File size: 7,144 Bytes
9a5e266
 
 
 
 
 
4b22c57
 
9a5e266
 
4b22c57
9a5e266
 
 
 
 
 
 
 
 
 
 
 
 
 
3072d42
 
9a5e266
 
 
 
 
 
4b22c57
 
 
 
 
9a5e266
 
 
 
 
 
 
3072d42
 
9a5e266
 
 
 
 
 
 
 
4b22c57
9a5e266
 
 
 
 
 
 
 
 
 
 
 
 
 
4b22c57
9a5e266
 
 
 
4b22c57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a5e266
 
 
 
 
 
 
 
 
 
 
 
 
4b22c57
9a5e266
 
 
 
 
4b22c57
 
 
 
 
 
 
 
9a5e266
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
import os
import shutil
import tempfile
import subprocess
import cv2 # type: ignore
import numpy as np
import asyncio
from fastapi import FastAPI, UploadFile, File, Response, BackgroundTasks, Query, HTTPException
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.concurrency import run_in_threadpool
from rembg import new_session, remove
from enum import Enum

# Application instance; every route in this module is registered on it.
app = FastAPI()

# CORS is fully open: any origin, HTTP method and header is accepted, and
# credentials are allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive -- confirm this is intended before exposing the API publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class ModelName(str, Enum):
    """Model identifiers accepted by the ``model`` query parameter.

    Each member's value is the model name string handed to ``get_session``
    (and from there to ``new_session``) when a request is processed.
    """

    birefnet_general = "birefnet-general"
    birefnet_general_lite = "birefnet-general-lite"
    isnet_anime = "isnet-anime"
    u2net = "u2net"

# Cache of loaded sessions keyed by model name, so a model is not reloaded on
# every request.
sessions = {}

# Global semaphore bounding how many heavy background-removal jobs run at once.
# Free tiers have limited CPU/RAM, so only one heavy task is allowed at a time
# to prevent OOM/crashes; additional requests wait for the slot.
MAX_CONCURRENT_PROCESSING = 1
processing_semaphore = asyncio.Semaphore(MAX_CONCURRENT_PROCESSING)

def get_session(model_name: str):
    """Return the cached session for ``model_name``, loading it on first use.

    A newly created session is stored in the module-level ``sessions`` cache
    so later calls with the same name reuse it.
    """
    try:
        return sessions[model_name]
    except KeyError:
        pass

    print(f"Loading model: {model_name}...")
    loaded = new_session(model_name)
    sessions[model_name] = loaded
    return loaded

# Default model: pre-loaded at startup and used when a request does not pick one.
# 'birefnet-general' was chosen because it offers superior edge detection and
# quality for mascot/cartoon images.
DEFAULT_MODEL = ModelName.birefnet_general

@app.on_event("startup")
async def startup_event():
    """Warm the session cache with the default model when the app starts.

    Doing this at startup means the model download/load cost is paid before
    the first request arrives rather than during it.
    """
    default_model_name = DEFAULT_MODEL.value
    get_session(default_model_name)

@app.get("/")
def read_root():
    """Report that the API is running, together with its concurrency limit."""
    status = {
        "message": "Background Removal API is running",
        "concurrent_limit": MAX_CONCURRENT_PROCESSING,
    }
    return status

@app.post("/image-bg-removal")
async def image_bg_removal(
    file: UploadFile = File(...),
    model: ModelName = Query(DEFAULT_MODEL, description="Model to use for background removal"),
    alpha_matting: bool = Query(False, description="Enable alpha matting for softer edges"),
    alpha_matting_foreground_threshold: int = Query(240, description="Trimap foreground threshold"),
    alpha_matting_background_threshold: int = Query(10, description="Trimap background threshold"),
    alpha_matting_erode_size: int = Query(10, description="Erode size for alpha matting")
):
    """
    Removes background from an image.
    Returns the image with transparent background (PNG).

    Args:
        file: Uploaded image to process.
        model: Which model to run.
        alpha_matting: Whether to refine edges with alpha matting. If the
            matting pass fails, the request is retried once without it.
        alpha_matting_foreground_threshold: Trimap foreground threshold.
        alpha_matting_background_threshold: Trimap background threshold.
        alpha_matting_erode_size: Erode size used by alpha matting.

    Raises:
        HTTPException: 400 when the uploaded file is empty.
        Exception: Any processing error that is not recovered by the
            alpha-matting fallback is propagated unchanged.
    """
    # Read file content first (IO bound, doesn't need semaphore)
    input_image = await file.read()
    if not input_image:
        # Reject obviously invalid input before queueing for the processing slot.
        raise HTTPException(status_code=400, detail="Uploaded file is empty")

    # Acquire semaphore before heavy processing
    if processing_semaphore.locked():
        print("Waiting for processing slot...")

    async with processing_semaphore:
        # Loading a model that is not cached yet is itself heavy and blocking,
        # so do it inside the processing slot and off the event loop.
        session = await run_in_threadpool(get_session, model.value)

        try:
            # Run blocking 'remove' function in a separate thread to avoid blocking the event loop
            output_image = await run_in_threadpool(
                remove,
                input_image,
                session=session,
                alpha_matting=alpha_matting,
                alpha_matting_foreground_threshold=alpha_matting_foreground_threshold,
                alpha_matting_background_threshold=alpha_matting_background_threshold,
                alpha_matting_erode_size=alpha_matting_erode_size
            )
        except Exception as e:
            if not alpha_matting:
                # Nothing to fall back to: log and let the error propagate
                # with its original traceback.
                print(f"Error during background removal: {e}")
                raise
            print(f"Error with alpha matting: {e}")
            print("Falling back to standard background removal (alpha_matting=False)...")
            # Fallback also runs in thread pool
            output_image = await run_in_threadpool(remove, input_image, session=session, alpha_matting=False)

    return Response(content=output_image, media_type="image/png")

