| import spaces |
| import os |
| import gradio as gr |
| import numpy as np |
| import torch |
| from PIL import Image |
| import trimesh |
| import random |
| from transformers import AutoModelForImageSegmentation |
| from torchvision import transforms |
| from huggingface_hub import hf_hub_download, snapshot_download, login |
| import subprocess |
| import shutil |
|
|
|
|
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only on GPU: PyTorch's CPU backend lacks float16 kernels for
# some ops, so fall back to float32 when running without CUDA.
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

print("DEVICE: ", DEVICE)

# NOTE(review): not referenced by the code visible in this file.
DEFAULT_PART_FACE_NUMBER = 10000
# Upper bound for seeds (largest int32, 2**31 - 1).
MAX_SEED = np.iinfo(np.int32).max
HOLOPART_REPO_URL = "https://github.com/VAST-AI-Research/HoloPart.git"
# Local directory the pretrained HoloPart weights are downloaded into.
HOLOPART_PRETRAINED_MODEL = "checkpoints/HoloPart"

# Scratch space for generated meshes, located next to this script (not the CWD).
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
os.makedirs(TMP_DIR, exist_ok=True)
|
|
# Clone the HoloPart sources once so its modules can be imported below.
# The existence check must test the local checkout directory (the original
# tested the repo URL, which is never a local path, so it re-cloned on every
# start). GIT_LFS_SKIP_SMUDGE=1 skips fetching Git LFS blobs during the clone.
HOLOPART_CODE_DIR = "./holopart"
if not os.path.exists(HOLOPART_CODE_DIR):
    subprocess.run(
        ["git", "clone", HOLOPART_REPO_URL, HOLOPART_CODE_DIR],
        env={**os.environ, "GIT_LFS_SKIP_SMUDGE": "1"},
        check=True,  # fail loudly here rather than with a confusing ImportError later
    )


import sys

