|
|
import hmac
import os
import shutil
import uuid
from datetime import datetime
from typing import Dict, List, Optional

import numpy as np
from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Header
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel
from PIL import Image

from src.core import process_inpaint
|
|
|
|
|
|
|
|
# Root directory for stored files: the DATA_DIR environment variable, or /data.
BASE_DIR = os.getenv("DATA_DIR", "/data")
if not os.path.isdir(BASE_DIR):
    # The configured location is not an existing directory; fall back to /tmp.
    BASE_DIR = "/tmp"

# Uploaded inputs and generated results are kept in separate subdirectories.
UPLOAD_DIR = os.path.join(BASE_DIR, "uploads")
OUTPUT_DIR = os.path.join(BASE_DIR, "outputs")

# Both directories are created at import time so handlers can write to them.
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
|
|
|
|
|
|
|
# Expected bearer token; bearer_auth skips authentication when this is unset
# or empty.
ENV_TOKEN = os.getenv("API_TOKEN")

app = FastAPI(title="Photo Object Removal API", version="1.0.0")

# In-memory registry of uploaded files, keyed by the generated upload id.
file_store: Dict[str, Dict[str, str]] = {}

# In-memory list of upload/result records, returned by GET /logs.
logs: List[Dict[str, str]] = []
|
|
|
|
|
|
|
|
def bearer_auth(authorization: Optional[str] = Header(default=None)) -> None:
    """Dependency that enforces bearer-token authentication.

    Authentication is skipped entirely when ``ENV_TOKEN`` is unset or empty.

    Args:
        authorization: Raw ``Authorization`` request header, if present.

    Raises:
        HTTPException: 401 when the header is missing or does not use the
            ``Bearer`` scheme; 403 when the presented token does not match
            ``ENV_TOKEN``.
    """
    if not ENV_TOKEN:
        return
    if authorization is None or not authorization.lower().startswith("bearer "):
        raise HTTPException(
            status_code=401,
            detail="Unauthorized",
            headers={"WWW-Authenticate": "Bearer"},
        )
    token = authorization.split(" ", 1)[1]
    # Compare in constant time so response timing does not reveal how much of
    # the token matched. Encoding to bytes keeps compare_digest from raising
    # TypeError on non-ASCII header values.
    if not hmac.compare_digest(token.encode("utf-8"), ENV_TOKEN.encode("utf-8")):
        raise HTTPException(status_code=403, detail="Forbidden")
|
|
|
|
|
|
|
|
class InpaintRequest(BaseModel):
    """JSON body of ``POST /inpaint``: ids of a stored image and a stored mask."""

    # Id of an entry registered with type "image" (as returned by /upload-image).
    image_id: str
    # Id of an entry registered with type "mask" (as returned by /upload-mask).
    mask_id: str
|
|
|
|
|
|
|
|
@app.get("/")
def root() -> Dict[str, object]:
    """Return a static, human-readable description of the API."""
    # Route -> short usage hint, in the order the routes are advertised.
    endpoint_help = {
        "GET /health": "health check",
        "POST /upload-image": "form-data: image=file",
        "POST /upload-mask": "form-data: mask=file",
        "POST /inpaint": "JSON: {image_id, mask_id}",
        "POST /inpaint-multipart": "form-data: image=file, mask=file",
        "GET /download/{filename}": "download result image",
        "GET /logs": "recent uploads/results",
    }
    auth_help = "set API_TOKEN env var to require Authorization: Bearer <token> (except /health)"
    return {
        "name": "Photo Object Removal API",
        "status": "ok",
        "endpoints": endpoint_help,
        "auth": auth_help,
    }
|
|
|
|
|
|
|
|
@app.get("/health")
def health() -> Dict[str, str]:
    """Health check; this route declares no authentication dependency."""
    payload = {"status": "healthy"}
    return payload
