File size: 3,092 Bytes
03022ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import re
import importlib
from speaker_diarization.local.utils.config import Config


def dynamic_import(import_path):
    """Return the object named by the dotted path *import_path*.

    Everything before the last '.' is imported as a module via
    ``importlib.import_module``; the final component is then looked up on
    that module with ``getattr``. For example ``"os.path.join"`` yields
    ``os.path.join``.
    """
    module_path, attr_name = import_path.rsplit(sep='.', maxsplit=1)
    return getattr(importlib.import_module(module_path), attr_name)

def is_ref_type(value: str) -> bool:
    """Return True if *value* is exactly a config reference like ``<name>``.

    A reference is '<', an ASCII letter, zero or more word characters, and
    '>', with nothing before or after it.

    Raises:
        AssertionError: if *value* is not a str.
    """
    assert isinstance(value, str), 'Input value is not a str.'
    # Raw string: '\w' in a plain literal is an invalid escape sequence.
    # fullmatch (rather than match + '$') so a trailing newline such as
    # "<abc>\n" is NOT treated as a reference; callers strip exactly one
    # character from each end to obtain the referenced name.
    return re.fullmatch(r'<[a-zA-Z]\w*>', value) is not None

def is_built(ins):
    """Return True when *ins* holds nothing that still needs building.

    A value counts as unbuilt if, anywhere inside it, there is a dict
    carrying both an ``'obj'`` and an ``'args'`` key, or a string accepted
    by ``is_ref_type`` (including one appearing as a '/'-separated segment
    of a path-like string). Dict values and list items are inspected
    recursively; every other type is treated as already built.
    """
    if isinstance(ins, dict):
        if 'obj' in ins and 'args' in ins:
            return False
        return all(is_built(value) for value in ins.values())
    if isinstance(ins, list):
        return all(is_built(item) for item in ins)
    if isinstance(ins, str):
        # A reference may be embedded in a path, so check each segment.
        if '/' in ins:
            return all(is_built(segment) for segment in ins.split('/'))
        return not is_ref_type(ins)
    return True

def _build_object(spec, config, build_space):
    """Instantiate ``spec['obj']`` (a dotted import path) with built ``spec['args']``."""
    obj = spec['obj']
    args = spec['args']
    assert isinstance(args, dict), f"Args for {obj} must be a dict."
    args = deep_build(args, config, build_space)

    module_cls = dynamic_import(obj)
    return module_cls(**args)


def _build_path(path, config, build_space):
    """Build every '/'-separated segment of *path* and join them back together."""
    segments = deep_build(path.split('/'), config, build_space)
    for segment in segments:
        # A reference may resolve to a non-string value (number, built
        # object, ...); fail with context instead of an opaque join error.
        if not isinstance(segment, str):
            raise TypeError(
                f"Reference in path string {path!r} resolved to a "
                f"{type(segment).__name__}, expected str."
            )
    return '/'.join(segments)


def _resolve_ref(ref_str, config, build_space):
    """Look up ``<name>`` in *config*, build it, store it back, and return it."""
    ref = ref_str[1:-1]
    if ref in build_space:
        # The name is already being resolved higher up this call path.
        raise ValueError("Cross referencing is not allowed in config.")
    build_space.add(ref)
    try:
        if isinstance(config, dict):
            if ref not in config:
                raise AssertionError(f"Key name {ref_str} not found in config.")
            attr = config[ref]
        else:
            if not hasattr(config, ref):
                raise AssertionError(f"Key name {ref_str} not found in config.")
            attr = getattr(config, ref)

        attr = deep_build(attr, config, build_space)

        # Write the built value back so later references to the same name
        # get this already-built value instead of building it again.
        if isinstance(config, dict):
            config[ref] = attr
        else:
            setattr(config, ref, attr)
    finally:
        # Unmark even when an error is raised, so a caller-supplied
        # build_space is not left with a stale entry that would cause a
        # false "Cross referencing" error on a later call.
        build_space.discard(ref)
    return attr


def deep_build(ins, config, build_space: set = None):
    """Recursively build *ins*, resolving references and instantiating objects.

    Value shapes handled:

    * ``list`` -- each item is built and stored back in place; the same
      list is returned.
    * ``dict`` with both ``'obj'`` and ``'args'`` -- ``args`` (must be a
      dict) is built, then ``dynamic_import(obj)(**args)`` is returned.
    * any other ``dict`` -- each value is built and stored back in place;
      the same dict is returned.
    * ``str`` containing '/' -- each '/'-separated segment is built and the
      segments are re-joined, so references may appear inside paths.
    * ``str`` of the form ``<name>`` -- ``name`` is looked up in *config*
      (item access for dicts, attribute access otherwise), built
      recursively, written back into *config*, and returned.
    * anything else -- returned unchanged.

    Args:
        ins: the value to build.
        config: a dict or attribute-style object holding referenced entries.
        build_space: names of references currently being resolved on this
            call path, used to detect circular references; created when
            omitted.

    Returns:
        The built value.

    Raises:
        ValueError: if a reference (directly or indirectly) refers to itself.
        AssertionError: if a referenced name is missing from *config*, or
            (via ``assert``, so not under ``-O``) if ``'args'`` is not a dict.
        TypeError: if a reference inside a path string resolves to a non-str.
    """
    if is_built(ins):
        return ins

    if build_space is None:
        build_space = set()

    if isinstance(ins, list):
        for i, item in enumerate(ins):
            ins[i] = deep_build(item, config, build_space)
        return ins
    elif isinstance(ins, dict):
        if 'obj' in ins and 'args' in ins:  # return an instantiated module.
            return _build_object(ins, config, build_space)
        for k in ins:  # return a normal dict, built in place.
            ins[k] = deep_build(ins[k], config, build_space)
        return ins
    elif isinstance(ins, str):
        if '/' in ins:  # reference may exist in a path string.
            return _build_path(ins, config, build_space)
        elif is_ref_type(ins):
            return _resolve_ref(ins, config, build_space)
    return ins

def build(name: str, config: Config):
    """Build and return the config entry called *name*.

    This resolves the reference ``<name>`` through ``deep_build`` against
    *config* and returns the result.
    """
    ref_expr = "<{}>".format(name)
    return deep_build(ref_expr, config)