|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from dataclasses import dataclass |
|
|
from vggt.utils.pose_enc import extri_intri_to_pose_encoding |
|
|
from train_utils.general import check_and_fix_inf_nan |
|
|
from math import ceil, floor |
|
|
|
|
|
|
|
|
@dataclass(eq=False)
class MultitaskLoss(torch.nn.Module):
    """
    Multi-task loss module that combines different loss types for VGGT.

    Supports:
        - Camera loss
        - Depth loss
        - Point loss
        - Tracking loss (not cleaned yet, dirty code is at the bottom of this file)

    Each task is configured by a mapping that is forwarded as keyword
    arguments to the matching ``compute_*_loss`` function. The mapping must
    also contain a ``"weight"`` entry, which scales that task's contribution
    to the total objective.
    """

    def __init__(self, camera=None, depth=None, point=None, track=None, **kwargs):
        super().__init__()
        # Per-task configuration mappings; None means the task is not configured.
        self.camera = camera
        self.depth = depth
        self.point = point
        self.track = track

    def _task_config(self, task_name):
        """Return the config mapping of ``task_name`` or fail with a clear error."""
        config = getattr(self, task_name)
        if config is None:
            # Without this check the failure would surface as an opaque
            # "argument after ** must be a mapping" TypeError.
            raise TypeError(
                f"Predictions contain outputs for the '{task_name}' loss, but "
                f"MultitaskLoss was constructed with {task_name}=None"
            )
        return config

    def forward(self, predictions, batch) -> dict:
        """
        Compute the total multi-task loss.

        A task is evaluated only when its prediction key is present:
        ``"pose_enc_list"`` (camera), ``"depth"`` (depth) and
        ``"world_points"`` (point).

        Args:
            predictions: Dict containing model predictions for different tasks
            batch: Dict containing ground truth data and masks

        Returns:
            Dict containing the individual loss terms of every evaluated task
            plus the weighted sum under the key ``"objective"``.

        Raises:
            TypeError: If a task has predictions but its config is None.
            NotImplementedError: If ``"track"`` is present in ``predictions``.
        """
        total_loss = 0
        loss_dict = {}

        # Camera pose loss
        if "pose_enc_list" in predictions:
            camera_cfg = self._task_config("camera")
            camera_loss_dict = compute_camera_loss(predictions, batch, **camera_cfg)
            camera_loss = camera_loss_dict["loss_camera"] * camera_cfg["weight"]
            total_loss = total_loss + camera_loss
            loss_dict.update(camera_loss_dict)

        # Depth loss: confidence + regression + gradient terms
        if "depth" in predictions:
            depth_cfg = self._task_config("depth")
            depth_loss_dict = compute_depth_loss(predictions, batch, **depth_cfg)
            depth_loss = (
                depth_loss_dict["loss_conf_depth"]
                + depth_loss_dict["loss_reg_depth"]
                + depth_loss_dict["loss_grad_depth"]
            )
            total_loss = total_loss + depth_loss * depth_cfg["weight"]
            loss_dict.update(depth_loss_dict)

        # World-point loss: confidence + regression + gradient terms
        if "world_points" in predictions:
            point_cfg = self._task_config("point")
            point_loss_dict = compute_point_loss(predictions, batch, **point_cfg)
            point_loss = (
                point_loss_dict["loss_conf_point"]
                + point_loss_dict["loss_reg_point"]
                + point_loss_dict["loss_grad_point"]
            )
            total_loss = total_loss + point_loss * point_cfg["weight"]
            loss_dict.update(point_loss_dict)

        if "track" in predictions:
            raise NotImplementedError("Track loss is not cleaned up yet")

        loss_dict["objective"] = total_loss

        return loss_dict
