vanishly / main.py
Samir87699's picture
Fix quality loss by avoiding patch downscaling and improving text-erase crop
e94e2ed
import os
import io
import time
import logging
import asyncio
from contextlib import asynccontextmanager
from concurrent.futures import ThreadPoolExecutor
import cv2
import torch
import numpy as np
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response, JSONResponse
from PIL import Image, ImageFilter
from simple_lama_inpainting import SimpleLama
from ultralytics import FastSAM
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from rembg import remove as rembg_remove, new_session as rembg_new_session
# ─────────────────────────────────────────────
# Logging
# ─────────────────────────────────────────────
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ─────────────────────────────────────────────
# Global model cache + thread pool
# ─────────────────────────────────────────────
# Models are populated once at startup by `lifespan` / `_load_realesrgan`.
# A model whose load fails is simply left as None, so code that uses one
# checks it for None first.
lama_model: SimpleLama | None = None
fastsam_model: FastSAM | None = None
# CLIPSeg text-prompted segmentation: pre-processor + model, loaded together.
clipseg_processor: CLIPSegProcessor | None = None
clipseg_model: CLIPSegForImageSegmentation | None = None
# Real-ESRGAN upsampler instances (×2 and ×4); None when the weights are
# missing or failed to load.
realesrgan_x2: RealESRGANer | None = None
realesrgan_x4: RealESRGANer | None = None
# ─── FREE TIER CPU OPTIMIZATIONS ─────────────────────────────────────────────
# Real-ESRGAN is very heavy. On a 2-vCPU Hugging Face space, processing several
# images at once oversubscribes the CPUs and throughput collapses.
# So torch is limited to one intra-op thread, and requests are queued one at a
# time through the single-worker executor below.
import torch  # NOTE: re-import of the top-of-file `import torch`; harmless.
torch.set_num_threads(1)
# Dedicated single-worker thread pool: model calls are handed to it with
# run_in_executor so they run OFF the async event loop (FastAPI stays
# responsive) and execute strictly one at a time.
_executor = ThreadPoolExecutor(max_workers=1)
# ─────────────────────────────────────────────
# Model weight paths (baked into Docker image)
# ─────────────────────────────────────────────
WEIGHTS_DIR = os.path.join(os.path.dirname(__file__), "weights")
REALESRGAN_X2_PATH = os.path.join(WEIGHTS_DIR, "RealESRGAN_x2plus.pth")
REALESRGAN_X4_PATH = os.path.join(WEIGHTS_DIR, "RealESRGAN_x4plus.pth")
FASTSAM_MODEL_PATH = os.path.join(WEIGHTS_DIR, "FastSAM-s.pt")
# ─────────────────────────────────────────────
# Inpainting resolution caps
# ─────────────────────────────────────────────
# NOTE(review): these caps do not currently bound the inpainting resolution;
# the crop-based inpainting path caps crops at 2560px instead. Confirm
# whether they are still meant to apply.
FREE_MAX_PX = 720  # free users — fast inference
PREMIUM_MAX_PX = 1280  # premium — higher quality inpainting
def _ensure_params_ema_wrapper(path: str) -> None:
    """Re-save a bare state dict at *path* as ``{"params_ema": state_dict}``.

    Checkpoints such as 4x-UltraSharp are plain state dicts without the
    ``params_ema`` / ``params`` key that RealESRGANer loads from, so they are
    wrapped in place. Any error is logged and swallowed; the caller still
    attempts to load the file afterwards.
    """
    try:
        checkpoint = torch.load(path, map_location="cpu")
        already_wrapped = "params_ema" in checkpoint or "params" in checkpoint
        if not already_wrapped:
            logger.info(f"Wrapping state dict for {path} to be compatible with RealESRGANer...")
            torch.save({"params_ema": checkpoint}, path)
    except Exception as e:
        logger.error(f"Error checking/wrapping weights {path}: {e}")


