|
|
|
|
|
|
|
|
|
|
|
|
| import logging
|
| import math
|
| import os
|
| import random
|
| import re
|
| from datetime import timedelta
|
| from typing import Optional
|
|
|
| import hydra
|
|
|
| import numpy as np
|
| import omegaconf
|
| import torch
|
| import torch.distributed as dist
|
| from iopath.common.file_io import g_pathmgr
|
| from omegaconf import OmegaConf
|
|
|
|
|
def multiply_all(*args):
    """Return the product of every positional argument as a native Python scalar.

    The arguments are packed into a NumPy array, reduced with ``prod`` and
    unwrapped with ``.item()`` (so integer inputs yield an ``int``, float
    inputs a ``float``).
    """
    factors = np.array(args)
    return factors.prod().item()
|
|
|
|
|
def collect_dict_keys(config):
    """Recursively collect the ``dict_key`` of every collate-fn node in a dataset config.

    A node matches when it has a ``_target_`` entry matching the regex
    ``.*collate_fn.*``. The node's ``dict_key`` is then recorded, and the node
    is not searched any further. Otherwise, every value with the same mapping
    type as ``config`` is searched recursively. Items inside OmegaConf
    ``ListConfig`` values are searched the same way, when they have that type.

    Returns:
        list of the collected ``dict_key`` values, in traversal order.
    """
    if "_target_" in config and re.match(r".*collate_fn.*", config["_target_"]):
        return [config["dict_key"]]

    node_type = type(config)
    found = []
    for value in config.values():
        if isinstance(value, node_type):
            children = [value]
        elif isinstance(value, omegaconf.listconfig.ListConfig):
            children = [item for item in value if isinstance(item, node_type)]
        else:
            continue
        for child in children:
            found.extend(collect_dict_keys(child))
    return found
|
|
|
|
|
class Phase:
    """String identifiers for the two phases: training and validation."""

    TRAIN: str = "train"
    VAL: str = "val"
|
|
|
|
|
def register_omegaconf_resolvers():
    """Register this project's custom OmegaConf interpolation resolvers.

    Each entry below becomes available in configs as ``${name:...}``. They
    cover arithmetic (add/times/divide/pow/subtract), conversions
    (int/ceil_int), ``range`` (as a list), config merging, and Hydra's
    method/class lookup helpers.
    NOTE(review): OmegaConf presumably rejects re-registering an existing name,
    so call this once per process -- confirm against the OmegaConf version in use.
    """
    resolvers = {
        "get_method": hydra.utils.get_method,
        "get_class": hydra.utils.get_class,
        "add": lambda x, y: x + y,
        "times": multiply_all,
        "divide": lambda x, y: x / y,
        "pow": lambda x, y: x**y,
        "subtract": lambda x, y: x - y,
        "range": lambda x: list(range(x)),
        "int": lambda x: int(x),
        "ceil_int": lambda x: int(math.ceil(x)),
        "merge": lambda *x: OmegaConf.merge(*x),
    }
    # Dicts preserve insertion order, so registration order matches the table.
    for resolver_name, resolver in resolvers.items():
        OmegaConf.register_new_resolver(resolver_name, resolver)
|
|
|
|
|
def setup_distributed_backend(backend, timeout_mins):
    """
    Initialize the default torch.distributed process group and return this process's rank.

    Initialization uses torch.distributed's environment-variable method, so the
    variables described at
    https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization
    must already be set.

    Args:
        backend: backend name passed straight to ``dist.init_process_group``.
        timeout_mins: process-group timeout, in minutes.

    Returns:
        The global rank reported by ``dist.get_rank()``.

    Note: this function does not select a CUDA device. The caller is
    responsible for that (e.g. from ``LOCAL_RANK``).
    """
    # Set before the process group is created, so that init presumably picks it up.
    os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
    logging.info(f"Setting up torch.distributed with a timeout of {timeout_mins} mins")
    group_timeout = timedelta(minutes=timeout_mins)
    dist.init_process_group(backend=backend, timeout=group_timeout)
    return dist.get_rank()
