Spaces:
Running
Running
File size: 3,769 Bytes
0df353e 353cbab 839de5d 353cbab 0df353e 353cbab 0df353e 353cbab 0df353e 353cbab 0df353e 353cbab 0df353e 353cbab 0df353e 353cbab 0b06727 0df353e 353cbab 0b06727 353cbab 0df353e 353cbab 0df353e 353cbab 0df353e 353cbab b423626 0df353e b423626 0b06727 0df353e 353cbab 0df353e b423626 0df353e b423626 353cbab 0df353e 353cbab 0df353e 353cbab 0df353e 353cbab 0df353e 353cbab 0df353e 353cbab 0df353e 353cbab 0df353e b423626 0df353e b423626 0df353e 353cbab 0df353e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import hmac
import io
import os

import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, File, UploadFile, HTTPException, Request, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.security import APIKeyHeader
from PIL import Image
from safetensors.torch import load_file

from birefnet import BiRefNet
from BiRefNet_config import BiRefNetConfig
# =========================
# HUGGING FACE SECRET
# =========================
# Read the shared secret once at import time and refuse to start without it,
# so the service can never run unauthenticated by accident.
API_KEY = os.environ.get("BIREFNET_API_KEY")
if API_KEY is None or API_KEY == "":
    raise RuntimeError("❌ BIREFNET_API_KEY not found in HF Space Secrets")

# Inference device passed to ``.to()`` for both the model and input tensors.
DEVICE = "cpu"
# =========================
# LOAD MODEL
# =========================
# Build the network from its default config and load weights from the local
# safetensors checkpoint.
config = BiRefNetConfig()
model = BiRefNet(config)
state_dict = load_file("model.safetensors")

# strict=False tolerates key mismatches between checkpoint and architecture.
# Report them instead of discarding the result, so a partially initialised
# model is not mistaken for a clean load.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(
        f"⚠️ {len(load_result.missing_keys)} model parameter(s) missing from "
        f"checkpoint, e.g. {load_result.missing_keys[:5]}"
    )
if load_result.unexpected_keys:
    print(
        f"⚠️ {len(load_result.unexpected_keys)} unexpected checkpoint key(s) "
        f"ignored, e.g. {load_result.unexpected_keys[:5]}"
    )

model.to(DEVICE)
model.eval()  # inference mode for layers such as dropout / batch norm
print("✅ BiRefNet Lite loaded")
# =========================
# API KEY AUTH
# =========================
# auto_error=False lets verify_api_key produce the 401 itself.
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)


def verify_api_key(api_key: str = Depends(api_key_header)):
    """FastAPI dependency that rejects requests without the expected API key.

    Args:
        api_key: Value extracted from the ``X-API-Key`` request header; it is
            treated as missing when falsy (``None`` or empty).

    Raises:
        HTTPException: 401 when the key is missing or does not match
            ``API_KEY``.
    """
    # Compare in constant time: a plain ``!=`` may return at the first
    # differing character, which leaks timing information about the secret.
    # Encoding to bytes avoids compare_digest's TypeError on non-ASCII str.
    supplied = (api_key or "").encode("utf-8")
    expected = API_KEY.encode("utf-8")
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(
            status_code=401,
            detail="Invalid or missing API key",
        )
# =========================
# IMAGE PIPELINE
# =========================
def preprocess(img: Image.Image):
    """Turn a PIL image into a model input tensor.

    The image is converted to RGB, resized to 1024x1024, scaled to the
    [0, 1] range as float32, and rearranged from HWC to a 1x3x1024x1024
    tensor.
    """
    # NOTE(review): no mean/std normalisation is applied here — confirm the
    # checkpoint expects raw [0, 1] inputs.
    rgb = img.convert("RGB")
    resized = rgb.resize((1024, 1024))
    pixels = np.array(resized).astype(np.float32)
    scaled = pixels / 255.0
    channels_first = np.transpose(scaled, (2, 0, 1))
    tensor = torch.from_numpy(channels_first)
    return torch.unsqueeze(tensor, 0)
def remove_bg(image: Image.Image) -> Image.Image:
    """Return an RGBA copy of ``image`` whose alpha channel is the model mask.

    The first element of the model output is passed through a sigmoid,
    scaled to 0-255, resized back to the input size, and installed as the
    alpha channel.
    """
    # Gradients are never needed for inference.
    with torch.no_grad():
        batch = preprocess(image).to(DEVICE)
        logits = model(batch)[0]
        probabilities = logits.sigmoid()
        mask_array = probabilities.squeeze().cpu().numpy()

    alpha_values = (mask_array * 255).astype(np.uint8)
    alpha = Image.fromarray(alpha_values).resize(image.size)

    cutout = image.convert("RGBA")
    cutout.putalpha(alpha)
    return cutout
# =========================
# FASTAPI APP
# =========================
app = FastAPI(title="BiRefNet Background Remover API")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Restrict later
    # Wildcard origins must not be combined with credentials: doing so lets
    # any site issue credentialed cross-origin requests. Authentication here
    # uses the X-API-Key header rather than cookies, so credentials are not
    # needed. Re-enable only together with an explicit origin list.
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# =========================
# ROUTES
# =========================
@app.get("/")
async def root():
    """Report service status and the path of the background-removal route."""
    info = {"status": "ok", "secured": True, "endpoint": "/remove-bg"}
    return info
@app.post("/remove-bg")
async def remove_background(
    request: Request,
    file: UploadFile = File(None),
    _: str = Depends(verify_api_key),
):
    """Remove the background from an uploaded image and return it as a PNG.

    The image is read from the multipart ``file`` field when present,
    otherwise from the raw request body.

    Raises:
        HTTPException: 400 when no data is sent, the payload cannot be opened
            as an image, or its format is not JPEG/PNG; 500 when background
            removal or PNG encoding fails.
    """
    # Input validation stays outside the catch-all below. HTTPException is an
    # Exception subclass, so raising it inside that try block would turn every
    # 400 into a 500.
    if file is None:
        payload = await request.body()
        if not payload:
            raise HTTPException(400, "No image data received")
    else:
        payload = await file.read()
        if not payload:
            raise HTTPException(400, "Empty file")

    try:
        image = Image.open(io.BytesIO(payload))
    except (OSError, Image.DecompressionBombError) as exc:
        # Unrecognised or oversized uploads are a client error.
        raise HTTPException(400, "Invalid image format") from exc

    if image.format not in ("JPEG", "JPG", "PNG"):
        raise HTTPException(400, "Invalid image format")

    # NOTE(review): remove_bg runs synchronously inside this async handler,
    # so the event loop is blocked for the duration of inference.
    try:
        result = remove_bg(image)
        img_bytes = io.BytesIO()
        result.save(img_bytes, format="PNG")
    except Exception as e:
        print("❌ Error:", e)
        raise HTTPException(500, "Processing failed") from e

    img_bytes.seek(0)
    return StreamingResponse(
        img_bytes,
        media_type="image/png",
        headers={
            "Content-Disposition": "inline; filename=removed-bg.png"
        },
    )
# =========================
# HF DOCKER ENTRYPOINT
# =========================
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from the PORT environment
    # variable and falls back to 7860 when it is unset.
    listen_port = int(os.getenv("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)
|