Alexvatti commited on
Commit
d98d462
·
verified ·
1 Parent(s): c3f9b72

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -46
app.py CHANGED
@@ -147,123 +147,134 @@ optic_file = st.file_uploader("Upload Optical Image", type=["tiff"])
147
  mask_file = st.file_uploader("Upload Mask Image", type=["tiff"])
148
 
149
  num_samples = st.slider("Number of test samples to visualize", 1, 10, 1)
 
 
 
 
 
150
 
151
- if sar_file is not None and optic_file is not None and mask_file is not None:
152
- st.success("All files uploaded successfully!")
153
- st.write(f"Number of samples selected for visualization: {num_samples}")
154
- sar_path = save_uploaded_file(sar_file, suffix=".tif")
155
- optic_path = save_uploaded_file(optic_file, suffix=".tif")
156
- mask_path = save_uploaded_file(mask_file, suffix=".tif")
157
-
158
- sarImages = [sar_path]
159
- opticImages = [optic_path]
160
- masks = [mask_path]
161
- model_path = "Residual_UNET_Bilinear.keras"
162
-
163
-
164
- if st.button("Run Inference"):
165
- with st.spinner("Loading data and model..."):
166
-
167
  sar_images = readImages(sarImages, typeData='s', width=WIDTH, height=HEIGHT)
168
  optic_images = readImages(opticImages, typeData='o', width=WIDTH, height=HEIGHT)
169
  masks = readImages(masks, typeData='m', width=WIDTH, height=HEIGHT)
170
-
171
  sar_images = normalizeImages(sar_images, 's')
172
  optic_images = normalizeImages(optic_images, 'i')
173
-
174
  # Load model
175
- model = tf.keras.models.load_model(model_path,
176
- custom_objects={"cce_dice_loss": cce_dice_loss, "dice_score": dice_score})
177
-
 
 
 
178
  pred_masks = model.predict([optic_images, sar_images], verbose=0)
179
  is_multiclass = pred_masks.shape[-1] > 1
180
-
181
  num_samples = min(num_samples, len(sar_images))
182
-
183
- # Plotting
184
  fig, axes = plt.subplots(num_samples, 4, figsize=(21, 6 * num_samples))
185
-
186
  for i in range(num_samples):
187
  ax = axes[i] if num_samples > 1 else axes
188
-
 
189
  ax[0].imshow(sar_images[i].squeeze(), cmap='gray')
190
  ax[0].set_title(f"SAR Image {i+1}")
191
  ax[0].axis('off')
192
-
 
193
  ax[1].imshow(optic_images[i])
194
  ax[1].set_title(f"Optic Image {i+1}")
195
  ax[1].axis('off')
196
-
 
197
  if is_multiclass:
198
  gt_color_mask = np.zeros((*masks[i].shape[:2], 3))
199
  for j, color in enumerate(CLASS_COLORS):
200
- gt_color_mask += masks[i][:,:,j][:,:,np.newaxis] * np.array(color)
201
  ax[2].imshow(gt_color_mask)
202
  else:
203
  ax[2].imshow(masks[i], cmap='gray')
204
  ax[2].set_title(f"Ground Truth Mask {i+1}")
205
  ax[2].axis('off')
206
-
 
207
  if is_multiclass:
208
  pred_color_mask = np.zeros((*pred_masks[i].shape[:2], 3))
209
  for j, color in enumerate(CLASS_COLORS):
210
- pred_color_mask += pred_masks[i][:,:,j][:,:,np.newaxis] * np.array(color)
211
  ax[3].imshow(pred_color_mask)
212
  else:
213
  ax[3].imshow(pred_masks[i], cmap='gray')
214
  ax[3].set_title(f"Predicted Mask {i+1}")
215
  ax[3].axis('off')
216
-
217
  st.pyplot(fig)
218
-
219
  # Define color for class 1: illegal mining
220
  red_color = [255, 0, 0]
221
-
222
  # Convert optic_images to uint8 if needed
223
  if optic_images.dtype != np.uint8:
