# Docker-BiRefNet / app.py
# Last change by um41r: "Update app.py" (commit 0df353e, verified)
import io
import os
import secrets

import numpy as np
import torch
from PIL import Image, UnidentifiedImageError
from safetensors.torch import load_file
from fastapi import FastAPI, File, UploadFile, HTTPException, Request, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.security import APIKeyHeader
import uvicorn

from birefnet import BiRefNet
from BiRefNet_config import BiRefNetConfig
# =========================
# HUGGING FACE SECRET
# =========================
# Shared secret that clients must present (checked against the X-API-Key
# header). Read from the environment; refuse to start without it so the
# service is never exposed unauthenticated.
API_KEY = os.environ.get("BIREFNET_API_KEY")
if API_KEY is None or API_KEY == "":
    raise RuntimeError("❌ BIREFNET_API_KEY not found in HF Space Secrets")

# All tensors and the model live on the CPU.
DEVICE = "cpu"
# =========================
# LOAD MODEL
# =========================
config = BiRefNetConfig()
model = BiRefNet(config)
state_dict = load_file("model.safetensors")
# strict=False tolerates checkpoint/architecture mismatches, but any parameter
# that fails to match silently keeps its random initial value. Surface the
# mismatch report so a wrong or renamed checkpoint is visible in the logs
# instead of only showing up as bad masks.
load_report = model.load_state_dict(state_dict, strict=False)
if load_report.missing_keys:
    print(
        f"⚠️ {len(load_report.missing_keys)} model parameters not found in "
        f"checkpoint, e.g. {load_report.missing_keys[:5]}"
    )
if load_report.unexpected_keys:
    print(
        f"⚠️ {len(load_report.unexpected_keys)} checkpoint entries not used "
        f"by the model, e.g. {load_report.unexpected_keys[:5]}"
    )
model.to(DEVICE)
model.eval()  # inference mode for layers such as dropout / batch-norm
print("✅ BiRefNet Lite loaded")
# =========================
# API KEY AUTH
# =========================
# auto_error=False: a missing header reaches verify_api_key as None instead of
# FastAPI rejecting it automatically, so the 401 below covers both cases.
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)


