Commit
·
ed7d157
1
Parent(s):
c96c733
fix: match reference model exactly - use simple 255-alpha inversion like aryadytm/remove-photo-object
Browse files- src/core.py +32 -44
src/core.py
CHANGED
|
@@ -443,81 +443,69 @@ def get_args_parser():
|
|
| 443 |
|
| 444 |
|
| 445 |
def process_inpaint(image, mask, invert_mask=True):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 446 |
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
|
| 447 |
original_shape = image.shape
|
| 448 |
interpolation = cv2.INTER_CUBIC
|
| 449 |
|
| 450 |
-
#size_limit: Union[int, str] = request.form.get("sizeLimit", "1080")
|
| 451 |
-
#if size_limit == "Original":
|
| 452 |
size_limit = max(image.shape)
|
| 453 |
-
#else:
|
| 454 |
-
# size_limit = int(size_limit)
|
| 455 |
|
| 456 |
print(f"Origin image shape: {original_shape}")
|
| 457 |
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
|
| 458 |
print(f"Resized image shape: {image.shape}")
|
| 459 |
image = norm_img(image)
|
| 460 |
|
| 461 |
-
#
|
| 462 |
-
#
|
| 463 |
-
#
|
| 464 |
-
#
|
| 465 |
|
|
|
|
| 466 |
alpha_channel = mask[:,:,3]
|
| 467 |
rgb_channels = mask[:,:,:3]
|
| 468 |
-
|
| 469 |
-
# Convert RGB to grayscale to detect white/black
|
| 470 |
-
gray = cv2.cvtColor(rgb_channels, cv2.COLOR_RGB2GRAY)
|
| 471 |
-
|
| 472 |
-
# Standard: white (255) = remove, black (0) = keep
|
| 473 |
-
# Detect white pixels (>128) as removal areas
|
| 474 |
-
mask = (gray > 128).astype(np.uint8) * 255
|
| 475 |
-
|
| 476 |
-
# Also explicitly detect magenta (255, 0, 255) which is commonly used for painting
|
| 477 |
-
magenta = np.all(rgb_channels == [255, 0, 255], axis=2).astype(np.uint8) * 255
|
| 478 |
-
mask = np.maximum(mask, magenta)
|
| 479 |
-
|
| 480 |
-
# If alpha channel is mostly transparent (<50 mean), use it as mask source
|
| 481 |
alpha_mean = alpha_channel.mean()
|
| 482 |
-
if alpha_mean < 50:
|
| 483 |
-
# Transparent areas (alpha=0) should be removed
|
| 484 |
-
if invert_mask:
|
| 485 |
-
mask = np.maximum(mask, (255 - alpha_channel)) # transparent → white
|
| 486 |
-
else:
|
| 487 |
-
mask = np.maximum(mask, alpha_channel) # opaque → white
|
| 488 |
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
mask =
|
| 494 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 495 |
else:
|
| 496 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 497 |
|
| 498 |
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
| 499 |
|
| 500 |
-
# Debug: log mask statistics
|
| 501 |
mask_nonzero = int((mask > 128).sum())
|
| 502 |
mask_total = mask.shape[0] * mask.shape[1]
|
| 503 |
print(f"Mask shape: {mask.shape}, pixels to remove (>128): {mask_nonzero}/{mask_total} ({100*mask_nonzero/mask_total:.1f}%)")
|
| 504 |
|
| 505 |
-
# Normalize: values > 0 become 1.0, 0 stays 0
|
| 506 |
-
# After this, 1.0 = remove, 0.0 = keep (LaMa expects this)
|
| 507 |
mask = norm_img(mask)
|
| 508 |
|
| 509 |
-
# Final check
|
| 510 |
mask_final_pixels = int((mask > 0.5).sum())
|
| 511 |
print(f"After normalization: {mask_final_pixels} pixels marked for removal (value > 0.5)")
|
| 512 |
|
| 513 |
if mask_final_pixels < 10:
|
| 514 |
-
print("WARNING: Very few pixels marked for removal!
|
| 515 |
-
print("Check your mask format: white pixels (255) should indicate areas to remove when invert_mask=True")
|
| 516 |
|
| 517 |
res_np_img = run(image, mask)
|
| 518 |
-
|
| 519 |
-
# Debug: verify output changed
|
| 520 |
-
diff_pixels = int(np.sum(np.abs(res_np_img.astype(np.float32) - cv2.cvtColor(image, cv2.COLOR_RGBA2RGB).astype(np.float32)) > 5))
|
| 521 |
-
print(f"Output check: {diff_pixels} pixels differ from input (should be > 0 if removal worked)")
|
| 522 |
|
| 523 |
return cv2.cvtColor(res_np_img, cv2.COLOR_BGR2RGB)
|
|
|
|
| 443 |
|
| 444 |
|
| 445 |
def process_inpaint(image, mask, invert_mask=True):
|
| 446 |
+
"""
|
| 447 |
+
Process inpainting - matches reference model implementation exactly.
|
| 448 |
+
Reference: https://huggingface.co/spaces/aryadytm/remove-photo-object
|
| 449 |
+
"""
|
| 450 |
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
|
| 451 |
original_shape = image.shape
|
| 452 |
interpolation = cv2.INTER_CUBIC
|
| 453 |
|
|
|
|
|
|
|
| 454 |
size_limit = max(image.shape)
|
|
|
|
|
|
|
| 455 |
|
| 456 |
print(f"Origin image shape: {original_shape}")
|
| 457 |
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
|
| 458 |
print(f"Resized image shape: {image.shape}")
|
| 459 |
image = norm_img(image)
|
| 460 |
|
| 461 |
+
# Match reference model exactly: invert alpha channel
|
| 462 |
+
# Reference line 460: mask = 255-mask[:,:,3]
|
| 463 |
+
# This means: alpha=0 (transparent/drawn) → 255 (white/remove)
|
| 464 |
+
# alpha=255 (opaque) → 0 (black/keep)
|
| 465 |
|
| 466 |
+
# Check if we should use RGB channels (for uploaded black/white masks)
|
| 467 |
alpha_channel = mask[:,:,3]
|
| 468 |
rgb_channels = mask[:,:,:3]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 469 |
alpha_mean = alpha_channel.mean()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 470 |
|
| 471 |
+
if alpha_mean > 200:
|
| 472 |
+
# Alpha is mostly opaque - use RGB channels (white=remove, black=keep)
|
| 473 |
+
gray = cv2.cvtColor(rgb_channels, cv2.COLOR_RGB2GRAY)
|
| 474 |
+
# White pixels (>128) = remove
|
| 475 |
+
mask = (gray > 128).astype(np.uint8) * 255
|
| 476 |
+
# Also detect magenta specifically
|
| 477 |
+
magenta = np.all(rgb_channels == [255, 0, 255], axis=2).astype(np.uint8) * 255
|
| 478 |
+
mask = np.maximum(mask, magenta)
|
| 479 |
+
|
| 480 |
+
# Apply invert_mask if needed
|
| 481 |
+
if not invert_mask:
|
| 482 |
+
mask = 255 - mask
|
| 483 |
else:
|
| 484 |
+
# Alpha channel encodes mask - use reference model's exact logic
|
| 485 |
+
# Invert alpha: transparent (0) → white (255), opaque (255) → black (0)
|
| 486 |
+
mask = 255 - alpha_channel
|
| 487 |
+
|
| 488 |
+
# Apply invert_mask if user wants opposite
|
| 489 |
+
if not invert_mask:
|
| 490 |
+
mask = 255 - mask # double invert back to original
|
| 491 |
|
| 492 |
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
| 493 |
|
| 494 |
+
# Debug: log mask statistics
|
| 495 |
mask_nonzero = int((mask > 128).sum())
|
| 496 |
mask_total = mask.shape[0] * mask.shape[1]
|
| 497 |
print(f"Mask shape: {mask.shape}, pixels to remove (>128): {mask_nonzero}/{mask_total} ({100*mask_nonzero/mask_total:.1f}%)")
|
| 498 |
|
| 499 |
+
# Normalize: values > 0 become 1.0, 0 stays 0 (LaMa expects this)
|
|
|
|
| 500 |
mask = norm_img(mask)
|
| 501 |
|
| 502 |
+
# Final check
|
| 503 |
mask_final_pixels = int((mask > 0.5).sum())
|
| 504 |
print(f"After normalization: {mask_final_pixels} pixels marked for removal (value > 0.5)")
|
| 505 |
|
| 506 |
if mask_final_pixels < 10:
|
| 507 |
+
print("WARNING: Very few pixels marked for removal! Check mask format.")
|
|
|
|
| 508 |
|
| 509 |
res_np_img = run(image, mask)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 510 |
|
| 511 |
return cv2.cvtColor(res_np_img, cv2.COLOR_BGR2RGB)
|