# backend/models/model_manager.py
import threading
from typing import Dict, Any, Optional, Callable
from .base_model import BaseModel
from .internvl import InternVLModel
from .qwen import QwenModel
from ..config.config_manager import ConfigManager
class ModelManager:
    """Manager class for handling multiple vision-language models.

    Keeps one (initially unloaded) instance per configured model, ensures at
    most one model is resident in memory at a time, and serializes all
    load/unload transitions behind a lock so concurrent callers cannot race.
    """

    def __init__(self, config_manager: ConfigManager):
        """
        Initialize the model manager.

        Args:
            config_manager: Configuration manager instance
        """
        self.config_manager = config_manager
        self.models: Dict[str, BaseModel] = {}
        # The single model currently resident in memory (if any).
        self.current_model: Optional[BaseModel] = None
        self.current_model_name: Optional[str] = None
        # Guards load/unload transitions: they are slow and mutate shared state.
        self.loading_lock = threading.Lock()

        # Apply environment settings
        self.config_manager.apply_environment_settings()

        # Initialize model instances but don't load weights yet (lazy loading)
        self._initialize_models()

    def _get_model_class(self, model_config: Dict[str, Any]) -> type:
        """
        Determine the appropriate model class based on model configuration.

        Args:
            model_config: Model configuration dictionary

        Returns:
            Model class to instantiate
        """
        model_type = model_config.get('model_type', 'internvl').lower()
        model_id = model_config.get('model_id', '').lower()

        # Prefer a match on the model_id substring; fall back to the explicit
        # model_type field. Unknown types default to InternVL.
        if 'qwen' in model_id or model_type == 'qwen':
            return QwenModel
        elif 'internvl' in model_id or model_type == 'internvl':
            return InternVLModel
        else:
            # Default to InternVL for backward compatibility
            print(f"⚠️ Unknown model type for {model_config.get('name', 'unknown')}, defaulting to InternVL")
            return InternVLModel

    def _initialize_models(self) -> None:
        """Initialize model instances without loading them."""
        available_models = self.config_manager.get_available_models()
        for model_name, model_id in available_models.items():
            model_config = self.config_manager.get_model_config(model_name)

            # Determine the appropriate model class
            model_class = self._get_model_class(model_config)

            # Create the instance; weights are loaded later via load_model().
            self.models[model_name] = model_class(
                model_name=model_name,
                model_config=model_config,
                config_manager=self.config_manager
            )
            print(f"✅ Initialized {model_class.__name__}: {model_name}")

    def get_available_models(self) -> list[str]:
        """Get list of available model names."""
        return list(self.models.keys())

    def get_model_info(self, model_name: str) -> Dict[str, Any]:
        """
        Get information about a specific model.

        Args:
            model_name: Name of the model

        Returns:
            Model information dictionary

        Raises:
            KeyError: If ``model_name`` is not a known model.
        """
        if model_name not in self.models:
            raise KeyError(f"Model '{model_name}' not available")
        return self.models[model_name].get_model_info()

    def get_all_models_info(self) -> Dict[str, Dict[str, Any]]:
        """Get information about all available models."""
        return {name: model.get_model_info() for name, model in self.models.items()}

    def load_model(
        self,
        model_name: str,
        quantization_type: str,
        progress_callback: Optional[Callable[[str], None]] = None
    ) -> bool:
        """
        Load a specific model with given quantization.

        Unloads any different, currently resident model first. Safe to call
        from multiple threads; the whole transition runs under the lock.

        Args:
            model_name: Name of the model to load
            quantization_type: Type of quantization to use
            progress_callback: Callback receiving human-readable progress strings

        Returns:
            True if successful, False otherwise

        Raises:
            KeyError: If ``model_name`` is not a known model.
        """
        with self.loading_lock:
            if model_name not in self.models:
                raise KeyError(f"Model '{model_name}' not available")

            model = self.models[model_name]

            # Fast path: requested model already resident with the same quantization.
            if (self.current_model == model and
                    model.is_model_loaded() and
                    model.current_quantization == quantization_type):
                if progress_callback:
                    progress_callback(f"✅ {model_name} already loaded with {quantization_type}!")
                return True

            # Unload the currently resident model if it is a different one.
            if (self.current_model and
                    self.current_model != model and
                    self.current_model.is_model_loaded()):
                if progress_callback:
                    progress_callback(f"🔄 Unloading {self.current_model_name}...")
                self.current_model.unload_model()
                # Fix: clear stale references immediately. If the load below
                # fails or raises, the manager must not keep pointing at a
                # model that has already been unloaded.
                self.current_model = None
                self.current_model_name = None

            # Load the requested model
            try:
                success = model.load_model(quantization_type, progress_callback)
                if success:
                    self.current_model = model
                    self.current_model_name = model_name
                    print(f"✅ Successfully loaded {model_name} with {quantization_type}")
                    return True
                else:
                    if progress_callback:
                        progress_callback(f"❌ Failed to load {model_name}")
                    return False
            except Exception as e:
                # Best-effort error reporting; the manager stays usable so the
                # caller can retry with another model/quantization.
                error_msg = f"Error loading {model_name}: {str(e)}"
                print(error_msg)
                if progress_callback:
                    progress_callback(f"❌ {error_msg}")
                return False

    def unload_current_model(self) -> None:
        """Unload the currently loaded model (no-op if nothing is loaded)."""
        with self.loading_lock:
            if self.current_model and self.current_model.is_model_loaded():
                print(f"🔄 Unloading {self.current_model_name}...")
                self.current_model.unload_model()
                self.current_model = None
                self.current_model_name = None
                print("✅ Model unloaded successfully")
            else:
                print("ℹ️ No model currently loaded")

    def inference(self, image_path: str, prompt: str, **kwargs) -> str:
        """
        Perform inference using the currently loaded model.

        Args:
            image_path: Path to the image file
            prompt: Text prompt for the model
            **kwargs: Additional inference parameters (forwarded to the model)

        Returns:
            Model's text response

        Raises:
            RuntimeError: If no model is currently loaded.
        """
        if not self.current_model or not self.current_model.is_model_loaded():
            raise RuntimeError("No model is currently loaded. Load a model first.")
        return self.current_model.inference(image_path, prompt, **kwargs)

    def get_current_model_status(self) -> str:
        """Get status string for the currently loaded model."""
        if not self.current_model or not self.current_model.is_model_loaded():
            return "❌ No model loaded"
        quantization = self.current_model.current_quantization or "Unknown"
        model_class = self.current_model.__class__.__name__
        return f"✅ {self.current_model_name} ({model_class}) loaded with {quantization}"

    def get_supported_quantizations(self, model_name: str) -> list[str]:
        """Get supported quantization methods for a model.

        Raises:
            KeyError: If ``model_name`` is not a known model.
        """
        if model_name not in self.models:
            raise KeyError(f"Model '{model_name}' not available")
        return self.models[model_name].get_supported_quantizations()

    def validate_model_and_quantization(self, model_name: str, quantization_type: str) -> bool:
        """
        Validate if a model and quantization combination is valid.

        Args:
            model_name: Name of the model
            quantization_type: Type of quantization

        Returns:
            True if valid, False otherwise (including unknown model names)
        """
        if model_name not in self.models:
            return False
        return self.models[model_name].validate_quantization(quantization_type)

    def get_model_memory_requirements(self, model_name: str) -> Dict[str, int]:
        """Get memory requirements for a specific model.

        Raises:
            KeyError: If ``model_name`` is not a known model.
        """
        if model_name not in self.models:
            raise KeyError(f"Model '{model_name}' not available")
        return self.models[model_name].get_memory_requirements()

    def preload_default_model(self) -> bool:
        """
        Preload the default model specified in configuration.

        Returns:
            True if successful, False otherwise (load errors are swallowed so
            startup can continue without a preloaded model)
        """
        default_model = self.config_manager.get_default_model()
        default_quantization = self.config_manager.get_default_quantization(default_model)

        print(f"🚀 Preloading default model: {default_model} with {default_quantization}")
        try:
            return self.load_model(default_model, default_quantization)
        except Exception as e:
            print(f"⚠️ Failed to preload default model: {str(e)}")
            return False

    def __str__(self) -> str:
        """String representation of the model manager."""
        loaded_info = f"Current: {self.current_model_name}" if self.current_model_name else "None loaded"
        return f"ModelManager({len(self.models)} models available, {loaded_info})"

    def __repr__(self) -> str:
        """Detailed string representation."""
        models_list = list(self.models.keys())
        return f"ModelManager(models={models_list}, current={self.current_model_name})"