_vggt / training /trainer.py
CgvKodai's picture
Upload folder using huggingface_hub
66003a2 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
# --- Environment Variable Setup for Performance and Debugging ---
# Helps with memory fragmentation in PyTorch's memory allocator.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
# Specifies the threading layer for MKL, can prevent hangs in some environments.
os.environ["MKL_THREADING_LAYER"] = "GNU"
# Provides full Hydra stack traces on error for easier debugging.
os.environ["HYDRA_FULL_ERROR"] = "1"
# Enables asynchronous error handling for NCCL, which can prevent hangs.
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
import contextlib
import gc
import json
import logging
import math
import time
from datetime import timedelta
from typing import Any, Dict, List, Mapping, Optional, Sequence
import torch
import torch.distributed as dist
import torch.nn as nn
import torchvision
from hydra.utils import instantiate
from iopath.common.file_io import g_pathmgr
from train_utils.checkpoint import DDPCheckpointSaver
from train_utils.distributed import get_machine_local_and_dist_rank
from train_utils.freeze import freeze_modules
from train_utils.general import *
from train_utils.logging import setup_logging
from train_utils.normalization import normalize_camera_extrinsics_and_points_batch
from train_utils.optimizer import construct_optimizers
class Trainer:
"""
A generic trainer for DDP training. This should naturally support multi-node training.
This class orchestrates the entire training and validation process, including:
- Setting up the distributed environment (DDP).
- Initializing the model, optimizers, loss functions, and data loaders.
- Handling checkpointing for resuming training.
- Executing the main training and validation loops.
- Logging metrics and visualizations to TensorBoard.
"""
EPSILON = 1e-8
def __init__(
self,
*,
data: Dict[str, Any],
model: Dict[str, Any],
logging: Dict[str, Any],
checkpoint: Dict[str, Any],
max_epochs: int,
mode: str = "train",
device: str = "cuda",
seed_value: int = 123,
val_epoch_freq: int = 1,
distributed: Dict[str, bool] = None,
cuda: Dict[str, bool] = None,
limit_train_batches: Optional[int] = None,
limit_val_batches: Optional[int] = None,
optim: Optional[Dict[str, Any]] = None,
loss: Optional[Dict[str, Any]] = None,
env_variables: Optional[Dict[str, Any]] = None,
accum_steps: int = 1,
**kwargs,
):
"""
Initializes the Trainer.
Args:
data: Hydra config for datasets and dataloaders.
model: Hydra config for the model.
logging: Hydra config for logging (TensorBoard, log frequencies).
checkpoint: Hydra config for checkpointing.
max_epochs: Total number of epochs to train.
mode: "train" for training and validation, "val" for validation only.
device: "cuda" or "cpu".
seed_value: A random seed for reproducibility.
val_epoch_freq: Frequency (in epochs) to run validation.
distributed: Hydra config for DDP settings.
cuda: Hydra config for CUDA-specific settings (e.g., cuDNN).
limit_train_batches: Limit the number of training batches per epoch (for debugging).
limit_val_batches: Limit the number of validation batches per epoch (for debugging).
optim: Hydra config for optimizers and schedulers.
loss: Hydra config for the loss function.
env_variables: Dictionary of environment variables to set.
accum_steps: Number of steps to accumulate gradients before an optimizer step.
"""
self._setup_env_variables(env_variables)
self._setup_timers()
# Store Hydra configurations
self.data_conf = data
self.model_conf = model
self.loss_conf = loss
self.logging_conf = logging
self.checkpoint_conf = checkpoint
self.optim_conf = optim
# Store hyperparameters
self.accum_steps = accum_steps
self.max_epochs = max_epochs
self.mode = mode
self.val_epoch_freq = val_epoch_freq
self.limit_train_batches = limit_train_batches
self.limit_val_batches = limit_val_batches
self.seed_value = seed_value
# 'where' tracks training progress from 0.0 to 1.0 for schedulers
self.where = 0.0
self._setup_device(device)
self._setup_torch_dist_and_backend(cuda, distributed)
# Setup logging directory and configure logger
safe_makedirs(self.logging_conf.log_dir)
setup_logging(
__name__,
output_dir=self.logging_conf.log_dir,
rank=self.rank,
log_level_primary=self.logging_conf.log_level_primary,
log_level_secondary=self.logging_conf.log_level_secondary,
all_ranks=self.logging_conf.all_ranks,
)
set_seeds(seed_value, self.max_epochs, self.distributed_rank)
assert is_dist_avail_and_initialized(), "Torch distributed needs to be initialized before calling the trainer."
# Instantiate components (model, loss, etc.)
self._setup_components()
self._setup_dataloaders()
# Move model to the correct device
self.model.to(self.device)
self.time_elapsed_meter = DurationMeter("Time Elapsed", self.device, ":.4f")
# Construct optimizers (after moving model to device)
if self.mode != "val":
self.optims = construct_optimizers(self.model, self.optim_conf)
# Load checkpoint if available or specified
if self.checkpoint_conf.resume_checkpoint_path is not None:
self._load_resuming_checkpoint(self.checkpoint_conf.resume_checkpoint_path)
else:
ckpt_path = get_resume_checkpoint(self.checkpoint_conf.save_dir)
if ckpt_path is not None:
self._load_resuming_checkpoint(ckpt_path)
# Wrap the model with DDP
self._setup_ddp_distributed_training(distributed, device)
# Barrier to ensure all processes are synchronized before starting
dist.barrier()
def _setup_timers(self):
"""Initializes timers for tracking total elapsed time."""
self.start_time = time.time()
self.ckpt_time_elapsed = 0
def _setup_env_variables(self, env_variables_conf: Optional[Dict[str, Any]]) -> None:
"""Sets environment variables from the configuration."""
if env_variables_conf:
for variable_name, value in env_variables_conf.items():
os.environ[variable_name] = value
logging.info(f"Environment:\n{json.dumps(dict(os.environ), sort_keys=True, indent=2)}")
def _setup_torch_dist_and_backend(self, cuda_conf: Dict, distributed_conf: Dict) -> None:
"""Initializes the distributed process group and configures PyTorch backends."""
if torch.cuda.is_available():
# Configure CUDA backend settings for performance
torch.backends.cudnn.deterministic = cuda_conf.cudnn_deterministic
torch.backends.cudnn.benchmark = cuda_conf.cudnn_benchmark
torch.backends.cuda.matmul.allow_tf32 = cuda_conf.allow_tf32
torch.backends.cudnn.allow_tf32 = cuda_conf.allow_tf32
# Initialize the DDP process group
dist.init_process_group(
backend=distributed_conf.backend,
timeout=timedelta(minutes=distributed_conf.timeout_mins)
)
self.rank = dist.get_rank()
def _load_resuming_checkpoint(self, ckpt_path: str):
"""Loads a checkpoint from the given path to resume training."""
logging.info(f"Resuming training from {ckpt_path} (rank {self.rank})")
with g_pathmgr.open(ckpt_path, "rb") as f:
checkpoint = torch.load(f, map_location="cpu")
# Load model state
model_state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint
missing, unexpected = self.model.load_state_dict(
model_state_dict, strict=self.checkpoint_conf.strict
)
if self.rank == 0:
logging.info(f"Model state loaded. Missing keys: {missing or 'None'}. Unexpected keys: {unexpected or 'None'}.")
# Load optimizer state if available and in training mode
if "optimizer" in checkpoint:
logging.info(f"Loading optimizer state dict (rank {self.rank})")
self.optims.optimizer.load_state_dict(checkpoint["optimizer"])
# Load training progress
if "epoch" in checkpoint:
self.epoch = checkpoint["epoch"]
self.steps = checkpoint["steps"] if "steps" in checkpoint else {"train": 0, "val": 0}
self.ckpt_time_elapsed = checkpoint.get("time_elapsed", 0)
# Load AMP scaler state if available
if self.optim_conf.amp.enabled and "scaler" in checkpoint:
self.scaler.load_state_dict(checkpoint["scaler"])
def _setup_device(self, device: str):
"""Sets up the device for training (CPU or CUDA)."""
self.local_rank, self.distributed_rank = get_machine_local_and_dist_rank()
if device == "cuda":
self.device = torch.device("cuda", self.local_rank)
torch.cuda.set_device(self.local_rank)
elif device == "cpu":
self.device = torch.device("cpu")
else:
raise ValueError(f"Unsupported device: {device}")
def _setup_components(self):
"""Initializes all core training components using Hydra configs."""
logging.info("Setting up components: Model, Loss, Logger, etc.")
self.epoch = 0
self.steps = {'train': 0, 'val': 0}
# Instantiate components from configs
self.tb_writer = instantiate(self.logging_conf.tensorboard_writer, _recursive_=False)
self.model = instantiate(self.model_conf, _recursive_=False)
self.loss = instantiate(self.loss_conf, _recursive_=False)
self.gradient_clipper = instantiate(self.optim_conf.gradient_clip)
self.scaler = torch.cuda.amp.GradScaler(enabled=self.optim_conf.amp.enabled)
# Freeze specified model parameters if any
if getattr(self.optim_conf, "frozen_module_names", None):
logging.info(
f"[Start] Freezing modules: {self.optim_conf.frozen_module_names} on rank {self.distributed_rank}"
)
self.model = freeze_modules(
self.model,
patterns=self.optim_conf.frozen_module_names,
)
logging.info(
f"[Done] Freezing modules: {self.optim_conf.frozen_module_names} on rank {self.distributed_rank}"
)
# Log model summary on rank 0
if self.rank == 0:
model_summary_path = os.path.join(self.logging_conf.log_dir, "model.txt")
model_summary(self.model, log_file=model_summary_path)
logging.info(f"Model summary saved to {model_summary_path}")
logging.info("Successfully initialized training components.")
def _setup_dataloaders(self):
"""Initializes train and validation datasets and dataloaders."""
self.train_dataset = None
self.val_dataset = None
if self.mode in ["train", "val"]:
self.val_dataset = instantiate(
self.data_conf.get('val', None), _recursive_=False
)
if self.val_dataset is not None:
self.val_dataset.seed = self.seed_value
if self.mode in ["train"]:
self.train_dataset = instantiate(self.data_conf.train, _recursive_=False)
self.train_dataset.seed = self.seed_value
def _setup_ddp_distributed_training(self, distributed_conf: Dict, device: str):
"""Wraps the model with DistributedDataParallel (DDP)."""
assert isinstance(self.model, torch.nn.Module)
ddp_options = dict(
find_unused_parameters=distributed_conf.find_unused_parameters,
gradient_as_bucket_view=distributed_conf.gradient_as_bucket_view,
bucket_cap_mb=distributed_conf.bucket_cap_mb,
broadcast_buffers=distributed_conf.broadcast_buffers,
)
self.model = nn.parallel.DistributedDataParallel(
self.model,
device_ids=[self.local_rank] if device == "cuda" else [],
**ddp_options,
)
def save_checkpoint(self, epoch: int, checkpoint_names: Optional[List[str]] = None):
"""
Saves a training checkpoint.
Args:
epoch: The current epoch number.
checkpoint_names: A list of names for the checkpoint file (e.g., "checkpoint_latest").
If None, saves "checkpoint" and "checkpoint_{epoch}" on frequency.
"""
checkpoint_folder = self.checkpoint_conf.save_dir
safe_makedirs(checkpoint_folder)
if checkpoint_names is None:
checkpoint_names = ["checkpoint"]
if (
self.checkpoint_conf.save_freq > 0
and int(epoch) % self.checkpoint_conf.save_freq == 0
and (int(epoch) > 0 or self.checkpoint_conf.save_freq == 1)
):
checkpoint_names.append(f"checkpoint_{int(epoch)}")
checkpoint_content = {
"prev_epoch": epoch,
"steps": self.steps,
"time_elapsed": self.time_elapsed_meter.val,
"optimizer": [optim.optimizer.state_dict() for optim in self.optims],
}
if len(self.optims) == 1:
checkpoint_content["optimizer"] = checkpoint_content["optimizer"][0]
if self.optim_conf.amp.enabled:
checkpoint_content["scaler"] = self.scaler.state_dict()
# Save the checkpoint for DDP only
saver = DDPCheckpointSaver(
checkpoint_folder,
checkpoint_names=checkpoint_names,
rank=self.distributed_rank,
epoch=epoch,
)
if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
model = self.model.module
saver.save_checkpoint(
model=model,
ema_models = None,
skip_saving_parameters=[],
**checkpoint_content,
)
def _get_scalar_log_keys(self, phase: str) -> List[str]:
"""Retrieves keys for scalar values to be logged for a given phase."""
if self.logging_conf.scalar_keys_to_log:
return self.logging_conf.scalar_keys_to_log[phase].keys_to_log
return []
def run(self):
"""Main entry point to start the training or validation process."""
assert self.mode in ["train", "val"], f"Invalid mode: {self.mode}"
if self.mode == "train":
self.run_train()
# Optionally run a final validation after all training is done
self.run_val()
elif self.mode == "val":
self.run_val()
else:
raise ValueError(f"Invalid mode: {self.mode}")
def run_train(self):
"""Runs the main training loop over all epochs."""
while self.epoch < self.max_epochs:
set_seeds(self.seed_value + self.epoch * 100, self.max_epochs, self.distributed_rank)
dataloader = self.train_dataset.get_loader(epoch=int(self.epoch + self.distributed_rank))
self.train_epoch(dataloader)
# Save checkpoint after each training epoch
self.save_checkpoint(self.epoch)
# Clean up memory
del dataloader
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
# Run validation at the specified frequency
# Skips validation after the last training epoch, as it can be run separately.
if self.epoch % self.val_epoch_freq == 0 and self.epoch < self.max_epochs - 1:
self.run_val()
self.epoch += 1
self.epoch -= 1
def run_val(self):
"""Runs a full validation epoch if a validation dataset is available."""
if not self.val_dataset:
logging.info("No validation dataset configured. Skipping validation.")
return
dataloader = self.val_dataset.get_loader(epoch=int(self.epoch + self.distributed_rank))
self.val_epoch(dataloader)
del dataloader
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
@torch.no_grad()
def val_epoch(self, val_loader):
batch_time = AverageMeter("Batch Time", self.device, ":.4f")
data_time = AverageMeter("Data Time", self.device, ":.4f")
mem = AverageMeter("Mem (GB)", self.device, ":.4f")
data_times = []
phase = 'val'
loss_names = self._get_scalar_log_keys(phase)
loss_names = [f"Loss/{phase}_{name}" for name in loss_names]
loss_meters = {
name: AverageMeter(name, self.device, ":.4f") for name in loss_names
}
progress = ProgressMeter(
num_batches=len(val_loader),
meters=[
batch_time,
data_time,
mem,
self.time_elapsed_meter,
*loss_meters.values(),
],
real_meters={},
prefix="Val Epoch: [{}]".format(self.epoch),
)
self.model.eval()
end = time.time()
iters_per_epoch = len(val_loader)
limit_val_batches = (
iters_per_epoch
if self.limit_val_batches is None
else self.limit_val_batches
)
for data_iter, batch in enumerate(val_loader):
if data_iter > limit_val_batches:
break
# measure data loading time
data_time.update(time.time() - end)
data_times.append(data_time.val)
with torch.cuda.amp.autocast(enabled=False):
batch = self._process_batch(batch)
batch = copy_data_to_device(batch, self.device, non_blocking=True)
amp_type = self.optim_conf.amp.amp_dtype
assert amp_type in ["bfloat16", "float16"], f"Invalid Amp type: {amp_type}"
if amp_type == "bfloat16":
amp_type = torch.bfloat16
else:
amp_type = torch.float16
# compute output
with torch.no_grad():
with torch.cuda.amp.autocast(
enabled=self.optim_conf.amp.enabled,
dtype=amp_type,
):
val_loss_dict = self._step(
batch, self.model, phase, loss_meters
)
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
self.time_elapsed_meter.update(
time.time() - self.start_time + self.ckpt_time_elapsed
)
if torch.cuda.is_available():
mem.update(torch.cuda.max_memory_allocated() // 1e9)
if data_iter % self.logging_conf.log_freq == 0:
progress.display(data_iter)
return True
def train_epoch(self, train_loader):
batch_time = AverageMeter("Batch Time", self.device, ":.4f")
data_time = AverageMeter("Data Time", self.device, ":.4f")
mem = AverageMeter("Mem (GB)", self.device, ":.4f")
data_times = []
phase = 'train'
loss_names = self._get_scalar_log_keys(phase)
loss_names = [f"Loss/{phase}_{name}" for name in loss_names]
loss_meters = {
name: AverageMeter(name, self.device, ":.4f") for name in loss_names
}
for config in self.gradient_clipper.configs:
param_names = ",".join(config['module_names'])
loss_meters[f"Grad/{param_names}"] = AverageMeter(f"Grad/{param_names}", self.device, ":.4f")
progress = ProgressMeter(
num_batches=len(train_loader),
meters=[
batch_time,
data_time,
mem,
self.time_elapsed_meter,
*loss_meters.values(),
],
real_meters={},
prefix="Train Epoch: [{}]".format(self.epoch),
)
self.model.train()
end = time.time()
iters_per_epoch = len(train_loader)
limit_train_batches = (
iters_per_epoch
if self.limit_train_batches is None
else self.limit_train_batches
)
if self.gradient_clipper is not None:
# setup gradient clipping at the beginning of training
self.gradient_clipper.setup_clipping(self.model)
for data_iter, batch in enumerate(train_loader):
if data_iter > limit_train_batches:
break
# measure data loading time
data_time.update(time.time() - end)
data_times.append(data_time.val)
with torch.cuda.amp.autocast(enabled=False):
batch = self._process_batch(batch)
batch = copy_data_to_device(batch, self.device, non_blocking=True)
accum_steps = self.accum_steps
if accum_steps==1:
chunked_batches = [batch]
else:
chunked_batches = chunk_batch_for_accum_steps(batch, accum_steps)
self._run_steps_on_batch_chunks(
chunked_batches, phase, loss_meters
)
# compute gradient and do SGD step
assert data_iter <= limit_train_batches # allow for off by one errors
exact_epoch = self.epoch + float(data_iter) / limit_train_batches
self.where = float(exact_epoch) / self.max_epochs
assert self.where <= 1 + self.EPSILON
if self.where < 1.0:
for optim in self.optims:
optim.step_schedulers(self.where)
else:
logging.warning(
f"Skipping scheduler update since the training is at the end, i.e, {self.where} of [0,1]."
)
# Log schedulers
if self.steps[phase] % self.logging_conf.log_freq == 0:
for i, optim in enumerate(self.optims):
for j, param_group in enumerate(optim.optimizer.param_groups):
for option in optim.schedulers[j]:
optim_prefix = (
f"{i}_"
if len(self.optims) > 1
else (
"" + f"{j}_"
if len(optim.optimizer.param_groups) > 1
else ""
)
)
self.tb_writer.log(
os.path.join("Optim", f"{optim_prefix}", option),
param_group[option],
self.steps[phase],
)
self.tb_writer.log(
os.path.join("Optim", "where"),
self.where,
self.steps[phase],
)
# Clipping gradients and detecting diverging gradients
if self.gradient_clipper is not None:
for optim in self.optims:
self.scaler.unscale_(optim.optimizer)
grad_norm_dict = self.gradient_clipper(model=self.model)
for key, grad_norm in grad_norm_dict.items():
loss_meters[f"Grad/{key}"].update(grad_norm)
# Optimizer step
for optim in self.optims:
self.scaler.step(optim.optimizer)
self.scaler.update()
# Measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
self.time_elapsed_meter.update(
time.time() - self.start_time + self.ckpt_time_elapsed
)
mem.update(torch.cuda.max_memory_allocated() // 1e9)
if data_iter % self.logging_conf.log_freq == 0:
progress.display(data_iter)
return True
def _run_steps_on_batch_chunks(
self,
chunked_batches: List[Any],
phase: str,
loss_meters: Dict[str, AverageMeter],
):
"""
Run the forward / backward as many times as there are chunks in the batch,
accumulating the gradients on each backward
"""
for optim in self.optims:
optim.zero_grad(set_to_none=True)
accum_steps = len(chunked_batches)
amp_type = self.optim_conf.amp.amp_dtype
assert amp_type in ["bfloat16", "float16"], f"Invalid Amp type: {amp_type}"
if amp_type == "bfloat16":
amp_type = torch.bfloat16
else:
amp_type = torch.float16
for i, chunked_batch in enumerate(chunked_batches):
ddp_context = (
self.model.no_sync()
if i < accum_steps - 1
else contextlib.nullcontext()
)
with ddp_context:
with torch.cuda.amp.autocast(
enabled=self.optim_conf.amp.enabled,
dtype=amp_type,
):
loss_dict = self._step(
chunked_batch, self.model, phase, loss_meters
)
loss = loss_dict["objective"]
loss_key = f"Loss/{phase}_loss_objective"
batch_size = chunked_batch["images"].shape[0]
if not math.isfinite(loss.item()):
error_msg = f"Loss is {loss.item()}, attempting to stop training"
logging.error(error_msg)
return
loss /= accum_steps
self.scaler.scale(loss).backward()
loss_meters[loss_key].update(loss.item(), batch_size)
def _apply_batch_repetition(self, batch: Mapping) -> Mapping:
"""
Applies a data augmentation by concatenating the original batch with a
flipped version of itself.
"""
tensor_keys = [
"images", "depths", "extrinsics", "intrinsics",
"cam_points", "world_points", "point_masks",
]
string_keys = ["seq_name"]
for key in tensor_keys:
if key in batch:
original_tensor = batch[key]
batch[key] = torch.concatenate([original_tensor,
torch.flip(original_tensor, dims=[1])],
dim=0)
for key in string_keys:
if key in batch:
batch[key] = batch[key] * 2
return batch
def _process_batch(self, batch: Mapping):
if self.data_conf.train.common_config.repeat_batch:
batch = self._apply_batch_repetition(batch)
# Normalize camera extrinsics and points. The function returns new tensors.
normalized_extrinsics, normalized_cam_points, normalized_world_points, normalized_depths = \
normalize_camera_extrinsics_and_points_batch(
extrinsics=batch["extrinsics"],
cam_points=batch["cam_points"],
world_points=batch["world_points"],
depths=batch["depths"],
point_masks=batch["point_masks"],
)
# Replace the original values in the batch with the normalized ones.
batch["extrinsics"] = normalized_extrinsics
batch["cam_points"] = normalized_cam_points
batch["world_points"] = normalized_world_points
batch["depths"] = normalized_depths
return batch
def _step(self, batch, model: nn.Module, phase: str, loss_meters: dict):
"""
Performs a single forward pass, computes loss, and logs results.
Returns:
A dictionary containing the computed losses.
"""
# Forward pass
y_hat = model(images=batch["images"])
# Loss computation
loss_dict = self.loss(y_hat, batch)
# Combine all data for logging
log_data = {**y_hat, **loss_dict, **batch}
self._update_and_log_scalars(log_data, phase, self.steps[phase], loss_meters)
self._log_tb_visuals(log_data, phase, self.steps[phase])
self.steps[phase] += 1
return loss_dict
def _update_and_log_scalars(self, data: Mapping, phase: str, step: int, loss_meters: dict):
"""Updates average meters and logs scalar values to TensorBoard."""
keys_to_log = self._get_scalar_log_keys(phase)
batch_size = data['extrinsics'].shape[0]
for key in keys_to_log:
if key in data:
value = data[key].item() if torch.is_tensor(data[key]) else data[key]
loss_meters[f"Loss/{phase}_{key}"].update(value, batch_size)
if step % self.logging_conf.log_freq == 0 and self.rank == 0:
self.tb_writer.log(f"Values/{phase}/{key}", value, step)
def _log_tb_visuals(self, batch: Mapping, phase: str, step: int) -> None:
"""Logs image or video visualizations to TensorBoard."""
if not (
self.logging_conf.log_visuals
and (phase in self.logging_conf.log_visual_frequency)
and self.logging_conf.log_visual_frequency[phase] > 0
and (step % self.logging_conf.log_visual_frequency[phase] == 0)
and (self.logging_conf.visuals_keys_to_log is not None)
):
return
if phase in self.logging_conf.visuals_keys_to_log:
keys_to_log = self.logging_conf.visuals_keys_to_log[phase][
"keys_to_log"
]
assert (
len(keys_to_log) > 0
), "Need to include some visual keys to log"
modality = self.logging_conf.visuals_keys_to_log[phase][
"modality"
]
assert modality in [
"image",
"video",
], "Currently only support video or image logging"
name = f"Visuals/{phase}"
visuals_to_log = torchvision.utils.make_grid(
[
torchvision.utils.make_grid(
batch[key][0], # Ensure batch[key][0] is tensor and has at least 3 dimensions
nrow=self.logging_conf.visuals_per_batch_to_log,
)
for key in keys_to_log if key in batch and batch[key][0].dim() >= 3
],
nrow=1,
).clamp(-1, 1)
visuals_to_log = visuals_to_log.cpu()
if visuals_to_log.dtype == torch.bfloat16:
visuals_to_log = visuals_to_log.to(torch.float16)
visuals_to_log = visuals_to_log.numpy()
self.tb_writer.log_visuals(
name, visuals_to_log, step, self.logging_conf.video_logging_fps
)
def chunk_batch_for_accum_steps(batch: Mapping, accum_steps: int) -> List[Mapping]:
"""Splits a batch into smaller chunks for gradient accumulation."""
if accum_steps == 1:
return [batch]
return [get_chunk_from_data(batch, i, accum_steps) for i in range(accum_steps)]
def is_sequence_of_primitives(data: Any) -> bool:
"""Checks if data is a sequence of primitive types (str, int, float, bool)."""
return (
isinstance(data, Sequence)
and not isinstance(data, str)
and len(data) > 0
and isinstance(data[0], (str, int, float, bool))
)
def get_chunk_from_data(data: Any, chunk_id: int, num_chunks: int) -> Any:
"""
Recursively splits tensors and sequences within a data structure into chunks.
Args:
data: The data structure to split (e.g., a dictionary of tensors).
chunk_id: The index of the chunk to retrieve.
num_chunks: The total number of chunks to split the data into.
Returns:
A chunk of the original data structure.
"""
if isinstance(data, torch.Tensor) or is_sequence_of_primitives(data):
# either a tensor or a list of primitive objects
# assert len(data) % num_chunks == 0
start = (len(data) // num_chunks) * chunk_id
end = (len(data) // num_chunks) * (chunk_id + 1)
return data[start:end]
elif isinstance(data, Mapping):
return {
key: get_chunk_from_data(value, chunk_id, num_chunks)
for key, value in data.items()
}
elif isinstance(data, str):
# NOTE: this is a hack to support string keys in the batch
return data
elif isinstance(data, Sequence):
return [get_chunk_from_data(value, chunk_id, num_chunks) for value in data]
else:
return data