"""BiRefNet background-removal API (FastAPI, Hugging Face Space Docker entrypoint).

POST an image (multipart field ``file`` or raw request body) to ``/remove-bg``
with a valid ``X-API-Key`` header; the response is a PNG whose alpha channel is
the predicted foreground mask.
"""
import io
import logging
import os
import secrets
import threading
from typing import Optional

import numpy as np
import torch
import uvicorn
from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.security import APIKeyHeader
from PIL import Image, UnidentifiedImageError
from safetensors.torch import load_file

from birefnet import BiRefNet
from BiRefNet_config import BiRefNetConfig

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("birefnet_api")

# =========================
# HUGGING FACE SECRET
# =========================
API_KEY = os.getenv("BIREFNET_API_KEY")
if not API_KEY:
    raise RuntimeError("❌ BIREFNET_API_KEY not found in HF Space Secrets")

DEVICE = "cpu"
MODEL_PATH = "model.safetensors"
INPUT_SIZE = (1024, 1024)  # (width, height) the image is resized to before inference
# Formats as reported by PIL's ``Image.format`` (PIL reports JPEG files as "JPEG").
ALLOWED_FORMATS = frozenset({"JPEG", "PNG"})

# =========================
# LOAD MODEL
# =========================
config = BiRefNetConfig()
model = BiRefNet(config)
state_dict = load_file(MODEL_PATH)
# strict=False tolerates key mismatches; report them so a wrong or partial
# checkpoint does not silently produce a half-initialized model.
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
if missing_keys or unexpected_keys:
    logger.warning(
        "Checkpoint/model key mismatch: %d missing, %d unexpected",
        len(missing_keys),
        len(unexpected_keys),
    )
model.to(DEVICE)
model.eval()
logger.info("✅ BiRefNet Lite loaded")

# =========================
# API KEY AUTH
# =========================
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)


def verify_api_key(api_key: Optional[str] = Depends(api_key_header)) -> None:
    """Raise HTTP 401 unless the ``X-API-Key`` header matches ``API_KEY``.

    ``api_key`` is None when the header is absent (``auto_error=False``).
    The comparison is constant-time so the key cannot be probed via response
    timing; both sides are compared as UTF-8 bytes because
    ``secrets.compare_digest`` rejects non-ASCII ``str`` arguments.
    """
    if api_key is None or not secrets.compare_digest(
        api_key.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        raise HTTPException(
            status_code=401,
            detail="Invalid or missing API key",
        )


# =========================
# IMAGE PIPELINE
# =========================
# Guards the forward pass. Requests now run inference in worker threads; the
# lock keeps passes one-at-a-time (as they effectively were when inference
# blocked the event loop), bounding peak memory on the CPU host.
_INFERENCE_LOCK = threading.Lock()


def preprocess(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a (1, 3, H, W) float32 tensor with values in [0, 1].

    NOTE(review): the upstream BiRefNet inference recipe also applies ImageNet
    mean/std normalization; confirm whether this checkpoint expects it.
    """
    img = img.convert("RGB").resize(INPUT_SIZE)
    arr = np.asarray(img, dtype=np.float32) / 255.0
    # HWC -> CHW, then prepend the batch dimension.
    return torch.from_numpy(arr.transpose(2, 0, 1)).unsqueeze(0)


@torch.no_grad()
def remove_bg(image: Image.Image) -> Image.Image:
    """Return ``image`` as RGBA with the predicted foreground mask as its alpha."""
    x = preprocess(image).to(DEVICE)
    with _INFERENCE_LOCK:
        # NOTE(review): upstream BiRefNet examples use the last output
        # (model(x)[-1]) as the final mask; confirm [0] is right for the local
        # birefnet module.
        pred = model(x)[0]
    mask = torch.sigmoid(pred).squeeze().cpu().numpy()
    mask = Image.fromarray((mask * 255).astype(np.uint8)).resize(image.size)
    out = image.convert("RGBA")
    out.putalpha(mask)
    return out


def _open_image(data: bytes) -> Image.Image:
    """Decode uploaded bytes into a fully loaded PIL image.

    Raises:
        HTTPException: 400 for undecodable, corrupt, or disallowed-format data;
            413 when PIL flags the image as a decompression bomb.
    """
    try:
        image = Image.open(io.BytesIO(data))
    except UnidentifiedImageError as err:
        raise HTTPException(400, "Invalid image format") from err
    except Image.DecompressionBombError as err:
        raise HTTPException(413, "Image too large") from err
    if image.format not in ALLOWED_FORMATS:
        raise HTTPException(400, "Invalid image format")
    try:
        # Image.open only parses the header; decode now so truncated or corrupt
        # pixel data is reported as a client error rather than a 500.
        image.load()
    except OSError as err:
        raise HTTPException(400, "Invalid image data") from err
    return image


def _render_png(image: Image.Image) -> io.BytesIO:
    """Remove the background and encode the result as PNG, rewound to offset 0."""
    buf = io.BytesIO()
    remove_bg(image).save(buf, format="PNG")
    buf.seek(0)
    return buf


# =========================
# FASTAPI APP
# =========================
app = FastAPI(title="BiRefNet Background Remover API")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Restrict later
    # NOTE(review): wildcard origins combined with allow_credentials=True is
    # permissive; auth here uses the X-API-Key header rather than cookies, so
    # credentials may not be needed — revisit when restricting origins.
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# =========================
# ROUTES
# =========================


@app.get("/")
async def root():
    """Unauthenticated status probe."""
    return {
        "status": "ok",
        "secured": True,
        "endpoint": "/remove-bg",
    }


@app.post("/remove-bg")
async def remove_background(
    request: Request,
    file: UploadFile = File(None),
    _: str = Depends(verify_api_key),
):
    """Remove the background from a JPEG/PNG upload and stream back a PNG.

    The image may be sent as the multipart field ``file`` or as the raw request
    body. Missing, empty, or undecodable input yields HTTP 400; failures during
    inference or PNG encoding yield HTTP 500.
    """
    if file is None:
        data = await request.body()
        if not data:
            raise HTTPException(400, "No image data received")
    else:
        data = await file.read()
        if not data:
            raise HTTPException(400, "Empty file")

    # Validation errors raised here are client errors and must reach the caller
    # unchanged, so they stay outside the catch-all below.
    image = _open_image(data)

    try:
        # Inference is CPU-bound; run it off the event loop so other requests
        # (including the "/" probe) are served meanwhile.
        png = await run_in_threadpool(_render_png, image)
    except Exception:
        logger.exception("❌ Error: background removal failed")
        raise HTTPException(500, "Processing failed") from None

    return StreamingResponse(
        png,
        media_type="image/png",
        headers={
            "Content-Disposition": "inline; filename=removed-bg.png"
        },
    )


# =========================
# HF DOCKER ENTRYPOINT
# =========================
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 7860)),
    )