# MNIST / app.py
# Author: UdayPrasad
# Commit a62cbb0: "Update app.py"
import functools

import gradio as gr
import numpy as np
import scipy
import tensorflow as tf
@functools.lru_cache(maxsize=1)
def _load_mnist_train_flat():
    """Load the MNIST training split once and return it flattened.

    Cached so repeated predictions do not reload the dataset from disk.
    The returned arrays are shared between calls; callers must not mutate
    them.

    Returns:
        ``(x_train_flatten, y_train)``: training images reshaped to
        ``(n_samples, width * height)`` and their labels.
    """
    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
    n_samples, width, height = x_train.shape
    return x_train.reshape(n_samples, width * height), y_train


def digit_KNN_prediction(test_image, K=5):
    """Predict the digit in ``test_image`` with a K-nearest-neighbours classifier.

    The classifier is fitted on the flattened MNIST training images, and the
    input image is flattened to 28*28 = 784 pixel features to match.

    Args:
        test_image: image array whose total pixel count is a multiple of
            28*28. The Gradio input below configures it as 28x28, mode 'L'.
        K: number of neighbours that vote on the label. Coerced to ``int``
            because scikit-learn's ``n_neighbors`` must be an integer, and a
            UI widget may deliver a float.

    Returns:
        The predicted label for the first image in ``test_image``.

    Raises:
        ValueError: if no image was supplied.
    """
    from sklearn.neighbors import KNeighborsClassifier

    if test_image is None:
        raise ValueError("No image provided; please draw or upload a digit.")

    # One row of 784 features per image, matching the training layout.
    test_image_flatten = np.asarray(test_image).reshape((-1, 28 * 28))

    x_train_flatten, y_train = _load_mnist_train_flat()
    knn_classifier = KNeighborsClassifier(n_neighbors=int(K))
    knn_classifier.fit(x_train_flatten, y_train)
    return knn_classifier.predict(test_image_flatten)[0]
# Gradio UI: a 28x28 image in PIL mode 'L' plus an odd K from 1 to 13 are fed
# to digit_KNN_prediction, and the predicted label is shown in a text box.
input_image = gr.inputs.Image(shape=(28, 28), image_mode='L')
input_K = gr.inputs.Slider(1, 13, step=2, default=5)
output_label = gr.outputs.Textbox(label="Predicted Digit")
demo = gr.Interface(
    fn=digit_KNN_prediction,
    inputs=[input_image, input_K],
    outputs=[output_label],
    title="Digit classification using KNN algorithm",
)

if __name__ == "__main__":
    # Launch only when run as a script, so importing this module (e.g. to
    # test digit_KNN_prediction) does not start a blocking web server.
    demo.launch(debug=True)