# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import einops
# fused_ssim is used in compute_loss below; assumed to come from the
# `fused-ssim` package, which provides a fast fused CUDA SSIM.
from fused_ssim import fused_ssim


def resize_images_lpips(images, img_size, img_size_min):
    # Clamp the first axis to at most `img_size_min`, scale the second axis by
    # the same factor to preserve aspect ratio, and map pixel values from
    # [0, 1] to the [-1, 1] range that LPIPS expects.
    image_out_size = (
        min(img_size[0], img_size_min),
        min(img_size[1], int(img_size_min / img_size[0] * img_size[1])),
    )
    return F.interpolate(
        images.view(-1, 3, img_size[0], img_size[1]) * 2 - 1,
        image_out_size,
        mode='bilinear',
        align_corners=False
    )
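

# Sketch (assumed sizes): a batch of (N, 3, 512, 768) images in [0, 1] with
# img_size_min=256 is resized to (N, 3, 256, 384) and mapped to [-1, 1]:
#   x = torch.rand(4, 3, 512, 768)
#   y = resize_images_lpips(x, (512, 768), 256)  # y.shape == (4, 3, 256, 384)
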
def normalize_depth(depth: torch.Tensor, valid_mask: torch.Tensor):
    depth = einops.rearrange(depth, 'b t 1 h w -> b t (h w)')
    valid_mask = einops.rearrange(valid_mask, 'b t 1 h w -> b t (h w)').float()
    # Count valid pixels (avoid zero count)
    valid_count = valid_mask.sum(dim=-1, keepdim=True).clamp(min=1)
    # Zero out invalid pixels
    depth_valid = depth * valid_mask
    # Median of the masked tensor. Invalid pixels contribute zeros, so this
    # only approximates the median over valid pixels when many are invalid.
    depth_median = torch.median(depth_valid, dim=-1, keepdim=True)[0]
    # Centered depth (subtract median), mask invalids again
    depth_centered = (depth_valid - depth_median) * valid_mask
    # Mean absolute deviation over valid pixels
    depth_mad = depth_centered.abs().sum(dim=-1, keepdim=True) / valid_count
    # Clamp the deviation (avoid division by zero or inf)
    depth_mad = torch.clamp(depth_mad, min=1e-3, max=1e3)
    # Normalize centered depth to unit mean absolute deviation
    depth_normalized = depth_centered / depth_mad
    return depth_normalized
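

# Sketch: up to the clamp on the deviation, the normalization is invariant to
# a per-frame affine rescaling of depth, e.g. with d of shape (B, T, 1, H, W)
# and m = torch.ones_like(d).bool():
#   torch.allclose(normalize_depth(d, m), normalize_depth(2 * d + 5, m), atol=1e-5)
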
def compute_depth_loss(pred_depths: torch.Tensor, gt_depths: torch.Tensor):
    # Valid mask computed once: depth is positive and finite (this also excludes NaN)
    valid_mask = (gt_depths > 0) & torch.isfinite(gt_depths)
    # Normalize both depths using the same valid mask
    pred_depths_norm = normalize_depth(pred_depths, valid_mask)
    gt_depths_norm = normalize_depth(gt_depths, valid_mask)
    # Flatten valid_mask to (b, t, h*w) to match the normalized tensors
    valid_mask_float = einops.rearrange(valid_mask.float(), 'b t 1 h w -> b t (h w)')
    # Apply mask before loss
    loss_depth = F.smooth_l1_loss(
        pred_depths_norm * valid_mask_float,
        gt_depths_norm * valid_mask_float
    )
    return loss_depth
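

# Note: both normalized tensors are zeroed at invalid pixels before the loss,
# so invalid pixels contribute nothing to the numerator but are still counted
# by smooth_l1_loss's default 'mean' reduction; with mostly-valid masks this
# only rescales the loss.
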
def compute_lpips_loss_in_chunks(lpips_loss_module, gt_images, pred_images, lpips_img_size, lpips_img_size_min, chunk_size=64):
    """
    Computes LPIPS loss with chunking along the V dimension and uses gradient checkpointing.
    Args:
        lpips_loss_module: A callable LPIPS loss module.
        gt_images (Tensor): Ground truth images of shape (B, V, C, H, W).
        pred_images (Tensor): Predicted images of shape (B, V, C, H, W).
        lpips_img_size (tuple): Target (H, W) image size for LPIPS.
        lpips_img_size_min (int): Minimum image size for LPIPS.
        chunk_size (int): Number of V elements to process at once. Default is 64.
    Returns:
        Tensor: Per-pair LPIPS losses of shape (B * V, 1); the caller is
        expected to average them into a scalar.
    """
    B, V, C, H, W = gt_images.shape
    total_loss = []
    num_chunks = (V + chunk_size - 1) // chunk_size
    for i in range(num_chunks):
        start = i * chunk_size
        end = min((i + 1) * chunk_size, V)
        gt_chunk = gt_images[:, start:end].reshape(-1, C, H, W)
        pred_chunk = pred_images[:, start:end].reshape(-1, C, H, W)
        gt_chunk = resize_images_lpips(gt_chunk, lpips_img_size, lpips_img_size_min)
        pred_chunk = resize_images_lpips(pred_chunk, lpips_img_size, lpips_img_size_min)
        # Checkpoint the LPIPS forward pass so its activations are recomputed
        # during backward instead of being kept in memory.
        loss_chunk = torch.utils.checkpoint.checkpoint(
            lpips_loss_module,
            gt_chunk,
            pred_chunk,
            use_reentrant=False
        )
        total_loss.append(loss_chunk)
    total_loss = torch.cat(total_loss, 0)
    # Collapse any spatial LPIPS output (N, 1, h, w) to one value per pair
    total_loss = total_loss.mean((2, 3))
    return total_loss
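

# Usage sketch (hypothetical setup; assumes the `lpips` package, though any
# module with the same (img0, img1) -> loss call signature works):
#   import lpips
#   lpips_module = lpips.LPIPS(net='vgg').to(gt_images.device)
#   per_pair = compute_lpips_loss_in_chunks(
#       lpips_module, gt_images, pred_images, (256, 256), 256, chunk_size=16)
#   loss_lpips = per_pair.mean()
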
def compute_loss(accelerator, train_loss, pred_images, gt_images, pred_depths, gt_depths, pred_opacity, config, lpips_loss_module=None, lpips_img_size=None):
    # MSE loss
    loss = F.mse_loss(pred_images, gt_images)
    # LPIPS loss
    if config.get('lambda_lpips', 0) > 0:
        if config.get('lpips_chunk_size', None) is not None:
            loss_lpips = compute_lpips_loss_in_chunks(
                lpips_loss_module, gt_images, pred_images,
                lpips_img_size, config.lpips_img_size_min, config.lpips_chunk_size
            )
        else:
            loss_lpips = lpips_loss_module(
                resize_images_lpips(gt_images, lpips_img_size, config.lpips_img_size_min),
                resize_images_lpips(pred_images, lpips_img_size, config.lpips_img_size_min),
            )
        loss_lpips = loss_lpips.mean()
        loss = loss + config.lambda_lpips * loss_lpips
    # SSIM loss (DSSIM: maps SSIM in [-1, 1] to a loss in [0, 1])
    if config.get('lambda_ssim', 0) > 0:
        ssim_img_size = config.img_size
        loss_ssim = fused_ssim(
            pred_images.view(-1, 3, ssim_img_size[0], ssim_img_size[1]).float(),
            gt_images.view(-1, 3, ssim_img_size[0], ssim_img_size[1]).float()
        )
        loss_ssim = (1 - loss_ssim) / 2
        loss = loss + config.lambda_ssim * loss_ssim
    # Depth loss
    if config.get('lambda_depth', 0) > 0:
        loss_depth = compute_depth_loss(pred_depths, gt_depths)
        loss = loss + config.lambda_depth * loss_depth
    # Opacity regularizer (penalizes the mean sigmoid opacity)
    if config.get('lambda_opacity', 0) > 0:
        loss_opacity = pred_opacity.to(pred_images.dtype).sigmoid().mean()
        loss = loss + config.lambda_opacity * loss_opacity
    # Average loss
    loss = loss.mean()
    # Gather the losses across all processes for logging (if we use distributed training).
    avg_loss = accelerator.gather(loss.repeat(config.batch_size)).mean()
    train_loss += avg_loss.item() / config.gradient_accumulation_steps
    return train_loss, loss
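

if __name__ == '__main__':
    # Minimal smoke test for the depth loss (hypothetical shapes, CPU only);
    # the full compute_loss path additionally needs an accelerator, a config,
    # and the optional LPIPS/SSIM dependencies.
    gt = torch.rand(2, 4, 1, 32, 32) + 0.1
    pred = gt + 0.05 * torch.randn_like(gt)
    print('depth loss:', compute_depth_loss(pred, gt).item())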