File size: 30,473 Bytes
66003a2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from dataclasses import dataclass
from vggt.utils.pose_enc import extri_intri_to_pose_encoding
from train_utils.general import check_and_fix_inf_nan
from math import ceil, floor
@dataclass(eq=False)
class MultitaskLoss(torch.nn.Module):
    """
    Combined loss module for VGGT that aggregates per-task losses.

    Supported tasks:
    - Camera pose loss
    - Depth loss
    - Point-map loss
    - Tracking loss (not cleaned up yet; dirty code at the bottom of this file)
    """

    def __init__(self, camera=None, depth=None, point=None, track=None, **kwargs):
        super().__init__()
        # Per-task configuration dicts; each carries a "weight" entry plus
        # keyword arguments forwarded to that task's loss function.
        self.camera = camera
        self.depth = depth
        self.point = point
        self.track = track

    def forward(self, predictions, batch) -> torch.Tensor:
        """
        Compute the weighted sum of all active task losses.

        Args:
            predictions: dict of model outputs; which keys are present
                determines which task losses are evaluated
            batch: dict of ground-truth tensors and validity masks

        Returns:
            dict with every task's sub-losses plus the scalar "objective"
        """
        objective = 0
        results = {}

        # Camera pose loss — active when pose encodings were predicted.
        if "pose_enc_list" in predictions:
            cam_losses = compute_camera_loss(predictions, batch, **self.camera)
            objective = objective + cam_losses["loss_camera"] * self.camera["weight"]
            results.update(cam_losses)

        # Depth estimation loss — active when depth maps were predicted.
        if "depth" in predictions:
            depth_losses = compute_depth_loss(predictions, batch, **self.depth)
            depth_sum = (
                depth_losses["loss_conf_depth"]
                + depth_losses["loss_reg_depth"]
                + depth_losses["loss_grad_depth"]
            )
            objective = objective + depth_sum * self.depth["weight"]
            results.update(depth_losses)

        # 3D point reconstruction loss — active when world points were predicted.
        if "world_points" in predictions:
            point_losses = compute_point_loss(predictions, batch, **self.point)
            point_sum = (
                point_losses["loss_conf_point"]
                + point_losses["loss_reg_point"]
                + point_losses["loss_grad_point"]
            )
            objective = objective + point_sum * self.point["weight"]
            results.update(point_losses)

        # Tracking loss — intentionally disabled; see the commented-out code
        # at the bottom of this file.
        if "track" in predictions:
            raise NotImplementedError("Track loss is not cleaned up yet")

        results["objective"] = objective
        return results
def compute_camera_loss(
    pred_dict,          # predictions dict, contains pose encodings
    batch_data,         # ground truth and mask batch dict
    loss_type="l1",     # "l1" or "l2" loss
    gamma=0.6,          # temporal decay weight for multi-stage training
    pose_encoding_type="absT_quaR_FoV",
    weight_trans=1.0,   # weight for translation loss
    weight_rot=1.0,     # weight for rotation loss
    weight_focal=0.5,   # weight for focal length loss
    **kwargs
):
    """
    Multi-stage camera pose loss against encoded ground-truth extrinsics/intrinsics.

    Later prediction stages receive exponentially larger weights (gamma decay),
    and frames with too few valid points are excluded from supervision.
    """
    pose_enc_stages = pred_dict['pose_enc_list']

    # Frame is usable only when its first-view point mask has >100 valid pixels.
    masks = batch_data['point_masks']
    usable_frames = masks[:, 0].sum(dim=[-1, -2]) > 100

    num_stages = len(pose_enc_stages)

    # Encode the ground-truth cameras into the same representation as the predictions.
    gt_pose_enc = extri_intri_to_pose_encoding(
        batch_data['extrinsics'],
        batch_data['intrinsics'],
        batch_data['images'].shape[-2:],
        pose_encoding_type=pose_encoding_type,
    )

    sum_T = sum_R = sum_FL = 0
    has_usable = usable_frames.sum() != 0

    for idx, stage_pred in enumerate(pose_enc_stages):
        # Final stage gets weight gamma^0 = 1.0; earlier stages decay.
        decay = gamma ** (num_stages - idx - 1)

        if not has_usable:
            # No supervisable frames: emit zeros that keep the graph alive.
            zero = (stage_pred * 0).mean()
            t_loss = r_loss = fl_loss = zero
        else:
            t_loss, r_loss, fl_loss = camera_loss_single(
                stage_pred[usable_frames].clone(),
                gt_pose_enc[usable_frames].clone(),
                loss_type=loss_type,
            )

        sum_T += t_loss * decay
        sum_R += r_loss * decay
        sum_FL += fl_loss * decay

    # Average the weighted per-stage losses.
    loss_T = sum_T / num_stages
    loss_R = sum_R / num_stages
    loss_FL = sum_FL / num_stages

    loss_camera = loss_T * weight_trans + loss_R * weight_rot + loss_FL * weight_focal

    return {
        "loss_camera": loss_camera,
        "loss_T": loss_T,
        "loss_R": loss_R,
        "loss_FL": loss_FL,
    }
