"""Configuration loading and lookup helpers backed by a YAML file."""

import os
from pathlib import Path
from typing import Any, Dict, List, Optional

import yaml


class ConfigManager:
    """Manages configuration loading and access for the application.

    The configuration is read from a YAML file whose top level must be a
    mapping. The keys this class reads are ``models`` (a mapping of model
    name to per-model settings) and ``default_model``.
    """

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the configuration manager and load the file immediately.

        Args:
            config_path: Path to the configuration file. If None, uses
                ``config/models.yaml`` under the directory three levels above
                this source file (treated as the project root).

        Raises:
            FileNotFoundError: If the configuration file does not exist.
            ValueError: If the file is not valid YAML or its root is not a
                mapping.
        """
        if config_path is None:
            # Default to config/models.yaml relative to project root
            project_root = Path(__file__).parent.parent.parent
            config_path = project_root / "config" / "models.yaml"

        self.config_path = Path(config_path)
        self._config: Optional[Dict[str, Any]] = None
        self.load_config()

    def load_config(self) -> None:
        """Load configuration from the YAML file at ``self.config_path``.

        An empty file is treated as an empty configuration. On failure the
        previously loaded configuration (if any) is left untouched.

        Raises:
            FileNotFoundError: If the configuration file does not exist.
            ValueError: If the file is not valid YAML or its root is not a
                mapping.
        """
        try:
            with open(self.config_path, 'r', encoding='utf-8') as file:
                loaded = yaml.safe_load(file)
        except FileNotFoundError as e:
            raise FileNotFoundError(
                f"Configuration file not found: {self.config_path}"
            ) from e
        except yaml.YAMLError as e:
            raise ValueError(f"Invalid YAML in configuration file: {e}") from e

        # safe_load returns None for an empty document; storing None would
        # make every later ``self.config.get(...)`` fail with AttributeError.
        if loaded is None:
            loaded = {}
        elif not isinstance(loaded, dict):
            raise ValueError(
                "Invalid YAML in configuration file: expected a mapping at "
                f"the top level, got {type(loaded).__name__}"
            )

        self._config = loaded
        print(f"✅ Configuration loaded from {self.config_path}")

    def reload_config(self) -> None:
        """Reload configuration from file."""
        self.load_config()

    @property
    def config(self) -> Dict[str, Any]:
        """Get the full configuration dictionary, loading it if necessary."""
        if self._config is None:
            self.load_config()
        return self._config

    def _models_section(self) -> Dict[str, Any]:
        """Return the ``models`` mapping, or an empty dict if absent or null."""
        # ``models:`` with no value parses to None, which .get's default
        # does not cover, hence the explicit ``or {}``.
        return self.config.get('models') or {}

    def get_available_models(self) -> Dict[str, str]:
        """Get a dictionary of available model names and their IDs.

        Raises:
            KeyError: If a model entry has no ``model_id`` key.
        """
        return {
            name: model_config['model_id']
            for name, model_config in self._models_section().items()
        }

    def get_model_config(self, model_name: str) -> Dict[str, Any]:
        """
        Get configuration for a specific model.

        Args:
            model_name: Name of the model (e.g., 'InternVL3-8B')

        Returns:
            Model configuration dictionary

        Raises:
            KeyError: If model name is not found
        """
        models = self._models_section()
        if model_name not in models:
            available = list(models.keys())
            raise KeyError(
                f"Model '{model_name}' not found. Available models: {available}"
            )
        return models[model_name]

    def get_supported_quantizations(self, model_name: str) -> List[str]:
        """Get supported quantization methods for a model.

        Returns an empty list when the model declares none.

        Raises:
            KeyError: If model name is not found
        """
        model_config = self.get_model_config(model_name)
        # A null ``supported_quantizations`` value is treated like a missing
        # one so callers can always iterate or test membership.
        return model_config.get('supported_quantizations') or []

    def get_default_quantization(self, model_name: str) -> str:
        """Get the default quantization method for a model.

        Falls back to 'non-quantized(fp16)' when the model does not set one.

        Raises:
            KeyError: If model name is not found
        """
        model_config = self.get_model_config(model_name)
        return model_config.get('default_quantization', 'non-quantized(fp16)')

    def get_default_model(self) -> str:
        """Get the default model name ('InternVL3-8B' if not configured)."""
        return self.config.get('default_model', 'InternVL3-8B')

    def validate_model_and_quantization(self, model_name: str, quantization: str) -> bool:
        """
        Validate if a quantization method is supported for a model.

        Args:
            model_name: Name of the model
            quantization: Quantization method

        Returns:
            True if valid, False otherwise (including when the model is
            unknown)
        """
        try:
            supported = self.get_supported_quantizations(model_name)
        except KeyError:
            return False
        return quantization in supported

    def apply_environment_settings(self) -> None:
        """Apply environment settings to the current process.

        Sets ``PYTORCH_CUDA_ALLOC_CONF`` in ``os.environ``, overwriting any
        value already present.
        """
        # Set CUDA memory allocation
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

    def get_model_description(self, model_name: str) -> str:
        """Get description for a model.

        Falls back to 'No description available' when none is configured.

        Raises:
            KeyError: If model name is not found
        """
        model_config = self.get_model_config(model_name)
        return model_config.get('description', 'No description available')

    def __str__(self) -> str:
        """String representation of the configuration manager."""
        return f"ConfigManager(config_path={self.config_path})"

    def __repr__(self) -> str:
        """Detailed string representation, including configured model names."""
        models = list(self.get_available_models().keys())
        return f"ConfigManager(config_path={self.config_path}, models={models})"