Spaces:
Running
Running
| import os | |
| import asyncio | |
| from concurrent.futures import ThreadPoolExecutor | |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException | |
| from fastapi.responses import StreamingResponse, HTMLResponse | |
| from PIL import Image, ImageSequence | |
| import pillow_heif # HEIC/HEIF support | |
| import numpy as np | |
| import torch | |
| from transformers import AutoModelForImageSegmentation | |
| from io import BytesIO | |
| from loadimg import load_img | |
| import uvicorn | |
# -------------------------
# Enable HEIC/HEIF Support
# -------------------------
# Register pillow_heif's opener so HEIC/HEIF uploads go through PIL.
pillow_heif.register_heif_opener()

# -------------------------
# Thread Pool for Concurrency
# -------------------------
# Shared pool used to keep blocking image work off the event loop.
# Sized to the reported CPU count, or 4 when os.cpu_count() returns None.
executor = ThreadPoolExecutor(
    max_workers=os.cpu_count() or 4,
)
# -------------------------
# Model Setup (Load Once)
# -------------------------
# Directory handed to from_pretrained as its cache location.
MODEL_DIR = "models/BiRefNet"
os.makedirs(MODEL_DIR, exist_ok=True)

# Prefer the GPU when torch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

print("Loading BiRefNet model (first run may take a while)...")
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    cache_dir=MODEL_DIR,
    trust_remote_code=True,
)
# Move the weights to the selected device and switch to evaluation mode;
# the model is loaded once at import time and shared by every request.
birefnet.to(device)
birefnet.eval()
print(f"Model loaded successfully on {device}.")
| # ------------------------- | |
| # Image Preprocessing | |
| # ------------------------- | |
TARGET_SIZE = (512, 512)  # Lower resolution for faster inference


def transform_image(image: Image.Image, size: tuple = TARGET_SIZE) -> torch.Tensor:
    """Convert a PIL image into a normalized 1x3xHxW float32 tensor on `device`.

    Args:
        image: Source image. Any mode is accepted; non-RGB images are
            converted to RGB first so the 3-channel normalization applies.
        size: (width, height) the image is resized to before inference.
            Defaults to TARGET_SIZE.

    Returns:
        A float32 tensor with a leading batch dimension, moved to the
        module-level `device`.
    """
    # The mean/std arrays below have three entries, so a grayscale, palette
    # or RGBA input would either fail to broadcast or break the HWC
    # transpose. Normalize the mode up front.
    if image.mode != "RGB":
        image = image.convert("RGB")

    resized = image.resize(size)
    arr = np.asarray(resized, dtype=np.float32) / 255.0

    # Per-channel normalization constants.
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    arr = (arr - mean) / std

    # HWC -> CHW; ascontiguousarray because torch.from_numpy shares memory
    # with the array and the transpose is only a strided view.
    arr = np.ascontiguousarray(np.transpose(arr, (2, 0, 1)))

    # arr is already float32, so no further dtype cast is needed.
    return torch.from_numpy(arr).unsqueeze(0).to(device)
def process_image_sync(image: Image.Image) -> BytesIO:
    """Remove the background from `image` and return it as in-memory PNG bytes.

    The model's last output is passed through a sigmoid, scaled to 0-255,
    resized back to the input's size and attached to a copy of the input
    as its alpha channel.

    Args:
        image: The image to process. It is not modified; a copy receives
            the alpha channel.

    Returns:
        A BytesIO positioned at offset 0 containing the PNG-encoded result.
    """
    image_size = image.size
    input_tensor = transform_image(image)

    # One code path for both devices: autocast (mixed precision) is only
    # enabled on CUDA and is a no-op on CPU.
    use_amp = device == "cuda"
    with torch.no_grad(), torch.autocast(device_type=device, enabled=use_amp):
        preds = birefnet(input_tensor)[-1].sigmoid()

    # Under autocast the output may be half precision; cast to float32 so
    # the NumPy conversion and the 0-255 scaling behave the same everywhere.
    pred = preds[0, 0].float().cpu().numpy()
    mask = Image.fromarray((pred * 255).astype(np.uint8)).resize(image_size)

    # Work on a copy so the caller's image is left untouched.
    result = image.copy()
    result.putalpha(mask)

    output_buffer = BytesIO()
    result.save(output_buffer, format="PNG")
    output_buffer.seek(0)  # rewind so the consumer reads from the start
    return output_buffer