|
|
|
|
|
|
|
|
def compute_camera_loss(
    pred_dict,
    batch_data,
    loss_type="l1",
    gamma=0.6,
    pose_encoding_type="absT_quaR_FoV",
    weight_trans=1.0,
    weight_rot=1.0,
    weight_focal=0.5,
    **kwargs
):
    """
    Compute the camera loss over every entry of ``pred_dict['pose_enc_list']``.

    Stage ``i`` of ``n`` is weighted by ``gamma ** (n - i - 1)``, so the last
    stage always has weight 1. The per-stage translation / rotation / focal
    losses come from ``camera_loss_single`` and are averaged over the stages.

    Args:
        pred_dict: Dict containing 'pose_enc_list', a sequence of predicted
            pose encodings (one per stage)
        batch_data: Dict containing 'point_masks', 'extrinsics', 'intrinsics'
            and 'images'
        loss_type: Passed through to ``camera_loss_single`` ("l1" or "l2")
        gamma: Base of the per-stage weight
        pose_encoding_type: Encoding type passed to
            ``extri_intri_to_pose_encoding`` for the ground truth
        weight_trans: Weight of the translation term in 'loss_camera'
        weight_rot: Weight of the rotation term in 'loss_camera'
        weight_focal: Weight of the focal term in 'loss_camera'
        **kwargs: Ignored (lets callers pass a whole config mapping)

    Returns:
        Dict with 'loss_camera' (weighted sum) and the stage-averaged
        'loss_T', 'loss_R' and 'loss_FL'.
    """
    pred_pose_encodings = pred_dict['pose_enc_list']
    point_masks = batch_data['point_masks']

    # Keep only batch entries whose index-0 mask along dim 1 has more than
    # 100 valid pixels (presumably the first frame of each sequence -- confirm).
    valid_frame_mask = point_masks[:, 0].sum(dim=[-1, -2]) > 100
    n_stages = len(pred_pose_encodings)

    # Ground-truth pose encoding, built once from the batch cameras
    gt_extrinsics = batch_data['extrinsics']
    gt_intrinsics = batch_data['intrinsics']
    image_hw = batch_data['images'].shape[-2:]
    gt_pose_encoding = extri_intri_to_pose_encoding(
        gt_extrinsics, gt_intrinsics, image_hw, pose_encoding_type=pose_encoding_type
    )

    # Both of these are loop-invariant, so evaluate them once instead of
    # repeating the tensor-to-bool check and the mask indexing in every stage.
    has_valid_frames = bool(valid_frame_mask.any())
    gt_pose_valid = gt_pose_encoding[valid_frame_mask] if has_valid_frames else None

    total_loss_T = total_loss_R = total_loss_FL = 0

    for stage_idx, pred_pose_stage in enumerate(pred_pose_encodings):
        stage_weight = gamma ** (n_stages - stage_idx - 1)

        if has_valid_frames:
            loss_T_stage, loss_R_stage, loss_FL_stage = camera_loss_single(
                pred_pose_stage[valid_frame_mask].clone(),
                gt_pose_valid.clone(),
                loss_type=loss_type
            )
        else:
            # Zero-valued loss that still depends on the prediction, so the
            # returned value stays connected to the prediction tensor.
            zero_loss = (pred_pose_stage * 0).mean()
            loss_T_stage = loss_R_stage = loss_FL_stage = zero_loss

        total_loss_T += loss_T_stage * stage_weight
        total_loss_R += loss_R_stage * stage_weight
        total_loss_FL += loss_FL_stage * stage_weight

    # Average over the stages
    avg_loss_T = total_loss_T / n_stages
    avg_loss_R = total_loss_R / n_stages
    avg_loss_FL = total_loss_FL / n_stages

    total_camera_loss = (
        avg_loss_T * weight_trans +
        avg_loss_R * weight_rot +
        avg_loss_FL * weight_focal
    )

    return {
        "loss_camera": total_camera_loss,
        "loss_T": avg_loss_T,
        "loss_R": avg_loss_R,
        "loss_FL": avg_loss_FL
    }
|
|
|
|
|
def camera_loss_single(pred_pose_enc, gt_pose_enc, loss_type="l1"):
    """
    Translation, rotation and focal losses between two pose encodings.

    The last axis is split as ``[:3]`` (translation), ``[3:7]`` (rotation)
    and ``[7:]`` (focal length / intrinsics).

    Args:
        pred_pose_enc: (N, D) predicted pose encoding
        gt_pose_enc: (N, D) ground truth pose encoding
        loss_type: "l1" (abs error) or "l2" (euclidean error)

    Returns:
        loss_T: translation loss (mean, each element clamped to at most 100)
        loss_R: rotation loss (mean)
        loss_FL: focal length/intrinsics loss (mean)

    Raises:
        ValueError: If ``loss_type`` is neither "l1" nor "l2".

    NOTE: The paper uses smooth l1 loss, but we found l1 loss is more stable than smooth l1 and l2 loss.
        So here we use l1 loss.
    """
    if loss_type not in ("l1", "l2"):
        raise ValueError(f"Unknown loss type: {loss_type}")

    # Per-component residuals
    err_T = pred_pose_enc[..., :3] - gt_pose_enc[..., :3]
    err_R = pred_pose_enc[..., 3:7] - gt_pose_enc[..., 3:7]
    err_FL = pred_pose_enc[..., 7:] - gt_pose_enc[..., 7:]

    if loss_type == "l1":
        # Element-wise absolute error
        loss_T, loss_R, loss_FL = err_T.abs(), err_R.abs(), err_FL.abs()
    else:
        # Euclidean norm over the component axis
        loss_T = err_T.norm(dim=-1, keepdim=True)
        loss_R = err_R.norm(dim=-1)
        loss_FL = err_FL.norm(dim=-1)

    # Sanitise non-finite values before reducing
    loss_T = check_and_fix_inf_nan(loss_T, "loss_T")
    loss_R = check_and_fix_inf_nan(loss_R, "loss_R")
    loss_FL = check_and_fix_inf_nan(loss_FL, "loss_FL")

    # Only the translation term is clamped
    return loss_T.clamp(max=100).mean(), loss_R.mean(), loss_FL.mean()
