# File size: 8,627 Bytes | 63f0b06
from pathlib import Path
from typing import Dict, List, Any, Optional
import json
import os
from .logger import get_logger
from .exceptions import ConfigurationError
# Module-level logger; ConfigValidator below writes all of its progress,
# warning and error messages through it.
logger = get_logger(__name__)
class ConfigValidator:
    """Collect configuration problems as errors and warnings.

    Every ``_validate_*`` step appends human-readable messages to
    ``validation_errors`` / ``validation_warnings`` (and logs them) instead
    of raising, so one :meth:`validate_all` run reports everything it finds.

    The ``config`` object is only accessed through ``config.get_path(name)``
    and ``config.model_configs``.
    NOTE(review): ``get_path`` is assumed to return a ``pathlib.Path``-like
    object and ``model_configs`` a mapping of name -> dict; confirm against
    the project's configuration class.
    """

    def __init__(self, config):
        """Store ``config`` and start with empty error/warning lists."""
        self.config = config
        self.validation_errors: List[str] = []
        self.validation_warnings: List[str] = []

    def validate_all(self) -> Dict[str, Any]:
        """Run every validation step and summarize the outcome.

        Returns:
            A dict with keys ``valid`` (True when no errors were recorded),
            ``errors``, ``warnings``, ``total_errors`` and ``total_warnings``.
        """
        logger.info("Starting configuration validation...")
        self.validation_errors.clear()
        self.validation_warnings.clear()

        self._validate_paths()
        self._validate_models()
        self._validate_environment()
        self._validate_dependencies()
        self._validate_permissions()

        # Hand out copies: the instance lists are cleared at the start of the
        # next run, which must not empty a result a caller is still holding.
        results = {
            "valid": not self.validation_errors,
            "errors": list(self.validation_errors),
            "warnings": list(self.validation_warnings),
            "total_errors": len(self.validation_errors),
            "total_warnings": len(self.validation_warnings),
        }

        if results["valid"]:
            logger.info(f"Configuration validation passed ({len(self.validation_warnings)} warnings)")
        else:
            logger.error(f"Configuration validation failed ({len(self.validation_errors)} errors, {len(self.validation_warnings)} warnings)")
        return results

    def _validate_paths(self):
        """Check that critical directories exist; warn about optional ones."""
        logger.debug("Validating paths...")

        # Missing or non-directory critical paths are errors.
        critical_paths = [
            ("project_root", "Project root directory"),
            ("models", "Models directory"),
            ("models_config", "Model configuration directory"),
            ("backend", "Backend directory"),
            ("frontend", "Frontend directory"),
        ]
        for path_name, description in critical_paths:
            try:
                path = self.config.get_path(path_name)
                if not path.exists():
                    self._add_error(f"{description} does not exist: {path}")
                elif not path.is_dir():
                    self._add_error(f"{description} is not a directory: {path}")
                else:
                    logger.debug(f"{description}: {path}")
            except Exception as e:
                # Anything get_path (or the filesystem) raises is reported,
                # not propagated, so the remaining paths are still checked.
                self._add_error(f"Failed to validate {description}: {e}")

        # Missing optional paths only produce a warning.
        optional_paths = [
            ("models_pretrained", "Pretrained models directory"),
            ("models_fine_tuned", "Fine-tuned models directory"),
            ("data_raw", "Raw data directory"),
            ("data_processed", "Processed data directory"),
        ]
        for path_name, description in optional_paths:
            try:
                path = self.config.get_path(path_name)
                if not path.exists():
                    self._add_warning(f"{description} will be created: {path}")
                else:
                    logger.debug(f"{description}: {path}")
            except Exception as e:
                self._add_warning(f"Could not check {description}: {e}")

    def _validate_models(self):
        """Check each model's config file (valid JSON) and checkpoint path."""
        logger.debug("Validating model configurations...")
        try:
            model_configs = self.config.model_configs
            for model_name, model_config in model_configs.items():
                self._validate_model_config_file(model_name, model_config.get("config"))
                self._validate_model_checkpoint(model_name, model_config.get("ckpt"))
        except Exception as e:
            # E.g. model_configs missing or an entry that is not a mapping.
            self._add_error(f"Failed to validate model configurations: {e}")

    def _validate_model_config_file(self, model_name, config_value):
        """Warn if a model's config file is absent; error if it is unreadable or not JSON."""
        # An empty value must be handled before building a Path: Path("") is
        # the current directory, which "exists" and cannot be opened as a file.
        if not config_value:
            self._add_warning(f"Model config file not specified for {model_name}")
            return

        config_file = Path(config_value)
        if not config_file.exists():
            self._add_warning(f"Model config file not found for {model_name}: {config_file}")
            return

        try:
            with open(config_file, "r", encoding="utf-8") as f:
                json.load(f)
        except json.JSONDecodeError as e:
            self._add_error(f"Invalid JSON in model config {model_name}: {e}")
        except (OSError, UnicodeDecodeError) as e:
            # Directory instead of file, permission denied, bad encoding, ...
            # Report per model so the remaining models are still validated.
            self._add_error(f"Could not read model config {model_name}: {e}")
        else:
            logger.debug(f"Model config valid: {model_name}")

    def _validate_model_checkpoint(self, model_name, ckpt_value):
        """Warn if a model's checkpoint path is unspecified or does not exist."""
        if not ckpt_value:
            self._add_warning(f"Model checkpoint not specified for {model_name}")
            return

        ckpt_file = Path(ckpt_value)
        if not ckpt_file.exists():
            self._add_warning(f"Model checkpoint not found for {model_name}: {ckpt_file}")
        else:
            logger.debug(f"Model checkpoint exists: {model_name}")

    def _validate_environment(self):
        """Check the Python version, PyTorch/CUDA availability and basic env vars."""
        logger.debug("Validating environment...")
        import sys

        python_version = sys.version_info
        if python_version < (3, 8):
            self._add_error(f"Python 3.8+ required, found {python_version.major}.{python_version.minor}")
        else:
            logger.debug(f"Python version: {python_version.major}.{python_version.minor}.{python_version.micro}")

        try:
            import torch

            if torch.cuda.is_available():
                device_count = torch.cuda.device_count()
                device_name = torch.cuda.get_device_name(0) if device_count > 0 else "Unknown"
                logger.debug(f"CUDA available: {device_count} device(s), {device_name}")
            else:
                self._add_warning("CUDA not available, will use CPU (slower)")
        except ImportError:
            self._add_error("PyTorch not installed or not accessible")

        env_vars = [
            ("HOME", "User home directory"),
            ("PATH", "System PATH"),
        ]
        for var_name, description in env_vars:
            # Unset and empty are treated the same.
            if not os.environ.get(var_name):
                self._add_warning(f"Environment variable not set: {var_name} ({description})")

    def _validate_dependencies(self):
        """Try to import required packages (errors) and optional ones (warnings)."""
        logger.debug("Validating dependencies...")

        required_packages = [
            ("torch", "PyTorch"),
            ("torchaudio", "TorchAudio"),
            ("flask", "Flask"),
            ("transformers", "Transformers"),
            ("diffusers", "Diffusers"),
            ("librosa", "Librosa"),
            ("soundfile", "SoundFile"),
            ("numpy", "NumPy"),
            ("scipy", "SciPy"),
        ]
        for package_name, description in required_packages:
            self._check_package(package_name, description, required=True)

        optional_packages = [
            ("wandb", "Weights & Biases"),
            ("gradio", "Gradio"),
            ("matplotlib", "Matplotlib"),
        ]
        for package_name, description in optional_packages:
            self._check_package(package_name, description, required=False)

    def _check_package(self, package_name: str, description: str, *, required: bool):
        """Import one package and record an error (required) or warning (optional) on failure."""
        import importlib

        report = self._add_error if required else self._add_warning
        label = "Required" if required else "Optional"
        try:
            importlib.import_module(package_name)
        except ImportError:
            report(f"{label} package not installed: {package_name} ({description})")
        except Exception as e:
            # Installed but broken (e.g. a native library fails to load):
            # report it instead of letting it abort the whole validation run.
            report(f"{label} package failed to import: {package_name} ({description}): {e}")
        else:
            logger.debug(f"{description} available")

    def _validate_permissions(self):
        """Verify write access to the directories the validator writes a probe file into.

        Side effect: each directory is created (with parents) if missing, and
        a temporary ``.permission_test`` file is written and removed in it.
        """
        logger.debug("Validating permissions...")
        write_dirs = [
            ("models", "Models directory"),
            ("data_raw", "Raw data directory"),
            ("data_processed", "Processed data directory"),
        ]
        for path_name, description in write_dirs:
            try:
                path = self.config.get_path(path_name)
                path.mkdir(exist_ok=True, parents=True)
                test_file = path / ".permission_test"
                try:
                    test_file.write_text("test")
                    test_file.unlink()
                    logger.debug(f"Write permission: {description}")
                except PermissionError:
                    self._add_error(f"No write permission for {description}: {path}")
            except Exception as e:
                # Covers get_path/mkdir failures and non-permission OS errors.
                self._add_error(f"Failed to check permissions for {description}: {e}")

    def _add_error(self, message: str):
        """Record ``message`` as a validation error and log it."""
        self.validation_errors.append(message)
        logger.error(f"Validation Error: {message}")

    def _add_warning(self, message: str):
        """Record ``message`` as a validation warning and log it."""
        self.validation_warnings.append(message)
        logger.warning(f"Validation Warning: {message}")
def validate_config(config) -> Dict[str, Any]:
    """Validate ``config`` with a fresh ConfigValidator and return its result dict."""
    return ConfigValidator(config).validate_all()
def ensure_config_valid(config) -> bool:
    """Validate ``config`` and raise if any validation error was recorded.

    Returns:
        True when validation produced no errors (warnings are allowed).

    Raises:
        ConfigurationError: when at least one validation error was found.
            The third argument carries the error count followed by the
            individual messages, one per line.
            NOTE(review): the meaning of ConfigurationError's three positional
            arguments is not visible in this file; confirm that a multi-line
            string is acceptable in the last one.
    """
    results = validate_config(config)
    if not results["valid"]:
        # Include the individual messages; a bare count forces callers to dig
        # through the logs to learn what actually failed.
        error_messages = "\n".join(results["errors"])
        raise ConfigurationError(
            "configuration_validation",
            "valid configuration",
            f"{results['total_errors']} validation errors:\n{error_messages}",
        )
    return True