Chancee12 committed on
Commit
7e6d971
·
1 Parent(s): 71b4ab7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -105
app.py CHANGED
@@ -1,156 +1,109 @@
1
  import gradio as gr
2
  import os
3
  import cv2
4
- from PIL import Image
5
- import numpy as np
6
  from matplotlib import pyplot as plt
7
  import random
8
  from keras.utils import get_custom_objects
9
  import os
10
- import tensorflow as tf
11
- from huggingface_hub import hf_hub_download
12
 
13
  os.environ['SM_FRAMEWORK'] = 'tf.keras'
 
14
  import segmentation_models as sm
 
15
  from keras import backend as K
16
  from keras.models import load_model
17
 
18
- # Replace 'your_username' and 'your_model_name' with the appropriate values
19
- import zipfile
20
-
21
- from huggingface_hub import HfApi, HfFolder
22
- import requests
23
-
24
- api = HfApi()
25
- model_identifier = "Chancee12/satellite_segmentation_v3_full"
26
- filename = "satellite_segmentation_v3_full.zip"
27
-
28
- # Get the URL to download the model
29
- url = api.model_download_url(model_identifier, filename)
30
-
31
- # Download the model
32
- response = requests.get(url)
33
- open(filename, "wb").write(response.content)
34
-
35
- # Extract the zip file
36
- with zipfile.ZipFile(filename, 'r') as zip_ref:
37
- zip_ref.extractall("satellite_segmentation_v3_full")
38
-
39
-
40
- # The rest of code remains the same
41
-
42
- class_building = '#2A2A2A'
43
  class_building = class_building.lstrip('#')
44
  class_building = np.array(tuple(int(class_building[i:i+2], 16) for i in (0,2,4)))
45
 
46
- class_land = '#996515'
47
  class_land = class_land.lstrip('#')
48
  class_land = np.array(tuple(int(class_land[i:i+2], 16) for i in (0,2,4)))
49
 
50
- class_road = '#FFC107'
51
  class_road = class_road.lstrip('#')
52
  class_road = np.array(tuple(int(class_road[i:i+2], 16) for i in (0,2,4)))
53
 
54
- class_vegetation = '#4CAF50'
55
  class_vegetation = class_vegetation.lstrip('#')
56
  class_vegetation = np.array(tuple(int(class_vegetation[i:i+2], 16) for i in (0,2,4)))
57
 
58
- class_water = '#03A9F4'
59
  class_water = class_water.lstrip('#')
60
  class_water = np.array(tuple(int(class_water[i:i+2], 16) for i in (0,2,4)))
61
 
62
- class_unlabeled = '#BDBDBD'
63
  class_unlabeled = class_unlabeled.lstrip('#')
64
  class_unlabeled = np.array(tuple(int(class_unlabeled[i:i+2], 16) for i in (0,2,4)))
65
 
66
- def label_to_rgb(mask):
67
- rgb_mask = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
68
- rgb_mask[mask == 0] = class_building
69
- rgb_mask[mask == 1] = class_land
70
- rgb_mask[mask == 2] = class_road
71
- rgb_mask[mask == 3] = class_vegetation
72
- rgb_mask[mask == 4] = class_water
73
- rgb_mask[mask == 5] = class_unlabeled
74
- return rgb_mask
75
-
76
-
77
-
78
  def jaccard_coef(y_true, y_pred):
79
- y_true_flatten = K.flatten(y_true)
80
- y_pred_flatten = K.flatten(y_pred)
81
- intersection = K.sum(y_true_flatten * y_pred_flatten)
82
- final_coef_value = (intersection + 1.0) / (K.sum(y_true_flatten) + K.sum(y_pred_flatten) - intersection + 1.0)
83
- return final_coef_value
 
 
84
 
 
85
  weights = [0.1666, 0.1666, 0.1666, 0.1666, 0.1666, 0.1666]
86
- dice_loss = sm.losses.DiceLoss(class_weights = weights)
87
  focal_loss = sm.losses.CategoricalFocalLoss()
88
  total_loss = dice_loss + (1 * focal_loss)
89
 
90
- # ... (previous code)
91
- # ... (previous code)
92
- satellite_model = tf.keras.models.load_model("satellite_segmentation_v3_full", custom_objects={'dice_loss_plus_1focal_loss': total_loss, 'jaccard_coef': jaccard_coef, 'DiceLoss': sm.losses.DiceLoss, 'CategoricalFocalLoss': sm.losses.CategoricalFocalLoss})
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  def process_input_image(image_source):
95
  image = np.expand_dims(image_source, 0)