def camera_loss_single(pred_pose_enc, gt_pose_enc, loss_type="l1"):
    """
    Compute translation, rotation, and focal losses for a batch of pose encodings.

    Encoding layout: dims [0:3] translation, [3:7] rotation quaternion,
    [7:] focal length / intrinsics.

    Args:
        pred_pose_enc: (N, D) predicted pose encoding
        gt_pose_enc: (N, D) ground truth pose encoding
        loss_type: "l1" (abs error) or "l2" (euclidean error)

    Returns:
        loss_T: translation loss (mean)
        loss_R: rotation loss (mean)
        loss_FL: focal length/intrinsics loss (mean)

    Raises:
        ValueError: if loss_type is neither "l1" nor "l2"

    NOTE: The paper uses smooth l1 loss, but we found l1 loss is more stable
    than smooth l1 and l2 loss. So here we use l1 loss.
    """
    if loss_type == "l1":
        loss_T = (pred_pose_enc[..., :3] - gt_pose_enc[..., :3]).abs()
        loss_R = (pred_pose_enc[..., 3:7] - gt_pose_enc[..., 3:7]).abs()
        loss_FL = (pred_pose_enc[..., 7:] - gt_pose_enc[..., 7:]).abs()
    elif loss_type == "l2":
        # FIX: dropped a stray keepdim=True that only the translation norm had,
        # so all three components now have consistent shapes. The reduced mean
        # value is unchanged.
        loss_T = (pred_pose_enc[..., :3] - gt_pose_enc[..., :3]).norm(dim=-1)
        loss_R = (pred_pose_enc[..., 3:7] - gt_pose_enc[..., 3:7]).norm(dim=-1)
        loss_FL = (pred_pose_enc[..., 7:] - gt_pose_enc[..., 7:]).norm(dim=-1)
    else:
        raise ValueError(f"Unknown loss type: {loss_type}")

    # Guard each component against nan/inf before reduction.
    loss_T = check_and_fix_inf_nan(loss_T, "loss_T")
    loss_R = check_and_fix_inf_nan(loss_R, "loss_R")
    loss_FL = check_and_fix_inf_nan(loss_FL, "loss_FL")

    # Clamp translation outliers to keep training stable, then average.
    loss_T = loss_T.clamp(max=100).mean()
    loss_R = loss_R.mean()
    loss_FL = loss_FL.mean()

    return loss_T, loss_R, loss_FL
