|
|
import binascii |
|
|
import codecs |
|
|
import marshal |
|
|
import os |
|
|
import types as python_types |
|
|
|
|
|
|
|
|
def default(method):
    """Marks `method` as a default implementation.

    The marker is an `_is_default = True` attribute set directly on the
    given object, which lets callers detect whether a subclass replaced
    the method with its own (unmarked) implementation.

    Args:
        method: the callable to mark.

    Returns:
        The very same `method` object, so this can be used as a decorator.
    """
    setattr(method, "_is_default", True)
    return method
|
|
|
|
|
|
|
|
def is_default(method):
    """Tells whether `method` carries the `_is_default` marker.

    Args:
        method: any object; typically a function or bound method.

    Returns:
        The value of `method._is_default` when that attribute exists,
        otherwise `False`.
    """
    try:
        return method._is_default
    except AttributeError:
        # No marker on the object: it is not a default implementation.
        return False
|
|
|
|
|
|
|
|
def func_dump(func):
    """Serializes a user-defined function.

    The function's code object is marshalled and then base64-encoded into
    ASCII text. When `os.name == "nt"`, every backslash byte in the
    marshalled payload is replaced by a forward slash before encoding.

    Args:
        func: the function to serialize.

    Returns:
        A tuple `(code, defaults, closure)` where `code` is the base64
        text of the marshalled code object, `defaults` is
        `func.__defaults__`, and `closure` is a tuple holding the contents
        of the function's closure cells (or `None` if it has no closure).
    """
    payload = marshal.dumps(func.__code__)
    if os.name == "nt":
        payload = payload.replace(b"\\", b"/")
    encoded = codecs.encode(payload, "base64").decode("ascii")

    cells = func.__closure__
    captured = tuple(cell.cell_contents for cell in cells) if cells else None
    return encoded, func.__defaults__, captured
|
|
|
|
|
|
|
|
def func_load(code, defaults=None, closure=None, globs=None):
    """Deserializes a user defined function.

    Args:
        code: bytecode of the function, as base64 text of a marshalled
            code object. A `(code, defaults, closure)` tuple or list may
            be passed instead, in which case it is unpacked and overrides
            the `defaults` and `closure` arguments.
        defaults: defaults of the function. A list is accepted and
            converted to a tuple.
        closure: closure of the function, as an iterable of values and/or
            cell objects.
        globs: dictionary of global objects. Defaults to this module's
            globals.

    Returns:
        A function object.
    """
    if isinstance(code, (tuple, list)):
        # A complete `(code, defaults, closure)` dump was passed at once.
        code, defaults, closure = code
    if isinstance(defaults, list):
        # `FunctionType` only accepts a tuple (or None) for `argdefs`.
        # Lists show up after serialization round-trips that do not
        # preserve tuples, so normalize no matter how they were passed.
        defaults = tuple(defaults)

    def ensure_value_to_cell(value):
        """Ensures that a value is converted to a python cell object.

        Args:
            value: Any value that needs to be casted to the cell type

        Returns:
            A value wrapped as a cell object (see function "func_load")
        """

        def dummy_fn():
            # Referencing `value` forces Python to create a closure cell
            # holding it; the function itself is never called.
            value

        cell_value = dummy_fn.__closure__[0]
        if not isinstance(value, type(cell_value)):
            return cell_value
        # `value` is already a cell: pass it through untouched.
        return value

    if closure is not None:
        closure = tuple(ensure_value_to_cell(item) for item in closure)
    try:
        raw_code = codecs.decode(code.encode("ascii"), "base64")
    except (UnicodeEncodeError, binascii.Error):
        # Not valid base64 ASCII text: fall back to treating the string
        # as raw-unicode-escaped marshalled bytes.
        raw_code = code.encode("raw_unicode_escape")
    # NOTE(security): `marshal.loads` is not safe against malicious data,
    # and the resulting function runs arbitrary bytecode. Only load
    # payloads that come from a trusted source.
    code = marshal.loads(raw_code)
    if globs is None:
        globs = globals()
    return python_types.FunctionType(
        code, globs, name=code.co_name, argdefs=defaults, closure=closure
    )
|
|
|
|
|
|
|
|
def to_list(x):
    """Wraps `x` in a list unless it already is one.

    A `list` instance is returned as-is (the same object, not a copy).
    Anything else, e.g. a single tensor, is returned as a one-element
    list containing it.

    Args:
        x: target object to be normalized.

    Returns:
        A list.
    """
    return x if isinstance(x, list) else [x]
|
|
|
|
|
|
|
|
def remove_long_seq(maxlen, seq, label):
    """Drops every sequence whose length is `maxlen` or more.

    Sequences and labels are paired positionally; a label is kept only
    when its sequence is kept.

    Args:
        maxlen: Int, maximum length of the output sequences.
        seq: List of lists, where each sublist is a sequence.
        label: List where each element is an integer.

    Returns:
        new_seq, new_label: shortened lists for `seq` and `label`.
    """
    kept_pairs = [
        (sequence, target)
        for sequence, target in zip(seq, label)
        if len(sequence) < maxlen
    ]
    new_seq = [sequence for sequence, _ in kept_pairs]
    new_label = [target for _, target in kept_pairs]
    return new_seq, new_label
|
|
|
|
|
|
|
|
def removeprefix(x, prefix):
    """Backport of `removeprefix` from PEP-616 (Python 3.9+).

    Args:
        x: the string to strip.
        prefix: the leading text to remove.

    Returns:
        `x` without `prefix` when `prefix` is non-empty and `x` starts
        with it; otherwise `x` unchanged.
    """
    prefix_len = len(prefix)
    if prefix_len == 0 or not x.startswith(prefix):
        return x
    return x[prefix_len:]
|
|
|
|
|
|
|
|
def removesuffix(x, suffix):
    """Backport of `removesuffix` from PEP-616 (Python 3.9+).

    Args:
        x: the string to strip.
        suffix: the trailing text to remove.

    Returns:
        `x` without `suffix` when `suffix` is non-empty and `x` ends with
        it; otherwise `x` unchanged.
    """
    suffix_len = len(suffix)
    if suffix_len == 0 or not x.endswith(suffix):
        return x
    return x[:-suffix_len]
|
|
|
|
|
|
|
|
def remove_by_id(lst, value):
    """Remove a value from a list by id.

    Deletes, in place, the first element of `lst` that is the very same
    object as `value` (identity, not equality). The list is left
    untouched when no element matches.

    Args:
        lst: the list to modify in place.
        value: the object to remove.
    """
    for position, element in enumerate(lst):
        if element is value:
            del lst[position]
            break
|
|
|
|
|
|
|
|
def pythonify_logs(logs):
    """Flatten and convert log values to Python-native types.

    This function attempts to convert dict value by `float(value)` and skips
    the conversion if it fails. Nested dicts are flattened recursively into
    the top-level result; their own keys are used as-is, so a nested key
    overwrites an earlier entry with the same name.

    Args:
        logs: A dict containing log values. `None` (or any falsy value) is
            treated as an empty dict.

    Returns:
        A flattened dict with values converted to Python-native types if
        possible.
    """
    logs = logs or {}
    result = {}
    # Keys are visited in sorted order.
    for key, value in sorted(logs.items()):
        if isinstance(value, dict):
            result.update(pythonify_logs(value))
        else:
            try:
                value = float(value)
            except Exception:
                # Best-effort conversion: keep the original value when it
                # cannot be turned into a float. `Exception` (rather than a
                # bare `except`) lets KeyboardInterrupt and SystemExit
                # propagate instead of being silently swallowed.
                pass
            result[key] = value
    return result
|
|
|