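# NOTE(review): the three lines below were bare text at the top of the file
# (possibly scraped author / commit / hash metadata); commented out so the
# module parses.
# Geetansh
# Model_demo_done
# 1c15775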
import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
import pickle
from PIL import Image
import pandas as pd
# Load the trained Keras model (path is relative to the current working directory).
model = load_model("./model_final.keras")
# Load the preprocessing pipeline fitted at training time; preprocess_and_predict
# applies its .transform() to every input before calling the model.
# NOTE(review): pickle.load can execute arbitrary code — only load this file from
# a trusted source (it appears to ship alongside the app).
with open("pipeline_1.pkl", "rb") as f:
    pipeline_1 = pickle.load(f)
def preprocess_and_predict(image):
    """
    Preprocess the input image using the pipeline and make a prediction.

    The image is converted to single-channel grayscale, resized to 28x28,
    flattened into one row of pixel values, passed through the fitted
    ``pipeline_1`` and then classified by ``model``.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user presses Submit without providing an image.

    Returns:
        str: ``"Predicted Digit: <d>"``, where ``<d>`` is the arg-max class of
        the model output, or a prompt asking for an image when ``image`` is
        ``None``.
    """
    # Gradio passes None for an empty image input; without this guard the
    # .convert() call below raises AttributeError and the UI only shows a
    # generic "Error".
    if image is None:
        return "Please upload an image of a digit (0-9)."

    # Convert to grayscale *before* resizing: Pillow forces NEAREST resampling
    # when resizing palette ("P") or bilevel ("1") images, which aliases thin
    # strokes. Resizing the "L" image uses Pillow's default filter instead.
    image = image.convert("L").resize((28, 28))

    # Flatten the 28x28 image into a single row of 784 float32 pixel values.
    image_array = np.array(image).reshape(1, -1).astype(np.float32)

    # The pipeline was fitted on a DataFrame, so give it the same column names
    # (pixel0 .. pixel783). NOTE(review): presumably these match the training
    # data's headers — confirm against the training code.
    image_df = pd.DataFrame(
        image_array,
        columns=[f"pixel{i}" for i in range(image_array.shape[1])],
    )

    # Transform with the fitted pipeline, then reshape to a (1, n_features)
    # matrix because the Keras model expects a batch dimension.
    image_array_transformed = pipeline_1.transform(image_df).reshape(1, -1)

    # verbose=0 suppresses Keras' per-call progress bar in the server logs.
    predictions = model.predict(image_array_transformed, verbose=0)

    # The predicted digit is the class with the highest score.
    predicted_digit = np.argmax(predictions, axis=1)[0]
    return f"Predicted Digit: {predicted_digit}"
# Sample inputs offered by the interface. Gradio expects one inner list per
# example, holding a value for each input component (here: one image path).
examples = [
    [example_path]
    for example_path in (
        "./examples/0.jpg",
        "./examples/1.jpg",
        "./examples/2_high_contrast.jpg",
        "./examples/4.jpg",
        "./examples/6.jpg",
        "./examples/7.jpg",
        "./examples/8_high_contrast.jpg",
        "./examples/8.jpg",
    )
]
# Define the Gradio interface: one PIL image input mapped to a text output.
demo = gr.Interface(
    fn=preprocess_and_predict,  # Called with the uploaded image
    inputs=gr.Image(type="pil"),  # Deliver the upload as a PIL image
    outputs="text",  # Prediction string is shown as plain text
    title="MNIST Digit Classifier",
    # Gradio renders `description` as Markdown, where a single "\n" is only a
    # soft break (rendered as a space). Use blank lines and a real list so the
    # caveats appear on separate lines.
    description=(
        "Upload an image of a digit (0-9) from the "
        "[MNIST dataset](https://huggingface.co/datasets/ylecun/mnist), "
        "and the model will predict the digit.\n\n"
        "**Note:** the model was trained only on MNIST images as-is, "
        "i.e. images that are:\n\n"
        "1. roughly centered\n"
        "2. 28x28 pixels\n"
        "3. on a solid black background\n"
        "4. drawn in white\n\n"
        "Custom images are resized to 28x28, but they may still lack the other "
        "properties above, so the model may perform poorly on them."
    ),
    examples=examples,  # Clickable sample images
)
# Launch the app.
demo.launch()