theYiran commited on
Commit
79b41c0
·
verified ·
1 Parent(s): ae6e06d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +339 -76
app.py CHANGED
@@ -1,23 +1,13 @@
1
  import os
2
  import sys
3
- import subprocess
4
  import importlib
5
  import site
6
  import time
7
- import uuid
8
- import shutil
9
- import glob
10
- from types import ModuleType
11
-
12
- # ========================================================
13
- # 1. 核心修复:路径环境变量与内存级伪造 diso
14
- # ========================================================
15
- # 解决 KeyError: 'PARTCRAFTER_PROCESSED'
16
- os.environ["PARTCRAFTER_PROCESSED"] = os.environ.get("PARTCRAFTER_PROCESSED", "outputs")
17
- os.makedirs(os.environ["PARTCRAFTER_PROCESSED"], exist_ok=True)
18
- os.environ['PYOPENGL_PLATFORM'] = 'egl'
19
 
 
20
  def mock_diso():
 
21
  print("🧪 Creating emergency mock for diso...")
22
  diso = ModuleType("diso")
23
  class FakeDiffDMC:
@@ -27,51 +17,155 @@ def mock_diso():
27
  sys.modules["diso"] = diso
28
  sys.modules["diso._C"] = ModuleType("diso._C")
29
  sys.modules["diso.diso_native"] = ModuleType("diso.diso_native")
30
- print("✅ diso mocked.")
31
 
32
  mock_diso()
33
 
34
- # ========================================================
35
- # 2. 极速环境安装 (避开编译,解决超时)
36
- # ========================================================
37
- def setup_environment():
38
- print("🚀 Initializing Optimized Environment...")
39
- # 强制降级 torch 确保与预编译包匹配
40
- try:
41
- import torch
42
- if "2.9" in torch.__version__:
43
- print("🔄 Downgrading torch to 2.4.0...")
44
- subprocess.run([sys.executable, "-m", "pip", "install", "torch==2.4.0+cu121", "torchvision==0.19.0+cu121", "--extra-index-url", "https://download.pytorch.org/whl/cu121"], check=True)
45
- importlib.invalidate_caches()
46
- os.execv(sys.executable, ['python'] + sys.argv)
47
- except: pass
48
-
49
- # 极速安装 PyG 扩展和渲染工具
50
  subprocess.run([
51
  sys.executable, "-m", "pip", "install",
52
- "torch-scatter", "torch-sparse", "torch-cluster", "ninja", "pyrender", "pyopengl==3.1.0", "trimesh", "accelerate",
53
- "-f", "https://data.pyg.org/whl/torch-2.4.0+cu121.html", "--no-cache-dir"
 
 
 
 
 
 
 
54
  ])
55
 
56
  importlib.invalidate_caches()
57
  site.main()
58
- print("🎉 Environment ready!")
 
 
 
 
 
 
 
 
 
 
59
 
60
- # 只有环境未就绪时运行一次
61
- if "torch-scatter" not in str(subprocess.check_output([sys.executable, "-m", "pip", "freeze"])):
62
- setup_environment()
63
 
64
- # ========================================================
65
- # 3. 业务代码导入 (保留你原本的所有引用)
66
- # ========================================================
67
  import spaces
68
  import gradio as gr
69
  import numpy as np
70
  import torch
 
 
71
  from huggingface_hub import snapshot_download
72
  from PIL import Image
73
  from accelerate.utils import set_seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  import trimesh
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  from src.utils.data_utils import get_colored_mesh_composition, scene_to_parts, load_surfaces
77
  from src.utils.render_utils import render_views_around_mesh, render_normal_views_around_mesh, make_grid_for_images_or_videos, export_renderings, explode_mesh
@@ -79,13 +173,12 @@ from src.pipelines.pipeline_partcrafter import PartCrafterPipeline
79
  from src.utils.image_utils import prepare_image
80
  from src.models.briarmbg import BriaRMBG
81
 
82
- # ========================================================
83
- # 4. 业务逻辑 (100% 保留你代码中的参数与函数)
84
- # ========================================================
85
  MAX_NUM_PARTS = 16
86
  DEVICE = "cuda"
87
  DTYPE = torch.float16
88
 
 
89
  partcrafter_weights_dir = "pretrained_weights/PartCrafter"
