Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import yaml | |
| import json | |
| import torch | |
| import random | |
| import warnings | |
| import importlib | |
| import numpy as np | |
def load_yaml_config(path):
    """Read the YAML file at ``path`` and return the parsed object.

    Parsing is done with ``yaml.full_load``, which is not PyYAML's safe loader;
    only pass config files you trust.
    """
    with open(path) as stream:
        return yaml.full_load(stream)
def save_config_to_yaml(config, path):
    """Serialise ``config`` with ``yaml.dump`` and write it to ``path``.

    Args:
        config: Object accepted by ``yaml.dump`` (typically a nested dict).
        path: Destination file; must end with ``.yaml``. An existing file is
            overwritten.

    Raises:
        ValueError: If ``path`` does not end with ``.yaml``.
    """
    # Explicit check instead of ``assert`` so the validation still runs under
    # ``python -O``.
    if not path.endswith(".yaml"):
        raise ValueError("Expected a path ending with '.yaml', got {!r}".format(path))
    # The ``with`` block closes the file; no explicit ``close()`` is needed.
    with open(path, "w") as f:
        f.write(yaml.dump(config))
def save_dict_to_json(d, path, indent=None):
    """Write ``d`` as JSON to ``path``, overwriting any existing file.

    Args:
        d: JSON-serialisable object.
        path: Destination file path.
        indent: Forwarded to ``json.dump``; ``None`` produces compact output.
    """
    # Context manager so the file is flushed and closed deterministically
    # instead of whenever the garbage collector gets to the handle.
    with open(path, "w") as f:
        json.dump(d, f, indent=indent)
def load_dict_from_json(path):
    """Parse and return the JSON document stored at ``path``."""
    # Context manager so the handle is closed as soon as parsing finishes.
    with open(path, "r") as f:
        return json.load(f)
def write_args(args, path):
    """Append a run header to the text file at ``path``.

    The header contains the torch and cuDNN versions, ``sys.argv``, and every
    attribute of ``args`` whose name (as listed by ``dir``) does not start with
    ``_``, one ``key: value`` line each, sorted by name. The file is opened in
    append mode, so repeated calls accumulate headers.

    Args:
        args: Any object, e.g. an ``argparse.Namespace``.
        path: Text file to append to (created if missing).
    """
    args_dict = {
        name: getattr(args, name) for name in dir(args) if not name.startswith("_")
    }
    # The ``with`` block closes the file; the former explicit close() was redundant.
    with open(path, "a") as args_file:
        args_file.write("==> torch version: {}\n".format(torch.__version__))
        args_file.write(
            "==> cudnn version: {}\n".format(torch.backends.cudnn.version())
        )
        args_file.write("==> Cmd:\n")
        args_file.write(str(sys.argv))
        args_file.write("\n==> args:\n")
        for k, v in sorted(args_dict.items()):
            args_file.write(" %s: %s\n" % (str(k), str(v)))
def seed_everything(seed, cudnn_deterministic=False):
    """
    Seed the global generators of Python ``random``, NumPy and PyTorch
    (``torch.manual_seed`` plus ``torch.cuda.manual_seed_all``).

    Args:
        seed: integer applied to every generator; ``None`` skips seeding.
        cudnn_deterministic: if true, switch ``torch.backends.cudnn.deterministic``
            on and emit a warning about the possible slowdown.
    """
    if seed is not None:
        print(f"Global seed set to {seed}")
        seeders = (
            random.seed,
            np.random.seed,
            torch.manual_seed,
            torch.cuda.manual_seed_all,
        )
        for apply_seed in seeders:
            apply_seed(seed)
        # Seeding resets cuDNN to non-deterministic; only the flag below turns it on.
        torch.backends.cudnn.deterministic = False

    if not cudnn_deterministic:
        return
    torch.backends.cudnn.deterministic = True
    warnings.warn(
        "You have chosen to seed training. "
        "This will turn on the CUDNN deterministic setting, "
        "which can slow down your training considerably! "
        "You may see unexpected behavior when restarting "
        "from checkpoints."
    )
_TRUE_STRINGS = frozenset({"true", "1", "yes", "y", "on"})
_FALSE_STRINGS = frozenset({"false", "0", "no", "n", "off"})


