Spaces:
Running on Zero
Running on Zero
| import numpy as np | |
| import torch | |
| import trimesh | |
| import io | |
| from PIL import Image | |
# Module-level slot for the loaded (model, processor) pair; None until first use.
_triposr_cache = None


def get_triposr():
    """Return the ``(model, processor)`` pair for TripoSR, loading it once.

    The first call loads the processor and the model from the
    ``stabilityai/TripoSR`` checkpoint, puts the model in eval mode and stores
    both in the module-level ``_triposr_cache``. Later calls return the stored
    tuple without loading anything.

    Returns:
        A ``(model, processor)`` tuple.
    """
    global _triposr_cache

    if _triposr_cache is not None:
        return _triposr_cache

    # Imported inside the function, so transformers is only imported when the
    # model is actually loaded.
    # NOTE(review): confirm that the installed transformers release really
    # exports these TripoSR classes.
    from transformers import TripoSRForImageTo3D, TripoSRImageProcessor

    checkpoint = "stabilityai/TripoSR"
    loaded_processor = TripoSRImageProcessor.from_pretrained(checkpoint)
    loaded_model = TripoSRForImageTo3D.from_pretrained(checkpoint)
    loaded_model.eval()

    _triposr_cache = (loaded_model, loaded_processor)
    return _triposr_cache