90
  rmbg_weights_dir = "pretrained_weights/RMBG-1.4"
91
  snapshot_download(repo_id="wgsxm/PartCrafter", local_dir=partcrafter_weights_dir)
@@ -99,34 +192,119 @@ def first_file_from_dir(directory, ext):
99
  files = glob.glob(os.path.join(directory, f"*.{ext}"))
100
  return sorted(files)[0] if files else None
101
 
102
- def get_duration(image_path, num_parts, seed, num_tokens, num_steps, guidance, flash, rmbg, session, progress):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  duration_seconds = 75
104
- if num_parts > 10: duration_seconds = 120
105
- elif num_parts > 5: duration_seconds = 90
 
 
 
 
106
  return int(duration_seconds)
 
107
 
108
  @spaces.GPU(duration=140)
109
- def gen_model_n_video(image_path, num_parts, progress=gr.Progress(track_tqdm=True)):
 
 
 
110
  model_path = run_partcrafter(image_path, num_parts=num_parts, progress=progress)
111
  video_path = gen_video(model_path)
 
112
  return model_path, video_path
113
 
114
  @spaces.GPU()
115
  def gen_video(model_path):
 
116
  if model_path is None:
117
  gr.Info("You must craft the 3d parts first")
 
118
  return None
 
119
  export_dir = os.path.dirname(model_path)
 
120
  merged = trimesh.load(model_path)
 
121
  preview_path = os.path.join(export_dir, "rendering.gif")
122
- rendered_images = render_views_around_mesh(merged, num_views=36, radius=4)
123
- export_renderings(rendered_images, preview_path, fps=7)
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  return preview_path
125
 
126
  @spaces.GPU(duration=get_duration)
127
  @torch.no_grad()
