brian4dwell committed on
Commit
f1e0138
·
1 Parent(s): 1c5aca1

working for larger batches now

Browse files
notes.md ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+
2
+
3
+ **Manually Clear the GPU Lock from Redis:**
4
+ redis-cli -u "$REDIS_URL" DEL gpu:lock
stream3r/worker/__init__.py CHANGED
@@ -1,6 +1,14 @@
1
  """Worker utilities for running STream3R jobs via RQ."""
2
 
3
- from .tasks import model_build_job, pose_pointmap_job
 
 
 
 
 
 
 
 
4
 
5
  __all__ = [
6
  "pose_pointmap_job",
 
1
  """Worker utilities for running STream3R jobs via RQ."""
2
 
3
+ from rq.queue import Queue
4
+
5
+ from .config import WorkerSettings
6
+
7
+ _settings = WorkerSettings.from_env()
8
+ if _settings.default_job_timeout and _settings.default_job_timeout > 0:
9
+ Queue.DEFAULT_TIMEOUT = _settings.default_job_timeout
10
+
11
+ from .tasks import model_build_job, pose_pointmap_job # noqa: E402
12
 
13
  __all__ = [
14
  "pose_pointmap_job",
stream3r/worker/config.py CHANGED
@@ -114,8 +114,9 @@ class WorkerSettings:
114
  scene_media_api_base_url: str | None = None
115
  scene_media_api_token: str | None = None
116
  scene_media_page_size: int = 200
117
- stream_window_size: int = 10
118
- max_frames_per_job: int = 10
 
119
 
120
  @classmethod
121
  def from_env(cls) -> "WorkerSettings":
@@ -208,6 +209,9 @@ class WorkerSettings:
208
  "max_frames_per_job": _env_int(
209
  "STREAM3R_MAX_FRAMES", base.max_frames_per_job
210
  ),
 
 
 
211
  }
212
 
213
  return cls(**kwargs)
 
114
  scene_media_api_base_url: str | None = None
115
  scene_media_api_token: str | None = None
116
  scene_media_page_size: int = 200
117
+ stream_window_size: int = 20
118
+ max_frames_per_job: int = 0
119
+ default_job_timeout: int = 15 * 60
120
 
121
  @classmethod
122
  def from_env(cls) -> "WorkerSettings":
 
209
  "max_frames_per_job": _env_int(
210
  "STREAM3R_MAX_FRAMES", base.max_frames_per_job
211
  ),
212
+ "default_job_timeout": _env_int(
213
+ "STREAM3R_JOB_TIMEOUT", base.default_job_timeout
214
+ ),
215
  }
216
 
217
  return cls(**kwargs)
stream3r/worker/main.py CHANGED
@@ -4,6 +4,7 @@ from __future__ import annotations
4
 
5
  import argparse
6
  import logging
 
7
  from typing import Sequence
8
 
9
  from rq import Queue, Worker
@@ -40,8 +41,42 @@ def _parse_args(default_queues: Sequence[str]) -> argparse.Namespace:
40
  return args
41
 
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  def main() -> None:
44
  settings = WorkerSettings.from_env()
 
 
 
45
  args = _parse_args([settings.pose_queue, settings.model_queue])
46
  logging.basicConfig(level=getattr(logging, str(args.log_level).upper(), logging.INFO))
47
 
@@ -51,7 +86,11 @@ def main() -> None:
51
  for queue in queues:
52
  logger.info("Listening on queue '%s'", queue.name)
53
 
54
- worker = Worker(queues, name=settings.worker_name)
 
 
 
 
55
  worker.work(burst=args.burst)
56
 
57
 
 
4
 
5
  import argparse
6
  import logging
7
+ import os
8
  from typing import Sequence
9
 
10
  from rq import Queue, Worker
 
41
  return args
42
 
43
 
