"""Gradio app that classifies hand-drawn digits with a small MLP.

The module loads a pickled ``Net`` from ``./model.pth``, exposes
``predict_sketch`` as the inference callback, and launches a live
sketchpad interface.
"""

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import datasets, transforms

# Inference is pinned to the CPU; the checkpoint is mapped to CPU on load.
device = "cpu"


class Net(nn.Module):
    """Three-layer fully connected classifier: 784 -> 512 -> 256 -> 10.

    The class and its attribute names (``fc1``, ``fc2``, ``fc3``) must stay
    as they are: ``torch.load`` below unpickles a whole module object and
    resolves this class by name.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits for ``x`` flattened to rows of 784."""
        x = x.view(-1, 784)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)


# Load the saved model.
# SECURITY: torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from a trusted source.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# rejects a fully pickled module -- confirm against the pinned torch version.
model = torch.load('./model.pth', map_location=torch.device('cpu'))
model.eval()

# Torchvision preprocessing pipeline.
# NOTE(review): this pipeline is defined but not applied by preprocess_image,
# which only scales pixels to [0, 1] -- confirm which preprocessing the
# checkpoint was trained with.
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])


def preprocess_image(image) -> torch.Tensor:
    """Convert a PIL image or numpy array into a ``(1, 28, 28)`` float tensor.

    Args:
        image: A ``PIL.Image.Image`` or a ``numpy.ndarray``. Arrays may be
            2-D (already grayscale), ``(H, W, 1)``, or ``(H, W, C)`` with
            ``C >= 3`` (RGB / RGBA; only the first three channels are used).

    Returns:
        A float32 tensor of shape ``(1, 28, 28)`` with values scaled by
        ``1 / 255``.

    Raises:
        ValueError: If ``image`` is neither a PIL image nor a numpy array.
    """
    if isinstance(image, Image.Image):
        gray = image.convert("L")
    elif isinstance(image, np.ndarray):
        pixels = image
        if pixels.ndim == 3 and pixels.shape[2] >= 3:
            # Luma-weighted RGB -> grayscale. Any extra channel (alpha) is
            # dropped; left in place it would survive the resize and be
            # flattened by the model into several bogus 784-pixel rows.
            pixels = np.dot(pixels[..., :3], [0.2989, 0.5870, 0.1140])
        elif pixels.ndim == 3 and pixels.shape[2] == 1:
            # Drop the trailing singleton channel so PIL sees a 2-D image.
            pixels = pixels[..., 0]
        gray = Image.fromarray(pixels.astype(np.uint8))
    else:
        raise ValueError("Invalid image type. Only PIL images and numpy arrays are supported.")

    resized = gray.resize((28, 28))
    scaled = np.asarray(resized) / 255.0
    # Leading dimension of size 1 acts as the batch dimension.
    return torch.from_numpy(scaled).float().unsqueeze(0)


def predict_sketch(image):
    """Return per-digit probabilities for a drawn image.

    Args:
        image: The sketchpad payload (PIL image or numpy array), or ``None``
            when there is nothing to classify.

    Returns:
        A dict mapping the labels ``"0"``..``"9"`` to softmax probabilities,
        or ``None`` when ``image`` is ``None``.

    Raises:
        ValueError: If ``image`` is of an unsupported type.
    """
    if image is None:
        # A live interface may invoke the callback before anything is drawn;
        # return an empty result instead of raising.
        return None

    batch = preprocess_image(image)
    with torch.no_grad():
        logits = model(batch)
    probabilities = torch.nn.functional.softmax(logits[0], dim=0).cpu().numpy()
    return {str(digit): float(prob) for digit, prob in enumerate(probabilities)}


# Create the interface
interface = gr.Interface(
    fn=predict_sketch,
    inputs="sketchpad",
    outputs=gr.Label(num_top_classes=3),
    title='MNIST Realtime Recognition',
    live=True,
    css=".footer {display:none !important}",
    description="Draw a number 0 through 9 on the sketchpad, and see predictions in real time. Model accuracy is 96%.",
)

# Run the interface
# NOTE(review): `enable_queue` is a legacy launch() argument; newer Gradio
# releases configure queuing via `interface.queue()` -- confirm against the
# pinned Gradio version before changing.
interface.launch(enable_queue=True)