""" Model Cache Module Shared model caching system with shared Qwen model integration. """ import torch import logging from typing import Dict, Any, Optional import sys import os logger = logging.getLogger(__name__) # Try to import shared model SHARED_MODEL_AVAILABLE = False get_shared_model_func = None get_shared_tokenizer_func = None try: import importlib.util shared_model_path = os.path.join(os.path.dirname(__file__), '..', 'Shared Model', 'shared_model.py') if os.path.exists(shared_model_path): spec = importlib.util.spec_from_file_location("shared_model", shared_model_path) if spec and spec.loader: shared_model_module = importlib.util.module_from_spec(spec) spec.loader.exec_module(shared_model_module) SharedModel = shared_model_module.SharedModel SharedModelConfig = shared_model_module.SharedModelConfig get_shared_model_func = shared_model_module.get_shared_model get_shared_tokenizer_func = shared_model_module.get_shared_tokenizer SHARED_MODEL_AVAILABLE = True except Exception as e: logger.debug(f"Shared model not available: {e}") class ModelCache: """ Shared model cache to prevent memory duplication. Integrates with shared Qwen model for zero memory overhead. """ def __init__(self, use_shared_model: bool = True, shared_model_name: str = "Qwen/Qwen3-0.6B"): self.use_shared_model = use_shared_model self.shared_model_name = shared_model_name self.shared_models = {} self.shared_tokenizers = {} logger.debug("ModelCache initialized") def get_shared_model(self, model_name: str, model_type: str = "transformer", device: Optional[str] = None, **kwargs) -> Any: """ Get or create a shared model instance. Uses shared Qwen model if available for zero memory overhead. Args: model_name: Name of the model to load model_type: Type of model (transformer, tokenizer, etc.) device: Device to load model on **kwargs: Additional model loading parameters Returns: Shared model instance """ if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" # Try to use shared Qwen model first if (self.use_shared_model and model_type == "transformer" and SHARED_MODEL_AVAILABLE and get_shared_model_func is not None): try: shared_model = get_shared_model_func() if shared_model is not None: logger.info(f"[CACHE] Using shared Qwen model (zero memory overhead)") return shared_model except Exception as e: logger.debug(f"[CACHE] Shared model not available: {e}") # Try to use shared tokenizer if (self.use_shared_model and model_type == "tokenizer" and SHARED_MODEL_AVAILABLE and get_shared_tokenizer_func is not None): try: shared_tokenizer = get_shared_tokenizer_func() if shared_tokenizer is not None: logger.info(f"[CACHE] Using shared Qwen tokenizer (zero memory overhead)") return shared_tokenizer except Exception as e: logger.debug(f"[CACHE] Shared tokenizer not available: {e}") # Fallback to cached models cache_key = f"{model_name}_{model_type}_{device}_{hash(str(sorted(kwargs.items())))}" if model_type == "tokenizer": cache_dict = self.shared_tokenizers else: cache_dict = self.shared_models if cache_key not in cache_dict: logger.info(f"[CACHE] Loading {model_type} model: {model_name}") try: if model_type == "transformer": model = self._load_transformer_model(model_name, device, **kwargs) elif model_type == "tokenizer": model = self._load_tokenizer_model(model_name, device, **kwargs) else: raise ValueError(f"Unknown model type: {model_type}") cache_dict[cache_key] = model logger.info(f"[CACHE] {model_type} model cached: {cache_key}") except Exception as e: logger.error(f"[CACHE] Failed to load {model_type} model {model_name}: {e}") raise else: logger.debug(f"[CACHE] Using cached {model_type} model: {cache_key}") return cache_dict[cache_key] def _load_transformer_model(self, model_name: str, device: str, **kwargs) -> Any: """Load transformer model with memory optimizations.""" from transformers import AutoModelForCausalLM, BitsAndBytesConfig # Memory-optimized loading configuration load_config = { "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32, "device_map": "auto" if torch.cuda.is_available() else None, "trust_remote_code": True, "attn_implementation": "eager", # More memory efficient } # Add quantization if available and beneficial if kwargs.get('use_4bit_quantization', True) and torch.cuda.is_available(): try: # Check if bitsandbytes is properly installed with CUDA support import bitsandbytes as bnb if hasattr(bnb, 'libbitsandbytes_cuda'): quantization_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4" ) load_config["quantization_config"] = quantization_config else: logger.debug("[CACHE] BitsAndBytes CUDA support not available, skipping quantization") except (ImportError, AttributeError, Exception) as e: logger.debug(f"[CACHE] 4-bit quantization not available: {e}") # Remove problematic kwargs filtered_kwargs = {k: v for k, v in kwargs.items() if k not in ['use_4bit', 'dtype', 'use_4bit_quantization']} # Merge with user-provided kwargs load_config.update(filtered_kwargs) return AutoModelForCausalLM.from_pretrained(model_name, **load_config) def _load_tokenizer_model(self, model_name: str, device: str, **kwargs) -> Any: """Load tokenizer with memory optimizations.""" from transformers import AutoTokenizer load_config = { "trust_remote_code": True, } load_config.update({k: v for k, v in kwargs.items() if k != 'use_4bit_quantization'}) return AutoTokenizer.from_pretrained(model_name, **load_config) def clear_cache(self) -> None: """Clear all cached models.""" self.shared_models.clear() self.shared_tokenizers.clear() logger.info("[CACHE] Model cache cleared") def get_stats(self) -> Dict: """Get cache statistics.""" return { 'shared_models': list(self.shared_models.keys()), 'shared_tokenizers': list(self.shared_tokenizers.keys()), 'use_shared_model': self.use_shared_model }