mouttaki committed on
Commit
7632107
·
1 Parent(s): 0a24ac3

fix: simplify app.py for Hugging Face deployment

Browse files
Files changed (1) hide show
  1. app.py +242 -50
app.py CHANGED
@@ -1,30 +1,24 @@
1
- # tout en haut (cache persistant pour éviter de re-télécharger à chaque boot)
2
  import os
3
- os.environ.setdefault("HF_HOME", "/data/.cache/huggingface")
4
- os.environ.setdefault("HUGGINGFACE_HUB_CACHE", "/data/.cache/huggingface")
5
- os.environ.setdefault("TRANSFORMERS_CACHE", "/data/.cache/huggingface")
6
- os.environ.setdefault("XDG_CACHE_HOME", "/data/.cache")
7
- os.environ.setdefault("U2NET_HOME", "/data/.u2net")
8
-
 
9
  import argparse
10
  import torch
 
11
 
12
- # -------- args --------
 
 
13
  parser = argparse.ArgumentParser()
14
  parser.add_argument("--model_path", type=str, default="tencent/Hunyuan3D-2mini")
15
  parser.add_argument("--subfolder", type=str, default="hunyuan3d-dit-v2-mini-turbo")
16
  parser.add_argument("--texgen_model_path", type=str, default="tencent/Hunyuan3D-2")
17
  parser.add_argument("--port", type=int, default=7860)
18
  parser.add_argument("--host", type=str, default="0.0.0.0")
19
-
20
- # auto-device: cuda > mps > cpu
21
- _auto_device = (
22
- "cuda" if torch.cuda.is_available()
23
- else ("mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
24
- else "cpu")
25
- )
26
- parser.add_argument("--device", type=str, default=_auto_device)
27
- parser.add_argument("--mc_algo", type=str, default="mc")
28
  parser.add_argument("--cache_path", type=str, default="gradio_cache")
29
  parser.add_argument("--enable_t23d", action="store_true")
30
  parser.add_argument("--disable_tex", action="store_true")
@@ -32,44 +26,242 @@ parser.add_argument("--enable_flashvdm", action="store_true")
32
  parser.add_argument("--compile", action="store_true")
33
  parser.add_argument("--low_vram_mode", action="store_true")
34
  args = parser.parse_args()
 
35
 
36
- # fp16 sur cuda, sinon fp32
37
- DTYPE = torch.float16 if args.device == "cuda" else torch.float32
 
 
 
 
 
38
 
39
- # Si flashvdm activé, ne l’utilise pas en CPU/MPS
40
- args.enable_flashvdm = args.enable_flashvdm or True # si tu veux garder True par défaut
41
- if args.device in ["cpu", "mps"]:
42
- args.enable_flashvdm = False # évite les features GPU-only
43
 
44
- # -------- lazy load des workers --------
45
- from functools import lru_cache
46
 
47
- @lru_cache(maxsize=1)
48
- def get_workers():
49
- from hy3dgen.shapegen import (
50
- FaceReducer, FloaterRemover, DegenerateFaceRemover, Hunyuan3DDiTFlowMatchingPipeline
51
- )
52
- from hy3dgen.rembg import BackgroundRemover
53
-
54
- rmbg_worker = BackgroundRemover()
55
- i23d_worker = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(
56
- args.model_path,
57
- subfolder=args.subfolder,
58
- use_safetensors=True,
59
- device=args.device, # la pipeline sait gérer cpu/mps/cuda
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  )
61
 
