Spaces:
Runtime error
Runtime error
| import scipy | |
| import gradio as gr | |
| import numpy as np | |
| import tensorflow as tf | |
| from tensorflow import keras | |
| from keras.models import load_model | |
| import pickle | |
| def mnist_prediction(test_image, model='KNN'): | |
| test_image_flatten = test_image.reshape((-1, 28*28)) | |
| if model == 'KNN': | |
| with open('KNN_best_model_final.pkl', 'rb') as file: | |
| knn_loaded = pickle.load(file) | |
| ans = knn_loaded.predict(test_image_flatten) | |
| return ans[0] | |
| elif model == 'SoftMax': | |
| with open('softmax_best_model_final.pkl', 'rb') as file: | |
| softmax_model_loaded = pickle.load(file) | |
| ans = softmax_model_loaded.predict(test_image_flatten) | |
| return ans[0] | |
| elif model == 'Deep Neural Network': | |
| dnn_model = load_model("deep_nn_model_final.h5") | |
| ans_prediction = dnn_model.predict(np.asarray(test_image_flatten)) | |
| ans = np.argmax(ans_prediction) | |
| return ans | |
| elif model == 'CNN': | |
| cnn_model = load_model("cnn_model_final.h5") | |
| ans_prediction = cnn_model.predict(np.asarray([test_image])) | |
| ans = np.argmax(ans_prediction) | |
| return ans | |
| elif model == 'SVM': | |
| with open('svm_best_model_final.pkl', 'rb') as file: | |
| svm_model_loaded = pickle.load(file) | |
| ans = svm_model_loaded.predict(test_image_flatten) | |
| return ans[0] | |
| elif model == 'Decision Tree': | |
| with open('tree_model_final.pkl', 'rb') as file: | |
| tree_model_loaded = pickle.load(file) | |
| ans = tree_model_loaded.predict(test_image_flatten) | |
| return ans[0] | |
| elif model == 'Random Forest': | |
| with open('forest_model_final.pkl', 'rb') as file: | |
| forest_model_loaded = pickle.load(file) | |
| ans = forest_model_loaded.predict(test_image_flatten) | |
| return ans[0] | |
| return "Not found" | |
# NOTE(review): ``gr.inputs`` / ``gr.outputs`` are Gradio's legacy component
# namespaces (deprecated in Gradio 3.x, removed in 4.0). Confirm the pinned
# Gradio version still provides them before upgrading Gradio.
input_image = gr.inputs.Image(shape=(28, 28), image_mode='L')
input_model = gr.inputs.Dropdown(
    ['KNN', 'SoftMax', 'Deep Neural Network', 'CNN', 'SVM',
     'Decision Tree', 'Random Forest'])
output_label = gr.outputs.Textbox(label="Predicted Digit")

# Bound at module level as ``demo`` so the interface can be found by tools
# that import this file (e.g. the ``gradio`` CLI) without launching it here.
demo = gr.Interface(
    fn=mnist_prediction,
    inputs=[input_image, input_model],
    outputs=output_label,
    title="MNIST classification",
)

if __name__ == "__main__":
    # debug=True makes launch() block the main thread (per Gradio docs).
    demo.launch(debug=True)