|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Contains small Python utility functions.
|
|
""" |
|
|
|
|
|
import importlib.util |
|
|
import re |
|
|
from functools import lru_cache |
|
|
from typing import Any, Dict, List, Union |
|
|
|
|
|
import numpy as np |
|
|
import yaml |
|
|
from yaml import Dumper |
|
|
|
|
|
|
|
|
def is_sci_notation(number: float) -> bool:
    """Tell whether ``str(number)`` is written in scientific (exponent) notation.

    Args:
        number: The value to inspect; only its ``str()`` form is examined.

    Returns:
        True when the text is an optionally signed decimal mantissa followed
        by ``e``/``E`` and an optionally signed integer exponent.
    """
    text = str(number)
    return re.match(r"^[+-]?\d+(\.\d*)?[eE][+-]?\d+$", text) is not None
|
|
|
|
|
|
|
|
def float_representer(dumper: Dumper, number: Union[float, np.float32, np.float64]):
    """Represent a float as a YAML scalar tagged ``tag:yaml.org,2002:float``.

    Formatting rules:
      * NaN and infinities use the canonical YAML spellings ``.nan``,
        ``.inf`` and ``-.inf``.
      * Values whose ``str()`` form is in scientific notation keep that form;
        a ``.0`` is inserted before the exponent when the mantissa has no
        decimal point (``1e-05`` becomes ``1.0e-05``).
      * Every other value is rounded to three decimal places.

    Args:
        dumper: The YAML dumper that builds the scalar node.
        number: The Python or NumPy float to serialize.

    Returns:
        Whatever ``dumper.represent_scalar`` returns for the formatted text.
    """
    import math  # local import so this block does not depend on a new file-level import

    if math.isnan(number):
        # str(nan) is "nan", which is not a YAML float literal.
        value = ".nan"
    elif math.isinf(number):
        value = ".inf" if number > 0 else "-.inf"
    elif is_sci_notation(number):
        value = str(number)
        # NOTE(review): PyYAML's implicit float resolver appears to require a
        # "." in the mantissa, which is presumably why ".0" is inserted here.
        if "." not in value and "e" in value:
            value = value.replace("e", ".0e", 1)
    else:
        value = str(round(number, 3))

    return dumper.represent_scalar("tag:yaml.org,2002:float", value)
|
|
|
|
|
|
|
|
# Register the custom float formatting for built-in floats and for NumPy
# float32/float64 scalars. No Dumper argument is passed, so PyYAML's default
# applies.
for _float_type in (float, np.float32, np.float64):
    yaml.add_representer(_float_type, float_representer)
|
|
|
|
|
|
|
|
@lru_cache
def is_package_available(name: str) -> bool:
    """Return True if a module or package called ``name`` can be located.

    The lookup uses ``importlib.util.find_spec``. The result is memoized per
    name by ``lru_cache``, so a package installed after the first query is not
    noticed within the same process.

    Args:
        name: Importable module name, optionally dotted (``"pkg.sub"``).

    Returns:
        True when a spec is found. False when no spec exists, when a parent
        package of a dotted name cannot be imported, or when ``name`` is not a
        usable module name (for example the empty string).
    """
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        # find_spec imports parent packages of dotted names and raises
        # ModuleNotFoundError/ImportError if that fails; it raises ValueError
        # for an empty name. Either way the package is not available.
        return False
|
|
|
|
|
|
|
|
def union_two_dict(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> Dict[str, Any]:
    """Merge ``dict2`` into ``dict1`` in place and return ``dict1``.

    Keys present in both dicts must map to values that compare equal (``==``).
    All shared keys are checked before anything is written, so ``dict1`` is
    left untouched when a conflict is found.

    Args:
        dict1: Destination dict; it is mutated and returned.
        dict2: Dict whose items are copied into ``dict1``.

    Returns:
        ``dict1`` itself, now containing every item of ``dict2``.

    Raises:
        AssertionError: If a key exists in both dicts with unequal values.
            It is raised explicitly so the check also runs under ``python -O``.
    """
    for key, value in dict2.items():
        if key in dict1 and not dict1[key] == value:
            raise AssertionError(f"{key} in dict1 and dict2 are not the same object")

    dict1.update(dict2)
    return dict1
|
|
|
|
|
|
|
|
def append_to_dict(data: Dict[str, List[Any]], new_data: Dict[str, Any]) -> None:
    """Append each value of ``new_data`` to the list stored under its key in ``data``.

    Keys missing from ``data`` start with an empty list. ``data`` is modified
    in place and nothing is returned.

    Args:
        data: Mapping from key to a list of collected values.
        new_data: Mapping from key to the single value to append.
    """
    for name, item in new_data.items():
        data.setdefault(name, []).append(item)
|
|
|
|
|
|
|
|
def unflatten_dict(data: Dict[str, Any], sep: str = "/") -> Dict[str, Any]:
    """Expand a flat dict with ``sep``-joined keys into nested dicts.

    Args:
        data: Flat mapping whose keys are paths such as ``"a/b/c"``.
        sep: Separator between path components.

    Returns:
        A new nested dict; each key's last path component holds its value.
    """
    nested: Dict[str, Any] = {}
    for path, leaf in data.items():
        *parents, last = path.split(sep)
        node = nested
        for part in parents:
            # Descend into the child dict, creating it on first visit.
            node = node.setdefault(part, {})

        node[last] = leaf

    return nested
|
|
|
|
|
|
|
|
def flatten_dict(data: Dict[str, Any], parent_key: str = "", sep: str = "/") -> Dict[str, Any]:
    """Collapse nested dicts into one flat dict with ``sep``-joined keys.

    Args:
        data: Possibly nested mapping to flatten.
        parent_key: Prefix placed before every key; empty means no prefix.
        sep: Separator inserted between path components.

    Returns:
        A new flat dict. Values that are dicts are expanded recursively, so an
        empty nested dict contributes no entries.
    """
    result: Dict[str, Any] = {}
    for name, item in data.items():
        path = name if not parent_key else parent_key + sep + name
        if not isinstance(item, dict):
            result[path] = item
            continue

        # Nested mapping: merge its flattened entries under the current path.
        for sub_path, leaf in flatten_dict(item, path, sep=sep).items():
            result[sub_path] = leaf

    return result
|
|
|
|
|
|
|
|
import torch |
|
|
def filter_config_for_hparams(config: Dict[str, Any]) -> Dict[str, Any]:
    """Filter the config dictionary to only include values that are supported by add_hparams.

    Values that are ``int``, ``float``, ``str``, ``bool`` or ``torch.Tensor``
    are kept unchanged. Nested dicts are filtered recursively and kept as
    nested dicts. Entries of any other type (lists, ``None``, ...) are dropped.

    Args:
        config: Possibly nested configuration mapping.

    Returns:
        A new dict containing only the supported entries; ``config`` itself is
        not modified.
    """
    filtered_config = {}
    for key, value in config.items():
        if isinstance(value, (int, float, str, bool, torch.Tensor)):
            filtered_config[key] = value
        elif isinstance(value, dict):
            filtered_config[key] = filter_config_for_hparams(value)

    return filtered_config
|
|
|
|
|
|
|
|
def convert_dict_to_str(data: Dict[str, Any]) -> str:
    """Serialize ``data`` to a YAML-formatted string with an indent of 2.

    Args:
        data: Mapping to serialize.

    Returns:
        The text produced by ``yaml.dump``.
    """
    yaml_text = yaml.dump(data, indent=2)
    return yaml_text
|
|
|