|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import fields |
|
|
from typing import Any, List, Mapping |
|
|
|
|
|
from fairseq2.typing import DataClass, is_dataclass_instance |
|
|
|
|
|
|
|
|
def update_dataclass(
    obj: DataClass,
    overrides: Mapping[str, Any],
) -> List[str]:
    """Update ``obj`` in place with the data contained in ``overrides``.

    Copied from an old version of fairseq2 with simplification.

    When a field currently holds a data class instance and its override is a
    mapping, the nested instance is updated recursively (in place). The field
    ``silent_freeze`` is never updated; an override targeting it is reported
    as unknown.

    :param obj:
        The data class instance to update.
    :param overrides:
        The dictionary containing the data to set in ``obj``. It is not
        modified.

    :returns:
        The sorted, dot-separated paths of the keys in ``overrides`` (at any
        nesting level) that do not correspond to an updatable field.

    :raises RuntimeError:
        If a field holding a data class instance is overridden with a value
        that is neither ``None``, a mapping, nor an instance of the field's
        current type.
    """
    unknown_fields: List[str] = []

    # Fields that must never be overwritten through ``overrides``.
    forbidden_fields_ = ("silent_freeze",)

    def update(
        obj_: DataClass, overrides_: Mapping[str, Any], path: List[str]
    ) -> None:
        # Work on a copy so the caller's mapping is left untouched; whatever
        # is still in it after the loop is unknown to ``obj_``.
        remaining = dict(overrides_)

        for field in fields(obj_):
            if field.name in forbidden_fields_ or field.name not in remaining:
                continue

            override = remaining.pop(field.name)

            # Read the current value only for fields actually being
            # overridden, so unset ``init=False`` fields are never touched.
            value = getattr(obj_, field.name)

            if override is None or not is_dataclass_instance(value):
                setattr(obj_, field.name, override)
                continue

            if isinstance(override, Mapping):
                # Merge into the existing nested data class instead of
                # replacing it, so unspecified nested fields are preserved.
                update(value, override, path + [field.name])
                continue

            if isinstance(override, type(value)):
                # A ready-made instance of the expected type replaces the
                # nested data class wholesale.
                setattr(obj_, field.name, override)
                continue

            pathname = ".".join(path + [field.name])

            raise RuntimeError(
                f"The field '{pathname}' is expected to be of type `{type(value)}`, "
                f"but is of type `{type(override)}` instead."
            )

        unknown_fields.extend(".".join(path + [name]) for name in remaining)

    update(obj, overrides, [])

    return sorted(unknown_fields)
|
|
|