| |
| |
| |
| |
| |
| |
|
|
| import sys |
|
|
| sys.path.append("./") |
| import glob |
| import math |
| import os |
| import pdb |
| from typing import Dict, List, Optional, Tuple, Union |
|
|
| import cv2 |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| import tqdm |
| import trimesh |
| from engine.continuous_remeshing.core.opt import MeshOptimizer |
| from engine.MVSRecon.mvs_utils import ( |
| camera_traj, |
| compute_chamfer_loss, |
| plot_losses, |
| update_mesh_shape_prior_losses, |
| visualize_prediction, |
| ) |
| from engine.MVSRender.camera_utils import MiniCam, OrbitCamera, orbit_camera |
| from engine.MVSRender.mvs_render import GaussianModel |
| from core.models.rendering.mesh_utils import Mesh, safe_normalize |
| from core.models.rendering.utils.sh_utils import RGB2SH, SH2RGB |
| from core.models.rendering.utils.typing import * |
| from PIL import Image |
| from plyfile import PlyData, PlyElement |
| from pytorch3d.loss import ( |
| chamfer_distance, |
| mesh_edge_loss, |
| mesh_laplacian_smoothing, |
| mesh_normal_consistency, |
| ) |
| from pytorch3d.structures import Meshes |
|
|
| import nvdiffrast.torch as dr |
|
|
|
|
def is_format(f: str, format: Sequence[str]):
    """Check whether a file name's extension belongs to a set of formats.

    Args:
        f (str): file name to test.
        format (Sequence[str]): accepted extensions; both '.jpg' and 'jpg'
            spellings are recognized.

    Returns:
        bool: True if the file's extension is in the set.
    """
    extension = os.path.splitext(f)[1].lower()
    # Accept the extension with or without its leading dot.
    return extension in format or extension.lstrip(".") in format
|
|
|
|
def next_files(path, format=None):
    """Return the sorted full paths of entries under ``path``.

    Args:
        path: directory to list.
        format: optional set of accepted extensions (see ``is_format``);
            when None, no filtering is applied.

    Returns:
        list[str]: sorted file paths.
    """
    entries = [os.path.join(path, name) for name in os.listdir(path)]
    if format is not None:
        entries = [entry for entry in entries if is_format(entry, format)]
    return sorted(entries)
|
|
|
|
def avaliable_device():
    """Return the preferred torch device string.

    Returns:
        str: ``"cuda:<idx>"`` for the current CUDA device when available,
        otherwise ``"cpu"``.
    """
    # NOTE: name spelling ("avaliable") is kept — it is the public interface.
    if not torch.cuda.is_available():
        return "cpu"
    return f"cuda:{torch.cuda.current_device()}"
