"""Background removal HTTP API.

Serves a FastAPI app that loads the ``ZhengPeng7/BiRefNet`` segmentation model
from Hugging Face at import time and exposes:

* ``POST /remove_bg`` - takes an uploaded image and returns a PNG whose alpha
  channel is the model's predicted mask.
* ``GET /`` - a simple health check.
"""

import threading
from io import BytesIO

import torch
import torchvision.transforms as transforms
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from PIL import Image
from transformers import AutoModelForImageSegmentation

app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site issue credentialed requests. Left unchanged here so existing frontends
# keep working; restrict allow_origins before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # For production, specify your frontend domain
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load model directly from Hugging Face with trust_remote_code=True.
# NOTE(review): trust_remote_code executes Python code shipped with the model
# repository; consider pinning a revision.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
model.to(device)
model.eval()

# Preprocessing applied to every upload before it is fed to the model.
transform = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# Inference runs in worker threads (see remove_bg); this lock keeps forward
# passes one-at-a-time, as they were when the handler ran on the event loop.
_inference_lock = threading.Lock()


def _decode_image(data: bytes) -> Image.Image:
    """Decode uploaded bytes into an RGB image.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image.
    """
    try:
        return Image.open(BytesIO(data)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        # UnidentifiedImageError is an OSError subclass, so empty, truncated
        # and non-image uploads all land here instead of surfacing as a 500.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image"
        ) from exc


def _remove_background(data: bytes) -> bytes:
    """Return PNG bytes of the uploaded image with the predicted mask as alpha.

    Args:
        data: Raw bytes of the uploaded image file.

    Returns:
        The encoded RGBA PNG.

    Raises:
        HTTPException: 400 if ``data`` is not a decodable image.
    """
    image = _decode_image(data)
    original_size = image.size

    input_tensor = transform(image).unsqueeze(0).to(device)

    # no_grad is entered here because grad mode is thread-local and this
    # function runs in a worker thread.
    with _inference_lock, torch.no_grad():
        # Take the last element [-1] from the model output and apply sigmoid.
        pred = model(input_tensor)[-1].sigmoid().cpu()[0].squeeze()

    # Convert to a PIL image and resize back to the original size.
    mask = transforms.ToPILImage()(pred).resize(original_size)

    # convert() already returns a new image, so no explicit copy is needed.
    result_image = image.convert("RGBA")
    result_image.putalpha(mask)

    img_io = BytesIO()
    result_image.save(img_io, format="PNG")
    return img_io.getvalue()


@app.post("/remove_bg")
async def remove_bg(file: UploadFile = File(...)):
    """Remove the background from an uploaded image.

    Args:
        file: The uploaded image file.

    Returns:
        A ``image/png`` response containing the RGBA result.

    Raises:
        HTTPException: 400 if the upload is not a decodable image.
    """
    data = await file.read()
    # Decoding, inference and PNG encoding are blocking; run them off the
    # event loop so other requests (e.g. the health check) stay responsive.
    png_bytes = await run_in_threadpool(_remove_background, data)
    return Response(content=png_bytes, media_type="image/png")


# Add a simple root route for health check
@app.get("/")
def read_root():
    """Report that the service is up."""
    return {"status": "ok", "message": "Background removal API is running"}


# Make sure this is included for Hugging Face Spaces
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)