44
class Stream3RWorker(Worker):
    """RQ worker that enforces a minimum job timeout and holds the GPU lock.

    Before delegating to ``Worker.execute_job`` the worker raises any job
    timeout that is missing or shorter than the configured default, then
    acquires the runtime GPU lock for the duration of the job. While the lock
    is held, the ``STREAM3R_GPU_LOCK_HELD`` environment variable is set to
    ``"1"`` so that code running under this worker can detect that the lock
    has already been taken.
    """

    def __init__(self, *args, default_timeout: int | None = None, **kwargs) -> None:
        """Initialise the worker.

        Args:
            *args: Positional arguments forwarded unchanged to ``Worker``.
            default_timeout: Minimum timeout (seconds) to apply to each job.
                ``None``, ``0`` or a negative value disables the enforcement.
            **kwargs: Keyword arguments forwarded unchanged to ``Worker``.
        """
        super().__init__(*args, **kwargs)
        self._default_timeout = default_timeout if default_timeout and default_timeout > 0 else None

    def _normalize_timeout(self, timeout_value: object) -> int | None:
        """Coerce ``timeout_value`` to ``int``.

        Returns:
            The integer value, or ``None`` when the input is ``None`` or
            cannot be converted with ``int()``.
        """
        if timeout_value is None:
            return None
        try:
            return int(timeout_value)
        except (TypeError, ValueError):
            return None

    def execute_job(self, job, queue):  # type: ignore[override]
        """Run ``job`` under the GPU lock after applying the minimum timeout.

        Args:
            job: The RQ job about to be executed; its ``timeout`` attribute
                may be raised to the configured default.
            queue: The queue the job was taken from, passed through to
                ``Worker.execute_job``.

        Returns:
            Whatever ``Worker.execute_job`` returns.
        """
        minimum = self._default_timeout
        if minimum is not None:
            current = self._normalize_timeout(getattr(job, "timeout", None))
            # RQ uses a negative timeout (-1) to mean "no timeout". Such a
            # value compares below any positive default, so it must be
            # excluded explicitly or an infinite timeout would be silently
            # replaced by the finite default.
            if current is None or 0 <= current < minimum:
                job.timeout = minimum
        runtime = get_runtime()
        lock_ctx = runtime.gpu_lock()
        self.log.debug("Worker %s acquiring GPU lock for job %s", self.name, job.id)
        with lock_ctx:
            # Advertise that the lock is held for the duration of the job.
            os.environ["STREAM3R_GPU_LOCK_HELD"] = "1"
            try:
                return super().execute_job(job, queue)
            finally:
                # Always clear the marker, even when the job raises.
                os.environ.pop("STREAM3R_GPU_LOCK_HELD", None)
73
+
74
+
75
  def main() -> None:
76
  settings = WorkerSettings.from_env()
77
+ if settings.default_job_timeout and settings.default_job_timeout > 0:
78
+ Queue.DEFAULT_TIMEOUT = settings.default_job_timeout
79
+
80
  args = _parse_args([settings.pose_queue, settings.model_queue])
81
  logging.basicConfig(level=getattr(logging, str(args.log_level).upper(), logging.INFO))
82
 
 
86
  for queue in queues:
87
  logger.info("Listening on queue '%s'", queue.name)
88
 
89
+ worker = Stream3RWorker(
90
+ queues,
91
+ name=settings.worker_name,
92
+ default_timeout=settings.default_job_timeout,
93
+ )
94
  worker.work(burst=args.burst)
95
 
96
 
