# Spaces: Running
| import os | |
| import threading | |
| import torch | |
| import numpy as np | |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request | |
| from fastapi.responses import StreamingResponse, HTMLResponse, RedirectResponse, JSONResponse | |
| from PIL import Image | |
| from io import BytesIO | |
| import requests | |
| from transformers import AutoModelForImageSegmentation | |
| import uvicorn | |
| # --------------------------------------------------------- | |
| # Optional HEIC/HEIF | |
| # --------------------------------------------------------- | |
try:
    # Optional dependency: when pillow_heif is installed, register its opener
    # so HEIC/HEIF uploads can be opened through PIL.
    import pillow_heif
    pillow_heif.register_heif_opener()
except ImportError:
    # Deliberate best-effort: without pillow_heif the service still starts,
    # just without HEIC/HEIF support.
    pass
| # --------------------------------------------------------- | |
| # Performance settings for HF CPU | |
| # --------------------------------------------------------- | |
# Restrict numeric back-ends to a single thread.
# NOTE(review): these variables are assigned after torch and numpy have
# already been imported at the top of the file; runtimes that read them at
# load time may not see the values -- confirm, or move this above the imports.
os.environ.update({"OMP_NUM_THREADS": "1", "MKL_NUM_THREADS": "1"})
torch.set_num_threads(1)
| # --------------------------------------------------------- | |
| # Constants | |
| # --------------------------------------------------------- | |
# Size every image is resized to before it is handed to the model
# (see transform()).
TARGET_SIZE = (512, 512)  # Faster inference
# Longest side, in pixels, an input may have before auto_downscale() shrinks it.
MAX_SIDE = 3000  # Auto-downscale for huge uploads
| # --------------------------------------------------------- | |
| # Load model | |
| # --------------------------------------------------------- | |
# Local cache directory for the downloaded model files.
MODEL_DIR = "models/BiRefNet"
os.makedirs(MODEL_DIR, exist_ok=True)

# Use half precision on CUDA and full precision on CPU.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.float16
else:
    device, dtype = "cpu", torch.float32

print("Loading BiRefNet…")
# NOTE(review): trust_remote_code=True executes code shipped with the model
# repository, and revision="main" is a moving reference -- consider pinning
# a specific commit.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    cache_dir=MODEL_DIR,
    trust_remote_code=True,
    revision="main",
)
birefnet.to(device, dtype=dtype)
birefnet.eval()
print("Model ready.")

# Serialises model calls across concurrent requests (see run_inference()).
lock = threading.Lock()
| # --------------------------------------------------------- | |
| # Helpers | |
| # --------------------------------------------------------- | |
def load_image_from_url(url: str) -> Image.Image:
    """Download an image from *url* and return it converted to RGB.

    Args:
        url: Location of the image to fetch (10-second request timeout).

    Returns:
        The decoded image in RGB mode.

    Raises:
        HTTPException: status 400 with detail "Invalid image URL" when the
            request fails, the server answers with an error status, or the
            payload cannot be decoded as an image. The underlying error is
            attached as ``__cause__``.
    """
    # NOTE(review): the URL is caller-supplied and fetched from the server
    # with no host restriction or size limit -- consider an allow-list and a
    # download cap if this endpoint is exposed publicly.
    try:
        r = requests.get(url, timeout=10)
        r.raise_for_status()
        return Image.open(BytesIO(r.content)).convert("RGB")
    except Exception as exc:
        # Any failure here is reported to the client as a bad URL; chain the
        # cause so it still shows up in tracebacks and logs.
        raise HTTPException(status_code=400, detail="Invalid image URL") from exc
