Spaces:
Running
Running
| import os | |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Query | |
| from fastapi.responses import StreamingResponse, HTMLResponse | |
| from PIL import Image | |
| import torch | |
| import numpy as np | |
| from transformers import AutoModelForImageSegmentation | |
| from io import BytesIO | |
| import requests | |
| import uvicorn | |
# -------------------------
# Optional HEIC/HEIF Support
# -------------------------
# When pillow-heif is installed, register its opener with Pillow so that
# Image.open() can decode .heic/.heif files; otherwise print an install hint.
try:
    import pillow_heif

    pillow_heif.register_heif_opener()
except ImportError:
    print("⚠️ Install pillow-heif for HEIC support: pip install pillow-heif")
else:
    print("✅ HEIC/HEIF format supported.")
# -------------------------
# Model Setup
# -------------------------
MODEL_DIR = "models/BiRefNet"
os.makedirs(MODEL_DIR, exist_ok=True)

# Run on the GPU in half precision when CUDA is available, otherwise on the
# CPU in full precision.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.float16
else:
    device, dtype = "cpu", torch.float32

print("Loading BiRefNet model...")
# NOTE(review): trust_remote_code=True executes Python code shipped with the
# model repository, and revision="main" follows whatever that branch holds at
# load time. Pinning a specific commit hash would make deployments
# reproducible and limit supply-chain risk.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    revision="main",
    trust_remote_code=True,
    cache_dir=MODEL_DIR,
)
birefnet.to(device, dtype=dtype)
birefnet.eval()  # inference mode for layers such as dropout / batch norm
print("Model loaded successfully.")

# -------------------------
# FastAPI App
# -------------------------
app = FastAPI(title="Background Remover API")
# -------------------------
# Utility Functions
# -------------------------
def load_image_from_url(url: str) -> Image.Image:
    """
    Download the image at *url* and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch (request timeout of 10 seconds).

    Returns:
        The decoded image, converted to RGB.

    Raises:
        HTTPException: status 400 wrapping any failure while downloading
            (including a non-2xx response) or decoding the image.
    """
    # NOTE(review): *url* comes straight from the client, so this is a
    # server-side fetch of an arbitrary address (SSRF risk) and the whole body
    # is buffered in memory with no size cap. Consider an allow-list and a
    # maximum download size.
    try:
        reply = requests.get(url, timeout=10)
        reply.raise_for_status()
        raw = BytesIO(reply.content)
        picture = Image.open(raw)
        return picture.convert("RGB")
    except Exception as err:
        message = f"Error loading image from URL: {err}"
        raise HTTPException(status_code=400, detail=message)
