Spaces:
Running
Running
| import inspect | |
| from dataclasses import _MISSING_TYPE, MISSING, Field, field, fields | |
| from functools import wraps | |
| from typing import ( | |
| Any, | |
| Callable, | |
| Dict, | |
| ForwardRef, | |
| List, | |
| Literal, | |
| Optional, | |
| Tuple, | |
| Type, | |
| TypeVar, | |
| Union, | |
| get_args, | |
| get_origin, | |
| overload, | |
| ) | |
| from .errors import ( | |
| StrictDataclassClassValidationError, | |
| StrictDataclassDefinitionError, | |
| StrictDataclassFieldValidationError, | |
| ) | |
# A field validator is any callable taking the candidate value as its single
# argument. It signals an invalid value by raising (ValueError/TypeError are
# the exceptions converted into field validation errors by `strict`).
Validator_T = Callable[[Any], None]

# Generic placeholder for the class being decorated by `strict`.
T = TypeVar("T")


# The overload decorator helps type checkers understand the different return types:
# - `@strict` applied directly to a class returns the class,
# - `@strict(accept_kwargs=...)` returns a decorator.
# These stubs are never executed; the real implementation follows them.
@overload
def strict(cls: Type[T]) -> Type[T]: ...


@overload
def strict(*, accept_kwargs: bool = False) -> Callable[[Type[T]], Type[T]]: ...
def strict(
    cls: Optional[Type[T]] = None, *, accept_kwargs: bool = False
) -> Union[Type[T], Callable[[Type[T]], Type[T]]]:
    """
    Decorator to add strict validation to a dataclass.

    This decorator must be used on top of `@dataclass` to ensure IDEs and static typing tools
    recognize the class as a dataclass.

    Can be used with or without arguments:
    - `@strict`
    - `@strict(accept_kwargs=True)`

    Args:
        cls:
            The class to convert to a strict dataclass.
        accept_kwargs (`bool`, *optional*):
            If True, allows arbitrary keyword arguments in `__init__`. Defaults to False.

    Returns:
        The enhanced dataclass with strict validation on field assignment.

    Raises:
        StrictDataclassDefinitionError:
            If the class is not a dataclass, if a field validator is not callable with a single
            argument, if a `validate_*` method does not take exactly one argument, or if the class
            already defines its own `validate` method.

    Example:
    ```py
    >>> from dataclasses import dataclass
    >>> from huggingface_hub.dataclasses import as_validated_field, strict, validated_field

    >>> @as_validated_field
    ... def positive_int(value: int):
    ...     if not value >= 0:
    ...         raise ValueError(f"Value must be positive, got {value}")

    >>> @strict(accept_kwargs=True)
    ... @dataclass
    ... class User:
    ...     name: str
    ...     age: int = positive_int(default=10)

    # Initialize
    >>> User(name="John")
    User(name='John', age=10)

    # Extra kwargs are accepted
    >>> User(name="John", age=30, lastname="Doe")
    User(name='John', age=30, *lastname='Doe')

    # Invalid type => raises
    >>> User(name="John", age="30")
    huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
        TypeError: Field 'age' expected int, got str (value: '30')

    # Invalid value => raises
    >>> User(name="John", age=-1)
    huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
        ValueError: Value must be positive, got -1
    ```
    """

    def wrap(cls: Type[T]) -> Type[T]:
        if not hasattr(cls, "__dataclass_fields__"):
            raise StrictDataclassDefinitionError(
                f"Class '{cls.__name__}' must be a dataclass before applying @strict."
            )

        # List and store validators: one type validator per field, followed by
        # any custom validators attached through the field metadata.
        field_validators: Dict[str, List[Validator_T]] = {}
        for f in fields(cls):  # type: ignore [arg-type]
            validators = []
            validators.append(_create_type_validator(f))
            custom_validator = f.metadata.get("validator")
            if custom_validator is not None:
                if not isinstance(custom_validator, list):
                    custom_validator = [custom_validator]
                for validator in custom_validator:
                    if not _is_validator(validator):
                        raise StrictDataclassDefinitionError(
                            f"Invalid validator for field '{f.name}': {validator}. Must be a callable taking a single argument."
                        )
                validators.extend(custom_validator)
            field_validators[f.name] = validators
        cls.__validators__ = field_validators  # type: ignore

        # Override __setattr__ to validate fields on assignment
        original_setattr = cls.__setattr__

        def __strict_setattr__(self: Any, name: str, value: Any) -> None:
            """Custom __setattr__ method for strict dataclasses."""
            # Run all validators (attributes without registered validators are set as-is)
            for validator in self.__validators__.get(name, []):
                try:
                    validator(value)
                except (ValueError, TypeError) as e:
                    raise StrictDataclassFieldValidationError(field=name, cause=e) from e

            # If validation passed, set the attribute
            original_setattr(self, name, value)

        cls.__setattr__ = __strict_setattr__  # type: ignore[method-assign]

        if accept_kwargs:
            # (optional) Override __init__ to accept arbitrary keyword arguments
            original_init = cls.__init__

            # `wraps` keeps the name, docstring and signature of the generated dataclass __init__
            @wraps(original_init)
            def __init__(self, **kwargs: Any) -> None:
                # Extract only the fields that are part of the dataclass
                dataclass_fields = {f.name for f in fields(cls)}  # type: ignore [arg-type]
                standard_kwargs = {k: v for k, v in kwargs.items() if k in dataclass_fields}

                # Call the original __init__ with standard fields
                original_init(self, **standard_kwargs)

                # Add any additional kwargs as attributes
                for name, value in kwargs.items():
                    if name not in dataclass_fields:
                        self.__setattr__(name, value)

            cls.__init__ = __init__  # type: ignore[method-assign]

            # (optional) Override __repr__ to include additional kwargs
            original_repr = cls.__repr__

            @wraps(original_repr)
            def __repr__(self) -> str:
                # Call the original __repr__ to get the standard fields
                standard_repr = original_repr(self)

                # Get additional kwargs
                additional_kwargs = [
                    # add a '*' in front of additional kwargs to let the user know they are not part of the dataclass
                    f"*{k}={v!r}"
                    for k, v in self.__dict__.items()
                    if k not in cls.__dataclass_fields__  # type: ignore [attr-defined]
                ]
                additional_repr = ", ".join(additional_kwargs)

                # Combine both representations (drop the closing parenthesis of the standard repr first)
                return f"{standard_repr[:-1]}, {additional_repr})" if additional_kwargs else standard_repr

            cls.__repr__ = __repr__  # type: ignore [method-assign]

        # List all public methods starting with `validate_` => class validators.
        class_validators = []

        for name in dir(cls):
            if not name.startswith("validate_"):
                continue
            method = getattr(cls, name)
            if not callable(method):
                continue
            if len(inspect.signature(method).parameters) != 1:
                raise StrictDataclassDefinitionError(
                    f"Class '{cls.__name__}' has a class validator '{name}' that takes more than one argument."
                    " Class validators must take only 'self' as an argument. Methods starting with 'validate_'"
                    " are considered to be class validators."
                )
            class_validators.append(method)

        cls.__class_validators__ = class_validators  # type: ignore [attr-defined]

        # Add `validate` method to the class, but first check if it already exists
        def validate(self: T) -> None:
            """Run class validators on the instance."""
            for validator in cls.__class_validators__:  # type: ignore [attr-defined]
                try:
                    validator(self)
                except (ValueError, TypeError) as e:
                    raise StrictDataclassClassValidationError(validator=validator.__name__, cause=e) from e

        # Hack to be able to raise if `.validate()` already exists except if it was created by this decorator on a parent class
        # (in which case we just override it)
        validate.__is_defined_by_strict_decorator__ = True  # type: ignore [attr-defined]

        if hasattr(cls, "validate"):
            if not getattr(cls.validate, "__is_defined_by_strict_decorator__", False):  # type: ignore [attr-defined]
                raise StrictDataclassDefinitionError(
                    f"Class '{cls.__name__}' already implements a method called 'validate'."
                    " This method name is reserved when using the @strict decorator on a dataclass."
                    " If you want to keep your own method, please rename it."
                )

        cls.validate = validate  # type: ignore

        # Run class validators after initialization
        initial_init = cls.__init__

        @wraps(initial_init)
        def init_with_validate(self, *args, **kwargs) -> None:
            """Run class validators after initialization."""
            initial_init(self, *args, **kwargs)  # type: ignore [call-arg]
            cls.validate(self)  # type: ignore [attr-defined]

        setattr(cls, "__init__", init_with_validate)

        return cls

    # Return wrapped class or the decorator itself
    return wrap(cls) if cls is not None else wrap
