opsiclear-admin committed on
Commit
3e5d851
·
verified ·
1 Parent(s): a90a10a

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +34 -106
app.py CHANGED
@@ -1,8 +1,3 @@
1
- import warnings
2
- warnings.filterwarnings("ignore", message=".*torch.distributed.reduce_op.*")
3
- warnings.filterwarnings("ignore", message=".*torch.cuda.amp.autocast.*")
4
- warnings.filterwarnings("ignore", message=".*Default grid_sample and affine_grid behavior.*")
5
-
6
  import gradio as gr
7
  from gradio_client import Client, handle_file
8
  import spaces
@@ -32,9 +27,9 @@ import o_voxel
32
 
33
  # Patch postprocess module with local fix for cumesh.fill_holes() bug
34
  import importlib.util
35
- import sys
36
  _local_postprocess = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'o-voxel', 'o_voxel', 'postprocess.py')
37
  if os.path.exists(_local_postprocess):
 
38
  _spec = importlib.util.spec_from_file_location('o_voxel.postprocess', _local_postprocess)
39
  _mod = importlib.util.module_from_spec(_spec)
40
  _spec.loader.exec_module(_mod)
@@ -341,7 +336,7 @@ def remove_background(input: Image.Image) -> Image.Image:
341
  return output
342
 
343
 
344
- def preprocess_image(input: Image.Image, histogram_normalize: bool = False) -> Image.Image:
345
  """
346
  Preprocess the input image.
347
  """
@@ -375,63 +370,24 @@ def preprocess_image(input: Image.Image, histogram_normalize: bool = False) -> I
375
  size = int(size * 1)
376
  bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
377
  output = output.crop(bbox) # type: ignore
378
- output = np.array(output).astype(np.float32) / 255
379
- rgb = output[:, :, :3]
380
- alpha = output[:, :, 3:4]
381
- rgb = rgb * alpha # premultiply alpha
382
-
383
- if histogram_normalize:
384
- fg_mask = (alpha[:, :, 0] > 0.05)
385
- if fg_mask.any():
386
- for c in range(3):
387
- ch = rgb[:, :, c]
388
- fg_vals = ch[fg_mask]
389
- if fg_vals.max() > fg_vals.min():
390
- lo, hi = np.percentile(fg_vals, [1, 99])
391
- if hi > lo:
392
- ch_norm = np.clip((ch - lo) / (hi - lo), 0, 1)
393
- rgb[:, :, c] = np.where(fg_mask, ch_norm, 0)
394
-
395
- output = Image.fromarray((rgb * 255).astype(np.uint8))
396
  return output
397
 
398
 
399
- def apply_histogram_normalization(img: Image.Image) -> Image.Image:
400
- """Apply histogram normalization to a preprocessed RGB image (black background)."""
401
- arr = np.array(img).astype(np.float32) / 255
402
- fg_mask = arr.sum(axis=2) > 0.05
403
- if not fg_mask.any():
404
- return img
405
- for c in range(3):
406
- ch = arr[:, :, c]
407
- fg_vals = ch[fg_mask]
408
- if fg_vals.max() > fg_vals.min():
409
- lo, hi = np.percentile(fg_vals, [1, 99])
410
- if hi > lo:
411
- ch_norm = np.clip((ch - lo) / (hi - lo), 0, 1)
412
- arr[:, :, c] = np.where(fg_mask, ch_norm, 0)
413
- return Image.fromarray((arr * 255).astype(np.uint8))
414
-
415
-
416
- def preprocess_images(images: List[Tuple[Image.Image, str]]) -> Tuple[List[Image.Image], List[Image.Image]]:
417
  """
418
  Preprocess a list of input images for multi-image conditioning.
419
- Uses parallel processing for faster background removal.
420
- Returns (gallery_images, base_images_for_state).
421
  """
422
  images = [image[0] for image in images]
