Spaces:
Paused
Paused
| """Model loading with mock/real switching based on environment. | |
| Supports two loading modes: | |
| - MOCK_MODELS=true: Loads mock models (fast, for local dev on RTX 4090) | |
| - MOCK_MODELS=false: Loads all real models at startup (~38-43GB total) | |
| Memory Strategy (Simultaneous Loading for 4xL4 GPUs with 88GB total): | |
| - Vision 30B-A3B FP8 via vLLM: ~30-35GB | |
| - Embedding 2B: ~4GB | |
| - Reranker 2B: ~4GB | |
| - Total: ~38-43GB, leaving ~45GB+ headroom | |
| """ | |
import logging
import time
from typing import TYPE_CHECKING, Union

from config.settings import settings
| logger = logging.getLogger(__name__) | |
| # Type alias for model stack | |
| ModelStack = Union["MockModelStack", "RealModelStack"] # noqa: F821 | |
| # Lazy singleton | |
| _model_stack: ModelStack | None = None | |
def get_model_stack() -> ModelStack:
    """Build and return a freshly loaded model stack.

    The stack type is chosen by ``settings.mock_models``:

    - True: ``models.mock.MockModelStack`` is imported and loaded (fast,
      intended for local development).
    - False: ``models.real.RealModelStack`` is imported and all real models
      (vision, embedding, reranker, as named in ``settings``) are loaded
      together through a single ``load_all()`` call.

    No caching happens here: every call constructs and loads a new stack.
    Use ``get_models()`` to obtain the shared singleton instead.

    Returns:
        Whatever the chosen stack's ``load_all()`` returns.

    Raises:
        Any exception raised while importing or loading the stack propagates
        unchanged.
    """
    # perf_counter is monotonic, so the reported load time cannot be skewed
    # (or go negative) if the wall clock is adjusted mid-load.
    start_time = time.perf_counter()
    if settings.mock_models:
        logger.info("Loading MOCK model stack (development mode)")
        # Deferred import: only the selected implementation is imported.
        from models.mock import MockModelStack

        stack = MockModelStack().load_all()
        label = "Mock"
    else:
        logger.info("Loading REAL model stack (production mode)")
        logger.info(f"Vision model: {settings.vision_model} (FP8 via vLLM)")
        logger.info(f"Embedding model: {settings.embedding_model}")
        logger.info(f"Reranker model: {settings.reranker_model}")
        # Deferred import: only the selected implementation is imported.
        from models.real import RealModelStack

        # Load all models at startup (simultaneous loading)
        stack = RealModelStack().load_all()
        label = "Real"
    elapsed = time.perf_counter() - start_time
    logger.info(f"{label} model stack loaded in {elapsed:.2f}s")
    return stack
def get_models() -> ModelStack:
    """Return the shared model stack, building it on the first call.

    The first call delegates to ``get_model_stack()`` and caches its result
    in the module-level ``_model_stack``; later calls return that cached
    object. If building raises, nothing is cached, so the next call retries.

    NOTE(review): the None-check and assignment are not guarded by a lock.
    If several threads can reach this before the first load finishes, each
    may build its own stack. Confirm that it is called once at startup.
    """
    global _model_stack
    if _model_stack is not None:
        logger.debug("Returning cached model stack")
        return _model_stack

    logger.debug("Model stack not initialized, creating new stack")
    _model_stack = get_model_stack()
    return _model_stack
def reset_models() -> None:
    """Forget the cached model stack so the next ``get_models()`` rebuilds it.

    Intended for tests. Only the module-level reference is cleared; no
    unload/cleanup method is called on the previous stack, so whatever it
    holds is released only once that object is garbage-collected.
    NOTE(review): confirm this is acceptable if used outside tests.
    """
    global _model_stack
    _model_stack = None