def validated_field(
    validator: Union[List[Validator_T], Validator_T],
    default: Union[Any, _MISSING_TYPE] = MISSING,
    default_factory: Union[Callable[[], Any], _MISSING_TYPE] = MISSING,
    init: bool = True,
    repr: bool = True,
    hash: Optional[bool] = None,
    compare: bool = True,
    metadata: Optional[Dict] = None,
    **kwargs: Any,
) -> Any:
    """
    Create a dataclass field with a custom validator.

    Useful to apply several checks to a field. If only applying one rule, check out the [`as_validated_field`] decorator.

    Args:
        validator (`Callable` or `List[Callable]`):
            A method that takes a value as input and raises ValueError/TypeError if the value is invalid.
            Can be a list of validators to apply multiple checks.
        default, default_factory, init, repr, hash, compare:
            Forwarded unchanged to `dataclasses.field()`.
        metadata (`Dict`, *optional*):
            Extra metadata to attach to the field. It is copied, never modified in place; the
            validators are stored in the copy under the `"validator"` key.
        **kwargs:
            Additional arguments to pass to `dataclasses.field()`.

    Returns:
        A field with the validator attached in metadata
    """
    if not isinstance(validator, list):
        validator = [validator]

    # Work on a copy so that a metadata dict shared between several fields (or reused by
    # the caller afterwards) does not silently receive this field's validators.
    metadata = dict(metadata) if metadata is not None else {}
    metadata["validator"] = validator

    return field(  # type: ignore
        default=default,  # type: ignore [arg-type]
        default_factory=default_factory,  # type: ignore [arg-type]
        init=init,
        repr=repr,
        hash=hash,
        compare=compare,
        metadata=metadata,
        **kwargs,
    )
