# App by DavidFernandes — "Update app logic and deploy to Hugging Face" (commit 4b22c57)
import os
import shutil
import tempfile
import subprocess
import cv2 # type: ignore
import numpy as np
import asyncio
from fastapi import FastAPI, UploadFile, File, Response, BackgroundTasks, Query, HTTPException
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.concurrency import run_in_threadpool
from rembg import new_session, remove
from enum import Enum
# FastAPI application instance serving the background-removal endpoints.
app = FastAPI()
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive (and browsers ignore credentials for "*" origins) — confirm the
# intended audience before tightening/relaxing.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class ModelName(str, Enum):
    """Supported rembg model identifiers, selectable via the `model` query parameter."""
    birefnet_general = "birefnet-general"
    birefnet_general_lite = "birefnet-general-lite"
    isnet_anime = "isnet-anime"
    u2net = "u2net"
# Cache sessions to avoid reloading models on every request.
# Keys are model-name strings; values are rembg session objects.
sessions = {}
# Global semaphore to limit concurrent processing.
# Free tiers have limited CPU/RAM. We limit to 1 concurrent heavy task to prevent OOM/Crashes.
MAX_CONCURRENT_PROCESSING = 1
processing_semaphore = asyncio.Semaphore(MAX_CONCURRENT_PROCESSING)
def get_session(model_name: str):
    """Return the cached rembg session for *model_name*, creating it on first use."""
    session = sessions.get(model_name)
    if session is None:
        # First request for this model: load it once and keep it warm.
        print(f"Loading model: {model_name}...")
        session = new_session(model_name)
        sessions[model_name] = session
    return session
# Pre-load the default model suitable for Mascots/Cartoons.
# 'birefnet-general' offers superior edge detection and quality for mascots.
DEFAULT_MODEL = ModelName.birefnet_general
@app.on_event("startup")
async def startup_event():
# Trigger download/load of default model on startup
get_session(DEFAULT_MODEL.value)
@app.get("/")
def read_root():
return {"message": "Background Removal API is running", "concurrent_limit": MAX_CONCURRENT_PROCESSING}
@app.post("/image-bg-removal")
async def image_bg_removal(
file: UploadFile = File(...),
model: ModelName = Query(DEFAULT_MODEL, description="Model to use for background removal"),
alpha_matting: bool = Query(False, description="Enable alpha matting for softer edges"),
alpha_matting_foreground_threshold: int = Query(240, description="Trimap foreground threshold"),
alpha_matting_background_threshold: int = Query(10, description="Trimap background threshold"),
alpha_matting_erode_size: int = Query(10, description="Erode size for alpha matting")
):
"""
Removes background from an image.
Returns the image with transparent background (PNG).
"""
# Read file content first (IO bound, doesn't need semaphore)
input_image = await file.read()
session = get_session(model.value)
# Acquire semaphore before heavy processing
if processing_semaphore.locked():
print("Waiting for processing slot...")
async with processing_semaphore:
try:
# Run blocking 'remove' function in a separate thread to avoid blocking the event loop
output_image = await run_in_threadpool(
remove,
input_image,
session=session,
alpha_matting=alpha_matting,
alpha_matting_foreground_threshold=alpha_matting_foreground_threshold,
alpha_matting_background_threshold=alpha_matting_background_threshold,
alpha_matting_erode_size=alpha_matting_erode_size
)
except Exception as e:
print(f"Error with alpha matting: {e}")
if alpha_matting:
print("Falling back to standard background removal (alpha_matting=False)...")
# Fallback also runs in thread pool
output_image = await run_in_threadpool(remove, input_image, session=session, alpha_matting=False)
else:
raise e
return Response(content=output_image, media_type="image/png")
@app.post("/video-bg-removal")
async def video_bg_removal(
background_tasks: BackgroundTasks,
file: UploadFile = File(...),
model: ModelName = Query(DEFAULT_MODEL, description="Model to use for background removal")
):
"""
Removes background from a video.
Returns WebM with Alpha.
"""
# Create temp file for input (IO bound)
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_input:
shutil.copyfileobj(file.file, tmp_input)
tmp_input_path = tmp_input.name
try:
# Acquire semaphore for the heavy video processing
if processing_semaphore.locked():
print("Waiting for video processing slot...")
async with processing_semaphore:
# Pass model name to processing function, run in thread pool
output_path = await run_in_threadpool(process_video, tmp_input_path, model.value)
except Exception as e:
if os.path.exists(tmp_input_path):
os.remove(tmp_input_path)
return {"error": str(e)}
background_tasks.add_task(os.remove, tmp_input_path)
background_tasks.add_task(os.remove, output_path)
return FileResponse(output_path, media_type="video/webm", filename="output_bg_removed.webm")
def process_video(input_path: str, model_name: str) -> str:
    """
    Strip the background from every frame of the video at *input_path*.

    Each frame is run through rembg, then piped as raw RGBA into ffmpeg,
    which encodes a WebM (libvpx-vp9, yuva420p) with a real alpha channel.

    Returns the path of the encoded output file; the caller is responsible
    for deleting it. Raises ValueError if the input cannot be opened and
    Exception if ffmpeg exits non-zero.
    """
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        raise ValueError(f"Could not open video file: {input_path}")
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    if fps <= 0:
        fps = 30.0  # Some containers report 0 fps; assume a sane default.
    # mkstemp instead of the deprecated, race-prone tempfile.mktemp; close the
    # fd immediately — ffmpeg opens the path itself ('-y' permits overwrite).
    out_fd, output_path = tempfile.mkstemp(suffix=".webm")
    os.close(out_fd)
    # FFmpeg command to read raw RGBA video from stdin and output WebM with Alpha.
    command = [
        'ffmpeg',
        '-y',  # Overwrite output file
        '-f', 'rawvideo',
        '-vcodec', 'rawvideo',
        '-s', f'{width}x{height}',
        '-pix_fmt', 'rgba',
        '-r', str(fps),
        '-i', '-',  # Input from stdin
        '-c:v', 'libvpx-vp9',
        '-b:v', '2M',  # Reasonable bitrate
        '-pix_fmt', 'yuva420p',  # Important for alpha transparency in WebM
        output_path
    ]
    process = subprocess.Popen(command, stdin=subprocess.PIPE)
    session = get_session(model_name)
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes BGR; rembg expects RGB and returns RGBA.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # Single call, no matting: the original try/except "fallback" made
            # the identical call in both branches and was dead code.
            result_rgba = remove(frame_rgb, session=session, alpha_matting=False)
            process.stdin.write(result_rgba.tobytes())
    except Exception as e:
        print(f"Error during video processing: {e}")
        # Don't leak the partial output file when the frame loop fails.
        if os.path.exists(output_path):
            os.remove(output_path)
        raise
    finally:
        cap.release()
        if process.stdin:
            process.stdin.close()
        process.wait()
    # Checked AFTER the finally block: raising inside 'finally' (as the
    # original did) would silently replace any in-flight loop exception.
    if process.returncode != 0:
        if os.path.exists(output_path):
            os.remove(output_path)
        raise Exception(f"FFmpeg exited with error code {process.returncode}")
    return output_path