File size: 3,653 Bytes
b9e7a79
 
 
7e6d971
 
b9e7a79
 
 
 
 
7886c83
7e6d971
b9e7a79
7e6d971
b9e7a79
 
5eb4896
7e6d971
b5fb4c3
 
 
7e6d971
b5fb4c3
 
 
7e6d971
b5fb4c3
 
 
7e6d971
b5fb4c3
 
 
7e6d971
b5fb4c3
 
 
7e6d971
b5fb4c3
 
 
0e93c39
7e6d971
 
 
 
 
 
 
cfd9318
7e6d971
cfd9318
7e6d971
cfd9318
 
 
964f4ce
7e6d971
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5fb4c3
7da17b6
b07dd12
71b4ab7
b07dd12
 
0455b57
7e6d971
 
083cfec
7e6d971
71b4ab7
 
c17b1d3
 
 
7e6d971
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7da17b6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import gradio as gr
import os
import cv2
from PIL import Image
import numpy as np
from matplotlib import pyplot as plt
import random
from keras.utils import get_custom_objects
import os

os.environ['SM_FRAMEWORK'] = 'tf.keras'

import segmentation_models as sm

from keras import backend as K
from keras.models import load_model

def hex_to_rgb(hex_code):
    """Convert a ``'#RRGGBB'`` colour string into a length-3 NumPy int array.

    The leading ``'#'`` is optional. Each pair of hex digits becomes one
    channel value (0-255), in R, G, B order.

    Raises:
        ValueError: if the code does not contain exactly six hex digits.
    """
    digits = hex_code.lstrip('#')
    # Reject e.g. '#FFF' or 8-digit RGBA codes instead of silently mis-parsing.
    if len(digits) != 6:
        raise ValueError(f"expected a 6-digit hex colour, got {hex_code!r}")
    return np.array(tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4)))


# Display colour (RGB) used for each segmentation class in the output mask.
class_building = hex_to_rgb('#3C1098')
class_land = hex_to_rgb('#8429F6')
class_road = hex_to_rgb('#6EC1E4')
class_vegetation = hex_to_rgb('#FEDD3A')
class_water = hex_to_rgb('#E2A929')
class_unlabeled = hex_to_rgb('#9B9B9B')

def jaccard_coef(y_true, y_pred):
    """Return the smoothed Jaccard (intersection-over-union) coefficient.

    Both tensors are flattened. The intersection is the sum of their
    element-wise product, and the result is
    ``(intersection + 1) / (sum(y_true) + sum(y_pred) - intersection + 1)``.
    The ``+ 1`` smoothing keeps the ratio defined when both inputs sum to 0.

    This function is passed to ``load_model`` under the key ``'jaccard_coef'``,
    so its name must not change.
    """
    smooth = 1.0
    flat_true = K.flatten(y_true)
    flat_pred = K.flatten(y_pred)
    overlap = K.sum(flat_true * flat_pred)
    union = K.sum(flat_true) + K.sum(flat_pred) - overlap
    return (overlap + smooth) / (union + smooth)


# Equal per-class weights for the Dice loss: one entry for each of the six
# output classes.
weights = [0.1666] * 6
dice_loss = sm.losses.DiceLoss(class_weights=weights)
focal_loss = sm.losses.CategoricalFocalLoss()
# NOTE(review): the explicit `1 *` is kept on purpose. The custom-object key
# below ('dice_loss_plus_1focal_loss') appears to mirror this exact
# expression, so confirm before simplifying it.
total_loss = dice_loss + (1 * focal_loss)

# The saved model refers to its custom loss and metric by name, so both must
# be supplied when it is deserialized.
satellite_model = load_model(
    'satellite_segmentation_full_v2.h5',
    custom_objects=dict(
        dice_loss_plus_1focal_loss=total_loss,
        jaccard_coef=jaccard_coef,
    ),
)


def label_to_rgb(label_segment):
    """Render a map of integer class ids as an RGB image.

    Args:
        label_segment: array whose first two dimensions are height and width,
            holding class ids 0-5.

    Returns:
        A ``uint8`` array of shape ``(height, width, 3)``. Pixels whose id is
        not in 0-5 are left black.
    """
    # Position i in this tuple is the colour painted for class id i.
    # NOTE(review): this id -> class order must match the label encoding used
    # when the model was trained; confirm against the training pipeline.
    palette = (
        class_water,
        class_land,
        class_road,
        class_building,
        class_vegetation,
        class_unlabeled,
    )
    height, width = label_segment.shape[0], label_segment.shape[1]
    colored = np.zeros((height, width, 3), dtype=np.uint8)
    for class_id, color in enumerate(palette):
        colored[label_segment == class_id] = color
    return colored


def process_input_image(image_source):
    """Segment one image with ``satellite_model`` and colourise the result.

    Args:
        image_source: the image from the ``gr.Image`` input. With that
            component's default ``type`` this is a NumPy array. It is ``None``
            when the button is clicked before an image has been chosen.

    Returns:
        A ``(caption, mask)`` pair: a fixed caption for the label output and
        the per-pixel predicted classes rendered as an RGB ``uint8`` array.

    Raises:
        ValueError: if no image was supplied.
    """
    # Fail fast with a readable message. Otherwise ``None`` would be wrapped
    # into an object array and the model would fail with an obscure error.
    if image_source is None:
        raise ValueError("No image selected: please choose a source image first.")

    # The model works on batches, so predict on a batch of one image.
    # NOTE(review): pixels are passed unscaled; confirm this matches the
    # preprocessing used when the model was trained.
    batch = np.expand_dims(image_source, 0)
    prediction = satellite_model.predict(batch)
    # Most likely class per pixel, then drop the batch axis.
    predicted_image = np.argmax(prediction, axis=3)[0, :, :]

    colored_predicted_image = label_to_rgb(predicted_image)
    return "Predicted Masked Image", colored_predicted_image


# Build the UI: one tab with the image picker and trigger button in the left
# column, and the caption and predicted mask in the right column.
with gr.Blocks() as my_app:
    gr.Markdown("Image Processing Application UI with Gradio")
    with gr.Tabs():
        with gr.TabItem("Select your image"):
            with gr.Row():
                with gr.Column():
                    img_source = gr.Image(label="Please select source Image", shape=(256, 256))
                    source_image_loader = gr.Button("Load above Image")
                with gr.Column():
                    output_label = gr.Label(label="Image Info")
                    img_output = gr.Image(label="Image Output")
            source_image_loader.click(
                process_input_image,
                inputs=[img_source],
                outputs=[output_label, img_output],
            )

# Start the (blocking) server only when this file is run as a script, so that
# importing the module does not launch the app.
if __name__ == "__main__":
    my_app.launch(debug=True)