keeendaaa committed on
Commit
2b1896f
·
1 Parent(s): 98202ab

Initial TripoSG Space app

Browse files
Files changed (5) hide show
  1. .gitignore +5 -0
  2. README.md +41 -2
  3. app.py +225 -0
  4. requirements.txt +23 -0
  5. utils.py +37 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ checkpoints/
2
+ triposg/
3
+ tmp/
4
+ __pycache__/
5
+ *.glb
README.md CHANGED
@@ -1,12 +1,51 @@
1
  ---
2
- title: Trip W Oblaka
3
  emoji: 😻
4
  colorFrom: pink
5
  colorTo: blue
6
  sdk: gradio
7
  sdk_version: 6.5.0
8
  app_file: app.py
 
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: TripoSG Image-to-3D API
3
  emoji: 😻
4
  colorFrom: pink
5
  colorTo: blue
6
  sdk: gradio
7
  sdk_version: 6.5.0
8
  app_file: app.py
9
+ python_version: 3.10
10
  pinned: false
11
  ---
12
 
13
+ # TripoSG Image-to-3D API
14
+
15
+ This Space wraps the official TripoSG pipeline and exposes a `/predict` API endpoint for programmatic generation of GLB meshes.
16
+
17
+ ## API usage
18
+
19
+ Python:
20
+
21
+ ```python
22
+ from gradio_client import Client
23
+
24
+ client = Client("your-username/your-space")
25
+ result = client.predict(
26
+ image_path="input.png",
27
+ seed=0,
28
+ num_inference_steps=50,
29
+ guidance_scale=7.5,
30
+ simplify=True,
31
+ target_face_num=100000,
32
+ api_name="/predict",
33
+ )
34
+ print(result)
35
+ ```
36
+
37
+ Raw HTTP (example):
38
+
39
+ ```bash
40
+ curl -X POST \
41
+ -H "Content-Type: application/json" \
42
+ -d '{"data": ["data:image/png;base64,......", 0, 50, 7.5, true, 100000]}' \
43
+ https://your-username-your-space.hf.space/api/predict
44
+ ```
45
+
46
+ The response contains the generated GLB file path and URL.
47
+
48
+ ## Notes
49
+
50
+ - The Space will clone `VAST-AI-Research/TripoSG` at runtime and download weights from `VAST-AI/TripoSG` and `briaai/RMBG-1.4`.
51
+ - `requirements.txt` targets the default Hugging Face Spaces GPU runtime (Linux). For local runs, adjust Torch/CUDA and the `diso` wheel as needed.
app.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import shutil
import subprocess
import sys
import uuid

import gradio as gr
import numpy as np
import torch
import trimesh
from huggingface_hub import snapshot_download
11
+
12
try:
    import spaces

    gpu = spaces.GPU
except Exception:
    # Not running on a ZeroGPU Space: provide a no-op stand-in so the
    # decorators below still work. Unlike the naive fallback, this supports
    # BOTH usages spaces.GPU allows: bare ``@gpu`` (the function itself is
    # passed in) and parameterized ``@gpu(duration=...)``.
    def gpu(*args, **kwargs):
        """No-op replacement for ``spaces.GPU`` outside ZeroGPU runtimes."""
        if len(args) == 1 and callable(args[0]) and not kwargs:
            # Bare ``@gpu`` usage: args[0] is the decorated function.
            return args[0]

        def _wrap(fn):
            return fn

        return _wrap
23
+
24
+
25
# Inference device / precision: fp16 on GPU, fp32 on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# The TripoSG code is cloned at startup rather than vendored into the Space.
TRIPOSG_REPO_URL = "https://github.com/VAST-AI-Research/TripoSG.git"
TRIPOSG_CODE_DIR = "./triposg"

CHECKPOINT_DIR = "checkpoints"
RMBG_PRETRAINED_MODEL = os.path.join(CHECKPOINT_DIR, "RMBG-1.4")
TRIPOSG_PRETRAINED_MODEL = os.path.join(CHECKPOINT_DIR, "TripoSG")

# Per-session outputs are written under tmp/ next to this file.
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
os.makedirs(TMP_DIR, exist_ok=True)


if not os.path.exists(TRIPOSG_CODE_DIR):
    # subprocess.run with an argument list avoids shell string parsing, and
    # check=True makes a failed clone fail loudly here instead of surfacing
    # later as a confusing ImportError (os.system ignored the exit status).
    subprocess.run(
        ["git", "clone", TRIPOSG_REPO_URL, TRIPOSG_CODE_DIR], check=True
    )

sys.path.append(TRIPOSG_CODE_DIR)
sys.path.append(os.path.join(TRIPOSG_CODE_DIR, "scripts"))


