Spaces:
Paused
Paused
| """ | |
| Watermark Remover — FastAPI backend. | |
| Run: uvicorn app:app --reload --port 8000 | |
| """ | |
| import base64 | |
| import shutil | |
| import threading | |
| import uuid | |
| from pathlib import Path | |
| import cv2 | |
| import numpy as np | |
| from fastapi import FastAPI, File, Form, HTTPException, UploadFile | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import FileResponse | |
| from fastapi.staticfiles import StaticFiles | |
# ── Job registry ──────────────────────────────────────────
# In-memory map of job_id -> job state dict. Entries are created by the
# process_* handlers and updated through _job_update(); they carry "status"
# ("processing" / "done" / "error"), "progress", "total", "message" and,
# once a job succeeds, "download_url".
_jobs: dict = {}
def _job_update(job_id: str, **kwargs) -> None:
    """Merge *kwargs* into the registry entry for *job_id*.

    Updates addressed to a job id that is not in the registry are ignored.
    """
    entry = _jobs.get(job_id)
    if entry is not None:
        entry.update(kwargs)
# ── Paths ─────────────────────────────────────────────────
BASE_DIR = Path(__file__).parent
STATIC_DIR = BASE_DIR.joinpath("static")  # frontend assets, incl. index.html

# Scratch directory shared by uploads, previews and results; created at import.
UPLOAD_DIR = Path("/tmp/wm_tool")
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)

# ── Accepted upload extensions (lower-case, leading dot) ──
VIDEO_EXTS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".flv"}
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".bmp", ".tiff", ".tif"}