|
|
|
|
| DEFAULT_DEVICE = avaliable_device() |
|
|
|
|
def read_gt_data(
    data_root: str,
    tgt_size: int,
    mask_data_root: Optional[str] = None,
    normal_data_root: Optional[str] = None,
    read_normal: bool = True,
    device: Optional[torch.device] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
    """
    Read and preprocess ground truth images, masks, and optionally normals.

    Images and masks are paired by sorted order, so both directories must
    contain the same number of files.

    Args:
        data_root: Directory containing input images (``*.png`` / ``*.jpg``).
        tgt_size: Target square size (pixels) each image is resized to.
        mask_data_root: Directory containing mask images.
        normal_data_root: Directory containing normal maps, named either
            ``<base>.png`` or ``normal_<base>.png``.
        read_normal: Whether to read normal maps.
        device: Target device for torch tensors.

    Returns:
        Float tensor in [0, 1] of shape (N, tgt_size, tgt_size, C), with
        C = 4 (RGB + mask) or C = 7 (RGB + mask + normal) when
        ``read_normal`` is True and normals were found.

    Raises:
        FileNotFoundError: If no images, no masks, or a required normal map
            is missing.
        ValueError: If the number of masks does not match the number of images.
    """
    img_paths = sorted(
        glob.glob(os.path.join(data_root, "*.png"))
        + glob.glob(os.path.join(data_root, "*.jpg"))
    )
    if not img_paths:
        raise FileNotFoundError(f"No images found in {data_root}")

    if mask_data_root and not os.path.exists(mask_data_root):
        os.makedirs(mask_data_root, exist_ok=True)

    mask_paths = next_files(mask_data_root)
    if not mask_paths:
        raise FileNotFoundError(f"No masks found in {mask_data_root}")
    # Fix: zip() silently truncated on a count mismatch, pairing images with
    # the wrong masks. Fail loudly instead.
    if len(mask_paths) != len(img_paths):
        raise ValueError(
            f"Found {len(img_paths)} images but {len(mask_paths)} masks; "
            "they are paired by sorted order and must match in count."
        )

    images, masks, normals = [], [], []

    for img_path, mask_path in zip(img_paths, mask_paths):
        # OpenCV loads BGR; convert to RGB before resizing.
        img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (tgt_size, tgt_size), interpolation=cv2.INTER_LINEAR)
        images.append(img)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        mask = cv2.resize(mask, (tgt_size, tgt_size), interpolation=cv2.INTER_LINEAR)
        masks.append(mask[..., None])

        if read_normal and normal_data_root:
            base_name = os.path.splitext(os.path.basename(img_path))[0]
            normal_path = os.path.join(normal_data_root, f"{base_name}.png")

            # Fall back to the "normal_" prefixed naming scheme.
            if not os.path.exists(normal_path):
                normal_path = os.path.join(normal_data_root, f"normal_{base_name}.png")

            if os.path.exists(normal_path):
                normal = cv2.cvtColor(cv2.imread(normal_path), cv2.COLOR_BGR2RGB)
                # Nearest-neighbor keeps encoded normal vectors from being
                # blended across pixels.
                normal = cv2.resize(
                    normal, (tgt_size, tgt_size), interpolation=cv2.INTER_NEAREST
                )
                normals.append(normal)
            else:
                raise FileNotFoundError(f"Normal map not found: {normal_path}")

    images = np.stack(images)
    masks = np.stack(masks)

    def to_tensor(arr):
        # uint8 [0, 255] -> float [0, 1] on the requested device.
        return torch.from_numpy(arr).float().to(device=device) / 255.0

    if not read_normal or not normals:
        return to_tensor(np.concatenate([images, masks], axis=-1))

    normals = np.stack(normals)

    return to_tensor(np.concatenate([images, masks, normals], axis=-1))
