opsiclear-admin committed on
Commit
e3cd099
·
verified ·
1 Parent(s): 9c5ed9a

Fix GLB export: add DownloadButton and return tuple

Browse files
Files changed (1) hide show
  1. app.py +53 -97
app.py CHANGED
@@ -1,18 +1,17 @@
1
- import warnings
2
- warnings.filterwarnings("ignore", message=".*torch.distributed.reduce_op.*")
3
- warnings.filterwarnings("ignore", message=".*torch.cuda.amp.autocast.*")
4
- warnings.filterwarnings("ignore", message=".*Default grid_sample and affine_grid behavior.*")
5
-
6
  import gradio as gr
7
  from gradio_client import Client, handle_file
8
  import spaces
9
  from concurrent.futures import ThreadPoolExecutor
10
 
11
  import os
 
 
 
 
12
  os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
13
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
14
  os.environ["ATTN_BACKEND"] = "flash_attn_3"
15
- os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json')
16
  os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'
17
  from datetime import datetime
18
  import shutil
@@ -32,8 +31,7 @@ import o_voxel
32
 
33
  # Patch postprocess module with local fix for cumesh.fill_holes() bug
34
  import importlib.util
35
- import sys
36
- _local_postprocess = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'o-voxel', 'o_voxel', 'postprocess.py')
37
  if os.path.exists(_local_postprocess):
38
  _spec = importlib.util.spec_from_file_location('o_voxel.postprocess', _local_postprocess)
39
  _mod = importlib.util.module_from_spec(_spec)
@@ -328,8 +326,7 @@ def start_session(req: gr.Request):
328
 
329
  def end_session(req: gr.Request):
330
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
331
- if os.path.exists(user_dir):
332
- shutil.rmtree(user_dir)
333
 
334
 
335
  def remove_background(input: Image.Image) -> Image.Image:
@@ -375,10 +372,9 @@ def preprocess_image(input: Image.Image) -> Image.Image:
375
  size = int(size * 1)
376
  bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
377
  output = output.crop(bbox) # type: ignore
378
- output_np = np.array(output)
379
- alpha = output_np[:, :, 3]
380
- output_np[:, :, :3][alpha < 0.5 * 255] = [0, 0, 0]
381
- output = Image.fromarray(output_np[:, :, :3])
382
  return output
383
 
384
 
@@ -435,40 +431,34 @@ def prepare_multi_example() -> List[str]:
435
 
436
  def load_multi_example(image) -> List[Image.Image]:
437
  """Load all views for a multi-image case by matching the input image."""
438
- if image is None:
439
- return []
440
 
441
- # Convert to PIL Image if needed
442
  if isinstance(image, np.ndarray):
443
  image = Image.fromarray(image)
444
 
445
- # Convert to RGB for consistent comparison
446
- input_rgb = np.array(image.convert('RGB'))
447
 
448
  # Find matching case by comparing with first images
449
- example_dir = "assets/example_multi_image"
450
- case_names = sorted(set([f.rsplit('_', 1)[0] for f in os.listdir(example_dir) if f.endswith('.png')]))
451
-
452
- for case_name in case_names:
453
- first_img_path = f'{example_dir}/{case_name}_1.png'
454
  if os.path.exists(first_img_path):
455
- first_img = Image.open(first_img_path).convert('RGB')
456
- first_rgb = np.array(first_img)
457
-
458
- # Compare images (check if same shape and content)
459
- if input_rgb.shape == first_rgb.shape and np.array_equal(input_rgb, first_rgb):
460
- # Found match, load all views (without preprocessing - will be done on Generate)
461
  images = []
462
  for i in range(1, 7):
463
- img_path = f'{example_dir}/{case_name}_{i}.png'
464
  if os.path.exists(img_path):
465
- img = Image.open(img_path).convert('RGBA')
466
- images.append(img)
467
- if images:
468
- return images
469
 
470
- # No match found, return the single image
471
- return [image.convert('RGBA') if image.mode != 'RGBA' else image]
472
 
473
 
474
  def split_image(image: Image.Image) -> List[Image.Image]:
@@ -486,7 +476,7 @@ def split_image(image: Image.Image) -> List[Image.Image]:
486
  return [preprocess_image(image) for image in images]
487
 
488
 