|
|
|
|
|
|
|
|
def compute_point_loss(predictions, batch, gamma=1.0, alpha=0.2, gradient_loss_fn = None, valid_range=-1, **kwargs):
    """
    World-point loss: confidence, regression and gradient terms.

    Args:
        predictions: Dict containing 'world_points' and 'world_points_conf'
        batch: Dict containing ground truth 'world_points' and 'point_masks'
        gamma: Weight for confidence loss
        alpha: Weight for confidence regularization
        gradient_loss_fn: Type of gradient loss to apply
        valid_range: Quantile range for outlier filtering
        **kwargs: Ignored

    Returns:
        Dict with 'loss_conf_point', 'loss_reg_point' and 'loss_grad_point'.
        All three are the same zero-valued tensor when the mask has fewer
        than 100 valid entries.
    """
    pred_points = predictions['world_points']
    pred_points_conf = predictions['world_points_conf']
    gt_points = check_and_fix_inf_nan(batch['world_points'], "gt_points")
    gt_points_mask = batch['point_masks']

    # Too few valid points: return zeros that still depend on the prediction
    if gt_points_mask.sum() < 100:
        zero = (0.0 * pred_points).mean()
        return dict.fromkeys(("loss_conf_point", "loss_reg_point", "loss_grad_point"), zero)

    conf_term, grad_term, reg_term = regression_loss(
        pred_points,
        gt_points,
        gt_points_mask,
        conf=pred_points_conf,
        gradient_loss_fn=gradient_loss_fn,
        gamma=gamma,
        alpha=alpha,
        valid_range=valid_range,
    )

    return {
        "loss_conf_point": conf_term,
        "loss_reg_point": reg_term,
        "loss_grad_point": grad_term,
    }
|
|
|
|
|
|
|
|
def compute_depth_loss(predictions, batch, gamma=1.0, alpha=0.2, gradient_loss_fn = None, valid_range=-1, **kwargs):
    """
    Depth loss: confidence, regression and gradient terms.

    Args:
        predictions: Dict containing 'depth' and 'depth_conf'
        batch: Dict containing ground truth 'depths' and 'point_masks'
        gamma: Weight for confidence loss
        alpha: Weight for confidence regularization
        gradient_loss_fn: Type of gradient loss to apply
        valid_range: Quantile range for outlier filtering
        **kwargs: Ignored

    Returns:
        Dict with 'loss_conf_depth', 'loss_reg_depth' and 'loss_grad_depth'.
        All three are the same zero-valued tensor when the mask has fewer
        than 100 valid entries.
    """
    pred_depth = predictions['depth']
    pred_depth_conf = predictions['depth_conf']

    # Sanitise the ground truth, then append a trailing channel axis
    gt_depth = check_and_fix_inf_nan(batch['depths'], "gt_depth")[..., None]
    gt_depth_mask = batch['point_masks'].clone()

    # Too few valid pixels: return zeros that still depend on the prediction
    if gt_depth_mask.sum() < 100:
        zero = (0.0 * pred_depth).mean()
        return dict.fromkeys(("loss_conf_depth", "loss_reg_depth", "loss_grad_depth"), zero)

    conf_term, grad_term, reg_term = regression_loss(
        pred_depth,
        gt_depth,
        gt_depth_mask,
        conf=pred_depth_conf,
        gradient_loss_fn=gradient_loss_fn,
        gamma=gamma,
        alpha=alpha,
        valid_range=valid_range,
    )

    return {
        "loss_conf_depth": conf_term,
        "loss_reg_depth": reg_term,
        "loss_grad_depth": grad_term,
    }
