| from typing import Dict, List, Optional, Tuple, Union |
|
|
| import matplotlib.pyplot as plt |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from pytorch3d.loss import ( |
| chamfer_distance, |
| mesh_edge_loss, |
| mesh_laplacian_smoothing, |
| mesh_normal_consistency, |
| ) |
| from pytorch3d.structures import Meshes |
|
|
| from engine.MVSRender.camera_utils import MiniCam, OrbitCamera, orbit_camera |
|
|
|
|
def compute_chamfer_loss(
    query_verts: torch.Tensor, gt_pts: torch.Tensor
) -> torch.Tensor:
    """Return the chamfer distance between two unbatched point sets.

    Each input is given a leading axis of size one before being handed to
    ``chamfer_distance``. Only the first element of that function's result
    is returned; the remaining element(s) are discarded.
    """
    batched_query = query_verts[None]
    batched_gt = gt_pts[None]
    result = chamfer_distance(batched_query, batched_gt)
    return result[0]
|
|
|
|
def update_mesh_shape_prior_losses(
    mesh: Union[Meshes, "Mesh"], loss: Dict[str, torch.Tensor]
) -> None:
    """Write the mesh regularization terms into ``loss`` in place.

    A PyTorch3D ``Meshes`` input contributes ``"edge"``, ``"normal"`` and
    ``"laplacian"`` entries computed with the PyTorch3D loss functions
    (uniform Laplacian weighting). Any other mesh object contributes only
    ``"normal"`` and ``"laplacian"``, taken from its own
    ``normal_consistency()`` and ``laplacian()`` methods.
    """
    # All terms are computed before ``loss`` is touched, so a failure in any
    # one of them leaves the caller's dictionary unmodified.
    if isinstance(mesh, Meshes):
        terms = {
            "edge": mesh_edge_loss(mesh),
            "normal": mesh_normal_consistency(mesh),
            "laplacian": mesh_laplacian_smoothing(mesh, method="uniform"),
        }
    else:
        terms = {
            "normal": mesh.normal_consistency(),
            "laplacian": mesh.laplacian(),
        }
    loss.update(terms)
|
|
|
|
def plot_losses(losses: Dict[str, Dict[str, List[float]]]) -> None:
    """Draw one curve per loss term on a single new figure.

    Each entry of ``losses`` maps a loss name to a record whose ``"values"``
    list is plotted against its index; the name is used in the legend label.
    """
    plt.figure(figsize=(13, 5))
    axis = plt.gca()

    for name, record in losses.items():
        axis.plot(record["values"], label=f"{name} loss")

    text_size = 16
    axis.legend(fontsize=text_size)
    axis.set_xlabel("Iteration", fontsize=text_size)
    axis.set_ylabel("Loss", fontsize=text_size)
    axis.set_title("Loss vs iterations", fontsize=text_size)
|
|
|
|
def inverse_sigmoid(x: Union[float, torch.Tensor]) -> torch.Tensor:
    """Return the logit ``log(x / (1 - x))``, the inverse of the sigmoid.

    Args:
        x: A tensor, or any value ``torch.as_tensor`` can convert (Python
            float or int, NumPy scalar, sequence of numbers). Non-tensor
            inputs are converted to a ``float32`` tensor; tensor inputs are
            used as given.

    Returns:
        A tensor with the element-wise logit of ``x``. Values of exactly 0
        or 1 map to ``-inf`` / ``+inf``.
    """
    if not isinstance(x, torch.Tensor):
        # Convert every non-tensor input, not only built-in ``float``:
        # ints and NumPy scalars would otherwise reach ``torch.log`` as
        # plain numbers and raise a TypeError.
        x = torch.as_tensor(x, dtype=torch.float32)
    return torch.log(x / (1 - x))