def _load_realesrgan() -> None:
    """Build the ×2 and ×4 Real-ESRGAN upsamplers and store them in the globals.

    For every scale whose weight file exists, an RRDBNet generator is wrapped
    in a CPU RealESRGANer and assigned to ``realesrgan_x2`` / ``realesrgan_x4``.
    A missing weight file only logs a warning. Any load failure is logged and
    leaves that global as ``None``, so ``_run_superres`` uses its fallback.
    """
    global realesrgan_x2, realesrgan_x4
    cpu = torch.device("cpu")  # Hugging Face free tier is CPU-only

    for scale, path in ((2, REALESRGAN_X2_PATH), (4, REALESRGAN_X4_PATH)):
        if not os.path.exists(path):
            logger.warning(f"Real-ESRGAN x{scale} weights not found at {path} — upscaling at x{scale} disabled.")
            continue
        try:
            _ensure_params_ema_wrapper(path)
            # RRDBNet is the generator backbone of Real-ESRGAN.
            generator = RRDBNet(
                num_in_ch=3,
                num_out_ch=3,
                num_feat=64,
                num_block=23,
                num_grow_ch=32,
                scale=scale,
            )
            upsampler = RealESRGANer(
                scale=scale,
                model_path=path,
                model=generator,
                tile=128,  # small tiles keep per-tile memory low on CPU
                tile_pad=10,
                pre_pad=0,
                half=False,  # full float32; fp16 is not used on CPU
                device=cpu,
            )
            if scale == 2:
                realesrgan_x2 = upsampler
            else:
                realesrgan_x4 = upsampler
            logger.info(f"✅ Real-ESRGAN x{scale} loaded.")
        except Exception as e:
            logger.error(f"Failed to load Real-ESRGAN x{scale}: {e}", exc_info=True)
