File size: 2,136 Bytes
3d79eb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#


from dataclasses import fields
from typing import Any, List, Mapping

from fairseq2.typing import DataClass, is_dataclass_instance


def update_dataclass(
    obj: DataClass,
    overrides: Mapping[str, Any],
) -> List[str]:
    """Update ``obj`` in place with the data contained in ``overrides``.

    Copied from an old version of fairseq2 with simplification.

    For each field of ``obj`` that has a matching key in ``overrides``:

    - if the current field value is a data class instance and the override is
      not ``None``, the override must be a mapping and is applied recursively
      to that nested data class;
    - otherwise the override value is assigned to the field as is (so an
      override of ``None`` replaces a nested data class with ``None``).

    Fields named in the forbidden set (currently only ``silent_freeze``) are
    never assigned, at any nesting level; an override targeting such a field
    is not consumed and is therefore reported as an unknown field.

    :param obj:
        The data class instance to update.
    :param overrides:
        The dictionary containing the data to set in ``obj``.

    :returns:
        The sorted, dot-separated paths (e.g. ``"model.dim"``) of the keys in
        ``overrides`` that did not match an updatable field.

    :raises RuntimeError:
        If a field holding a data class instance receives a non-``None``
        override that is not a mapping.
    """
    unknown_fields: List[str] = []

    # Path of field names from ``obj`` down to the data class currently being
    # updated; used to build dotted names for errors and unknown fields.
    field_path: List[str] = []

    # The dataset config has a special attribute `silent_freeze` that must not
    # be changed by a hard update. It is skipped before its override is
    # consumed, so such an override ends up in ``unknown_fields``.
    forbidden_fields = frozenset({"silent_freeze"})

    def update(obj_: DataClass, overrides_: Mapping[str, Any]) -> None:
        # Work on a copy so consumed keys can be removed; whatever remains
        # after the loop has no matching field at this level.
        overrides_copy = {**overrides_}

        for field in fields(obj_):
            if field.name in forbidden_fields:
                continue

            try:
                override = overrides_copy.pop(field.name)
            except KeyError:
                continue

            value = getattr(obj_, field.name)

            # Recursively traverse child dataclasses.
            if override is not None and is_dataclass_instance(value):
                if not isinstance(override, Mapping):
                    pathname = ".".join(field_path + [field.name])

                    # A single message argument so that ``str(exc)`` reads
                    # naturally instead of showing a tuple of arguments.
                    raise RuntimeError(
                        f"The field '{pathname}' is expected to be of type `{type(value)}`, but is of type `{type(override)}` instead.",  # fmt: skip
                    )

                field_path.append(field.name)

                update(value, override)

                field_path.pop()
            else:
                setattr(obj_, field.name, override)

        if overrides_copy:
            unknown_fields.extend(
                ".".join(field_path + [name]) for name in overrides_copy
            )

    update(obj, overrides)

    unknown_fields.sort()

    return unknown_fields