# backend/config/config_manager.py
import yaml
import os
from pathlib import Path
from typing import Dict, List, Any, Optional
class ConfigManager:
    """Manages configuration loading and access for the application.

    Loads a YAML model-configuration file once at construction and exposes
    convenience accessors for model IDs, quantization options, and defaults.
    """

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the configuration manager.

        Args:
            config_path: Path to the configuration file. If None, uses default path.
        """
        if config_path is None:
            # Default to config/models.yaml relative to project root
            # (this file lives at <root>/backend/config/, so three .parent hops).
            project_root = Path(__file__).parent.parent.parent
            config_path = project_root / "config" / "models.yaml"
        self.config_path = Path(config_path)
        self._config: Optional[Dict[str, Any]] = None
        self.load_config()

    def load_config(self) -> None:
        """Load configuration from YAML file.

        Raises:
            FileNotFoundError: If the configuration file does not exist.
            ValueError: If the file contains invalid YAML.
        """
        try:
            with open(self.config_path, 'r', encoding='utf-8') as file:
                # safe_load returns None for an empty file; normalize to {}
                # so downstream self.config.get(...) never hits AttributeError.
                self._config = yaml.safe_load(file) or {}
            print(f"✅ Configuration loaded from {self.config_path}")
        except FileNotFoundError:
            raise FileNotFoundError(f"Configuration file not found: {self.config_path}")
        except yaml.YAMLError as e:
            # Chain the original parser error so the traceback shows the cause.
            raise ValueError(f"Invalid YAML in configuration file: {e}") from e

    def reload_config(self) -> None:
        """Reload configuration from file."""
        self.load_config()

    @property
    def config(self) -> Dict[str, Any]:
        """Get the full configuration dictionary, loading lazily if needed."""
        if self._config is None:
            self.load_config()
        return self._config

    def get_available_models(self) -> Dict[str, str]:
        """Get a dictionary of available model names mapped to their model IDs."""
        models = self.config.get('models', {})
        return {name: model_config['model_id'] for name, model_config in models.items()}

    def get_model_config(self, model_name: str) -> Dict[str, Any]:
        """
        Get configuration for a specific model.

        Args:
            model_name: Name of the model (e.g., 'InternVL3-8B')

        Returns:
            Model configuration dictionary

        Raises:
            KeyError: If model name is not found
        """
        models = self.config.get('models', {})
        if model_name not in models:
            available = list(models.keys())
            raise KeyError(f"Model '{model_name}' not found. Available models: {available}")
        return models[model_name]

    def get_supported_quantizations(self, model_name: str) -> List[str]:
        """Get supported quantization methods for a model (empty list if unset)."""
        model_config = self.get_model_config(model_name)
        return model_config.get('supported_quantizations', [])

    def get_default_quantization(self, model_name: str) -> str:
        """Get the default quantization method for a model."""
        model_config = self.get_model_config(model_name)
        return model_config.get('default_quantization', 'non-quantized(fp16)')

    def get_default_model(self) -> str:
        """Get the default model name."""
        return self.config.get('default_model', 'InternVL3-8B')

    def validate_model_and_quantization(self, model_name: str, quantization: str) -> bool:
        """
        Validate if a quantization method is supported for a model.

        Args:
            model_name: Name of the model
            quantization: Quantization method

        Returns:
            True if valid, False otherwise (including unknown model names)
        """
        try:
            supported = self.get_supported_quantizations(model_name)
            return quantization in supported
        except KeyError:
            # Unknown model name is reported as "not valid" rather than raising.
            return False

    def apply_environment_settings(self) -> None:
        """Apply environment settings to the current process."""
        # Allow PyTorch's CUDA allocator to grow segments instead of
        # failing on fragmentation.
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

    def get_model_description(self, model_name: str) -> str:
        """Get description for a model, with a fallback placeholder."""
        model_config = self.get_model_config(model_name)
        return model_config.get('description', 'No description available')

    def __str__(self) -> str:
        """String representation of the configuration manager."""
        return f"ConfigManager(config_path={self.config_path})"

    def __repr__(self) -> str:
        """Detailed string representation including available model names."""
        models = list(self.get_available_models().keys())
        return f"ConfigManager(config_path={self.config_path}, models={models})"