File size: 2,434 Bytes
5f0db1e
 
 
333c083
706520f
5f0db1e
333c083
706520f
 
 
 
5f0db1e
88bdcff
f3ebc82
 
88bdcff
 
 
 
f3ebc82
 
88bdcff
 
 
 
 
 
 
 
5f0db1e
 
333c083
706520f
5f0db1e
f3ebc82
 
88bdcff
f3ebc82
88bdcff
 
f3ebc82
 
 
 
88bdcff
333c083
706520f
f3ebc82
 
88bdcff
 
333c083
 
f3ebc82
333c083
f3ebc82
88bdcff
 
 
5f0db1e
 
333c083
5f0db1e
88bdcff
 
f3ebc82
88bdcff
f3ebc82
 
88bdcff
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""Model loading with mock/real switching based on environment.

Supports two loading modes:
- MOCK_MODELS=true: Loads mock models (fast, for local dev on RTX 4090)
- MOCK_MODELS=false: Loads all real models at startup (~38-43GB total)

Memory Strategy (Simultaneous Loading for 4xL4 GPUs with 88GB total):
- Vision 30B-A3B FP8 via vLLM: ~30-35GB
- Embedding 2B: ~4GB
- Reranker 2B: ~4GB
- Total: ~38-43GB, leaving ~45GB+ headroom
"""

import logging
import time
from typing import Union

from config.settings import settings

logger = logging.getLogger(__name__)

# Type alias for model stack
ModelStack = Union["MockModelStack", "RealModelStack"]  # noqa: F821

# Lazy singleton
_model_stack: ModelStack | None = None


def get_model_stack() -> ModelStack:
    """Build and return a model stack appropriate for the current environment.

    Mock mode (``settings.mock_models`` truthy) loads lightweight stand-in
    models for fast local development; otherwise all three real models are
    loaded up front at startup (~38-43GB total).
    """
    started = time.time()

    if settings.mock_models:
        logger.info("Loading MOCK model stack (development mode)")
        # Imported lazily so production deployments never pull in mock code.
        from models.mock import MockModelStack

        loaded = MockModelStack().load_all()
        logger.info(f"Mock model stack loaded in {time.time() - started:.2f}s")
        return loaded

    logger.info("Loading REAL model stack (production mode)")
    logger.info(f"Vision model: {settings.vision_model} (FP8 via vLLM)")
    logger.info(f"Embedding model: {settings.embedding_model}")
    logger.info(f"Reranker model: {settings.reranker_model}")
    # Imported lazily so dev machines without the real-model deps still work.
    from models.real import RealModelStack

    # All three models are brought up together at startup (simultaneous loading).
    loaded = RealModelStack().load_all()
    logger.info(f"Real model stack loaded in {time.time() - started:.2f}s")
    return loaded


def get_models() -> ModelStack:
    """Return the process-wide singleton model stack, creating it on first use.

    The returned stack is fully loaded and ready for inference.
    """
    global _model_stack
    if _model_stack is not None:
        logger.debug("Returning cached model stack")
        return _model_stack
    logger.debug("Model stack not initialized, creating new stack")
    _model_stack = get_model_stack()
    return _model_stack


def reset_models() -> None:
    """Drop the cached stack so the next get_models() call rebuilds it (testing aid)."""
    global _model_stack
    _model_stack = None