| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """ |
| | Generic utilities |
| | """ |
| |
|
| | from collections import OrderedDict |
| | from dataclasses import fields, is_dataclass |
| | from typing import Any, Tuple |
| |
|
| | import numpy as np |
| |
|
| | from .import_utils import is_torch_available, is_torch_version |
| |
|
| |
|
def is_tensor(x) -> bool:
    """
    Report whether `x` is a tensor-like array object.

    Returns `True` when `x` is a `np.ndarray`, or a `torch.Tensor` if PyTorch is available; `False` otherwise.
    """
    accepted_types = [np.ndarray]

    # torch is imported lazily so this helper keeps working when PyTorch is not installed.
    if is_torch_available():
        import torch

        accepted_types.append(torch.Tensor)

    return isinstance(x, tuple(accepted_types))
| |
|
| |
|
class BaseOutput(OrderedDict):
    """
    Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
    tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
    Python dictionary.

    <Tip warning={true}>

    You can't unpack a [`BaseOutput`] directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple
    first.

    </Tip>
    """

    def __init_subclass__(cls) -> None:
        """Register subclasses as pytree nodes.

        This is necessary to synchronize gradients when using `torch.nn.parallel.DistributedDataParallel` with
        `static_graph=True` with modules that output `BaseOutput` subclasses.
        """
        if is_torch_available():
            import torch.utils._pytree

            # torch < 2.2 is reached through the underscore-prefixed registration function; later versions
            # through the public name. Both receive the same arguments.
            if is_torch_version("<", "2.2"):
                register_node = torch.utils._pytree._register_pytree_node
            else:
                register_node = torch.utils._pytree.register_pytree_node

            register_node(
                cls,
                torch.utils._pytree._dict_flatten,
                lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
            )

    def __post_init__(self) -> None:
        """
        Mirror the dataclass fields into the underlying dictionary.

        If the first field holds a `dict` and every other field is `None`, the entries of that dict become the
        entries of this output. Otherwise every field whose value is not `None` is stored under its field name.

        Raises:
            ValueError: If the dataclass declares no fields.
        """
        class_fields = fields(self)

        # Safety and consistency check: an output without any field cannot be populated.
        if not class_fields:
            raise ValueError(f"{self.__class__.__name__} has no fields.")

        first_field = getattr(self, class_fields[0].name)
        other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])

        if other_fields_are_none and isinstance(first_field, dict):
            for key, value in first_field.items():
                self[key] = value
        else:
            for field in class_fields:
                v = getattr(self, field.name)
                if v is not None:
                    self[field.name] = v

    def __delitem__(self, *args, **kwargs):
        """Always raises: entries cannot be deleted from an output."""
        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")

    def setdefault(self, *args, **kwargs):
        """Always raises: `setdefault` is disabled on outputs."""
        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")

    def pop(self, *args, **kwargs):
        """Always raises: `pop` is disabled on outputs."""
        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")

    def update(self, *args, **kwargs):
        """Always raises: `update` is disabled on outputs."""
        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")

    def __getitem__(self, k: Any) -> Any:
        """
        Return the entry stored under `k`.

        String keys are looked up like dictionary keys (raising `KeyError` when absent); any other key (integer or
        slice) indexes into the tuple returned by `to_tuple()`.
        """
        if isinstance(k, str):
            # Direct lookup in the parent mapping; no need to copy all items into a temporary dict first.
            return super().__getitem__(k)
        return self.to_tuple()[k]

    def __setattr__(self, name: Any, value: Any) -> None:
        """
        Set attribute `name`, keeping the dictionary entry in sync.

        The dictionary entry is only updated when `name` is already a key and `value` is not `None`; the attribute
        itself is always set.
        """
        if name in self and value is not None:
            # Use the parent `__setitem__`: the overridden one would set the attribute as well, which is done
            # unconditionally just below.
            super().__setitem__(name, value)
        super().__setattr__(name, value)

    def __setitem__(self, key, value):
        """Store `value` under `key` and expose it as an attribute of the same name."""
        # Write the dictionary entry first...
        super().__setitem__(key, value)
        # ...then set the attribute through the parent `__setattr__`, bypassing the overridden one which would
        # write the dictionary entry a second time.
        super().__setattr__(key, value)

    def __reduce__(self):
        """
        Support pickling/copying.

        Non-dataclass instances use the parent implementation unchanged. For dataclass instances, the constructor
        arguments are replaced by the current value of every field, in field order.
        """
        if not is_dataclass(self):
            return super().__reduce__()
        constructor, _args, *remaining = super().__reduce__()
        args = tuple(getattr(self, field.name) for field in fields(self))
        return constructor, args, *remaining

    def to_tuple(self) -> Tuple[Any, ...]:
        """
        Convert self to a tuple containing all the attributes/keys that are not `None`.
        """
        return tuple(self[k] for k in self)
| |
|