# LHMPP / engine / MVSRecon / mvs_recon.py
# Author: Lingteng Qiu (邱陵腾)
# (commit: "rm assets & wheels", 434b0b0)
# -*- coding: utf-8 -*-
# @Organization : Tongyi Lab, Alibaba
# @Author : Lingteng Qiu
# @Email : 220019047@link.cuhk.edu.cn
# @Time : 2025-07-09 19:20:42
# @Function : MVS Reconstruction through continuous remeshing
import sys
sys.path.append("./")
import glob
import math
import os
import pdb
from typing import Dict, List, Optional, Tuple, Union
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
import trimesh
from engine.continuous_remeshing.core.opt import MeshOptimizer
from engine.MVSRecon.mvs_utils import (
camera_traj,
compute_chamfer_loss,
plot_losses,
update_mesh_shape_prior_losses,
visualize_prediction,
)
from engine.MVSRender.camera_utils import MiniCam, OrbitCamera, orbit_camera
from engine.MVSRender.mvs_render import GaussianModel
from core.models.rendering.mesh_utils import Mesh, safe_normalize
from core.models.rendering.utils.sh_utils import RGB2SH, SH2RGB
from core.models.rendering.utils.typing import *
from PIL import Image
from plyfile import PlyData, PlyElement
from pytorch3d.loss import (
chamfer_distance,
mesh_edge_loss,
mesh_laplacian_smoothing,
mesh_normal_consistency,
)
from pytorch3d.structures import Meshes
import nvdiffrast.torch as dr
def is_format(f: str, format: Sequence[str]):
    """Return True when file *f*'s extension belongs to *format*.

    Extensions in *format* may be written with or without the leading dot
    (both ".jpg" and "jpg" match "photo.jpg"); matching is case-insensitive.
    """
    suffix = os.path.splitext(f)[1].lower()  # includes the leading dot
    return any(candidate in format for candidate in (suffix, suffix[1:]))
def next_files(path, format=None):
    """Return sorted full paths of entries in *path*.

    When *format* is given, keep only files whose extension is in that set
    (see ``is_format``).
    """
    entries = [os.path.join(path, name) for name in os.listdir(path)]
    if format is not None:
        entries = [entry for entry in entries if is_format(entry, format)]
    return sorted(entries)
def avaliable_device():
    """Return the current CUDA device string (e.g. "cuda:0"), or "cpu"."""
    if not torch.cuda.is_available():
        return "cpu"
    return f"cuda:{torch.cuda.current_device()}"
# Resolved once at import time; used as the default device for renderer tensors.
DEFAULT_DEVICE = avaliable_device()
def read_gt_data(
    data_root: str,
    tgt_size: int,
    mask_data_root: Optional[str] = None,
    normal_data_root: Optional[str] = None,
    read_normal: bool = True,
    device: Optional[torch.device] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
    """
    Read and preprocess ground truth images, masks, and optionally normals.

    Args:
        data_root: Directory containing input images (*.png / *.jpg).
        tgt_size: Target square size (pixels) everything is resized to.
        mask_data_root: Directory containing mask images.
        normal_data_root: Directory containing normal maps.
        read_normal: Whether to read normal maps.
        device: Target device for the returned tensor.

    Returns:
        Float tensor in [0, 1] of shape (N, tgt_size, tgt_size, C) where
        C = 4 (RGB + mask) or 7 (RGB + mask + normal) when normals are read.

    Raises:
        FileNotFoundError: If no images/masks are found, or an expected
            normal map is missing.
        ValueError: If the image and mask counts differ.
    """
    # Sorted so images and masks pair up deterministically by filename order.
    img_paths = sorted(
        glob.glob(os.path.join(data_root, "*.png"))
        + glob.glob(os.path.join(data_root, "*.jpg"))
    )
    if not img_paths:
        raise FileNotFoundError(f"No images found in {data_root}")

    # Create mask directory if needed (next_files would fail on a missing dir).
    if mask_data_root and not os.path.exists(mask_data_root):
        os.makedirs(mask_data_root, exist_ok=True)

    mask_dir = mask_data_root
    mask_paths = next_files(mask_dir)
    if not mask_paths:
        raise FileNotFoundError(f"No masks found in {mask_dir}")

    # Bug fix: zip() silently truncates on a count mismatch, which would pair
    # images with the wrong masks downstream — fail loudly instead.
    if len(img_paths) != len(mask_paths):
        raise ValueError(
            f"Image/mask count mismatch: {len(img_paths)} images vs "
            f"{len(mask_paths)} masks"
        )

    images, masks, normals = [], [], []
    for img_path, mask_path in zip(img_paths, mask_paths):
        # Image: BGR -> RGB, bilinear resize.
        img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (tgt_size, tgt_size), interpolation=cv2.INTER_LINEAR)
        images.append(img)

        # Mask: single channel, resized; keep a trailing channel dimension.
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        mask = cv2.resize(mask, (tgt_size, tgt_size), interpolation=cv2.INTER_LINEAR)
        masks.append(mask[..., None])

        if read_normal and normal_data_root:
            # Normal maps may be named "<base>.png" or "normal_<base>.png".
            base_name = os.path.splitext(os.path.basename(img_path))[0]
            normal_path = os.path.join(normal_data_root, f"{base_name}.png")
            if not os.path.exists(normal_path):
                normal_path = os.path.join(normal_data_root, f"normal_{base_name}.png")
            if not os.path.exists(normal_path):
                raise FileNotFoundError(f"Normal map not found: {normal_path}")
            normal = cv2.cvtColor(cv2.imread(normal_path), cv2.COLOR_BGR2RGB)
            # Nearest-neighbour resize avoids blending unrelated normal vectors.
            normal = cv2.resize(
                normal, (tgt_size, tgt_size), interpolation=cv2.INTER_NEAREST
            )
            normals.append(normal)

    images = np.stack(images)
    masks = np.stack(masks)

    def to_tensor(arr):
        # uint8 [0, 255] -> float [0, 1] on the requested device.
        return torch.from_numpy(arr).float().to(device=device) / 255.0

    if not read_normal or not normals:
        return to_tensor(np.concatenate([images, masks], axis=-1))

    normals = np.stack(normals)
    return to_tensor(np.concatenate([images, masks, normals], axis=-1))
