Prerak51 commited on
Commit
143d762
·
verified ·
1 Parent(s): 2ff606d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +224 -74
app.py CHANGED
@@ -1,23 +1,30 @@
1
-
2
-
3
- import os
4
- import cv2
5
- from PIL import Image
6
- import numpy as np
7
- from matplotlib import pyplot as plt
8
- import random
9
  import gradio as gr
10
- from keras import backend as K
11
  from keras.models import load_model
12
- def jaccard_coef(y_true, y_pred):
13
- y_true_flatten = K.flatten(y_true)
14
- y_pred_flatten = K.flatten(y_pred)
15
- intersection = K.sum(y_true_flatten * y_pred_flatten)
16
- final_coef_value = (intersection + 1.0) / (K.sum(y_true_flatten) + K.sum(y_pred_flatten) - intersection + 1.0)
17
- return final_coef_value
 
 
 
 
 
 
 
 
 
 
18
 
 
 
 
 
 
 
 
19
 
20
- # Define Dice Loss
21
  def dice_loss(y_true, y_pred):
22
  smooth = 1e-12
23
  intersection = K.sum(y_true * y_pred, axis=[1,2,3])
@@ -25,7 +32,6 @@ def dice_loss(y_true, y_pred):
25
  dice = K.mean((2.0 * intersection + smooth) / (union + smooth), axis=0)
26
  return 1.0 - dice
27
 
28
- # Define Focal Loss
29
  def focal_loss(y_true, y_pred, alpha=0.25, gamma=2.0):
30
  y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())
31
  ce_loss = -y_true * K.log(y_pred)
@@ -33,77 +39,221 @@ def focal_loss(y_true, y_pred, alpha=0.25, gamma=2.0):
33
  fl_loss = ce_loss * weight
34
  return K.mean(K.sum(fl_loss, axis=-1))
35
 
36
- # Define Total Loss
37
  def total_loss(y_true, y_pred):
38
  return dice_loss(y_true, y_pred) + (1 * focal_loss(y_true, y_pred))
39
 
40
- weights = [0.1666, 0.1666, 0.1666, 0.1666, 0.1666, 0.1666]
 
 
41
 
 
 
42
 
43
- from keras.models import load_model
44
- import numpy as np
45
- from PIL import Image
46
- import matplotlib.pyplot as plt
47
- saved_model=load_model('satmodel.h5', custom_objects={'total_loss': total_loss, 'dice_loss': dice_loss, 'focal_loss': focal_loss, 'jaccard_coef': jaccard_coef})
48
- # def process_input_image(image_source):
49
- # image = np.expand_dims(image_source, 0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
- # prediction = saved_model.predict(image)
52
- # predicted_image = np.argmax(prediction, axis=3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- # predicted_image = predicted_image[0,:,:]
55
- # predicted_image = predicted_image * 50
56
- # return 'Predicted Masked Image', predicted_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
- import matplotlib.pyplot as plt
59
- import matplotlib.colors as mcolors
60
 
61
- # # Define the image processing function
62
 
63
- # Define the image processing function
64
- def process_input_image(image):
65
- image = Image.fromarray(image)
66
- image = image.convert('RGB') # Convert the image to RGB
67
- image = image.resize((256, 256))
68
- image = np.array(image)
69
- image = np.expand_dims(image, 0)
70
 
71
- prediction = saved_model.predict(image)
72
- predicted_image = np.argmax(prediction, axis=3)
73
 
74
- predicted_image = predicted_image[0,:,:]
75
- predicted_image = predicted_image * 50
76
 
77
 
78
- # Apply a colormap to the predicted image
79
- cmap = plt.get_cmap('viridis') # You can choose any colormap you prefer
80
- colored_image = cmap(predicted_image / predicted_image.max()) # Normalize to [0, 1]
81
- colored_image = (colored_image[:, :, :3] * 255).astype(np.uint8) # Convert to RGB and scale to [0, 255]
82
 
83
- return 'Predicted Masked Image', colored_image
84
- # return 'Predicted Masked Image', predicted_image
85
 
