"""Shared GLB export logic used by both the Gradio app and the FastAPI export worker.

This module owns the ``remesh=True`` / ``remesh=False`` branching and the
``SAFE_NONREMESH_GLB_EXPORT`` env-flag behaviour so that the two entry-points
stay in lock-step.
"""

from __future__ import annotations

import numbers
import os
from typing import Any, Dict, Tuple

import cv2
import numpy as np
import torch
from PIL import Image

import o_voxel

# ---------------------------------------------------------------------------
# Env helpers
# ---------------------------------------------------------------------------


def _env_flag(name: str, default: bool) -> bool:
    """Read a boolean flag from the environment variable ``name``.

    Returns ``default`` when the variable is unset. Otherwise the value is
    stripped and lower-cased: ``"1"``, ``"true"``, ``"yes"`` and ``"on"`` mean
    ``True``; any other value (including an empty string) means ``False``.
    """
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() in {"1", "true", "yes", "on"}


# Evaluated once at import time; later changes to the environment do not
# affect it. Callers can still override it per call through
# ``export_glb(safe_nonremesh_fallback=...)``.
SAFE_NONREMESH_GLB_EXPORT: bool = _env_flag("SAFE_NONREMESH_GLB_EXPORT", True)

# ---------------------------------------------------------------------------
# Logging helpers
# ---------------------------------------------------------------------------


def _cumesh_counts(mesh: Any) -> str:
    """Format a mesh's vertex/face counts for logging ("?" when unavailable)."""
    num_vertices = getattr(mesh, "num_vertices", "?")
    num_faces = getattr(mesh, "num_faces", "?")
    return f"vertices={num_vertices}, faces={num_faces}"


def _log_cumesh_counts(label: str, mesh: Any) -> None:
    """Print ``label`` followed by the mesh's vertex/face counts (flushed)."""
    print(f"{label}: {_cumesh_counts(mesh)}", flush=True)


# ---------------------------------------------------------------------------
# Safe non-remesh fallback (extracted from app.py)
# ---------------------------------------------------------------------------

# Substring identifying CuMesh CUDA failures; only these are treated as
# recoverable during face-orientation unification.
_CUMESH_CUDA_ERROR_MARKER = "[CuMesh] CUDA error"

# Number of faces handed to the rasterizer per call when baking textures.
_RASTER_FACE_CHUNK = 100000


def _is_cumesh_cuda_error(error: BaseException) -> bool:
    """Return True if ``error`` looks like a CuMesh CUDA failure."""
    return _CUMESH_CUDA_ERROR_MARKER in str(error)


def _try_unify_face_orientations(current_mesh: Any) -> Any:
    """Unify face orientations, tolerating CuMesh CUDA errors.

    Attempts ``current_mesh.unify_face_orientations()`` in place. On a CuMesh
    CUDA error it retries once on a fresh ``CuMesh`` rebuilt from a readback
    of ``current_mesh`` (after duplicate-face / small-component cleanup). If
    the retry also hits a CuMesh CUDA error, unification is skipped.

    Returns:
        The mesh to continue with: ``current_mesh`` (unified or, after a
        failed retry, not unified) or the successfully unified retry mesh.

    Raises:
        RuntimeError: any RuntimeError that is not a CuMesh CUDA error.
    """
    _log_cumesh_counts("Before face-orientation unification", current_mesh)
    try:
        current_mesh.unify_face_orientations()
        _log_cumesh_counts("After face-orientation unification", current_mesh)
        return current_mesh
    except RuntimeError as error:
        if not _is_cumesh_cuda_error(error):
            raise
        print(
            "Face-orientation unification failed in remesh=False fallback; "
            f"retrying once from readback. error={error}",
            flush=True,
        )
        # The retry stays nested in this handler so any unexpected error it
        # raises keeps the original failure attached as exception context.
        try:
            retry_vertices, retry_faces = current_mesh.read()
            retry_mesh = o_voxel.postprocess.cumesh.CuMesh()
            retry_mesh.init(retry_vertices, retry_faces)
            retry_mesh.remove_duplicate_faces()
            retry_mesh.remove_small_connected_components(1e-5)
            _log_cumesh_counts("Before face-orientation retry", retry_mesh)
            retry_mesh.unify_face_orientations()
            _log_cumesh_counts("After face-orientation retry", retry_mesh)
            return retry_mesh
        except RuntimeError as retry_error:
            if not _is_cumesh_cuda_error(retry_error):
                raise
            print(
                "Skipping face-orientation unification in remesh=False fallback after "
                f"retry failure: {retry_error}",
                flush=True,
            )
            return current_mesh


