Alexvatti commited on
Commit
c3f9b72
·
verified ·
1 Parent(s): 662786f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -117
app.py CHANGED
@@ -146,7 +146,7 @@ sar_file = st.file_uploader("Upload SAR Image", type=["tiff"])
146
  optic_file = st.file_uploader("Upload Optical Image", type=["tiff"])
147
  mask_file = st.file_uploader("Upload Mask Image", type=["tiff"])
148
 
149
- num_samples = st.slider("Number of test samples to visualize", 1, 10, 3)
150
 
151
  if sar_file is not None and optic_file is not None and mask_file is not None:
152
  st.success("All files uploaded successfully!")
@@ -154,123 +154,116 @@ if sar_file is not None and optic_file is not None and mask_file is not None:
154
  sar_path = save_uploaded_file(sar_file, suffix=".tif")
155
  optic_path = save_uploaded_file(optic_file, suffix=".tif")
156
  mask_path = save_uploaded_file(mask_file, suffix=".tif")
157
-
158
- st.write("Temporary paths created for inference:")
159
- st.code(f"SAR Path: {sar_path}")
160
- st.code(f"Optic Path: {optic_path}")
161
- st.code(f"Mask Path: {mask_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
 
 
 
163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  else:
165
  st.warning("Please upload all three .tiff files to proceed.")
166
-
167
- sarImages = [sar_path]
168
- opticImages = [optic_path]
169
- masks = [mask_path]
170
- model_path = "Residual_UNET_Bilinear.keras"
171
-
172
-
173
- if st.button("Run Inference"):
174
- with st.spinner("Loading data and model..."):
175
-
176
- sar_images = readImages(sarImages, typeData='s', width=WIDTH, height=HEIGHT)
177
- optic_images = readImages(opticImages, typeData='o', width=WIDTH, height=HEIGHT)
178
- masks = readImages(masks, typeData='m', width=WIDTH, height=HEIGHT)
179
-
180
- sar_images = normalizeImages(sar_images, 's')
181
- optic_images = normalizeImages(optic_images, 'i')
182
-
183
- # Load model
184
- model = tf.keras.models.load_model(model_path,
185
- custom_objects={"cce_dice_loss": cce_dice_loss, "dice_score": dice_score})
186
-
187
- pred_masks = model.predict([optic_images, sar_images], verbose=0)
188
- is_multiclass = pred_masks.shape[-1] > 1
189
-
190
- num_samples = min(num_samples, len(sar_images))
191
-
192
- # Plotting
193
- fig, axes = plt.subplots(num_samples, 4, figsize=(21, 6 * num_samples))
194
-
195
- for i in range(num_samples):
196
- ax = axes[i] if num_samples > 1 else axes
197
-
198
- ax[0].imshow(sar_images[i].squeeze(), cmap='gray')
199
- ax[0].set_title(f"SAR Image {i+1}")
200
- ax[0].axis('off')
201
-
202
- ax[1].imshow(optic_images[i])
203
- ax[1].set_title(f"Optic Image {i+1}")
204
- ax[1].axis('off')
205
-
206
- if is_multiclass:
207
- gt_color_mask = np.zeros((*masks[i].shape[:2], 3))
208
- for j, color in enumerate(CLASS_COLORS):
209
- gt_color_mask += masks[i][:,:,j][:,:,np.newaxis] * np.array(color)
210
- ax[2].imshow(gt_color_mask)
211
- else:
212
- ax[2].imshow(masks[i], cmap='gray')
213
- ax[2].set_title(f"Ground Truth Mask {i+1}")
214
- ax[2].axis('off')
215
-
216
- if is_multiclass:
217
- pred_color_mask = np.zeros((*pred_masks[i].shape[:2], 3))
218
- for j, color in enumerate(CLASS_COLORS):
219
- pred_color_mask += pred_masks[i][:,:,j][:,:,np.newaxis] * np.array(color)
220
- ax[3].imshow(pred_color_mask)
221
- else:
222
- ax[3].imshow(pred_masks[i], cmap='gray')
223
- ax[3].set_title(f"Predicted Mask {i+1}")
224
- ax[3].axis('off')
225
-
226
- st.pyplot(fig)
227
-
228
- # Define color for class 1: illegal mining
229
- red_color = [255, 0, 0]
230
-
231
- # Convert optic_images to uint8 if needed
232
- if optic_images.dtype != np.uint8:
233
- optic_images = (optic_images * 255).astype(np.uint8)
234
-
235
- # Create figure with subplots
236
- fig, axes = plt.subplots(num_samples, 4, figsize=(21, 6 * num_samples))
237
-
238
- for i in range(num_samples):
239
- ax = axes[i] if num_samples > 1 else axes
240
-
241
- # SAR image
242
- ax[0].imshow(sar_images[i].squeeze(), cmap='gray')
243
- ax[0].set_title(f"SAR Image {i+1}")
244
- ax[0].axis('off')
245
-
246
- # Optic image
247
- ax[1].imshow(optic_images[i])
248
- ax[1].set_title(f"Optic Image {i+1}")
249
- ax[1].axis('off')
250
-
251
- # Ground truth overlay
252
- gt_overlay = optic_images[i].copy()
253
- if is_multiclass:
254
- gt_overlay[masks[i][:, :, 1] == 1] = red_color
255
- else:
256
- gt_overlay[masks[i].squeeze() == 1] = red_color
257
-
258
- ax[2].imshow(optic_images[i])
259
- ax[2].imshow(gt_overlay, alpha=0.4)
260
- ax[2].set_title(f"Ground Truth Overlay {i+1}")
261
- ax[2].axis('off')
262
-
263
- # Predicted mask overlay
264
- pred_overlay = optic_images[i].copy()
265
- if is_multiclass:
266
- pred_overlay[pred_masks[i][:, :, 1] > 0.5] = red_color
267
- else:
268
- pred_overlay[pred_masks[i].squeeze() > 0.5] = red_color
269
-
270
- ax[3].imshow(optic_images[i])
271
- ax[3].imshow(pred_overlay, alpha=0.4)
272
- ax[3].set_title(f"Predicted Overlay {i+1}")
273
- ax[3].axis('off')
274
-
275
- plt.tight_layout()
276
- st.pyplot(fig)
 
146
  optic_file = st.file_uploader("Upload Optical Image", type=["tiff"])
147
  mask_file = st.file_uploader("Upload Mask Image", type=["tiff"])
148
 
149
+ num_samples = st.slider("Number of test samples to visualize", 1, 10, 1)
150
 
151
  if sar_file is not None and optic_file is not None and mask_file is not None:
152
  st.success("All files uploaded successfully!")
 
154
  sar_path = save_uploaded_file(sar_file, suffix=".tif")
155
  optic_path = save_uploaded_file(optic_file, suffix=".tif")
156
  mask_path = save_uploaded_file(mask_file, suffix=".tif")
157
+
158
+ sarImages = [sar_path]
159
+ opticImages = [optic_path]
160
+ masks = [mask_path]
161
+ model_path = "Residual_UNET_Bilinear.keras"
162
+
163
+
164
+ if st.button("Run Inference"):
165
+ with st.spinner("Loading data and model..."):
166
+
167
+ sar_images = readImages(sarImages, typeData='s', width=WIDTH, height=HEIGHT)
168
+ optic_images = readImages(opticImages, typeData='o', width=WIDTH, height=HEIGHT)
169
+ masks = readImages(masks, typeData='m', width=WIDTH, height=HEIGHT)
170
+
171
+ sar_images = normalizeImages(sar_images, 's')
172
+ optic_images = normalizeImages(optic_images, 'i')
173
+
174
+ # Load model
175
+ model = tf.keras.models.load_model(model_path,
176
+ custom_objects={"cce_dice_loss": cce_dice_loss, "dice_score": dice_score})
177
+
178
+ pred_masks = model.predict([optic_images, sar_images], verbose=0)
179
+ is_multiclass = pred_masks.shape[-1] > 1
180
+
181
+ num_samples = min(num_samples, len(sar_images))
182
+
183
+ # Plotting
184
+ fig, axes = plt.subplots(num_samples, 4, figsize=(21, 6 * num_samples))
185
+
186
+ for i in range(num_samples):
187
+ ax = axes[i] if num_samples > 1 else axes
188
 
189
+ ax[0].imshow(sar_images[i].squeeze(), cmap='gray')
190
+ ax[0].set_title(f"SAR Image {i+1}")
191
+ ax[0].axis('off')
192
 
193
+ ax[1].imshow(optic_images[i])
194
+ ax[1].set_title(f"Optic Image {i+1}")
195
+ ax[1].axis('off')
196
+
197
+ if is_multiclass:
198
+ gt_color_mask = np.zeros((*masks[i].shape[:2], 3))
199
+ for j, color in enumerate(CLASS_COLORS):
200
+ gt_color_mask += masks[i][:,:,j][:,:,np.newaxis] * np.array(color)
201
+ ax[2].imshow(gt_color_mask)
202
+ else:
203
+ ax[2].imshow(masks[i], cmap='gray')
204
+ ax[2].set_title(f"Ground Truth Mask {i+1}")
205
+ ax[2].axis('off')
206
+
207
+ if is_multiclass:
208
+ pred_color_mask = np.zeros((*pred_masks[i].shape[:2], 3))
209
+ for j, color in enumerate(CLASS_COLORS):
210
+ pred_color_mask += pred_masks[i][:,:,j][:,:,np.newaxis] * np.array(color)
211
+ ax[3].imshow(pred_color_mask)
212
+ else:
213
+ ax[3].imshow(pred_masks[i], cmap='gray')
214
+ ax[3].set_title(f"Predicted Mask {i+1}")
215
+ ax[3].axis('off')
216
+
217
+ st.pyplot(fig)
218
+
219
+ # Define color for class 1: illegal mining
220
+ red_color = [255, 0, 0]
221
+
222
+ # Convert optic_images to uint8 if needed
223
+ if optic_images.dtype != np.uint8:
224
+ optic_images = (optic_images * 255).astype(np.uint8)
225
+
226
+ # Create figure with subplots
227
+ fig, axes = plt.subplots(num_samples, 4, figsize=(21, 6 * num_samples))
228
+
229
+ for i in range(num_samples):
230
+ ax = axes[i] if num_samples > 1 else axes
231
+
232
+ # SAR image
233
+ ax[0].imshow(sar_images[i].squeeze(), cmap='gray')
234
+ ax[0].set_title(f"SAR Image {i+1}")
235
+ ax[0].axis('off')
236
+
237
+ # Optic image
238
+ ax[1].imshow(optic_images[i])
239
+ ax[1].set_title(f"Optic Image {i+1}")
240
+ ax[1].axis('off')
241
+
242
+ # Ground truth overlay
243
+ gt_overlay = optic_images[i].copy()
244
+ if is_multiclass:
245
+ gt_overlay[masks[i][:, :, 1] == 1] = red_color
246
+ else:
247
+ gt_overlay[masks[i].squeeze() == 1] = red_color
248
+
249
+ ax[2].imshow(optic_images[i])
250
+ ax[2].imshow(gt_overlay, alpha=0.4)
251
+ ax[2].set_title(f"Ground Truth Overlay {i+1}")
252
+ ax[2].axis('off')
253
+
254
+ # Predicted mask overlay
255
+ pred_overlay = optic_images[i].copy()
256
+ if is_multiclass:
257
+ pred_overlay[pred_masks[i][:, :, 1] > 0.5] = red_color
258
+ else:
259
+ pred_overlay[pred_masks[i].squeeze() > 0.5] = red_color
260
+
261
+ ax[3].imshow(optic_images[i])
262
+ ax[3].imshow(pred_overlay, alpha=0.4)
263
+ ax[3].set_title(f"Predicted Overlay {i+1}")
264
+ ax[3].axis('off')
265
+
266
+ plt.tight_layout()
267
+ st.pyplot(fig)
268
  else:
269
  st.warning("Please upload all three .tiff files to proceed.")