Alic22 commited on
Commit
536c5aa
·
verified ·
1 Parent(s): 5a2c9a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -54
app.py CHANGED
@@ -55,7 +55,6 @@ print("Model ready!")
55
  ##################
56
 
57
  to_tensor = transforms.ToTensor()
58
- to_array = transforms.ToPILImage()
59
  resize = transforms.Resize((512,512))
60
  resize_small = transforms.Resize((369,369))
61
  normalize = transforms.Normalize(
@@ -91,19 +90,13 @@ def transparent(fg, bg, alpha_factor):
91
  return background
92
 
93
  def show_img(all_imgs, dropdown, bg, alpha_factor):
94
- if all_imgs is None:
95
- return None
96
 
97
- idx = target_list_all.index(dropdown)
98
- fg = all_imgs[idx]
 
99
 
100
- foreground = Image.open(fg)
101
- background = np.array(bg)
102
-
103
- background = Image.fromarray(bg)
104
- new_alpha_factor = int(255*alpha_factor)
105
- foreground.putalpha(new_alpha_factor)
106
- background.paste(foreground, (0, 0), foreground)
107
 
108
  return background
109
 
@@ -112,30 +105,19 @@ def show_img(all_imgs, dropdown, bg, alpha_factor):
112
 
113
  def inference(img):
114
  background = resize_pil(img)
115
-
116
  img = process_pil(img).unsqueeze(0)
117
 
118
  with torch.no_grad():
119
- mask = model(img)[0]
120
 
121
 
122
  # Get probability values (logits to probs)
123
- mask_probs = torch.sigmoid(mask)
124
- mask_probs = mask_probs.detach().numpy()
125
- mask_probs.shape
126
-
127
- # Make binary mask
128
- THRESHOLD = 0.5
129
- mask_preds = mask_probs > THRESHOLD
130
-
131
- # All combined
132
- mask_all = mask_preds.sum(axis=0)
133
- mask_all = np.expand_dims(mask_all, axis=0)
134
- mask_all.shape
135
 
136
  # Concat all combined with normal preds
137
- mask_preds = np.concatenate((mask_all, mask_preds),axis=0)
138
- labs = ["ALL"] + target_list
139
 
140
  fig, axes = plt.subplots(5, 4, figsize = (10,10))
141
 
@@ -143,34 +125,31 @@ def inference(img):
143
  all_masks = []
144
 
145
  for i, ax in enumerate(axes.flat):
146
- label = labs[i]
147
-
148
- all_masks.append(mask_preds[i])
149
-
150
- ax.imshow(mask_preds[i])
151
- ax.set_title(label)
152
-
153
- plt.tight_layout()
154
 
155
  # plt to PIL
156
- img_buf = io.BytesIO()
157
- fig.savefig(img_buf, format='png')
158
- im = Image.open(img_buf)
 
 
159
 
160
  # Saved all masks combined with unvisible xaxis und yaxis and without a white
161
  # background.
162
- all_images = []
163
- for i in range(len(all_masks)):
164
- plt.figure()
165
- fig = plt.imshow(all_masks[i])
166
  plt.axis('off')
167
- fig.axes.get_xaxis().set_visible(False)
168
- fig.axes.get_yaxis().set_visible(False)
169
- img_buf = io.BytesIO()
170
- plt.savefig(img_buf, bbox_inches='tight', pad_inches = 0, format='png')
171
- all_images.append(Image.open(img_buf))
172
 
173
- return im, all_images, background
174
 
175
 
176
 
@@ -212,16 +191,16 @@ with gr.Blocks(title=title) as app:
212
  dropdown = gr.Dropdown(choices=target_list_all, label="Select Label", value="All")
213
  slider = gr.Slider(minimum=0, maximum=1, value=0.4, label="Alpha Factor")
214
 
215
- all_masks = gr.Gallery(visible=False)
216
- background = gr.Image(visible=False)
217
 
218
  gr.Button("1) Generate Masks").click(fn=inference,
219
  inputs=[input_img],
220
- outputs=[img, all_masks, background])
221
 
222
  gr.Button("2) Generate Transparent Mask (with Alpha Factor)").click(fn=show_img,
223
- inputs=[all_masks, dropdown, background, slider],
224
- outputs=[transparent_img])
225
 
226
 
227
  app.launch()
 
55
  ##################
56
 
57
  to_tensor = transforms.ToTensor()
 
58
  resize = transforms.Resize((512,512))
59
  resize_small = transforms.Resize((369,369))
60
  normalize = transforms.Normalize(
 
90
  return background
91
 
92
  def show_img(all_imgs, dropdown, bg, alpha_factor):
 
 
93
 
94
+ idx = target_list_all.index(label)
95
+ fg = mask_images[idx].copy()
96
+ bg = bg.copy()
97
 
98
+ fg = putalpha(int(255 * alpha))
99
+ bg.paste(fg, (0, 0), fg)
 
 
 
 
 
100
 
101
  return background
102
 
 
105
 
106
  def inference(img):
107
  background = resize_pil(img)
 
108
  img = process_pil(img).unsqueeze(0)
109
 
110
  with torch.no_grad():
111
+ logits = model(img)[0]
112
 
113
 
114
  # Get probability values (logits to probs)
115
+ probs = torch.sigmoid(logits).numpy()
116
+ mask = probs > 0.5
 
 
 
 
 
 
 
 
 
 
117
 
118
  # Concat all combined with normal preds
119
+ mask_all = np.sum(masks, axis=0, keepdims=True)
120
+ masks = np.concatenate([mask_all, masks], axis=0)
121
 
122
  fig, axes = plt.subplots(5, 4, figsize = (10,10))
123
 
 
125
  all_masks = []
126
 
127
  for i, ax in enumerate(axes.flat):
128
+ ax.imshow(masks[i])
129
+ ax.set_title(target_list_all[i])
130
+ ax.axis("off")
131
+
 
 
 
 
132
 
133
  # plt to PIL
134
+ buf = io.BytesIO()
135
+ plt.tight_layout()
136
+ plt.savefig(buf, format='png')
137
+ plt.close()
138
+ preview = Image.open(buf)
139
 
140
  # Saved all masks combined with unvisible xaxis und yaxis and without a white
141
  # background.
142
+ mask_images = []
143
+ for m in masks:
144
+ fig = plt.figure()
145
+ plt.imshow(m)
146
  plt.axis('off')
147
+ buf = io.BytesIO()
148
+ plt.savefig(buf, bbox_inches='tight', pad_inches = 0)
149
+ plt.close()
150
+ mask_images.append(Image.open(buf).convert("RGBA"))
 
151
 
152
+ return preview, mask_images, background, mask_images
153
 
154
 
155
 
 
191
  dropdown = gr.Dropdown(choices=target_list_all, label="Select Label", value="All")
192
  slider = gr.Slider(minimum=0, maximum=1, value=0.4, label="Alpha Factor")
193
 
194
+ mask_state = gr.Sate()
195
+ bg_state = gr.State()
196
 
197
  gr.Button("1) Generate Masks").click(fn=inference,
198
  inputs=[input_img],
199
+ outputs=[preview, gr.Gallery(visible=False), bg_state, mask_state])
200
 
201
  gr.Button("2) Generate Transparent Mask (with Alpha Factor)").click(fn=show_img,
202
+ inputs=[mask_state, dropdown, bg_state, slider],
203
+ outputs=output)
204
 
205
 
206
  app.launch()