Janeka commited on
Commit
e50da0f
·
verified ·
1 Parent(s): d770436

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -22
app.py CHANGED
@@ -4,20 +4,21 @@ from PIL import Image
4
  import numpy as np
5
  import torch
6
  import os
 
7
  import requests
8
  from tqdm import tqdm
9
  import subprocess
10
 
11
- # Clone InSPyReNet repository if not present
12
  if not os.path.exists('InSPyReNet'):
13
  print("Cloning InSPyReNet repository...")
14
  subprocess.run(['git', 'clone', 'https://github.com/plemeri/InSPyReNet.git'])
15
 
16
- # Add to Python path
17
- import sys
18
- sys.path.insert(0, 'InSPyReNet/lib')
19
 
20
- # Download model weights if not present
21
  def download_file(url, filename):
22
  response = requests.get(url, stream=True)
23
  total_size = int(response.headers.get('content-length', 0))
@@ -40,7 +41,7 @@ if not os.path.exists('InSPyReNet.pth'):
40
  "InSPyReNet.pth"
41
  )
42
 
43
- # Import after setup
44
  from InSPyReNet import InSPyReNet
45
  from modules.layers import load_model
46
  from utils.misc import load_config
@@ -53,32 +54,36 @@ model = InSPyReNet(cfg)
53
  model = load_model(model, 'InSPyReNet.pth', device)
54
  model.eval()
55
 
56
- def process_with_inspyrenet(image):
57
- # Convert to numpy and normalize
58
  image = np.array(image).astype(np.float32)
59
  image -= np.array([104.00699, 116.66877, 122.67892])
60
  image = image.transpose((2, 0, 1))
61
- image = torch.from_numpy(image).unsqueeze(0).to(device)
62
-
63
- # Predict
 
64
  with torch.no_grad():
65
- pred = model(image)
66
-
67
- # Create mask
68
  mask = (pred.squeeze().cpu().numpy() > 0.5).astype(np.uint8) * 255
69
  return mask
70
 
71
- def remove_background(input_image):
72
  try:
73
  if isinstance(input_image, np.ndarray):
74
  input_image = Image.fromarray(input_image)
75
 
76
- # Process with InSPyReNet
77
- mask = process_with_inspyrenet(input_image)
78
- output = input_image.copy()
79
- output.putalpha(Image.fromarray(mask))
 
 
 
 
 
 
80
 
81
- return output, Image.fromarray(mask)
82
 
83
  except Exception as e:
84
  print(f"Error: {str(e)}")
@@ -86,13 +91,20 @@ def remove_background(input_image):
86
 
87
  iface = gr.Interface(
88
  fn=remove_background,
89
- inputs=gr.Image(type="pil", label="Input Image"),
 
 
 
 
 
 
 
90
  outputs=[
91
  gr.Image(type="pil", label="Result"),
92
  gr.Image(type="pil", label="Mask")
93
  ],
94
  title="Professional Background Remover",
95
- description="Using InSPyReNet for high-quality background removal"
96
  )
97
 
98
  if __name__ == "__main__":
 
4
  import numpy as np
5
  import torch
6
  import os
7
+ import sys
8
  import requests
9
  from tqdm import tqdm
10
  import subprocess
11
 
12
+ # Clone and set up repository
13
  if not os.path.exists('InSPyReNet'):
14
  print("Cloning InSPyReNet repository...")
15
  subprocess.run(['git', 'clone', 'https://github.com/plemeri/InSPyReNet.git'])
16
 
17
+ # Add correct paths to system
18
+ sys.path.insert(0, os.path.abspath('InSPyReNet'))
19
+ sys.path.insert(0, os.path.abspath('InSPyReNet/lib'))
20
 
21
+ # Download model weights
22
  def download_file(url, filename):
23
  response = requests.get(url, stream=True)
24
  total_size = int(response.headers.get('content-length', 0))
 
41
  "InSPyReNet.pth"
42
  )
43
 
44
+ # Import after setting paths
45
  from InSPyReNet import InSPyReNet
46
  from modules.layers import load_model
47
  from utils.misc import load_config
 
54
  model = load_model(model, 'InSPyReNet.pth', device)
55
  model.eval()
56
 
57
+ def preprocess(image):
 
58
  image = np.array(image).astype(np.float32)
59
  image -= np.array([104.00699, 116.66877, 122.67892])
60
  image = image.transpose((2, 0, 1))
61
+ return torch.from_numpy(image).unsqueeze(0)
62
+
63
+ def process_with_inspyrenet(image):
64
+ image_tensor = preprocess(image).to(device)
65
  with torch.no_grad():
66
+ pred = model(image_tensor)
 
 
67
  mask = (pred.squeeze().cpu().numpy() > 0.5).astype(np.uint8) * 255
68
  return mask
69
 
70
+ def remove_background(input_image, model_choice="Rembg (U²-Net)"):
71
  try:
72
  if isinstance(input_image, np.ndarray):
73
  input_image = Image.fromarray(input_image)
74
 
75
+ if model_choice == "InSPyReNet":
76
+ mask = process_with_inspyrenet(input_image)
77
+ output = input_image.copy()
78
+ output.putalpha(Image.fromarray(mask))
79
+ else:
80
+ output = remove(input_image)
81
+ if output.mode == 'RGBA':
82
+ mask = output.split()[-1]
83
+ else:
84
+ mask = Image.new('L', output.size, 255)
85
 
86
+ return output, mask
87
 
88
  except Exception as e:
89
  print(f"Error: {str(e)}")
 
91
 
92
  iface = gr.Interface(
93
  fn=remove_background,
94
+ inputs=[
95
+ gr.Image(type="pil", label="Input Image"),
96
+ gr.Radio(
97
+ choices=["Rembg (U²-Net)", "InSPyReNet"],
98
+ value="Rembg (U²-Net)",
99
+ label="Model Selection"
100
+ )
101
+ ],
102
  outputs=[
103
  gr.Image(type="pil", label="Result"),
104
  gr.Image(type="pil", label="Mask")
105
  ],
106
  title="Professional Background Remover",
107
+ description="Choose between Rembg (faster) or InSPyReNet (higher quality)"
108
  )
109
 
110
  if __name__ == "__main__":