# Hugging Face Spaces status: Running
import io
import logging
import os
import secrets

import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, File, UploadFile, HTTPException, Request, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.security import APIKeyHeader
from PIL import Image
from PIL import UnidentifiedImageError
from safetensors.torch import load_file

from birefnet import BiRefNet
from BiRefNet_config import BiRefNetConfig
# =========================
# HUGGING FACE SECRET
# =========================
# Device used for model weights and input tensors; fixed to CPU.
DEVICE = "cpu"

# Shared secret that clients must present to use the API. It is read from the
# environment (configured as a Hugging Face Space secret). Startup is aborted
# when it is unset or empty, so the service never runs without authentication.
API_KEY = os.environ.get("BIREFNET_API_KEY")
if API_KEY is None or API_KEY == "":
    raise RuntimeError("β BIREFNET_API_KEY not found in HF Space Secrets")
# =========================
# LOAD MODEL
# =========================
config = BiRefNetConfig()
model = BiRefNet(config)
state_dict = load_file("model.safetensors")
# strict=False tolerates name mismatches between the checkpoint and the
# architecture. However, any parameter it skips keeps whatever value
# BiRefNet(config) initialised it with, which silently degrades the masks.
# The mismatch is therefore surfaced in the startup log instead of being hidden.
load_report = model.load_state_dict(state_dict, strict=False)
if load_report.missing_keys:
    print(
        f"Warning: {len(load_report.missing_keys)} model keys missing from "
        f"model.safetensors, e.g. {load_report.missing_keys[:5]}"
    )
if load_report.unexpected_keys:
    print(
        f"Warning: {len(load_report.unexpected_keys)} unexpected keys in "
        f"model.safetensors were ignored, e.g. {load_report.unexpected_keys[:5]}"
    )
# load_state_dict copies the values into the model's own parameters, so the
# loaded dict is no longer needed. Drop it so it does not stay resident.
del state_dict
model.to(DEVICE)
model.eval()
print("β BiRefNet Lite loaded")
# =========================
# API KEY AUTH
# =========================
# Clients authenticate by sending the shared secret in the "X-API-Key" header.
# auto_error=False makes a missing header arrive as None instead of FastAPI
# rejecting the request itself, so verify_api_key decides the response.
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)


