Alexvatti commited on
Commit
b90fed9
·
verified ·
1 Parent(s): c07cd59

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -0
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ from tensorflow.keras.models import load_model
6
+ from your_dataset_module import readDataset # Replace with your actual dataset module
7
+
8
+ # Configuration
9
+ HEIGHT = WIDTH = 256
10
+ SAR_SHAPE = (HEIGHT, WIDTH, 1)
11
+ OPTIC_SHAPE = (HEIGHT, WIDTH, 3)
12
+ MASK_SHAPE = (HEIGHT, WIDTH, 4) # One-hot encoded masks with 4 classes
13
+
14
+ # Class colors: non-mining (green), illegal mining (red), beach (black)
15
+ CLASS_COLORS = [
16
+ [115/255, 178/255, 115/255],
17
+ [1, 0, 0],
18
+ [0, 0, 0]
19
+ ]
20
+
21
+ # Streamlit App Title
22
+ st.title("Satellite Mining Segmentation: SAR + Optic Image Inference")
23
+
24
+ # Sidebar inputs
25
+ dataset_path = st.text_input("Enter dataset path", "path/to/your/dataset")
26
+ model_path = st.text_input("Enter path to trained model (.h5)", "model.h5")
27
+ num_samples = st.slider("Number of test samples to visualize", 1, 10, 3)
28
+
29
+ if st.button("Run Inference"):
30
+ with st.spinner("Loading data and model..."):
31
+
32
+ # Prepare paths
33
+ sar_path = os.path.join(dataset_path, 'sar_images')
34
+ optic_path = os.path.join(dataset_path, 'optic_images')
35
+ mask_path = os.path.join(dataset_path, 'masks')
36
+
37
+ # Read dataset
38
+ dataset = readDataset(
39
+ sarPathes=sar_path,
40
+ opticPathes=optic_path,
41
+ masksPathes=mask_path
42
+ )
43
+ dataset.readPathes()
44
+
45
+ sar_images = dataset.readImages(dataset.sarImages, typeData='s', width=WIDTH, height=HEIGHT)
46
+ optic_images = dataset.readImages(dataset.opticImages, typeData='o', width=WIDTH, height=HEIGHT)
47
+ masks = dataset.readImages(dataset.masks, typeData='m', width=WIDTH, height=HEIGHT)
48
+
49
+ sar_images = dataset.normalizeImages(sar_images, 's')
50
+ optic_images = dataset.normalizeImages(optic_images, 'i')
51
+
52
+ # Load model
53
+ model = load_model(model_path)
54
+
55
+ # Predict
56
+ pred_masks = model.predict([optic_images, sar_images], verbose=0)
57
+ is_multiclass = pred_masks.shape[-1] > 1
58
+
59
+ num_samples = min(num_samples, len(sar_images))
60
+
61
+ # Plotting
62
+ fig, axes = plt.subplots(num_samples, 4, figsize=(21, 6 * num_samples))
63
+
64
+ for i in range(num_samples):
65
+ ax = axes[i] if num_samples > 1 else axes
66
+
67
+ ax[0].imshow(sar_images[i].squeeze(), cmap='gray')
68
+ ax[0].set_title(f"SAR Image {i+1}")
69
+ ax[0].axis('off')
70
+
71
+ ax[1].imshow(optic_images[i])
72
+ ax[1].set_title(f"Optic Image {i+1}")
73
+ ax[1].axis('off')
74
+
75
+ if is_multiclass:
76
+ gt_color_mask = np.zeros((*masks[i].shape[:2], 3))
77
+ for j, color in enumerate(CLASS_COLORS):
78
+ gt_color_mask += masks[i][:,:,j][:,:,np.newaxis] * np.array(color)
79
+ ax[2].imshow(gt_color_mask)
80
+ else:
81
+ ax[2].imshow(masks[i], cmap='gray')
82
+ ax[2].set_title(f"Ground Truth Mask {i+1}")
83
+ ax[2].axis('off')
84
+
85
+ if is_multiclass:
86
+ pred_color_mask = np.zeros((*pred_masks[i].shape[:2], 3))
87
+ for j, color in enumerate(CLASS_COLORS):
88
+ pred_color_mask += pred_masks[i][:,:,j][:,:,np.newaxis] * np.array(color)
89
+ ax[3].imshow(pred_color_mask)
90
+ else:
91
+ ax[3].imshow(pred_masks[i], cmap='gray')
92
+ ax[3].set_title(f"Predicted Mask {i+1}")
93
+ ax[3].axis('off')
94
+
95
+ st.pyplot(fig)