# Phillnet-2 / memory_optimization / model_cache.py
# ayjays132's picture
# Upload 478 files
# 101858b verified
"""
Model Cache Module
Shared model caching system with shared Qwen model integration.
"""
import torch
import logging
from typing import Dict, Any, Optional
import sys
import os
logger = logging.getLogger(__name__)

# Optional integration with the project's shared model module.
#
# The module is looked up at ``../Shared Model/shared_model.py`` relative to
# this file and executed from that path with importlib (the directory name
# contains a space, so it cannot be reached with a normal ``import``).
# Every name below has a safe default so the rest of this module can be
# imported, and can test ``SHARED_MODEL_AVAILABLE``, whether or not the
# shared module exists.
#
# NOTE(review): the loaded module object is not registered in
# ``sys.modules``; any other file that loads the same path this way gets its
# own module object. Confirm that this is intended for a "shared" model.
SHARED_MODEL_AVAILABLE = False
SharedModel = None
SharedModelConfig = None
get_shared_model_func = None
get_shared_tokenizer_func = None

try:
    import importlib.util

    shared_model_path = os.path.join(
        os.path.dirname(__file__), '..', 'Shared Model', 'shared_model.py'
    )
    if os.path.exists(shared_model_path):
        spec = importlib.util.spec_from_file_location("shared_model", shared_model_path)
        if spec and spec.loader:
            shared_model_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(shared_model_module)
            # Resolve all four attributes before binding any of them, so a
            # shared module that lacks one leaves every default in place
            # instead of a half-populated namespace.
            _resolved = (
                shared_model_module.SharedModel,
                shared_model_module.SharedModelConfig,
                shared_model_module.get_shared_model,
                shared_model_module.get_shared_tokenizer,
            )
            (SharedModel, SharedModelConfig,
             get_shared_model_func, get_shared_tokenizer_func) = _resolved
            SHARED_MODEL_AVAILABLE = True
except Exception as e:
    # Deliberately best-effort: the shared model is an optional optimisation,
    # so any failure while locating or executing it only disables the feature.
    logger.debug(f"Shared model not available: {e}")
