import threading
from typing import Dict, Any, Optional, Callable

from .base_model import BaseModel
from .internvl import InternVLModel
from .qwen import QwenModel
from ..config.config_manager import ConfigManager


class ModelManager:
    """Manager class for handling multiple vision-language models.

    Holds one lazily-loaded model instance per configured model name and
    tracks which one (if any) is currently resident. Loading/unloading is
    serialized through a lock so concurrent callers cannot interleave.
    """

    def __init__(self, config_manager: ConfigManager):
        """
        Initialize the model manager.

        Args:
            config_manager: Configuration manager instance
        """
        self.config_manager = config_manager
        self.models: Dict[str, BaseModel] = {}
        # The model currently resident in memory (at most one at a time).
        self.current_model: Optional[BaseModel] = None
        self.current_model_name: Optional[str] = None
        # Serializes load/unload so two threads can't swap models concurrently.
        self.loading_lock = threading.Lock()

        # Apply environment settings
        self.config_manager.apply_environment_settings()

        # Initialize models but don't load them yet
        self._initialize_models()

    def _get_model_class(self, model_config: Dict[str, Any]) -> type:
        """
        Determine the appropriate model class based on model configuration.

        Args:
            model_config: Model configuration dictionary

        Returns:
            Model class to instantiate
        """
        model_type = model_config.get('model_type', 'internvl').lower()
        model_id = model_config.get('model_id', '').lower()

        # Determine model type based on model_id or explicit model_type
        if 'qwen' in model_id or model_type == 'qwen':
            return QwenModel
        elif 'internvl' in model_id or model_type == 'internvl':
            return InternVLModel
        else:
            # Default to InternVL for backward compatibility
            print(f"⚠️ Unknown model type for {model_config.get('name', 'unknown')}, defaulting to InternVL")
            return InternVLModel

    def _initialize_models(self) -> None:
        """Initialize model instances without loading them."""
        available_models = self.config_manager.get_available_models()

        for model_name, model_id in available_models.items():
            model_config = self.config_manager.get_model_config(model_name)

            # Determine the appropriate model class
            model_class = self._get_model_class(model_config)

            # Create model instance (weights are not loaded until load_model)
            self.models[model_name] = model_class(
                model_name=model_name,
                model_config=model_config,
                config_manager=self.config_manager
            )
            print(f"✅ Initialized {model_class.__name__}: {model_name}")

    def get_available_models(self) -> list[str]:
        """Get list of available model names."""
        return list(self.models.keys())

    def get_model_info(self, model_name: str) -> Dict[str, Any]:
        """
        Get information about a specific model.

        Args:
            model_name: Name of the model

        Returns:
            Model information dictionary

        Raises:
            KeyError: If the model name is not configured.
        """
        if model_name not in self.models:
            raise KeyError(f"Model '{model_name}' not available")
        return self.models[model_name].get_model_info()

    def get_all_models_info(self) -> Dict[str, Dict[str, Any]]:
        """Get information about all available models."""
        return {name: model.get_model_info() for name, model in self.models.items()}

    def load_model(
        self,
        model_name: str,
        quantization_type: str,
        progress_callback: Optional[Callable] = None
    ) -> bool:
        """
        Load a specific model with given quantization.

        If a different model is resident it is unloaded first. If the
        requested model is already loaded with the same quantization this
        is a no-op.

        Args:
            model_name: Name of the model to load
            quantization_type: Type of quantization to use
            progress_callback: Callback function for progress updates

        Returns:
            True if successful, False otherwise

        Raises:
            KeyError: If the model name is not configured.
        """
        with self.loading_lock:
            if model_name not in self.models:
                raise KeyError(f"Model '{model_name}' not available")

            model = self.models[model_name]

            # Check if this model is already loaded with the same quantization
            if (self.current_model == model and
                    model.is_model_loaded() and
                    model.current_quantization == quantization_type):
                if progress_callback:
                    progress_callback(f"✅ {model_name} already loaded with {quantization_type}!")
                return True

            # Unload current model if different
            if (self.current_model and
                    self.current_model != model and
                    self.current_model.is_model_loaded()):
                if progress_callback:
                    progress_callback(f"🔄 Unloading {self.current_model_name}...")
                self.current_model.unload_model()
                # Clear state immediately: if the load below fails, the
                # manager must not keep pointing at an unloaded model.
                self.current_model = None
                self.current_model_name = None

            # Load the requested model
            try:
                success = model.load_model(quantization_type, progress_callback)

                if success:
                    self.current_model = model
                    self.current_model_name = model_name
                    print(f"✅ Successfully loaded {model_name} with {quantization_type}")
                    return True
                else:
                    if progress_callback:
                        progress_callback(f"❌ Failed to load {model_name}")
                    return False

            except Exception as e:
                error_msg = f"Error loading {model_name}: {str(e)}"
                print(error_msg)
                if progress_callback:
                    progress_callback(f"❌ {error_msg}")
                return False

    def unload_current_model(self) -> None:
        """Unload the currently loaded model."""
        with self.loading_lock:
            if self.current_model and self.current_model.is_model_loaded():
                print(f"🔄 Unloading {self.current_model_name}...")
                self.current_model.unload_model()
                self.current_model = None
                self.current_model_name = None
                print("✅ Model unloaded successfully")
            else:
                print("ℹ️ No model currently loaded")

    def inference(self, image_path: str, prompt: str, **kwargs) -> str:
        """
        Perform inference using the currently loaded model.

        Args:
            image_path: Path to the image file
            prompt: Text prompt for the model
            **kwargs: Additional inference parameters

        Returns:
            Model's text response

        Raises:
            RuntimeError: If no model is currently loaded.
        """
        if not self.current_model or not self.current_model.is_model_loaded():
            raise RuntimeError("No model is currently loaded. Load a model first.")

        return self.current_model.inference(image_path, prompt, **kwargs)

    def get_current_model_status(self) -> str:
        """Get status string for the currently loaded model."""
        if not self.current_model or not self.current_model.is_model_loaded():
            return "❌ No model loaded"

        quantization = self.current_model.current_quantization or "Unknown"
        model_class = self.current_model.__class__.__name__
        return f"✅ {self.current_model_name} ({model_class}) loaded with {quantization}"

    def get_supported_quantizations(self, model_name: str) -> list[str]:
        """Get supported quantization methods for a model."""
        if model_name not in self.models:
            raise KeyError(f"Model '{model_name}' not available")
        return self.models[model_name].get_supported_quantizations()

    def validate_model_and_quantization(self, model_name: str, quantization_type: str) -> bool:
        """
        Validate if a model and quantization combination is valid.

        Args:
            model_name: Name of the model
            quantization_type: Type of quantization

        Returns:
            True if valid, False otherwise
        """
        if model_name not in self.models:
            return False
        return self.models[model_name].validate_quantization(quantization_type)

    def get_model_memory_requirements(self, model_name: str) -> Dict[str, int]:
        """Get memory requirements for a specific model."""
        if model_name not in self.models:
            raise KeyError(f"Model '{model_name}' not available")
        return self.models[model_name].get_memory_requirements()

    def preload_default_model(self) -> bool:
        """
        Preload the default model specified in configuration.

        Returns:
            True if successful, False otherwise
        """
        default_model = self.config_manager.get_default_model()
        default_quantization = self.config_manager.get_default_quantization(default_model)

        print(f"🚀 Preloading default model: {default_model} with {default_quantization}")

        try:
            return self.load_model(default_model, default_quantization)
        except Exception as e:
            # Best-effort: preload failure is reported, not raised.
            print(f"⚠️ Failed to preload default model: {str(e)}")
            return False

    def __str__(self) -> str:
        """String representation of the model manager."""
        loaded_info = f"Current: {self.current_model_name}" if self.current_model_name else "None loaded"
        return f"ModelManager({len(self.models)} models available, {loaded_info})"

    def __repr__(self) -> str:
        """Detailed string representation."""
        models_list = list(self.models.keys())
        return f"ModelManager(models={models_list}, current={self.current_model_name})"