AkashKumarave committed on
Commit
c1db78e
·
verified ·
1 Parent(s): e7338e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -118
app.py CHANGED
@@ -1,183 +1,140 @@
1
  import os
2
  import cv2
3
- import gradio as gr
4
  import numpy as np
5
  import torch
6
- import torch.nn as nn
7
  from torchvision import transforms
8
- import torch.nn.functional as F
9
  from PIL import Image
10
- import warnings
11
- warnings.filterwarnings("ignore")
12
 
13
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fetch the DIS sources on first start.
if not os.path.exists("DIS"):
    if os.system("git clone --depth 1 https://github.com/xuebinqin/DIS") != 0:
        # Keep going: the modules imported below may already be present
        # locally, but make the failed clone visible instead of ignoring it.
        print("Warning: could not clone https://github.com/xuebinqin/DIS")

# Move the contents of DIS/IS-Net into the working directory; skipped when
# there is nothing (left) to move so `mv` is not run on an empty glob.
# Presumably this is what makes the two imports below resolvable.
if os.path.isdir("DIS/IS-Net") and os.listdir("DIS/IS-Net"):
    os.system("mv DIS/IS-Net/* .")

# Import model components (must come after the move above)
from data_loader_cache import normalize, im_reader, im_preprocess
from models import ISNetDIS

# Checkpoints are looked up in saved_models/; relocate a checkpoint that was
# placed next to this script.
os.makedirs("saved_models", exist_ok=True)
if os.path.exists("isnet.pth"):
    os.rename("isnet.pth", "saved_models/isnet.pth")
29
 
30
class GOSNormalize:
    """Callable transform that applies ``normalize`` with a fixed mean and std.

    Args:
        mean: Mean values handed to ``normalize``; ``[0.5, 0.5, 0.5]`` when
            omitted.
        std: Standard-deviation values handed to ``normalize``;
            ``[1.0, 1.0, 1.0]`` when omitted.
    """

    def __init__(self, mean=None, std=None):
        # The defaults are created per instance: list literals in the
        # signature would be one shared, mutable object for every instance.
        self.mean = [0.5, 0.5, 0.5] if mean is None else mean
        self.std = [1.0, 1.0, 1.0] if std is None else std

    def __call__(self, image):
        """Return ``normalize(image, self.mean, self.std)``."""
        return normalize(image, self.mean, self.std)
37
-
38
# Preprocessing pipeline applied to every image tensor before inference.
transform = transforms.Compose([GOSNormalize()])


def load_image(im_path, hypar):
    """Read an image file and turn it into a batched model input.

    Args:
        im_path: Path of the image to read with ``im_reader``.
        hypar: Settings dict; ``hypar["cache_size"]`` is forwarded to
            ``im_preprocess``.

    Returns:
        A ``(batch, shape_batch)`` tuple: the transformed image tensor and the
        shape reported by ``im_preprocess`` as a tensor, each with a new
        leading dimension.
    """
    raw = im_reader(im_path)
    resized, reported_shape = im_preprocess(raw, hypar["cache_size"])

    # Presumably maps 0-255 pixel values to 0-1 before normalization.
    scaled = torch.divide(resized, 255.0)

    shape_batch = torch.from_numpy(np.array(reported_shape)).unsqueeze(0)
    batch = transform(scaled).unsqueeze(0)
    return batch, shape_batch
46
-
47
def build_model(hypar, device):
    """Prepare the configured network for inference on ``device``.

    Args:
        hypar: Settings dict. Reads ``"model"`` (the network instance),
            ``"model_digit"`` (``"half"`` casts the network to fp16) and
            ``"model_path"`` / ``"restore_model"`` (directory and file name of
            the checkpoint).
        device: Target ``torch.device``; also used as ``map_location`` when
            loading the checkpoint.

    Returns:
        The same network object, moved to ``device`` and switched to eval
        mode.
    """
    net = hypar["model"]

    if hypar["model_digit"] == "half":
        net.half()
        # BatchNorm2d layers are switched back to fp32 after the fp16 cast.
        for module in net.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.float()

    net.to(device)

    checkpoint = os.path.join(hypar["model_path"], hypar["restore_model"])
    if os.path.exists(checkpoint):
        net.load_state_dict(torch.load(checkpoint, map_location=device))
    else:
        # Still return the network (as before), but report the problem: with
        # no checkpoint loaded the predictions come from untrained weights.
        print(f"Warning: checkpoint not found: {checkpoint}; model weights were not loaded")

    net.eval()
    return net