def compute_point_loss(predictions, batch, gamma=1.0, alpha=0.2, gradient_loss_fn = None, valid_range=-1, **kwargs):
    """
    Compute the 3D point-map loss.

    Args:
        predictions: Dict containing 'world_points' and 'world_points_conf'
        batch: Dict containing ground truth 'world_points' and 'point_masks'
        gamma: Weight for confidence loss
        alpha: Weight for confidence regularization
        gradient_loss_fn: Type of gradient loss to apply
        valid_range: Quantile range for outlier filtering
    """
    pred_pts = predictions['world_points']
    pred_conf = predictions['world_points_conf']
    valid_mask = batch['point_masks']

    # Sanitize ground truth before any arithmetic.
    gt_pts = check_and_fix_inf_nan(batch['world_points'], "gt_points")

    # Too few supervised pixels: return zero losses that keep the graph alive.
    if valid_mask.sum() < 100:
        zero = (0.0 * pred_pts).mean()
        return {
            "loss_conf_point": zero,
            "loss_reg_point": zero,
            "loss_grad_point": zero,
        }

    # Confidence-weighted regression loss, with optional gradient loss.
    conf_loss, grad_loss, reg_loss = regression_loss(
        pred_pts, gt_pts, valid_mask, conf=pred_conf,
        gradient_loss_fn=gradient_loss_fn, gamma=gamma, alpha=alpha, valid_range=valid_range,
    )

    return {
        "loss_conf_point": conf_loss,
        "loss_reg_point": reg_loss,
        "loss_grad_point": grad_loss,
    }
def compute_depth_loss(predictions, batch, gamma=1.0, alpha=0.2, gradient_loss_fn = None, valid_range=-1, **kwargs):
    """
    Compute the depth-map loss.

    Args:
        predictions: Dict containing 'depth' and 'depth_conf'
        batch: Dict containing ground truth 'depths' and 'point_masks'
        gamma: Weight for confidence loss
        alpha: Weight for confidence regularization
        gradient_loss_fn: Type of gradient loss to apply
        valid_range: Quantile range for outlier filtering
    """
    pred_depth = predictions['depth']
    pred_conf = predictions['depth_conf']

    # Sanitize ground truth and add a trailing channel dim: (B, S, H, W, 1).
    gt_depth = check_and_fix_inf_nan(batch['depths'], "gt_depth")[..., None]

    # The 3D points are derived from the depth map, so the same mask applies.
    valid_mask = batch['point_masks'].clone()

    # Too few supervised pixels: return zero losses that keep the graph alive.
    if valid_mask.sum() < 100:
        zero = (0.0 * pred_depth).mean()
        return {
            "loss_conf_depth": zero,
            "loss_reg_depth": zero,
            "loss_grad_depth": zero,
        }

    # NOTE: conf lives inside regression_loss so the confidence weighting can
    # also reach the multi-scale gradient loss. Hacky but simple.
    conf_loss, grad_loss, reg_loss = regression_loss(
        pred_depth, gt_depth, valid_mask, conf=pred_conf,
        gradient_loss_fn=gradient_loss_fn, gamma=gamma, alpha=alpha, valid_range=valid_range,
    )

    return {
        "loss_conf_depth": conf_loss,
        "loss_reg_depth": reg_loss,
        "loss_grad_depth": grad_loss,
    }