def reconstruct_region(image: Image.Image, mask: list[list[bool]], bbox: list[int]) -> bytes:
    """Reconstruct the masked region of *image* as a 3D mesh and return GLB bytes.

    The bounding box is grown by 20% of its width/height on each side, clamped
    to the image, and the resulting crop is resized to 512x512. The matching
    window of the mask becomes the alpha channel of the RGBA image that is
    passed to the TripoSR processor and model.

    Args:
        image: Source image.
        mask: 2-D per-pixel mask indexed ``mask[row][col]`` on the same pixel
            grid as *image*; truthy entries mark the region to keep.
        bbox: ``[x, y, w, h]`` of the region in image pixel coordinates.

    Returns:
        The reconstructed mesh serialized as binary glTF (GLB).

    Raises:
        ValueError: If *mask* does not have the image's (height, width) shape,
            or if the padded bounding box does not overlap the image.
    """
    W, H = image.size

    # Coerce to bool first: casting straight to uint8 and multiplying by 255
    # wraps around for masks that are already stored as 0/255.
    mask_arr = np.asarray(mask, dtype=bool)
    if mask_arr.shape != (H, W):
        # The mask is sliced with image coordinates below, so a different
        # resolution would silently select the wrong pixels.
        raise ValueError(
            f"mask shape {mask_arr.shape} does not match image size {(H, W)} (height, width)"
        )

    # Crop to bounding box with 20% padding, clamped to the image bounds
    x, y, w, h = bbox
    pad_x = int(w * 0.20)
    pad_y = int(h * 0.20)
    x0 = max(0, x - pad_x)
    y0 = max(0, y - pad_y)
    x1 = min(W, x + w + pad_x)
    y1 = min(H, y + h + pad_y)
    if x1 <= x0 or y1 <= y0:
        raise ValueError(f"bbox {list(bbox)} does not overlap the {W}x{H} image")

    # Inputs are validated before the model is fetched so bad requests fail fast.
    model, processor = get_triposr()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    cropped = image.crop((x0, y0, x1, y1)).resize((512, 512), Image.LANCZOS)

    # Apply mask as alpha channel so TripoSR focuses on the region.
    # NEAREST keeps the resized mask strictly binary (0 or 255).
    alpha_crop = mask_arr[y0:y1, x0:x1].astype(np.uint8) * 255
    mask_resized = np.array(
        Image.fromarray(alpha_crop).resize((512, 512), Image.NEAREST)
    )
    rgba = np.array(cropped.convert("RGBA"))
    rgba[:, :, 3] = mask_resized
    input_img = Image.fromarray(rgba)

    inputs = processor(images=input_img, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # Export as GLB
    mesh_data = outputs.mesh  # presumably a trimesh-like object — TODO confirm
    if hasattr(mesh_data, "export"):
        return mesh_data.export(file_type="glb")

    # Fallback: build a trimesh from objects exposing verts_list()/faces_list()
    verts = mesh_data.verts_list()[0].cpu().numpy()
    faces = mesh_data.faces_list()[0].cpu().numpy()
    mesh = trimesh.Trimesh(vertices=verts, faces=faces, process=False)
    buf = io.BytesIO()
    mesh.export(buf, file_type="glb")
    return buf.getvalue()
def _flat_quad_glb() -> bytes:
    """Return GLB bytes for a unit quad in the Z=0 plane (placeholder mesh)."""
    verts = np.array([[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]], dtype=np.float32)
    faces = np.array([[0, 1, 2], [0, 2, 3]])
    mesh = trimesh.Trimesh(vertices=verts, faces=faces)
    buf = io.BytesIO()
    mesh.export(buf, file_type="glb")
    return buf.getvalue()


def depth_to_mesh(depth: list[list[float]], mask: list[list[bool]], image: Image.Image) -> bytes:
    """Lift a depth map into a textured 3D mesh limited to the masked region.

    Fallback when TripoSR isn't available. Every masked pixel becomes a vertex:
    X/Y are the pixel position mapped to roughly [-0.5, 0.5] (Y pointing up)
    and Z is the depth, min-max normalized over the whole depth map and scaled
    to [0, 0.5]. UVs are the pixel position in the source image.

    Faces are built on the pixel grid: each 2x2 block of pixels that is fully
    inside the mask yields two triangles. A Delaunay triangulation of the
    masked points is deliberately not used, because it covers the convex hull
    of the points (bridging holes and concavities of the mask) and fails on
    degenerate input such as a one-pixel-wide mask.

    Args:
        depth: 2-D depth map indexed ``depth[row][col]``.
        mask: 2-D mask with the same shape as *depth*; truthy entries are kept.
        image: Source image used as the mesh texture.

    Returns:
        The mesh serialized as binary glTF (GLB). If the mask contains no
        fully covered 2x2 block (including an empty mask), a flat unit quad
        is returned instead.

    Raises:
        ValueError: If *depth* is not 2-D or *mask* has a different shape.
    """
    depth_arr = np.asarray(depth, dtype=np.float32)
    mask_arr = np.asarray(mask, dtype=bool)
    if depth_arr.ndim != 2 or depth_arr.shape != mask_arr.shape:
        raise ValueError(
            f"depth shape {depth_arr.shape} and mask shape {mask_arr.shape} "
            "must be identical 2-D shapes"
        )
    H, W = depth_arr.shape

    # A grid cell (r, c) is meshable when all four of its corner pixels are masked.
    full_cells = (
        mask_arr[:-1, :-1] & mask_arr[:-1, 1:] & mask_arr[1:, :-1] & mask_arr[1:, 1:]
    )
    cell_r, cell_c = np.nonzero(full_cells)
    if cell_r.size == 0:
        # Empty or too-thin mask: no triangle can be formed, return a flat quad
        return _flat_quad_glb()

    # Normalize depth to [0, 1] then scale to a reasonable Z range
    dmin, dmax = depth_arr.min(), depth_arr.max()
    if dmax > dmin:
        depth_norm = (depth_arr - dmin) / (dmax - dmin)
    else:
        depth_norm = np.zeros_like(depth_arr)
    depth_scaled = depth_norm * 0.5  # 0.5 units of Z range

    # One vertex per masked pixel (pixels that end up in no face stay unreferenced)
    ys, xs = np.nonzero(mask_arr)
    x_norm = (xs / W) - 0.5
    y_norm = 0.5 - (ys / H)  # flip so that image "up" is +Y
    z_vals = depth_scaled[ys, xs]
    vertices = np.stack([x_norm, y_norm, z_vals], axis=1).astype(np.float32)

    # UV = source pixel position (V flipped: image rows grow downwards)
    uvs = np.stack([xs / W, 1.0 - ys / H], axis=1).astype(np.float32)

    # Map pixel coordinates to vertex indices, then emit two triangles per cell.
    vertex_id = np.full((H, W), -1, dtype=np.int64)
    vertex_id[ys, xs] = np.arange(ys.size)
    top_left = vertex_id[cell_r, cell_c]
    top_right = vertex_id[cell_r, cell_c + 1]
    bottom_left = vertex_id[cell_r + 1, cell_c]
    bottom_right = vertex_id[cell_r + 1, cell_c + 1]
    # Counter-clockwise in the (X, Y-up) plane, so face normals point towards +Z.
    faces = np.concatenate(
        [
            np.stack([top_left, bottom_left, bottom_right], axis=1),
            np.stack([top_left, bottom_right, top_right], axis=1),
        ]
    ).astype(np.int32)

    # Build mesh with texture
    texture = trimesh.visual.texture.TextureVisuals(
        uv=uvs,
        image=image.convert("RGB"),
    )
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces, visual=texture, process=False)
    buf = io.BytesIO()
    mesh.export(buf, file_type="glb")
    return buf.getvalue()