|
|
|
|
|
import hmac
import logging
import os
import shutil
import uuid
from datetime import datetime
from typing import Dict, List, Optional

import cv2
import numpy as np
from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Header, Request
from fastapi.responses import FileResponse, JSONResponse
from PIL import Image
from pydantic import BaseModel
|
|
|
|
|
# Module-wide logger used by every endpoint; the root logger is configured at
# INFO so upload/detection progress messages are emitted.
log = logging.getLogger("api")
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
from src.core import process_inpaint |
|
|
|
|
|
|
|
|
# Root directory for uploads and generated results. DATA_DIR overrides the
# default "/data"; when the chosen directory does not exist, files are stored
# under /tmp instead.
BASE_DIR = os.environ.get("DATA_DIR", "/data")
if not os.path.isdir(BASE_DIR):
    # Make the fallback visible: a DATA_DIR that was configured explicitly but
    # is missing is most likely a deployment mistake, so warn; falling back
    # from the implicit default only merits an INFO line.
    logging.getLogger("api").log(
        logging.WARNING if "DATA_DIR" in os.environ else logging.INFO,
        "Data directory %r does not exist; using /tmp instead",
        BASE_DIR,
    )
    BASE_DIR = "/tmp"

UPLOAD_DIR = os.path.join(BASE_DIR, "uploads")
OUTPUT_DIR = os.path.join(BASE_DIR, "outputs")

os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
|
|
|
|
|
|
|
# Shared secret checked by `bearer_auth`; when unset or empty, auth is off.
ENV_TOKEN = os.getenv("API_TOKEN")

app = FastAPI(version="1.0.0", title="Photo Object Removal API")

# Process-local registry of uploaded files keyed by their generated id;
# nothing here is persisted, so entries do not survive a restart.
file_store: Dict[str, Dict[str, str]] = dict()

# Chronological upload/result records, returned verbatim by GET /logs.
logs: List[Dict[str, str]] = list()
|
|
|
|
|
|
|
|
def bearer_auth(authorization: Optional[str] = Header(default=None)) -> None:
    """FastAPI dependency enforcing ``Authorization: Bearer <token>``.

    Authentication is disabled entirely when ``ENV_TOKEN`` (the ``API_TOKEN``
    environment variable) is unset or empty.

    Args:
        authorization: raw ``Authorization`` header value, if any.

    Raises:
        HTTPException: 401 when the header is missing or is not a Bearer
            credential; 403 when the supplied token does not match.
    """
    if not ENV_TOKEN:
        return

    if authorization is None or not authorization.lower().startswith("bearer "):
        raise HTTPException(status_code=401, detail="Unauthorized")

    token = authorization.split(" ", 1)[1]

    # Constant-time comparison so response timing does not reveal how much
    # of the token prefix was correct. Compare bytes: compare_digest rejects
    # non-ASCII str arguments with TypeError.
    if not hmac.compare_digest(token.encode("utf-8"), ENV_TOKEN.encode("utf-8")):
        raise HTTPException(status_code=403, detail="Forbidden")
|
|
|
|
|
|
|
|
class InpaintRequest(BaseModel):
    # JSON body accepted by POST /inpaint and POST /inpaint-url.

    # Id returned by POST /upload-image; must refer to an entry of type "image".
    image_id: str

    # Id returned by POST /upload-mask; must refer to an entry of type "mask".
    mask_id: str

    # Forwarded unchanged as process_inpaint(..., invert_mask=...).
    invert_mask: bool = True

    # When true, inpainting is skipped and the image is saved converted to RGB.
    passthrough: bool = False
|
|
|
|
|
|
|
|
class SimpleRemoveRequest(BaseModel):
    # Request model carrying a single uploaded-image id.
    # NOTE(review): no endpoint in this module references this model
    # (/remove-pink takes a multipart file instead) — confirm whether it is
    # still needed before relying on it.
    image_id: str
