|
|
import asyncio
import os
import shutil
import subprocess
import tempfile
import threading
from enum import Enum

import cv2
import numpy as np
from fastapi import FastAPI, UploadFile, File, Response, BackgroundTasks, Query, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from rembg import new_session, remove
|
|
|
|
|
app = FastAPI()

# Permit cross-origin browser calls from any origin, with any method/header.
# NOTE(review): Starlette's CORS docs say wildcard origins/methods/headers are
# not meant to be combined with allow_credentials=True; confirm whether
# credentialed (cookie/auth) cross-origin requests are actually needed here.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
|
|
|
|
|
class ModelName(str, Enum):
    """Background-removal models that clients may select via the ``model`` query parameter.

    Because the enum mixes in ``str``, every member compares equal to its
    identifier string. ``member.value`` is the name handed to ``get_session``
    (and from there to ``rembg.new_session``).
    """

    birefnet_general = "birefnet-general"
    birefnet_general_lite = "birefnet-general-lite"
    isnet_anime = "isnet-anime"
    u2net = "u2net"
|
|
|
|
|
|
|
|
# Lazily-populated cache of rembg sessions, keyed by model name (see get_session).
sessions = dict()

# How many inference jobs (image or video) may run at the same time.
MAX_CONCURRENT_PROCESSING = 1

# Shared gate used by both the image and the video endpoints so that at most
# MAX_CONCURRENT_PROCESSING jobs hold a processing slot concurrently.
processing_semaphore = asyncio.Semaphore(value=MAX_CONCURRENT_PROCESSING)
|
|
|
|
|
# Serializes first-time model loading. get_session is reached from the event
# loop thread (startup, request handlers) and from threadpool workers
# (process_video), so without a lock two callers could build the same
# session concurrently and load the model twice.
_sessions_lock = threading.Lock()


def get_session(model_name: str):
    """Return the cached rembg session for *model_name*, creating it on first use.

    Args:
        model_name: rembg model identifier (typically a ``ModelName`` value).

    Returns:
        The object produced by ``rembg.new_session(model_name)``. Later calls
        with the same name return that same cached object.
    """
    # Fast path: already loaded, no locking required.
    session = sessions.get(model_name)
    if session is not None:
        return session

    with _sessions_lock:
        # Re-check under the lock: another thread may have finished loading
        # this model while we were waiting for it.
        if model_name not in sessions:
            print(f"Loading model: {model_name}...")
            sessions[model_name] = new_session(model_name)
        return sessions[model_name]