|
|
|
|
|
|
|
|
def regression_loss(pred, gt, mask, conf=None, gradient_loss_fn=None, gamma=1.0, alpha=0.2, valid_range=-1):
    """
    Core regression loss function with confidence weighting and optional gradient loss.

    Computes:
        1. gamma * ||pred - gt|| * conf - alpha * log(conf)
           (Euclidean norm over the channel axis, masked pixels only)
        2. Optional gradient loss

    Args:
        pred: (B, S, H, W, C) predicted values
        gt: (B, S, H, W, C) ground truth values
        mask: (B, S, H, W) valid pixel mask
        conf: (B, S, H, W) confidence weights (optional). When None the
            confidence is treated as 1 everywhere, so the confidence loss
            reduces to ``gamma * ||pred - gt||`` and no confidence is passed
            to the gradient loss.
        gradient_loss_fn: Type of gradient loss, matched by substring:
            "normal" selects the normal loss (3 scales), otherwise "grad"
            selects the gradient loss; "conf" additionally feeds the
            confidence to it. None (or no match) disables the gradient loss.
        gamma: Weight for confidence loss
        alpha: Weight for confidence regularization
        valid_range: Quantile range for outlier filtering (applied when > 0)

    Returns:
        loss_conf: Confidence-weighted loss
        loss_grad: Gradient loss (0 if not specified)
        loss_reg: Mean Euclidean regression error
    """
    bb, ss, hh, ww, nc = pred.shape

    # The default None used to crash on the substring checks below;
    # treat it as "no gradient loss".
    grad_mode = "" if gradient_loss_fn is None else gradient_loss_fn

    # Per-pixel Euclidean error over the valid pixels
    loss_reg = torch.norm(gt[mask] - pred[mask], dim=-1)
    loss_reg = check_and_fix_inf_nan(loss_reg, "loss_reg")

    # Confidence-weighted error with a log regulariser
    if conf is None:
        loss_conf = gamma * loss_reg
    else:
        conf_valid = conf[mask]
        loss_conf = gamma * loss_reg * conf_valid - alpha * torch.log(conf_valid)
    loss_conf = check_and_fix_inf_nan(loss_conf, "loss_conf")

    # Confidence is only handed to the gradient loss when requested by name
    to_feed_conf = None
    if conf is not None and "conf" in grad_mode:
        to_feed_conf = conf.reshape(bb*ss, hh, ww)

    # Select the gradient loss; "normal" takes precedence over "grad"
    if "normal" in grad_mode:
        grad_fn, grad_kwargs = normal_loss, {"scales": 3}
    elif "grad" in grad_mode:
        grad_fn, grad_kwargs = gradient_loss, {}
    else:
        grad_fn, grad_kwargs = None, {}

    loss_grad = 0
    if grad_fn is not None:
        loss_grad = gradient_loss_multi_scale_wrapper(
            pred.reshape(bb*ss, hh, ww, nc),
            gt.reshape(bb*ss, hh, ww, nc),
            mask.reshape(bb*ss, hh, ww),
            gradient_loss_fn=grad_fn,
            conf=to_feed_conf,
            **grad_kwargs,
        )

    def _reduce(values, label):
        """Optionally drop outliers, sanitise, and average a per-pixel loss."""
        if values.numel() == 0:
            # Nothing valid: zero that still depends on the prediction
            return (0.0 * pred).mean()
        if valid_range > 0:
            values = filter_by_quantile(values, valid_range)
        values = check_and_fix_inf_nan(values, label)
        return values.mean()

    loss_conf = _reduce(loss_conf, "loss_conf_depth")
    loss_reg = _reduce(loss_reg, "loss_reg_depth")

    return loss_conf, loss_grad, loss_reg
|
|
|
|
|
|
|
|
def gradient_loss_multi_scale_wrapper(prediction, target, mask, scales=4, gradient_loss_fn = None, conf=None):
    """
    Average ``gradient_loss_fn`` over several subsampled copies of the input.

    Scale ``s`` keeps every ``2**s``-th row and column of the inputs, so
    scale 0 is the full resolution.

    Args:
        prediction: (B, H, W, C) predicted values
        target: (B, H, W, C) ground truth values
        mask: (B, H, W) valid pixel mask
        scales: Number of scales to use
        gradient_loss_fn: Callable invoked as
            ``gradient_loss_fn(prediction, target, mask, conf=...)``
        conf: (B, H, W) confidence weights (optional)

    Returns:
        Sum of the per-scale losses divided by ``scales``.
    """
    accumulated = 0
    for level in range(scales):
        stride = 2 ** level

        # Subsample the confidence only when one was supplied
        conf_level = None if conf is None else conf[:, ::stride, ::stride]

        accumulated += gradient_loss_fn(
            prediction[:, ::stride, ::stride],
            target[:, ::stride, ::stride],
            mask[:, ::stride, ::stride],
            conf=conf_level,
        )

    return accumulated / scales
