amosfang commited on
Commit
c4eb93a
·
verified ·
1 Parent(s): e490bc5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -3
app.py CHANGED
@@ -74,7 +74,42 @@ def get_predictions(y_prediction_encoded):
74
 
75
  return predicted_label_indices
76
 
77
- def predict(image):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  # Steps to get prediction
79
  sample_image_resized = resize_image(image)
80
  y_pred = ensemble_predict(sample_image_resized)
@@ -122,7 +157,7 @@ sample_images = get_sample_images('example_images')
122
  # ).launch(debug=True, share=True)
123
 
124
  tab1 = gr.Interface(
125
- fn=predict,
126
  inputs=gr.Image(label='', type="pil"),
127
  outputs=[gr.Image(type="pil"), gr.Image(type="pil")],
128
  title='Images with Ground Truth',
@@ -131,7 +166,7 @@ tab1 = gr.Interface(
131
 
132
  # Create the video processing interface
133
  tab2 = gr.Interface(
134
- fn=predict,
135
  inputs=gr.File(label=""),
136
  outputs=gr.File(label=""),
137
  title='Images with Ground Truth',
 
74
 
75
  return predicted_label_indices
76
 
77
+ def predict_on_train(image):
78
+ # Steps to get prediction
79
+ sample_image_resized = resize_image(image)
80
+ y_pred = ensemble_predict(sample_image_resized)
81
+ y_pred = get_predictions(y_pred).squeeze()
82
+
83
+ # Define your custom colors for each label
84
+ colors = ['cyan', 'yellow', 'magenta', 'green', 'blue', 'black', 'white']
85
+ # Create a ListedColormap
86
+ cmap = ListedColormap(colors)
87
+
88
+ # Create a figure
89
+ fig, ax = plt.subplots()
90
+
91
+ # Display the image
92
+ ax.imshow(sample_image_resized)
93
+
94
+ # Display the predictions using the specified colormap
95
+ cax = ax.imshow(y_pred, cmap=cmap, vmin=1, vmax=7, alpha=0.5)
96
+
97
+ # Create colorbar and set ticks and ticklabels
98
+ cbar = plt.colorbar(cax, ticks=np.arange(1, 8))
99
+ cbar.set_ticklabels(['Urban', 'Agriculture', 'Range Land', 'Forest', 'Water', 'Barren', 'Unknown'])
100
+
101
+ # Convert the figure to a PIL Image
102
+ image_buffer = io.BytesIO()
103
+ plt.savefig(image_buffer, format='png')
104
+ image_buffer.seek(0)
105
+ image_pil = Image.open(image_buffer)
106
+
107
+ # Close the figure to release resources
108
+ plt.close(fig)
109
+
110
+ return image_pil, image_pil
111
+
112
+ def predict_on_test(image):
113
  # Steps to get prediction
114
  sample_image_resized = resize_image(image)
115
  y_pred = ensemble_predict(sample_image_resized)
 
157
  # ).launch(debug=True, share=True)
158
 
159
  tab1 = gr.Interface(
160
+ fn=predict_on_train,
161
  inputs=gr.Image(label='', type="pil"),
162
  outputs=[gr.Image(type="pil"), gr.Image(type="pil")],
163
  title='Images with Ground Truth',
 
166
 
167
  # Create the video processing interface
168
  tab2 = gr.Interface(
169
+ fn=predict_on_test,
170
  inputs=gr.File(label=""),
171
  outputs=gr.File(label=""),
172
  title='Images with Ground Truth',