# These imports resolve inside the freshly cloned repo (see sys.path above).
# NOTE(review): ``from utils import simplify_mesh`` may resolve to either this
# Space's utils.py or a utils module in the cloned repo's scripts dir,
# depending on sys.path order — confirm the intended module wins.
from image_process import prepare_image
from briarmbg import BriaRMBG
from triposg.pipelines.pipeline_triposg import TripoSGPipeline
from utils import simplify_mesh


# Background-removal model (RMBG-1.4), used for segmentation preprocessing.
snapshot_download("briaai/RMBG-1.4", local_dir=RMBG_PRETRAINED_MODEL)
rmbg_net = BriaRMBG.from_pretrained(RMBG_PRETRAINED_MODEL).to(DEVICE)
rmbg_net.eval()

# TripoSG image-to-3D diffusion pipeline, in half precision on GPU.
snapshot_download("VAST-AI/TripoSG", local_dir=TRIPOSG_PRETRAINED_MODEL)
triposg_pipe = TripoSGPipeline.from_pretrained(TRIPOSG_PRETRAINED_MODEL).to(
    DEVICE, DTYPE
)
60
+
61
+
62
def _session_dir(req: gr.Request | None) -> str:
    """Return the scratch directory for this request's session.

    Falls back to the shared TMP_DIR when no request context is available
    (e.g. direct calls without a Gradio request); otherwise a per-session
    subdirectory keyed by the session hash, created on demand.
    """
    if req is None:
        return TMP_DIR
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(session_dir, exist_ok=True)
    return session_dir
68
+
69
+
70
+ def _unique_glb_path(save_dir: str) -> str:
71
+ return os.path.join(save_dir, f"triposg_{uuid.uuid4().hex}.glb")
72
+
73
+
74
def _run_triposg(
    image_path: str,
    seed: int,
    num_inference_steps: int,
    guidance_scale: float,
    simplify: bool,
    target_face_num: int,
    req: gr.Request | None = None,
):
    """Core generation routine shared by the UI and API endpoints.

    Segments the object out of the input image, runs the TripoSG diffusion
    pipeline, optionally decimates the resulting mesh, and exports it as a
    GLB file in the session's scratch directory.

    Returns:
        (image_seg, mesh_path): the background-removed preview image and the
        filesystem path of the exported GLB.

    Raises:
        gr.Error: if no image path was provided.
    """
    if not image_path:
        raise gr.Error("Upload an image first.")

    # Remove the background with RMBG-1.4, compositing onto a white background.
    image_seg = prepare_image(
        image_path, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net
    )

    # Seeded generator on the pipeline's device for reproducible sampling.
    generator = torch.Generator(device=triposg_pipe.device).manual_seed(seed)
    outputs = triposg_pipe(
        image=image_seg,
        generator=generator,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).samples[0]

    # outputs looks like a (vertices, faces) pair — TODO confirm against the
    # TripoSG pipeline's documented return contract.
    mesh = trimesh.Trimesh(outputs[0].astype(np.float32), np.ascontiguousarray(outputs[1]))

    if simplify:
        mesh = simplify_mesh(mesh, target_face_num)

    # Export into a unique per-session path so concurrent users don't collide.
    save_dir = _session_dir(req)
    mesh_path = _unique_glb_path(save_dir)
    mesh.export(mesh_path)

    return image_seg, mesh_path
108
+
109
+
110
@gpu(duration=180)
@torch.no_grad()
def generate_mesh(
    image_path: str,
    seed: int,
    num_inference_steps: int,
    guidance_scale: float,
    simplify: bool,
    target_face_num: int,
    req: gr.Request | None = None,
):
    """UI entry point: return (segmentation preview, GLB path)."""
    result = _run_triposg(
        image_path,
        seed,
        num_inference_steps,
        guidance_scale,
        simplify,
        target_face_num,
        req,
    )
    # Release cached GPU memory between requests so allocator reservations
    # don't pile up across sessions.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return result
133
+
134
+
135
@gpu(duration=180)
@torch.no_grad()
def api_generate(
    image_path: str,
    seed: int,
    num_inference_steps: int,
    guidance_scale: float,
    simplify: bool,
    target_face_num: int,
    req: gr.Request | None = None,
):
    """API entry point (/predict): return only the generated GLB path."""
    mesh_path = _run_triposg(
        image_path,
        seed,
        num_inference_steps,
        guidance_scale,
        simplify,
        target_face_num,
        req,
    )[1]
    # Same post-run GPU cache release as the UI path.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return mesh_path
158
+
159
+
160
def _cleanup_session(req: gr.Request):
    """Delete the per-session scratch directory created by _session_dir."""
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    if os.path.exists(session_dir):
        shutil.rmtree(session_dir)
164
+
165
+
166
# UI text shown in the page header.
TITLE = "TripoSG Image-to-3D API"
DESCRIPTION = (
    "Upload a single-object image to generate a 3D mesh (GLB). "
    "This demo exposes a /predict API endpoint."
)