async def process_image_async(image: Image.Image) -> BytesIO:
    """Run `process_image_sync` in the shared thread pool and await its result.

    Args:
        image: The image to process.

    Returns:
        The BytesIO produced by `process_image_sync`.
    """
    # Inside a coroutine there is always a running loop; get_running_loop()
    # returns it directly and never creates a new one implicitly.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, process_image_sync, image)
| # ------------------------- | |
| # Safe Image Loader | |
| # ------------------------- | |
def open_image_safely(file_bytes: bytes) -> Image.Image:
    """Decode uploaded bytes into an RGB PIL image.

    PDF and SVG payloads are detected from their leading bytes and
    rasterized with pdf2image / cairosvg; everything else is handed to
    Pillow (with the HEIF opener registered at import time).

    Args:
        file_bytes: Raw contents of the uploaded file.

    Returns:
        The decoded image converted to RGB. For PDFs this is the first
        page; for animated GIFs the first frame.

    Raises:
        HTTPException: 400 if the data cannot be decoded by any handler.
    """
    try:
        # Sniff the payload BEFORE calling Image.open: Pillow cannot read
        # PDF or SVG, so checking `img.format` after opening never reaches
        # those handlers (Image.open raises first).
        head = file_bytes[:2048].lstrip(b"\xef\xbb\xbf \t\r\n").lower()

        # Handle PDF: first page
        if head.startswith(b"%pdf"):
            from pdf2image import convert_from_bytes

            pdf_images = convert_from_bytes(file_bytes, first_page=1, last_page=1)
            return pdf_images[0].convert("RGB")

        # Handle SVG: markup that contains an <svg element near the top
        if head.startswith(b"<") and b"<svg" in head:
            import cairosvg

            png_bytes = cairosvg.svg2png(bytestring=file_bytes)
            return Image.open(BytesIO(png_bytes)).convert("RGB")

        img = Image.open(BytesIO(file_bytes))
        fmt = (img.format or "").lower()

        # Handle GIF: make sure the first frame is selected
        if fmt == "gif" and getattr(img, "is_animated", False):
            img.seek(0)

        # Raster formats (HEIC, HEIF, JPG, PNG, GIF, ...)
        return img.convert("RGB")
    except Exception as e:
        raise HTTPException(
            status_code=400, detail=f"Unsupported or corrupted image: {e}"
        ) from e
# -------------------------
# FastAPI App
# -------------------------
# Single application object; uvicorn is pointed at it by name at the bottom.
app = FastAPI(
    title="Background Removal API",
    description="Removes image backgrounds in-memory",
)
| # ------------------------- | |
| # API Endpoints | |
| # ------------------------- | |
# The handler must be registered on the app; the web page posts uploads to
# this exact path.
@app.post("/remove_bg_file")
async def remove_bg_file(file: UploadFile = File(...)):
    """Remove the background from an uploaded file and return a PNG.

    Args:
        file: The uploaded file, sent as the multipart form field "file".

    Returns:
        A StreamingResponse with media type image/png.

    Raises:
        HTTPException: whatever `open_image_safely` raises for undecodable
            input (status 400), or status 500 for any other failure.
    """
    try:
        contents = await file.read()
        image = open_image_safely(contents)
        output_buffer = await process_image_async(image)
        return StreamingResponse(output_buffer, media_type="image/png")
    except HTTPException:
        # Already carries the intended status code; re-raise unchanged.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error processing image: {e}"
        ) from e
# The handler must be registered on the app; the web page posts URLs to
# this exact path.
@app.post("/remove_bg_url")
async def remove_bg_url(image_url: str = Form(...)):
    """Load an image from `image_url`, remove its background, return a PNG.

    Args:
        image_url: Location of the image, sent as the form field "image_url".

    Returns:
        A StreamingResponse with media type image/png.

    Raises:
        HTTPException: status 500 if loading or processing fails.
    """

    def _load() -> Image.Image:
        # NOTE(review): image_url is client-supplied and passed straight to
        # load_img. If load_img also accepts local paths or other source
        # kinds, consider restricting this endpoint to http(s) URLs.
        return load_img(image_url, output_type="pil").convert("RGB")

    try:
        # load_img is a synchronous call; run it in the thread pool so the
        # load does not stall the event loop for other requests.
        loop = asyncio.get_running_loop()
        image = await loop.run_in_executor(executor, _load)
        output_buffer = await process_image_async(image)
        return StreamingResponse(output_buffer, media_type="image/png")
    except HTTPException:
        # Keep an explicit status code if one was already chosen.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error processing URL: {e}"
        ) from e
