Spaces:
Paused
Paused
File size: 2,434 Bytes
5f0db1e 333c083 706520f 5f0db1e 333c083 706520f 5f0db1e 88bdcff f3ebc82 88bdcff f3ebc82 88bdcff 5f0db1e 333c083 706520f 5f0db1e f3ebc82 88bdcff f3ebc82 88bdcff f3ebc82 88bdcff 333c083 706520f f3ebc82 88bdcff 333c083 f3ebc82 333c083 f3ebc82 88bdcff 5f0db1e 333c083 5f0db1e 88bdcff f3ebc82 88bdcff f3ebc82 88bdcff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
"""Model loading with mock/real switching based on environment.
Supports two loading modes:
- MOCK_MODELS=true: Loads mock models (fast, for local dev on RTX 4090)
- MOCK_MODELS=false: Loads all real models at startup (~38-43GB total)
Memory Strategy (Simultaneous Loading for 4xL4 GPUs with 88GB total):
- Vision 30B-A3B FP8 via vLLM: ~30-35GB
- Embedding 2B: ~4GB
- Reranker 2B: ~4GB
- Total: ~38-43GB, leaving ~45-50GB headroom
"""
import logging
import time
from typing import TYPE_CHECKING, Union

from config.settings import settings
logger = logging.getLogger(__name__)
# Type alias for model stack
ModelStack = Union["MockModelStack", "RealModelStack"] # noqa: F821
# Lazy singleton
_model_stack: ModelStack | None = None
def get_model_stack() -> ModelStack:
    """Build and load a new model stack based on environment configuration.

    This always constructs a fresh stack. Callers that want the shared
    instance should use ``get_models()``, which caches the result.

    - Mock mode (``settings.mock_models`` truthy): builds ``MockModelStack``
      and calls its ``load_all()`` (fast, intended for local development).
    - Real mode: logs the configured vision, embedding and reranker models,
      then builds ``RealModelStack`` and calls its ``load_all()`` so every
      model is loaded up front.

    Returns:
        The object returned by the selected stack's ``load_all()``.
    """
    # perf_counter() is monotonic, so the reported load time cannot be skewed
    # (or go negative) if the wall clock is adjusted during a long load.
    start = time.perf_counter()

    if settings.mock_models:
        logger.info("Loading MOCK model stack (development mode)")
        # Imported inside the branch so the unused implementation's module is
        # never imported.
        from models.mock import MockModelStack

        stack = MockModelStack().load_all()
        kind = "Mock"
    else:
        logger.info("Loading REAL model stack (production mode)")
        logger.info("Vision model: %s (FP8 via vLLM)", settings.vision_model)
        logger.info("Embedding model: %s", settings.embedding_model)
        logger.info("Reranker model: %s", settings.reranker_model)
        from models.real import RealModelStack

        # Load all models at startup (simultaneous loading).
        stack = RealModelStack().load_all()
        kind = "Real"

    elapsed = time.perf_counter() - start
    logger.info("%s model stack loaded in %.2fs", kind, elapsed)
    return stack
def get_models() -> ModelStack:
    """Return the shared model stack, creating it on first use.

    The first call (or the first call after reset_models()) builds the stack
    with get_model_stack() and stores it in the module-level _model_stack.
    Later calls return that same object without reloading anything.

    Returns:
        The cached, fully loaded model stack.
    """
    global _model_stack
    if _model_stack is not None:
        logger.debug("Returning cached model stack")
        return _model_stack

    logger.debug("Model stack not initialized, creating new stack")
    _model_stack = get_model_stack()
    return _model_stack
def reset_models() -> None:
    """Forget the cached model stack (primarily useful in tests).

    After this call, the next get_models() invocation builds a fresh stack
    via get_model_stack(). Only the module-level reference is cleared;
    nothing here explicitly unloads the previously cached stack.
    """
    global _model_stack
    _model_stack = None
|