class DrRender:
    """nvdiffrast-based rasterizer rendering a mesh from an orbit camera.

    Produces an antialiased coverage mask and (optionally) a camera-space
    normal image per call to :meth:`step`.
    """

    def __init__(
        self,
        mesh: "Mesh",
        width: int,
        height: int,
        radius: float,
        fovy: float,
        num_azimuth: int = 30,
    ):
        # NOTE(review): num_azimuth is accepted but unused here; view
        # selection is driven by the caller through self.cam.
        self.W = width
        self.H = height
        self.cam = OrbitCamera(width, height, r=radius, fovy=fovy)
        self.bg_color = torch.zeros(3, dtype=torch.float32, device=DEFAULT_DEVICE)
        self.glctx = dr.RasterizeCudaContext()
        self.mesh = mesh
        self.mesh.auto_normal()
        # Render cache flag: step() re-renders only while need_update is True.
        self.need_update = True
        self.render_normal = True
        self.render_depth = False

    def step(self, mesh: Union[Meshes, "Mesh"]) -> Dict[str, torch.Tensor]:
        """Render mesh from current camera view."""
        # Return the cached buffers unless the camera/mesh was flagged dirty.
        if not self.need_update:
            return self.render_dict
        pose = torch.from_numpy(self.cam.pose.astype(np.float32)).to(DEFAULT_DEVICE)
        proj = torch.from_numpy(self.cam.perspective.astype(np.float32)).to(
            DEFAULT_DEVICE
        )
        # Transform vertices: world -> camera (homogeneous) -> clip space.
        v_cam = (
            torch.matmul(F.pad(mesh.v, pad=(0, 1), value=1.0), torch.inverse(pose).T)
            .float()
            .unsqueeze(0)
        )
        v_clip = v_cam @ proj.T
        # Rasterize; the last rast channel is > 0 where a triangle covers the pixel.
        rast, _ = dr.rasterize(self.glctx, v_clip, mesh.f.int(), (self.H, self.W))
        alpha = (rast[..., 3:] > 0).float()
        alpha = dr.antialias(alpha, rast, v_clip, mesh.f).squeeze(0).clamp(0, 1)
        render_buff = {"mask": alpha}
        # Render normals if enabled
        if self.render_normal:
            # Rotate vertex normals into the camera frame, then interpolate
            # them per pixel over the rasterized triangles.
            vn = mesh.vn @ torch.from_numpy(self.cam.pose[:3, :3]).to(mesh.vn)
            normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, mesh.f)
            normal = safe_normalize(normal)
            normal = dr.antialias(normal, rast, v_clip, mesh.f)
            # coordinate system change: flip x, then negate all axes
            normal[0][:, :, 0] = -normal[0][:, :, 0]
            normal[0] = -normal[0]
            # Map normals from [-1, 1] to [0, 1]; zero out background pixels.
            normal_image = (normal[0] + 1) / 2
            normal_image = torch.where(rast[..., 3:] > 0, normal_image, 0)
            render_buff["normal"] = normal_image[0]
        self.need_update = False
        self.render_dict = render_buff
        return render_buff
