from typing import Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from pytorch3d.loss import (
    chamfer_distance,
    mesh_edge_loss,
    mesh_laplacian_smoothing,
    mesh_normal_consistency,
)
from pytorch3d.structures import Meshes

from engine.MVSRender.camera_utils import MiniCam, OrbitCamera, orbit_camera


def compute_chamfer_loss(
    query_verts: torch.Tensor, gt_pts: torch.Tensor
) -> torch.Tensor:
    """Compute the chamfer distance between two point sets.

    Both inputs get a leading batch dimension of size 1 before being passed
    to ``pytorch3d.loss.chamfer_distance``. Presumably each is an unbatched
    ``(P, D)`` point cloud -- TODO confirm with callers.

    Args:
        query_verts: Points to compare, without a batch dimension.
        gt_pts: Reference points, without a batch dimension.

    Returns:
        The chamfer loss term: the first element of the tuple returned by
        ``chamfer_distance``. The remaining element is discarded.
    """
    return chamfer_distance(query_verts[None], gt_pts[None])[0]


def update_mesh_shape_prior_losses(
    mesh: Union[Meshes, "Mesh"], loss: Dict[str, torch.Tensor]
) -> None:
    """Add mesh regularization terms to ``loss`` in place.

    For a pytorch3d ``Meshes`` the ``"edge"``, ``"normal"`` and
    ``"laplacian"`` entries are written using pytorch3d's regularizers
    (uniform Laplacian). Any other mesh object must provide
    ``normal_consistency()`` and ``laplacian()`` methods. For such objects only
    ``"normal"`` and ``"laplacian"`` are written; there is no edge term.

    Args:
        mesh: Mesh to regularize. ``"Mesh"`` is a forward reference to a mesh
            class that is not defined in this module's visible code.
        loss: Named loss terms. Updated in place; existing entries with the
            same keys are overwritten.
    """
    if isinstance(mesh, Meshes):
        loss.update(
            {
                "edge": mesh_edge_loss(mesh),
                "normal": mesh_normal_consistency(mesh),
                "laplacian": mesh_laplacian_smoothing(mesh, method="uniform"),
            }
        )
        return
    loss.update(
        {"normal": mesh.normal_consistency(), "laplacian": mesh.laplacian()}
    )


def plot_losses(losses: Dict[str, Dict[str, List[float]]]) -> None:
    """Plot every loss history as one curve on a new figure.

    Args:
        losses: Maps a loss name to a dict whose ``"values"`` entry is that
            loss's per-iteration history.
    """
    plt.figure(figsize=(13, 5))
    ax = plt.gca()
    for name, history in losses.items():
        ax.plot(history["values"], label=f"{name} loss")
    ax.legend(fontsize=16)
    ax.set_xlabel("Iteration", fontsize=16)
    ax.set_ylabel("Loss", fontsize=16)
    ax.set_title("Loss vs iterations", fontsize=16)


def inverse_sigmoid(x: Union[float, torch.Tensor]) -> torch.Tensor:
    """Return the logit ``log(x / (1 - x))``, the inverse of the sigmoid.

    Tensor inputs are used as-is, so their dtype and device are preserved.
    Any other input (Python float or int, sequence, NumPy array) is converted
    to a float32 tensor first; ``torch.log`` only accepts tensors.

    Inputs outside the open interval (0, 1) yield ``inf``/``nan`` rather than
    raising.
    """
    if not isinstance(x, torch.Tensor):
        x = torch.as_tensor(x, dtype=torch.float32)
    return torch.log(x / (1 - x))


