File size: 4,498 Bytes
12f0afd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68e67db
 
 
12f0afd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#!/usr/bin/env python3
"""
Model Cache Manager

Provides global caching for HuggingFace models to prevent re-downloads
across multiple instances and sessions.
"""

import logging
from typing import Optional
from pathlib import Path
from langchain_huggingface import HuggingFaceEmbeddings
from sentence_transformers import CrossEncoder

from app.core.logging import logger

# Optional accelerate import: this module must stay importable when the
# `accelerate` package is not installed.
try:
    from accelerate import Accelerator
    ACCELERATE_AVAILABLE = True
except ImportError:
    # Record the absence and keep `Accelerator` bound (to None) so the name
    # always exists; code in this module checks ACCELERATE_AVAILABLE before
    # calling it.
    ACCELERATE_AVAILABLE = False
    Accelerator = None

# Global model cache: module-level dicts keyed by the model name passed to
# the getter functions below; values are the constructed model objects.
_EMBEDDINGS_CACHE = {}
_CROSS_ENCODER_CACHE = {}

# Local models directory - support worktrees via environment variable.
# Read from MODELS_DIR (default: 'models') and resolved to an absolute path
# once, when this module is imported.
# NOTE(review): `import os` sits mid-file; by convention it belongs with the
# imports at the top of the module.
import os
_MODELS_DIR = Path(os.getenv('MODELS_DIR', 'models')).resolve()

def _get_local_model_path(model_name: str) -> Optional[Path]:
    """
    Get local path for a model if it exists.

    Only names of the form ``<namespace>/<model>`` are looked up; a name
    without a slash always yields None. The final path segment is used as
    the directory name under the local models directory:

    - ``sentence-transformers/<model>`` -> ``<models>/sentence_transformers/<model>``
    - ``cross-encoder/<model>``         -> ``<models>/cross_encoder/<model>``
    - anything else                     -> ``<models>/<model>``

    Args:
        model_name: HuggingFace model name

    Returns:
        Path to local model directory or None if not found
    """
    if "/" not in model_name:
        return None

    model_short_name = model_name.rsplit("/", 1)[-1]
    # An empty, "." or ".." final segment (e.g. a trailing slash) would make
    # the candidate collapse to the category directory, the models root or
    # its parent, which exists but is not a model. Treat it as not found.
    if model_short_name in ("", ".", ".."):
        return None

    if model_name.startswith("sentence-transformers/"):
        # For sentence transformers: sentence-transformers/all-mpnet-base-v2
        base_dir = _MODELS_DIR / "sentence_transformers"
    elif model_name.startswith("cross-encoder/"):
        # For cross encoders: cross-encoder/ms-marco-MiniLM-L-6-v2
        base_dir = _MODELS_DIR / "cross_encoder"
    else:
        # Fallback for other models
        base_dir = _MODELS_DIR

    local_path = base_dir / model_short_name
    # Require a directory: a stray regular file with the model's name is not
    # a loadable local model.
    if local_path.is_dir():
        return local_path

    return None

def get_cached_embeddings(model_name: str = "sentence-transformers/all-mpnet-base-v2") -> HuggingFaceEmbeddings:
    """
    Return the shared HuggingFace embeddings instance for ``model_name``.

    The first call for a given name builds the model and stores it in the
    module-level cache; every later call returns that same object. A copy
    found under the local models directory is preferred; otherwise the model
    name is handed to HuggingFaceEmbeddings unchanged.

    When accelerate is importable, an Accelerator is instantiated and the
    device it reports is logged. Any failure in that step is logged as a
    warning and does not prevent the model from being cached.

    Args:
        model_name: HuggingFace model name; also the cache key.

    Returns:
        The cached HuggingFaceEmbeddings instance for ``model_name``.
    """
    # Fast path: already built during this process.
    if model_name in _EMBEDDINGS_CACHE:
        logger.debug(f"Using cached embeddings model: {model_name}")
        return _EMBEDDINGS_CACHE[model_name]

    # Prefer an on-disk copy over fetching by name.
    model_dir = _get_local_model_path(model_name)
    if model_dir:
        logger.info(f"Using local embeddings model: {model_dir}")
        model_source = str(model_dir)
    else:
        logger.info(f"Downloading embeddings model: {model_name}")
        model_source = model_name
    model = HuggingFaceEmbeddings(model_name=model_source)

    # Best-effort accelerate step: report the device accelerate selects.
    if ACCELERATE_AVAILABLE:
        try:
            accel = Accelerator()
            logger.info(f"Embeddings model optimized for device: {accel.device}")
        except Exception as exc:
            logger.warning(f"Failed to optimize embeddings with accelerate: {exc}")

    _EMBEDDINGS_CACHE[model_name] = model
    return model

def get_cached_cross_encoder(model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2') -> CrossEncoder:
    """
    Return the shared CrossEncoder instance for ``model_name``.

    The first call for a given name constructs the model and stores it in
    the module-level cache; later calls return the stored object. A copy
    found under the local models directory is preferred; otherwise the model
    name is passed to CrossEncoder unchanged.

    Args:
        model_name: HuggingFace model name; also the cache key.

    Returns:
        The cached CrossEncoder instance for ``model_name``.
    """
    # Fast path: already built during this process.
    if model_name in _CROSS_ENCODER_CACHE:
        logger.debug(f"Using cached cross-encoder model: {model_name}")
        return _CROSS_ENCODER_CACHE[model_name]

    # Prefer an on-disk copy over fetching by name.
    model_dir = _get_local_model_path(model_name)
    if model_dir:
        logger.info(f"Using local cross-encoder model: {model_dir}")
        encoder = CrossEncoder(str(model_dir))
    else:
        logger.info(f"Downloading cross-encoder model: {model_name}")
        encoder = CrossEncoder(model_name)

    _CROSS_ENCODER_CACHE[model_name] = encoder
    return encoder

def clear_model_cache():
    """
    Empty both module-level model caches.

    Drops every cached embeddings and cross-encoder entry so the next getter
    call rebuilds the model. Handy for freeing memory or isolating tests.
    """
    for cache in (_EMBEDDINGS_CACHE, _CROSS_ENCODER_CACHE):
        cache.clear()
    logger.info("Model cache cleared")