""" Watermark Remover — FastAPI backend. Run: uvicorn app:app --reload --port 8000 """ import base64 import shutil import threading import uuid from pathlib import Path import cv2 import numpy as np from fastapi import FastAPI, File, Form, HTTPException, UploadFile from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse from fastapi.staticfiles import StaticFiles # ── Job registry ────────────────────────────────────────── # job_id -> {status, progress, total, message, download_url, error} _jobs: dict = {} def _job_update(job_id: str, **kwargs) -> None: if job_id in _jobs: _jobs[job_id].update(kwargs) BASE_DIR = Path(__file__).parent STATIC_DIR = BASE_DIR / "static" UPLOAD_DIR = Path("/tmp/wm_tool") UPLOAD_DIR.mkdir(parents=True, exist_ok=True) VIDEO_EXTS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".flv"} IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".bmp", ".tiff", ".tif"} app = FastAPI(title="Watermark Remover") app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], ) app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static") # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _decode_mask(mask_data: str, target_shape=None) -> np.ndarray: """Decode a base64 data-URL PNG into a grayscale numpy mask.""" b64 = mask_data.split(",")[-1] raw = base64.b64decode(b64) arr = np.frombuffer(raw, dtype=np.uint8) mask = cv2.imdecode(arr, cv2.IMREAD_GRAYSCALE) if mask is None: raise HTTPException(400, "Invalid mask image") if target_shape is not None and mask.shape[:2] != target_shape[:2]: mask = cv2.resize(mask, (target_shape[1], target_shape[0]), interpolation=cv2.INTER_NEAREST) return mask def _safe_path(filename: str) -> Path: """Resolve path and ensure it stays within UPLOAD_DIR.""" p = (UPLOAD_DIR / filename).resolve() try: p.relative_to(UPLOAD_DIR.resolve()) except ValueError: raise HTTPException(403, "Access denied") return p # --------------------------------------------------------------------------- # Routes # --------------------------------------------------------------------------- @app.get("/") async def index(): return FileResponse(str(STATIC_DIR / "index.html")) @app.get("/capabilities") async def capabilities(): from inpaint import LAMA_AVAILABLE from sam_segment import SAM_AVAILABLE from sd_inpaint import SD_AVAILABLE try: from detect import EASYOCR_AVAILABLE except Exception: EASYOCR_AVAILABLE = False return {"lama": LAMA_AVAILABLE, "easyocr": EASYOCR_AVAILABLE, "opencv": True, "sam": SAM_AVAILABLE, "sd": SD_AVAILABLE} @app.post("/upload") async def upload_file(file: UploadFile = File(...)): ext = Path(file.filename).suffix.lower() if ext not in VIDEO_EXTS | IMAGE_EXTS: raise HTTPException(400, f"Unsupported file type: {ext}") file_id = str(uuid.uuid4()) save_path = UPLOAD_DIR / f"{file_id}{ext}" with open(save_path, "wb") as f: shutil.copyfileobj(file.file, f) is_video = ext in VIDEO_EXTS if is_video: import subprocess preview_path = UPLOAD_DIR / f"{file_id}_preview.jpg" subprocess.run( ["ffmpeg", "-i", str(save_path), "-vframes", "1", "-q:v", "2", str(preview_path), "-y"], capture_output=True, ) return {"file_id": file_id, "ext": ext, "is_video": is_video, "filename": file.filename} @app.get("/source/{filename:path}") async def get_source(filename: str): """Serve the original uploaded file (used by the compare video player).""" file_path = _safe_path(filename) if not file_path.exists(): raise HTTPException(404, "Source file not found") ext = file_path.suffix.lower() media_types = { ".mp4": "video/mp4", ".mov": "video/quicktime", ".avi": "video/x-msvideo", ".mkv": "video/x-matroska", ".webm": "video/webm", ".flv": "video/x-flv", } return FileResponse(str(file_path), media_type=media_types.get(ext, "application/octet-stream")) @app.get("/preview/{file_id}") async def get_preview(file_id: str): # Video preview (first frame) p = UPLOAD_DIR / f"{file_id}_preview.jpg" if p.exists(): return FileResponse(str(p), media_type="image/jpeg") # Direct image for ext in IMAGE_EXTS: p = UPLOAD_DIR / f"{file_id}{ext}" if p.exists(): return FileResponse(str(p)) raise HTTPException(404, "Preview not found") @app.post("/detect") async def detect_watermark(file_id: str = Form(...), ext: str = Form(...)): from detect import detect_watermarks if ext in VIDEO_EXTS: img_path = UPLOAD_DIR / f"{file_id}_preview.jpg" else: img_path = UPLOAD_DIR / f"{file_id}{ext}" if not img_path.exists(): raise HTTPException(404, "File not found") img = cv2.imread(str(img_path)) if img is None: raise HTTPException(400, "Cannot read image") h, w = img.shape[:2] regions = detect_watermarks(str(img_path)) return {"regions": regions, "image_width": w, "image_height": h} @app.get("/status/{job_id}") async def job_status(job_id: str): job = _jobs.get(job_id) if not job: raise HTTPException(404, "Job not found") return job @app.post("/process/image") async def process_image( file_id: str = Form(...), ext: str = Form(...), mask_data: str = Form(...), method: str = Form("opencv"), ): from inpaint import LAMA_AVAILABLE, inpaint_image img_path = UPLOAD_DIR / f"{file_id}{ext}" if not img_path.exists(): raise HTTPException(404, "File not found") if method == "lama" and not LAMA_AVAILABLE: method = "opencv" img = cv2.imread(str(img_path)) if img is None: raise HTTPException(400, "Cannot read image") mask = _decode_mask(mask_data, target_shape=img.shape) out_ext = ext if ext in {".jpg", ".jpeg", ".png", ".webp"} else ".png" result_path = UPLOAD_DIR / f"{file_id}_result{out_ext}" download_url = f"/download/{file_id}_result{out_ext}" job_id = str(uuid.uuid4()) _jobs[job_id] = {"status": "processing", "progress": 0, "total": 1, "message": "Running LaMa inpainting…" if method == "lama" else "Inpainting…"} def _run(): try: result = inpaint_image(img, mask, method=method) cv2.imwrite(str(result_path), result) _job_update(job_id, status="done", progress=1, total=1, message="Done.", download_url=download_url) except Exception as e: _job_update(job_id, status="error", message=str(e)) threading.Thread(target=_run, daemon=True).start() return {"job_id": job_id} @app.post("/process/video") async def process_video( file_id: str = Form(...), ext: str = Form(...), mask_data: str = Form(...), mode: str = Form("fast"), method: str = Form("opencv"), ): from inpaint import LAMA_AVAILABLE from video import process_video_file video_path = UPLOAD_DIR / f"{file_id}{ext}" if not video_path.exists(): raise HTTPException(404, "File not found") if method == "lama" and not LAMA_AVAILABLE: method = "opencv" cap = cv2.VideoCapture(str(video_path)) w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) cap.release() mask = _decode_mask(mask_data, target_shape=(h, w)) result_path = UPLOAD_DIR / f"{file_id}_result.mp4" download_url = f"/download/{file_id}_result.mp4" job_id = str(uuid.uuid4()) frames_to_process = 1 if mode == "fast" else total_frames _jobs[job_id] = {"status": "processing", "progress": 0, "total": frames_to_process, "message": "Starting…"} def _progress(frame_num: int, total: int): pct = round(frame_num / total * 100) if total else 0 _job_update(job_id, progress=frame_num, total=total, message=f"Frame {frame_num} / {total} ({pct}%)") def _run(): try: process_video_file(str(video_path), mask, str(result_path), mode=mode, method=method, progress_callback=_progress) _job_update(job_id, status="done", progress=frames_to_process, total=frames_to_process, message="Done.", download_url=download_url) except Exception as e: _job_update(job_id, status="error", message=str(e)) threading.Thread(target=_run, daemon=True).start() return {"job_id": job_id} @app.post("/segment") async def segment_point( file_id: str = Form(...), ext: str = Form(...), x: float = Form(...), y: float = Form(...), canvas_w: int = Form(...), canvas_h: int = Form(...), ): """SAM click-to-segment: returns a base64 PNG mask.""" from sam_segment import SAM_AVAILABLE, segment_at_point if not SAM_AVAILABLE: raise HTTPException(503, "SAM model not available") if ext in VIDEO_EXTS: img_path = UPLOAD_DIR / f"{file_id}_preview.jpg" else: img_path = UPLOAD_DIR / f"{file_id}{ext}" if not img_path.exists(): raise HTTPException(404, "File not found") img = cv2.imread(str(img_path)) if img is None: raise HTTPException(400, "Cannot read image") mask = segment_at_point(img, x, y, canvas_w, canvas_h) _, buf = cv2.imencode(".png", mask) b64 = base64.b64encode(buf).decode() return {"mask": f"data:image/png;base64,{b64}"} @app.get("/download/{filename:path}") async def download_file(filename: str): file_path = _safe_path(filename) if not file_path.exists(): raise HTTPException(404, "File not found") ext = file_path.suffix.lower() name_map = { ".mp4": "watermark_removed.mp4", ".jpg": "watermark_removed.jpg", ".jpeg": "watermark_removed.jpg", ".png": "watermark_removed.png", ".webp": "watermark_removed.webp", } download_name = name_map.get(ext, f"watermark_removed{ext}") return FileResponse(str(file_path), filename=download_name) if __name__ == "__main__": import os import uvicorn port = int(os.environ.get("PORT", 8000)) uvicorn.run(app, host="0.0.0.0", port=port)