Alexvatti commited on
Commit
cd33515
·
verified ·
1 Parent(s): 473fe9e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -192
app.py CHANGED
@@ -24,174 +24,6 @@ CLASS_COLORS = [
24
  [0, 0, 0]
25
  ]
26
 
27
- class readDataset:
28
- def __init__(self, sarPathes, opticPathes, masksPathes):
29
- self.sarPathes = sarPathes
30
- self.opticPathes = opticPathes
31
- self.masksPathes = masksPathes
32
- self.sarImages = None
33
- self.opticImages = None
34
- self.masks = None
35
- self.testSarImages = None
36
- self.testopticImages = None
37
- self.testMasks = None
38
-
39
- def readPathes(self):
40
- # Get all file paths
41
- all_sar_images = natsort.natsorted(list(pathlib.Path(self.sarPathes).glob('*.*')))
42
- all_optic_images = natsort.natsorted(list(pathlib.Path(self.opticPathes).glob('*.*')))
43
- all_mask_images = natsort.natsorted(list(pathlib.Path(self.masksPathes).glob('*.*')))
44
-
45
- # Clean up .ipynb_checkpoints
46
- for directory in [self.sarPathes, self.opticPathes, self.masksPathes]:
47
- try:
48
- shutil.rmtree(os.path.join(directory, ".ipynb_checkpoints"))
49
- print(f".ipynb_checkpoints directory deleted successfully from {directory}.")
50
- except Exception:
51
- pass
52
-
53
- # Extract image IDs - just getting the filename without extension
54
- def extract_id(filepath):
55
- return pathlib.Path(str(filepath)).stem
56
-
57
- # Create dictionaries mapping IDs to paths for efficient lookup
58
- sar_dict = {extract_id(f): f for f in all_sar_images}
59
- optic_dict = {extract_id(f): f for f in all_optic_images}
60
- mask_dict = {extract_id(f): f for f in all_mask_images}
61
-
62
- # Find common IDs across all three datasets
63
- common_ids = set(sar_dict.keys()) & set(optic_dict.keys()) & set(mask_dict.keys())
64
-
65
- # Create matched file lists using sorted common IDs
66
- sorted_common_ids = natsort.natsorted(list(common_ids))
67
- self.sarImages = [sar_dict[id] for id in sorted_common_ids]
68
- self.opticImages = [optic_dict[id] for id in sorted_common_ids]
69
- self.masks = [mask_dict[id] for id in sorted_common_ids]
70
-
71
- print(f"(INFO..) Found {len(all_sar_images)} SAR, {len(all_optic_images)} optical, {len(all_mask_images)} mask images")
72
- print(f"(INFO..) Complete triplets: {len(common_ids)}")
73
-
74
- def convertColorToLabel(self, img):
75
- color_to_label = {
76
- (115, 178, 115): 0, # non_mining_land (green)
77
- (255, 0, 0): 1, # illegal_mining_land (red)
78
- (0, 0, 0): 2, # beach (black)
79
- }
80
- # Create empty label array
81
- label_img = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8)
82
-
83
- # Map colors to labels
84
- for color, label in color_to_label.items():
85
- mask = np.all(img == color, axis=2)
86
- label_img[mask] = label
87
-
88
- # One-hot encode labels
89
- num_classes = len(color_to_label)
90
- one_hot = np.zeros((img.shape[0], img.shape[1], num_classes), dtype=np.uint8)
91
- for c in range(num_classes):
92
- one_hot[:, :, c] = (label_img == c).astype(np.uint8)
93
-
94
- return one_hot
95
-
96
- def readImages(self, data, typeData, width, height):
97
- images = []
98
- for img in data:
99
- if typeData == 's': # SAR image
100
- with rasterio.open(str(img)) as src:
101
- sar_bands = [src.read(i) for i in range(1, src.count + 1)]
102
- sar_image = np.stack(sar_bands, axis=-1)
103
- # Stretching
104
- p2, p98 = np.percentile(sar_image, (2, 98))
105
- sar_image = np.clip(sar_image, p2, p98)
106
- sar_image = ((sar_image - p2) / (p98 - p2) * 255).astype(np.uint8)
107
- # Resize
108
- sar_image = cv2.resize(sar_image, (width, height), interpolation=cv2.INTER_AREA)
109
- images.append(np.expand_dims(sar_image, axis=-1))
110
- elif typeData == 'm': # Mask image
111
- img = cv2.imread(str(img), cv2.IMREAD_COLOR)
112
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
113
- img = cv2.resize(img, (width, height), interpolation=cv2.INTER_NEAREST)
114
- images.append(self.convertColorToLabel(img))
115
- elif typeData == 'o': # Optic image
116
- img = cv2.imread(str(img), cv2.IMREAD_COLOR)
117
- img = cv2.resize(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), (width, height), interpolation=cv2.INTER_AREA)
118
- images.append(img)
119
- print(f"(INFO..) Read {len(images)} {typeData} images")
120
- return np.array(images)
121
-
122
- def normalizeImages(self, images, typeData):
123
- normalized_images = []
124
- for img in images:
125
- img = img.astype(np.uint8)
126
- if typeData == 's':
127
- img = img / 255.
128
- if typeData == 'o':
129
- img = img / 255.
130
- normalized_images.append(img)
131
- print("(INFO..) Normalization Image Done")
132
- return np.array(normalized_images)
133
-
134
- def dataAugmentation(self, sar_images, optic_images, masks, n_augments, size=WIDTH):
135
- # Define augmentation pipeline once
136
- augmentation = A.ReplayCompose([
137
- A.RandomResizedCrop(size=(size, size), scale=(0.2, 0.9), ratio=(1, 1),
138
- interpolation=cv2.INTER_AREA, mask_interpolation=cv2.INTER_NEAREST, p=0.5),
139
- A.HorizontalFlip(p=0.5),
140
- A.ShiftScaleRotate(scale_limit=(0.0, 0.15), rotate_limit=(-90, 90),
141
- interpolation=cv2.INTER_AREA, mask_interpolation=cv2.INTER_NEAREST,
142
- border_mode=cv2.BORDER_REFLECT, p=0.5),
143
- A.RandomGamma(p=0.5),
144
- A.RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.25, 0.25), p=0.5)
145
- ], additional_targets={'sar': 'image'})
146
-
147
- if not (len(sar_images) == len(optic_images) == len(masks)):
148
- raise ValueError("Number of SAR images, optic images, and masks must be the same.")
149
-
150
- # Initialize lists with original data
151
- augmented_sar = list(sar_images)
152
- augmented_optic = list(optic_images)
153
- augmented_masks = list(masks)
154
-
155
- # Perform augmentations
156
- for i, (sar, optic, mask) in enumerate(zip(sar_images, optic_images, masks)):
157
- for _ in range(n_augments):
158
- augmented = augmentation(image=optic.astype(np.uint8),
159
- mask=mask.astype(np.uint8),
160
- sar=sar.astype(np.uint8))
161
- augmented_sar.append(augmented['sar'])
162
- augmented_optic.append(augmented['image'])
163
- augmented_masks.append(augmented['mask'])
164
-
165
- # Print statistics
166
- total_original = len(optic_images)
167
- total_augmented = len(augmented_optic)
168
- print(f"(INFO..) Original Train Optic Images: {total_original}")
169
- print(f"(INFO..) Total Augmented Train Optic Images: {total_augmented}")
170
- print(f"(INFO..) Augmentation Multiplier: {total_augmented / total_original:.2f}x")
171
- print("(INFO..) Augmentation Image Done \n")
172
-
173
- return (np.array(augmented_sar), np.array(augmented_optic), np.array(augmented_masks))
174
-
175
- def splitDataset(self, sar_images, optic_images, masks, test_size=0.1, n_augments=10):
176
- data = list(zip(sar_images, optic_images, masks))
177
- train_data, test_data = train_test_split(data, test_size=test_size, random_state=42)
178
-
179
- # Unpack the training and test data
180
- train_sar, train_optic, train_masks = zip(*train_data)
181
- test_sar, test_optic, test_masks = zip(*test_data)
182
-
183
- # Augment train data
184
- train_sar_aug, train_optic_aug, train_masks_aug = self.dataAugmentation(
185
- np.array(train_sar), np.array(train_optic), np.array(train_masks),
186
- n_augments=n_augments
187
- )
188
-
189
- print("(INFO..) Splitting and Saving Data Done \n")
190
- return (
191
- np.array(train_sar_aug), np.array(train_optic_aug), np.array(train_masks_aug),
192
- np.array(test_sar), np.array(test_optic), np.array(test_masks)
193
- )
194
-
195
  @tf.keras.saving.register_keras_serializable()
