LogicGoInfotechSpaces commited on
Commit
e7611db
·
1 Parent(s): 5452c3d

fix(mask): improve mask loading to match standard convention (white=remove); add debug logs

Browse files
Files changed (2) hide show
  1. api/main.py +18 -4
  2. src/core.py +7 -1
api/main.py CHANGED
@@ -125,16 +125,30 @@ def _load_rgba_image(path: str) -> Image.Image:
125
 
126
 
127
  def _load_rgba_mask_from_image(img: Image.Image) -> np.ndarray:
128
- # Expected by process_inpaint: RGBA where alpha=0 for drawn (to remove), 255 elsewhere
 
129
  if img.mode != "RGBA":
130
- # If no alpha, treat non-black/white>0 as masked areas
131
  gray = img.convert("L")
132
  arr = np.array(gray)
133
- alpha = np.where(arr > 0, 0, 255).astype(np.uint8)
 
134
  rgba = np.zeros((img.height, img.width, 4), dtype=np.uint8)
135
  rgba[:, :, 3] = alpha
136
  return rgba
137
- return np.array(img)
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
 
140
  @app.post("/inpaint")
 
125
 
126
 
127
  def _load_rgba_mask_from_image(img: Image.Image) -> np.ndarray:
128
+ # Standard convention: white=remove (255), black=keep (0)
129
+ # Convert to RGBA where alpha=0 means "to remove", alpha=255 means "keep"
130
  if img.mode != "RGBA":
131
+ # For RGB/Grayscale masks: white (value>128) = remove, black (value<=128) = keep
132
  gray = img.convert("L")
133
  arr = np.array(gray)
134
+ # White pixels (>128) should have alpha=0 (to remove), black pixels (<=128) alpha=255 (keep)
135
+ alpha = np.where(arr > 128, 0, 255).astype(np.uint8)
136
  rgba = np.zeros((img.height, img.width, 4), dtype=np.uint8)
137
  rgba[:, :, 3] = alpha
138
  return rgba
139
+ # For RGBA: check if alpha channel is used or RGB channels
140
+ arr = np.array(img)
141
+ alpha = arr[:, :, 3]
142
+ # If alpha is mostly opaque (mean > 200), treat RGB channels as mask values
143
+ if alpha.mean() > 200:
144
+ # Use RGB to determine mask: white in RGB = remove
145
+ gray = cv2.cvtColor(arr[:, :, :3], cv2.COLOR_RGB2GRAY)
146
+ alpha = np.where(gray > 128, 0, 255).astype(np.uint8)
147
+ rgba = arr.copy()
148
+ rgba[:, :, 3] = alpha
149
+ return rgba
150
+ # Alpha channel already encodes the mask
151
+ return arr
152
 
153
 
154
  @app.post("/inpaint")
src/core.py CHANGED
@@ -459,10 +459,16 @@ def process_inpaint(image, mask, invert_mask=True):
459
  image = norm_img(image)
460
 
461
  # Convert RGBA mask to single-channel mask.
462
- # When invert_mask=True (default), areas the user paints (alpha=0) become 255 (to remove).
 
463
  alpha_channel = mask[:,:,3]
464
  mask = (255 - alpha_channel) if invert_mask else alpha_channel
465
  mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
 
 
 
 
 
466
  mask = norm_img(mask)
467
 
468
  res_np_img = run(image, mask)
 
459
  image = norm_img(image)
460
 
461
  # Convert RGBA mask to single-channel mask.
462
+ # Standard: white=remove (255), black=keep (0)
463
+ # When invert_mask=True (default): alpha=0 (transparent/painted) → 255 (remove), alpha=255 → 0 (keep)
464
  alpha_channel = mask[:,:,3]
465
  mask = (255 - alpha_channel) if invert_mask else alpha_channel
466
  mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
467
+
468
+ # Debug: log mask statistics
469
+ mask_nonzero = int((mask > 128).sum())
470
+ print(f"Mask shape: {mask.shape}, non-zero pixels (>128): {mask_nonzero}")
471
+
472
  mask = norm_img(mask)
473
 
474
  res_np_img = run(image, mask)