|
|
|
|
class DrRender:
    """Thin nvdiffrast-based rasterizer driven by an OrbitCamera.

    Renders a mesh's silhouette mask and (optionally) its camera-space
    normals from the camera's current pose. Results are cached in
    ``self.render_dict`` and reused until ``need_update`` is set True again.
    """

    def __init__(
        self,
        mesh: "Mesh",
        width: int,
        height: int,
        radius: float,
        fovy: float,
        num_azimuth: int = 30,  # NOTE(review): currently unused in this class
    ):
        self.W = width
        self.H = height
        self.cam = OrbitCamera(width, height, r=radius, fovy=fovy)
        # Black background color; stored but not referenced in step().
        self.bg_color = torch.zeros(3, dtype=torch.float32, device=DEFAULT_DEVICE)
        # CUDA rasterization context — requires a GPU.
        self.glctx = dr.RasterizeCudaContext()
        self.mesh = mesh
        self.mesh.auto_normal()
        # When True, the next step() call re-renders; otherwise the cached
        # buffers are returned.
        self.need_update = True
        self.render_normal = True
        self.render_depth = False  # depth rendering is not implemented in step()

    def step(self, mesh: Union[Meshes, "Mesh"]) -> Dict[str, torch.Tensor]:
        """Render mesh from current camera view.

        Note: renders the *passed* ``mesh`` (expects ``.v``, ``.f`` and
        ``.vn`` attributes), not ``self.mesh``. Returns a dict with "mask"
        and, when ``render_normal`` is set, "normal" (values in [0, 1],
        background zeroed).
        """
        if not self.need_update:
            return self.render_dict

        pose = torch.from_numpy(self.cam.pose.astype(np.float32)).to(DEFAULT_DEVICE)
        proj = torch.from_numpy(self.cam.perspective.astype(np.float32)).to(
            DEFAULT_DEVICE
        )

        # World -> camera (inverse pose) -> clip space, in homogeneous coords.
        v_cam = (
            torch.matmul(F.pad(mesh.v, pad=(0, 1), value=1.0), torch.inverse(pose).T)
            .float()
            .unsqueeze(0)
        )
        v_clip = v_cam @ proj.T

        # Rasterize; alpha is the antialiased coverage (silhouette) mask.
        rast, _ = dr.rasterize(self.glctx, v_clip, mesh.f.int(), (self.H, self.W))
        alpha = (rast[..., 3:] > 0).float()
        alpha = dr.antialias(alpha, rast, v_clip, mesh.f).squeeze(0).clamp(0, 1)

        render_buff = {"mask": alpha}

        if self.render_normal:
            # Rotate vertex normals into the camera frame, interpolate them
            # per pixel, then re-normalize and antialias.
            vn = mesh.vn @ torch.from_numpy(self.cam.pose[:3, :3]).to(mesh.vn)
            normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, mesh.f)
            normal = safe_normalize(normal)
            normal = dr.antialias(normal, rast, v_clip, mesh.f)

            # Sign convention: flip x, then negate all axes (net effect: x
            # kept, y/z negated). Presumably matches the ground-truth normal
            # map convention — TODO(review): confirm against the GT data.
            normal[0][:, :, 0] = -normal[0][:, :, 0]
            normal[0] = -normal[0]

            # Map from [-1, 1] to [0, 1]; zero out background pixels.
            normal_image = (normal[0] + 1) / 2
            normal_image = torch.where(rast[..., 3:] > 0, normal_image, 0)
            render_buff["normal"] = normal_image[0]

        self.need_update = False
        self.render_dict = render_buff
        return render_buff