# Make the cloned repo root and its `scripts/` folder importable
# (e.g. for the `inference_holopart` and `holopart` imports further down).
sys.path.append(HOLOPART_CODE_DIR)
sys.path.append(os.path.join(HOLOPART_CODE_DIR, "scripts"))
|
|
|
|
# Example rows for gr.Examples: each pairs a segmented input mesh (.glb) with
# its preview image (.PNG); the bundled examples are numbered 000 through 005.
EXAMPLES = [
    [f"./assets/example_data/{idx:03d}.glb", f"./assets/example_data/{idx:03d}.PNG"]
    for idx in range(6)
]
|
|
|
|
|
|
# Markdown rendered at the top of the demo page. The embedded snippet is meant
# to be copied and run by users, so it must be valid, indented Python; it no
# longer binds the (unneeded) return value of `.export()` to `mesh_parts`.
HEADER = """
# 🔮 Decompose a 3D shape into complete parts with [HoloPart](https://github.com/VAST-AI-Research/HoloPart).
### Step 1: Prepare Your Segmented Mesh
Upload a mesh with part segmentation. We recommend using these segmentation tools:
- [SAMPart3D](https://github.com/Pointcept/SAMPart3D)
- [SAMesh](https://github.com/gtangg12/samesh)

For a mesh file `mesh.glb` and corresponding face mask `mask.npy`, prepare your input using this Python code:
```python
import trimesh
import numpy as np
mesh = trimesh.load("mesh.glb", force="mesh")
mask_npy = np.load("mask.npy")
mesh_parts = []
for part_id in np.unique(mask_npy):
    mesh_part = mesh.submesh([mask_npy == part_id], append=True)
    mesh_parts.append(mesh_part)
trimesh.Scene(mesh_parts).export("input_mesh.glb")
```
The resulting **input_mesh.glb** is your prepared input for HoloPart.
### Step 2: Click the Decompose Parts button to begin the decomposition process.
"""
|
|
|
|
| from inference_holopart import prepare_data, run_holopart |
| from holopart.pipelines.pipeline_holopart import HoloPartPipeline |
|
|
# Fetch the pretrained weights from the Hugging Face Hub into the local
# checkpoint directory, then build the pipeline from that directory.
snapshot_download("VAST-AI/HoloPart", local_dir=HOLOPART_PRETRAINED_MODEL)
holopart_pipe = HoloPartPipeline.from_pretrained(HOLOPART_PRETRAINED_MODEL)
# Rebind to whatever `.to()` returns (target device and dtype from the config above).
holopart_pipe = holopart_pipe.to(DEVICE, DTYPE)
|
|
def start_session(req: gr.Request):
    """Create this browser session's scratch directory under ``TMP_DIR``.

    The directory is named after ``req.session_hash``; an existing directory
    is left untouched.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(session_dir, exist_ok=True)
    print("start session, mkdir", session_dir)
|
|
def end_session(req: gr.Request):
    """Delete this browser session's scratch directory under ``TMP_DIR``.

    Uses ``ignore_errors=True`` so a missing directory (e.g. the load hook
    never ran for this session, or the unload hook fires more than once) does
    not raise inside the unload callback.
    """
    save_dir = os.path.join(TMP_DIR, str(req.session_hash))
    shutil.rmtree(save_dir, ignore_errors=True)
|
|
def get_random_hex():
    """Return 16 lowercase hex characters encoding 8 bytes from ``os.urandom``."""
    return os.urandom(8).hex()
|
|
def get_random_seed(randomize_seed, seed):
    """Return ``seed`` unchanged, or a fresh value in ``[0, MAX_SEED]`` when
    ``randomize_seed`` is truthy."""
    return random.randint(0, MAX_SEED) if randomize_seed else seed
|
|
def explode_mesh(mesh: trimesh.Scene, explode_factor: float = 0.5):
    """Return a new scene with every part pushed away from the scene centroid.

    For each scene-graph node that carries geometry, the part's world-space
    vertex mean is computed. The node's transform is then translated by
    ``explode_factor`` along the unit vector from ``mesh.centroid`` to that
    mean. Parts whose mean coincides with the centroid, or that have no
    vertices, keep their original transform.

    Args:
        mesh: Scene whose nodes reference the individual parts.
        explode_factor: Translation distance applied to each part, in the
            scene's units.

    Returns:
        A new ``trimesh.Scene``; ``mesh`` itself is not modified.
    """
    center = mesh.centroid
    exploded_mesh = trimesh.Scene()
    # Iterate scene-graph *nodes*: `graph[...]` is keyed by node name, which
    # need not equal the geometry name, and one geometry may be instanced by
    # several nodes (each instance is exploded independently).
    for node_name in mesh.graph.nodes_geometry:
        transform, geometry_name = mesh.graph[node_name]
        geometry = mesh.geometry[geometry_name]
        new_transform = np.copy(transform)
        vertices = np.asarray(geometry.vertices)
        # An empty part has no mean (it would be NaN); leave it in place.
        if len(vertices) > 0:
            vertices_global = trimesh.transformations.transform_points(
                vertices, transform)
            direction = vertices_global.mean(axis=0) - center
            direction_length = np.linalg.norm(direction)
            if direction_length > 0:
                direction = direction / direction_length
            new_transform[:3, 3] += direction * explode_factor
        exploded_mesh.add_geometry(
            geometry,
            node_name=node_name,
            geom_name=geometry_name,
            transform=new_transform,
        )
    return exploded_mesh
|
|
|
|
@spaces.GPU(duration=600)
def run_full(
    data_path: str,
    example_image=None,
    seed=42,
    num_inference_steps=25,
    guidance_scale=3.5,
    progress=gr.Progress(track_tqdm=True),
    req: gr.Request = None,
):
    """Decompose a segmented mesh into complete parts with HoloPart.

    Args:
        data_path: Path to the uploaded ``.glb`` prepared as described in
            ``HEADER`` (one scene geometry per segmented part).
        example_image: Preview image carried by ``gr.Examples`` rows; not
            used for inference.
        seed: Random seed forwarded to ``run_holopart``.
        num_inference_steps: Step count forwarded to ``run_holopart``.
        guidance_scale: Guidance scale forwarded to ``run_holopart``.
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors tqdm bars.
        req: Injected by Gradio when available. Outputs are then written to
            the per-session directory that ``start_session``/``end_session``
            create and delete. Without a request (e.g. a direct call) they go
            to the shared ``examples`` folder.

    Returns:
        ``(mesh_path, exploded_mesh_path)``: paths of the exported
        decomposed scene and its exploded view.

    Raises:
        gr.Error: if no input mesh was provided.
    """
    if not data_path:
        raise gr.Error("Please upload a segmented mesh (.glb) before decomposing.")

    # NOTE(review): forwarded as-is to run_holopart; presumably the number of
    # parts processed per batch — confirm against inference_holopart.
    batch_size = 30
    parts_data = prepare_data(data_path)

    part_scene = run_holopart(
        holopart_pipe,
        batch=parts_data,
        batch_size=batch_size,
        # Slider values may arrive as floats; seed and step count are integral.
        seed=int(seed),
        num_inference_steps=int(num_inference_steps),
        guidance_scale=float(guidance_scale),
        num_chunks=1000000,
    )
    print("mesh extraction done")

    # Use the session's scratch dir when known so end_session removes these
    # files; otherwise fall back to a shared folder.
    session_hash = getattr(req, "session_hash", None)
    save_dir = os.path.join(TMP_DIR, str(session_hash) if session_hash else "examples")
    os.makedirs(save_dir, exist_ok=True)

    mesh_path = os.path.join(save_dir, f"holopart_{get_random_hex()}.glb")
    part_scene.export(mesh_path)
    print("save to ", mesh_path)

    exploded_mesh = explode_mesh(part_scene, 0.7)
    exploded_mesh_path = os.path.join(save_dir, f"holopart_exploded_{get_random_hex()}.glb")
    exploded_mesh.export(exploded_mesh_path)

    torch.cuda.empty_cache()

    return mesh_path, exploded_mesh_path
|
|
def visualize_input(mesh_in):
    """Pass the uploaded mesh value straight through to the 3D preview."""
    preview = mesh_in
    return preview
|
|
|
|
with gr.Blocks(title="HoloPart") as demo:
    gr.Markdown(HEADER)

    with gr.Row():
        # Left column: input upload/preview, generation settings, trigger button.
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    input_mesh = gr.File(label="Input Mesh", file_types=[".glb"])
                    input_visualizer = gr.Model3D(label="Input visualizer", interactive=False)
                    # Hidden: only carries the preview image of each gr.Examples row.
                    example_image = gr.Image(label="Example Image", type="filepath", interactive=False, visible=False)

            with gr.Accordion("Generation Settings", open=True):
                # step=1: seeds are whole numbers (a zero step is not a usable increment).
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=8,
                    maximum=50,
                    step=1,
                    value=25,
                )
                guidance_scale = gr.Slider(
                    label="CFG scale",
                    minimum=0.0,
                    maximum=20.0,
                    step=0.1,
                    value=3.5,
                )

            with gr.Row():
                # NOTE(review): display-only; its value is not passed to run_full.
                reduce_face = gr.Checkbox(label="Simplify Mesh", value=True, interactive=False)

            gen_button = gr.Button("Decompose Parts", variant="primary")

        # Right column: decomposition results.
        with gr.Column():
            model_output = gr.Model3D(label="Decomposed GLB", interactive=False)
            exploded_parts_output = gr.Model3D(label="Exploded Parts", interactive=False)

    with gr.Row():
        examples = gr.Examples(
            examples=EXAMPLES,
            fn=run_full,
            inputs=[input_mesh, example_image],
            outputs=[model_output, exploded_parts_output],
            cache_examples=True,
            cache_mode="lazy",
        )

    # `.change` already fires on upload (as well as on example selection and
    # clearing), so a separate `.upload` listener would only run the same
    # preview update twice.
    input_mesh.change(
        fn=visualize_input,
        inputs=[input_mesh],
        outputs=[input_visualizer],
        queue=False,
    )

    gen_button.click(
        run_full,
        inputs=[
            input_mesh,
            example_image,
            seed,
            num_inference_steps,
            guidance_scale,
        ],
        outputs=[model_output, exploded_parts_output],
    )

    # Create / delete the per-session scratch directory.
    demo.load(start_session)
    demo.unload(end_session)


if __name__ == "__main__":
    demo.launch(ssr_mode=False)
|
|