import asyncio
import os
import shutil
import subprocess
import tempfile
from enum import Enum

import cv2  # type: ignore
import numpy as np
from fastapi import (
    BackgroundTasks,
    FastAPI,
    File,
    HTTPException,
    Query,
    Response,
    UploadFile,
)
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from rembg import new_session, remove

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ModelName(str, Enum):
    """Background-removal models supported by rembg."""
    birefnet_general = "birefnet-general"
    birefnet_general_lite = "birefnet-general-lite"
    isnet_anime = "isnet-anime"
    u2net = "u2net"


# Cache sessions to avoid reloading models on every request
sessions: dict = {}

# Global semaphore to limit concurrent processing.
# Free tiers have limited CPU/RAM. We limit to 1 concurrent heavy task
# to prevent OOM/Crashes.
MAX_CONCURRENT_PROCESSING = 1
processing_semaphore = asyncio.Semaphore(MAX_CONCURRENT_PROCESSING)


def get_session(model_name: str):
    """Return a cached rembg session for *model_name*, creating it on first use."""
    if model_name not in sessions:
        print(f"Loading model: {model_name}...")
        sessions[model_name] = new_session(model_name)
    return sessions[model_name]


# Pre-load the default model suitable for Mascots/Cartoons
# 'birefnet-general' offers superior edge detection and quality for mascots
DEFAULT_MODEL = ModelName.birefnet_general


@app.on_event("startup")
async def startup_event():
    """Trigger download/load of the default model on startup."""
    get_session(DEFAULT_MODEL.value)


@app.get("/")
def read_root():
    """Health/info endpoint."""
    return {
        "message": "Background Removal API is running",
        "concurrent_limit": MAX_CONCURRENT_PROCESSING,
    }


@app.post("/image-bg-removal")
async def image_bg_removal(
    file: UploadFile = File(...),
    model: ModelName = Query(DEFAULT_MODEL, description="Model to use for background removal"),
    alpha_matting: bool = Query(False, description="Enable alpha matting for softer edges"),
    alpha_matting_foreground_threshold: int = Query(240, description="Trimap foreground threshold"),
    alpha_matting_background_threshold: int = Query(10, description="Trimap background threshold"),
    alpha_matting_erode_size: int = Query(10, description="Erode size for alpha matting"),
):
    """
    Removes background from an image.
    Returns the image with transparent background (PNG).
    """
    # Read file content first (IO bound, doesn't need semaphore)
    input_image = await file.read()
    session = get_session(model.value)

    # Acquire semaphore before heavy processing
    if processing_semaphore.locked():
        print("Waiting for processing slot...")

    async with processing_semaphore:
        try:
            # Run blocking 'remove' function in a separate thread to avoid
            # blocking the event loop
            output_image = await run_in_threadpool(
                remove,
                input_image,
                session=session,
                alpha_matting=alpha_matting,
                alpha_matting_foreground_threshold=alpha_matting_foreground_threshold,
                alpha_matting_background_threshold=alpha_matting_background_threshold,
                alpha_matting_erode_size=alpha_matting_erode_size,
            )
        except Exception as e:
            print(f"Error with alpha matting: {e}")
            if alpha_matting:
                print("Falling back to standard background removal (alpha_matting=False)...")
                # Fallback also runs in thread pool
                output_image = await run_in_threadpool(
                    remove, input_image, session=session, alpha_matting=False
                )
            else:
                # Bare raise preserves the original traceback
                raise

    return Response(content=output_image, media_type="image/png")


@app.post("/video-bg-removal")
async def video_bg_removal(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    model: ModelName = Query(DEFAULT_MODEL, description="Model to use for background removal"),
):
    """
    Removes background from a video. Returns WebM with Alpha.
    """
    # Create temp file for input (IO bound)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_input:
        shutil.copyfileobj(file.file, tmp_input)
        tmp_input_path = tmp_input.name

    try:
        # Acquire semaphore for the heavy video processing
        if processing_semaphore.locked():
            print("Waiting for video processing slot...")
        async with processing_semaphore:
            # Pass model name to processing function, run in thread pool
            output_path = await run_in_threadpool(process_video, tmp_input_path, model.value)
    except Exception as e:
        if os.path.exists(tmp_input_path):
            os.remove(tmp_input_path)
        # Fix: previously returned HTTP 200 with an error dict; signal the
        # failure properly (HTTPException was imported but unused).
        raise HTTPException(status_code=500, detail=str(e))

    # Clean both temp files after the response has been sent
    background_tasks.add_task(os.remove, tmp_input_path)
    background_tasks.add_task(os.remove, output_path)

    return FileResponse(output_path, media_type="video/webm", filename="output_bg_removed.webm")


def process_video(input_path: str, model_name: str) -> str:
    """Remove the background from every frame of the video at *input_path*.

    Raw RGBA frames are piped into ffmpeg, which encodes a VP9 WebM with an
    alpha channel (yuva420p). Blocking — callers run it in a thread pool.

    Returns:
        Path to the encoded .webm file. The caller owns its deletion.

    Raises:
        Exception: if the input cannot be opened or ffmpeg exits non-zero.
    """
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        # Fail early with a clear message instead of feeding ffmpeg a 0x0 stream
        raise Exception(f"Could not open video file: {input_path}")

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    if fps <= 0:
        fps = 30.0

    # mkstemp avoids the race condition of the deprecated tempfile.mktemp.
    # Close the fd immediately; ffmpeg (-y) reopens and overwrites the file.
    fd, output_path = tempfile.mkstemp(suffix=".webm")
    os.close(fd)

    # FFmpeg command to read raw RGBA video from stdin and output WebM with Alpha
    command = [
        'ffmpeg',
        '-y',                    # Overwrite output file
        '-f', 'rawvideo',
        '-vcodec', 'rawvideo',
        '-s', f'{width}x{height}',
        '-pix_fmt', 'rgba',
        '-r', str(fps),
        '-i', '-',               # Input from stdin
        '-c:v', 'libvpx-vp9',
        '-b:v', '2M',            # Reasonable bitrate
        '-pix_fmt', 'yuva420p',  # Important for alpha transparency in WebM
        output_path,
    ]

    # Open ffmpeg process
    process = subprocess.Popen(command, stdin=subprocess.PIPE)
    session = get_session(model_name)

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes BGR; rembg expects RGB and returns an RGBA array.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # NOTE: the old per-frame try/except was dead code — its "fallback"
            # called remove() with identical arguments (alpha_matting=False in
            # both branches), so it has been removed.
            result_rgba = remove(frame_rgb, session=session, alpha_matting=False)
            process.stdin.write(result_rgba.tobytes())
    except Exception as e:
        print(f"Error during video processing: {e}")
        raise
    finally:
        cap.release()
        if process.stdin:
            process.stdin.close()
        process.wait()

    if process.returncode != 0:
        # Don't leak a partial output file when the encoder fails
        if os.path.exists(output_path):
            os.remove(output_path)
        raise Exception(f"FFmpeg exited with error code {process.returncode}")

    return output_path