from typing import *
import numpy as np
import torch
import utils3d
import nvdiffrast.torch as dr
from tqdm import tqdm
import trimesh
import trimesh.visual
import xatlas
import pyvista as pv
from pymeshfix import _meshfix
import igraph
import cv2
from PIL import Image
from .random_utils import sphere_hammersley_sequence
from .render_utils import render_multiview
from ..renderers import GaussianRenderer
from ..representations import Strivec, Gaussian, MeshExtractResult


@torch.no_grad()
def _fill_holes(
    verts,
    faces,
    max_hole_size=0.04,
    max_hole_nbe=32,
    resolution=128,
    num_views=500,
    debug=False,
    verbose=False
):
    """
    Rasterize a mesh from multiple views and remove invisible faces.
    Also includes postprocessing to:
        1. Remove connected components that have low visibility.
        2. Mincut to remove faces at the inner side of the mesh connected to the outer side with a small hole.

    If no face is invisible from every view, the inputs are returned unchanged
    (no boundary filling is performed in that case).

    Args:
        verts (torch.Tensor): Vertices of the mesh. Shape (V, 3).
        faces (torch.Tensor): Faces of the mesh. Shape (F, 3).
        max_hole_size (float): Maximum area of a cutting loop for a cut component to be removed.
        max_hole_nbe (int): Maximum number of boundary edges of a hole to fill (passed to meshfix).
        resolution (int): Resolution of the rasterization.
        num_views (int): Number of views to rasterize the mesh.
        debug (bool): Whether to print debug information and write 'dbg_dual.ply' / 'dbg_cut.ply'.
        verbose (bool): Whether to print progress.

    Returns:
        (verts, faces): The cleaned mesh as CUDA tensors.
    """
    # Construct cameras
    yaws = []
    pitchs = []
    for i in range(num_views):
        y, p = sphere_hammersley_sequence(i, num_views)
        yaws.append(y)
        pitchs.append(p)
    yaws = torch.tensor(yaws).cuda()
    pitchs = torch.tensor(pitchs).cuda()
    radius = 2.0
    fov = torch.deg2rad(torch.tensor(40)).cuda()
    projection = utils3d.torch.perspective_from_fov_xy(fov, fov, 1, 3)
    # Camera origins on a sphere of the given radius, all looking at the world origin (z-up).
    origins = torch.stack([
        torch.sin(yaws) * torch.cos(pitchs),
        torch.cos(yaws) * torch.cos(pitchs),
        torch.sin(pitchs),
    ], dim=-1).float() * radius
    look_at = torch.tensor([0, 0, 0]).float().cuda()
    up = torch.tensor([0, 0, 1]).float().cuda()
    views = torch.stack(
        [utils3d.torch.view_look_at(orig, look_at, up) for orig in origins], dim=0
    )

    # Rasterize: count in how many views each face is visible
    visibility = torch.zeros(faces.shape[0], dtype=torch.int32, device=verts.device)
    rastctx = utils3d.torch.RastContext(backend='cuda')
    for i in tqdm(range(views.shape[0]), total=views.shape[0], disable=not verbose, desc='Rasterizing'):
        view = views[i]
        buffers = utils3d.torch.rasterize_triangle_faces(
            rastctx, verts[None], faces, resolution, resolution, view=view, projection=projection
        )
        # face ids are offset by one in the buffer; only covered pixels are considered
        face_id = buffers['face_id'][0][buffers['mask'][0] > 0.95] - 1
        face_id = torch.unique(face_id).long()
        visibility[face_id] += 1
    visibility = visibility.float() / num_views

    # Mincut
    ## construct outer faces
    edges, face2edge, edge_degrees = utils3d.torch.compute_edges(faces)
    boundary_edge_indices = torch.nonzero(edge_degrees == 1).reshape(-1)
    connected_components = utils3d.torch.compute_connected_components(faces, edges, face2edge)
    outer_face_indices = torch.zeros(faces.shape[0], dtype=torch.bool, device=faces.device)
    for component in connected_components:
        component_visibility = visibility[component]
        # per-component threshold: 75% quantile clamped to [0.25, 0.5]
        threshold = min(max(component_visibility.quantile(0.75).item(), 0.25), 0.5)
        outer_face_indices[component] = component_visibility > threshold
    outer_face_indices = outer_face_indices.nonzero().reshape(-1)

    ## construct inner faces
    inner_face_indices = torch.nonzero(visibility == 0).reshape(-1)
    if verbose:
        tqdm.write(f'Found {inner_face_indices.shape[0]} invisible faces')
    if inner_face_indices.shape[0] == 0:
        return verts, faces

    ## Construct dual graph (faces as nodes, edges as edges)
    dual_edges, dual_edge2edge = utils3d.torch.compute_dual_graph(face2edge)
    dual_edge2edge = edges[dual_edge2edge]
    # weight of a dual edge = length of the mesh edge shared by the two faces
    dual_edges_weights = torch.norm(verts[dual_edge2edge[:, 0]] - verts[dual_edge2edge[:, 1]], dim=1)
    if verbose:
        tqdm.write(f'Dual graph: {dual_edges.shape[0]} edges')

    ## solve mincut problem
    ### construct main graph
    g = igraph.Graph()
    g.add_vertices(faces.shape[0])
    g.add_edges(dual_edges.cpu().numpy())
    g.es['weight'] = dual_edges_weights.cpu().numpy()

    ### source and target
    g.add_vertex('s')
    g.add_vertex('t')

    ### connect invisible faces to source
    # .tolist() copies the indices to the host once instead of once per element
    inner_ids = inner_face_indices.tolist()
    g.add_edges(
        [(f, 's') for f in inner_ids],
        attributes={'weight': np.ones(len(inner_ids), dtype=np.float32)},
    )

    ### connect outer faces to target
    outer_ids = outer_face_indices.tolist()
    g.add_edges(
        [(f, 't') for f in outer_ids],
        attributes={'weight': np.ones(len(outer_ids), dtype=np.float32)},
    )

    ### solve mincut
    cut = g.mincut('s', 't', (np.array(g.es['weight']) * 1000).tolist())
    # keep only real faces of the source side (drop the 's'/'t' helper vertices)
    remove_face_indices = torch.tensor(
        [v for v in cut.partition[0] if v < faces.shape[0]],
        dtype=torch.long, device=faces.device,
    )
    if verbose:
        tqdm.write('Mincut solved, start checking the cut')

    ### check if the cut is valid with each connected component
    to_remove_cc = utils3d.torch.compute_connected_components(faces[remove_face_indices])
    if debug:
        tqdm.write(f'Number of connected components of the cut: {len(to_remove_cc)}')
    valid_remove_cc = []
    cutting_edges = []
    for cc in to_remove_cc:
        #### check if the connected component has low visibility
        visibility_median = visibility[remove_face_indices[cc]].median()
        if debug:
            tqdm.write(f'visblity_median: {visibility_median}')
        if visibility_median > 0.25:
            continue

        #### check if the cutting loop is small enough
        cc_edge_indices, cc_edges_degree = torch.unique(face2edge[remove_face_indices[cc]], return_counts=True)
        cc_boundary_edge_indices = cc_edge_indices[cc_edges_degree == 1]
        # boundary edges created by the cut (i.e. not already on the mesh boundary)
        cc_new_boundary_edge_indices = cc_boundary_edge_indices[
            ~torch.isin(cc_boundary_edge_indices, boundary_edge_indices)
        ]
        if len(cc_new_boundary_edge_indices) > 0:
            cc_new_boundary_edge_cc = utils3d.torch.compute_edge_connected_components(
                edges[cc_new_boundary_edge_indices]
            )
            cc_new_boundary_edges_cc_area = []
            for edge_cc in cc_new_boundary_edge_cc:
                loop_edges = edges[cc_new_boundary_edge_indices[edge_cc]]
                center = verts[loop_edges].mean(dim=1).mean(dim=0)
                # area of the triangle fan spanned by the loop centre and each loop edge
                _e1 = verts[loop_edges[:, 0]] - center
                _e2 = verts[loop_edges[:, 1]] - center
                cc_new_boundary_edges_cc_area.append(
                    torch.norm(torch.cross(_e1, _e2, dim=-1), dim=1).sum() * 0.5
                )
            if debug:
                cutting_edges.append(cc_new_boundary_edge_indices)
                tqdm.write(f'Area of the cutting loop: {cc_new_boundary_edges_cc_area}')
            if any(area > max_hole_size for area in cc_new_boundary_edges_cc_area):
                continue

        valid_remove_cc.append(cc)

    if debug:
        face_v = verts[faces].mean(dim=1).cpu().numpy()
        vis_dual_edges = dual_edges.cpu().numpy()
        vis_colors = np.zeros((faces.shape[0], 3), dtype=np.uint8)
        vis_colors[inner_face_indices.cpu().numpy()] = [0, 0, 255]
        vis_colors[outer_face_indices.cpu().numpy()] = [0, 255, 0]
        vis_colors[remove_face_indices.cpu().numpy()] = [255, 0, 255]
        if len(valid_remove_cc) > 0:
            vis_colors[remove_face_indices[torch.cat(valid_remove_cc)].cpu().numpy()] = [255, 0, 0]
        utils3d.io.write_ply('dbg_dual.ply', face_v, edges=vis_dual_edges, vertex_colors=vis_colors)

        # torch.cat raises on an empty list, so only dump the cut when one was recorded
        if len(cutting_edges) > 0:
            vis_verts = verts.cpu().numpy()
            vis_edges = edges[torch.cat(cutting_edges)].cpu().numpy()
            utils3d.io.write_ply('dbg_cut.ply', vis_verts, edges=vis_edges)

    if len(valid_remove_cc) > 0:
        remove_face_indices = remove_face_indices[torch.cat(valid_remove_cc)]
        mask = torch.ones(faces.shape[0], dtype=torch.bool, device=faces.device)
        mask[remove_face_indices] = 0
        faces = faces[mask]
        faces, verts = utils3d.torch.remove_unreferenced_vertices(faces, verts)
        if verbose:
            tqdm.write(f'Removed {(~mask).sum()} faces by mincut')
    else:
        if verbose:
            tqdm.write('Removed 0 faces by mincut')

    # Close the small boundaries left over with meshfix
    mesh = _meshfix.PyTMesh()
    mesh.load_array(verts.cpu().numpy(), faces.cpu().numpy())
    mesh.fill_small_boundaries(nbe=max_hole_nbe, refine=True)
    verts, faces = mesh.return_arrays()
    verts = torch.tensor(verts, device='cuda', dtype=torch.float32)
    faces = torch.tensor(faces, device='cuda', dtype=torch.int32)

    return verts, faces


