# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F
from dataclasses import dataclass
from math import ceil, floor

from vggt.utils.pose_enc import extri_intri_to_pose_encoding
from train_utils.general import check_and_fix_inf_nan


@dataclass(eq=False)
class MultitaskLoss(torch.nn.Module):
    """
    Multi-task loss module that combines different loss types for VGGT.

    Supports:
    - Camera loss
    - Depth loss
    - Point loss
    - Tracking loss (not cleaned yet; the dirty code is at the bottom of this file)
    """

    def __init__(self, camera=None, depth=None, point=None, track=None, **kwargs):
        super().__init__()
        # Loss configuration dictionaries for each task
        self.camera = camera
        self.depth = depth
        self.point = point
        self.track = track

    def forward(self, predictions, batch) -> dict:
        """
        Compute the total multi-task loss.

        Args:
            predictions: Dict containing model predictions for the different tasks
            batch: Dict containing ground truth data and masks

        Returns:
            Dict containing the individual losses and the total objective
        """
        total_loss = 0
        loss_dict = {}

        # Camera pose loss - if pose encodings are predicted
        if "pose_enc_list" in predictions:
            camera_loss_dict = compute_camera_loss(predictions, batch, **self.camera)
            camera_loss = camera_loss_dict["loss_camera"] * self.camera["weight"]
            total_loss = total_loss + camera_loss
            loss_dict.update(camera_loss_dict)

        # Depth estimation loss - if depth maps are predicted
        if "depth" in predictions:
            depth_loss_dict = compute_depth_loss(predictions, batch, **self.depth)
            depth_loss = (
                depth_loss_dict["loss_conf_depth"]
                + depth_loss_dict["loss_reg_depth"]
                + depth_loss_dict["loss_grad_depth"]
            )
            depth_loss = depth_loss * self.depth["weight"]
            total_loss = total_loss + depth_loss
            loss_dict.update(depth_loss_dict)

        # 3D point reconstruction loss - if world points are predicted
        if "world_points" in predictions:
            point_loss_dict = compute_point_loss(predictions, batch, **self.point)
            point_loss = (
                point_loss_dict["loss_conf_point"]
                + point_loss_dict["loss_reg_point"]
                + point_loss_dict["loss_grad_point"]
            )
            point_loss = point_loss * self.point["weight"]
            total_loss = total_loss + point_loss
            loss_dict.update(point_loss_dict)

        # Tracking loss - not cleaned yet; the dirty code is at the bottom of this file
        if "track" in predictions:
            raise NotImplementedError("Track loss is not cleaned up yet")

        loss_dict["objective"] = total_loss
        return loss_dict
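
# Example usage of MultitaskLoss (a minimal sketch, not part of the training
# pipeline). The per-task dicts normally come from the training config; the
# values below are illustrative assumptions, except "weight", which forward()
# reads for every enabled task:
#
#   loss_fn = MultitaskLoss(
#       camera={"weight": 5.0, "loss_type": "l1", "gamma": 0.6},
#       depth={"weight": 1.0, "gamma": 1.0, "alpha": 0.2,
#              "gradient_loss_fn": "grad", "valid_range": 0.98},
#       point={"weight": 1.0, "gamma": 1.0, "alpha": 0.2,
#              "gradient_loss_fn": "normal", "valid_range": 0.98},
#   )
#   loss_dict = loss_fn(predictions, batch)  # dicts with the keys documented above
#   loss_dict["objective"].backward()
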
def compute_camera_loss(
    pred_dict,                            # predictions dict; contains pose encodings
    batch_data,                           # ground truth and mask batch dict
    loss_type="l1",                       # "l1" or "l2" loss
    gamma=0.6,                            # decay weight across prediction stages
    pose_encoding_type="absT_quaR_FoV",
    weight_trans=1.0,                     # weight for translation loss
    weight_rot=1.0,                       # weight for rotation loss
    weight_focal=0.5,                     # weight for focal length loss
    **kwargs,
):
    # List of predicted pose encodings, one per stage
    pred_pose_encodings = pred_dict["pose_enc_list"]
    # Binary mask of valid points per frame (B, S, H, W); a sequence is used only
    # if its first frame has more than 100 valid points
    point_masks = batch_data["point_masks"]
    valid_frame_mask = point_masks[:, 0].sum(dim=[-1, -2]) > 100
    # Number of prediction stages
    n_stages = len(pred_pose_encodings)

    # Get ground truth camera extrinsics and intrinsics
    gt_extrinsics = batch_data["extrinsics"]
    gt_intrinsics = batch_data["intrinsics"]
    image_hw = batch_data["images"].shape[-2:]

    # Encode the ground truth pose to match the predicted encoding format
    gt_pose_encoding = extri_intri_to_pose_encoding(
        gt_extrinsics, gt_intrinsics, image_hw, pose_encoding_type=pose_encoding_type
    )

    # Initialize loss accumulators for translation, rotation, and focal length
    total_loss_T = total_loss_R = total_loss_FL = 0

    # Compute the loss for each prediction stage with decayed weighting
    for stage_idx in range(n_stages):
        # Later stages get higher weight (gamma^0 = 1.0 for the final stage)
        stage_weight = gamma ** (n_stages - stage_idx - 1)
        pred_pose_stage = pred_pose_encodings[stage_idx]

        if valid_frame_mask.sum() == 0:
            # If no valid frames, set the losses to zero to avoid gradient issues
            loss_T_stage = (pred_pose_stage * 0).mean()
            loss_R_stage = (pred_pose_stage * 0).mean()
            loss_FL_stage = (pred_pose_stage * 0).mean()
        else:
            # Only consider valid frames for the loss computation
            loss_T_stage, loss_R_stage, loss_FL_stage = camera_loss_single(
                pred_pose_stage[valid_frame_mask].clone(),
                gt_pose_encoding[valid_frame_mask].clone(),
                loss_type=loss_type,
            )

        # Accumulate the weighted losses across stages
        total_loss_T += loss_T_stage * stage_weight
        total_loss_R += loss_R_stage * stage_weight
        total_loss_FL += loss_FL_stage * stage_weight

    # Average over all stages
    avg_loss_T = total_loss_T / n_stages
    avg_loss_R = total_loss_R / n_stages
    avg_loss_FL = total_loss_FL / n_stages

    # Compute the total weighted camera loss
    total_camera_loss = (
        avg_loss_T * weight_trans
        + avg_loss_R * weight_rot
        + avg_loss_FL * weight_focal
    )

    # Return a loss dictionary with the individual components
    return {
        "loss_camera": total_camera_loss,
        "loss_T": avg_loss_T,
        "loss_R": avg_loss_R,
        "loss_FL": avg_loss_FL,
    }


def camera_loss_single(pred_pose_enc, gt_pose_enc, loss_type="l1"):
    """
    Compute translation, rotation, and focal losses for a batch of pose encodings.

    Args:
        pred_pose_enc: (N, D) predicted pose encoding
        gt_pose_enc: (N, D) ground truth pose encoding
        loss_type: "l1" (absolute error) or "l2" (Euclidean error)

    Returns:
        loss_T: translation loss (mean)
        loss_R: rotation loss (mean)
        loss_FL: focal length/intrinsics loss (mean)

    NOTE: The paper uses smooth L1 loss, but we found plain L1 to be more stable
    than both smooth L1 and L2, so L1 is used here.
    """
    if loss_type == "l1":
        # Translation: first 3 dims; rotation: next 4 (quaternion); focal/intrinsics: remaining dims
        loss_T = (pred_pose_enc[..., :3] - gt_pose_enc[..., :3]).abs()
        loss_R = (pred_pose_enc[..., 3:7] - gt_pose_enc[..., 3:7]).abs()
        loss_FL = (pred_pose_enc[..., 7:] - gt_pose_enc[..., 7:]).abs()
    elif loss_type == "l2":
        # L2 norm for each component
        loss_T = (pred_pose_enc[..., :3] - gt_pose_enc[..., :3]).norm(dim=-1, keepdim=True)
        loss_R = (pred_pose_enc[..., 3:7] - gt_pose_enc[..., 3:7]).norm(dim=-1)
        loss_FL = (pred_pose_enc[..., 7:] - gt_pose_enc[..., 7:]).norm(dim=-1)
    else:
        raise ValueError(f"Unknown loss type: {loss_type}")

    # Check/fix numerical issues (nan/inf) in each loss component
    loss_T = check_and_fix_inf_nan(loss_T, "loss_T")
    loss_R = check_and_fix_inf_nan(loss_R, "loss_R")
    loss_FL = check_and_fix_inf_nan(loss_FL, "loss_FL")

    # Clamp outlier translation losses to prevent instability, then average
    loss_T = loss_T.clamp(max=100).mean()
    loss_R = loss_R.mean()
    loss_FL = loss_FL.mean()

    return loss_T, loss_R, loss_FL
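
# Shape-level sanity check for camera_loss_single (illustrative; assumes the
# "absT_quaR_FoV" encoding: 3 translation + 4 quaternion + 2 FoV = 9 dims):
#
#   pred = torch.randn(8, 9)  # 8 valid frames
#   gt = torch.randn(8, 9)
#   loss_T, loss_R, loss_FL = camera_loss_single(pred, gt, loss_type="l1")
#   # Each is a scalar tensor; compute_camera_loss combines them with
#   # weight_trans / weight_rot / weight_focal and the per-stage gamma weights.
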
def compute_point_loss(predictions, batch, gamma=1.0, alpha=0.2,
                       gradient_loss_fn=None, valid_range=-1, **kwargs):
    """
    Compute the point loss.

    Args:
        predictions: Dict containing 'world_points' and 'world_points_conf'
        batch: Dict containing ground truth 'world_points' and 'point_masks'
        gamma: Weight for the confidence-weighted loss
        alpha: Weight for the confidence regularization
        gradient_loss_fn: Type of gradient loss to apply
        valid_range: Quantile range for outlier filtering
    """
    pred_points = predictions["world_points"]
    pred_points_conf = predictions["world_points_conf"]
    gt_points = batch["world_points"]
    gt_points_mask = batch["point_masks"]
    gt_points = check_and_fix_inf_nan(gt_points, "gt_points")

    if gt_points_mask.sum() < 100:
        # If there are fewer than 100 valid points, skip this batch
        dummy_loss = (0.0 * pred_points).mean()
        return {
            "loss_conf_point": dummy_loss,
            "loss_reg_point": dummy_loss,
            "loss_grad_point": dummy_loss,
        }

    # Compute the confidence-weighted regression loss with an optional gradient loss
    loss_conf, loss_grad, loss_reg = regression_loss(
        pred_points, gt_points, gt_points_mask,
        conf=pred_points_conf, gradient_loss_fn=gradient_loss_fn,
        gamma=gamma, alpha=alpha, valid_range=valid_range,
    )

    return {
        "loss_conf_point": loss_conf,
        "loss_reg_point": loss_reg,
        "loss_grad_point": loss_grad,
    }


def compute_depth_loss(predictions, batch, gamma=1.0, alpha=0.2,
                       gradient_loss_fn=None, valid_range=-1, **kwargs):
    """
    Compute the depth loss.

    Args:
        predictions: Dict containing 'depth' and 'depth_conf'
        batch: Dict containing ground truth 'depths' and 'point_masks'
        gamma: Weight for the confidence-weighted loss
        alpha: Weight for the confidence regularization
        gradient_loss_fn: Type of gradient loss to apply
        valid_range: Quantile range for outlier filtering
    """
    pred_depth = predictions["depth"]
    pred_depth_conf = predictions["depth_conf"]
    gt_depth = batch["depths"]
    gt_depth = check_and_fix_inf_nan(gt_depth, "gt_depth")
    gt_depth = gt_depth[..., None]  # (B, S, H, W, 1)

    # The 3D points are derived from the depth maps, so the same mask applies
    gt_depth_mask = batch["point_masks"].clone()

    if gt_depth_mask.sum() < 100:
        # If there are fewer than 100 valid points, skip this batch
        dummy_loss = (0.0 * pred_depth).mean()
        return {
            "loss_conf_depth": dummy_loss,
            "loss_reg_depth": dummy_loss,
            "loss_grad_depth": dummy_loss,
        }

    # NOTE: conf is passed into regression_loss so that the confidence weighting
    # can also be applied to the gradient loss in a multi-scale manner.
    # This is hacky, but much easier to implement.
    loss_conf, loss_grad, loss_reg = regression_loss(
        pred_depth, gt_depth, gt_depth_mask,
        conf=pred_depth_conf, gradient_loss_fn=gradient_loss_fn,
        gamma=gamma, alpha=alpha, valid_range=valid_range,
    )

    return {
        "loss_conf_depth": loss_conf,
        "loss_reg_depth": loss_reg,
        "loss_grad_depth": loss_grad,
    }
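
# Why the confidence-weighted objective is self-calibrating (a short derivation,
# not from the original code): for a fixed per-pixel error e, the per-pixel term
#     f(c) = gamma * e * c - alpha * log(c)
# has derivative f'(c) = gamma * e - alpha / c, which vanishes at
#     c* = alpha / (gamma * e).
# The optimal confidence is therefore inversely proportional to the error: the
# model is rewarded for high confidence on easy pixels and low confidence on
# hard ones, while the -alpha * log(c) term prevents collapsing c to zero.
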
def regression_loss(pred, gt, mask, conf=None, gradient_loss_fn=None,
                    gamma=1.0, alpha=0.2, valid_range=-1):
    """
    Core regression loss with confidence weighting and an optional gradient loss.

    Computes:
        1. gamma * ||pred - gt|| * conf - alpha * log(conf)
        2. An optional gradient loss

    Args:
        pred: (B, S, H, W, C) predicted values
        gt: (B, S, H, W, C) ground truth values
        mask: (B, S, H, W) valid pixel mask
        conf: (B, S, H, W) confidence weights (optional)
        gradient_loss_fn: Type of gradient loss ("normal", "grad", etc.)
        gamma: Weight for the confidence-weighted loss
        alpha: Weight for the confidence regularization
        valid_range: Quantile range for outlier filtering

    Returns:
        loss_conf: Confidence-weighted loss
        loss_grad: Gradient loss (0 if not specified)
        loss_reg: Plain L2 regression loss
    """
    bb, ss, hh, ww, nc = pred.shape
    gradient_loss_fn = gradient_loss_fn or ""  # guard against None for the "in" checks below

    # Compute the L2 distance between predicted and ground truth points
    loss_reg = torch.norm(gt[mask] - pred[mask], dim=-1)
    loss_reg = check_and_fix_inf_nan(loss_reg, "loss_reg")

    # Confidence-weighted loss: gamma * loss * conf - alpha * log(conf).
    # This encourages the model to be confident on easy examples and less
    # confident on hard ones.
    loss_conf = gamma * loss_reg * conf[mask] - alpha * torch.log(conf[mask])
    loss_conf = check_and_fix_inf_nan(loss_conf, "loss_conf")

    # Initialize the gradient loss
    loss_grad = 0

    # Prepare the confidence map for the gradient loss if needed
    if "conf" in gradient_loss_fn:
        to_feed_conf = conf.reshape(bb * ss, hh, ww)
    else:
        to_feed_conf = None

    # Compute the gradient loss, if specified, for spatial smoothness
    if "normal" in gradient_loss_fn:
        # Surface-normal-based gradient loss
        loss_grad = gradient_loss_multi_scale_wrapper(
            pred.reshape(bb * ss, hh, ww, nc),
            gt.reshape(bb * ss, hh, ww, nc),
            mask.reshape(bb * ss, hh, ww),
            gradient_loss_fn=normal_loss,
            scales=3,
            conf=to_feed_conf,
        )
    elif "grad" in gradient_loss_fn:
        # Standard gradient-based loss
        loss_grad = gradient_loss_multi_scale_wrapper(
            pred.reshape(bb * ss, hh, ww, nc),
            gt.reshape(bb * ss, hh, ww, nc),
            mask.reshape(bb * ss, hh, ww),
            gradient_loss_fn=gradient_loss,
            conf=to_feed_conf,
        )

    # Process the confidence-weighted loss
    if loss_conf.numel() > 0:
        # Filter out outliers using quantile-based thresholding
        if valid_range > 0:
            loss_conf = filter_by_quantile(loss_conf, valid_range)
        loss_conf = check_and_fix_inf_nan(loss_conf, "loss_conf")
        loss_conf = loss_conf.mean()
    else:
        loss_conf = (0.0 * pred).mean()

    # Process the plain regression loss
    if loss_reg.numel() > 0:
        # Filter out outliers using quantile-based thresholding
        if valid_range > 0:
            loss_reg = filter_by_quantile(loss_reg, valid_range)
        loss_reg = check_and_fix_inf_nan(loss_reg, "loss_reg")
        loss_reg = loss_reg.mean()
    else:
        loss_reg = (0.0 * pred).mean()

    return loss_conf, loss_grad, loss_reg


def gradient_loss_multi_scale_wrapper(prediction, target, mask, scales=4,
                                      gradient_loss_fn=None, conf=None):
    """
    Multi-scale gradient loss wrapper.

    Applies the gradient loss at multiple scales by subsampling the input.
    This helps capture both fine and coarse spatial structure.

    Args:
        prediction: (B, H, W, C) predicted values
        target: (B, H, W, C) ground truth values
        mask: (B, H, W) valid pixel mask
        scales: Number of scales to use
        gradient_loss_fn: Gradient loss function to apply
        conf: (B, H, W) confidence weights (optional)
    """
    total = 0
    for scale in range(scales):
        step = 2 ** scale  # subsample by 2^scale
        total += gradient_loss_fn(
            prediction[:, ::step, ::step],
            target[:, ::step, ::step],
            mask[:, ::step, ::step],
            conf=conf[:, ::step, ::step] if conf is not None else None,
        )
    total = total / scales
    return total
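
# Subsampling sketch for the multi-scale wrapper (illustrative shapes only):
# with scales=4 on a 64x64 input, the gradient loss is evaluated on strided
# views of size 64x64, 32x32, 16x16, and 8x8, then averaged:
#
#   pred = torch.rand(2, 64, 64, 3)
#   gt = torch.rand(2, 64, 64, 3)
#   mask = torch.ones(2, 64, 64, dtype=torch.bool)
#   loss = gradient_loss_multi_scale_wrapper(
#       pred, gt, mask, scales=4, gradient_loss_fn=gradient_loss
#   )
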
def normal_loss(prediction, target, mask, cos_eps=1e-8, conf=None, gamma=1.0, alpha=0.2):
    """
    Surface-normal-based loss for geometric consistency.

    Computes surface normals from 3D point maps using cross products of
    neighboring points, then measures the angle between the predicted and
    ground truth normals.

    Args:
        prediction: (B, H, W, 3) predicted 3D coordinates/points
        target: (B, H, W, 3) ground-truth 3D coordinates/points
        mask: (B, H, W) valid pixel mask
        cos_eps: Epsilon for numerical stability in the cosine computation
        conf: (B, H, W) confidence weights (optional)
        gamma: Weight for the confidence-weighted loss
        alpha: Weight for the confidence regularization
    """
    # Convert the point maps to surface normals using cross products
    pred_normals, pred_valids = point_map_to_normal(prediction, mask, eps=cos_eps)
    gt_normals, gt_valids = point_map_to_normal(target, mask, eps=cos_eps)

    # Only consider regions where both the predicted and GT normals are valid
    all_valid = pred_valids & gt_valids  # shape: (4, B, H, W)

    # Early return if there are not enough valid points
    divisor = torch.sum(all_valid)
    if divisor < 10:
        return 0

    # Extract the valid normals
    pred_normals = pred_normals[all_valid].clone()
    gt_normals = gt_normals[all_valid].clone()

    # Compute the cosine similarity between corresponding normals
    dot = torch.sum(pred_normals * gt_normals, dim=-1)
    # Clamp the dot product to [-1, 1] for numerical stability
    dot = torch.clamp(dot, -1 + cos_eps, 1 - cos_eps)

    # Use 1 - cos(theta) as the loss instead of arccos(dot) for numerical stability
    loss = 1 - dot

    # Return the mean loss if we have enough valid points
    if loss.numel() < 10:
        return 0

    loss = check_and_fix_inf_nan(loss, "normal_loss")
    if conf is not None:
        # Apply confidence weighting
        conf = conf[None, ...].expand(4, -1, -1, -1)
        conf = conf[all_valid].clone()
        loss = gamma * loss * conf - alpha * torch.log(conf)
    return loss.mean()


def gradient_loss(prediction, target, mask, conf=None, gamma=1.0, alpha=0.2):
    """
    Gradient-based loss.

    Computes the L1 difference between adjacent pixels in the x and y directions.

    Args:
        prediction: (B, H, W, C) predicted values
        target: (B, H, W, C) ground truth values
        mask: (B, H, W) valid pixel mask
        conf: (B, H, W) confidence weights (optional)
        gamma: Weight for the confidence-weighted loss
        alpha: Weight for the confidence regularization
    """
    # Expand the mask to match the prediction channels
    mask = mask[..., None].expand(-1, -1, -1, prediction.shape[-1])
    M = torch.sum(mask, (1, 2, 3))

    # Compute the masked difference between prediction and target
    diff = prediction - target
    diff = torch.mul(mask, diff)

    # Gradients in the x direction (horizontal)
    grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1])
    mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1])
    grad_x = torch.mul(mask_x, grad_x)

    # Gradients in the y direction (vertical)
    grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :])
    mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :])
    grad_y = torch.mul(mask_y, grad_y)

    # Clamp the gradients to suppress outliers
    grad_x = grad_x.clamp(max=100)
    grad_y = grad_y.clamp(max=100)

    # Apply confidence weighting if provided
    if conf is not None:
        conf = conf[..., None].expand(-1, -1, -1, prediction.shape[-1])
        conf_x = conf[:, :, 1:]
        conf_y = conf[:, 1:, :]
        grad_x = gamma * grad_x * conf_x - alpha * torch.log(conf_x)
        grad_y = gamma * grad_y * conf_y - alpha * torch.log(conf_y)

    # Sum the gradients and normalize by the number of valid pixels
    grad_loss = torch.sum(grad_x, (1, 2, 3)) + torch.sum(grad_y, (1, 2, 3))
    divisor = torch.sum(M)
    if divisor == 0:
        return 0
    grad_loss = torch.sum(grad_loss) / divisor
    return grad_loss
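
# Sanity check for normal_loss (illustrative): a point map compared against an
# exact copy of itself should give a near-zero loss, since corresponding
# normals align and 1 - cos(0) ≈ 0 (up to the cos_eps clamp):
#
#   pts = torch.rand(1, 32, 32, 3)
#   mask = torch.ones(1, 32, 32, dtype=torch.bool)
#   assert normal_loss(pts, pts.clone(), mask) < 1e-4
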
def point_map_to_normal(point_map, mask, eps=1e-6):
    """
    Convert a 3D point map to surface normal vectors using cross products.

    Computes normals from cross products of neighboring point differences,
    using 4 different cross-product directions for robustness.

    Args:
        point_map: (B, H, W, 3) 3D points laid out in a 2D grid
        mask: (B, H, W) valid pixels (bool)
        eps: Epsilon for numerical stability in the normalization

    Returns:
        normals: (4, B, H, W, 3) normal vectors for each of the 4 cross-product directions
        valids: (4, B, H, W) corresponding valid masks
    """
    with torch.cuda.amp.autocast(enabled=False):
        # Pad the inputs to avoid boundary issues
        padded_mask = F.pad(mask, (1, 1, 1, 1), mode='constant', value=0)
        pts = F.pad(point_map.permute(0, 3, 1, 2), (1, 1, 1, 1), mode='constant', value=0).permute(0, 2, 3, 1)

        # Gather the neighboring points for each pixel
        center = pts[:, 1:-1, 1:-1, :]  # B, H, W, 3
        up = pts[:, :-2, 1:-1, :]
        left = pts[:, 1:-1, :-2, :]
        down = pts[:, 2:, 1:-1, :]
        right = pts[:, 1:-1, 2:, :]

        # Direction vectors from the center to its neighbors
        up_dir = up - center
        left_dir = left - center
        down_dir = down - center
        right_dir = right - center

        # Four cross products for the different normal directions
        n1 = torch.cross(up_dir, left_dir, dim=-1)     # up x left
        n2 = torch.cross(left_dir, down_dir, dim=-1)   # left x down
        n3 = torch.cross(down_dir, right_dir, dim=-1)  # down x right
        n4 = torch.cross(right_dir, up_dir, dim=-1)    # right x up

        # Validity masks - the center pixel and both neighbors must be valid
        v1 = padded_mask[:, :-2, 1:-1] & padded_mask[:, 1:-1, 1:-1] & padded_mask[:, 1:-1, :-2]
        v2 = padded_mask[:, 1:-1, :-2] & padded_mask[:, 1:-1, 1:-1] & padded_mask[:, 2:, 1:-1]
        v3 = padded_mask[:, 2:, 1:-1] & padded_mask[:, 1:-1, 1:-1] & padded_mask[:, 1:-1, 2:]
        v4 = padded_mask[:, 1:-1, 2:] & padded_mask[:, 1:-1, 1:-1] & padded_mask[:, :-2, 1:-1]

        # Stack the normals and validity masks
        normals = torch.stack([n1, n2, n3, n4], dim=0)  # (4, B, H, W, 3)
        valids = torch.stack([v1, v2, v3, v4], dim=0)   # (4, B, H, W)

        # Normalize the normal vectors
        normals = F.normalize(normals, p=2, dim=-1, eps=eps)

    return normals, valids


def filter_by_quantile(loss_tensor, valid_range, min_elements=1000, hard_max=100):
    """
    Filter a loss tensor by keeping only values below a quantile threshold.

    This helps remove outliers that could destabilize training.

    Args:
        loss_tensor: Tensor containing loss values
        valid_range: Float between 0 and 1 indicating the quantile threshold
        min_elements: Minimum number of elements required to apply the filtering
        hard_max: Maximum allowed value for any individual loss

    Returns:
        Filtered and clamped loss tensor
    """
    if loss_tensor.numel() <= min_elements:
        # Too few elements; return as-is
        return loss_tensor

    # Randomly subsample very large tensors to avoid memory issues
    if loss_tensor.numel() > 100_000_000:
        # Flatten and randomly select 1M elements
        indices = torch.randperm(loss_tensor.numel(), device=loss_tensor.device)[:1_000_000]
        loss_tensor = loss_tensor.view(-1)[indices]

    # First clamp individual values to prevent extreme outliers
    loss_tensor = loss_tensor.clamp(max=hard_max)

    # Compute the quantile threshold
    quantile_thresh = torch_quantile(loss_tensor.detach(), valid_range)
    quantile_thresh = min(quantile_thresh, hard_max)

    # Apply the quantile filtering if enough elements remain
    quantile_mask = loss_tensor < quantile_thresh
    if quantile_mask.sum() > min_elements:
        return loss_tensor[quantile_mask]
    return loss_tensor
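
# Example of quantile filtering (illustrative; the tensor must exceed
# min_elements for the filter to kick in): with valid_range=0.98, roughly the
# top 2% of per-element losses are dropped before averaging:
#
#   losses = torch.cat([torch.rand(2000), torch.tensor([1e6])])  # one huge outlier
#   filtered = filter_by_quantile(losses, valid_range=0.98)
#   # The outlier is first clamped to hard_max=100, then removed by the
#   # 98th-percentile threshold, so filtered.mean() stays well-behaved.
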
def torch_quantile(
    input,
    q,
    dim=None,
    keepdim: bool = False,
    *,
    interpolation: str = "nearest",
    out: torch.Tensor = None,
) -> torch.Tensor:
    """Better torch.quantile for one SCALAR quantile, using torch.kthvalue.

    Better than torch.quantile because:
    - No 2**24 input size limit (see pytorch/issues/67592),
    - Much faster, at least on large inputs.

    Arguments:
        input (torch.Tensor): See torch.quantile.
        q (float): See torch.quantile. Supports only scalar values currently.
        dim (int | None): See torch.quantile.
        keepdim (bool): See torch.quantile. Supports only False currently.
        interpolation: {"nearest", "lower", "higher"}. See torch.quantile.
        out (torch.Tensor | None): See torch.quantile. Supports only None currently.
    """
    # https://github.com/pytorch/pytorch/issues/64947

    # Sanitization of q
    try:
        q = float(q)
        assert 0 <= q <= 1
    except Exception:
        raise ValueError(f"Only scalar input 0 <= q <= 1 is currently supported (got {q})!")

    # Handle the dim=None case
    if dim_was_none := dim is None:
        dim = 0
        input = input.reshape((-1,) + (1,) * (input.ndim - 1))

    # Select the interpolation method
    if interpolation == "nearest":
        inter = round
    elif interpolation == "lower":
        inter = floor
    elif interpolation == "higher":
        inter = ceil
    else:
        raise ValueError(
            "Supported interpolations currently are {'nearest', 'lower', 'higher'} "
            f"(got '{interpolation}')!"
        )

    # Validate the out parameter
    if out is not None:
        raise ValueError(f"Only None value is currently supported for out (got {out})!")

    # Compute the k-th value
    k = inter(q * (input.shape[dim] - 1)) + 1
    out = torch.kthvalue(input, k, dim, keepdim=True, out=out)[0]

    # Handle the keepdim and dim=None cases
    if keepdim:
        return out
    if dim_was_none:
        return out.squeeze()
    return out.squeeze(dim)
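
# torch_quantile should agree with torch.quantile when the interpolation modes
# match (illustrative check; small enough for torch.quantile's size limit):
#
#   x = torch.rand(10_000)
#   a = torch_quantile(x, 0.9, interpolation="lower")
#   b = torch.quantile(x, 0.9, interpolation="lower")
#   assert torch.allclose(a, b)
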
########################################################################################
########################################################################################
# Dirty code for tracking loss:
########################################################################################
########################################################################################

'''
def _compute_losses(self, coord_preds, vis_scores, conf_scores, batch):
    """Compute tracking losses using sequence_loss"""
    gt_tracks = batch["tracks"]  # B, S, N, 2
    gt_track_vis_mask = batch["track_vis_mask"]  # B, S, N

    # if self.training and hasattr(self, "train_query_points"):
    train_query_points = coord_preds[-1].shape[2]
    gt_tracks = gt_tracks[:, :, :train_query_points]
    gt_tracks = check_and_fix_inf_nan(gt_tracks, "gt_tracks", hard_max=None)
    gt_track_vis_mask = gt_track_vis_mask[:, :, :train_query_points]

    # Create a validity mask that filters out tracks not visible in the first frame
    valids = torch.ones_like(gt_track_vis_mask)
    mask = gt_track_vis_mask[:, 0, :] == True
    valids = valids * mask.unsqueeze(1)

    if not valids.any():
        print("No valid tracks found in first frame")
        print("seq_name: ", batch["seq_name"])
        print("ids: ", batch["ids"])
        print("time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        dummy_coord = coord_preds[0].mean() * 0  # keeps graph & grads
        dummy_vis = vis_scores.mean() * 0
        if conf_scores is not None:
            dummy_conf = conf_scores.mean() * 0
        else:
            dummy_conf = 0
        return dummy_coord, dummy_vis, dummy_conf  # three scalar zeros

    # Compute the tracking loss using sequence_loss
    track_loss = sequence_loss(
        flow_preds=coord_preds,
        flow_gt=gt_tracks,
        vis=gt_track_vis_mask,
        valids=valids,
        **self.loss_kwargs,
    )

    vis_loss = F.binary_cross_entropy_with_logits(
        vis_scores[valids], gt_track_vis_mask[valids].float()
    )
    vis_loss = check_and_fix_inf_nan(vis_loss, "vis_loss", hard_max=None)

    # A prediction counts as confident if it lands within 3 pixels of the GT track
    if conf_scores is not None:
        gt_conf_mask = (gt_tracks - coord_preds[-1]).norm(dim=-1) < 3
        conf_loss = F.binary_cross_entropy_with_logits(
            conf_scores[valids], gt_conf_mask[valids].float()
        )
        conf_loss = check_and_fix_inf_nan(conf_loss, "conf_loss", hard_max=None)
    else:
        conf_loss = 0

    return track_loss, vis_loss, conf_loss


def reduce_masked_mean(x, mask, dim=None, keepdim=False):
    # Mean of x over the entries where mask is nonzero (x and mask must have the same shape)
    for a, b in zip(x.size(), mask.size()):
        assert a == b
    prod = x * mask
    if dim is None:
        numer = torch.sum(prod)
        denom = torch.sum(mask)
    else:
        numer = torch.sum(prod, dim=dim, keepdim=keepdim)
        denom = torch.sum(mask, dim=dim, keepdim=keepdim)
    mean = numer / denom.clamp(min=1)
    mean = torch.where(denom > 0, mean, torch.zeros_like(mean))
    return mean


def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8, vis_aware=False,
                  huber=False, delta=10, vis_aware_w=0.1, **kwargs):
    """Loss function defined over a sequence of flow predictions"""
    B, S, N, D = flow_gt.shape
    assert D == 2
    B, S1, N = vis.shape
    B, S2, N = valids.shape
    assert S == S1
    assert S == S2

    n_predictions = len(flow_preds)
    flow_loss = 0.0

    for i in range(n_predictions):
        i_weight = gamma ** (n_predictions - i - 1)
        flow_pred = flow_preds[i]
        i_loss = (flow_pred - flow_gt).abs()  # B, S, N, 2
        i_loss = check_and_fix_inf_nan(i_loss, f"i_loss_iter_{i}", hard_max=None)
        i_loss = torch.mean(i_loss, dim=3)  # B, S, N

        # Combine valids and vis for per-frame valid masking
        combined_mask = torch.logical_and(valids, vis)
        num_valid_points = combined_mask.sum()

        if vis_aware:
            # Upweight visible points rather than modifying the mask itself
            combined_mask = combined_mask.float() * (1.0 + vis_aware_w)
            flow_loss += i_weight * reduce_masked_mean(i_loss, combined_mask)
        else:
            if num_valid_points > 2:
                i_loss = i_loss[combined_mask]
                flow_loss += i_weight * i_loss.mean()
            else:
                i_loss = check_and_fix_inf_nan(i_loss, f"i_loss_iter_safe_check_{i}", hard_max=None)
                flow_loss += 0 * i_loss.mean()

    # Avoid division by zero if n_predictions is 0 (though it shouldn't be)
    if n_predictions > 0:
        flow_loss = flow_loss / n_predictions

    return flow_loss
'''