|
|
|
|
|
|
|
|
@app.get("/")
def root() -> Dict[str, object]:
    """Describe the service: name, status, available endpoints and auth rules.

    The endpoint list must mirror the routes registered in this module; the
    auth note lists the routes that do not depend on ``bearer_auth``.
    """
    return {
        "name": "Photo Object Removal API",
        "status": "ok",
        "endpoints": {
            "GET /health": "health check",
            "POST /upload-image": "form-data: image=file",
            "POST /upload-mask": "form-data: mask=file",
            "POST /inpaint": "JSON: {image_id, mask_id}",
            "POST /inpaint-url": "JSON: {image_id, mask_id} (response also includes a download url)",
            "POST /inpaint-multipart": "form-data: image=file, mask=file",
            "POST /remove-pink": "form-data: image=file (auto-detects pink segments and removes them)",
            "GET /download/{filename}": "download result image",
            "GET /result/{filename}": "view result image in browser",
            "GET /logs": "recent uploads/results",
        },
        "auth": "set API_TOKEN env var to require Authorization: Bearer <token> (except /, /health, /download and /result)",
    }
|
|
|
|
|
|
|
|
@app.get("/health")
def health() -> Dict[str, str]:
    """Liveness probe; takes no auth dependency and always reports healthy."""
    return dict(status="healthy")
|
|
|
|
|
|
|
|
def _save_upload(upload: UploadFile, kind: str) -> Dict[str, str]:
    """Persist an uploaded file under UPLOAD_DIR and register it in ``file_store``.

    Args:
        upload: the multipart file sent by the client.
        kind: ``"image"`` or ``"mask"``; stored as the entry's ``type`` and
            checked by ``/inpaint`` before the file is used.

    Returns:
        ``{"id": <generated uuid>, "filename": <client filename>}``. When the
        client sent no filename, the generated stored name is reported.
    """
    file_id = str(uuid.uuid4())

    # Clients may omit the filename (None); don't crash in splitext.
    client_name = upload.filename or ""
    ext = os.path.splitext(client_name)[1]
    # The extension is client-controlled and becomes part of a path on disk:
    # only accept short, plain ASCII alphanumeric suffixes.
    if not (ext[1:].isascii() and ext[1:].isalnum() and len(ext) <= 10):
        ext = ".png"

    stored_name = f"{file_id}{ext}"
    stored_path = os.path.join(UPLOAD_DIR, stored_name)
    with open(stored_path, "wb") as f:
        shutil.copyfileobj(upload.file, f)

    display_name = client_name or stored_name
    timestamp = datetime.utcnow().isoformat()
    file_store[file_id] = {
        "type": kind,
        "filename": display_name,
        "stored_name": stored_name,
        "path": stored_path,
        "timestamp": timestamp,
    }
    logs.append({"id": file_id, "filename": display_name, "type": kind, "timestamp": timestamp})
    return {"id": file_id, "filename": display_name}