# ── Application ───────────────────────────────────────────
app = FastAPI(title="Watermark Remover")
# Permissive CORS: any origin, method and header may call the API.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
| # --------------------------------------------------------------------------- | |
| # Helpers | |
| # --------------------------------------------------------------------------- | |
def _decode_mask(mask_data: str, target_shape=None) -> np.ndarray:
    """Decode a base64 PNG (optionally a data URL) into a grayscale mask.

    Args:
        mask_data: "data:image/png;base64,..." or bare base64 text; everything
            up to and including the last comma is ignored.
        target_shape: optional (height, width, ...) shape. When given and the
            decoded mask has a different height/width, the mask is resized
            with nearest-neighbour interpolation so mask values are not blended.

    Returns:
        The decoded single-channel mask as a numpy array.

    Raises:
        HTTPException: 400 if the payload is not valid base64, is empty, or is
            not a decodable image.
    """
    b64 = mask_data.split(",")[-1]
    try:
        raw = base64.b64decode(b64)
    except ValueError:  # binascii.Error (bad padding) subclasses ValueError
        raise HTTPException(400, "Invalid mask image") from None
    # cv2.imdecode raises (rather than returning None) on an empty buffer.
    if not raw:
        raise HTTPException(400, "Invalid mask image")
    mask = cv2.imdecode(np.frombuffer(raw, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise HTTPException(400, "Invalid mask image")
    if target_shape is not None and mask.shape[:2] != tuple(target_shape[:2]):
        # cv2.resize takes (width, height).
        mask = cv2.resize(mask, (target_shape[1], target_shape[0]), interpolation=cv2.INTER_NEAREST)
    return mask
def _safe_path(filename: str) -> Path:
    """Resolve *filename* relative to UPLOAD_DIR and ensure it stays inside it.

    Both the directory and the candidate are fully resolved (symlinks followed,
    ".." collapsed) before they are compared.

    Raises:
        HTTPException: 403 if the resolved path lies outside UPLOAD_DIR or is
            UPLOAD_DIR itself (e.g. "" or "."). Callers serve the result as a
            file, so the directory itself must not be returned.
    """
    base = UPLOAD_DIR.resolve()
    candidate = (UPLOAD_DIR / filename).resolve()
    # A path strictly inside base always has base among its ancestors; this
    # single test rejects both escapes and the directory itself.
    if base not in candidate.parents:
        raise HTTPException(403, "Access denied")
    return candidate
| # --------------------------------------------------------------------------- | |
| # Routes | |
| # --------------------------------------------------------------------------- | |
async def index():
    """Serve static/index.html."""
    html_file = STATIC_DIR / "index.html"
    return FileResponse(str(html_file))
async def capabilities():
    """Report which processing backends are usable.

    Each optional backend module exposes a module-level ``*_AVAILABLE`` flag.
    A module that fails to import, or lacks the flag, is reported as
    unavailable instead of failing the whole request. OpenCV is always
    reported as available.

    Returns:
        dict with boolean "lama", "easyocr", "opencv", "sam" and "sd" entries.
    """
    import importlib

    def _flag(module_name: str, attr: str) -> bool:
        """Return the module's availability flag, or False if it cannot be loaded."""
        try:
            module = importlib.import_module(module_name)
        except Exception:  # optional deps can raise arbitrary errors on import
            return False
        return bool(getattr(module, attr, False))

    lama = _flag("inpaint", "LAMA_AVAILABLE")
    sam = _flag("sam_segment", "SAM_AVAILABLE")
    sd = _flag("sd_inpaint", "SD_AVAILABLE")
    easyocr = _flag("detect", "EASYOCR_AVAILABLE")
    return {"lama": lama, "easyocr": easyocr,
            "opencv": True, "sam": sam, "sd": sd}
def _write_first_frame_preview(video_path: Path, preview_path: Path) -> bool:
    """Save the first frame of *video_path* as *preview_path* using OpenCV.

    Fallback for when ffmpeg is missing or produced no output. Returns True
    when the preview file was written.
    """
    cap = cv2.VideoCapture(str(video_path))
    try:
        ok, frame = cap.read()
    finally:
        cap.release()
    if not ok or frame is None:
        return False
    return bool(cv2.imwrite(str(preview_path), frame))


async def upload_file(file: UploadFile = File(...)):
    """Store an uploaded image/video under a fresh id and build a video preview.

    For videos, the first frame is written to "<file_id>_preview.jpg" with
    ffmpeg. OpenCV is used as a fallback if ffmpeg is unavailable or leaves no
    preview file.

    Returns:
        {"file_id", "ext", "is_video", "filename"} describing the stored upload.

    Raises:
        HTTPException: 400 if the extension is not a supported image/video type.
    """
    import subprocess

    # UploadFile.filename may be None; treat that as "no extension" (-> 400).
    ext = Path(file.filename or "").suffix.lower()
    if ext not in VIDEO_EXTS | IMAGE_EXTS:
        raise HTTPException(400, f"Unsupported file type: {ext}")

    file_id = str(uuid.uuid4())
    save_path = UPLOAD_DIR / f"{file_id}{ext}"
    with open(save_path, "wb") as f:
        shutil.copyfileobj(file.file, f)

    is_video = ext in VIDEO_EXTS
    if is_video:
        preview_path = UPLOAD_DIR / f"{file_id}_preview.jpg"
        try:
            subprocess.run(
                ["ffmpeg", "-i", str(save_path), "-vframes", "1", "-q:v", "2", str(preview_path), "-y"],
                capture_output=True,
            )
        except OSError:
            # ffmpeg binary missing / not executable: the OpenCV fallback below
            # takes over instead of failing the upload.
            pass
        if not preview_path.exists():
            _write_first_frame_preview(save_path, preview_path)
    return {"file_id": file_id, "ext": ext, "is_video": is_video, "filename": file.filename}
async def get_source(filename: str):
    """Serve the original uploaded file (used by the compare video player).

    Video extensions map to explicit MIME types. Any other file's type is
    guessed from its name, falling back to application/octet-stream.

    Raises:
        HTTPException: 403 if *filename* escapes UPLOAD_DIR, 404 if it is not
            an existing regular file.
    """
    import mimetypes

    file_path = _safe_path(filename)
    # is_file() (not exists()) so a directory is never handed to FileResponse.
    if not file_path.is_file():
        raise HTTPException(404, "Source file not found")
    ext = file_path.suffix.lower()
    media_types = {
        ".mp4": "video/mp4", ".mov": "video/quicktime", ".avi": "video/x-msvideo",
        ".mkv": "video/x-matroska", ".webm": "video/webm", ".flv": "video/x-flv",
    }
    media_type = (
        media_types.get(ext)
        or mimetypes.guess_type(file_path.name)[0]
        or "application/octet-stream"
    )
    return FileResponse(str(file_path), media_type=media_type)
async def get_preview(file_id: str):
    """Serve the preview image for an upload.

    Videos use the first-frame JPEG written at upload time. Images are served
    directly from whichever supported image extension exists on disk.

    Raises:
        HTTPException: 403 if *file_id* would resolve outside UPLOAD_DIR,
            404 if no preview exists.
    """
    # file_id is request input; on Windows a backslash inside it would act as a
    # path separator, so confine every lookup to UPLOAD_DIR.
    video_preview = _safe_path(f"{file_id}_preview.jpg")
    if video_preview.exists():
        return FileResponse(str(video_preview), media_type="image/jpeg")
    for ext in IMAGE_EXTS:
        candidate = _safe_path(f"{file_id}{ext}")
        if candidate.exists():
            return FileResponse(str(candidate))
    raise HTTPException(404, "Preview not found")
async def detect_watermark(file_id: str = Form(...), ext: str = Form(...)):
    """Run automatic watermark detection on an upload.

    Videos are analysed through the first-frame preview written at upload
    time; images are analysed directly.

    Returns:
        {"regions": <detect_watermarks() result>, "image_width", "image_height"}.

    Raises:
        HTTPException: 403 for ids that escape UPLOAD_DIR, 404 if the file is
            missing, 400 if it cannot be read as an image.
    """
    from detect import detect_watermarks

    # file_id/ext are client-supplied form fields; without _safe_path a value
    # such as "../../some/dir/x" would read images outside UPLOAD_DIR.
    if ext in VIDEO_EXTS:
        img_path = _safe_path(f"{file_id}_preview.jpg")
    else:
        img_path = _safe_path(f"{file_id}{ext}")
    if not img_path.exists():
        raise HTTPException(404, "File not found")
    img = cv2.imread(str(img_path))
    if img is None:
        raise HTTPException(400, "Cannot read image")
    h, w = img.shape[:2]
    regions = detect_watermarks(str(img_path))
    return {"regions": regions, "image_width": w, "image_height": h}
async def job_status(job_id: str):
    """Return a snapshot of the registry entry for *job_id*.

    A shallow copy is returned because the entry is mutated by the worker
    thread. Serialising the live dict could observe it changing size
    mid-iteration, for example when "download_url" is added on completion.

    Raises:
        HTTPException: 404 if the job id is unknown.
    """
    job = _jobs.get(job_id)
    if not job:
        raise HTTPException(404, "Job not found")
    return dict(job)
async def process_image(
    file_id: str = Form(...),
    ext: str = Form(...),
    mask_data: str = Form(...),
    method: str = Form("opencv"),
):
    """Start a background inpainting job for an uploaded image.

    The image and mask are decoded up front, so bad input is rejected before a
    job is created. inpaint_image() then runs on a daemon thread. "lama" falls
    back to "opencv" when LaMa is unavailable.

    Returns:
        {"job_id": ...}; the job's registry entry gains "download_url" on success.

    Raises:
        HTTPException: 403 for ids that escape UPLOAD_DIR, 404 if the upload is
            missing, 400 if the image or mask cannot be decoded.
    """
    from inpaint import LAMA_AVAILABLE, inpaint_image

    # file_id/ext are client-supplied. _safe_path stops "../" from reading, or
    # (via result_path below) writing, outside UPLOAD_DIR.
    img_path = _safe_path(f"{file_id}{ext}")
    if not img_path.exists():
        raise HTTPException(404, "File not found")
    if method == "lama" and not LAMA_AVAILABLE:
        method = "opencv"
    img = cv2.imread(str(img_path))
    if img is None:
        raise HTTPException(400, "Cannot read image")
    mask = _decode_mask(mask_data, target_shape=img.shape)

    # Keep the upload's format for common web formats; anything else -> PNG.
    out_ext = ext if ext in {".jpg", ".jpeg", ".png", ".webp"} else ".png"
    result_path = _safe_path(f"{file_id}_result{out_ext}")
    download_url = f"/download/{result_path.name}"

    job_id = str(uuid.uuid4())
    _jobs[job_id] = {"status": "processing", "progress": 0, "total": 1,
                     "message": "Running LaMa inpainting…" if method == "lama" else "Inpainting…"}

    def _run():
        try:
            result = inpaint_image(img, mask, method=method)
            # cv2.imwrite can signal failure by returning False instead of
            # raising; without this check the job would be "done" with no file.
            if not cv2.imwrite(str(result_path), result):
                raise RuntimeError("Failed to write result image")
            _job_update(job_id, status="done", progress=1, total=1,
                        message="Done.", download_url=download_url)
        except Exception as e:  # thread boundary: report any failure via the job
            _job_update(job_id, status="error", message=str(e))

    threading.Thread(target=_run, daemon=True).start()
    return {"job_id": job_id}
async def process_video(
    file_id: str = Form(...),
    ext: str = Form(...),
    mask_data: str = Form(...),
    mode: str = Form("fast"),
    method: str = Form("opencv"),
):
    """Start a background job that removes the masked region from a video.

    The mask is scaled to the video's frame size. process_video_file() then
    runs on a daemon thread and reports progress through the job registry.
    "lama" falls back to "opencv" when LaMa is unavailable.

    Returns:
        {"job_id": ...}; the job's registry entry gains "download_url" on success.

    Raises:
        HTTPException: 403 for ids that escape UPLOAD_DIR, 404 if the upload is
            missing, 400 if the video cannot be opened or the mask is invalid.
    """
    from inpaint import LAMA_AVAILABLE
    from video import process_video_file

    # file_id/ext are client-supplied: keep both input and output in UPLOAD_DIR.
    video_path = _safe_path(f"{file_id}{ext}")
    if not video_path.exists():
        raise HTTPException(404, "File not found")
    if method == "lama" and not LAMA_AVAILABLE:
        method = "opencv"

    cap = cv2.VideoCapture(str(video_path))
    try:
        opened = cap.isOpened()
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        cap.release()
    # An unopenable video reports 0x0; resizing the mask to that would crash.
    if not opened or w <= 0 or h <= 0:
        raise HTTPException(400, "Cannot read video")
    # NOTE(review): CAP_PROP_FRAME_COUNT may be 0/negative when the container
    # does not report it; clamp so the job's "total" is never negative.
    total_frames = max(total_frames, 0)

    mask = _decode_mask(mask_data, target_shape=(h, w))
    result_path = _safe_path(f"{file_id}_result.mp4")
    download_url = f"/download/{result_path.name}"

    job_id = str(uuid.uuid4())
    frames_to_process = 1 if mode == "fast" else total_frames
    _jobs[job_id] = {"status": "processing", "progress": 0,
                     "total": frames_to_process, "message": "Starting…"}

    def _progress(frame_num: int, total: int):
        pct = round(frame_num / total * 100) if total else 0
        _job_update(job_id, progress=frame_num, total=total,
                    message=f"Frame {frame_num} / {total} ({pct}%)")

    def _run():
        try:
            process_video_file(str(video_path), mask, str(result_path),
                               mode=mode, method=method, progress_callback=_progress)
            _job_update(job_id, status="done", progress=frames_to_process,
                        total=frames_to_process, message="Done.", download_url=download_url)
        except Exception as e:  # thread boundary: report any failure via the job
            _job_update(job_id, status="error", message=str(e))

    threading.Thread(target=_run, daemon=True).start()
    return {"job_id": job_id}
async def segment_point(
    file_id: str = Form(...),
    ext: str = Form(...),
    x: float = Form(...),
    y: float = Form(...),
    canvas_w: int = Form(...),
    canvas_h: int = Form(...),
):
    """SAM click-to-segment: returns a base64 PNG mask.

    Loads the uploaded image (for videos, its first-frame preview) and passes
    it, the click position and the canvas size to segment_at_point(). The
    resulting mask is returned as {"mask": "data:image/png;base64,..."}.

    Raises:
        HTTPException: 503 if SAM is unavailable, 403 for ids that escape
            UPLOAD_DIR, 404 if the file is missing, 400 if it cannot be read,
            500 if the mask cannot be PNG-encoded.
    """
    from sam_segment import SAM_AVAILABLE, segment_at_point

    if not SAM_AVAILABLE:
        raise HTTPException(503, "SAM model not available")
    # file_id/ext are client-supplied: confine the lookup to UPLOAD_DIR.
    if ext in VIDEO_EXTS:
        img_path = _safe_path(f"{file_id}_preview.jpg")
    else:
        img_path = _safe_path(f"{file_id}{ext}")
    if not img_path.exists():
        raise HTTPException(404, "File not found")
    img = cv2.imread(str(img_path))
    if img is None:
        raise HTTPException(400, "Cannot read image")
    mask = segment_at_point(img, x, y, canvas_w, canvas_h)
    ok, buf = cv2.imencode(".png", mask)
    if not ok:
        raise HTTPException(500, "Failed to encode mask")
    b64 = base64.b64encode(buf).decode()
    return {"mask": f"data:image/png;base64,{b64}"}
async def download_file(filename: str):
    """Serve a file from UPLOAD_DIR as an attachment with a friendly name.

    The attachment is named "watermark_removed<ext>". ".jpeg" is normalised
    to ".jpg".

    Raises:
        HTTPException: 403 if *filename* escapes UPLOAD_DIR, 404 if it is not
            an existing regular file.
    """
    file_path = _safe_path(filename)
    # is_file() (not exists()) so a directory is never handed to FileResponse.
    if not file_path.is_file():
        raise HTTPException(404, "File not found")
    ext = file_path.suffix.lower()
    name_map = {
        ".mp4": "watermark_removed.mp4",
        ".jpg": "watermark_removed.jpg",
        ".jpeg": "watermark_removed.jpg",
        ".png": "watermark_removed.png",
        ".webp": "watermark_removed.webp",
    }
    download_name = name_map.get(ext, f"watermark_removed{ext}")
    return FileResponse(str(file_path), filename=download_name)
if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn.
    import os

    import uvicorn

    # Listen on all interfaces; $PORT overrides the default port 8000.
    port_number = int(os.environ.get("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=port_number)