def postprocess_mesh(
    vertices: np.ndarray,
    faces: np.ndarray,
    simplify: bool = True,
    simplify_ratio: float = 0.9,
    fill_holes: bool = True,
    fill_holes_max_hole_size: float = 0.04,
    fill_holes_max_hole_nbe: int = 32,
    fill_holes_resolution: int = 1024,
    fill_holes_num_views: int = 1000,
    debug: bool = False,
    verbose: bool = False,
):
    """
    Postprocess a mesh by simplifying, removing invisible faces, and removing isolated pieces.

    Args:
        vertices (np.ndarray): Vertices of the mesh. Shape (V, 3).
        faces (np.ndarray): Faces of the mesh. Shape (F, 3).
        simplify (bool): Whether to simplify the mesh (pyvista ``decimate``).
        simplify_ratio (float): Value passed to ``decimate`` as the target reduction,
            i.e. the fraction of faces to remove (``to_glb`` documents it the same way).
        fill_holes (bool): Whether to fill holes in the mesh.
        fill_holes_max_hole_size (float): Maximum area of a hole to fill.
        fill_holes_max_hole_nbe (int): Maximum number of boundary edges of a hole to fill.
        fill_holes_resolution (int): Resolution of the rasterization.
        fill_holes_num_views (int): Number of views to rasterize the mesh.
        debug (bool): Whether to emit debug output during hole filling.
        verbose (bool): Whether to print progress.

    Returns:
        (vertices, faces): The processed mesh as numpy arrays.
    """
    if verbose:
        tqdm.write(f'Before postprocess: {vertices.shape[0]} vertices, {faces.shape[0]} faces')

    # Simplify
    if simplify and simplify_ratio > 0:
        # pyvista expects each face prefixed by its vertex count
        padded_faces = np.concatenate([np.full((faces.shape[0], 1), 3), faces], axis=1)
        mesh = pv.PolyData(vertices, padded_faces)
        mesh = mesh.decimate(simplify_ratio, progress_bar=verbose)
        vertices, faces = mesh.points, mesh.faces.reshape(-1, 4)[:, 1:]
        if verbose:
            tqdm.write(f'After decimate: {vertices.shape[0]} vertices, {faces.shape[0]} faces')

    # Remove invisible faces
    if fill_holes:
        vertices = torch.tensor(vertices).cuda()
        faces = torch.tensor(faces.astype(np.int32)).cuda()
        vertices, faces = _fill_holes(
            vertices, faces,
            max_hole_size=fill_holes_max_hole_size,
            max_hole_nbe=fill_holes_max_hole_nbe,
            resolution=fill_holes_resolution,
            num_views=fill_holes_num_views,
            debug=debug,
            verbose=verbose,
        )
        vertices, faces = vertices.cpu().numpy(), faces.cpu().numpy()
        if verbose:
            tqdm.write(f'After remove invisible faces: {vertices.shape[0]} vertices, {faces.shape[0]} faces')

    return vertices, faces