def as_validated_field(validator: Validator_T):
    """
    Turn a validator function into a field factory built on [`validated_field`].

    The returned callable accepts the same options as `dataclasses.field()` and produces a
    dataclass field that carries `validator` as its custom validator.

    Args:
        validator (`Callable`):
            A method that takes a value as input and raises ValueError/TypeError if the value is invalid.

    Returns:
        A callable creating a validated dataclass field each time it is called.
    """

    def _inner(
        default: Union[Any, _MISSING_TYPE] = MISSING,
        default_factory: Union[Callable[[], Any], _MISSING_TYPE] = MISSING,
        init: bool = True,
        repr: bool = True,
        hash: Optional[bool] = None,
        compare: bool = True,
        metadata: Optional[Dict] = None,
        **kwargs: Any,
    ):
        # Gather the explicit `dataclasses.field()` options, then forward everything at once.
        field_options: Dict[str, Any] = {
            "default": default,
            "default_factory": default_factory,
            "init": init,
            "repr": repr,
            "hash": hash,
            "compare": compare,
            "metadata": metadata,
        }
        return validated_field(validator, **field_options, **kwargs)

    return _inner
def type_validator(name: str, value: Any, expected_type: Any) -> None:
    """Validate that 'value' matches 'expected_type'.

    Args:
        name: Name of the field being validated (only used in error messages).
        value: The value to check.
        expected_type: The type annotation the value must conform to.

    Raises:
        TypeError: If the value does not match the annotation, or if the annotation
            is of a kind this module does not know how to check.
    """
    import types  # local import: only needed to recognise PEP 604 unions

    origin = get_origin(expected_type)
    args = get_args(expected_type)

    # `X | Y` annotations (PEP 604) report `types.UnionType` as their origin where that
    # type exists. Treat them exactly like `typing.Union[X, Y]`.
    pep604_union = getattr(types, "UnionType", None)
    if pep604_union is not None and origin is pep604_union:
        origin = Union

    if expected_type is Any:
        return
    elif validator := _BASIC_TYPE_VALIDATORS.get(origin):
        validator(name, value, args)
    elif isinstance(expected_type, type):  # simple types
        _validate_simple_type(name, value, expected_type)
    elif isinstance(expected_type, ForwardRef) or isinstance(expected_type, str):
        # Unresolved forward references / string annotations cannot be checked: accept as-is
        return
    else:
        raise TypeError(f"Unsupported type for field '{name}': {expected_type}")
