Commit
·
e7611db
1
Parent(s):
5452c3d
fix(mask): improve mask loading to match standard convention (white=remove); add debug logs
Browse files
- api/main.py +18 -4
- src/core.py +7 -1
api/main.py
CHANGED
|
@@ -125,16 +125,30 @@ def _load_rgba_image(path: str) -> Image.Image:
|
|
| 125 |
|
| 126 |
|
| 127 |
def _load_rgba_mask_from_image(img: Image.Image) -> np.ndarray:
|
| 128 |
-
#
|
|
|
|
| 129 |
if img.mode != "RGBA":
|
| 130 |
-
#
|
| 131 |
gray = img.convert("L")
|
| 132 |
arr = np.array(gray)
|
| 133 |
-
|
|
|
|
| 134 |
rgba = np.zeros((img.height, img.width, 4), dtype=np.uint8)
|
| 135 |
rgba[:, :, 3] = alpha
|
| 136 |
return rgba
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
|
| 139 |
|
| 140 |
@app.post("/inpaint")
|
|
|
|
| 125 |
|
| 126 |
|
| 127 |
def _load_rgba_mask_from_image(img: Image.Image) -> np.ndarray:
    """Convert a user-supplied mask image into an RGBA mask array.

    Input convention: white means "remove", black means "keep".
    Output convention: the alpha channel (index 3) carries the mask, with
    alpha=0 marking pixels to remove and alpha=255 marking pixels to keep.

    Three cases are handled:

    * Non-RGBA input: the image is converted to grayscale and thresholded;
      the RGB channels of the result are all zero.
    * RGBA input whose alpha channel averages above 200: alpha is assumed
      to carry no mask information, so the RGB channels are converted to
      grayscale and thresholded instead. The original RGB values are kept.
    * Any other RGBA input: the array is returned unchanged, i.e. the
      existing alpha channel is taken to be the mask already.

    NOTE(review): an alpha-encoded mask with only a small transparent
    region also has a mean alpha above 200 and therefore takes the RGB
    path -- confirm this is the intended behaviour for small strokes.

    Args:
        img: Mask image in any mode that supports ``convert("L")``.

    Returns:
        Array of shape (height, width, 4). It is uint8 on the thresholded
        paths; on the pass-through path the dtype is whatever
        ``np.array(img)`` yields (presumably uint8 for RGBA -- verify).
    """

    def _alpha_from_luma(luma):
        # Bright pixels (>128) are "remove" -> alpha 0; the rest are kept.
        return np.where(luma > 128, 0, 255).astype(np.uint8)

    if img.mode == "RGBA":
        pixels = np.array(img)

        # Mostly opaque alpha: derive the mask from the colour channels.
        if pixels[:, :, 3].mean() > 200:
            luma = cv2.cvtColor(pixels[:, :, :3], cv2.COLOR_RGB2GRAY)
            result = pixels.copy()
            result[:, :, 3] = _alpha_from_luma(luma)
            return result

        # Otherwise the alpha channel is already the mask.
        return pixels

    # RGB / grayscale / other modes: threshold the luminance.
    luma = np.array(img.convert("L"))
    result = np.zeros((img.height, img.width, 4), dtype=np.uint8)
    result[:, :, 3] = _alpha_from_luma(luma)
    return result
|
| 152 |
|
| 153 |
|
| 154 |
@app.post("/inpaint")
|
src/core.py
CHANGED
|
@@ -459,10 +459,16 @@ def process_inpaint(image, mask, invert_mask=True):
|
|
| 459 |
image = norm_img(image)
|
| 460 |
|
| 461 |
# Convert RGBA mask to single-channel mask.
|
| 462 |
-
#
|
|
|
|
| 463 |
alpha_channel = mask[:,:,3]
|
| 464 |
mask = (255 - alpha_channel) if invert_mask else alpha_channel
|
| 465 |
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 466 |
mask = norm_img(mask)
|
| 467 |
|
| 468 |
res_np_img = run(image, mask)
|
|
|
|
| 459 |
image = norm_img(image)
|
| 460 |
|
| 461 |
# Convert RGBA mask to single-channel mask.
|
| 462 |
+
# Standard: white=remove (255), black=keep (0)
|
| 463 |
+
# When invert_mask=True (default): alpha=0 (transparent/painted) → 255 (remove), alpha=255 → 0 (keep)
|
| 464 |
alpha_channel = mask[:,:,3]
|
| 465 |
mask = (255 - alpha_channel) if invert_mask else alpha_channel
|
| 466 |
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
| 467 |
+
|
| 468 |
+
# Debug: log mask statistics
|
| 469 |
+
mask_nonzero = int((mask > 128).sum())
|
| 470 |
+
print(f"Mask shape: {mask.shape}, non-zero pixels (>128): {mask_nonzero}")
|
| 471 |
+
|
| 472 |
mask = norm_img(mask)
|
| 473 |
|
| 474 |
res_np_img = run(image, mask)
|