"""HTTP API for removing objects from photos via mask-based inpainting.

Clients either upload an image and a mask separately and reference them by id
(``/upload-image``, ``/upload-mask``, ``/inpaint``, ``/inpaint-url``), or send
everything in one multipart request (``/inpaint-multipart``, ``/remove-pink``).
Results are written as PNG files to ``OUTPUT_DIR`` and served by
``/download/{filename}`` and ``/result/{filename}``.
"""
import hmac
import logging
import os
import shutil
import uuid
from datetime import datetime
from typing import Dict, List, Optional

import cv2
import numpy as np
from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Header, Request
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel
from PIL import Image

# Logging is configured before the project import, matching the original
# statement order of this module.
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("api")

from src.core import process_inpaint  # noqa: E402

# Directories (use writable space on HF Spaces)
BASE_DIR = os.environ.get("DATA_DIR", "/data")
if not os.path.isdir(BASE_DIR):
    # Fallback to /tmp if /data not available
    BASE_DIR = "/tmp"

UPLOAD_DIR = os.path.join(BASE_DIR, "uploads")
OUTPUT_DIR = os.path.join(BASE_DIR, "outputs")
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Optional Bearer token: set env API_TOKEN to require auth; if not set, endpoints are open
ENV_TOKEN = os.environ.get("API_TOKEN")

app = FastAPI(title="Photo Object Removal API", version="1.0.0")

# In-memory stores (lost on restart)
file_store: Dict[str, Dict[str, str]] = {}
logs: List[Dict[str, object]] = []

# Below this many detected pixels the paint detection is considered a failure.
_MIN_PAINT_PIXELS = 50


def bearer_auth(authorization: Optional[str] = Header(default=None)) -> None:
    """FastAPI dependency enforcing the optional bearer token.

    Does nothing when ``API_TOKEN`` is not configured.

    Raises:
        HTTPException: 401 if the header is missing or is not a ``Bearer``
            header; 403 if the supplied token does not match.
    """
    if not ENV_TOKEN:
        return
    if authorization is None or not authorization.lower().startswith("bearer "):
        raise HTTPException(status_code=401, detail="Unauthorized")
    token = authorization.split(" ", 1)[1]
    # compare_digest avoids leaking the matching prefix length through timing.
    # It only accepts ASCII str, so compare encoded bytes instead.
    supplied = token.encode("utf-8", "surrogateescape")
    expected = ENV_TOKEN.encode("utf-8", "surrogateescape")
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=403, detail="Forbidden")


class InpaintRequest(BaseModel):
    """JSON body for ``/inpaint`` and ``/inpaint-url``."""

    image_id: str
    mask_id: str
    invert_mask: bool = True  # True => selected/painted area is removed
    passthrough: bool = False  # If True, return the original image unchanged


class SimpleRemoveRequest(BaseModel):
    """Request model carrying only an image id."""

    image_id: str  # Image with pink/magenta segments to remove


@app.get("/")
def root() -> Dict[str, object]:
    """Return a static description of the service and its endpoints."""
    return {
        "name": "Photo Object Removal API",
        "status": "ok",
        "endpoints": {
            "GET /health": "health check",
            "POST /upload-image": "form-data: image=file",
            "POST /upload-mask": "form-data: mask=file",
            "POST /inpaint": "JSON: {image_id, mask_id}",
            "POST /inpaint-url": "JSON: {image_id, mask_id} (also returns a download URL)",
            "POST /inpaint-multipart": "form-data: image=file, mask=file",
            "POST /remove-pink": "form-data: image=file (auto-detects pink segments and removes them)",
            "GET /download/{filename}": "download result image",
            "GET /result/{filename}": "view result image in browser",
            "GET /logs": "recent uploads/results",
        },
        "auth": "set API_TOKEN env var to require Authorization: Bearer (except /health)",
    }


@app.get("/health")
def health() -> Dict[str, str]:
    """Liveness probe; never requires authentication."""
    return {"status": "healthy"}


def _utc_timestamp() -> str:
    """Return the current naive-UTC time as an ISO-8601 string."""
    return datetime.utcnow().isoformat()