def _clamped_barycentric_weights(
    v0: np.ndarray,
    v1: np.ndarray,
    v2: np.ndarray,
    points: np.ndarray,
) -> np.ndarray:
    """
    Barycentric weights of ``points`` w.r.t. triangles (v0, v1, v2), row by row.

    Weights are computed with the dot-product method, clipped to be non-negative
    and renormalised to sum to one. Near-degenerate triangles are guarded by
    replacing a tiny denominator with 1e-12.

    Args:
        v0, v1, v2 (np.ndarray): Triangle corners. Shape (N, 3) each.
        points (np.ndarray): Query points. Shape (N, 3).

    Returns:
        np.ndarray: Weights for (v0, v1, v2). Shape (N, 3).
    """
    e0 = v1 - v0
    e1 = v2 - v0
    w = points - v0
    d00 = np.sum(e0 * e0, axis=1)
    d01 = np.sum(e0 * e1, axis=1)
    d11 = np.sum(e1 * e1, axis=1)
    d20 = np.sum(w * e0, axis=1)
    d21 = np.sum(w * e1, axis=1)
    denom = d00 * d11 - d01 * d01
    denom = np.where(np.abs(denom) < 1e-12, 1e-12, denom)
    b1 = (d11 * d20 - d01 * d21) / denom
    b2 = (d00 * d21 - d01 * d20) / denom
    b0 = 1.0 - b1 - b2
    bary = np.stack([b0, b1, b2], axis=1)
    np.clip(bary, 0.0, None, out=bary)
    bary /= np.maximum(bary.sum(axis=1, keepdims=True), 1e-12)
    return bary