# ─────────────────────────────────────────────
# Startup: pre-load all models once
# ─────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load every model once before the app starts serving requests.

    Each loader has its own try/except, so one failing model only stays
    ``None`` (logged with a traceback) instead of aborting startup. The final
    log line reports which models are actually unavailable, rather than
    unconditionally claiming success.
    """
    global lama_model, fastsam_model, clipseg_processor, clipseg_model
    logger.info("Loading FastSAM model…")
    try:
        # Prefer the weights baked into the image; otherwise hand ultralytics
        # the bare model name so it fetches it.
        if os.path.exists(FASTSAM_MODEL_PATH):
            logger.info(f"Loading FastSAM from local path: {FASTSAM_MODEL_PATH}")
            fastsam_model = FastSAM(FASTSAM_MODEL_PATH)
        else:
            logger.warning(f"FastSAM weights not found at {FASTSAM_MODEL_PATH}, downloading...")
            fastsam_model = FastSAM("FastSAM-s.pt")
    except Exception as e:
        logger.error(f"FastSAM load failed: {e}", exc_info=True)

    logger.info("Loading CLIPSeg model for text-based removal…")
    try:
        clipseg_processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
        clipseg_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
    except Exception as e:
        logger.error(f"CLIPSeg load failed: {e}", exc_info=True)

    logger.info("Loading LaMa model…")
    try:
        lama_model = SimpleLama()
    except Exception as e:
        logger.error(f"LaMa load failed: {e}", exc_info=True)

    logger.info("Loading Real-ESRGAN upscaling models…")
    _load_realesrgan()

    availability = {
        "fastsam": fastsam_model is not None,
        "clipseg": clipseg_model is not None,
        "lama": lama_model is not None,
        "realesrgan_x2": realesrgan_x2 is not None,
        "realesrgan_x4": realesrgan_x4 is not None,
    }
    missing = [name for name, loaded in availability.items() if not loaded]
    if missing:
        logger.warning(f"⚠️ Startup finished with models unavailable: {', '.join(missing)}")
    else:
        logger.info("✅ All models ready.")
    yield
    logger.info("Shutting down.")
# ─────────────────────────────────────────────
# App
# ─────────────────────────────────────────────
_APP_DESCRIPTION = "High-performance AI inpainting + super-resolution powered by LaMa & Real-ESRGAN."

app = FastAPI(
    title="Vanishly AI Backend",
    description=_APP_DESCRIPTION,
    version="4.0.0",
    lifespan=lifespan,
)

# Permissive CORS: every origin and header is allowed; only GET/POST are listed.
_cors_options = dict(
    allow_origins=["*"],
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
# ─────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────
def _fit(img: Image.Image, max_px: int) -> Image.Image:
    """Return *img* shrunk so that its longest edge is at most *max_px*.

    An image already within the limit is returned as-is (the same object);
    it is never enlarged. Otherwise it is LANCZOS-resampled with its aspect
    ratio kept. Each new side is rounded to the nearest integer, then down
    to an even number (never below 2).
    """
    width, height = img.size
    longest = max(width, height)
    if longest > max_px:
        scale = max_px / longest
        new_size = tuple(max(2, round(side * scale) & ~1) for side in (width, height))
        return img.resize(new_size, Image.LANCZOS)
    return img
def _encode_jpeg(img: Image.Image, quality: int = 88) -> bytes:
    """Serialize *img* as RGB JPEG bytes (4:2:0 chroma subsampling, no optimize pass)."""
    with io.BytesIO() as out:
        rgb = img.convert("RGB")
        rgb.save(out, format="JPEG", quality=quality, optimize=False, subsampling=2)
        return out.getvalue()
def _encode_png_fast(img: Image.Image) -> bytes:
    """Serialize *img* as lossless PNG bytes using the fastest zlib level (1)."""
    with io.BytesIO() as out:
        img.save(out, format="PNG", compress_level=1, optimize=False)
        return out.getvalue()
def _run_lama(src: Image.Image, msk: Image.Image) -> Image.Image:
    """Inpaint *src* wherever *msk* is non-zero with the global LaMa model.

    The result is always the same size as *src*. SimpleLama pads its inputs
    on the bottom/right up to a multiple of 8 and can hand back that padded
    canvas. If a caller then resized it to the source size, the padding
    would be stretched into the picture and the inpainted region would be
    slightly misaligned. Cropping the top-left ``src.size`` region removes
    the padding instead.
    """
    out = lama_model(src, msk)
    width, height = src.size
    if out.size != src.size and out.width >= width and out.height >= height:
        out = out.crop((0, 0, width, height))
    return out
def _run_fastsam(src: Image.Image, px: int, py: int) -> Image.Image | None:
    """Segment the object under pixel (px, py) of *src* with FastSAM.

    Args:
        src: RGB image to segment (the /segment endpoint passes a copy shrunk
            to at most 512px).
        px, py: prompt point in *src* pixel coordinates. They are clamped into
            the image, because FastSAM indexes its masks at this point and an
            edge tap (x == width) would otherwise raise.

    Returns:
        An "L" mask the size of *src* (255 = object) built from the first mask
        FastSAM returns for the point, or None when it returns no mask.
    """
    px = min(max(px, 0), src.width - 1)
    py = min(max(py, 0), src.height - 1)
    results = fastsam_model(
        src,
        imgsz=512,
        conf=0.3,
        iou=0.7,
        points=[[px, py]],
        labels=[1],
        # Request masks at the source resolution. With retina_masks=False
        # ultralytics returns them at the (letterboxed) network-input size,
        # and stretching that back to src.size can misalign the mask.
        retina_masks=True,
    )
    if not results:
        return None
    masks = results[0].masks
    # The point prompt can filter every candidate out, leaving zero masks.
    if masks is None or len(masks.data) == 0:
        return None
    mask_array = masks.data[0].cpu().numpy()
    mask_img = Image.fromarray((mask_array * 255).astype(np.uint8))
    # Safety net in case the mask still comes back at a different size.
    if mask_img.size != src.size:
        mask_img = mask_img.resize(src.size, Image.NEAREST)
    return mask_img
def _run_clipseg(src: Image.Image, prompt: str, threshold: float = 0.25) -> Image.Image | None:
    """Segment whatever *prompt* describes in *src* using CLIPSeg.

    Args:
        src: RGB image to segment.
        prompt: free-text description of the target object.
        threshold: cut-off on the sigmoid of CLIPSeg's logits. Lower values
            select more aggressively, higher values more conservatively.
            Defaults to 0.25 (raised to 0.4 previously gave weaker detection).

    Returns:
        An "L" mask the size of *src* (255 = selected), dilated with a 9x9
        kernel twice, or None when CLIPSeg is not loaded.
    """
    if clipseg_processor is None or clipseg_model is None:
        return None
    inputs = clipseg_processor(text=[prompt], images=[src], padding="max_length", return_tensors="pt")
    with torch.no_grad():
        outputs = clipseg_model(**inputs)
    probs = torch.sigmoid(outputs.logits.squeeze())
    mask_np = (probs > threshold).cpu().numpy().astype(np.uint8) * 255
    # Grow the selection so the object's edges are covered as well.
    kernel = np.ones((9, 9), np.uint8)
    mask_np = cv2.dilate(mask_np, kernel, iterations=2)
    mask_img = Image.fromarray(mask_np)
    # The prediction generally differs in size from src; stretch it back.
    if mask_img.size != src.size:
        mask_img = mask_img.resize(src.size, Image.NEAREST)
    return mask_img
def _run_superres(src: Image.Image, scale: int) -> Image.Image:
    """
    AI Super-Resolution via Real-ESRGAN.

    Real-ESRGAN is a GAN-based model trained on real-world photos. It
    reconstructs texture detail (hair strands, skin pores, fabric weave,
    foliage) rather than just interpolating pixels.

    ``scale == 2`` uses the ×2 upsampler. Any other value uses the ×4
    upsampler with ``outscale=scale``. Large images are processed in
    overlapping tiles to keep CPU memory manageable. The tile size (128px)
    and padding are configured where the upsamplers are built in
    ``_load_realesrgan``.

    If the required upsampler is not loaded, falls back to a LANCZOS resize
    followed by an unsharp mask.

    Returns an RGB image ``scale`` times larger than *src*.
    """
    upsampler = realesrgan_x2 if scale == 2 else realesrgan_x4
    if upsampler is None:
        # Fallback if model failed to load
        logger.warning(f"Real-ESRGAN x{scale} not available, falling back to LANCZOS+sharpen")
        w, h = src.size
        upscaled = src.resize((w * scale, h * scale), Image.LANCZOS)
        return upscaled.filter(ImageFilter.UnsharpMask(radius=2.0, percent=200, threshold=2))
    logger.info(f" Real-ESRGAN x{scale}: input={src.size}")
    # RealESRGANer.enhance works on uint8 BGR (OpenCV-style) arrays.
    img_np = np.array(src.convert("RGB"))
    img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
    output_bgr, _ = upsampler.enhance(img_bgr, outscale=scale)
    # Back to RGB / PIL for the caller.
    output_rgb = cv2.cvtColor(output_bgr, cv2.COLOR_BGR2RGB)
    result = Image.fromarray(output_rgb)
    logger.info(f" Real-ESRGAN x{scale}: output={result.size}")
    return result
# ─────────────────────────────────────────────
# Background removal (rembg)
# ─────────────────────────────────────────────
# Lazily-created rembg session shared by every background-removal call.
_rembg_session = None


def _get_rembg_session():
    """Return the shared rembg session, creating it on first use.

    ``rembg.remove`` builds a brand-new ONNX inference session (reloading the
    model) whenever it is called without ``session=``. Reusing one session
    avoids paying that cost on every request. Calls arrive through the
    single-worker ``_executor``, so no lock is used.
    """
    global _rembg_session
    if _rembg_session is None:
        # "u2net" is the model rembg.remove falls back to without a session.
        _rembg_session = rembg_new_session("u2net")
    return _rembg_session


def _run_remove_bg(src: Image.Image) -> Image.Image:
    """
    Remove background using rembg (u2net ONNX model).
    Returns an RGBA image with the background made transparent.
    """
    # rembg accepts and returns PIL Images when given a PIL Image.
    return rembg_remove(src, session=_get_rembg_session())


# ─────────────────────────────────────────────
# Routes
# ─────────────────────────────────────────────
@app.get("/", tags=["Health"])
async def root():
    """Liveness probe: always answers 200 and reports which models are loaded."""
    payload = {"status": "ok", "model": "lama+realesrgan"}
    payload["ready"] = lama_model is not None
    payload["upscale_x2"] = realesrgan_x2 is not None
    payload["upscale_x4"] = realesrgan_x4 is not None
    return JSONResponse(payload)
@app.get("/health", tags=["Health"])
async def health():
    """Readiness probe: 503 until both LaMa and FastSAM are loaded."""
    core_models_loaded = lama_model is not None and fastsam_model is not None
    if not core_models_loaded:
        raise HTTPException(status_code=503, detail="Models not yet loaded")
    body = {
        "status": "healthy",
        "upscale_x2": realesrgan_x2 is not None,
        "upscale_x4": realesrgan_x4 is not None,
    }
    return JSONResponse(body)
@app.post("/segment", tags=["Segmentation"])
async def segment_image(
    image: UploadFile = File(...),
    normX: float = Form(0.5),
    normY: float = Form(0.5),
):
    """
    Point-tap segmentation via FastSAM.
    Returns a grayscale PNG mask (white = selected object) the size of the upload.
    normX / normY are the tap position as fractions of width / height. Values
    at or beyond the border are clamped to the last pixel.
    """
    if fastsam_model is None:
        raise HTTPException(status_code=503, detail="FastSAM not loaded")
    if image.content_type not in ("image/png", "image/jpeg", "image/jpg", "application/octet-stream"):
        raise HTTPException(status_code=415, detail="Unsupported image format")
    t0 = time.perf_counter()
    try:
        img_data = await image.read()
        src = Image.open(io.BytesIO(img_data)).convert("RGB")
        src_small = _fit(src, 512)
        # Clamp so a tap on the far edge (norm == 1.0) maps to the last pixel
        # rather than one past it, which FastSAM cannot index.
        px = min(max(int(normX * src_small.width), 0), src_small.width - 1)
        py = min(max(int(normY * src_small.height), 0), src_small.height - 1)
        logger.info(f"Segmenting @ ({px},{py}) on {src_small.size}")
        loop = asyncio.get_running_loop()
        mask_img = await loop.run_in_executor(_executor, _run_fastsam, src_small, px, py)
        if mask_img is None:
            mask_img = Image.new("L", src.size, 0)
        elif mask_img.size != src.size:
            mask_img = mask_img.resize(src.size, Image.NEAREST)
        output_bytes = _encode_png_fast(mask_img)
        logger.info(f"Segmentation done in {time.perf_counter()-t0:.2f}s, {len(output_bytes)//1024}KB")
        return Response(content=output_bytes, media_type="image/png", headers={"X-Model": "fastsam"})
    except Exception as e:
        logger.error(f"Segmentation failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Segmentation error: {str(e)}")
@app.post("/segment_text", tags=["Segmentation"])
async def segment_text(
    image: UploadFile = File(...),
    prompt: str = Form(...),
):
    """
    Text-based segmentation via CLIPSeg.
    Returns a grayscale PNG mask (white = selected object) the size of the upload.
    An empty or whitespace-only prompt is rejected with 400.
    """
    if clipseg_model is None:
        raise HTTPException(status_code=503, detail="CLIPSeg not loaded")
    if image.content_type not in ("image/png", "image/jpeg", "image/jpg", "application/octet-stream"):
        raise HTTPException(status_code=415, detail="Unsupported image format")
    prompt = prompt.strip()
    if not prompt:
        # An empty prompt gives CLIPSeg nothing to match; fail fast.
        raise HTTPException(status_code=400, detail="Prompt must not be empty")
    t0 = time.perf_counter()
    try:
        img_data = await image.read()
        src = Image.open(io.BytesIO(img_data)).convert("RGB")
        src_small = _fit(src, 512)
        logger.info(f"Segmenting text '{prompt}' on {src_small.size}")
        loop = asyncio.get_running_loop()
        mask_img = await loop.run_in_executor(_executor, _run_clipseg, src_small, prompt)
        if mask_img is None:
            mask_img = Image.new("L", src.size, 0)
        elif mask_img.size != src.size:
            mask_img = mask_img.resize(src.size, Image.NEAREST)
        output_bytes = _encode_png_fast(mask_img)
        logger.info(f"Text Segmentation done in {time.perf_counter()-t0:.2f}s, {len(output_bytes)//1024}KB")
        return Response(content=output_bytes, media_type="image/png", headers={"X-Model": "clipseg"})
    except Exception as e:
        logger.error(f"Text Segmentation failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Text Segmentation error: {str(e)}")
@app.post("/text_erase", tags=["AI Edit"])
async def text_erase(
    image: UploadFile = File(...),
    prompt: str = Form(...),
    premium: str = Form("false"),
):
    """
    AI Text-Based Object Removal — single endpoint.
    Pipeline: Image + "remove the person" → CLIPSeg mask → LaMa inpaint → Result.
    Returns the edited image directly as a lossless PNG at the original
    resolution; no separate mask is needed.

    Only a padded crop around the detected region is inpainted, at native
    resolution (capped at 2560px on its longest edge). That crop is then
    feather-blended back into the untouched original. If nothing matches
    the prompt, the original image is returned (re-encoded as PNG).
    `premium` is currently only reported in the logs; it does not change
    resolution or encoding. An empty prompt is rejected with 400.
    """
    if clipseg_model is None or lama_model is None:
        raise HTTPException(status_code=503, detail="Models not loaded")
    if image.content_type not in ("image/png", "image/jpeg", "image/jpg", "application/octet-stream"):
        raise HTTPException(status_code=415, detail="Unsupported image format")
    prompt = prompt.strip()
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt must not be empty")
    is_premium = premium.lower() in ("true", "1", "yes")
    t0 = time.perf_counter()
    try:
        img_data = await image.read()
        src = Image.open(io.BytesIO(img_data)).convert("RGB")
        orig_size = src.size
        loop = asyncio.get_running_loop()

        # Step 1: Generate the mask from the text prompt via CLIPSeg on a
        # ≤512px copy, then scale it back to the original size.
        src_small = _fit(src, 512)
        logger.info(f"Text erase '{prompt}' — generating mask on {src_small.size}")
        mask_img = await loop.run_in_executor(_executor, _run_clipseg, src_small, prompt)
        if mask_img is None:
            logger.warning("CLIPSeg returned None — nothing detected for prompt")
            mask_img = Image.new("L", orig_size, 0)
        elif mask_img.size != orig_size:
            mask_img = mask_img.resize(orig_size, Image.NEAREST)

        # Step 2: High-resolution cropped inpainting.
        msk_np = np.array(mask_img)
        kernel = np.ones((9, 9), np.uint8)
        msk_dilated_np = cv2.dilate(msk_np, kernel, iterations=2)
        coords = cv2.findNonZero(msk_dilated_np)
        if coords is None:
            # Nothing selected: the composite would reproduce src exactly.
            result = src
        else:
            x, y, w, h = cv2.boundingRect(coords)
            # Pad the box generously (50% of its size or 128px) for context.
            pad_x = max(128, int(w * 0.5))
            pad_y = max(128, int(h * 0.5))
            x1 = max(0, x - pad_x)
            y1 = max(0, y - pad_y)
            x2 = min(orig_size[0], x + w + pad_x)
            y2 = min(orig_size[1], y + h + pad_y)
            msk_dilated = Image.fromarray(msk_dilated_np)
            crop_src = src.crop((x1, y1, x2, y2))
            crop_msk = msk_dilated.crop((x1, y1, x2, y2))
            # Process the crop at native resolution (capped at 2560 to prevent OOM).
            inpaint_src = _fit(crop_src, 2560)
            if crop_msk.size != inpaint_src.size:
                inpaint_msk = crop_msk.resize(inpaint_src.size, Image.NEAREST)
            else:
                inpaint_msk = crop_msk
            logger.info(f"Text erase inpainting crop {crop_src.size}{inpaint_src.size} (premium={is_premium})")
            inpainted = await loop.run_in_executor(_executor, _run_lama, inpaint_src, inpaint_msk)
            # LaMa may return its input padded (bottom/right) to a multiple of
            # 8. Drop that padding before any resize so it is not stretched in.
            iw, ih = inpaint_src.size
            if inpainted.size != inpaint_src.size and inpainted.width >= iw and inpainted.height >= ih:
                inpainted = inpainted.crop((0, 0, iw, ih))
            if inpainted.size != crop_src.size:
                inpainted = inpainted.resize(crop_src.size, Image.LANCZOS)
            pasted_result = src.copy()
            pasted_result.paste(inpainted, (x1, y1))
            # Feather blend for seamless compositing. GaussianBlur's radius is
            # a standard deviation (25px), which pulls alpha well below 255
            # inside an ~8px-dilated mask, so the removed object would partly
            # show through. Flooring the blend with the selection mask keeps
            # every selected pixel fully replaced; the feather only acts around it.
            feather = np.array(msk_dilated.filter(ImageFilter.GaussianBlur(radius=25)))
            blend_mask = Image.fromarray(np.maximum(feather, msk_np))
            result = Image.composite(pasted_result, src, blend_mask)

        output_bytes = _encode_png_fast(result)
        elapsed = time.perf_counter() - t0
        logger.info(f"Text erase done in {elapsed:.2f}s, {len(output_bytes)//1024}KB")
        return Response(
            content=output_bytes,
            media_type="image/png",
            headers={"X-Model": "clipseg+lama", "X-Time": f"{elapsed:.2f}"},
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Text erase failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Text erase error: {str(e)}")
@app.post("/inpaint", tags=["Inpaint"])
async def inpaint(
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
    premium: str = Form("false"),
):
    """
    Generative object removal with LaMa.

    `mask` marks the region to remove (any non-zero pixel counts). It is
    resized to the image size if the two differ. Only a padded crop around
    the masked area is inpainted, at native resolution (capped at 2560px on
    its longest edge). That crop is then feather-blended back into the
    untouched original. Returns a lossless PNG at the original resolution.
    `premium` is currently only reported in the logs; it does not change
    resolution or encoding.
    """
    if lama_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded, retry in a moment")
    for upload, name in [(image, "image"), (mask, "mask")]:
        if upload.content_type not in ("image/png", "image/jpeg", "image/jpg", "application/octet-stream"):
            raise HTTPException(status_code=415, detail=f"Unsupported type for '{name}'")
    is_premium = premium.lower() in ("true", "1", "yes")
    t0 = time.perf_counter()
    try:
        img_data, mask_data = await asyncio.gather(image.read(), mask.read())
        src = Image.open(io.BytesIO(img_data)).convert("RGB")
        msk = Image.open(io.BytesIO(mask_data)).convert("L")
        orig_size = src.size
        if msk.size != orig_size:
            # The crop box below is computed in mask coordinates and applied
            # to the image, so both must share one pixel grid.
            msk = msk.resize(orig_size, Image.NEAREST)
        # Dilate the mask so LaMa overwrites the exact boundary and a bit of background.
        msk_np = np.array(msk)
        kernel = np.ones((9, 9), np.uint8)
        msk_dilated_np = cv2.dilate(msk_np, kernel, iterations=2)
        # High-resolution cropped inpainting.
        coords = cv2.findNonZero(msk_dilated_np)
        if coords is None:
            # Empty mask: the composite would reproduce src exactly.
            result = src
        else:
            x, y, w, h = cv2.boundingRect(coords)
            # Add generous padding (50% of dimension or 128px, whichever is larger).
            pad_x = max(128, int(w * 0.5))
            pad_y = max(128, int(h * 0.5))
            x1 = max(0, x - pad_x)
            y1 = max(0, y - pad_y)
            x2 = min(orig_size[0], x + w + pad_x)
            y2 = min(orig_size[1], y + h + pad_y)
            msk_dilated = Image.fromarray(msk_dilated_np)
            crop_src = src.crop((x1, y1, x2, y2))
            crop_msk = msk_dilated.crop((x1, y1, x2, y2))
            # Process the crop at native resolution (capped at 2560 to prevent OOM).
            inpaint_src = _fit(crop_src, 2560)
            if crop_msk.size != inpaint_src.size:
                inpaint_msk = crop_msk.resize(inpaint_src.size, Image.NEAREST)
            else:
                inpaint_msk = crop_msk
            logger.info(f"Inpainting Crop {crop_src.size}{inpaint_src.size} (premium={is_premium})")
            loop = asyncio.get_running_loop()
            inpainted = await loop.run_in_executor(_executor, _run_lama, inpaint_src, inpaint_msk)
            # LaMa may return its input padded (bottom/right) to a multiple of
            # 8. Drop that padding before any resize so it is not stretched in.
            iw, ih = inpaint_src.size
            if inpainted.size != inpaint_src.size and inpainted.width >= iw and inpainted.height >= ih:
                inpainted = inpainted.crop((0, 0, iw, ih))
            # Restore the inpainted crop to its original crop size.
            if inpainted.size != crop_src.size:
                inpainted = inpainted.resize(crop_src.size, Image.LANCZOS)
            # Paste the inpainted crop back into the full-size image.
            pasted_result = src.copy()
            pasted_result.paste(inpainted, (x1, y1))
            # Feather blend onto the untouched original. GaussianBlur's radius
            # is a standard deviation (25px), which pulls alpha well below 255
            # inside an ~8px-dilated mask; small objects would stay largely
            # visible. Flooring the blend with the user's mask keeps every
            # masked pixel fully replaced; the feather only acts around it.
            feather = np.array(msk_dilated.filter(ImageFilter.GaussianBlur(radius=25)))
            blend_mask = Image.fromarray(np.maximum(feather, msk_np))
            result = Image.composite(pasted_result, src, blend_mask)
        # Lossless output.
        output_bytes = _encode_png_fast(result)
        elapsed = time.perf_counter() - t0
        logger.info(f"Inpainting done in {elapsed:.2f}s, {len(output_bytes)//1024}KB")
        return Response(
            content=output_bytes,
            media_type="image/png",
            headers={"X-Model": "lama", "X-Time": f"{elapsed:.2f}"},
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Inpainting failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Inpainting error: {str(e)}")
@app.post("/upscale", tags=["Upscale"])
async def upscale_image(
    image: UploadFile = File(...),
    scale: str = Form("2"),
    premium: str = Form("false"),
):
    """
    AI Super-Resolution using Real-ESRGAN — PRO feature only.
    - scale=2 → 2× super-resolution (High PRO tier)
    - scale=4 → 4× super-resolution (Max PRO tier)
    Any other `scale` value (surrounding whitespace is ignored) falls back to 2×.
    Real-ESRGAN reconstructs texture detail (hair, skin, fabric, etc.)
    with a GAN trained on real-world photos.
    Returns a lossless PNG; output dimensions are echoed in X-Output-* headers.
    """
    is_premium = premium.lower() in ("true", "1", "yes")
    if not is_premium:
        # NOTE(review): `premium` is a client-supplied form field, so this gate
        # is only as trustworthy as whatever authenticates requests upstream.
        raise HTTPException(
            status_code=403,
            detail="AI upscaling is a PRO feature. Upgrade to access it."
        )
    if image.content_type not in ("image/png", "image/jpeg", "image/jpg", "application/octet-stream"):
        raise HTTPException(status_code=415, detail="Unsupported image format")
    scale = scale.strip()
    scale_int = int(scale) if scale in ("2", "4") else 2
    t0 = time.perf_counter()
    try:
        img_data = await image.read()
        src = Image.open(io.BytesIO(img_data)).convert("RGB")
        logger.info(f"Upscaling {src.size} × x{scale_int} with Real-ESRGAN")
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            _executor, _run_superres, src, scale_int
        )
        # Lossless PNG so the upscaled image is not re-compressed.
        output_bytes = _encode_png_fast(result)
        elapsed = time.perf_counter() - t0
        logger.info(
            f"Upscaling done in {elapsed:.2f}s | "
            f"{src.size}{result.size} | {len(output_bytes)//1024}KB"
        )
        return Response(
            content=output_bytes,
            media_type="image/png",
            headers={
                "X-Model": "realesrgan",
                "X-Scale": str(scale_int),
                "X-Time": f"{elapsed:.2f}",
                "X-Output-Width": str(result.width),
                "X-Output-Height": str(result.height),
            },
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Upscaling failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Upscaling error: {str(e)}")
@app.post("/remove_bg", tags=["AI Edit"])
async def remove_background(
    image: UploadFile = File(...),
):
    """
    AI Background Removal using rembg (u2net ONNX model).
    Accepts PNG or JPEG uploads (or application/octet-stream) and returns a
    transparent PNG with the background removed. CPU-efficient — no GPU required.
    """
    if image.content_type not in (
        "image/png", "image/jpeg", "image/jpg", "application/octet-stream"
    ):
        raise HTTPException(status_code=415, detail="Unsupported image format")
    t0 = time.perf_counter()
    try:
        img_data = await image.read()
        src = Image.open(io.BytesIO(img_data)).convert("RGBA")
        logger.info(f"Background removal: input {src.size}")
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(_executor, _run_remove_bg, src)
        # Always output a transparent PNG; JPEG has no alpha channel.
        output_buf = io.BytesIO()
        result.save(output_buf, format="PNG", optimize=False)
        output_bytes = output_buf.getvalue()
        elapsed = time.perf_counter() - t0
        logger.info(
            f"Background removal done in {elapsed:.2f}s | "
            f"{src.size} | {len(output_bytes) // 1024}KB"
        )
        return Response(
            content=output_bytes,
            media_type="image/png",
            headers={
                "X-Model": "rembg-u2net",
                "X-Time": f"{elapsed:.2f}",
                "X-Output-Width": str(result.width),
                "X-Output-Height": str(result.height),
            },
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Background removal failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Background removal error: {str(e)}")
if __name__ == "__main__":
    # Local entry point: serve this app with uvicorn on $PORT (default 7860).
    import uvicorn

    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 7860)),
        log_level="info",
    )