|
|
|
class MVSRec:
    """Multi-view mesh reconstruction driver.

    Loads a Gaussian-splat model (for re-centering) and an initial mesh,
    renders the mesh with ``DrRender`` from a fixed orbit of views, and
    optimizes the mesh geometry via continuous remeshing against
    ground-truth silhouette masks and normal maps.
    """

    def __init__(
        self,
        gs_file,
        img_dir,
        mask_dir,
        normal_dir,
        mesh_path,
        fovy=50,
        ref_size=1024,
        views=30,
        radius=2.0,
        sh_degree=0,
    ):
        # Orbit camera shared with trajectory generation below.
        self.cam = OrbitCamera(ref_size, ref_size, r=radius, fovy=fovy)

        self.fovy = fovy
        self.ref_size = ref_size
        self.views = views
        self.radius = radius

        # Load the input Gaussian-splat point cloud.
        input_gs = GaussianModel(use_rgb=True)
        input_gs.load_ply(gs_file)

        self.gs = input_gs
        # Re-centers the splats at the origin and records self.offset.
        self.autosize()

        self.camera_views = camera_traj(self.cam, ref_size, views, radius)
        self.device = avaliable_device()
        self.sh_degree = sh_degree

        self.data_root = img_dir
        self.mask_data_root = mask_dir
        self.normal_data_root = normal_dir

        # Initial (canonical) mesh used as the optimization starting point.
        cano_mesh = mesh_path
        cano_mesh = trimesh.load_mesh(cano_mesh, process=False)
        vertices = torch.from_numpy(cano_mesh.vertices).float()
        faces = torch.from_numpy(cano_mesh.faces).long()

        # Apply the same centering offset as the Gaussians so both live in
        # the same coordinate frame.
        vertices = vertices - self.offset
        mesh = Mesh(vertices, faces)
        mesh.auto_normal()

        self.dr = DrRender(
            mesh, ref_size, ref_size, radius=radius, fovy=fovy, num_azimuth=views
        )

        self.num_azimuth = views

        # NOTE(review): device is assigned twice in __init__ (harmless).
        self.device = avaliable_device()

    def autosize(self):
        """Center the Gaussian point cloud at the origin; store the offset.

        The offset (bounding-box center) is kept in ``self.offset`` so the
        reconstructed mesh can be translated back at export time.
        """
        xyz = self.gs.xyz
        min_xyz = xyz.min(dim=0).values
        max_xyz = xyz.max(dim=0).values
        middle_offset = (min_xyz + max_xyz) / 2
        xyz -= middle_offset
        self.gs.xyz = xyz
        self.offset = middle_offset

    def fitting(
        self,
        save_dir,
        Niter=100,
        vis_data_root="./debug/mv_rec/",
        save_name="rec_mesh.obj",
    ):
        """Optimize the mesh against GT masks/normals and export an OBJ.

        Args:
            save_dir: directory the final mesh is written to.
            Niter: number of optimization iterations.
            vis_data_root: directory for debug renders; None disables
                visualization.
            save_name: output OBJ file name.
        """
        device = self.device

        num_azimuth = self.num_azimuth

        use_normal = True
        # NOTE(review): with the default Niter=100 only iteration 0 is
        # visualized during the loop (plot_period=200).
        plot_period = 200
        vis = vis_data_root is not None
        if vis:
            os.makedirs(vis_data_root, exist_ok=True)

        # (N, H, W, 7): RGB + mask + normal, float in [0, 1].
        imgs_w_masks = read_gt_data(
            self.data_root,
            self.ref_size,
            self.mask_data_root,
            self.normal_data_root,
            read_normal=True,
            device=self.device,
        )

        # NOTE(review): imgs_gt is sliced but never used below (no photo loss).
        imgs_gt = imgs_w_masks[:, :, :, :3]
        alpha_gt = imgs_w_masks[:, :, :, 3:4]
        if use_normal:
            normal_gt = imgs_w_masks[:, :, :, 4:7]

        # One elevation ring, evenly spaced azimuths; must match the order
        # the GT views were rendered in.
        elevation = [0]
        azimuth = np.linspace(0, 360, num_azimuth, dtype=np.int32, endpoint=False)

        view_list = []
        for ele in tqdm.tqdm(elevation):
            for azi in tqdm.tqdm(azimuth):
                view_list.append((ele, azi))

        assert len(view_list) == imgs_w_masks.shape[0]

        # Loss weights; "values" accumulate per-iteration history for plotting.
        # NOTE(review): chamfer_loss/edge/normal/laplacian terms are either
        # zero-weighted or never computed below — chamfer stays 0 despite its
        # weight of 100. Confirm whether that is intentional.
        losses = {
            "silhouette": {"weight": 1.0, "values": []},
            "normal_sup": {"weight": 1.0, "values": []},
            "edge": {"weight": 0.0, "values": []},
            "normal": {"weight": 0.00, "values": []},
            "laplacian": {"weight": 0.0, "values": []},
            "chamfer_loss": {"weight": 100, "values": []},
        }

        src_mesh = Meshes(
            verts=[self.dr.mesh.v],
            faces=[self.dr.mesh.f],
            textures=None,
        )

        loop = tqdm.tqdm(range(Niter))

        # Edge-length schedule for the continuous-remeshing optimizer.
        start_edge_len = 0.02
        end_edge_len = 0.005

        optimizer = MeshOptimizer(
            src_mesh.verts_packed().cuda(),
            src_mesh.faces_packed().cuda(),
            ramp=5,
            edge_len_lims=(end_edge_len, start_edge_len),
            local_edgelen=False,
        )
        vertices = optimizer.vertices
        faces = optimizer.faces

        for i in loop:
            # Rebuild a pytorch3d mesh from the optimizer's current topology.
            new_mesh = Meshes(
                verts=[vertices],
                faces=[faces],
                textures=None,
            )
            # Attach .v/.f/.vn so DrRender.step can consume this object.
            # NOTE(review): normals are negated here — presumably to match
            # the renderer's sign convention; confirm.
            new_mesh.v = new_mesh.verts_packed()
            new_mesh.f = new_mesh.faces_packed().int()
            new_mesh.vn = -new_mesh.verts_normals_packed()

            loss = {k: torch.tensor(0.0, device=device) for k in losses}

            # Fills regularization terms (edge/normal/laplacian) in-place.
            update_mesh_shape_prior_losses(new_mesh, loss)

            num_views_per_iteration = len(view_list)

            for idx in range(len(view_list)):
                ele, azi = view_list[idx]

                self.dr.cam.from_angle(ele, azi)
                self.dr.need_update = True

                render_buff = self.dr.step(mesh=new_mesh)

                # Silhouette (mask) MSE, averaged over views.
                loss_silhouette = ((render_buff["mask"] - alpha_gt[idx]) ** 2).mean()
                loss["silhouette"] += loss_silhouette / num_views_per_iteration

                if use_normal:
                    # Compare normals only inside the GT foreground mask.
                    mask = alpha_gt[idx]
                    normal_gt_tmp = normal_gt[idx] * mask
                    pred_normal_tmp = render_buff["normal"] * mask

                    loss_normal_sup = ((pred_normal_tmp - normal_gt_tmp) ** 2).mean()
                    loss["normal_sup"] += loss_normal_sup / num_views_per_iteration
                else:
                    normal_gt_tmp = None
                    pred_normal_tmp = None

            # Weighted sum of all loss terms; record history for plotting.
            sum_loss = torch.tensor(0.0, device=device)
            for k, l in loss.items():
                sum_loss += l * losses[k]["weight"]
                losses[k]["values"].append(float(l.detach().cpu()))

            loop.set_description(
                f"total_loss = {sum_loss:.05f}, alpha = {loss['silhouette']:.05f}, n_sup = {loss['normal_sup']:.05f}, champ = {loss['chamfer_loss']:.05f}, n = {loss['normal']:.05f}, l={loss['laplacian']:.05f}, e={loss['edge']:.05f}"
            )

            # Periodic debug visualization (uses the last rendered view).
            if vis and (i % plot_period == 0):
                save_path = os.path.join(vis_data_root, f"it{i:05d}.jpg")
                visualize_prediction(
                    self.dr.render_dict["mask"],
                    alpha_gt[idx],
                    vis_normal=use_normal,
                    gt_normal=normal_gt_tmp,
                    pred_normal=pred_normal_tmp,
                    title="iter: %d" % i,
                )
                plt.savefig(save_path)

            # Backprop, step, then remesh (topology may change every iter).
            sum_loss.backward()
            optimizer.step()
            vertices, faces = optimizer.remesh()

        # Final visualization + loss curves after the loop.
        if vis:
            save_path = os.path.join(vis_data_root, f"it{i:05d}.jpg")
            visualize_prediction(
                self.dr.render_dict["mask"],
                alpha_gt[idx],
                vis_normal=use_normal,
                gt_normal=normal_gt_tmp,
                pred_normal=pred_normal_tmp,
                title="iter: %d" % i,
            )
            plt.savefig(save_path)
            save_path = os.path.join(vis_data_root, f"loss.jpg")
            plot_losses(losses)
            plt.savefig(save_path)

        os.makedirs(save_dir, exist_ok=True)
        final_obj = os.path.join(save_dir, save_name)

        # Wrap the pytorch3d Meshes in the project Mesh type (which has
        # write_obj), translating back by the stored centering offset.
        if not hasattr(new_mesh, "write_obj"):
            offset = self.offset.to(new_mesh.verts_packed())
            new_mesh = Mesh(
                v=(new_mesh.verts_packed() + offset), f=new_mesh.faces_packed()
            )

        new_mesh.write_obj(final_obj)
|
|
|
|
def main():
    """Run multi-view mesh reconstruction for one hard-coded example subject."""
    name = "extract_frame_taobao_man9000_445363-23368393-1280376597"

    reconstructor = MVSRec(
        f"./exps/output_gs/{name}.ply",       # Gaussian-splat input
        f"./exps/mvs_render/{name}",          # rendered GT images
        f"./exps/mvs_sam/{name}",             # segmentation masks
        f"./exps/mvs_normal/{name}",          # normal maps
        f"./exps/smplx_output/{name}.obj",    # initial mesh
        radius=2.5,
    )
    reconstructor.fitting("./debug/meshs_recon", Niter=100, save_name=f"{name}.obj")
|
|
|
|
# Script entry point.
if __name__ == "__main__":
    main()
|
|