Chancee12 commited on
Commit
bcae9cb
·
1 Parent(s): b07dd12

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -24
app.py CHANGED
@@ -7,6 +7,8 @@ from matplotlib import pyplot as plt
7
  import random
8
  from keras.utils import get_custom_objects
9
  import os
 
 
10
 
11
  os.environ['SM_FRAMEWORK'] = 'tf.keras'
12
 
@@ -49,7 +51,6 @@ def label_to_rgb(mask):
49
  return rgb_mask
50
 
51
 
52
-
53
  def jaccard_coef(y_true, y_pred):
54
  y_true_flatten = K.flatten(y_true)
55
  y_pred_flatten = K.flatten(y_pred)
@@ -64,16 +65,42 @@ total_loss = dice_loss + (1 * focal_loss)
64
 
65
  satellite_model = load_model('model/satellite_segmentation_full.h5', custom_objects=({'dice_loss_plus_1focal_loss' : total_loss, 'jaccard_coef': jaccard_coef}))
66
 
67
- def process_input_image(image_source):
68
  image = np.expand_dims(image_source, 0)
69
  prediction = satellite_model.predict(image)
70
  predicted_image = np.argmax(prediction, axis=3)
71
  predicted_image = predicted_image[0, :, :]
72
  rgb_image = label_to_rgb(predicted_image)
73
- return "Predicted Masked Image", rgb_image
 
 
 
 
 
 
74
 
75
  my_app = gr.Blocks()
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  # Define the custom legend HTML
78
  legend_html = '''
79
  <div style="font-size: 14px; font-weight: bold; display: flex; flex-wrap: wrap; margin-top: 10px;">
@@ -99,25 +126,30 @@ legend_html = '''
99
  '''
100
 
101
  with my_app:
102
- gr.Markdown("Image Processing Application UI with Gradio")
103
- with gr.Tabs():
104
- with gr.TabItem("Select your image"):
105
- with gr.Row():
106
- with gr.Column():
107
- img_source = gr.Image(label="Please select source Image", shape=(256,256))
108
- source_image_loader = gr.Button("Load above Image")
109
- with gr.Column():
110
- output_label = gr.Label(label="Image Info")
111
- img_output = gr.Image(label="Image Output")
112
- legend = gr.HTML(legend_html) # Add the custom legend component here
113
- source_image_loader.click(
114
- process_input_image,
115
- [
116
- img_source
117
- ],
118
- [
119
- output_label,
120
- img_output
121
- ]
122
- )
 
 
 
 
 
123
  my_app.launch(debug=True)
 
7
  import random
8
  from keras.utils import get_custom_objects
9
  import os
10
+ import zipfile
11
+ from io import BytesIO
12
 
13
  os.environ['SM_FRAMEWORK'] = 'tf.keras'
14
 
 
51
  return rgb_mask
52
 
53
 
 
54
  def jaccard_coef(y_true, y_pred):
55
  y_true_flatten = K.flatten(y_true)
56
  y_pred_flatten = K.flatten(y_pred)
 
65
 
66
  satellite_model = load_model('model/satellite_segmentation_full.h5', custom_objects=({'dice_loss_plus_1focal_loss' : total_loss, 'jaccard_coef': jaccard_coef}))
67
 
68
+ def process_input_image(image_source, download_masks=False):
69
  image = np.expand_dims(image_source, 0)
70
  prediction = satellite_model.predict(image)
71
  predicted_image = np.argmax(prediction, axis=3)
72
  predicted_image = predicted_image[0, :, :]
73
  rgb_image = label_to_rgb(predicted_image)
74
+
75
+ if download_masks:
76
+ mask_zip, output_name = create_mask_zip(predicted_image)
77
+ return "Predicted Masked Image", rgb_image, mask_zip, output_name
78
+ else:
79
+ return "Predicted Masked Image", rgb_image
80
+
81
 
82
  my_app = gr.Blocks()
83
 
84
+
85
+ def create_mask_zip(mask, output_name='masks.zip'):
86
+ buffer = BytesIO()
87
+
88
+ with zipfile.ZipFile(buffer, mode='w', compression=zipfile.ZIP_DEFLATED) as zf:
89
+ for idx, class_name in enumerate(['building', 'land', 'road', 'vegetation', 'water', 'unlabeled']):
90
+ mask_class = (mask == idx) * 255
91
+ mask_class = mask_class.astype(np.uint8)
92
+ mask_image = Image.fromarray(mask_class)
93
+
94
+ mask_buffer = BytesIO()
95
+ mask_image.save(mask_buffer, format='PNG')
96
+ mask_buffer.seek(0)
97
+
98
+ zf.writestr(f'{class_name}_mask.png', mask_buffer.getvalue())
99
+
100
+ buffer.seek(0)
101
+ return buffer, output_name
102
+
103
+
104
  # Define the custom legend HTML
105
  legend_html = '''
106
  <div style="font-size: 14px; font-weight: bold; display: flex; flex-wrap: wrap; margin-top: 10px;">
 
126
  '''
127
 
128
  with my_app:
129
+ gr.Markdown("Image Processing Application UI with Gradio")
130
+ with gr.Tabs():
131
+ with gr.TabItem("Select your image"):
132
+ with gr.Row():
133
+ with gr.Column():
134
+ img_source = gr.Image(label="Please select source Image", shape=(256, 256))
135
+ source_image_loader = gr.Button("Load above Image")
136
+ download_masks_checkbox = gr.Checkbox(label="Download masks as separate files")
137
+ with gr.Column():
138
+ output_label = gr.Label(label="Image Info")
139
+ img_output = gr.Image(label="Image Output")
140
+ masks_zip = gr.File(label="Download Masks", filename="masks.zip")
141
+ legend = gr.HTML(legend_html)
142
+
143
+ source_image_loader.click(
144
+ process_input_image,
145
+ [
146
+ img_source,
147
+ download_masks_checkbox
148
+ ],
149
+ [
150
+ output_label,
151
+ img_output,
152
+ masks_zip,
153
+ ]
154
+ )
155
  my_app.launch(debug=True)