def _store_upload(upload: UploadFile, kind: str) -> Dict[str, str]:
    """Persist an uploaded file under a fresh id and register it as ``kind``.

    Args:
        upload: The multipart file part.
        kind: ``"image"`` or ``"mask"``; stored as the entry's ``type``.

    Returns:
        ``{"id": <new id>, "filename": <client filename or "">}``.
    """
    # The client may omit the filename entirely; splitext(None) would raise.
    original_name = upload.filename or ""
    ext = os.path.splitext(original_name)[1] or ".png"
    file_id = str(uuid.uuid4())
    stored_name = f"{file_id}{ext}"
    stored_path = os.path.join(UPLOAD_DIR, stored_name)
    with open(stored_path, "wb") as f:
        shutil.copyfileobj(upload.file, f)
    timestamp = _utc_timestamp()
    file_store[file_id] = {
        "type": kind,
        "filename": original_name,
        "stored_name": stored_name,
        "path": stored_path,
        "timestamp": timestamp,
    }
    logs.append({"id": file_id, "filename": original_name, "type": kind, "timestamp": timestamp})
    return {"id": file_id, "filename": original_name}


@app.post("/upload-image")
def upload_image(image: UploadFile = File(...), _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Store an uploaded image and return its id for later ``/inpaint`` calls."""
    return _store_upload(image, "image")


@app.post("/upload-mask")
def upload_mask(mask: UploadFile = File(...), _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Store an uploaded mask and return its id for later ``/inpaint`` calls."""
    return _store_upload(mask, "mask")


def _load_rgba_image(path: str) -> Image.Image:
    """Open the image at ``path`` and return an RGBA copy of it."""
    # convert() returns an independent image, so the source can be closed.
    with Image.open(path) as img:
        return img.convert("RGBA")


def _binary_mask_to_rgba(mask_bw: np.ndarray) -> np.ndarray:
    """Expand a 2-D uint8 mask into an opaque RGBA array with R=G=B=mask."""
    rgba = np.empty((mask_bw.shape[0], mask_bw.shape[1], 4), dtype=np.uint8)
    rgba[:, :, 0] = mask_bw  # R
    rgba[:, :, 1] = mask_bw  # G
    rgba[:, :, 2] = mask_bw  # B
    rgba[:, :, 3] = 255  # Fully opaque
    return rgba


def _load_rgba_mask_from_image(img: Image.Image) -> np.ndarray:
    """
    Convert mask image to RGBA format (black/white mask).

    Standard convention: white (255) = area to remove, black (0) = area to keep.
    Returns RGBA with white in RGB channels where removal is needed, alpha=255.

    Three cases are handled:
      * non-RGBA input: thresholded on its grayscale value (>128 = remove);
      * RGBA with mostly opaque alpha (mean > 200): thresholded on RGB
        brightness, plus any exact magenta (255, 0, 255) pixel;
      * other RGBA: alpha encodes the mask, transparent (<128) = remove.
    """
    if img.mode != "RGBA":
        # For RGB/Grayscale masks: white (value>128) = remove, black (value<=128) = keep
        arr = np.array(img.convert("L"))
        mask_bw = np.where(arr > 128, 255, 0).astype(np.uint8)
        log.info("Loaded %s mask: %d white pixels (to remove)", img.mode, int((mask_bw > 0).sum()))
        return _binary_mask_to_rgba(mask_bw)

    # For RGBA: check if alpha channel is meaningful
    arr = np.array(img)
    alpha = arr[:, :, 3]
    rgb = arr[:, :, :3]

    # If alpha is mostly opaque everywhere (mean > 200), treat RGB channels as mask values
    if alpha.mean() > 200:
        # Use RGB to determine mask: white/bright in RGB = remove
        gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
        # Also detect magenta specifically (its gray value falls below the threshold)
        magenta = _detect_exact_magenta(rgb)
        mask_bw = np.maximum(np.where(gray > 128, 255, 0).astype(np.uint8), magenta)
        log.info("Loaded RGBA mask (RGB-based): %d white pixels (to remove)", int((mask_bw > 0).sum()))
        return _binary_mask_to_rgba(mask_bw)

    # Alpha channel encodes the mask - convert to RGB-based
    # Transparent areas (alpha < 128) = remove, Opaque areas = keep
    mask_bw = np.where(alpha < 128, 255, 0).astype(np.uint8)
    log.info("Loaded RGBA mask (alpha-based): %d white pixels (to remove)", int((mask_bw > 0).sum()))
    return _binary_mask_to_rgba(mask_bw)


def _detect_magenta(rgb: np.ndarray) -> np.ndarray:
    """Return a 0/255 uint8 mask of near-magenta pixels (R>240, G<30, B>240)."""
    hits = (rgb[:, :, 0] > 240) & (rgb[:, :, 1] < 30) & (rgb[:, :, 2] > 240)
    return hits.astype(np.uint8) * 255


def _detect_exact_magenta(rgb: np.ndarray) -> np.ndarray:
    """Return a 0/255 uint8 mask of pixels that are exactly (255, 0, 255)."""
    return np.all(rgb == [255, 0, 255], axis=2).astype(np.uint8) * 255


def _paint_kernel() -> np.ndarray:
    """Return the 5x5 elliptical structuring element used for mask clean-up."""
    return cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))