def image_grid(
    images: np.ndarray,
    rows: Optional[int] = None,
    cols: Optional[int] = None,
    fill: bool = True,
    show_axes: bool = False,
    rgb: bool = True,
    normal: bool = False,
    figsize: Tuple[float, float] = (15, 9),
) -> None:
    """Plot a grid of images, one image per subplot.

    Images are indexed on their last axis:
    - channels 0-2 are shown when ``rgb`` is true;
    - channel 3 is shown when ``rgb`` is false;
    - channels 4-6 are shown when ``normal`` is true (overrides ``rgb``).

    Presumably the images pack RGB / alpha / normal renders -- confirm with
    callers.

    If there are more images than grid cells, the extra images are not drawn
    (``zip`` stops at the shorter sequence).

    Args:
        images: Sequence of images with channels on the last axis.
        rows: Number of grid rows. Give both ``rows`` and ``cols`` or neither;
            with neither, the images are stacked in a single column.
        cols: Number of grid columns.
        fill: Remove the spacing between subplots.
        show_axes: Leave the axes on; otherwise each subplot's axis is turned
            off.
        rgb: Show channels 0-2 instead of channel 3.
        normal: Show channels 4-6 (takes precedence over ``rgb``).
        figsize: Figure size in inches.

    Raises:
        ValueError: If exactly one of ``rows`` and ``cols`` is given.
    """
    if (rows is None) != (cols is None):
        raise ValueError("Specify either both rows and cols or neither")
    rows = rows or len(images)
    cols = cols or 1

    fig, axes = plt.subplots(
        rows,
        cols,
        gridspec_kw={"wspace": 0.0, "hspace": 0.0} if fill else {},
        figsize=figsize,
    )
    fig.subplots_adjust(left=0, bottom=0, right=1, top=1)

    # The channel selection is the same for every image, so compute it once.
    if normal:
        channel_slice = slice(4, 7)
    elif rgb:
        channel_slice = slice(3)
    else:
        channel_slice = slice(3, 4)

    # np.array(...).ravel() handles both a single Axes (1x1 grid) and an
    # array of Axes.
    for ax, im in zip(np.array(axes).ravel(), images):
        ax.imshow(im[..., channel_slice])
        if not show_axes:
            ax.set_axis_off()


def _to_numpy(t: torch.Tensor) -> np.ndarray:
    """Detach ``t`` from autograd and copy it into a host NumPy array."""
    return t.detach().cpu().numpy()


def visualize_prediction(
    predicted: torch.Tensor,
    target: torch.Tensor,
    vis_normal: bool = False,
    gt_normal: Optional[torch.Tensor] = None,
    pred_normal: Optional[torch.Tensor] = None,
    title: str = "",
) -> None:
    """Show a prediction next to its target, optionally with normal maps.

    Without ``vis_normal``, a 1x2 grid (prediction, target) is drawn. With it,
    a 2x2 grid adds the predicted normal (bottom-left) and the ground-truth
    normal (bottom-right).

    Args:
        predicted: Predicted image tensor; any device, may require grad.
        target: Target image tensor.
        vis_normal: Also draw the normal maps.
        gt_normal: Ground-truth normal map; required when ``vis_normal``.
        pred_normal: Predicted normal map; required when ``vis_normal``.
        title: Figure super-title.

    Raises:
        ValueError: If ``vis_normal`` is true but either normal map is
            missing. Checked before any figure is created.
    """
    if vis_normal and (gt_normal is None or pred_normal is None):
        raise ValueError(
            "vis_normal=True requires both gt_normal and pred_normal"
        )

    panels = [predicted, target]
    if vis_normal:
        panels += [pred_normal, gt_normal]
    n_rows = 2 if vis_normal else 1

    plt.figure(figsize=(20, 20) if vis_normal else (20, 10))
    for idx, panel in enumerate(panels, start=1):
        plt.subplot(n_rows, 2, idx).imshow(_to_numpy(panel))
    plt.suptitle(title)
    plt.tight_layout()


def camera_traj(
    cam, ref_size: int = 1024, views: int = 30, radius: float = 2.0
) -> List[MiniCam]:
    """Build ``views`` cameras evenly spaced in azimuth on a horizontal orbit.

    Every camera is posed by ``orbit_camera`` at elevation 0 and the given
    ``radius``. The azimuths are ``0, 360/views, 2*360/views, ...`` degrees,
    so the cameras are evenly spaced around the full circle. Each camera
    reuses the field of view and clip planes of ``cam``.

    Args:
        cam: Template camera. Only its ``fovy``, ``fovx``, ``near`` and
            ``far`` attributes are read (presumably an ``OrbitCamera`` --
            confirm with callers).
        ref_size: Passed to ``MiniCam`` as both its second and third
            arguments (presumably image width and height -- confirm).
        views: Number of cameras to generate; must be positive.
        radius: Orbit radius passed to ``orbit_camera``.

    Returns:
        A list of ``views`` ``MiniCam`` objects in increasing azimuth order.

    Raises:
        ValueError: If ``views`` is not positive.
    """
    if views <= 0:
        raise ValueError(f"views must be positive, got {views}")

    # True division keeps the spacing uniform even when 360 is not a multiple
    # of ``views``. Integer division would leave a larger gap before the orbit
    # wraps back to 0 degrees.
    azimuth_step = 360.0 / views

    cameras = []
    for view_idx in range(views):
        pose = orbit_camera(0, view_idx * azimuth_step, radius)
        cameras.append(
            MiniCam(
                pose,
                ref_size,
                ref_size,
                cam.fovy,
                cam.fovx,
                cam.near,
                cam.far,
            )
        )
    return cameras