def barycentric_transfer_attributes(
    src_mesh: trimesh.Trimesh,
    src_attrs: np.ndarray,
    dst_vertices: np.ndarray,
) -> np.ndarray:
    """
    Transfer per-vertex attributes from a source mesh to new vertices via
    barycentric interpolation on the closest triangle.

    Args:
        src_mesh (trimesh.Trimesh): Source mesh (must have faces).
        src_attrs (np.ndarray): Per-vertex attributes on the source mesh. Shape (V_src, C).
        dst_vertices (np.ndarray): Destination vertex positions. Shape (V_dst, 3).

    Returns:
        np.ndarray: Interpolated attributes for each destination vertex. Shape (V_dst, C), float32.
    """
    src_attrs = np.asarray(src_attrs, dtype=np.float64)
    dst_vertices = np.asarray(dst_vertices, dtype=np.float64)

    # Nothing to query: return an empty result with the attribute shape preserved
    if dst_vertices.shape[0] == 0:
        return np.zeros((0,) + src_attrs.shape[1:], dtype=np.float32)

    closest_points, _, triangle_ids = trimesh.proximity.closest_point(src_mesh, dst_vertices)
    face_indices = src_mesh.faces[triangle_ids]  # (N, 3)

    bary = _clamped_barycentric_weights(
        src_mesh.vertices[face_indices[:, 0]].astype(np.float64),
        src_mesh.vertices[face_indices[:, 1]].astype(np.float64),
        src_mesh.vertices[face_indices[:, 2]].astype(np.float64),
        closest_points.astype(np.float64),
    )

    a0 = src_attrs[face_indices[:, 0]]
    a1 = src_attrs[face_indices[:, 1]]
    a2 = src_attrs[face_indices[:, 2]]
    result = bary[:, 0:1] * a0 + bary[:, 1:2] * a1 + bary[:, 2:3] * a2
    return result.astype(np.float32)


def parametrize_mesh(vertices: np.ndarray, faces: np.ndarray):
    """
    Parametrize a mesh to a texture space, using xatlas.

    Args:
        vertices (np.ndarray): Vertices of the mesh. Shape (V, 3).
        faces (np.ndarray): Faces of the mesh. Shape (F, 3).

    Returns:
        vertices, faces, uvs, vmapping
        vmapping maps new vertex indices back to original vertex indices
        (new vertices may be duplicated at UV seams).
    """
    vmapping, indices, uvs = xatlas.parametrize(vertices, faces)
    # xatlas re-indexes the mesh; gather the vertices in the new order
    return vertices[vmapping], indices, uvs, vmapping


@torch.no_grad()
def bake_vertex_colors_to_texture(
    dense_vertices: np.ndarray,
    dense_faces: np.ndarray,
    dense_vertex_colors: np.ndarray,
    simp_vertices: np.ndarray,
    simp_faces: np.ndarray,
    simp_uvs: np.ndarray,
    texture_size: int = 1024,
) -> np.ndarray:
    """
    Bake per-vertex colors from a dense mesh into a UV-mapped texture on a simplified mesh.

    For each texel covered by the simplified mesh in UV space, the 3D position is
    computed via nvdiffrast interpolation, then the closest point on the dense mesh
    is queried and its vertex color is barycentric-interpolated. Uncovered texels
    are filled by inpainting. If no texel is covered at all, an all-black texture
    is returned.

    Args:
        dense_vertices (np.ndarray): Dense mesh vertices. Shape (Vd, 3).
        dense_faces (np.ndarray): Dense mesh faces. Shape (Fd, 3).
        dense_vertex_colors (np.ndarray): Per-vertex RGB in [0,1]. Shape (Vd, 3).
        simp_vertices (np.ndarray): Simplified (UV-split) mesh vertices. Shape (Vs, 3).
        simp_faces (np.ndarray): Simplified mesh faces. Shape (Fs, 3).
        simp_uvs (np.ndarray): UV coordinates for simplified mesh. Shape (Vs, 2).
        texture_size (int): Output texture resolution (square).

    Returns:
        np.ndarray: Baked texture image, shape (texture_size, texture_size, 3), uint8.
    """
    device = 'cuda'
    dense_vertices = np.asarray(dense_vertices)
    dense_faces = np.asarray(dense_faces)
    dense_vertex_colors = np.asarray(dense_vertex_colors)

    verts_t = torch.tensor(simp_vertices, dtype=torch.float32, device=device)
    faces_t = torch.tensor(simp_faces.astype(np.int32), dtype=torch.int32, device=device)
    uvs_t = torch.tensor(simp_uvs, dtype=torch.float32, device=device)

    # Map UVs to clip space for nvdiffrast: [0,1] -> [-1,1], z=0, w=1
    uv_clip = torch.zeros(uvs_t.shape[0], 4, dtype=torch.float32, device=device)
    uv_clip[:, 0] = uvs_t[:, 0] * 2.0 - 1.0
    uv_clip[:, 1] = uvs_t[:, 1] * 2.0 - 1.0
    uv_clip[:, 3] = 1.0

    glctx = dr.RasterizeCudaContext()
    rast_out, _ = dr.rasterize(glctx, uv_clip[None], faces_t, resolution=[texture_size, texture_size])

    # Interpolate 3D positions at each texel
    pos_map, _ = dr.interpolate(verts_t[None].contiguous(), rast_out, faces_t)
    # pos_map: (1, H, W, 3); the last rasterizer channel is > 0 where a triangle was hit
    mask = (rast_out[0, :, :, 3] > 0)  # (H, W)
    positions = pos_map[0][mask].cpu().numpy()  # (N, 3)

    if positions.shape[0] == 0:
        # No texel covered by the UV layout: nothing to sample or inpaint from
        return np.zeros((texture_size, texture_size, 3), dtype=np.uint8)

    # Query dense mesh for closest-point colors
    dense_mesh = trimesh.Trimesh(vertices=dense_vertices, faces=dense_faces, process=False)
    closest_pts, _, tri_ids = trimesh.proximity.closest_point(dense_mesh, positions)

    face_verts = dense_faces[tri_ids]  # (N, 3)
    bary = _clamped_barycentric_weights(
        dense_vertices[face_verts[:, 0]],
        dense_vertices[face_verts[:, 1]],
        dense_vertices[face_verts[:, 2]],
        closest_pts,
    )

    c0 = dense_vertex_colors[face_verts[:, 0]]
    c1 = dense_vertex_colors[face_verts[:, 1]]
    c2 = dense_vertex_colors[face_verts[:, 2]]
    colors = bary[:, 0:1] * c0 + bary[:, 1:2] * c1 + bary[:, 2:3] * c2

    # Write to texture (flip vertically to match image convention)
    texture = np.zeros((texture_size, texture_size, 3), dtype=np.float32)
    mask_np = mask.cpu().numpy()
    texture[mask_np] = colors.astype(np.float32)
    texture = np.flipud(texture)
    mask_np = np.flipud(mask_np)

    texture = np.clip(texture * 255, 0, 255).astype(np.uint8)
    # flipud yields negative-stride views; hand OpenCV contiguous buffers
    texture = np.ascontiguousarray(texture)
    inpaint_mask = np.ascontiguousarray((~mask_np).astype(np.uint8))
    texture = cv2.inpaint(texture, inpaint_mask, 3, cv2.INPAINT_TELEA)

    return texture


