Chancee12 commited on
Commit
7da17b6
·
1 Parent(s): bb82107

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -55
app.py CHANGED
@@ -7,8 +7,6 @@ from matplotlib import pyplot as plt
7
  import random
8
  from keras.utils import get_custom_objects
9
  import os
10
- import zipfile
11
- from io import BytesIO
12
 
13
  os.environ['SM_FRAMEWORK'] = 'tf.keras'
14
 
@@ -51,6 +49,7 @@ def label_to_rgb(mask):
51
  return rgb_mask
52
 
53
 
 
54
  def jaccard_coef(y_true, y_pred):
55
  y_true_flatten = K.flatten(y_true)
56
  y_pred_flatten = K.flatten(y_pred)
@@ -65,41 +64,16 @@ total_loss = dice_loss + (1 * focal_loss)
65
 
66
  satellite_model = load_model('model/satellite_segmentation_full.h5', custom_objects=({'dice_loss_plus_1focal_loss' : total_loss, 'jaccard_coef': jaccard_coef}))
67
 
68
- def process_input_image(image_source, download_masks=False):
69
  image = np.expand_dims(image_source, 0)
70
  prediction = satellite_model.predict(image)
71
  predicted_image = np.argmax(prediction, axis=3)
72
  predicted_image = predicted_image[0, :, :]
73
  rgb_image = label_to_rgb(predicted_image)
74
-
75
- mask_zip, output_name = create_mask_zip(predicted_image)
76
- if download_masks:
77
- return {"img_output": rgb_image, "masks_zip": mask_zip.getvalue()}
78
- else:
79
- return {"img_output": rgb_image}
80
-
81
 
82
  my_app = gr.Blocks()
83
 
84
-
85
- def create_mask_zip(mask, output_name='masks.zip'):
86
- buffer = BytesIO()
87
-
88
- with zipfile.ZipFile(buffer, mode='w', compression=zipfile.ZIP_DEFLATED) as zf:
89
- for idx, class_name in enumerate(['building', 'land', 'road', 'vegetation', 'water', 'unlabeled']):
90
- mask_class = (mask == idx) * 255
91
- mask_class = mask_class.astype(np.uint8)
92
- mask_image = Image.fromarray(mask_class)
93
-
94
- mask_buffer = BytesIO()
95
- mask_image.save(mask_buffer, format='PNG')
96
- mask_buffer.seek(0)
97
-
98
- zf.writestr(f'{class_name}_mask.png', mask_buffer.getvalue())
99
-
100
- buffer.seek(0)
101
- return buffer, output_name
102
-
103
  # Define the custom legend HTML
104
  legend_html = '''
105
  <div style="font-size: 14px; font-weight: bold; display: flex; flex-wrap: wrap; margin-top: 10px;">
@@ -125,29 +99,25 @@ legend_html = '''
125
  '''
126
 
127
  with my_app:
128
- gr.Markdown("Image Processing Application UI with Gradio")
129
- with gr.Tabs():
130
- with gr.TabItem("Select your image"):
131
- with gr.Row():
132
- with gr.Column():
133
- img_source = gr.Image(label="Please select source Image", shape=(256, 256))
134
- source_image_loader = gr.Button("Load above Image")
135
- download_masks_checkbox = gr.Checkbox(label="Download masks as separate files")
136
- with gr.Column():
137
- output_label = gr.Label(label="Image Info")
138
- img_output = gr.Image(label="Image Output")
139
- masks_zip = gr.File(label="Download Masks", filename="masks.zip")
140
- legend = gr.HTML(legend_html)
141
-
142
- source_image_loader.click(
143
- process_input_image,
144
- [
145
- img_source,
146
- download_masks_checkbox
147
- ],
148
- [
149
- img_output,
150
- masks_zip,
151
- ]
152
- )
153
- my_app.launch(debug=True)
 
7
  import random
8
  from keras.utils import get_custom_objects
9
  import os
 
 
10
 
11
  os.environ['SM_FRAMEWORK'] = 'tf.keras'
12
 
 
49
  return rgb_mask
50
 
51
 
52
+
53
  def jaccard_coef(y_true, y_pred):
54
  y_true_flatten = K.flatten(y_true)
55
  y_pred_flatten = K.flatten(y_pred)
 
64
 
65
  satellite_model = load_model('model/satellite_segmentation_full.h5', custom_objects=({'dice_loss_plus_1focal_loss' : total_loss, 'jaccard_coef': jaccard_coef}))
66
 
67
+ def process_input_image(image_source):
68
  image = np.expand_dims(image_source, 0)
69
  prediction = satellite_model.predict(image)
70
  predicted_image = np.argmax(prediction, axis=3)
71
  predicted_image = predicted_image[0, :, :]
72
  rgb_image = label_to_rgb(predicted_image)
73
+ return "Predicted Masked Image", rgb_image
 
 
 
 
 
 
74
 
75
  my_app = gr.Blocks()
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  # Define the custom legend HTML
78
  legend_html = '''
79
  <div style="font-size: 14px; font-weight: bold; display: flex; flex-wrap: wrap; margin-top: 10px;">
 
99
  '''
100
 
101
  with my_app:
102
+ gr.Markdown("Image Processing Application UI with Gradio")
103
+ with gr.Tabs():
104
+ with gr.TabItem("Select your image"):
105
+ with gr.Row():
106
+ with gr.Column():
107
+ img_source = gr.Image(label="Please select source Image", shape=(256,256))
108
+ source_image_loader = gr.Button("Load above Image")
109
+ with gr.Column():
110
+ output_label = gr.Label(label="Image Info")
111
+ img_output = gr.Image(label="Image Output")
112
+ legend = gr.HTML(legend_html) # Add the custom legend component here
113
+ source_image_loader.click(
114
+ process_input_image,
115
+ [
116
+ img_source
117
+ ],
118
+ [
119
+ output_label,
120
+ img_output
121
+ ]
122
+ )
123
+ my_app.launch(debug=True)