def regression_loss(pred, gt, mask, conf=None, gradient_loss_fn=None, gamma=1.0, alpha=0.2, valid_range=-1):
    """
    Core regression loss function with confidence weighting and optional gradient loss.

    Computes:
        1. gamma * ||pred - gt|| * conf - alpha * log(conf)
        2. Optional multi-scale gradient loss

    Args:
        pred: (B, S, H, W, C) predicted values
        gt: (B, S, H, W, C) ground truth values
        mask: (B, S, H, W) valid pixel mask
        conf: (B, S, H, W) confidence weights. NOTE(review): despite the None
            default, the conf-weighted term below indexes conf unconditionally,
            so callers are expected to always pass it — confirm before relying
            on conf=None.
        gradient_loss_fn: None, or a string selecting the gradient loss
            ("normal", "grad"; include "conf" to confidence-weight it)
        gamma: Weight for confidence loss
        alpha: Weight for confidence regularization
        valid_range: Quantile in (0, 1] for outlier filtering; <= 0 disables it

    Returns:
        loss_conf: Confidence-weighted loss
        loss_grad: Gradient loss (0 if not specified)
        loss_reg: Regular L2 loss
    """
    bb, ss, hh, ww, nc = pred.shape

    # L2 distance between prediction and ground truth at valid pixels.
    loss_reg = torch.norm(gt[mask] - pred[mask], dim=-1)
    loss_reg = check_and_fix_inf_nan(loss_reg, "loss_reg")

    # Confidence-weighted loss: gamma * loss * conf - alpha * log(conf).
    # Encourages confidence on easy pixels and uncertainty on hard ones.
    loss_conf = gamma * loss_reg * conf[mask] - alpha * torch.log(conf[mask])
    loss_conf = check_and_fix_inf_nan(loss_conf, "loss_conf")

    loss_grad = 0

    # FIX: guard against the default gradient_loss_fn=None — the original
    # `"conf" in gradient_loss_fn` raised TypeError whenever no gradient loss
    # was requested.
    if gradient_loss_fn:
        # Confidence for the gradient loss, flattened to (B*S, H, W).
        if "conf" in gradient_loss_fn:
            to_feed_conf = conf.reshape(bb * ss, hh, ww)
        else:
            to_feed_conf = None

        if "normal" in gradient_loss_fn:
            # Surface-normal-based gradient loss.
            loss_grad = gradient_loss_multi_scale_wrapper(
                pred.reshape(bb * ss, hh, ww, nc),
                gt.reshape(bb * ss, hh, ww, nc),
                mask.reshape(bb * ss, hh, ww),
                gradient_loss_fn=normal_loss,
                scales=3,
                conf=to_feed_conf,
            )
        elif "grad" in gradient_loss_fn:
            # Standard finite-difference gradient loss.
            loss_grad = gradient_loss_multi_scale_wrapper(
                pred.reshape(bb * ss, hh, ww, nc),
                gt.reshape(bb * ss, hh, ww, nc),
                mask.reshape(bb * ss, hh, ww),
                gradient_loss_fn=gradient_loss,
                conf=to_feed_conf,
            )

    # Reduce the confidence-weighted loss, optionally dropping outliers.
    if loss_conf.numel() > 0:
        if valid_range > 0:
            loss_conf = filter_by_quantile(loss_conf, valid_range)
        loss_conf = check_and_fix_inf_nan(loss_conf, "loss_conf_depth")
        loss_conf = loss_conf.mean()
    else:
        loss_conf = (0.0 * pred).mean()

    # Reduce the plain regression loss the same way.
    if loss_reg.numel() > 0:
        if valid_range > 0:
            loss_reg = filter_by_quantile(loss_reg, valid_range)
        loss_reg = check_and_fix_inf_nan(loss_reg, "loss_reg_depth")
        loss_reg = loss_reg.mean()
    else:
        loss_reg = (0.0 * pred).mean()

    return loss_conf, loss_grad, loss_reg
def gradient_loss_multi_scale_wrapper(prediction, target, mask, scales=4, gradient_loss_fn = None, conf=None):
    """
    Apply a gradient loss at several resolutions via strided subsampling.

    Averaging over dyadic scales captures both fine and coarse spatial structure.

    Args:
        prediction: (B, H, W, C) predicted values
        target: (B, H, W, C) ground truth values
        mask: (B, H, W) valid pixel mask
        scales: Number of scales to evaluate
        gradient_loss_fn: Per-scale gradient loss function
        conf: (B, H, W) confidence weights (optional)
    """
    accumulated = 0
    for level in range(scales):
        stride = pow(2, level)  # subsample by 2^level
        sub_conf = conf[:, ::stride, ::stride] if conf is not None else None
        accumulated += gradient_loss_fn(
            prediction[:, ::stride, ::stride],
            target[:, ::stride, ::stride],
            mask[:, ::stride, ::stride],
            conf=sub_conf,
        )
    return accumulated / scales
