LogicGoInfotechSpaces committed on
Commit
f854294
·
1 Parent(s): 0964a65

feat(api): auto-convert painted images to black/white masks (white=remove, black=keep) for better compatibility

Browse files
Files changed (1) hide show
  1. api/main.py +48 -21
api/main.py CHANGED
@@ -126,40 +126,56 @@ def _load_rgba_image(path: str) -> Image.Image:
126
 
def _load_rgba_mask_from_image(img: Image.Image) -> np.ndarray:
    """
    Turn a mask image into an RGBA array whose alpha channel carries the mask.

    Input convention: white (255) = area to remove, black (0) = area to keep.
    In the returned array alpha=0 means "to remove" and alpha=255 means "keep".
    (Per the original note, this is inverted later in process_inpaint when
    invert_mask=True; that code is not visible here.)

    Three cases are handled:
      * non-RGBA input: the image is converted to grayscale and thresholded
        at 128; the RGB channels of the result are left at zero.
      * RGBA input whose mean alpha exceeds 200: the alpha channel is treated
        as uninformative, so the grayscale of the RGB channels is thresholded
        instead and written into alpha (RGB values are preserved).
      * any other RGBA input: the pixel array is returned unchanged.
    """
    if img.mode == "RGBA":
        pixels = np.array(img)
        opacity = pixels[:, :, 3]

        if opacity.mean() > 200:
            # Mostly opaque everywhere: derive the mask from RGB brightness.
            luminance = cv2.cvtColor(pixels[:, :, :3], cv2.COLOR_RGB2GRAY)
            keep_alpha = np.where(luminance <= 128, 255, 0).astype(np.uint8)
            result = pixels.copy()
            result[..., 3] = keep_alpha
            removed = int(np.count_nonzero(keep_alpha == 0))
            log.info(f"Loaded RGBA mask (RGB-based): {removed} pixels marked for removal (alpha=0)")
            return result

        # The alpha channel already encodes the mask; hand the array back as-is.
        removed = int(np.count_nonzero(opacity < 128))
        log.info(f"Loaded RGBA mask (alpha-based): {removed} pixels marked for removal (alpha<128)")
        return pixels

    # RGB / grayscale / other modes: bright (>128) = remove, dark (<=128) = keep.
    luminance = np.array(img.convert("L"))
    keep_alpha = np.where(luminance <= 128, 255, 0).astype(np.uint8)
    result = np.zeros((img.height, img.width, 4), dtype=np.uint8)
    result[..., 3] = keep_alpha
    removed = int(np.count_nonzero(keep_alpha == 0))
    log.info(f"Loaded {img.mode} mask: {removed} pixels marked for removal (alpha=0)")
    return result
 
 
 
 
 
 
 
163
 
164
 
165
  @app.post("/inpaint")