86
- my_app = gr.Blocks()
87
- with my_app:
88
- gr.Markdown("Statellite Image Segmentation Application UI with Gradio")
89
- with gr.Tabs():
90
- with gr.TabItem("Select your image"):
91
- with gr.Row():
92
- with gr.Column():
93
- img_source = gr.Image(label="Please select source Image")
94
- source_image_loader = gr.Button("Load above Image")
95
- with gr.Column():
96
- output_label = gr.Label(label="Image Info")
97
- img_output = gr.Image(label="Image Output")
98
- source_image_loader.click(
99
- process_input_image,
100
- [
101
- img_source
102
- ],
103
- [
104
- output_label,
105
- img_output
106
- ]
107
- )
108
- my_app.launch(debug=True,share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
2
  from keras.models import load_model
3
+ from patchify import patchify, unpatchify
4
+ import numpy as np
5
+ import cv2
6
+ from sklearn.preprocessing import MinMaxScaler
7
+ import matplotlib.pyplot as plt
8
+
9
+ # Define colors for classes
10
+ class_building = np.array([60, 16, 152])
11
+ class_land = np.array([132, 41, 246])
12
+ class_road = np.array([110, 193, 228])
13
+ class_vegetation = np.array([254, 221, 58])
14
+ class_water = np.array([226, 169, 41])
15
+ class_unlabeled = np.array([155, 155, 155])
16
+
17
+ # Number of classes in your segmentation task
18
+ total_classes = 6 # Update this with your total number of classes
19
 
20
+ # Define custom loss functions
21
+ def jaccard_coef(y_true, y_pred):
22
+ smooth = 1e-12
23
+ intersection = K.sum(K.abs(y_true * y_pred), axis=[1,2,3])
24
+ union = K.sum(y_true,[1,2,3])+K.sum(y_pred,[1,2,3])-intersection
25
+ jac = K.mean((intersection + smooth) / (union + smooth), axis=0)
26
+ return jac
27
 
 
28
  def dice_loss(y_true, y_pred):
29
  smooth = 1e-12
30
  intersection = K.sum(y_true * y_pred, axis=[1,2,3])
 
32
  dice = K.mean((2.0 * intersection + smooth) / (union + smooth), axis=0)
33
  return 1.0 - dice
34
 
 
35
  def focal_loss(y_true, y_pred, alpha=0.25, gamma=2.0):
36
  y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())
37
  ce_loss = -y_true * K.log(y_pred)
 
39
  fl_loss = ce_loss * weight
40
  return K.mean(K.sum(fl_loss, axis=-1))
41
 
 
42
  def total_loss(y_true, y_pred):
43
  return dice_loss(y_true, y_pred) + (1 * focal_loss(y_true, y_pred))
44
 
45
+ # Load the pre-trained model
46
+ model_path = 'satmodel.h5' # Replace with your model path
47
+ model = load_model(model_path, custom_objects={'total_loss': total_loss, 'jaccard_coef': jaccard_coef, 'dice_loss': dice_loss, 'focal_loss': focal_loss})
48
 
49
+ # MinMaxScaler for normalization
50
+ minmaxscaler = MinMaxScaler()
51
 