@torch.no_grad()
def render_multiview_mesh_colors(
    vertices: np.ndarray,
    faces: np.ndarray,
    vertex_colors: np.ndarray,
    resolution: int = 1024,
    nviews: int = 100,
    near: float = 0.1,
    far: float = 10.0,
    verbose: bool = True,
):
    """
    Render multiview color images from a mesh with per-vertex colors.

    Uses ``utils3d.torch.rasterize_triangle_faces`` -- the same rasterisation
    call that :func:`bake_texture` uses -- with the vertex colors passed through
    the ``uv`` attribute slot.

    Args:
        vertices (np.ndarray): Mesh vertices. Shape (V, 3).
        faces (np.ndarray): Mesh faces. Shape (F, 3).
        vertex_colors (np.ndarray): Per-vertex colors; clamped to [0, 1]. Shape (V, 3).
        resolution (int): Square render resolution.
        nviews (int): Number of Hammersley-sampled views.
        near (float): Near plane of the camera.
        far (float): Far plane of the camera.
        verbose (bool): Whether to show a progress bar.

    Returns:
        observations: list of (H, W, 3) uint8 images in standard top-left origin
        extrinsics: list of numpy (4, 4)
        intrinsics: list of numpy (3, 3)
    """
    from .render_utils import yaw_pitch_r_fov_to_extrinsics_intrinsics

    r, fov = 2, 40
    cams = [sphere_hammersley_sequence(i, nviews) for i in range(nviews)]
    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(
        [c[0] for c in cams], [c[1] for c in cams], r, fov,
    )

    verts_t = torch.tensor(vertices, dtype=torch.float32, device='cuda')
    faces_t = torch.tensor(faces.astype(np.int32), dtype=torch.int32, device='cuda')
    colors_t = torch.tensor(vertex_colors, dtype=torch.float32, device='cuda').clamp(0, 1)

    rastctx = utils3d.torch.RastContext(backend='cuda')

    observations = []
    for extr, intr in tqdm(
        zip(extrinsics, intrinsics),
        total=nviews,
        disable=not verbose,
        desc='Rendering multiview',
    ):
        view = utils3d.torch.extrinsics_to_view(extr)
        proj = utils3d.torch.intrinsics_to_perspective(intr, near, far)
        rast = utils3d.torch.rasterize_triangle_faces(
            rastctx, verts_t[None], faces_t, resolution, resolution,
            uv=colors_t[None], view=view, projection=proj,
        )
        color_img = rast['uv'][0]  # (H, W, 3) interpolated vertex colors
        mask = rast['mask'][0] > 0.5
        # rasterisation is in OpenGL bottom-left origin; flip to top-left
        # NOTE(review): bake_texture's fast path flips rast['uv'] but NOT
        # rast['mask'], whereas both are flipped here -- confirm which
        # orientation utils3d returns for 'mask'.
        color_img = color_img.flip(0).clamp(0, 1)
        mask = mask.flip(0)
        # zero out background so mask-based workflows stay correct
        color_img[~mask] = 0
        observations.append(
            np.clip(color_img.cpu().numpy() * 255, 0, 255).astype(np.uint8)
        )

    extrinsics_np = [e.cpu().numpy() for e in extrinsics]
    intrinsics_np = [i.cpu().numpy() for i in intrinsics]
    return observations, extrinsics_np, intrinsics_np


