# whitepeacock's picture
# Update app.py
# 7437b84 verified
# raw
# history blame
# 9.84 kB
import os
import asyncio
from concurrent.futures import ThreadPoolExecutor
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import StreamingResponse, HTMLResponse
from PIL import Image, ImageSequence
import pillow_heif # HEIC/HEIF support
import numpy as np
import torch
from transformers import AutoModelForImageSegmentation
from io import BytesIO
from loadimg import load_img
import uvicorn
# -------------------------
# Enable HEIC/HEIF Support
# -------------------------
# Register pillow_heif's opener with Pillow so that Image.open() can read
# HEIC/HEIF payloads. Must run before any image is opened.
pillow_heif.register_heif_opener()
# -------------------------
# Thread Pool for Concurrency
# -------------------------
# Shared pool used by process_image_async() to run the blocking inference
# off the event loop. Sized to the CPU count, falling back to 4 workers
# when os.cpu_count() returns nothing usable.
executor = ThreadPoolExecutor(max_workers=os.cpu_count() or 4)
# -------------------------
# Model Setup (Load Once)
# -------------------------
# Everything below runs at import time, so each process that imports this
# module loads its own copy of the model.
# Directory handed to from_pretrained() as its cache_dir; created up front.
MODEL_DIR = "models/BiRefNet"
os.makedirs(MODEL_DIR, exist_ok=True)
# Plain string ("cuda" / "cpu"); other code compares it against "cuda".
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
print("Loading BiRefNet model (first run may take a while)...")
# NOTE(review): trust_remote_code=True lets transformers execute the model
# code shipped in the "ZhengPeng7/BiRefNet" repository; no revision is
# pinned here - consider pinning one.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    cache_dir=MODEL_DIR,
    trust_remote_code=True
)
birefnet.to(device)
# Switch to evaluation mode; the model is only used for inference here.
birefnet.eval()
print(f"Model loaded successfully on {device}.")
# -------------------------
# Image Preprocessing
# -------------------------
# Size passed to Image.resize() in transform_image() before inference.
TARGET_SIZE = (512, 512) # Lower resolution for faster inference
def transform_image(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a normalised model input tensor.

    The image is resized to ``TARGET_SIZE``, scaled to the 0-1 range,
    normalised per channel with the fixed mean/std constants below, and
    rearranged from HWC to a single-item CHW batch on ``device``.

    The three-element mean/std arrays broadcast against the last axis, so
    the input is expected to be a 3-channel image; the callers in this file
    convert to RGB before calling.

    Args:
        image: Source image (not modified; ``resize`` returns a new image).

    Returns:
        A float32 tensor of shape (1, C, H, W) placed on ``device``.
    """
    channel_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    channel_std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    resized = image.resize(TARGET_SIZE)
    pixels = np.array(resized).astype(np.float32) / 255.0
    normalised = (pixels - channel_mean) / channel_std

    # HWC -> CHW, then add the leading batch dimension.
    chw = np.transpose(normalised, (2, 0, 1))
    batch = torch.from_numpy(chw).unsqueeze(0)
    return batch.to(torch.float32).to(device)
def process_image_sync(image: Image.Image) -> BytesIO:
    """Remove the background of ``image`` and return it as in-memory PNG bytes.

    Runs the BiRefNet model on a resized copy of the image, turns the
    predicted foreground probability map into an 8-bit mask at the original
    size, and attaches that mask as the alpha channel.

    This call is blocking (model inference plus PNG encoding); async callers
    should run it in an executor.

    Args:
        image: Input image. It is copied before the alpha channel is added,
            so the caller's object is left untouched.

    Returns:
        A ``BytesIO`` positioned at offset 0 containing the PNG data.
    """
    original_size = image.size
    input_tensor = transform_image(image)

    # torch.autocast supersedes the deprecated torch.cuda.amp.autocast().
    # Mixed precision is enabled only on CUDA; on CPU the context is a no-op,
    # which lets both devices share a single inference path.
    use_mixed_precision = device == "cuda"
    with torch.no_grad(), torch.autocast(device_type=device, enabled=use_mixed_precision):
        # The last model output is used; sigmoid maps logits to 0-1.
        preds = birefnet(input_tensor)[-1].sigmoid().cpu()

    # First batch item, first channel -> 2-D probability map.
    pred = preds[0, 0].numpy()
    mask = Image.fromarray((pred * 255).astype(np.uint8)).resize(original_size)

    result = image.copy()
    result.putalpha(mask)

    output_buffer = BytesIO()
    result.save(output_buffer, format="PNG")
    output_buffer.seek(0)  # rewind so readers start at the beginning
    return output_buffer
async def process_image_async(image: Image.Image) -> BytesIO:
    """Run ``process_image_sync`` in the shared thread pool and await the result.

    Keeps the blocking inference off the event loop; no disk I/O is involved.

    Args:
        image: Image to process.

    Returns:
        The ``BytesIO`` produced by ``process_image_sync``.
    """
    # get_running_loop() is the supported way to obtain the loop from inside
    # a coroutine; get_event_loop() is discouraged for this use.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, process_image_sync, image)
# -------------------------
# Safe Image Loader
# -------------------------
def _sniff_vector_format(file_bytes: bytes) -> str:
    """Return "pdf" or "svg" when the payload looks like one, else ""."""
    # Skip a UTF-8 BOM and leading whitespace, then look at the start only.
    head = file_bytes[:4096].lstrip(b"\xef\xbb\xbf \t\r\n").lower()
    if head.startswith(b"%pdf-"):
        return "pdf"
    # SVG is XML: it starts with "<" (prolog, doctype, comment or the root
    # element itself) and has an <svg tag near the top.
    if head.startswith(b"<") and b"<svg" in head:
        return "svg"
    return ""


def open_image_safely(file_bytes: bytes) -> Image.Image:
    """Decode uploaded bytes into an RGB image.

    Handles PDF (first page), SVG, animated images (first frame) and every
    raster format Pillow can open, including HEIC/HEIF.

    PDF and SVG are detected from the payload itself *before* calling
    ``Image.open``: Pillow cannot identify either format, so checking
    ``img.format`` afterwards could never reach those branches.

    Args:
        file_bytes: Raw file content.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        HTTPException: 400 when the payload cannot be decoded (including a
            missing optional converter such as pdf2image or cairosvg).
    """
    try:
        kind = _sniff_vector_format(file_bytes)

        # Handle PDF: rasterise the first page only.
        if kind == "pdf":
            from pdf2image import convert_from_bytes
            pages = convert_from_bytes(file_bytes, first_page=1, last_page=1)
            return pages[0].convert("RGB")

        # Handle SVG: render to PNG, then decode that.
        if kind == "svg":
            import cairosvg
            png_bytes = cairosvg.svg2png(bytestring=file_bytes)
            return Image.open(BytesIO(png_bytes)).convert("RGB")

        # Raster formats (HEIC, HEIF, JPG, PNG, GIF, ...).
        img = Image.open(BytesIO(file_bytes))
        if getattr(img, "is_animated", False):
            img.seek(0)  # animated input: use the first frame
        return img.convert("RGB")
    except Exception as e:
        raise HTTPException(
            status_code=400, detail=f"Unsupported or corrupted image: {e}"
        ) from e
# -------------------------
# FastAPI App
# -------------------------
# Application object. The route decorators in this file register their
# endpoints on it, and the __main__ block serves it through uvicorn using
# the "<module>:app" import string, so the name `app` must not change.
app = FastAPI(title="Background Removal API", description="Removes image backgrounds in-memory")
# -------------------------
# API Endpoints
# -------------------------
@app.post("/remove_bg_file")
async def remove_bg_file(file: UploadFile = File(...)):
    """Remove the background from an uploaded image file.

    Args:
        file: Multipart upload under the form field ``file``.

    Returns:
        A ``StreamingResponse`` with media type ``image/png`` containing the
        image with its background made transparent.

    Raises:
        HTTPException: 400 when the upload cannot be decoded (raised by
            ``open_image_safely``); 500 for any other processing failure.
    """
    try:
        contents = await file.read()
        # Decoding can be slow (large images, PDF/SVG rasterisation), so it
        # runs in the thread pool instead of blocking the event loop.
        loop = asyncio.get_running_loop()
        image = await loop.run_in_executor(executor, open_image_safely, contents)
        output_buffer = await process_image_async(image)
        return StreamingResponse(output_buffer, media_type="image/png")
    except HTTPException:
        # Already carries the right status code; re-raise unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {e}") from e
@app.post("/remove_bg_url")
async def remove_bg_url(image_url: str = Form(...)):
    """Remove the background from an image fetched from a URL.

    Args:
        image_url: Form field holding an ``http://`` or ``https://`` URL.

    Returns:
        A ``StreamingResponse`` with media type ``image/png`` containing the
        image with its background made transparent.

    Raises:
        HTTPException: 400 when ``image_url`` is not an http(s) URL; 500 when
            downloading or processing the image fails.
    """
    url = image_url.strip()
    # The value is untrusted and is handed to load_img(); accept only web
    # URLs so it cannot be used to read other kinds of input.
    # NOTE(review): load_img is believed to also accept local paths - confirm.
    # NOTE(review): this does not block URLs that point at internal hosts.
    if not url.lower().startswith(("http://", "https://")):
        raise HTTPException(status_code=400, detail="image_url must be an http(s) URL")

    def _download() -> Image.Image:
        return load_img(url, output_type="pil").convert("RGB")

    try:
        # The download is blocking; run it in the thread pool so the event
        # loop stays responsive.
        loop = asyncio.get_running_loop()
        image = await loop.run_in_executor(executor, _download)
        output_buffer = await process_image_async(image)
        return StreamingResponse(output_buffer, media_type="image/png")
    except HTTPException:
        # Keep explicit HTTP errors instead of rewrapping them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing URL: {e}") from e
# -------------------------
# Web Interface
# -------------------------
@app.get("/", response_class=HTMLResponse)
async def index():
    """Serve the single-page web interface.

    The page contains two forms: a file upload that POSTs to
    ``/remove_bg_file`` and a URL field that POSTs to ``/remove_bg_url``.
    The embedded script shows the original next to the returned PNG and
    displays the ``detail`` of an error response in an alert. Bootstrap CSS
    is loaded from the jsDelivr CDN.

    Returns:
        An ``HTMLResponse`` with the static page markup.
    """
    # The markup is sent verbatim; it contains no server-side substitutions.
    html = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>Background Removal Tool</title>
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/css/bootstrap.min.css" rel="stylesheet">
<style>
body { padding: 30px; background-color: #f8f9fa; }
.container { max-width: 700px; background: #fff; padding: 20px; border-radius: 10px;
box-shadow: 0 0 10px rgba(0,0,0,0.1);}
img { max-width: 100%; border-radius: 8px; }
.preview-grid { display: flex; gap: 15px; justify-content: space-between; margin-top: 15px; }
.preview-item { flex: 1; text-align: center; }
.preview-item img { width: 100%; border: 1px solid #ddd; padding: 5px; background: #fff; }
</style>
</head>
<body>
<div class="container">
<h2 class="mb-4">Background Removal Tool</h2>
<form id="fileForm" enctype="multipart/form-data">
<div class="mb-3">
<label for="fileInput" class="form-label">Upload Image</label>
<input class="form-control" type="file" id="fileInput" name="file"
accept="image/*,application/pdf,.heic,.heif,.svg">
</div>
<button class="btn btn-primary" type="submit">Remove Background</button>
</form>
<hr>
<form id="urlForm">
<div class="mb-3">
<label for="urlInput" class="form-label">Image URL</label>
<input class="form-control" type="text" id="urlInput" placeholder="Enter image URL">
</div>
<button class="btn btn-success" type="submit">Remove Background</button>
</form>
<hr>
<h5>Preview:</h5>
<div class="preview-grid">
<div class="preview-item">
<strong>Before</strong>
<img id="beforeImg" src="" alt="Original Image">
</div>
<div class="preview-item">
<strong>After</strong>
<img id="afterImg" src="" alt="Processed Image">
</div>
</div>
</div>
<script>
const fileForm = document.getElementById('fileForm');
const urlForm = document.getElementById('urlForm');
const beforeImg = document.getElementById('beforeImg');
const afterImg = document.getElementById('afterImg');
fileForm.addEventListener('submit', async (e) => {
e.preventDefault();
const fileInput = document.getElementById('fileInput');
if (!fileInput.files.length) return alert("Select a file!");
const file = fileInput.files[0];
beforeImg.src = URL.createObjectURL(file);
const formData = new FormData();
formData.append("file", file);
const res = await fetch('/remove_bg_file', { method: 'POST', body: formData });
if (!res.ok) {
const err = await res.json();
alert(err.detail || "Failed to process image");
return;
}
const blob = await res.blob();
afterImg.src = URL.createObjectURL(blob);
});
urlForm.addEventListener('submit', async (e) => {
e.preventDefault();
const urlInput = document.getElementById('urlInput').value;
if (!urlInput) return alert("Enter an image URL");
beforeImg.src = urlInput;
const formData = new FormData();
formData.append("image_url", urlInput);
const res = await fetch('/remove_bg_url', { method: 'POST', body: formData });
if (!res.ok) {
const err = await res.json();
alert(err.detail || "Failed to process image URL");
return;
}
const blob = await res.blob();
afterImg.src = URL.createObjectURL(blob);
});
</script>
</body>
</html>
"""
    return HTMLResponse(content=html)
# -------------------------
# Run Server (Auto-detect filename)
# -------------------------
if __name__ == "__main__":
    # Build the "<module>:app" import string from this file's own name so the
    # script keeps working if it is renamed. The app is passed as an import
    # string rather than as an object because multiple workers are requested.
    script_name = os.path.basename(__file__)
    module_name, _extension = os.path.splitext(script_name)
    uvicorn.run(
        f"{module_name}:app",
        host="0.0.0.0",
        port=7860,
        workers=2,
    )