"""Model loading with mock/real switching based on environment.

Supports two loading modes:

- ``MOCK_MODELS=true``: loads mock models (fast, for local dev on an RTX 4090).
- ``MOCK_MODELS=false``: loads all real models together (~38-43GB total) the
  first time the model stack is requested.

Memory strategy (simultaneous loading for 4x L4 GPUs with 88GB total):

- Vision 30B-A3B FP8 via vLLM: ~30-35GB
- Embedding 2B: ~4GB
- Reranker 2B: ~4GB
- Total: ~38-43GB, leaving ~45GB+ headroom
"""

import logging
import time
from typing import TYPE_CHECKING, Union

from config.settings import settings

if TYPE_CHECKING:
    # Imported for static type checkers only, so the forward references in
    # ModelStack resolve. At runtime each backend is imported lazily inside
    # get_model_stack(), so only the selected one is ever loaded.
    from models.mock import MockModelStack
    from models.real import RealModelStack

logger = logging.getLogger(__name__)

# Type alias for whichever model stack backend is active.
ModelStack = Union["MockModelStack", "RealModelStack"]

# Lazy singleton: populated by get_models(), cleared by reset_models().
_model_stack: ModelStack | None = None


def get_model_stack() -> ModelStack:
    """Build and fully load a new model stack selected by configuration.

    Every call constructs a fresh stack. Use :func:`get_models` to get the
    cached singleton instead.

    - If ``settings.mock_models`` is true, loads the mock stack (fast, for
      local development).
    - Otherwise loads all 3 real models (vision, embedding, reranker) in one
      go (~38-43GB total).

    Returns:
        The object returned by ``load_all()`` on the selected backend.
    """
    # perf_counter is monotonic, unlike time.time(), so wall-clock
    # adjustments cannot skew the reported load duration.
    start_time = time.perf_counter()

    if settings.mock_models:
        logger.info("Loading MOCK model stack (development mode)")
        from models.mock import MockModelStack

        stack = MockModelStack().load_all()
        label = "Mock"
    else:
        logger.info("Loading REAL model stack (production mode)")
        logger.info("Vision model: %s (FP8 via vLLM)", settings.vision_model)
        logger.info("Embedding model: %s", settings.embedding_model)
        logger.info("Reranker model: %s", settings.reranker_model)
        # Deferred import: only pulled in when real models are requested.
        from models.real import RealModelStack

        # Load all models together (simultaneous loading).
        stack = RealModelStack().load_all()
        label = "Real"

    elapsed = time.perf_counter() - start_time
    logger.info("%s model stack loaded in %.2fs", label, elapsed)
    return stack


def get_models() -> ModelStack:
    """Get or create the singleton model stack.

    The first call builds the stack via :func:`get_model_stack`, which loads
    every model. Later calls return the cached instance until
    :func:`reset_models` clears it. If loading raises, nothing is cached, so
    the next call tries again.

    Note: initialisation is not guarded by a lock, so concurrent first calls
    could each trigger a full load.

    Returns:
        The fully loaded model stack (all models ready for inference).
    """
    global _model_stack
    if _model_stack is not None:
        logger.debug("Returning cached model stack")
        return _model_stack

    logger.debug("Model stack not initialized, creating new stack")
    _model_stack = get_model_stack()
    return _model_stack


def reset_models() -> None:
    """Reset the model stack (useful for testing).

    Only drops the module's reference to the cached stack. The next
    :func:`get_models` call will build and load a new one.
    """
    global _model_stack
    _model_stack = None