423
- with ThreadPoolExecutor(max_workers=min(4, len(images))) as executor:
424
- processed_images = list(executor.map(preprocess_image, images))
425
- return processed_images, processed_images
426
-
427
-
428
- def toggle_normalize(base_images: list, histogram_normalize: bool) -> List[Image.Image]:
429
- """Toggle histogram normalization on stored base images."""
430
- if not base_images:
431
- return []
432
- if histogram_normalize:
433
- return [apply_histogram_normalization(img) for img in base_images]
434
- return list(base_images)
435
 
436
 
437
  def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict:
@@ -474,10 +430,10 @@ def prepare_multi_example() -> List[str]:
474
  return examples
475
 
476
 
477
- def load_multi_example(image) -> Tuple[List[Image.Image], List[Image.Image]]:
478
- """Load all views for a multi-image case, preprocess, and store base images."""
479
  if image is None:
480
- return [], []
481
 
482
  # Convert to PIL Image if needed
483
  if isinstance(image, np.ndarray):
@@ -490,7 +446,6 @@ def load_multi_example(image) -> Tuple[List[Image.Image], List[Image.Image]]:
490
  example_dir = "assets/example_multi_image"
491
  case_names = sorted(set([f.rsplit('_', 1)[0] for f in os.listdir(example_dir) if f.endswith('.png')]))
492
 
493
- raw_images = None
494
  for case_name in case_names:
495
  first_img_path = f'{example_dir}/{case_name}_1.png'
496
  if os.path.exists(first_img_path):
@@ -499,18 +454,18 @@ def load_multi_example(image) -> Tuple[List[Image.Image], List[Image.Image]]:
499
 
500
  # Compare images (check if same shape and content)
501
  if input_rgb.shape == first_rgb.shape and np.array_equal(input_rgb, first_rgb):
502
- raw_images = []
 
503
  for i in range(1, 7):
504
  img_path = f'{example_dir}/{case_name}_{i}.png'
505
  if os.path.exists(img_path):
506
- raw_images.append(Image.open(img_path).convert('RGBA'))
507
- break
508
-
509
- if not raw_images:
510
- raw_images = [image.convert('RGBA') if image.mode != 'RGBA' else image]
511
 
512
- processed = [preprocess_image(img) for img in raw_images]
513
- return processed, processed
514
 
515
 
516
  def split_image(image: Image.Image) -> List[Image.Image]:
@@ -553,16 +508,9 @@ def image_to_3d(
553
  if not multiimages:
554
  raise gr.Error("Please upload images or select an example first.")
555
 
556
- # Use gallery images directly (already preprocessed with optional normalization)
557
- processed_images = [image[0] for image in multiimages]
558
-
559
- # Debug: save preprocessed images and log stats
560
- for i, img in enumerate(processed_images):
561
- arr = np.array(img)
562
- print(f"[DEBUG] Preprocessed image {i}: mode={img.mode}, size={img.size}, "
563
- f"dtype={arr.dtype}, min={arr.min()}, max={arr.max()}, mean={arr.mean():.1f}")
564
- img.save(os.path.join(TMP_DIR, f'debug_preprocessed_{i}.png'))
565
- print(f"[DEBUG] Pipeline params: mode={multiimage_algo}, tex_mode={tex_multiimage_algo}")
566
 
567
  # --- Sampling ---
568
  outputs, latents = pipeline.run_multi_image(
@@ -599,15 +547,6 @@ def image_to_3d(
599
  mesh = outputs[0]
600
  mesh.simplify(16777216) # nvdiffrast limit
601
  images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
602
-
603
- # Debug: save base_color render and log stats for all render modes
604
- for key in images:
605
- arr = images[key][0] # first view
606
- print(f"[DEBUG] Render '{key}': shape={arr.shape}, min={arr.min()}, max={arr.max()}, mean={arr.mean():.1f}")
607
- # Save base_color and shaded_forest for inspection
608
- Image.fromarray(images['base_color'][0]).save(os.path.join(TMP_DIR, 'debug_base_color.png'))
609
- Image.fromarray(images['shaded_forest'][0]).save(os.path.join(TMP_DIR, 'debug_shaded_forest.png'))
610
-
611
  state = pack_state(latents)
612
  torch.cuda.empty_cache()
613
 
@@ -747,6 +686,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange", neutral_hue="slate"))
747
  <li>Click an example below to load a pre-made multi-view set, or upload your own images.</li>
748
  <li>Click <b>Generate</b> to create the 3D model, then <b>Extract GLB</b> to export.</li>
749
  <li style="color: #e67300;"><b>⚠️ Note:</b> Generation quality is highly sensitive to parameters. Adjust settings in Advanced Settings if results are unsatisfactory.</li>
 
750
  </ul>
751
  </div>
752
  </div>
@@ -782,25 +722,23 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange", neutral_hue="slate"))
782
  tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
783
  tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
784
  multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Structure Algorithm", value="stochastic")
785
- tex_multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Texture Algorithm", value="stochastic")
786
- histogram_normalize = gr.Checkbox(label="Histogram Normalize", value=False)
787
 
788
  with gr.Column(scale=10):
789
- with gr.Row():
790
- generate_btn = gr.Button("Generate", variant="primary")
791
- extract_btn = gr.Button("Extract GLB")
792
-
793
  preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True)
794
  glb_output = gr.Model3D(label="Extracted GLB", height=400, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
795
  download_btn = gr.DownloadButton(label="Download GLB")
796
 
 
 
 
 
797
  example_image = gr.Image(visible=False) # Hidden component for examples
798
- preprocessed_state = gr.State([])
799
  examples_multi = gr.Examples(
800
  examples=prepare_multi_example(),
801
  inputs=[example_image],
802
  fn=load_multi_example,
803
- outputs=[multiimage_prompt, preprocessed_state],
804
  run_on_click=True,
805
  cache_examples=False,
806
  examples_per_page=50,
@@ -815,16 +753,6 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange", neutral_hue="slate"))
815
  multiimage_prompt.upload(
816
  preprocess_images,
817
  inputs=[multiimage_prompt],
818
- outputs=[multiimage_prompt, preprocessed_state],
819
- ).then(
820
- toggle_normalize,
821
- inputs=[preprocessed_state, histogram_normalize],
822
- outputs=[multiimage_prompt],
823
- )
824
-
825
- histogram_normalize.change(
826
- toggle_normalize,
827
- inputs=[preprocessed_state, histogram_normalize],
828
  outputs=[multiimage_prompt],
829
  )
830
 
 
 
 
 
 
 
1
  import gradio as gr
2
  from gradio_client import Client, handle_file
3
  import spaces
 
27
 
28
  # Patch postprocess module with local fix for cumesh.fill_holes() bug
29
  import importlib.util
 
30
  _local_postprocess = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'o-voxel', 'o_voxel', 'postprocess.py')
31
  if os.path.exists(_local_postprocess):
32
+ import sys
33
  _spec = importlib.util.spec_from_file_location('o_voxel.postprocess', _local_postprocess)
34
  _mod = importlib.util.module_from_spec(_spec)
35
  _spec.loader.exec_module(_mod)
 
336
  return output
337
 
338
 
339
+ def preprocess_image(input: Image.Image) -> Image.Image:
340
  """
341
  Preprocess the input image.
342
  """
 
370
  size = int(size * 1)
371
  bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
372
  output = output.crop(bbox) # type: ignore
373
+ output_np = np.array(output).astype(np.float32)
374
+ rgb = output_np[:, :, :3]
375
+ alpha = output_np[:, :, 3:4] / 255.0
376
+ # Keep full RGB for visible pixels, zero out transparent background
377
+ mask = (alpha > 0.05).astype(np.float32)
378
+ rgb = rgb * mask
379
+ output = Image.fromarray(rgb.astype(np.uint8))
 
 
 
 
 
 
 
 
 
 
 
380
  return output
381
 
382
 
383
+ @spaces.GPU(duration=60)
384
+ def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
  """
386
  Preprocess a list of input images for multi-image conditioning.
 
 
387
  """
388
  images = [image[0] for image in images]
389
+ processed_images = [preprocess_image(img) for img in images]
390
+ return processed_images
 
 
 
 
 
 
 
 
 
 
391
 
392
 
393
  def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict:
 
430
  return examples
431
 
432
 
433
+ def load_multi_example(image) -> List[Image.Image]:
434
+ """Load all views for a multi-image case by matching the input image."""
435
  if image is None:
436
+ return []
437
 
438
  # Convert to PIL Image if needed
439
  if isinstance(image, np.ndarray):
 
446
  example_dir = "assets/example_multi_image"
447
  case_names = sorted(set([f.rsplit('_', 1)[0] for f in os.listdir(example_dir) if f.endswith('.png')]))
448
 
 
449
  for case_name in case_names:
450
  first_img_path = f'{example_dir}/{case_name}_1.png'
451
  if os.path.exists(first_img_path):
 
454
 
455
  # Compare images (check if same shape and content)
456
  if input_rgb.shape == first_rgb.shape and np.array_equal(input_rgb, first_rgb):
457
+ # Found match, load all views (without preprocessing - will be done on Generate)
458
+ images = []
459
  for i in range(1, 7):
460
  img_path = f'{example_dir}/{case_name}_{i}.png'
461
  if os.path.exists(img_path):
462
+ img = Image.open(img_path).convert('RGBA')
463
+ images.append(img)
464
+ if images:
465
+ return images
 
466
 
467
+ # No match found, return the single image
468
+ return [image.convert('RGBA') if image.mode != 'RGBA' else image]
469
 
470
 
471
  def split_image(image: Image.Image) -> List[Image.Image]:
 
508
  if not multiimages:
509
  raise gr.Error("Please upload images or select an example first.")
510
 
511
+ # Preprocess images (background removal, cropping, etc.)
512
+ images = [image[0] for image in multiimages]
513
+ processed_images = [preprocess_image(img) for img in images]
 
 
 
 
 
 
 
514
 
515
  # --- Sampling ---
516
  outputs, latents = pipeline.run_multi_image(
 
547
  mesh = outputs[0]
548
  mesh.simplify(16777216) # nvdiffrast limit
549
  images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
 
 
 
 
 
 
 
 
 
550
  state = pack_state(latents)
551
  torch.cuda.empty_cache()
552
 
 
686
  <li>Click an example below to load a pre-made multi-view set, or upload your own images.</li>
687
  <li>Click <b>Generate</b> to create the 3D model, then <b>Extract GLB</b> to export.</li>
688
  <li style="color: #e67300;"><b>⚠️ Note:</b> Generation quality is highly sensitive to parameters. Adjust settings in Advanced Settings if results are unsatisfactory.</li>
689
+ <li style="color: #cc3333;"><b>⚠️ Non-Commercial:</b> This space uses models with licenses that <b>forbid commercial use</b> (BRIA RMBG-2.0: CC BY-NC 4.0, nvdiffrast/nvdiffrec: NVIDIA Source Code License).</li>
690
  </ul>
691
  </div>
692
  </div>
 
722
  tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
723
  tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
724
  multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Structure Algorithm", value="stochastic")
725
+ tex_multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Texture Algorithm", value="multidiffusion")
 
726
 
727
  with gr.Column(scale=10):
 
 
 
 
728
  preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True)
729
  glb_output = gr.Model3D(label="Extracted GLB", height=400, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
730
  download_btn = gr.DownloadButton(label="Download GLB")
731
 
732
+ with gr.Row():
733
+ generate_btn = gr.Button("Generate", variant="primary")
734
+ extract_btn = gr.Button("Extract GLB")
735
+
736
  example_image = gr.Image(visible=False) # Hidden component for examples
 
737
  examples_multi = gr.Examples(
738
  examples=prepare_multi_example(),
739
  inputs=[example_image],
740
  fn=load_multi_example,
741
+ outputs=[multiimage_prompt],
742
  run_on_click=True,
743
  cache_examples=False,
744
  examples_per_page=50,
 
753
  multiimage_prompt.upload(
754
  preprocess_images,
755
  inputs=[multiimage_prompt],
 
 
 
 
 
 
 
 
 
 
756
  outputs=[multiimage_prompt],
757
  )
758