|
|
|
|
|
|
|
|
def normal_loss(prediction, target, mask, cos_eps=1e-8, conf=None, gamma=1.0, alpha=0.2):
    """
    Surface-normal loss between a predicted and a ground-truth point map.

    Normals are derived from both point maps with ``point_map_to_normal``;
    the loss is ``1 - cos`` between corresponding normals, averaged over the
    positions valid in both maps.

    Args:
        prediction: (B, H, W, 3) predicted 3D coordinates/points
        target: (B, H, W, 3) ground-truth 3D coordinates/points
        mask: (B, H, W) valid pixel mask
        cos_eps: Epsilon for numerical stability in cosine computation
        conf: (B, H, W) confidence weights (optional)
        gamma: Weight for confidence loss
        alpha: Weight for confidence regularization

    Returns:
        Mean loss, or the integer 0 when fewer than 10 normals are valid.
    """
    pred_normals, pred_valids = point_map_to_normal(prediction, mask, eps=cos_eps)
    gt_normals, gt_valids = point_map_to_normal(target, mask, eps=cos_eps)

    # A normal counts only if it is valid for both prediction and ground truth
    both_valid = pred_valids & gt_valids
    if torch.sum(both_valid) < 10:
        return 0

    picked_pred = pred_normals[both_valid].clone()
    picked_gt = gt_normals[both_valid].clone()

    # Cosine between the normals, kept strictly inside [-1, 1]
    cosine = (picked_pred * picked_gt).sum(dim=-1)
    cosine = cosine.clamp(-1 + cos_eps, 1 - cos_eps)
    per_normal = 1 - cosine

    if per_normal.numel() < 10:
        return 0
    per_normal = check_and_fix_inf_nan(per_normal, "normal_loss")

    if conf is None:
        return per_normal.mean()

    # Repeat the confidence for the 4 normal directions, then select as above
    conf_valid = conf[None, ...].expand(4, -1, -1, -1)[both_valid].clone()
    weighted = gamma * per_normal * conf_valid - alpha * torch.log(conf_valid)
    return weighted.mean()
|
|
|
|
|
|
|
|
def gradient_loss(prediction, target, mask, conf=None, gamma=1.0, alpha=0.2):
    """
    Gradient-based loss on the prediction error.

    Takes the absolute difference of the masked error ``prediction - target``
    between horizontally and vertically adjacent pixels. A pair contributes
    only when both of its pixels are valid. The summed loss is divided by the
    number of valid elements (mask expanded over the channel axis).

    When ``conf`` is given, each pair term becomes
    ``gamma * term * conf - alpha * log(conf)``, using the confidence of the
    second pixel of the pair. The whole expression, including the ``log``
    regulariser, is restricted to valid pairs so masked-out pixels do not
    contribute to the loss.

    Args:
        prediction: (B, H, W, C) predicted values
        target: (B, H, W, C) ground truth values
        mask: (B, H, W) valid pixel mask
        conf: (B, H, W) confidence weights (optional)
        gamma: Weight for confidence loss
        alpha: Weight for confidence regularization

    Returns:
        Scalar loss tensor, or the integer 0 when the mask has no valid element.
    """
    # Broadcast the mask over the channel axis
    mask = mask[..., None].expand(-1, -1, -1, prediction.shape[-1])
    M = torch.sum(mask, (1, 2, 3))

    divisor = torch.sum(M)
    if divisor == 0:
        return 0

    # Error map with invalid pixels zeroed out
    diff = torch.mul(mask, prediction - target)

    # Horizontal neighbours (along W)
    mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1])
    grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1])
    grad_x = torch.mul(mask_x, grad_x)

    # Vertical neighbours (along H)
    mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :])
    grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :])
    grad_y = torch.mul(mask_y, grad_y)

    # Cap individual terms so single outliers cannot dominate
    grad_x = grad_x.clamp(max=100)
    grad_y = grad_y.clamp(max=100)

    if conf is not None:
        conf = conf[..., None].expand(-1, -1, -1, prediction.shape[-1])
        conf_x = conf[:, :, 1:]
        conf_y = conf[:, 1:, :]

        weighted_x = gamma * grad_x * conf_x - alpha * torch.log(conf_x)
        weighted_y = gamma * grad_y * conf_y - alpha * torch.log(conf_y)

        # Keep the confidence term only on valid pairs; previously the
        # -alpha*log(conf) regulariser was also summed over masked-out pairs.
        grad_x = torch.where(mask_x.bool(), weighted_x, torch.zeros_like(weighted_x))
        grad_y = torch.where(mask_y.bool(), weighted_y, torch.zeros_like(weighted_y))

    grad_loss = torch.sum(grad_x, (1, 2, 3)) + torch.sum(grad_y, (1, 2, 3))

    return torch.sum(grad_loss) / divisor
