Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -7,8 +7,6 @@ from matplotlib import pyplot as plt
|
|
| 7 |
import random
|
| 8 |
from keras.utils import get_custom_objects
|
| 9 |
import os
|
| 10 |
-
import zipfile
|
| 11 |
-
from io import BytesIO
|
| 12 |
|
| 13 |
os.environ['SM_FRAMEWORK'] = 'tf.keras'
|
| 14 |
|
|
@@ -51,6 +49,7 @@ def label_to_rgb(mask):
|
|
| 51 |
return rgb_mask
|
| 52 |
|
| 53 |
|
|
|
|
| 54 |
def jaccard_coef(y_true, y_pred):
|
| 55 |
y_true_flatten = K.flatten(y_true)
|
| 56 |
y_pred_flatten = K.flatten(y_pred)
|
|
@@ -65,41 +64,16 @@ total_loss = dice_loss + (1 * focal_loss)
|
|
| 65 |
|
| 66 |
satellite_model = load_model('model/satellite_segmentation_full.h5', custom_objects=({'dice_loss_plus_1focal_loss' : total_loss, 'jaccard_coef': jaccard_coef}))
|
| 67 |
|
| 68 |
-
def process_input_image(image_source
|
| 69 |
image = np.expand_dims(image_source, 0)
|
| 70 |
prediction = satellite_model.predict(image)
|
| 71 |
predicted_image = np.argmax(prediction, axis=3)
|
| 72 |
predicted_image = predicted_image[0, :, :]
|
| 73 |
rgb_image = label_to_rgb(predicted_image)
|
| 74 |
-
|
| 75 |
-
mask_zip, output_name = create_mask_zip(predicted_image)
|
| 76 |
-
if download_masks:
|
| 77 |
-
return {"img_output": rgb_image, "masks_zip": mask_zip.getvalue()}
|
| 78 |
-
else:
|
| 79 |
-
return {"img_output": rgb_image}
|
| 80 |
-
|
| 81 |
|
| 82 |
my_app = gr.Blocks()
|
| 83 |
|
| 84 |
-
|
| 85 |
-
def create_mask_zip(mask, output_name='masks.zip'):
|
| 86 |
-
buffer = BytesIO()
|
| 87 |
-
|
| 88 |
-
with zipfile.ZipFile(buffer, mode='w', compression=zipfile.ZIP_DEFLATED) as zf:
|
| 89 |
-
for idx, class_name in enumerate(['building', 'land', 'road', 'vegetation', 'water', 'unlabeled']):
|
| 90 |
-
mask_class = (mask == idx) * 255
|
| 91 |
-
mask_class = mask_class.astype(np.uint8)
|
| 92 |
-
mask_image = Image.fromarray(mask_class)
|
| 93 |
-
|
| 94 |
-
mask_buffer = BytesIO()
|
| 95 |
-
mask_image.save(mask_buffer, format='PNG')
|
| 96 |
-
mask_buffer.seek(0)
|
| 97 |
-
|
| 98 |
-
zf.writestr(f'{class_name}_mask.png', mask_buffer.getvalue())
|
| 99 |
-
|
| 100 |
-
buffer.seek(0)
|
| 101 |
-
return buffer, output_name
|
| 102 |
-
|
| 103 |
# Define the custom legend HTML
|
| 104 |
legend_html = '''
|
| 105 |
<div style="font-size: 14px; font-weight: bold; display: flex; flex-wrap: wrap; margin-top: 10px;">
|
|
@@ -125,29 +99,25 @@ legend_html = '''
|
|
| 125 |
'''
|
| 126 |
|
| 127 |
with my_app:
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
masks_zip,
|
| 151 |
-
]
|
| 152 |
-
)
|
| 153 |
-
my_app.launch(debug=True)
|
|
|
|
| 7 |
import random
|
| 8 |
from keras.utils import get_custom_objects
|
| 9 |
import os
|
|
|
|
|
|
|
| 10 |
|
| 11 |
os.environ['SM_FRAMEWORK'] = 'tf.keras'
|
| 12 |
|
|
|
|
| 49 |
return rgb_mask
|
| 50 |
|
| 51 |
|
| 52 |
+
|
| 53 |
def jaccard_coef(y_true, y_pred):
|
| 54 |
y_true_flatten = K.flatten(y_true)
|
| 55 |
y_pred_flatten = K.flatten(y_pred)
|
|
|
|
| 64 |
|
| 65 |
satellite_model = load_model('model/satellite_segmentation_full.h5', custom_objects=({'dice_loss_plus_1focal_loss' : total_loss, 'jaccard_coef': jaccard_coef}))
|
| 66 |
|
| 67 |
+
def process_input_image(image_source):
|
| 68 |
image = np.expand_dims(image_source, 0)
|
| 69 |
prediction = satellite_model.predict(image)
|
| 70 |
predicted_image = np.argmax(prediction, axis=3)
|
| 71 |
predicted_image = predicted_image[0, :, :]
|
| 72 |
rgb_image = label_to_rgb(predicted_image)
|
| 73 |
+
return "Predicted Masked Image", rgb_image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
my_app = gr.Blocks()
|
| 76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
# Define the custom legend HTML
|
| 78 |
legend_html = '''
|
| 79 |
<div style="font-size: 14px; font-weight: bold; display: flex; flex-wrap: wrap; margin-top: 10px;">
|
|
|
|
| 99 |
'''
|
| 100 |
|
| 101 |
with my_app:
|
| 102 |
+
gr.Markdown("Image Processing Application UI with Gradio")
|
| 103 |
+
with gr.Tabs():
|
| 104 |
+
with gr.TabItem("Select your image"):
|
| 105 |
+
with gr.Row():
|
| 106 |
+
with gr.Column():
|
| 107 |
+
img_source = gr.Image(label="Please select source Image", shape=(256,256))
|
| 108 |
+
source_image_loader = gr.Button("Load above Image")
|
| 109 |
+
with gr.Column():
|
| 110 |
+
output_label = gr.Label(label="Image Info")
|
| 111 |
+
img_output = gr.Image(label="Image Output")
|
| 112 |
+
legend = gr.HTML(legend_html) # Add the custom legend component here
|
| 113 |
+
source_image_loader.click(
|
| 114 |
+
process_input_image,
|
| 115 |
+
[
|
| 116 |
+
img_source
|
| 117 |
+
],
|
| 118 |
+
[
|
| 119 |
+
output_label,
|
| 120 |
+
img_output
|
| 121 |
+
]
|
| 122 |
+
)
|
| 123 |
+
my_app.launch(debug=True)
|
|
|
|
|
|
|
|
|
|
|
|