class ModelCache:
    """
    Shared model cache to prevent memory duplication.

    Lookups are answered in this order:

    1. If ``use_shared_model`` is set and the optional shared-model module was
       imported, the shared model (for ``model_type == "transformer"``) or
       shared tokenizer (for ``model_type == "tokenizer"``) is returned.
       This happens regardless of the requested ``model_name``.
    2. Otherwise the object is loaded through Hugging Face ``transformers``
       and memoised per ``(model_name, model_type, device, kwargs)``.
    """

    # Keyword arguments that steer this cache rather than the Hugging Face
    # loaders; they are stripped before ``from_pretrained`` is called.
    _CONTROL_KWARGS = frozenset({'use_4bit', 'dtype', 'use_4bit_quantization'})

    def __init__(self, use_shared_model: bool = True, shared_model_name: str = "Qwen/Qwen3-0.6B"):
        """
        Create an empty cache.

        Args:
            use_shared_model: Prefer the project-wide shared model/tokenizer
                when it is available.
            shared_model_name: Name of the shared model.
        """
        self.use_shared_model = use_shared_model
        # NOTE(review): stored but never consulted by this class; the shared
        # model is returned whatever ``model_name`` the caller asks for.
        self.shared_model_name = shared_model_name
        self.shared_models: Dict[str, Any] = {}
        self.shared_tokenizers: Dict[str, Any] = {}
        logger.debug("ModelCache initialized")

    @staticmethod
    def _cache_key(model_name: str, model_type: str, device: Any, kwargs: Dict[str, Any]) -> str:
        """Build the memoisation key; kwargs are sorted so their order is irrelevant."""
        return f"{model_name}_{model_type}_{device}_{hash(str(sorted(kwargs.items())))}"

    @classmethod
    def _loader_kwargs(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Return ``kwargs`` without the cache-control keys in ``_CONTROL_KWARGS``."""
        return {k: v for k, v in kwargs.items() if k not in cls._CONTROL_KWARGS}

    @staticmethod
    def _use_gpu(device: Any) -> bool:
        """
        Decide whether GPU-oriented load settings should be used.

        An explicit CPU request always wins; for any other device value the
        decision is whether CUDA is available.
        """
        if str(device).lower().startswith("cpu"):
            return False
        return torch.cuda.is_available()

    def _try_shared(self, model_type: str) -> Any:
        """
        Return the shared model/tokenizer for ``model_type``, or None.

        None is returned when the type has no shared counterpart, the getter
        is missing, the getter raises, or the getter itself returns None.
        """
        if model_type == "transformer":
            getter, label = get_shared_model_func, "model"
        elif model_type == "tokenizer":
            getter, label = get_shared_tokenizer_func, "tokenizer"
        else:
            return None
        if getter is None:
            return None
        try:
            instance = getter()
        except Exception as e:
            # Best-effort: fall back to the local cache on any failure.
            logger.debug(f"[CACHE] Shared {label} not available: {e}")
            return None
        if instance is not None:
            logger.info(f"[CACHE] Using shared Qwen {label} (zero memory overhead)")
        return instance

    def get_shared_model(self, model_name: str, model_type: str = "transformer",
                         device: Optional[str] = None, **kwargs) -> Any:
        """
        Get or create a shared model instance.

        Uses the shared Qwen model/tokenizer if available; otherwise loads
        the requested object once and serves it from the cache afterwards.

        Args:
            model_name: Name of the model to load.
            model_type: ``"transformer"`` or ``"tokenizer"``.
            device: Device to load the model on. Defaults to ``"cuda"`` when
                CUDA is available, else ``"cpu"``.
            **kwargs: Additional loading parameters. ``use_4bit_quantization``
                (default True) toggles 4-bit loading of transformers; the
                remaining keys are forwarded to ``from_pretrained``.

        Returns:
            The shared or cached model/tokenizer instance.

        Raises:
            ValueError: If ``model_type`` is not a supported type.
            Exception: Any error raised by the underlying loader is logged
                and re-raised.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        # Preferred path: the project-wide shared instance.
        if self.use_shared_model and SHARED_MODEL_AVAILABLE:
            shared = self._try_shared(model_type)
            if shared is not None:
                return shared

        # Fallback path: per-instance memoised load.
        cache_key = self._cache_key(model_name, model_type, device, kwargs)
        cache_dict = self.shared_tokenizers if model_type == "tokenizer" else self.shared_models

        if cache_key in cache_dict:
            logger.debug(f"[CACHE] Using cached {model_type} model: {cache_key}")
            return cache_dict[cache_key]

        logger.info(f"[CACHE] Loading {model_type} model: {model_name}")
        try:
            if model_type == "transformer":
                model = self._load_transformer_model(model_name, device, **kwargs)
            elif model_type == "tokenizer":
                model = self._load_tokenizer_model(model_name, device, **kwargs)
            else:
                raise ValueError(f"Unknown model type: {model_type}")
        except Exception as e:
            logger.error(f"[CACHE] Failed to load {model_type} model {model_name}: {e}")
            raise

        cache_dict[cache_key] = model
        logger.info(f"[CACHE] {model_type} model cached: {cache_key}")
        return model

    def _load_transformer_model(self, model_name: str, device: str, **kwargs) -> Any:
        """
        Load a causal-LM transformer with memory-oriented defaults.

        Args:
            model_name: Model identifier passed to ``from_pretrained``.
            device: Requested device. An explicit ``"cpu"`` forces float32
                with no device map, even on a CUDA host.
            **kwargs: Extra ``from_pretrained`` arguments; they override the
                defaults chosen here. Cache-control keys are removed first.

        Returns:
            The loaded ``AutoModelForCausalLM`` instance.
        """
        from transformers import AutoModelForCausalLM, BitsAndBytesConfig

        on_gpu = self._use_gpu(device)

        load_config = {
            "torch_dtype": torch.float16 if on_gpu else torch.float32,
            "device_map": "auto" if on_gpu else None,
            # NOTE(review): this lets code shipped in the model repository
            # run locally; only load model names you trust.
            "trust_remote_code": True,
            # The original author chose eager attention as the more
            # memory-efficient option. NOTE(review): confirm for this model.
            "attn_implementation": "eager",
        }

        # 4-bit quantization is attempted only for GPU loads and is strictly
        # best-effort: any problem falls back to the unquantized config.
        if kwargs.get('use_4bit_quantization', True) and on_gpu:
            try:
                import bitsandbytes as bnb
                # NOTE(review): confirm this attribute exists in the installed
                # bitsandbytes release; when absent, quantization is skipped.
                if hasattr(bnb, 'libbitsandbytes_cuda'):
                    load_config["quantization_config"] = BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_compute_dtype=torch.float16,
                        bnb_4bit_use_double_quant=True,
                        bnb_4bit_quant_type="nf4",
                    )
                else:
                    logger.debug("[CACHE] BitsAndBytes CUDA support not available, skipping quantization")
            except Exception as e:
                logger.debug(f"[CACHE] 4-bit quantization not available: {e}")

        # Caller-supplied loader arguments take precedence over the defaults.
        load_config.update(self._loader_kwargs(kwargs))
        return AutoModelForCausalLM.from_pretrained(model_name, **load_config)

    def _load_tokenizer_model(self, model_name: str, device: str, **kwargs) -> Any:
        """
        Load a tokenizer.

        Args:
            model_name: Tokenizer identifier passed to ``from_pretrained``.
            device: Accepted for signature parity with the transformer
                loader; not used here.
            **kwargs: Extra ``from_pretrained`` arguments. The same
                cache-control keys are removed as for transformers.

        Returns:
            The loaded ``AutoTokenizer`` instance.
        """
        from transformers import AutoTokenizer

        load_config = {
            "trust_remote_code": True,
        }
        load_config.update(self._loader_kwargs(kwargs))
        return AutoTokenizer.from_pretrained(model_name, **load_config)

    def clear_cache(self) -> None:
        """
        Clear all cached models and tokenizers and release freed memory.

        Only objects stored in this cache are dropped; the shared model is
        never stored here and is not affected. Objects still referenced
        elsewhere stay alive.
        """
        import gc

        self.shared_models.clear()
        self.shared_tokenizers.clear()
        # Collect now so freed tensors can be handed back by the CUDA
        # caching allocator.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("[CACHE] Model cache cleared")

    def get_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dict with the cache keys of stored models and tokenizers and the
            ``use_shared_model`` flag.
        """
        return {
            'shared_models': list(self.shared_models.keys()),
            'shared_tokenizers': list(self.shared_tokenizers.keys()),
            'use_shared_model': self.use_shared_model,
        }