opsiclear-admin commited on
Commit
c458082
·
verified ·
1 Parent(s): 83e9fff

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +56 -19
app.py CHANGED
@@ -396,15 +396,42 @@ def preprocess_image(input: Image.Image, histogram_normalize: bool = False) -> I
396
  return output
397
 
398
 
399
- def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
400
  """
401
  Preprocess a list of input images for multi-image conditioning.
402
  Uses parallel processing for faster background removal.
 
403
  """
404
  images = [image[0] for image in images]
405
  with ThreadPoolExecutor(max_workers=min(4, len(images))) as executor:
406
  processed_images = list(executor.map(preprocess_image, images))
407
- return processed_images
 
 
 
 
 
 
 
 
 
408
 
409
 
410
  def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict:
@@ -447,10 +474,10 @@ def prepare_multi_example() -> List[str]:
447
  return examples
448
 
449
 
450
- def load_multi_example(image) -> List[Image.Image]:
451
- """Load all views for a multi-image case by matching the input image."""
452
  if image is None:
453
- return []
454
 
455
  # Convert to PIL Image if needed
456
  if isinstance(image, np.ndarray):
@@ -463,6 +490,7 @@ def load_multi_example(image) -> List[Image.Image]:
463
  example_dir = "assets/example_multi_image"
464
  case_names = sorted(set([f.rsplit('_', 1)[0] for f in os.listdir(example_dir) if f.endswith('.png')]))
465
 
 
466
  for case_name in case_names:
467
  first_img_path = f'{example_dir}/{case_name}_1.png'
468
  if os.path.exists(first_img_path):
@@ -471,18 +499,18 @@ def load_multi_example(image) -> List[Image.Image]:
471
 
472
  # Compare images (check if same shape and content)
473
  if input_rgb.shape == first_rgb.shape and np.array_equal(input_rgb, first_rgb):
474
- # Found match, load all views (without preprocessing - will be done on Generate)
475
- images = []
476
  for i in range(1, 7):
477
  img_path = f'{example_dir}/{case_name}_{i}.png'
478
  if os.path.exists(img_path):
479
- img = Image.open(img_path).convert('RGBA')
480
- images.append(img)
481
- if images:
482
- return images
483
 
484
- # No match found, return the single image
485
- return [image.convert('RGBA') if image.mode != 'RGBA' else image]
 
 
 
486
 
487
 
488
  def split_image(image: Image.Image) -> List[Image.Image]:
@@ -519,16 +547,14 @@ def image_to_3d(
519
  multiimages: List[Tuple[Image.Image, str]],
520
  multiimage_algo: Literal["multidiffusion", "stochastic"],
521
  tex_multiimage_algo: Literal["multidiffusion", "stochastic"],
522
- histogram_normalize: bool,
523
  req: gr.Request,
524
  progress=gr.Progress(track_tqdm=True),
525
  ) -> str:
526
  if not multiimages:
527
  raise gr.Error("Please upload images or select an example first.")
528
 
529
- # Preprocess images (background removal, cropping, etc.)
530
- images = [image[0] for image in multiimages]
531
- processed_images = [preprocess_image(img, histogram_normalize=histogram_normalize) for img in images]
532
 
533
  # Debug: save preprocessed images and log stats
534
  for i, img in enumerate(processed_images):
@@ -773,13 +799,14 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange", neutral_hue="slate"))
773
  examples=prepare_multi_example(),
774
  inputs=[example_image],
775
  fn=load_multi_example,
776
- outputs=[multiimage_prompt],
777
  run_on_click=True,
778
  cache_examples=False,
779
  examples_per_page=50,
780
  )
781
 
782
  output_buf = gr.State()
 
783
 
784
 
785
  # Handlers
@@ -788,6 +815,16 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange", neutral_hue="slate"))
788
  multiimage_prompt.upload(
789
  preprocess_images,
790
  inputs=[multiimage_prompt],
 
 
 
 
 
 
 
 
 
 
791
  outputs=[multiimage_prompt],
792
  )
793
 
@@ -802,7 +839,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange", neutral_hue="slate"))
802
  ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
803
  shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t,
804
  tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
805
- multiimage_prompt, multiimage_algo, tex_multiimage_algo, histogram_normalize
806
  ],
807
  outputs=[output_buf, preview_output],
808
  )
 
396
  return output
397
 
398
 
