Spaces:
Running
Running
| """ | |
| Enterprise model lifecycle manager. | |
| Verifies local models, triggers automatic download if missing, | |
| and ensures thread/async-safe initialization across the library. | |
| Author: IntelliDeep Labs Team | |
| License: BSL 1.1 | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import hashlib | |
| import logging | |
| import os | |
| import sys | |
| import threading | |
| import zipfile | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Dict, List, Optional | |
| logger = logging.getLogger(__name__) | |
@dataclass(frozen=True)
class ModelConfig:
    """Immutable configuration for a required model.

    Declared as a frozen dataclass so that instances can be built with
    keyword arguments (as the model registry does) and cannot be mutated
    after construction.

    Attributes:
        name: Identifier of the model.
        expected_dir: Directory name, relative to the models directory,
            in which the extracted model is expected.
        zip_filename: Name of the archive file looked up in the models
            directory.
        min_files: Minimum number of directory entries for the model to
            count as extracted.
        type: Free-form category label (embedding, nli, perplexity,
            spacy, etc.).
    """

    name: str
    expected_dir: str
    zip_filename: str = "nlproxy_models.zip"
    min_files: int = 2  # Minimum files expected after extraction
    # NOTE: shadows the builtin ``type`` inside the class body only; the
    # name is part of the keyword interface, so it is kept as is.
    type: str = "unknown"  # embedding, nli, perplexity, spacy, etc.
class ModelManager:
    """
    Singleton manager for local model verification and on-demand download.

    A single instance is shared process-wide: ``__new__`` guards creation
    with a ``threading.Lock`` and ``ensure_ready`` serialises the
    verify/download sequence with an ``asyncio.Lock``. Prefer
    ``ModelManager.get_instance()`` over calling the class directly so the
    models directory is configured explicitly.
    """

    _instance: Optional["ModelManager"] = None
    _lock = threading.Lock()
    _async_lock: Optional[asyncio.Lock] = None
    _initialized: bool = False
    _models_dir: Optional[Path] = None

    # Registry of required models
    REQUIRED_MODELS: Dict[str, ModelConfig] = {
        "all-MiniLM-L6-v2": ModelConfig(
            name="all-MiniLM-L6-v2",
            expected_dir="all-MiniLM-L6-v2",
            type="embedding",
        ),
        "nli-distilroberta-base": ModelConfig(
            name="nli-distilroberta-base",
            expected_dir="nli-distilroberta-base",
            type="nli",
        ),
        "distilgpt2": ModelConfig(
            name="distilgpt2",
            expected_dir="distilgpt2",
            type="perplexity",
        ),
    }

    def __new__(cls) -> "ModelManager":
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            with cls._lock:
                # Re-check under the lock: another thread may have created
                # the instance while this one was waiting.
                if cls._instance is None:
                    instance = super().__new__(cls)
                    # Set the async lock BEFORE publishing the instance so
                    # no caller can observe a manager whose lock is None.
                    cls._async_lock = asyncio.Lock()
                    cls._instance = instance
        return cls._instance

    @classmethod
    def get_instance(cls, models_dir: Optional[str] = None) -> "ModelManager":
        """Get singleton instance and optionally set models directory.

        Args:
            models_dir: Directory holding the models. When given (and
                non-empty) it replaces any previously configured
                directory. When omitted, an already configured directory
                is kept; otherwise the default is used.

        Returns:
            The shared ``ModelManager`` instance.
        """
        instance = cls()
        if models_dir:
            instance._models_dir = Path(models_dir).resolve()
        elif instance._models_dir is None:
            instance._models_dir = cls._default_models_dir()
        return instance

    @staticmethod
    def _default_models_dir() -> Path:
        """Return the directory from NLPROXY_MODELS_DIR, or ``nlproxy/models``."""
        return Path(os.getenv("NLPROXY_MODELS_DIR", "nlproxy/models")).resolve()

    def _resolved_models_dir(self) -> Path:
        """Return the models directory, falling back to the default if unset.

        Covers managers obtained via ``ModelManager()`` without going
        through ``get_instance``, where ``_models_dir`` is still ``None``.
        """
        if self._models_dir is None:
            self._models_dir = self._default_models_dir()
        return self._models_dir

    def _is_model_extracted(self, config: ModelConfig) -> bool:
        """Check if model directory exists and contains expected files.

        Returns:
            True when ``<models_dir>/<expected_dir>`` is a directory with
            at least ``config.min_files`` entries.
        """
        model_path = self._resolved_models_dir() / config.expected_dir
        if not model_path.is_dir():
            return False
        # Count entries without materialising the listing.
        entry_count = sum(1 for _ in model_path.iterdir())
        return entry_count >= config.min_files

    def _has_zip_archive(self, config: ModelConfig) -> bool:
        """Fallback: check if a non-empty source zip file exists in models_dir."""
        zip_path = self._resolved_models_dir() / config.zip_filename
        try:
            # is_file() rejects a directory that happens to carry the
            # archive name; stat() can still fail if the file vanishes
            # between the two calls.
            return zip_path.is_file() and zip_path.stat().st_size > 0
        except OSError:
            return False

    def _compute_sha256(self, path: Path) -> str:
        """Compute SHA256 checksum for a file in a streaming manner.

        Returns:
            The hexadecimal digest of the file contents.
        """
        hash_obj = hashlib.sha256()
        with path.open("rb") as file_obj:
            # Read fixed-size chunks until read() returns b"" at EOF.
            for chunk in iter(lambda: file_obj.read(8192), b""):
                hash_obj.update(chunk)
        return hash_obj.hexdigest()

    def verify_zip_checksum(self, zip_path: Path, expected_hash: Optional[str] = None) -> bool:
        """Verify the ZIP archive checksum against an expected SHA256 hash.

        Args:
            zip_path: Archive to hash.
            expected_hash: Expected hex digest. Falls back to the
                ``NLPROXY_MODELS_SHA256`` environment variable when not
                given.

        Returns:
            True when the digests match, or when no expected hash is
            available (verification is skipped).

        Raises:
            ValueError: If the computed digest differs from the expected
                one.
        """
        expected_hash = expected_hash or os.getenv("NLPROXY_MODELS_SHA256")
        # Tolerate surrounding whitespace (e.g. a trailing newline in an
        # environment value) and compare case-insensitively.
        expected = expected_hash.strip().lower() if expected_hash else ""
        if not expected:
            logger.debug("No expected SHA256 checksum provided for ZIP validation.")
            return True

        actual = self._compute_sha256(zip_path).lower()
        if actual != expected:
            raise ValueError(
                f"ZIP checksum mismatch: expected {expected}, got {actual}"
            )
        logger.info("ZIP checksum verification passed.")
        return True

    def verify_all(self) -> bool:
        """Synchronous verification of all required models."""
        return all(self._is_model_extracted(cfg) for cfg in self.REQUIRED_MODELS.values())

    async def ensure_ready(self) -> None:
        """
        Verify required models and trigger a download for missing ones.

        Returns immediately once a previous call has succeeded. Otherwise
        the check/download sequence runs under the shared asyncio lock and
        the initialized flag is re-checked after acquiring it.

        Raises:
            RuntimeError: If the download fails, or if models are still
                missing after the download.
        """
        if self._initialized:
            return
        async with self._async_lock:
            # Another coroutine may have completed initialization while
            # this one was waiting for the lock.
            if self._initialized:
                return
            missing = [
                cfg.name
                for cfg in self.REQUIRED_MODELS.values()
                if not self._is_model_extracted(cfg)
            ]
            if missing:
                logger.info(
                    "Missing models: %s. Triggering automatic download...",
                    ", ".join(missing),
                )
                await self._trigger_download()
            # Final verification
            if not self.verify_all():
                raise RuntimeError(
                    "Model verification failed after download. Check NLPROXY_MODELS_URL or permissions."
                )
            self._initialized = True
            logger.info("✅ All required models verified and ready for use.")

    async def _trigger_download(self) -> None:
        """Run the download CLI entry point in the default executor.

        Raises:
            RuntimeError: If the download entry point returns a non-zero
                exit code.
        """
        # Dynamic import to prevent circular dependencies
        from nlproxy.cli.download_models import main as download_main

        loop = asyncio.get_running_loop()
        # Pass --models-dir explicitly to avoid env var race conditions
        args = ["--models-dir", str(self._resolved_models_dir()), "--verbose"]
        # run_in_executor keeps the blocking download off the event loop.
        exit_code = await loop.run_in_executor(None, download_main, args)
        if exit_code != 0:
            raise RuntimeError(f"Model download failed with exit code {exit_code}")

    def sync_ensure_ready(self) -> None:
        """Synchronous wrapper for legacy or CLI initialization contexts.

        Runs ``ensure_ready`` via ``asyncio.run``, so it must not be called
        from code that is already running inside an event loop; use
        ``await ensure_ready()`` there instead.
        """
        if self._initialized:
            return
        asyncio.run(self.ensure_ready())