fix(api): remove black paint detection - just auto-detect pink segments, everything else kept automatically
87c6da8
| import os | |
| import uuid | |
| import shutil | |
| from datetime import datetime | |
| from typing import Dict, List, Optional | |
| import numpy as np | |
| from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Header, Request | |
| from fastapi.responses import FileResponse, JSONResponse | |
| from pydantic import BaseModel | |
| from PIL import Image | |
| import cv2 | |
| import logging | |
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("api")
from src.core import process_inpaint

# Directories (use writable space on HF Spaces).
# DATA_DIR overrides the default persistent location "/data". If that directory is
# missing OR not writable, fall back to /tmp so the os.makedirs calls below cannot
# fail at import time.
BASE_DIR = os.environ.get("DATA_DIR", "/data")
if not (os.path.isdir(BASE_DIR) and os.access(BASE_DIR, os.W_OK)):
    log.warning("Data directory %s is missing or not writable; falling back to /tmp", BASE_DIR)
    BASE_DIR = "/tmp"
UPLOAD_DIR = os.path.join(BASE_DIR, "uploads")
OUTPUT_DIR = os.path.join(BASE_DIR, "outputs")
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Optional Bearer token: set env API_TOKEN to require auth; if unset (or empty), endpoints are open.
ENV_TOKEN = os.environ.get("API_TOKEN")

app = FastAPI(title="Photo Object Removal API", version="1.0.0")

# In-memory stores: module-level, so they are per-process and lost on restart.
# file_store: upload id -> metadata ("type", "filename", "stored_name", "path", "timestamp").
file_store: Dict[str, Dict[str, str]] = {}
# logs: upload/result events in insertion order; some entries hold non-string values
# (e.g. the integer "pink_pixels" count written by the remove-pink handler).
logs: List[Dict[str, object]] = []
def bearer_auth(authorization: Optional[str] = Header(default=None)) -> None:
    """FastAPI dependency enforcing the optional Bearer token.

    When the API_TOKEN environment variable is unset or empty, every request is
    allowed. Otherwise the request must send ``Authorization: Bearer <token>``
    with a token equal to API_TOKEN.

    Raises:
        HTTPException(401): header missing or not a Bearer header.
        HTTPException(403): Bearer token present but wrong.
    """
    import hmac  # constant-time comparison for the secret token

    if not ENV_TOKEN:
        return
    if authorization is None or not authorization.lower().startswith("bearer "):
        raise HTTPException(status_code=401, detail="Unauthorized")
    # RFC 6750 allows one or more spaces after the scheme; strip them so they are
    # not treated as part of the token.
    token = authorization.split(" ", 1)[1].strip()
    # compare_digest does not leak, via timing, how many leading characters matched.
    # Comparing bytes avoids a TypeError on non-ASCII header values.
    if not hmac.compare_digest(token.encode("utf-8"), ENV_TOKEN.encode("utf-8")):
        raise HTTPException(status_code=403, detail="Forbidden")
class InpaintRequest(BaseModel):
    """JSON body for the id-based inpaint handlers (inpaint / inpaint_url).

    Both ids must refer to entries in ``file_store`` created by upload_image and
    upload_mask respectively; otherwise the handlers answer 404.
    """

    image_id: str  # id returned by upload_image (file_store entry with type "image")
    mask_id: str  # id returned by upload_mask (file_store entry with type "mask")
    invert_mask: bool = True  # True => selected/painted area is removed; passed to process_inpaint as invert_mask
    passthrough: bool = False  # If True, return the original image unchanged (process_inpaint is skipped)
class SimpleRemoveRequest(BaseModel):
    """JSON body carrying a single image id for pink/magenta removal.

    NOTE(review): no handler in this module references this model;
    remove_pink_segments takes a multipart upload instead. Confirm whether it is
    still needed.
    """

    image_id: str  # Image with pink/magenta segments to remove