399
+ def apply_histogram_normalization(img: Image.Image) -> Image.Image:
400
+ """Apply histogram normalization to a preprocessed RGB image (black background)."""
401
+ arr = np.array(img).astype(np.float32) / 255
402
+ fg_mask = arr.sum(axis=2) > 0.05
403
+ if not fg_mask.any():
404
+ return img
405
+ for c in range(3):
406
+ ch = arr[:, :, c]
407
+ fg_vals = ch[fg_mask]
408
+ if fg_vals.max() > fg_vals.min():
409
+ lo, hi = np.percentile(fg_vals, [1, 99])
410
+ if hi > lo:
411
+ ch_norm = np.clip((ch - lo) / (hi - lo), 0, 1)
412
+ arr[:, :, c] = np.where(fg_mask, ch_norm, 0)
413
+ return Image.fromarray((arr * 255).astype(np.uint8))
414
+
415
+
416
+ def preprocess_images(images: List[Tuple[Image.Image, str]]) -> Tuple[List[Image.Image], List[Image.Image]]:
417
  """
418
  Preprocess a list of input images for multi-image conditioning.
419
  Uses parallel processing for faster background removal.
420
+ Returns (gallery_images, base_images_for_state).
421
  """
422
  images = [image[0] for image in images]
423
  with ThreadPoolExecutor(max_workers=min(4, len(images))) as executor:
424
  processed_images = list(executor.map(preprocess_image, images))
425
+ return processed_images, processed_images
426
+
427
+
428
+ def toggle_normalize(base_images: list, histogram_normalize: bool) -> List[Image.Image]:
429
+ """Toggle histogram normalization on stored base images."""
430
+ if not base_images:
431
+ return []
432
+ if histogram_normalize:
433
+ return [apply_histogram_normalization(img) for img in base_images]
434
+ return list(base_images)
435
 
436
 
437
  def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict:
 
474
  return examples
475
 
476
 
477
+ def load_multi_example(image) -> Tuple[List[Image.Image], List[Image.Image]]:
478
+ """Load all views for a multi-image case, preprocess, and store base images."""
479
  if image is None:
480
+ return [], []
481
 
482
  # Convert to PIL Image if needed
483
  if isinstance(image, np.ndarray):
 
490
  example_dir = "assets/example_multi_image"
491
  case_names = sorted(set([f.rsplit('_', 1)[0] for f in os.listdir(example_dir) if f.endswith('.png')]))
492
 
493
+ raw_images = None
494
  for case_name in case_names:
495
  first_img_path = f'{example_dir}/{case_name}_1.png'
496
  if os.path.exists(first_img_path):
 
499
 
500
  # Compare images (check if same shape and content)
501
  if input_rgb.shape == first_rgb.shape and np.array_equal(input_rgb, first_rgb):
502
+ raw_images = []
 
503
  for i in range(1, 7):
504
  img_path = f'{example_dir}/{case_name}_{i}.png'
505
  if os.path.exists(img_path):
506
+ raw_images.append(Image.open(img_path).convert('RGBA'))
507
+ break
 
 
508
 
509
+ if not raw_images:
510
+ raw_images = [image.convert('RGBA') if image.mode != 'RGBA' else image]
511
+
512
+ processed = [preprocess_image(img) for img in raw_images]
513
+ return processed, processed
514
 
515
 
516
  def split_image(image: Image.Image) -> List[Image.Image]:
 
547
  multiimages: List[Tuple[Image.Image, str]],
548
  multiimage_algo: Literal["multidiffusion", "stochastic"],
549
  tex_multiimage_algo: Literal["multidiffusion", "stochastic"],
 
550
  req: gr.Request,
551
  progress=gr.Progress(track_tqdm=True),
552
  ) -> str:
553
  if not multiimages:
554
  raise gr.Error("Please upload images or select an example first.")
555
 
556
+ # Use gallery images directly (already preprocessed with optional normalization)
557
+ processed_images = [image[0] for image in multiimages]
 
558
 
559
  # Debug: save preprocessed images and log stats
560
  for i, img in enumerate(processed_images):
 
799
  examples=prepare_multi_example(),
800
  inputs=[example_image],
801
  fn=load_multi_example,
802
+ outputs=[multiimage_prompt, preprocessed_state],
803
  run_on_click=True,
804
  cache_examples=False,
805
  examples_per_page=50,
806
  )
807
 
808
  output_buf = gr.State()
809
+ preprocessed_state = gr.State([])
810
 
811
 
812
  # Handlers
 
815
  multiimage_prompt.upload(
816
  preprocess_images,
817
  inputs=[multiimage_prompt],
818
+ outputs=[multiimage_prompt, preprocessed_state],
819
+ ).then(
820
+ toggle_normalize,
821
+ inputs=[preprocessed_state, histogram_normalize],
822
+ outputs=[multiimage_prompt],
823
+ )
824
+
825
+ histogram_normalize.change(
826
+ toggle_normalize,
827
+ inputs=[preprocessed_state, histogram_normalize],
828
  outputs=[multiimage_prompt],
829
  )
830
 
 
839
  ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
840
  shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t,
841
  tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
842
+ multiimage_prompt, multiimage_algo, tex_multiimage_algo
843
  ],
844
  outputs=[output_buf, preview_output],
845
  )