62
- # cast correct pour le device courant
63
- try:
64
- i23d_worker.to(args.device, DTYPE)
65
- except Exception:
66
- # dernier filet de secours pour CPU
67
- i23d_worker.to("cpu", torch.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- if args.enable_flashvdm and args.device == "cuda":
70
- i23d_worker.enable_flashvdm(mc_algo=args.mc_algo)
 
 
71
 
72
- floater_remove_worker = FloaterRemover()
73
- degenerate_face_remove_worker = DegenerateFaceRemover()
74
- face_reduce_worker = FaceReducer()
75
- return rmbg_worker, i23d_worker, floater_remove_worker, degenerate_face_remove_worker, face_reduce_worker
 
 
1
  import os
2
+ import spaces
3
+ import random
4
+ import shutil
5
+ import gradio as gr
6
+ from glob import glob
7
+ from pathlib import Path
8
+ import uuid
9
  import argparse
10
  import torch
11
+ import trimesh
12
 
13
# -------------------
# Arguments & Config
# -------------------
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="tencent/Hunyuan3D-2mini")
parser.add_argument("--subfolder", type=str, default="hunyuan3d-dit-v2-mini-turbo")
parser.add_argument("--texgen_model_path", type=str, default="tencent/Hunyuan3D-2")
parser.add_argument("--port", type=int, default=7860)
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--cache_path", type=str, default="gradio_cache")
parser.add_argument("--enable_t23d", action="store_true")
parser.add_argument("--disable_tex", action="store_true")
parser.add_argument("--compile", action="store_true")
parser.add_argument("--low_vram_mode", action="store_true")
args = parser.parse_args()
# Forced on regardless of CLI; the mc algorithm is chosen at model-load time.
args.enable_flashvdm = True

# Auto-select device: cuda > mps > cpu.
# Fix: guard the MPS probe with getattr — torch builds that lack the mps
# backend raise AttributeError on `torch.backends.mps`.
if torch.cuda.is_available():
    args.device = "cuda"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    args.device = "mps"
else:
    args.device = "cpu"

# All generated meshes / viewer HTML land under this folder (served at /static/).
SAVE_DIR = args.cache_path
os.makedirs(SAVE_DIR, exist_ok=True)

# Directory of this file — used to resolve the bundled HTML viewer templates.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))