stream3r/worker/pipeline.py CHANGED
@@ -86,6 +86,7 @@ def run_stream3r_inference(
86
  frame = images[idx : idx + 1].to(device)
87
  with autocast_ctx:
88
  session.forward_stream(frame)
 
89
  if progress_cb is not None:
90
  progress_cb(idx + 1, total_frames)
91
 
 
86
  frame = images[idx : idx + 1].to(device)
87
  with autocast_ctx:
88
  session.forward_stream(frame)
89
+ print(f"Processed frame {idx + 1}/{total_frames}")
90
  if progress_cb is not None:
91
  progress_cb(idx + 1, total_frames)
92
 
stream3r/worker/tasks.py CHANGED
@@ -5,6 +5,7 @@ from __future__ import annotations
5
  import base64
6
  import json
7
  import logging
 
8
  import re
9
  import shutil
10
  import tempfile
@@ -13,6 +14,7 @@ import uuid
13
  from dataclasses import dataclass, field
14
  from datetime import datetime, timezone
15
  from pathlib import Path
 
16
  from typing import Any, Callable, Mapping
17
 
18
  import numpy as np
@@ -797,16 +799,35 @@ def _execute_job(job_type: str, payload: Mapping[str, Any], handler: JobHandler)
797
  window_size = None
798
 
799
  payload["mode"] = mode
 
800
  timeout_override = payload.get("timeout")
 
 
801
  if timeout_override is not None:
802
  try:
803
- job.timeout = int(timeout_override)
 
 
804
  except (TypeError, ValueError):
805
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
806
 
807
- # Default to 15 minutes if no timeout already applied
808
- if job.timeout is None:
809
- job.timeout = 15 * 60
810
 
811
  sanitized_payload = _sanitize_payload(payload)
812
 
@@ -834,8 +855,10 @@ def _execute_job(job_type: str, payload: Mapping[str, Any], handler: JobHandler)
834
  },
835
  )
836
 
 
 
837
  try:
838
- with runtime.gpu_lock():
839
  with tempfile.TemporaryDirectory(prefix=f"stream3r_{job_id}_") as tmp_dir:
840
  temp_path = Path(tmp_dir)
841
  frame_records = _collect_frames(runtime, scene_id, payload, temp_path)
 
5
  import base64
6
  import json
7
  import logging
8
+ import os
9
  import re
10
  import shutil
11
  import tempfile
 
14
  from dataclasses import dataclass, field
15
  from datetime import datetime, timezone
16
  from pathlib import Path
17
+ from contextlib import nullcontext
18
  from typing import Any, Callable, Mapping
19
 
20
  import numpy as np
 
799
  window_size = None
800
 
801
  payload["mode"] = mode
802
+ desired_timeout = runtime.settings.default_job_timeout
803
  timeout_override = payload.get("timeout")
804
+ applied_timeout: int | None = None
805
+
806
  if timeout_override is not None:
807
  try:
808
+ applied_timeout = int(timeout_override)
809
+ if job is not None:
810
+ job.timeout = applied_timeout
811
  except (TypeError, ValueError):
812
+ applied_timeout = None
813
+
814
+ if applied_timeout is None and desired_timeout and desired_timeout > 0:
815
+ if job is not None:
816
+ current_timeout = getattr(job, "timeout", None)
817
+ try:
818
+ current_timeout_value = int(current_timeout) if current_timeout is not None else None
819
+ except (TypeError, ValueError):
820
+ current_timeout_value = None
821
+ if current_timeout_value is None or current_timeout_value < desired_timeout:
822
+ job.timeout = desired_timeout
823
+ applied_timeout = desired_timeout
824
+ else:
825
+ applied_timeout = current_timeout_value
826
+ else:
827
+ applied_timeout = desired_timeout
828
 
829
+ if applied_timeout is not None:
830
+ payload["timeout"] = applied_timeout
 
831
 
832
  sanitized_payload = _sanitize_payload(payload)
833
 
 
855
  },
856
  )
857
 
858
+ lock_ctx = nullcontext() if os.getenv("STREAM3R_GPU_LOCK_HELD") == "1" else runtime.gpu_lock()
859
+
860
  try:
861
+ with lock_ctx:
862
  with tempfile.TemporaryDirectory(prefix=f"stream3r_{job_id}_") as tmp_dir:
863
  temp_path = Path(tmp_dir)
864
  frame_records = _collect_frames(runtime, scene_id, payload, temp_path)