| | |
| | |
| |
|
| | |
| | |
| |
|
| | import logging |
| | import math |
| | import os |
| | import random |
| | import re |
| | from datetime import timedelta |
| | from typing import Optional |
| |
|
| | import hydra |
| |
|
| | import numpy as np |
| | import omegaconf |
| | import torch |
| | import torch.distributed as dist |
| | from iopath.common.file_io import g_pathmgr |
| | from omegaconf import OmegaConf |
| |
|
| |
|
def multiply_all(*args):
    """Multiply all the given scalars together and return a native Python number."""
    product = np.prod(np.array(args))
    return product.item()
| |
|
| |
|
def collect_dict_keys(config):
    """This function recursively iterates through a dataset configuration, and collect all the dict_key that are defined"""
    # A node whose _target_ names a collate_fn contributes its dict_key directly;
    # any other node is searched recursively through nested mappings and through
    # lists whose elements are mappings.
    if "_target_" in config and re.match(r".*collate_fn.*", config["_target_"]):
        return [config["dict_key"]]

    collected = []
    for child in config.values():
        if isinstance(child, type(config)):
            collected += collect_dict_keys(child)
        elif isinstance(child, omegaconf.listconfig.ListConfig):
            for element in child:
                if isinstance(element, type(config)):
                    collected += collect_dict_keys(element)
    return collected
| |
|
| |
|
class Phase:
    """String constants naming the phases of a training run."""

    TRAIN = "train"
    VAL = "val"
| |
|
| |
|
def register_omegaconf_resolvers():
    """Install the custom ``${...}`` interpolation resolvers used by the configs."""
    # Dict insertion order preserves the original registration order.
    resolvers = {
        "get_method": hydra.utils.get_method,
        "get_class": hydra.utils.get_class,
        "add": lambda x, y: x + y,
        "times": multiply_all,
        "divide": lambda x, y: x / y,
        "pow": lambda x, y: x**y,
        "subtract": lambda x, y: x - y,
        "range": lambda x: list(range(x)),
        "int": lambda x: int(x),
        "ceil_int": lambda x: int(math.ceil(x)),
        "merge": lambda *x: OmegaConf.merge(*x),
    }
    for resolver_name, resolver_fn in resolvers.items():
        OmegaConf.register_new_resolver(resolver_name, resolver_fn)
| |
|
| |
|
def setup_distributed_backend(backend, timeout_mins):
    """
    Initialize torch.distributed and set the CUDA device.
    Expects environment variables to be set as per
    https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization
    along with the environ variable "LOCAL_RANK" which is used to set the CUDA device.
    """
    # Enable async error handling so NCCL failures surface instead of hanging.
    os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
    logging.info(f"Setting up torch.distributed with a timeout of {timeout_mins} mins")
    timeout = timedelta(minutes=timeout_mins)
    dist.init_process_group(backend=backend, timeout=timeout)
    return dist.get_rank()
| |
|
| |
|
def get_machine_local_and_dist_rank():
    """
    Get the distributed and local rank of the current gpu.

    Returns:
        Tuple (local_rank, distributed_rank) parsed from the LOCAL_RANK
        and RANK environment variables.

    Raises:
        AssertionError: if either environment variable is not set.
    """
    # Read the raw string values first: the original code called int() on the
    # result of os.environ.get(..., None), which raises TypeError on a missing
    # variable *before* the assert below could produce its helpful message.
    local_rank = os.environ.get("LOCAL_RANK", None)
    distributed_rank = os.environ.get("RANK", None)
    assert (
        local_rank is not None and distributed_rank is not None
    ), "Please set the RANK and LOCAL_RANK environment variables."
    return int(local_rank), int(distributed_rank)
| |
|
| |
|
def print_cfg(cfg):
    """
    Log the training configuration as YAML.

    Supports both Hydra DictConfig and AttrDict-style configs.
    """
    logging.info("Training with config:")
    yaml_dump = OmegaConf.to_yaml(cfg)
    logging.info(yaml_dump)