with gr.Blocks(title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}\n\n{DESCRIPTION}")

    with gr.Row():
        with gr.Column():
            # type="filepath" so handlers receive a path string, matching
            # _run_triposg's image_path parameter.
            image_input = gr.Image(label="Input Image", type="filepath")
            seg_output = gr.Image(
                label="Segmentation Preview", type="pil", format="png"
            )

            with gr.Accordion("Generation Settings", open=True):
                seed = gr.Slider(
                    label="Seed", minimum=0, maximum=2**31 - 1, step=1, value=0
                )
                steps = gr.Slider(
                    label="Inference Steps", minimum=8, maximum=50, step=1, value=50
                )
                guidance = gr.Slider(
                    label="CFG Scale", minimum=0.0, maximum=20.0, step=0.1, value=7.5
                )
                simplify = gr.Checkbox(label="Simplify Mesh", value=True)
                face_count = gr.Slider(
                    label="Target Face Count",
                    minimum=10000,
                    maximum=1000000,
                    step=1000,
                    value=100000,
                )

            generate_btn = gr.Button("Generate 3D", variant="primary")

        with gr.Column():
            model_output = gr.Model3D(label="Generated GLB", interactive=False)
            file_output = gr.File(label="Download GLB", interactive=False)

    # UI path: render the mesh, then mirror the same path into the download box.
    generate_btn.click(
        generate_mesh,
        inputs=[image_input, seed, steps, guidance, simplify, face_count],
        outputs=[seg_output, model_output],
    ).then(lambda path: path, inputs=model_output, outputs=file_output)

    # Hidden button whose only purpose is to register the /predict API route
    # used by gradio_client and raw HTTP callers.
    api_btn = gr.Button(visible=False)
    api_btn.click(
        api_generate,
        inputs=[image_input, seed, steps, guidance, simplify, face_count],
        outputs=[file_output],
        api_name="/predict",
    )

    # Remove per-session tmp files when the browser session ends.
    demo.unload(_cleanup_session)


demo.launch()
requirements.txt ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diffusers==0.32.2
2
+ trimesh
3
+ pillow
4
+ spandrel==0.4.0
5
+ plyfile==1.1
6
+ xformers
7
+ pymcubes==0.1.4
8
+ shapely
9
+ mkl==2022.0.2
10
+ nvdiffrast
11
+ cvcuda_cu12==0.6.0.16
12
+ triton==3.1.0
13
+ imageio==2.36.0
14
+ numpy==1.26.4
15
+ scipy==1.13.1
16
+ tqdm==4.67.1
17
+ opencv-python
18
+ open3d==0.18.0
19
+ pymeshlab
20
+ ninja==1.11.1.3
21
+ matplotlib
22
+
23
+ diso @ https://github.com/Chumbyte/DiSO/releases/download/v0.1.4/diso-0.1.4-cp310-cp310-linux_x86_64.whl
utils.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import open3d as o3d
3
+ import pymeshlab as pml
4
+
5
+
6
def simplify_mesh(mesh, target_face_num: int = 100000):
    """Decimate *mesh* to at most ``target_face_num`` faces.

    Meshes already at or below the target are returned unchanged. Otherwise
    the mesh is decimated with pymeshlab's quadric edge collapse, cleaned up
    with open3d, and rebuilt as a new mesh of the same class as the input.
    """
    if mesh.faces.shape[0] <= target_face_num:
        return mesh

    # Quadric edge-collapse decimation via pymeshlab, keeping boundaries.
    mesh_set = pml.MeshSet()
    mesh_set.add_mesh(pml.Mesh(mesh.vertices, mesh.faces))
    mesh_set.meshing_decimation_quadric_edge_collapse(
        targetfacenum=int(target_face_num), preserveboundary=True
    )
    decimated = mesh_set.current_mesh()

    # Clean up decimation artifacts with open3d: duplicate / degenerate /
    # non-manifold / unreferenced geometry, in that order.
    cleaned = o3d.geometry.TriangleMesh(
        o3d.utility.Vector3dVector(decimated.vertex_matrix()),
        o3d.utility.Vector3iVector(decimated.face_matrix()),
    )
    cleaned = cleaned.remove_duplicated_vertices()
    cleaned = cleaned.remove_degenerate_triangles()
    cleaned = cleaned.remove_non_manifold_edges()
    cleaned = cleaned.remove_unreferenced_vertices()

    # process=False keeps trimesh from re-merging/re-ordering what open3d
    # just cleaned.
    return mesh.__class__(
        vertices=np.asarray(cleaned.vertices),
        faces=np.asarray(cleaned.triangles),
        vertex_normals=np.asarray(cleaned.vertex_normals),
        process=False,
    )