def bake_texture(
    vertices: np.ndarray,
    faces: np.ndarray,
    uvs: np.ndarray,
    observations: List[np.ndarray],
    masks: List[np.ndarray],
    extrinsics: List[np.ndarray],
    intrinsics: List[np.ndarray],
    texture_size: int = 2048,
    near: float = 0.1,
    far: float = 10.0,
    mode: Literal['fast', 'opt'] = 'opt',
    lambda_tv: float = 1e-2,
    verbose: bool = False,
):
    """
    Bake texture to a mesh from multiple observations.

    Args:
        vertices (np.ndarray): Vertices of the mesh. Shape (V, 3).
        faces (np.ndarray): Faces of the mesh. Shape (F, 3).
        uvs (np.ndarray): UV coordinates of the mesh. Shape (V, 2).
        observations (List[np.ndarray]): List of observations. Each observation is a 2D image. Shape (H, W, 3).
        masks (List[np.ndarray]): List of masks, one per observation. Each mask is a 2D image. Shape (H, W).
        extrinsics (List[np.ndarray]): List of extrinsics. Shape (4, 4).
        intrinsics (List[np.ndarray]): List of intrinsics. Shape (3, 3).
        texture_size (int): Size of the texture.
        near (float): Near plane of the camera.
        far (float): Far plane of the camera.
        mode (Literal['fast', 'opt']): Mode of texture baking. 'fast' averages
            nearest-texel projections; 'opt' optimises the texture with Adam.
        lambda_tv (float): Weight of total variation loss in optimization.
        verbose (bool): Whether to print progress.

    Returns:
        np.ndarray: Texture image, shape (texture_size, texture_size, 3), uint8.

    Raises:
        ValueError: If ``mode`` is neither 'fast' nor 'opt'.
    """
    vertices = torch.tensor(vertices).cuda()
    faces = torch.tensor(faces.astype(np.int32)).cuda()
    uvs = torch.tensor(uvs).cuda()
    observations = [torch.tensor(obs / 255.0).float().cuda() for obs in observations]
    masks = [torch.tensor(m > 0).bool().cuda() for m in masks]
    views = [utils3d.torch.extrinsics_to_view(torch.tensor(extr).cuda()) for extr in extrinsics]
    projections = [
        utils3d.torch.intrinsics_to_perspective(torch.tensor(intr).cuda(), near, far)
        for intr in intrinsics
    ]

    if mode == 'fast':
        texture = torch.zeros((texture_size * texture_size, 3), dtype=torch.float32).cuda()
        texture_weights = torch.zeros((texture_size * texture_size), dtype=torch.float32).cuda()
        rastctx = utils3d.torch.RastContext(backend='cuda')
        for observation, obs_mask, view, projection in tqdm(
            zip(observations, masks, views, projections),
            total=len(observations), disable=not verbose, desc='Texture baking (fast)',
        ):
            with torch.no_grad():
                rast = utils3d.torch.rasterize_triangle_faces(
                    rastctx, vertices[None], faces, observation.shape[1], observation.shape[0],
                    uv=uvs[None], view=view, projection=projection
                )
                uv_map = rast['uv'][0].detach().flip(0)
                # use the mask of the view being processed, not always the first one
                mask = rast['mask'][0].detach().bool() & obs_mask

            # nearest neighbor interpolation; clamp so uv == 1.0 stays inside the texture
            uv_map = (uv_map * texture_size).floor().long().clamp(0, texture_size - 1)
            obs = observation[mask]
            uv_map = uv_map[mask]
            idx = uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size
            texture = texture.scatter_add(0, idx.view(-1, 1).expand(-1, 3), obs)
            texture_weights = texture_weights.scatter_add(
                0, idx, torch.ones((obs.shape[0]), dtype=torch.float32, device=texture.device)
            )

        # average the accumulated colors of every texel that was hit
        mask = texture_weights > 0
        texture[mask] /= texture_weights[mask][:, None]
        texture = np.clip(
            texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255, 0, 255
        ).astype(np.uint8)

        # inpaint
        mask = (texture_weights == 0).cpu().numpy().astype(np.uint8).reshape(texture_size, texture_size)
        texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)

    elif mode == 'opt':
        rastctx = utils3d.torch.RastContext(backend='cuda')
        # flip images and masks vertically so they line up with the unflipped UV buffers
        observations = [obs.flip(0) for obs in observations]
        masks = [m.flip(0) for m in masks]
        _uv = []
        _uv_dr = []
        for observation, view, projection in tqdm(
            zip(observations, views, projections),
            total=len(views),
            disable=not verbose,
            desc='Texture baking (opt): UV',
        ):
            with torch.no_grad():
                rast = utils3d.torch.rasterize_triangle_faces(
                    rastctx, vertices[None], faces, observation.shape[1], observation.shape[0],
                    uv=uvs[None], view=view, projection=projection,
                )
                _uv.append(rast['uv'].detach())
                _uv_dr.append(rast['uv_dr'].detach())

        texture = torch.nn.Parameter(
            torch.zeros((1, texture_size, texture_size, 3), dtype=torch.float32).cuda()
        )
        optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2)

        def cosine_anealing(optimizer, step, total_steps, start_lr, end_lr):
            return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps))

        def tv_loss(texture):
            return torch.nn.functional.l1_loss(
                texture[:, :-1, :, :], texture[:, 1:, :, :]
            ) + torch.nn.functional.l1_loss(
                texture[:, :, :-1, :], texture[:, :, 1:, :]
            )

        total_steps = 500
        with tqdm(
            total=total_steps,
            disable=not verbose,
            desc='Texture baking (opt): optimizing',
        ) as pbar:
            for step in range(total_steps):
                optimizer.zero_grad()
                # one random view per step
                selected = np.random.randint(0, len(views))
                uv, uv_dr, observation, mask = (
                    _uv[selected], _uv_dr[selected], observations[selected], masks[selected],
                )
                render = dr.texture(texture, uv, uv_dr)[0]
                loss = torch.nn.functional.l1_loss(render[mask], observation[mask])
                if lambda_tv > 0:
                    loss += lambda_tv * tv_loss(texture)
                loss.backward()
                optimizer.step()
                # annealing
                optimizer.param_groups[0]['lr'] = cosine_anealing(
                    optimizer, step, total_steps, 1e-2, 1e-5
                )
                pbar.set_postfix({'loss': loss.item()})
                pbar.update()

        texture = np.clip(
            texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255
        ).astype(np.uint8)
        # texels not covered by any UV triangle are filled by inpainting
        mask = 1 - utils3d.torch.rasterize_triangle_faces(
            rastctx, (uvs * 2 - 1)[None], faces, texture_size, texture_size
        )['mask'][0].detach().cpu().numpy().astype(np.uint8)
        texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)

    else:
        raise ValueError(f'Unknown mode: {mode}')

    return texture


