brian4dwell's picture
workflows
4b27cfa
"""CLI entrypoint to launch an RQ worker for STream3R jobs."""
from __future__ import annotations
import argparse
import logging
import os
from typing import Sequence
from rq import Queue, Worker
from .config import WorkerSettings
from .runtime import get_runtime
logger = logging.getLogger(__name__)
def _parse_args(default_queues: Sequence[str]) -> argparse.Namespace:
    """Parse the worker's command-line options from ``sys.argv``.

    Args:
        default_queues: Queue names used when neither ``--queue`` nor
            ``--queues`` appears on the command line.

    Returns:
        The parsed namespace with ``queues`` (list of queue names),
        ``burst`` (bool) and ``log_level`` (str, ``"INFO"`` by default).
    """
    cli = argparse.ArgumentParser(description="Run the STream3R RQ worker")

    # Each entry is (option flags, keyword arguments for add_argument).
    option_table = (
        (
            ("--queue", "--queues"),
            {
                "dest": "queues",
                "action": "append",
                "help": "Queue names to listen to (can be repeated)",
            },
        ),
        (
            ("--burst",),
            {
                "action": "store_true",
                "help": "Run in burst mode (exit when queues are empty)",
            },
        ),
        (
            ("--log-level",),
            {
                "default": "INFO",
                "help": "Logging level",
            },
        ),
    )
    for flags, options in option_table:
        cli.add_argument(*flags, **options)

    namespace = cli.parse_args()
    # With action="append" the attribute stays None when the option is never
    # given; in that case fall back to a fresh list copy of the defaults.
    namespace.queues = namespace.queues or list(default_queues)
    return namespace
class Stream3RWorker(Worker):
    """RQ worker that applies a timeout floor before running each job.

    When a positive ``default_timeout`` is configured, any job whose own
    timeout is missing, not convertible to ``int``, or smaller than that
    value has its ``timeout`` raised to it. The job is then executed inside
    the context manager returned by ``get_runtime().gpu_lock()``, with the
    ``STREAM3R_GPU_LOCK_HELD`` environment variable set to ``"1"`` for the
    duration of the call.
    """

    def __init__(self, *args, default_timeout: int | None = None, **kwargs) -> None:
        """Initialise the underlying worker and remember the timeout floor.

        Args:
            *args: Positional arguments forwarded to ``Worker.__init__``.
            default_timeout: Minimum job timeout; ``None``, zero and
                negative values disable the adjustment.
            **kwargs: Keyword arguments forwarded to ``Worker.__init__``.
        """
        super().__init__(*args, **kwargs)
        self._default_timeout = None
        if default_timeout and default_timeout > 0:
            self._default_timeout = default_timeout

    def _normalize_timeout(self, timeout_value: object) -> int | None:
        """Return ``timeout_value`` as an ``int``, or ``None`` if unusable."""
        if timeout_value is not None:
            try:
                return int(timeout_value)
            except (TypeError, ValueError):
                # Not int-convertible: treated the same as "no timeout set".
                pass
        return None

    def execute_job(self, job, queue):  # type: ignore[override]
        """Raise the job's timeout to the floor if needed, then run it.

        Args:
            job: The job to execute; its ``timeout`` attribute may be
                overwritten.
            queue: The queue the job was taken from.

        Returns:
            Whatever ``Worker.execute_job`` returns.
        """
        floor = self._default_timeout
        if floor is not None:
            existing = self._normalize_timeout(getattr(job, "timeout", None))
            if existing is None or existing < floor:
                job.timeout = floor

        gpu_guard = get_runtime().gpu_lock()
        self.log.debug("Worker %s acquiring GPU lock for job %s", self.name, job.id)
        with gpu_guard:
            # The marker is set before delegating and always removed
            # afterwards, even when the delegated call raises.
            os.environ["STREAM3R_GPU_LOCK_HELD"] = "1"
            try:
                return super().execute_job(job, queue)
            finally:
                os.environ.pop("STREAM3R_GPU_LOCK_HELD", None)
def main() -> None:
    """Configure and start the worker from environment and CLI options.

    Loads ``WorkerSettings`` from the environment, applies a positive
    ``default_job_timeout`` to ``Queue.DEFAULT_TIMEOUT``, parses the
    command line, configures logging, builds one ``Queue`` per selected
    name and hands them to a ``Stream3RWorker`` whose ``work`` loop is then
    started.

    Raises:
        ValueError: If no queues end up being configured.
    """
    settings = WorkerSettings.from_env()
    job_timeout = settings.default_job_timeout
    if job_timeout and job_timeout > 0:
        Queue.DEFAULT_TIMEOUT = job_timeout

    fallback_queue_names = [
        settings.pose_queue,
        settings.model_queue,
        settings.keyframe_queue,
    ]
    args = _parse_args(fallback_queue_names)

    # Unknown level names fall back to INFO.
    level_name = str(args.log_level).upper()
    log_level = getattr(logging, level_name, logging.INFO)
    logging.basicConfig(level=log_level)

    runtime = get_runtime()
    queues = [Queue(queue_name, connection=runtime.redis) for queue_name in args.queues]
    if not queues:
        raise ValueError("No queues configured for worker")
    for listened in queues:
        logger.info("Listening on queue '%s'", listened.name)

    worker = Stream3RWorker(
        queues,
        connection=runtime.redis,
        name=settings.worker_name,
        default_timeout=job_timeout,
    )
    worker.work(burst=args.burst)
# Start the worker only when this module is run as a script, not on import.
if __name__ == "__main__": # pragma: no cover
    main()