# fragmenta/utils/validators.py
# (Uploaded by MazCodes via huggingface_hub, commit 63f0b06.)
import functools
import math
import os
import re
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union

from .exceptions import ValidationError
from .logger import get_logger
# Module-level logger; the config validators in this module log their
# aggregated validation failures through it before raising.
logger = get_logger(__name__)
class Validator:
    """Namespace of stateless field validators.

    Every validator takes the value to check plus a ``field_name`` used in
    error reports. On success it returns the (possibly normalized) value. On
    failure it raises ``ValidationError(field_name, value, message)``.
    """

    # Both patterns are applied with ``fullmatch``. A bare ``$`` also matches
    # just before a trailing newline, which previously let values such as
    # "a@b.co\n" through.
    _EMAIL_RE = re.compile(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$')
    # Layout: http(s)://host[:port][/path][?query][#fragment]. Path, query
    # and fragment accept common URL characters such as '-', '~', '%', '+'.
    _URL_RE = re.compile(
        r'^https?://[-\w.]+(?::[0-9]+)?'
        r'(?:/[-\w/.~%+@:=,;]*)?'
        r'(?:\?[-\w&=%.+~/:@,;]*)?'
        r'(?:#[-\w.~/%&=]*)?$'
    )

    @staticmethod
    def required(value: Any, field_name: str = "field") -> Any:
        """Ensure *value* is present.

        Rejects ``None``, empty or whitespace-only strings, and empty
        lists/dicts. Returns *value* unchanged otherwise.
        """
        if value is None:
            raise ValidationError(field_name, None, "value is required")
        if isinstance(value, str) and not value.strip():
            raise ValidationError(field_name, value, "value cannot be empty")
        if isinstance(value, (list, dict)) and len(value) == 0:
            raise ValidationError(field_name, str(value), "value cannot be empty")
        return value

    @staticmethod
    def string(
        value: Any,
        field_name: str = "field",
        min_length: Optional[int] = None,
        max_length: Optional[int] = None,
        pattern: Optional[str] = None,
    ) -> str:
        """Check that *value* is a ``str`` satisfying optional constraints.

        Args:
            value: Value to check.
            field_name: Name reported in errors.
            min_length: Minimum allowed ``len(value)``, if given.
            max_length: Maximum allowed ``len(value)``, if given.
            pattern: Regex applied with ``re.match``. It is anchored at the
                start of the string only; end it with ``$`` to constrain the
                end as well.

        Returns:
            *value* unchanged.
        """
        if not isinstance(value, str):
            raise ValidationError(field_name, str(value), "must be a string")
        if min_length is not None and len(value) < min_length:
            raise ValidationError(
                field_name,
                value,
                f"must be at least {min_length} characters long"
            )
        if max_length is not None and len(value) > max_length:
            raise ValidationError(
                field_name,
                value,
                f"must be no more than {max_length} characters long"
            )
        if pattern is not None and not re.match(pattern, value):
            raise ValidationError(
                field_name,
                value,
                f"must match pattern: {pattern}"
            )
        return value

    @staticmethod
    def number(
        value: Any,
        field_name: str = "field",
        min_value: Optional[Union[int, float]] = None,
        max_value: Optional[Union[int, float]] = None,
        integer_only: bool = False,
    ) -> Union[int, float]:
        """Convert *value* to a number and check it against optional bounds.

        Args:
            value: Anything accepted by ``int()`` (when *integer_only*) or
                ``float()``, e.g. ``5``, ``"5"``, ``2.5``.
            field_name: Name reported in errors.
            min_value: Inclusive lower bound, if given.
            max_value: Inclusive upper bound, if given.
            integer_only: Convert with ``int()`` instead of ``float()``.

        Returns:
            The converted ``int`` or ``float``.

        Raises:
            ValidationError: If *value* meets any of these conditions:
                it is a ``bool``; it cannot be converted; it is a
                non-integral float while *integer_only* is set; it is NaN;
                or it falls outside the bounds.
        """
        kind = "an integer" if integer_only else "a number"
        try:
            # bool subclasses int, so int(True) == 1 would otherwise silently
            # accept a flag where a number was expected.
            if isinstance(value, bool):
                raise TypeError("bool is not a number")
            if integer_only:
                # int() truncates (int(3.7) == 3), so only integral floats are
                # accepted. This check also rejects inf/NaN.
                if isinstance(value, float) and not value.is_integer():
                    raise ValueError("non-integral float")
                num_value = int(value)
            else:
                num_value = float(value)
                # NaN compares False with every bound and would otherwise
                # slip past the range checks below.
                if math.isnan(num_value):
                    raise ValueError("NaN is not a valid number")
        except (ValueError, TypeError, OverflowError):  # Overflow: e.g. float(10**400)
            raise ValidationError(field_name, str(value), f"must be {kind}") from None
        if min_value is not None and num_value < min_value:
            raise ValidationError(
                field_name,
                str(value),
                f"must be at least {min_value}"
            )
        if max_value is not None and num_value > max_value:
            raise ValidationError(
                field_name,
                str(value),
                f"must be no more than {max_value}"
            )
        return num_value

    @staticmethod
    def file_path(
        value: Any,
        field_name: str = "field",
        must_exist: bool = True,
        allowed_extensions: Optional[List[str]] = None,
    ) -> Path:
        """Validate a filesystem path and return it as a ``Path``.

        Args:
            value: A non-empty ``str`` or any ``os.PathLike`` (e.g. ``Path``).
            field_name: Name reported in errors.
            must_exist: Require ``Path.exists()`` to be true. Note that this
                is also true for directories.
            allowed_extensions: If given, the path's last suffix must be one
                of these. The comparison is case-insensitive, and the leading
                dot is optional (``"wav"`` is the same as ``".wav"``).

        Returns:
            ``Path(value)``.
        """
        # Path("") silently becomes Path("."), which exists, so reject it here.
        if not isinstance(value, (str, os.PathLike)) or (isinstance(value, str) and not value):
            raise ValidationError(field_name, str(value), "must be a valid file path")
        path = Path(value)
        if must_exist and not path.exists():
            raise ValidationError(field_name, str(value), "file does not exist")
        if allowed_extensions:
            # Path.suffix always carries a leading dot, so normalize entries to
            # match. An empty entry stays "" so it still matches suffix-less files.
            normalized = {
                ext.lower() if not ext or ext.startswith(".") else f".{ext.lower()}"
                for ext in allowed_extensions
            }
            if path.suffix.lower() not in normalized:
                raise ValidationError(
                    field_name,
                    str(value),
                    f"must have one of these extensions: {', '.join(allowed_extensions)}"
                )
        return path

    @staticmethod
    def choice(
        value: Any,
        field_name: str = "field",
        choices: Optional[List[Any]] = None,
    ) -> Any:
        """Ensure *value* is one of *choices*.

        When *choices* is ``None``, no check is done. Returns *value* unchanged.
        """
        if choices is not None and value not in choices:
            raise ValidationError(
                field_name,
                str(value),
                f"must be one of: {', '.join(str(c) for c in choices)}"
            )
        return value

    @staticmethod
    def email(value: Any, field_name: str = "field") -> str:
        """Validate a simple ``local@domain.tld`` address and return it lowercased."""
        if not isinstance(value, str):
            raise ValidationError(field_name, str(value), "must be a string")
        if not Validator._EMAIL_RE.fullmatch(value):
            raise ValidationError(field_name, value, "must be a valid email address")
        return value.lower()

    @staticmethod
    def url(value: Any, field_name: str = "field") -> str:
        """Validate an ``http``/``https`` URL and return it unchanged."""
        if not isinstance(value, str):
            raise ValidationError(field_name, str(value), "must be a string")
        if not Validator._URL_RE.fullmatch(value):
            raise ValidationError(field_name, value, "must be a valid URL")
        return value
def validate_request_data(schema: Dict[str, Dict[str, Any]]):
    """Decorator factory intended to validate request data against *schema*.

    The returned decorator wraps *func* and preserves its metadata
    (``__name__``, ``__doc__``, ...) via ``functools.wraps``. Without that,
    every decorated function would be named ``wrapper``, which breaks
    introspection and any name-keyed registration.

    NOTE(review): *schema* is accepted but not applied. The wrapper currently
    passes all arguments through unchanged. TODO: implement the validation
    or drop the parameter.
    """
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)
        return wrapper
    return decorator