224
  optic_images = (optic_images * 255).astype(np.uint8)
225
-
226
- # Create figure with subplots
227
  fig, axes = plt.subplots(num_samples, 4, figsize=(21, 6 * num_samples))
228
-
229
  for i in range(num_samples):
230
  ax = axes[i] if num_samples > 1 else axes
231
-
232
  # SAR image
233
  ax[0].imshow(sar_images[i].squeeze(), cmap='gray')
234
  ax[0].set_title(f"SAR Image {i+1}")
235
  ax[0].axis('off')
236
-
237
  # Optic image
238
  ax[1].imshow(optic_images[i])
239
  ax[1].set_title(f"Optic Image {i+1}")
240
  ax[1].axis('off')
241
-
242
  # Ground truth overlay
243
  gt_overlay = optic_images[i].copy()
244
  if is_multiclass:
245
  gt_overlay[masks[i][:, :, 1] == 1] = red_color
246
  else:
247
  gt_overlay[masks[i].squeeze() == 1] = red_color
248
-
249
  ax[2].imshow(optic_images[i])
250
  ax[2].imshow(gt_overlay, alpha=0.4)
251
  ax[2].set_title(f"Ground Truth Overlay {i+1}")
252
  ax[2].axis('off')
253
-
254
  # Predicted mask overlay
255
  pred_overlay = optic_images[i].copy()
256
  if is_multiclass:
257
  pred_overlay[pred_masks[i][:, :, 1] > 0.5] = red_color
258
  else:
259
  pred_overlay[pred_masks[i].squeeze() > 0.5] = red_color
260
-
261
  ax[3].imshow(optic_images[i])
262
  ax[3].imshow(pred_overlay, alpha=0.4)
263
  ax[3].set_title(f"Predicted Overlay {i+1}")
264
  ax[3].axis('off')
265
-
266
  plt.tight_layout()
267
  st.pyplot(fig)
268
- else:
269
- st.warning("Please upload all three .tiff files to proceed.")
 
 
147
  mask_file = st.file_uploader("Upload Mask Image", type=["tiff"])
148
 
149
  num_samples = st.slider("Number of test samples to visualize", 1, 10, 1)
150
+ if st.button("Run Inference"):
151
+ with st.spinner("Loading data and model..."):
152
+ if sar_file is not None and optic_file is not None and mask_file is not None:
153
+ st.success("All files uploaded successfully!")
154
+ st.write(f"Number of samples selected for visualization: {num_samples}")
155
 
156
+ # Save uploaded files
157
+ sar_path = save_uploaded_file(sar_file, suffix=".tif")
158
+ optic_path = save_uploaded_file(optic_file, suffix=".tif")
159
+ mask_path = save_uploaded_file(mask_file, suffix=".tif")
160
+
161
+ # Create image lists
162
+ sarImages = [sar_path]
163
+ opticImages = [optic_path]
164
+ masks = [mask_path]
165
+
166
+ # Model path
167
+ model_path = "Residual_UNET_Bilinear.keras"
168
+
169
+ # Read and normalize images
 
 
170
  sar_images = readImages(sarImages, typeData='s', width=WIDTH, height=HEIGHT)
171
  optic_images = readImages(opticImages, typeData='o', width=WIDTH, height=HEIGHT)
172
  masks = readImages(masks, typeData='m', width=WIDTH, height=HEIGHT)
173
+
174
  sar_images = normalizeImages(sar_images, 's')
175
  optic_images = normalizeImages(optic_images, 'i')
176
+
177
  # Load model
178
+ model = tf.keras.models.load_model(
179
+ model_path,
180
+ custom_objects={"cce_dice_loss": cce_dice_loss, "dice_score": dice_score}
181
+ )
182
+
183
+ # Predict masks
184
  pred_masks = model.predict([optic_images, sar_images], verbose=0)
185
  is_multiclass = pred_masks.shape[-1] > 1
186
+
187
  num_samples = min(num_samples, len(sar_images))
