LHMPP / engine /MVSRecon /mvs_utils.py
Lingteng Qiu (邱陵腾)
rm assets & wheels
434b0b0
from typing import Dict, List, Optional, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from pytorch3d.loss import (
chamfer_distance,
mesh_edge_loss,
mesh_laplacian_smoothing,
mesh_normal_consistency,
)
from pytorch3d.structures import Meshes
from engine.MVSRender.camera_utils import MiniCam, OrbitCamera, orbit_camera
def compute_chamfer_loss(
    query_verts: torch.Tensor, gt_pts: torch.Tensor
) -> torch.Tensor:
    """Return the chamfer distance between two unbatched point sets.

    Both inputs are given a leading batch dimension of size 1 before being
    handed to pytorch3d's ``chamfer_distance``. Presumably each input is a
    (num_points, 3) tensor — TODO confirm against callers.

    Args:
        query_verts: Query points (e.g. mesh vertices).
        gt_pts: Reference points.

    Returns:
        The first element of ``chamfer_distance``'s result (the distance term).
    """
    batched_query = query_verts.unsqueeze(0)
    batched_target = gt_pts.unsqueeze(0)
    chamfer_result = chamfer_distance(batched_query, batched_target)
    return chamfer_result[0]
def update_mesh_shape_prior_losses(
    mesh: Union[Meshes, "Mesh"], loss: Dict[str, torch.Tensor]
) -> None:
    """Store mesh regularization terms into ``loss`` in place.

    For a pytorch3d ``Meshes`` object, the ``"edge"``, ``"normal"`` and
    ``"laplacian"`` entries are filled from pytorch3d's built-in losses (the
    Laplacian uses the ``"uniform"`` method). Any other mesh object must
    expose ``normal_consistency()`` and ``laplacian()`` methods. In that case
    only ``"normal"`` and ``"laplacian"`` are written (no edge term).

    Args:
        mesh: The mesh to regularize.
        loss: Dictionary of named loss terms. It is updated in place.
    """
    if not isinstance(mesh, Meshes):
        # Custom mesh type: delegate to its own regularizers.
        normal_term = mesh.normal_consistency()
        laplacian_term = mesh.laplacian()
        loss.update({"normal": normal_term, "laplacian": laplacian_term})
        return

    # Compute every term before touching ``loss`` so a failure leaves it intact.
    edge_term = mesh_edge_loss(mesh)
    normal_term = mesh_normal_consistency(mesh)
    laplacian_term = mesh_laplacian_smoothing(mesh, method="uniform")
    loss.update(
        {"edge": edge_term, "normal": normal_term, "laplacian": laplacian_term}
    )
def plot_losses(losses: Dict[str, Dict[str, List[float]]]) -> None:
    """Draw every recorded loss curve on a single new matplotlib figure.

    Args:
        losses: Maps a loss name to a record whose ``"values"`` entry holds
            that loss's per-iteration values. Each curve is labelled
            ``"<name> loss"``.
    """
    text_size = 16
    plt.figure(figsize=(13, 5))
    axis = plt.gca()
    for loss_name, record in losses.items():
        axis.plot(record["values"], label=f"{loss_name} loss")
    axis.legend(fontsize=text_size)
    axis.set_xlabel("Iteration", fontsize=text_size)
    axis.set_ylabel("Loss", fontsize=text_size)
    axis.set_title("Loss vs iterations", fontsize=text_size)
def inverse_sigmoid(x: Union[float, torch.Tensor]) -> torch.Tensor:
    """Return the logit ``log(x / (1 - x))``, the inverse of the sigmoid.

    Args:
        x: Probability-like value(s). Tensors are used as-is, keeping their
            dtype and device. Any other input (Python float or int, NumPy
            array, nested sequence) is converted to a float32 tensor first.

    Returns:
        Tensor of logits with the same shape as ``x``. An input of exactly 0
        gives ``-inf`` and exactly 1 gives ``+inf``. Values outside [0, 1]
        give NaN.
    """
    if not isinstance(x, torch.Tensor):
        # Only Python floats used to be converted, so ints (e.g. ``0`` or
        # ``1``) and NumPy arrays reached ``torch.log`` as non-tensors and
        # raised. Convert every non-tensor, matching the old float32 path.
        x = torch.as_tensor(x, dtype=torch.float32)
    return torch.log(x / (1 - x))
