fantaxy commited on
Commit
cc6ce4e
·
verified ·
1 Parent(s): b73ce25

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -1,12 +1,12 @@
1
  import numpy as np
2
  import torch
3
  import torch.nn.functional as F
4
- from torchvision.transforms.functional import normalize # 수정된 부분
5
  from huggingface_hub import hf_hub_download
6
  from briarmbg import BriaRMBG
7
  from PIL import Image
8
  from fastapi import FastAPI, File, UploadFile
9
- from fastapi.responses import FileResponse
10
  import os
11
 
12
  app = FastAPI()
@@ -15,10 +15,10 @@ app = FastAPI()
15
  net = BriaRMBG()
16
  model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
17
  if torch.cuda.is_available():
18
- net.load_state_dict(torch.load(model_path))
19
  net = net.cuda()
20
  else:
21
- net.load_state_dict(torch.load(model_path, map_location="cpu"))
22
  net.eval()
23
 
24
  def resize_image(image):
@@ -58,6 +58,10 @@ def process_image(image: Image.Image):
58
 
59
  return output_path
60
 
 
 
 
 
61
  @app.post("/remove-background/")
62
  async def remove_background(file: UploadFile = File(...)):
63
  image = Image.open(file.file)
 
1
  import numpy as np
2
  import torch
3
  import torch.nn.functional as F
4
+ from torchvision.transforms.functional import normalize
5
  from huggingface_hub import hf_hub_download
6
  from briarmbg import BriaRMBG
7
  from PIL import Image
8
  from fastapi import FastAPI, File, UploadFile
9
+ from fastapi.responses import FileResponse, JSONResponse
10
  import os
11
 
12
  app = FastAPI()
 
15
  net = BriaRMBG()
16
  model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
17
  if torch.cuda.is_available():
18
+ net.load_state_dict(torch.load(model_path, map_location="cuda", weights_only=True))
19
  net = net.cuda()
20
  else:
21
+ net.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=True))
22
  net.eval()
23
 
24
  def resize_image(image):
 
58
 
59
  return output_path
60
 
61
+ @app.get("/")
62
+ def read_root():
63
+ return {"message": "Welcome to the Background Removal API"}
64
+
65
  @app.post("/remove-background/")
66
  async def remove_background(file: UploadFile = File(...)):
67
  image = Image.open(file.file)