Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,12 +1,12 @@
|
|
| 1 |
import numpy as np
|
| 2 |
import torch
|
| 3 |
import torch.nn.functional as F
|
| 4 |
-
from torchvision.transforms.functional import normalize
|
| 5 |
from huggingface_hub import hf_hub_download
|
| 6 |
from briarmbg import BriaRMBG
|
| 7 |
from PIL import Image
|
| 8 |
from fastapi import FastAPI, File, UploadFile
|
| 9 |
-
from fastapi.responses import FileResponse
|
| 10 |
import os
|
| 11 |
|
| 12 |
app = FastAPI()
|
|
@@ -15,10 +15,10 @@ app = FastAPI()
|
|
| 15 |
net = BriaRMBG()
|
| 16 |
model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
|
| 17 |
if torch.cuda.is_available():
|
| 18 |
-
net.load_state_dict(torch.load(model_path))
|
| 19 |
net = net.cuda()
|
| 20 |
else:
|
| 21 |
-
net.load_state_dict(torch.load(model_path, map_location="cpu"))
|
| 22 |
net.eval()
|
| 23 |
|
| 24 |
def resize_image(image):
|
|
@@ -58,6 +58,10 @@ def process_image(image: Image.Image):
|
|
| 58 |
|
| 59 |
return output_path
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
@app.post("/remove-background/")
|
| 62 |
async def remove_background(file: UploadFile = File(...)):
|
| 63 |
image = Image.open(file.file)
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
import torch
|
| 3 |
import torch.nn.functional as F
|
| 4 |
+
from torchvision.transforms.functional import normalize
|
| 5 |
from huggingface_hub import hf_hub_download
|
| 6 |
from briarmbg import BriaRMBG
|
| 7 |
from PIL import Image
|
| 8 |
from fastapi import FastAPI, File, UploadFile
|
| 9 |
+
from fastapi.responses import FileResponse, JSONResponse
|
| 10 |
import os
|
| 11 |
|
| 12 |
app = FastAPI()
|
|
|
|
| 15 |
net = BriaRMBG()
|
| 16 |
model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
|
| 17 |
if torch.cuda.is_available():
|
| 18 |
+
net.load_state_dict(torch.load(model_path, map_location="cuda", weights_only=True))
|
| 19 |
net = net.cuda()
|
| 20 |
else:
|
| 21 |
+
net.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=True))
|
| 22 |
net.eval()
|
| 23 |
|
| 24 |
def resize_image(image):
|
|
|
|
| 58 |
|
| 59 |
return output_path
|
| 60 |
|
| 61 |
+
@app.get("/")
|
| 62 |
+
def read_root():
|
| 63 |
+
return {"message": "Welcome to the Background Removal API"}
|
| 64 |
+
|
| 65 |
@app.post("/remove-background/")
|
| 66 |
async def remove_background(file: UploadFile = File(...)):
|
| 67 |
image = Image.open(file.file)
|