LogicGoInfotechSpaces committed on
Commit
ed7d157
·
1 Parent(s): c96c733

fix: match reference model exactly - use simple 255-alpha inversion like aryadytm/remove-photo-object

Browse files
Files changed (1) hide show
  1. src/core.py +32 -44
src/core.py CHANGED
@@ -443,81 +443,69 @@ def get_args_parser():
443
 
444
 
445
def process_inpaint(image, mask, invert_mask=True):
    """Inpaint (remove) regions of `image` marked by `mask` via the `run` model.

    Parameters:
        image: RGBA uint8 array; converted to RGB before processing.
        mask: RGBA uint8 array. Removal regions are taken from white/magenta
            RGB pixels, and additionally from the alpha channel when it is
            mostly transparent (mean alpha < 50).
        invert_mask: when False, the computed mask polarity is flipped so
            black pixels mark removal instead of white.

    Returns:
        The inpainted image as an RGB uint8 array.
    """
    image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    original_shape = image.shape
    interpolation = cv2.INTER_CUBIC

    # Originally size_limit came from the request; now fixed to the image's
    # largest dimension (i.e. no downscaling).
    #size_limit: Union[int, str] = request.form.get("sizeLimit", "1080")
    #if size_limit == "Original":
    size_limit = max(image.shape)
    #else:
    #    size_limit = int(size_limit)

    print(f"Origin image shape: {original_shape}")
    image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
    print(f"Resized image shape: {image.shape}")
    image = norm_img(image)

    # Convert RGBA mask to single-channel mask.
    # Standard LaMa convention: 1 = remove, 0 = keep.
    # Simple approach: white pixels in RGB = remove, black = keep.
    alpha_channel = mask[:,:,3]
    rgb_channels = mask[:,:,:3]

    # Convert RGB to grayscale to detect white vs. black pixels.
    gray = cv2.cvtColor(rgb_channels, cv2.COLOR_RGB2GRAY)

    # White pixels (>128) mark removal areas.
    mask = (gray > 128).astype(np.uint8) * 255

    # Also explicitly detect magenta (255, 0, 255), commonly used for painting;
    # its grayscale value falls below the 128 threshold, so it needs its own check.
    magenta = np.all(rgb_channels == [255, 0, 255], axis=2).astype(np.uint8) * 255
    mask = np.maximum(mask, magenta)

    # If the alpha channel is mostly transparent (mean < 50), treat it as an
    # additional mask source (drawn strokes often appear as transparent holes).
    alpha_mean = alpha_channel.mean()
    if alpha_mean < 50:
        # Transparent areas (alpha=0) should be removed.
        if invert_mask:
            mask = np.maximum(mask, (255 - alpha_channel)) # transparent -> white
        else:
            mask = np.maximum(mask, alpha_channel) # opaque -> white

    # Apply invert_mask if needed:
    #  - invert_mask=False: black pixels (0) become white (255) to remove
    #  - invert_mask=True (default): white pixels (255) stay white to remove
    if not invert_mask:
        mask = 255 - mask
        print(f"Applied invert_mask=False: inverted mask - {int((mask > 128).sum())} pixels now marked for removal")
    else:
        print(f"Using invert_mask=True: {int((mask > 128).sum())} white pixels will be removed (standard)")

    mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)

    # Debug: log mask statistics BEFORE normalization.
    mask_nonzero = int((mask > 128).sum())
    mask_total = mask.shape[0] * mask.shape[1]
    print(f"Mask shape: {mask.shape}, pixels to remove (>128): {mask_nonzero}/{mask_total} ({100*mask_nonzero/mask_total:.1f}%)")

    # Normalize: values > 0 become 1.0, 0 stays 0.
    # After this, 1.0 = remove, 0.0 = keep (LaMa expects this).
    mask = norm_img(mask)

    # Final check: ensure we actually have pixels to remove.
    mask_final_pixels = int((mask > 0.5).sum())
    print(f"After normalization: {mask_final_pixels} pixels marked for removal (value > 0.5)")

    if mask_final_pixels < 10:
        print("WARNING: Very few pixels marked for removal! The mask might be empty or inverted.")
        print("Check your mask format: white pixels (255) should indicate areas to remove when invert_mask=True")

    res_np_img = run(image, mask)

    # Debug: verify output changed.
    # NOTE(review): `image` was already converted to RGB and normalized above,
    # so this RGBA2RGB call on it looks wrong (wrong channel count / range) —
    # likely a latent bug in this debug check; confirm before relying on it.
    diff_pixels = int(np.sum(np.abs(res_np_img.astype(np.float32) - cv2.cvtColor(image, cv2.COLOR_RGBA2RGB).astype(np.float32)) > 5))
    print(f"Output check: {diff_pixels} pixels differ from input (should be > 0 if removal worked)")

    return cv2.cvtColor(res_np_img, cv2.COLOR_BGR2RGB)
 
