|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import einops |
|
|
|
|
|
def resize_images_lpips(images, img_size, img_size_min):
    """Downscale images for LPIPS and map them from [0, 1] to [-1, 1].

    Args:
        images: Image tensor reshapeable to (-1, 3, H, W); values assumed
            in [0, 1] — TODO confirm against callers.
        img_size: (H, W) of the input images.
        img_size_min: Target minimum size; the output height is capped at
            this value and the width is scaled to keep the aspect ratio.

    Returns:
        Tensor of shape (-1, 3, out_h, out_w) with values in [-1, 1].
    """
    in_h, in_w = img_size
    out_h = min(in_h, img_size_min)
    # Width scaled by the same ratio as the height cap, never upscaled.
    out_w = min(in_w, int(img_size_min / in_h * in_w))
    rescaled = images.view(-1, 3, in_h, in_w) * 2 - 1
    return F.interpolate(rescaled, (out_h, out_w), mode='bilinear', align_corners=False)
|
|
|
|
|
def normalize_depth(depth: torch.Tensor, valid_mask: torch.Tensor):
    """Median/MAD-style normalization of depth maps per frame.

    Args:
        depth: Depth values of shape (b, t, 1, h, w).
        valid_mask: Mask of usable pixels, same shape as ``depth``.

    Returns:
        Tensor of shape (b, t, h*w): invalid pixels are zeroed, valid ones
        are centered on the per-frame median and divided by the mean
        absolute deviation (clamped to [1e-3, 1e3]).
    """
    # Collapse the spatial dims: (b, t, 1, h, w) -> (b, t, h*w).
    flat_depth = depth.flatten(2)
    flat_mask = valid_mask.flatten(2).float()

    # Per-frame valid-pixel count, floored at 1 to avoid division by zero.
    n_valid = flat_mask.sum(dim=-1, keepdim=True).clamp(min=1)

    masked = flat_depth * flat_mask

    # NOTE(review): the median is taken over the zero-filled tensor, so
    # frames with many invalid pixels bias it toward 0 — behavior preserved
    # from the original implementation.
    med = torch.median(masked, dim=-1, keepdim=True)[0]

    centered = (masked - med) * flat_mask

    # Mean absolute deviation over valid pixels, clamped for stability.
    scale = centered.abs().sum(dim=-1, keepdim=True) / n_valid
    scale = torch.clamp(scale, min=1e-3, max=1e3)

    return centered / scale
|
|
|
|
|
def compute_depth_loss(pred_depths: torch.Tensor, gt_depths: torch.Tensor):
    """Scale-invariant depth loss: smooth-L1 between normalized depths.

    Args:
        pred_depths: Predicted depths, shape (b, t, 1, h, w).
        gt_depths: Ground-truth depths, same shape; non-positive or
            non-finite values are treated as invalid.

    Returns:
        Scalar smooth-L1 loss over the masked, normalized depth maps.
    """
    # Valid pixels: positive and finite ground-truth depth.
    valid_mask = (gt_depths > 0) & torch.isfinite(gt_depths)

    pred_norm = normalize_depth(pred_depths, valid_mask)
    gt_norm = normalize_depth(gt_depths, valid_mask)

    # Match the flattened (b, t, h*w) layout returned by normalize_depth.
    mask_flat = valid_mask.float().flatten(2)

    # NOTE(review): the default mean reduction runs over all pixels; invalid
    # ones contribute zero but still dilute the average, so the loss scale
    # depends on the valid-pixel ratio — preserved from the original.
    return F.smooth_l1_loss(pred_norm * mask_flat, gt_norm * mask_flat)