|
|
|
|
|
|
|
|
def point_map_to_normal(point_map, mask, eps=1e-6):
    """
    Estimate surface normals of a 3D point map from its 4-neighbourhood.

    For every pixel the offsets to its up / left / down / right neighbours
    are computed (the map is zero-padded by one pixel on each side), and four
    normals are formed as the cross products
    (up, left), (left, down), (down, right) and (right, up).
    A normal is valid when the pixel and both neighbours used are valid.

    Args:
        point_map: (B, H, W, 3) 3D points laid out in a 2D grid
        mask: (B, H, W) valid pixels (bool)
        eps: Epsilon for numerical stability in normalization

    Returns:
        normals: (4, B, H, W, 3) unit normal vectors for each of the 4 cross-product directions
        valids: (4, B, H, W) corresponding valid masks
    """
    with torch.cuda.amp.autocast(enabled=False):
        # One-pixel zero border so that border pixels have (invalid) neighbours
        border = (1, 1, 1, 1)
        mask_pad = F.pad(mask, border, mode='constant', value=0)
        pts_pad = F.pad(point_map.permute(0, 3, 1, 2), border, mode='constant', value=0).permute(0, 2, 3, 1)

        # Row/column windows into the padded grid for each neighbour
        inner = slice(1, -1)
        before = slice(None, -2)
        after = slice(2, None)
        windows = {
            "up": (before, inner),
            "left": (inner, before),
            "down": (after, inner),
            "right": (inner, after),
        }

        center_pts = pts_pad[:, inner, inner, :]
        center_ok = mask_pad[:, inner, inner]

        # Offset vectors from the centre to each neighbour, and neighbour validity
        offsets = {}
        neighbour_ok = {}
        for name, (rows, cols) in windows.items():
            offsets[name] = pts_pad[:, rows, cols, :] - center_pts
            neighbour_ok[name] = mask_pad[:, rows, cols]

        # Consecutive neighbour pairs going around the centre pixel
        pairs = (("up", "left"), ("left", "down"), ("down", "right"), ("right", "up"))

        normals = torch.stack(
            [torch.cross(offsets[a], offsets[b], dim=-1) for a, b in pairs], dim=0
        )
        valids = torch.stack(
            [neighbour_ok[a] & center_ok & neighbour_ok[b] for a, b in pairs], dim=0
        )

        # Scale every normal to unit length
        normals = F.normalize(normals, p=2, dim=-1, eps=eps)

    return normals, valids
|
|
|
|
|
|
|
|
def filter_by_quantile(loss_tensor, valid_range, min_elements=1000, hard_max=100):
    """
    Drop the largest loss values using a quantile threshold.

    Tensors with at most ``min_elements`` entries are returned untouched.
    Otherwise values are clamped to ``hard_max`` and only entries strictly
    below the ``valid_range`` quantile are kept, provided more than
    ``min_elements`` entries survive; if not, the clamped tensor is returned.
    Tensors with more than 100,000,000 entries are first reduced to a random
    subset of 1,000,000 entries.

    Args:
        loss_tensor: Tensor containing loss values
        valid_range: Float between 0 and 1 indicating the quantile threshold
        min_elements: Minimum number of elements required to apply filtering
        hard_max: Maximum allowed value for any individual loss

    Returns:
        Filtered and clamped loss tensor
    """
    total = loss_tensor.numel()
    if total <= min_elements:
        # Too few values for a meaningful quantile
        return loss_tensor

    if total > 100000000:
        # Work on a random subset of very large tensors
        picked = torch.randperm(total, device=loss_tensor.device)[:1_000_000]
        loss_tensor = loss_tensor.view(-1)[picked]

    # Hard cap on individual values
    loss_tensor = loss_tensor.clamp(max=hard_max)

    # Threshold is computed on detached values
    cutoff = min(torch_quantile(loss_tensor.detach(), valid_range), hard_max)

    keep = loss_tensor < cutoff
    return loss_tensor[keep] if keep.sum() > min_elements else loss_tensor