96
  prediction = satellite_model.predict(image)
97
  predicted_image = np.argmax(prediction, axis=3)
98
  predicted_image = predicted_image[0, :, :]
99
- rgb_image = label_to_rgb(predicted_image)
100
- return "Predicted Masked Image", rgb_image
101
 
102
- # ... (rest of the code)
 
103
 
104
-
105
- # ... (rest of the code)
106
 
107
 
108
  my_app = gr.Blocks()
109
 
110
- # Define the custom legend HTML
111
- legend_html = '''
112
- <div style="font-size: 14px; font-weight: bold; display: flex; flex-wrap: wrap; margin-top: 10px;">
113
- <div style="display: flex; align-items: center; margin-right: 10px;">
114
- <div style="width: 20px; height: 20px; background-color: #2A2A2A; margin-right: 5px;"></div>Building
115
- </div>
116
- <div style="display: flex; align-items: center; margin-right: 10px;">
117
- <div style="width: 20px; height: 20px; background-color: #996515; margin-right: 5px;"></div>Land
118
- </div>
119
- <div style="display: flex; align-items: center; margin-right: 10px;">
120
- <div style="width: 20px; height: 20px; background-color: #FFC107; margin-right: 5px;"></div>Road
121
- </div>
122
- <div style="display: flex; align-items: center; margin-right: 10px;">
123
- <div style="width: 20px; height: 20px; background-color: #4CAF50; margin-right: 5px;"></div>Vegetation
124
- </div>
125
- <div style="display: flex; align-items: center; margin-right: 10px;">
126
- <div style="width: 20px; height: 20px; background-color: #03A9F4; margin-right: 5px;"></div>Water
127
- </div>
128
- <div style="display: flex; align-items: center;">
129
- <div style="width: 20px; height: 20px; background-color: #BDBDBD; margin-right: 5px;"></div>Other
130
- </div>
131
- </div>
132
- '''
133
-
134
  with my_app:
135
- gr.Markdown("Image Processing Application UI with Gradio")
136
- with gr.Tabs():
137
- with gr.TabItem("Select your image"):
138
- with gr.Row():
139
- with gr.Column():
140
- img_source = gr.Image(label="Please select source Image", shape=(256,256))
141
- source_image_loader = gr.Button("Load above Image")
142
- with gr.Column():
143
- output_label = gr.Label(label="Image Info")
144
- img_output = gr.Image(label="Image Output")
145
- legend = gr.HTML(legend_html) # Add the custom legend component here
146
- source_image_loader.click(
147
- process_input_image,
148
- [
149
- img_source
150
- ],
151
- [
152
- output_label,
153
- img_output
154
- ]
155
- )
156
  my_app.launch(debug=True)
 
1
  import gradio as gr
2
  import os
3
  import cv2
4
+ from PIL import Image
5
+ import numpy as np
6
  from matplotlib import pyplot as plt
7
  import random
8
  from keras.utils import get_custom_objects
9
  import os
 
 
10
 
11
  os.environ['SM_FRAMEWORK'] = 'tf.keras'
12
+
13
  import segmentation_models as sm
14
+
15
  from keras import backend as K
16
  from keras.models import load_model
17
 
18
+ class_building = '#3C1098'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  class_building = class_building.lstrip('#')
20
  class_building = np.array(tuple(int(class_building[i:i+2], 16) for i in (0,2,4)))
21
 
22
+ class_land = '#8429F6'
23
  class_land = class_land.lstrip('#')
24
  class_land = np.array(tuple(int(class_land[i:i+2], 16) for i in (0,2,4)))
25
 
26
+ class_road = '#6EC1E4'
27
  class_road = class_road.lstrip('#')
28
  class_road = np.array(tuple(int(class_road[i:i+2], 16) for i in (0,2,4)))
29
 
30
+ class_vegetation = '#FEDD3A'
31
  class_vegetation = class_vegetation.lstrip('#')
32
  class_vegetation = np.array(tuple(int(class_vegetation[i:i+2], 16) for i in (0,2,4)))
33
 
34
+ class_water = '#E2A929'
35
  class_water = class_water.lstrip('#')
36
  class_water = np.array(tuple(int(class_water[i:i+2], 16) for i in (0,2,4)))
37
 
38
+ class_unlabeled = '#9B9B9B'
39
  class_unlabeled = class_unlabeled.lstrip('#')
40
  class_unlabeled = np.array(tuple(int(class_unlabeled[i:i+2], 16) for i in (0,2,4)))
