LogicGoInfotechSpaces committed on
Commit
2eaee03
·
1 Parent(s): 7f02de5

feat(api): allow painted-on mask via mask_is_painted flag (auto-diff + Otsu threshold)

Browse files
Files changed (1) hide show
  1. api/main.py +20 -2
api/main.py CHANGED
@@ -10,6 +10,7 @@ from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Header, R
10
  from fastapi.responses import FileResponse, JSONResponse
11
  from pydantic import BaseModel
12
  from PIL import Image
 
13
 
14
  from src.core import process_inpaint
15
 
@@ -179,12 +180,29 @@ def inpaint_multipart(
179
  mask: UploadFile = File(...),
180
  request: Request = None,
181
  invert_mask: bool = True,
 
182
  _: None = Depends(bearer_auth),
183
  ) -> Dict[str, str]:
184
  # Load in-memory
185
  img = Image.open(image.file).convert("RGBA")
186
- m = Image.open(mask.file)
187
- mask_rgba = _load_rgba_mask_from_image(m)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
  result = process_inpaint(np.array(img), mask_rgba, invert_mask=invert_mask)
190
  result_name = f"output_{uuid.uuid4().hex}.png"
 
10
  from fastapi.responses import FileResponse, JSONResponse
11
  from pydantic import BaseModel
12
  from PIL import Image
13
+ import cv2
14
 
15
  from src.core import process_inpaint
16
 
 
180
  mask: UploadFile = File(...),
181
  request: Request = None,
182
  invert_mask: bool = True,
183
+ mask_is_painted: bool = False, # if True, mask file is the painted-on image (e.g., black strokes on original)
184
  _: None = Depends(bearer_auth),
185
  ) -> Dict[str, str]:
186
  # Load in-memory
187
  img = Image.open(image.file).convert("RGBA")
188
+ m = Image.open(mask.file).convert("RGBA")
189
+
190
+ if mask_is_painted:
191
+ # Derive mask by differencing painted image vs original
192
+ img_rgb = cv2.cvtColor(np.array(img), cv2.COLOR_RGBA2RGB)
193
+ m_rgb = cv2.cvtColor(np.array(m), cv2.COLOR_RGBA2RGB)
194
+ diff = cv2.absdiff(img_rgb, m_rgb)
195
+ gray = cv2.cvtColor(diff, cv2.COLOR_RGB2GRAY)
196
+ # Otsu threshold for robustness; fallback threshold if Otsu fails
197
+ try:
198
+ _, binmask = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
199
+ except Exception:
200
+ _, binmask = cv2.threshold(gray, 40, 255, cv2.THRESH_BINARY)
201
+ # Build RGBA mask where selected area has alpha=0
202
+ mask_rgba = np.zeros((binmask.shape[0], binmask.shape[1], 4), dtype=np.uint8)
203
+ mask_rgba[:, :, 3] = np.where(binmask > 0, 0, 255).astype(np.uint8)
204
+ else:
205
+ mask_rgba = _load_rgba_mask_from_image(m)
206
 
207
  result = process_inpaint(np.array(img), mask_rgba, invert_mask=invert_mask)
208
  result_name = f"output_{uuid.uuid4().hex}.png"