|
|
|
|
|
|
|
|
def torch_quantile(
    input,
    q,
    dim = None,
    keepdim: bool = False,
    *,
    interpolation: str = "nearest",
    out: torch.Tensor = None,
) -> torch.Tensor:
    """Better torch.quantile for one SCALAR quantile.

    Using torch.kthvalue. Better than torch.quantile because:
        - No 2**24 input size limit (pytorch/issues/67592),
        - Much faster, at least on big input sizes.

    The quantile is the ``k``-th smallest value along ``dim`` with
    ``k = inter(q * (size - 1)) + 1``, where ``inter`` is ``round``, ``floor``
    or ``ceil`` depending on ``interpolation``.

    Arguments:
        input (torch.Tensor): See torch.quantile.
        q (float): See torch.quantile. Supports only scalar input
            currently.
        dim (int | None): See torch.quantile. None flattens the input.
        keepdim (bool): See torch.quantile.
        interpolation: {"nearest", "lower", "higher"}
            See torch.quantile.
        out (torch.Tensor | None): See torch.quantile. Supports only
            None currently.

    Raises:
        ValueError: If ``q`` is not a scalar in [0, 1], ``interpolation`` is
            unsupported, or ``out`` is not None.
    """
    # Validate q explicitly: an `assert` would be stripped under `python -O`
    # and let out-of-range values through.
    try:
        q_value = float(q)
    except Exception as err:  # any conversion failure means "not a scalar"
        raise ValueError(f"Only scalar input 0<=q<=1 is currently supported (got {q})!") from err
    if not 0 <= q_value <= 1:
        raise ValueError(f"Only scalar input 0<=q<=1 is currently supported (got {q_value})!")

    # dim=None: flatten into the first axis, keeping the original rank
    dim_was_none = dim is None
    if dim_was_none:
        dim = 0
        input = input.reshape((-1,) + (1,) * (input.ndim - 1))

    # Rounding rule that turns the fractional rank into an integer index
    if interpolation == "nearest":
        inter = round
    elif interpolation == "lower":
        inter = floor
    elif interpolation == "higher":
        inter = ceil
    else:
        raise ValueError(
            "Supported interpolations currently are {'nearest', 'lower', 'higher'} "
            f"(got '{interpolation}')!"
        )

    if out is not None:
        raise ValueError(f"Only None value is currently supported for out (got {out})!")

    # kthvalue is 1-indexed
    k = inter(q_value * (input.shape[dim] - 1)) + 1
    result = torch.kthvalue(input, k, dim, keepdim=True)[0]

    if keepdim:
        return result
    if dim_was_none:
        return result.squeeze()
    return result.squeeze(dim)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