def _validate_union(name: str, value: Any, args: Tuple[Any, ...]) -> None:
    """Accept `value` as soon as it matches one member type of a Union.

    Raises:
        TypeError: If no member type accepts the value; the message lists the
            error reported for each member type.
    """
    failures: List[str] = []
    for candidate_type in args:
        try:
            type_validator(name, value, candidate_type)
        except TypeError as exc:
            failures.append(str(exc))
        else:
            # First matching member type wins
            return

    raise TypeError(
        f"Field '{name}' with value {repr(value)} doesn't match any type in {args}. Errors: {'; '.join(failures)}"
    )
def _validate_literal(name: str, value: Any, args: Tuple[Any, ...]) -> None:
    """Check that `value` is one of the allowed values of a Literal type.

    Raises:
        TypeError: If `value` is not among `args`.
    """
    if value in args:
        return
    raise TypeError(f"Field '{name}' expected one of {args}, got {value}")
def _validate_list(name: str, value: Any, args: Tuple[Any, ...]) -> None:
    """Validate List[T] type.

    Args:
        name: Field name, used in error messages.
        value: The value to check.
        args: Type arguments of the annotation (`(T,)` for `List[T]`, empty for a bare `List`).

    Raises:
        TypeError: If `value` is not a list or if one of its items does not match `T`.
    """
    if not isinstance(value, list):
        raise TypeError(f"Field '{name}' expected a list, got {type(value).__name__}")

    # A bare `typing.List` annotation carries no item type: any list is valid.
    if not args:
        return

    # Validate each item in the list
    item_type = args[0]
    for i, item in enumerate(value):
        try:
            type_validator(f"{name}[{i}]", item, item_type)
        except TypeError as e:
            raise TypeError(f"Invalid item at index {i} in list '{name}'") from e
def _validate_dict(name: str, value: Any, args: Tuple[Any, ...]) -> None:
    """Validate Dict[K, V] type.

    Args:
        name: Field name, used in error messages.
        value: The value to check.
        args: Type arguments of the annotation (`(K, V)` for `Dict[K, V]`, empty for a bare `Dict`).

    Raises:
        TypeError: If `value` is not a dict or if one of its keys/values does not match `K`/`V`.
    """
    if not isinstance(value, dict):
        raise TypeError(f"Field '{name}' expected a dict, got {type(value).__name__}")

    # A bare `typing.Dict` annotation carries no key/value types: any dict is valid.
    if not args:
        return

    # Validate keys and values
    key_type, value_type = args
    for k, v in value.items():
        try:
            type_validator(f"{name}.key", k, key_type)
            type_validator(f"{name}[{k!r}]", v, value_type)
        except TypeError as e:
            raise TypeError(f"Invalid key or value in dict '{name}'") from e
def _validate_tuple(name: str, value: Any, args: Tuple[Any, ...]) -> None:
    """Check a value against a Tuple annotation.

    Both variable-length (`Tuple[T, ...]`) and fixed-length (`Tuple[T1, T2]`)
    forms are handled.

    Raises:
        TypeError: If `value` is not a tuple, has the wrong length for a
            fixed-length annotation, or contains an item of the wrong type.
    """
    if not isinstance(value, tuple):
        raise TypeError(f"Field '{name}' expected a tuple, got {type(value).__name__}")

    is_variable_length = len(args) == 2 and args[1] is Ellipsis
    if is_variable_length:
        # Tuple[T, ...]: every item is checked against the same type
        expected_types = (args[0],) * len(value)
    elif len(args) != len(value):
        raise TypeError(f"Field '{name}' expected a tuple of length {len(args)}, got {len(value)}")
    else:
        # Tuple[T1, T2, ...]: one expected type per position
        expected_types = args

    for index, (element, element_type) in enumerate(zip(value, expected_types)):
        try:
            type_validator(f"{name}[{index}]", element, element_type)
        except TypeError as exc:
            raise TypeError(f"Invalid item at index {index} in tuple '{name}'") from exc