def to_glb(
    app_rep: Union[Strivec, Gaussian],
    mesh: MeshExtractResult,
    simplify: float = 0.95,
    fill_holes: bool = True,
    fill_holes_max_size: float = 0.04,
    texture_size: int = 1024,
    debug: bool = False,
    verbose: bool = True,
) -> trimesh.Trimesh:
    """
    Convert a generated asset to a glb file.

    Args:
        app_rep (Union[Strivec, Gaussian]): Appearance representation.
        mesh (MeshExtractResult): Extracted mesh.
        simplify (float): Ratio of faces to remove in simplification.
        fill_holes (bool): Whether to fill holes in the mesh.
        fill_holes_max_size (float): Maximum area of a hole to fill.
        texture_size (int): Size of the texture.
        debug (bool): Whether to print debug information.
        verbose (bool): Whether to print progress.

    Returns:
        trimesh.Trimesh: Textured mesh (y-up) with a PBR material.
    """
    vertices = mesh.vertices.cpu().numpy()
    faces = mesh.faces.cpu().numpy()

    # mesh postprocess
    vertices, faces = postprocess_mesh(
        vertices, faces,
        simplify=simplify > 0,
        simplify_ratio=simplify,
        fill_holes=fill_holes,
        fill_holes_max_hole_size=fill_holes_max_size,
        fill_holes_max_hole_nbe=int(250 * np.sqrt(1 - simplify)),
        fill_holes_resolution=1024,
        fill_holes_num_views=1000,
        debug=debug,
        verbose=verbose,
    )

    # parametrize mesh
    vertices, faces, uvs, _vmapping = parametrize_mesh(vertices, faces)

    # bake texture
    observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=100)
    # any non-black pixel counts as foreground
    masks = [np.any(observation > 0, axis=-1) for observation in observations]
    extrinsics = [extr.cpu().numpy() for extr in extrinsics]
    intrinsics = [intr.cpu().numpy() for intr in intrinsics]
    texture = bake_texture(
        vertices, faces, uvs,
        observations, masks, extrinsics, intrinsics,
        texture_size=texture_size, mode='opt',
        lambda_tv=0.01,
        verbose=verbose
    )
    texture = Image.fromarray(texture)

    # rotate mesh (from z-up to y-up)
    vertices = vertices @ np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
    material = trimesh.visual.material.PBRMaterial(
        roughnessFactor=1.0,
        baseColorTexture=texture,
        baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8)
    )
    glb_mesh = trimesh.Trimesh(
        vertices, faces,
        visual=trimesh.visual.TextureVisuals(uv=uvs, material=material),
    )
    return glb_mesh