489
- @spaces.GPU(duration=120)
490
  def image_to_3d(
491
  seed: int,
492
  resolution: str,
@@ -503,29 +493,14 @@ def image_to_3d(
503
  tex_slat_sampling_steps: int,
504
  tex_slat_rescale_t: float,
505
  multiimages: List[Tuple[Image.Image, str]],
506
- multiimage_algo: Literal["multidiffusion", "stochastic"],
507
- tex_multiimage_algo: Literal["multidiffusion", "stochastic"],
508
  req: gr.Request,
509
  progress=gr.Progress(track_tqdm=True),
510
  ) -> str:
511
- if not multiimages:
512
- raise gr.Error("Please upload images or select an example first.")
513
-
514
- # Preprocess images (background removal, cropping, etc.)
515
- images = [image[0] for image in multiimages]
516
- processed_images = [preprocess_image(img) for img in images]
517
-
518
- # Debug: save preprocessed images and log stats
519
- for i, img in enumerate(processed_images):
520
- arr = np.array(img)
521
- print(f"[DEBUG] Preprocessed image {i}: mode={img.mode}, size={img.size}, "
522
- f"dtype={arr.dtype}, min={arr.min()}, max={arr.max()}, mean={arr.mean():.1f}")
523
- img.save(os.path.join(TMP_DIR, f'debug_preprocessed_{i}.png'))
524
- print(f"[DEBUG] Pipeline params: mode={multiimage_algo}, tex_mode={tex_multiimage_algo}")
525
-
526
  # --- Sampling ---
527
  outputs, latents = pipeline.run_multi_image(
528
- processed_images,
529
  seed=seed,
530
  preprocess_image=False,
531
  sparse_structure_sampler_params={
@@ -558,16 +533,8 @@ def image_to_3d(
558
  mesh = outputs[0]
559
  mesh.simplify(16777216) # nvdiffrast limit
560
  images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
561
-
562
- # Debug: save base_color render and log stats for all render modes
563
- for key in images:
564
- arr = images[key][0] # first view
565
- print(f"[DEBUG] Render '{key}': shape={arr.shape}, min={arr.min()}, max={arr.max()}, mean={arr.mean():.1f}")
566
- # Save base_color and shaded_forest for inspection
567
- Image.fromarray(images['base_color'][0]).save(os.path.join(TMP_DIR, 'debug_base_color.png'))
568
- Image.fromarray(images['shaded_forest'][0]).save(os.path.join(TMP_DIR, 'debug_shaded_forest.png'))
569
-
570
  state = pack_state(latents)
 
571
  torch.cuda.empty_cache()
572
 
573
  # --- HTML Construction ---
@@ -653,7 +620,7 @@ def extract_glb(
653
  texture_size: int,
654
  req: gr.Request,
655
  progress=gr.Progress(track_tqdm=True),
656
- ) -> str:
657
  """
658
  Extract a GLB file from the 3D model.
659
 
@@ -663,7 +630,7 @@ def extract_glb(
663
  texture_size (int): The texture resolution.
664
 
665
  Returns:
666
- str: The path to the extracted GLB file.
667
  """
668
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
669
  shape_slat, tex_slat, res = unpack_state(state)
@@ -690,25 +657,14 @@ def extract_glb(
690
  glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
691
  glb.export(glb_path, extension_webp=True)
692
  torch.cuda.empty_cache()
693
- return glb_path
694
-
695
-
696
- with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange", neutral_hue="slate")) as demo:
697
- gr.HTML("""
698
- <div style="display: flex; align-items: center; gap: 20px; margin-bottom: 10px;">
699
- <a href="https://www.opsiclear.com" target="_blank">
700
- <img src="https://www.opsiclear.com/assets/logos/Logo_v2_compact_name.svg" alt="OpsiClear" style="height: 80px;">
701
- </a>
702
- <div>
703
- <h2 style="margin: 0;">Multi-View to 3D with <a href="https://microsoft.github.io/TRELLIS.2" target="_blank">TRELLIS.2</a></h2>
704
- <ul style="margin: 5px 0; padding-left: 20px;">
705
- <li>Upload multiple images from different viewpoints to create a 3D asset with multi-image conditioning.</li>
706
- <li>Click an example below to load a pre-made multi-view set, or upload your own images.</li>
707
- <li>Click <b>Generate</b> to create the 3D model, then <b>Extract GLB</b> to export.</li>
708
- <li style="color: #e67300;"><b>⚠️ Note:</b> Generation quality is highly sensitive to parameters. Adjust settings in Advanced Settings if results are unsatisfactory.</li>
709
- </ul>
710
- </div>
711
- </div>
712
  """)
713
 
714
  with gr.Row():
@@ -721,6 +677,10 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange", neutral_hue="slate"))
721
  decimation_target = gr.Slider(100000, 500000, label="Decimation Target", value=300000, step=10000)
722
  texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)
723
 
 
 
 
 
724
  with gr.Accordion(label="Advanced Settings", open=False):
725
  gr.Markdown("Stage 1: Sparse Structure Generation")
726
  with gr.Row():
@@ -741,15 +701,12 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange", neutral_hue="slate"))
741
  tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
742
  tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
743
  multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Structure Algorithm", value="stochastic")
744
- tex_multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Texture Algorithm", value="stochastic")
745
 
746
  with gr.Column(scale=10):
747
  preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True)
748
- glb_output = gr.Model3D(label="Extracted GLB", height=400, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0), visible=False)
749
-
750
- with gr.Row():
751
- generate_btn = gr.Button("Generate", variant="primary")
752
- extract_btn = gr.Button("Extract GLB")
753
 
754
  example_image = gr.Image(visible=False) # Hidden component for examples
755
  examples_multi = gr.Examples(
@@ -758,8 +715,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange", neutral_hue="slate"))
758
  fn=load_multi_example,
759
  outputs=[multiimage_prompt],
760
  run_on_click=True,
761
- cache_examples=False,
762
- examples_per_page=50,
763
  )
764
 
765
  output_buf = gr.State()
@@ -793,7 +749,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange", neutral_hue="slate"))
793
  extract_btn.click(
794
  extract_glb,
795
  inputs=[output_buf, decimation_target, texture_size],
796
- outputs=[glb_output],
797
  )
798
 
799
 
@@ -810,7 +766,7 @@ if __name__ == "__main__":
810
  rmbg_client = Client("briaai/BRIA-RMBG-2.0")
811
  pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
812
  pipeline.rembg_model = None
813
- pipeline.low_vram = False
814
  pipeline.cuda()
815
 
816
  envmap = {
@@ -828,4 +784,4 @@ if __name__ == "__main__":
828
  )),
829
  }