196
  def dice_score(y_true, y_pred, threshold=0.5, smooth=1.0):
197
  #determine binary or multiclass segmentation
@@ -230,39 +62,83 @@ def cce_dice_loss(y_true, y_pred):
230
  dice = dice_loss(y_true, y_pred)
231
  return tf.cast(cce, dtype=tf.float32) + dice
232
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
 
 
234
  # Streamlit App Title
235
  st.title("Satellite Mining Segmentation: SAR + Optic Image Inference")
236
 
237
- # Sidebar inputs
238
- #dataset_path = st.text_input("Enter dataset path", "path/to/your/dataset")
239
- #model_path = st.text_input("Enter path to trained model (.h5)", "model.h5")
240
- dataset_path = "bangka_testing_data"
241
- model_path = "Residual_UNET_Bilinear.keras"
242
  num_samples = st.slider("Number of test samples to visualize", 1, 10, 3)
243
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  if st.button("Run Inference"):
245
  with st.spinner("Loading data and model..."):
246
 
247
- # Prepare paths
248
- sar_path = os.path.join(dataset_path, 'sar_images')
249
- optic_path = os.path.join(dataset_path, 'optic_images')
250
- mask_path = os.path.join(dataset_path, 'masks')
251
-
252
- # Read dataset
253
- dataset = readDataset(
254
- sarPathes=sar_path,
255
- opticPathes=optic_path,
256
- masksPathes=mask_path
257
- )
258
- dataset.readPathes()
259
-
260
- sar_images = dataset.readImages(dataset.sarImages, typeData='s', width=WIDTH, height=HEIGHT)
261
- optic_images = dataset.readImages(dataset.opticImages, typeData='o', width=WIDTH, height=HEIGHT)
262
- masks = dataset.readImages(dataset.masks, typeData='m', width=WIDTH, height=HEIGHT)
263
-
264
- sar_images = dataset.normalizeImages(sar_images, 's')
265
- optic_images = dataset.normalizeImages(optic_images, 'i')
266
 