def normal_loss(prediction, target, mask, cos_eps=1e-8, conf=None, gamma=1.0, alpha=0.2):
    """
    Surface-normal consistency loss between predicted and ground-truth point maps.

    Normals are estimated from each 3D point map via cross products of
    neighboring-point differences; the loss is 1 - cos(angle) between matching
    predicted and ground-truth normals.

    Args:
        prediction: (B, H, W, 3) predicted 3D coordinates/points
        target: (B, H, W, 3) ground-truth 3D coordinates/points
        mask: (B, H, W) valid pixel mask
        cos_eps: Epsilon for numerical stability in cosine computation
        conf: (B, H, W) confidence weights (optional)
        gamma: Weight for confidence loss
        alpha: Weight for confidence regularization
    """
    pred_n, pred_ok = point_map_to_normal(prediction, mask, eps=cos_eps)
    gt_n, gt_ok = point_map_to_normal(target, mask, eps=cos_eps)

    # Use only pixels where both normal estimates are valid; shape (4, B, H, W).
    usable = pred_ok & gt_ok

    # Not enough overlap between the valid sets: contribute nothing.
    if torch.sum(usable) < 10:
        return 0

    picked_pred = pred_n[usable].clone()
    picked_gt = gt_n[usable].clone()

    # Cosine similarity between matched normals, clamped for stability.
    cosine = torch.sum(picked_pred * picked_gt, dim=-1)
    cosine = torch.clamp(cosine, -1 + cos_eps, 1 - cos_eps)

    # 1 - cos(theta) is better conditioned than arccos near 0 and pi.
    angle_loss = 1 - cosine

    if angle_loss.numel() < 10:
        return 0

    angle_loss = check_and_fix_inf_nan(angle_loss, "normal_loss")

    if conf is None:
        return angle_loss.mean()

    # Broadcast confidence over the 4 cross-product directions, then weight.
    conf4 = conf[None, ...].expand(4, -1, -1, -1)
    conf4 = conf4[usable].clone()
    weighted = gamma * angle_loss * conf4 - alpha * torch.log(conf4)
    return weighted.mean()
def gradient_loss(prediction, target, mask, conf=None, gamma=1.0, alpha=0.2):
    """
    Masked L1 gradient-matching loss over adjacent pixels in x and y.

    Args:
        prediction: (B, H, W, C) predicted values
        target: (B, H, W, C) ground truth values
        mask: (B, H, W) valid pixel mask
        conf: (B, H, W) confidence weights (optional)
        gamma: Weight for confidence loss
        alpha: Weight for confidence regularization
    """
    channels = prediction.shape[-1]

    # Broadcast the mask over the channel dimension and count valid entries.
    mask = mask[..., None].expand(-1, -1, -1, channels)
    valid_counts = torch.sum(mask, (1, 2, 3))

    # Masked residual between prediction and target.
    residual = torch.mul(mask, prediction - target)

    # Horizontal (x) finite differences, valid only where both pixels are valid.
    dx = torch.abs(residual[:, :, 1:] - residual[:, :, :-1])
    dx = torch.mul(torch.mul(mask[:, :, 1:], mask[:, :, :-1]), dx)

    # Vertical (y) finite differences.
    dy = torch.abs(residual[:, 1:, :] - residual[:, :-1, :])
    dy = torch.mul(torch.mul(mask[:, 1:, :], mask[:, :-1, :]), dy)

    # Clamp to suppress extreme outliers.
    dx = dx.clamp(max=100)
    dy = dy.clamp(max=100)

    # Optional confidence weighting of both gradient directions.
    if conf is not None:
        conf = conf[..., None].expand(-1, -1, -1, channels)
        conf_dx = conf[:, :, 1:]
        conf_dy = conf[:, 1:, :]
        dx = gamma * dx * conf_dx - alpha * torch.log(conf_dx)
        dy = gamma * dy * conf_dy - alpha * torch.log(conf_dy)

    # Sum both directions, then normalize by the number of valid pixels.
    per_batch = torch.sum(dx, (1, 2, 3)) + torch.sum(dy, (1, 2, 3))
    denom = torch.sum(valid_counts)
    if denom == 0:
        return 0
    return torch.sum(per_batch) / denom
