"""
Enterprise model lifecycle manager.

Verifies local models, triggers automatic download if missing, and ensures
thread/async-safe initialization across the library.

Author: IntelliDeep Labs Team
License: BSL 1.1
"""

from __future__ import annotations

import asyncio
import hashlib
import logging
import os
import sys
import threading
import zipfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class ModelConfig:
    """Immutable configuration for a required model."""

    name: str
    expected_dir: str  # Sub-directory of the models dir that holds this model
    zip_filename: str = "nlproxy_models.zip"
    min_files: int = 2  # Minimum files expected after extraction
    type: str = "unknown"  # embedding, nli, perplexity, spacy, etc.


class ModelManager:
    """
    Singleton manager for local model verification and on-demand download.

    Initialization (verification plus any download) runs under a single
    ``threading.Lock`` and, for async callers, off the event loop. This makes
    ``ensure_ready`` safe to await from concurrent coroutines on any event
    loop and ``sync_ensure_ready`` safe to call from concurrent threads.
    """

    _instance: Optional[ModelManager] = None
    _lock = threading.Lock()  # Guards singleton creation in __new__
    # Serializes verification/download across threads *and* event loops.
    # A threading lock is used because asyncio locks are not thread-safe and
    # cannot be shared by the separate loops that asyncio.run() creates.
    _init_lock = threading.Lock()
    _async_lock: Optional[asyncio.Lock] = None  # Unused; kept for backward compatibility
    _initialized: bool = False
    _models_dir: Optional[Path] = None

    # Registry of required models
    REQUIRED_MODELS: Dict[str, ModelConfig] = {
        "all-MiniLM-L6-v2": ModelConfig(
            name="all-MiniLM-L6-v2", expected_dir="all-MiniLM-L6-v2", type="embedding"
        ),
        "nli-distilroberta-base": ModelConfig(
            name="nli-distilroberta-base", expected_dir="nli-distilroberta-base", type="nli"
        ),
        "distilgpt2": ModelConfig(
            name="distilgpt2", expected_dir="distilgpt2", type="perplexity"
        ),
    }

    def __new__(cls) -> ModelManager:
        # Double-checked locking: take the lock only while the instance is missing.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    @classmethod
    def get_instance(cls, models_dir: Optional[str] = None) -> ModelManager:
        """
        Get the singleton instance and optionally set the models directory.

        Args:
            models_dir: Directory holding the extracted models. When omitted,
                the previously configured directory is kept; if none was set,
                ``$NLPROXY_MODELS_DIR`` (default ``nlproxy/models``) is used.

        Returns:
            The shared ``ModelManager`` instance.
        """
        instance = cls()
        if models_dir:
            new_dir = Path(models_dir).resolve()
            if new_dir != instance._models_dir:
                instance._models_dir = new_dir
                # The ready flag described the previous directory; the new one
                # has not been verified yet.
                instance._initialized = False
        else:
            instance._resolve_models_dir()
        return instance

    def _resolve_models_dir(self) -> Path:
        """Return the models directory, initializing it from the environment if unset."""
        if self._models_dir is None:
            self._models_dir = Path(
                os.getenv("NLPROXY_MODELS_DIR", "nlproxy/models")
            ).resolve()
        return self._models_dir

    def _is_model_extracted(self, config: ModelConfig) -> bool:
        """Check if the model directory exists and holds at least ``min_files`` entries."""
        model_path = self._resolve_models_dir() / config.expected_dir
        if not model_path.is_dir():
            return False
        return len(list(model_path.iterdir())) >= config.min_files

    def _has_zip_archive(self, config: ModelConfig) -> bool:
        """Fallback: check if a non-empty source zip file exists in the models dir."""
        zip_path = self._resolve_models_dir() / config.zip_filename
        return zip_path.is_file() and zip_path.stat().st_size > 0

    def _compute_sha256(self, path: Path, chunk_size: int = 1024 * 1024) -> str:
        """
        Compute the SHA-256 hex digest of a file, streaming it in chunks.

        Args:
            path: File to hash.
            chunk_size: Bytes read per iteration; bounds memory use for large archives.
        """
        hash_obj = hashlib.sha256()
        with path.open("rb") as file_obj:
            for chunk in iter(lambda: file_obj.read(chunk_size), b""):
                hash_obj.update(chunk)
        return hash_obj.hexdigest()

    def verify_zip_checksum(self, zip_path: Path, expected_hash: Optional[str] = None) -> bool:
        """
        Verify the ZIP archive checksum against an expected SHA-256 hash.

        Args:
            zip_path: Archive to verify.
            expected_hash: Hex digest to compare against (case-insensitive).
                Falls back to ``$NLPROXY_MODELS_SHA256``.

        Returns:
            True when the digests match or no expected digest is configured.

        Raises:
            ValueError: If the digests differ.
        """
        expected = (expected_hash or os.getenv("NLPROXY_MODELS_SHA256") or "").strip().lower()
        if not expected:
            logger.debug("No expected SHA256 checksum provided for ZIP validation.")
            return True
        actual = self._compute_sha256(Path(zip_path)).lower()
        if actual != expected:
            raise ValueError(f"ZIP checksum mismatch: expected {expected}, got {actual}")
        logger.info("ZIP checksum verification passed.")
        return True

    def verify_all(self) -> bool:
        """Synchronously check that every model in ``REQUIRED_MODELS`` is extracted."""
        return all(self._is_model_extracted(cfg) for cfg in self.REQUIRED_MODELS.values())

    async def ensure_ready(self) -> None:
        """
        Verify models and download missing ones without blocking the event loop.

        Idempotent: safe to call repeatedly from different coroutines, threads
        or event loops. The blocking work runs in the loop's default executor
        and is serialized by ``_init_lock``.

        Raises:
            RuntimeError: If the download fails or models are still missing afterwards.
        """
        if self._initialized:
            return
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._ensure_ready_blocking)

    def _ensure_ready_blocking(self) -> None:
        """Verify models and download missing ones, holding ``_init_lock`` throughout."""
        with self._init_lock:
            if self._initialized:
                return
            missing = [
                cfg.name
                for cfg in self.REQUIRED_MODELS.values()
                if not self._is_model_extracted(cfg)
            ]
            if missing:
                logger.info(
                    "Missing models: %s. Triggering automatic download...", ", ".join(missing)
                )
                self._run_download()
            # Final verification
            if not self.verify_all():
                raise RuntimeError(
                    "Model verification failed after download. Check NLPROXY_MODELS_URL or permissions."
                )
            self._initialized = True
            logger.info("✅ All required models verified and ready for use.")

    def _run_download(self) -> None:
        """
        Run the ``download_models`` CLI entry point synchronously.

        Raises:
            RuntimeError: If the CLI reports a non-zero exit status.
        """
        # Dynamic import to prevent circular dependencies
        from nlproxy.cli.download_models import main as download_main

        # Pass --models-dir explicitly to avoid env var race conditions
        args = ["--models-dir", str(self._resolve_models_dir()), "--verbose"]
        try:
            exit_code = download_main(args)
        except SystemExit as exc:
            # A CLI main may call sys.exit() instead of returning; don't let that
            # escape into the caller (or the awaiting coroutine) as SystemExit.
            exit_code = exc.code
        # None follows sys.exit() semantics (exit status 0). A genuinely failed
        # download is still caught by the verify_all() check afterwards.
        if exit_code not in (0, None):
            raise RuntimeError(f"Model download failed with exit code {exit_code}")

    async def _trigger_download(self) -> None:
        """Run the download CLI in the default executor to avoid blocking the event loop."""
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._run_download)

    def sync_ensure_ready(self) -> None:
        """
        Synchronous counterpart of ``ensure_ready`` for legacy or CLI contexts.

        Blocks the calling thread until models are verified. Unlike a wrapper
        built on ``asyncio.run``, it also works when an event loop is already
        running, although it then blocks that loop.
        """
        if self._initialized:
            return
        self._ensure_ready_blocking()