# Commit 1c5aca1 — "init worker working" (author: brian4dwell)
"""Runtime registry for shared worker resources."""
from __future__ import annotations
import logging
from contextlib import contextmanager
from threading import Lock
from typing import Any, Dict, Mapping
import redis
import torch
from stream3r.models.stream3r import STream3R
from .config import WorkerSettings
from .db import BaseDatabaseClient, create_database_client
from .storage import StorageClient, create_storage_client
logger = logging.getLogger(__name__)
class WorkerRuntime:
    """Holds shared state reused across RQ jobs.

    A single instance per worker process owns the Redis connection, the
    storage and database clients, and a lazily loaded STream3R model so
    individual jobs do not repeat expensive setup.
    """

    def __init__(self, settings: WorkerSettings):
        """Build shared clients from *settings*; schema setup is best-effort."""
        self.settings = settings
        self._redis = redis.Redis.from_url(
            settings.redis_url,
            decode_responses=False,
            health_check_interval=settings.redis_healthcheck_interval,
        )
        self.storage: StorageClient = create_storage_client(settings)
        self.db: BaseDatabaseClient = create_database_client(settings)
        try:
            self.db.ensure_schema()
        except Exception as exc:  # pragma: no cover - depends on external DB
            # Best-effort: the worker can still run if the schema already
            # exists, so log the failure instead of aborting startup.
            logger.warning("Failed to ensure job schema: %s", exc)
        # Model state is populated lazily by get_model(); the lock serializes
        # the first load across concurrent jobs.
        self._model: STream3R | None = None
        self._model_lock = Lock()
        self._device: torch.device | None = None
        self._autocast_dtype: torch.dtype | None = None

    # -----------------------------------------------------------------
    # Redis helpers
    # -----------------------------------------------------------------
    @property
    def redis(self) -> redis.Redis:
        """The shared Redis connection."""
        return self._redis

    def emit_event(self, payload: Mapping[str, Any]) -> None:
        """Append *payload* to the worker event stream (best-effort).

        ``None`` values are dropped and remaining values are stringified,
        since Redis stream fields must be strings/bytes. Redis errors are
        logged rather than raised.
        """
        try:
            stream = self.settings.redis_events_stream
            data = {k: str(v) for k, v in payload.items() if v is not None}
            maxlen = self.settings.redis_stream_maxlen
            self._redis.xadd(stream, data, maxlen=maxlen, approximate=True)
        except redis.RedisError as exc:  # pragma: no cover - depends on Redis
            logger.warning("Failed to emit event to Redis: %s", exc)

    @contextmanager
    def gpu_lock(self) -> Any:
        """Context manager serializing GPU access via a distributed lock.

        Raises:
            TimeoutError: if the lock cannot be acquired within the
                configured ``gpu_lock_blocking_timeout``.
        """
        lock = self._redis.lock(
            self.settings.gpu_lock_key,
            timeout=self.settings.gpu_lock_timeout,
            blocking_timeout=self.settings.gpu_lock_blocking_timeout,
        )
        acquired = False
        try:
            acquired = lock.acquire(blocking=True)
            if not acquired:
                raise TimeoutError("Timed out waiting for GPU lock")
            yield
        finally:
            if acquired:
                try:
                    lock.release()
                except redis.RedisError:  # pragma: no cover - depends on Redis
                    # The lock may have expired server-side in the meantime.
                    logger.debug("GPU lock already released")

    # -----------------------------------------------------------------
    # Model helpers
    # -----------------------------------------------------------------
    def _resolve_device(self) -> torch.device:
        """Resolve and cache the torch device used for model execution.

        An explicit ``model_device_preference`` wins when it names a valid
        device; otherwise fall back to CUDA if available, else CPU.
        """
        if self._device is not None:
            return self._device
        preference = self.settings.model_device_preference
        device = None
        if preference:
            try:
                device = torch.device(preference)
            except (ValueError, RuntimeError):
                logger.warning("Unknown device preference '%s', falling back to auto", preference)
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._device = device
        return device

    def _resolve_autocast_dtype(self) -> torch.dtype:
        """Resolve and cache the autocast dtype.

        A configured ``model_dtype`` name is honoured only when it names a
        real ``torch.dtype``; otherwise choose bfloat16 on compute
        capability 8.0+ GPUs, float16 on older GPUs, float32 on CPU.
        """
        if self._autocast_dtype is not None:
            return self._autocast_dtype
        dtype_name = self.settings.model_dtype
        if dtype_name:
            candidate = getattr(torch, dtype_name, None)
            # A bare getattr is not enough: names like "tensor" exist on the
            # torch module but are not dtypes, so validate the type as well.
            if isinstance(candidate, torch.dtype):
                self._autocast_dtype = candidate
                return candidate
            logger.warning("Unsupported dtype '%s', using default", dtype_name)
        if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
            # Compute capability 8.0+ (Ampere) has native bfloat16 support.
            self._autocast_dtype = torch.bfloat16
        else:
            self._autocast_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        return self._autocast_dtype

    def get_model(self) -> STream3R:
        """Return the shared STream3R model, loading it on first use.

        Thread-safe: the model lock ensures the checkpoint is loaded exactly
        once even when multiple jobs request it concurrently.
        """
        with self._model_lock:
            if self._model is None:
                logger.info("Loading STream3R model '%s'", self.settings.model_id)
                model = STream3R.from_pretrained(
                    self.settings.model_id,
                    revision=self.settings.model_revision,
                )
                model.to(self._resolve_device())
                model.eval()
                self._model = model
            return self._model

    def model_device(self) -> torch.device:
        """Device the model runs on (resolved lazily, cached)."""
        return self._resolve_device()

    def autocast_dtype(self) -> torch.dtype:
        """Autocast dtype for inference (resolved lazily, cached)."""
        return self._resolve_autocast_dtype()

    # -----------------------------------------------------------------
    def close(self) -> None:
        """Release the database and Redis connections held by the runtime."""
        try:
            self.db.close()
        except AttributeError:
            # Some database clients do not implement close(); ignore.
            pass
        try:
            # Previously leaked: the Redis connection was never closed.
            self._redis.close()
        except redis.RedisError:  # pragma: no cover - depends on Redis
            logger.debug("Failed to close Redis connection cleanly")
# Process-wide runtime singleton; created lazily by get_runtime().
_RUNTIME: WorkerRuntime | None = None


def get_runtime() -> WorkerRuntime:
    """Return the process-wide :class:`WorkerRuntime`, creating it on first use."""
    global _RUNTIME
    if _RUNTIME is None:
        _RUNTIME = WorkerRuntime(WorkerSettings.from_env())
    return _RUNTIME