# PRIMA-demo / prima/models/prima.py
# HF Space deploy
# Deploy snapshot (LFS for demo images per .gitattributes)
# c3a4f1d
"""
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
Official implementation of the paper:
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
Licensed under a modified MIT license
"""
import torch
import pickle
import pytorch_lightning as pl
from typing import Any, Dict
from yacs.config import CfgNode
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchvision.utils import make_grid
from ..utils.geometry import perspective_projection, aa_to_rotmat
from ..utils.pylogger import get_pylogger
from .backbones import create_backbone
from .heads import build_smal_head
from ..utils import MeshRenderer
from ..utils import renderer
from prima.models.smal_wrapper import SMAL
from .discriminator import Discriminator
from .bioclip_embedding import BioClipEmbedding
import sys
from transformers import AutoModel, AutoFeatureExtractor
import einops
import open_clip
from .losses import Keypoint3DLoss, Keypoint2DLoss, ParameterLoss, ShapePriorLoss, PosePriorLoss, SupConLoss
# Module-level logger, obtained from the project's get_pylogger helper.
log = get_pylogger(__name__)
class PRIMA(pl.LightningModule):
    """SMAL-based animal mesh recovery model (ViT backbone + SMAL regression head).

    Uses manual optimization (``automatic_optimization = False``) so that an
    optional discriminator, enabled when ``LOSS_WEIGHTS.ADVERSARIAL > 0``, can be
    stepped with its own optimizer after the main model step.
    """

    def __init__(self, cfg: CfgNode, init_renderer: bool = True):
        """
        Setup PRIMA model
        Args:
            cfg (CfgNode): Config file as a yacs CfgNode
            init_renderer (bool): If True, build a MeshRenderer used for Tensorboard
                visualizations; if False, ``self.mesh_renderer`` is set to None.
        """
        super().__init__()

        # Save hyperparameters
        self.save_hyperparameters(logger=False, ignore=['init_renderer'])

        self.cfg = cfg

        # Create backbone feature extractor. Only the 'vith' backbone is built here;
        # forward_step rejects any other type.
        if cfg.MODEL.BACKBONE.TYPE == 'vith':
            # The ViT is always created; at inference time its weights come from the
            # full model checkpoint rather than PRETRAINED_WEIGHTS.
            self.backbone = create_backbone(cfg)
            if cfg.MODEL.BACKBONE.get('PRETRAINED_WEIGHTS', None):
                log.info(f'Loading backbone weights from {cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS}')
                state_dict = torch.load(cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS,
                                        map_location='cpu', weights_only=True)['state_dict']
                # Checkpoint keys carry a 'backbone.' prefix; strip it to match self.backbone.
                state_dict = {k.replace('backbone.', ''): v for k, v in state_dict.items()}
                missing_keys, unexpected_keys = self.backbone.load_state_dict(state_dict, strict=False)
                # strict=False tolerates mismatches silently; report them so a wrong
                # checkpoint does not go unnoticed.
                if missing_keys:
                    log.info(f'Backbone weights: {len(missing_keys)} missing keys, e.g. {missing_keys[:5]}')
                if unexpected_keys:
                    log.info(f'Backbone weights: {len(unexpected_keys)} unexpected keys, e.g. {unexpected_keys[:5]}')

        # Optionally freeze the patch embedding and the first 2/3 of the ViT blocks.
        if cfg.MODEL.BACKBONE.get('FREEZE', False) and cfg.MODEL.BACKBONE.TYPE == 'vith':
            log.info('Freezing first 2/3 blocks of vit backbone')
            if hasattr(self.backbone, 'patch_embed'):
                for p in self.backbone.patch_embed.parameters():
                    p.requires_grad = False
            if hasattr(self.backbone, 'blocks'):
                total_blocks = len(self.backbone.blocks)
                freeze_blocks = int(total_blocks * 2 / 3)
                log.info(f'Freezing {freeze_blocks} out of {total_blocks} blocks')
                for i in range(freeze_blocks):
                    for p in self.backbone.blocks[i].parameters():
                        p.requires_grad = False

        # Create SMAL head (predicts SMAL params + perspective camera)
        self.smal_head = build_smal_head(cfg)

        # Instantiate SMAL model. NOTE: pickle.load executes arbitrary code; the
        # model file at cfg.SMAL.MODEL_PATH must come from a trusted source.
        with open(cfg.SMAL.MODEL_PATH, 'rb') as f:
            smal_cfg = pickle.load(f, encoding="latin1")
        self.smal = SMAL(**smal_cfg)

        # Optional BioClip model providing an extra species-feature token.
        if cfg.MODEL.get('USE_BIOCLIP_EMBEDDING', False):
            bioclip_config = cfg.MODEL.get('BIOCLIP_EMBEDDING', {})
            embed_dim = bioclip_config.get('EMBED_DIM', 1280)
            self.bioclip_embedding = BioClipEmbedding(cfg, embed_dim=embed_dim)
            # The species model is frozen; only its projection is optimized.
            for param in self.bioclip_embedding.species_model.parameters():
                param.requires_grad = False
        else:
            self.bioclip_embedding = None

        # Create discriminator
        self.discriminator = Discriminator()

        # Define loss functions. The intermediate keypoint losses only exist when their
        # weight is positive; compute_loss checks for them before use.
        self.keypoint_3d_loss = Keypoint3DLoss(loss_type='l1')
        self.keypoint_2d_loss = Keypoint2DLoss(loss_type='l1')
        if self.cfg.LOSS_WEIGHTS.get('INTERMEDIATE_KP2D', 0) > 0:
            self.intermediate_kp2d_loss = Keypoint2DLoss(loss_type='l1')
        if self.cfg.LOSS_WEIGHTS.get('INTERMEDIATE_KP3D', 0) > 0:
            self.intermediate_kp3d_loss = Keypoint3DLoss(loss_type='l1')
        self.smal_parameter_loss = ParameterLoss()
        self.shape_prior_loss = ShapePriorLoss(path_prior=cfg.SMAL.SHAPE_PRIOR_PATH)
        self.pose_prior_loss = PosePriorLoss(path_prior=cfg.SMAL.POSE_PRIOR_PATH)
        self.supcon_loss = SupConLoss()

        self.register_buffer('initialized', torch.tensor(False))

        # Setup renderer for visualization
        if init_renderer:
            self.mesh_renderer = MeshRenderer(self.cfg, faces=self.smal.faces.numpy())
        else:
            self.mesh_renderer = None

        # Disable automatic optimization since we use adversarial training
        self.automatic_optimization = False

    def get_parameters(self):
        """Return the parameters driven by the main (non-discriminator) optimizer.

        Used by the shared-learning-rate optimizer path and for gradient clipping.
        Includes frozen parameters; callers filter on ``requires_grad`` if needed.
        """
        all_params = list(self.smal_head.parameters())
        if self.cfg.MODEL.BACKBONE.TYPE in ['vith', 'dinov2', 'dinov3']:
            all_params += list(self.backbone.parameters())
        if getattr(self, 'keypoint_projection', None) is not None:
            all_params += list(self.keypoint_projection.parameters())
        if getattr(self, 'bioclip_embedding', None) is not None:
            # Only the projection is trained; the BioClip species model is frozen.
            all_params += list(self.bioclip_embedding.projection.parameters())
        return all_params

    def configure_optimizers(self):
        """
        Setup model and discriminator Optimizers
        Returns:
            Tuple[torch.optim.Optimizer, ...]: ``(model_optimizer,)`` when the adversarial
            loss is disabled, otherwise ``(model_optimizer, discriminator_optimizer)``.
        """
        base_lr = self.cfg.TRAIN.LR

        if self.cfg.MODEL.BACKBONE.TYPE == 'vith':
            # vith: backbone trains at 1/10 of the base learning rate.
            backbone_params = []
            if hasattr(self, 'backbone'):
                backbone_params = [p for p in self.backbone.parameters() if p.requires_grad]

            other_params = list(self.smal_head.parameters())
            if getattr(self, 'keypoint_projection', None) is not None:
                other_params += list(self.keypoint_projection.parameters())
            if getattr(self, 'bioclip_embedding', None) is not None:
                other_params += list(self.bioclip_embedding.projection.parameters())
            other_params = [p for p in other_params if p.requires_grad]

            param_groups = [
                {'params': backbone_params, 'lr': base_lr / 10.0},  # Backbone: 1/10 lr
                {'params': other_params, 'lr': base_lr},  # Other modules: normal lr
            ]
            log.info('Using separate LR for vith backbone')
            log.info(f'Backbone parameters: {len(backbone_params)}, lr={base_lr / 10.0}')
            log.info(f'Other parameters: {len(other_params)}, lr={base_lr}')
        else:
            # Use same learning rate for all parameters
            all_params = [p for p in self.get_parameters() if p.requires_grad]
            param_groups = [{'params': all_params, 'lr': base_lr}]
            log.info(f'Using same LR for all parameters: {len(all_params)}, lr={base_lr}')

        optimizer = torch.optim.AdamW(params=param_groups,
                                      weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)

        if self.cfg.LOSS_WEIGHTS.get("ADVERSARIAL", 0) > 0:
            optimizer_disc = torch.optim.AdamW(params=self.discriminator.parameters(),
                                               lr=base_lr,
                                               weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)
            return optimizer, optimizer_disc
        # A one-element tuple; Lightning's self.optimizers() unwraps a single optimizer.
        return (optimizer,)

    def forward_step(self, batch: Dict, train: bool = False) -> Dict:
        """
        Run a forward step of the network
        Args:
            batch (Dict): Dictionary containing batch data; reads 'img' and 'focal_length'
            train (bool): Flag indicating whether it is training or validation mode (currently unused)
        Returns:
            Dict: Dictionary containing the regression output
        Raises:
            ValueError: if the configured backbone type is not 'vith' (the only one built in __init__).
        """
        # Use RGB image as input
        x = batch['img']  # [B, 3, H, W]
        batch_size = x.shape[0]

        backbone_type = self.cfg.MODEL.BACKBONE.TYPE
        if backbone_type != 'vith':
            raise ValueError(f"Unsupported backbone type '{backbone_type}': only 'vith' is supported")

        # Crop 32 px from each side of the width axis before the ViT (e.g. a 256x256
        # input becomes 256x192 — presumably the ViT's training aspect; TODO confirm).
        # The second return value (presumably a class token) is unused here.
        conditioning_feats, _cls_token = self.backbone(x[:, :, :, 32:-32])
        if conditioning_feats.ndim == 4:
            # Flatten spatial dimensions into sequence dimension: [B, D, Hp, Wp] -> [B, Hp*Wp, D]
            B, D, Hp, Wp = conditioning_feats.shape
            conditioning_feats = conditioning_feats.permute(0, 2, 3, 1).reshape(B, Hp * Wp, D)

        # add bioclip embedding if enabled (computed on the uncropped image)
        if self.bioclip_embedding is not None:
            species_feature = self.bioclip_embedding(batch['img'])  # [B, embed_dim]
            if conditioning_feats.ndim == 3:
                # Append the species feature as one extra token:
                # (B, num_tokens, C) + (B, 1, embed_dim) -> (B, num_tokens + 1, C).
                # This requires C == embed_dim.
                species_token = species_feature.unsqueeze(1)
                conditioning_feats = torch.cat([conditioning_feats, species_token], dim=1)
            else:
                # 2D features (B, C): concatenate along the feature dimension instead.
                conditioning_feats = torch.cat([conditioning_feats, species_feature], dim=-1)

        # Predict SMAL parameters and camera
        pred_smal_params, pred_cam, extra_outputs = self.smal_head(conditioning_feats)

        # Store useful regression outputs to the output dict
        output = {}
        if 'shape_feat' in extra_outputs:
            output['shape_feat'] = extra_outputs['shape_feat']
        if 'init_betas' in extra_outputs:
            output['init_betas'] = extra_outputs['init_betas'].reshape(batch_size, -1)

        output['pred_cam'] = pred_cam  # [B, 3]
        # Clone before the in-place reshapes below so losses see the head's raw tensors.
        output['pred_smal_params'] = {k: v.clone() for k, v in pred_smal_params.items()}

        # Compute camera translation: (tx, ty) from the camera head, depth from scale.
        focal_length = batch['focal_length']
        pred_cam_t = torch.stack([
            pred_cam[:, 1],
            pred_cam[:, 2],
            2 * focal_length[:, 0] / (self.cfg.MODEL.IMAGE_SIZE * pred_cam[:, 0] + 1e-9)
        ], dim=-1)  # [B, 3]
        output['pred_cam_t'] = pred_cam_t  # [B, 3]
        output['focal_length'] = focal_length  # [B, 2]

        # Compute model vertices, joints and the projected joints
        pred_smal_params['global_orient'] = pred_smal_params['global_orient'].reshape(batch_size, -1, 3, 3)
        pred_smal_params['pose'] = pred_smal_params['pose'].reshape(batch_size, -1, 3, 3)
        pred_smal_params['betas'] = pred_smal_params['betas'].reshape(batch_size, -1)
        # NOTE(review): any extra head outputs (e.g. 'keypoints_2d'/'keypoints_3d') are
        # forwarded to SMAL too; presumably SMAL accepts/ignores them — verify.
        smal_output = self.smal(**pred_smal_params, pose2rot=False)
        pred_keypoints_3d = smal_output.joints
        pred_vertices = smal_output.vertices
        output['pred_keypoints_3d'] = pred_keypoints_3d.reshape(batch_size, -1, 3)
        output['pred_vertices'] = pred_vertices.reshape(batch_size, -1, 3)

        # Project 3D keypoints to 2D with a focal length divided by IMAGE_SIZE
        # (presumably yielding normalized image coordinates).
        pred_keypoints_2d = perspective_projection(
            pred_keypoints_3d,
            translation=pred_cam_t,
            focal_length=focal_length / self.cfg.MODEL.IMAGE_SIZE
        )  # [B, num_joints, 2]
        output['pred_keypoints_2d'] = pred_keypoints_2d

        # get intermediate keypoint predictions if available
        if pred_smal_params.get('keypoints_3d') is not None:
            output['inter_keypoints_3d'] = pred_smal_params['keypoints_3d'].reshape(batch_size, -1, 3)
        if pred_smal_params.get('keypoints_2d') is not None:
            output['inter_keypoints_2d'] = pred_smal_params['keypoints_2d'].reshape(batch_size, -1, 2)
        return output

    def compute_loss(self, batch: Dict, output: Dict, train: bool = True) -> torch.Tensor:
        """
        Compute losses given the input batch and the regression output.

        Also stores the detached per-term losses in ``output['losses']``.

        Args:
            batch (Dict): Dictionary containing batch data
            output (Dict): Dictionary containing the regression output (from forward_step)
            train (bool): Flag indicating whether it is training or validation mode (currently unused)
        Returns:
            torch.Tensor : Total loss for current batch
        """
        pred_smal_params = output['pred_smal_params']
        pred_keypoints_2d = output['pred_keypoints_2d']
        pred_keypoints_3d = output['pred_keypoints_3d']

        batch_size = pred_smal_params['pose'].shape[0]

        # Get annotations
        gt_keypoints_2d = batch['keypoints_2d']
        gt_keypoints_3d = batch['keypoints_3d']
        gt_smal_params = batch['smal_params']
        has_smal_params = batch['has_smal_params']
        is_axis_angle = batch['smal_params_is_axis_angle']

        # Keypoint losses on the final SMAL prediction
        loss_keypoints_2d = self.keypoint_2d_loss(pred_keypoints_2d, gt_keypoints_2d)
        loss_keypoints_3d = self.keypoint_3d_loss(pred_keypoints_3d, gt_keypoints_3d, pelvis_id=0)

        # Intermediate keypoint losses. __init__ only creates these loss modules when the
        # matching LOSS_WEIGHTS entry is > 0, so a head that emits intermediate keypoints
        # while the weight is 0 must skip them instead of hitting a missing attribute.
        inter_kp2d_loss_fn = getattr(self, 'intermediate_kp2d_loss', None)
        inter_kp3d_loss_fn = getattr(self, 'intermediate_kp3d_loss', None)
        use_inter_kp2d = 'inter_keypoints_2d' in output and inter_kp2d_loss_fn is not None
        use_inter_kp3d = 'inter_keypoints_3d' in output and inter_kp3d_loss_fn is not None
        if use_inter_kp2d:
            loss_intermediate_kp2d = inter_kp2d_loss_fn(output['inter_keypoints_2d'], gt_keypoints_2d)
        if use_inter_kp3d:
            loss_intermediate_kp3d = inter_kp3d_loss_fn(output['inter_keypoints_3d'], gt_keypoints_3d,
                                                        pelvis_id=0)

        # Compute loss on SMAL parameters
        loss_smal_params = {}
        for k, pred in pred_smal_params.items():
            # Keypoint predictions share this dict but are handled above.
            if k in ('keypoints_2d', 'keypoints_3d'):
                continue
            gt = gt_smal_params[k].view(batch_size, -1)
            if is_axis_angle[k].all():
                gt = aa_to_rotmat(gt.reshape(-1, 3)).view(batch_size, -1, 3, 3)
            has_gt = has_smal_params[k]
            param_loss = self.smal_parameter_loss(pred.reshape(batch_size, -1),
                                                  gt.reshape(batch_size, -1),
                                                  has_gt)
            if k == "betas":
                # Shape prior regularizer; presumably ShapePriorLoss uses has_gt to
                # restrict it to samples without GT — TODO confirm.
                loss_smal_params[k] = param_loss + self.shape_prior_loss(pred, batch["category"], has_gt)
                if 'init_betas' in output:
                    loss_smal_params[k] = loss_smal_params[k] + \
                        self.shape_prior_loss(output['init_betas'], batch["category"], has_gt) / 2.
            else:
                # Pose prior over [global_orient, pose], halved and added for every
                # non-betas key (so once in total across global_orient + pose).
                full_pose = torch.cat((pred_smal_params["global_orient"],
                                       pred_smal_params["pose"]),
                                      dim=1)
                loss_smal_params[k] = param_loss + self.pose_prior_loss(full_pose, has_gt) / 2.

        if 'shape_feat' in output:
            loss_supcon = self.supcon_loss(output['shape_feat'], labels=batch['category'])
        else:
            ref = pred_smal_params['pose']
            loss_supcon = torch.tensor(0., device=ref.device, dtype=ref.dtype)

        weights = self.cfg.LOSS_WEIGHTS
        loss = weights['KEYPOINTS_3D'] * loss_keypoints_3d + \
            weights['KEYPOINTS_2D'] * loss_keypoints_2d + \
            sum([loss_smal_params[k] * weights[k.upper()] for k in loss_smal_params]) + \
            weights['SUPCON'] * loss_supcon
        if use_inter_kp2d:
            loss = loss + weights.get('INTERMEDIATE_KP2D', 0) * loss_intermediate_kp2d
        if use_inter_kp3d:
            loss = loss + weights.get('INTERMEDIATE_KP3D', 0) * loss_intermediate_kp3d

        losses = dict(loss=loss.detach(),
                      loss_keypoints_2d=loss_keypoints_2d.detach(),
                      loss_keypoints_3d=loss_keypoints_3d.detach(),
                      loss_supcon=loss_supcon.detach(),
                      )
        for k, v in loss_smal_params.items():
            losses['loss_' + k] = v.detach()
        if use_inter_kp2d:
            losses['loss_inter_keypoints_2d'] = loss_intermediate_kp2d.detach()
        if use_inter_kp3d:
            losses['loss_inter_keypoints_3d'] = loss_intermediate_kp3d.detach()

        output['losses'] = losses
        return loss

    def forward(self, batch: Dict) -> Dict:
        """
        Run a forward step of the network in val mode
        Args:
            batch (Dict): Dictionary containing batch data
        Returns:
            Dict: Dictionary containing the regression output
        """
        return self.forward_step(batch, train=False)

    def training_step_discriminator(self, batch: Dict,
                                    pose: torch.Tensor,
                                    betas: torch.Tensor,
                                    optimizer: torch.optim.Optimizer) -> torch.Tensor:
        """
        Run a discriminator training step (least-squares GAN objective).
        Args:
            batch (Dict): Dictionary with real 'pose' (axis-angle) and 'betas' entries
            pose (torch.Tensor): Regressed pose from current step (rotation matrices)
            betas (torch.Tensor): Regressed betas from current step
            optimizer (torch.optim.Optimizer): Discriminator optimizer
        Returns:
            torch.Tensor: Discriminator loss (unweighted, detached)
        """
        batch_size = pose.shape[0]
        gt_pose = batch['pose']
        gt_betas = batch['betas']
        gt_rotmat = aa_to_rotmat(gt_pose.view(-1, 3)).view(batch_size, -1, 3, 3)

        # Fake samples are pushed towards 0, real samples towards 1.
        disc_fake_out = self.discriminator(pose.detach(), betas.detach())
        loss_fake = ((disc_fake_out - 0.0) ** 2).sum() / batch_size
        disc_real_out = self.discriminator(gt_rotmat.detach(), gt_betas.detach())
        loss_real = ((disc_real_out - 1.0) ** 2).sum() / batch_size
        loss_disc = loss_fake + loss_real

        loss = self.cfg.LOSS_WEIGHTS.ADVERSARIAL * loss_disc
        optimizer.zero_grad()
        self.manual_backward(loss)
        optimizer.step()
        return loss_disc.detach()

    # Tensorboard logging should run from first rank only
    @pl.utilities.rank_zero.rank_zero_only
    def tensorboard_logging(self, batch: Dict, output: Dict, step_count: int, train: bool = True,
                            write_to_summary_writer: bool = True) -> None:
        """
        Log loss scalars and a rendered prediction grid to Tensorboard.
        Args:
            batch (Dict): Dictionary containing batch data
            output (Dict): Dictionary containing the regression output (with 'losses')
            step_count (int): Global training step count
            train (bool): Flag indicating whether it is training or validation mode
            write_to_summary_writer (bool): If False, nothing is written; only the grid is returned
        Returns:
            The image grid from make_grid, or None when the model was built with
            init_renderer=False (no renderer available).
        """
        mode = 'train' if train else 'val'
        images = batch['img']
        gt_keypoints_2d = batch['keypoints_2d']
        batch_size = images.shape[0]

        # Undo the input normalization: multiply by std, then add mean.
        std = torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1, 3, 1, 1)
        mean = torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1, 3, 1, 1)
        images = images * std
        images = images + mean

        pred_vertices = output['pred_vertices'].detach().reshape(batch_size, -1, 3)
        losses = output['losses']
        pred_cam_t = output['pred_cam_t'].detach().reshape(batch_size, 3)
        pred_keypoints_2d = output['pred_keypoints_2d'].detach().reshape(batch_size, -1, 2)

        if write_to_summary_writer:
            summary_writer = self.logger.experiment
            for loss_name, val in losses.items():
                summary_writer.add_scalar(mode + '/' + loss_name, val.detach().item(), step_count)

        if self.mesh_renderer is None:
            # Built with init_renderer=False: there is nothing to render with.
            return None

        num_images = min(batch_size, self.cfg.EXTRA.NUM_LOG_IMAGES)
        pred_masks = output.get('pred_masks', None)
        gt_masks = output.get('gt_masks', None)
        predictions = self.mesh_renderer.visualize_tensorboard(
            pred_vertices[:num_images].cpu().numpy(),
            pred_cam_t[:num_images].cpu().numpy(),
            images[:num_images].cpu().numpy(),
            self.cfg.SMAL.get("FOCAL_LENGTH", 1000),
            pred_keypoints_2d[:num_images].cpu().numpy(),
            gt_keypoints_2d[:num_images].cpu().numpy(),
            pred_masks=pred_masks[:num_images] if pred_masks is not None else None,
            gt_masks=gt_masks[:num_images] if gt_masks is not None else None,
        )
        predictions = make_grid(predictions, nrow=5, padding=2)
        if write_to_summary_writer:
            summary_writer.add_image('%s/predictions' % mode, predictions, step_count)

        return predictions

    def training_step(self, batch: Dict) -> Dict:
        """
        Run a full training step (manual optimization).
        Args:
            batch (Dict): Combined-loader batch whose 'img' entry is the training batch dict
                containing {'img', 'mask', 'keypoints_2d', 'keypoints_3d', 'orig_keypoints_2d',
                'box_center', 'box_size', 'img_size', 'smal_params',
                'smal_params_is_axis_angle', '_trans', 'imgname', 'focal_length'}
        Returns:
            Dict: Dictionary containing regression output.
        """
        batch = batch['img']
        use_adversarial = self.cfg.LOSS_WEIGHTS.get("ADVERSARIAL", 0) > 0

        optimizer = self.optimizers(use_pl_optimizer=True)
        if use_adversarial:
            optimizer, optimizer_disc = optimizer

        batch_size = batch['img'].shape[0]
        output = self.forward_step(batch, train=True)
        pred_smal_params = output['pred_smal_params']
        loss = self.compute_loss(batch, output, train=True)

        if use_adversarial:
            # Generator side of the least-squares GAN: push predictions towards "real" (1).
            disc_out = self.discriminator(pred_smal_params['pose'].reshape(batch_size, -1),
                                          pred_smal_params['betas'].reshape(batch_size, -1))
            loss_adv = ((disc_out - 1.0) ** 2).sum() / batch_size
            loss = loss + self.cfg.LOSS_WEIGHTS.ADVERSARIAL * loss_adv

        # Error if Nan
        if torch.isnan(loss):
            raise ValueError('Loss is NaN')

        optimizer.zero_grad()
        self.manual_backward(loss)
        # Clip gradient
        if self.cfg.TRAIN.get('GRAD_CLIP_VAL', 0) > 0:
            gn = torch.nn.utils.clip_grad_norm_(self.get_parameters(), self.cfg.TRAIN.GRAD_CLIP_VAL,
                                                error_if_nonfinite=True)
            self.log('train/grad_norm', gn, on_step=True, on_epoch=True, prog_bar=True, logger=True,
                     batch_size=batch_size, sync_dist=True)
        optimizer.step()

        if use_adversarial:
            loss_disc = self.training_step_discriminator(batch['smal_params'],
                                                         pred_smal_params['pose'].reshape(batch_size, -1),
                                                         pred_smal_params['betas'].reshape(batch_size, -1),
                                                         optimizer_disc)
            # Detach so the logged losses do not keep the autograd graph alive.
            output['losses']['loss_gen'] = loss_adv.detach()
            output['losses']['loss_disc'] = loss_disc

        if self.global_step > 0 and self.global_step % self.cfg.GENERAL.LOG_STEPS == 0:
            self.tensorboard_logging(batch, output, self.global_step, train=True)

        # Log training loss to the logger so checkpoint callback can monitor it.
        self.log('train/loss', output['losses']['loss'], on_step=True, on_epoch=True, prog_bar=True,
                 logger=True, batch_size=batch_size, sync_dist=True)

        return output

    def validation_step(self, batch: Dict, batch_idx: int, dataloader_idx=0) -> Dict:
        """
        Run a validation step and log epoch-level losses plus visuals to Tensorboard.
        Args:
            batch (Dict): Dictionary containing batch data (yielded directly, not wrapped as {'img': ...})
            batch_idx (int): Index of the batch; visuals are logged for batch 0 only.
            dataloader_idx (int): Unused.
        Returns:
            Dict: Dictionary containing regression output.
        """
        output = self.forward_step(batch, train=False)
        # compute_loss populates output['losses']; the returned scalar is not needed here.
        self.compute_loss(batch, output, train=False)

        batch_size = batch['img'].shape[0]
        # on_epoch=True so checkpoint callbacks monitor the epoch-level 'val/<name>' metrics.
        for loss_name, val in output.get('losses', {}).items():
            self.log(f'val/{loss_name}', val, on_step=False, on_epoch=True,
                     prog_bar=(loss_name == 'loss'), logger=True,
                     batch_size=batch_size, sync_dist=True)

        # Log visualizations on the first batch of each validation epoch
        if batch_idx == 0:
            self.tensorboard_logging(batch, output, self.global_step, train=False)

        return output