Spaces:
Runtime error
Runtime error
| import threading | |
| from typing import Dict, Any, Optional, Callable | |
| from .base_model import BaseModel | |
| from .internvl import InternVLModel | |
| from .qwen import QwenModel | |
| from ..config.config_manager import ConfigManager | |
class ModelManager:
    """Manager class for handling multiple vision-language models.

    One model wrapper is created per configured model name at construction
    time (none are loaded). At most one wrapper is tracked as the "current"
    model; ``load_model`` and ``unload_current_model`` serialize on
    ``loading_lock`` so two callers cannot swap models at the same time.
    """

    def __init__(self, config_manager: "ConfigManager"):
        """
        Initialize the model manager.

        Applies the configuration's environment settings, then creates a
        wrapper for every configured model without loading it.

        Args:
            config_manager: Configuration manager instance
        """
        self.config_manager = config_manager
        self.models: Dict[str, BaseModel] = {}
        self.current_model: Optional[BaseModel] = None
        self.current_model_name: Optional[str] = None
        # Guards load/unload so the "current model" swap is not interleaved.
        self.loading_lock = threading.Lock()

        # Apply environment settings
        self.config_manager.apply_environment_settings()

        # Initialize models but don't load them yet
        self._initialize_models()

    def _get_model_class(self, model_config: Dict[str, Any]) -> type:
        """
        Determine the appropriate model class based on model configuration.

        A ``model_id`` containing "qwen" or ``model_type == "qwen"`` selects
        QwenModel. Otherwise, "internvl" in the id or type selects
        InternVLModel. Anything else falls back to InternVLModel with a
        printed warning.

        Args:
            model_config: Model configuration dictionary

        Returns:
            Model class to instantiate
        """
        # Treat a key that is present but set to None like a missing key;
        # calling .lower() on None would otherwise raise AttributeError.
        raw_type = model_config.get('model_type')
        model_type = 'internvl' if raw_type is None else str(raw_type).lower()
        model_id = str(model_config.get('model_id') or '').lower()

        if 'qwen' in model_id or model_type == 'qwen':
            return QwenModel
        if 'internvl' in model_id or model_type == 'internvl':
            return InternVLModel

        # Default to InternVL for backward compatibility
        print(f"β οΈ Unknown model type for {model_config.get('name', 'unknown')}, defaulting to InternVL")
        return InternVLModel

    def _initialize_models(self) -> None:
        """Create (but do not load) a wrapper for every configured model."""
        available_models = self.config_manager.get_available_models()
        # Only the names are needed; the full config is fetched per model below.
        for model_name in available_models:
            model_config = self.config_manager.get_model_config(model_name)
            model_class = self._get_model_class(model_config)
            self.models[model_name] = model_class(
                model_name=model_name,
                model_config=model_config,
                config_manager=self.config_manager,
            )
            print(f"β Initialized {model_class.__name__}: {model_name}")

    def _require_model(self, model_name: str) -> "BaseModel":
        """Return the wrapper registered as *model_name*.

        Raises:
            KeyError: if no model with that name was configured.
        """
        try:
            return self.models[model_name]
        except KeyError:
            raise KeyError(f"Model '{model_name}' not available") from None

    def _clear_stale_current_model(self) -> None:
        """Forget the current model if it is no longer loaded.

        Used on failure paths so the manager never reports an unloaded
        model as current. Does not unload anything itself.
        """
        if self.current_model is not None and not self.current_model.is_model_loaded():
            self.current_model = None
            self.current_model_name = None

    def get_available_models(self) -> list[str]:
        """Get list of available model names."""
        return list(self.models.keys())

    def get_model_info(self, model_name: str) -> Dict[str, Any]:
        """
        Get information about a specific model.

        Args:
            model_name: Name of the model

        Returns:
            Model information dictionary

        Raises:
            KeyError: if the model is not configured.
        """
        return self._require_model(model_name).get_model_info()

    def get_all_models_info(self) -> Dict[str, Dict[str, Any]]:
        """Get information about all available models, keyed by model name."""
        return {name: model.get_model_info() for name, model in self.models.items()}

    def load_model(
        self,
        model_name: str,
        quantization_type: str,
        progress_callback: Optional[Callable] = None
    ) -> bool:
        """
        Load a specific model with given quantization.

        If a different model is currently loaded it is unloaded first. On
        any failure the manager stops tracking a current model that is no
        longer loaded, so status/str/repr stay truthful.

        Args:
            model_name: Name of the model to load
            quantization_type: Type of quantization to use
            progress_callback: Callback function for progress updates

        Returns:
            True if successful, False otherwise

        Raises:
            KeyError: if the model is not configured.
        """
        with self.loading_lock:
            model = self._require_model(model_name)

            # Check if this model is already loaded with the same quantization
            if (self.current_model is model and
                    model.is_model_loaded() and
                    model.current_quantization == quantization_type):
                if progress_callback:
                    progress_callback(f"β {model_name} already loaded with {quantization_type}!")
                return True

            # Unload current model if different
            if (self.current_model is not None and
                    self.current_model is not model and
                    self.current_model.is_model_loaded()):
                if progress_callback:
                    progress_callback(f"π Unloading {self.current_model_name}...")
                self.current_model.unload_model()

            # Load the requested model; the wrapper's failure is reported,
            # not propagated, so callers only need to check the bool.
            try:
                success = model.load_model(quantization_type, progress_callback)
            except Exception as e:
                error_msg = f"Error loading {model_name}: {str(e)}"
                print(error_msg)
                if progress_callback:
                    progress_callback(f"β {error_msg}")
                self._clear_stale_current_model()
                return False

            if not success:
                if progress_callback:
                    progress_callback(f"β Failed to load {model_name}")
                self._clear_stale_current_model()
                return False

            self.current_model = model
            self.current_model_name = model_name
            print(f"β Successfully loaded {model_name} with {quantization_type}")
            return True

    def unload_current_model(self) -> None:
        """Unload the currently loaded model and forget it as current."""
        with self.loading_lock:
            if self.current_model is not None and self.current_model.is_model_loaded():
                print(f"π Unloading {self.current_model_name}...")
                self.current_model.unload_model()
                self.current_model = None
                self.current_model_name = None
                print("β Model unloaded successfully")
            else:
                # Nothing loaded, but drop any stale reference to a model
                # that was unloaded elsewhere.
                self.current_model = None
                self.current_model_name = None
                print("βΉοΈ No model currently loaded")

    def inference(self, image_path: str, prompt: str, **kwargs) -> str:
        """
        Perform inference using the currently loaded model.

        Args:
            image_path: Path to the image file
            prompt: Text prompt for the model
            **kwargs: Additional inference parameters, passed through unchanged

        Returns:
            Model's text response

        Raises:
            RuntimeError: if no model is currently loaded.
        """
        if self.current_model is None or not self.current_model.is_model_loaded():
            raise RuntimeError("No model is currently loaded. Load a model first.")
        return self.current_model.inference(image_path, prompt, **kwargs)

    def get_current_model_status(self) -> str:
        """Get status string for the currently loaded model."""
        if self.current_model is None or not self.current_model.is_model_loaded():
            return "β No model loaded"
        quantization = self.current_model.current_quantization or "Unknown"
        model_class = self.current_model.__class__.__name__
        return f"β {self.current_model_name} ({model_class}) loaded with {quantization}"

    def get_supported_quantizations(self, model_name: str) -> list[str]:
        """Get supported quantization methods for a model (KeyError if unknown)."""
        return self._require_model(model_name).get_supported_quantizations()

    def validate_model_and_quantization(self, model_name: str, quantization_type: str) -> bool:
        """
        Validate if a model and quantization combination is valid.

        Args:
            model_name: Name of the model
            quantization_type: Type of quantization

        Returns:
            True if valid, False otherwise (including unknown model names)
        """
        model = self.models.get(model_name)
        if model is None:
            return False
        return model.validate_quantization(quantization_type)

    def get_model_memory_requirements(self, model_name: str) -> Dict[str, int]:
        """Get memory requirements for a specific model (KeyError if unknown)."""
        return self._require_model(model_name).get_memory_requirements()

    def preload_default_model(self) -> bool:
        """
        Preload the default model specified in configuration.

        Best-effort: any error, including failures while reading the default
        model/quantization from configuration, is printed and reported as
        False rather than raised.

        Returns:
            True if successful, False otherwise
        """
        try:
            default_model = self.config_manager.get_default_model()
            default_quantization = self.config_manager.get_default_quantization(default_model)
            print(f"π Preloading default model: {default_model} with {default_quantization}")
            return self.load_model(default_model, default_quantization)
        except Exception as e:
            print(f"β οΈ Failed to preload default model: {str(e)}")
            return False

    def __str__(self) -> str:
        """String representation of the model manager."""
        loaded_info = f"Current: {self.current_model_name}" if self.current_model_name else "None loaded"
        return f"ModelManager({len(self.models)} models available, {loaded_info})"

    def __repr__(self) -> str:
        """Detailed string representation."""
        models_list = list(self.models.keys())
        return f"ModelManager(models={models_list}, current={self.current_model_name})"