def simplify_gs(
    gs: Gaussian,
    simplify: float = 0.95,
    verbose: bool = True,
):
    """
    Simplify 3D Gaussians
    NOTE: this function is not used in the current implementation for the unsatisfactory performance.

    Args:
        gs (Gaussian): 3D Gaussian.
        simplify (float): Ratio of Gaussians to remove in simplification.
        verbose (bool): Whether to show a progress bar.

    Returns:
        Gaussian: The simplified Gaussians (``gs`` itself when ``simplify <= 0``).
    """
    if simplify <= 0:
        return gs

    # simplify
    observations, extrinsics, intrinsics = render_multiview(gs, resolution=1024, nviews=100)
    observations = [torch.tensor(obs / 255.0).float().cuda().permute(2, 0, 1) for obs in observations]

    # Following https://arxiv.org/pdf/2411.06019
    renderer = GaussianRenderer({
        "resolution": 1024,
        "near": 0.8,
        "far": 1.6,
        "ssaa": 1,
        "bg_color": (0, 0, 0),
    })
    new_gs = Gaussian(**gs.init_params)
    new_gs._features_dc = gs._features_dc.clone()
    new_gs._features_rest = gs._features_rest.clone() if gs._features_rest is not None else None
    new_gs._opacity = torch.nn.Parameter(gs._opacity.clone())
    new_gs._rotation = torch.nn.Parameter(gs._rotation.clone())
    new_gs._scaling = torch.nn.Parameter(gs._scaling.clone())
    new_gs._xyz = torch.nn.Parameter(gs._xyz.clone())

    # per-group learning rates, in the order: xyz, rotation, scaling, opacity
    start_lr = [1e-4, 1e-3, 5e-3, 0.025]
    end_lr = [1e-6, 1e-5, 5e-5, 0.00025]
    optimizer = torch.optim.Adam([
        {"params": new_gs._xyz, "lr": start_lr[0]},
        {"params": new_gs._rotation, "lr": start_lr[1]},
        {"params": new_gs._scaling, "lr": start_lr[2]},
        {"params": new_gs._opacity, "lr": start_lr[3]},
    ], lr=start_lr[0])

    def cosine_anealing(optimizer, step, total_steps, start_lr, end_lr):
        return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps))

    # reshape(-1) instead of squeeze(): squeeze() would collapse a single
    # remaining Gaussian to a 0-d tensor and break the indexing below
    _zeta = new_gs.get_opacity.clone().detach().reshape(-1)
    _lambda = torch.zeros_like(_zeta)
    _delta = 1e-7
    _interval = 10
    num_target = int((1 - simplify) * _zeta.shape[0])
    total_steps = 2500

    with tqdm(total=total_steps, disable=not verbose, desc='Simplifying Gaussian') as pbar:
        for i in range(total_steps):
            # prune
            if i % 100 == 0:
                mask = new_gs.get_opacity.reshape(-1) > 0.05
                mask = torch.nonzero(mask).reshape(-1)
                new_gs._xyz = torch.nn.Parameter(new_gs._xyz[mask])
                new_gs._rotation = torch.nn.Parameter(new_gs._rotation[mask])
                new_gs._scaling = torch.nn.Parameter(new_gs._scaling[mask])
                new_gs._opacity = torch.nn.Parameter(new_gs._opacity[mask])
                new_gs._features_dc = new_gs._features_dc[mask]
                new_gs._features_rest = new_gs._features_rest[mask] if new_gs._features_rest is not None else None
                _zeta = _zeta[mask]
                _lambda = _lambda[mask]
                # update optimizer state: re-key each group to its new Parameter
                # and carry over the (pruned) Adam moments
                new_params = [new_gs._xyz, new_gs._rotation, new_gs._scaling, new_gs._opacity]
                for param_group, new_param in zip(optimizer.param_groups, new_params):
                    old_param = param_group['params'][0]
                    stored_state = optimizer.state[old_param]
                    if 'exp_avg' in stored_state:
                        stored_state['exp_avg'] = stored_state['exp_avg'][mask]
                        stored_state['exp_avg_sq'] = stored_state['exp_avg_sq'][mask]
                    del optimizer.state[old_param]
                    param_group['params'][0] = new_param
                    optimizer.state[new_param] = stored_state

            opacity = new_gs.get_opacity.reshape(-1)

            # sparisfy
            if i % _interval == 0:
                _zeta = _lambda + opacity.detach()
                if opacity.shape[0] > num_target:
                    # keep only the num_target largest entries of _zeta
                    index = _zeta.topk(num_target)[1]
                    _m = torch.ones_like(_zeta, dtype=torch.bool)
                    _m[index] = 0
                    _zeta[_m] = 0
                _lambda = _lambda + opacity.detach() - _zeta

            # sample a random view
            view_idx = np.random.randint(len(observations))
            observation = observations[view_idx]
            extrinsic = extrinsics[view_idx]
            intrinsic = intrinsics[view_idx]

            color = renderer.render(new_gs, extrinsic, intrinsic)['color']
            rgb_loss = torch.nn.functional.l1_loss(color, observation)
            loss = rgb_loss + \
                _delta * torch.sum(torch.pow(_lambda + opacity - _zeta, 2))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # update lr
            for j, param_group in enumerate(optimizer.param_groups):
                param_group['lr'] = cosine_anealing(optimizer, i, total_steps, start_lr[j], end_lr[j])

            pbar.set_postfix({'loss': rgb_loss.item(), 'num': opacity.shape[0], 'lambda': _lambda.mean().item()})
            pbar.update()

    # strip the Parameter wrappers before returning
    new_gs._xyz = new_gs._xyz.data
    new_gs._rotation = new_gs._rotation.data
    new_gs._scaling = new_gs._scaling.data
    new_gs._opacity = new_gs._opacity.data

    return new_gs