| # ------------------------- | |
| # Web Interface | |
| # ------------------------- | |
# Serve the page at the site root; without a route registration the HTML
# below is never reachable.
@app.get("/", response_class=HTMLResponse)
async def index():
    """Return the single-page web UI.

    The page offers a file-upload form and a URL form. Its script posts to
    /remove_bg_file and /remove_bg_url respectively and shows the original
    and processed images side by side.
    """
    html = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>Background Removal Tool</title>
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/css/bootstrap.min.css" rel="stylesheet">
<style>
body { padding: 30px; background-color: #f8f9fa; }
.container { max-width: 700px; background: #fff; padding: 20px; border-radius: 10px;
box-shadow: 0 0 10px rgba(0,0,0,0.1);}
img { max-width: 100%; border-radius: 8px; }
.preview-grid { display: flex; gap: 15px; justify-content: space-between; margin-top: 15px; }
.preview-item { flex: 1; text-align: center; }
.preview-item img { width: 100%; border: 1px solid #ddd; padding: 5px; background: #fff; }
</style>
</head>
<body>
<div class="container">
<h2 class="mb-4">Background Removal Tool</h2>
<form id="fileForm" enctype="multipart/form-data">
<div class="mb-3">
<label for="fileInput" class="form-label">Upload Image</label>
<input class="form-control" type="file" id="fileInput" name="file"
accept="image/*,application/pdf,.heic,.heif,.svg">
</div>
<button class="btn btn-primary" type="submit">Remove Background</button>
</form>
<hr>
<form id="urlForm">
<div class="mb-3">
<label for="urlInput" class="form-label">Image URL</label>
<input class="form-control" type="text" id="urlInput" placeholder="Enter image URL">
</div>
<button class="btn btn-success" type="submit">Remove Background</button>
</form>
<hr>
<h5>Preview:</h5>
<div class="preview-grid">
<div class="preview-item">
<strong>Before</strong>
<img id="beforeImg" src="" alt="Original Image">
</div>
<div class="preview-item">
<strong>After</strong>
<img id="afterImg" src="" alt="Processed Image">
</div>
</div>
</div>
<script>
const fileForm = document.getElementById('fileForm');
const urlForm = document.getElementById('urlForm');
const beforeImg = document.getElementById('beforeImg');
const afterImg = document.getElementById('afterImg');
fileForm.addEventListener('submit', async (e) => {
e.preventDefault();
const fileInput = document.getElementById('fileInput');
if (!fileInput.files.length) return alert("Select a file!");
const file = fileInput.files[0];
beforeImg.src = URL.createObjectURL(file);
const formData = new FormData();
formData.append("file", file);
const res = await fetch('/remove_bg_file', { method: 'POST', body: formData });
if (!res.ok) {
const err = await res.json();
alert(err.detail || "Failed to process image");
return;
}
const blob = await res.blob();
afterImg.src = URL.createObjectURL(blob);
});
urlForm.addEventListener('submit', async (e) => {
e.preventDefault();
const urlInput = document.getElementById('urlInput').value;
if (!urlInput) return alert("Enter an image URL");
beforeImg.src = urlInput;
const formData = new FormData();
formData.append("image_url", urlInput);
const res = await fetch('/remove_bg_url', { method: 'POST', body: formData });
if (!res.ok) {
const err = await res.json();
alert(err.detail || "Failed to process image URL");
return;
}
const blob = await res.blob();
afterImg.src = URL.createObjectURL(blob);
});
</script>
</body>
</html>
"""
    return HTMLResponse(content=html)
# -------------------------
# Run Server (Auto-detect filename)
# -------------------------
if __name__ == "__main__":
    # Build the "<module>:app" import string from this script's own file
    # name, so it stays correct if the file is renamed.
    script_name = os.path.basename(__file__)
    module_name, _extension = os.path.splitext(script_name)
    uvicorn.run(f"{module_name}:app", host="0.0.0.0", port=7860, workers=2)