830
 
831
- demo.launch(css=css, head=head)
 
 
 
 
 
 
1
  import gradio as gr
2
  from gradio_client import Client, handle_file
3
  import spaces
4
  from concurrent.futures import ThreadPoolExecutor
5
 
6
  import os
7
+ import sys
8
+
9
+ _script_dir = os.path.dirname(os.path.abspath(__file__))
10
+
11
  os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
12
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
13
  os.environ["ATTN_BACKEND"] = "flash_attn_3"
14
+ os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(_script_dir, 'autotune_cache.json')
15
  os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'
16
  from datetime import datetime
17
  import shutil
 
31
 
32
  # Patch postprocess module with local fix for cumesh.fill_holes() bug
33
  import importlib.util
34
+ _local_postprocess = os.path.join(_script_dir, 'o-voxel', 'o_voxel', 'postprocess.py')
 
35
  if os.path.exists(_local_postprocess):
36
  _spec = importlib.util.spec_from_file_location('o_voxel.postprocess', _local_postprocess)
37
  _mod = importlib.util.module_from_spec(_spec)
 
326
 
327
  def end_session(req: gr.Request):
328
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
329
+ shutil.rmtree(user_dir)
 
330
 
331
 
332
  def remove_background(input: Image.Image) -> Image.Image:
 
372
  size = int(size * 1)
373
  bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
374
  output = output.crop(bbox) # type: ignore
375
+ output = np.array(output).astype(np.float32) / 255
376
+ output = output[:, :, :3] * output[:, :, 3:4]
377
+ output = Image.fromarray((output * 255).astype(np.uint8))
 
378
  return output
379
 
380
 
 
431
 
432
  def load_multi_example(image) -> List[Image.Image]:
433
  """Load all views for a multi-image case by matching the input image."""
434
+ import hashlib
 
435
 
436
+ # Convert numpy array to PIL Image if needed
437
  if isinstance(image, np.ndarray):
438
  image = Image.fromarray(image)
439
 
440
+ # Get hash of input image for matching
441
+ input_hash = hashlib.md5(np.array(image.convert('RGBA')).tobytes()).hexdigest()
442
 
443
  # Find matching case by comparing with first images
