File size: 4,615 Bytes
1314bf5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import yaml
import os
from pathlib import Path
from typing import Dict, List, Any, Optional

class ConfigManager:
    """Load a YAML model configuration file and expose typed accessors for it.

    The configuration is expected to be a mapping with an optional
    ``models`` mapping (model name -> per-model settings) and an optional
    ``default_model`` entry. Missing or ``null`` sections are treated as
    empty so the accessors keep working on sparse configuration files.
    """

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the configuration manager and load the file immediately.

        Args:
            config_path: Path to the configuration file. If None, uses
                ``config/models.yaml`` three directory levels above this file.

        Raises:
            FileNotFoundError: If the configuration file does not exist.
            ValueError: If the file is not valid YAML or is not a mapping.
        """
        if config_path is None:
            # Default to config/models.yaml relative to project root
            project_root = Path(__file__).parent.parent.parent
            config_path = project_root / "config" / "models.yaml"

        self.config_path = Path(config_path)
        self._config: Optional[Dict[str, Any]] = None
        self.load_config()

    def load_config(self) -> None:
        """
        Load configuration from the YAML file at ``self.config_path``.

        The previously loaded configuration is only replaced once the new
        file has been read and validated successfully.

        Raises:
            FileNotFoundError: If the configuration file does not exist.
            ValueError: If the file is not valid YAML, or its top-level
                value is not a mapping.
        """
        try:
            with open(self.config_path, 'r', encoding='utf-8') as file:
                loaded = yaml.safe_load(file)
        except FileNotFoundError as e:
            raise FileNotFoundError(
                f"Configuration file not found: {self.config_path}"
            ) from e
        except yaml.YAMLError as e:
            raise ValueError(f"Invalid YAML in configuration file: {e}") from e

        # An empty YAML document parses to None; treat it as an empty mapping
        # so the accessors below fall back to their defaults instead of
        # failing with AttributeError.
        if loaded is None:
            loaded = {}
        if not isinstance(loaded, dict):
            raise ValueError(
                "Invalid configuration file: expected a mapping at the top level, "
                f"got {type(loaded).__name__}"
            )

        self._config = loaded
        print(f"✅ Configuration loaded from {self.config_path}")

    def reload_config(self) -> None:
        """Reload configuration from file (same errors as ``load_config``)."""
        self.load_config()

    @property
    def config(self) -> Dict[str, Any]:
        """Get the full configuration dictionary, loading it if not yet loaded."""
        if self._config is None:
            self.load_config()
        return self._config

    def _models(self) -> Dict[str, Any]:
        """Return the ``models`` section, or an empty dict if absent or null."""
        return self.config.get('models') or {}

    def get_available_models(self) -> Dict[str, str]:
        """
        Get a dictionary of available model names and their IDs.

        Raises:
            KeyError: If a model entry has no ``model_id`` key.
        """
        return {
            name: model_config['model_id']
            for name, model_config in self._models().items()
        }

    def get_model_config(self, model_name: str) -> Dict[str, Any]:
        """
        Get configuration for a specific model.

        Args:
            model_name: Name of the model (e.g., 'InternVL3-8B')

        Returns:
            Model configuration dictionary

        Raises:
            KeyError: If model name is not found
        """
        models = self._models()
        if model_name not in models:
            available = list(models.keys())
            raise KeyError(f"Model '{model_name}' not found. Available models: {available}")

        return models[model_name]

    def get_supported_quantizations(self, model_name: str) -> List[str]:
        """
        Get supported quantization methods for a model.

        Returns an empty list when the entry is missing or null.

        Raises:
            KeyError: If model name is not found
        """
        model_config = self.get_model_config(model_name)
        return model_config.get('supported_quantizations') or []

    def get_default_quantization(self, model_name: str) -> str:
        """
        Get the default quantization method for a model.

        Falls back to 'non-quantized(fp16)' when the model does not set one.

        Raises:
            KeyError: If model name is not found
        """
        model_config = self.get_model_config(model_name)
        return model_config.get('default_quantization', 'non-quantized(fp16)')

    def get_default_model(self) -> str:
        """Get the default model name ('InternVL3-8B' if not configured)."""
        return self.config.get('default_model', 'InternVL3-8B')

    def validate_model_and_quantization(self, model_name: str, quantization: str) -> bool:
        """
        Validate if a quantization method is supported for a model.

        Args:
            model_name: Name of the model
            quantization: Quantization method

        Returns:
            True if valid, False otherwise (including unknown model names)
        """
        try:
            supported = self.get_supported_quantizations(model_name)
        except KeyError:
            # Unknown model: nothing can be valid for it.
            return False
        return quantization in supported

    def apply_environment_settings(self) -> None:
        """
        Apply environment settings to the current process.

        Overwrites ``PYTORCH_CUDA_ALLOC_CONF`` in ``os.environ``.
        """
        # Set CUDA memory allocation
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

    def get_model_description(self, model_name: str) -> str:
        """
        Get description for a model.

        Falls back to 'No description available' when none is configured.

        Raises:
            KeyError: If model name is not found
        """
        model_config = self.get_model_config(model_name)
        return model_config.get('description', 'No description available')

    def __str__(self) -> str:
        """String representation of the configuration manager."""
        return f"ConfigManager(config_path={self.config_path})"

    def __repr__(self) -> str:
        """
        Detailed string representation listing the configured model names.

        Reads the already-loaded configuration directly so that building the
        repr never touches the file system and never fails on model entries
        that are missing ``model_id``.
        """
        loaded = self._config if isinstance(self._config, dict) else {}
        models = list(loaded.get('models') or {})
        return f"ConfigManager(config_path={self.config_path}, models={models})"