"""Background-removal HTTP service built on FastAPI and BiRefNet.

Endpoints:
    GET  /                   -> minimal test page
    GET  /remove-background  -> 405 hint that only POST is supported
    POST /remove-background  -> accepts an uploaded file or an image URL and
                                returns a PNG whose alpha channel is the
                                predicted foreground mask
"""
import os

# ---------------------------------------------------------
# Performance settings for HF CPU
# ---------------------------------------------------------
# Set before importing numpy/torch: native thread pools typically read these
# variables when the library is first loaded, so assigning them afterwards
# may have no effect.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

import threading
from io import BytesIO
from typing import Optional

import numpy as np
import requests
import torch
import uvicorn
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse, HTMLResponse, RedirectResponse, JSONResponse
from PIL import Image
from transformers import AutoModelForImageSegmentation

# ---------------------------------------------------------
# Optional HEIC/HEIF
# ---------------------------------------------------------
try:
    import pillow_heif

    pillow_heif.register_heif_opener()
except ImportError:
    # HEIC/HEIF support is optional; other formats keep working without it.
    pass

torch.set_num_threads(1)

# ---------------------------------------------------------
# Constants
# ---------------------------------------------------------
TARGET_SIZE = (512, 512)  # Faster inference
MAX_SIDE = 3000  # Auto-downscale for huge uploads

# Per-channel normalisation constants, kept in float32 so the normalised
# array is not silently promoted to float64.
_NORM_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
_NORM_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

# ---------------------------------------------------------
# Load model
# ---------------------------------------------------------
MODEL_DIR = "models/BiRefNet"
os.makedirs(MODEL_DIR, exist_ok=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

print("Loading BiRefNet…")
# NOTE(review): trust_remote_code=True together with revision="main" executes
# whatever code the upstream repository currently publishes; consider pinning
# `revision` to a specific commit hash.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    cache_dir=MODEL_DIR,
    trust_remote_code=True,
    revision="main",
)
birefnet.to(device, dtype=dtype).eval()
print("Model ready.")

# Serialises forward passes; request work runs in worker threads.
lock = threading.Lock()


# ---------------------------------------------------------
# Helpers
# ---------------------------------------------------------
def load_image_from_url(url: str) -> Image.Image:
    """Download *url* and decode it as an RGB image.

    Raises:
        HTTPException: 400 if the download or the decoding fails.
    """
    # NOTE(security): this fetches a caller-supplied URL from the server
    # (SSRF risk) and buffers the whole body without a size cap. Restrict
    # target hosts / response size before exposing this publicly.
    try:
        r = requests.get(url, timeout=10)
        r.raise_for_status()
        return Image.open(BytesIO(r.content)).convert("RGB")
    except Exception as exc:
        # Deliberately broad: any failure here is reported as a bad URL.
        raise HTTPException(status_code=400, detail="Invalid image URL") from exc


def _decode_image(raw: bytes) -> Image.Image:
    """Decode uploaded bytes as an RGB image, or raise a 400 HTTPException."""
    try:
        return Image.open(BytesIO(raw)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        raise HTTPException(status_code=400, detail="Invalid image file") from exc


def auto_downscale(img: Image.Image) -> Image.Image:
    """Shrink *img* so its longest side is at most MAX_SIDE.

    Images already within the limit are returned unchanged. The aspect ratio
    is preserved, and each side is kept at a minimum of 1 pixel so extreme
    aspect ratios cannot produce an invalid zero-sized resize.
    """
    w, h = img.size
    longest = max(w, h)
    if longest <= MAX_SIDE:
        return img

    scale = MAX_SIDE / longest
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))

    print(f"[INFO] Downscaling {w}×{h} → {new_w}×{new_h}")
    return img.resize((new_w, new_h), Image.LANCZOS)


def transform(img: Image.Image) -> torch.Tensor:
    """Convert an RGB image into a normalised 1×3×H×W tensor for the model.

    The image is resized to TARGET_SIZE, scaled to [0, 1], normalised per
    channel, transposed from HWC to CHW and moved to the model's
    device/dtype.
    """
    img = img.resize(TARGET_SIZE)
    arr = np.asarray(img, dtype=np.float32) / 255.0
    arr = (arr - _NORM_MEAN) / _NORM_STD
    arr = np.transpose(arr, (2, 0, 1))
    return torch.from_numpy(arr).unsqueeze(0).to(device=device, dtype=dtype)


def run_inference(img: Image.Image) -> Image.Image:
    """Return an RGBA copy of *img* whose alpha channel is the predicted mask.

    The mask is predicted at TARGET_SIZE and resized back to the input size.
    """
    orig_size = img.size
    tensor = transform(img)

    # Only one forward pass at a time.
    with lock:
        with torch.no_grad():
            pred = birefnet(tensor)[-1].sigmoid().cpu()[0, 0]

    mask = Image.fromarray((pred.numpy() * 255).astype(np.uint8)).resize(orig_size)

    # convert() returns a new image, so the caller's image is not modified.
    img = img.convert("RGBA")
    img.putalpha(mask)
    return img


def _render_png(img: Image.Image) -> BytesIO:
    """Downscale, run background removal and encode the result as PNG."""
    result = run_inference(auto_downscale(img))
    buf = BytesIO()
    result.save(buf, format="PNG")
    buf.seek(0)
    return buf


# ---------------------------------------------------------
# FastAPI app
# ---------------------------------------------------------
app = FastAPI(title="Background Remover API")


# ---------------------------------------------------------
# Redirect GET → POST logic
# ---------------------------------------------------------
@app.get("/remove-background")
async def redirect_to_post():
    """Answer GET with 405 and a hint to use POST instead."""
    return JSONResponse(
        {"detail": "This endpoint only supports POST. Use POST /remove-background"},
        status_code=405,
        # A 405 response is expected to advertise the permitted methods.
        headers={"Allow": "POST"},
    )


# ---------------------------------------------------------
# Main POST endpoint
# ---------------------------------------------------------
@app.post("/remove-background")
async def remove_bg(
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None),
):
    """Remove the background from an uploaded file or a remote image.

    A non-empty upload takes precedence over `image_url`; an empty file part
    falls through to `image_url`.

    Returns:
        A streamed PNG with transparency.

    Raises:
        HTTPException: 400 when no usable input is supplied or the image
            cannot be loaded; 500 for unexpected processing errors.
    """
    try:
        raw = await file.read() if file is not None else b""

        # Decoding, downloading and inference are blocking, so they run in
        # the thread pool instead of stalling the event loop.
        if raw:
            img = await run_in_threadpool(_decode_image, raw)
        elif image_url:
            img = await run_in_threadpool(load_image_from_url, image_url)
        else:
            raise HTTPException(status_code=400, detail="Upload file or image_url required")

        buf = await run_in_threadpool(_render_png, img)
        return StreamingResponse(buf, media_type="image/png")

    except HTTPException:
        # Keep deliberate 4xx responses instead of rewrapping them as 500.
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc


# ---------------------------------------------------------
# UI: Show INPUT + OUTPUT (big preview)
# ---------------------------------------------------------
@app.get("/", response_class=HTMLResponse)
async def ui():
    """Serve the test page."""
    # NOTE(review): the page body below is plain text with no HTML tags; it
    # looks like the markup was stripped at some point. Confirm against the
    # original template before relying on this page.
    return """ Background Remover – Test UI

API Test Panel (POST Only)

Input Image
Output Image

Upload Test


URL Test

"""


# ---------------------------------------------------------
# Run app
# ---------------------------------------------------------
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)