def root() -> Dict[str, object]:
    """Return a static, self-describing index of the service.

    The payload holds the service name, an "ok" status, a mapping of
    "METHOD /path" to a short usage hint, and a note on enabling auth.
    """
    route_hints = (
        ("GET /health", "health check"),
        ("POST /upload-image", "form-data: image=file"),
        ("POST /upload-mask", "form-data: mask=file"),
        ("POST /inpaint", "JSON: {image_id, mask_id}"),
        ("POST /inpaint-multipart", "form-data: image=file, mask=file"),
        ("POST /remove-pink", "form-data: image=file (auto-detects pink segments and removes them)"),
        ("GET /download/{filename}", "download result image"),
        ("GET /result/{filename}", "view result image in browser"),
        ("GET /logs", "recent uploads/results"),
    )
    auth_hint = "set API_TOKEN env var to require Authorization: Bearer <token> (except /health)"
    index: Dict[str, object] = dict(
        name="Photo Object Removal API",
        status="ok",
        endpoints=dict(route_hints),
        auth=auth_hint,
    )
    return index
def health() -> Dict[str, str]:
    """Liveness probe: unconditionally report the service as healthy."""
    payload: Dict[str, str] = dict(status="healthy")
    return payload
def upload_image(image: UploadFile = File(...), _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Persist an uploaded image under UPLOAD_DIR and register it in ``file_store``.

    The file is stored as ``<uuid><ext>``. ``ext`` is reused from the client
    filename only when it is a plain ASCII alphanumeric extension; otherwise
    ``.png`` is used.

    Returns:
        ``{"id": <uuid>, "filename": <client filename, or "" if none was sent>}``.
        Pass ``id`` as ``image_id`` to the id-based inpaint handlers.
    """
    # UploadFile.filename may be None; os.path.splitext(None) would raise TypeError (HTTP 500).
    original_name = image.filename or ""
    ext = os.path.splitext(original_name)[1]
    # The extension comes from untrusted input and ends up in a filesystem path, so
    # only simple ones such as ".png" or ".JPEG" are kept.
    if not (ext[1:].isascii() and ext[1:].isalnum()):
        ext = ".png"
    file_id = str(uuid.uuid4())
    stored_name = f"{file_id}{ext}"
    stored_path = os.path.join(UPLOAD_DIR, stored_name)
    with open(stored_path, "wb") as f:
        shutil.copyfileobj(image.file, f)
    # One timestamp so the store entry and the log entry agree exactly.
    timestamp = datetime.utcnow().isoformat()
    file_store[file_id] = {
        "type": "image",
        "filename": original_name,
        "stored_name": stored_name,
        "path": stored_path,
        "timestamp": timestamp,
    }
    logs.append({"id": file_id, "filename": original_name, "type": "image", "timestamp": timestamp})
    return {"id": file_id, "filename": original_name}
def upload_mask(mask: UploadFile = File(...), _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Persist an uploaded mask under UPLOAD_DIR and register it in ``file_store``.

    The file is stored as ``<uuid><ext>``. ``ext`` is reused from the client
    filename only when it is a plain ASCII alphanumeric extension; otherwise
    ``.png`` is used.

    Returns:
        ``{"id": <uuid>, "filename": <client filename, or "" if none was sent>}``.
        Pass ``id`` as ``mask_id`` to the id-based inpaint handlers.
    """
    # UploadFile.filename may be None; os.path.splitext(None) would raise TypeError (HTTP 500).
    original_name = mask.filename or ""
    ext = os.path.splitext(original_name)[1]
    # The extension comes from untrusted input and ends up in a filesystem path, so
    # only simple ones such as ".png" or ".JPEG" are kept.
    if not (ext[1:].isascii() and ext[1:].isalnum()):
        ext = ".png"
    file_id = str(uuid.uuid4())
    stored_name = f"{file_id}{ext}"
    stored_path = os.path.join(UPLOAD_DIR, stored_name)
    with open(stored_path, "wb") as f:
        shutil.copyfileobj(mask.file, f)
    # One timestamp so the store entry and the log entry agree exactly.
    timestamp = datetime.utcnow().isoformat()
    file_store[file_id] = {
        "type": "mask",
        "filename": original_name,
        "stored_name": stored_name,
        "path": stored_path,
        "timestamp": timestamp,
    }
    logs.append({"id": file_id, "filename": original_name, "type": "mask", "timestamp": timestamp})
    return {"id": file_id, "filename": original_name}
def _load_rgba_image(path: str) -> Image.Image:
    """Open the image file at ``path`` and return an RGBA copy of it.

    ``convert`` returns a converted copy, so the source file is closed before
    returning instead of waiting for garbage collection.

    Raises:
        FileNotFoundError: ``path`` does not exist.
        OSError: Pillow cannot decode the file (``PIL.UnidentifiedImageError``
            is an ``OSError`` subclass).
    """
    with Image.open(path) as img:
        return img.convert("RGBA")
def _load_rgba_mask_from_image(img: Image.Image) -> np.ndarray:
    """Normalise an uploaded mask into an opaque black/white RGBA array.

    Result convention: 255 in R, G and B marks a pixel to remove, 0 marks a pixel
    to keep, and alpha is always 255. How the input is interpreted depends on it:

    * any mode other than RGBA: luminance (``convert("L")``) above 128 -> remove;
    * RGBA whose mean alpha is above 200: RGB luminance above 128, or an exact
      (255, 0, 255) magenta pixel -> remove;
    * any other RGBA: pixels with alpha below 128 (transparent) -> remove.

    Returns:
        uint8 array of shape (height, width, 4).
    """

    def _opaque_bw(bw: np.ndarray) -> np.ndarray:
        # Copy the single-channel 0/255 mask into R, G and B and force alpha to 255.
        out = np.empty((bw.shape[0], bw.shape[1], 4), dtype=np.uint8)
        out[:, :, :3] = bw[:, :, np.newaxis]
        out[:, :, 3] = 255
        return out

    if img.mode != "RGBA":
        # RGB / grayscale / other modes: bright pixels mean "remove".
        luminance = np.array(img.convert("L"))
        bw = (luminance > 128).astype(np.uint8) * 255
        log.info(f"Loaded {img.mode} mask: {int((bw > 0).sum())} white pixels (to remove)")
        return _opaque_bw(bw)

    pixels = np.array(img)
    alpha = pixels[:, :, 3]
    if alpha.mean() > 200:
        # Mostly opaque: the RGB channels carry the mask.
        rgb = pixels[:, :, :3]
        bright = (cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY) > 128).astype(np.uint8) * 255
        # Pure magenta has luminance below 128, so it is added explicitly.
        exact_magenta = np.all(rgb == [255, 0, 255], axis=2).astype(np.uint8) * 255
        bw = np.maximum(bright, exact_magenta)
        log.info(f"Loaded RGBA mask (RGB-based): {int((bw > 0).sum())} white pixels (to remove)")
        return _opaque_bw(bw)

    # The alpha channel carries the mask: transparent (alpha < 128) means "remove".
    bw = (alpha < 128).astype(np.uint8) * 255
    log.info(f"Loaded RGBA mask (alpha-based): {int((bw > 0).sum())} white pixels (to remove)")
    return _opaque_bw(bw)
