File size: 4,498 Bytes
12f0afd 68e67db 12f0afd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
#!/usr/bin/env python3
"""
Model Cache Manager
Provides global caching for HuggingFace models to prevent re-downloads
across multiple instances and sessions.
"""
import logging
from typing import Optional
from pathlib import Path
from langchain_huggingface import HuggingFaceEmbeddings
from sentence_transformers import CrossEncoder
from app.core.logging import logger
# Optional accelerate import
try:
from accelerate import Accelerator
ACCELERATE_AVAILABLE = True
except ImportError:
ACCELERATE_AVAILABLE = False
Accelerator = None
# Global model cache
_EMBEDDINGS_CACHE = {}
_CROSS_ENCODER_CACHE = {}
# Local models directory - support worktrees via environment variable
import os
_MODELS_DIR = Path(os.getenv('MODELS_DIR', 'models')).resolve()

# Known HuggingFace namespaces and the local subdirectory each is mirrored under
# (dashes become underscores on disk).
_NAMESPACE_SUBDIRS = {
    "sentence-transformers": "sentence_transformers",
    "cross-encoder": "cross_encoder",
}


def _get_local_model_path(model_name: str) -> Optional[Path]:
    """
    Get local path for a model if it exists.

    Args:
        model_name: HuggingFace model name, e.g.
            "sentence-transformers/all-mpnet-base-v2".

    Returns:
        Path to local model directory, or None if the name has no namespace
        or no local copy exists.
    """
    if "/" not in model_name:
        # Bare names are never mirrored locally; avoids referencing an
        # unassigned local_path below.
        return None

    parts = model_name.split("/")
    namespace, model_short_name = parts[0], parts[-1]
    subdir = _NAMESPACE_SUBDIRS.get(namespace)
    if subdir:
        local_path = _MODELS_DIR / subdir / model_short_name
    else:
        # Fallback for other models: stored directly under the models root.
        local_path = _MODELS_DIR / model_short_name

    return local_path if local_path.exists() else None
def get_cached_embeddings(model_name: str = "sentence-transformers/all-mpnet-base-v2") -> HuggingFaceEmbeddings:
    """
    Get cached HuggingFace embeddings model with accelerate optimization.

    Creates the model only once and reuses it across all instances.
    Uses local models directory if available, otherwise downloads from HuggingFace.
    Automatically uses GPU if available via accelerate.
    """
    # Fast path: hand back the previously constructed instance.
    if model_name in _EMBEDDINGS_CACHE:
        logger.debug(f"Using cached embeddings model: {model_name}")
        return _EMBEDDINGS_CACHE[model_name]

    # Prefer a locally mirrored copy over a network download.
    local_path = _get_local_model_path(model_name)
    if local_path:
        logger.info(f"Using local embeddings model: {local_path}")
        source = str(local_path)
    else:
        logger.info(f"Downloading embeddings model: {model_name}")
        source = model_name
    embeddings = HuggingFaceEmbeddings(model_name=source)

    # Optimize device placement with accelerate if available.
    if ACCELERATE_AVAILABLE:
        try:
            accelerator = Accelerator()
            # Accelerate will automatically handle device placement.
            logger.info(f"Embeddings model optimized for device: {accelerator.device}")
        except Exception as e:
            logger.warning(f"Failed to optimize embeddings with accelerate: {e}")

    _EMBEDDINGS_CACHE[model_name] = embeddings
    return embeddings
def get_cached_cross_encoder(model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2') -> CrossEncoder:
    """
    Get cached cross-encoder model.

    Creates the model only once and reuses it across all instances.
    Uses local models directory if available, otherwise downloads from HuggingFace.
    """
    encoder = _CROSS_ENCODER_CACHE.get(model_name)
    if encoder is None:
        # Not cached yet: load from the local mirror when present,
        # otherwise fetch from the HuggingFace hub.
        local_path = _get_local_model_path(model_name)
        if local_path:
            logger.info(f"Using local cross-encoder model: {local_path}")
            encoder = CrossEncoder(str(local_path))
        else:
            logger.info(f"Downloading cross-encoder model: {model_name}")
            encoder = CrossEncoder(model_name)
        _CROSS_ENCODER_CACHE[model_name] = encoder
    else:
        logger.debug(f"Using cached cross-encoder model: {model_name}")
    return encoder
def clear_model_cache():
    """
    Clear all cached models.

    Useful for memory management or testing.
    """
    # Drop every cached model so the next getter call rebuilds from scratch.
    for cache in (_EMBEDDINGS_CACHE, _CROSS_ENCODER_CACHE):
        cache.clear()
    logger.info("Model cache cleared")
|