import os
import shutil
import tempfile
import subprocess
import cv2 # type: ignore
import numpy as np
import asyncio
from fastapi import FastAPI, UploadFile, File, Response, BackgroundTasks, Query, HTTPException
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.concurrency import run_in_threadpool
from rembg import new_session, remove
from enum import Enum
# Application instance; endpoints below are registered against it.
app = FastAPI()
# Permissive CORS: any origin/method/header is allowed. Fine for a public
# demo API; tighten allow_origins before exposing sensitive functionality.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
class ModelName(str, Enum):
    """Allowed rembg model identifiers (str-valued so they serialize cleanly
    as query parameters and appear as a dropdown in the OpenAPI docs)."""
    birefnet_general = "birefnet-general"
    birefnet_general_lite = "birefnet-general-lite"
    isnet_anime = "isnet-anime"
    u2net = "u2net"
# Cache sessions to avoid reloading models on every request.
# Keyed by model name string; values are rembg session objects.
sessions: dict = {}
# Global semaphore to limit concurrent processing.
# Free tiers have limited CPU/RAM. We limit to 1 concurrent heavy task to prevent OOM/Crashes.
MAX_CONCURRENT_PROCESSING: int = 1
processing_semaphore = asyncio.Semaphore(MAX_CONCURRENT_PROCESSING)
def get_session(model_name: str):
    """Return a cached rembg session for *model_name*, creating (and
    caching) it on first use so the model weights load only once."""
    session = sessions.get(model_name)
    if session is None:
        print(f"Loading model: {model_name}...")
        session = new_session(model_name)
        sessions[model_name] = session
    return session
# Pre-load the default model suitable for Mascots/Cartoons
# 'birefnet-general' offers superior edge detection and quality for mascots
DEFAULT_MODEL: ModelName = ModelName.birefnet_general
@app.on_event("startup")
async def startup_event():
# Trigger download/load of default model on startup
get_session(DEFAULT_MODEL.value)
@app.get("/")
def read_root():
return {"message": "Background Removal API is running", "concurrent_limit": MAX_CONCURRENT_PROCESSING}
@app.post("/image-bg-removal")
async def image_bg_removal(
file: UploadFile = File(...),
model: ModelName = Query(DEFAULT_MODEL, description="Model to use for background removal"),
alpha_matting: bool = Query(False, description="Enable alpha matting for softer edges"),
alpha_matting_foreground_threshold: int = Query(240, description="Trimap foreground threshold"),
alpha_matting_background_threshold: int = Query(10, description="Trimap background threshold"),
alpha_matting_erode_size: int = Query(10, description="Erode size for alpha matting")
):
"""
Removes background from an image.
Returns the image with transparent background (PNG).
"""
# Read file content first (IO bound, doesn't need semaphore)
input_image = await file.read()
session = get_session(model.value)
# Acquire semaphore before heavy processing
if processing_semaphore.locked():
print("Waiting for processing slot...")
async with processing_semaphore:
try:
# Run blocking 'remove' function in a separate thread to avoid blocking the event loop
output_image = await run_in_threadpool(
remove,
input_image,
session=session,
alpha_matting=alpha_matting,
alpha_matting_foreground_threshold=alpha_matting_foreground_threshold,
alpha_matting_background_threshold=alpha_matting_background_threshold,
alpha_matting_erode_size=alpha_matting_erode_size
)
except Exception as e:
print(f"Error with alpha matting: {e}")
if alpha_matting:
print("Falling back to standard background removal (alpha_matting=False)...")
# Fallback also runs in thread pool
output_image = await run_in_threadpool(remove, input_image, session=session, alpha_matting=False)
else:
raise e
return Response(content=output_image, media_type="image/png")
@app.post("/video-bg-removal")
async def video_bg_removal(
background_tasks: BackgroundTasks,
file: UploadFile = File(...),
model: ModelName = Query(DEFAULT_MODEL, description="Model to use for background removal")
):
"""
Removes background from a video.
Returns WebM with Alpha.
"""
# Create temp file for input (IO bound)
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_input:
shutil.copyfileobj(file.file, tmp_input)
tmp_input_path = tmp_input.name
try:
# Acquire semaphore for the heavy video processing
if processing_semaphore.locked():
print("Waiting for video processing slot...")
async with processing_semaphore:
# Pass model name to processing function, run in thread pool
output_path = await run_in_threadpool(process_video, tmp_input_path, model.value)
except Exception as e:
if os.path.exists(tmp_input_path):
os.remove(tmp_input_path)
return {"error": str(e)}
background_tasks.add_task(os.remove, tmp_input_path)
background_tasks.add_task(os.remove, output_path)
return FileResponse(output_path, media_type="video/webm", filename="output_bg_removed.webm")
def process_video(input_path: str, model_name: str) -> str:
    """
    Strip the background from every frame of the video at *input_path* and
    encode the result as a WebM (VP9) file with an alpha channel.

    Frames are decoded with OpenCV, segmented by rembg, and piped as raw
    RGBA bytes into an FFmpeg subprocess. Blocking/CPU-heavy; callers run
    this in a threadpool.

    Args:
        input_path: Path to a video file readable by OpenCV.
        model_name: rembg model identifier (see ModelName).

    Returns:
        Path to the generated .webm file; the caller is responsible for
        deleting it.

    Raises:
        Exception: if FFmpeg exits non-zero or frame processing fails.
    """
    cap = cv2.VideoCapture(input_path)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    if fps <= 0:
        # Some containers report 0/invalid FPS; fall back to a sane default
        fps = 30.0
    # mkstemp instead of the deprecated, race-prone tempfile.mktemp
    fd, output_path = tempfile.mkstemp(suffix=".webm")
    os.close(fd)  # only the path is needed; ffmpeg -y overwrites the file
    # FFmpeg command to read raw RGBA video from stdin and output WebM with Alpha
    command = [
        'ffmpeg',
        '-y',                    # Overwrite output file
        '-f', 'rawvideo',
        '-vcodec', 'rawvideo',
        '-s', f'{width}x{height}',
        '-pix_fmt', 'rgba',
        '-r', str(fps),
        '-i', '-',               # Input from stdin
        '-c:v', 'libvpx-vp9',
        '-b:v', '2M',            # Reasonable bitrate
        '-pix_fmt', 'yuva420p',  # Important for alpha transparency in WebM
        output_path
    ]
    # Open ffmpeg process
    process = subprocess.Popen(command, stdin=subprocess.PIPE)
    session = get_session(model_name)
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes BGR; rembg expects RGB
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # Alpha matting is deliberately off per-frame for speed/stability.
            # (The previous try/except "fallback" retried the exact same call
            # with identical arguments, so it was dead code and is removed.)
            result_rgba = remove(frame_rgb, session=session, alpha_matting=False)
            # rembg returns RGBA, matching ffmpeg's declared input format
            process.stdin.write(result_rgba.tobytes())
    except Exception as e:
        print(f"Error during video processing: {e}")
        # Don't leak the partially written output file on failure
        if os.path.exists(output_path):
            os.remove(output_path)
        raise
    finally:
        cap.release()
        if process.stdin:
            process.stdin.close()
        process.wait()
    if process.returncode != 0:
        # Clean up the (likely invalid) output before reporting the failure
        if os.path.exists(output_path):
            os.remove(output_path)
        raise Exception(f"FFmpeg exited with error code {process.returncode}")
    return output_path
|