@app.post("/video-bg-removal")
async def video_bg_removal(
    background_tasks: BackgroundTasks, 
    file: UploadFile = File(...),
    model: ModelName = Query(DEFAULT_MODEL, description="Model to use for background removal")
):
    """
    Removes background from a video.
    Returns WebM with Alpha.

    Args:
        background_tasks: Used to delete the generated output file after the
            response has been sent.
        file: Uploaded video to process.
        model: Which model to run on each frame.

    Returns:
        A ``video/webm`` file response on success. On failure, a JSON body
        ``{"error": <message>}`` with HTTP status 500.
    """
    # Imported locally so this handler does not depend on a file-level import.
    from fastapi.responses import JSONResponse

    # Create temp file for input; delete=False because it is re-opened by path
    # during processing and removed explicitly below.
    tmp_input = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
    tmp_input_path = tmp_input.name

    try:
        with tmp_input:
            # Copying the upload to disk is blocking I/O, so keep it off the
            # event loop.
            await run_in_threadpool(shutil.copyfileobj, file.file, tmp_input)

        # Acquire semaphore for the heavy video processing
        if processing_semaphore.locked():
             print("Waiting for video processing slot...")

        async with processing_semaphore:
            # Pass model name to processing function, run in thread pool
            output_path = await run_in_threadpool(process_video, tmp_input_path, model.value)

    except Exception as e:
        print(f"Video background removal failed: {e}")
        # Keep the existing {"error": ...} body, but report it as a server
        # error instead of a 200 response.
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        # The input copy is only needed while processing; remove it on success,
        # failure and cancellation alike so it never leaks.
        try:
            os.remove(tmp_input_path)
        except FileNotFoundError:
            pass

    # The output file must outlive this function: it is streamed by the
    # response and deleted afterwards.
    background_tasks.add_task(os.remove, output_path)

    return FileResponse(output_path, media_type="video/webm", filename="output_bg_removed.webm")

def process_video(input_path: str, model_name: str) -> str:
    """Remove the background from every frame of a video and encode it as WebM.

    Frames are read with OpenCV, passed through ``remove`` and streamed as raw
    RGBA to an ``ffmpeg`` subprocess that encodes VP9 with an alpha channel.

    Args:
        input_path: Path of the video file to read.
        model_name: Name of the model whose session is used for every frame.

    Returns:
        Path of the generated ``.webm`` file. The caller is responsible for
        deleting it.

    Raises:
        ValueError: If the input video cannot be opened.
        Exception: If frame processing fails or ffmpeg exits with a non-zero
            return code. Any partially written output file is removed first.
    """
    cap = cv2.VideoCapture(input_path)
    output_path = None
    completed = False

    try:
        if not cap.isOpened():
            # Fail early with a clear message instead of handing ffmpeg a
            # 0x0 frame size.
            raise ValueError("Could not open input video for reading")

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Fall back to 30 fps when the container reports no usable rate
        # ("not fps > 0" also covers NaN).
        if not fps > 0:
            fps = 30.0

        # Load the model before starting ffmpeg so that a load failure does
        # not leave an encoder process waiting on stdin.
        session = get_session(model_name)

        # mkstemp creates the file atomically (unlike the deprecated mktemp);
        # ffmpeg overwrites it thanks to '-y'.
        fd, output_path = tempfile.mkstemp(suffix=".webm")
        os.close(fd)

        # FFmpeg command to read raw RGBA video from stdin and output WebM with Alpha
        command = [
            'ffmpeg',
            '-y', # Overwrite output file
            '-f', 'rawvideo',
            '-vcodec', 'rawvideo',
            '-s', f'{width}x{height}',
            '-pix_fmt', 'rgba',
            '-r', str(fps),
            '-i', '-', # Input from stdin
            '-c:v', 'libvpx-vp9',
            '-b:v', '2M', # Reasonable bitrate
            '-pix_fmt', 'yuva420p', # Important for alpha transparency in WebM
            output_path
        ]

        # Open ffmpeg process
        process = subprocess.Popen(command, stdin=subprocess.PIPE)

        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                # Alpha matting is disabled for video frames. The previous
                # per-frame "fallback" repeated this exact call, so it could
                # not recover from anything and has been dropped.
                result_rgba = remove(frame_rgb, session=session, alpha_matting=False)

                # rembg returns RGBA
                process.stdin.write(result_rgba.tobytes())
        except Exception as e:
            print(f"Error during video processing: {e}")
            raise
        finally:
            try:
                process.stdin.close()
            except OSError:
                # ffmpeg already went away (e.g. broken pipe); the original
                # error or the return-code check below reports the failure.
                pass
            process.wait()

        if process.returncode != 0:
            raise Exception(f"FFmpeg exited with error code {process.returncode}")

        completed = True
        return output_path
    finally:
        cap.release()
        # Do not leave a partial or empty output file behind on failure.
        if not completed and output_path is not None and os.path.exists(output_path):
            os.remove(output_path)