64
 
65
def predict(net, inputs_val, shapes_val, hypar, device):
    """Run the network on one input batch and return an 8-bit mask.

    Args:
        net: Callable model. ``net(x)[0]`` is indexed as a sequence whose
            first element is a 4-D tensor.
        inputs_val: Input batch tensor.
        shapes_val: Tensor whose first row holds the target
            ``(height, width)`` of the returned mask.
        hypar: Settings dict; ``hypar["model_digit"] == "half"`` selects
            fp16 inputs, anything else fp32.
        device: Device the inputs are moved to.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8``: the prediction resized to the
        target size and min-max scaled to 0-255.
    """
    dtype = torch.float16 if hypar["model_digit"] == "half" else torch.float32
    # Plain ints: F.interpolate expects an integer size, not 0-dim tensors.
    target_size = (int(shapes_val[0][0]), int(shapes_val[0][1]))

    with torch.no_grad():
        batch = inputs_val.type(dtype).to(device)
        ds_val = net(batch)[0]
        # First output, first sample of the batch -> 3-D (C, H, W).
        pred_val = ds_val[0][0, :, :, :]

        # Bilinear interpolation only accepts 4-D (N, C, H, W) input, so
        # exactly one leading dimension is added back. Adding two (as this
        # code did before) produces a 5-D tensor that interpolate rejects.
        pred_val = F.interpolate(
            pred_val.unsqueeze(0),
            size=target_size,
            mode="bilinear",
            align_corners=False,
        )
        # Drop the batch dimension, then the channel dimension when it is 1.
        pred_val = pred_val[0].squeeze(0)

        # Min-max normalization; the epsilon guards against a constant map.
        pred_val = (pred_val - pred_val.min()) / (pred_val.max() - pred_val.min() + 1e-8)

    return (pred_val.cpu().numpy() * 255).astype(np.uint8)
82
 
83
# Inference settings read by load_image, build_model and predict.
hypar = dict(
    model_path="saved_models",  # directory holding the checkpoint
    restore_model="isnet.pth",  # checkpoint file name inside model_path
    interm_sup=False,
    model_digit="full",  # "half" makes build_model/predict use fp16
    seed=0,
    cache_size=[1024, 1024],  # forwarded to im_preprocess by load_image
    input_size=[1024, 1024],
    crop_size=[1024, 1024],
    model=ISNetDIS(),
)

# Build the network once at import time; process_image reuses this instance.
net = build_model(hypar, device)
 
 
 
 
 
 
 
98
 
99
def process_image(image):
    """Segment an image and return it with a transparent background.

    Args:
        image: Path of the image file, or an object exposing that path as
            ``.name``.

    Returns:
        A ``(rgba_img, mask_img)`` tuple of PIL images: the input picture
        with the predicted mask as its alpha channel, and the mask itself in
        mode ``"L"``.

    Raises:
        gr.Error: Wraps any failure so the message is shown in the UI.
    """
    try:
        if image is None:
            # Without this guard the user would only see an AttributeError
            # about NoneType.
            raise ValueError("No image provided")
        image_path = image if isinstance(image, str) else image.name

        # Verify image exists
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image file not found: {image_path}")

        image_tensor, orig_size = load_image(image_path, hypar)
        mask = predict(net, image_tensor, orig_size, hypar, device)
        mask_img = Image.fromarray(mask).convert('L')

        # convert() returns a new image, so the file handle can be released
        # as soon as the pixels have been read.
        with Image.open(image_path) as source:
            rgba_img = source.convert("RGB")
        rgba_img.putalpha(mask_img)

        return rgba_img, mask_img
    except Exception as e:
        # Chain the cause so the original traceback stays in the server log.
        raise gr.Error(f"Error processing image: {str(e)}") from e
119
 