|
|
|
|
|
def get_machine_local_and_dist_rank():
    """
    Get the machine-local rank and the global distributed rank of the current process.

    Both values are read from the ``LOCAL_RANK`` and ``RANK`` environment variables.

    Returns:
        tuple ``(local_rank, distributed_rank)`` of ints.

    Raises:
        AssertionError: if either environment variable is unset.
        ValueError: if either variable is set but is not an integer string.
    """
    local_rank = os.environ.get("LOCAL_RANK")
    distributed_rank = os.environ.get("RANK")
    # Check presence *before* converting. Calling int(None) first would raise an
    # opaque TypeError, and this assertion's message would never be shown.
    assert (
        local_rank is not None and distributed_rank is not None
    ), "Please set the RANK and LOCAL_RANK environment variables."
    return int(local_rank), int(distributed_rank)
|
|
|
|
|
def print_cfg(cfg):
    """
    Log ``cfg`` as YAML under a "Training with config:" header.

    The YAML is produced by ``OmegaConf.to_yaml``, so ``cfg`` must be a type
    that function accepts (e.g. a Hydra/OmegaConf ``DictConfig``).
    """
    header = "Training with config:"
    logging.info(header)
    logging.info(OmegaConf.to_yaml(cfg))
|
|
|
|
|
def set_seeds(seed_value, max_epochs, dist_rank):
    """
    Seed Python's ``random``, NumPy and torch with a per-rank seed.

    The seed actually used is ``(seed_value + dist_rank) * max_epochs``, so each
    distributed rank receives a different seed. The seed is logged, and all
    CUDA devices are seeded too when CUDA is available.
    """
    machine_seed = (seed_value + dist_rank) * max_epochs
    logging.info(f"MACHINE SEED: {machine_seed}")
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(machine_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(machine_seed)
|
|
|
|
|
def makedir(dir_path):
    """
    Create ``dir_path`` via ``g_pathmgr.mkdirs`` unless ``g_pathmgr.exists`` reports it already.

    This is best-effort. Ordinary errors are logged (with traceback) instead of
    raised. Interpreter-level exceptions such as ``KeyboardInterrupt`` and
    ``SystemExit`` still propagate.

    Returns:
        True if the path already existed or was created without error, False otherwise.
    """
    try:
        if not g_pathmgr.exists(dir_path):
            g_pathmgr.mkdirs(dir_path)
    except Exception:
        # Catch Exception, not BaseException, so Ctrl-C / sys.exit are not swallowed.
        # logging.exception logs at ERROR level and includes the traceback.
        logging.exception(f"Error creating directory: {dir_path}")
        return False
    return True
|
|
|
|
|
def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is available and a process group is initialized.

    ``dist.is_initialized`` is consulted only when the package is available.
    """
    return dist.is_available() and dist.is_initialized()
|
|
|
|
|
def get_amp_type(amp_type: Optional[str] = None):
    """Map an AMP dtype name to the corresponding torch dtype.

    Args:
        amp_type: ``"bfloat16"``, ``"float16"``, or None.

    Returns:
        ``torch.bfloat16`` or ``torch.float16``, or None when ``amp_type`` is None.

    Raises:
        AssertionError: for any other name (when assertions are enabled).
    """
    if amp_type is None:
        return None
    assert amp_type in ["bfloat16", "float16"], "Invalid Amp type."
    return torch.bfloat16 if amp_type == "bfloat16" else torch.float16
|
|
|
|
|
def log_env_variables():
    """Log every environment variable of the current process, sorted by name.

    Emits a "Logging ENV_VARIABLES" header, then one message containing one
    newline-terminated ``KEY=VALUE`` line per variable.
    NOTE(review): values are logged verbatim, which may expose credentials or
    tokens stored in the environment.
    """
    # One join instead of repeated ``+=`` string concatenation (quadratic in general).
    st = "".join(f"{k}={v}\n" for k, v in sorted(os.environ.items()))
    logging.info("Logging ENV_VARIABLES")
    logging.info(st)
|
|
|
|
|
class AverageMeter:
    """Track the latest value of a metric and its count-weighted running mean."""

    def __init__(self, name, device, fmt=":f"):
        # ``fmt`` is a format spec (including the leading colon) used by __str__.
        self.name, self.fmt, self.device = name, fmt, device
        self.reset()

    def reset(self):
        """Zero every running statistic and re-enable updates."""
        self.val = self.avg = self.sum = self.count = 0
        self._allow_updates = True

    def update(self, val, n=1):
        """Record ``val`` as the latest value, giving it weight ``n`` in the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = f"{{name}}: {{val{self.fmt}}} ({{avg{self.fmt}}})"
        return template.format(**vars(self))
|
|
|
|
|
class MemMeter:
    """Track per-iteration peak CUDA memory: the latest reading, its running mean, and the max seen."""

    def __init__(self, name, device, fmt=":f"):
        # ``fmt`` is a format spec (including the leading colon) used by __str__.
        self.name, self.fmt, self.device = name, fmt, device
        self.reset()

    def reset(self):
        """Zero every statistic and re-enable updates."""
        self.val = self.avg = self.peak = self.sum = self.count = 0
        self._allow_updates = True

    def update(self, n=1, reset_peak_usage=True):
        """Sample ``torch.cuda.max_memory_allocated()`` and fold it into the statistics.

        The reading (bytes) is floor-divided by 1e9, i.e. truncated to whole
        (decimal) gigabytes. With ``reset_peak_usage`` set, CUDA's peak counters
        are reset afterwards, so the next reading covers only the next iteration.
        """
        self.val = torch.cuda.max_memory_allocated() // 1e9
        self.sum += self.val * n
        self.count += n
        self.avg = self.sum / self.count
        self.peak = max(self.peak, self.val)
        if reset_peak_usage:
            torch.cuda.reset_peak_memory_stats()

    def __str__(self):
        spec = self.fmt
        template = f"{{name}}: {{val{spec}}} ({{avg{spec}}}/{{peak{spec}}})"
        return template.format(**vars(self))
|
|
|
|
|
def human_readable_time(time_seconds):
    """Format a duration in seconds as ``"DDd HHh MMm"``.

    Fractional seconds are truncated by ``int``, and leftover seconds are dropped.
    """
    total_minutes = int(time_seconds) // 60
    total_hours, minutes = divmod(total_minutes, 60)
    days, hours = divmod(total_hours, 24)
    return f"{days:02}d {hours:02}h {minutes:02}m"
|
|
|
|
|
class DurationMeter:
    """Hold a single accumulated duration and render it as days/hours/minutes.

    ``__str__`` passes the value to ``human_readable_time``, which interprets it as seconds.
    """

    def __init__(self, name, device, fmt=":f"):
        self.name = name
        self.device = device
        # Stored for interface parity with the other meters; not used by __str__.
        self.fmt = fmt
        self.val = 0

    def reset(self):
        """Set the stored duration back to zero."""
        self.val = 0

    def update(self, val):
        """Overwrite the stored duration with ``val``."""
        self.val = val

    def add(self, val):
        """Accumulate ``val`` onto the stored duration."""
        self.val += val

    def __str__(self):
        readable = human_readable_time(self.val)
        return f"{self.name}: {readable}"
|
|
|
|
|
class ProgressMeter:
    """Log a one-line progress summary for a batch, built from a set of meters.

    ``meters`` are rendered with ``str()``. Each value in ``real_meters`` must
    expose ``compute()`` returning a mapping of sub-metric name -> number. Each
    of those is printed as ``"<name>/<subname>: <value:.4f>"``.
    """

    def __init__(self, num_batches, meters, real_meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.real_meters = real_meters
        self.prefix = prefix

    def display(self, batch, enable_print=False):
        """Log the summary for ``batch``, and also print it when ``enable_print`` is set."""
        parts = [self.prefix + self.batch_fmtstr.format(batch)]
        parts.extend(str(m) for m in self.meters)
        for name, meter in self.real_meters.items():
            metrics = meter.compute()
            parts.append(
                " | ".join(
                    f"{os.path.join(name, sub)}: {value:.4f}"
                    for sub, value in metrics.items()
                )
            )
        line = " | ".join(parts)
        logging.info(line)
        if enable_print:
            print(line)

    def _get_batch_fmtstr(self, num_batches):
        # Pad the batch index to the width of the total, e.g. "[{:3d}/120]".
        width = len(str(num_batches // 1))
        spec = f"{{:{width}d}}"
        return f"[{spec}/{spec.format(num_batches)}]"
|
|
|
|
|
def get_resume_checkpoint(checkpoint_save_dir):
    """Return the path of ``checkpoint.pt`` inside ``checkpoint_save_dir`` if it exists.

    Returns None when the directory is missing or contains no such file.
    Path checks go through ``g_pathmgr``.
    """
    if g_pathmgr.isdir(checkpoint_save_dir):
        candidate = os.path.join(checkpoint_save_dir, "checkpoint.pt")
        if g_pathmgr.isfile(candidate):
            return candidate
    return None
|
|
|