@@ -285,16 +301,27 @@ def inpaint_multipart(
285
  nonzero = int((binmask > 0).sum())
286
  log.info("fallback detection: %d pixels", nonzero)
287
 
288
- # Build RGBA mask: painted areas should be white in RGB for direct detection
289
- # Use RGB channels with white=remove, black=keep, then set alpha appropriately
290
  mask_rgba = np.zeros((binmask.shape[0], binmask.shape[1], 4), dtype=np.uint8)
291
- # Paint detected areas as white in RGB (will be detected in process_inpaint)
 
292
  mask_rgba[:, :, 0] = binmask # R
293
  mask_rgba[:, :, 1] = binmask # G
294
  mask_rgba[:, :, 2] = binmask # B
295
- # Set alpha to opaque so RGB channels are used
296
  mask_rgba[:, :, 3] = 255
297
- log.info("Final mask: %d pixels marked for removal (white in RGB)", int((binmask > 0).sum()))
 
 
 
 
 
 
 
 
 
 
298
  else:
299
  mask_rgba = _load_rgba_mask_from_image(m)
300
 
 
126
 
def _load_rgba_mask_from_image(img: Image.Image) -> np.ndarray:
    """
    Normalise a mask image into a fully opaque black/white RGBA array.

    Standard convention: white (255) = area to remove, black (0) = area to keep.
    The returned array has the mask replicated across the R, G and B channels
    and alpha set to 255 everywhere.

    How the mask is derived depends on the input:
      * non-RGBA input: grayscale value > 128 counts as white (remove).
      * RGBA input with mean alpha > 200: grayscale of RGB > 128 counts as
        white, and pixels that are exactly magenta (255, 0, 255) are also
        marked white.
      * any other RGBA input: pixels with alpha < 128 count as white.
    """

    def _paint(canvas: np.ndarray, mask: np.ndarray) -> np.ndarray:
        # Write the single-channel mask into R, G and B and force alpha opaque.
        canvas[..., :3] = mask[..., np.newaxis]
        canvas[..., 3] = 255
        return canvas

    if img.mode != "RGBA":
        # RGB / grayscale / other modes: threshold the luminance.
        luminance = np.array(img.convert("L"))
        mask_bw = np.where(luminance > 128, 255, 0).astype(np.uint8)
        blank = np.zeros((img.height, img.width, 4), dtype=np.uint8)
        out = _paint(blank, mask_bw)
        log.info(f"Loaded {img.mode} mask: {int(np.count_nonzero(mask_bw))} white pixels (to remove)")
        return out

    pixels = np.array(img)
    opacity = pixels[:, :, 3]
    colour = pixels[:, :, :3]

    if opacity.mean() > 200:
        # Alpha is mostly opaque, so it carries no mask; use the RGB content.
        bright = cv2.cvtColor(colour, cv2.COLOR_RGB2GRAY) > 128
        # Pure magenta is below the brightness threshold, so it is matched explicitly.
        is_magenta = np.all(colour == [255, 0, 255], axis=2)
        mask_bw = np.where(bright | is_magenta, 255, 0).astype(np.uint8)
        out = _paint(pixels.copy(), mask_bw)
        log.info(f"Loaded RGBA mask (RGB-based): {int(np.count_nonzero(mask_bw))} white pixels (to remove)")
        return out

    # Alpha encodes the mask: transparent (alpha < 128) = remove, opaque = keep.
    mask_bw = np.where(opacity < 128, 255, 0).astype(np.uint8)
    out = _paint(pixels.copy(), mask_bw)
    log.info(f"Loaded RGBA mask (alpha-based): {int(np.count_nonzero(mask_bw))} white pixels (to remove)")
    return out
179
 
180
 
181
  @app.post("/inpaint")
 
301
  nonzero = int((binmask > 0).sum())
302
  log.info("fallback detection: %d pixels", nonzero)
303
 
304
+ # Build RGBA mask: convert to proper black/white mask
305
+ # White (255) = remove, Black (0) = keep (standard convention)
306
  mask_rgba = np.zeros((binmask.shape[0], binmask.shape[1], 4), dtype=np.uint8)
307
+
308
+ # Set RGB channels: white where paint detected, black elsewhere
309
  mask_rgba[:, :, 0] = binmask # R
310
  mask_rgba[:, :, 1] = binmask # G
311
  mask_rgba[:, :, 2] = binmask # B
312
+ # Set alpha to opaque so it's treated as a standard RGB mask
313
  mask_rgba[:, :, 3] = 255
314
+
315
+ # Also create a cleaner version: apply morphological operations to smooth edges
316
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
317
+ binmask_clean = cv2.morphologyEx(binmask, cv2.MORPH_CLOSE, kernel)
318
+ binmask_clean = cv2.morphologyEx(binmask_clean, cv2.MORPH_OPEN, kernel)
319
+ mask_rgba[:, :, 0] = binmask_clean
320
+ mask_rgba[:, :, 1] = binmask_clean
321
+ mask_rgba[:, :, 2] = binmask_clean
322
+
323
+ log.info("Auto-converted painted image to black/white mask: %d white pixels (to remove)",
324
+ int((binmask_clean > 0).sum()))
325
  else:
326
  mask_rgba = _load_rgba_mask_from_image(m)
327