Commit · 0964a65
1 Parent(s): dac0915
fix(mask): improve mask detection - handle RGB paint directly, detect magenta and white areas; match reference model behavior
Browse files
- api/main.py +9 -3
- src/core.py +35 -8
api/main.py
CHANGED
|
@@ -285,10 +285,16 @@ def inpaint_multipart(
|
|
| 285 |
nonzero = int((binmask > 0).sum())
|
| 286 |
log.info("fallback detection: %d pixels", nonzero)
|
| 287 |
|
| 288 |
-
# Build RGBA mask
|
|
|
|
| 289 |
mask_rgba = np.zeros((binmask.shape[0], binmask.shape[1], 4), dtype=np.uint8)
|
| 290 |
-
|
| 291 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
else:
|
| 293 |
mask_rgba = _load_rgba_mask_from_image(m)
|
| 294 |
|
|
|
|
| 285 |
nonzero = int((binmask > 0).sum())
|
| 286 |
log.info("fallback detection: %d pixels", nonzero)
|
| 287 |
|
| 288 |
+
# Build RGBA mask: painted areas should be white in RGB for direct detection
|
| 289 |
+
# Use RGB channels with white=remove, black=keep, then set alpha appropriately
|
| 290 |
mask_rgba = np.zeros((binmask.shape[0], binmask.shape[1], 4), dtype=np.uint8)
|
| 291 |
+
# Paint detected areas as white in RGB (will be detected in process_inpaint)
|
| 292 |
+
mask_rgba[:, :, 0] = binmask # R
|
| 293 |
+
mask_rgba[:, :, 1] = binmask # G
|
| 294 |
+
mask_rgba[:, :, 2] = binmask # B
|
| 295 |
+
# Set alpha to opaque so RGB channels are used
|
| 296 |
+
mask_rgba[:, :, 3] = 255
|
| 297 |
+
log.info("Final mask: %d pixels marked for removal (white in RGB)", int((binmask > 0).sum()))
|
| 298 |
else:
|
| 299 |
mask_rgba = _load_rgba_mask_from_image(m)
|
| 300 |
|
src/core.py
CHANGED
|
@@ -460,17 +460,40 @@ def process_inpaint(image, mask, invert_mask=True):
|
|
| 460 |
|
| 461 |
# Convert RGBA mask to single-channel mask.
|
| 462 |
# Standard LaMa convention: 1 = remove, 0 = keep
|
| 463 |
-
#
|
|
|
|
|
|
|
|
|
|
| 464 |
alpha_channel = mask[:,:,3]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 465 |
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 471 |
else:
|
| 472 |
-
#
|
| 473 |
-
|
|
|
|
|
|
|
|
|
|
| 474 |
|
| 475 |
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
| 476 |
|
|
@@ -482,6 +505,10 @@ def process_inpaint(image, mask, invert_mask=True):
|
|
| 482 |
# Normalize: values > 0 become 1.0, 0 stays 0
|
| 483 |
# After this, 1.0 = remove, 0.0 = keep (LaMa expects this)
|
| 484 |
mask = norm_img(mask)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 485 |
|
| 486 |
res_np_img = run(image, mask)
|
| 487 |
|
|
|
|
| 460 |
|
| 461 |
# Convert RGBA mask to single-channel mask.
|
| 462 |
# Standard LaMa convention: 1 = remove, 0 = keep
|
| 463 |
+
# The mask can come in different formats:
|
| 464 |
+
# - RGBA with alpha channel encoding (alpha=0 means remove when invert_mask=True)
|
| 465 |
+
# - RGBA with RGB encoding (white/colored areas mean remove)
|
| 466 |
+
|
| 467 |
alpha_channel = mask[:,:,3]
|
| 468 |
+
rgb_channels = mask[:,:,:3]
|
| 469 |
+
|
| 470 |
+
# Check if alpha channel is meaningful (not all 255)
|
| 471 |
+
alpha_mean = alpha_channel.mean()
|
| 472 |
|
| 473 |
+
if alpha_mean < 50:
|
| 474 |
+
# Alpha channel is mostly transparent - use alpha directly
|
| 475 |
+
# Transparent (0) = remove, Opaque (255) = keep
|
| 476 |
+
if invert_mask:
|
| 477 |
+
mask = 255 - alpha_channel # transparent → white (remove)
|
| 478 |
+
else:
|
| 479 |
+
mask = alpha_channel # opaque → white (remove)
|
| 480 |
+
elif alpha_mean > 200:
|
| 481 |
+
# Alpha channel is mostly opaque - check RGB channels for paint colors
|
| 482 |
+
# Detect magenta (255, 0, 255) or any bright colored paint
|
| 483 |
+
gray = cv2.cvtColor(rgb_channels, cv2.COLOR_RGB2GRAY)
|
| 484 |
+
# White or bright colors (>200) in RGB = remove
|
| 485 |
+
mask_rgb = (gray > 200).astype(np.uint8) * 255
|
| 486 |
+
# Also detect magenta specifically
|
| 487 |
+
magenta = np.all(rgb_channels == [255, 0, 255], axis=2).astype(np.uint8) * 255
|
| 488 |
+
mask = np.maximum(mask_rgb, magenta)
|
| 489 |
+
if not invert_mask:
|
| 490 |
+
mask = 255 - mask # invert if needed
|
| 491 |
else:
|
| 492 |
+
# Mixed alpha - use alpha channel with inversion logic
|
| 493 |
+
if invert_mask:
|
| 494 |
+
mask = 255 - alpha_channel
|
| 495 |
+
else:
|
| 496 |
+
mask = alpha_channel
|
| 497 |
|
| 498 |
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
| 499 |
|
|
|
|
| 505 |
# Normalize: values > 0 become 1.0, 0 stays 0
|
| 506 |
# After this, 1.0 = remove, 0.0 = keep (LaMa expects this)
|
| 507 |
mask = norm_img(mask)
|
| 508 |
+
|
| 509 |
+
# Final check: ensure we have some pixels to remove
|
| 510 |
+
mask_final_pixels = int((mask > 0.5).sum())
|
| 511 |
+
print(f"After normalization: {mask_final_pixels} pixels marked for removal (value > 0.5)")
|
| 512 |
|
| 513 |
res_np_img = run(image, mask)
|
| 514 |
|