Janeka commited on
Commit
8fbeea7
·
verified ·
1 Parent(s): c9844a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -60
app.py CHANGED
@@ -3,92 +3,97 @@ from rembg import remove
3
  from PIL import Image
4
  import numpy as np
5
  import torch
6
- import cv2
7
  import os
 
 
 
8
 
9
- # Initialize InSPyReNet if available
10
- try:
11
- from InSPyReNet.models.InSPyReNet import InSPyReNet
12
- from InSPyReNet.utils.dataloader import test_dataset
 
 
 
 
 
 
 
 
 
13
 
14
- # Download InSPyReNet weights
15
- if not os.path.exists('InSPyReNet.pth'):
16
- os.system('wget https://github.com/plemeri/InSPyReNet/releases/download/v1.0/InSPyReNet.pth')
17
-
18
- # Load InSPyReNet model
19
- inspyrenet = InSPyReNet()
20
- inspyrenet.load_state_dict(torch.load('InSPyReNet.pth', map_location='cpu'))
21
- inspyrenet.eval()
22
- HAS_INSPYRE = True
23
- except:
24
- HAS_INSPYRE = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  def process_with_inspyrenet(image):
27
- # Preprocess
28
- image = test_dataset.preprocess(np.array(image))
29
- image = torch.from_numpy(image).unsqueeze(0)
 
 
30
 
31
  # Predict
32
  with torch.no_grad():
33
- pred = inspyrenet(image)
34
 
35
- # Post-process
36
  mask = (pred.squeeze().cpu().numpy() > 0.5).astype(np.uint8) * 255
37
  return mask
38
 
39
- def remove_background(input_image, model_choice="Rembg (U²-Net)"):
40
  try:
41
- # Convert to PIL Image if it's a numpy array
42
  if isinstance(input_image, np.ndarray):
43
  input_image = Image.fromarray(input_image)
44
 
45
- # Process with selected model
46
- if model_choice == "InSPyReNet" and HAS_INSPYRE:
47
- mask = process_with_inspyrenet(input_image)
48
- mask_img = Image.fromarray(mask)
49
-
50
- # Apply mask to original image
51
- output = input_image.copy()
52
- output.putalpha(mask_img)
53
- else:
54
- # Default to Rembg
55
- output = remove(input_image)
56
- if output.mode == 'RGBA':
57
- mask = output.split()[-1]
58
- mask_np = np.array(mask)
59
- else:
60
- mask_np = np.ones(output.size[::-1], dtype=np.uint8) * 255
61
- mask_img = Image.fromarray(mask_np)
62
 
63
- return output, mask_img
64
 
65
  except Exception as e:
66
- print(f"Error processing image: {str(e)}")
67
  return None, None
68
 
69
- # Create interface
70
  iface = gr.Interface(
71
  fn=remove_background,
72
- inputs=[
73
- gr.Image(type="pil", label="Input Image"),
74
- gr.Radio(
75
- choices=["Rembg (U²-Net)", "InSPyReNet"],
76
- value="Rembg (U²-Net)",
77
- label="Model Selection"
78
- )
79
- ],
80
  outputs=[
81
- gr.Image(type="pil", label="Result with Transparent Background"),
82
- gr.Image(type="pil", label="Segmentation Mask")
83
  ],
84
- title="Hybrid Background Remover (CPU)",
85
- description="""
86
- Upload an image to remove the background. Choose between:
87
- - Rembg (U²-Net): Faster (5-15 sec)
88
- - InSPyReNet: More accurate but slower (15-30 sec)
89
- """
90
  )
91
 
92
- # Launch with minimal configuration
93
  if __name__ == "__main__":
94
  iface.launch(server_name="0.0.0.0", server_port=7860)
 
3
  from PIL import Image
4
  import numpy as np
5
  import torch
 
6
  import os
7
+ import requests
8
+ from tqdm import tqdm
9
+ import subprocess
10
 
11
+ # Clone InSPyReNet repository if not present
12
+ if not os.path.exists('InSPyReNet'):
13
+ print("Cloning InSPyReNet repository...")
14
+ subprocess.run(['git', 'clone', 'https://github.com/plemeri/InSPyReNet.git'])
15
+
16
+ # Add to Python path
17
+ import sys
18
+ sys.path.insert(0, 'InSPyReNet/lib')
19
+
20
+ # Download model weights if not present
21
+ def download_file(url, filename):
22
+ response = requests.get(url, stream=True)
23
+ total_size = int(response.headers.get('content-length', 0))
24
 
25
+ with open(filename, 'wb') as f, tqdm(
26
+ desc=filename,
27
+ total=total_size,
28
+ unit='iB',
29
+ unit_scale=True,
30
+ unit_divisor=1024,
31
+ ) as bar:
32
+ for data in response.iter_content(chunk_size=1024):
33
+ size = f.write(data)
34
+ bar.update(size)
35
+
36
+ if not os.path.exists('InSPyReNet.pth'):
37
+ print("Downloading model weights...")
38
+ download_file(
39
+ "https://github.com/plemeri/InSPyReNet/releases/download/v1.0/InSPyReNet.pth",
40
+ "InSPyReNet.pth"
41
+ )
42
+
43
+ # Import after setup
44
+ from InSPyReNet import InSPyReNet
45
+ from modules.layers import load_model
46
+ from utils.misc import load_config
47
+
48
+ # Initialize model
49
+ print("Loading model...")
50
+ cfg = load_config('InSPyReNet/configs/InSPyReNet_SwinB.yaml')
51
+ device = torch.device('cpu')
52
+ model = InSPyReNet(cfg)
53
+ model = load_model(model, 'InSPyReNet.pth', device)
54
+ model.eval()
55
 
56
  def process_with_inspyrenet(image):
57
+ # Convert to numpy and normalize
58
+ image = np.array(image).astype(np.float32)
59
+ image -= np.array([104.00699, 116.66877, 122.67892])
60
+ image = image.transpose((2, 0, 1))
61
+ image = torch.from_numpy(image).unsqueeze(0).to(device)
62
 
63
  # Predict
64
  with torch.no_grad():
65
+ pred = model(image)
66
 
67
+ # Create mask
68
  mask = (pred.squeeze().cpu().numpy() > 0.5).astype(np.uint8) * 255
69
  return mask
70
 
71
+ def remove_background(input_image):
72
  try:
 
73
  if isinstance(input_image, np.ndarray):
74
  input_image = Image.fromarray(input_image)
75
 
76
+ # Process with InSPyReNet
77
+ mask = process_with_inspyrenet(input_image)
78
+ output = input_image.copy()
79
+ output.putalpha(Image.fromarray(mask))
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
+ return output, Image.fromarray(mask)
82
 
83
  except Exception as e:
84
+ print(f"Error: {str(e)}")
85
  return None, None
86
 
 
87
  iface = gr.Interface(
88
  fn=remove_background,
89
+ inputs=gr.Image(type="pil", label="Input Image"),
 
 
 
 
 
 
 
90
  outputs=[
91
+ gr.Image(type="pil", label="Result"),
92
+ gr.Image(type="pil", label="Mask")
93
  ],
94
+ title="Professional Background Remover",
95
+ description="Using InSPyReNet for high-quality background removal"
 
 
 
 
96
  )
97
 
 
98
  if __name__ == "__main__":
99
  iface.launch(server_name="0.0.0.0", server_port=7860)