''' |
|
|
def _compute_losses(self, coord_preds, vis_scores, conf_scores, batch): |
|
|
"""Compute tracking losses using sequence_loss""" |
|
|
gt_tracks = batch["tracks"] # B, S, N, 2 |
|
|
gt_track_vis_mask = batch["track_vis_mask"] # B, S, N |
|
|
|
|
|
# if self.training and hasattr(self, "train_query_points"): |
|
|
train_query_points = coord_preds[-1].shape[2] |
|
|
gt_tracks = gt_tracks[:, :, :train_query_points] |
|
|
gt_tracks = check_and_fix_inf_nan(gt_tracks, "gt_tracks", hard_max=None) |
|
|
|
|
|
gt_track_vis_mask = gt_track_vis_mask[:, :, :train_query_points] |
|
|
|
|
|
# Create validity mask that filters out tracks not visible in first frame |
|
|
valids = torch.ones_like(gt_track_vis_mask) |
|
|
mask = gt_track_vis_mask[:, 0, :] == True |
|
|
valids = valids * mask.unsqueeze(1) |
|
|
|
|
|
|
|
|
|
|
|
if not valids.any(): |
|
|
print("No valid tracks found in first frame") |
|
|
print("seq_name: ", batch["seq_name"]) |
|
|
print("ids: ", batch["ids"]) |
|
|
print("time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) |
|
|
|
|
|
dummy_coord = coord_preds[0].mean() * 0 # keeps graph & grads |
|
|
dummy_vis = vis_scores.mean() * 0 |
|
|
if conf_scores is not None: |
|
|
dummy_conf = conf_scores.mean() * 0 |
|
|
else: |
|
|
dummy_conf = 0 |
|
|
return dummy_coord, dummy_vis, dummy_conf # three scalar zeros |
|
|
|
|
|
|
|
|
# Compute tracking loss using sequence_loss |
|
|
track_loss = sequence_loss( |
|
|
flow_preds=coord_preds, |
|
|
flow_gt=gt_tracks, |
|
|
vis=gt_track_vis_mask, |
|
|
valids=valids, |
|
|
**self.loss_kwargs |
|
|
) |
|
|
|
|
|
vis_loss = F.binary_cross_entropy_with_logits(vis_scores[valids], gt_track_vis_mask[valids].float()) |
|
|
|
|
|
vis_loss = check_and_fix_inf_nan(vis_loss, "vis_loss", hard_max=None) |
|
|
|
|
|
|
|
|
# within 3 pixels |
|
|
if conf_scores is not None: |
|
|
gt_conf_mask = (gt_tracks - coord_preds[-1]).norm(dim=-1) < 3 |
|
|
conf_loss = F.binary_cross_entropy_with_logits(conf_scores[valids], gt_conf_mask[valids].float()) |
|
|
conf_loss = check_and_fix_inf_nan(conf_loss, "conf_loss", hard_max=None) |
|
|
else: |
|
|
conf_loss = 0 |
|
|
|
|
|
return track_loss, vis_loss, conf_loss |
|
|
|
|
|
|
|
|
|
|
|
def reduce_masked_mean(x, mask, dim=None, keepdim=False): |
|
|
for a, b in zip(x.size(), mask.size()): |
|
|
assert a == b |
|
|
prod = x * mask |
|
|
|
|
|
if dim is None: |
|
|
numer = torch.sum(prod) |
|
|
denom = torch.sum(mask) |
|
|
else: |
|
|
numer = torch.sum(prod, dim=dim, keepdim=keepdim) |
|
|
denom = torch.sum(mask, dim=dim, keepdim=keepdim) |
|
|
|
|
|
mean = numer / denom.clamp(min=1) |
|
|
mean = torch.where(denom > 0, |
|
|
mean, |
|
|
torch.zeros_like(mean)) |
|
|
return mean |
|
|
|
|
|
|
|
|
def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8, vis_aware=False, huber=False, delta=10, vis_aware_w=0.1, **kwargs): |
|
|
"""Loss function defined over sequence of flow predictions""" |
|
|
B, S, N, D = flow_gt.shape |
|
|
assert D == 2 |
|
|
B, S1, N = vis.shape |
|
|
B, S2, N = valids.shape |
|
|
assert S == S1 |
|
|
assert S == S2 |
|
|
n_predictions = len(flow_preds) |
|
|
flow_loss = 0.0 |
|
|
|
|
|
for i in range(n_predictions): |
|
|
i_weight = gamma ** (n_predictions - i - 1) |
|
|
flow_pred = flow_preds[i] |
|
|
|
|
|
i_loss = (flow_pred - flow_gt).abs() # B, S, N, 2 |
|
|
i_loss = check_and_fix_inf_nan(i_loss, f"i_loss_iter_{i}", hard_max=None) |
|
|
|
|
|
i_loss = torch.mean(i_loss, dim=3) # B, S, N |
|
|
|
|
|
# Combine valids and vis for per-frame valid masking. |
|
|
combined_mask = torch.logical_and(valids, vis) |
|
|
|
|
|
num_valid_points = combined_mask.sum() |
|
|
|
|
|
if vis_aware: |
|
|
combined_mask = combined_mask.float() * (1.0 + vis_aware_w) # Add, don't add to the mask itself. |
|
|
flow_loss += i_weight * reduce_masked_mean(i_loss, combined_mask) |
|
|
else: |
|
|
if num_valid_points > 2: |
|
|
i_loss = i_loss[combined_mask] |
|
|
flow_loss += i_weight * i_loss.mean() |
|
|
else: |
|
|
i_loss = check_and_fix_inf_nan(i_loss, f"i_loss_iter_safe_check_{i}", hard_max=None) |
|
|
flow_loss += 0 * i_loss.mean() |
|
|
|
|
|
# Avoid division by zero if n_predictions is 0 (though it shouldn't be). |
|
|
if n_predictions > 0: |
|
|
flow_loss = flow_loss / n_predictions |
|
|
|
|
|
return flow_loss |
|
|
''' |
|
|
|
|
|
|
|
|
|