Update app.py
Browse files
app.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
import matplotlib.pyplot as plt
|
| 2 |
import numpy as np
|
| 3 |
from PIL import Image
|
| 4 |
from skimage.transform import resize
|
|
@@ -7,6 +6,9 @@ from tensorflow.keras.models import load_model
|
|
| 7 |
|
| 8 |
from huggingface_hub import snapshot_download
|
| 9 |
|
|
|
|
|
|
|
|
|
|
| 10 |
import gradio as gr
|
| 11 |
import os
|
| 12 |
import io
|
|
@@ -62,13 +64,24 @@ def get_predictions(y_prediction_encoded):
|
|
| 62 |
return predicted_label_indices
|
| 63 |
|
| 64 |
def predict(image):
|
|
|
|
|
|
|
| 65 |
sample_image_resized = resize_image(image)
|
| 66 |
y_pred = ensemble_predict(sample_image_resized)
|
| 67 |
y_pred = get_predictions(y_pred).squeeze()
|
| 68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
# Create a figure without saving it to a file
|
| 70 |
fig, ax = plt.subplots()
|
| 71 |
-
cax = ax.imshow(y_pred, cmap=
|
| 72 |
|
| 73 |
# Convert the figure to a PIL Image
|
| 74 |
image_buffer = io.BytesIO()
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
from PIL import Image
|
| 3 |
from skimage.transform import resize
|
|
|
|
| 6 |
|
| 7 |
from huggingface_hub import snapshot_download
|
| 8 |
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
from matplotlib.colors import ListedColormap
|
| 11 |
+
|
| 12 |
import gradio as gr
|
| 13 |
import os
|
| 14 |
import io
|
|
|
|
| 64 |
return predicted_label_indices
|
| 65 |
|
| 66 |
def predict(image):
|
| 67 |
+
|
| 68 |
+
# Steps to get prediction
|
| 69 |
sample_image_resized = resize_image(image)
|
| 70 |
y_pred = ensemble_predict(sample_image_resized)
|
| 71 |
y_pred = get_predictions(y_pred).squeeze()
|
| 72 |
|
| 73 |
+
# Define your custom colors for each label
|
| 74 |
+
colors = ['cyan', 'yellow', 'magenta', 'green', 'blue', 'black', 'white']
|
| 75 |
+
# Create a ListedColormap
|
| 76 |
+
cmap = ListedColormap(colors)
|
| 77 |
+
# Create colorbar and set ticks and ticklabels
|
| 78 |
+
cbar = plt.colorbar(ticks=np.arange(1, 8))
|
| 79 |
+
cbar.set_ticklabels(['Urban', 'Agriculture', 'Range Land', 'Forest', 'Water', 'Barren', 'Unknown'])
|
| 80 |
+
|
| 81 |
+
|
| 82 |
# Create a figure without saving it to a file
|
| 83 |
fig, ax = plt.subplots()
|
| 84 |
+
cax = ax.imshow(y_pred, cmap=cmap, vmin=1, vmax=7)
|
| 85 |
|
| 86 |
# Convert the figure to a PIL Image
|
| 87 |
image_buffer = io.BytesIO()
|