128
- def run_partcrafter(image_path, num_parts=1, seed=0, num_tokens=1024, num_inference_steps=50, guidance_scale=7.0, use_flash_decoder=False, rmbg=True, session_id=None, progress=gr.Progress(track_tqdm=True)):
129
- if session_id is None: session_id = uuid.uuid4().hex
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  if rmbg:
131
  img_pil = prepare_image(image_path, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
132
  else:
@@ -141,47 +319,92 @@ def run_partcrafter(image_path, num_parts=1, seed=0, num_tokens=1024, num_infere
141
  generator=torch.Generator(device=pipe.device).manual_seed(seed),
142
  num_inference_steps=num_inference_steps,
143
  guidance_scale=guidance_scale,
144
- max_num_expanded_coords=1e9,
145
  use_flash_decoder=use_flash_decoder,
146
  ).meshes
147
- print(f"Generation time: {time.time() - start_time:.2f}s")
 
148
 
 
149
  for i, mesh in enumerate(outputs):
150
- if mesh is None: outputs[i] = trimesh.Trimesh(vertices=[[0,0,0]], faces=[[0,0,0]])
 
 
151
 
152
  export_dir = os.path.join(os.environ["PARTCRAFTER_PROCESSED"], session_id)
153
- if os.path.exists(export_dir): shutil.rmtree(export_dir)
 
 
 
 
154
  os.makedirs(export_dir, exist_ok=True)
155
 
 
 
156
  for idx, mesh in enumerate(outputs):
157
- mesh.export(os.path.join(export_dir, f"part_{idx:02}.glb"))
 
 
158
 
 
159
  merged = get_colored_mesh_composition(outputs)
 
 
160
  merged_path = os.path.join(export_dir, "object.glb")
161
  merged.export(merged_path)
 
162
  return merged_path
163
 
164
- # ========================================================
165
- # 5. UI 界面逻辑 (完全保留你原来的 CSS 和 Examples)
166
- # ========================================================
167
  def cleanup(request: gr.Request):
 
168
  sid = request.session_hash
169
  if sid:
170
- shutil.rmtree(os.path.join(os.environ["PARTCRAFTER_PROCESSED"], sid), ignore_errors=True)
 
 
 
171
 
 
 
172
  def build_demo():
173
- css = "#col-container { margin: 0 auto; max-width: 1560px; }"
174
- with gr.Blocks(css=css, theme=gr.themes.Ocean()) as demo:
 
 
 
 
 
 
 
175
  session_state = gr.State()
176
- demo.load(lambda r: r.session_hash, outputs=[session_state])
 
177
  with gr.Column(elem_id="col-container"):
178
- gr.HTML("<div style='text-align:center;'><strong>PartCrafter</strong> – Structured 3D Mesh Generation</div>")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  with gr.Row():
180
  with gr.Column(scale=1):
 
181
  input_image = gr.Image(type="filepath", label="Input Image", height=256)
182
  num_parts = gr.Slider(1, MAX_NUM_PARTS, value=4, step=1, label="Number of Parts")
183
  run_button = gr.Button("Step 1 - 🧩 Craft 3D Parts", variant="primary")
184
  video_button = gr.Button("Step 2 - 🎥 Generate Split Preview Gif (Optional)")
 
185
  with gr.Accordion("Advanced Settings", open=False):
186
  seed = gr.Number(value=0, label="Random Seed", precision=0)
187
  num_tokens = gr.Slider(256, 2048, value=1024, step=64, label="Num Tokens")
@@ -189,19 +412,59 @@ def build_demo():
189
  guidance = gr.Slider(1.0, 20.0, value=7.0, step=0.1, label="Guidance Scale")
190
  flash_decoder = gr.Checkbox(value=False, label="Use Flash Decoder")
191
  remove_bg = gr.Checkbox(value=True, label="Remove Background (RMBG)")
 
192
  with gr.Column(scale=2):
193
- output_model = gr.Model3D(label="Merged 3D Object", height=512)
194
- video_output = gr.Image(label="Split Preview", height=512)
195
- gr.Examples(
196
- examples=[["assets/images/np5_b81f29e567ea4db48014f89c9079e403.png", 5]],
197
- inputs=[input_image, num_parts],
198
- outputs=[output_model, video_output],
199
- fn=gen_model_n_video,
200
- cache_examples=True
201
- )
202
- run_button.click(fn=run_partcrafter, inputs=[input_image, num_parts, seed, num_tokens, num_steps, guidance, flash_decoder, remove_bg, session_state], outputs=[output_model])
203
- video_button.click(fn=gen_video, inputs=[output_model], outputs=[video_output])
204
- return demo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
  if __name__ == "__main__":
207
- build_demo().launch()
 
 
 
 
1
  import os
2
  import sys
3
+ import subprocess # <--- 确保这行在这里!
4
  import importlib
5
  import site
6
  import time
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
+ # --- 🧪 1. 内存级伪造 diso (必须在任何业务 import 之前) ---
9
  def mock_diso():
10
+ from types import ModuleType
11
  print("🧪 Creating emergency mock for diso...")
12
  diso = ModuleType("diso")
13
  class FakeDiffDMC:
 
17
  sys.modules["diso"] = diso
18
  sys.modules["diso._C"] = ModuleType("diso._C")
19
  sys.modules["diso.diso_native"] = ModuleType("diso.diso_native")
20
+ print("✅ diso has been mocked successfully!")
21
 
22
  mock_diso()
23
 
24
+ # --- 🚀 2. 极速环境安装 (已经成功的 scatter/sparse) ---
25
+ def install_essential_packages():
26
+ print("📦 Checking core dependencies...")
27
+ # 确保基础环境正确
28
+ subprocess.run([sys.executable, "-m", "pip", "install", "ninja", "setuptools", "wheel", "-q"])
29
+
30
+ # 极速安装 PyG 扩展
 
 
 
 
 
 
 
 
 
31
  subprocess.run([
32
  sys.executable, "-m", "pip", "install",
33
+ "torch-scatter", "torch-sparse", "torch-cluster",
34
+ "-f", "https://data.pyg.org/whl/torch-2.4.0+cu121.html",
35
+ "--no-cache-dir"
36
+ ])
37
+
38
+ # 安装剩下的渲染工具
39
+ subprocess.run([
40
+ sys.executable, "-m", "pip", "install",
41
+ "pyrender", "pyopengl==3.1.0", "pyyaml", "trimesh", "accelerate", "-q"
42
  ])
43
 
44
  importlib.invalidate_caches()
45
  site.main()
46
+ print("🎉 Environment Installation Phase Finished.")
47
+
48
+ install_essential_packages()
49
+
50
+
51
+ # ... 之前的 mock_diso 和安装逻辑 ...
52
+
53
+
54
+ # 1. 核心路径保护
55
+ os.environ["PARTCRAFTER_PROCESSED"] = os.environ.get("PARTCRAFTER_PROCESSED", "outputs")
56
+ os.makedirs(os.environ["PARTCRAFTER_PROCESSED"], exist_ok=True)
57
 
58
+ # 2. 模型权重下载路径确认 (确保这些目录也存)
59
+ os.makedirs("pretrained_weights/PartCrafter", exist_ok=True)
60
+ os.makedirs("pretrained_weights/RMBG-1.4", exist_ok=True)
61
 
62
+ # ... 继续执行 snapshot_download ...
63
+
64
+ # --- 3. 正式导入业务逻辑 (现在开始这几百行代码就不会报错了) ---
65
  import spaces
66
  import gradio as gr
67
  import numpy as np
68
  import torch
69
+ import uuid
70
+ import shutil
71
  from huggingface_hub import snapshot_download
72
  from PIL import Image
73
  from accelerate.utils import set_seed
74
+
75
+ # 从这里往下,粘贴你原本所有的业务逻辑代码 (PartCrafterPipeline 等)
76
+ # ...
77
+
78
+ # --- 🚀 核心修复:强制版本回退以避开编译 ---
79
+ def pre_install_check():
80
+ try:
81
+ import torch
82
+ # 如果是 2.9+ 版本,强制降级到有预编译包的 2.4.0
83
+ if "2.9" in torch.__version__:
84
+ print(f"🔄 Current torch {torch.__version__} is too new. Downgrading to 2.4.0 for speed...")
85
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "ninja", "setuptools", "wheel", "-q"])
86
+ subprocess.check_call([
87
+ sys.executable, "-m", "pip", "install",
88
+ "torch==2.4.0+cu121", "torchvision==0.19.0+cu121",
89
+ "--extra-index-url", "https://download.pytorch.org/whl/cu121"
90
+ ])
91
+ # 刷新路径
92
+ importlib.invalidate_caches()
93
+ os.execv(sys.executable, ['python'] + sys.argv) # 重启进程以加载新版本
94
+ except Exception as e:
95
+ print(f"Pre-install check note: {e}")
96
+
97
+ pre_install_check()
98
+
99
  import trimesh
