# Source: object_remover / api / main.py (19.9 kB)
# Author: LogicGoInfotechSpaces
# Commit 87c6da8: fix(api): remove black paint detection - just auto-detect pink segments, everything else kept automatically
import os
import uuid
import shutil
from datetime import datetime
from typing import Dict, List, Optional
import numpy as np
from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Header, Request
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel
from PIL import Image
import cv2
import logging
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("api")
from src.core import process_inpaint
# Storage root (use writable space on HF Spaces): honour DATA_DIR, defaulting
# to /data, and drop back to /tmp when that directory does not exist.
BASE_DIR = os.environ.get("DATA_DIR", "/data")
BASE_DIR = BASE_DIR if os.path.isdir(BASE_DIR) else "/tmp"

# Uploaded inputs and generated results live in separate sub-directories,
# both created eagerly at import time.
UPLOAD_DIR = os.path.join(BASE_DIR, "uploads")
OUTPUT_DIR = os.path.join(BASE_DIR, "outputs")
for _storage_dir in (UPLOAD_DIR, OUTPUT_DIR):
    os.makedirs(_storage_dir, exist_ok=True)

# Optional Bearer token: when API_TOKEN is set, protected endpoints require it;
# when it is unset, they are open.
ENV_TOKEN = os.environ.get("API_TOKEN")

app = FastAPI(title="Photo Object Removal API", version="1.0.0")

# In-memory bookkeeping: uploaded-file metadata keyed by generated ID, and an
# append-only list of upload/result log entries.
file_store: Dict[str, Dict[str, str]] = {}
logs: List[Dict[str, str]] = []
def bearer_auth(authorization: Optional[str] = Header(default=None)) -> None:
    """FastAPI dependency enforcing the optional ``Authorization: Bearer`` token.

    When ``ENV_TOKEN`` is falsy (``API_TOKEN`` not set) every request is let
    through. Otherwise the header must start with ``bearer `` (matched
    case-insensitively) and the remainder must equal the configured token.

    Args:
        authorization: Raw value of the ``Authorization`` request header, or
            ``None`` when the header is absent.

    Raises:
        HTTPException: 401 when the header is missing or is not a Bearer
            header; 403 when the supplied token does not match.
    """
    import hmac  # local import: keeps this block independent of the file's import list

    if not ENV_TOKEN:
        return
    if authorization is None or not authorization.lower().startswith("bearer "):
        raise HTTPException(status_code=401, detail="Unauthorized")
    token = authorization.split(" ", 1)[1]
    # Constant-time comparison so response timing does not reveal how many
    # leading characters of a guessed token are correct. Bytes are compared
    # because compare_digest rejects non-ASCII str arguments.
    if not hmac.compare_digest(token.encode("utf-8"), ENV_TOKEN.encode("utf-8")):
        raise HTTPException(status_code=403, detail="Forbidden")
class InpaintRequest(BaseModel):
    """JSON body accepted by the /inpaint and /inpaint-url endpoints."""

    # ID previously returned by /upload-image (must be stored with type "image").
    image_id: str
    # ID previously returned by /upload-mask (must be stored with type "mask").
    mask_id: str
    # Forwarded unchanged to process_inpaint(invert_mask=...).
    # Original author's note: True => selected/painted area is removed.
    # NOTE(review): /remove-pink passes False for white=remove masks — confirm
    # the intended default against process_inpaint.
    invert_mask: bool = True
    # If True, skip inpainting and return the original image converted to RGB.
    passthrough: bool = False
class SimpleRemoveRequest(BaseModel):
    """Request model naming an uploaded image whose pink/magenta segments should be removed.

    NOTE(review): no endpoint visible in this file references this model;
    /remove-pink takes a multipart file upload instead.
    """

    # Image with pink/magenta segments to remove
    image_id: str
@app.get("/")
def root() -> Dict[str, object]:
    """Describe the service: name, status, available endpoints and auth rule.

    Returns:
        A static dictionary with the API name, an ``"ok"`` status, a mapping
        of ``"METHOD /path"`` to a short usage hint, and a note on how
        authentication is enabled.
    """
    return {
        "name": "Photo Object Removal API",
        "status": "ok",
        "endpoints": {
            "GET /health": "health check",
            "POST /upload-image": "form-data: image=file",
            "POST /upload-mask": "form-data: mask=file",
            "POST /inpaint": "JSON: {image_id, mask_id}",
            # Previously missing from the listing although the route exists.
            "POST /inpaint-url": "JSON: {image_id, mask_id} (also returns a download URL)",
            "POST /inpaint-multipart": "form-data: image=file, mask=file",
            "POST /remove-pink": "form-data: image=file (auto-detects pink segments and removes them)",
            "GET /download/{filename}": "download result image",
            "GET /result/{filename}": "view result image in browser",
            "GET /logs": "recent uploads/results",
        },
        # Lists every route that is declared without the bearer_auth dependency.
        "auth": "set API_TOKEN env var to require Authorization: Bearer <token> (except /, /health, /download and /result)",
    }