def _clean_mask(binmask: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """Close small gaps in ``binmask`` (2 iterations), then drop small specks."""
    closed = cv2.morphologyEx(binmask, cv2.MORPH_CLOSE, kernel, iterations=2)
    return cv2.morphologyEx(closed, cv2.MORPH_OPEN, kernel, iterations=1)


def _open_upload_rgba(upload: UploadFile, field: str) -> Image.Image:
    """Decode an uploaded file as an RGBA image.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    try:
        return Image.open(upload.file).convert("RGBA")
    except OSError as exc:  # includes PIL.UnidentifiedImageError
        raise HTTPException(status_code=400, detail=f"{field} is not a readable image") from exc


def _save_result(result: np.ndarray) -> str:
    """Write ``result`` as a uniquely named PNG in OUTPUT_DIR; return its name."""
    result_name = f"output_{uuid.uuid4().hex}.png"
    Image.fromarray(result).save(os.path.join(OUTPUT_DIR, result_name))
    return result_name


def _result_url(request: Optional[Request], result_name: str) -> Optional[str]:
    """Best-effort absolute download URL for ``result_name`` (None on failure)."""
    if request is None:
        return None
    try:
        return str(request.url_for("download_file", filename=result_name))
    except Exception:
        # The URL is optional in the response; report the problem, don't fail.
        log.warning("Could not build download URL for %s", result_name, exc_info=True)
        return None


def _publish_result(result: np.ndarray, request: Optional[Request]) -> Dict[str, str]:
    """Save ``result``, record it in ``logs`` and build the response dict."""
    result_name = _save_result(result)
    url = _result_url(request, result_name)
    entry: Dict[str, str] = {"result": result_name, "timestamp": _utc_timestamp()}
    resp: Dict[str, str] = {"result": result_name}
    if url:
        entry["url"] = url
        resp["url"] = url
    logs.append(entry)
    return resp


def _require_stored(file_id: str, kind: str) -> Dict[str, str]:
    """Return the ``file_store`` entry for ``file_id`` or raise 404.

    The entry must exist and have been uploaded as ``kind``.
    """
    entry = file_store.get(file_id)
    if entry is None or entry["type"] != kind:
        raise HTTPException(status_code=404, detail=f"{kind}_id not found")
    return entry


def _run_stored_inpaint(req: InpaintRequest) -> str:
    """Inpaint a previously uploaded image/mask pair; return the result name.

    Raises:
        HTTPException: 404 if either id is unknown or of the wrong type;
            400 if a stored file cannot be read as an image.
    """
    image_entry = _require_stored(req.image_id, "image")
    mask_entry = _require_stored(req.mask_id, "mask")

    try:
        img_rgba = _load_rgba_image(image_entry["path"])
        with Image.open(mask_entry["path"]) as mask_img:  # may be RGB/gray/RGBA
            mask_rgba = _load_rgba_mask_from_image(mask_img)
    except OSError as exc:
        raise HTTPException(status_code=400, detail="stored file is not a readable image") from exc

    if req.passthrough:
        result = np.array(img_rgba.convert("RGB"))
    else:
        result = process_inpaint(np.array(img_rgba), mask_rgba, invert_mask=req.invert_mask)
    return _save_result(result)


@app.post("/inpaint")
def inpaint(req: InpaintRequest, _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Inpaint a stored image with a stored mask; return the result file name."""
    result_name = _run_stored_inpaint(req)
    logs.append({"result": result_name, "timestamp": _utc_timestamp()})
    return {"result": result_name}


@app.post("/inpaint-url")
def inpaint_url(req: InpaintRequest, request: Request, _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Same as /inpaint but returns a JSON with a public download URL instead of image bytes."""
    result_name = _run_stored_inpaint(req)
    url = str(request.url_for("download_file", filename=result_name))
    logs.append({"result": result_name, "url": url, "timestamp": _utc_timestamp()})
    return {"result": result_name, "url": url}


@app.post("/inpaint-multipart")
def inpaint_multipart(
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
    request: Request = None,
    invert_mask: bool = True,
    mask_is_painted: bool = False,  # if True, mask file is the painted-on image (e.g., black strokes on original)
    passthrough: bool = False,
    _: None = Depends(bearer_auth),
) -> Dict[str, str]:
    """Inpaint from an image and mask sent in a single multipart request.

    With ``mask_is_painted`` the mask file is treated as a painted copy of the
    image: near-magenta pixels, plus pixels that differ strongly from the
    image when both have the same size, become the removal mask, and
    ``invert_mask`` is ignored. With ``passthrough`` the image is echoed back.

    Returns:
        ``{"result": name}`` plus ``"url"`` when it can be built, or
        ``{"result": name, "error": ...}`` when paint detection fails.

    Raises:
        HTTPException: 400 if either upload cannot be decoded as an image.
    """
    img = _open_upload_rgba(image, "image")
    m = _open_upload_rgba(mask, "mask")

    if passthrough:
        # Just echo the input image, ignore mask
        return _publish_result(np.array(img.convert("RGB")), request)

    if mask_is_painted:
        # Auto-detect pink/magenta paint and convert to black/white mask
        # White pixels = areas to remove, Black pixels = areas to keep
        log.info("Auto-detecting pink/magenta paint from uploaded image...")
        m_rgb = cv2.cvtColor(np.array(m), cv2.COLOR_RGBA2RGB)

        # Method 1: detect magenta/pink paint directly, with some tolerance
        binmask = _detect_magenta(m_rgb)

        # Method 2: any significant difference (>50) from the original could be paint
        img_rgb = cv2.cvtColor(np.array(img), cv2.COLOR_RGBA2RGB)
        if img_rgb.shape == m_rgb.shape:
            gray_diff = cv2.cvtColor(cv2.absdiff(img_rgb, m_rgb), cv2.COLOR_RGB2GRAY)
            diff_mask = (gray_diff > 50).astype(np.uint8) * 255
            binmask = cv2.bitwise_or(binmask, diff_mask)

        # Clean up the mask: remove noise and fill small holes
        kernel = _paint_kernel()
        binmask = _clean_mask(binmask, kernel)

        nonzero = int((binmask > 0).sum())
        log.info("Pink/magenta paint detected: %d pixels marked for removal (white)", nonzero)

        if nonzero < _MIN_PAINT_PIXELS:
            log.warning("Very few pixels detected! Trying stricter magenta detection...")
            # Try more strict magenta detection (exact match)
            binmask = cv2.morphologyEx(_detect_exact_magenta(m_rgb), cv2.MORPH_CLOSE, kernel, iterations=3)
            nonzero = int((binmask > 0).sum())
            log.info("Strict magenta detection: %d pixels", nonzero)

            if nonzero < _MIN_PAINT_PIXELS:
                log.error("CRITICAL: Could not detect pink/magenta paint! Returning original image.")
                result_name = _save_result(np.array(img.convert("RGB")))
                return {"result": result_name, "error": "pink/magenta paint detection failed - very few pixels detected"}

        # Create black/white mask: white = remove (pink areas), black = keep (everything else)
        mask_rgba = _binary_mask_to_rgba(binmask)
        log.info(
            "Successfully created black/white mask: %d white pixels (to remove), %d black pixels (to keep)",
            nonzero,
            binmask.shape[0] * binmask.shape[1] - nonzero,
        )
    else:
        mask_rgba = _load_rgba_mask_from_image(m)

    # When mask_is_painted=true, we create white=remove masks, so invert_mask should be False
    # (white pixels should stay white to indicate removal)
    actual_invert = invert_mask if not mask_is_painted else False
    log.info("Using invert_mask=%s (mask_is_painted=%s)", actual_invert, mask_is_painted)

    result = process_inpaint(np.array(img), mask_rgba, invert_mask=actual_invert)
    return _publish_result(result, request)


@app.post("/remove-pink")
def remove_pink_segments(
    image: UploadFile = File(...),
    request: Request = None,
    _: None = Depends(bearer_auth),
) -> Dict[str, str]:
    """
    Simple endpoint: upload an image with pink/magenta segments to remove.

    - Pink/Magenta segments → automatically removed (white in mask)
    - Everything else → automatically kept (black in mask)

    Just paint pink/magenta on areas you want to remove, upload the image, and it works!

    Returns ``{"result", "pink_segments_detected"}`` plus ``"url"`` when it can
    be built, or ``{"result", "error"}`` with the unchanged image when fewer
    than 50 pink pixels are found. Raises a 400 HTTPException if the upload
    cannot be decoded as an image.
    """
    log.info("Simple remove-pink: processing image %s", image.filename)

    # Load the image (with pink paint on it)
    img = _open_upload_rgba(image, "image")
    img_rgb = cv2.cvtColor(np.array(img), cv2.COLOR_RGBA2RGB)

    # Pink/Magenta → white in mask (remove); everything else → black (keep)
    kernel = _paint_kernel()
    binmask = _clean_mask(_detect_magenta(img_rgb), kernel)

    nonzero = int((binmask > 0).sum())
    log.info("Detected %d pink pixels to remove", nonzero)

    if nonzero < _MIN_PAINT_PIXELS:
        log.warning("Very few pink pixels detected! Trying strict magenta detection...")
        binmask = cv2.morphologyEx(_detect_exact_magenta(img_rgb), cv2.MORPH_CLOSE, kernel, iterations=3)
        nonzero = int((binmask > 0).sum())
        log.info("Strict detection: %d pink pixels", nonzero)

        if nonzero < _MIN_PAINT_PIXELS:
            log.error("No pink segments detected! Returning original image.")
            result_name = _save_result(np.array(img.convert("RGB")))
            return {
                "result": result_name,
                "error": "No pink/magenta segments detected. Please paint areas to remove with magenta/pink color (RGB 255,0,255)."
            }

    # White = pink areas (remove); black = everything else (keep)
    mask_rgba = _binary_mask_to_rgba(binmask)

    total_pixels = binmask.shape[0] * binmask.shape[1]
    log.info(
        "Created mask: %d white pixels (remove/pink), %d black pixels (keep - everything else automatically)",
        nonzero,
        total_pixels - nonzero,
    )

    # Process with invert_mask=False because white pixels = remove (standard)
    result = process_inpaint(np.array(img), mask_rgba, invert_mask=False)
    result_name = _save_result(result)
    url = _result_url(request, result_name)

    logs.append({
        "result": result_name,
        "filename": image.filename,
        "pink_pixels": nonzero,
        "timestamp": _utc_timestamp()
    })

    resp: Dict[str, str] = {"result": result_name, "pink_segments_detected": str(nonzero)}
    if url:
        resp["url"] = url
    return resp


def _resolve_output_path(filename: str) -> str:
    """Map a client-supplied file name to an existing file inside OUTPUT_DIR.

    Raises:
        HTTPException: 404 if the name contains a directory component or no
            such file exists.
    """
    # Only bare file names are accepted so the path cannot leave OUTPUT_DIR.
    if filename in ("", ".", "..") or os.path.basename(filename) != filename:
        raise HTTPException(status_code=404, detail="file not found")
    path = os.path.join(OUTPUT_DIR, filename)
    if not os.path.isfile(path):
        raise HTTPException(status_code=404, detail="file not found")
    return path


@app.get("/download/{filename}")
def download_file(filename: str):
    """Serve a result file from OUTPUT_DIR (404 if it does not exist)."""
    return FileResponse(_resolve_output_path(filename))


@app.get("/result/{filename}")
def view_result(filename: str):
    """View result image directly in browser (same as download but with proper content-type for viewing)"""
    return FileResponse(_resolve_output_path(filename), media_type="image/png")


@app.get("/logs")
def get_logs(_: None = Depends(bearer_auth)) -> JSONResponse:
    """Return every recorded upload/result entry as JSON."""
    return JSONResponse(content=logs)