def _normalize_grid_spec(
    aabb: Any,
    voxel_size: Any,
    grid_size: Any,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Coerce ``aabb`` / ``voxel_size`` / ``grid_size`` to tensors on ``device``.

    ``aabb`` may be a list/tuple/ndarray/tensor of shape ``(2, 3)`` (min row,
    max row). When ``voxel_size`` is given, ``grid_size`` is derived from it
    (a passed ``grid_size`` is ignored); otherwise ``voxel_size`` is derived
    from ``grid_size``. A scalar is broadcast to all three axes. Any real
    number (including numpy scalars) is accepted for ``voxel_size`` and any
    integer (including numpy integers) for ``grid_size``.

    Returns:
        ``(aabb, voxel_size, grid_size)``: ``aabb`` is a ``(2, 3)`` tensor,
        ``voxel_size`` a ``(3,)`` tensor and ``grid_size`` a ``(3,)`` int tensor.

    Raises:
        AssertionError: if the inputs cannot be normalized to those shapes,
            or neither ``voxel_size`` nor ``grid_size`` is provided.
    """
    if isinstance(aabb, (list, tuple)):
        aabb = np.array(aabb)
    if isinstance(aabb, np.ndarray):
        aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
    assert isinstance(aabb, torch.Tensor)
    assert aabb.dim() == 2 and aabb.size(0) == 2 and aabb.size(1) == 3

    extent = aabb[1] - aabb[0]
    if voxel_size is not None:
        if isinstance(voxel_size, numbers.Real):
            voxel_size = [voxel_size, voxel_size, voxel_size]
        if isinstance(voxel_size, (list, tuple)):
            voxel_size = np.array(voxel_size)
        if isinstance(voxel_size, np.ndarray):
            voxel_size = torch.tensor(voxel_size, dtype=torch.float32, device=device)
        grid_size = (extent / voxel_size).round().int()
    else:
        assert grid_size is not None, "Either voxel_size or grid_size must be provided"
        if isinstance(grid_size, numbers.Integral):
            grid_size = [grid_size, grid_size, grid_size]
        if isinstance(grid_size, (list, tuple)):
            grid_size = np.array(grid_size)
        if isinstance(grid_size, np.ndarray):
            grid_size = torch.tensor(grid_size, dtype=torch.int32, device=device)
        voxel_size = extent / grid_size

    assert isinstance(voxel_size, torch.Tensor)
    assert voxel_size.dim() == 1 and voxel_size.size(0) == 3
    assert isinstance(grid_size, torch.Tensor)
    assert grid_size.dim() == 1 and grid_size.size(0) == 3
    return aabb, voxel_size, grid_size


def _channel_to_uint8(attrs: torch.Tensor, channel: slice) -> np.ndarray:
    """Copy ``attrs[..., channel]`` to host and map ``[0, 1]`` to clipped uint8."""
    return np.clip(attrs[..., channel].cpu().numpy() * 255, 0, 255).astype(np.uint8)


def _build_pbr_material(
    attrs: torch.Tensor,
    mask: torch.Tensor,
    attr_layout: Dict[str, slice],
) -> Any:
    """Turn baked attribute maps into a trimesh ``PBRMaterial``.

    Texels outside ``mask`` (not covered by any UV triangle) are filled with
    ``cv2.inpaint`` (Telea) so texture filtering near chart seams does not
    bleed in black. ``attr_layout`` must provide ``"base_color"``,
    ``"metallic"``, ``"roughness"`` and ``"alpha"`` channel slices.
    """
    mask_np = mask.cpu().numpy()
    base_color = _channel_to_uint8(attrs, attr_layout["base_color"])
    metallic = _channel_to_uint8(attrs, attr_layout["metallic"])
    roughness = _channel_to_uint8(attrs, attr_layout["roughness"])
    alpha = _channel_to_uint8(attrs, attr_layout["alpha"])

    # cv2.inpaint fills the pixels where the mask is non-zero, i.e. the
    # uncovered texels.
    mask_inv = (~mask_np).astype(np.uint8)
    base_color = cv2.inpaint(base_color, mask_inv, 3, cv2.INPAINT_TELEA)
    metallic = cv2.inpaint(metallic, mask_inv, 1, cv2.INPAINT_TELEA)[..., None]
    roughness = cv2.inpaint(roughness, mask_inv, 1, cv2.INPAINT_TELEA)[..., None]
    alpha = cv2.inpaint(alpha, mask_inv, 1, cv2.INPAINT_TELEA)[..., None]

    return o_voxel.postprocess.trimesh.visual.material.PBRMaterial(
        baseColorTexture=Image.fromarray(np.concatenate([base_color, alpha], axis=-1)),
        baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8),
        # glTF packs roughness in G and metallic in B of this texture.
        metallicRoughnessTexture=Image.fromarray(
            np.concatenate([np.zeros_like(metallic), roughness, metallic], axis=-1)
        ),
        metallicFactor=1.0,
        roughnessFactor=1.0,
        alphaMode="OPAQUE",
        doubleSided=True,
    )


def _to_glb_without_risky_nonremesh_cleanup(
    *,
    vertices: torch.Tensor,
    faces: torch.Tensor,
    attr_volume: torch.Tensor,
    coords: torch.Tensor,
    attr_layout: Dict[str, slice],
    aabb: Any,
    voxel_size: Any = None,
    grid_size: Any = None,
    decimation_target: int = 1000000,
    texture_size: int = 2048,
    mesh_cluster_threshold_cone_half_angle_rad=np.radians(90.0),
    mesh_cluster_refine_iterations=0,
    mesh_cluster_global_iterations=1,
    mesh_cluster_smooth_strength=1,
    verbose: bool = False,
    use_tqdm: bool = False,
):
    """Build a textured mesh for the ``remesh=False`` path using guarded cleanup.

    Steps:

    1. Load the decoded mesh into CuMesh and build a BVH over the original
       (pre-simplification) geometry.
    2. Simplify in two stages (``3 * decimation_target``, then
       ``decimation_target``), removing duplicate faces and small connected
       components after each.
    3. Unify face orientations via ``_try_unify_face_orientations``, which
       retries once and then skips on CuMesh CUDA errors instead of crashing.
    4. UV-unwrap and bake ``attr_volume`` into ``texture_size`` x
       ``texture_size`` PBR textures. Each covered texel is projected back
       onto the original mesh before the volume is trilinearly sampled.

    Args:
        vertices, faces: decoded mesh; moved to CUDA here.
        attr_volume, coords: attribute rows and their voxel coordinates,
            presumably one row per entry of ``coords``. TODO confirm.
        attr_layout: channel slices for ``"base_color"``, ``"metallic"``,
            ``"roughness"`` and ``"alpha"`` within ``attr_volume``.
        aabb, voxel_size, grid_size: volume placement. See
            ``_normalize_grid_spec`` for the accepted forms.
        decimation_target: target face count passed to ``simplify``.
        texture_size: side length in pixels of the baked textures.
        mesh_cluster_*: forwarded to ``uv_unwrap`` ``compute_charts_kwargs``.
        verbose: forwarded to CuMesh ``simplify`` / ``uv_unwrap``.
        use_tqdm: show a 6-step progress bar. It is always closed, even on
            error.

    Returns:
        A ``trimesh.Trimesh`` (via ``o_voxel.postprocess.trimesh``) carrying
        UVs, vertex normals and a PBR material.

    Raises:
        AssertionError: malformed ``aabb`` / ``voxel_size`` / ``grid_size``.
        RuntimeError: CuMesh errors other than the CUDA errors tolerated
            during face-orientation unification.
    """
    postprocess = o_voxel.postprocess
    aabb, voxel_size, grid_size = _normalize_grid_spec(
        aabb, voxel_size, grid_size, coords.device
    )

    pbar = postprocess.tqdm(total=6, desc="Extracting GLB") if use_tqdm else None

    def _advance(next_description: str | None = None) -> None:
        """Finish one progress step and optionally label the next one."""
        if pbar is None:
            return
        pbar.update(1)
        if next_description is not None:
            pbar.set_description(next_description)

    try:
        vertices = vertices.cuda()
        faces = faces.cuda()

        mesh = postprocess.cumesh.CuMesh()
        mesh.init(vertices, faces)
        _log_cumesh_counts("Fallback mesh init", mesh)
        _advance("Building BVH")

        # BVH over the ORIGINAL geometry. Texels of the simplified mesh are
        # snapped back onto it below, so attributes are sampled at the
        # original surface.
        bvh = postprocess.cumesh.cuBVH(vertices, faces)
        _advance("Cleaning mesh")

        mesh.simplify(decimation_target * 3, verbose=verbose)
        _log_cumesh_counts("After fallback coarse simplification", mesh)
        mesh.remove_duplicate_faces()
        mesh.remove_small_connected_components(1e-5)
        _log_cumesh_counts("After fallback initial cleanup", mesh)
        mesh.simplify(decimation_target, verbose=verbose)
        _log_cumesh_counts("After fallback target simplification", mesh)
        mesh.remove_duplicate_faces()
        mesh.remove_small_connected_components(1e-5)
        _log_cumesh_counts("After fallback final cleanup", mesh)
        mesh = _try_unify_face_orientations(mesh)
        _advance("Parameterizing new mesh")

        out_vertices, out_faces, out_uvs, out_vmaps = mesh.uv_unwrap(
            compute_charts_kwargs={
                "threshold_cone_half_angle_rad": mesh_cluster_threshold_cone_half_angle_rad,
                "refine_iterations": mesh_cluster_refine_iterations,
                "global_iterations": mesh_cluster_global_iterations,
                "smooth_strength": mesh_cluster_smooth_strength,
            },
            return_vmaps=True,
            verbose=verbose,
        )
        out_vertices = out_vertices.cuda()
        out_faces = out_faces.cuda()
        out_uvs = out_uvs.cuda()
        out_vmaps = out_vmaps.cuda()
        mesh.compute_vertex_normals()
        # out_vmaps maps each unwrapped (seam-split) vertex back to its
        # source vertex.
        out_normals = mesh.read_vertex_normals()[out_vmaps]
        _advance("Sampling attributes")

        ctx = postprocess.dr.RasterizeCudaContext()
        # Rasterize in UV space: UV [0, 1] -> clip [-1, 1], z = 0, w = 1.
        uvs_rast = torch.cat(
            [
                out_uvs * 2 - 1,
                torch.zeros_like(out_uvs[:, :1]),
                torch.ones_like(out_uvs[:, :1]),
            ],
            dim=-1,
        ).unsqueeze(0)
        rast = torch.zeros(
            (1, texture_size, texture_size, 4), device="cuda", dtype=torch.float32
        )
        for start in range(0, out_faces.shape[0], _RASTER_FACE_CHUNK):
            rast_chunk, _ = postprocess.dr.rasterize(
                ctx,
                uvs_rast,
                out_faces[start : start + _RASTER_FACE_CHUNK],
                resolution=[texture_size, texture_size],
            )
            # Channel 3 holds (chunk-local triangle id + 1), with 0 meaning
            # empty. Offset covered texels so ids index the full out_faces.
            mask_chunk = rast_chunk[..., 3:4] > 0
            rast_chunk[..., 3:4] += start
            rast = torch.where(mask_chunk, rast_chunk, rast)

        mask = rast[0, ..., 3] > 0
        pos = postprocess.dr.interpolate(out_vertices.unsqueeze(0), rast, out_faces)[0][0]
        valid_pos = pos[mask]
        # Snap each texel position to the closest point on the original mesh,
        # rebuilt from barycentrics of the nearest original triangle.
        _, face_id, uvw = bvh.unsigned_distance(valid_pos, return_uvw=True)
        orig_tri_verts = vertices[faces[face_id.long()]]
        valid_pos = (orig_tri_verts * uvw.unsqueeze(-1)).sum(dim=1)

        attrs = torch.zeros(
            texture_size, texture_size, attr_volume.shape[1], device="cuda"
        )
        attrs[mask] = postprocess.grid_sample_3d(
            attr_volume,
            torch.cat([torch.zeros_like(coords[:, :1]), coords], dim=-1),
            shape=torch.Size([1, attr_volume.shape[1], *grid_size.tolist()]),
            grid=((valid_pos - aabb[0]) / voxel_size).reshape(1, -1, 3),
            mode="trilinear",
        )
        _advance("Finalizing mesh")

        material = _build_pbr_material(attrs, mask, attr_layout)

        vertices_np = out_vertices.cpu().numpy()
        faces_np = out_faces.cpu().numpy()
        uvs_np = out_uvs.cpu().numpy()
        normals_np = out_normals.cpu().numpy()
        # Rotate (x, y, z) -> (x, z, -y), presumably a Z-up to glTF Y-up
        # conversion. The negated column is a fresh array built before either
        # assignment, so the in-place swap is safe.
        vertices_np[:, 1], vertices_np[:, 2] = vertices_np[:, 2], -vertices_np[:, 1]
        normals_np[:, 1], normals_np[:, 2] = normals_np[:, 2], -normals_np[:, 1]
        # Flip V for glTF's top-left texture origin.
        uvs_np[:, 1] = 1 - uvs_np[:, 1]

        textured_mesh = postprocess.trimesh.Trimesh(
            vertices=vertices_np,
            faces=faces_np,
            vertex_normals=normals_np,
            process=False,
            visual=postprocess.trimesh.visual.TextureVisuals(uv=uvs_np, material=material),
        )
        _advance()
        return textured_mesh
    finally:
        # Close the bar on success AND on failure so it does not linger.
        if pbar is not None:
            pbar.close()


# ---------------------------------------------------------------------------
# Public entry-point -- mirrors the branching in app.py extract_glb()
# ---------------------------------------------------------------------------


def export_glb(
    *,
    vertices: torch.Tensor,
    faces: torch.Tensor,
    attr_volume: torch.Tensor,
    coords: torch.Tensor,
    attr_layout: Dict[str, slice],
    grid_size: Any,
    aabb: Any,
    decimation_target: int,
    texture_size: int,
    remesh: bool,
    safe_nonremesh_fallback: bool | None = None,
    use_tqdm: bool = False,
):
    """Build the textured mesh for GLB export from decoded mesh data.

    Args:
        remesh: Whether to rebuild mesh topology during export. ``True``
            always uses upstream ``o_voxel.postprocess.to_glb(remesh=True)``.
        safe_nonremesh_fallback: When ``remesh=False``, selects which
            non-remesh path to use. ``True`` = safe fallback (guarded
            face-orientation, retry logic). ``False`` = upstream raw
            ``to_glb(remesh=False)``. ``None`` (default) = fall back to the
            ``SAFE_NONREMESH_GLB_EXPORT`` env var (which itself defaults to
            ``True``). Ignored when ``remesh=True``.
        Remaining arguments are forwarded unchanged to whichever export path
        is selected.

    Returns:
        The mesh object produced by the selected path. The safe fallback
        returns a ``trimesh.Trimesh``. The upstream ``to_glb`` return type is
        not visible here; presumably also trimesh.
    """
    # Single source of truth for the arguments shared by all three paths, so
    # they cannot drift apart.
    glb_kwargs: Dict[str, Any] = dict(
        vertices=vertices,
        faces=faces,
        attr_volume=attr_volume,
        coords=coords,
        attr_layout=attr_layout,
        grid_size=grid_size,
        aabb=aabb,
        decimation_target=decimation_target,
        texture_size=texture_size,
        use_tqdm=use_tqdm,
    )

    if remesh:
        return o_voxel.postprocess.to_glb(
            **glb_kwargs,
            remesh=True,
            remesh_band=1,
            remesh_project=0,
        )

    use_safe = (
        safe_nonremesh_fallback
        if safe_nonremesh_fallback is not None
        else SAFE_NONREMESH_GLB_EXPORT
    )

    if use_safe:
        print(
            "Using remesh=False safe GLB export fallback "
            f"(safe_nonremesh_fallback={safe_nonremesh_fallback}, "
            f"SAFE_NONREMESH_GLB_EXPORT={SAFE_NONREMESH_GLB_EXPORT})",
            flush=True,
        )
        return _to_glb_without_risky_nonremesh_cleanup(**glb_kwargs)

    print(
        "Using upstream remesh=False GLB export path "
        f"(safe_nonremesh_fallback={safe_nonremesh_fallback}, "
        f"SAFE_NONREMESH_GLB_EXPORT={SAFE_NONREMESH_GLB_EXPORT})",
        flush=True,
    )
    # remesh_band / remesh_project are passed for parity with the remesh=True
    # call. Whether upstream uses them when remesh=False is not visible here.
    return o_voxel.postprocess.to_glb(
        **glb_kwargs,
        remesh=False,
        remesh_band=1,
        remesh_project=0,
    )