Spaces:
Configuration error
Configuration error
| """CLI entrypoint to launch an RQ worker for STream3R jobs.""" | |
| from __future__ import annotations | |
| import argparse | |
| import logging | |
| import os | |
| from typing import Sequence | |
| from rq import Queue, Worker | |
| from .config import WorkerSettings | |
| from .runtime import get_runtime | |
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
def _parse_args(default_queues: Sequence[str]) -> argparse.Namespace:
    """Parse the worker's command-line options.

    Arguments are read from ``sys.argv`` (``parse_args`` is called without
    an explicit argument list).

    Args:
        default_queues: Queue names to fall back on when neither ``--queue``
            nor ``--queues`` appears on the command line.

    Returns:
        The parsed namespace exposing ``queues`` (list of queue names),
        ``burst`` (bool) and ``log_level`` (str, default ``"INFO"``).
    """
    cli = argparse.ArgumentParser(description="Run the STream3R RQ worker")

    # (flags, keyword options) for each supported command-line option.
    option_table = (
        (
            ("--queue", "--queues"),
            {
                "dest": "queues",
                "action": "append",
                "help": "Queue names to listen to (can be repeated)",
            },
        ),
        (
            ("--burst",),
            {
                "action": "store_true",
                "help": "Run in burst mode (exit when queues are empty)",
            },
        ),
        (
            ("--log-level",),
            {
                "default": "INFO",
                "help": "Logging level",
            },
        ),
    )
    for flags, options in option_table:
        cli.add_argument(*flags, **options)

    namespace = cli.parse_args()
    # With action="append" the attribute stays None when the flag is never
    # given, so substitute a fresh copy of the defaults in that case.
    namespace.queues = namespace.queues or list(default_queues)
    return namespace
class Stream3RWorker(Worker):
    """RQ worker that enforces a minimum job timeout and runs jobs under the GPU lock.

    Before a job is executed its ``timeout`` is raised to the configured
    default when it is missing, unparsable, or shorter than that default.
    Each job is then executed inside the context manager returned by
    ``get_runtime().gpu_lock()``, with the ``STREAM3R_GPU_LOCK_HELD``
    environment variable set to ``"1"`` for the duration of the call.
    """

    # RQ's sentinel for "no timeout"; such jobs must not be shortened.
    _INFINITE_TIMEOUT = -1
    # Environment flag set while a job runs under the GPU lock.
    # NOTE(review): presumably read by job code to skip re-acquiring the
    # lock -- confirm against the consumers of this variable.
    _GPU_LOCK_ENV = "STREAM3R_GPU_LOCK_HELD"

    def __init__(self, *args, default_timeout: int | None = None, **kwargs) -> None:
        """Create the worker.

        Args:
            *args: Positional arguments forwarded to ``Worker.__init__``.
            default_timeout: Minimum job timeout to enforce. ``None``, zero
                or a negative value disables enforcement.
            **kwargs: Keyword arguments forwarded to ``Worker.__init__``.
        """
        super().__init__(*args, **kwargs)
        self._default_timeout = default_timeout if default_timeout and default_timeout > 0 else None

    def _normalize_timeout(self, timeout_value: object) -> int | None:
        """Return ``timeout_value`` as an ``int``, or ``None`` if it cannot be converted."""
        if timeout_value is None:
            return None
        try:
            return int(timeout_value)
        except (TypeError, ValueError, OverflowError):
            # OverflowError covers values such as float("inf").
            return None

    def execute_job(self, job, queue):  # type: ignore[override]
        """Apply the default timeout, then run the job while holding the GPU lock.

        Args:
            job: The job to run; its ``timeout`` attribute may be updated.
            queue: The queue the job came from (forwarded unchanged).

        Returns:
            Whatever ``Worker.execute_job`` returns.
        """
        if self._default_timeout is not None:
            current = self._normalize_timeout(getattr(job, "timeout", None))
            # A timeout of -1 means "unlimited" in RQ, so it is already longer
            # than any default and must not be replaced by a finite value.
            unlimited = current == self._INFINITE_TIMEOUT
            if not unlimited and (current is None or current < self._default_timeout):
                job.timeout = self._default_timeout

        runtime = get_runtime()
        lock_ctx = runtime.gpu_lock()
        self.log.debug("Worker %s acquiring GPU lock for job %s", self.name, job.id)
        with lock_ctx:
            # Remember any pre-existing value so it can be restored afterwards
            # instead of being unconditionally removed.
            previous = os.environ.get(self._GPU_LOCK_ENV)
            os.environ[self._GPU_LOCK_ENV] = "1"
            try:
                return super().execute_job(job, queue)
            finally:
                if previous is None:
                    os.environ.pop(self._GPU_LOCK_ENV, None)
                else:
                    os.environ[self._GPU_LOCK_ENV] = previous
def main() -> None:
    """Build the worker from environment settings and CLI options, then run it.

    Steps, in order: load ``WorkerSettings`` from the environment, apply a
    positive default job timeout to ``Queue.DEFAULT_TIMEOUT``, parse the
    command line, configure logging, create one ``Queue`` per requested
    name on the runtime's Redis connection, and start a ``Stream3RWorker``
    on those queues.

    Raises:
        ValueError: If no queue names are configured.
    """
    settings = WorkerSettings.from_env()
    job_timeout = settings.default_job_timeout
    if job_timeout and job_timeout > 0:
        Queue.DEFAULT_TIMEOUT = job_timeout

    fallback_names = [settings.pose_queue, settings.model_queue, settings.keyframe_queue]
    options = _parse_args(fallback_names)

    # Unknown level names fall back to INFO.
    level_name = str(options.log_level).upper()
    logging.basicConfig(level=getattr(logging, level_name, logging.INFO))

    runtime = get_runtime()
    listening = []
    for queue_name in options.queues:
        listening.append(Queue(queue_name, connection=runtime.redis))
    if not listening:
        raise ValueError("No queues configured for worker")

    for active in listening:
        logger.info("Listening on queue '%s'", active.name)

    worker_options = {
        "connection": runtime.redis,
        "name": settings.worker_name,
        "default_timeout": job_timeout,
    }
    Stream3RWorker(listening, **worker_options).work(burst=options.burst)
# Start the worker only when this module is executed directly, not on import.
if __name__ == "__main__":  # pragma: no cover
    main()