52
+ # Function to predict the full image
53
+ def predict_full_image(image, patch_size, model):
54
+ original_shape = image.shape
55
+ print(f"Original image shape: {original_shape}")
56
+
57
+ # Pad image to make its dimensions divisible by the patch size
58
+ pad_height = (patch_size - image.shape[0] % patch_size) % patch_size
59
+ pad_width = (patch_size - image.shape[1] % patch_size) % patch_size
60
+ image = np.pad(image, ((0, pad_height), (0, pad_width), (0, 0)), mode='constant', constant_values=0)
61
+ padded_shape = image.shape
62
+ print(f"Padded image shape: {padded_shape}")
63
+
64
+ # Normalize the image
65
+ image = minmaxscaler.fit_transform(image.reshape(-1, image.shape[-1])).reshape(image.shape)
66
+
67
+ # Create patches
68
+ patched_images = patchify(image, (patch_size, patch_size, 3), step=patch_size)
69
+ print(f"Patched image shape: {patched_images.shape}")
70
+
71
+ predicted_patches = []
72
+
73
+ # Predict on each patch
74
+ for i in range(patched_images.shape[0]):
75
+ for j in range(patched_images.shape[1]):
76
+ single_patch = patched_images[i, j, 0]
77
+ single_patch = np.expand_dims(single_patch, axis=0)
78
+ prediction = model.predict(single_patch)
79
+ predicted_patches.append(prediction[0])
80
+
81
+ # Reshape predicted patches
82
+ predicted_patches = np.array(predicted_patches)
83
+ print(f"Predicted patches shape: {predicted_patches.shape}")
84
+
85
+ predicted_patches = predicted_patches.reshape(patched_images.shape[0], patched_images.shape[1], patch_size, patch_size, total_classes)
86
+ print(f"Reshaped predicted patches shape: {predicted_patches.shape}")
87
+
88
+ # Unpatchify the image
89
+ reconstructed_image = np.zeros((padded_shape[0], padded_shape[1], total_classes))
90
+ for i in range(patched_images.shape[0]):
91
+ for j in range(patched_images.shape[1]):
92
+ reconstructed_image[i * patch_size:(i + 1) * patch_size, j * patch_size:(j + 1) * patch_size, :] = predicted_patches[i, j]
93
+ print(f"Reconstructed image shape (with padding): {reconstructed_image.shape}")
94
+
95
+ # Remove padding
96
+ reconstructed_image = reconstructed_image[:original_shape[0], :original_shape[1]]
97
+ print(f"Final reconstructed image shape: {reconstructed_image.shape}")
98
+
99
+ return reconstructed_image
100
 
101
+ # Function to process the input image
102
+ def process_input_image(input_image):
103
+ image_patch_size = 256
104
+ predicted_full_image = predict_full_image(input_image, image_patch_size, model)
105
+
106
+ # Convert the predictions to RGB
107
+ predicted_full_image_rgb = np.zeros_like(input_image)
108
+
109
+ # Map the predicted class labels to RGB colors
110
+ predicted_full_image_rgb[predicted_full_image.argmax(axis=-1) == 0] = class_water
111
+ predicted_full_image_rgb[predicted_full_image.argmax(axis=-1) == 1] = class_land
112
+ predicted_full_image_rgb[predicted_full_image.argmax(axis=-1) == 2] = class_road
113
+ predicted_full_image_rgb[predicted_full_image.argmax(axis=-1) == 3] = class_building
114
+ predicted_full_image_rgb[predicted_full_image.argmax(axis=-1) == 4] = class_vegetation
115
+ predicted_full_image_rgb[predicted_full_image.argmax(axis=-1) == 5] = class_unlabeled
116
+
117
+ return "Image processed", predicted_full_image_rgb
118
 
119
+ # Gradio application
120
+ my_app = gr.Blocks()
121
+ with my_app:
122
+ gr.Markdown("Satellite Image Segmentation Application UI with Gradio")
123
+ with gr.Tabs():
124
+ with gr.TabItem("Select your image"):
125
+ with gr.Row():
126
+ with gr.Column():
127
+ img_source = gr.Image(label="Please select source Image")
128
+ source_image_loader = gr.Button("Load above Image")
129
+ with gr.Column():
130
+ output_label = gr.Label(label="Image Info")
131
+ img_output = gr.Image(label="Image Output")
132
+ source_image_loader.click(
133
+ process_input_image,
134
+ inputs=[img_source],
135
+ outputs=[output_label, img_output]
136
+ )
137
 
138
+ # Launch the app
139
+ my_app.launch()
140
 
 
141
 
 
 
 
 
 
 
 
142
 
 
 
143
 
 
 
144
 
145
 
 
 
 
 
146
 
 
 
147
 