@app.post("/upload-image")
def upload_image(image: UploadFile = File(...), _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Store an image (form field ``image``) for later use with ``/inpaint``."""
    return _save_upload(image, "image")


@app.post("/upload-mask")
def upload_mask(mask: UploadFile = File(...), _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Store a mask (form field ``mask``) for later use with ``/inpaint``."""
    return _save_upload(mask, "mask")
|
|
|
|
|
|
|
|
def _load_rgba_image(path: str) -> Image.Image:
    """Open the image stored at *path* and return it converted to RGBA.

    The file is opened in a ``with`` block so the handle is released
    deterministically (Pillow may otherwise keep it open after loading, e.g.
    for multi-frame formats). ``convert`` returns a new image, so the result
    stays valid after the source is closed.
    """
    with Image.open(path) as img:
        return img.convert("RGBA")
|
|
|
|
|
|
|
|
def _load_rgba_mask_from_image(img: Image.Image) -> np.ndarray:
    """Turn an uploaded mask image into a black/white RGBA array.

    White (255 in R, G and B) marks pixels to remove, black (0) marks pixels
    to keep; the alpha channel of the result is always 255. How a pixel is
    classified depends on the input:

    * non-RGBA images: grayscale value above 128 is white;
    * mostly opaque RGBA images (mean alpha above 200): grayscale value above
      128, or an exact (255, 0, 255) magenta pixel, is white;
    * other RGBA images: pixels with alpha below 128 (transparent) are white.

    Returns:
        ``uint8`` array of shape (height, width, 4).
    """

    def _to_rgba(bw: np.ndarray) -> np.ndarray:
        # Replicate the binary mask into R, G and B and make it fully opaque.
        return np.dstack((bw, bw, bw, np.full_like(bw, 255)))

    if img.mode != "RGBA":
        luminance = np.array(img.convert("L"))
        bw = np.where(luminance > 128, 255, 0).astype(np.uint8)
        out = _to_rgba(bw)
        log.info(f"Loaded {img.mode} mask: {int((bw > 0).sum())} white pixels (to remove)")
        return out

    pixels = np.array(img)
    rgb, alpha = pixels[:, :, :3], pixels[:, :, 3]

    if alpha.mean() > 200:
        # Opaque mask: decide from the colour channels, with pure magenta
        # (dark in grayscale) forced to white as well.
        bright = np.where(cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY) > 128, 255, 0).astype(np.uint8)
        pure_magenta = np.all(rgb == [255, 0, 255], axis=2).astype(np.uint8) * 255
        bw = np.maximum(bright, pure_magenta)
        out = _to_rgba(bw)
        log.info(f"Loaded RGBA mask (RGB-based): {int((bw > 0).sum())} white pixels (to remove)")
        return out

    # Largely transparent mask: the transparent areas are the ones to remove.
    bw = np.where(alpha < 128, 255, 0).astype(np.uint8)
    out = _to_rgba(bw)
    log.info(f"Loaded RGBA mask (alpha-based): {int((bw > 0).sum())} white pixels (to remove)")
    return out
|
|
|
|
|
|
|
|
def _run_inpaint_request(req: InpaintRequest) -> str:
    """Validate the ids in *req*, run inpainting (or passthrough), save the PNG.

    Shared by ``/inpaint`` and ``/inpaint-url``.

    Returns:
        The generated result file name inside OUTPUT_DIR.

    Raises:
        HTTPException: 404 when an id is unknown or refers to the wrong kind
            of upload; 400 when a stored upload can no longer be read or
            decoded as an image.
    """
    image_entry = file_store.get(req.image_id)
    if image_entry is None or image_entry["type"] != "image":
        raise HTTPException(status_code=404, detail="image_id not found")
    mask_entry = file_store.get(req.mask_id)
    if mask_entry is None or mask_entry["type"] != "mask":
        raise HTTPException(status_code=404, detail="mask_id not found")

    try:
        img_rgba = _load_rgba_image(image_entry["path"])
        mask_rgba: Optional[np.ndarray] = None
        if not req.passthrough:
            # Close the mask file once it has been converted to an array.
            with Image.open(mask_entry["path"]) as mask_img:
                mask_rgba = _load_rgba_mask_from_image(mask_img)
    except OSError as exc:
        # Covers undecodable data (PIL.UnidentifiedImageError) and files that
        # disappeared from disk: a bad upload, not a server fault.
        raise HTTPException(status_code=400, detail="stored image or mask could not be read") from exc

    if req.passthrough:
        result = np.array(img_rgba.convert("RGB"))
    else:
        result = process_inpaint(np.array(img_rgba), mask_rgba, invert_mask=req.invert_mask)

    result_name = f"output_{uuid.uuid4().hex}.png"
    Image.fromarray(result).save(os.path.join(OUTPUT_DIR, result_name))
    return result_name