444
+ multi_case = sorted(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
445
+ for case_name in multi_case:
446
+ first_img_path = f'assets/example_multi_image/{case_name}_1.png'
 
 
447
  if os.path.exists(first_img_path):
448
+ first_img = Image.open(first_img_path).convert('RGBA')
449
+ first_hash = hashlib.md5(np.array(first_img).tobytes()).hexdigest()
450
+ if first_hash == input_hash:
451
+ # Found match, load all views
 
 
452
  images = []
453
  for i in range(1, 7):
454
+ img_path = f'assets/example_multi_image/{case_name}_{i}.png'
455
  if os.path.exists(img_path):
456
+ img = Image.open(img_path)
457
+ images.append(preprocess_image(img))
458
+ return images
 
459
 
460
+ # No match found, return the single image preprocessed
461
+ return [preprocess_image(image)]
462
 
463
 
464
  def split_image(image: Image.Image) -> List[Image.Image]:
 
476
  return [preprocess_image(image) for image in images]
477
 
478
 
479
+ @spaces.GPU(duration=90)
480
  def image_to_3d(
481
  seed: int,
482
  resolution: str,
 
493
  tex_slat_sampling_steps: int,
494
  tex_slat_rescale_t: float,
495
  multiimages: List[Tuple[Image.Image, str]],
496
+ multiimage_algo: Literal["stochastic", "multidiffusion"],
497
+ tex_multiimage_algo: Literal["stochastic", "multidiffusion"],
498
  req: gr.Request,
499
  progress=gr.Progress(track_tqdm=True),
500
  ) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
501
  # --- Sampling ---
502
  outputs, latents = pipeline.run_multi_image(
503
+ [image[0] for image in multiimages],
504
  seed=seed,
505
  preprocess_image=False,
506
  sparse_structure_sampler_params={
 
533
  mesh = outputs[0]
534
  mesh.simplify(16777216) # nvdiffrast limit
535
  images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
 
 
 
 
 
 
 
 
 
536
  state = pack_state(latents)
537
+ del outputs, mesh, latents # Free memory
538
  torch.cuda.empty_cache()
539
 
540
  # --- HTML Construction ---
 
620
  texture_size: int,
621
  req: gr.Request,
622
  progress=gr.Progress(track_tqdm=True),
623
+ ) -> Tuple[str, str]:
624
  """
625
  Extract a GLB file from the 3D model.
626
 
 
630
  texture_size (int): The texture resolution.
631
 
632
  Returns:
633
+ Tuple[str, str]: The path to the extracted GLB file (for Model3D and DownloadButton).
634
  """
635
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
636
  shape_slat, tex_slat, res = unpack_state(state)
 
657
  glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
658
  glb.export(glb_path, extension_webp=True)
659
  torch.cuda.empty_cache()
660
+ return glb_path, glb_path
661
+
662
+
663
+ with gr.Blocks(delete_cache=(600, 600), theme=gr.themes.Soft(primary_hue="orange", neutral_hue="slate")) as demo:
664
+ gr.Markdown("""
665
+ ## Image to 3D Asset with [TRELLIS.2](https://microsoft.github.io/TRELLIS.2)
666
+ * Upload an image and click Generate to create a 3D asset. If the image has alpha channel, it will be used as the mask. Otherwise, background is automatically removed.
667
+ * Click Extract GLB to export the GLB file if you're satisfied with the preview.
 
 
 
 
 
 
 
 
 
 
 
668
  """)
669
 
670
  with gr.Row():
 
677
  decimation_target = gr.Slider(100000, 500000, label="Decimation Target", value=300000, step=10000)
678
  texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)
679
 
680
+ with gr.Row():
681
+ generate_btn = gr.Button("Generate", variant="primary")
682
+ extract_btn = gr.Button("Extract GLB")
683
+
684
  with gr.Accordion(label="Advanced Settings", open=False):
685
  gr.Markdown("Stage 1: Sparse Structure Generation")
686
  with gr.Row():
 
701
  tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
702
  tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
703
  multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Structure Algorithm", value="stochastic")
704
+ tex_multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Texture Algorithm", value="multidiffusion")
705
 
706
  with gr.Column(scale=10):
707
  preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True)
708
+ glb_output = gr.Model3D(label="Extracted GLB", height=400, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
709
+ download_btn = gr.DownloadButton(label="Download GLB")
 
 
 
710
 
711
  example_image = gr.Image(visible=False) # Hidden component for examples
712
  examples_multi = gr.Examples(
 
715
  fn=load_multi_example,
716
  outputs=[multiimage_prompt],
717
  run_on_click=True,
718
+ examples_per_page=24,
 
719
  )
720
 
721
  output_buf = gr.State()
 
749
  extract_btn.click(
750
  extract_glb,
751
  inputs=[output_buf, decimation_target, texture_size],
752
+ outputs=[glb_output, download_btn],
753
  )
754
 
755
 
 
766
  rmbg_client = Client("briaai/BRIA-RMBG-2.0")
767
  pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
768
  pipeline.rembg_model = None
769
+ pipeline.low_vram = True # Enable low VRAM mode for better memory efficiency
770
  pipeline.cuda()
771
 
772
  envmap = {
 
784
  )),
785
  }
786
 
787
+ demo.queue(max_size=10, default_concurrency_limit=1).launch(css=css, head=head)