class MVSRec:
    """Multi-view mesh reconstruction via continuous remeshing.

    Loads a Gaussian-splat PLY (used for centering), ground-truth renders,
    masks, and normal maps, plus an initial mesh, then optimizes the mesh
    vertices so its rendered silhouettes and normals match the targets.
    """

    def __init__(
        self,
        gs_file,
        img_dir,
        mask_dir,
        normal_dir,
        mesh_path,
        fovy=50,
        ref_size=1024,
        views=30,
        radius=2.0,
        sh_degree=0,
    ):
        # Orbit camera matching the resolution/FOV of the ground-truth renders.
        self.cam = OrbitCamera(ref_size, ref_size, r=radius, fovy=fovy)
        self.fovy = fovy
        self.ref_size = ref_size
        self.views = views
        self.radius = radius
        input_gs = GaussianModel(use_rgb=True)
        input_gs.load_ply(gs_file)
        self.gs = input_gs
        # Centers the Gaussians and records self.offset (reused for the mesh).
        self.autosize()
        self.camera_views = camera_traj(self.cam, ref_size, views, radius)
        self.device = avaliable_device()
        self.sh_degree = sh_degree
        self.data_root = img_dir
        self.mask_data_root = mask_dir
        self.normal_data_root = normal_dir
        cano_mesh = mesh_path
        cano_mesh = trimesh.load_mesh(cano_mesh, process=False)
        vertices = torch.from_numpy(cano_mesh.vertices).float()
        faces = torch.from_numpy(cano_mesh.faces).long()
        # Apply the same centering offset as the Gaussians so they align.
        vertices = vertices - self.offset
        mesh = Mesh(vertices, faces)
        mesh.auto_normal()
        self.dr = DrRender(
            mesh, ref_size, ref_size, radius=radius, fovy=fovy, num_azimuth=views
        )
        self.num_azimuth = views
        self.device = avaliable_device()

    def autosize(self):
        """Center the Gaussian point cloud at the origin; store the offset."""
        xyz = self.gs.xyz
        min_xyz = xyz.min(dim=0).values
        max_xyz = xyz.max(dim=0).values
        # Midpoint of the axis-aligned bounding box.
        middle_offset = (min_xyz + max_xyz) / 2
        xyz -= middle_offset
        self.gs.xyz = xyz
        self.offset = middle_offset

    def fitting(
        self,
        save_dir,
        Niter=100,
        vis_data_root="./debug/mv_rec/",
        save_name="rec_mesh.obj",
    ):
        """Optimize the mesh against multi-view masks/normals and save an OBJ.

        Args:
            save_dir: Output directory for the reconstructed mesh.
            Niter: Number of optimization iterations.
            vis_data_root: Directory for debug plots (None disables plotting).
            save_name: Filename of the exported OBJ.
        """
        device = self.device
        num_azimuth = self.num_azimuth
        use_normal = True
        plot_period = 200
        vis = vis_data_root is not None
        if vis:
            os.makedirs(vis_data_root, exist_ok=True)
        # Ground truth packed as (N, H, W, 7): RGB + alpha + normal.
        imgs_w_masks = read_gt_data(
            self.data_root,
            self.ref_size,
            self.mask_data_root,
            self.normal_data_root,
            read_normal=True,
            device=self.device,
        )
        imgs_gt = imgs_w_masks[:, :, :, :3]
        alpha_gt = imgs_w_masks[:, :, :, 3:4]
        if use_normal:
            normal_gt = imgs_w_masks[:, :, :, 4:7]
        # render from fixed views and save all images
        elevation = [0]
        azimuth = np.linspace(0, 360, num_azimuth, dtype=np.int32, endpoint=False)
        view_list = []
        for ele in tqdm.tqdm(elevation):
            for azi in tqdm.tqdm(azimuth):
                view_list.append((ele, azi))
        # One ground-truth image is expected per (elevation, azimuth) view.
        assert len(view_list) == imgs_w_masks.shape[0]
        # Loss weights; zero-weight terms are still tracked for plotting.
        losses = {
            "silhouette": {"weight": 1.0, "values": []},
            "normal_sup": {"weight": 1.0, "values": []},
            "edge": {"weight": 0.0, "values": []},
            "normal": {"weight": 0.00, "values": []},
            "laplacian": {"weight": 0.0, "values": []},
            "chamfer_loss": {"weight": 100, "values": []},
        }
        # Convert to Pytorch3d Mesh
        src_mesh = Meshes(
            verts=[self.dr.mesh.v],
            faces=[self.dr.mesh.f],
            textures=None,
        )
        loop = tqdm.tqdm(range(Niter))
        # Target edge lengths ramp from coarse to fine during optimization.
        start_edge_len = 0.02
        end_edge_len = 0.005
        optimizer = MeshOptimizer(
            src_mesh.verts_packed().cuda(),
            src_mesh.faces_packed().cuda(),
            ramp=5,
            edge_len_lims=(end_edge_len, start_edge_len),
            local_edgelen=False,
        )
        vertices = optimizer.vertices
        faces = optimizer.faces
        for i in loop:
            new_mesh = Meshes(
                verts=[vertices],
                faces=[faces],
                textures=None,
            )
            # Attach Mesh-like attributes (v/f/vn) expected by DrRender.step;
            # normals are negated to match the renderer's convention.
            new_mesh.v = new_mesh.verts_packed()
            new_mesh.f = new_mesh.faces_packed().int()
            new_mesh.vn = -new_mesh.verts_normals_packed()
            # Losses to smooth /regularize the mesh shape
            loss = {k: torch.tensor(0.0, device=device) for k in losses}
            update_mesh_shape_prior_losses(new_mesh, loss)
            num_views_per_iteration = len(view_list)
            for idx in range(len(view_list)):
                ele, azi = view_list[idx]
                self.dr.cam.from_angle(ele, azi)
                # Force a fresh render for the new camera pose.
                self.dr.need_update = True
                # render
                render_buff = self.dr.step(mesh=new_mesh)
                loss_silhouette = ((render_buff["mask"] - alpha_gt[idx]) ** 2).mean()
                loss["silhouette"] += loss_silhouette / num_views_per_iteration
                if use_normal:
                    # Supervise normals only inside the ground-truth mask.
                    mask = alpha_gt[idx]
                    normal_gt_tmp = normal_gt[idx] * mask
                    pred_normal_tmp = render_buff["normal"] * mask
                    loss_normal_sup = ((pred_normal_tmp - normal_gt_tmp) ** 2).mean()
                    loss["normal_sup"] += loss_normal_sup / num_views_per_iteration
                else:
                    normal_gt_tmp = None
                    pred_normal_tmp = None
            # Weighted sum of the losses
            sum_loss = torch.tensor(0.0, device=device)
            for k, l in loss.items():
                sum_loss += l * losses[k]["weight"]
                losses[k]["values"].append(float(l.detach().cpu()))
            # Print the losses
            loop.set_description(
                f"total_loss = {sum_loss:.05f}, alpha = {loss['silhouette']:.05f}, n_sup = {loss['normal_sup']:.05f}, champ = {loss['chamfer_loss']:.05f}, n = {loss['normal']:.05f}, l={loss['laplacian']:.05f}, e={loss['edge']:.05f}"
            )
            # Plot mesh
            if vis and (i % plot_period == 0):
                save_path = os.path.join(vis_data_root, f"it{i:05d}.jpg")
                visualize_prediction(
                    self.dr.render_dict["mask"],
                    alpha_gt[idx],
                    vis_normal=use_normal,
                    gt_normal=normal_gt_tmp,
                    pred_normal=pred_normal_tmp,
                    title="iter: %d" % i,
                )
                plt.savefig(save_path)
            # Optimization step
            sum_loss.backward()
            optimizer.step()
            # Remeshing returns the updated vertex/face tensors for next iter.
            vertices, faces = optimizer.remesh()
        if vis:
            # Final visualization (last view of the last iteration) + loss plot.
            save_path = os.path.join(vis_data_root, f"it{i:05d}.jpg")
            visualize_prediction(
                self.dr.render_dict["mask"],
                alpha_gt[idx],
                vis_normal=use_normal,
                gt_normal=normal_gt_tmp,
                pred_normal=pred_normal_tmp,
                title="iter: %d" % i,
            )
            plt.savefig(save_path)
            save_path = os.path.join(vis_data_root, f"loss.jpg")
            plot_losses(losses)
            plt.savefig(save_path)
        # final_obj = os.path.join(vis_data_root, f"final_model_a{losses['silhouette']['weight']}_ns{losses['normal_sup']['weight']}_n{losses['normal']['weight']}_l{losses['laplacian']['weight']}_e{losses['edge']['weight']}.obj")
        os.makedirs(save_dir, exist_ok=True)
        final_obj = os.path.join(save_dir, save_name)
        if not hasattr(new_mesh, "write_obj"):
            # back to original coordinate
            offset = self.offset.to(new_mesh.verts_packed())
            new_mesh = Mesh(
                v=(new_mesh.verts_packed() + offset), f=new_mesh.faces_packed()
            )
        new_mesh.write_obj(final_obj)
def main():
    """Run MVS reconstruction for a single hard-coded subject."""
    subject = "extract_frame_taobao_man9000_445363-23368393-1280376597"
    recon = MVSRec(
        f"./exps/output_gs/{subject}.ply",
        f"./exps/mvs_render/{subject}",
        f"./exps/mvs_sam/{subject}",
        f"./exps/mvs_normal/{subject}",
        f"./exps/smplx_output/{subject}.obj",
        radius=2.5,
    )
    recon.fitting("./debug/meshs_recon", save_name=f"{subject}.obj", Niter=100)
# Script entry point: run the demo reconstruction when executed directly.
if __name__ == "__main__":
    main()