# Hugging Face Spaces status: Runtime error
import functools

import gradio as gr
import numpy as np
import scipy
import tensorflow as tf
# Side length of an MNIST image; the classifier is trained on flattened 28x28 images.
_MNIST_SIDE = 28


@functools.lru_cache(maxsize=1)
def _load_mnist_train():
    """Load the MNIST training split once; return (flattened_images, labels).

    Images are reshaped to (n_samples, width * height) so they can be passed
    straight to scikit-learn. Cached so the dataset is read only once per process.
    """
    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
    n_samples, width, height = x_train.shape
    return x_train.reshape(n_samples, width * height), y_train


@functools.lru_cache(maxsize=16)
def _fit_knn(k):
    """Return a KNeighborsClassifier(n_neighbors=k) fitted on MNIST, cached per k."""
    from sklearn.neighbors import KNeighborsClassifier

    x_train_flatten, y_train = _load_mnist_train()
    classifier = KNeighborsClassifier(n_neighbors=k)
    classifier.fit(x_train_flatten, y_train)
    return classifier


def _to_mnist_vector(test_image):
    """Convert a grayscale image array into a single (1, 784) feature row."""
    image = np.asarray(test_image)
    if image.ndim == 3:
        # Collapse a channel axis if one is present; image_mode="L" normally
        # yields a 2-D array already.
        image = image.mean(axis=-1)
    if image.shape != (_MNIST_SIDE, _MNIST_SIDE):
        # Gradio 4 removed Image(shape=...), so uploads arrive at their original
        # size; resize here to match the training resolution.
        image = tf.image.resize(
            image[..., np.newaxis], (_MNIST_SIDE, _MNIST_SIDE), antialias=True
        ).numpy()[..., 0]
    return image.reshape(1, _MNIST_SIDE * _MNIST_SIDE)


def digit_KNN_prediction(test_image, K=5):
    """Predict the digit shown in ``test_image`` with a K-nearest-neighbours model.

    Args:
        test_image: grayscale image as an array-like (as produced by
            ``gr.Image(image_mode="L")``). Images that are not 28x28 are resized.
        K: number of neighbours that vote. Coerced to ``int`` because a Gradio
            slider may deliver a float, which scikit-learn rejects.

    Returns:
        The predicted label: the first element of ``KNeighborsClassifier.predict``.

    Raises:
        gr.Error: if no image was supplied.
    """
    if test_image is None:
        raise gr.Error("Please upload an image of a digit.")
    classifier = _fit_knn(int(K))
    return classifier.predict(_to_mnist_vector(test_image))[0]


# Top-level components replace the gr.inputs / gr.outputs namespaces, which
# were deprecated in Gradio 3 and removed in Gradio 4 (``default=`` became ``value=``).
input_image = gr.Image(image_mode="L")
input_K = gr.Slider(1, 13, step=2, value=5)
output_label = gr.Textbox(label="Predicted Digit")
gr.Interface(
    fn=digit_KNN_prediction,
    inputs=[input_image, input_K],
    outputs=[output_label],
    title="Digit classification using KNN algorithm",
).launch(debug=True)