import gc
import json
import logging
import math
import os
import time
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Mapping, Optional

import numpy as np

import torch
import torch.distributed as dist
import torch.nn as nn
from hydra.utils import instantiate
from iopath.common.file_io import g_pathmgr

from training.optimizer import construct_optimizer

from training.utils.checkpoint_utils import (
    assert_skipped_parameters_are_frozen,
    exclude_params_matching_unix_pattern,
    load_state_dict_into_model,
    with_check_parameter_frozen,
)
from training.utils.data_utils import BatchedVideoDatapoint
from training.utils.distributed import all_reduce_max, barrier, get_rank

from training.utils.logger import Logger, setup_logging

from training.utils.train_utils import (
    AverageMeter,
    collect_dict_keys,
    DurationMeter,
    get_amp_type,
    get_machine_local_and_dist_rank,
    get_resume_checkpoint,
    human_readable_time,
    is_dist_avail_and_initialized,
    log_env_variables,
    makedir,
    MemMeter,
    Phase,
    ProgressMeter,
    set_seeds,
    setup_distributed_backend,
)


CORE_LOSS_KEY = "core_loss"
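
# The loss modules used by the Trainer may return either a single tensor or a dict of
# sub-losses. When a dict is returned, the entry stored under CORE_LOSS_KEY is the value
# that gets backpropagated, while the remaining entries are only logged (see
# `Trainer._log_loss_detailed_and_return_core_loss` and `Trainer._step` below).
# Illustrative (hypothetical) shape of such a dict -- the sub-loss names depend on the
# configured loss and are not fixed by this file:
#   {"core_loss": total_loss, "loss_mask": ..., "loss_dice": ...}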


def unwrap_ddp_if_wrapped(model):
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        return model.module
    return model


@dataclass
class OptimAMPConf:
    enabled: bool = False
    amp_dtype: str = "float16"


@dataclass
class OptimConf:
    optimizer: torch.optim.Optimizer = None
    options: Optional[Dict[str, Any]] = None
    param_group_modifiers: Optional[List] = None
    amp: Optional[Dict[str, Any]] = None
    gradient_clip: Any = None
    gradient_logger: Any = None

    def __post_init__(self):
        if not isinstance(self.amp, OptimAMPConf):
            if self.amp is None:
                self.amp = {}
            assert isinstance(self.amp, Mapping)
            self.amp = OptimAMPConf(**self.amp)
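
# For example, passing amp={"enabled": True, "amp_dtype": "bfloat16"} to OptimConf yields
# self.amp == OptimAMPConf(enabled=True, amp_dtype="bfloat16"), while amp=None falls back
# to the OptimAMPConf defaults, so downstream code can always read `optim_conf.amp.enabled`.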


@dataclass
class DistributedConf:
    backend: Optional[str] = None
    comms_dtype: Optional[str] = None
    find_unused_parameters: bool = False
    timeout_mins: int = 30


@dataclass
class CudaConf:
    cudnn_deterministic: bool = False
    cudnn_benchmark: bool = True
    allow_tf32: bool = False
    # If not None, this value overrides `allow_tf32` for torch.backends.cuda.matmul.
    matmul_allow_tf32: Optional[bool] = None
    # If not None, this value overrides `allow_tf32` for torch.backends.cudnn.
    cudnn_allow_tf32: Optional[bool] = None


@dataclass
class CheckpointConf:
    save_dir: str
    save_freq: int
    save_list: List[int] = field(default_factory=list)
    model_weight_initializer: Any = None
    save_best_meters: List[str] = None
    skip_saving_parameters: List[str] = field(default_factory=list)
    initialize_after_preemption: Optional[bool] = None
    # If not None, a copy of this checkpoint is placed in save_dir and training resumes from it.
    resume_from: Optional[str] = None

    def infer_missing(self):
        if self.initialize_after_preemption is None:
            with_skip_saving = len(self.skip_saving_parameters) > 0
            self.initialize_after_preemption = with_skip_saving
        return self
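
# For example (hypothetical pattern), skip_saving_parameters=["image_encoder.*"] would drop
# the matching weights from saved checkpoints. Such parameters must remain frozen (this is
# verified via assert_skipped_parameters_are_frozen) and be re-created by the
# model_weight_initializer after a preemption, which is why infer_missing() defaults
# initialize_after_preemption to True whenever the skip list is non-empty.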


@dataclass
class LoggingConf:
    log_dir: str
    log_freq: int
    tensorboard_writer: Any
    log_level_primary: str = "INFO"
    log_level_secondary: str = "ERROR"
    log_scalar_frequency: int = 100
    log_visual_frequency: int = 100
    scalar_keys_to_log: Optional[Dict[str, Any]] = None
    log_batch_stats: bool = False


class Trainer:
    """
    Trainer supporting distributed data parallel (DDP) training.
    """

    EPSILON = 1e-8

    def __init__(
        self,
        *,
        data: Dict[str, Any],
        model: Dict[str, Any],
        logging: Dict[str, Any],
        checkpoint: Dict[str, Any],
        max_epochs: int,
        mode: str = "train",
        accelerator: str = "cuda",
        seed_value: int = 123,
        val_epoch_freq: int = 1,
        distributed: Dict[str, bool] = None,
        cuda: Dict[str, bool] = None,
        env_variables: Optional[Dict[str, Any]] = None,
        optim: Optional[Dict[str, Any]] = None,
        optim_overrides: Optional[List[Dict[str, Any]]] = None,
        meters: Optional[Dict[str, Any]] = None,
        loss: Optional[Dict[str, Any]] = None,
    ):
        self._setup_env_variables(env_variables)
        self._setup_timers()

        self.data_conf = data
        self.model_conf = model
        self.logging_conf = LoggingConf(**logging)
        self.checkpoint_conf = CheckpointConf(**checkpoint).infer_missing()
        self.max_epochs = max_epochs
        self.mode = mode
        self.val_epoch_freq = val_epoch_freq
        self.optim_conf = OptimConf(**optim) if optim is not None else None
        self.meters_conf = meters
        self.loss_conf = loss
        distributed = DistributedConf(**distributed or {})
        cuda = CudaConf(**cuda or {})
        self.where = 0.0  # fraction of training completed, in [0, 1]

        self._infer_distributed_backend_if_none(distributed, accelerator)

        self._setup_device(accelerator)

        self._setup_torch_dist_and_backend(cuda, distributed)

        makedir(self.logging_conf.log_dir)
        setup_logging(
            __name__,
            output_dir=self.logging_conf.log_dir,
            rank=self.rank,
            log_level_primary=self.logging_conf.log_level_primary,
            log_level_secondary=self.logging_conf.log_level_secondary,
        )

        set_seeds(seed_value, self.max_epochs, self.distributed_rank)
        log_env_variables()

        assert (
            is_dist_avail_and_initialized()
        ), "Torch distributed needs to be initialized before calling the trainer."

        self._setup_components()
        self._move_to_device()
        self._construct_optimizers()
        self._setup_dataloaders()

        self.time_elapsed_meter = DurationMeter("Time Elapsed", self.device, ":.2f")

        if self.checkpoint_conf.resume_from is not None:
            assert os.path.exists(
                self.checkpoint_conf.resume_from
            ), f"The 'resume_from' checkpoint {self.checkpoint_conf.resume_from} does not exist!"
            dst = os.path.join(self.checkpoint_conf.save_dir, "checkpoint.pt")
            if self.distributed_rank == 0 and not os.path.exists(dst):
                # Copy the `resume_from` checkpoint into save_dir (on rank 0 only), so that
                # load_checkpoint() below picks it up exactly as if this job were resuming.
                makedir(self.checkpoint_conf.save_dir)
                g_pathmgr.copy(self.checkpoint_conf.resume_from, dst)
            barrier()

        self.load_checkpoint()
        self._setup_ddp_distributed_training(distributed, accelerator)
        barrier()

    def _setup_timers(self):
        """
        Initializes counters for elapsed time and eta.
        """
        self.start_time = time.time()
        self.ckpt_time_elapsed = 0
        self.est_epoch_time = dict.fromkeys([Phase.TRAIN, Phase.VAL], 0)

    def _get_meters(self, phase_filters=None):
        if self.meters is None:
            return {}
        meters = {}
        for phase, phase_meters in self.meters.items():
            if phase_filters is not None and phase not in phase_filters:
                continue
            for key, key_meters in phase_meters.items():
                if key_meters is None:
                    continue
                for name, meter in key_meters.items():
                    meters[f"{phase}_{key}/{name}"] = meter
        return meters

    def _infer_distributed_backend_if_none(self, distributed_conf, accelerator):
        if distributed_conf.backend is None:
            distributed_conf.backend = "nccl" if accelerator == "cuda" else "gloo"

    def _setup_env_variables(self, env_variables_conf) -> None:
        if env_variables_conf is not None:
            for variable_name, value in env_variables_conf.items():
                os.environ[variable_name] = value

    def _setup_torch_dist_and_backend(self, cuda_conf, distributed_conf) -> None:
        if torch.cuda.is_available():
            torch.backends.cudnn.deterministic = cuda_conf.cudnn_deterministic
            torch.backends.cudnn.benchmark = cuda_conf.cudnn_benchmark
            torch.backends.cuda.matmul.allow_tf32 = (
                cuda_conf.matmul_allow_tf32
                if cuda_conf.matmul_allow_tf32 is not None
                else cuda_conf.allow_tf32
            )
            torch.backends.cudnn.allow_tf32 = (
                cuda_conf.cudnn_allow_tf32
                if cuda_conf.cudnn_allow_tf32 is not None
                else cuda_conf.allow_tf32
            )

        self.rank = setup_distributed_backend(
            distributed_conf.backend, distributed_conf.timeout_mins
        )

    def _setup_device(self, accelerator):
        self.local_rank, self.distributed_rank = get_machine_local_and_dist_rank()
        if accelerator == "cuda":
            self.device = torch.device("cuda", self.local_rank)
            torch.cuda.set_device(self.local_rank)
        elif accelerator == "cpu":
            self.device = torch.device("cpu")
        else:
            raise ValueError(f"Unsupported accelerator: {accelerator}")

    def _setup_ddp_distributed_training(self, distributed_conf, accelerator):
        assert isinstance(self.model, torch.nn.Module)

        self.model = nn.parallel.DistributedDataParallel(
            self.model,
            device_ids=[self.local_rank] if accelerator == "cuda" else [],
            find_unused_parameters=distributed_conf.find_unused_parameters,
        )
        if distributed_conf.comms_dtype is not None:
            from torch.distributed.algorithms import ddp_comm_hooks

            amp_type = get_amp_type(distributed_conf.comms_dtype)
            if amp_type == torch.bfloat16:
                hook = ddp_comm_hooks.default_hooks.bf16_compress_hook
                logging.info("Enabling bfloat16 grad communication")
            else:
                hook = ddp_comm_hooks.default_hooks.fp16_compress_hook
                logging.info("Enabling fp16 grad communication")
            process_group = None
            self.model.register_comm_hook(process_group, hook)

    def _move_to_device(self):
        logging.info(
            f"Moving components to device {self.device} and local rank {self.local_rank}."
        )

        self.model.to(self.device)

        logging.info(
            f"Done moving components to device {self.device} and local rank {self.local_rank}."
        )

    def save_checkpoint(self, epoch, checkpoint_names=None):
        checkpoint_folder = self.checkpoint_conf.save_dir
        makedir(checkpoint_folder)
        if checkpoint_names is None:
            checkpoint_names = ["checkpoint"]
        if (
            self.checkpoint_conf.save_freq > 0
            and (int(epoch) % self.checkpoint_conf.save_freq == 0)
        ) or int(epoch) in self.checkpoint_conf.save_list:
            checkpoint_names.append(f"checkpoint_{int(epoch)}")

        checkpoint_paths = []
        for ckpt_name in checkpoint_names:
            checkpoint_paths.append(os.path.join(checkpoint_folder, f"{ckpt_name}.pt"))

        state_dict = unwrap_ddp_if_wrapped(self.model).state_dict()
        state_dict = exclude_params_matching_unix_pattern(
            patterns=self.checkpoint_conf.skip_saving_parameters, state_dict=state_dict
        )

        checkpoint = {
            "model": state_dict,
            "optimizer": self.optim.optimizer.state_dict(),
            "epoch": epoch,
            "loss": self.loss.state_dict(),
            "steps": self.steps,
            "time_elapsed": self.time_elapsed_meter.val,
            "best_meter_values": self.best_meter_values,
        }
        if self.optim_conf.amp.enabled:
            checkpoint["scaler"] = self.scaler.state_dict()

        # DDP checkpoints are only saved on rank 0 (all workers hold identical state).
        if self.distributed_rank != 0:
            return

        for checkpoint_path in checkpoint_paths:
            self._save_checkpoint(checkpoint, checkpoint_path)

    def _save_checkpoint(self, checkpoint, checkpoint_path):
        """
        Save a checkpoint while guarding against the job being killed in the middle
        of checkpoint saving (which corrupts the checkpoint file and ruins the
        entire training run, since usually only the last checkpoint is kept per run).

        We first save the new checkpoint to a temp file (with a '.tmp' suffix), and
        then move it to overwrite the old checkpoint_path.
        """
        checkpoint_path_tmp = f"{checkpoint_path}.tmp"
        with g_pathmgr.open(checkpoint_path_tmp, "wb") as f:
            torch.save(checkpoint, f)
        # Once the temp file is fully written, replace the old checkpoint with it.
        if g_pathmgr.exists(checkpoint_path):
            # Remove the previous checkpoint before moving the new one into place.
            g_pathmgr.rm(checkpoint_path)
        success = g_pathmgr.mv(checkpoint_path_tmp, checkpoint_path)
        assert success

    def load_checkpoint(self):
        ckpt_path = get_resume_checkpoint(self.checkpoint_conf.save_dir)
        if ckpt_path is None:
            self._init_model_state()
        else:
            if self.checkpoint_conf.initialize_after_preemption:
                self._call_model_initializer()
            self._load_resuming_checkpoint(ckpt_path)

    def _init_model_state(self):
        # Check that the parameters that won't be saved are indeed frozen. Doing this
        # before the first checkpoint is written catches misconfigurations early,
        # rather than at the end of the first epoch.
        assert_skipped_parameters_are_frozen(
            patterns=self.checkpoint_conf.skip_saving_parameters,
            model=self.model,
        )

        # Also verify that the initializer does not modify the skipped parameters;
        # this check is disabled when `initialize_after_preemption` is True, since
        # the initializer is then expected to (re-)create those parameters.
        allow_init_skip_parameters = self.checkpoint_conf.initialize_after_preemption
        with with_check_parameter_frozen(
            patterns=self.checkpoint_conf.skip_saving_parameters,
            model=self.model,
            disabled=allow_init_skip_parameters,
        ):
            self._call_model_initializer()

    def _call_model_initializer(self):
        model_weight_initializer = instantiate(
            self.checkpoint_conf.model_weight_initializer
        )
        if model_weight_initializer is not None:
            logging.info(
                f"Loading pretrained checkpoint from {self.checkpoint_conf.model_weight_initializer}"
            )
            self.model = model_weight_initializer(model=self.model)

    def _load_resuming_checkpoint(self, ckpt_path: str):
        logging.info(f"Resuming training from {ckpt_path}")

        with g_pathmgr.open(ckpt_path, "rb") as f:
            checkpoint = torch.load(f, map_location="cpu")
        load_state_dict_into_model(
            model=self.model,
            state_dict=checkpoint["model"],
            ignore_missing_keys=self.checkpoint_conf.skip_saving_parameters,
        )

        self.optim.optimizer.load_state_dict(checkpoint["optimizer"])
        self.loss.load_state_dict(checkpoint["loss"], strict=True)
        self.epoch = checkpoint["epoch"]
        self.steps = checkpoint["steps"]
        self.ckpt_time_elapsed = checkpoint.get("time_elapsed")

        if self.optim_conf.amp.enabled and "scaler" in checkpoint:
            self.scaler.load_state_dict(checkpoint["scaler"])

        self.best_meter_values = checkpoint.get("best_meter_values", {})

        if "train_dataset" in checkpoint and self.train_dataset is not None:
            self.train_dataset.load_checkpoint_state(checkpoint["train_dataset"])

    def is_intermediate_val_epoch(self, epoch):
        # An "intermediate" val epoch is any validation other than the final one after
        # training; e.g. with val_epoch_freq=2 and max_epochs=10 this returns True for
        # epochs 0, 2, 4, 6 and 8, while the last epoch's validation is run in run().
        return epoch % self.val_epoch_freq == 0 and epoch < self.max_epochs - 1

    def _step(
        self,
        batch: BatchedVideoDatapoint,
        model: nn.Module,
        phase: str,
    ):
        outputs = model(batch)
        targets = batch.masks
        batch_size = len(batch.img_batch)

        key = batch.dict_key  # key for dataset
        loss = self.loss[key](outputs, targets)
        loss_str = f"Losses/{phase}_{key}_loss"

        loss_log_str = os.path.join("Step_Losses", loss_str)

        # loss may contain multiple sub-components we wish to log separately
        step_losses = {}
        if isinstance(loss, dict):
            step_losses.update(
                {f"Losses/{phase}_{key}_{k}": v for k, v in loss.items()}
            )
            loss = self._log_loss_detailed_and_return_core_loss(
                loss, loss_log_str, self.steps[phase]
            )

        if self.steps[phase] % self.logging_conf.log_scalar_frequency == 0:
            self.logger.log(
                loss_log_str,
                loss,
                self.steps[phase],
            )

        self.steps[phase] += 1

        ret_tuple = {loss_str: loss}, batch_size, step_losses

        if phase in self.meters and key in self.meters[phase]:
            meters_dict = self.meters[phase][key]
            if meters_dict is not None:
                for _, meter in meters_dict.items():
                    meter.update(
                        find_stages=outputs,
                        find_metadatas=batch.metadata,
                    )

        return ret_tuple

    def run(self):
        assert self.mode in ["train", "train_only", "val"]
        if self.mode == "train":
            if self.epoch > 0:
                logging.info(f"Resuming training from epoch: {self.epoch}")
                # A checkpoint is saved right after each train epoch and before its
                # validation, so the previous epoch's val may not have run; re-run it first.
                if self.is_intermediate_val_epoch(self.epoch - 1):
                    logging.info("Running previous val epoch")
                    self.epoch -= 1
                    self.run_val()
                    self.epoch += 1
            self.run_train()
            self.run_val()
        elif self.mode == "val":
            self.run_val()
        elif self.mode == "train_only":
            self.run_train()

    def _setup_dataloaders(self):
        self.train_dataset = None
        self.val_dataset = None

        if self.mode in ["train", "val"]:
            self.val_dataset = instantiate(self.data_conf.get(Phase.VAL, None))

        if self.mode in ["train", "train_only"]:
            self.train_dataset = instantiate(self.data_conf.train)

    def run_train(self):
        while self.epoch < self.max_epochs:
            dataloader = self.train_dataset.get_loader(epoch=int(self.epoch))
            barrier()
            outs = self.train_epoch(dataloader)
            self.logger.log_dict(outs, self.epoch)

            # Log train stats to a JSON-lines file (rank 0 only).
            if self.distributed_rank == 0:
                with g_pathmgr.open(
                    os.path.join(self.logging_conf.log_dir, "train_stats.json"),
                    "a",
                ) as f:
                    f.write(json.dumps(outs) + "\n")

            # Save checkpoint before validating.
            self.save_checkpoint(self.epoch + 1)

            del dataloader
            gc.collect()

            # Run intermediate validation (the final epoch's validation is run in run()
            # after this loop finishes).
            if self.is_intermediate_val_epoch(self.epoch):
                self.run_val()

            if self.distributed_rank == 0:
                self.best_meter_values.update(self._get_trainer_state("train"))
                with g_pathmgr.open(
                    os.path.join(self.logging_conf.log_dir, "best_stats.json"),
                    "a",
                ) as f:
                    f.write(json.dumps(self.best_meter_values) + "\n")

            self.epoch += 1
        # The loop exits with self.epoch == max_epochs; step back so the final
        # run_val() in run() uses the last trained epoch.
        self.epoch -= 1

    def run_val(self):
        if not self.val_dataset:
            return

        dataloader = self.val_dataset.get_loader(epoch=int(self.epoch))
        outs = self.val_epoch(dataloader, phase=Phase.VAL)
        del dataloader
        gc.collect()
        self.logger.log_dict(outs, self.epoch)

        if self.distributed_rank == 0:
            with g_pathmgr.open(
                os.path.join(self.logging_conf.log_dir, "val_stats.json"),
                "a",
            ) as f:
                f.write(json.dumps(outs) + "\n")

    def val_epoch(self, val_loader, phase):
        batch_time = AverageMeter("Batch Time", self.device, ":.2f")
        data_time = AverageMeter("Data Time", self.device, ":.2f")
        mem = MemMeter("Mem (GB)", self.device, ":.2f")

        iters_per_epoch = len(val_loader)

        curr_phases = [phase]
        curr_models = [self.model]

        loss_names = []
        for p in curr_phases:
            for key in self.loss.keys():
                loss_names.append(f"Losses/{p}_{key}_loss")

        loss_mts = OrderedDict(
            [(name, AverageMeter(name, self.device, ":.2e")) for name in loss_names]
        )
        extra_loss_mts = {}

        for model in curr_models:
            model.eval()
            if hasattr(unwrap_ddp_if_wrapped(model), "on_validation_epoch_start"):
                unwrap_ddp_if_wrapped(model).on_validation_epoch_start()

        progress = ProgressMeter(
            iters_per_epoch,
            [batch_time, data_time, mem, self.time_elapsed_meter, *loss_mts.values()],
            self._get_meters(curr_phases),
            prefix="Val Epoch: [{}]".format(self.epoch),
        )

        end = time.time()

        for data_iter, batch in enumerate(val_loader):
            # measure data loading time
            data_time.update(time.time() - end)

            batch = batch.to(self.device, non_blocking=True)

            # compute output
            with torch.no_grad():
                with torch.cuda.amp.autocast(
                    enabled=(self.optim_conf.amp.enabled if self.optim_conf else False),
                    dtype=(
                        get_amp_type(self.optim_conf.amp.amp_dtype)
                        if self.optim_conf
                        else None
                    ),
                ):
                    for phase, model in zip(curr_phases, curr_models):
                        loss_dict, batch_size, extra_losses = self._step(
                            batch,
                            model,
                            phase,
                        )

                        assert len(loss_dict) == 1
                        loss_key, loss = loss_dict.popitem()

                        loss_mts[loss_key].update(loss.item(), batch_size)

                        for k, v in extra_losses.items():
                            if k not in extra_loss_mts:
                                extra_loss_mts[k] = AverageMeter(k, self.device, ":.2e")
                            extra_loss_mts[k].update(v.item(), batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            self.time_elapsed_meter.update(
                time.time() - self.start_time + self.ckpt_time_elapsed
            )

            if torch.cuda.is_available():
                mem.update(reset_peak_usage=True)

            if data_iter % self.logging_conf.log_freq == 0:
                progress.display(data_iter)

            if data_iter % self.logging_conf.log_scalar_frequency == 0:
                # Log progress meters.
                for progress_meter in progress.meters:
                    self.logger.log(
                        os.path.join("Step_Stats", phase, progress_meter.name),
                        progress_meter.val,
                        self.steps[Phase.VAL],
                    )

            if data_iter % 10 == 0:
                dist.barrier()

        self.est_epoch_time[phase] = batch_time.avg * iters_per_epoch
        self._log_timers(phase)
        for model in curr_models:
            if hasattr(unwrap_ddp_if_wrapped(model), "on_validation_epoch_end"):
                unwrap_ddp_if_wrapped(model).on_validation_epoch_end()

        out_dict = self._log_meters_and_save_best_ckpts(curr_phases)

        for k, v in loss_mts.items():
            out_dict[k] = v.avg
        for k, v in extra_loss_mts.items():
            out_dict[k] = v.avg

        for phase in curr_phases:
            out_dict.update(self._get_trainer_state(phase))
        self._reset_meters(curr_phases)
        logging.info(f"Meters: {out_dict}")
        return out_dict

    def _get_trainer_state(self, phase):
        return {
            "Trainer/where": self.where,
            "Trainer/epoch": self.epoch,
            f"Trainer/steps_{phase}": self.steps[phase],
        }

    def train_epoch(self, train_loader):
        # Init stat meters
        batch_time_meter = AverageMeter("Batch Time", self.device, ":.2f")
        data_time_meter = AverageMeter("Data Time", self.device, ":.2f")
        mem_meter = MemMeter("Mem (GB)", self.device, ":.2f")
        data_times = []
        phase = Phase.TRAIN

        iters_per_epoch = len(train_loader)

        loss_names = []
        for batch_key in self.loss.keys():
            loss_names.append(f"Losses/{phase}_{batch_key}_loss")

        loss_mts = OrderedDict(
            [(name, AverageMeter(name, self.device, ":.2e")) for name in loss_names]
        )
        extra_loss_mts = {}

        progress = ProgressMeter(
            iters_per_epoch,
            [
                batch_time_meter,
                data_time_meter,
                mem_meter,
                self.time_elapsed_meter,
                *loss_mts.values(),
            ],
            self._get_meters([phase]),
            prefix="Train Epoch: [{}]".format(self.epoch),
        )

        # Model training loop
        self.model.train()
        end = time.time()

        for data_iter, batch in enumerate(train_loader):
            # measure data loading time
            data_time_meter.update(time.time() - end)
            data_times.append(data_time_meter.val)
            batch = batch.to(
                self.device, non_blocking=True
            )  # move tensors in the batch to the training device

            try:
                self._run_step(batch, phase, loss_mts, extra_loss_mts)

                # compute where we are in training and step the schedulers
                exact_epoch = self.epoch + float(data_iter) / iters_per_epoch
                self.where = float(exact_epoch) / self.max_epochs
                assert self.where <= 1 + self.EPSILON
                if self.where < 1.0:
                    self.optim.step_schedulers(
                        self.where, step=int(exact_epoch * iters_per_epoch)
                    )
                else:
                    logging.warning(
                        f"Skipping scheduler update since the training is at the end, i.e., {self.where} of [0,1]."
                    )

                # Log schedulers
                if data_iter % self.logging_conf.log_scalar_frequency == 0:
                    for j, param_group in enumerate(self.optim.optimizer.param_groups):
                        for option in self.optim.schedulers[j]:
                            optim_prefix = (
                                "" + f"{j}_"
                                if len(self.optim.optimizer.param_groups) > 1
                                else ""
                            )
                            self.logger.log(
                                os.path.join("Optim", f"{optim_prefix}", option),
                                param_group[option],
                                self.steps[phase],
                            )

                # Clip gradients and detect diverging gradients
                if self.gradient_clipper is not None:
                    self.scaler.unscale_(self.optim.optimizer)
                    self.gradient_clipper(model=self.model)

                if self.gradient_logger is not None:
                    self.gradient_logger(
                        self.model, rank=self.distributed_rank, where=self.where
                    )

                # Optimizer step: the scaler skips the update if the gradients
                # are not finite.
                self.scaler.step(self.optim.optimizer)
                self.scaler.update()

                # measure elapsed time
                batch_time_meter.update(time.time() - end)
                end = time.time()

                self.time_elapsed_meter.update(
                    time.time() - self.start_time + self.ckpt_time_elapsed
                )

                mem_meter.update(reset_peak_usage=True)
                if data_iter % self.logging_conf.log_freq == 0:
                    progress.display(data_iter)

                if data_iter % self.logging_conf.log_scalar_frequency == 0:
                    # Log progress meters.
                    for progress_meter in progress.meters:
                        self.logger.log(
                            os.path.join("Step_Stats", phase, progress_meter.name),
                            progress_meter.val,
                            self.steps[phase],
                        )

            # Catch NaN/Inf losses raised by _run_step
            except FloatingPointError as e:
                raise e

        self.est_epoch_time[Phase.TRAIN] = batch_time_meter.avg * iters_per_epoch
        self._log_timers(Phase.TRAIN)
        self._log_sync_data_times(Phase.TRAIN, data_times)

        out_dict = self._log_meters_and_save_best_ckpts([Phase.TRAIN])

        for k, v in loss_mts.items():
            out_dict[k] = v.avg
        for k, v in extra_loss_mts.items():
            out_dict[k] = v.avg
        out_dict.update(self._get_trainer_state(phase))
        logging.info(f"Losses and meters: {out_dict}")
        self._reset_meters([phase])
        return out_dict

    def _log_sync_data_times(self, phase, data_times):
        data_times = all_reduce_max(torch.tensor(data_times)).tolist()
        steps = range(self.steps[phase] - len(data_times), self.steps[phase])
        for step, data_time in zip(steps, data_times):
            if step % self.logging_conf.log_scalar_frequency == 0:
                self.logger.log(
                    os.path.join("Step_Stats", phase, "Data Time Synced"),
                    data_time,
                    step,
                )

    def _run_step(
        self,
        batch: BatchedVideoDatapoint,
        phase: str,
        loss_mts: Dict[str, AverageMeter],
        extra_loss_mts: Dict[str, AverageMeter],
        raise_on_error: bool = True,
    ):
        """
        Run the forward / backward pass for a single batch.
        """

        # Set grads to None rather than zero; with optimizers such as Adam, parameters
        # with a zero gradient can still be updated (momentum / weight decay), which
        # set_to_none avoids.
        self.optim.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(
            enabled=self.optim_conf.amp.enabled,
            dtype=get_amp_type(self.optim_conf.amp.amp_dtype),
        ):
            loss_dict, batch_size, extra_losses = self._step(
                batch,
                self.model,
                phase,
            )

        assert len(loss_dict) == 1
        loss_key, loss = loss_dict.popitem()

        if not math.isfinite(loss.item()):
            error_msg = f"Loss is {loss.item()}, attempting to stop training"
            logging.error(error_msg)
            if raise_on_error:
                raise FloatingPointError(error_msg)
            else:
                return

        self.scaler.scale(loss).backward()
        loss_mts[loss_key].update(loss.item(), batch_size)
        for extra_loss_key, extra_loss in extra_losses.items():
            if extra_loss_key not in extra_loss_mts:
                extra_loss_mts[extra_loss_key] = AverageMeter(
                    extra_loss_key, self.device, ":.2e"
                )
            extra_loss_mts[extra_loss_key].update(extra_loss.item(), batch_size)

    def _log_meters_and_save_best_ckpts(self, phases: List[str]):
        logging.info("Synchronizing meters")
        out_dict = {}
        checkpoint_save_keys = []
        for key, meter in self._get_meters(phases).items():
            meter_output = meter.compute_synced()
            is_better_check = getattr(meter, "is_better", None)

            for meter_subkey, meter_value in meter_output.items():
                out_dict[os.path.join("Meters_train", key, meter_subkey)] = meter_value

                if is_better_check is None:
                    continue

                tracked_meter_key = os.path.join(key, meter_subkey)
                if tracked_meter_key not in self.best_meter_values or is_better_check(
                    meter_value,
                    self.best_meter_values[tracked_meter_key],
                ):
                    self.best_meter_values[tracked_meter_key] = meter_value

                    if (
                        self.checkpoint_conf.save_best_meters is not None
                        and key in self.checkpoint_conf.save_best_meters
                    ):
                        checkpoint_save_keys.append(tracked_meter_key.replace("/", "_"))

        if len(checkpoint_save_keys) > 0:
            self.save_checkpoint(self.epoch + 1, checkpoint_save_keys)

        return out_dict

    def _log_timers(self, phase):
        time_remaining = 0
        epochs_remaining = self.max_epochs - self.epoch - 1
        val_epochs_remaining = sum(
            n % self.val_epoch_freq == 0 for n in range(self.epoch, self.max_epochs)
        )

        # Add the guaranteed final val run in case (max_epochs - 1) is not a multiple
        # of val_epoch_freq.
        if (self.max_epochs - 1) % self.val_epoch_freq != 0:
            val_epochs_remaining += 1

        # Remove the currently running val epoch from the estimate.
        if phase == Phase.VAL:
            val_epochs_remaining -= 1

        time_remaining += (
            epochs_remaining * self.est_epoch_time[Phase.TRAIN]
            + val_epochs_remaining * self.est_epoch_time[Phase.VAL]
        )

        self.logger.log(
            os.path.join("Step_Stats", phase, self.time_elapsed_meter.name),
            self.time_elapsed_meter.val,
            self.steps[phase],
        )

        logging.info(f"Estimated time remaining: {human_readable_time(time_remaining)}")

    def _reset_meters(self, phases: str) -> None:
        for meter in self._get_meters(phases).values():
            meter.reset()

    def _check_val_key_match(self, val_keys, phase):
        if val_keys is not None:
            # Check for duplicate keys across val datasets
            assert len(val_keys) == len(
                set(val_keys)
            ), f"Duplicate keys in val datasets, keys: {val_keys}"

            # Check that the val dataset keys match the meter keys
            if self.meters_conf is not None and phase in self.meters_conf:
                assert set(val_keys) == set(self.meters_conf[phase].keys()), (
                    f"Keys in val datasets do not match the keys in meters."
                    f"\nMissing in meters: {set(val_keys) - set(self.meters_conf[phase].keys())}"
                    f"\nMissing in val datasets: {set(self.meters_conf[phase].keys()) - set(val_keys)}"
                )

            if self.loss_conf is not None:
                loss_keys = set(self.loss_conf.keys()) - set(["all"])
                assert all([k in loss_keys for k in val_keys]), (
                    f"Keys in val datasets do not match the keys in losses."
                    f"\nMissing in losses: {set(val_keys) - loss_keys}"
                    f"\nMissing in val datasets: {loss_keys - set(val_keys)}"
                )

    def _setup_components(self):
        # Get the keys for all the val datasets, if any
        val_phase = Phase.VAL
        val_keys = None
        if self.data_conf.get(val_phase, None) is not None:
            val_keys = collect_dict_keys(self.data_conf[val_phase])
        # Additional sanity checks on the config for val datasets
        self._check_val_key_match(val_keys, phase=val_phase)

        logging.info("Setting up components: Model, loss, optim, meters etc.")
        self.epoch = 0
        self.steps = {Phase.TRAIN: 0, Phase.VAL: 0}

        self.logger = Logger(self.logging_conf)

        self.model = instantiate(self.model_conf, _convert_="all")
        print_model_summary(self.model)

        self.loss = None
        if self.loss_conf:
            self.loss = {
                key: el
                for (key, el) in instantiate(self.loss_conf, _convert_="all").items()
            }
            self.loss = nn.ModuleDict(self.loss)

        self.meters = {}
        self.best_meter_values = {}
        if self.meters_conf:
            self.meters = instantiate(self.meters_conf, _convert_="all")

        self.scaler = torch.amp.GradScaler(
            self.device,
            enabled=self.optim_conf.amp.enabled if self.optim_conf else False,
        )

        self.gradient_clipper = (
            instantiate(self.optim_conf.gradient_clip) if self.optim_conf else None
        )
        self.gradient_logger = (
            instantiate(self.optim_conf.gradient_logger) if self.optim_conf else None
        )

        logging.info("Finished setting up components: Model, loss, optim, meters etc.")

    def _construct_optimizers(self):
        self.optim = construct_optimizer(
            self.model,
            self.optim_conf.optimizer,
            self.optim_conf.options,
            self.optim_conf.param_group_modifiers,
        )

    def _log_loss_detailed_and_return_core_loss(self, loss, loss_str, step):
        core_loss = loss.pop(CORE_LOSS_KEY)
        if step % self.logging_conf.log_scalar_frequency == 0:
            for k in loss:
                log_str = os.path.join(loss_str, k)
                self.logger.log(log_str, loss[k], step)
        return core_loss


def print_model_summary(model: torch.nn.Module, log_dir: str = ""):
    """
    Prints the model and the number of parameters in the model.
    # Multiple packages provide this info in a nice table format
    # However, they need us to provide an `input` (as they also write down the output sizes)
    # Our models are complex, and a single input is restrictive.
    # https://github.com/sksq96/pytorch-summary
    # https://github.com/nmhkahn/torchsummaryX
    """
    if get_rank() != 0:
        return
    param_kwargs = {}
    trainable_parameters = sum(
        p.numel() for p in model.parameters(**param_kwargs) if p.requires_grad
    )
    total_parameters = sum(p.numel() for p in model.parameters(**param_kwargs))
    non_trainable_parameters = total_parameters - trainable_parameters
    logging.info("==" * 10)
    logging.info(f"Summary for model {type(model)}")
    logging.info(f"Model is {model}")
    logging.info(f"\tTotal parameters {get_human_readable_count(total_parameters)}")
    logging.info(
        f"\tTrainable parameters {get_human_readable_count(trainable_parameters)}"
    )
    logging.info(
        f"\tNon-Trainable parameters {get_human_readable_count(non_trainable_parameters)}"
    )
    logging.info("==" * 10)

    if log_dir:
        output_fpath = os.path.join(log_dir, "model.txt")
        with g_pathmgr.open(output_fpath, "w") as f:
            print(model, file=f)


PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"]


def get_human_readable_count(number: int) -> str:
    """
    Abbreviates an integer number with K, M, B, T for thousands, millions,
    billions and trillions, respectively.
    Examples:
        >>> get_human_readable_count(123)
        '123 '
        >>> get_human_readable_count(1234)  # (one thousand)
        '1.2 K'
        >>> get_human_readable_count(2e6)  # (two million)
        '2.0 M'
        >>> get_human_readable_count(3e9)  # (three billion)
        '3.0 B'
        >>> get_human_readable_count(4e14)  # (four hundred trillion)
        '400 T'
        >>> get_human_readable_count(5e15)  # (more than trillion)
        '5,000 T'
    Args:
        number: a positive integer number
    Return:
        A string formatted according to the pattern described above.
    """
    assert number >= 0
    labels = PARAMETER_NUM_UNITS
    num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1)
    num_groups = int(np.ceil(num_digits / 3))
    num_groups = min(num_groups, len(labels))
    shift = -3 * (num_groups - 1)
    number = number * (10**shift)
    index = num_groups - 1
    if index < 1 or number >= 100:
        return f"{int(number):,d} {labels[index]}"
    else:
        return f"{number:,.1f} {labels[index]}"