Lord-Raven commited on
Commit
8adef27
·
1 Parent(s): 9308bbb

Trying to add CPU support.

Browse files
Files changed (1) hide show
  1. app.py +43 -29
app.py CHANGED
@@ -17,10 +17,15 @@ from briarmbg import BriaRMBG
17
  from depth_anything_v2.dpt import DepthAnythingV2
18
 
19
 
20
- net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
21
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
- net.to(device)
23
- net.eval()
 
 
 
 
 
24
 
25
  def resize_image(image):
26
  image = image.convert('RGB')
@@ -28,36 +33,45 @@ def resize_image(image):
28
  image = image.resize(model_input_size, Image.BILINEAR)
29
  return image
30
 
31
- @spaces.GPU(duration=6)
32
- def process_background(image):
33
-
34
- # prepare input
35
- orig_image = Image.fromarray(image)
36
- w,h = orig_im_size = orig_image.size
37
- image = resize_image(orig_image)
38
- im_np = np.array(image)
39
- im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2,0,1)
40
- im_tensor = torch.unsqueeze(im_tensor,0)
41
- im_tensor = torch.divide(im_tensor,255.0)
42
- im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1.0,1.0,1.0])
43
- if torch.cuda.is_available():
44
- im_tensor=im_tensor.cuda()
45
-
46
- #inference
47
- result=net(im_tensor)
48
- # post process
49
- result = torch.squeeze(F.interpolate(result[0][0], size=(h,w), mode='bilinear') ,0)
50
- ma = torch.max(result)
51
- mi = torch.min(result)
52
- result = (result-mi)/(ma-mi)
53
- # image to pil
54
- result_array = (result*255).cpu().data.numpy().astype(np.uint8)
55
  pil_mask = Image.fromarray(np.squeeze(result_array))
56
- # add the mask on the original image as alpha channel
57
  new_im = orig_image.copy()
58
  new_im.putalpha(pil_mask)
59
  return new_im
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  css = """
63
  #img-display-container {
 
17
  from depth_anything_v2.dpt import DepthAnythingV2
18
 
19
 
20
+ net_cpu = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
21
+ net_cpu.to('cpu')
22
+ net_cpu.eval()
23
+
24
+ net_gpu = None
25
+ if torch.cuda.is_available():
26
+ net_gpu = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
27
+ net_gpu.to('cuda')
28
+ net_gpu.eval()
29
 
30
  def resize_image(image):
31
  image = image.convert('RGB')
 
33
  image = image.resize(model_input_size, Image.BILINEAR)
34
  return image
35
 
36
+ def _run_rmbg_on_image(image_np, net, device_str):
37
+ """Shared helper: run RMBG net on a numpy image and return a PIL RGBA with alpha mask."""
38
+ orig_image = Image.fromarray(image_np)
39
+ w, h = orig_image.size
40
+ img = resize_image(orig_image)
41
+ im_np = np.array(img)
42
+ im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0) / 255.0
43
+ im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
44
+ if device_str == 'cuda':
45
+ im_tensor = im_tensor.cuda()
46
+ with torch.no_grad():
47
+ result = net(im_tensor)
48
+ result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
49
+ ma = torch.max(result); mi = torch.min(result)
50
+ result = (result - mi) / (ma - mi + 1e-8)
51
+ result_array = (result * 255).cpu().numpy().astype(np.uint8)
 
 
 
 
 
 
 
 
52
  pil_mask = Image.fromarray(np.squeeze(result_array))
 
53
  new_im = orig_image.copy()
54
  new_im.putalpha(pil_mask)
55
  return new_im
56
 
57
+ @spaces.GPU(duration=6)
58
+ def process_background_gpu(image):
59
+ if net_gpu is None:
60
+ raise RuntimeError("No GPU instance available")
61
+ return _run_rmbg_on_image(image, net_gpu, 'cuda')
62
+
63
+ def process_background_cpu(image):
64
+ return _run_rmbg_on_image(image, net_cpu, 'cpu')
65
+
66
+ # wrapper used by the UI: try GPU first, fall back to CPU on any exception
67
+ def process_background(image):
68
+ try:
69
+ # attempt GPU call (this can raise if Zero-GPU is unavailable)
70
+ return process_background_gpu(image)
71
+ except Exception:
72
+ # fallback to CPU path
73
+ return process_background_cpu(image)
74
+
75
 
76
  css = """
77
  #img-display-container {