100
+ import glob
101
+ import importlib, site
102
+
103
+ # Re-discover all .pth/.egg-link files
104
+ for sitedir in site.getsitepackages():
105
+ site.addsitedir(sitedir)
106
+
107
+ importlib.invalidate_caches()
108
+
109
+ # --- 简化的 CUDA 环境配置 ---
110
+ def setup_cuda_env():
111
+ cuda_path = "/usr/local/cuda"
112
+ if os.path.exists(cuda_path):
113
+ os.environ["CUDA_HOME"] = cuda_path
114
+ os.environ["PATH"] = f"{cuda_path}/bin:{os.environ['PATH']}"
115
+ os.environ["LD_LIBRARY_PATH"] = f"{cuda_path}/lib64:{os.environ.get('LD_LIBRARY_PATH', '')}"
116
+ print(f"==> Using system CUDA at {cuda_path}")
117
+
118
+ setup_cuda_env()
119
+
120
+ # --- 🚀 针对 PyTorch 2.9.1 的优化源码编译方案 ---
121
+ # --- 🚀 暴力整合版:攻克 diso 最后的防线 ---
122
+ def install_heavy_packages():
123
+ os.environ['PYOPENGL_PLATFORM'] = 'egl'
124
+
125
+ # 1. PyG 扩展(这部分已经稳了,保持不动)
126
+ print("📦 Installing PyG extensions...")
127
+ subprocess.run([
128
+ sys.executable, "-m", "pip", "install",
129
+ "torch-scatter", "torch-sparse", "torch-cluster",
130
+ "-f", "https://data.pyg.org/whl/torch-2.4.0+cu121.html"
131
+ ], check=True)
132
+
133
+ # 2. 暴力解决 diso:克隆源码 -> 强行导入
134
+ print("🔥 Attempting D-Plan: Manual diso injection...")
135
+ diso_path = os.path.join(os.getcwd(), "diso_source")
136
+ if not os.path.exists(diso_path):
137
+ subprocess.run(["git", "clone", "https://github.com/SarahWeiii/diso.git", diso_path])
138
+
139
+ # 将 diso 的源码路径直接加入系统搜索路径
140
+ # 这样即使没有编译成功 .so 文件,Python 也能找到包结构
141
+ if diso_path not in sys.path:
142
+ sys.path.insert(0, diso_path)
143
+
144
+ # 3. 安装渲染和其他轻量级依赖
145
+ print("📦 Installing rendering tools...")
146
+ subprocess.run([sys.executable, "-m", "pip", "install", "pyrender", "pyopengl==3.1.0", "pyyaml", "-q"], check=True)
147
+
148
+ importlib.invalidate_caches()
149
+ print("🎉 Environment Installation Phase Finished.")
150
+
151
+ # 执行安装
152
+ install_heavy_packages()
153
+
154
+ # --- 🛰️ 关键:diso 导入补丁 ---
155
+ try:
156
+ import diso
157
+ print("✅ diso imported successfully!")
158
+ except ImportError:
159
+ # 如果还是报错,尝试将 diso 内部的包直接暴露出来
160
+ print("⚠️ diso import failed, applying emergency mock...")
161
+ diso_src_path = os.path.join(os.getcwd(), "diso_source")
162
+ sys.path.insert(0, diso_src_path)
163
+ # 强制让 Python 识别 diso 目录
164
+ importlib.invalidate_caches()
165
+
166
+ # ... 后续代码保持不变 ...
167
+
168
+
169
 
