hysts HF Staff commited on
Commit
62addc9
·
1 Parent(s): 9feb399
Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -29,6 +29,7 @@ from trellis.representations import Gaussian, MeshExtractResult
29
  from trellis.utils import postprocessing_utils, render_utils
30
 
31
  MAX_SEED = np.iinfo(np.int32).max
 
32
 
33
  pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
34
  pipeline.cuda()
@@ -177,8 +178,8 @@ def image_to_3d(
177
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
178
 
179
  with (
180
- tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as state_file,
181
- tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as video_file,
182
  ):
183
  save_state_to_file(outputs["gaussian"][0], outputs["mesh"][0], state_file.name)
184
  torch.cuda.empty_cache()
@@ -205,7 +206,7 @@ def extract_glb(
205
  gs, mesh = load_state_from_file(state_path)
206
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
207
  torch.cuda.empty_cache()
208
- with tempfile.NamedTemporaryFile(suffix=".glb", delete=False) as glb_file:
209
  glb.export(glb_file.name)
210
  return glb_file.name
211
 
@@ -221,7 +222,7 @@ def extract_gaussian(state_path: str) -> str:
221
  str: The path to the extracted Gaussian file.
222
  """
223
  gs, _ = load_state_from_file(state_path)
224
- with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as gaussian_file:
225
  gs.save_ply(gaussian_file.name)
226
  return gaussian_file.name
227
 
@@ -278,7 +279,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
278
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
279
  model_output = gr.Model3D(label="Extracted GLB/Gaussian", height=300)
280
 
281
- state_file_path = gr.Textbox(visible=False)
282
 
283
  examples = gr.Examples(
284
  examples=sorted(pathlib.Path("assets/example_image").glob("*.png")),
@@ -309,7 +310,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
309
  slat_guidance_strength,
310
  slat_sampling_steps,
311
  ],
312
- outputs=[state_file_path, video_output],
313
  ).then(
314
  fn=lambda: (gr.Button(interactive=True), gr.Button(interactive=True)),
315
  outputs=[extract_glb_btn, extract_gs_btn],
@@ -322,8 +323,8 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
322
  api_name=False,
323
  )
324
 
325
- extract_glb_btn.click(fn=extract_glb, inputs=[state_file_path, mesh_simplify, texture_size], outputs=model_output)
326
- extract_gs_btn.click(fn=extract_gaussian, inputs=state_file_path, outputs=model_output)
327
 
328
  if __name__ == "__main__":
329
  demo.launch(mcp_server=True)
 
29
  from trellis.utils import postprocessing_utils, render_utils
30
 
31
  MAX_SEED = np.iinfo(np.int32).max
32
+ TEMP_DIR = gr.utils.get_upload_folder()
33
 
34
  pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
35
  pipeline.cuda()
 
178
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
179
 
180
  with (
181
+ tempfile.NamedTemporaryFile(suffix=".pth", dir=TEMP_DIR, delete=False) as state_file,
182
+ tempfile.NamedTemporaryFile(suffix=".mp4", dir=TEMP_DIR, delete=False) as video_file,
183
  ):
184
  save_state_to_file(outputs["gaussian"][0], outputs["mesh"][0], state_file.name)
185
  torch.cuda.empty_cache()
 
206
  gs, mesh = load_state_from_file(state_path)
207
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
208
  torch.cuda.empty_cache()
209
+ with tempfile.NamedTemporaryFile(suffix=".glb", dir=TEMP_DIR, delete=False) as glb_file:
210
  glb.export(glb_file.name)
211
  return glb_file.name
212
 
 
222
  str: The path to the extracted Gaussian file.
223
  """
224
  gs, _ = load_state_from_file(state_path)
225
+ with tempfile.NamedTemporaryFile(suffix=".ply", dir=TEMP_DIR, delete=False) as gaussian_file:
226
  gs.save_ply(gaussian_file.name)
227
  return gaussian_file.name
228
 
 
279
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
280
  model_output = gr.Model3D(label="Extracted GLB/Gaussian", height=300)
281
 
282
+ state_file = gr.File(visible=False)
283
 
284
  examples = gr.Examples(
285
  examples=sorted(pathlib.Path("assets/example_image").glob("*.png")),
 
310
  slat_guidance_strength,
311
  slat_sampling_steps,
312
  ],
313
+ outputs=[state_file, video_output],
314
  ).then(
315
  fn=lambda: (gr.Button(interactive=True), gr.Button(interactive=True)),
316
  outputs=[extract_glb_btn, extract_gs_btn],
 
323
  api_name=False,
324
  )
325
 
326
+ extract_glb_btn.click(fn=extract_glb, inputs=[state_file, mesh_simplify, texture_size], outputs=model_output)
327
+ extract_gs_btn.click(fn=extract_gaussian, inputs=state_file, outputs=model_output)
328
 
329
  if __name__ == "__main__":
330
  demo.launch(mcp_server=True)