def transform_image(image: Image.Image, resolution: int = 512) -> torch.Tensor:
    """
    Convert a PIL image into a normalized model input tensor.

    The image is resized to a ``resolution`` x ``resolution`` square and scaled
    to [0, 1]. It is then normalized with the ImageNet per-channel mean/std and
    returned as a ``(1, 3, resolution, resolution)`` tensor. The tensor lives
    on the module-level ``device`` and uses the module-level ``dtype``.

    Args:
        image: Input image. Modes other than RGB (e.g. ``L``, ``RGBA``, ``P``)
            are converted to RGB first so the 3-channel normalization applies.
        resolution: Side length of the square model input (default 512).

    Returns:
        A batch-of-one float tensor ready to be fed to ``birefnet``.
    """
    if image.mode != "RGB":
        # The mean/std arrays below have exactly three entries; any other
        # channel count would fail to broadcast.
        image = image.convert("RGB")
    resized = image.resize((resolution, resolution))
    arr = np.asarray(resized, dtype=np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    arr = (arr - mean) / std
    chw = np.transpose(arr, (2, 0, 1))  # HWC -> CHW
    # One .to() call moves and casts together instead of two separate copies.
    return torch.from_numpy(chw).unsqueeze(0).to(device=device, dtype=dtype)
def process_image(image: Image.Image, resolution: int = 512) -> Image.Image:
    """
    Predict a foreground mask for *image* and attach it as the alpha channel.

    Args:
        image: Source image (callers in this module pass RGB images).
        resolution: Square input size used for model inference.

    Returns:
        An RGBA copy of *image* at its original size. Its alpha channel is the
        model's sigmoid output scaled to 0-255.
    """
    width_height = image.size
    batch = transform_image(image, resolution)
    with torch.no_grad():
        # The model output is indexable; only its last entry is used.
        probs = birefnet(batch)[-1].sigmoid().cpu()
    plane = probs[0, 0].numpy()
    alpha = Image.fromarray((plane * 255).astype(np.uint8)).resize(width_height)
    rgba = image.convert("RGBA")
    rgba.putalpha(alpha)
    return rgba
# -------------------------
# /remove-background Endpoint
# -------------------------
# Accepted range for ``resolution``. It matches the min/max of the test page's
# resolution inputs and stops a single request from allocating an oversized
# model input (or crashing on zero/negative sizes).
_MIN_RESOLUTION = 64
_MAX_RESOLUTION = 2048


def _decode_image_bytes(data: bytes) -> Image.Image:
    """Decode raw uploaded bytes into an RGB PIL image."""
    return Image.open(BytesIO(data)).convert("RGB")


def _render_png(image: Image.Image, resolution: int) -> BytesIO:
    """Remove the background from *image* and return the PNG buffer, rewound."""
    result = process_image(image, resolution)
    buf = BytesIO()
    result.save(buf, format="PNG")
    buf.seek(0)
    return buf


@app.post("/remove-background")
async def remove_background(
    file: UploadFile = File(None),
    image_url: str = Form(None),
    resolution: int = Form(512),
):
    """
    Remove the background from an image and return it as a transparent PNG.

    Accepts either a file upload or an image URL; when both are sent the
    uploaded file is used.

    Args:
        file: Uploaded image (multipart form field ``file``).
        image_url: URL of an image to download (form field ``image_url``).
        resolution: Square size the image is resized to for inference
            (default 512, allowed 64-2048). Lower values are faster.

    Returns:
        StreamingResponse with ``image/png`` content whose alpha channel is
        the predicted foreground mask.

    Raises:
        HTTPException: 400 for an out-of-range resolution, missing input, an
            undecodable upload, or a failed URL download. 500 for unexpected
            failures during inference or PNG encoding.
    """
    # fastapi is already a dependency of this module; imported locally so
    # this block stays self-contained.
    from fastapi.concurrency import run_in_threadpool

    if not _MIN_RESOLUTION <= resolution <= _MAX_RESOLUTION:
        raise HTTPException(
            status_code=400,
            detail=f"'resolution' must be between {_MIN_RESOLUTION} and {_MAX_RESOLUTION}.",
        )

    # Image decoding, the HTTP download and model inference are all blocking.
    # Run them in the threadpool so the event loop keeps serving requests.
    if file:
        data = await file.read()
        try:
            image = await run_in_threadpool(_decode_image_bytes, data)
        except Exception as e:
            raise HTTPException(
                status_code=400, detail=f"Error reading uploaded image: {e}"
            ) from e
    elif image_url:
        # load_image_from_url already raises HTTPException(400) on failure;
        # it is deliberately not wrapped so that status code is preserved.
        image = await run_in_threadpool(load_image_from_url, image_url)
    else:
        raise HTTPException(status_code=400, detail="Provide either 'file' or 'image_url'.")

    try:
        buf = await run_in_threadpool(_render_png, image, resolution)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return StreamingResponse(buf, media_type="image/png")
# -------------------------
# Developer Test Page (Bootstrap)
# -------------------------
@app.get("/", response_class=HTMLResponse)
async def index():
    """
    Serve a small Bootstrap page for manually exercising /remove-background.

    The page has two forms: a file upload and an image URL, each with a
    resolution field. Both POST to ``/remove-background`` and the returned PNG
    is displayed. API and network errors are reported with an alert instead of
    being shown as a broken image.

    Returns:
        HTMLResponse containing the test page.
    """
    html = """
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1">
  <title>Background Remover API Test</title>
  <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css" rel="stylesheet">
  <style>
    body { background-color: #f8f9fa; padding-top: 40px; }
    .container { max-width: 700px; }
    img { max-width: 100%; margin-top: 20px; border-radius: 10px; }
  </style>
</head>
<body>
  <div class="container text-center">
    <h2 class="mb-4">Background Remover API Tester</h2>
    <form id="uploadForm" class="mb-4" enctype="multipart/form-data">
      <div class="mb-3">
        <label for="fileInput" class="form-label">Upload Image (any format, e.g. JPG, PNG, HEIC):</label>
        <input class="form-control" type="file" id="fileInput" name="file" accept="image/*">
      </div>
      <div class="mb-3">
        <label for="resInput" class="form-label">Resolution (default 512):</label>
        <input class="form-control" type="number" id="resInput" name="resolution" value="512" min="64" max="2048">
      </div>
      <button class="btn btn-primary" type="submit">Remove Background</button>
    </form>
    <div class="mb-4">OR</div>
    <form id="urlForm" class="mb-4">
      <div class="mb-3">
        <label for="urlInput" class="form-label">Enter Image URL:</label>
        <input class="form-control" type="text" id="urlInput" placeholder="https://example.com/image.jpg">
      </div>
      <div class="mb-3">
        <label for="urlResInput" class="form-label">Resolution (default 512):</label>
        <input class="form-control" type="number" id="urlResInput" name="resolution" value="512" min="64" max="2048">
      </div>
      <button class="btn btn-success" type="submit">Remove Background</button>
    </form>
    <div id="resultContainer" class="mt-4">
      <h5>Result:</h5>
      <img id="resultImg" src="" alt="">
    </div>
  </div>
  <script>
    const uploadForm = document.getElementById("uploadForm");
    const urlForm = document.getElementById("urlForm");
    const resultImg = document.getElementById("resultImg");

    // POST the form data to the API; show the PNG on success, alert on error.
    async function submitAndShow(formData) {
      let response;
      try {
        response = await fetch("/remove-background", { method: "POST", body: formData });
      } catch (err) {
        alert("Request failed: " + err);
        return;
      }
      if (!response.ok) {
        let detail = response.statusText;
        try {
          const body = await response.json();
          if (body && body.detail) {
            detail = typeof body.detail === "string" ? body.detail : JSON.stringify(body.detail);
          }
        } catch (_) {
          // Error body was not JSON; keep the HTTP status text.
        }
        alert("Error " + response.status + ": " + detail);
        return;
      }
      const blob = await response.blob();
      // Release the previous result so repeated runs do not leak blob URLs.
      if (resultImg.src.startsWith("blob:")) URL.revokeObjectURL(resultImg.src);
      resultImg.src = URL.createObjectURL(blob);
    }

    uploadForm.addEventListener("submit", async e => {
      e.preventDefault();
      const fileInput = document.getElementById("fileInput");
      const res = document.getElementById("resInput").value || 512;
      if (!fileInput.files.length) return alert("Please select a file!");
      const formData = new FormData();
      formData.append("file", fileInput.files[0]);
      formData.append("resolution", res);
      await submitAndShow(formData);
    });

    urlForm.addEventListener("submit", async e => {
      e.preventDefault();
      const url = document.getElementById("urlInput").value.trim();
      const res = document.getElementById("urlResInput").value || 512;
      if (!url) return alert("Please enter an image URL!");
      const formData = new FormData();
      formData.append("image_url", url);
      formData.append("resolution", res);
      await submitAndShow(formData);
    });
  </script>
</body>
</html>
"""
    return HTMLResponse(html)
# -------------------------
# Run App
# -------------------------
if __name__ == "__main__":
    # When executed as a script, serve the app on all interfaces at port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")