170
  from src.utils.data_utils import get_colored_mesh_composition, scene_to_parts, load_surfaces
171
  from src.utils.render_utils import render_views_around_mesh, render_normal_views_around_mesh, make_grid_for_images_or_videos, export_renderings, explode_mesh
 
173
  from src.utils.image_utils import prepare_image
174
  from src.models.briarmbg import BriaRMBG
175
 
176
+ # Constants
 
 
177
  MAX_NUM_PARTS = 16
178
  DEVICE = "cuda"
179
  DTYPE = torch.float16
180
 
181
+ # Download and initialize models
182
  partcrafter_weights_dir = "pretrained_weights/PartCrafter"
183
  rmbg_weights_dir = "pretrained_weights/RMBG-1.4"
184
  snapshot_download(repo_id="wgsxm/PartCrafter", local_dir=partcrafter_weights_dir)
 
192
  files = glob.glob(os.path.join(directory, f"*.{ext}"))
193
  return sorted(files)[0] if files else None
194
 
195
+
196
+
197
+ def get_duration(
198
+ image_path,
199
+ num_parts,
200
+ seed,
201
+ num_tokens,
202
+ num_inference_steps,
203
+ guidance_scale,
204
+ use_flash_decoder,
205
+ rmbg,
206
+ session_id,
207
+ progress,
208
+ ):
209
+
210
  duration_seconds = 75
211
+
212
+ if num_parts > 10:
213
+ duration_seconds = 120
214
+ elif num_parts > 5:
215
+ duration_seconds = 90
216
+
217
  return int(duration_seconds)
218
+
219
 
220
  @spaces.GPU(duration=140)
221
+ def gen_model_n_video(image_path: str,
222
+ num_parts: int,
223
+ progress=gr.Progress(track_tqdm=True),):
224
+
225
  model_path = run_partcrafter(image_path, num_parts=num_parts, progress=progress)
226
  video_path = gen_video(model_path)
227
+
228
  return model_path, video_path
229
 
230
  @spaces.GPU()
231
  def gen_video(model_path):
232
+
233
  if model_path is None:
234
  gr.Info("You must craft the 3d parts first")
235
+
236
  return None
237
+
238
  export_dir = os.path.dirname(model_path)
239
+
240
  merged = trimesh.load(model_path)
241
+
242
  preview_path = os.path.join(export_dir, "rendering.gif")
243
+
244
+ num_views = 36
245
+ radius = 4
246
+ fps = 7
247
+ rendered_images = render_views_around_mesh(
248
+ merged,
249
+ num_views=num_views,
250
+ radius=radius,
251
+ )
252
+
253
+ export_renderings(
254
+ rendered_images,
255
+ preview_path,
256
+ fps=fps,
257
+ )
258
  return preview_path
259
 
260
  @spaces.GPU(duration=get_duration)
261
  @torch.no_grad()
