# app/utils/model_cache.py
"""
Model Caching Module for Production
Loads models once at startup and reuses them for all requests.
This eliminates the overhead of loading models per-request.
"""
import torch
import os
import logging
from pathlib import Path
logger = logging.getLogger(__name__)
# Get base directory (project root)
BASE_DIR = Path(__file__).resolve().parent.parent.parent
MODELS_DIR = BASE_DIR / "models"
class ModelCache:
    """Process-wide cache that loads models once and reuses them.

    Intended to be used through the module-level singleton so that model
    checkpoints are read from disk a single time at startup instead of once
    per request. Two model families are supported:

    * ResNet encoder/decoder pair (loaded on demand only).
    * EfficientNet captioning model + GPT-2 tokenizer.
    """

    def __init__(self):
        # Cached model objects; ``None`` until the corresponding loader runs.
        self._resnet_encoder = None
        self._resnet_decoder = None
        self._resnet_vocab = None
        self._efficientnet_model = None
        self._efficientnet_tokenizer = None
        # Prefer GPU when available; all models and dummy inputs go here.
        self._device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Guards the startup loaders against being run twice.
        self._models_loaded = False
        logger.info(f"ModelCache initialized on device: {self._device}")

    @staticmethod
    def _first_existing(candidates):
        """Return the first path in *candidates* that exists on disk, else None."""
        for path in candidates:
            if os.path.exists(path):
                return path
        return None

    @staticmethod
    def _hf_cache_hits(filenames):
        """Yield existing paths for *filenames* inside the HF Hub cache layout.

        Searches the nested structure ``models--*/snapshots/*/<filename>``
        under ``models/optimized_models``. Yields matches in directory-glob
        order, one per (snapshot, filename) combination found.
        """
        hf_cache_base = MODELS_DIR / "optimized_models"
        if not hf_cache_base.exists():
            return
        for cache_dir in hf_cache_base.glob("models--*"):
            if not cache_dir.is_dir():
                continue
            snapshots_dir = cache_dir / "snapshots"
            if not snapshots_dir.exists():
                continue
            for snapshot_dir in snapshots_dir.glob("*"):
                if not snapshot_dir.is_dir():
                    continue
                for filename in filenames:
                    candidate = snapshot_dir / filename
                    if candidate.exists():
                        yield str(candidate)

    def load_all_models(self,
                        resnet_path=None,
                        efficientnet_path=None,
                        use_optimized=True):
        """Load models at application startup.

        NOTE: despite the name, only the EfficientNet model is loaded here;
        ResNet loading is intentionally skipped (it can still be loaded on
        demand via ``load_resnet_models``).

        Args:
            resnet_path: Path to ResNet checkpoint (default: models/resnet_best_model.pth)
            efficientnet_path: Path to EfficientNet checkpoint (default: models/efficient_best_model.pth)
            use_optimized: If True, prefer quantized checkpoints when present
        """
        if self._models_loaded:
            logger.warning("Models already loaded, skipping")
            return
        # Set default paths
        if resnet_path is None:
            resnet_path = str(MODELS_DIR / "resnet_best_model.pth")
        if efficientnet_path is None:
            efficientnet_path = str(MODELS_DIR / "efficient_best_model.pth")
        if use_optimized:
            # Check multiple possible locations for optimized checkpoints.
            opt_resnet = self._first_existing([
                str(MODELS_DIR / "optimized_models" / "resnet_resnet_best_model_quantized.pth"),
                str(BASE_DIR / "optimized_models" / "resnet_resnet_best_model_quantized.pth"),
                resnet_path.replace('.pth', '_quantized.pth'),
                resnet_path.replace('resnet_best_model.pth', 'resnet_resnet_best_model_quantized.pth'),
            ])
            if opt_resnet is not None:
                resnet_path = opt_resnet
                logger.info(f"Using optimized ResNet model: {resnet_path}")
            opt_efficient = self._first_existing([
                str(MODELS_DIR / "optimized_models" / "efficientnet_efficient_best_model_quantized.pth"),
                str(BASE_DIR / "optimized_models" / "efficientnet_efficient_best_model_quantized.pth"),
                efficientnet_path.replace('.pth', '_quantized.pth'),
                efficientnet_path.replace('efficient_best_model.pth', 'efficientnet_efficient_best_model_quantized.pth'),
            ])
            if opt_efficient is not None:
                efficientnet_path = opt_efficient
                logger.info(f"Using optimized EfficientNet model: {efficientnet_path}")
        # Load EfficientNet only (ResNet skipped). Failures are logged, not
        # raised, so startup proceeds with whatever did load.
        try:
            self.load_efficientnet_model(efficientnet_path)
            logger.info("EfficientNet model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load EfficientNet model: {e}", exc_info=True)
        # Marked loaded even on failure so startup is not retried per-request.
        self._models_loaded = True

    def load_efficientnet_model_only(self, use_optimized=True):
        """
        Load only EfficientNet model (skip ResNet).
        Useful when only EfficientNet is needed.
        """
        if self._models_loaded:
            logger.warning("Models already loaded, skipping")
            return
        efficientnet_path = str(MODELS_DIR / "efficient_best_model.pth")
        if use_optimized:
            candidates = [
                str(MODELS_DIR / "optimized_models" / "efficientnet_efficient_best_model_quantized.pth"),
                str(BASE_DIR / "optimized_models" / "efficientnet_efficient_best_model_quantized.pth"),
                efficientnet_path.replace('.pth', '_quantized.pth'),
                efficientnet_path.replace('efficient_best_model.pth', 'efficientnet_efficient_best_model_quantized.pth'),
            ]
            # HF Hub cache hits take priority: each is pushed to the front.
            for hf_model_path in self._hf_cache_hits(["efficientnet_efficient_best_model_quantized.pth"]):
                candidates.insert(0, hf_model_path)
                logger.info(f"Found model in HF Hub cache: {hf_model_path}")
            found = self._first_existing(candidates)
            if found is not None:
                efficientnet_path = found
                logger.info(f"Using optimized EfficientNet model: {efficientnet_path}")
        # Load EfficientNet; failures are logged rather than raised.
        try:
            self.load_efficientnet_model(efficientnet_path)
            logger.info("EfficientNet model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load EfficientNet model: {e}", exc_info=True)
        # Marked loaded even on failure so startup is not retried per-request.
        self._models_loaded = True

    def load_resnet_models(self, checkpoint_path=None):
        """Load ResNet encoder and decoder models.

        Returns:
            ``(encoder, decoder, vocab)``; cached after the first call.

        Raises:
            ImportError: if the ResNet model classes cannot be imported.
        """
        if self._resnet_encoder is not None:
            return self._resnet_encoder, self._resnet_decoder, self._resnet_vocab
        if checkpoint_path is None:
            checkpoint_path = str(MODELS_DIR / "resnet_best_model.pth")
        # Resolve path - try multiple locations
        checkpoint_path = self._resolve_model_path(checkpoint_path)
        logger.info(f"Loading ResNet models from {checkpoint_path}")
        import sys
        # Import the model classes BEFORE torch.load: the checkpoint was
        # pickled with references to its defining module, so that module
        # must be importable (under its legacy top-level name) to unpickle.
        try:
            from training.resnet_train import EncoderCNN, DecoderRNN
            # Alias under the legacy module name so pickle can resolve it.
            if 'resnet_train' not in sys.modules:
                sys.modules['resnet_train'] = sys.modules['training.resnet_train']
        except ImportError:
            try:
                # Fallback for backward compatibility
                sys.path.insert(0, str(BASE_DIR))
                from resnet_train import EncoderCNN, DecoderRNN
            except ImportError:
                logger.error("Could not import ResNet model classes. Make sure resnet_train.py exists in training/ or root.")
                raise
        # Safety net: if the legacy alias is still missing, load the module
        # straight from its file so unpickling class references succeeds.
        # Best-effort only — errors here are deliberately swallowed.
        if 'resnet_train' not in sys.modules:
            import importlib.util
            try:
                spec = importlib.util.spec_from_file_location("resnet_train", str(BASE_DIR / "training" / "resnet_train.py"))
                if spec and spec.loader:
                    resnet_module = importlib.util.module_from_spec(spec)
                    sys.modules['resnet_train'] = resnet_module
                    spec.loader.exec_module(resnet_module)
            except Exception:
                pass
        # SECURITY: weights_only=False unpickles arbitrary objects — only
        # ever load trusted, project-produced checkpoints here.
        checkpoint = torch.load(checkpoint_path, map_location=self._device, weights_only=False)
        # Initialize models and load weights
        self._resnet_encoder = EncoderCNN().to(self._device)
        self._resnet_decoder = DecoderRNN().to(self._device)
        self._resnet_encoder.load_state_dict(checkpoint['encoder'])
        self._resnet_decoder.load_state_dict(checkpoint['decoder'])
        # Set to eval mode
        self._resnet_encoder.eval()
        self._resnet_decoder.eval()
        # Store vocabulary (may be absent from older checkpoints -> None)
        self._resnet_vocab = checkpoint.get('vocab')
        # Warm up models (first inference is slower)
        logger.info("Warming up ResNet models...")
        dummy_input = torch.randn(1, 3, 224, 224).to(self._device)
        with torch.no_grad():
            _ = self._resnet_encoder(dummy_input)
        logger.info("ResNet models warmed up")
        return self._resnet_encoder, self._resnet_decoder, self._resnet_vocab

    def load_efficientnet_model(self, checkpoint_path=None):
        """Load EfficientNet model.

        Returns:
            ``(model, tokenizer)``; cached after the first call.

        Raises:
            ImportError: if the EfficientNet model classes cannot be imported.
        """
        if self._efficientnet_model is not None:
            return self._efficientnet_model, self._efficientnet_tokenizer
        if checkpoint_path is None:
            checkpoint_path = str(MODELS_DIR / "efficient_best_model.pth")
        # Resolve path - try multiple locations
        checkpoint_path = self._resolve_model_path(checkpoint_path)
        logger.info(f"Loading EfficientNet model from {checkpoint_path}")
        # Import from training module (handles both old and new locations)
        try:
            from training.efficient_train import Encoder, Decoder, ImageCaptioningModel
        except ImportError:
            try:
                # Fallback for backward compatibility
                import sys
                sys.path.insert(0, str(BASE_DIR))
                from efficient_train import Encoder, Decoder, ImageCaptioningModel
            except ImportError:
                logger.error("Could not import EfficientNet model classes. Make sure efficient_train.py exists in training/ or root.")
                raise
        from transformers import AutoTokenizer
        # GPT-2 tokenizer extended with the caption start/end markers.
        tokenizer = AutoTokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.add_special_tokens({'additional_special_tokens': ['<start>', '<end>']})
        self._efficientnet_tokenizer = tokenizer
        # Model hyperparameters must match those used at training time.
        encoder = Encoder(model_name='efficientnet_b3', embed_dim=512)
        decoder = Decoder(
            vocab_size=len(tokenizer),
            embed_dim=512,
            num_layers=8,
            num_heads=8,
            max_seq_length=64
        )
        self._efficientnet_model = ImageCaptioningModel(encoder, decoder).to(self._device)
        # SECURITY: weights_only=False unpickles arbitrary objects — only
        # ever load trusted, project-produced checkpoints here.
        checkpoint = torch.load(checkpoint_path, map_location=self._device, weights_only=False)
        # Quantized checkpoints carry '_packed_params' entries in their keys.
        is_quantized = any('_packed_params' in key for key in checkpoint.get('model_state', checkpoint).keys())
        if is_quantized:
            # The float model must be converted to its dynamically-quantized
            # form before a quantized state dict can be loaded into it.
            logger.info("Detected quantized model, preparing model for quantization...")
            try:
                import torch.quantization as quant
                self._efficientnet_model = quant.quantize_dynamic(
                    self._efficientnet_model, {torch.nn.Linear}, dtype=torch.qint8
                )
                logger.info("Model prepared for quantization")
            except Exception as e:
                logger.warning(f"Could not prepare model for quantization: {e}. Trying to load anyway...")
        if 'model_state' in checkpoint:
            try:
                self._efficientnet_model.load_state_dict(checkpoint['model_state'], strict=False)
            except Exception as e:
                logger.warning(f"Could not load quantized state dict: {e}. Trying regular model...")
                # Fall back to the non-quantized checkpoint if one exists.
                regular_path = checkpoint_path.replace('_quantized.pth', '.pth').replace('efficientnet_efficient_best_model', 'efficient_best_model')
                if os.path.exists(regular_path) and regular_path != checkpoint_path:
                    logger.info(f"Trying regular model: {regular_path}")
                    checkpoint = torch.load(regular_path, map_location=self._device, weights_only=False)
                    if 'model_state' in checkpoint:
                        self._efficientnet_model.load_state_dict(checkpoint['model_state'])
                    else:
                        self._efficientnet_model.load_state_dict(checkpoint)
                # NOTE(review): if no regular fallback exists, the failure is
                # only warned about and the model keeps its initial weights.
        else:
            # Checkpoint is a bare state dict — try loading it directly.
            try:
                self._efficientnet_model.load_state_dict(checkpoint, strict=False)
            except Exception:
                logger.warning("Could not load state dict. Model may not work correctly.")
        self._efficientnet_model.eval()
        # Warm up (first inference pays one-time initialization cost)
        logger.info("Warming up EfficientNet model...")
        dummy_input = torch.randn(1, 3, 224, 224).to(self._device)
        with torch.no_grad():
            _ = self._efficientnet_model.encoder(dummy_input)
        logger.info("EfficientNet model warmed up")
        return self._efficientnet_model, self._efficientnet_tokenizer

    def _resolve_model_path(self, checkpoint_path):
        """Resolve a checkpoint path, trying multiple known locations.

        Search order: the path as given, models/, models/optimized_models/,
        the project root, then the nested HF Hub cache. Returns the original
        path unchanged if nothing is found (so the caller fails with a clear
        file-not-found error).
        """
        if os.path.exists(checkpoint_path):
            return checkpoint_path
        basename = os.path.basename(checkpoint_path)
        for alt_path in (
            str(MODELS_DIR / basename),
            str(MODELS_DIR / "optimized_models" / basename),
            str(BASE_DIR / basename),  # root directory (backward compatibility)
        ):
            if os.path.exists(alt_path):
                logger.info(f"Found model at: {alt_path}")
                return alt_path
        # Search the HF Hub cache; when looking for the base EfficientNet
        # checkpoint, also accept its quantized variant.
        search_filenames = [basename]
        if 'efficient_best_model.pth' in basename and 'quantized' not in basename:
            search_filenames.append('efficientnet_efficient_best_model_quantized.pth')
            search_filenames.append(basename.replace('efficient_best_model.pth', 'efficientnet_efficient_best_model_quantized.pth'))
        for hf_model_path in self._hf_cache_hits(search_filenames):
            logger.info(f"Found model in HF Hub cache: {hf_model_path}")
            return hf_model_path
        # Return original path (will fail with clear error)
        return checkpoint_path

    def get_resnet_models(self):
        """Get cached ResNet models; raises RuntimeError if not loaded."""
        if self._resnet_encoder is None:
            raise RuntimeError("ResNet models not loaded. Call load_resnet_models() first.")
        return self._resnet_encoder, self._resnet_decoder, self._resnet_vocab

    def get_efficientnet_model(self):
        """Get cached EfficientNet model; raises RuntimeError if not loaded."""
        if self._efficientnet_model is None:
            raise RuntimeError("EfficientNet model not loaded. Call load_efficientnet_model() first.")
        return self._efficientnet_model, self._efficientnet_tokenizer

    def is_resnet_loaded(self):
        """Check if ResNet models are loaded."""
        return self._resnet_encoder is not None

    def is_efficientnet_loaded(self):
        """Check if EfficientNet model is loaded."""
        return self._efficientnet_model is not None
# Module-level singleton: import `model_cache` so every request shares the
# same in-memory models instead of reloading checkpoints per call.
model_cache = ModelCache()