|
|
|
|
|
|
|
|
@app.post("/upload-image")
def upload_image(image: UploadFile = File(...), _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Store an uploaded source image and register it under a fresh id.

    The file is written to ``UPLOAD_DIR`` as ``<uuid><ext>``, where the
    extension comes from the client filename and defaults to ``.png``.

    Args:
        image: Multipart file sent in the ``image`` form field.
        _: Placeholder for the ``bearer_auth`` dependency; its value is unused.

    Returns:
        ``{"id": <generated id>, "filename": <client filename>}``.
    """
    # The upload may carry no filename; treat that as an empty name instead of
    # letting os.path.splitext fail on None.
    original_name = image.filename or ""
    ext = os.path.splitext(original_name)[1] or ".png"
    file_id = str(uuid.uuid4())
    stored_name = f"{file_id}{ext}"
    stored_path = os.path.join(UPLOAD_DIR, stored_name)
    with open(stored_path, "wb") as f:
        shutil.copyfileobj(image.file, f)

    # One timestamp so the registry entry and the log record agree.
    timestamp = datetime.utcnow().isoformat()
    file_store[file_id] = {
        "type": "image",
        "filename": original_name,
        "stored_name": stored_name,
        "path": stored_path,
        "timestamp": timestamp,
    }
    logs.append({"id": file_id, "filename": original_name, "type": "image", "timestamp": timestamp})
    return {"id": file_id, "filename": original_name}
|
|
|
|
|
|
|
|
@app.post("/upload-mask")
def upload_mask(mask: UploadFile = File(...), _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Store an uploaded mask and register it under a fresh id.

    The file is written to ``UPLOAD_DIR`` as ``<uuid><ext>``, where the
    extension comes from the client filename and defaults to ``.png``.

    Args:
        mask: Multipart file sent in the ``mask`` form field.
        _: Placeholder for the ``bearer_auth`` dependency; its value is unused.

    Returns:
        ``{"id": <generated id>, "filename": <client filename>}``.
    """
    # The upload may carry no filename; treat that as an empty name instead of
    # letting os.path.splitext fail on None.
    original_name = mask.filename or ""
    ext = os.path.splitext(original_name)[1] or ".png"
    file_id = str(uuid.uuid4())
    stored_name = f"{file_id}{ext}"
    stored_path = os.path.join(UPLOAD_DIR, stored_name)
    with open(stored_path, "wb") as f:
        shutil.copyfileobj(mask.file, f)

    # One timestamp so the registry entry and the log record agree.
    timestamp = datetime.utcnow().isoformat()
    file_store[file_id] = {
        "type": "mask",
        "filename": original_name,
        "stored_name": stored_name,
        "path": stored_path,
        "timestamp": timestamp,
    }
    logs.append({"id": file_id, "filename": original_name, "type": "mask", "timestamp": timestamp})
    return {"id": file_id, "filename": original_name}
|
|
|
|
|
|
|
|
def _load_rgba_image(path: str) -> Image.Image:
    """Open the image at *path* and return it converted to RGBA.

    Args:
        path: Filesystem path of the image to load.

    Returns:
        A new RGBA image produced by ``convert``.
    """
    # Image.open keeps the underlying file open; the context manager closes it
    # once the converted image has been produced.
    with Image.open(path) as img:
        return img.convert("RGBA")
|
|
|
|
|
|
|
|
def _load_rgba_mask_from_image(img: Image.Image) -> np.ndarray:
    """Return an RGBA array for a mask image.

    An image already in RGBA mode is returned unchanged as an array. Any other
    mode is converted to grayscale and mapped to an array whose colour channels
    are all zero and whose alpha channel is 0 where the grayscale value is
    non-zero and 255 where it is zero.

    Args:
        img: The mask image.

    Returns:
        A ``(height, width, 4)`` uint8 array in the non-RGBA case; otherwise
        ``np.array(img)``.
    """
    if img.mode == "RGBA":
        return np.array(img)

    luminance = np.array(img.convert("L"))
    mask = np.zeros((img.height, img.width, 4), dtype=np.uint8)
    # Non-zero grayscale pixels become fully transparent, the rest fully opaque.
    mask[:, :, 3] = np.where(luminance > 0, 0, 255).astype(np.uint8)
    return mask
|
|
|
|
|
|
|
|
@app.post("/inpaint")
def inpaint(req: InpaintRequest, _: None = Depends(bearer_auth)):
    """Inpaint a previously uploaded image using a previously uploaded mask.

    Args:
        req: Ids of the stored image and mask.
        _: Placeholder for the ``bearer_auth`` dependency; its value is unused.

    Returns:
        A ``FileResponse`` streaming the PNG result written to ``OUTPUT_DIR``.

    Raises:
        HTTPException: 404 when either id is unknown or registered under the
            wrong type.
    """
    image_entry = file_store.get(req.image_id)
    if image_entry is None or image_entry["type"] != "image":
        raise HTTPException(status_code=404, detail="image_id not found")
    mask_entry = file_store.get(req.mask_id)
    if mask_entry is None or mask_entry["type"] != "mask":
        raise HTTPException(status_code=404, detail="mask_id not found")

    img_rgba = _load_rgba_image(image_entry["path"])
    # Close the mask file as soon as its pixels have been copied into an array.
    with Image.open(mask_entry["path"]) as mask_img:
        mask_rgba = _load_rgba_mask_from_image(mask_img)

    # NOTE(review): assumes process_inpaint returns an array that
    # Image.fromarray accepts - confirm against src.core.
    result = process_inpaint(np.array(img_rgba), mask_rgba)
    result_name = f"output_{uuid.uuid4().hex}.png"
    result_path = os.path.join(OUTPUT_DIR, result_name)
    Image.fromarray(result).save(result_path)

    logs.append({"result": result_name, "timestamp": datetime.utcnow().isoformat()})
    return FileResponse(result_path, media_type="image/png", filename=result_name)
|
|
|
|
|
|
|
|
@app.post("/inpaint-multipart")
def inpaint_multipart(
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
    _: None = Depends(bearer_auth),
) -> Dict[str, str]:
    """Inpaint an image and mask sent together in one multipart request.

    Args:
        image: Multipart file sent in the ``image`` form field.
        mask: Multipart file sent in the ``mask`` form field.
        _: Placeholder for the ``bearer_auth`` dependency; its value is unused.

    Returns:
        ``{"result": <output filename>}``; the file is written to
        ``OUTPUT_DIR``.

    Raises:
        HTTPException: 400 when either upload cannot be decoded as an image.
    """
    # Pillow reports unreadable or truncated data as OSError (its
    # UnidentifiedImageError subclasses it); that is a client error, not a 500.
    try:
        img = Image.open(image.file).convert("RGBA")
    except OSError as exc:
        raise HTTPException(status_code=400, detail="invalid image file") from exc
    try:
        mask_rgba = _load_rgba_mask_from_image(Image.open(mask.file))
    except OSError as exc:
        raise HTTPException(status_code=400, detail="invalid mask file") from exc

    # NOTE(review): assumes process_inpaint returns an array that
    # Image.fromarray accepts - confirm against src.core.
    result = process_inpaint(np.array(img), mask_rgba)
    result_name = f"output_{uuid.uuid4().hex}.png"
    result_path = os.path.join(OUTPUT_DIR, result_name)
    Image.fromarray(result).save(result_path)

    logs.append({"result": result_name, "timestamp": datetime.utcnow().isoformat()})
    return {"result": result_name}
|
|
|
|
|
|
|
|
@app.get("/download/{filename}")
def download_file(filename: str):
    """Serve a file from ``OUTPUT_DIR`` by name.

    Args:
        filename: Name of a file directly inside ``OUTPUT_DIR``.

    Returns:
        A ``FileResponse`` for the requested file.

    Raises:
        HTTPException: 404 when the name contains a path component or no such
            file exists.
    """
    # NOTE(review): this route has no bearer_auth dependency although the root
    # endpoint advertises auth on everything except /health - confirm intent.
    # The name comes from the client: accept only a bare file name so the
    # joined path cannot leave OUTPUT_DIR.
    if filename in ("", ".", "..") or os.path.basename(filename) != filename:
        raise HTTPException(status_code=404, detail="file not found")
    path = os.path.join(OUTPUT_DIR, filename)
    if not os.path.isfile(path):
        raise HTTPException(status_code=404, detail="file not found")
    return FileResponse(path)
|
|
|
|
|
|
|
|
@app.get("/logs")
def get_logs(_: None = Depends(bearer_auth)) -> JSONResponse:
    """Return all recorded upload/result log entries as a JSON array."""
    response = JSONResponse(content=logs)
    return response
|
|
|