# MNIST-Realtime / app.py
# liquidaudit
# added css
# 24e9f54
import gradio as gr
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms
# Inference is pinned to the CPU. The previous CUDA auto-detection was dead
# code: its result was unconditionally overwritten by this assignment.
device = "cpu"
class Net(nn.Module):
    """Fully connected digit classifier: 784 -> 512 -> 256 -> 10.

    The layer attribute names (``fc1``, ``fc2``, ``fc3``) and sizes must stay
    as they are: this file restores a pickled model with ``torch.load``, and
    a saved module can only be restored into a class with a matching layout.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)

    def forward(self, x):
        """Return raw class scores (logits) of shape ``(batch, 10)``.

        Args:
            x: Tensor whose elements can be regrouped into rows of 784
                values, e.g. ``(batch, 28, 28)`` or ``(batch, 784)``.

        Returns:
            Tensor of shape ``(batch, 10)``; no softmax is applied here.
        """
        # reshape() instead of view(): it returns the same zero-copy view
        # whenever view() would succeed, but also accepts non-contiguous
        # inputs (e.g. transposed or sliced tensors) instead of raising.
        flat = x.reshape(-1, 784)
        hidden = torch.relu(self.fc1(flat))
        hidden = torch.relu(self.fc2(hidden))
        return self.fc3(hidden)
# Load the saved model onto the CPU and switch it to inference mode.
# torch.load unpickles the file, so only point this at a checkpoint you trust.
# NOTE(review): the result is used directly as a module, so the checkpoint is
# presumably a whole pickled model rather than a state_dict. Newer PyTorch
# releases default torch.load to weights_only=True, which rejects such files;
# confirm against the pinned torch version.
model = torch.load('./model.pth', map_location=torch.device('cpu')).eval()
# Preprocessing pipeline: single-channel grayscale, 28x28 resize, tensor
# conversion, then normalisation with mean 0.1307 and std 0.3081.
# NOTE(review): nothing in the visible part of this file calls `transform`;
# preprocess_image() does its own resizing and scaling instead.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize(size=(28, 28)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])
# Function to preprocess the image
def preprocess_image(image):
    """Convert a sketch into a ``(1, 28, 28)`` float tensor scaled to [0, 1].

    Args:
        image: A PIL image, or a numpy array shaped ``(H, W)``,
            ``(H, W, 1)``, ``(H, W, 3)`` (RGB) or ``(H, W, 4)`` (RGBA; the
            alpha channel is ignored).

    Returns:
        A ``torch.float32`` tensor of shape ``(1, 28, 28)`` holding the
        grayscale pixels divided by 255.

    Raises:
        ValueError: If ``image`` is neither a PIL image nor a numpy array.
    """
    if isinstance(image, Image.Image):
        pixels = np.array(image.convert("L"))  # grayscale, 2-D
    elif isinstance(image, np.ndarray):
        pixels = image
        if pixels.ndim == 3 and pixels.shape[2] == 1:
            # Drop a trailing singleton channel; PIL cannot build an image
            # from an (H, W, 1) array.
            pixels = pixels[..., 0]
        elif pixels.ndim == 3 and pixels.shape[2] in (3, 4):
            # Luma-weighted RGB -> grayscale. Accepting 4 channels keeps RGBA
            # input from slipping through as a 4-channel "image", which would
            # otherwise yield a (1, 28, 28, 4) tensor.
            pixels = np.dot(pixels[..., :3], [0.2989, 0.5870, 0.1140])
        pixels = pixels.astype(np.uint8)
    else:
        raise ValueError("Invalid image type. Only PIL images and numpy arrays are supported.")

    resized = Image.fromarray(pixels).resize((28, 28))
    scaled = np.array(resized) / 255.0  # map 0..255 to 0..1
    # NOTE(review): the module-level `transform` additionally applies
    # Normalize((0.1307,), (0.3081,)); this function does not. Confirm which
    # preprocessing the saved model was trained with.
    return torch.FloatTensor(scaled).unsqueeze(0)  # add the batch dimension
# Function to make predictions
def predict_sketch(image):
    """Classify a sketched digit and return per-class probabilities.

    Args:
        image: The sketch as a PIL image or numpy array, or ``None`` when
            there is nothing to classify (e.g. an empty or cleared canvas).

    Returns:
        A dict mapping the labels ``"0"`` .. ``"9"`` to their softmax
        probability as Python floats, or ``None`` when ``image`` is ``None``.

    Raises:
        ValueError: Propagated from ``preprocess_image`` for unsupported
            input types.
    """
    # In live mode the callback can fire with no drawing at all; previously
    # that fell through to preprocess_image() and raised ValueError.
    # NOTE(review): assumes the Label output treats None as "no prediction";
    # confirm against the pinned Gradio version.
    if image is None:
        return None

    batch = preprocess_image(image)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(batch)
        # The batch holds a single image, so normalise its row of class scores.
        probabilities = torch.nn.functional.softmax(logits[0], dim=0)

    labels = [str(digit) for digit in range(10)]
    return dict(zip(labels, probabilities.cpu().tolist()))
# Build the UI: a sketchpad input wired to predict_sketch, with the three
# most likely digits shown in a Label. The custom CSS hides the page footer.
interface = gr.Interface(
    fn=predict_sketch,
    inputs="sketchpad",
    outputs=gr.Label(num_top_classes=3),
    title='MNIST Realtime Recognition',
    description="Draw a number 0 through 9 on the sketchpad, and see predictions in real time. Model accuracy is 96%.",
    live=True,
    css=".footer {display:none !important}",
)

# Start serving the app with request queueing enabled.
interface.launch(enable_queue=True)