Spaces:
Sleeping
Sleeping
| """ | |
| Model manager for handling multiple models | |
| """ | |
| from typing import Dict, Optional | |
| import logging | |
| from config import MODELS | |
| from .musicgen_model import MusicGenModel | |
| from .audioldm_model import AudioLDMModel | |
logger = logging.getLogger(__name__)


class ModelManager:
    """Load and cache the models declared in ``config.MODELS``.

    Instantiated models are cached in ``self.models`` keyed by their
    ``MODELS`` key. ``self.current_model`` holds the key of the most
    recently *successfully* loaded model and is the default used by
    :meth:`get_model`.
    """

    def __init__(self):
        # Cache of instantiated model wrappers, keyed by MODELS key.
        self.models: Dict[str, object] = {}
        # Key of the most recently loaded model; None until a load succeeds.
        self.current_model: Optional[str] = None

    def load_model(self, model_key: str) -> bool:
        """Instantiate the model registered under *model_key* and cache it.

        The wrapper class is chosen by substring match on the key, checked
        in this order: keys containing ``"musicgen"`` use ``MusicGenModel``,
        keys containing ``"audioldm"`` use ``AudioLDMModel``. The wrapper is
        constructed from the config entry's ``repo_id``. A model already
        cached under the same key is re-instantiated.

        Args:
            model_key: A key of ``config.MODELS``.

        Returns:
            True if the model was constructed, cached, and made the current
            model. False on any failure (unknown key, no matching wrapper
            class, or an error raised while constructing the model). The
            failure is logged with its traceback rather than raised.
        """
        try:
            if model_key not in MODELS:
                raise ValueError(f"Unknown model: {model_key}")
            model_config = MODELS[model_key]
            logger.info("Loading model: %s", model_config.name)

            # Ordered (substring, wrapper class) pairs; first match wins,
            # mirroring the original if/elif order.
            model_types = (
                ("musicgen", MusicGenModel),
                ("audioldm", AudioLDMModel),
            )
            model_cls = next(
                (cls for marker, cls in model_types if marker in model_key),
                None,
            )
            if model_cls is None:
                # Previously this case fell through, reported success and set
                # current_model without caching anything, so get_model()
                # silently returned None.
                raise ValueError(f"No model implementation for: {model_key}")

            self.models[model_key] = model_cls(model_config.repo_id)
            self.current_model = model_key
            logger.info("Model %s loaded successfully", model_config.name)
            return True
        except Exception as e:
            # Error boundary for model construction: callers receive a bool,
            # so keep the full traceback in the log for diagnosis.
            logger.exception("Failed to load model %s: %s", model_key, e)
            return False

    def get_model(self, model_key: Optional[str] = None):
        """Return the cached model for *model_key*, loading it on demand.

        Args:
            model_key: A ``MODELS`` key. A falsy value falls back to
                ``self.current_model``.

        Returns:
            The cached model instance. None when no key is given and no
            model has been loaded yet, or when loading fails (the failure is
            logged by :meth:`load_model`).
        """
        key = model_key or self.current_model
        if key is None:
            # Nothing requested and nothing loaded yet: don't attempt (and
            # log a failed) load of a None key.
            return None
        if key not in self.models:
            self.load_model(key)
        return self.models.get(key)

    def list_models(self) -> Dict[str, str]:
        """Map every configured ``MODELS`` key to its ``name``."""
        return {key: model_config.name for key, model_config in MODELS.items()}

    def get_model_info(self, model_key: str) -> Dict:
        """Describe a configured model without loading it.

        Args:
            model_key: A key of ``config.MODELS``.

        Returns:
            A dict with ``name``, ``max_duration``, ``description`` and
            ``default_params`` copied from the model's config entry. An
            empty dict if *model_key* is not configured.
        """
        if model_key not in MODELS:
            return {}
        model_config = MODELS[model_key]
        return {
            "name": model_config.name,
            "max_duration": model_config.max_duration,
            "description": model_config.description,
            # NOTE(review): default_params is returned by reference; if it is
            # a mutable dict, callers mutating it would alter the shared
            # config. Confirm whether a copy is wanted.
            "default_params": model_config.default_params,
        }