fragmenta / utils /config_validator.py
MazCodes's picture
Upload folder using huggingface_hub
63f0b06 verified
from pathlib import Path
from typing import Dict, List, Any, Optional
import json
import os
from .logger import get_logger
from .exceptions import ConfigurationError
logger = get_logger(__name__)
class ConfigValidator:
    """Collect errors and warnings about a project configuration.

    The validator reads two things from ``config``: a ``get_path(name)``
    method whose result is used like a ``pathlib.Path``, and a
    ``model_configs`` mapping of model name to a dict that may carry
    ``"config"`` and ``"ckpt"`` path entries.

    Problems are accumulated in ``validation_errors`` (fatal) and
    ``validation_warnings`` (non-fatal) and are also written to the module
    logger as they are found.
    """

    def __init__(self, config):
        self.config = config
        self.validation_errors: List[str] = []
        self.validation_warnings: List[str] = []

    def validate_all(self) -> Dict[str, Any]:
        """Run every check and return a summary dict.

        Returns:
            A dict with the keys ``"valid"``, ``"errors"``, ``"warnings"``,
            ``"total_errors"`` and ``"total_warnings"``. The two lists are
            snapshots, so a later call to this method does not alter a
            previously returned result.
        """
        logger.info("Starting configuration validation...")
        self.validation_errors.clear()
        self.validation_warnings.clear()

        self._validate_paths()
        self._validate_models()
        self._validate_environment()
        self._validate_dependencies()
        self._validate_permissions()

        # Copy the lists: the originals are cleared in place on the next run,
        # which would otherwise empty results handed out earlier.
        errors = list(self.validation_errors)
        warnings = list(self.validation_warnings)
        results = {
            "valid": len(errors) == 0,
            "errors": errors,
            "warnings": warnings,
            "total_errors": len(errors),
            "total_warnings": len(warnings),
        }

        if results["valid"]:
            logger.info(f"Configuration validation passed ({len(self.validation_warnings)} warnings)")
        else:
            logger.error(f"Configuration validation failed ({len(self.validation_errors)} errors, {len(self.validation_warnings)} warnings)")
        return results

    def _validate_paths(self) -> None:
        """Check required directories (errors) and optional ones (warnings)."""
        logger.debug("Validating paths...")

        critical_paths = [
            ("project_root", "Project root directory"),
            ("models", "Models directory"),
            ("models_config", "Model configuration directory"),
            ("backend", "Backend directory"),
            ("frontend", "Frontend directory"),
        ]
        for path_name, description in critical_paths:
            try:
                path = self.config.get_path(path_name)
                if not path.exists():
                    self._add_error(f"{description} does not exist: {path}")
                elif not path.is_dir():
                    self._add_error(f"{description} is not a directory: {path}")
                else:
                    logger.debug(f"{description}: {path}")
            except Exception as e:
                # A failing lookup for one path must not stop the others.
                self._add_error(f"Failed to validate {description}: {e}")

        optional_paths = [
            ("models_pretrained", "Pretrained models directory"),
            ("models_fine_tuned", "Fine-tuned models directory"),
            ("data_raw", "Raw data directory"),
            ("data_processed", "Processed data directory"),
        ]
        for path_name, description in optional_paths:
            try:
                path = self.config.get_path(path_name)
                if not path.exists():
                    self._add_warning(f"{description} will be created: {path}")
                else:
                    logger.debug(f"{description}: {path}")
            except Exception as e:
                self._add_warning(f"Could not check {description}: {e}")

    def _validate_models(self) -> None:
        """Validate the config file and checkpoint of every configured model.

        Each model is checked independently, so a problem with one entry
        does not prevent the remaining entries from being examined.
        """
        logger.debug("Validating model configurations...")
        try:
            model_items = list(self.config.model_configs.items())
        except Exception as e:
            self._add_error(f"Failed to validate model configurations: {e}")
            return

        for model_name, model_config in model_items:
            try:
                self._validate_model_entry(model_name, model_config)
            except Exception as e:
                self._add_error(f"Failed to validate model configuration {model_name}: {e}")

    def _validate_model_entry(self, model_name, model_config) -> None:
        """Check one model's ``"config"`` JSON file and ``"ckpt"`` path."""
        # A missing/empty entry must be handled before building a Path:
        # Path("") is the current directory, which always exists.
        config_value = model_config.get("config")
        if not config_value:
            self._add_warning(f"Model config file not specified for {model_name}")
        else:
            config_file = Path(config_value)
            if not config_file.is_file():
                self._add_warning(f"Model config file not found for {model_name}: {config_file}")
            else:
                try:
                    with open(config_file, "r", encoding="utf-8") as f:
                        json.load(f)
                except json.JSONDecodeError as e:
                    self._add_error(f"Invalid JSON in model config {model_name}: {e}")
                except (OSError, UnicodeDecodeError) as e:
                    self._add_error(f"Could not read model config {model_name}: {e}")
                else:
                    logger.debug(f"Model config valid: {model_name}")

        ckpt_value = model_config.get("ckpt")
        if not ckpt_value:
            self._add_warning(f"Model checkpoint not specified for {model_name}")
        else:
            ckpt_file = Path(ckpt_value)
            if not ckpt_file.exists():
                self._add_warning(f"Model checkpoint not found for {model_name}: {ckpt_file}")
            else:
                logger.debug(f"Model checkpoint exists: {model_name}")

    def _validate_environment(self) -> None:
        """Check the Python version, PyTorch/CUDA and basic environment variables."""
        logger.debug("Validating environment...")
        import sys

        python_version = sys.version_info
        if python_version < (3, 8):
            self._add_error(f"Python 3.8+ required, found {python_version.major}.{python_version.minor}")
        else:
            logger.debug(f"Python version: {python_version.major}.{python_version.minor}.{python_version.micro}")

        try:
            import torch

            if torch.cuda.is_available():
                device_count = torch.cuda.device_count()
                device_name = torch.cuda.get_device_name(0) if device_count > 0 else "Unknown"
                logger.debug(f"CUDA available: {device_count} device(s), {device_name}")
            else:
                self._add_warning("CUDA not available, will use CPU (slower)")
        except ImportError:
            self._add_error("PyTorch not installed or not accessible")
        except Exception as e:
            # A broken install can fail with something other than ImportError;
            # record it instead of aborting the whole validation run.
            self._add_error(f"Failed to check PyTorch/CUDA availability: {e}")

        env_vars = [
            ("HOME", "User home directory"),
            ("PATH", "System PATH"),
        ]
        for var_name, description in env_vars:
            if not os.environ.get(var_name):
                self._add_warning(f"Environment variable not set: {var_name} ({description})")

    def _validate_dependencies(self) -> None:
        """Check that required packages import (errors) and optional ones (warnings)."""
        logger.debug("Validating dependencies...")

        required_packages = [
            ("torch", "PyTorch"),
            ("torchaudio", "TorchAudio"),
            ("flask", "Flask"),
            ("transformers", "Transformers"),
            ("diffusers", "Diffusers"),
            ("librosa", "Librosa"),
            ("soundfile", "SoundFile"),
            ("numpy", "NumPy"),
            ("scipy", "SciPy"),
        ]
        self._check_packages(required_packages, required=True)

        optional_packages = [
            ("wandb", "Weights & Biases"),
            ("gradio", "Gradio"),
            ("matplotlib", "Matplotlib"),
        ]
        self._check_packages(optional_packages, required=False)

    def _check_packages(self, packages, required: bool) -> None:
        """Try to import each ``(module_name, description)`` pair in *packages*.

        Failures are recorded as errors when *required* is true and as
        warnings otherwise.
        """
        import importlib

        report = self._add_error if required else self._add_warning
        label = "Required" if required else "Optional"
        for package_name, description in packages:
            try:
                importlib.import_module(package_name)
            except ImportError:
                report(f"{label} package not installed: {package_name} ({description})")
            except Exception as e:
                # Installed but unusable (e.g. a native library fails to load).
                report(f"{label} package failed to import: {package_name} ({description}): {e}")
            else:
                logger.debug(f"{description} available")

    def _validate_permissions(self) -> None:
        """Verify write access to the data/model directories.

        Note that this creates each directory if it is missing and briefly
        writes a ``.permission_test`` file inside it.
        """
        logger.debug("Validating permissions...")

        write_dirs = [
            ("models", "Models directory"),
            ("data_raw", "Raw data directory"),
            ("data_processed", "Processed data directory"),
        ]
        for path_name, description in write_dirs:
            try:
                path = self.config.get_path(path_name)
                path.mkdir(exist_ok=True, parents=True)
                test_file = path / ".permission_test"
                try:
                    test_file.write_text("test")
                    test_file.unlink()
                    logger.debug(f"Write permission: {description}")
                except PermissionError:
                    self._add_error(f"No write permission for {description}: {path}")
            except Exception as e:
                self._add_error(f"Failed to check permissions for {description}: {e}")

    def _add_error(self, message: str) -> None:
        """Record *message* as a validation error and log it."""
        self.validation_errors.append(message)
        logger.error(f"Validation Error: {message}")

    def _add_warning(self, message: str) -> None:
        """Record *message* as a validation warning and log it."""
        self.validation_warnings.append(message)
        logger.warning(f"Validation Warning: {message}")
def validate_config(config) -> Dict[str, Any]:
    """Validate *config* with a fresh ``ConfigValidator`` and return its summary dict."""
    return ConfigValidator(config).validate_all()
def ensure_config_valid(config) -> bool:
    """Validate *config* and raise if any validation error was found.

    Returns:
        ``True`` when validation reports no errors.

    Raises:
        ConfigurationError: when at least one validation error was
            recorded. The individual error messages are appended to the
            last argument so the failure is actionable without re-running
            the validator.
    """
    results = validate_config(config)
    if not results["valid"]:
        error_messages = "\n".join(results["errors"])
        # NOTE(review): the three positional arguments are passed in the same
        # order as before; confirm against ConfigurationError's signature
        # that the last one is free-form text.
        raise ConfigurationError(
            "configuration_validation",
            "valid configuration",
            f"{results['total_errors']} validation errors:\n{error_messages}"
        )
    return True