"""CLI entrypoint to launch an RQ worker for STream3R jobs."""

from __future__ import annotations

import argparse
import logging
import os
from typing import Sequence

from rq import Queue, Worker

from .config import WorkerSettings
from .runtime import get_runtime

logger = logging.getLogger(__name__)

# Environment flag set to "1" while this worker holds the GPU lock for a job.
_GPU_LOCK_ENV_VAR = "STREAM3R_GPU_LOCK_HELD"


def _unique_queue_names(names: Sequence[str | None]) -> list[str]:
    """Return *names* with empty/None entries dropped and duplicates removed.

    The order of first occurrence is preserved, so the caller's queue ordering
    is kept intact.
    """
    return list(dict.fromkeys(name for name in names if name))


def _resolve_log_level(value: object) -> int:
    """Translate a ``--log-level`` value into a numeric logging level.

    Accepts level names (case-insensitive, e.g. ``debug``) and integer strings
    (e.g. ``10``). Anything unrecognised falls back to ``logging.INFO``.
    """
    text = str(value).strip()
    try:
        return int(text)
    except ValueError:
        pass
    level = getattr(logging, text.upper(), None)
    # getattr() can also return unrelated module attributes (functions, bools
    # such as ``raiseExceptions``); only accept genuine integer levels.
    if isinstance(level, int) and not isinstance(level, bool):
        return level
    return logging.INFO


def _parse_args(
    default_queues: Sequence[str],
    argv: Sequence[str] | None = None,
) -> argparse.Namespace:
    """Parse the worker's command-line arguments.

    Args:
        default_queues: Queue names used when no ``--queue`` option is given.
        argv: Arguments to parse; ``None`` (the default) parses ``sys.argv[1:]``.

    Returns:
        Namespace with ``queues`` (deduplicated list of non-empty queue names),
        ``burst`` (bool) and ``log_level`` (str).
    """
    parser = argparse.ArgumentParser(description="Run the STream3R RQ worker")
    parser.add_argument(
        "--queue",
        "--queues",
        dest="queues",
        action="append",
        help="Queue names to listen to (can be repeated)",
    )
    parser.add_argument(
        "--burst",
        action="store_true",
        help="Run in burst mode (exit when queues are empty)",
    )
    parser.add_argument(
        "--log-level",
        default="INFO",
        help="Logging level",
    )
    args = parser.parse_args(None if argv is None else list(argv))
    # action="append" leaves ``queues`` as None when the flag is never given.
    args.queues = _unique_queue_names(args.queues or default_queues)
    return args


class Stream3RWorker(Worker):
    """RQ worker that enforces configured default timeouts before execution.

    Every job is executed while holding the runtime's GPU lock. For that
    duration the ``STREAM3R_GPU_LOCK_HELD`` environment variable is set to
    ``"1"``.
    """

    def __init__(self, *args, default_timeout: int | None = None, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Missing or non-positive values disable timeout enforcement.
        self._default_timeout = (
            default_timeout if default_timeout and default_timeout > 0 else None
        )

    def _normalize_timeout(self, timeout_value: object) -> int | None:
        """Coerce a job timeout to ``int``; return ``None`` if missing or unparseable."""
        if timeout_value is None:
            return None
        try:
            return int(timeout_value)
        except (TypeError, ValueError):
            return None

    def execute_job(self, job, queue):  # type: ignore[override]
        """Run *job* under the GPU lock after raising its timeout to the configured minimum."""
        if self._default_timeout is not None:
            current = self._normalize_timeout(getattr(job, "timeout", None))
            # NOTE(review): a job timeout of -1 (which RQ may treat as "no
            # timeout") is also raised to the default here — confirm intended.
            if current is None or current < self._default_timeout:
                job.timeout = self._default_timeout
        runtime = get_runtime()
        lock_ctx = runtime.gpu_lock()
        self.log.debug("Worker %s acquiring GPU lock for job %s", self.name, job.id)
        with lock_ctx:
            previous = os.environ.get(_GPU_LOCK_ENV_VAR)
            os.environ[_GPU_LOCK_ENV_VAR] = "1"
            try:
                return super().execute_job(job, queue)
            finally:
                # Restore the prior state rather than always deleting the
                # variable, so a value present before the job is not lost.
                if previous is None:
                    os.environ.pop(_GPU_LOCK_ENV_VAR, None)
                else:
                    os.environ[_GPU_LOCK_ENV_VAR] = previous


def main() -> None:
    """Build queues from settings and CLI options, then run the worker loop.

    Raises:
        ValueError: If no usable queue names remain after parsing.
    """
    settings = WorkerSettings.from_env()
    if settings.default_job_timeout and settings.default_job_timeout > 0:
        # Override the class-level Queue default for this process.
        Queue.DEFAULT_TIMEOUT = settings.default_job_timeout
    args = _parse_args(
        [settings.pose_queue, settings.model_queue, settings.keyframe_queue]
    )
    logging.basicConfig(level=_resolve_log_level(args.log_level))
    # Fail fast before touching the runtime / Redis connection.
    if not args.queues:
        raise ValueError("No queues configured for worker")
    runtime = get_runtime()
    queues = [Queue(name, connection=runtime.redis) for name in args.queues]
    for queue in queues:
        logger.info("Listening on queue '%s'", queue.name)
    worker = Stream3RWorker(
        queues,
        connection=runtime.redis,
        name=settings.worker_name,
        default_timeout=settings.default_job_timeout,
    )
    worker.work(burst=args.burst)


if __name__ == "__main__":  # pragma: no cover
    main()