HTML_HEIGHT = 500
HTML_WIDTH = 500
# Fix: keep MAX_SEED an int — it was 1e7 (float), and random.randint rejects
# float bounds (TypeError on Python 3.12+), breaking "Randomize seed".
MAX_SEED = 10_000_000
47
+
48
+ # -------------------
49
+ # Utils
50
+ # -------------------
51
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return `seed` unchanged, or a fresh random seed in [0, MAX_SEED].

    Fix: MAX_SEED is defined as the float ``1e7`` at module level, and
    ``random.randint`` requires integer bounds (TypeError on Python 3.12+),
    so cast explicitly before drawing.
    """
    if randomize_seed:
        seed = random.randint(0, int(MAX_SEED))
    return seed
55
+
56
def gen_save_folder(max_size=200):
    """Create and return a fresh UUID-named output folder under SAVE_DIR.

    Caps the cache at `max_size` folders by evicting the oldest one
    (by creation time) before creating the new folder.
    """
    os.makedirs(SAVE_DIR, exist_ok=True)
    existing = [entry for entry in Path(SAVE_DIR).iterdir() if entry.is_dir()]
    if len(existing) >= max_size:
        victim = min(existing, key=lambda d: d.stat().st_ctime)
        shutil.rmtree(victim)
        print(f"Removed oldest folder: {victim}")
    fresh = os.path.join(SAVE_DIR, str(uuid.uuid4()))
    os.makedirs(fresh, exist_ok=True)
    return fresh
66
+
67
def export_mesh(mesh, save_folder, textured=False, type="glb"):
    """Export `mesh` into `save_folder` and return the file path.

    The filename stem is "textured_mesh" or "white_mesh" depending on
    `textured`; `include_normals` is only passed for glb/obj exports.
    (`type` shadows the builtin but is part of the public keyword API.)
    """
    stem = "textured_mesh" if textured else "white_mesh"
    path = os.path.join(save_folder, f"{stem}.{type}")
    if type in ("glb", "obj"):
        mesh.export(path, include_normals=textured)
    else:
        mesh.export(path)
    return path
77
+
78
def build_model_viewer_html(save_folder, height=660, width=790, textured=False):
    """Fill the model-viewer HTML template for the mesh in `save_folder`.

    Writes `<stem>.html` next to the exported glb and returns a <div>
    wrapping an iframe that loads it from the /static/ mount.
    """
    stem = "textured_mesh" if textured else "white_mesh"
    related_path = f"./{stem}.glb"
    template_name = (
        "./assets/modelviewer-textured-template.html"
        if textured
        else "./assets/modelviewer-template.html"
    )
    output_html_path = os.path.join(save_folder, f"{stem}.html")
    # Textured viewer reserves more vertical space for its controls.
    offset = 50 if textured else 10

    with open(os.path.join(CURRENT_DIR, template_name), "r", encoding="utf-8") as src:
        page = src.read()

    # NOTE(review): the #src# placeholder gets a trailing "/" appended
    # ("./white_mesh.glb/") — preserved as-is; verify against the template.
    page = page.replace("#height#", f"{height - offset}")
    page = page.replace("#width#", f"{width}")
    page = page.replace("#src#", f"{related_path}/")
    with open(output_html_path, "w", encoding="utf-8") as dst:
        dst.write(page)

    rel_path = os.path.relpath(output_html_path, SAVE_DIR)
    iframe_tag = (
        f'<iframe src="/static/{rel_path}" height="{height}" '
        f'width="100%" frameborder="0"></iframe>'
    )
    return f"<div style='height: {height}; width: 100%;'>{iframe_tag}</div>"
101
+
102
# -------------------
# Model loaders
# -------------------
from hy3dgen.shapegen import (
    FaceReducer,
    FloaterRemover,
    DegenerateFaceRemover,
    MeshSimplifier,
    Hunyuan3DDiTFlowMatchingPipeline,
)
from hy3dgen.shapegen.pipelines import export_to_trimesh
from hy3dgen.rembg import BackgroundRemover

# Workers are instantiated once at import time and shared by all requests.
rmbg_worker = BackgroundRemover()
i23d_worker = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(
    args.model_path,
    subfolder=args.subfolder,
    use_safetensors=True,
    device=args.device,
)
if args.enable_flashvdm:
    # Fix: the original ternary was `"mc" if args.device in ["cpu", "mps"]
    # else "mc"` — both branches identical, so the conditional was dead code.
    # NOTE(review): flashvdm is force-enabled even on cpu/mps here; the
    # pre-refactor code disabled it off-GPU — confirm it is safe there.
    mc_algo = "mc"
    i23d_worker.enable_flashvdm(mc_algo=mc_algo)
if args.compile:
    i23d_worker.compile()

# Mesh post-processing workers used by gen_shape.
floater_remove_worker = FloaterRemover()
degenerate_face_remove_worker = DegenerateFaceRemover()
face_reduce_worker = FaceReducer()

# Shared Gradio progress reporter driven from gen_shape's callback.
progress = gr.Progress()
133
+
134
# -------------------
# Main function
# -------------------
# NOTE(review): re-enable the decorator on a GPU Space; left off for CPU-only
# deployments.
# @spaces.GPU(duration=40)
def gen_shape(
    image=None,
    steps=50,
    guidance_scale=7.5,
    seed=1234,
    octree_resolution=256,
    num_chunks=200000,
    target_face_num=10000,
    randomize_seed: bool = False,
):
    """Generate a 3D mesh from an input image.

    Returns (viewer HTML, DownloadButton update, glb URL, obj URL); the
    URLs are /static/-relative paths into SAVE_DIR.

    Raises gr.Error when no image is supplied or mesh export fails.
    """
    progress(0, desc="Starting")

    def callback(step_idx, timestep, outputs):
        # The first half of the progress bar tracks diffusion steps.
        frac = ((step_idx + 1.0) / steps) * (0.5 / 1.0)
        progress(frac, desc=f"Mesh generating, {step_idx + 1}/{steps} steps")

    if image is None:
        raise gr.Error("Please provide an image.")

    seed = int(randomize_seed_fn(seed, randomize_seed))
    raw_dir = gen_save_folder()
    image = rmbg_worker(image.convert("RGB"))

    rng = torch.Generator().manual_seed(seed)
    outputs = i23d_worker(
        image=image,
        num_inference_steps=steps,
        guidance_scale=guidance_scale,
        generator=rng,
        octree_resolution=octree_resolution,
        num_chunks=num_chunks,
        output_type="mesh",
        callback=callback,
        callback_steps=1,
    )

    raw_mesh = export_to_trimesh(outputs)[0]
    raw_path = export_mesh(raw_mesh, raw_dir, textured=False)

    if args.low_vram_mode:
        torch.cuda.empty_cache()
    if raw_path is None:
        raise gr.Error("Mesh generation failed.")

    # Round-trip through disk, then clean up and decimate the mesh.
    mesh = trimesh.load(raw_path)
    progress(0.5, desc="Optimizing mesh")
    mesh = floater_remove_worker(mesh)
    mesh = degenerate_face_remove_worker(mesh)
    progress(0.6, desc="Reducing mesh faces")
    mesh = face_reduce_worker(mesh, target_face_num)

    obj_dir = gen_save_folder()
    progress(0.9, desc="Converting format")
    sourceObjPath = export_mesh(mesh, obj_dir, textured=False, type="obj")
    objPath = "/static/" + os.path.relpath(sourceObjPath, SAVE_DIR)

    # A separate folder holds the glb + viewer HTML pair.
    viewer_dir = gen_save_folder()
    _ = export_mesh(mesh, viewer_dir, textured=False)
    model_viewer_html = build_model_viewer_html(
        viewer_dir, height=HTML_HEIGHT, width=HTML_WIDTH, textured=False
    )

    glbPath = "/static/" + os.path.relpath(
        os.path.join(viewer_dir, "white_mesh.glb"), SAVE_DIR
    )

    progress(1, desc="Complete")
    return model_viewer_html, gr.update(value=sourceObjPath, interactive=True), glbPath, objPath
207
+
208
# -------------------
# Gradio UI
# -------------------
def get_example_img_list():
    """Return all bundled example PNGs (recursive), sorted for a stable order."""
    pattern = "./assets/example_images/**/*.png"
    return sorted(glob(pattern, recursive=True))
213
example_imgs = get_example_img_list()

# Placeholder shown in the viewer pane before any mesh is generated.
HTML_OUTPUT_PLACEHOLDER = """
<div style='height: 500px; width: 100%; border-radius: 8px; border: 1px solid #e5e7eb; display: flex; justify-content: center; align-items: center;'>
<div style="text-align: center; font-size: 16px; color: #6b7280;">
<p style="color: #8d8d8d;">No mesh here.</p>
</div>
</div>
"""

title = "## AI 3D Model Generator"
description = "Upload an image and generate a 3D mesh using Hunyuan 3D. Ready for AR/VR, games or 3D printing."

with gr.Blocks().queue() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        # Left column: input image + generation controls.
        with gr.Column(scale=3):
            image = gr.Image(sources=["upload"], label="Image", type="pil", image_mode="RGBA", height=290)
            gen_button = gr.Button(value="Generate Shape", variant="primary")
            with gr.Accordion("Advanced Options", open=False):
                with gr.Column():
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=1234)
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                with gr.Column():
                    num_steps = gr.Slider(maximum=100, minimum=1, value=5, step=1, label="Inference Steps")
                    octree_resolution = gr.Slider(maximum=512, minimum=16, value=256, label="Octree Resolution")
                with gr.Column():
                    cfg_scale = gr.Slider(maximum=20.0, minimum=1.0, value=5.5, step=0.1, label="Guidance Scale")
                    num_chunks = gr.Slider(maximum=5000000, minimum=1000, value=8000, label="Number of Chunks")
                    target_face_num = gr.Slider(maximum=1000000, minimum=100, value=10000, label="Target Face Number")

        # Middle column: viewer, download button, and raw path outputs.
        with gr.Column(scale=6):
            html_export_mesh = gr.HTML(HTML_OUTPUT_PLACEHOLDER, label="Output")
            file_export = gr.DownloadButton(label="Download", variant="primary", interactive=False)
            with gr.Row():
                objPath_output = gr.Text(label="Obj Path", interactive=False)
                glbPath_output = gr.Text(label="Glb Path", interactive=False)

        # Right column: clickable example gallery.
        with gr.Column(scale=3):
            gr.Examples(examples=example_imgs, inputs=[image], examples_per_page=18)

    gen_button.click(
        fn=gen_shape,
        inputs=[image, num_steps, cfg_scale, seed, octree_resolution, num_chunks, target_face_num, randomize_seed],
        outputs=[html_export_mesh, file_export, glbPath_output, objPath_output],
    )
260
 
261
# -------------------
# Hugging Face Spaces entrypoint
# -------------------
# Spaces' gradio runtime picks up the module-level `app` object.
app = demo

if __name__ == "__main__":
    # Direct launch for local runs / custom Docker entrypoints.
    demo.launch(server_name=args.host, server_port=args.port)