the-adrianator's picture
Initial commit: AI watermark remover
b2c1b6b
"""
Watermark Remover — FastAPI backend.
Run: uvicorn app:app --reload --port 8000
"""
import base64
import shutil
import threading
import uuid
from pathlib import Path
import cv2
import numpy as np
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
# ── Job registry ──────────────────────────────────────────
# In-memory map of job_id -> job-state dict. Entries are created by the
# /process/* routes, mutated by their background worker threads through
# _job_update(), and returned as-is by /status/{job_id}.
# Keys written in this file: status ("processing" | "done" | "error"),
# progress, total, message, and download_url (added when a job finishes
# successfully). Failures are reported as status="error" with the exception
# text in `message`; no separate `error` key is written here.
_jobs: dict = {}
def _job_update(job_id: str, **kwargs) -> None:
    """Merge ``kwargs`` into the registry entry for ``job_id``.

    Ids that are not present in the registry are ignored; no entry is
    created for them.
    """
    if job_id not in _jobs:
        return
    _jobs[job_id].update(kwargs)
# Directory containing this module; STATIC_DIR is resolved relative to it.
BASE_DIR = Path(__file__).parent
STATIC_DIR = BASE_DIR / "static"
# Working directory for uploads, video preview frames and processed results.
UPLOAD_DIR = Path("/tmp/wm_tool")
# Created at import time so the routes can write into it immediately.
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
# Accepted upload extensions (lower-case, with leading dot). /upload
# lower-cases the incoming suffix before testing membership.
VIDEO_EXTS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".flv"}
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".bmp", ".tiff", ".tif"}
app = FastAPI(title="Watermark Remover")
# NOTE(review): CORS is fully open (any origin, method and header) —
# confirm this is intended outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Expose the files in STATIC_DIR under the /static URL prefix.
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _decode_mask(mask_data: str, target_shape=None) -> np.ndarray:
    """Decode a base64 image (bare or as a data URL) into a grayscale mask.

    Args:
        mask_data: Base64 text, optionally prefixed data-URL style
            (``data:image/png;base64,<payload>``). Only the text after the
            last comma is decoded.
        target_shape: Optional ``(height, width, ...)`` shape. When given and
            the decoded mask has a different height/width, the mask is
            resized to match.

    Returns:
        The single-channel mask produced by ``cv2.imdecode`` in grayscale
        mode, resized to ``target_shape`` when requested.

    Raises:
        HTTPException: 400 when the payload is not valid base64, is empty,
            or does not decode to an image.
    """
    payload = mask_data.split(",")[-1]
    try:
        raw = base64.b64decode(payload)
    except ValueError:
        # binascii.Error (bad padding / length) is a ValueError subclass;
        # report it as a client error instead of an unhandled 500.
        raise HTTPException(400, "Invalid mask image") from None
    if not raw:
        # cv2.imdecode raises on an empty buffer rather than returning None.
        raise HTTPException(400, "Invalid mask image")

    arr = np.frombuffer(raw, dtype=np.uint8)
    mask = cv2.imdecode(arr, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise HTTPException(400, "Invalid mask image")

    if target_shape is not None and mask.shape[:2] != target_shape[:2]:
        # cv2.resize takes (width, height); nearest-neighbour keeps the mask
        # free of interpolated in-between grey values.
        mask = cv2.resize(
            mask,
            (target_shape[1], target_shape[0]),
            interpolation=cv2.INTER_NEAREST,
        )
    return mask
def _safe_path(filename: str) -> Path:
    """Return the resolved path of ``filename`` under UPLOAD_DIR.

    The joined path is fully resolved before it is compared against the
    resolved upload directory.

    Raises:
        HTTPException: 403 when the resolved path is not located inside
            UPLOAD_DIR.
    """
    candidate = (UPLOAD_DIR / filename).resolve()
    upload_root = UPLOAD_DIR.resolve()
    try:
        candidate.relative_to(upload_root)
    except ValueError:
        raise HTTPException(403, "Access denied")
    return candidate
# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------
@app.get("/")
async def index():
    """Serve ``index.html`` from STATIC_DIR at the site root."""
    index_page = STATIC_DIR / "index.html"
    return FileResponse(str(index_page))
@app.get("/capabilities")
async def capabilities():
    """Report which processing backends are usable.

    Each flag is imported from its owning module at request time. A failure
    to import ``detect`` is reported as EasyOCR being unavailable; the other
    three imports are not guarded. ``opencv`` is always reported as true.
    """
    from inpaint import LAMA_AVAILABLE as lama_ok
    from sam_segment import SAM_AVAILABLE as sam_ok
    from sd_inpaint import SD_AVAILABLE as sd_ok

    try:
        from detect import EASYOCR_AVAILABLE as easyocr_ok
    except Exception:
        easyocr_ok = False

    return dict(
        lama=lama_ok,
        easyocr=easyocr_ok,
        opencv=True,
        sam=sam_ok,
        sd=sd_ok,
    )
@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    """Store an uploaded image or video under a fresh id.

    The file is saved in UPLOAD_DIR as ``<file_id><ext>``. For videos,
    ffmpeg is additionally asked to write the first frame to
    ``<file_id>_preview.jpg``.

    Returns:
        Dict with ``file_id``, the lower-cased ``ext``, ``is_video`` and the
        client's original ``filename``.

    Raises:
        HTTPException: 400 when the extension is not in VIDEO_EXTS or
            IMAGE_EXTS (including uploads that carry no filename).
    """
    # UploadFile.filename is optional; treat a missing name as "no extension"
    # so it is rejected with a 400 instead of crashing in Path(None).
    ext = Path(file.filename or "").suffix.lower()
    if ext not in VIDEO_EXTS | IMAGE_EXTS:
        raise HTTPException(400, f"Unsupported file type: {ext}")

    # The stored name is derived from a new UUID, never from client input.
    file_id = str(uuid.uuid4())
    save_path = UPLOAD_DIR / f"{file_id}{ext}"
    with open(save_path, "wb") as f:
        shutil.copyfileobj(file.file, f)

    is_video = ext in VIDEO_EXTS
    if is_video:
        import subprocess

        preview_path = UPLOAD_DIR / f"{file_id}_preview.jpg"
        # NOTE(review): ffmpeg's exit status is not checked, so the preview
        # may be absent after this call; the upload still succeeds.
        subprocess.run(
            ["ffmpeg", "-i", str(save_path), "-vframes", "1", "-q:v", "2", str(preview_path), "-y"],
            capture_output=True,
        )

    return {"file_id": file_id, "ext": ext, "is_video": is_video, "filename": file.filename}
@app.get("/source/{filename:path}")
async def get_source(filename: str):
    """Serve the original uploaded file (used by the compare video player).

    Args:
        filename: Path relative to UPLOAD_DIR; it is confined to that
            directory by ``_safe_path``.

    Raises:
        HTTPException: 403 when the path escapes UPLOAD_DIR (raised by
            ``_safe_path``); 404 when it does not name an existing file.
    """
    file_path = _safe_path(filename)
    # is_file() rather than exists(): an empty or directory path (for example
    # "/source/" resolves to UPLOAD_DIR itself) must yield a 404 instead of
    # making FileResponse fail with a server error.
    if not file_path.is_file():
        raise HTTPException(404, "Source file not found")

    ext = file_path.suffix.lower()
    # Only video types are mapped; anything else is sent as a generic
    # binary stream.
    media_types = {
        ".mp4": "video/mp4", ".mov": "video/quicktime", ".avi": "video/x-msvideo",
        ".mkv": "video/x-matroska", ".webm": "video/webm", ".flv": "video/x-flv",
    }
    return FileResponse(str(file_path), media_type=media_types.get(ext, "application/octet-stream"))
@app.get("/preview/{file_id}")
async def get_preview(file_id: str):
    """Return a still image for ``file_id``.

    The video preview frame (``<file_id>_preview.jpg``) is preferred; if it
    does not exist, the first existing ``<file_id><ext>`` among IMAGE_EXTS is
    served instead.

    Raises:
        HTTPException: 404 when none of the candidate files exists.
    """
    # (path, media type) pairs in lookup order: video preview first, then
    # the uploaded image itself under each accepted image extension.
    candidates = [(UPLOAD_DIR / f"{file_id}_preview.jpg", "image/jpeg")]
    candidates.extend((UPLOAD_DIR / f"{file_id}{ext}", None) for ext in IMAGE_EXTS)

    for path, media_type in candidates:
        if path.exists():
            return FileResponse(str(path), media_type=media_type)
    raise HTTPException(404, "Preview not found")
@app.post("/detect")
async def detect_watermark(file_id: str = Form(...), ext: str = Form(...)):
    """Run watermark detection on an uploaded image or a video's preview frame.

    Args:
        file_id: Id returned by /upload.
        ext: Extension returned by /upload. For video extensions the
            ``<file_id>_preview.jpg`` frame is analysed instead of the video.

    Returns:
        Dict with ``regions`` (whatever ``detect_watermarks`` returns) plus
        ``image_width`` and ``image_height`` of the analysed image.

    Raises:
        HTTPException: 403 when the derived path escapes UPLOAD_DIR, 404 when
            the file is missing, 400 when OpenCV cannot read it.
    """
    from detect import detect_watermarks

    if ext in VIDEO_EXTS:
        name = f"{file_id}_preview.jpg"
    else:
        name = f"{file_id}{ext}"
    # file_id and ext are client-supplied form fields, so the joined path is
    # confined to UPLOAD_DIR the same way /source and /download do it.
    img_path = _safe_path(name)
    if not img_path.exists():
        raise HTTPException(404, "File not found")

    # The image is read here only to validate it and report its dimensions;
    # detect_watermarks is handed the path.
    img = cv2.imread(str(img_path))
    if img is None:
        raise HTTPException(400, "Cannot read image")
    h, w = img.shape[:2]

    regions = detect_watermarks(str(img_path))
    return {"regions": regions, "image_width": w, "image_height": h}
@app.get("/status/{job_id}")
async def job_status(job_id: str):
    """Return the registry entry for ``job_id``.

    Raises:
        HTTPException: 404 when the id has no (non-empty) entry.
    """
    record = _jobs.get(job_id)
    if record:
        return record
    raise HTTPException(404, "Job not found")
@app.post("/process/image")
async def process_image(
    file_id: str = Form(...),
    ext: str = Form(...),
    mask_data: str = Form(...),
    method: str = Form("opencv"),
):
    """Start a background inpainting job for an uploaded image.

    Args:
        file_id: Id returned by /upload.
        ext: Extension returned by /upload.
        mask_data: Base64 / data-URL mask image; resized to the image size
            when the dimensions differ.
        method: Inpainting backend. ``"lama"`` silently falls back to
            ``"opencv"`` when LaMa is not available.

    Returns:
        ``{"job_id": ...}``; poll /status/{job_id} for progress and the
        ``download_url`` of the result.

    Raises:
        HTTPException: 403 when a derived path escapes UPLOAD_DIR, 404 when
            the image is missing, 400 when the image or mask cannot be read.
    """
    from inpaint import LAMA_AVAILABLE, inpaint_image

    # file_id and ext are client-supplied form fields: confine both the input
    # and the output path to UPLOAD_DIR.
    img_path = _safe_path(f"{file_id}{ext}")
    if not img_path.exists():
        raise HTTPException(404, "File not found")
    if method == "lama" and not LAMA_AVAILABLE:
        method = "opencv"

    img = cv2.imread(str(img_path))
    if img is None:
        raise HTTPException(400, "Cannot read image")
    mask = _decode_mask(mask_data, target_shape=img.shape)

    # Keep the original format for these extensions; everything else is
    # written as PNG.
    out_ext = ext if ext in {".jpg", ".jpeg", ".png", ".webp"} else ".png"
    result_path = _safe_path(f"{file_id}_result{out_ext}")
    download_url = f"/download/{file_id}_result{out_ext}"

    job_id = str(uuid.uuid4())
    _jobs[job_id] = {"status": "processing", "progress": 0, "total": 1,
                     "message": "Running LaMa inpainting…" if method == "lama" else "Inpainting…"}

    def _run():
        # Worker body: any failure is recorded on the job instead of raised.
        try:
            result = inpaint_image(img, mask, method=method)
            # cv2.imwrite signals failure through its return value; without
            # this check the job would be marked done with a dead download URL.
            if not cv2.imwrite(str(result_path), result):
                raise RuntimeError("Failed to write result image")
            _job_update(job_id, status="done", progress=1, total=1,
                        message="Done.", download_url=download_url)
        except Exception as e:
            _job_update(job_id, status="error", message=str(e))

    threading.Thread(target=_run, daemon=True).start()
    return {"job_id": job_id}
@app.post("/process/video")
async def process_video(
    file_id: str = Form(...),
    ext: str = Form(...),
    mask_data: str = Form(...),
    mode: str = Form("fast"),
    method: str = Form("opencv"),
):
    """Start a background watermark-removal job for an uploaded video.

    Args:
        file_id: Id returned by /upload.
        ext: Extension returned by /upload.
        mask_data: Base64 / data-URL mask image; resized to the video frame
            size when the dimensions differ.
        mode: Passed through to ``process_video_file``. The job's initial
            ``total`` is 1 for ``"fast"`` and the video's frame count
            otherwise.
        method: Inpainting backend. ``"lama"`` silently falls back to
            ``"opencv"`` when LaMa is not available.

    Returns:
        ``{"job_id": ...}``; poll /status/{job_id} for progress and the
        ``download_url`` of the resulting MP4.

    Raises:
        HTTPException: 403 when a derived path escapes UPLOAD_DIR, 404 when
            the video is missing, 400 when the video or mask cannot be read.
    """
    from inpaint import LAMA_AVAILABLE
    from video import process_video_file

    # file_id and ext are client-supplied form fields: confine both the input
    # and the output path to UPLOAD_DIR.
    video_path = _safe_path(f"{file_id}{ext}")
    if not video_path.exists():
        raise HTTPException(404, "File not found")
    if method == "lama" and not LAMA_AVAILABLE:
        method = "opencv"

    cap = cv2.VideoCapture(str(video_path))
    try:
        opened = cap.isOpened()
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        cap.release()
    # An unreadable file yields a 0x0 size, which would otherwise crash the
    # mask resize below with an unhandled OpenCV error.
    if not opened or w <= 0 or h <= 0:
        raise HTTPException(400, "Cannot read video")

    mask = _decode_mask(mask_data, target_shape=(h, w))
    result_path = _safe_path(f"{file_id}_result.mp4")
    download_url = f"/download/{file_id}_result.mp4"

    job_id = str(uuid.uuid4())
    frames_to_process = 1 if mode == "fast" else total_frames
    _jobs[job_id] = {"status": "processing", "progress": 0,
                     "total": frames_to_process, "message": "Starting…"}

    def _progress(frame_num: int, total: int):
        # Progress callback handed to process_video_file.
        pct = round(frame_num / total * 100) if total else 0
        _job_update(job_id, progress=frame_num, total=total,
                    message=f"Frame {frame_num} / {total} ({pct}%)")

    def _run():
        # Worker body: any failure is recorded on the job instead of raised.
        try:
            process_video_file(str(video_path), mask, str(result_path),
                               mode=mode, method=method, progress_callback=_progress)
            _job_update(job_id, status="done", progress=frames_to_process,
                        total=frames_to_process, message="Done.", download_url=download_url)
        except Exception as e:
            _job_update(job_id, status="error", message=str(e))

    threading.Thread(target=_run, daemon=True).start()
    return {"job_id": job_id}
@app.post("/segment")
async def segment_point(
    file_id: str = Form(...),
    ext: str = Form(...),
    x: float = Form(...),
    y: float = Form(...),
    canvas_w: int = Form(...),
    canvas_h: int = Form(...),
):
    """SAM click-to-segment: returns a base64 PNG mask.

    Args:
        file_id: Id returned by /upload.
        ext: Extension returned by /upload. For video extensions the
            ``<file_id>_preview.jpg`` frame is segmented.
        x, y, canvas_w, canvas_h: Forwarded unchanged to
            ``segment_at_point`` (presumably the click position and the size
            of the canvas it was made on — confirm against ``sam_segment``).

    Returns:
        ``{"mask": "data:image/png;base64,..."}``.

    Raises:
        HTTPException: 503 when SAM is unavailable, 403 when the derived path
            escapes UPLOAD_DIR, 404 when the file is missing, 400 when OpenCV
            cannot read it, 500 when the mask cannot be PNG-encoded.
    """
    from sam_segment import SAM_AVAILABLE, segment_at_point

    if not SAM_AVAILABLE:
        raise HTTPException(503, "SAM model not available")

    if ext in VIDEO_EXTS:
        name = f"{file_id}_preview.jpg"
    else:
        name = f"{file_id}{ext}"
    # file_id and ext are client-supplied form fields, so the joined path is
    # confined to UPLOAD_DIR the same way /source and /download do it.
    img_path = _safe_path(name)
    if not img_path.exists():
        raise HTTPException(404, "File not found")

    img = cv2.imread(str(img_path))
    if img is None:
        raise HTTPException(400, "Cannot read image")

    mask = segment_at_point(img, x, y, canvas_w, canvas_h)
    # cv2.imencode reports failure through its first return value.
    ok, buf = cv2.imencode(".png", mask)
    if not ok:
        raise HTTPException(500, "Failed to encode mask")
    b64 = base64.b64encode(buf).decode()
    return {"mask": f"data:image/png;base64,{b64}"}
@app.get("/download/{filename:path}")
async def download_file(filename: str):
    """Send a file from UPLOAD_DIR as ``watermark_removed<ext>``.

    Args:
        filename: Path relative to UPLOAD_DIR; it is confined to that
            directory by ``_safe_path``.

    Raises:
        HTTPException: 403 when the path escapes UPLOAD_DIR (raised by
            ``_safe_path``); 404 when it does not name an existing file.
    """
    file_path = _safe_path(filename)
    # is_file() rather than exists(): an empty or directory path (for example
    # "/download/" resolves to UPLOAD_DIR itself) must yield a 404 instead of
    # making FileResponse fail with a server error.
    if not file_path.is_file():
        raise HTTPException(404, "File not found")

    ext = file_path.suffix.lower()
    # The suggested name keeps the file's extension, except that ".jpeg" is
    # normalised to ".jpg".
    download_ext = ".jpg" if ext == ".jpeg" else ext
    download_name = f"watermark_removed{download_ext}"
    return FileResponse(str(file_path), filename=download_name)
if __name__ == "__main__":
    # Direct execution: serve the app on all interfaces, on the port given
    # by the PORT environment variable (8000 when it is unset).
    import os

    import uvicorn

    listen_port = int(os.getenv("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)