def point_map_to_normal(point_map, mask, eps=1e-6):
    """
    Estimate per-pixel surface normals from a gridded 3D point map.

    Four cross products of neighbor-difference vectors (up x left, left x down,
    down x right, right x up) are computed per pixel for robustness.

    Args:
        point_map: (B, H, W, 3) 3D points laid out in a 2D grid
        mask: (B, H, W) valid pixels (bool)
        eps: Epsilon for numerical stability in normalization

    Returns:
        normals: (4, B, H, W, 3) unit normals, one per cross-product direction
        valids: (4, B, H, W) corresponding validity masks
    """
    with torch.cuda.amp.autocast(enabled=False):
        # Zero-pad so every pixel has four neighbors.
        mask_pad = F.pad(mask, (1, 1, 1, 1), mode='constant', value=0)
        pts_pad = F.pad(point_map.permute(0, 3, 1, 2), (1, 1, 1, 1), mode='constant', value=0).permute(0, 2, 3, 1)

        # Neighborhood views at the original (unpadded) resolution.
        ctr = pts_pad[:, 1:-1, 1:-1, :]   # B, H, W, 3
        nbr_up = pts_pad[:, :-2, 1:-1, :]
        nbr_left = pts_pad[:, 1:-1, :-2, :]
        nbr_down = pts_pad[:, 2:, 1:-1, :]
        nbr_right = pts_pad[:, 1:-1, 2:, :]

        # Difference vectors from the center to each neighbor.
        d_up = nbr_up - ctr
        d_left = nbr_left - ctr
        d_down = nbr_down - ctr
        d_right = nbr_right - ctr

        # Four cross products covering the full neighborhood.
        n1 = torch.cross(d_up, d_left, dim=-1)     # up x left
        n2 = torch.cross(d_left, d_down, dim=-1)   # left x down
        n3 = torch.cross(d_down, d_right, dim=-1)  # down x right
        n4 = torch.cross(d_right, d_up, dim=-1)    # right x up

        # A direction is valid only when the center pixel and both
        # contributing neighbors are valid.
        v1 = mask_pad[:, :-2, 1:-1] & mask_pad[:, 1:-1, 1:-1] & mask_pad[:, 1:-1, :-2]
        v2 = mask_pad[:, 1:-1, :-2] & mask_pad[:, 1:-1, 1:-1] & mask_pad[:, 2:, 1:-1]
        v3 = mask_pad[:, 2:, 1:-1] & mask_pad[:, 1:-1, 1:-1] & mask_pad[:, 1:-1, 2:]
        v4 = mask_pad[:, 1:-1, 2:] & mask_pad[:, 1:-1, 1:-1] & mask_pad[:, :-2, 1:-1]

        normals = torch.stack([n1, n2, n3, n4], dim=0)  # (4, B, H, W, 3)
        valids = torch.stack([v1, v2, v3, v4], dim=0)   # (4, B, H, W)

        # Normalize to unit length.
        normals = F.normalize(normals, p=2, dim=-1, eps=eps)

    return normals, valids
