| import datetime |
| import glob |
| import importlib |
| import itertools |
| import os |
| import os.path as osp |
| import subprocess |
| import time |
|
|
| import numpy as np |
|
|
|
|
class AverageMeter:
    """Track the latest value and the weighted running mean of a series.

    Attributes:
        val: most recent value passed to ``update`` (or the initial ``avg``).
        avg: running mean, i.e. ``sum / count``.
        sum: weighted sum of all values seen so far.
        count: total weight (number of samples) seen so far.
    """

    def __init__(self, avg=None, count=1):
        """Start empty, or pre-load the meter with ``avg`` over ``count`` samples."""
        self.reset()
        if avg is None:
            return
        self.val = self.avg = avg
        self.count = count
        self.sum = avg * count

    def __repr__(self) -> str:
        # The " .4f" spec reserves a leading blank for the sign of
        # non-negative values.
        return format(self.avg, " .4f")

    def reset(self):
        """Zero out all statistics."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Fold in ``val`` with weight ``n``; calls with ``n`` not > 0 are ignored."""
        if not n > 0:
            return
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
|
|
|
|
def worker_init_fn(worker_id):
    """Give a worker its own NumPy seed and publish its id in the environment.

    Presumably meant to be passed as a data-loader ``worker_init_fn`` --
    confirm against the caller.

    Args:
        worker_id: integer index of the worker.

    Side effects:
        Sets ``os.environ["worker_id"]`` and re-seeds NumPy's global legacy
        random generator.
    """
    os.environ["worker_id"] = str(worker_id)
    # First word of the current generator state serves as the base seed.
    # Convert to a Python int so the arithmetic below cannot overflow a
    # fixed-width NumPy scalar.
    base_seed = int(np.random.get_state()[1][0])
    # np.random.seed only accepts values in [0, 2**32 - 1]; wrap around
    # instead of raising when the offset pushes the seed past that range.
    np.random.seed((base_seed + worker_id * 7) % 2**32)
|
|
|
|
def find_last_version(folder, prefix="version_", cp="last"):
    """Return the highest version number among ``{folder}/{prefix}<N>`` folders.

    Args:
        folder: directory that contains the version folders.
        prefix: folder-name prefix that precedes the integer version number.
        cp: checkpoint selector. When not ``None`` only version folders that
            contain a matching checkpoint file are considered:
            ``"last"`` -> ``last.ckpt``; ``"best"`` -> ``*best*.ckpt``;
            all-digit value -> ``*{int(cp):07d}.ckpt``; anything else ->
            ``{cp}.ckpt``. ``None`` disables the checkpoint filter.

    Returns:
        The largest version number as an ``int``, or ``None`` when no folder
        qualifies.
    """
    version_folders = glob.glob(f"{folder}/{prefix}*")
    if cp is not None:
        cp = str(cp)  # also accept an integer step number
        if cp == "last":
            suffix = "last.ckpt"
        elif cp == "best":
            suffix = "*best*.ckpt"
        elif cp.isdigit():
            suffix = f"*{int(cp):07d}.ckpt"
        else:
            suffix = f"{cp}.ckpt"
        # glob is called without recursive=True, so "**" behaves like "*":
        # the checkpoint must sit exactly one directory below the version
        # folder.
        version_folders = [
            path for path in version_folders if glob.glob(f"{path}/**/{suffix}")
        ]

    version_numbers = []
    for path in version_folders:
        tail = osp.basename(path)[len(prefix):]
        try:
            version_numbers.append(int(tail))
        except ValueError:
            # Entry shares the prefix but is not "<prefix><int>" (e.g. a
            # backup or temp folder); skip it rather than abort the search.
            continue
    return max(version_numbers, default=None)
|
|
|
|
def get_eta_str(cur_iter, total_iter, time_per_iter):
    """Format the estimated remaining time as a string.

    The number of remaining iterations is ``total_iter - cur_iter - 1``; it
    is multiplied by ``time_per_iter`` (seconds) and formatted by
    ``convert_sec_to_time``.
    """
    iters_left = total_iter - cur_iter - 1
    return convert_sec_to_time(time_per_iter * iters_left)
|
|
|
|
def convert_sec_to_time(secs):
    """Round ``secs`` to whole seconds and render it as a ``timedelta`` string.

    Example: ``3661`` -> ``"1:01:01"``.
    """
    whole_seconds = round(secs)
    delta = datetime.timedelta(seconds=whole_seconds)
    return str(delta)
|
|
|
|
def concat_lists(list_of_lists):
    """Flatten one nesting level: return the items of every inner iterable, in order."""
    return [item for inner in list_of_lists for item in inner]
