| from pathlib import Path | |
| from typing import Dict, List, Any, Optional | |
| import json | |
| import os | |
| from .logger import get_logger | |
| from .exceptions import ConfigurationError | |
| logger = get_logger(__name__) | |
| class ConfigValidator: | |
| def __init__(self, config): | |
| self.config = config | |
| self.validation_errors = [] | |
| self.validation_warnings = [] | |
| def validate_all(self) -> Dict[str, Any]: | |
| logger.info("Starting configuration validation...") | |
| self.validation_errors.clear() | |
| self.validation_warnings.clear() | |
| self._validate_paths() | |
| self._validate_models() | |
| self._validate_environment() | |
| self._validate_dependencies() | |
| self._validate_permissions() | |
| results = { | |
| "valid": len(self.validation_errors) == 0, | |
| "errors": self.validation_errors, | |
| "warnings": self.validation_warnings, | |
| "total_errors": len(self.validation_errors), | |
| "total_warnings": len(self.validation_warnings) | |
| } | |
| if results["valid"]: | |
| logger.info(f"Configuration validation passed ({len(self.validation_warnings)} warnings)") | |
| else: | |
| logger.error(f"Configuration validation failed ({len(self.validation_errors)} errors, {len(self.validation_warnings)} warnings)") | |
| return results | |
| def _validate_paths(self): | |
| logger.debug("Validating paths...") | |
| critical_paths = [ | |
| ("project_root", "Project root directory"), | |
| ("models", "Models directory"), | |
| ("models_config", "Model configuration directory"), | |
| ("backend", "Backend directory"), | |
| ("frontend", "Frontend directory") | |
| ] | |
| for path_name, description in critical_paths: | |
| try: | |
| path = self.config.get_path(path_name) | |
| if not path.exists(): | |
| self._add_error(f"{description} does not exist: {path}") | |
| elif not path.is_dir(): | |
| self._add_error(f"{description} is not a directory: {path}") | |
| else: | |
| logger.debug(f"{description}: {path}") | |
| except Exception as e: | |
| self._add_error(f"Failed to validate {description}: {e}") | |
| optional_paths = [ | |
| ("models_pretrained", "Pretrained models directory"), | |
| ("models_fine_tuned", "Fine-tuned models directory"), | |
| ("data_raw", "Raw data directory"), | |
| ("data_processed", "Processed data directory") | |
| ] | |
| for path_name, description in optional_paths: | |
| try: | |
| path = self.config.get_path(path_name) | |
| if not path.exists(): | |
| self._add_warning(f"{description} will be created: {path}") | |
| else: | |
| logger.debug(f"{description}: {path}") | |
| except Exception as e: | |
| self._add_warning(f"Could not check {description}: {e}") | |
| def _validate_models(self): | |
| logger.debug("Validating model configurations...") | |
| try: | |
| model_configs = self.config.model_configs | |
| for model_name, model_config in model_configs.items(): | |
| config_file = Path(model_config.get("config", "")) | |
| if not config_file.exists(): | |
| self._add_warning(f"Model config file not found for {model_name}: {config_file}") | |
| else: | |
| try: | |
| with open(config_file, 'r') as f: | |
| json.load(f) | |
| logger.debug(f"Model config valid: {model_name}") | |
| except json.JSONDecodeError as e: | |
| self._add_error(f"Invalid JSON in model config {model_name}: {e}") | |
| ckpt_file = Path(model_config.get("ckpt", "")) | |
| if not ckpt_file.exists(): | |
| self._add_warning(f"Model checkpoint not found for {model_name}: {ckpt_file}") | |
| else: | |
| logger.debug(f"Model checkpoint exists: {model_name}") | |
| except Exception as e: | |
| self._add_error(f"Failed to validate model configurations: {e}") | |
| def _validate_environment(self): | |
| logger.debug("Validating environment...") | |
| import sys | |
| python_version = sys.version_info | |
| if python_version < (3, 8): | |
| self._add_error(f"Python 3.8+ required, found {python_version.major}.{python_version.minor}") | |
| else: | |
| logger.debug(f"Python version: {python_version.major}.{python_version.minor}.{python_version.micro}") | |
| try: | |
| import torch | |
| if torch.cuda.is_available(): | |
| device_count = torch.cuda.device_count() | |
| device_name = torch.cuda.get_device_name(0) if device_count > 0 else "Unknown" | |
| logger.debug(f"CUDA available: {device_count} device(s), {device_name}") | |
| else: | |
| self._add_warning("CUDA not available, will use CPU (slower)") | |
| except ImportError: | |
| self._add_error("PyTorch not installed or not accessible") | |
| env_vars = [ | |
| ("HOME", "User home directory"), | |
| ("PATH", "System PATH") | |
| ] | |
| for var_name, description in env_vars: | |
| if not os.environ.get(var_name): | |
| self._add_warning(f"Environment variable not set: {var_name} ({description})") | |
| def _validate_dependencies(self): | |
| logger.debug("Validating dependencies...") | |
| required_packages = [ | |
| ("torch", "PyTorch"), | |
| ("torchaudio", "TorchAudio"), | |
| ("flask", "Flask"), | |
| ("transformers", "Transformers"), | |
| ("diffusers", "Diffusers"), | |
| ("librosa", "Librosa"), | |
| ("soundfile", "SoundFile"), | |
| ("numpy", "NumPy"), | |
| ("scipy", "SciPy") | |
| ] | |
| for package_name, description in required_packages: | |
| try: | |
| __import__(package_name) | |
| logger.debug(f"{description} available") | |
| except ImportError: | |
| self._add_error(f"Required package not installed: {package_name} ({description})") | |
| optional_packages = [ | |
| ("wandb", "Weights & Biases"), | |
| ("gradio", "Gradio"), | |
| ("matplotlib", "Matplotlib") | |
| ] | |
| for package_name, description in optional_packages: | |
| try: | |
| __import__(package_name) | |
| logger.debug(f"{description} available") | |
| except ImportError: | |
| self._add_warning(f"Optional package not installed: {package_name} ({description})") | |
| def _validate_permissions(self): | |
| logger.debug("Validating permissions...") | |
| write_dirs = [ | |
| ("models", "Models directory"), | |
| ("data_raw", "Raw data directory"), | |
| ("data_processed", "Processed data directory") | |
| ] | |
| for path_name, description in write_dirs: | |
| try: | |
| path = self.config.get_path(path_name) | |
| path.mkdir(exist_ok=True, parents=True) | |
| test_file = path / ".permission_test" | |
| try: | |
| test_file.write_text("test") | |
| test_file.unlink() | |
| logger.debug(f"Write permission: {description}") | |
| except PermissionError: | |
| self._add_error(f"No write permission for {description}: {path}") | |
| except Exception as e: | |
| self._add_error(f"Failed to check permissions for {description}: {e}") | |
| def _add_error(self, message: str): | |
| self.validation_errors.append(message) | |
| logger.error(f"Validation Error: {message}") | |
| def _add_warning(self, message: str): | |
| self.validation_warnings.append(message) | |
| logger.warning(f"Validation Warning: {message}") | |
| def validate_config(config) -> Dict[str, Any]: | |
| validator = ConfigValidator(config) | |
| return validator.validate_all() | |
| def ensure_config_valid(config) -> bool: | |
| results = validate_config(config) | |
| if not results["valid"]: | |
| error_messages = "\n".join(results["errors"]) | |
| raise ConfigurationError( | |
| "configuration_validation", | |
| "valid configuration", | |
| f"{results['total_errors']} validation errors" | |
| ) | |
| return True |