# videopix's picture
# Update app.py
# bf34fae verified
# raw
# history blame
# 7.7 kB
import os
import threading
import torch
import numpy as np
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
from fastapi.responses import StreamingResponse, HTMLResponse, RedirectResponse, JSONResponse
from PIL import Image
from io import BytesIO
import requests
from transformers import AutoModelForImageSegmentation
import uvicorn
# ---------------------------------------------------------
# Optional HEIC/HEIF
# ---------------------------------------------------------
try:
    # pillow_heif is an optional dependency: when it is importable, register
    # its HEIF opener so HEIC/HEIF uploads can be handled.
    import pillow_heif
    pillow_heif.register_heif_opener()
except ImportError:
    # Package not installed: carry on without HEIC/HEIF support.
    pass
# ---------------------------------------------------------
# Performance settings for HF CPU
# ---------------------------------------------------------
# Pin the OMP/MKL thread-count variables and torch's own thread pool to a
# single thread.
# NOTE(review): torch has already been imported when these variables are
# set; confirm the native libraries still honour them at this point.
os.environ.update(OMP_NUM_THREADS="1", MKL_NUM_THREADS="1")
torch.set_num_threads(1)
# ---------------------------------------------------------
# Constants
# ---------------------------------------------------------
# Size every image is resized to before it is fed to the model
# (kept small for faster inference).
TARGET_SIZE = (512, 512)
# Longest side, in pixels, an image may have before it is downscaled.
MAX_SIDE = 3_000
# ---------------------------------------------------------
# Load model
# ---------------------------------------------------------
# Local cache directory for the downloaded model files.
MODEL_DIR = "models/BiRefNet"
os.makedirs(MODEL_DIR, exist_ok=True)

# float16 on CUDA, float32 on CPU.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.float16
else:
    device, dtype = "cpu", torch.float32

print("Loading BiRefNet…")
# NOTE(review): trust_remote_code=True together with revision="main" means
# the code executed is whatever the repo's main branch holds at download
# time; consider pinning a specific commit.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    revision="main",
    trust_remote_code=True,
    cache_dir=MODEL_DIR,
)
birefnet.to(device, dtype=dtype).eval()
print("Model ready.")

# Held around the model call in run_inference so only one thread runs
# the model at a time.
lock = threading.Lock()
# ---------------------------------------------------------
# Helpers
# ---------------------------------------------------------
def load_image_from_url(url: str) -> Image.Image:
    """Download the image at *url* and return it converted to RGB.

    The request uses a 10-second timeout.

    Raises:
        HTTPException: status 400 ("Invalid image URL") when the request
            fails, the server answers with an error status, or the payload
            cannot be opened as an image.
    """
    try:
        r = requests.get(url, timeout=10)
        r.raise_for_status()
        return Image.open(BytesIO(r.content)).convert("RGB")
    except Exception as exc:
        # Deliberately broad: every failure on this path is reported to the
        # client as a bad URL. Chain the cause so the real error is still
        # visible in tracebacks and logs.
        raise HTTPException(status_code=400, detail="Invalid image URL") from exc
def auto_downscale(img: Image.Image) -> Image.Image:
    """Shrink *img* so that its longest side is at most ``MAX_SIDE`` pixels.

    An image already within the limit is returned as-is (same object).
    A larger one is resized with LANCZOS resampling, scaling both sides by
    the same factor, and the new image is returned.
    """
    w, h = img.size
    longest = max(w, h)
    if longest <= MAX_SIDE:
        return img
    scale = MAX_SIDE / longest
    # Clamp to 1 px: int() truncates, so a very elongated image (for example
    # 30000x5) would otherwise get a zero-length side, which resize() rejects.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    print(f"[INFO] Downscaling {w}×{h} → {new_w}×{new_h}")
    return img.resize((new_w, new_h), Image.LANCZOS)
def transform(img: Image.Image) -> torch.Tensor:
    """Convert *img* into a normalised, batched tensor for the model.

    The image is resized to ``TARGET_SIZE``, scaled to [0, 1], normalised
    with a fixed mean and standard deviation per channel, reordered so the
    last array axis comes first, given a leading batch axis of length 1,
    and moved to the module-level ``device`` / ``dtype``.
    """
    resized = img.resize(TARGET_SIZE)
    pixels = np.array(resized).astype(np.float32) / 255.0
    # Normalisation constants, one value per channel.
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])
    normalised = (pixels - channel_mean) / channel_std
    # Move the last axis (the channels, assuming an RGB input) to the front.
    channels_first = normalised.transpose(2, 0, 1)
    batch = torch.from_numpy(channels_first).unsqueeze(0)
    return batch.to(device=device, dtype=dtype)
def run_inference(img: Image.Image) -> Image.Image:
    """Run the model on *img* and return an RGBA image with the mask as alpha.

    The last model output is passed through a sigmoid, scaled to 0-255,
    resized back to the input image's size and installed as the alpha
    channel of an RGBA conversion of *img*.
    """
    original_size = img.size
    model_input = transform(img)
    # One model call at a time; gradients are not needed for prediction.
    with lock, torch.no_grad():
        prediction = birefnet(model_input)[-1].sigmoid().cpu()[0, 0]
    mask_array = (prediction.numpy() * 255).astype(np.uint8)
    mask = Image.fromarray(mask_array).resize(original_size)
    rgba = img.convert("RGBA")
    rgba.putalpha(mask)
    return rgba