267
  # Load model
268
  model = tf.keras.models.load_model(model_path,
 
24
  [0, 0, 0]
25
  ]
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  @tf.keras.saving.register_keras_serializable()
28
  def dice_score(y_true, y_pred, threshold=0.5, smooth=1.0):
29
  #determine binary or multiclass segmentation
 
62
  dice = dice_loss(y_true, y_pred)
63
  return tf.cast(cce, dtype=tf.float32) + dice
64
 
65
def convertColorToLabel(img):
    """Convert an RGB mask image into a one-hot encoded label array.

    Color-to-class mapping (recovered from the removed readDataset class):
        (115, 178, 115) -> 0  non_mining_land (green)
        (255, 0, 0)     -> 1  illegal_mining_land (red)
        (0, 0, 0)       -> 2  beach (black)

    Pixels matching none of the colors keep the initial label 0.

    Args:
        img: (H, W, 3) uint8 RGB mask image.

    Returns:
        (H, W, num_classes) uint8 one-hot label array.
    """
    color_to_label = {
        (115, 178, 115): 0,  # non_mining_land (green)
        (255, 0, 0): 1,      # illegal_mining_land (red)
        (0, 0, 0): 2,        # beach (black)
    }
    # Per-pixel integer labels, default 0.
    label_img = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8)
    for color, label in color_to_label.items():
        mask = np.all(img == color, axis=2)
        label_img[mask] = label

    # One-hot encode: one channel per class.
    num_classes = len(color_to_label)
    one_hot = np.zeros((img.shape[0], img.shape[1], num_classes), dtype=np.uint8)
    for c in range(num_classes):
        one_hot[:, :, c] = (label_img == c).astype(np.uint8)
    return one_hot


