merve HF Staff commited on
Commit
7a6faa9
·
verified ·
1 Parent(s): 53a50aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -7
app.py CHANGED
@@ -7,11 +7,10 @@ from torchvision import transforms
7
  from typing import Union, Tuple
8
  from PIL import Image
9
 
10
- torch.set_float32_matmul_precision(["high", "highest"][0])
11
-
12
  birefnet = AutoModelForImageSegmentation.from_pretrained(
13
- "ZhengPeng7/BiRefNet", trust_remote_code=True
14
  )
 
15
  #birefnet.to("cuda")
16
 
17
  transform_image = transforms.Compose(
@@ -59,10 +58,9 @@ def process(image: Image.Image) -> Image.Image:
59
  PIL.Image: The image with the background removed, using the segmentation mask as transparency.
60
  """
61
  image_size = image.size
62
- input_images = transform_image(image).unsqueeze(0)#.to("cuda")
63
- # Prediction
64
- with torch.no_grad():
65
- preds = birefnet(input_images)[-1].sigmoid().cpu()
66
  pred = preds[0].squeeze()
67
  pred_pil = transforms.ToPILImage()(pred)
68
  mask = pred_pil.resize(image_size)
 
7
  from typing import Union, Tuple
8
  from PIL import Image
9
 
 
 
10
  birefnet = AutoModelForImageSegmentation.from_pretrained(
11
+ "ZhengPeng7/BiRefNet", low_cpu_mem_usage=False, trust_remote_code=True, torch_dtype=torch.float32, device_map=None
12
  )
13
+ birefnet = birefnet.eval()
14
  #birefnet.to("cuda")
15
 
16
  transform_image = transforms.Compose(
 
58
  PIL.Image: The image with the background removed, using the segmentation mask as transparency.
59
  """
60
  image_size = image.size
61
+ input_images = transform_image(image).unsqueeze(0)
62
+ with torch.inference_mode():
63
+ preds = birefnet(input_images)[-1].sigmoid().detach().cpu()
 
64
  pred = preds[0].squeeze()
65
  pred_pil = transforms.ToPILImage()(pred)
66
  mask = pred_pil.resize(image_size)