@app.get("/health")
def health() -> Dict[str, str]:
    """Health check: unconditionally report the service as healthy."""
    status_payload: Dict[str, str] = {"status": "healthy"}
    return status_payload
@app.post("/upload-image")
def upload_image(image: UploadFile = File(...), _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Store an uploaded image on disk and register it in ``file_store``.

    The file is written to ``UPLOAD_DIR`` under a freshly generated UUID,
    keeping the extension of the client-supplied name (``.png`` when there is
    none). A matching entry is appended to ``logs``.

    Args:
        image: The uploaded file.
        _: Placeholder for the ``bearer_auth`` dependency.

    Returns:
        ``{"id": <generated id>, "filename": <client-supplied name>}``.
    """
    # The upload's filename is optional; splitext(None) would raise, so treat a
    # missing name as empty and let the ".png" default apply.
    original_name = image.filename or ""
    ext = os.path.splitext(original_name)[1] or ".png"
    file_id = str(uuid.uuid4())
    stored_name = f"{file_id}{ext}"
    stored_path = os.path.join(UPLOAD_DIR, stored_name)
    with open(stored_path, "wb") as f:
        shutil.copyfileobj(image.file, f)
    # One timestamp so the store entry and the log entry always agree.
    timestamp = datetime.utcnow().isoformat()
    file_store[file_id] = {
        "type": "image",
        "filename": original_name,
        "stored_name": stored_name,
        "path": stored_path,
        "timestamp": timestamp,
    }
    logs.append({"id": file_id, "filename": original_name, "type": "image", "timestamp": timestamp})
    return {"id": file_id, "filename": original_name}
@app.post("/upload-mask")
def upload_mask(mask: UploadFile = File(...), _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Store an uploaded mask on disk and register it in ``file_store``.

    The file is written to ``UPLOAD_DIR`` under a freshly generated UUID,
    keeping the extension of the client-supplied name (``.png`` when there is
    none). A matching entry is appended to ``logs``.

    Args:
        mask: The uploaded file.
        _: Placeholder for the ``bearer_auth`` dependency.

    Returns:
        ``{"id": <generated id>, "filename": <client-supplied name>}``.
    """
    # The upload's filename is optional; splitext(None) would raise, so treat a
    # missing name as empty and let the ".png" default apply.
    original_name = mask.filename or ""
    ext = os.path.splitext(original_name)[1] or ".png"
    file_id = str(uuid.uuid4())
    stored_name = f"{file_id}{ext}"
    stored_path = os.path.join(UPLOAD_DIR, stored_name)
    with open(stored_path, "wb") as f:
        shutil.copyfileobj(mask.file, f)
    # One timestamp so the store entry and the log entry always agree.
    timestamp = datetime.utcnow().isoformat()
    file_store[file_id] = {
        "type": "mask",
        "filename": original_name,
        "stored_name": stored_name,
        "path": stored_path,
        "timestamp": timestamp,
    }
    logs.append({"id": file_id, "filename": original_name, "type": "mask", "timestamp": timestamp})
    return {"id": file_id, "filename": original_name}
def _load_rgba_image(path: str) -> Image.Image:
    """Open the image file at *path* and return it converted to RGBA mode.

    Args:
        path: Filesystem path of the image to read.

    Returns:
        A new RGBA image holding the decoded pixels.
    """
    # Image.open is lazy; the context manager closes the underlying file as
    # soon as convert() has produced the in-memory copy, instead of leaving
    # that to garbage collection.
    with Image.open(path) as img:
        return img.convert("RGBA")
def _load_rgba_mask_from_image(img: Image.Image) -> np.ndarray:
    """
    Normalise a mask image into an opaque black/white RGBA array.

    Output convention: R, G and B are 255 where the area should be removed and
    0 where it should be kept; alpha is always 255.

    How "remove" is decided depends on the input:
    - non-RGBA image: grayscale value > 128;
    - RGBA with mean alpha > 200: grayscale of RGB > 128, or exact magenta
      (255, 0, 255);
    - any other RGBA: alpha < 128 (transparent pixels).
    """

    def _opaque_rgba(binary: np.ndarray) -> np.ndarray:
        # Copy the 2-D 0/255 mask into R, G and B and set alpha to 255.
        out = np.zeros((binary.shape[0], binary.shape[1], 4), dtype=np.uint8)
        for channel in range(3):
            out[:, :, channel] = binary
        out[:, :, 3] = 255
        return out

    if img.mode != "RGBA":
        # RGB / grayscale / other modes: threshold the luminance.
        luminance = np.array(img.convert("L"))
        mask_bw = np.where(luminance > 128, 255, 0).astype(np.uint8)
        log.info(f"Loaded {img.mode} mask: {int((mask_bw > 0).sum())} white pixels (to remove)")
        return _opaque_rgba(mask_bw)

    pixels = np.array(img)
    alpha = pixels[:, :, 3]
    rgb = pixels[:, :, :3]

    if alpha.mean() > 200:
        # Alpha is (almost) fully opaque, so it carries no information:
        # derive the mask from the colour channels instead.
        gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
        bright = np.where(gray > 128, 255, 0).astype(np.uint8)
        # Exact magenta also counts as "remove" even though its grayscale value is below 128.
        magenta = np.all(rgb == [255, 0, 255], axis=2).astype(np.uint8) * 255
        mask_bw = np.maximum(bright, magenta)
        log.info(f"Loaded RGBA mask (RGB-based): {int((mask_bw > 0).sum())} white pixels (to remove)")
        return _opaque_rgba(mask_bw)

    # Alpha encodes the mask: transparent (alpha < 128) = remove, opaque = keep.
    mask_bw = np.where(alpha < 128, 255, 0).astype(np.uint8)
    log.info(f"Loaded RGBA mask (alpha-based): {int((mask_bw > 0).sum())} white pixels (to remove)")
    return _opaque_rgba(mask_bw)
@app.post("/inpaint")
def inpaint(req: InpaintRequest, _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Inpaint a previously uploaded image using a previously uploaded mask.

    Args:
        req: IDs of the stored image and mask plus the ``invert_mask`` and
            ``passthrough`` options.
        _: Placeholder for the ``bearer_auth`` dependency.

    Returns:
        ``{"result": <file name>}`` naming the PNG written to ``OUTPUT_DIR``.

    Raises:
        HTTPException: 404 when either ID is unknown or registered under the
            wrong type.
    """
    image_entry = file_store.get(req.image_id)
    if image_entry is None or image_entry["type"] != "image":
        raise HTTPException(status_code=404, detail="image_id not found")
    mask_entry = file_store.get(req.mask_id)
    if mask_entry is None or mask_entry["type"] != "mask":
        raise HTTPException(status_code=404, detail="mask_id not found")

    img_rgba = _load_rgba_image(image_entry["path"])
    if req.passthrough:
        # Passthrough ignores the mask, so do not decode it at all.
        result = np.array(img_rgba.convert("RGB"))
    else:
        # The mask may be RGB/gray/RGBA. The context manager closes its file
        # once it has been converted to an array.
        with Image.open(mask_entry["path"]) as mask_img:
            mask_rgba = _load_rgba_mask_from_image(mask_img)
        result = process_inpaint(np.array(img_rgba), mask_rgba, invert_mask=req.invert_mask)

    result_name = f"output_{uuid.uuid4().hex}.png"
    result_path = os.path.join(OUTPUT_DIR, result_name)
    Image.fromarray(result).save(result_path)
    logs.append({"result": result_name, "timestamp": datetime.utcnow().isoformat()})
    return {"result": result_name}
@app.post("/inpaint-url")
def inpaint_url(req: InpaintRequest, request: Request, _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Same as /inpaint, but the JSON response also carries a download URL.

    Args:
        req: IDs of the stored image and mask plus the ``invert_mask`` and
            ``passthrough`` options.
        request: Current request, used to build the URL of the
            ``download_file`` route.
        _: Placeholder for the ``bearer_auth`` dependency.

    Returns:
        ``{"result": <file name>, "url": <download URL>}``.

    Raises:
        HTTPException: 404 when either ID is unknown or registered under the
            wrong type.
    """
    image_entry = file_store.get(req.image_id)
    if image_entry is None or image_entry["type"] != "image":
        raise HTTPException(status_code=404, detail="image_id not found")
    mask_entry = file_store.get(req.mask_id)
    if mask_entry is None or mask_entry["type"] != "mask":
        raise HTTPException(status_code=404, detail="mask_id not found")

    img_rgba = _load_rgba_image(image_entry["path"])
    if req.passthrough:
        # Passthrough ignores the mask, so do not decode it at all.
        result = np.array(img_rgba.convert("RGB"))
    else:
        # The mask may be RGB/gray/RGBA. The context manager closes its file
        # once it has been converted to an array.
        with Image.open(mask_entry["path"]) as mask_img:
            mask_rgba = _load_rgba_mask_from_image(mask_img)
        result = process_inpaint(np.array(img_rgba), mask_rgba, invert_mask=req.invert_mask)

    result_name = f"output_{uuid.uuid4().hex}.png"
    result_path = os.path.join(OUTPUT_DIR, result_name)
    Image.fromarray(result).save(result_path)
    url = str(request.url_for("download_file", filename=result_name))
    logs.append({"result": result_name, "url": url, "timestamp": datetime.utcnow().isoformat()})
    return {"result": result_name, "url": url}
@app.post("/inpaint-multipart")
def inpaint_multipart(
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
    request: Request = None,
    invert_mask: bool = True,
    mask_is_painted: bool = False,  # if True, mask file is the painted-on image
    passthrough: bool = False,
    _: None = Depends(bearer_auth),
) -> Dict[str, str]:
    """Inpaint an image and mask that are uploaded together in one request.

    Args:
        image: The original image.
        mask: Either a mask image, or (with ``mask_is_painted``) a copy of the
            original with the areas to remove painted over.
        request: Current request, used to build a download URL when available.
        invert_mask: Forwarded to ``process_inpaint`` for regular masks;
            ignored (forced to ``False``) when ``mask_is_painted`` is set.
        mask_is_painted: Derive the mask from pink/magenta paint and from
            differences between ``mask`` and ``image``.
        passthrough: Return the input image unchanged (as RGB), ignoring the mask.
        _: Placeholder for the ``bearer_auth`` dependency.

    Returns:
        ``{"result": <file name>}`` plus ``"url"`` when a download URL could be
        built. If paint detection finds fewer than 50 pixels, the original
        image is saved instead and an ``"error"`` message is returned.
    """
    img = Image.open(image.file).convert("RGBA")

    if passthrough:
        # Just echo the input image; the mask is ignored and never decoded.
        result_name = _multipart_save_png(np.array(img.convert("RGB")))
        return _multipart_finish(result_name, request)

    m = Image.open(mask.file).convert("RGBA")

    if mask_is_painted:
        # White pixels = areas to remove, black pixels = areas to keep.
        log.info("Auto-detecting pink/magenta paint from uploaded image...")
        binmask, nonzero = _multipart_detect_paint(img, m)
        if nonzero < 50:
            log.error("CRITICAL: Could not detect pink/magenta paint! Returning original image.")
            result_name = _multipart_save_png(np.array(img.convert("RGB")))
            return {"result": result_name, "error": "pink/magenta paint detection failed - very few pixels detected"}

        # Replicate the binary mask into R, G and B; alpha is fully opaque.
        mask_rgba = np.zeros((binmask.shape[0], binmask.shape[1], 4), dtype=np.uint8)
        mask_rgba[:, :, 0] = binmask
        mask_rgba[:, :, 1] = binmask
        mask_rgba[:, :, 2] = binmask
        mask_rgba[:, :, 3] = 255
        log.info("Successfully created black/white mask: %d white pixels (to remove), %d black pixels (to keep)",
                 nonzero, binmask.shape[0] * binmask.shape[1] - nonzero)
        # The generated mask is already white=remove, so it must not be inverted.
        actual_invert = False
    else:
        mask_rgba = _load_rgba_mask_from_image(m)
        actual_invert = invert_mask

    log.info("Using invert_mask=%s (mask_is_painted=%s)", actual_invert, mask_is_painted)
    result = process_inpaint(np.array(img), mask_rgba, invert_mask=actual_invert)
    return _multipart_finish(_multipart_save_png(result), request)


def _multipart_save_png(pixels: np.ndarray) -> str:
    """Write *pixels* to ``OUTPUT_DIR`` as a uniquely named PNG and return the file name."""
    result_name = f"output_{uuid.uuid4().hex}.png"
    Image.fromarray(pixels).save(os.path.join(OUTPUT_DIR, result_name))
    return result_name


def _multipart_finish(result_name: str, request: Optional[Request]) -> Dict[str, str]:
    """Append a ``logs`` entry for *result_name* and build the response, adding a URL when possible."""
    url: Optional[str] = None
    if request is not None:
        try:
            url = str(request.url_for("download_file", filename=result_name))
        except Exception:
            # Best effort: the result is still reachable by name, so report the
            # failure and answer without a URL rather than failing the request.
            log.warning("Could not build download URL for %s", result_name, exc_info=True)
            url = None
    entry: Dict[str, str] = {"result": result_name, "timestamp": datetime.utcnow().isoformat()}
    resp: Dict[str, str] = {"result": result_name}
    if url:
        entry["url"] = url
        resp["url"] = url
    logs.append(entry)
    return resp


def _multipart_detect_paint(img: Image.Image, m: Image.Image) -> tuple:
    """Return ``(binmask, nonzero)``: a 0/255 mask of painted pixels in *m* and its non-zero count."""
    m_rgb = cv2.cvtColor(np.array(m), cv2.COLOR_RGBA2RGB)
    # Method 1: near-magenta paint (R > 240, G < 30, B > 240).
    binmask = (
        (m_rgb[:, :, 0] > 240)
        & (m_rgb[:, :, 1] < 30)
        & (m_rgb[:, :, 2] > 240)
    ).astype(np.uint8) * 255

    # Method 2: pixels that differ noticeably from the original image, which is
    # only possible when both images have the same dimensions.
    img_rgb = cv2.cvtColor(np.array(img), cv2.COLOR_RGBA2RGB)
    if img_rgb.shape == m_rgb.shape:
        gray_diff = cv2.cvtColor(cv2.absdiff(img_rgb, m_rgb), cv2.COLOR_RGB2GRAY)
        diff_mask = (gray_diff > 50).astype(np.uint8) * 255
        binmask = cv2.bitwise_or(binmask, diff_mask)

    # Clean up: close small gaps first, then remove small specks.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    binmask = cv2.morphologyEx(binmask, cv2.MORPH_CLOSE, kernel, iterations=2)
    binmask = cv2.morphologyEx(binmask, cv2.MORPH_OPEN, kernel, iterations=1)
    nonzero = int((binmask > 0).sum())
    log.info("Pink/magenta paint detected: %d pixels marked for removal (white)", nonzero)

    if nonzero < 50:
        # Fallback: exact magenta only, closed but not opened, so thin strokes
        # removed by the opening step above can still be picked up.
        log.warning("Very few pixels detected! Trying stricter magenta detection...")
        magenta_strict = np.all(m_rgb == [255, 0, 255], axis=2).astype(np.uint8) * 255
        binmask = cv2.morphologyEx(magenta_strict, cv2.MORPH_CLOSE, kernel, iterations=3)
        nonzero = int((binmask > 0).sum())
        log.info("Strict magenta detection: %d pixels", nonzero)
    return binmask, nonzero
@app.post("/remove-pink")
def remove_pink_segments(
    image: UploadFile = File(...),
    request: Request = None,
    _: None = Depends(bearer_auth),
) -> Dict[str, str]:
    """
    One-shot removal: upload an image on which the unwanted areas are painted
    pink/magenta.
    - Pink/magenta pixels become white in the generated mask (removed).
    - All other pixels become black in the mask (kept).
    If fewer than 50 pink pixels are found, the original image is saved and an
    "error" message is returned instead.
    """

    def _store_png(pixels: np.ndarray) -> str:
        # Save under a unique name in OUTPUT_DIR and hand the name back.
        name = f"output_{uuid.uuid4().hex}.png"
        Image.fromarray(pixels).save(os.path.join(OUTPUT_DIR, name))
        return name

    log.info(f"Simple remove-pink: processing image {image.filename}")

    # The uploaded picture already has the pink paint on it.
    painted = Image.open(image.file).convert("RGBA")
    painted_rgb = cv2.cvtColor(np.array(painted), cv2.COLOR_RGBA2RGB)

    # Tolerant magenta test: R > 240, G < 30, B > 240.
    red = painted_rgb[:, :, 0]
    green = painted_rgb[:, :, 1]
    blue = painted_rgb[:, :, 2]
    loose_hits = ((red > 240) & (green < 30) & (blue > 240)).astype(np.uint8) * 255

    # Morphological clean-up: close gaps, then open to drop specks.
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    binmask = cv2.morphologyEx(loose_hits, cv2.MORPH_CLOSE, ellipse, iterations=2)
    binmask = cv2.morphologyEx(binmask, cv2.MORPH_OPEN, ellipse, iterations=1)
    nonzero = int((binmask > 0).sum())
    log.info(f"Detected {nonzero} pink pixels to remove")

    if nonzero < 50:
        # Second attempt: exact (255, 0, 255) matches, closed but not opened.
        log.warning("Very few pink pixels detected! Trying strict magenta detection...")
        exact_hits = np.all(painted_rgb == [255, 0, 255], axis=2).astype(np.uint8) * 255
        binmask = cv2.morphologyEx(exact_hits, cv2.MORPH_CLOSE, ellipse, iterations=3)
        nonzero = int((binmask > 0).sum())
        log.info(f"Strict detection: {nonzero} pink pixels")
        if nonzero < 50:
            log.error("No pink segments detected! Returning original image.")
            untouched_name = _store_png(np.array(painted.convert("RGB")))
            return {
                "result": untouched_name,
                "error": "No pink/magenta segments detected. Please paint areas to remove with magenta/pink color (RGB 255,0,255)."
            }

    # Build the RGBA mask: white where pink was found, black elsewhere, opaque alpha.
    height, width = binmask.shape[0], binmask.shape[1]
    mask_rgba = np.zeros((height, width, 4), dtype=np.uint8)
    for channel in range(3):
        mask_rgba[:, :, channel] = binmask
    mask_rgba[:, :, 3] = 255
    total_pixels = height * width
    log.info(f"Created mask: {nonzero} white pixels (remove/pink), {total_pixels - nonzero} black pixels (keep - everything else automatically)")

    # invert_mask=False because white pixels already mean "remove".
    inpainted = process_inpaint(np.array(painted), mask_rgba, invert_mask=False)
    result_name = _store_png(inpainted)

    # The download URL is best-effort; any failure simply omits it.
    url: Optional[str] = None
    if request is not None:
        try:
            url = str(request.url_for("download_file", filename=result_name))
        except Exception:
            url = None

    logs.append({
        "result": result_name,
        "filename": image.filename,
        "pink_pixels": nonzero,
        "timestamp": datetime.utcnow().isoformat()
    })

    resp: Dict[str, str] = {"result": result_name, "pink_segments_detected": str(nonzero)}
    if url:
        resp["url"] = url
    return resp
@app.get("/download/{filename}")
def download_file(filename: str):
    """Serve a result file from ``OUTPUT_DIR`` by name.

    Args:
        filename: Name of a file directly inside ``OUTPUT_DIR``.

    Raises:
        HTTPException: 404 when the name does not resolve to a regular file
            located directly inside ``OUTPUT_DIR``.
    """
    # The name comes from the client and the route has no auth dependency, so
    # resolve it and insist that the result is a direct child of the output
    # directory. This rejects "..", absolute paths, embedded separators and
    # symlinks that point elsewhere.
    output_root = os.path.realpath(OUTPUT_DIR)
    path = os.path.realpath(os.path.join(output_root, filename))
    if os.path.dirname(path) != output_root or not os.path.isfile(path):
        raise HTTPException(status_code=404, detail="file not found")
    return FileResponse(path)
@app.get("/result/{filename}")
def view_result(filename: str):
    """View result image directly in browser (same as download but with an explicit image/png content type).

    Args:
        filename: Name of a file directly inside ``OUTPUT_DIR``.

    Raises:
        HTTPException: 404 when the name does not resolve to a regular file
            located directly inside ``OUTPUT_DIR``.
    """
    # The name comes from the client and the route has no auth dependency, so
    # resolve it and insist that the result is a direct child of the output
    # directory. This rejects "..", absolute paths, embedded separators and
    # symlinks that point elsewhere.
    output_root = os.path.realpath(OUTPUT_DIR)
    path = os.path.realpath(os.path.join(output_root, filename))
    if os.path.dirname(path) != output_root or not os.path.isfile(path):
        raise HTTPException(status_code=404, detail="file not found")
    return FileResponse(path, media_type="image/png")
@app.get("/logs")
def get_logs(_: None = Depends(bearer_auth)) -> JSONResponse:
    """Return every entry of the in-memory ``logs`` list as a JSON response."""
    response = JSONResponse(content=logs)
    return response