def readImages(data, typeData, width, height):
    """Read and preprocess a list of images for inference.

    Args:
        data: iterable of image paths (or path-convertible objects).
        typeData: 's' = SAR GeoTIFF, 'm' = RGB color mask, 'o' = optic RGB.
        width: target width in pixels after resizing.
        height: target height in pixels after resizing.

    Returns:
        np.ndarray of preprocessed images:
          's' -> (N, H, W, 1) uint8, percentile-stretched;
          'm' -> (N, H, W, num_classes) uint8 one-hot labels;
          'o' -> (N, H, W, 3) uint8 RGB.
    """
    images = []
    for img in data:
        if typeData == 's':  # SAR image (read all bands via rasterio)
            with rasterio.open(str(img)) as src:
                sar_bands = [src.read(i) for i in range(1, src.count + 1)]
                sar_image = np.stack(sar_bands, axis=-1)

            # Contrast stretching: clip to 2nd/98th percentile, rescale to 0-255.
            # NOTE(review): divides by (p98 - p2) — a constant-valued image
            # would divide by zero; confirm inputs always have dynamic range.
            p2, p98 = np.percentile(sar_image, (2, 98))
            sar_image = np.clip(sar_image, p2, p98)
            sar_image = ((sar_image - p2) / (p98 - p2) * 255).astype(np.uint8)

            # Resize, then restore the trailing channel axis dropped by cv2.
            sar_image = cv2.resize(sar_image, (width, height), interpolation=cv2.INTER_AREA)
            images.append(np.expand_dims(sar_image, axis=-1))

        elif typeData == 'm':  # Mask image
            img = cv2.imread(str(img), cv2.IMREAD_COLOR)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # Nearest-neighbor keeps class colors exact (no blended colors).
            img = cv2.resize(img, (width, height), interpolation=cv2.INTER_NEAREST)
            # Bug fix: original called self.convertColorToLabel(img), but
            # `self` does not exist in this module-level function (NameError).
            images.append(convertColorToLabel(img))

        elif typeData == 'o':  # Optic image
            img = cv2.imread(str(img), cv2.IMREAD_COLOR)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
            images.append(img)

    print(f"(INFO..) Read {len(images)} '{typeData}' image(s)")
    return np.array(images)
96
+
97
+
98
def normalizeImages(images, typeData):
    """Rescale pixel intensities for model input.

    SAR ('s') and optic ('o') images are scaled from [0, 255] to [0.0, 1.0];
    every other type (e.g. one-hot masks) passes through as uint8.

    Args:
        images: iterable of image arrays.
        typeData: 's' (SAR), 'o' (optic), or anything else (no scaling).

    Returns:
        np.ndarray stacking the processed images — float for 's'/'o',
        uint8 otherwise.
    """
    needs_scaling = typeData in ('s', 'o')
    processed = [
        frame.astype(np.uint8) / 255. if needs_scaling else frame.astype(np.uint8)
        for frame in images
    ]
    print("(INFO..) Normalization Image Done")
    return np.array(processed)
108
 
109
+
110
# Streamlit App Title
st.title("Satellite Mining Segmentation: SAR + Optic Image Inference")


# The model needs a matched triplet: SAR, optical, and ground-truth mask.
sar_file = st.file_uploader("Upload SAR Image", type=["tiff"])
optic_file = st.file_uploader("Upload Optical Image", type=["tiff"])
mask_file = st.file_uploader("Upload Mask Image", type=["tiff"])

num_samples = st.slider("Number of test samples to visualize", 1, 10, 3)

if sar_file is not None and optic_file is not None and mask_file is not None:
    st.success("All files uploaded successfully!")
    st.write(f"Number of samples selected for visualization: {num_samples}")
else:
    st.warning("Please upload all three .tiff files to proceed.")

sarImages = [sar_file]
opticImages = [optic_file]
masks = [mask_file]
model_path = "Residual_UNET_Bilinear.keras"


if st.button("Run Inference"):
    with st.spinner("Loading data and model..."):

        # NOTE(review): readImages() opens str(img) as a filesystem path, but
        # these entries are Streamlit UploadedFile objects — confirm they are
        # persisted to temp files first, or rasterio/cv2 cannot open them.
        sar_images = readImages(sarImages, typeData='s', width=WIDTH, height=HEIGHT)
        optic_images = readImages(opticImages, typeData='o', width=WIDTH, height=HEIGHT)
        masks = readImages(masks, typeData='m', width=WIDTH, height=HEIGHT)

        sar_images = normalizeImages(sar_images, 's')
        # Bug fix: was normalizeImages(optic_images, 'i') — 'i' is not a
        # recognized type, so optical pixels were never scaled by 1/255 and
        # the model received un-normalized input.
        optic_images = normalizeImages(optic_images, 'o')
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
  # Load model
144
  model = tf.keras.models.load_model(model_path,