Alexvatti committed on
Commit
b2d94fd
·
verified ·
1 Parent(s): 2327efa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +171 -1
app.py CHANGED
@@ -3,7 +3,10 @@ import os
3
  import numpy as np
4
  import matplotlib.pyplot as plt
5
  from tensorflow.keras.models import load_model
6
- from your_dataset_module import readDataset # Replace with your actual dataset module
 
 
 
7
 
8
  # Configuration
9
  HEIGHT = WIDTH = 256
@@ -18,6 +21,173 @@ CLASS_COLORS = [
18
  [0, 0, 0]
19
  ]
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  # Streamlit App Title
22
  st.title("Satellite Mining Segmentation: SAR + Optic Image Inference")
23
 
 
3
  import numpy as np
4
  import matplotlib.pyplot as plt
5
  from tensorflow.keras.models import load_model
6
+ import pathlib
7
+ import natsort
8
+ import datetime
9
+ import shutil
10
 
11
  # Configuration
12
  HEIGHT = WIDTH = 256
 
21
  [0, 0, 0]
22
  ]
23
 
24
class readDataset:
    """Load, label, normalise, augment and split SAR / optic / mask image triplets.

    The three constructor arguments are directory paths.  ``readPathes`` fills
    ``sarImages``, ``opticImages`` and ``masks`` with file paths that share the
    same file stem in all three directories; the remaining methods operate on
    the data that is passed to them and do not touch instance state.

    NOTE(review): the methods below reference ``rasterio``, ``cv2``, ``A``
    (presumably albumentations) and ``train_test_split`` (presumably
    scikit-learn).  None of them appear among the imports visible next to this
    class -- confirm they are imported elsewhere in the module.
    """

    def __init__(self, sarPathes, opticPathes, masksPathes):
        """Store the three source directories; nothing is read from disk yet."""
        self.sarPathes = sarPathes
        self.opticPathes = opticPathes
        self.masksPathes = masksPathes
        # Filled by readPathes() with matched, naturally sorted file paths.
        self.sarImages = None
        self.opticImages = None
        self.masks = None
        # Not assigned by any method of this class.
        self.testSarImages = None
        self.testopticImages = None
        self.testMasks = None

    def readPathes(self):
        """Collect the files that exist, by stem, in all three directories.

        Side effects: deletes a ``.ipynb_checkpoints`` sub-directory from each
        source directory when present, sets ``self.sarImages``,
        ``self.opticImages`` and ``self.masks`` to equally long, identically
        ordered lists of paths, and prints a short summary.
        """
        directories = (self.sarPathes, self.opticPathes, self.masksPathes)

        # Remove notebook checkpoints BEFORE listing the directories: the
        # name ``.ipynb_checkpoints`` can otherwise be picked up by the
        # ``*.*`` pattern and end up as a "triplet" that points at a path
        # which has just been deleted.
        for directory in directories:
            checkpoints = pathlib.Path(directory) / ".ipynb_checkpoints"
            if not checkpoints.is_dir():
                continue
            try:
                shutil.rmtree(checkpoints)
            except OSError as err:
                # Best effort only: a leftover checkpoint directory is
                # filtered out below, so failing to delete it is not fatal.
                print(f"(WARN..) Could not delete {checkpoints}: {err}")
            else:
                print(f".ipynb_checkpoints directory deleted successfully from {directory}.")

        def list_files(directory):
            # Regular files only, so stray sub-directories never become IDs.
            candidates = pathlib.Path(directory).glob('*.*')
            return natsort.natsorted([p for p in candidates if p.is_file()])

        all_sar_images = list_files(self.sarPathes)
        all_optic_images = list_files(self.opticPathes)
        all_mask_images = list_files(self.masksPathes)

        # The image ID is the file name without its extension.  If two files
        # in one directory share a stem, the later one in sort order wins.
        sar_dict = {path.stem: path for path in all_sar_images}
        optic_dict = {path.stem: path for path in all_optic_images}
        mask_dict = {path.stem: path for path in all_mask_images}

        # Keep only IDs present in all three directories.
        common_ids = sar_dict.keys() & optic_dict.keys() & mask_dict.keys()
        sorted_common_ids = natsort.natsorted(list(common_ids))

        self.sarImages = [sar_dict[image_id] for image_id in sorted_common_ids]
        self.opticImages = [optic_dict[image_id] for image_id in sorted_common_ids]
        self.masks = [mask_dict[image_id] for image_id in sorted_common_ids]

        print(f"(INFO..) Found {len(all_sar_images)} SAR, {len(all_optic_images)} optical, {len(all_mask_images)} mask images")
        print(f"(INFO..) Complete triplets: {len(common_ids)}")

    def convertColorToLabel(self, img):
        """Convert a colour-coded mask into a one-hot ``uint8`` array.

        Args:
            img: array indexed as ``[row, column, channel]`` whose last axis
                is compared against the colour triples below (the caller in
                this class passes an RGB image).

        Returns:
            Array of shape ``(rows, columns, 3)`` and dtype ``uint8`` with a 1
            in the channel of each pixel's class.  Pixels whose colour matches
            none of the table entries stay at label 0.
        """
        color_to_label = {
            (115, 178, 115): 0,  # non_mining_land (green)
            (255, 0, 0): 1,      # illegal_mining_land (red)
            (0, 0, 0): 2,        # beach (black) -- name kept from the original; TODO confirm
        }
        num_classes = len(color_to_label)

        # Start with every pixel at label 0, then paint exact colour matches.
        label_img = np.zeros(img.shape[:2], dtype=np.uint8)
        for color, label in color_to_label.items():
            label_img[np.all(img == np.array(color), axis=2)] = label

        # One-hot encode by comparing each label against 0..num_classes-1.
        return (label_img[..., np.newaxis] == np.arange(num_classes)).astype(np.uint8)

    @staticmethod
    def _readSarImage(path, width, height):
        """Read every band of one SAR raster, stretch it to 0-255 and resize it."""
        with rasterio.open(str(path)) as src:
            bands = [src.read(i) for i in range(1, src.count + 1)]
        sar_image = np.stack(bands, axis=-1)

        # 2nd-98th percentile contrast stretch.
        p2, p98 = np.percentile(sar_image, (2, 98))
        spread = p98 - p2
        if spread > 0:
            sar_image = np.clip(sar_image, p2, p98)
            sar_image = ((sar_image - p2) / spread * 255).astype(np.uint8)
        else:
            # Constant (or all-NaN) raster: dividing by the zero spread would
            # yield NaN/inf, so emit an all-zero image instead.
            sar_image = np.zeros(sar_image.shape, dtype=np.uint8)

        sar_image = cv2.resize(sar_image, (width, height), interpolation=cv2.INTER_AREA)
        # NOTE(review): a trailing axis is always appended, as in the original
        # code; this presumably assumes single-band SAR input -- confirm.
        return np.expand_dims(sar_image, axis=-1)

    @staticmethod
    def _readRgbImage(path, width, height, interpolation):
        """Read one colour image with OpenCV, convert BGR to RGB and resize it."""
        bgr = cv2.imread(str(path), cv2.IMREAD_COLOR)
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than inside cvtColor.
            raise ValueError(f"Could not read image file: {path}")
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        return cv2.resize(rgb, (width, height), interpolation=interpolation)

    def readImages(self, data, typeData, width, height):
        """Load every path in ``data`` and return the images as one array.

        Args:
            data: iterable of image paths.
            typeData: ``'s'`` for SAR rasters, ``'m'`` for colour-coded masks
                (returned one-hot encoded), ``'o'`` for optic images.
            width: target width passed to ``cv2.resize``.
            height: target height passed to ``cv2.resize``.

        Returns:
            ``np.array`` built from the list of loaded images.

        Raises:
            ValueError: if ``typeData`` is not one of ``'s'``, ``'m'``, ``'o'``
                or an optic/mask file cannot be read.
        """
        if typeData not in ('s', 'm', 'o'):
            raise ValueError(f"Unknown typeData {typeData!r}; expected 's', 'm' or 'o'.")

        images = []
        for path in data:
            if typeData == 's':  # SAR image
                images.append(self._readSarImage(path, width, height))
            elif typeData == 'm':  # Mask image
                # Nearest-neighbour keeps the mask colours exact for the lookup.
                mask_rgb = self._readRgbImage(path, width, height, cv2.INTER_NEAREST)
                images.append(self.convertColorToLabel(mask_rgb))
            else:  # Optic image
                images.append(self._readRgbImage(path, width, height, cv2.INTER_AREA))
        print(f"(INFO..) Read {len(images)} {typeData} images")
        return np.array(images)

    def normalizeImages(self, images, typeData):
        """Cast each image to ``uint8`` and scale SAR/optic images to 0-1.

        Args:
            images: iterable of arrays.
            typeData: ``'s'`` or ``'o'`` divides by 255; any other value
                leaves the ``uint8`` data unscaled.

        Returns:
            ``np.array`` of the processed images.
        """
        rescale = typeData in ('s', 'o')
        normalized_images = []
        for img in images:
            # The cast truncates fractional values, so this expects 0-255
            # data, not images that have already been normalised.
            img = img.astype(np.uint8)
            if rescale:
                img = img / 255.
            normalized_images.append(img)
        print("(INFO..) Normalization Image Done")
        return np.array(normalized_images)

    def dataAugmentation(self, sar_images, optic_images, masks, n_augments, size=WIDTH):
        """Return the originals followed by ``n_augments`` augmented copies of each sample.

        The optic image, its mask and the SAR image of a sample go through the
        same pipeline call; the SAR image is registered as an additional
        ``image`` target.

        Args:
            sar_images: sequence of SAR images.
            optic_images: sequence of optic images.
            masks: sequence of masks.
            n_augments: number of augmented samples generated per original.
            size: side length used by the random resized crop.

        Returns:
            Tuple ``(sar, optic, masks)`` of ``np.array`` objects.

        Raises:
            ValueError: if the three inputs differ in length.
        """
        if not (len(sar_images) == len(optic_images) == len(masks)):
            raise ValueError("Number of SAR images, optic images, and masks must be the same.")

        # Define augmentation pipeline once
        augmentation = A.ReplayCompose([
            A.RandomResizedCrop(size=(size, size), scale=(0.2, 0.9), ratio=(1, 1),
                                interpolation=cv2.INTER_AREA, mask_interpolation=cv2.INTER_NEAREST, p=0.5),
            A.HorizontalFlip(p=0.5),
            A.ShiftScaleRotate(scale_limit=(0.0, 0.15), rotate_limit=(-90, 90),
                               interpolation=cv2.INTER_AREA, mask_interpolation=cv2.INTER_NEAREST,
                               border_mode=cv2.BORDER_REFLECT, p=0.5),
            A.RandomGamma(p=0.5),
            A.RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.25, 0.25), p=0.5)
        ], additional_targets={'sar': 'image'})

        # Initialize lists with original data
        augmented_sar = list(sar_images)
        augmented_optic = list(optic_images)
        augmented_masks = list(masks)

        # Inputs are cast to uint8 before augmentation, so un-normalised
        # 0-255 data is expected here.
        for sar, optic, mask in zip(sar_images, optic_images, masks):
            for _ in range(n_augments):
                augmented = augmentation(image=optic.astype(np.uint8),
                                         mask=mask.astype(np.uint8),
                                         sar=sar.astype(np.uint8))
                augmented_sar.append(augmented['sar'])
                augmented_optic.append(augmented['image'])
                augmented_masks.append(augmented['mask'])

        # Print statistics
        total_original = len(optic_images)
        total_augmented = len(augmented_optic)
        # Guard the ratio so an empty input reports 0.00x instead of raising.
        multiplier = total_augmented / total_original if total_original else 0.0
        print(f"(INFO..) Original Train Optic Images: {total_original}")
        print(f"(INFO..) Total Augmented Train Optic Images: {total_augmented}")
        print(f"(INFO..) Augmentation Multiplier: {multiplier:.2f}x")
        print("(INFO..) Augmentation Image Done \n")

        return (np.array(augmented_sar), np.array(augmented_optic), np.array(augmented_masks))

    def splitDataset(self, sar_images, optic_images, masks, test_size=0.1, n_augments=10):
        """Split the triplets into train/test sets and augment the training part.

        Args:
            sar_images: sequence of SAR images.
            optic_images: sequence of optic images.
            masks: sequence of masks.
            test_size: forwarded to ``train_test_split``.
            n_augments: augmented samples generated per training sample.

        Returns:
            Tuple ``(train_sar, train_optic, train_masks, test_sar, test_optic,
            test_masks)`` of ``np.array`` objects; only the training arrays
            contain augmented samples.
        """
        # Zip first so the three modalities stay aligned through the shuffle.
        data = list(zip(sar_images, optic_images, masks))
        train_data, test_data = train_test_split(data, test_size=test_size, random_state=42)

        # Unpack the training and test data
        train_sar, train_optic, train_masks = zip(*train_data)
        test_sar, test_optic, test_masks = zip(*test_data)

        # Augment train data only; the test split is returned untouched.
        train_sar_aug, train_optic_aug, train_masks_aug = self.dataAugmentation(
            np.array(train_sar), np.array(train_optic), np.array(train_masks),
            n_augments=n_augments
        )

        print("(INFO..) Splitting and Saving Data Done \n")
        return (
            np.array(train_sar_aug), np.array(train_optic_aug), np.array(train_masks_aug),
            np.array(test_sar), np.array(test_optic), np.array(test_masks)
        )
191
  # Streamlit App Title
192
  st.title("Satellite Mining Segmentation: SAR + Optic Image Inference")
193