LogicGoInfotechSpaces committed on
Commit
0964a65
·
1 Parent(s): dac0915

fix(mask): improve mask detection - handle RGB paint directly, detect magenta and white areas; match reference model behavior

Browse files
Files changed (2) hide show
  1. api/main.py +9 -3
  2. src/core.py +35 -8
api/main.py CHANGED
@@ -285,10 +285,16 @@ def inpaint_multipart(
285
  nonzero = int((binmask > 0).sum())
286
  log.info("fallback detection: %d pixels", nonzero)
287
 
288
- # Build RGBA mask where painted area has alpha=0 (to be removed)
 
289
  mask_rgba = np.zeros((binmask.shape[0], binmask.shape[1], 4), dtype=np.uint8)
290
- mask_rgba[:, :, 3] = np.where(binmask > 0, 0, 255).astype(np.uint8)
291
- log.info("Final mask: %d pixels marked for removal (alpha=0)", int((mask_rgba[:,:,3] == 0).sum()))
 
 
 
 
 
292
  else:
293
  mask_rgba = _load_rgba_mask_from_image(m)
294
 
 
285
  nonzero = int((binmask > 0).sum())
286
  log.info("fallback detection: %d pixels", nonzero)
287
 
288
+ # Build RGBA mask: painted areas should be white in RGB for direct detection
289
+ # Use RGB channels with white=remove, black=keep, then set alpha appropriately
290
  mask_rgba = np.zeros((binmask.shape[0], binmask.shape[1], 4), dtype=np.uint8)
291
+ # Paint detected areas as white in RGB (will be detected in process_inpaint)
292
+ mask_rgba[:, :, 0] = binmask # R
293
+ mask_rgba[:, :, 1] = binmask # G
294
+ mask_rgba[:, :, 2] = binmask # B
295
+ # Set alpha to opaque so RGB channels are used
296
+ mask_rgba[:, :, 3] = 255
297
+ log.info("Final mask: %d pixels marked for removal (white in RGB)", int((binmask > 0).sum()))
298
  else:
299
  mask_rgba = _load_rgba_mask_from_image(m)
300
 
src/core.py CHANGED
@@ -460,17 +460,40 @@ def process_inpaint(image, mask, invert_mask=True):
460
 
461
  # Convert RGBA mask to single-channel mask.
462
  # Standard LaMa convention: 1 = remove, 0 = keep
463
- # User draws with alpha=0 (transparent), we want those to become 1 (remove)
 
 
 
464
  alpha_channel = mask[:,:,3]
 
 
 
 
465
 
466
- # When invert_mask=True: alpha=0 (painted/transparent) → 255 → 1 (remove)
467
- # When invert_mask=False: alpha=255 (opaque) → 255 → 1 (remove)
468
- if invert_mask:
469
- # Inverted: transparent (0) means remove, opaque (255) means keep
470
- mask = 255 - alpha_channel
 
 
 
 
 
 
 
 
 
 
 
 
 
471
  else:
472
- # Normal: opaque (255) means remove, transparent (0) means keep
473
- mask = alpha_channel
 
 
 
474
 
475
  mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
476
 
@@ -482,6 +505,10 @@ def process_inpaint(image, mask, invert_mask=True):
482
  # Normalize: values > 0 become 1.0, 0 stays 0
483
  # After this, 1.0 = remove, 0.0 = keep (LaMa expects this)
484
  mask = norm_img(mask)
 
 
 
 
485
 
486
  res_np_img = run(image, mask)
487
 
 
460
 
461
  # Convert RGBA mask to single-channel mask.
462
  # Standard LaMa convention: 1 = remove, 0 = keep
463
+ # The mask can come in different formats:
464
+ # - RGBA with alpha channel encoding (alpha=0 means remove when invert_mask=True)
465
+ # - RGBA with RGB encoding (white/colored areas mean remove)
466
+
467
  alpha_channel = mask[:,:,3]
468
+ rgb_channels = mask[:,:,:3]
469
+
470
+ # Check if alpha channel is meaningful (not all 255)
471
+ alpha_mean = alpha_channel.mean()
472
 
473
+ if alpha_mean < 50:
474
+ # Alpha channel is mostly transparent - use alpha directly
475
+ # With invert_mask=True: transparent (0) = remove, opaque (255) = keep (reversed when invert_mask=False)
476
+ if invert_mask:
477
+ mask = 255 - alpha_channel # transparent → white (remove)
478
+ else:
479
+ mask = alpha_channel # opaque → white (remove)
480
+ elif alpha_mean > 200:
481
+ # Alpha channel is mostly opaque - check RGB channels for paint colors
482
+ # Detect magenta (255, 0, 255) or any bright colored paint
483
+ gray = cv2.cvtColor(rgb_channels, cv2.COLOR_RGB2GRAY)
484
+ # White or bright colors (>200) in RGB = remove
485
+ mask_rgb = (gray > 200).astype(np.uint8) * 255
486
+ # Also detect magenta specifically
487
+ magenta = np.all(rgb_channels == [255, 0, 255], axis=2).astype(np.uint8) * 255
488
+ mask = np.maximum(mask_rgb, magenta)
489
+ if not invert_mask:
490
+ mask = 255 - mask # invert if needed
491
  else:
492
+ # Mixed alpha - use alpha channel with inversion logic
493
+ if invert_mask:
494
+ mask = 255 - alpha_channel
495
+ else:
496
+ mask = alpha_channel
497
 
498
  mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
499
 
 
505
  # Normalize: values > 0 become 1.0, 0 stays 0
506
  # After this, 1.0 = remove, 0.0 = keep (LaMa expects this)
507
  mask = norm_img(mask)
508
+
509
+ # Final check: ensure we have some pixels to remove
510
+ mask_final_pixels = int((mask > 0.5).sum())
511
+ print(f"After normalization: {mask_final_pixels} pixels marked for removal (value > 0.5)")
512
 
513
  res_np_img = run(image, mask)
514