def filter_by_quantile(loss_tensor, valid_range, min_elements=1000, hard_max=100):
    """
    Keep only loss values below the `valid_range` quantile to drop outliers
    that could destabilize training.

    Args:
        loss_tensor: Tensor containing loss values
        valid_range: Float between 0 and 1 indicating the quantile threshold
        min_elements: Minimum number of elements required to apply filtering
        hard_max: Maximum allowed value for any individual loss

    Returns:
        Clamped (and possibly quantile-filtered) loss tensor
    """
    # Tiny tensors are returned untouched.
    if loss_tensor.numel() <= min_elements:
        return loss_tensor

    # Extremely large tensors: subsample 1M random elements so the quantile
    # computation stays memory-bounded.
    if loss_tensor.numel() > 100000000:
        pick = torch.randperm(loss_tensor.numel(), device=loss_tensor.device)[:1_000_000]
        loss_tensor = loss_tensor.view(-1)[pick]

    # Clamp first so a single extreme value cannot dominate the quantile.
    loss_tensor = loss_tensor.clamp(max=hard_max)

    # Threshold at the requested quantile, never above the hard cap.
    threshold = torch_quantile(loss_tensor.detach(), valid_range)
    threshold = min(threshold, hard_max)

    # Only apply the filter when enough elements survive it.
    keep = loss_tensor < threshold
    if keep.sum() > min_elements:
        return loss_tensor[keep]
    return loss_tensor
def torch_quantile(
    input,
    q,
    dim = None,
    keepdim: bool = False,
    *,
    interpolation: str = "nearest",
    out: torch.Tensor = None,
) -> torch.Tensor:
    """Better torch.quantile for one SCALAR quantile.

    Using torch.kthvalue. Better than torch.quantile because:
        - No 2**24 input size limit (pytorch/issues/67592),
        - Much faster, at least on big input sizes.

    Arguments:
        input (torch.Tensor): See torch.quantile.
        q (float): See torch.quantile. Supports only scalar input
            currently.
        dim (int | None): See torch.quantile.
        keepdim (bool): See torch.quantile. Supports only False
            currently.
        interpolation: {"nearest", "lower", "higher"}
            See torch.quantile.
        out (torch.Tensor | None): See torch.quantile. Supports only
            None currently.

    Raises:
        ValueError: for non-scalar / out-of-range q, unsupported
            interpolation, or a non-None out.
    """
    # https://github.com/pytorch/pytorch/issues/64947
    # Sanitization: q must be a scalar in [0, 1].
    try:
        q = float(q)
        assert 0 <= q <= 1
    except Exception:
        raise ValueError(f"Only scalar input 0<=q<=1 is currently supported (got {q})!")

    # dim=None means "flatten": fold everything into dim 0.
    if dim_was_none := dim is None:
        dim = 0
        input = input.reshape((-1,) + (1,) * (input.ndim - 1))

    # Map the interpolation mode onto an index-rounding function.
    if interpolation == "nearest":
        inter = round
    elif interpolation == "lower":
        inter = floor
    elif interpolation == "higher":
        inter = ceil
    else:
        raise ValueError(
            "Supported interpolations currently are {'nearest', 'lower', 'higher'} "
            f"(got '{interpolation}')!"
        )

    # Validate out parameter.
    if out is not None:
        raise ValueError(f"Only None value is currently supported for out (got {out})!")

    # torch.kthvalue is 1-indexed, hence the +1.
    k = inter(q * (input.shape[dim] - 1)) + 1
    out = torch.kthvalue(input, k, dim, keepdim=True, out=out)[0]

    # Handle keepdim and dim=None cases.
    # FIX: removed the unreachable trailing `return out` that followed this
    # exhaustive if/else in the original.
    if keepdim:
        return out
    if dim_was_none:
        return out.squeeze()
    else:
        return out.squeeze(dim)
