import os import torch import numpy as np from fastapi import FastAPI, UploadFile, File, Form, HTTPException from fastapi.responses import StreamingResponse, HTMLResponse from PIL import Image from io import BytesIO import requests from transformers import AutoModelForImageSegmentation import uvicorn # --------------------------------------------------------- # CPU optimization (important for HF Spaces) # --------------------------------------------------------- os.environ["OMP_NUM_THREADS"] = "1" os.environ["MKL_NUM_THREADS"] = "1" torch.set_num_threads(1) # --------------------------------------------------------- # Config (speed focused) # --------------------------------------------------------- TARGET_SIZE = (320, 320) # 🔥 faster inference MAX_FILE_SIZE = 5 * 1024 * 1024 # 5MB MAX_COMPRESS_DIM = 1400 # aggressive resize # --------------------------------------------------------- # Load model # --------------------------------------------------------- MODEL_DIR = "models/BiRefNet" os.makedirs(MODEL_DIR, exist_ok=True) device = "cuda" if torch.cuda.is_available() else "cpu" dtype = torch.float16 if torch.cuda.is_available() else torch.float32 print("Loading model...") model = AutoModelForImageSegmentation.from_pretrained( "ZhengPeng7/BiRefNet", cache_dir=MODEL_DIR, trust_remote_code=True ) model.to(device, dtype=dtype).eval() print("Model ready") # --------------------------------------------------------- # Image helpers # --------------------------------------------------------- def load_image_from_url(url: str): r = requests.get(url, timeout=10) r.raise_for_status() return Image.open(BytesIO(r.content)).convert("RGB") # 🔥 FAST compression (key part) def compress_if_needed(img: Image.Image, raw_bytes: bytes): if len(raw_bytes) <= MAX_FILE_SIZE: return img print("[INFO] Compressing image >5MB") img = img.convert("RGB") # Resize aggressively w, h = img.size scale = min(1.0, MAX_COMPRESS_DIM / max(w, h)) img = img.resize((int(w * scale), int(h * scale)), Image.BILINEAR) # Reduce quality quickly (no loop → faster) buffer = BytesIO() img.save(buffer, format="JPEG", quality=70, optimize=True) buffer.seek(0) return Image.open(buffer).convert("RGB") def transform(img): img = img.resize(TARGET_SIZE, Image.BILINEAR) arr = np.asarray(img, dtype=np.float32) / 255.0 mean = np.array([0.485, 0.456, 0.406]) std = np.array([0.229, 0.224, 0.225]) arr = (arr - mean) / std arr = np.transpose(arr, (2, 0, 1)) return torch.from_numpy(arr).unsqueeze(0).to(device=device, dtype=dtype) # 🔥 FAST inference def remove_background(img: Image.Image): orig_size = img.size tensor = transform(img) with torch.inference_mode(): pred = model(tensor) pred = pred[-1] if isinstance(pred, (list, tuple)) else pred pred = pred.sigmoid()[0, 0].cpu() mask = Image.fromarray((pred.mul(255).byte().numpy())) mask = mask.resize(orig_size, Image.BILINEAR) img = img.convert("RGBA") img.putalpha(mask) return img # --------------------------------------------------------- # FastAPI # --------------------------------------------------------- app = FastAPI() @app.post("/remove-background") async def remove_bg(file: UploadFile = File(None), image_url: str = Form(None)): try: if file: raw = await file.read() img = Image.open(BytesIO(raw)).convert("RGB") # ✅ Step 1: compress if >5MB img = compress_if_needed(img, raw) elif image_url: img = load_image_from_url(image_url) else: raise HTTPException(400, "Provide file or URL") # ✅ Step 2: remove background result = remove_background(img) buf = BytesIO() result.save(buf, format="PNG") buf.seek(0) return StreamingResponse(buf, media_type="image/png") except Exception as e: raise HTTPException(500, str(e)) # --------------------------------------------------------- # Simple UI # --------------------------------------------------------- @app.get("/", response_class=HTMLResponse) async def home(): return """