| |
|
| |
|
def set_seeds(seed_value, max_epochs, dist_rank):
    """
    Set the python random, numpy and torch seed for each gpu. Also set the CUDA
    seeds if the CUDA is available. This ensures deterministic nature of the training.
    """
    # Derive a per-rank seed so every worker draws a distinct random stream.
    machine_seed = (seed_value + dist_rank) * max_epochs
    logging.info(f"MACHINE SEED: {machine_seed}")
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(machine_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(machine_seed)
| |
|
| |
|
def makedir(dir_path):
    """
    Create the directory if it does not exist.

    Returns True only when this call actually created the directory; False
    when it already existed or when creation failed.
    """
    created = False
    try:
        if not g_pathmgr.exists(dir_path):
            g_pathmgr.mkdirs(dir_path)
            created = True
    except BaseException:
        # Best-effort: swallow the error and signal failure via the return value.
        logging.info(f"Error creating directory: {dir_path}")
    return created
| |
|
| |
|
def is_dist_avail_and_initialized():
    """Return True iff torch.distributed is both available and initialized."""
    return dist.is_available() and dist.is_initialized()
| |
|
| |
|
def get_amp_type(amp_type: Optional[str] = None):
    """Map an amp-type name ("bfloat16"/"float16") to its torch dtype; None passes through."""
    if amp_type is None:
        return None
    assert amp_type in ["bfloat16", "float16"], "Invalid Amp type."
    dtype_by_name = {"bfloat16": torch.bfloat16, "float16": torch.float16}
    return dtype_by_name[amp_type]
| |
|
| |
|
def log_env_variables():
    """Log every environment variable as one KEY=VALUE line per variable."""
    dump = "".join(
        f"{key}={os.environ[key]}\n" for key in sorted(os.environ.keys())
    )
    logging.info("Logging ENV_VARIABLES")
    logging.info(dump)
| |
|
| |
|
class AverageMeter:
    """Computes and stores the average and current value"""

    def __init__(self, name, device, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.device = device
        self.reset()

    def reset(self):
        # Zero out all accumulated statistics.
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0
        self._allow_updates = True

    def update(self, val, n=1):
        # Record `val` observed `n` times and refresh the running mean.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Format keys straight out of __dict__, e.g. "name: 1.0 (0.5)".
        template = f"{{name}}: {{val{self.fmt}}} ({{avg{self.fmt}}})"
        return template.format(**self.__dict__)
| |
|
| |
|
class MemMeter:
    """Computes and stores the current, avg, and max of peak Mem usage per iteration"""

    def __init__(self, name, device, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.device = device
        self.reset()

    def reset(self):
        # Zero out all tracked memory statistics.
        self.val = 0
        self.avg = 0
        self.peak = 0
        self.sum = 0
        self.count = 0
        self._allow_updates = True

    def update(self, n=1, reset_peak_usage=True):
        # Sample CUDA peak memory (in GB) since the last reset of the peak stats.
        self.val = torch.cuda.max_memory_allocated() // 1e9
        self.sum += self.val * n
        self.count += n
        self.avg = self.sum / self.count
        self.peak = max(self.peak, self.val)
        if reset_peak_usage:
            torch.cuda.reset_peak_memory_stats()

    def __str__(self):
        # Render as "name: current (avg/peak)" using the configured format spec.
        template = (
            f"{{name}}: {{val{self.fmt}}} ({{avg{self.fmt}}}/{{peak{self.fmt}}})"
        )
        return template.format(**self.__dict__)
| |
|
| |
|
def human_readable_time(time_seconds):
    """Format a duration in seconds as 'DDd HHh MMm' (fractional seconds truncated)."""
    total_minutes, _ = divmod(int(time_seconds), 60)
    total_hours, minutes = divmod(total_minutes, 60)
    days, hours = divmod(total_hours, 24)
    return f"{days:02}d {hours:02}h {minutes:02}m"
| |
|
| |
|
class DurationMeter:
    """Tracks a single elapsed-time value (in seconds) under a given name."""

    def __init__(self, name, device, fmt=":f"):
        self.name = name
        self.device = device
        self.fmt = fmt
        self.val = 0

    def reset(self):
        self.val = 0

    def update(self, val):
        # Replace the stored duration outright.
        self.val = val

    def add(self, val):
        # Accumulate onto the stored duration.
        self.val += val

    def __str__(self):
        return f"{self.name}: {human_readable_time(self.val)}"
| |
|
| |
|
class ProgressMeter:
    """Logs batch progress plus the current values of a set of meters."""

    def __init__(self, num_batches, meters, real_meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.real_meters = real_meters
        self.prefix = prefix

    def display(self, batch, enable_print=False):
        # Assemble "[batch/total] | meter | meter | sub/metric: v" and log it.
        parts = [self.prefix + self.batch_fmtstr.format(batch)]
        parts.extend(str(meter) for meter in self.meters)
        for name, meter in self.real_meters.items():
            computed = meter.compute()
            parts.append(
                " | ".join(
                    f"{os.path.join(name, subname)}: {val:.4f}"
                    for subname, val in computed.items()
                )
            )
        line = " | ".join(parts)
        logging.info(line)
        if enable_print:
            print(line)

    def _get_batch_fmtstr(self, num_batches):
        # Pad the running batch index to the width of the total count.
        num_digits = len(str(num_batches // 1))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"
| |
|
| |
|
def get_resume_checkpoint(checkpoint_save_dir):
    """Return the path to checkpoint.pt inside `checkpoint_save_dir`, or None if absent."""
    if not g_pathmgr.isdir(checkpoint_save_dir):
        return None
    candidate = os.path.join(checkpoint_save_dir, "checkpoint.pt")
    return candidate if g_pathmgr.isfile(candidate) else None
| |
|