def _validate_set(name: str, value: Any, args: Tuple[Any, ...]) -> None:
    """Validate Set[T] type.

    Args:
        name: Field name, used in error messages.
        value: The value to check.
        args: Type arguments of the annotation (`(T,)` for `Set[T]`, empty for a bare `Set`).

    Raises:
        TypeError: If `value` is not a set or if one of its items does not match `T`.
    """
    if not isinstance(value, set):
        raise TypeError(f"Field '{name}' expected a set, got {type(value).__name__}")

    # A bare `typing.Set` annotation carries no item type: any set is valid.
    if not args:
        return

    # Validate each item in the set (sets are unordered, so no index is reported)
    item_type = args[0]
    for item in value:
        try:
            type_validator(f"{name} item", item, item_type)
        except TypeError as e:
            raise TypeError(f"Invalid item in set '{name}'") from e
def _validate_simple_type(name: str, value: Any, expected_type: type) -> None:
    """Check `value` against a plain class (int, str, ...) with `isinstance`.

    Raises:
        TypeError: If `value` is not an instance of `expected_type`.
    """
    if isinstance(value, expected_type):
        return

    actual_type_name = type(value).__name__
    raise TypeError(
        f"Field '{name}' expected {expected_type.__name__}, got {actual_type_name} (value: {value!r})"
    )
def _create_type_validator(field: Field) -> Validator_T:
    """Build the single-argument validator checking a value against `field`'s annotated type."""

    # A named inner function (rather than a lambda) keeps each field bound to its own validator.
    def validator(value: Any) -> None:
        # The field's name and type are looked up when the validator runs.
        type_validator(name=field.name, value=value, expected_type=field.type)

    return validator
def _is_validator(validator: Any) -> bool:
    """Check if a function is a validator.

    A validator is a Callable that can be called with a single positional argument.
    The validator can have more arguments with default values, as well as `*args` / `**kwargs`.

    Basically, returns True if `validator(value)` is possible.

    Args:
        validator: The object to inspect.

    Returns:
        True if `validator` can be called as `validator(value)`, False otherwise.
    """
    if not callable(validator):
        return False

    parameters = list(inspect.signature(validator).parameters.values())
    if not parameters:
        return False

    first, *others = parameters

    # The value is passed positionally, so the first parameter must accept a positional argument.
    if first.kind not in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.VAR_POSITIONAL,
    ):
        return False

    # Every other parameter must be optional. `*args` / `**kwargs` never have a default
    # but never need to be supplied either, so they must not be treated as required.
    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    return all(
        parameter.kind in variadic_kinds or parameter.default is not inspect.Parameter.empty
        for parameter in others
    )
# Dispatch table used by `type_validator`: maps the origin of an annotation
# (as returned by `typing.get_origin`) to the function validating values of that
# kind. Each function is called as `validator(name, value, args)` where `args`
# are the annotation's type arguments.
_BASIC_TYPE_VALIDATORS = {
    Union: _validate_union,
    Literal: _validate_literal,
    list: _validate_list,
    dict: _validate_dict,
    tuple: _validate_tuple,
    set: _validate_set,
}
# Public API of this module. `as_validated_field` is listed alongside
# `validated_field` since both are user-facing helpers for declaring fields.
__all__ = [
    "strict",
    "validated_field",
    "as_validated_field",
    "Validator_T",
    "StrictDataclassClassValidationError",
    "StrictDataclassDefinitionError",
    "StrictDataclassFieldValidationError",
]