def inpaint(req: InpaintRequest, _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Inpaint a previously uploaded image using a previously uploaded mask.

    Args:
        req: ids from upload_image / upload_mask plus the ``invert_mask`` and
            ``passthrough`` switches.

    Returns:
        ``{"result": <output filename>}``; the file lives in OUTPUT_DIR and is
        served by the download/result routes.

    Raises:
        HTTPException(404): an id is unknown, has the wrong type, or its stored file is gone.
        HTTPException(400): a stored upload is not a readable image.
    """
    image_entry = file_store.get(req.image_id)
    if image_entry is None or image_entry["type"] != "image":
        raise HTTPException(status_code=404, detail="image_id not found")
    mask_entry = file_store.get(req.mask_id)
    if mask_entry is None or mask_entry["type"] != "mask":
        raise HTTPException(status_code=404, detail="mask_id not found")
    try:
        img_rgba = _load_rgba_image(image_entry["path"])
        # Close the mask file once decoded; the helper returns a standalone numpy array.
        with Image.open(mask_entry["path"]) as mask_img:  # may be RGB/gray/RGBA
            mask_rgba = _load_rgba_mask_from_image(mask_img)
    except FileNotFoundError as err:
        raise HTTPException(status_code=404, detail="uploaded file no longer available") from err
    except OSError as err:  # includes PIL.UnidentifiedImageError
        raise HTTPException(status_code=400, detail="uploaded file is not a readable image") from err
    if req.passthrough:
        result = np.array(img_rgba.convert("RGB"))
    else:
        result = process_inpaint(np.array(img_rgba), mask_rgba, invert_mask=req.invert_mask)
    result_name = f"output_{uuid.uuid4().hex}.png"
    result_path = os.path.join(OUTPUT_DIR, result_name)
    Image.fromarray(result).save(result_path)
    logs.append({"result": result_name, "timestamp": datetime.utcnow().isoformat()})
    return {"result": result_name}
def inpaint_url(req: InpaintRequest, request: Request, _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Same processing as ``inpaint``, but the JSON response also carries a download URL.

    Args:
        req: ids from upload_image / upload_mask plus the ``invert_mask`` and
            ``passthrough`` switches.
        request: incoming request, used to build the absolute ``download_file`` URL.

    Returns:
        ``{"result": <output filename>, "url": <download URL>}``.

    Raises:
        HTTPException(404): an id is unknown, has the wrong type, or its stored file is gone.
        HTTPException(400): a stored upload is not a readable image.
    """
    image_entry = file_store.get(req.image_id)
    if image_entry is None or image_entry["type"] != "image":
        raise HTTPException(status_code=404, detail="image_id not found")
    mask_entry = file_store.get(req.mask_id)
    if mask_entry is None or mask_entry["type"] != "mask":
        raise HTTPException(status_code=404, detail="mask_id not found")
    try:
        img_rgba = _load_rgba_image(image_entry["path"])
        # Close the mask file once decoded; the helper returns a standalone numpy array.
        with Image.open(mask_entry["path"]) as mask_img:  # may be RGB/gray/RGBA
            mask_rgba = _load_rgba_mask_from_image(mask_img)
    except FileNotFoundError as err:
        raise HTTPException(status_code=404, detail="uploaded file no longer available") from err
    except OSError as err:  # includes PIL.UnidentifiedImageError
        raise HTTPException(status_code=400, detail="uploaded file is not a readable image") from err
    if req.passthrough:
        result = np.array(img_rgba.convert("RGB"))
    else:
        result = process_inpaint(np.array(img_rgba), mask_rgba, invert_mask=req.invert_mask)
    result_name = f"output_{uuid.uuid4().hex}.png"
    result_path = os.path.join(OUTPUT_DIR, result_name)
    Image.fromarray(result).save(result_path)
    url = str(request.url_for("download_file", filename=result_name))
    logs.append({"result": result_name, "url": url, "timestamp": datetime.utcnow().isoformat()})
    return {"result": result_name, "url": url}
def inpaint_multipart(
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
    request: Request = None,
    invert_mask: bool = True,
    mask_is_painted: bool = False,  # if True, mask file is the image with pink/magenta strokes painted on it
    passthrough: bool = False,
    _: None = Depends(bearer_auth),
) -> Dict[str, str]:
    """One-shot inpaint: image and mask arrive as multipart uploads.

    Modes:
      * ``passthrough=True``: the mask is ignored and the input image is saved unchanged.
      * ``mask_is_painted=True``: ``mask`` is a copy of the image with pink/magenta
        strokes. A removal mask is built from near-magenta pixels (R>240, G<30,
        B>240), OR'd with pixels whose grayscale difference from ``image`` exceeds
        50 (only when both have the same size). The mask is then cleaned with a
        5x5 elliptical close (2 iterations) and open (1 iteration).
        - If fewer than 50 pixels remain, exact (255, 0, 255) pixels closed 3 times
          are used instead.
        - If that also yields fewer than 50 pixels, the original image is saved and
          the response carries ``"error"``.
      * otherwise ``mask`` is interpreted by ``_load_rgba_mask_from_image``.

    Returns:
        ``{"result": <output filename>}``, plus ``"url"`` when the download route
        resolves and ``"error"`` when paint detection failed. Every saved output,
        including the failure case, is recorded in ``logs``.

    Raises:
        HTTPException(400): either upload is not a readable image.
    """
    # Load in-memory; reject non-image uploads with 400 instead of an unhandled 500.
    try:
        img = Image.open(image.file).convert("RGBA")
        m = Image.open(mask.file).convert("RGBA")
    except OSError as err:  # PIL.UnidentifiedImageError is an OSError subclass
        raise HTTPException(status_code=400, detail="image and mask must be readable image files") from err

    def _publish(result: np.ndarray, error: Optional[str] = None) -> Dict[str, str]:
        # Save the output under OUTPUT_DIR, record it in `logs`, and build the response.
        result_name = f"output_{uuid.uuid4().hex}.png"
        Image.fromarray(result).save(os.path.join(OUTPUT_DIR, result_name))
        url: Optional[str] = None
        try:
            if request is not None:
                url = str(request.url_for("download_file", filename=result_name))
        except Exception:
            # Best effort: if the download route cannot be resolved, the response omits "url".
            url = None
        entry: Dict[str, str] = {"result": result_name, "timestamp": datetime.utcnow().isoformat()}
        resp: Dict[str, str] = {"result": result_name}
        if url:
            entry["url"] = url
            resp["url"] = url
        if error is not None:
            entry["error"] = error
            resp["error"] = error
        logs.append(entry)
        return resp

    if passthrough:
        # Just echo the input image, ignore mask
        return _publish(np.array(img.convert("RGB")))

    if mask_is_painted:
        # Auto-detect pink/magenta paint and convert it to a black/white mask:
        # white pixels = areas to remove, black pixels = areas to keep.
        log.info("Auto-detecting pink/magenta paint from uploaded image...")
        m_rgb = cv2.cvtColor(np.array(m), cv2.COLOR_RGBA2RGB)
        # Method 1: near-magenta pixels; the tolerance absorbs slight colour drift.
        magenta_detected = (
            (m_rgb[:, :, 0] > 240)  # Red channel: high
            & (m_rgb[:, :, 1] < 30)  # Green channel: low
            & (m_rgb[:, :, 2] > 240)  # Blue channel: high
        ).astype(np.uint8) * 255
        # Method 2: pixels where the painted copy differs noticeably from the
        # original (grayscale difference > 50). Only possible for equal sizes.
        img_rgb = cv2.cvtColor(np.array(img), cv2.COLOR_RGBA2RGB)
        if img_rgb.shape == m_rgb.shape:
            gray_diff = cv2.cvtColor(cv2.absdiff(img_rgb, m_rgb), cv2.COLOR_RGB2GRAY)
            diff_mask = (gray_diff > 50).astype(np.uint8) * 255
            binmask = cv2.bitwise_or(magenta_detected, diff_mask)
        else:
            binmask = magenta_detected
        # Clean up the mask: close small gaps, then remove small noise.
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        binmask = cv2.morphologyEx(binmask, cv2.MORPH_CLOSE, kernel, iterations=2)
        binmask = cv2.morphologyEx(binmask, cv2.MORPH_OPEN, kernel, iterations=1)
        nonzero = int((binmask > 0).sum())
        log.info("Pink/magenta paint detected: %d pixels marked for removal (white)", nonzero)
        if nonzero < 50:
            log.warning("Very few pixels detected! Trying stricter magenta detection...")
            # Retry with exact (255, 0, 255) pixels only, closed 3 times and without
            # the opening step used above.
            magenta_strict = np.all(m_rgb == [255, 0, 255], axis=2).astype(np.uint8) * 255
            binmask = cv2.morphologyEx(magenta_strict, cv2.MORPH_CLOSE, kernel, iterations=3)
            nonzero = int((binmask > 0).sum())
            log.info("Strict magenta detection: %d pixels", nonzero)
            if nonzero < 50:
                log.error("CRITICAL: Could not detect pink/magenta paint! Returning original image.")
                return _publish(
                    np.array(img.convert("RGB")),
                    error="pink/magenta paint detection failed - very few pixels detected",
                )
        # Black/white mask: white = remove (paint), black = keep (everything else).
        mask_rgba = np.zeros((binmask.shape[0], binmask.shape[1], 4), dtype=np.uint8)
        mask_rgba[:, :, :3] = binmask[:, :, np.newaxis]
        mask_rgba[:, :, 3] = 255  # Alpha: fully opaque
        log.info(
            "Successfully created black/white mask: %d white pixels (to remove), %d black pixels (to keep)",
            nonzero,
            binmask.shape[0] * binmask.shape[1] - nonzero,
        )
    else:
        mask_rgba = _load_rgba_mask_from_image(m)

    # Painted masks are built as white = remove, so they must never be inverted.
    actual_invert = False if mask_is_painted else invert_mask
    log.info("Using invert_mask=%s (mask_is_painted=%s)", actual_invert, mask_is_painted)
    result = process_inpaint(np.array(img), mask_rgba, invert_mask=actual_invert)
    return _publish(result)
def remove_pink_segments(
    image: UploadFile = File(...),
    request: Request = None,
    _: None = Depends(bearer_auth),
) -> Dict[str, str]:
    """
    Simple endpoint: upload an image with pink/magenta segments to remove.
    - Pink/Magenta segments → automatically removed (white in mask)
    - Everything else → automatically kept (black in mask)
    Just paint pink/magenta on areas you want to remove, upload the image, and it works!

    Detection:
      * Pixels with R>240, G<30, B>240 are marked.
      * The marks are cleaned with a 5x5 elliptical close (2 iterations), then an
        open (1 iteration).
      * If fewer than 50 pixels remain, exact (255, 0, 255) pixels closed 3 times
        are used instead.

    Returns:
        ``{"result": <output filename>, "pink_segments_detected": <pixel count as str>}``
        plus ``"url"`` when the download route resolves. The
        ``pink_segments_detected`` value counts pixels, not segments.
        If no paint is found, the unmodified image is saved and the response carries
        ``"error"`` instead of ``"pink_segments_detected"``. Every saved output is
        recorded in ``logs``.

    Raises:
        HTTPException(400): the upload is not a readable image.
    """
    log.info("Simple remove-pink: processing image %s", image.filename)
    try:
        # Load the image (with pink paint on it)
        img = Image.open(image.file).convert("RGBA")
    except OSError as err:  # PIL.UnidentifiedImageError is an OSError subclass
        raise HTTPException(status_code=400, detail="uploaded file is not a readable image") from err
    img_rgb = cv2.cvtColor(np.array(img), cv2.COLOR_RGBA2RGB)

    def _publish(result: np.ndarray, pink_pixels: int, extra: Dict[str, str]) -> Dict[str, str]:
        # Save the output under OUTPUT_DIR, record it in `logs`, and build the response.
        result_name = f"output_{uuid.uuid4().hex}.png"
        Image.fromarray(result).save(os.path.join(OUTPUT_DIR, result_name))
        url: Optional[str] = None
        try:
            if request is not None:
                url = str(request.url_for("download_file", filename=result_name))
        except Exception:
            # Best effort: if the download route cannot be resolved, the response omits "url".
            url = None
        logs.append({
            "result": result_name,
            "filename": image.filename,
            "pink_pixels": pink_pixels,
            "timestamp": datetime.utcnow().isoformat(),
        })
        resp: Dict[str, str] = {"result": result_name, **extra}
        if url:
            resp["url"] = url
        return resp

    # Pink/magenta -> white in mask (remove); everything else, including dark
    # areas, -> black in mask (keep).
    magenta_detected = (
        (img_rgb[:, :, 0] > 240)  # Red: high
        & (img_rgb[:, :, 1] < 30)  # Green: low
        & (img_rgb[:, :, 2] > 240)  # Blue: high
    ).astype(np.uint8) * 255
    # Clean up the pink mask: close small gaps, then remove small noise.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    binmask = cv2.morphologyEx(magenta_detected, cv2.MORPH_CLOSE, kernel, iterations=2)
    binmask = cv2.morphologyEx(binmask, cv2.MORPH_OPEN, kernel, iterations=1)
    nonzero = int((binmask > 0).sum())
    log.info("Detected %d pink pixels to remove", nonzero)
    if nonzero < 50:
        log.warning("Very few pink pixels detected! Trying strict magenta detection...")
        magenta_strict = np.all(img_rgb == [255, 0, 255], axis=2).astype(np.uint8) * 255
        binmask = cv2.morphologyEx(magenta_strict, cv2.MORPH_CLOSE, kernel, iterations=3)
        nonzero = int((binmask > 0).sum())
        log.info("Strict detection: %d pink pixels", nonzero)
        if nonzero < 50:
            log.error("No pink segments detected! Returning original image.")
            return _publish(
                np.array(img.convert("RGB")),
                nonzero,
                {"error": "No pink/magenta segments detected. Please paint areas to remove with magenta/pink color (RGB 255,0,255)."},
            )

    # Mask: white = pink areas (remove); black = everything else (keep), with no
    # black painting needed.
    mask_rgba = np.zeros((binmask.shape[0], binmask.shape[1], 4), dtype=np.uint8)
    mask_rgba[:, :, :3] = binmask[:, :, np.newaxis]
    mask_rgba[:, :, 3] = 255  # Alpha: fully opaque
    total_pixels = binmask.shape[0] * binmask.shape[1]
    log.info(
        "Created mask: %d white pixels (remove/pink), %d black pixels (keep - everything else automatically)",
        nonzero,
        total_pixels - nonzero,
    )
    # invert_mask=False because white pixels = remove in this mask.
    result = process_inpaint(np.array(img), mask_rgba, invert_mask=False)
    return _publish(result, nonzero, {"pink_segments_detected": str(nonzero)})
def download_file(filename: str):
    """Serve a previously generated result file from OUTPUT_DIR.

    ``filename`` must be a bare file name, as returned in a handler's ``"result"``
    field. Anything that resolves outside OUTPUT_DIR (``..``, absolute paths,
    symlinks pointing elsewhere) is reported as not found.

    Raises:
        HTTPException(404): no such file directly inside OUTPUT_DIR.
    """
    output_root = os.path.realpath(OUTPUT_DIR)
    path = os.path.realpath(os.path.join(output_root, filename))
    # Defence in depth against path traversal: outputs are written flat into
    # OUTPUT_DIR, so only direct children of it may be served.
    if os.path.dirname(path) != output_root or not os.path.isfile(path):
        raise HTTPException(status_code=404, detail="file not found")
    return FileResponse(path)
def view_result(filename: str):
    """View result image directly in browser (same as download but with proper content-type for viewing).

    ``filename`` must be a bare file name, as returned in a handler's ``"result"``
    field. Anything that resolves outside OUTPUT_DIR is reported as not found.

    Raises:
        HTTPException(404): no such file directly inside OUTPUT_DIR.
    """
    output_root = os.path.realpath(OUTPUT_DIR)
    path = os.path.realpath(os.path.join(output_root, filename))
    # Defence in depth against path traversal: only direct children of OUTPUT_DIR.
    if os.path.dirname(path) != output_root or not os.path.isfile(path):
        raise HTTPException(status_code=404, detail="file not found")
    return FileResponse(path, media_type="image/png")
def get_logs(_: None = Depends(bearer_auth)) -> JSONResponse:
    """Expose the in-memory upload/result history as a JSON array (bearer_auth applies)."""
    history = logs
    return JSONResponse(history)