|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import tensorflow as tf |
|
|
from tensorflow.keras.models import load_model |
|
|
import pickle |
|
|
from PIL import Image |
|
|
import pandas as pd |
|
|
|
|
|
|
|
|
# Locations of the trained artifacts. Both are relative paths, so they are
# resolved against the current working directory when the app starts.
MODEL_PATH = "./model_final.keras"
PIPELINE_PATH = "pipeline_1.pkl"

# Trained Keras classifier; preprocess_and_predict calls model.predict on it.
model = load_model(MODEL_PATH)

# Preprocessing pipeline applied to the flattened pixel row before the model
# sees it (see pipeline_1.transform in preprocess_and_predict).
# NOTE: pickle.load can execute arbitrary code embedded in the file; only
# ever point PIPELINE_PATH at a pipeline you produced yourself or trust.
with open(PIPELINE_PATH, "rb") as pipeline_file:
    pipeline_1 = pickle.load(pipeline_file)
|
|
|
|
|
def preprocess_and_predict(image):
    """
    Preprocess the input image using the pipeline and make a prediction.

    The image is resized to 28x28, converted to 8-bit greyscale ("L" mode),
    flattened into a single row of 784 float32 pixel values, passed through
    the preprocessing pipeline ``pipeline_1`` and then classified by the Keras
    ``model``.

    Args:
        image: A PIL image as delivered by ``gr.Image(type="pil")``, or
            ``None`` when the user submits without uploading anything.

    Returns:
        ``"Predicted Digit: <d>"`` where ``<d>`` is the arg-max class index,
        or a prompt asking for an image when ``image`` is ``None``.
    """
    # Gradio hands us None for an empty input component; without this guard
    # image.resize below raises AttributeError and the UI just shows "Error".
    if image is None:
        return "Please upload an image of a digit."

    side = 28
    n_pixels = side * side

    # Bring the image to MNIST resolution and a single luminance channel.
    image = image.resize((side, side)).convert("L")

    # One row of raw intensities (0-255 from the "L" image). Any scaling is
    # presumably done inside pipeline_1 -- confirm against the training code.
    image_array = np.array(image).reshape(1, -1).astype(np.float32)

    # Column names pixel0..pixel783 -- presumably the names pipeline_1 was
    # fitted with; verify against the training code.
    image_df = pd.DataFrame(
        image_array, columns=[f"pixel{i}" for i in range(n_pixels)]
    )

    # np.asarray lets this work whether transform() returns an ndarray or a
    # DataFrame (e.g. a pipeline configured with set_output(transform="pandas")).
    image_array_transformed = np.asarray(pipeline_1.transform(image_df)).reshape(1, -1)

    # verbose=0 keeps Keras from printing a progress bar on every request.
    predictions = model.predict(image_array_transformed, verbose=0)

    # Plain int so the message never depends on NumPy scalar formatting.
    predicted_digit = int(np.argmax(predictions, axis=1)[0])

    return f"Predicted Digit: {predicted_digit}"
|
|
|
|
|
|
|
|
# Sample digit images shown as clickable examples in the Gradio interface.
# Each entry is a one-element list because the interface has a single input
# component (the image).
_EXAMPLE_DIR = "./examples"
_EXAMPLE_FILES = (
    "0.jpg",
    "1.jpg",
    "2_high_contrast.jpg",
    "4.jpg",
    "6.jpg",
    "7.jpg",
    "8_high_contrast.jpg",
    "8.jpg",
)
examples = [[f"{_EXAMPLE_DIR}/{file_name}"] for file_name in _EXAMPLE_FILES]
|
|
|
|
|
|
|
|
# Web UI: one PIL image in, the prediction string from
# preprocess_and_predict out.
demo = gr.Interface(
    fn=preprocess_and_predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="MNIST Digit Classifier",
    # Gradio renders the description as Markdown. Blank lines ("\n\n") and
    # list markers are needed, otherwise single "\n" breaks are collapsed and
    # the caveats run together on one line.
    description=(
        "Upload an image of a digit (0-9) from the "
        "[MNIST dataset](https://huggingface.co/datasets/ylecun/mnist), "
        "and the model will predict the digit.\n\n"
        "**Note:** the model has only been trained on unmodified MNIST "
        "images, which are:\n\n"
        "1. roughly centered,\n"
        "2. 28x28 pixels,\n"
        "3. on a perfectly black background, and\n"
        "4. drawn in white.\n\n"
        "A custom image is resized to 28x28, but it usually still differs "
        "in the other respects, so the model may perform poorly on it."
    ),
    examples=examples,
)

# Start the server only when the file is run as a script (e.g.
# `python app.py`), so importing this module does not launch a server as a
# side effect.
if __name__ == "__main__":
    demo.launch()
|
|
|