Spaces:
Runtime error
Runtime error
| import yaml | |
| import os | |
| from pathlib import Path | |
| from typing import Dict, List, Any, Optional | |
class ConfigManager:
    """Load a YAML model configuration file and expose accessors for it.

    The accessors read the top-level keys ``models`` (a mapping of model name
    to per-model settings) and ``default_model``. Per-model settings read by
    this class are ``model_id``, ``supported_quantizations``,
    ``default_quantization`` and ``description``.
    """

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the configuration manager and load the file immediately.

        Args:
            config_path: Path to the configuration file. If None, uses
                ``config/models.yaml`` three directory levels above this file.

        Raises:
            FileNotFoundError: If the configuration file does not exist.
            ValueError: If the file is not valid YAML or its top level is not
                a mapping.
        """
        if config_path is None:
            # Default to config/models.yaml relative to project root
            project_root = Path(__file__).parent.parent.parent
            config_path = project_root / "config" / "models.yaml"
        self.config_path = Path(config_path)
        self._config: Optional[Dict[str, Any]] = None
        self.load_config()

    def load_config(self) -> None:
        """
        Load configuration from the YAML file at ``self.config_path``.

        An empty YAML document is stored as an empty dictionary so that the
        accessors fall back to their defaults instead of failing on ``None``.

        Raises:
            FileNotFoundError: If the configuration file does not exist.
            ValueError: If the file is not valid YAML or its top level is not
                a mapping.
        """
        try:
            with open(self.config_path, 'r', encoding='utf-8') as file:
                loaded = yaml.safe_load(file)
        except FileNotFoundError as err:
            raise FileNotFoundError(
                f"Configuration file not found: {self.config_path}"
            ) from err
        except yaml.YAMLError as err:
            raise ValueError(f"Invalid YAML in configuration file: {err}") from err

        # safe_load returns None for an empty document.
        if loaded is None:
            loaded = {}
        if not isinstance(loaded, dict):
            raise ValueError(
                "Invalid YAML in configuration file: expected a mapping at the "
                f"top level, got {type(loaded).__name__}"
            )

        self._config = loaded
        # NOTE(review): the leading glyph in this message looks mis-encoded;
        # confirm the intended symbol before changing it.
        print(f"β Configuration loaded from {self.config_path}")

    def reload_config(self) -> None:
        """Reload configuration from file (same errors as ``load_config``)."""
        self.load_config()

    @property
    def config(self) -> Dict[str, Any]:
        """Get the full configuration dictionary, loading it first if needed."""
        # Must be a property: every accessor below reads ``self.config.get(...)``.
        if self._config is None:
            self.load_config()
        return self._config

    def _models(self) -> Dict[str, Any]:
        """Return the ``models`` section, or an empty dict if absent or empty."""
        # ``or {}`` also covers a ``models:`` key present with no value (None).
        return self.config.get('models') or {}

    def get_available_models(self) -> Dict[str, str]:
        """
        Get a dictionary of available model names and their IDs.

        Raises:
            KeyError: If a model entry has no ``model_id`` key.
        """
        return {
            name: model_config['model_id']
            for name, model_config in self._models().items()
        }

    def get_model_config(self, model_name: str) -> Dict[str, Any]:
        """
        Get configuration for a specific model.

        Args:
            model_name: Name of the model (e.g., 'InternVL3-8B')

        Returns:
            Model configuration dictionary

        Raises:
            KeyError: If model name is not found
        """
        models = self._models()
        if model_name not in models:
            available = list(models.keys())
            raise KeyError(f"Model '{model_name}' not found. Available models: {available}")
        return models[model_name]

    def get_supported_quantizations(self, model_name: str) -> List[str]:
        """
        Get supported quantization methods for a model.

        Returns an empty list when the model does not declare any.

        Raises:
            KeyError: If model name is not found
        """
        model_config = self.get_model_config(model_name)
        return model_config.get('supported_quantizations', [])

    def get_default_quantization(self, model_name: str) -> str:
        """
        Get the default quantization method for a model.

        Falls back to 'non-quantized(fp16)' when the model does not set one.

        Raises:
            KeyError: If model name is not found
        """
        model_config = self.get_model_config(model_name)
        return model_config.get('default_quantization', 'non-quantized(fp16)')

    def get_default_model(self) -> str:
        """Get the default model name ('InternVL3-8B' if not configured)."""
        return self.config.get('default_model', 'InternVL3-8B')

    def validate_model_and_quantization(self, model_name: str, quantization: str) -> bool:
        """
        Validate if a quantization method is supported for a model.

        Args:
            model_name: Name of the model
            quantization: Quantization method

        Returns:
            True if valid, False otherwise (including for an unknown model)
        """
        try:
            supported = self.get_supported_quantizations(model_name)
        except KeyError:
            # Unknown model: report as invalid instead of propagating.
            return False
        return quantization in supported

    def apply_environment_settings(self) -> None:
        """Apply environment settings to the current process."""
        # Set CUDA memory allocation; overwrites any existing value.
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

    def get_model_description(self, model_name: str) -> str:
        """
        Get description for a model.

        Falls back to 'No description available' when none is configured.

        Raises:
            KeyError: If model name is not found
        """
        model_config = self.get_model_config(model_name)
        return model_config.get('description', 'No description available')

    def __str__(self) -> str:
        """String representation of the configuration manager."""
        return f"ConfigManager(config_path={self.config_path})"

    def __repr__(self) -> str:
        """Detailed string representation including configured model names."""
        # Only names are needed, so entries lacking 'model_id' do not break repr.
        models = list(self._models())
        return f"ConfigManager(config_path={self.config_path}, models={models})"