# Hosting-page banner (not part of the module):
#   joebruce1313's picture -- "Upload 38004 files" -- 1f5470c (verified)
import binascii
import codecs
import marshal
import os
import types as python_types
def default(method):
    """Mark `method` as a default implementation.

    The marker is the attribute `_is_default = True` set directly on the
    passed object, which lets `is_default` later tell whether a subclass
    replaced the method with its own (unmarked) version.

    Args:
        method: the callable to mark.

    Returns:
        The same `method` object, now carrying the marker attribute.
    """
    setattr(method, "_is_default", True)
    return method
def is_default(method):
    """Report whether `method` carries the `default` marker.

    Args:
        method: any object, typically a function or bound method.

    Returns:
        The value of `method._is_default` if that attribute exists,
        otherwise `False`.
    """
    try:
        return method._is_default
    except AttributeError:
        # Never decorated with `default`.
        return False
def func_dump(func):
    """Serialize a user-defined function into plain Python values.

    The function's code object is marshalled and then base64-encoded into
    an ASCII string; its default arguments and closure contents are
    returned alongside it.

    Args:
        func: the function to serialize.

    Returns:
        A tuple `(code, defaults, closure)` where `code` is the base64
        text of the marshalled `func.__code__`, `defaults` is
        `func.__defaults__`, and `closure` is a tuple with the contents of
        each closure cell, or `None` when the function has no closure.
    """
    marshalled = marshal.dumps(func.__code__)
    if os.name == "nt":
        # On Windows, backslash bytes in the marshalled payload are
        # rewritten to forward slashes before encoding.
        marshalled = marshalled.replace(b"\\", b"/")
    encoded = codecs.encode(marshalled, "base64").decode("ascii")

    default_args = func.__defaults__
    cells = func.__closure__
    captured = tuple(cell.cell_contents for cell in cells) if cells else None
    return encoded, default_args, captured
def func_load(code, defaults=None, closure=None, globs=None):
    """Rebuild a function from its serialized pieces.

    NOTE: the payload is handed to `marshal.loads`, which must not be used
    on data from an untrusted source.

    Args:
        code: base64 text of the marshalled code object, or a
            `(code, defaults, closure)` tuple/list holding all three
            pieces at once.
        defaults: default argument values of the function.
        closure: values (or cell objects) for the function's free
            variables.
        globs: dictionary of global objects for the function; this
            module's globals are used when `None`.

    Returns:
        A function object.
    """
    if isinstance(code, (tuple, list)):
        # The whole dump was passed as the first argument: unpack it.
        code, defaults, closure = code
        if isinstance(defaults, list):
            defaults = tuple(defaults)

    def ensure_value_to_cell(value):
        """Return `value` as a cell object.

        Args:
            value: any value, possibly already a cell.

        Returns:
            `value` itself when it is already a cell, otherwise a new cell
            that contains it.
        """

        def _capture():
            # Referencing `value` is what makes Python allocate a closure
            # cell for it; this helper is never called.
            return value

        fresh_cell = _capture.__closure__[0]
        if isinstance(value, type(fresh_cell)):
            return value
        return fresh_cell

    if closure is not None:
        closure = tuple(ensure_value_to_cell(item) for item in closure)

    try:
        payload = codecs.decode(code.encode("ascii"), "base64")
    except (UnicodeEncodeError, binascii.Error):
        # Not valid base64 ASCII: treat the text as raw escaped bytes.
        payload = code.encode("raw_unicode_escape")
    code_obj = marshal.loads(payload)

    namespace = globals() if globs is None else globs
    return python_types.FunctionType(
        code_obj,
        namespace,
        name=code_obj.co_name,
        argdefs=defaults,
        closure=closure,
    )
def to_list(x):
    """Wrap `x` in a list unless it already is one.

    Any non-list object (a tensor, a tuple, a scalar, ...) is returned as
    the only element of a new one-item list.

    Args:
        x: target object to be normalized.

    Returns:
        `x` itself when it is a `list`, otherwise `[x]`.
    """
    return x if isinstance(x, list) else [x]
def remove_long_seq(maxlen, seq, label):
    """Keep only the sequences strictly shorter than `maxlen`.

    Sequences and labels are paired positionally; a pair is kept when the
    sequence's length is less than `maxlen` (a sequence of exactly
    `maxlen` items is dropped).

    Args:
        maxlen: Int, length bound for the output sequences.
        seq: List of lists, where each sublist is a sequence.
        label: List where each element is an integer.

    Returns:
        new_seq, new_label: the filtered lists for `seq` and `label`.
    """
    kept = [
        (sequence, target)
        for sequence, target in zip(seq, label)
        if len(sequence) < maxlen
    ]
    new_seq = [sequence for sequence, _ in kept]
    new_label = [target for _, target in kept]
    return new_seq, new_label
def removeprefix(x, prefix):
    """Backport of `removeprefix` from PEP-616 (Python 3.9+).

    Args:
        x: the string to trim.
        prefix: the leading text to strip.

    Returns:
        `x` without `prefix` at its start when `prefix` is non-empty and
        `x` starts with it, otherwise `x` unchanged.
    """
    if len(prefix) == 0 or not x.startswith(prefix):
        return x
    return x[len(prefix) :]
def removesuffix(x, suffix):
    """Backport of `removesuffix` from PEP-616 (Python 3.9+).

    Args:
        x: the string to trim.
        suffix: the trailing text to strip.

    Returns:
        `x` without `suffix` at its end when `suffix` is non-empty and `x`
        ends with it, otherwise `x` unchanged.
    """
    # The emptiness check matters: `x[:-0]` would be the empty string.
    if len(suffix) == 0 or not x.endswith(suffix):
        return x
    return x[: -len(suffix)]
def remove_by_id(lst, value):
    """Delete the first element of `lst` that is the very object `value`.

    Matching is by identity, not equality, so elements that merely compare
    equal to `value` are left alone. The list is modified in place and is
    left untouched when no element matches.

    Args:
        lst: the list to modify.
        value: the object to remove.
    """
    position = next(
        (index for index, item in enumerate(lst) if item is value), None
    )
    if position is not None:
        del lst[position]
def pythonify_logs(logs):
    """Flatten and convert log values to Python-native types.

    Each non-dict value is converted with `float(value)`; when that
    conversion fails the value is kept as-is. Nested dicts are flattened
    recursively, their entries being merged into the top-level result.
    Entries are processed in sorted key order.

    Args:
        logs: A dict containing log values. `None` (or any falsy value) is
            treated as an empty dict.

    Returns:
        A flattened dict with values converted to Python-native types if
        possible.
    """
    logs = logs or {}
    result = {}
    for key, value in sorted(logs.items()):
        if isinstance(value, dict):
            result.update(pythonify_logs(value))
            continue
        try:
            value = float(value)
        except Exception:
            # Best-effort conversion: keep the original value when it
            # cannot be turned into a float. `Exception` (rather than a
            # bare `except`) lets KeyboardInterrupt and SystemExit
            # propagate instead of being silently swallowed.
            pass
        result[key] = value
    return result