def _coerce_like(old, new):
    """Convert ``new`` to the type of ``old``, the value it is about to replace."""
    if old is None:
        # ``type(None)(new)`` raises TypeError; keep the raw override instead.
        return new
    if isinstance(old, bool):
        # ``bool("False")`` is True, so boolean strings must be parsed explicitly.
        if isinstance(new, bool):
            return new
        text = str(new).strip().lower()
        if text in _TRUE_STRINGS:
            return True
        if text in _FALSE_STRINGS:
            return False
        raise ValueError("Cannot interpret {!r} as a boolean".format(new))
    return type(old)(new)


def merge_opts_to_config(config, opts):
    """Override nested ``config`` entries from a flat ``[name, value, ...]`` list.

    Each ``name`` is a dot-separated key path into ``config`` (for example
    ``"dataloader.batch_size"``) that must already exist. The new ``value`` is
    converted to the type of the value it replaces. Boolean strings such as
    ``"false"``, ``"0"`` and ``"no"`` are parsed rather than passed to
    ``bool()``. An existing ``None`` is replaced by the raw value. ``config`` is
    modified in place and also returned.

    Args:
        config: Nested dict to update.
        opts: ``None`` or empty (no-op), or an even-length sequence alternating
            key paths and values, e.g. collected from the command line.

    Returns:
        The updated ``config``.

    Raises:
        ValueError: If ``opts`` has odd length, or a boolean entry receives a
            value that is not a recognised boolean string.
        KeyError: If a component of a key path does not exist.
    """
    if opts is None or len(opts) == 0:
        return config
    # Explicit check (not ``assert``) so it still applies under ``python -O``.
    if len(opts) % 2 != 0:
        raise ValueError(
            "each opts should be given by the name and values! The length shall be even number!"
        )
    pairs = iter(opts)
    for name, value in zip(pairs, pairs):
        *parents, leaf = name.split(".")
        node = config
        for key in parents:
            node = node[key]
        node[leaf] = _coerce_like(node[leaf], value)
    return config
def modify_config_for_debug(config):
    """Shrink the dataloader settings for a quick debugging run.

    Sets ``config["dataloader"]["num_workers"]`` to 0 and ``batch_size`` to 1
    in place, then returns the same ``config`` object.
    """
    loader_cfg = config["dataloader"]
    loader_cfg["num_workers"] = 0
    loader_cfg["batch_size"] = 1
    return config
def _format_param_count(num):
    """Render ``num`` with a binary K/M/G suffix (K = 2**10), rounded to 2 decimals."""
    for factor, unit in ((2**30, "G"), (2**20, "M"), (2**10, "K")):
        # Strictly greater: exactly 1024 is still printed as "1024".
        if num > factor:
            return "{}{}".format(round(float(num) / factor, 2), unit)
    return "{}".format(num)


def _format_counts_inplace(d):
    """Recursively replace every non-dict leaf of ``d`` with its formatted count."""
    for key, value in d.items():
        if isinstance(value, dict):
            _format_counts_inplace(value)
        else:
            d[key] = _format_param_count(value)


def get_model_parameters_info(model):
    """Summarise the trainable / non-trainable parameter counts of ``model``.

    Returns a dict with an ``"overall"`` entry followed by one entry per direct
    child module (from ``named_children()``). Each entry maps ``"trainable"``,
    ``"non_trainable"`` and ``"total"`` to a formatted string such as
    ``"110"`` or ``"1.97K"``, using binary units (K = 2**10, M = 2**20,
    G = 2**30).

    ``"overall"`` is counted over ``model.parameters()``. It therefore includes
    parameters registered directly on ``model`` rather than in a child, and
    counts parameters shared between children only once. Previously it was the
    sum of the per-child counts, which missed the former and double-counted the
    latter.
    """

    def count(params):
        trainable = 0
        non_trainable = 0
        for p in params:
            if p.requires_grad:
                trainable += p.numel()
            else:
                non_trainable += p.numel()
        return {
            "trainable": trainable,
            "non_trainable": non_trainable,
            "total": trainable + non_trainable,
        }

    parameters = {"overall": count(model.parameters())}
    for child_name, child_module in model.named_children():
        parameters[child_name] = count(child_module.parameters())

    _format_counts_inplace(parameters)
    return parameters