def validate_model_config(config: Dict[str, Any]) -> Dict[str, Any]:
    """Validate a model configuration mapping.

    Checks:
      * ``model_type`` and ``sample_rate`` are present and non-empty;
      * ``sample_rate`` is an integer in [8000, 48000];
      * ``model_type`` is one of ``'autoencoder'``, ``'diffusion'``, ``'lm'``.

    All failures are collected before raising, one entry per field.

    Args:
        config: Configuration to check. It is not modified.

    Returns:
        The same *config* object when every check passes.

    Raises:
        ValidationError: If any check fails. The message embeds the per-field
            error lists.
    """
    errors: Dict[str, List[str]] = {}
    required_fields = ['model_type', 'sample_rate']
    for field in required_fields:
        try:
            Validator.required(config.get(field), field)
        except ValidationError as e:
            errors[field] = [str(e)]

    # Skip type/range checks for a field whose presence check already failed.
    # Otherwise the clearer "required"/"empty" message for e.g. a None value
    # would be overwritten by a type error.
    if 'sample_rate' in config and 'sample_rate' not in errors:
        try:
            Validator.number(
                config['sample_rate'],
                'sample_rate',
                min_value=8000,
                max_value=48000,
                integer_only=True
            )
        except ValidationError as e:
            errors['sample_rate'] = [str(e)]
    if 'model_type' in config and 'model_type' not in errors:
        try:
            Validator.choice(
                config['model_type'],
                'model_type',
                choices=['autoencoder', 'diffusion', 'lm']
            )
        except ValidationError as e:
            errors['model_type'] = [str(e)]
    if errors:
        logger.error(f"Model configuration validation failed: {errors}")
        raise ValidationError("model_config", str(config), f"validation failed: {errors}")
    return config
