nathan ayers
Update app.py
3b750f3 verified
raw
history blame contribute delete
767 Bytes
import pickle
import numpy as np
from PIL import Image
import gradio as gr
# Load the pickled RandomForest classifier once, at import time.
# The path is relative to the current working directory; the original note
# says mnist_model.pkl should live in /app (presumably the app's working
# directory — confirm for the deployment).
# SECURITY: pickle.load can execute arbitrary code embedded in the file, so
# only load a model file you produced or otherwise trust.
with open("mnist_model.pkl", "rb") as model_file:
    # The context manager closes the file handle once the model is loaded,
    # instead of leaking it for the lifetime of the process.
    model = pickle.load(model_file)
def classify_digit(img: "Image.Image | None") -> str:
    """Classify an uploaded image as a digit using the module-level ``model``.

    The image is converted to 8-bit grayscale ("L" mode) and resized to
    28x28. It is then flattened into a single row of 28 * 28 = 784 pixel
    values, which is passed to ``model.predict``.

    Args:
        img: The image delivered by a ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        ``"Predicted digit: <label>"`` with the model's prediction, or a
        prompt to upload an image when ``img`` is ``None``.
    """
    # Gradio hands the function None for an empty image component; without
    # this guard, img.convert would raise AttributeError.
    if img is None:
        return "Please upload a digit image."

    # Any input size is accepted and rescaled to 28x28.
    # NOTE(review): the model presumably expects the same pixel convention as
    # its training data (e.g. MNIST-style light digit on dark background,
    # values 0-255). Confirm this; dark-on-light uploads may need inverting.
    gray = img.convert("L").resize((28, 28))
    # A single sample: shape (1, 784).
    features = np.array(gray).reshape(1, -1)
    pred = model.predict(features)[0]
    return f"Predicted digit: {pred}"
def _build_demo():
    """Assemble the Gradio interface that wraps ``classify_digit``.

    The PIL image upload component is created first and the text output
    second. The interface then ties them to the classifier function.
    """
    image_input = gr.Image(type="pil", label="Upload a 28×28 digit")
    text_output = gr.Textbox(label="Prediction")
    return gr.Interface(
        fn=classify_digit,
        inputs=image_input,
        outputs=text_output,
        title="MNIST Digit Classifier",
        description="Upload a handwritten digit image (28×28) to get a live prediction!",
    )


# Built at import time so the interface object is available as `demo`.
demo = _build_demo()

if __name__ == "__main__":
    # Start the web server only when this file is executed directly.
    demo.launch()