import os import asyncio from concurrent.futures import ThreadPoolExecutor from fastapi import FastAPI, UploadFile, File, Form, HTTPException from fastapi.responses import StreamingResponse, HTMLResponse from PIL import Image, ImageSequence import pillow_heif # HEIC/HEIF support import numpy as np import torch from transformers import AutoModelForImageSegmentation from io import BytesIO from loadimg import load_img import uvicorn # ------------------------- # Enable HEIC/HEIF Support # ------------------------- pillow_heif.register_heif_opener() # ------------------------- # Thread Pool for Concurrency # ------------------------- executor = ThreadPoolExecutor(max_workers=os.cpu_count() or 4) # ------------------------- # Model Setup (Load Once) # ------------------------- MODEL_DIR = "models/BiRefNet" os.makedirs(MODEL_DIR, exist_ok=True) device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {device}") print("Loading BiRefNet model (first run may take a while)...") birefnet = AutoModelForImageSegmentation.from_pretrained( "ZhengPeng7/BiRefNet", cache_dir=MODEL_DIR, trust_remote_code=True ) birefnet.to(device) birefnet.eval() print(f"Model loaded successfully on {device}.") # ------------------------- # Image Preprocessing # ------------------------- TARGET_SIZE = (512, 512) # Lower resolution for faster inference def transform_image(image: Image.Image) -> torch.Tensor: image = image.resize(TARGET_SIZE) arr = np.array(image).astype(np.float32) / 255.0 mean = np.array([0.485, 0.456, 0.406], dtype=np.float32) std = np.array([0.229, 0.224, 0.225], dtype=np.float32) arr = (arr - mean) / std arr = np.transpose(arr, (2, 0, 1)) # HWC -> CHW tensor = torch.from_numpy(arr).unsqueeze(0).to(torch.float32).to(device) return tensor def process_image_sync(image: Image.Image) -> BytesIO: """Process image synchronously and return PNG bytes (in-memory).""" image_size = image.size input_tensor = transform_image(image) with torch.no_grad(): if device == "cuda": # Mixed precision for GPU with torch.cuda.amp.autocast(): preds = birefnet(input_tensor)[-1].sigmoid().cpu() else: # CPU fallback preds = birefnet(input_tensor)[-1].sigmoid().cpu() pred = preds[0, 0].numpy() mask = Image.fromarray((pred * 255).astype(np.uint8)).resize(image_size) image = image.copy() image.putalpha(mask) output_buffer = BytesIO() image.save(output_buffer, format="PNG") output_buffer.seek(0) return output_buffer async def process_image_async(image: Image.Image) -> BytesIO: """Run processing asynchronously in thread pool (no disk I/O).""" loop = asyncio.get_event_loop() return await loop.run_in_executor(executor, process_image_sync, image) # ------------------------- # Safe Image Loader # ------------------------- def open_image_safely(file_bytes: bytes) -> Image.Image: """Open image safely (HEIC, HEIF, PDF, SVG, GIF, PNG, JPG, etc).""" try: img = Image.open(BytesIO(file_bytes)) fmt = (img.format or "").lower() # Handle PDF: first page if fmt == "pdf": from pdf2image import convert_from_bytes pdf_images = convert_from_bytes(file_bytes, first_page=1, last_page=1) return pdf_images[0].convert("RGB") # Handle GIF: first frame if fmt == "gif" and getattr(img, "is_animated", False): img.seek(0) return img.convert("RGB") # Handle SVG if fmt == "svg": import cairosvg png_bytes = cairosvg.svg2png(bytestring=file_bytes) return Image.open(BytesIO(png_bytes)).convert("RGB") # Other formats (HEIC, HEIF, JPG, PNG) return img.convert("RGB") except Exception as e: raise HTTPException(status_code=400, detail=f"Unsupported or corrupted image: {e}") # ------------------------- # FastAPI App # ------------------------- app = FastAPI(title="Background Removal API", description="Removes image backgrounds in-memory") # ------------------------- # API Endpoints # ------------------------- @app.post("/remove_bg_file") async def remove_bg_file(file: UploadFile = File(...)): try: contents = await file.read() image = open_image_safely(contents) output_buffer = await process_image_async(image) return StreamingResponse(output_buffer, media_type="image/png") except HTTPException as e: raise e except Exception as e: raise HTTPException(status_code=500, detail=f"Error processing image: {e}") @app.post("/remove_bg_url") async def remove_bg_url(image_url: str = Form(...)): try: image = load_img(image_url, output_type="pil").convert("RGB") output_buffer = await process_image_async(image) return StreamingResponse(output_buffer, media_type="image/png") except Exception as e: raise HTTPException(status_code=500, detail=f"Error processing URL: {e}") # ------------------------- # Web Interface # ------------------------- @app.get("/", response_class=HTMLResponse) async def index(): html = """