|
|
|
|
def find_consecutive_runs(x, min_len=1):
    """Split a 1-D array into runs where each item is its predecessor plus one.

    Args:
        x: 1-D array-like.
        min_len: runs shorter than this are dropped from the result.

    Returns:
        ``(run_values, run_starts, run_lengths)``. For non-empty input,
        ``run_values`` is a list of slices of ``x``, while ``run_starts`` and
        ``run_lengths`` are integer arrays holding each kept run's first
        index and length. For empty input, three empty arrays are returned.

    Raises:
        ValueError: if ``x`` is not one-dimensional.
    """
    values = np.asanyarray(x)
    if values.ndim != 1:
        raise ValueError("only 1D array supported")

    size = values.shape[0]
    if size == 0:
        return np.array([]), np.array([]), np.array([])

    # Position 0 always opens a run; every other position opens one when it
    # is not exactly "previous + 1".
    starts_mask = np.empty(size, dtype=bool)
    starts_mask[0] = True
    np.not_equal(values[:-1], values[1:] - 1, out=starts_mask[1:])
    starts = np.flatnonzero(starts_mask)

    # A run ends where the next one begins (or at the end of the array).
    boundaries = np.append(starts, size)
    lengths = boundaries[1:] - boundaries[:-1]

    keep = lengths >= min_len
    starts, lengths = starts[keep], lengths[keep]

    runs = [values[first : first + count] for first, count in zip(starts, lengths)]
    return runs, starts, lengths
|
|
|
|
def get_checkpoint_path(checkpoint_dir, cp, return_name=False):
    """Resolve a checkpoint selector to a file path inside ``checkpoint_dir``.

    Args:
        checkpoint_dir: directory holding the ``.ckpt`` files.
        cp: checkpoint selector.
            ``"last"`` -> ``last.ckpt`` (returned without checking the disk);
            ``"best"`` -> last match of ``*best*.ckpt`` in sorted order;
            anything else -> last match of ``{cp}.ckpt`` in sorted order
            (``cp`` may contain glob wildcards). For an all-digit ``cp`` with
            no such file, ``*{int(cp):07d}.ckpt`` is tried as well, which is
            the pattern ``find_last_version`` uses for numeric selectors.
        return_name: when true, also return the checkpoint file name.

    Returns:
        ``cp_path``, or ``(cp_path, cp_name)`` when ``return_name`` is true.

    Raises:
        IndexError: if no checkpoint file matches the selector. The type is
            what the previous ``sorted(...)[-1]`` lookup raised and is kept
            for backward compatibility.
    """
    if cp == "last":
        cp_name = "last.ckpt"
    else:
        if cp == "best":
            patterns = ["*best*.ckpt"]
        else:
            patterns = [f"{cp}.ckpt"]
            if str(cp).isdigit():
                patterns.append(f"*{int(cp):07d}.ckpt")

        matches = []
        for pattern in patterns:
            matches = sorted(glob.glob(f"{checkpoint_dir}/{pattern}"))
            if matches:
                break
        if not matches:
            raise IndexError(
                f"no checkpoint matching {patterns} found in {checkpoint_dir!r}"
            )
        cp_name = osp.basename(matches[-1])

    cp_path = f"{checkpoint_dir}/{cp_name}"
    if return_name:
        return cp_path, cp_name
    return cp_path
|
|
|
|
def subprocess_run(cmd, ignore_err=False, **kwargs):
    """Run ``cmd`` via ``subprocess.run`` and fail loudly on a non-zero exit.

    Args:
        cmd: command, forwarded to ``subprocess.run``.
        ignore_err: when true, a failing command does not raise; its result
            is returned instead.
        **kwargs: forwarded unchanged to ``subprocess.run``.

    Returns:
        The ``subprocess.CompletedProcess`` of the command. If
        ``check=True`` was passed and the command failed while
        ``ignore_err`` is true, an equivalent ``CompletedProcess`` is built
        from the raised ``CalledProcessError``.

    Raises:
        Exception: if the command exits non-zero and ``ignore_err`` is false.
    """
    try:
        result = subprocess.run(cmd, **kwargs)
    except subprocess.CalledProcessError as err:
        # Only reachable when the caller passed check=True.
        stderr = err.stderr
        if isinstance(stderr, bytes):
            stderr = stderr.decode("utf8", errors="replace")
        print("####### subprocess-run error message ######")
        # stderr is None unless it was captured; print nothing extra then.
        print(f"{err} {stderr or ''}")
        if not ignore_err:
            raise Exception("error in subprocess_run!") from err
        return subprocess.CompletedProcess(
            err.cmd, err.returncode, err.output, err.stderr
        )
    if result.returncode != 0 and not ignore_err:
        raise Exception("error in subprocess_run!")
    return result
|
|
|
|
def import_type_from_str(s):
    """Import and return the object named by the dotted path ``s``.

    Everything before the last dot is imported as a module; the part after
    it is looked up as an attribute of that module.
    """
    module_path, attr_name = s.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), attr_name)
