File size: 8,627 Bytes
63f0b06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
from pathlib import Path
from typing import Dict, List, Any, Optional
import json
import os

from .logger import get_logger
from .exceptions import ConfigurationError

logger = get_logger(__name__)

class ConfigValidator:
    """Run sanity checks over a project configuration and collect findings.

    ``config`` must expose ``get_path(name)`` (returning a ``pathlib.Path``-like
    object supporting ``exists``/``is_dir``/``mkdir``/``/``) and a
    ``model_configs`` mapping of model name -> dict with optional ``"config"``
    (JSON file path) and ``"ckpt"`` (checkpoint file path) entries.

    Findings are stored as human-readable strings in ``validation_errors``
    (blocking problems) and ``validation_warnings`` (non-blocking), and each
    one is also logged as soon as it is found.
    """

    def __init__(self, config):
        self.config = config
        self.validation_errors: List[str] = []
        self.validation_warnings: List[str] = []

    def validate_all(self) -> Dict[str, Any]:
        """Run every check and return a summary of the findings.

        Returns:
            A dict with ``valid`` (True when no errors were recorded),
            ``errors`` / ``warnings`` (snapshots of the collected messages),
            and ``total_errors`` / ``total_warnings`` counts.
        """
        logger.info("Starting configuration validation...")

        # Reset so the same validator can be re-run (e.g. after a fix).
        self.validation_errors.clear()
        self.validation_warnings.clear()

        self._validate_paths()
        self._validate_models()
        self._validate_environment()
        self._validate_dependencies()
        self._validate_permissions()

        # Snapshot the lists: handing out the live lists would let the next
        # validate_all() call (which clears them) wipe out this result.
        errors = list(self.validation_errors)
        warnings = list(self.validation_warnings)
        results = {
            "valid": not errors,
            "errors": errors,
            "warnings": warnings,
            "total_errors": len(errors),
            "total_warnings": len(warnings),
        }

        if results["valid"]:
            logger.info(f"Configuration validation passed ({len(warnings)} warnings)")
        else:
            logger.error(f"Configuration validation failed ({len(errors)} errors, {len(warnings)} warnings)")

        return results

    def _validate_paths(self):
        """Require the critical directories; only warn about optional ones."""
        logger.debug("Validating paths...")

        critical_paths = [
            ("project_root", "Project root directory"),
            ("models", "Models directory"),
            ("models_config", "Model configuration directory"),
            ("backend", "Backend directory"),
            ("frontend", "Frontend directory")
        ]

        for path_name, description in critical_paths:
            try:
                path = self.config.get_path(path_name)
                if not path.exists():
                    self._add_error(f"{description} does not exist: {path}")
                elif not path.is_dir():
                    self._add_error(f"{description} is not a directory: {path}")
                else:
                    logger.debug(f"{description}: {path}")
            except Exception as e:
                self._add_error(f"Failed to validate {description}: {e}")

        optional_paths = [
            ("models_pretrained", "Pretrained models directory"),
            ("models_fine_tuned", "Fine-tuned models directory"),
            ("data_raw", "Raw data directory"),
            ("data_processed", "Processed data directory")
        ]

        for path_name, description in optional_paths:
            try:
                path = self.config.get_path(path_name)
                if not path.exists():
                    # Message wording only; this method does not create anything.
                    self._add_warning(f"{description} will be created: {path}")
                else:
                    logger.debug(f"{description}: {path}")
            except Exception as e:
                self._add_warning(f"Could not check {description}: {e}")

    def _validate_models(self):
        """Check every model's JSON config file and checkpoint file.

        Unset or missing files are warnings; unreadable or malformed config
        files are errors. A problem with one model's files does not stop the
        remaining models from being checked.
        """
        logger.debug("Validating model configurations...")

        try:
            model_configs = self.config.model_configs

            for model_name, model_config in model_configs.items():
                self._validate_model_config_file(model_name, model_config.get("config"))
                self._validate_model_checkpoint(model_name, model_config.get("ckpt"))

        except Exception as e:
            self._add_error(f"Failed to validate model configurations: {e}")

    def _validate_model_config_file(self, model_name, raw_path):
        """Warn if a model's config file is unset/missing; error if unreadable or not JSON."""
        # Without this guard an absent key becomes Path("") == Path("."),
        # which always exists and then fails to open as a file.
        if not raw_path:
            self._add_warning(f"Model config path not specified for {model_name}")
            return

        config_file = Path(raw_path)
        if not config_file.exists():
            self._add_warning(f"Model config file not found for {model_name}: {config_file}")
            return

        try:
            # JSON text is UTF-8 by spec; don't depend on the platform locale.
            with open(config_file, 'r', encoding='utf-8') as f:
                json.load(f)
        except json.JSONDecodeError as e:
            self._add_error(f"Invalid JSON in model config {model_name}: {e}")
        except (OSError, UnicodeDecodeError) as e:
            self._add_error(f"Could not read model config {model_name}: {e}")
        else:
            logger.debug(f"Model config valid: {model_name}")

    def _validate_model_checkpoint(self, model_name, raw_path):
        """Warn if a model's checkpoint path is unset or does not exist."""
        # Same Path("") == Path(".") pitfall as for config files.
        if not raw_path:
            self._add_warning(f"Model checkpoint path not specified for {model_name}")
            return

        ckpt_file = Path(raw_path)
        if not ckpt_file.exists():
            self._add_warning(f"Model checkpoint not found for {model_name}: {ckpt_file}")
        else:
            logger.debug(f"Model checkpoint exists: {model_name}")

    def _validate_environment(self):
        """Check the Python version, PyTorch/CUDA availability and basic env vars."""
        logger.debug("Validating environment...")

        import sys
        python_version = sys.version_info
        if python_version < (3, 8):
            self._add_error(f"Python 3.8+ required, found {python_version.major}.{python_version.minor}")
        else:
            logger.debug(f"Python version: {python_version.major}.{python_version.minor}.{python_version.micro}")

        try:
            import torch
        except (ImportError, OSError):
            # OSError covers an installed-but-broken torch (missing native libs).
            self._add_error("PyTorch not installed or not accessible")
        else:
            try:
                if torch.cuda.is_available():
                    device_count = torch.cuda.device_count()
                    device_name = torch.cuda.get_device_name(0) if device_count > 0 else "Unknown"
                    logger.debug(f"CUDA available: {device_count} device(s), {device_name}")
                else:
                    self._add_warning("CUDA not available, will use CPU (slower)")
            except Exception as e:
                # A faulty CUDA driver/runtime must not abort the whole validation run.
                self._add_warning(f"Could not query CUDA devices: {e}")

        env_vars = [
            ("HOME", "User home directory"),
            ("PATH", "System PATH")
        ]

        for var_name, description in env_vars:
            if not os.environ.get(var_name):
                self._add_warning(f"Environment variable not set: {var_name} ({description})")

    def _validate_dependencies(self):
        """Import each required (error if absent) and optional (warning if absent) package."""
        logger.debug("Validating dependencies...")

        required_packages = [
            ("torch", "PyTorch"),
            ("torchaudio", "TorchAudio"),
            ("flask", "Flask"),
            ("transformers", "Transformers"),
            ("diffusers", "Diffusers"),
            ("librosa", "Librosa"),
            ("soundfile", "SoundFile"),
            ("numpy", "NumPy"),
            ("scipy", "SciPy")
        ]

        for package_name, description in required_packages:
            try:
                __import__(package_name)
            except ImportError:
                self._add_error(f"Required package not installed: {package_name} ({description})")
            except Exception as e:
                # Installed but fails at import time (e.g. missing native libraries).
                self._add_error(f"Required package failed to import: {package_name} ({description}): {e}")
            else:
                logger.debug(f"{description} available")

        optional_packages = [
            ("wandb", "Weights & Biases"),
            ("gradio", "Gradio"),
            ("matplotlib", "Matplotlib")
        ]

        for package_name, description in optional_packages:
            try:
                __import__(package_name)
            except ImportError:
                self._add_warning(f"Optional package not installed: {package_name} ({description})")
            except Exception as e:
                self._add_warning(f"Optional package failed to import: {package_name} ({description}): {e}")
            else:
                logger.debug(f"{description} available")

    def _validate_permissions(self):
        """Verify the writable directories accept a file write.

        Side effect: each directory (and its parents) is created if missing,
        and a ``.permission_test`` probe file is written and then removed.
        """
        logger.debug("Validating permissions...")

        write_dirs = [
            ("models", "Models directory"),
            ("data_raw", "Raw data directory"),
            ("data_processed", "Processed data directory")
        ]

        for path_name, description in write_dirs:
            try:
                path = self.config.get_path(path_name)
                path.mkdir(exist_ok=True, parents=True)

                test_file = path / ".permission_test"
                try:
                    test_file.write_text("test")
                    test_file.unlink()
                    logger.debug(f"Write permission: {description}")
                except PermissionError:
                    self._add_error(f"No write permission for {description}: {path}")
            except Exception as e:
                self._add_error(f"Failed to check permissions for {description}: {e}")

    def _add_error(self, message: str):
        """Record and log a blocking validation problem."""
        self.validation_errors.append(message)
        logger.error(f"Validation Error: {message}")

    def _add_warning(self, message: str):
        """Record and log a non-blocking validation problem."""
        self.validation_warnings.append(message)
        logger.warning(f"Validation Warning: {message}")

def validate_config(config) -> Dict[str, Any]:
    """Run all configuration checks on *config* and return the summary dict.

    Convenience wrapper: builds a fresh ``ConfigValidator`` for *config* and
    returns the result of its ``validate_all()``.
    """
    return ConfigValidator(config).validate_all()

def ensure_config_valid(config) -> bool:
    """Validate *config* and raise if any blocking error was found.

    Warnings never cause a failure; only recorded errors do.

    Returns:
        True when validation recorded no errors.

    Raises:
        ConfigurationError: if one or more validation errors were recorded.
            The newline-joined error messages are passed to the exception
            alongside the error count, so callers need not dig through logs.
    """
    results = validate_config(config)

    if not results["valid"]:
        # Previously computed but dropped; include the details in the error.
        error_messages = "\n".join(results["errors"])
        raise ConfigurationError(
            "configuration_validation",
            "valid configuration",
            f"{results['total_errors']} validation errors:\n{error_messages}"
        )

    return True