|
|
|
|
|
def compute_lpips_loss_in_chunks(lpips_loss_module, gt_images, pred_images, lpips_img_size, lpips_img_size_min, chunk_size=64):
    """Evaluate the LPIPS loss in chunks along the view (V) dimension.

    Each chunk is resized for LPIPS and run through the loss module under
    gradient checkpointing to bound activation memory.

    Args:
        lpips_loss_module: Callable LPIPS module taking (gt, pred) batches.
        gt_images: Ground-truth images of shape (B, V, C, H, W).
        pred_images: Predicted images of shape (B, V, C, H, W).
        lpips_img_size: (H, W) of the full-resolution images.
        lpips_img_size_min: Minimum target size used when downscaling.
        chunk_size: Number of views processed per chunk. Default is 64.

    Returns:
        Tensor of per-image-pair losses with spatial dims 2 and 3 averaged
        out; callers reduce it to a scalar with ``.mean()``.
    """
    _, V, C, H, W = gt_images.shape
    chunk_losses = []

    for start in range(0, V, chunk_size):
        stop = min(start + chunk_size, V)

        gt_flat = gt_images[:, start:stop].reshape(-1, C, H, W)
        pred_flat = pred_images[:, start:stop].reshape(-1, C, H, W)

        gt_flat = resize_images_lpips(gt_flat, lpips_img_size, lpips_img_size_min)
        pred_flat = resize_images_lpips(pred_flat, lpips_img_size, lpips_img_size_min)

        # Checkpoint the LPIPS forward pass: activations are recomputed
        # during backward instead of being stored for the whole batch.
        chunk_losses.append(torch.utils.checkpoint.checkpoint(
            lpips_loss_module,
            gt_flat,
            pred_flat,
            use_reentrant=False,
        ))

    losses = torch.cat(chunk_losses, 0)
    # NOTE(review): assumes the LPIPS module returns a 4-D per-pair map so
    # dims 2 and 3 exist — confirm against the lpips module's output shape.
    return losses.mean((2, 3))
|
|
|
|
|
def compute_loss(accelerator, train_loss, pred_images, gt_images, pred_depths, gt_depths, pred_opacity, config, lpips_loss_module=None, lpips_img_size=None):
    """Compute the total training loss and the gathered running average.

    Combines MSE with optional LPIPS, SSIM, depth and opacity terms, each
    gated by its ``lambda_*`` weight in ``config``.

    Args:
        accelerator: Accelerator providing ``gather`` across processes.
        train_loss: Running scalar loss accumulator to update.
        pred_images / gt_images: Image tensors of shape (B, V, C, H, W).
        pred_depths / gt_depths: Depth tensors (used when lambda_depth > 0).
        pred_opacity: Raw opacity logits (used when lambda_opacity > 0).
        config: Dict-like config with ``get`` and attribute access.
        lpips_loss_module: Callable LPIPS module (when lambda_lpips > 0).
        lpips_img_size: (H, W) passed to the LPIPS resize.

    Returns:
        Tuple ``(train_loss, loss)``: the updated running average and the
        scalar loss tensor for backprop.
    """
    # Base photometric reconstruction loss.
    loss = F.mse_loss(pred_images, gt_images)

    if config.get('lambda_lpips', 0) > 0:
        if config.lpips_chunk_size is not None:
            loss_lpips = compute_lpips_loss_in_chunks(
                lpips_loss_module, gt_images, pred_images,
                lpips_img_size, config.lpips_img_size_min, config.lpips_chunk_size)
        else:
            # BUG FIX: `lpips_img_size_min` was referenced as a bare name
            # here (NameError at runtime); it lives on the config, matching
            # the chunked branch above.
            loss_lpips = lpips_loss_module(
                resize_images_lpips(gt_images, lpips_img_size, config.lpips_img_size_min),
                resize_images_lpips(pred_images, lpips_img_size, config.lpips_img_size_min),
            )
        loss_lpips = loss_lpips.mean()
        loss = loss + config.lambda_lpips * loss_lpips

    if config.get('lambda_ssim', 0) > 0:
        ssim_img_size = config.img_size
        # NOTE(review): `fused_ssim` is not imported in this file's visible
        # imports — confirm it is provided elsewhere (e.g. the fused-ssim
        # package) before enabling lambda_ssim.
        loss_ssim = fused_ssim(
            pred_images.view(-1, 3, ssim_img_size[0], ssim_img_size[1]).float(),
            gt_images.view(-1, 3, ssim_img_size[0], ssim_img_size[1]).float()
        )
        # Map similarity to a loss; presumably fused_ssim is in [-1, 1] so
        # this lands in [0, 1] — TODO confirm fused_ssim's output range.
        loss_ssim = (1 - loss_ssim) / 2
        loss = loss + config.lambda_ssim * loss_ssim

    if config.get('lambda_depth', 0) > 0:
        loss_depth = compute_depth_loss(pred_depths, gt_depths)
        loss = loss + config.lambda_depth * loss_depth

    if config.get('lambda_opacity', 0) > 0:
        # Penalize mean opacity (sigmoid of raw logits) to encourage sparsity.
        loss_opacity = pred_opacity.to(pred_images.dtype).sigmoid().mean()
        loss = loss + config.lambda_opacity * loss_opacity

    loss = loss.mean()

    # Gather the scalar loss across processes for logging; repeating by
    # batch_size weights the cross-process mean by local batch size.
    avg_loss = accelerator.gather(loss.repeat(config.batch_size)).mean()
    train_loss += avg_loss.item() / config.gradient_accumulation_steps
    return train_loss, loss