File size: 2,257 Bytes
22d0b50
 
 
 
 
 
 
 
 
 
 
e875c80
22d0b50
 
e875c80
22d0b50
 
 
 
 
e875c80
22d0b50
 
 
 
 
e875c80
22d0b50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e875c80
22d0b50
e875c80
22d0b50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e875c80
22d0b50
 
 
 
e875c80
22d0b50
e875c80
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from io import BytesIO

import torch
import torchvision.transforms as transforms
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from PIL import Image, UnidentifiedImageError
from transformers import AutoModelForImageSegmentation

# Application and model setup. Everything below runs once at import time:
# creating the FastAPI app, configuring CORS, downloading/loading the
# segmentation model, and building the preprocessing pipeline.
app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True is
# rejected by browsers for credentialed requests (the CORS spec forbids
# `Access-Control-Allow-Origin: *` with credentials) — confirm whether
# credentials are actually needed, or pin the frontend origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # For production, specify your frontend domain
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load model directly from Hugging Face with trust_remote_code=True
# NOTE(review): trust_remote_code executes Python shipped with the model
# repo — acceptable only because the repo ('ZhengPeng7/BiRefNet') is trusted.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForImageSegmentation.from_pretrained('ZhengPeng7/BiRefNet', trust_remote_code=True)
model.to(device)
model.eval()  # inference mode: disables dropout/batch-norm updates

# Define transforms based on the other implementation
# Resize to the model's fixed 1024x1024 input, then normalize with the
# standard ImageNet channel statistics (mean, std).
transform = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

@app.post("/remove_bg")
async def remove_bg(file: UploadFile = File(...)):
    """Remove the background from an uploaded image.

    Runs the BiRefNet segmentation model on the upload, uses the predicted
    foreground mask as an alpha channel, and returns the result as a PNG
    with transparent background.

    Args:
        file: The uploaded image (any format Pillow can decode).

    Returns:
        A ``Response`` with ``image/png`` content: the original image at its
        original size, with background pixels made transparent.

    Raises:
        HTTPException: 400 if the upload is empty or not a decodable image.
    """
    raw = await file.read()
    if not raw:
        raise HTTPException(status_code=400, detail="Empty upload.")

    # Reject undecodable uploads with a 400 instead of letting Pillow's
    # exception bubble up as an opaque 500.
    try:
        image = Image.open(BytesIO(raw)).convert("RGB")
    except (UnidentifiedImageError, OSError) as exc:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from exc
    original_size = image.size  # (width, height), used to restore the mask

    # Preprocess: resize to 1024x1024, normalize, add batch dim, move to device.
    input_tensor = transform(image).unsqueeze(0).to(device)

    # Perform inference - notice the key differences here
    with torch.no_grad():
        # Take the last element [-1] from the model output and apply sigmoid
        # to get per-pixel foreground probabilities in [0, 1].
        pred = model(input_tensor)[-1].sigmoid().cpu()[0].squeeze()

    # Convert to PIL image and resize back to original size
    mask = transforms.ToPILImage()(pred)
    mask = mask.resize(original_size)

    # Apply mask to original image: the grayscale mask becomes the alpha channel.
    result_image = image.copy().convert("RGBA")
    result_image.putalpha(mask)

    # Save image to bytes (PNG preserves the alpha channel)
    img_io = BytesIO()
    result_image.save(img_io, format="PNG")
    img_io.seek(0)

    return Response(content=img_io.getvalue(), media_type="image/png")

# Liveness probe so platform health checks get a 200 at the root path.
@app.get("/")
def read_root():
    """Report that the background-removal service is up."""
    payload = {
        "status": "ok",
        "message": "Background removal API is running",
    }
    return payload

# Make sure this is included for Hugging Face Spaces
if __name__ == "__main__":
    # Bind on all interfaces; 7860 is the port Hugging Face Spaces expects.
    uvicorn.run(app, host="0.0.0.0", port=7860)