@app.post("/inpaint")
def inpaint(req: InpaintRequest, _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Inpaint a previously uploaded image with a previously uploaded mask.

    Returns:
        ``{"result": <file name>}``; fetch it via ``/download/{filename}``.
    """
    result_name = _run_inpaint_request(req)
    logs.append({"result": result_name, "timestamp": datetime.utcnow().isoformat()})
    return {"result": result_name}


@app.post("/inpaint-url")
def inpaint_url(req: InpaintRequest, request: Request, _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Same as /inpaint, but the JSON response also carries an absolute download URL.

    Returns:
        ``{"result": <file name>, "url": <URL of /download/{filename}>}``.
    """
    result_name = _run_inpaint_request(req)
    url = str(request.url_for("download_file", filename=result_name))
    logs.append({"result": result_name, "url": url, "timestamp": datetime.utcnow().isoformat()})
    return {"result": result_name, "url": url}
|
|
|
|
|
|
|
|
@app.post("/inpaint-multipart")
def inpaint_multipart(
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
    request: Request = None,
    invert_mask: bool = True,
    mask_is_painted: bool = False,
    passthrough: bool = False,
    _: None = Depends(bearer_auth),
) -> Dict[str, str]:
    """Inpaint an uploaded image with an uploaded mask in a single request.

    Args:
        image: the source photo.
        mask: a regular mask (interpreted by ``_load_rgba_mask_from_image``)
            or, when ``mask_is_painted`` is true, a copy of the photo with the
            areas to remove painted in pink/magenta.
        request: used only to build an absolute download URL.
        invert_mask: forwarded to ``process_inpaint`` for regular masks;
            forced to False for painted masks.
        mask_is_painted: derive the mask from magenta paint (R > 240,
            G < 30, B > 240) plus, when both uploads have the same size, any
            pixel whose grayscale difference from ``image`` exceeds 50.
        passthrough: skip inpainting (and mask decoding) and store the image
            converted to RGB.

    Returns:
        ``{"result": <file name>}`` plus ``"url"`` when it could be built.
        If painted-mask detection finds fewer than 50 pixels, the original
        image is stored and an ``"error"`` message is added.

    Raises:
        HTTPException: 400 when an upload cannot be decoded as an image.
    """

    def _decode(upload: UploadFile, field: str) -> Image.Image:
        # Decode an upload as RGBA; undecodable data is a client error.
        try:
            with Image.open(upload.file) as src:
                return src.convert("RGBA")
        except OSError as exc:
            raise HTTPException(status_code=400, detail=f"{field} is not a readable image") from exc

    def _finish(result: np.ndarray, extra: Optional[Dict[str, str]] = None) -> Dict[str, str]:
        # Save `result` as PNG, record it in `logs`, and build the response.
        result_name = f"output_{uuid.uuid4().hex}.png"
        Image.fromarray(result).save(os.path.join(OUTPUT_DIR, result_name))

        url: Optional[str] = None
        if request is not None:
            try:
                url = str(request.url_for("download_file", filename=result_name))
            except Exception:
                # The URL is a convenience; the result is still reachable via
                # /download/<result>, so don't fail the request over it.
                log.warning("Could not build download URL for %s", result_name, exc_info=True)

        entry: Dict[str, str] = {"result": result_name, "timestamp": datetime.utcnow().isoformat()}
        resp: Dict[str, str] = {"result": result_name}
        if url:
            entry["url"] = url
            resp["url"] = url
        if extra:
            entry.update(extra)
            resp.update(extra)
        logs.append(entry)
        return resp

    img = _decode(image, "image")

    if passthrough:
        return _finish(np.array(img.convert("RGB")))

    m = _decode(mask, "mask")

    if mask_is_painted:
        log.info("Auto-detecting pink/magenta paint from uploaded image...")
        m_rgb = cv2.cvtColor(np.array(m), cv2.COLOR_RGBA2RGB)

        # Tolerant magenta match: anti-aliased or compressed strokes are
        # rarely exactly (255, 0, 255).
        magenta_detected = (
            (m_rgb[:, :, 0] > 240) & (m_rgb[:, :, 1] < 30) & (m_rgb[:, :, 2] > 240)
        ).astype(np.uint8) * 255

        img_rgb = cv2.cvtColor(np.array(img), cv2.COLOR_RGBA2RGB)
        if img_rgb.shape == m_rgb.shape:
            # Presumably the painted upload is the photo with strokes on top,
            # so anything that differs noticeably from the photo counts too.
            gray_diff = cv2.cvtColor(cv2.absdiff(img_rgb, m_rgb), cv2.COLOR_RGB2GRAY)
            diff_mask = (gray_diff > 50).astype(np.uint8) * 255
            binmask = cv2.bitwise_or(magenta_detected, diff_mask)
        else:
            binmask = magenta_detected

        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        # Close to fill gaps inside strokes, then open to drop isolated specks.
        binmask = cv2.morphologyEx(binmask, cv2.MORPH_CLOSE, kernel, iterations=2)
        binmask = cv2.morphologyEx(binmask, cv2.MORPH_OPEN, kernel, iterations=1)

        nonzero = int((binmask > 0).sum())
        log.info("Pink/magenta paint detected: %d pixels marked for removal (white)", nonzero)

        if nonzero < 50:
            log.warning("Very few pixels detected! Trying stricter magenta detection...")
            magenta_strict = np.all(m_rgb == [255, 0, 255], axis=2).astype(np.uint8) * 255
            binmask = cv2.morphologyEx(magenta_strict, cv2.MORPH_CLOSE, kernel, iterations=3)
            nonzero = int((binmask > 0).sum())
            log.info("Strict magenta detection: %d pixels", nonzero)

        if nonzero < 50:
            log.error("CRITICAL: Could not detect pink/magenta paint! Returning original image.")
            return _finish(
                np.array(img.convert("RGB")),
                {"error": "pink/magenta paint detection failed - very few pixels detected"},
            )

        # Binary mask -> RGBA: white in R, G, B marks removal, fully opaque.
        mask_rgba = np.zeros((binmask.shape[0], binmask.shape[1], 4), dtype=np.uint8)
        mask_rgba[:, :, :3] = binmask[:, :, np.newaxis]
        mask_rgba[:, :, 3] = 255

        log.info("Successfully created black/white mask: %d white pixels (to remove), %d black pixels (to keep)",
                 nonzero, binmask.shape[0] * binmask.shape[1] - nonzero)
    else:
        mask_rgba = _load_rgba_mask_from_image(m)

    # A painted mask is already built as white = remove, so never invert it.
    actual_invert = False if mask_is_painted else invert_mask
    log.info("Using invert_mask=%s (mask_is_painted=%s)", actual_invert, mask_is_painted)

    result = process_inpaint(np.array(img), mask_rgba, invert_mask=actual_invert)
    return _finish(result)
|
|
|
|
|
|
|
|
@app.post("/remove-pink")
def remove_pink_segments(
    image: UploadFile = File(...),
    request: Request = None,
    _: None = Depends(bearer_auth),
) -> Dict[str, str]:
    """Remove pink/magenta-painted regions from a single uploaded image.

    Paint the areas to remove in magenta (ideally RGB 255,0,255) and upload
    the image. Pixels with R > 240, G < 30 and B > 240 become the removal
    mask (white); everything else is kept (black). If fewer than 50 pixels
    survive the morphological clean-up, an exact (255, 0, 255) match is
    tried instead. If that also fails, the original image is stored unchanged
    and an ``"error"`` message is returned.

    Returns:
        ``{"result": <file name>, "pink_segments_detected": <pixel count>}``
        on success, or ``{"result": <file name>, "error": <message>}`` on
        detection failure; both include ``"url"`` when it could be built.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    log.info("Simple remove-pink: processing image %s", image.filename)

    def _save_result(result: np.ndarray) -> str:
        # Write `result` as a PNG under OUTPUT_DIR and return its file name.
        result_name = f"output_{uuid.uuid4().hex}.png"
        Image.fromarray(result).save(os.path.join(OUTPUT_DIR, result_name))
        return result_name

    def _download_url(result_name: str) -> Optional[str]:
        # Best effort: the result stays reachable via /download without a URL.
        if request is None:
            return None
        try:
            return str(request.url_for("download_file", filename=result_name))
        except Exception:
            log.warning("Could not build download URL for %s", result_name, exc_info=True)
            return None

    try:
        with Image.open(image.file) as src:
            img = src.convert("RGBA")
    except OSError as exc:
        raise HTTPException(status_code=400, detail="image is not a readable image") from exc
    img_rgb = cv2.cvtColor(np.array(img), cv2.COLOR_RGBA2RGB)

    # Tolerant magenta match: anti-aliased or compressed strokes are rarely
    # exactly (255, 0, 255).
    magenta_detected = (
        (img_rgb[:, :, 0] > 240) & (img_rgb[:, :, 1] < 30) & (img_rgb[:, :, 2] > 240)
    ).astype(np.uint8) * 255

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    # Close to fill gaps inside strokes, then open to drop isolated specks.
    binmask = cv2.morphologyEx(magenta_detected, cv2.MORPH_CLOSE, kernel, iterations=2)
    binmask = cv2.morphologyEx(binmask, cv2.MORPH_OPEN, kernel, iterations=1)

    nonzero = int((binmask > 0).sum())
    log.info("Detected %d pink pixels to remove", nonzero)

    if nonzero < 50:
        log.warning("Very few pink pixels detected! Trying strict magenta detection...")
        magenta_strict = np.all(img_rgb == [255, 0, 255], axis=2).astype(np.uint8) * 255
        binmask = cv2.morphologyEx(magenta_strict, cv2.MORPH_CLOSE, kernel, iterations=3)
        nonzero = int((binmask > 0).sum())
        log.info("Strict detection: %d pink pixels", nonzero)

    if nonzero < 50:
        log.error("No pink segments detected! Returning original image.")
        error = "No pink/magenta segments detected. Please paint areas to remove with magenta/pink color (RGB 255,0,255)."
        result_name = _save_result(np.array(img.convert("RGB")))
        url = _download_url(result_name)
        # Record the failed attempt too, so /logs lists every stored result.
        logs.append({
            "result": result_name,
            "filename": image.filename,
            "pink_pixels": nonzero,
            "error": error,
            "timestamp": datetime.utcnow().isoformat(),
        })
        failure: Dict[str, str] = {"result": result_name, "error": error}
        if url:
            failure["url"] = url
        return failure

    # Binary mask -> RGBA: white in R, G, B marks removal, fully opaque.
    mask_rgba = np.zeros((binmask.shape[0], binmask.shape[1], 4), dtype=np.uint8)
    mask_rgba[:, :, :3] = binmask[:, :, np.newaxis]
    mask_rgba[:, :, 3] = 255

    total_pixels = binmask.shape[0] * binmask.shape[1]
    log.info("Created mask: %d white pixels (remove/pink), %d black pixels (keep - everything else automatically)",
             nonzero, total_pixels - nonzero)

    # The mask is already white = remove, so it must not be inverted.
    result = process_inpaint(np.array(img), mask_rgba, invert_mask=False)
    result_name = _save_result(result)
    url = _download_url(result_name)

    logs.append({
        "result": result_name,
        "filename": image.filename,
        "pink_pixels": nonzero,
        "timestamp": datetime.utcnow().isoformat()
    })

    resp: Dict[str, str] = {"result": result_name, "pink_segments_detected": str(nonzero)}
    if url:
        resp["url"] = url
    return resp
|
|
|
|
|
|
|
|
def _resolve_output_file(filename: str) -> str:
    """Map a client-supplied result name to a regular file directly in OUTPUT_DIR.

    The name comes straight from the URL, so it is resolved with
    ``os.path.realpath`` and must land directly inside OUTPUT_DIR. This
    rejects ``..``, absolute paths, embedded separators (including
    backslashes on Windows) and symlinks that point elsewhere.

    Raises:
        HTTPException: 404 when the name is unsafe or no such file exists.
    """
    output_root = os.path.realpath(OUTPUT_DIR)
    candidate = os.path.realpath(os.path.join(output_root, filename))
    if os.path.dirname(candidate) != output_root or not os.path.isfile(candidate):
        raise HTTPException(status_code=404, detail="file not found")
    return candidate


@app.get("/download/{filename}")
def download_file(filename: str):
    """Serve a generated result file (content type guessed from its name)."""
    return FileResponse(_resolve_output_file(filename))


@app.get("/result/{filename}")
def view_result(filename: str):
    """View result image directly in browser (same as download but with proper content-type for viewing)"""
    return FileResponse(_resolve_output_file(filename), media_type="image/png")
|
|
|
|
|
|
|
|
@app.get("/logs")
def get_logs(_: None = Depends(bearer_auth)) -> JSONResponse:
    """Return every recorded upload/result entry as a JSON array, oldest first.

    NOTE(review): nothing in this module ever trims ``logs``, so this payload
    grows for the lifetime of the process.
    """
    snapshot = list(logs)
    return JSONResponse(content=snapshot)