120
# Heading and description rendered at the top of the page.
title = "DIS Image Segmentation"
description = """
Highly Accurate Dichotomous Image Segmentation
<br>GitHub: [xuebinqin/DIS](https://github.com/xuebinqin/DIS)
"""

# Offer only the example images that are actually present on disk.
examples = [
    [example_file]
    for example_file in ("robot.png", "ship.png")
    if os.path.exists(example_file)
]

# Two-column layout: upload + button on the left, both results on the right.
with gr.Blocks(title=title) as app:
    gr.Markdown(f"## {title}")
    gr.Markdown(description)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="filepath", label="Input Image", height=400)
            submit_btn = gr.Button("Process", variant="primary")

        with gr.Column():
            output_rgba = gr.Image(label="Transparent Background", type="pil", height=400)
            output_mask = gr.Image(label="Segmentation Mask", type="pil", height=400)

    if examples:
        gr.Examples(
            examples=examples,
            inputs=input_image,
            outputs=[output_rgba, output_mask],
            fn=process_image,
            cache_examples=True,
            label="Example Images",
        )

    submit_btn.click(
        fn=process_image,
        inputs=input_image,
        outputs=[output_rgba, output_mask],
        api_name="predict",
    )

# Start the server only when executed as a script.
if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860, show_error=True, share=False)
 
1
  import os
2
  import cv2
 
3
  import numpy as np
4
  import torch
 
5
  from torchvision import transforms
 
6
  from PIL import Image
7
+ import gradio as gr
 
8
 
9
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fetch the DIS sources on first start. A shallow clone is enough because
# only the current files are used, not the repository history.
if not os.path.exists("DIS"):
    if os.system("git clone --depth 1 https://github.com/xuebinqin/DIS") != 0:
        # Keep going: the modules imported below may already be present
        # locally, but make the failed clone visible instead of ignoring it.
        print("Warning: could not clone https://github.com/xuebinqin/DIS")

# Move the contents of DIS/IS-Net into the working directory; skipped when
# there is nothing (left) to move so `mv` is not run on an empty glob.
# Presumably this is what makes the two imports below resolvable.
if os.path.isdir("DIS/IS-Net") and os.listdir("DIS/IS-Net"):
    os.system("mv DIS/IS-Net/* .")

# Import model components (must come after the move above)
from models import ISNetDIS
from data_loader_cache import normalize

# Checkpoints are looked up in saved_models/; relocate a checkpoint that was
# placed next to this script.
os.makedirs("saved_models", exist_ok=True)
if os.path.exists("isnet.pth"):
    os.rename("isnet.pth", "saved_models/isnet.pth")
25
 
26
+ # Define image preprocessing
27
class ImageNormalizer:
    """Callable transform that applies ``normalize`` with a fixed mean and std.

    Args:
        mean: Mean values handed to ``normalize``; ``[0.5, 0.5, 0.5]`` when
            omitted.
        std: Standard-deviation values handed to ``normalize``;
            ``[1.0, 1.0, 1.0]`` when omitted.
    """

    def __init__(self, mean=None, std=None):
        # The defaults are created per instance: list literals in the
        # signature would be one shared, mutable object for every instance.
        self.mean = [0.5, 0.5, 0.5] if mean is None else mean
        self.std = [1.0, 1.0, 1.0] if std is None else std

    def __call__(self, img):
        """Return ``normalize(img, self.mean, self.std)``."""
        return normalize(img, self.mean, self.std)
 
 
 
 
 
34
 
35
# Preprocessing pipeline applied to every image tensor before inference.
transform = transforms.Compose([ImageNormalizer()])

# Load and configure model
model_config = {
    "model_path": "saved_models",  # directory holding the checkpoint
    "model_file": "isnet.pth",  # checkpoint file name inside model_path
    "input_size": [1024, 1024],
    "device": device,
}

model = ISNetDIS().to(device)

# Build the checkpoint path once instead of repeating the expression.
checkpoint_path = os.path.join(model_config["model_path"], model_config["model_file"])
if os.path.exists(checkpoint_path):
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
else:
    # Still start the app (as before), but report the problem: with no
    # checkpoint loaded the predictions come from untrained weights.
    print(f"Warning: checkpoint not found: {checkpoint_path}; model weights were not loaded")