def validate_training_config(config: Dict[str, Any]) -> Dict[str, Any]:
    """Validate the optional fields of a training configuration mapping.

    Each field below is checked only if it is present in *config*:
      * ``modelName``    - string of 1 to 100 characters;
      * ``epochs``       - integer in [1, 1000];
      * ``batchSize``    - integer in [1, 64];
      * ``learningRate`` - number in [1e-6, 1e-1].

    Args:
        config: Configuration to check. It is not modified.

    Returns:
        The same *config* object when every present field is valid.

    Raises:
        ValidationError: If any present field is invalid. The message embeds
            the per-field error lists.
    """
    # (field, Validator method name, keyword options), checked in this order.
    rules = (
        ('modelName', 'string', {'min_length': 1, 'max_length': 100}),
        ('epochs', 'number', {'min_value': 1, 'max_value': 1000, 'integer_only': True}),
        ('batchSize', 'number', {'min_value': 1, 'max_value': 64, 'integer_only': True}),
        ('learningRate', 'number', {'min_value': 1e-6, 'max_value': 1e-1}),
    )
    problems = {}
    for key, method_name, options in rules:
        if key not in config:
            continue
        check = getattr(Validator, method_name)
        try:
            check(config[key], key, **options)
        except ValidationError as exc:
            problems[key] = [str(exc)]

    if not problems:
        return config
    logger.error(f"Training configuration validation failed: {problems}")
    raise ValidationError("training_config", str(config), f"validation failed: {problems}")