41
 
 
 
 
 
 
 
 
 
 
 
 
 
42
def jaccard_coef(y_true, y_pred):
    """Smoothed Jaccard (IoU) coefficient between two segmentation tensors.

    Both tensors are flattened first, so the score is computed over every
    pixel and channel at once.  The +1.0 smoothing term in numerator and
    denominator keeps the ratio defined when both masks are empty.
    """
    true_flat = K.flatten(y_true)
    pred_flat = K.flatten(y_pred)
    overlap = K.sum(true_flat * pred_flat)
    union = K.sum(true_flat) + K.sum(pred_flat) - overlap
    return (overlap + 1.0) / (union + 1.0)
49
+
50
 
51
# six class for six weights
# Uniform class weighting (~1/6 each) so no class dominates the dice loss.
weights = [0.1666, 0.1666, 0.1666, 0.1666, 0.1666, 0.1666]
dice_loss = sm.losses.DiceLoss(class_weights=weights)
focal_loss = sm.losses.CategoricalFocalLoss()
# Combined objective: dice + focal.  The (1 * ...) factor is the focal-loss
# weight; presumably segmentation_models names the resulting sum
# 'dice_loss_plus_1focal_loss', which is why custom_objects uses that key
# below — confirm against the training script.
total_loss = dice_loss + (1 * focal_loss)

# Load the pretrained segmentation network from the bundled HDF5 file.
# The custom loss and metric must be passed via custom_objects so keras can
# deserialize the compiled model.
satellite_model = load_model('model/satellite_segmentation_full.h5',
                             custom_objects=({'dice_loss_plus_1focal_loss': total_loss, 'jaccard_coef': jaccard_coef}))
59
+
60
+
61
def label_to_rgb(label_segment):
    """Map a 2-D class-index mask to an RGB image.

    Parameters
    ----------
    label_segment : np.ndarray
        2-D array of integer class ids in 0-5.

    Returns
    -------
    np.ndarray
        (H, W, 3) uint8 image coloured with the module-level class palette.
    """
    height = label_segment.shape[0]
    width = label_segment.shape[1]
    rgb_image = np.zeros((height, width, 3), dtype=np.uint8)

    # Class id -> display colour.  The id ordering is fixed by the model's
    # output channels, so the pairs below must not be reordered.
    palette = {
        0: class_water,
        1: class_land,
        2: class_road,
        3: class_building,
        4: class_vegetation,
        5: class_unlabeled,
    }
    for class_id, colour in palette.items():
        rgb_image[label_segment == class_id] = colour

    return rgb_image
72
+
73
 
74
def process_input_image(image_source):
    """Segment one image with the loaded model and colourize the result.

    Parameters
    ----------
    image_source : np.ndarray
        A single input image (H, W, C); a batch axis is added before predict.

    Returns
    -------
    tuple
        A fixed caption string for the UI label and the RGB-coloured mask.
    """
    batch = np.expand_dims(image_source, 0)
    class_probabilities = satellite_model.predict(batch)
    # Collapse per-class probabilities to the most likely class id per pixel,
    # then drop the batch axis to get a 2-D label mask.
    label_mask = np.argmax(class_probabilities, axis=3)
    label_mask = label_mask[0, :, :]

    # Convert the predicted image labels to RGB
    colored_predicted_image = label_to_rgb(label_mask)

    return "Predicted Masked Image", colored_predicted_image
 
84
 
85
 
86
# Gradio UI: a single tab with an input image + button on the left and the
# prediction label + coloured mask on the right.  Component declaration
# order inside the context managers determines the on-screen layout.
my_app = gr.Blocks()

with my_app:
    gr.Markdown("Image Processing Application UI with Gradio")
    with gr.Tabs():
        with gr.TabItem("Select your image"):
            with gr.Row():
                with gr.Column():
                    # shape=(256, 256): Gradio resizes the upload — presumably
                    # to match the model's expected input size; confirm.
                    img_source = gr.Image(label="Please select source Image", shape=(256, 256))
                    source_image_loader = gr.Button("Load above Image")
                with gr.Column():
                    output_label = gr.Label(label="Image Info")
                    img_output = gr.Image(label="Image Output")
            # Button click runs the model: input is the uploaded image,
            # outputs fill the caption label and the mask image in order.
            source_image_loader.click(
                process_input_image,
                [
                    img_source
                ],
                [
                    output_label,
                    img_output
                ]
            )

# debug=True surfaces server-side errors in the UI/console while running.
my_app.launch(debug=True)