model.eval()
54
 
55
def process_image(input_image):
    """Segment an image and return it with a transparent background.

    Args:
        input_image: Path of the image file, or an object exposing that path
            as ``.name``.

    Returns:
        A ``(transparent_img, mask_img)`` tuple of PIL images: the input
        picture with the predicted mask as its alpha channel, and the mask
        itself in mode ``"L"``. Both have the size of the input picture.

    Raises:
        gr.Error: Wraps any failure so the message is shown in the UI.
    """
    try:
        if input_image is None:
            raise ValueError("No image provided")
        # Convert Gradio input to usable image path
        image_path = getattr(input_image, "name", input_image)

        # cv2.imread signals failure by returning None instead of raising.
        bgr = cv2.imread(image_path)
        if bgr is None:
            raise ValueError(f"Could not read image: {image_path}")
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        orig_h, orig_w = rgb.shape[:2]

        # Feed the network at the configured size rather than at the upload's
        # native resolution, as the earlier revision of this file did.
        # NOTE(review): input_size is read as (height, width) - confirm.
        net_h, net_w = model_config["input_size"]
        resized = cv2.resize(rgb, (net_w, net_h), interpolation=cv2.INTER_LINEAR)
        img = torch.from_numpy(resized).float().permute(2, 0, 1) / 255.0
        img = transform(img).unsqueeze(0).to(device)

        # Get prediction
        with torch.no_grad():
            # First output of the model; treated as a 4-D (N, C, H, W) tensor,
            # which is how the earlier revision of this file indexed it.
            pred = model(img)[0][0]
            # No extra sigmoid here; the earlier revision used this output
            # directly. NOTE(review): confirm ISNetDIS.forward already
            # returns sigmoid-activated maps.
            pred = torch.nn.functional.interpolate(
                pred,
                size=(orig_h, orig_w),
                mode="bilinear",
                align_corners=False,
            )
            # Reduce to a 2-D (H, W) map so it can become a grayscale image.
            pred = pred[0, 0]
            pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
            mask = (pred.cpu().numpy() * 255).astype(np.uint8)

        # Create output images. The RGB picture is built from the pixels
        # decoded above, so it always has the same size as the mask (a second
        # decode via PIL can differ, e.g. for EXIF-rotated photos).
        mask_img = Image.fromarray(mask).convert("L")
        transparent_img = Image.fromarray(rgb)
        transparent_img.putalpha(mask_img)

        return transparent_img, mask_img

    except Exception as e:
        # Chain the cause so the original traceback stays in the server log.
        raise gr.Error(f"Error processing image: {str(e)}") from e
87
 
88
# Heading and description rendered at the top of the page.
title = "Image Background Removal"
description = """
Upload an image to automatically remove the background using DIS (Dichotomous Image Segmentation).
<br>Model from: <a href="https://github.com/xuebinqin/DIS">xuebinqin/DIS</a>
"""

# Offer only the example images that are actually present on disk.
examples = [
    [img_file]
    for img_file in ("robot.png", "ship.png")
    if os.path.exists(img_file)
]

# Create interface
with gr.Blocks() as app:
    gr.Markdown(f"## {title}")
    gr.Markdown(description)

    # Both columns are created inside the row first and filled afterwards.
    with gr.Row():
        input_col = gr.Column()
        output_col = gr.Column()

    with input_col:
        image_input = gr.Image(type="filepath", label="Upload Image")
        submit_btn = gr.Button("Remove Background", variant="primary")

    with output_col:
        transparent_output = gr.Image(label="Transparent Result", type="pil")
        mask_output = gr.Image(label="Segmentation Mask", type="pil")

    if examples:
        gr.Examples(
            examples=examples,
            inputs=image_input,
            outputs=[transparent_output, mask_output],
            fn=process_image,
            cache_examples=True,
            label="Try Example Images",
        )

    submit_btn.click(
        fn=process_image,
        inputs=image_input,
        outputs=[transparent_output, mask_output],
    )

# Start the server only when executed as a script.
if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860, share=False)