443
 
444
 
445
def process_inpaint(image, mask, invert_mask=True):
    """
    Inpaint (remove) the regions of `image` marked by `mask`.

    Mirrors the reference model implementation exactly:
    https://huggingface.co/spaces/aryadytm/remove-photo-object

    Parameters:
        image: RGBA uint8 array; converted to RGB before processing.
        mask: RGBA uint8 array. If the alpha channel is mostly opaque
            (mean > 200), the RGB channels carry the mask (white/magenta =
            remove); otherwise the inverted alpha channel is the mask.
        invert_mask: when False, flip the computed mask polarity.

    Returns:
        The inpainted image as an RGB uint8 array.
    """
    image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    original_shape = image.shape
    interp = cv2.INTER_CUBIC
    limit = max(image.shape)

    print(f"Origin image shape: {original_shape}")
    image = resize_max_size(image, size_limit=limit, interpolation=interp)
    print(f"Resized image shape: {image.shape}")
    image = norm_img(image)

    alpha = mask[:, :, 3]
    rgb = mask[:, :, :3]

    if alpha.mean() > 200:
        # Alpha is mostly opaque, so the mask information lives in the RGB
        # channels (uploaded black/white masks): white = remove, black = keep.
        luma = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
        mask = np.where(luma > 128, 255, 0).astype(np.uint8)
        # Magenta (255, 0, 255) strokes also mark removal; their grayscale
        # value is below the 128 threshold, so detect them explicitly.
        painted = np.all(rgb == [255, 0, 255], axis=2).astype(np.uint8) * 255
        mask = np.maximum(mask, painted)
        if not invert_mask:
            mask = 255 - mask
    else:
        # Alpha channel encodes the mask — reference model's exact logic
        # (reference line 460: mask = 255-mask[:,:,3]):
        # transparent (alpha=0, drawn) -> 255 (white/remove)
        # opaque (alpha=255)           -> 0 (black/keep)
        mask = 255 - alpha
        if not invert_mask:
            mask = 255 - mask  # double invert back to the original polarity

    mask = resize_max_size(mask, size_limit=limit, interpolation=interp)

    # Debug: log mask statistics before normalization.
    mask_nonzero = int((mask > 128).sum())
    mask_total = mask.shape[0] * mask.shape[1]
    print(f"Mask shape: {mask.shape}, pixels to remove (>128): {mask_nonzero}/{mask_total} ({100*mask_nonzero/mask_total:.1f}%)")

    # Normalize: values > 0 become 1.0, 0 stays 0 (LaMa expects 1.0 = remove).
    mask = norm_img(mask)

    # Sanity check: warn when essentially nothing is marked for removal.
    mask_final_pixels = int((mask > 0.5).sum())
    print(f"After normalization: {mask_final_pixels} pixels marked for removal (value > 0.5)")

    if mask_final_pixels < 10:
        print("WARNING: Very few pixels marked for removal! Check mask format.")

    res_np_img = run(image, mask)

    return cv2.cvtColor(res_np_img, cv2.COLOR_BGR2RGB)