simonpick commited on
Commit
11eb381
Β·
verified Β·
1 Parent(s): be21a32

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -37
app.py CHANGED
@@ -565,16 +565,25 @@ def unpack_state(state: dict) -> Tuple[SparseTensor, SparseTensor, int]:
565
  return shape_slat, tex_slat, state['res']
566
 
567
 
568
- @spaces.GPU(duration=120)
569
- def image_to_3d(
570
  image: Image.Image,
571
  req: gr.Request,
572
  progress=gr.Progress(track_tqdm=True),
573
- ) -> str:
 
 
 
 
 
 
 
574
  # Hardcoded values
575
  seed = np.random.randint(0, MAX_SEED)
576
- resolution = "1024"
 
577
 
 
578
  outputs, latents = pipeline.run(
579
  image,
580
  seed=seed,
@@ -602,11 +611,11 @@ def image_to_3d(
602
  )
603
  mesh = outputs[0]
604
  mesh.simplify(16777216)
 
 
605
  images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
606
- state = pack_state(latents)
607
- torch.cuda.empty_cache()
608
 
609
- # Build HTML
610
  images_html = ""
611
  for m_idx, mode in enumerate(MODES):
612
  for s_idx in range(STEPS):
@@ -621,7 +630,7 @@ def image_to_3d(
621
  active_class = "active" if idx == DEFAULT_MODE else ""
622
  btns_html += f'<img src="{mode["icon_base64"]}" class="mode-btn {active_class}" onclick="selectMode({idx})" title="{mode["name"]}">'
623
 
624
- full_html = f"""
625
  <div class="previewer-container">
626
  <div class="display-row">{images_html}</div>
627
  <div class="mode-row">{btns_html}</div>
@@ -631,23 +640,11 @@ def image_to_3d(
631
  </div>
632
  """
633
 
634
- return state, full_html
635
-
636
-
637
- @spaces.GPU(duration=120)
638
- def extract_glb(
639
- state: dict,
640
- req: gr.Request,
641
- progress=gr.Progress(track_tqdm=True),
642
- ) -> Tuple[str, str]:
643
- # Hardcoded values
644
- decimation_target = 300000
645
- texture_size = 4096
646
-
647
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
648
- shape_slat, tex_slat, res = unpack_state(state)
649
  mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
650
  mesh.simplify(16777216)
 
651
  glb = o_voxel.postprocess.to_glb(
652
  vertices=mesh.vertices,
653
  faces=mesh.faces,
@@ -663,13 +660,16 @@ def extract_glb(
663
  remesh_project=0,
664
  use_tqdm=True,
665
  )
 
666
  now = datetime.now()
667
  timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
668
- os.makedirs(user_dir, exist_ok=True)
669
  glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
670
  glb.export(glb_path, extension_webp=True)
 
671
  torch.cuda.empty_cache()
672
- return glb_path, glb_path
 
 
673
 
674
 
675
  # ═══════════════════════════════════════════════════════════════
@@ -720,9 +720,7 @@ with gr.Blocks(theme=gr.themes.Base(primary_hue="indigo"), delete_cache=(600, 60
720
  download_btn = gr.DownloadButton("Download GLB", elem_classes=["primary-btn"], size="lg")
721
 
722
  # Footer
723
- gr.HTML('<div class="footer-note">Generation includes automatic GLB extraction. This may take 60+ seconds total.</div>')
724
-
725
- output_buf = gr.State()
726
 
727
  # Event Handlers
728
  demo.load(start_session)
@@ -734,19 +732,13 @@ with gr.Blocks(theme=gr.themes.Base(primary_hue="indigo"), delete_cache=(600, 60
734
  outputs=[image_prompt],
735
  )
736
 
737
- # Generate 3D now automatically chains to Extract GLB
738
  generate_btn.click(
739
- lambda: gr.Walkthrough(selected=0), outputs=walkthrough
740
- ).then(
741
- image_to_3d,
742
  inputs=[image_prompt],
743
- outputs=[output_buf, preview_output],
744
  ).then(
745
  lambda: gr.Walkthrough(selected=1), outputs=walkthrough
746
- ).then(
747
- extract_glb,
748
- inputs=[output_buf],
749
- outputs=[glb_output, download_btn],
750
  )
751
 
752
 
 
565
  return shape_slat, tex_slat, state['res']
566
 
567
 
568
+ @spaces.GPU(duration=180)
569
+ def generate_and_extract(
570
  image: Image.Image,
571
  req: gr.Request,
572
  progress=gr.Progress(track_tqdm=True),
573
+ ) -> Tuple[str, str, str]:
574
+ """
575
+ Combined function: Generate 3D from image AND extract GLB in one GPU session.
576
+ This avoids issues with chaining multiple @spaces.GPU functions.
577
+ """
578
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
579
+ os.makedirs(user_dir, exist_ok=True)
580
+
581
  # Hardcoded values
582
  seed = np.random.randint(0, MAX_SEED)
583
+ decimation_target = 300000
584
+ texture_size = 4096
585
 
586
+ # === STAGE 1: Generate 3D ===
587
  outputs, latents = pipeline.run(
588
  image,
589
  seed=seed,
 
611
  )
612
  mesh = outputs[0]
613
  mesh.simplify(16777216)
614
+
615
+ # Render preview images
616
  images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
 
 
617
 
618
+ # Build preview HTML
619
  images_html = ""
620
  for m_idx, mode in enumerate(MODES):
621
  for s_idx in range(STEPS):
 
630
  active_class = "active" if idx == DEFAULT_MODE else ""
631
  btns_html += f'<img src="{mode["icon_base64"]}" class="mode-btn {active_class}" onclick="selectMode({idx})" title="{mode["name"]}">'
632
 
633
+ preview_html = f"""
634
  <div class="previewer-container">
635
  <div class="display-row">{images_html}</div>
636
  <div class="mode-row">{btns_html}</div>
 
640
  </div>
641
  """
642
 
643
+ # === STAGE 2: Extract GLB ===
644
+ shape_slat, tex_slat, res = latents
 
 
 
 
 
 
 
 
 
 
 
 
 
645
  mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
646
  mesh.simplify(16777216)
647
+
648
  glb = o_voxel.postprocess.to_glb(
649
  vertices=mesh.vertices,
650
  faces=mesh.faces,
 
660
  remesh_project=0,
661
  use_tqdm=True,
662
  )
663
+
664
  now = datetime.now()
665
  timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
 
666
  glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
667
  glb.export(glb_path, extension_webp=True)
668
+
669
  torch.cuda.empty_cache()
670
+
671
+ # Return: preview_html, glb_path (for viewer), glb_path (for download)
672
+ return preview_html, glb_path, glb_path
673
 
674
 
675
  # ═══════════════════════════════════════════════════════════════
 
720
  download_btn = gr.DownloadButton("Download GLB", elem_classes=["primary-btn"], size="lg")
721
 
722
  # Footer
723
+ gr.HTML('<div class="footer-note">Generation includes automatic GLB extraction. This may take 90+ seconds total.</div>')
 
 
724
 
725
  # Event Handlers
726
  demo.load(start_session)
 
732
  outputs=[image_prompt],
733
  )
734
 
735
+ # Single GPU call: Generate 3D + Extract GLB
736
  generate_btn.click(
737
+ generate_and_extract,
 
 
738
  inputs=[image_prompt],
739
+ outputs=[preview_output, glb_output, download_btn],
740
  ).then(
741
  lambda: gr.Walkthrough(selected=1), outputs=walkthrough
 
 
 
 
742
  )
743
 
744