Spaces:
Configuration error
Configuration error
| """Runtime registry for shared worker resources.""" | |
| from __future__ import annotations | |
| import logging | |
| from contextlib import contextmanager | |
| from threading import Lock | |
| from typing import Any, Dict, Mapping | |
| import redis | |
| import torch | |
| from stream3r.models.stream3r import STream3R | |
| from .config import WorkerSettings | |
| from .db import BaseDatabaseClient, create_database_client | |
| from .storage import StorageClient, create_storage_client | |
| logger = logging.getLogger(__name__) | |
| class WorkerRuntime: | |
| """Holds shared state reused across RQ jobs.""" | |
| def __init__(self, settings: WorkerSettings): | |
| self.settings = settings | |
| self._redis = redis.Redis.from_url( | |
| settings.redis_url, | |
| decode_responses=False, | |
| health_check_interval=settings.redis_healthcheck_interval, | |
| ) | |
| self.storage: StorageClient = create_storage_client(settings) | |
| self.db: BaseDatabaseClient = create_database_client(settings) | |
| try: | |
| self.db.ensure_schema() | |
| except Exception as exc: # pragma: no cover - depends on external DB | |
| logger.warning("Failed to ensure job schema: %s", exc) | |
| self._model: STream3R | None = None | |
| self._model_lock = Lock() | |
| self._device: torch.device | None = None | |
| self._autocast_dtype: torch.dtype | None = None | |
| # ----------------------------------------------------------------- | |
| # Redis helpers | |
| # ----------------------------------------------------------------- | |
| def redis(self) -> redis.Redis: | |
| return self._redis | |
| def emit_event(self, payload: Mapping[str, Any]) -> None: | |
| try: | |
| stream = self.settings.redis_events_stream | |
| data = {k: str(v) for k, v in payload.items() if v is not None} | |
| maxlen = self.settings.redis_stream_maxlen | |
| self._redis.xadd(stream, data, maxlen=maxlen, approximate=True) | |
| except redis.RedisError as exc: # pragma: no cover - depends on Redis | |
| logger.warning("Failed to emit event to Redis: %s", exc) | |
| def gpu_lock(self) -> Any: | |
| lock = self._redis.lock( | |
| self.settings.gpu_lock_key, | |
| timeout=self.settings.gpu_lock_timeout, | |
| blocking_timeout=self.settings.gpu_lock_blocking_timeout, | |
| ) | |
| acquired = False | |
| try: | |
| acquired = lock.acquire(blocking=True) | |
| if not acquired: | |
| raise TimeoutError("Timed out waiting for GPU lock") | |
| yield | |
| finally: | |
| if acquired: | |
| try: | |
| lock.release() | |
| except redis.RedisError: # pragma: no cover - depends on Redis | |
| logger.debug("GPU lock already released") | |
| # ----------------------------------------------------------------- | |
| # Model helpers | |
| # ----------------------------------------------------------------- | |
| def _resolve_device(self) -> torch.device: | |
| if self._device is not None: | |
| return self._device | |
| preference = self.settings.model_device_preference | |
| if preference: | |
| try: | |
| device = torch.device(preference) | |
| except (ValueError, RuntimeError): | |
| logger.warning("Unknown device preference '%s', falling back to auto", preference) | |
| device = None | |
| else: | |
| device = None | |
| if device is None: | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| self._device = device | |
| return device | |
| def _resolve_autocast_dtype(self) -> torch.dtype: | |
| if self._autocast_dtype is not None: | |
| return self._autocast_dtype | |
| dtype_name = self.settings.model_dtype | |
| if dtype_name: | |
| try: | |
| self._autocast_dtype = getattr(torch, dtype_name) | |
| return self._autocast_dtype | |
| except AttributeError: | |
| logger.warning("Unsupported dtype '%s', using default", dtype_name) | |
| if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8: | |
| self._autocast_dtype = torch.bfloat16 | |
| else: | |
| self._autocast_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 | |
| return self._autocast_dtype | |
| def get_model(self) -> STream3R: | |
| with self._model_lock: | |
| if self._model is None: | |
| logger.info("Loading STream3R model '%s'", self.settings.model_id) | |
| model = STream3R.from_pretrained( | |
| self.settings.model_id, | |
| revision=self.settings.model_revision, | |
| ) | |
| device = self._resolve_device() | |
| model.to(device) | |
| model.eval() | |
| self._model = model | |
| return self._model | |
| def model_device(self) -> torch.device: | |
| return self._resolve_device() | |
| def autocast_dtype(self) -> torch.dtype: | |
| return self._resolve_autocast_dtype() | |
| # ----------------------------------------------------------------- | |
| def close(self) -> None: | |
| try: | |
| self.db.close() | |
| except AttributeError: | |
| pass | |
_RUNTIME: WorkerRuntime | None = None


def get_runtime() -> WorkerRuntime:
    """Return the process-wide ``WorkerRuntime``, creating it on first use."""
    global _RUNTIME
    if _RUNTIME is None:
        _RUNTIME = WorkerRuntime(WorkerSettings.from_env())
    return _RUNTIME