File size: 5,458 Bytes
1c5aca1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""Runtime registry for shared worker resources."""

from __future__ import annotations

import logging
from contextlib import contextmanager
from threading import Lock
from typing import Any, Dict, Mapping

import redis
import torch

from stream3r.models.stream3r import STream3R

from .config import WorkerSettings
from .db import BaseDatabaseClient, create_database_client
from .storage import StorageClient, create_storage_client


logger = logging.getLogger(__name__)


class WorkerRuntime:
    """Holds shared state reused across RQ jobs.

    One instance per worker process bundles the Redis connection, the
    storage and database clients, and a lazily-loaded STream3R model so
    individual jobs do not repeatedly pay connection/model-load costs.
    """

    def __init__(self, settings: WorkerSettings):
        self.settings = settings
        # decode_responses=False: payloads are handled as raw bytes.
        self._redis = redis.Redis.from_url(
            settings.redis_url,
            decode_responses=False,
            health_check_interval=settings.redis_healthcheck_interval,
        )
        self.storage: StorageClient = create_storage_client(settings)
        self.db: BaseDatabaseClient = create_database_client(settings)

        # Best-effort: schema-creation failures are logged, not fatal, so a
        # worker can still start while the DB is temporarily unreachable.
        try:
            self.db.ensure_schema()
        except Exception as exc:  # pragma: no cover - depends on external DB
            logger.warning("Failed to ensure job schema: %s", exc)

        # Model/device state is resolved lazily; see get_model() and the
        # _resolve_* helpers. _model_lock guards one-time model loading.
        self._model: STream3R | None = None
        self._model_lock = Lock()
        self._device: torch.device | None = None
        self._autocast_dtype: torch.dtype | None = None

    # -----------------------------------------------------------------
    # Redis helpers
    # -----------------------------------------------------------------
    @property
    def redis(self) -> redis.Redis:
        """The shared Redis client for this worker process."""
        return self._redis

    def emit_event(self, payload: Mapping[str, Any]) -> None:
        """Append *payload* to the configured Redis event stream.

        Values are stringified and ``None`` values dropped; stream length is
        capped (approximately) at ``redis_stream_maxlen``. Failures are
        logged and swallowed — event emission is best-effort.
        """
        try:
            stream = self.settings.redis_events_stream
            data = {k: str(v) for k, v in payload.items() if v is not None}
            maxlen = self.settings.redis_stream_maxlen
            self._redis.xadd(stream, data, maxlen=maxlen, approximate=True)
        except redis.RedisError as exc:  # pragma: no cover - depends on Redis
            logger.warning("Failed to emit event to Redis: %s", exc)

    @contextmanager
    def gpu_lock(self) -> Any:
        """Hold the cross-process Redis GPU lock for the duration of the block.

        Raises:
            TimeoutError: if the lock cannot be acquired within the
                configured blocking timeout.
        """
        lock = self._redis.lock(
            self.settings.gpu_lock_key,
            timeout=self.settings.gpu_lock_timeout,
            blocking_timeout=self.settings.gpu_lock_blocking_timeout,
        )
        acquired = False
        try:
            acquired = lock.acquire(blocking=True)
            if not acquired:
                raise TimeoutError("Timed out waiting for GPU lock")
            yield
        finally:
            if acquired:
                try:
                    lock.release()
                except redis.RedisError:  # pragma: no cover - depends on Redis
                    logger.debug("GPU lock already released")

    # -----------------------------------------------------------------
    # Model helpers
    # -----------------------------------------------------------------
    def _resolve_device(self) -> torch.device:
        """Pick (and cache) the torch device: config preference, else auto."""
        if self._device is not None:
            return self._device

        preference = self.settings.model_device_preference
        if preference:
            try:
                device = torch.device(preference)
            except (ValueError, RuntimeError):
                logger.warning("Unknown device preference '%s', falling back to auto", preference)
                device = None
        else:
            device = None

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self._device = device
        return device

    def _resolve_autocast_dtype(self) -> torch.dtype:
        """Pick (and cache) the autocast dtype from config or hardware."""
        if self._autocast_dtype is not None:
            return self._autocast_dtype

        dtype_name = self.settings.model_dtype
        if dtype_name:
            dtype = getattr(torch, dtype_name, None)
            # Only accept genuine torch.dtype attributes: a config value
            # such as "nn" resolves via getattr to a module, which the old
            # AttributeError-only guard would have cached as the "dtype".
            if isinstance(dtype, torch.dtype):
                self._autocast_dtype = dtype
                return dtype
            logger.warning("Unsupported dtype '%s', using default", dtype_name)

        # bfloat16 requires compute capability >= 8 (Ampere or newer);
        # otherwise use float16 on CUDA and float32 on CPU.
        if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
            self._autocast_dtype = torch.bfloat16
        else:
            self._autocast_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        return self._autocast_dtype

    def get_model(self) -> STream3R:
        """Return the shared STream3R model, loading it on first use.

        Loading happens at most once per process; `_model_lock` serializes
        concurrent first calls.
        """
        with self._model_lock:
            if self._model is None:
                logger.info("Loading STream3R model '%s'", self.settings.model_id)
                model = STream3R.from_pretrained(
                    self.settings.model_id,
                    revision=self.settings.model_revision,
                )
                device = self._resolve_device()
                model.to(device)
                model.eval()
                self._model = model
            return self._model

    def model_device(self) -> torch.device:
        """Device the model runs on (resolved lazily)."""
        return self._resolve_device()

    def autocast_dtype(self) -> torch.dtype:
        """Dtype used for autocast during inference (resolved lazily)."""
        return self._resolve_autocast_dtype()

    # -----------------------------------------------------------------
    def close(self) -> None:
        """Release external connections held by this runtime."""
        try:
            self.db.close()
        except AttributeError:
            pass
        # Previously the Redis connection opened in __init__ was leaked;
        # close it best-effort, mirroring the DB handling above.
        try:
            self._redis.close()
        except (AttributeError, redis.RedisError):  # pragma: no cover - depends on Redis
            logger.debug("Redis connection already closed")


# Process-wide singleton; populated on first get_runtime() call.
_RUNTIME: WorkerRuntime | None = None


def get_runtime() -> WorkerRuntime:
    """Return the process-wide :class:`WorkerRuntime`, creating it lazily.

    # NOTE(review): no lock around first construction — assumes callers are
    # single-threaded within a worker process; confirm against the worker
    # entry point.
    """
    global _RUNTIME
    if _RUNTIME is None:
        _RUNTIME = WorkerRuntime(WorkerSettings.from_env())
    return _RUNTIME