262
+ def run_partcrafter(image_path: str,
263
+ num_parts: int = 1,
264
+ seed: int = 0,
265
+ num_tokens: int = 1024,
266
+ num_inference_steps: int = 50,
267
+ guidance_scale: float = 7.0,
268
+ use_flash_decoder: bool = False,
269
+ rmbg: bool = True,
270
+ session_id = None,
271
+ progress=gr.Progress(track_tqdm=True),):
272
+
273
+ """
274
+ Generate structured 3D meshes from a 2D image using the PartCrafter pipeline.
275
+
276
+ This function takes a single 2D image as input and produces a set of part-based 3D meshes,
277
+ using compositional latent diffusion with attention to structure and part separation.
278
+ Optionally removes the background using a pretrained background removal model (RMBG),
279
+ and outputs a merged object mesh.
280
+
281
+ Args:
282
+ image_path (str): Path to the input image file on disk.
283
+ num_parts (int, optional): Number of distinct parts to decompose the object into. Defaults to 1.
284
+ seed (int, optional): Random seed for reproducibility. Defaults to 0.
285
+ num_tokens (int, optional): Number of tokens used during latent encoding. Higher values yield finer detail. Defaults to 1024.
286
+ num_inference_steps (int, optional): Number of diffusion inference steps. More steps improve quality but increase runtime. Defaults to 50.
287
+ guidance_scale (float, optional): Classifier-free guidance scale. Higher values emphasize adherence to conditioning. Defaults to 7.0.
288
+ use_flash_decoder (bool, optional): Whether to use FlashAttention in the decoder for performance. Defaults to False.
289
+ rmbg (bool, optional): Whether to apply background removal before processing. Defaults to True.
290
+ session_id (str, optional): Optional session ID to manage export paths. If not provided, a random UUID is generated.
291
+ progress (gr.Progress, optional): Gradio progress object for visual feedback. Automatically handled by Gradio.
292
+
293
+ Returns:
294
+ Tuple[str, str, str, str]:
295
+ - `merged_path` (str): File path to the merged full object mesh (`object.glb`).
296
+
297
+ Notes:
298
+ - This function utilizes HuggingFace pretrained weights for both part generation and background removal.
299
+ - The final output includes merged model parts to visualize object structure.
300
+ - Generation time depends on the number of parts and inference parameters.
301
+ """
302
+
303
+ max_num_expanded_coords = 1e9
304
+
305
+ if session_id is None:
306
+ session_id = uuid.uuid4().hex
307
+
308
  if rmbg:
309
  img_pil = prepare_image(image_path, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
310
  else:
 
319
  generator=torch.Generator(device=pipe.device).manual_seed(seed),
320
  num_inference_steps=num_inference_steps,
321
  guidance_scale=guidance_scale,
322
+ max_num_expanded_coords=max_num_expanded_coords,
323
  use_flash_decoder=use_flash_decoder,
324
  ).meshes
325
+ duration = time.time() - start_time
326
+ print(f"Generation time: {duration:.2f}s")
327
 
328
+ # Ensure no None outputs
329
  for i, mesh in enumerate(outputs):
330
+ if mesh is None:
331
+ outputs[i] = trimesh.Trimesh(vertices=[[0,0,0]], faces=[[0,0,0]])
332
+
333
 
334
  export_dir = os.path.join(os.environ["PARTCRAFTER_PROCESSED"], session_id)
335
+
336
+ # If it already exists, delete it (and all its contents)
337
+ if os.path.exists(export_dir):
338
+ shutil.rmtree(export_dir)
339
+
340
  os.makedirs(export_dir, exist_ok=True)
341
 
342
+ parts = []
343
+
344
  for idx, mesh in enumerate(outputs):
345
+ part = os.path.join(export_dir, f"part_{idx:02}.glb")
346
+ mesh.export(part)
347
+ parts.append(part)
348
 
349
+ # Merge and color
350
  merged = get_colored_mesh_composition(outputs)
351
+ split_mesh = explode_mesh(merged)
352
+
353
  merged_path = os.path.join(export_dir, "object.glb")
354
  merged.export(merged_path)
355
+
356
  return merged_path
357
 
 
 
 
358
  def cleanup(request: gr.Request):
359
+
360
  sid = request.session_hash
361
  if sid:
362
+ d1 = os.path.join(os.environ["PARTCRAFTER_PROCESSED"], sid)
363
+ shutil.rmtree(d1, ignore_errors=True)
364
+
365
+ def start_session(request: gr.Request):
366
 