def auto_downscale(img: Image.Image) -> Image.Image:
    """Shrink *img* so that its longest side does not exceed MAX_SIDE.

    Args:
        img: Image to check.

    Returns:
        *img* itself when it already fits, otherwise a LANCZOS-resampled
        copy with the same aspect ratio whose longest side is MAX_SIDE.
    """
    w, h = img.size
    longest = max(w, h)
    if longest <= MAX_SIDE:
        return img

    # Integer arithmetic avoids float rounding that could leave the long side
    # at MAX_SIDE - 1. The clamp keeps the short side of an extremely
    # elongated image from collapsing to 0 px, which resize() rejects.
    new_w = max(1, w * MAX_SIDE // longest)
    new_h = max(1, h * MAX_SIDE // longest)
    print(f"[INFO] Downscaling {w}×{h} → {new_w}×{new_h}")
    return img.resize((new_w, new_h), Image.LANCZOS)
def transform(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a normalised, batched model input tensor.

    The image is resized to TARGET_SIZE, scaled to [0, 1], normalised per
    channel, reordered to channel-first layout and given a leading batch
    axis before being moved to the module-level ``device`` / ``dtype``.

    Assumes *img* has three channels (callers convert to RGB first).
    """
    resized = img.resize(TARGET_SIZE)
    pixels = np.array(resized).astype(np.float32) / 255.0

    # Per-channel statistics (presumably the usual ImageNet values -- confirm).
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])
    normalised = (pixels - channel_mean) / channel_std

    # Move the channel axis to the front, then add the batch axis.
    channel_first = normalised.transpose(2, 0, 1)
    batch = torch.from_numpy(channel_first).unsqueeze(0)
    return batch.to(device=device, dtype=dtype)
def run_inference(img: Image.Image) -> Image.Image:
    """Run the model on *img* and return an RGBA copy with the mask as alpha.

    The model call is made while holding the module-level ``lock`` and with
    gradient tracking disabled. The predicted mask is resized back to the
    input's size and installed as the alpha channel.
    """
    original_size = img.size
    batch = transform(img)

    with lock, torch.no_grad():
        outputs = birefnet(batch)
        pred = outputs[-1].sigmoid().cpu()[0, 0]

    # Scale the sigmoid output from [0, 1] to 8-bit values for the alpha mask.
    alpha_values = (pred.numpy() * 255).astype(np.uint8)
    alpha = Image.fromarray(alpha_values).resize(original_size)

    rgba = img.convert("RGBA")
    rgba.putalpha(alpha)
    return rgba
| # --------------------------------------------------------- | |
| # FastAPI app | |
| # --------------------------------------------------------- | |
# Application object; passed to uvicorn.run() in the __main__ guard.
app = FastAPI(title="Background Remover API")
| # --------------------------------------------------------- | |
| # Redirect GET → POST logic | |
| # --------------------------------------------------------- | |
@app.get("/remove-background")
async def redirect_to_post():
    """Answer GET requests on the removal endpoint with a 405 hint.

    Returns:
        A JSON response with status 405 whose ``detail`` tells the client to
        use POST, plus an ``Allow: POST`` header naming the supported method.
    """
    return JSONResponse(
        {"detail": "This endpoint only supports POST. Use POST /remove-background"},
        status_code=405,
        # A 405 response is expected to list the methods that are allowed.
        headers={"Allow": "POST"},
    )
| # --------------------------------------------------------- | |
| # Main POST endpoint | |
| # --------------------------------------------------------- | |
@app.post("/remove-background")
async def remove_bg(file: UploadFile = File(None), image_url: str = Form(None)):
    """Remove the background from an uploaded or remotely hosted image.

    Exactly one input is used: ``file`` takes precedence over ``image_url``.

    Args:
        file: Image uploaded as multipart form data.
        image_url: URL of an image to download instead of an upload.

    Returns:
        A streaming ``image/png`` response containing the RGBA result.

    Raises:
        HTTPException: status 400 when neither input is given, when the
            upload cannot be decoded, or when the URL cannot be loaded;
            status 500 for any other failure during processing.
    """
    try:
        if file:
            raw = await file.read()
            try:
                img = Image.open(BytesIO(raw)).convert("RGB")
            except Exception as exc:
                # An undecodable upload is a client error, not a server fault.
                raise HTTPException(status_code=400, detail="Invalid image file") from exc
        elif image_url:
            img = load_image_from_url(image_url)
        else:
            raise HTTPException(status_code=400, detail="Upload file or image_url required")

        img = auto_downscale(img)
        # NOTE(review): run_inference() is synchronous, so this coroutine
        # holds the event loop for the duration of the model call.
        result = run_inference(img)

        buf = BytesIO()
        result.save(buf, format="PNG")
        buf.seek(0)
        return StreamingResponse(buf, media_type="image/png")
    except HTTPException:
        # Let deliberate HTTP errors (the 400s above) through unchanged rather
        # than rewrapping them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
| # --------------------------------------------------------- | |
| # UI: Show INPUT + OUTPUT (big preview) | |
| # --------------------------------------------------------- | |
# NOTE(review): no route registration for this handler is visible in this
# view; confirm it is mounted on `app` and served as HTML rather than as a
# JSON-encoded string.
async def ui():
    """Return the HTML source of a browser test page for the API.

    The page shows the input and output images side by side and offers two
    forms, a file upload and an image URL, each of which POSTs to
    /remove-background and displays the returned blob as the output image.
    """
    return """
<html>
<head>
<title>Background Remover – Test UI</title>
<link rel='stylesheet'
href='https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css'>
</head>
<body class='bg-light'>
<div class='container py-4 text-center'>
<h2 class='mb-4'>API Test Panel (POST Only)</h2>
<div class='row'>
<div class='col-md-6'>
<h5>Input Image</h5>
<img id='inputImg' style='max-width:100%; border-radius:10px;'>
</div>
<div class='col-md-6'>
<h5>Output Image</h5>
<img id='outputImg' style='max-width:100%; border-radius:10px;'>
</div>
</div>
<hr>
<h4>Upload Test</h4>
<form id="uploadForm" enctype='multipart/form-data'>
<input type='file' id='fileInput' class='form-control mb-3'>
<button class='btn btn-primary'>Send POST</button>
</form>
<hr>
<h4>URL Test</h4>
<form id='urlForm'>
<input id='urlInput' class='form-control mb-3' placeholder='https://example.com/image.jpg'>
<button class='btn btn-success'>Send POST</button>
</form>
</div>
<script>
const inputImg = document.getElementById("inputImg");
const outputImg = document.getElementById("outputImg");
// FILE TEST
document.getElementById("uploadForm").addEventListener("submit", async e => {
e.preventDefault();
const file = document.getElementById("fileInput").files[0];
if (!file) return alert("Select a file first.");
inputImg.src = URL.createObjectURL(file);
const fd = new FormData();
fd.append("file", file);
const r = await fetch("/remove-background", { method:"POST", body:fd });
outputImg.src = URL.createObjectURL(await r.blob());
});
// URL TEST
document.getElementById("urlForm").addEventListener("submit", async e => {
e.preventDefault();
const url = document.getElementById("urlInput").value.trim();
if (!url) return alert("Enter an image URL first.");
inputImg.src = url;
const fd = new FormData();
fd.append("image_url", url);
const r = await fetch("/remove-background", { method:"POST", body:fd });
outputImg.src = URL.createObjectURL(await r.blob());
});
</script>
</body>
</html>
"""
| # --------------------------------------------------------- | |
| # Run app | |
| # --------------------------------------------------------- | |
if __name__ == "__main__":
    # When executed as a script, serve the app on all interfaces at port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)