def format_seconds(seconds):
    """Format a duration in seconds as a compact string.

    Leading zero-valued units are omitted: ``"05s"``, ``"01m:05s"``,
    ``"01h:01m:01s"``, or ``"1d:01h:01m:01s"`` once at least one day has passed.
    Fractional seconds are truncated.
    """
    total_hours = int(seconds // 3600)
    minutes = int(seconds // 60 - total_hours * 60)
    secs = int(seconds % 60)
    days, hours = divmod(total_hours, 24)

    if days:
        return "{:d}d:{:02d}h:{:02d}m:{:02d}s".format(days, hours, minutes, secs)
    if hours:
        return "{:02d}h:{:02d}m:{:02d}s".format(hours, minutes, secs)
    if minutes:
        return "{:02d}m:{:02d}s".format(minutes, secs)
    return "{:02d}s".format(secs)
def instantiate_from_config(config):
    """Import and call the object described by ``config``.

    ``config["target"]`` is a dotted path such as ``"package.module.ClassName"``.
    The named attribute is imported and called with ``config["params"]`` as
    keyword arguments. A missing or ``None`` ``"params"`` means no arguments.
    Because the target is imported dynamically, only use this with trusted
    configs.

    Returns:
        The constructed object, or ``None`` when ``config`` is ``None``.

    Raises:
        KeyError: If ``config`` has no ``"target"`` key.
    """
    if config is None:
        return None
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    module_name, attr_name = config["target"].rsplit(".", 1)
    target = getattr(importlib.import_module(module_name, package=None), attr_name)
    # An empty ``params:`` in YAML loads as None; treat it like a missing key
    # instead of crashing on ``target(**None)``.
    params = config.get("params")
    return target(**(params if params is not None else {}))
def class_from_string(class_name):
    """Resolve a dotted path like ``"pkg.module.Name"`` to the object it names.

    Everything before the last dot is imported as a module; the final segment
    is looked up as an attribute of that module and returned.
    """
    module_path, attr = class_name.rsplit(".", 1)
    module = importlib.import_module(module_path, package=None)
    return getattr(module, attr)
def get_all_file(dir, end_with=".h5"):
    """Recursively collect files under ``dir`` whose names end with a suffix.

    Args:
        dir: Root directory, walked with ``os.walk``.
        end_with: One suffix string, or an iterable of suffixes. A file is
            included once if its name ends with any of them.

    Returns:
        List of ``os.path.join(root, filename)`` paths in ``os.walk`` order.
    """
    if isinstance(end_with, str):
        end_with = [end_with]
    # Materialise once: re-iterating ``end_with`` per file silently exhausted
    # one-shot iterables (e.g. generators) after the first file.
    suffixes = tuple(end_with)
    return [
        os.path.join(root, name)
        for root, _subdirs, files in os.walk(dir)
        for name in files
        if name.endswith(suffixes)
    ]
def get_sub_dirs(dir, abs=True):
    """List the entries of ``dir`` (non-recursive, files included).

    When ``abs`` is true (the default), each entry name is joined onto ``dir``.
    Otherwise the bare names from ``os.listdir`` are returned.
    """
    names = os.listdir(dir)
    if not abs:
        return names
    return [os.path.join(dir, name) for name in names]
def get_model_buffer(model):
    """Return the entries of ``model.state_dict()`` that are not parameters.

    These are normally the module's persistent buffers (for example BatchNorm
    running statistics), keyed by their state-dict names.

    Parameters are recognised by type through ``state_dict(keep_vars=True)``
    rather than by looking names up in ``named_parameters()``. The latter
    de-duplicates tied/shared parameters, so the second name of a tied weight
    used to be reported as a buffer. Tensor values are detached, matching what
    a plain ``state_dict()`` returns.
    """
    buffers_ = {}
    for key, value in model.state_dict(keep_vars=True).items():
        if isinstance(value, torch.nn.Parameter):
            continue
        # Non-tensor entries (e.g. custom extra state) are passed through unchanged.
        buffers_[key] = value.detach() if isinstance(value, torch.Tensor) else value
    return buffers_