367
+ return request.session_hash
368
+
369
  def build_demo():
370
+ css = """
371
+ #col-container {
372
+ margin: 0 auto;
373
+ max-width: 1560px;
374
+ }
375
+ """
376
+ theme = gr.themes.Ocean()
377
+
378
+ with gr.Blocks(css=css, theme=theme) as demo:
379
  session_state = gr.State()
380
+ demo.load(start_session, outputs=[session_state])
381
+
382
  with gr.Column(elem_id="col-container"):
383
+ gr.HTML(
384
+ """
385
+ <div style="text-align: center;">
386
+ <p style="font-size:16px; display: inline; margin: 0;">
387
+ <strong>PartCrafter</strong> – Structured 3D Mesh Generation via Compositional Latent Diffusion Transformers
388
+ </p>
389
+ <a href="https://github.com/wgsxm/PartCrafter" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
390
+ <img src="https://img.shields.io/badge/GitHub-Repo-blue" alt="GitHub Repo">
391
+ </a>
392
+ </div>
393
+ <div style="text-align: center;">
394
+ HF Space by :<a href="https://twitter.com/alexandernasa/" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
395
+ <img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Follow Me" alt="GitHub Repo">
396
+ </a>
397
+ </div>
398
+ """
399
+ )
400
  with gr.Row():
401
  with gr.Column(scale=1):
402
+
403
  input_image = gr.Image(type="filepath", label="Input Image", height=256)
404
  num_parts = gr.Slider(1, MAX_NUM_PARTS, value=4, step=1, label="Number of Parts")
405
  run_button = gr.Button("Step 1 - 🧩 Craft 3D Parts", variant="primary")
406
  video_button = gr.Button("Step 2 - 🎥 Generate Split Preview Gif (Optional)")
407
+
408
  with gr.Accordion("Advanced Settings", open=False):
409
  seed = gr.Number(value=0, label="Random Seed", precision=0)
410
  num_tokens = gr.Slider(256, 2048, value=1024, step=64, label="Num Tokens")
 
412
  guidance = gr.Slider(1.0, 20.0, value=7.0, step=0.1, label="Guidance Scale")
413
  flash_decoder = gr.Checkbox(value=False, label="Use Flash Decoder")
414
  remove_bg = gr.Checkbox(value=True, label="Remove Background (RMBG)")
415
+
416
  with gr.Column(scale=2):
417
+ gr.HTML(
418
+ """
419
+ <p style="opacity: 0.6; font-style: italic;">
420
+ The 3D Preview might take a few seconds to load the 3D model
421
+ </p>
422
+ """
423
+ )
424
+ with gr.Row():
425
+ output_model = gr.Model3D(label="Merged 3D Object", height=512, interactive=False)
426
+ video_output = gr.Image(label="Split Preview", height=512)
427
+ with gr.Row():
428
+ with gr.Column():
429
+ examples = gr.Examples(
430
+
431
+ examples=[
432
+ [
433
+ "assets/images/np5_b81f29e567ea4db48014f89c9079e403.png",
434
+ 5,
435
+ ],
436
+ [
437
+ "assets/images/np7_1c004909dedb4ebe8db69b4d7b077434.png",
438
+ 7,
439
+ ],
440
+ [
441
+ "assets/images/np16_dino.png",
442
+ 16,
443
+ ],
444
+ [
445
+ "assets/images/np13_39c0fa16ed324b54a605dcdbcd80797c.png",
446
+ 13,
447
+ ],
448
+
449
+ ],
450
+ inputs=[input_image, num_parts],
451
+ outputs=[output_model, video_output],
452
+ fn=gen_model_n_video,
453
+ cache_examples=True
454
+ )
455
+
456
+ run_button.click(fn=run_partcrafter,
457
+ inputs=[input_image, num_parts, seed, num_tokens, num_steps,
458
+ guidance, flash_decoder, remove_bg, session_state],
459
+ outputs=[output_model])
460
+ video_button.click(fn=gen_video,
461
+ inputs=[output_model],
462
+ outputs=[video_output])
463
+
464
+ return demo
465
 
466
  if __name__ == "__main__":
467
+ demo = build_demo()
468
+ demo.unload(cleanup)
469
+ demo.queue()
470
+ demo.launch(mcp_server=True, ssr_mode=False)