def verify_api_key(api_key: str = Depends(api_key_header)):
    """FastAPI dependency that rejects requests lacking the correct API key.

    Args:
        api_key: Value of the ``X-API-Key`` header, or ``None`` when absent.

    Raises:
        HTTPException: 401 when the header is missing or does not match
            ``API_KEY``.
    """
    # Compare in constant time so that response timing does not reveal how much
    # of a guessed key was correct. Both sides are encoded to bytes because
    # compare_digest only accepts ASCII-only str arguments, and header values
    # are client-controlled.
    if api_key is None or not secrets.compare_digest(
        api_key.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        raise HTTPException(
            status_code=401,
            detail="Invalid or missing API key",
        )
# =========================
# IMAGE PIPELINE
# =========================
# Per-channel ImageNet statistics. BiRefNet's reference inference pipeline
# normalises its inputs with these values (torchvision Normalize). Feeding raw
# [0, 1] pixels instead shifts the input distribution away from what the
# trained weights expect.
_IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
_IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def preprocess(img: Image.Image, size: int = 1024) -> torch.Tensor:
    """Convert a PIL image into the normalised input batch for the model.

    Args:
        img: Source image in any PIL mode; it is converted to RGB.
        size: Side length of the square the image is resized to.

    Returns:
        A float32 tensor of shape (1, 3, size, size).
    """
    rgb = img.convert("RGB").resize((size, size))
    arr = np.asarray(rgb, dtype=np.float32) / 255.0  # (size, size, 3), in [0, 1]
    arr = (arr - _IMAGENET_MEAN) / _IMAGENET_STD  # broadcasts over the last (channel) axis
    arr = arr.transpose(2, 0, 1)  # HWC -> CHW
    return torch.from_numpy(arr).unsqueeze(0)
def remove_bg(image: Image.Image) -> Image.Image:
    """Return an RGBA copy of *image* whose alpha channel is the predicted mask.

    Args:
        image: Input picture in any PIL mode.

    Returns:
        ``image.convert("RGBA")``, the same size as *image*, with its alpha
        replaced by the model's sigmoid output scaled to 0-255.
    """
    x = preprocess(image).to(DEVICE)
    # model.eval() alone does not disable autograd. Without inference_mode,
    # every intermediate activation of the forward pass is kept for a backward
    # pass that never happens, inflating memory use and latency per request.
    # All tensor work stays inside the context; only a NumPy array leaves it.
    with torch.inference_mode():
        # NOTE(review): [0] takes the first element of the model output.
        # Upstream BiRefNet's README uses [-1] (the final prediction); in eval
        # mode they coincide there. Confirm against the local birefnet module.
        pred = torch.sigmoid(model(x)[0])
        mask = pred.squeeze().cpu().numpy()
    mask = (mask * 255).astype(np.uint8)
    # The prediction is at the preprocessing resolution. Scale it back to the
    # source size so it can serve as this image's alpha channel.
    mask = Image.fromarray(mask).resize(image.size)
    out = image.convert("RGBA")
    out.putalpha(mask)
    return out
# =========================
# FASTAPI APP
# =========================
app = FastAPI(title="BiRefNet Background Remover API")

# Allowed browser origins come from BIREFNET_CORS_ORIGINS as a comma-separated
# list (e.g. "https://a.example,https://b.example"). This lets them be
# restricted without editing code. When the variable is unset, every origin
# ("*") is allowed, matching the previous hard-coded default.
# NOTE(review): with allow_credentials=True, Starlette's CORSMiddleware
# reflects the caller's Origin instead of "*" for credentialed requests, so "*"
# effectively admits any site. Auth uses the X-API-Key header rather than
# cookies, so credentials may be unnecessary; confirm with clients first.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("BIREFNET_CORS_ORIGINS", "*").split(",")
    if origin.strip()
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# =========================
# ROUTES
# =========================
@app.get("/")
async def root():
    """Report that the service is up and name the background-removal endpoint."""
    # Without the route decorator this handler was never registered on `app`.
    return {
        "status": "ok",
        "secured": True,
        "endpoint": "/remove-bg",
    }
# Image formats accepted by /remove-bg, compared against PIL's `Image.format`.
# "JPG" is kept from the original list; Pillow reports JPEG files as "JPEG".
_ALLOWED_FORMATS = ("JPEG", "JPG", "PNG")


async def _read_image_bytes(request: Request, file) -> bytes:
    """Return the uploaded bytes from the multipart *file*, or the raw body if absent.

    Raises:
        HTTPException: 400 when the chosen source is empty.
    """
    if file is None:
        body = await request.body()
        if not body:
            raise HTTPException(400, "No image data received")
        return body
    contents = await file.read()
    if not contents:
        raise HTTPException(400, "Empty file")
    return contents


def _open_image(data: bytes) -> Image.Image:
    """Identify *data* as an image and ensure its format is accepted.

    Raises:
        HTTPException: 400 when PIL cannot recognise the data as an image, or
            when the detected format is not in ``_ALLOWED_FORMATS``.
    """
    try:
        image = Image.open(io.BytesIO(data))
    except UnidentifiedImageError as err:
        # A non-image upload is a client error, not a server failure.
        raise HTTPException(400, "Invalid image format") from err
    if image.format not in _ALLOWED_FORMATS:
        raise HTTPException(400, "Invalid image format")
    return image


def _encode_png(image: Image.Image) -> io.BytesIO:
    """Serialise *image* as PNG into an in-memory buffer rewound to the start."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    buffer.seek(0)
    return buffer


@app.post("/remove-bg")
async def remove_background(
    request: Request,
    file: UploadFile = File(None),
    _: str = Depends(verify_api_key),
):
    """Remove the background from an uploaded JPEG/PNG and return it as a PNG.

    The image may be sent either as a multipart ``file`` field or as the raw
    request body. Access is gated by the ``verify_api_key`` dependency.

    Returns:
        A ``StreamingResponse`` carrying the RGBA PNG (``image/png``).

    Raises:
        HTTPException: 400 for missing or empty data or an unrecognised or
            unsupported image; 500 if processing fails for any other reason.
    """
    try:
        data = await _read_image_bytes(request, file)
        image = _open_image(data)
        # NOTE(review): remove_bg is synchronous CPU inference inside an async
        # handler, so it blocks the event loop while it runs. Consider
        # fastapi.concurrency.run_in_threadpool if concurrent requests matter.
        result = remove_bg(image)
        png = _encode_png(result)
    except HTTPException:
        # Deliberate client errors raised above must reach the caller as-is,
        # not be rewrapped as 500 by the catch-all below.
        raise
    except Exception as err:
        # Keep the full traceback in the server log; the client gets a generic 500.
        logging.getLogger(__name__).exception("Background removal failed")
        raise HTTPException(500, "Processing failed") from err
    return StreamingResponse(
        png,
        media_type="image/png",
        headers={
            "Content-Disposition": "inline; filename=removed-bg.png"
        },
    )
# =========================
# HF DOCKER ENTRYPOINT
# =========================
if __name__ == "__main__":
    # Serve on all interfaces. The port is taken from $PORT and falls back to
    # 7860 when the variable is not set.
    listen_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)