########################################################################################
########################################################################################
# Dirty code for tracking loss:
########################################################################################
########################################################################################
'''
def _compute_losses(self, coord_preds, vis_scores, conf_scores, batch):
"""Compute tracking losses using sequence_loss"""
gt_tracks = batch["tracks"] # B, S, N, 2
gt_track_vis_mask = batch["track_vis_mask"] # B, S, N
# if self.training and hasattr(self, "train_query_points"):
train_query_points = coord_preds[-1].shape[2]
gt_tracks = gt_tracks[:, :, :train_query_points]
gt_tracks = check_and_fix_inf_nan(gt_tracks, "gt_tracks", hard_max=None)
gt_track_vis_mask = gt_track_vis_mask[:, :, :train_query_points]
# Create validity mask that filters out tracks not visible in first frame
valids = torch.ones_like(gt_track_vis_mask)
mask = gt_track_vis_mask[:, 0, :] == True
valids = valids * mask.unsqueeze(1)
if not valids.any():
print("No valid tracks found in first frame")
print("seq_name: ", batch["seq_name"])
print("ids: ", batch["ids"])
print("time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
dummy_coord = coord_preds[0].mean() * 0 # keeps graph & grads
dummy_vis = vis_scores.mean() * 0
if conf_scores is not None:
dummy_conf = conf_scores.mean() * 0
else:
dummy_conf = 0
return dummy_coord, dummy_vis, dummy_conf # three scalar zeros
# Compute tracking loss using sequence_loss
track_loss = sequence_loss(
flow_preds=coord_preds,
flow_gt=gt_tracks,
vis=gt_track_vis_mask,
valids=valids,
**self.loss_kwargs
)
vis_loss = F.binary_cross_entropy_with_logits(vis_scores[valids], gt_track_vis_mask[valids].float())
vis_loss = check_and_fix_inf_nan(vis_loss, "vis_loss", hard_max=None)
# within 3 pixels
if conf_scores is not None:
gt_conf_mask = (gt_tracks - coord_preds[-1]).norm(dim=-1) < 3
conf_loss = F.binary_cross_entropy_with_logits(conf_scores[valids], gt_conf_mask[valids].float())
conf_loss = check_and_fix_inf_nan(conf_loss, "conf_loss", hard_max=None)
else:
conf_loss = 0
return track_loss, vis_loss, conf_loss
def reduce_masked_mean(x, mask, dim=None, keepdim=False):
for a, b in zip(x.size(), mask.size()):
assert a == b
prod = x * mask
if dim is None:
numer = torch.sum(prod)
denom = torch.sum(mask)
else:
numer = torch.sum(prod, dim=dim, keepdim=keepdim)
denom = torch.sum(mask, dim=dim, keepdim=keepdim)
mean = numer / denom.clamp(min=1)
mean = torch.where(denom > 0,
mean,
torch.zeros_like(mean))
return mean
def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8, vis_aware=False, huber=False, delta=10, vis_aware_w=0.1, **kwargs):
"""Loss function defined over sequence of flow predictions"""
B, S, N, D = flow_gt.shape
assert D == 2
B, S1, N = vis.shape
B, S2, N = valids.shape
assert S == S1
assert S == S2
n_predictions = len(flow_preds)
flow_loss = 0.0
for i in range(n_predictions):
i_weight = gamma ** (n_predictions - i - 1)
flow_pred = flow_preds[i]
i_loss = (flow_pred - flow_gt).abs() # B, S, N, 2
i_loss = check_and_fix_inf_nan(i_loss, f"i_loss_iter_{i}", hard_max=None)
i_loss = torch.mean(i_loss, dim=3) # B, S, N
# Combine valids and vis for per-frame valid masking.
combined_mask = torch.logical_and(valids, vis)
num_valid_points = combined_mask.sum()
if vis_aware:
combined_mask = combined_mask.float() * (1.0 + vis_aware_w) # Add, don't add to the mask itself.
flow_loss += i_weight * reduce_masked_mean(i_loss, combined_mask)
else:
if num_valid_points > 2:
i_loss = i_loss[combined_mask]
flow_loss += i_weight * i_loss.mean()
else:
i_loss = check_and_fix_inf_nan(i_loss, f"i_loss_iter_safe_check_{i}", hard_max=None)
flow_loss += 0 * i_loss.mean()
# Avoid division by zero if n_predictions is 0 (though it shouldn't be).
if n_predictions > 0:
flow_loss = flow_loss / n_predictions
return flow_loss
'''
|