|
|
|
|
def image_grid(
    images: np.ndarray,
    rows: Optional[int] = None,
    cols: Optional[int] = None,
    fill: bool = True,
    show_axes: bool = False,
    rgb: bool = True,
    normal: bool = False,
) -> None:
    """Show a batch of images on a grid of matplotlib axes.

    Args:
        images: Batch of images, iterated along its first axis; channels are
            selected along the last axis of each image.
        rows: Number of grid rows. When neither ``rows`` nor ``cols`` is
            given, one row per image is used.
        cols: Number of grid columns. Defaults to a single column.
        fill: If true, remove the spacing between subplots.
        show_axes: If false, axis decorations are turned off.
        rgb: If true, display channels ``0:3``; otherwise channel ``3:4``.
        normal: If true, display channels ``4:7``; takes precedence over
            ``rgb``.

    Raises:
        ValueError: If exactly one of ``rows`` and ``cols`` is given.
    """
    rows_given = rows is not None
    cols_given = cols is not None
    if rows_given != cols_given:
        raise ValueError("Specify either both rows and cols or neither")

    if not rows:
        rows = len(images)
    if not cols:
        cols = 1

    spacing = {"wspace": 0.0, "hspace": 0.0} if fill else {}
    fig, axes = plt.subplots(rows, cols, gridspec_kw=spacing, figsize=(15, 9))
    fig.subplots_adjust(left=0, bottom=0, right=1, top=1)

    # The channel selection does not depend on the image, so pick it once.
    if normal:
        channels = slice(4, 7)
    elif rgb:
        channels = slice(3)
    else:
        channels = slice(3, 4)

    # np.array(...).ravel() also copes with the single-Axes case that
    # plt.subplots returns for a 1x1 grid.
    for axis, image in zip(np.array(axes).ravel(), images):
        axis.imshow(image[..., channels])
        if not show_axes:
            axis.set_axis_off()
|
|
|
|
def visualize_prediction(
    predicted: torch.Tensor,
    target: torch.Tensor,
    vis_normal: bool = False,
    gt_normal: Optional[torch.Tensor] = None,
    pred_normal: Optional[torch.Tensor] = None,
    title: str = "",
) -> None:
    """Show a prediction next to its target, optionally with normal maps.

    Without ``vis_normal`` a 1x2 figure is drawn (predicted, target). With
    ``vis_normal`` a 2x2 figure is drawn, adding ``pred_normal`` and
    ``gt_normal`` on the second row. Every tensor is detached, moved to the
    CPU and converted to NumPy before being passed to ``imshow``.

    Args:
        predicted: Predicted image tensor.
        target: Target image tensor.
        vis_normal: Whether to add the row of normal maps.
        gt_normal: Target normal map; required when ``vis_normal`` is true.
        pred_normal: Predicted normal map; required when ``vis_normal`` is
            true.
        title: Figure super-title.

    Raises:
        ValueError: If ``vis_normal`` is true but either normal map is
            missing.
    """
    # Validate before creating the figure so a bad call neither leaves an
    # empty figure behind nor fails with an AttributeError on None.
    if vis_normal and (pred_normal is None or gt_normal is None):
        raise ValueError(
            "pred_normal and gt_normal are required when vis_normal=True"
        )

    panels = [predicted, target]
    if vis_normal:
        panels += [pred_normal, gt_normal]

    n_rows = len(panels) // 2
    # 20x10 for a single row, 20x20 for two rows.
    plt.figure(figsize=(20, 10 * n_rows))

    for position, tensor in enumerate(panels, start=1):
        plt.subplot(n_rows, 2, position).imshow(tensor.detach().cpu().numpy())

    plt.suptitle(title)
    plt.tight_layout()
|
|
|
|
def camera_traj(
    cam, ref_size: int = 1024, views: int = 30, radius: float = 2.0
) -> "List[MiniCam]":
    """Build a ring of cameras orbiting at zero elevation.

    Args:
        cam: Reference camera; its ``fovy``, ``fovx``, ``near`` and ``far``
            attributes are copied to every generated camera.
        ref_size: Value passed as both size arguments of ``MiniCam``
            (presumably width and height; confirm against ``MiniCam``).
        views: Number of cameras to generate.
        radius: Orbit radius forwarded to ``orbit_camera``.

    Returns:
        A list of ``views`` ``MiniCam`` objects whose azimuths are
        ``0, step, 2 * step, ...`` with ``step = 360 // views``.
    """
    # NOTE: the step is an integer division, so when ``views`` does not
    # divide 360 the last camera stops short of a full turn.
    azimuth_step = 360 // views

    cameras = []
    # Iterate over ``views``; the count used to be hard-coded to 30, which
    # ignored the parameter the azimuth step is derived from.
    for view_idx in range(views):
        pose = orbit_camera(0, view_idx * azimuth_step, radius)
        cameras.append(
            MiniCam(
                pose,
                ref_size,
                ref_size,
                cam.fovy,
                cam.fovx,
                cam.near,
                cam.far,
            )
        )
    return cameras
|
|