# ---------------------------------------------------------
# FastAPI app
# ---------------------------------------------------------
# FastAPI application instance; the route handlers in this file are
# registered on it through its decorators.
app = FastAPI(title="Background Remover API")
# ---------------------------------------------------------
# Reject GET on the POST-only endpoint (405 Method Not Allowed)
# ---------------------------------------------------------
@app.get("/remove-background")
async def redirect_to_post():
    """Answer GET /remove-background with 405 and a hint to use POST.

    Despite the function name, no redirect is issued; the client receives
    a JSON error body and must repeat the request as a POST itself.
    """
    return JSONResponse(
        {"detail": "This endpoint only supports POST. Use POST /remove-background"},
        status_code=405,
        # A 405 response has to list the methods that are allowed.
        headers={"Allow": "POST"},
    )
# ---------------------------------------------------------
# Main POST endpoint
# ---------------------------------------------------------
@app.post("/remove-background")
async def remove_bg(file: UploadFile = File(None), image_url: str = Form(None)):
    """Remove the background from an uploaded or remotely hosted image.

    Exactly one source is used: the uploaded ``file`` if present, otherwise
    ``image_url``. The image is downscaled if needed, run through the
    model, and returned as a PNG with the predicted mask as alpha channel.

    Raises:
        HTTPException: 400 when neither input is given, when the upload
            cannot be decoded as an image, or when the URL cannot be
            loaded; 500 when processing or PNG encoding fails.
    """
    # Input handling sits outside the 500 handler below so that client
    # errors keep their 400 status instead of being rewrapped as 500.
    if file:
        raw = await file.read()
        try:
            img = Image.open(BytesIO(raw)).convert("RGB")
        except Exception as exc:
            raise HTTPException(status_code=400, detail="Invalid image file") from exc
    elif image_url:
        # Raises HTTPException(400) itself on failure.
        img = load_image_from_url(image_url)
    else:
        raise HTTPException(status_code=400, detail="Upload file or image_url required")

    # NOTE: run_inference is a synchronous call, so it occupies the event
    # loop for the duration of the model run.
    try:
        img = auto_downscale(img)
        result = run_inference(img)
        buf = BytesIO()
        result.save(buf, format="PNG")
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")
# ---------------------------------------------------------
# UI: Show INPUT + OUTPUT (big preview)
# ---------------------------------------------------------
@app.get("/", response_class=HTMLResponse)
async def ui():
    """Serve a small HTML test page for POST /remove-background.

    The page offers a file-upload form and a URL form, and shows the input
    and the returned PNG side by side. Its script reports failed requests
    with an alert instead of loading the error body into the output image.
    """
    return """
<html>
<head>
<title>Background Remover – Test UI</title>
<link rel='stylesheet'
href='https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css'>
</head>
<body class='bg-light'>
<div class='container py-4 text-center'>
<h2 class='mb-4'>API Test Panel (POST Only)</h2>
<div class='row'>
<div class='col-md-6'>
<h5>Input Image</h5>
<img id='inputImg' style='max-width:100%; border-radius:10px;'>
</div>
<div class='col-md-6'>
<h5>Output Image</h5>
<img id='outputImg' style='max-width:100%; border-radius:10px;'>
</div>
</div>
<hr>
<h4>Upload Test</h4>
<form id="uploadForm" enctype='multipart/form-data'>
<input type='file' id='fileInput' class='form-control mb-3'>
<button class='btn btn-primary'>Send POST</button>
</form>
<hr>
<h4>URL Test</h4>
<form id='urlForm'>
<input id='urlInput' class='form-control mb-3' placeholder='https://example.com/image.jpg'>
<button class='btn btn-success'>Send POST</button>
</form>
</div>
<script>
const inputImg = document.getElementById("inputImg");
const outputImg = document.getElementById("outputImg");
// POST the form data; show the PNG on success, report the API error otherwise
async function sendPost(fd) {
  let r;
  try {
    r = await fetch("/remove-background", { method:"POST", body:fd });
  } catch (err) {
    return alert("Request failed: " + err);
  }
  if (!r.ok) {
    let detail = r.statusText;
    try {
      const body = await r.json();
      detail = typeof body.detail === "string" ? body.detail : JSON.stringify(body.detail);
    } catch (err) { /* non-JSON error body: keep the status text */ }
    return alert("Error " + r.status + ": " + detail);
  }
  outputImg.src = URL.createObjectURL(await r.blob());
}
// FILE TEST
document.getElementById("uploadForm").addEventListener("submit", async e => {
e.preventDefault();
const file = document.getElementById("fileInput").files[0];
if (!file) return alert("Select a file first.");
inputImg.src = URL.createObjectURL(file);
const fd = new FormData();
fd.append("file", file);
await sendPost(fd);
});
// URL TEST
document.getElementById("urlForm").addEventListener("submit", async e => {
e.preventDefault();
const url = document.getElementById("urlInput").value.trim();
if (!url) return alert("Enter an image URL first.");
inputImg.src = url;
const fd = new FormData();
fd.append("image_url", url);
await sendPost(fd);
});
</script>
</body>
</html>
"""
# ---------------------------------------------------------
# Run app
# ---------------------------------------------------------
if __name__ == "__main__":
    # Started directly as a script: serve the app on every interface,
    # port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")