English
File size: 1,151 Bytes
26225c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
__all__ = ['listify', 'listify_with_reference']


def listify(obj):
    """Convert `obj` to nested lists.

    Sized, iterable containers are converted recursively: each element is
    passed through `listify` and the results are collected in a new list,
    so an empty container becomes ``[]``.

    The following are treated as leaves and returned unchanged:

    - ``None``;
    - ``str``, ``bytes`` and ``bytearray``. Iterating them would split the
      value into characters / integers.
    - mappings (e.g. ``dict``). Iterating them would keep only the keys and
      silently drop the values.
    - unsized objects, i.e. anything for which ``len(obj)`` raises
      ``TypeError``. This covers plain scalars and also 0-d array-likes that
      define ``__len__`` but refuse to report a length (e.g. a 0-d NumPy
      array).
    - objects exposing a ``dim()`` method that returns 0 (presumably 0-d
      tensors).

    Args:
        obj: value to convert.

    Returns:
        A (possibly nested) list, or `obj` itself when it is a leaf.
    """
    # Local import: this module has no top-level imports to extend.
    from collections.abc import Mapping

    if obj is None or isinstance(obj, (str, bytes, bytearray, Mapping)):
        return obj

    # EAFP rather than ``hasattr(obj, '__len__')``: some array types define
    # ``__len__`` on every instance but raise TypeError for 0-d ones.
    try:
        len(obj)
    except TypeError:
        return obj

    # Tensor-like objects reporting zero dimensions are scalar leaves.
    if hasattr(obj, 'dim') and obj.dim() == 0:
        return obj

    # An empty container naturally yields [] here.
    return [listify(x) for x in obj]


def listify_with_reference(arg_ref, *args):
    """listify `arg_ref` and the `args`, while ensuring that the length
    of `args` match the length of `arg_ref`. This is typically needed
    for parsing the input arguments of a function from an OmegaConf.

    Every value is first converted with `listify`. Then:

    - if `arg_ref` is ``None`` or an empty list, every output is ``[]``;
    - if `arg_ref` is not a list (a leaf), it and every arg are each wrapped
      in a one-element list;
    - otherwise each arg that is not a list is wrapped in one, and an arg
      holding a single element is repeated ``len(arg_ref)`` times (the same
      object is repeated, not copied).

    Args:
        arg_ref: reference value whose length every output must match.
        *args: values to align with `arg_ref`.

    Returns:
        A tuple ``(arg_ref_list, *arg_lists)`` with one list per input.

    Raises:
        ValueError: if an arg holds more than one element, or none, but not
            exactly ``len(arg_ref)`` elements, so it cannot be broadcast.
    """
    arg_ref = listify(arg_ref)
    args_out = [listify(a) for a in args]

    if arg_ref is None:
        return [], *([] for _ in args)

    if not isinstance(arg_ref, list):
        return [arg_ref], *[[a] for a in args_out]

    n_ref = len(arg_ref)
    if n_ref == 0:
        return [], *([] for _ in args)

    for i, a in enumerate(args_out):
        if not isinstance(a, list):
            a = [a]
        if len(a) == 1:
            # Broadcast a single value to every position of the reference.
            a = a * n_ref
        elif len(a) != n_ref:
            # Tiling a multi-element list cannot produce exactly n_ref
            # items, so refuse instead of returning misaligned outputs.
            raise ValueError(
                f"listify_with_reference: argument {i} has {len(a)} "
                f"elements, expected 1 or {n_ref} to match the reference")
        args_out[i] = a

    return arg_ref, *args_out