188
+
189
+ # Plotting results
190
  fig, axes = plt.subplots(num_samples, 4, figsize=(21, 6 * num_samples))
191
+
192
  for i in range(num_samples):
193
  ax = axes[i] if num_samples > 1 else axes
194
+
195
+ # Plot SAR image
196
  ax[0].imshow(sar_images[i].squeeze(), cmap='gray')
197
  ax[0].set_title(f"SAR Image {i+1}")
198
  ax[0].axis('off')
199
+
200
+ # Plot Optic image
201
  ax[1].imshow(optic_images[i])
202
  ax[1].set_title(f"Optic Image {i+1}")
203
  ax[1].axis('off')
204
+
205
+ # Plot Ground Truth Mask
206
  if is_multiclass:
207
  gt_color_mask = np.zeros((*masks[i].shape[:2], 3))
208
  for j, color in enumerate(CLASS_COLORS):
209
+ gt_color_mask += masks[i][:, :, j][:, :, np.newaxis] * np.array(color)
210
  ax[2].imshow(gt_color_mask)
211
  else:
212
  ax[2].imshow(masks[i], cmap='gray')
213
  ax[2].set_title(f"Ground Truth Mask {i+1}")
214
  ax[2].axis('off')
215
+
216
+ # Plot Predicted Mask
217
  if is_multiclass:
218
  pred_color_mask = np.zeros((*pred_masks[i].shape[:2], 3))
219
  for j, color in enumerate(CLASS_COLORS):
220
+ pred_color_mask += pred_masks[i][:, :, j][:, :, np.newaxis] * np.array(color)
221
  ax[3].imshow(pred_color_mask)
222
  else:
223
  ax[3].imshow(pred_masks[i], cmap='gray')
224
  ax[3].set_title(f"Predicted Mask {i+1}")
225
  ax[3].axis('off')
226
+
227
  st.pyplot(fig)
228
+
229
  # Define color for class 1: illegal mining
230
  red_color = [255, 0, 0]
231
+
232
  # Convert optic_images to uint8 if needed
233
  if optic_images.dtype != np.uint8:
234
  optic_images = (optic_images * 255).astype(np.uint8)
235
+
236
+ # Plot overlays
237
  fig, axes = plt.subplots(num_samples, 4, figsize=(21, 6 * num_samples))
238
+
239
  for i in range(num_samples):
240
  ax = axes[i] if num_samples > 1 else axes
241
+
242
  # SAR image
243
  ax[0].imshow(sar_images[i].squeeze(), cmap='gray')
244
  ax[0].set_title(f"SAR Image {i+1}")
245
  ax[0].axis('off')
246
+
247
  # Optic image
248
  ax[1].imshow(optic_images[i])
249
  ax[1].set_title(f"Optic Image {i+1}")
250
  ax[1].axis('off')
251
+
252
  # Ground truth overlay
253
  gt_overlay = optic_images[i].copy()
254
  if is_multiclass:
255
  gt_overlay[masks[i][:, :, 1] == 1] = red_color
256
  else:
257
  gt_overlay[masks[i].squeeze() == 1] = red_color
258
+
259
  ax[2].imshow(optic_images[i])
260
  ax[2].imshow(gt_overlay, alpha=0.4)
261
  ax[2].set_title(f"Ground Truth Overlay {i+1}")
262
  ax[2].axis('off')
263
+
264
  # Predicted mask overlay
265
  pred_overlay = optic_images[i].copy()
266
  if is_multiclass:
267
  pred_overlay[pred_masks[i][:, :, 1] > 0.5] = red_color
268
  else:
269
  pred_overlay[pred_masks[i].squeeze() > 0.5] = red_color
270
+
271
  ax[3].imshow(optic_images[i])
272
  ax[3].imshow(pred_overlay, alpha=0.4)
273
  ax[3].set_title(f"Predicted Overlay {i+1}")
274
  ax[3].axis('off')
275
+
276
  plt.tight_layout()
277
  st.pyplot(fig)
278
+ else:
279
+ st.warning("Please upload all three .tiff files to proceed.")
280
+