| """ |
| Model readiness tracker untuk CV Pipeline. |
| |
| Track state setiap model (LOADING / READY / ERROR) supaya: |
| 1. Frontend bisa polling endpoint /ready dan tau kapan boleh kirim request |
| 2. Endpoint inference bisa nunggu (atau reject cepat) kalau model belum ready |
| 3. Internal client (RAG OCR fallback) tau apakah CV API aman dipanggil |
| |
| Thread-safe karena dipakai di background thread (prewarmer) + request handler. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import threading |
| import time |
| from enum import Enum |
| from dataclasses import dataclass, field |
| from typing import Dict, Optional |
|
|
|
|
class ModelState(str, Enum):
    """Lifecycle of a single model inside the readiness tracker.

    Mixes in ``str`` so each member compares equal to its lowercase
    string value (e.g. ``ModelState.READY == "ready"``).
    """

    # Nothing has attempted to load the model yet.
    NOT_LOADED = "not_loaded"
    # A load attempt is currently in progress.
    LOADING = "loading"
    # The model finished loading and can be used.
    READY = "ready"
    # The most recent load attempt failed.
    ERROR = "error"
|
|
|
|
@dataclass
class ModelStatus:
    """Mutable load-state record for one model.

    Instances are updated in place by the readiness tracker while it holds
    its own lock; this class itself does no locking.
    """

    # Current lifecycle state of the model.
    state: ModelState = ModelState.NOT_LOADED
    # Error text recorded when a load fails ("" when there is no error).
    error_message: str = ""
    # Timestamp (seconds) of the most recent load start; 0.0 means unset.
    started_at: float = 0.0
    # Timestamp (seconds) when the model became ready; 0.0 means unset.
    ready_at: float = 0.0

    @property
    def load_seconds(self) -> Optional[float]:
        """Return the load duration in seconds (rounded to 2 decimals), or None.

        None is returned when either timestamp is unset, and also when
        ``ready_at`` predates ``started_at``. That happens when a model is
        re-marked as loading after an earlier successful load; the stale
        ``ready_at`` would otherwise produce a negative duration.
        """
        if not (self.ready_at and self.started_at):
            return None
        if self.ready_at < self.started_at:
            return None
        return round(self.ready_at - self.started_at, 2)

    def to_dict(self) -> dict:
        """Serialize this status into a plain dict.

        Always contains ``"state"`` (the enum's string value). ``"error"`` and
        ``"load_seconds"`` are included only when they carry information.
        """
        d = {"state": self.state.value}
        if self.error_message:
            d["error"] = self.error_message
        seconds = self.load_seconds  # evaluate the property once
        if seconds is not None:
            d["load_seconds"] = seconds
        return d
|
|
|
|
class ReadinessTracker:
    """Process-wide singleton that tracks the load state of each CV model.

    Writers (e.g. a background prewarmer) call the ``mark_*`` methods;
    request handlers query readiness. All per-model state is read and
    written under ``_state_lock``. That lock is a plain, non-reentrant
    ``threading.Lock``, so code that already holds it must only call the
    ``*_unlocked`` helpers.
    """

    # Models that must be READY for the service as a whole to be "ready".
    # Other tracked models (e.g. "clip") are reported per-model only.
    _REQUIRED_MODELS = frozenset({"captioner", "yolo", "ocr"})

    _instance: Optional["ReadinessTracker"] = None
    # Guards singleton creation only; model state uses ``_state_lock``.
    _lock = threading.Lock()

    def __new__(cls):
        with cls._lock:
            if cls._instance is None:
                instance = super().__new__(cls)
                instance._init()
                # Publish only after a successful init, so a failing _init
                # cannot leave a half-initialised singleton cached.
                cls._instance = instance
            return cls._instance

    def _init(self):
        """Initialise per-instance state (called once, from ``__new__``)."""
        self._models: Dict[str, ModelStatus] = {
            "captioner": ModelStatus(),
            "yolo": ModelStatus(),
            "clip": ModelStatus(),
            "ocr": ModelStatus(),
        }
        # Set while every required model is READY. It is cleared again when
        # one of them goes back to LOADING or ERROR.
        self._ready_event = threading.Event()
        self._state_lock = threading.Lock()

    # ----------------------------------------------------------------- writers

    def mark_loading(self, model_name: str):
        """Record that ``model_name`` has started (re)loading."""
        with self._state_lock:
            status = self._models.setdefault(model_name, ModelStatus())
            status.state = ModelState.LOADING
            status.started_at = time.time()
            # Drop the previous completion time. Otherwise a reload would pair
            # a new started_at with a stale ready_at (negative load_seconds).
            status.ready_at = 0.0
            status.error_message = ""
            self._sync_ready_event_unlocked()

    def mark_ready(self, model_name: str):
        """Record that ``model_name`` finished loading and is usable."""
        with self._state_lock:
            status = self._models.setdefault(model_name, ModelStatus())
            status.state = ModelState.READY
            status.ready_at = time.time()
            # A READY model should not keep reporting an earlier failure.
            status.error_message = ""
            self._sync_ready_event_unlocked()

    def mark_error(self, model_name: str, message: str):
        """Record that loading ``model_name`` failed with ``message``."""
        with self._state_lock:
            status = self._models.setdefault(model_name, ModelStatus())
            status.state = ModelState.ERROR
            # Cap the stored text so a long traceback cannot bloat responses.
            status.error_message = str(message)[:500]
            self._sync_ready_event_unlocked()

    # ----------------------------------------------------------------- readers

    def get_status(self, model_name: str) -> ModelStatus:
        """Return a point-in-time copy of ``model_name``'s status.

        Unknown names yield a fresh NOT_LOADED status. A copy is returned so
        callers can read several fields consistently, even while writers keep
        mutating the live record.
        """
        from dataclasses import replace

        with self._state_lock:
            status = self._models.get(model_name)
            return replace(status) if status is not None else ModelStatus()

    def is_ready(self, model_name: str) -> bool:
        """Return True if ``model_name`` is currently READY."""
        with self._state_lock:
            return self._state_unlocked(model_name) == ModelState.READY

    def all_ready(self) -> bool:
        """Return True if every required model is READY."""
        with self._state_lock:
            return self._all_ready_unlocked()

    def overall_state(self) -> str:
        """Return the aggregate state of the required models.

        Returns one of "ready", "loading", "degraded" or "not_started".
        """
        with self._state_lock:
            return self._overall_state_unlocked()

    def wait_until_ready(self, timeout: float = 120.0) -> bool:
        """Block until every required model is READY, or ``timeout`` elapses.

        Only ``_REQUIRED_MODELS`` are awaited; other tracked models are not.
        Returns True if the required models were ready, False on timeout.
        """
        return self._ready_event.wait(timeout=timeout)

    def wait_for(self, model_name: str, timeout: float = 120.0) -> bool:
        """Block until ``model_name`` is READY; poll-based, per model.

        Returns True once the model is READY. Returns False as soon as it is
        in ERROR, or when ``timeout`` elapses. State is always checked at
        least once, so ``timeout=0`` works as a non-blocking probe.
        """
        # Monotonic clock: immune to wall-clock adjustments mid-wait.
        deadline = time.monotonic() + timeout
        while True:
            with self._state_lock:
                state = self._state_unlocked(model_name)
            if state == ModelState.READY:
                return True
            if state == ModelState.ERROR:
                return False
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return False
            time.sleep(min(0.2, remaining))

    def snapshot(self) -> dict:
        """Return overall + per-model state, all read under a single lock hold."""
        with self._state_lock:
            # Must use the *_unlocked helpers here: calling overall_state()
            # would re-acquire the non-reentrant lock and deadlock.
            return {
                "overall": self._overall_state_unlocked(),
                "models": {name: status.to_dict() for name, status in self._models.items()},
                "all_ready": self._all_ready_unlocked(),
            }

    # ------------------------------------------ helpers (caller holds the lock)

    def _state_unlocked(self, model_name: str) -> ModelState:
        """State of ``model_name``; NOT_LOADED if it is unknown."""
        status = self._models.get(model_name)
        return status.state if status is not None else ModelState.NOT_LOADED

    def _required_states_unlocked(self) -> list[ModelState]:
        """States of all required models."""
        return [self._state_unlocked(name) for name in self._REQUIRED_MODELS]

    def _all_ready_unlocked(self) -> bool:
        """True if every required model is READY."""
        return all(s == ModelState.READY for s in self._required_states_unlocked())

    def _overall_state_unlocked(self) -> str:
        """Aggregate state of the required models (see ``overall_state``)."""
        states = self._required_states_unlocked()
        if all(s == ModelState.READY for s in states):
            return "ready"
        # A required model still loading takes precedence over an error.
        if ModelState.LOADING in states:
            return "loading"
        if ModelState.ERROR in states:
            return "degraded"
        return "not_started"

    def _sync_ready_event_unlocked(self):
        """Keep ``_ready_event`` in step with the required models' state."""
        if self._all_ready_unlocked():
            self._ready_event.set()
        else:
            self._ready_event.clear()
|
|
|
|
def get_readiness() -> ReadinessTracker:
    """Return the shared readiness tracker.

    ``ReadinessTracker.__new__`` caches a single instance on the class, so
    every call hands back the same object.
    """
    tracker = ReadinessTracker()
    return tracker
|
|