File size: 1,151 Bytes
26225c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
# Public names exported by this module on `from <module> import *`.
__all__ = ['listify', 'listify_with_reference']
def listify(obj):
    """Convert `obj` to nested lists.

    The following are returned unchanged: `None`, strings, objects
    without a `__len__` attribute, objects exposing a `dim()` method
    that returns 0, objects whose length cannot be computed, and empty
    containers (the empty container itself is returned, not a new
    list). Any other sized object is iterated over and each of its
    items is recursively listified.

    Note that iterating a mapping yields its keys, so a non-empty dict
    is converted to the list of its (listified) keys.
    """
    # Strings are sized and iterable, but must be treated as scalars:
    # iterating a 1-character string yields itself forever
    if obj is None or isinstance(obj, str):
        return obj

    # Unsized objects (numbers, etc.) are scalars
    if not hasattr(obj, '__len__'):
        return obj

    # Zero-dimensional objects exposing a `dim()` method are scalars
    if hasattr(obj, 'dim') and obj.dim() == 0:
        return obj

    # `hasattr(obj, '__len__')` also holds for objects that cannot
    # actually be measured: classes such as `list` or `dict` expose the
    # method without being instances, and some array types (e.g. NumPy
    # 0-d arrays) raise on `len()`. Treat those as scalars rather than
    # crashing
    try:
        size = len(obj)
    except TypeError:
        return obj

    if size == 0:
        return obj

    return [listify(x) for x in obj]
def listify_with_reference(arg_ref, *args):
    """listify `arg_ref` and the `args`, while ensuring that the length
    of `args` match the length of `arg_ref`. This is typically needed
    for parsing the input arguments of a function from an OmegaConf.

    Behavior, depending on the listified `arg_ref`:
    - `None` or an empty list: an empty list is returned for `arg_ref`
      and for each of the `args`
    - not a list (e.g. a scalar): `arg_ref` and each listified arg are
      each wrapped in a 1-item list. An arg that is already a list is
      wrapped as a whole
    - non-empty list: each arg that is not a list is wrapped in a
      1-item list, 1-item args are repeated to the length of `arg_ref`,
      and args already of the right length are kept as is

    Returns a tuple `(arg_ref, *args)` of the resulting lists.

    Raises a `ValueError` if `arg_ref` is a non-empty list and an arg
    has a length that is neither 1 nor the length of `arg_ref`, since
    it cannot be matched to the reference.
    """
    arg_ref = listify(arg_ref)
    args_out = [listify(a) for a in args]

    # No reference: everything collapses to empty lists
    if arg_ref is None:
        return ([], *([] for _ in args))

    # Scalar reference: the reference has an implicit length of 1, so
    # each arg is wrapped as a single item
    if not isinstance(arg_ref, list):
        return ([arg_ref], *([a] for a in args_out))

    n_ref = len(arg_ref)
    if n_ref == 0:
        return ([], *([] for _ in args))

    for i, a in enumerate(args_out):
        if not isinstance(a, list):
            a = [a]
        if len(a) != n_ref:
            # Only 1-item args can be broadcast. Repeating a longer (or
            # empty) list would produce a length of `len(a) * n_ref`,
            # which silently breaks the length-matching guarantee
            if len(a) != 1:
                raise ValueError(
                    f"Argument {i} has length {len(a)}, expected 1 or "
                    f"{n_ref} to match the reference argument")
            # NB: the repeated items all refer to the same object
            a = a * n_ref
        args_out[i] = a

    return (arg_ref, *args_out)
|