|
|
|
|
def build_object_from_dict(d, type_field="type", **add_kwargs):
    """Instantiate the type named in ``d[type_field]`` from the rest of ``d``.

    ``d`` itself is left untouched: a copy is made, the dotted type path is
    popped from the copy and resolved with ``import_type_from_str``, and the
    remaining entries plus ``add_kwargs`` are passed as keyword arguments.
    """
    params = d.copy()
    type_path = params.pop(type_field)
    factory = import_type_from_str(type_path)
    return factory(**params, **add_kwargs)
|
|
|
|
def write_list_to_file(filename, string_list):
    """Write each string of ``string_list`` to ``filename``, one per line.

    The file is opened in ``"w"`` mode, so existing content is replaced.
    """
    with open(filename, "w") as out:
        out.writelines(entry + "\n" for entry in string_list)
|
|
|
|
def are_arrays_equal(array1, array2, sort=False):
    """Check whether two sized collections hold equal items position by position.

    Args:
        array1: first collection, or ``None``.
        array2: second collection, or ``None``.
        sort: when true, both inputs are sorted first, so ordering is ignored.

    Returns:
        ``True`` if both inputs have the same length and every pair of
        items at the same position compares equal. ``False`` otherwise,
        including when either input is ``None`` (even if both are).
    """
    if array1 is None or array2 is None:
        return False
    if len(array1) != len(array2):
        return False
    if sort:
        array1, array2 = sorted(array1), sorted(array2)
    # Walk both inputs in lockstep rather than indexing by position, so that
    # sized containers without integer indexing work too.
    return not any(left != right for left, right in zip(array1, array2))
|
|
|
|
def load_ema_weights_from_checkpoint(model, checkpoint):
    """Overwrite the model's parameters in place with EMA values from a checkpoint.

    The EMA values are read from ``checkpoint["optimizer_states"][0]["ema"]``
    and paired with ``model.parameters()`` in iteration order; each value's
    ``.data`` is copied into the matching parameter's ``.data``. Pairing
    stops at the shorter of the two sequences. Returns ``None``.
    """
    first_optimizer_state = checkpoint["optimizer_states"][0]
    ema_values = first_optimizer_state["ema"]
    for target, source in zip(model.parameters(), ema_values):
        target.data.copy_(source.data)
|
|
|
|
def rsync_file_from_remote(fname, remote_dir, local_dir, hostname):
    """Fetch one file from ``hostname`` into ``local_dir`` by running rsync.

    The remote path is built by replacing every occurrence of ``local_dir``
    in ``fname`` with ``"{remote_dir}/./"``. The command is run through the
    shell via ``subprocess_run``.
    """
    # NOTE(review): the "/./" marker presumably tells rsync --relative to
    # recreate only the path after it under local_dir -- confirm against the
    # rsync man page.
    remote_anchor = f"{remote_dir}/./"
    remote_fname = fname.replace(local_dir, remote_anchor)
    source = f"{hostname}:{remote_fname}"
    cmd = f"rsync -avzP -m --relative {source} {local_dir}/"
    subprocess_run(cmd, shell=True)
|
|
|
|
| |
# Nesting depth of the Timer contexts that are currently open and reporting;
# used to indent the printed timing lines. Module-level state shared by all
# Timer instances.
timer_indent_level = 0


class Timer:
    """Context manager that prints the wall-clock time spent in its body.

    Nested timers are indented by their nesting depth. Nothing is printed
    when the body raises, but the nesting depth is still restored.

    Args:
        name: label shown in the printed line.
        enabled: when false, the timer does nothing.
        show_rank: prefix the printed line with ``[rank<N>]``.
        rank_zero_only: when true, only a process whose rank is 0 reports.

    The rank is read from the ``LOCAL_RANK`` environment variable at
    construction time and defaults to 0 when the variable is unset.
    """

    def __init__(self, name="", enabled=True, show_rank=False, rank_zero_only=True):
        self.name = name
        self.start_time = None
        self.enabled = enabled
        self.rank = int(os.environ.get("LOCAL_RANK", 0))
        self.show_rank = show_rank
        self.rank_zero_only = rank_zero_only

    def _is_active(self):
        """Return whether this timer should measure and report."""
        return self.enabled and not (self.rank_zero_only and self.rank != 0)

    def __enter__(self):
        global timer_indent_level
        if not self._is_active():
            return self
        self.start_time = time.perf_counter()
        self.current_indent = timer_indent_level
        timer_indent_level += 1
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        global timer_indent_level
        if not self._is_active():
            return False
        # Undo the increment from __enter__ even when the body raised;
        # otherwise every later timer stays over-indented.
        timer_indent_level -= 1
        if exc_type is not None:
            # Never swallow the exception, and report no timing for a body
            # that failed.
            return False
        elapsed_time = time.perf_counter() - self.start_time
        indent = " " * self.current_indent
        rank_str = f"[rank{self.rank}] " if self.show_rank else ""
        print(f"{indent}{rank_str}[{self.name}] time: {elapsed_time:.4f} seconds")
        return False
|
|