def image_grid(
    images: np.ndarray,
    rows: Optional[int] = None,
    cols: Optional[int] = None,
    fill: bool = True,
    show_axes: bool = False,
    rgb: bool = True,
    normal: bool = False,
) -> None:
    """Show a batch of images in a ``rows`` x ``cols`` matplotlib grid.

    Each image is indexed along its last axis to pick the channels to show.
    Presumably channels 0-2 are RGB, channel 3 is alpha/mask and channels
    4-6 are a normal map — TODO confirm against callers.

    Args:
        images: Iterable of images with a trailing channel axis.
        rows: Number of grid rows. Give both ``rows`` and ``cols`` or neither.
            When neither is given (or a falsy value is passed), the images are
            stacked in a single column.
        cols: Number of grid columns (see ``rows``).
        fill: Remove the spacing between subplots when True.
        show_axes: Keep the axis decorations when True.
        rgb: Show channels 0-2 when True, otherwise channel 3 only.
        normal: Show channels 4-6. Takes precedence over ``rgb``.

    Raises:
        ValueError: If exactly one of ``rows`` / ``cols`` is given.
    """
    rows_given = rows is not None
    cols_given = cols is not None
    if rows_given ^ cols_given:
        raise ValueError("Specify either both rows and cols or neither")
    if not rows:
        rows = len(images)
    if not cols:
        cols = 1

    # The channel selection does not depend on the image, so pick it once.
    if normal:
        channels = slice(4, 7)
    elif rgb:
        channels = slice(3)
    else:
        channels = slice(3, 4)

    spacing = {"wspace": 0.0, "hspace": 0.0} if fill else {}
    fig, axes = plt.subplots(rows, cols, gridspec_kw=spacing, figsize=(15, 9))
    fig.subplots_adjust(left=0, bottom=0, right=1, top=1)

    # ``axes`` is a single Axes for a 1x1 grid, so normalise it to a flat array.
    flat_axes = np.array(axes).ravel()
    for ax, im in zip(flat_axes, images):
        ax.imshow(im[..., channels])
        if not show_axes:
            ax.set_axis_off()
def visualize_prediction(
    predicted: torch.Tensor,
    target: torch.Tensor,
    vis_normal: bool = False,
    gt_normal: Optional[torch.Tensor] = None,
    pred_normal: Optional[torch.Tensor] = None,
    title: str = "",
) -> None:
    """Plot a predicted image next to its target, optionally with normals.

    Without normals the figure is a 1x2 grid (prediction, target) of size
    (20, 10). With ``vis_normal=True`` it is a 2x2 grid of size (20, 20)
    whose bottom row shows ``pred_normal`` then ``gt_normal``. Every panel
    is passed through ``.detach().cpu().numpy()`` before ``imshow``, so
    presumably each tensor is an (H, W) or (H, W, C) image — TODO confirm.

    Args:
        predicted: Predicted image.
        target: Ground-truth image.
        vis_normal: Also plot the normal maps when True.
        gt_normal: Ground-truth normal map. Required when ``vis_normal``.
        pred_normal: Predicted normal map. Required when ``vis_normal``.
        title: Figure super-title.

    Raises:
        ValueError: If ``vis_normal`` is True but a normal map is missing.
    """
    # Validate before opening a figure. Previously a missing normal map
    # surfaced as an AttributeError on ``None.detach`` mid-plot.
    if vis_normal and (pred_normal is None or gt_normal is None):
        raise ValueError(
            "pred_normal and gt_normal are required when vis_normal=True"
        )

    panels = [predicted, target]
    if vis_normal:
        panels += [pred_normal, gt_normal]
    n_rows = 2 if vis_normal else 1

    plt.figure(figsize=(20, 10 * n_rows))
    for index, panel in enumerate(panels, start=1):
        plt.subplot(n_rows, 2, index).imshow(panel.detach().cpu().numpy())
    plt.suptitle(title)
    plt.tight_layout()
def camera_traj(cam, ref_size=1024, views=30, radius=2.0):
    """Build ``views`` cameras evenly spaced in azimuth around one full orbit.

    Each pose comes from ``orbit_camera(0, azimuth, radius)``, with azimuth
    starting at 0 and stepping by ``360 / views`` (presumably degrees — TODO
    confirm against ``orbit_camera``). Each pose is wrapped in a ``MiniCam``
    of size ``ref_size`` x ``ref_size`` that reuses ``cam.fovy``,
    ``cam.fovx``, ``cam.near`` and ``cam.far``.

    Args:
        cam: Reference camera supplying fovy / fovx / near / far.
        ref_size: Width and height given to every ``MiniCam``.
        views: Number of cameras to generate. Must be positive.
        radius: Orbit radius passed to ``orbit_camera``.

    Returns:
        List of ``views`` ``MiniCam`` objects.

    Raises:
        ValueError: If ``views`` is not positive.
    """
    if views <= 0:
        raise ValueError(f"views must be positive, got {views}")

    # True division spaces the cameras evenly over the whole circle even when
    # 360 is not divisible by ``views``. For divisors (e.g. the default 30)
    # the angles equal the old integer steps.
    azimuth_step = 360.0 / views
    cameras = []
    # The loop used to be hard-coded to ``range(30)``, silently ignoring
    # ``views`` and producing a partial or over-wrapped orbit.
    for view_idx in range(views):
        pose = orbit_camera(0, view_idx * azimuth_step, radius)
        cameras.append(
            MiniCam(
                pose,
                ref_size,
                ref_size,
                cam.fovy,
                cam.fovx,
                cam.near,
                cam.far,
            )
        )
    return cameras