def verify_api_key(api_key: str = Depends(api_key_header)):
    """FastAPI dependency that rejects requests without the correct API key.

    Args:
        api_key: value of the ``X-API-Key`` header, or None when absent.

    Raises:
        HTTPException: 401 when the header is missing or does not match
            ``API_KEY``.
    """
    # compare_digest takes time independent of where the inputs differ, so the
    # key cannot be guessed byte-by-byte from response timing. It only accepts
    # bytes or ASCII-only str, so both sides are encoded explicitly (header
    # values may contain non-ASCII characters).
    if api_key is None or not secrets.compare_digest(
        api_key.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        raise HTTPException(
            status_code=401,
            detail="Invalid or missing API key"
        )
# =========================
# IMAGE PIPELINE
# =========================
def preprocess(img: Image.Image, size: tuple = (1024, 1024)):
    """Convert a PIL image into a normalized model input tensor.

    The image is converted to RGB and resized to ``size`` (aspect ratio is
    not preserved). Pixels are scaled to [0, 1] and normalized per channel
    with ImageNet mean/std. The result is a float32 tensor of shape
    (1, 3, H, W).

    Args:
        img: source image in any PIL mode.
        size: (width, height) passed to ``Image.resize``; defaults to the
            original hard-coded 1024x1024.

    Returns:
        A batch-of-one float32 ``torch.Tensor``.
    """
    # ImageNet channel statistics, matching the Normalize step in the upstream
    # BiRefNet inference examples. NOTE(review): confirm against the training
    # config of the checkpoint actually deployed.
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    rgb = img.convert("RGB").resize(size)
    arr = np.asarray(rgb, dtype=np.float32) / 255.0  # (H, W, 3) in [0, 1]
    arr = (arr - mean) / std                          # broadcasts over channels
    arr = np.ascontiguousarray(arr.transpose(2, 0, 1))  # HWC -> CHW
    return torch.from_numpy(arr).unsqueeze(0)
@torch.no_grad()
def remove_bg(image: Image.Image) -> Image.Image:
    """Return an RGBA copy of ``image`` whose alpha is the predicted mask.

    The model runs on the preprocessed image, without gradient tracking. Its
    output is passed through a sigmoid, quantized to an 8-bit grayscale mask,
    and resized back to the input's dimensions. The mask then becomes the
    alpha channel of an RGBA copy of the input.
    """
    batch = preprocess(image).to(DEVICE)
    # First element of the model output; presumably the final prediction
    # logits in eval mode — confirm against the BiRefNet implementation.
    logits = model(batch)[0]
    probs = torch.sigmoid(logits).squeeze().cpu().numpy()
    alpha = Image.fromarray((probs * 255).astype(np.uint8)).resize(image.size)
    cutout = image.convert("RGBA")
    cutout.putalpha(alpha)
    return cutout
# =========================
# FASTAPI APP
# =========================
app = FastAPI(title="BiRefNet Background Remover API")

# CORS: any origin may call the API from a browser. Authentication uses the
# X-API-Key header, not cookies, so credentialed CORS is unnecessary.
# Combining a wildcard origin with allow_credentials=True is a known
# misconfiguration: browsers refuse "*" on credentialed requests, and
# Starlette's CORSMiddleware responds by echoing the caller's Origin, which
# would let any site make credentialed calls.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # TODO: restrict to known front-end origins
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# =========================
# ROUTES
# =========================
@app.get("/")
async def root():
    """Unauthenticated status endpoint describing the service."""
    info = dict(status="ok", secured=True, endpoint="/remove-bg")
    return info
@app.post("/remove-bg")
async def remove_background(
    request: Request,
    file: UploadFile = File(None),
    _: str = Depends(verify_api_key),
):
    """Remove the background from a JPEG or PNG image.

    The image is read from the multipart ``file`` field when one is sent,
    otherwise from the raw request body. The response is a PNG whose alpha
    channel is the predicted foreground mask.

    Raises:
        HTTPException: 400 when no data or an empty file is received, or the
            bytes are not a JPEG/PNG image; 500 for any other failure while
            processing. (401 comes from the ``verify_api_key`` dependency.)
    """
    try:
        if file is None:
            body = await request.body()
            if not body:
                raise HTTPException(400, "No image data received")
            image = Image.open(io.BytesIO(body))
        else:
            contents = await file.read()
            if not contents:
                raise HTTPException(400, "Empty file")
            image = Image.open(io.BytesIO(contents))
        # Pillow's format name for JPEG files is "JPEG"; "JPG" is extra slack.
        if image.format not in ("JPEG", "JPG", "PNG"):
            raise HTTPException(400, "Invalid image format")
        # NOTE(review): remove_bg is synchronous, CPU-bound inference inside an
        # async handler, so it blocks the event loop (including "/") for the
        # duration of each request.
        result = remove_bg(image)
        img_bytes = io.BytesIO()
        result.save(img_bytes, format="PNG")
        img_bytes.seek(0)
    except HTTPException:
        # The deliberate 4xx responses above must reach the client unchanged;
        # without this clause the generic handler below turns them into 500s.
        raise
    except UnidentifiedImageError:
        # Image.open could not recognize the bytes as an image at all: this
        # is a client error, not a server failure.
        raise HTTPException(400, "Invalid image format") from None
    except Exception as e:
        print("❌ Error:", e)
        raise HTTPException(500, "Processing failed") from e
    return StreamingResponse(
        img_bytes,
        media_type="image/png",
        headers={
            "Content-Disposition": "inline; filename=removed-bg.png"
        }
    )
# =========================
# HF DOCKER ENTRYPOINT
# =========================
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from $PORT, defaulting to 7860.
    listen_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)