|
|
|
|
|
|
|
|
|
|
|
# Model used when a request does not specify one; also preloaded at startup.
DEFAULT_MODEL = ModelName("birefnet-general")
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Warm the session cache with the default model when the app starts.

    Because get_session caches sessions by name, the first request that uses
    the default model then finds it already loaded.
    """
    # NOTE: this call is synchronous, so it runs on the event loop during startup.
    default_model_name = DEFAULT_MODEL.value
    get_session(default_model_name)
|
|
|
|
|
@app.get("/")
def read_root():
    """Liveness endpoint: confirms the API is up and reports its concurrency limit."""
    status = {
        "message": "Background Removal API is running",
        "concurrent_limit": MAX_CONCURRENT_PROCESSING,
    }
    return status
|
|
|
|
|
@app.post("/image-bg-removal")
async def image_bg_removal(
    file: UploadFile = File(...),
    model: ModelName = Query(DEFAULT_MODEL, description="Model to use for background removal"),
    alpha_matting: bool = Query(False, description="Enable alpha matting for softer edges"),
    alpha_matting_foreground_threshold: int = Query(240, description="Trimap foreground threshold"),
    alpha_matting_background_threshold: int = Query(10, description="Trimap background threshold"),
    alpha_matting_erode_size: int = Query(10, description="Erode size for alpha matting"),
):
    """
    Removes background from an image.
    Returns the image with transparent background (PNG).

    If alpha matting is requested and fails, the request falls back to
    standard background removal instead of failing. An empty upload is
    rejected with HTTP 400.
    """
    input_image = await file.read()
    if not input_image:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    if processing_semaphore.locked():
        print("Waiting for processing slot...")

    async with processing_semaphore:
        # The first call for a model constructs a session via rembg.new_session,
        # which is blocking work; keep it off the event loop so other requests
        # (including the health check) stay responsive.
        session = await run_in_threadpool(get_session, model.value)

        try:
            output_image = await run_in_threadpool(
                remove,
                input_image,
                session=session,
                alpha_matting=alpha_matting,
                alpha_matting_foreground_threshold=alpha_matting_foreground_threshold,
                alpha_matting_background_threshold=alpha_matting_background_threshold,
                alpha_matting_erode_size=alpha_matting_erode_size,
            )
        except Exception as e:
            if not alpha_matting:
                # Nothing to fall back to; surface the error to the client.
                print(f"Error during background removal: {e}")
                raise
            print(f"Error with alpha matting: {e}")
            print("Falling back to standard background removal (alpha_matting=False)...")
            output_image = await run_in_threadpool(
                remove, input_image, session=session, alpha_matting=False
            )

    return Response(content=output_image, media_type="image/png")
|
|
|
|
|
@app.post("/video-bg-removal")
async def video_bg_removal(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    model: ModelName = Query(DEFAULT_MODEL, description="Model to use for background removal"),
):
    """
    Removes background from a video.
    Returns WebM with Alpha.

    On failure the temporary upload is deleted and a JSON body of the form
    {"error": "<message>"} is returned with HTTP status 500.
    """
    tmp_input_path = None
    try:
        # The upload is spooled to disk because OpenCV reads from a path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_input:
            tmp_input_path = tmp_input.name
            # Copying a large upload is blocking disk I/O; run it in a worker
            # thread instead of stalling the event loop.
            await run_in_threadpool(shutil.copyfileobj, file.file, tmp_input)

        if processing_semaphore.locked():
            print("Waiting for video processing slot...")

        async with processing_semaphore:
            # Frame-by-frame inference is CPU-bound and blocking.
            output_path = await run_in_threadpool(process_video, tmp_input_path, model.value)

    except Exception as e:
        if tmp_input_path is not None and os.path.exists(tmp_input_path):
            os.remove(tmp_input_path)
        # Report failure with an error status so clients expecting a video
        # don't mistake the JSON error body for a successful result.
        return JSONResponse(status_code=500, content={"error": str(e)})

    # Delete both temp files after the response has been sent.
    background_tasks.add_task(os.remove, tmp_input_path)
    background_tasks.add_task(os.remove, output_path)

    return FileResponse(output_path, media_type="video/webm", filename="output_bg_removed.webm")
|
|
|
|
|
def _close_and_wait(process: subprocess.Popen) -> int:
    """Signal end-of-input to the encoder subprocess and return its exit code."""
    if process.stdin:
        try:
            process.stdin.close()
        except OSError:
            # If ffmpeg already exited, flushing the pipe raises
            # BrokenPipeError; the exit code below is the meaningful signal,
            # so don't let this mask the original error.
            pass
    return process.wait()


def _remove_if_exists(path: str) -> None:
    """Delete *path*, tolerating the file already being gone."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def process_video(input_path: str, model_name: str) -> str:
    """Remove the background from every frame of a video and encode it as WebM with alpha.

    Frames are decoded with OpenCV, converted BGR -> RGB, passed through
    rembg one at a time, and streamed as raw RGBA bytes to an ``ffmpeg``
    subprocess that encodes VP9 with the ``yuva420p`` pixel format.

    Args:
        input_path: Path of the source video file.
        model_name: rembg model name, resolved through ``get_session``.

    Returns:
        Path of the newly created ``.webm`` file. The caller owns it and is
        responsible for deleting it.

    Raises:
        ValueError: If OpenCV cannot open ``input_path``.
        Exception: If a frame fails to process or ffmpeg exits non-zero. On
            any failure the partially written output file is removed.
    """
    # Load (or reuse) the model first: if this fails, no capture, temp file,
    # or subprocess has been created yet, so nothing can leak.
    session = get_session(model_name)

    cap = cv2.VideoCapture(input_path)
    output_path = None
    process = None
    succeeded = False
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video file: {input_path}")

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            # The container reported no usable frame rate; assume 30 fps.
            fps = 30.0

        # Reserve the output path atomically (tempfile.mktemp is race-prone);
        # ffmpeg's -y flag overwrites the empty placeholder file.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".webm") as tmp_output:
            output_path = tmp_output.name

        command = [
            'ffmpeg',
            '-y',
            # Input: raw RGBA frames of a fixed size arriving on stdin.
            '-f', 'rawvideo',
            '-vcodec', 'rawvideo',
            '-s', f'{width}x{height}',
            '-pix_fmt', 'rgba',
            '-r', str(fps),
            '-i', '-',
            # Output: VP9 in WebM; yuva420p keeps the alpha channel.
            '-c:v', 'libvpx-vp9',
            '-b:v', '2M',
            '-pix_fmt', 'yuva420p',
            output_path,
        ]
        process = subprocess.Popen(command, stdin=subprocess.PIPE)

        # ffmpeg slices the raw stream purely by the declared frame size, so
        # every frame written must match it exactly.
        expected_shape = (height, width, 4)
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV decodes to BGR; convert to RGB before handing to rembg.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            result_rgba = np.asarray(remove(frame_rgb, session=session, alpha_matting=False))
            if result_rgba.shape != expected_shape:
                raise RuntimeError(
                    f"Processed frame has shape {result_rgba.shape}, "
                    f"expected {expected_shape}"
                )
            process.stdin.write(result_rgba.tobytes())

        returncode = _close_and_wait(process)
        if returncode != 0:
            raise Exception(f"FFmpeg exited with error code {returncode}")
        succeeded = True
    except Exception as e:
        print(f"Error during video processing: {e}")
        raise
    finally:
        cap.release()
        if process is not None:
            # Idempotent: a no-op wait if the encoder was already shut down.
            _close_and_wait(process)
        if not succeeded and output_path is not None:
            # Don't leave a partial/empty output file behind on failure.
            _remove_if_exists(output_path)

    return output_path
|
|
|