148
+
149
+
150
+
151
+
152
+
153
+ # import os
154
+ # import cv2
155
+ # from PIL import Image
156
+ # import numpy as np
157
+ # from matplotlib import pyplot as plt
158
+ # import random
159
+ # import gradio as gr
160
+ # from keras import backend as K
161
+ # from keras.models import load_model
162
+ # def jaccard_coef(y_true, y_pred):
163
+ # y_true_flatten = K.flatten(y_true)
164
+ # y_pred_flatten = K.flatten(y_pred)
165
+ # intersection = K.sum(y_true_flatten * y_pred_flatten)
166
+ # final_coef_value = (intersection + 1.0) / (K.sum(y_true_flatten) + K.sum(y_pred_flatten) - intersection + 1.0)
167
+ # return final_coef_value
168
+
169
+
170
+ # # Define Dice Loss
171
+ # def dice_loss(y_true, y_pred):
172
+ # smooth = 1e-12
173
+ # intersection = K.sum(y_true * y_pred, axis=[1,2,3])
174
+ # union = K.sum(y_true, axis=[1,2,3]) + K.sum(y_pred, axis=[1,2,3])
175
+ # dice = K.mean((2.0 * intersection + smooth) / (union + smooth), axis=0)
176
+ # return 1.0 - dice
177
+
178
+ # # Define Focal Loss
179
+ # def focal_loss(y_true, y_pred, alpha=0.25, gamma=2.0):
180
+ # y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())
181
+ # ce_loss = -y_true * K.log(y_pred)
182
+ # weight = alpha * y_true * K.pow((1 - y_pred), gamma)
183
+ # fl_loss = ce_loss * weight
184
+ # return K.mean(K.sum(fl_loss, axis=-1))
185
+
186
+ # # Define Total Loss
187
+ # def total_loss(y_true, y_pred):
188
+ # return dice_loss(y_true, y_pred) + (1 * focal_loss(y_true, y_pred))
189
+
190
+ # weights = [0.1666, 0.1666, 0.1666, 0.1666, 0.1666, 0.1666]
191
+
192
+
193
+ # from keras.models import load_model
194
+ # import numpy as np
195
+ # from PIL import Image
196
+ # import matplotlib.pyplot as plt
197
+ # saved_model=load_model('satmodel.h5', custom_objects={'total_loss': total_loss, 'dice_loss': dice_loss, 'focal_loss': focal_loss, 'jaccard_coef': jaccard_coef})
198
+ # # def process_input_image(image_source):
199
+ # # image = np.expand_dims(image_source, 0)
200
+
201
+ # # prediction = saved_model.predict(image)
202
+ # # predicted_image = np.argmax(prediction, axis=3)
203
+
204
+ # # predicted_image = predicted_image[0,:,:]
205
+ # # predicted_image = predicted_image * 50
206
+ # # return 'Predicted Masked Image', predicted_image
207
+
208
+ # import matplotlib.pyplot as plt
209
+ # import matplotlib.colors as mcolors
210
+
211
+ # # # Define the image processing function
212
+
213
+ # # Define the image processing function
214
+ # def process_input_image(image):
215
+ # image = Image.fromarray(image)
216
+ # image = image.convert('RGB') # Convert the image to RGB
217
+ # image = image.resize((256, 256))
218
+ # image = np.array(image)
219
+ # image = np.expand_dims(image, 0)
220
+
221
+ # prediction = saved_model.predict(image)
222
+ # predicted_image = np.argmax(prediction, axis=3)
223
+
224
+ # predicted_image = predicted_image[0,:,:]
225
+ # predicted_image = predicted_image * 50
226
+
227
+
228
+ # # Apply a colormap to the predicted image
229
+ # cmap = plt.get_cmap('viridis') # You can choose any colormap you prefer
230
+ # colored_image = cmap(predicted_image / predicted_image.max()) # Normalize to [0, 1]
231
+ # colored_image = (colored_image[:, :, :3] * 255).astype(np.uint8) # Convert to RGB and scale to [0, 255]
232
+
233
+ # return 'Predicted Masked Image', colored_image
234
+ # # return 'Predicted Masked Image', predicted_image
235
+
236
+ # my_app = gr.Blocks()
237
+ # with my_app:
238
+ # gr.Markdown("Statellite Image Segmentation Application UI with Gradio")
239
+ # with gr.Tabs():
240
+ # with gr.TabItem("Select your image"):
241
+ # with gr.Row():
242
+ # with gr.Column():
243
+ # img_source = gr.Image(label="Please select source Image")
244
+ # source_image_loader = gr.Button("Load above Image")
245
+ # with gr.Column():
246
+ # output_label = gr.Label(label="Image Info")
247
+ # img_output = gr.Image(label="Image Output")
248
+ # source_image_loader.click(
249
+ # process_input_image,
250
+ # [
251
+ # img_source
252
+ # ],
253
+ # [
254
+ # output_label,
255
+ # img_output
256
+ # ]
257
+ # )
258
+ # my_app.launch(debug=True,share=True)
259