# BG_RM_AI / app.py
# Last commit: e875c80 "Update app.py" by um41r (verified)
import threading
from io import BytesIO

import torch
import torchvision.transforms as transforms
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from PIL import Image
from transformers import AutoModelForImageSegmentation
app = FastAPI()

# CORS policy: accept browser requests from any origin, with any method/header.
# NOTE(review): restrict allow_origins to the real frontend domain in production.
# Also, the CORS spec forbids a literal "*" origin together with credentials;
# Starlette's CORSMiddleware handles this combination by echoing the caller's
# Origin back, so effectively any site may make credentialed requests. Nothing
# visible in this file reads cookies or auth headers, so allow_credentials=False
# is probably safe -- confirm no client sends credentials before changing it.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Inference device: CUDA when PyTorch reports an available GPU, otherwise CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Load the BiRefNet segmentation model from the Hugging Face Hub.
# trust_remote_code=True lets transformers import and execute the Python model
# code stored in that Hub repository.
# NOTE(review): that code runs with this process's privileges; consider pinning
# `revision=` to a reviewed commit of ZhengPeng7/BiRefNet.
model = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
model.to(device)  # move weights to the chosen device
model.eval()  # inference mode for layers such as dropout / batch-norm
# Preprocessing applied to every upload before inference:
#   1. resize to a fixed 1024x1024 input (aspect ratio is not preserved),
#   2. convert to a float tensor in [0, 1],
#   3. normalize each RGB channel with the ImageNet mean / std values below.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)
# The endpoint below runs inference on thread-pool workers. This lock keeps
# model forward passes one at a time, as they were when the handler blocked the
# event loop, so concurrent uploads cannot stack several full-size passes (and
# their device memory) on the shared model at once. Decoding and PNG encoding
# still run in parallel outside the lock.
_inference_lock = threading.Lock()


def _decode_image(data: bytes) -> Image.Image:
    """Decode uploaded bytes into an RGB PIL image.

    Raises:
        HTTPException: 400 if Pillow cannot read the bytes as an image.
    """
    try:
        # .convert() forces the full decode, so truncated files fail here too.
        return Image.open(BytesIO(data)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError (unknown or empty data) and truncated-file errors
        # are OSError subclasses; DecompressionBombError rejects huge images.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image"
        ) from err


def _predict_mask(image: Image.Image) -> Image.Image:
    """Run the segmentation model on *image* and return a mask of the same size."""
    batch = transform(image).unsqueeze(0)  # add a batch dimension
    with _inference_lock, torch.no_grad():
        # Use the last element of the model output and squash it to [0, 1].
        # Assumes that element is (batch, 1, H, W): [0] and squeeze() reduce it
        # to (H, W) -- TODO confirm against the BiRefNet remote code.
        pred = model(batch.to(device))[-1].sigmoid().cpu()[0].squeeze()
    mask = transforms.ToPILImage()(pred)
    # Model output is 1024x1024; stretch it back to the upload's dimensions.
    return mask.resize(image.size)


def _remove_background(data: bytes) -> bytes:
    """Blocking pipeline: decode *data*, cut out the background, return PNG bytes.

    NOTE(review): EXIF orientation is not applied, so phone photos stored rotated
    come back rotated; consider PIL.ImageOps.exif_transpose if clients expect it.
    """
    image = _decode_image(data)
    mask = _predict_mask(image)
    # convert() returns a new image, so no explicit copy is needed.
    result_image = image.convert("RGBA")
    result_image.putalpha(mask)
    img_io = BytesIO()
    result_image.save(img_io, format="PNG")
    return img_io.getvalue()


@app.post("/remove_bg")
async def remove_bg(file: UploadFile = File(...)):
    """Remove the background from an uploaded image.

    Returns:
        A PNG response: the original image with the predicted foreground mask
        installed as its alpha channel, at the original resolution.

    Raises:
        HTTPException: 400 if the upload is not a readable image.
    """
    data = await file.read()
    # Decoding, inference and encoding are synchronous and CPU/GPU-bound; running
    # them in the thread pool keeps the event loop responsive (e.g. for "/").
    png_bytes = await run_in_threadpool(_remove_background, data)
    return Response(content=png_bytes, media_type="image/png")
@app.get("/")
def read_root():
    """Health check: confirm the service is up and answering requests."""
    status_payload = {
        "status": "ok",
        "message": "Background removal API is running",
    }
    return status_payload
if __name__ == "__main__":
    # Script entry point (how the Hugging Face Space launches the app): serve on
    # all network interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")