Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from PIL import Image | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import datasets, transforms | |
# Inference always runs on the CPU: the model below is loaded with
# map_location='cpu' and no tensor in this file is moved to another device,
# so probing for CUDA here would have no effect.
device = "cpu"
class Net(nn.Module):
    """Fully connected 784 -> 512 -> 256 -> 10 classifier for 28x28 inputs.

    The submodule attribute names ``fc1``/``fc2``/``fc3`` must not change:
    ``model.pth`` is unpickled into this class, and ``forward`` looks the
    layers up by these names.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this order so parameter initialisation
        # consumes the RNG exactly as before.
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)

    def forward(self, x):
        """Return raw (unnormalised) scores of shape (N, 10).

        ``x`` is flattened to rows of 784 values, so any input whose element
        count is a multiple of 784 (e.g. (N, 28, 28) or (N, 1, 28, 28)) works.
        """
        hidden = x.view(-1, 784)
        # Two hidden layers, each followed by a ReLU non-linearity.
        for layer in (self.fc1, self.fc2):
            hidden = torch.relu(layer(hidden))
        return self.fc3(hidden)
# Load the saved model. The result is used directly as a module (eval() and
# call), so model.pth appears to hold a whole pickled nn.Module rather than a
# state_dict; unpickling it relies on the Net class defined above.
# PyTorch >= 2.6 defaults torch.load to weights_only=True, which refuses to
# unpickle arbitrary classes such as Net, so opt out explicitly.
# SECURITY: weights_only=False performs full pickle deserialization, which can
# execute arbitrary code -- only ever point this at a model.pth you trust.
model = torch.load('./model.pth', map_location=torch.device('cpu'), weights_only=False)
model.eval()  # switch the module to evaluation mode for inference
# Torchvision preprocessing pipeline: grayscale -> resize to 28x28 -> tensor
# in [0, 1] -> per-channel normalisation.
# NOTE(review): preprocess_image() does not use this pipeline -- it only
# divides pixels by 255 and never applies the Normalize step. Confirm which
# preprocessing the model was actually trained with.
_NORMALIZE_MEAN = (0.1307,)  # presumably the usual MNIST mean -- confirm
_NORMALIZE_STD = (0.3081,)  # presumably the usual MNIST std -- confirm
transform = transforms.Compose(
    [
        transforms.Grayscale(),
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize(_NORMALIZE_MEAN, _NORMALIZE_STD),
    ]
)
# Function to preprocess the image
def preprocess_image(image):
    """Convert a sketchpad drawing into a (1, 28, 28) float32 tensor.

    Args:
        image: a PIL image, or a numpy array shaped (H, W), (H, W, 1),
            (H, W, 3) RGB or (H, W, 4) RGBA. Colour input is reduced to a
            single luma channel; an alpha channel is ignored, matching what
            PIL's ``convert("L")`` does for RGBA images.

    Returns:
        A float32 tensor of shape (1, 28, 28) holding the resized pixel
        values divided by 255.

    Raises:
        ValueError: if ``image`` is neither a PIL image nor a numpy array,
            or is an array with an unsupported shape.
    """
    if isinstance(image, Image.Image):
        image = image.convert("L")  # Convert to grayscale
    elif isinstance(image, np.ndarray):
        if image.ndim == 3 and image.shape[2] == 1:
            # PIL cannot build an image from an (H, W, 1) array; drop the axis.
            image = image[..., 0]
        elif image.ndim == 3 and image.shape[2] in (3, 4):
            # Luma weights over R, G, B only. Slicing [..., :3] also discards
            # the alpha channel of RGBA input, which would otherwise survive as
            # extra channels and break the 784-value flattening in Net.forward.
            image = np.dot(image[..., :3], [0.2989, 0.5870, 0.1140]).astype(np.uint8)
        elif image.ndim != 2:
            raise ValueError(
                f"Unsupported image array shape {image.shape}; expected "
                "(H, W), (H, W, 1), (H, W, 3) or (H, W, 4)."
            )
        image = Image.fromarray(image)
    else:
        raise ValueError("Invalid image type. Only PIL images and numpy arrays are supported.")
    image = image.resize((28, 28))  # Resize to 28x28
    pixels = np.array(image) / 255.0  # Normalize the pixel values
    return torch.as_tensor(pixels, dtype=torch.float32).unsqueeze(0)  # Add a batch dimension
# Function to make predictions
def predict_sketch(image):
    """Classify a sketchpad drawing as a digit.

    Args:
        image: the sketchpad value, anything ``preprocess_image`` accepts,
            or ``None`` when the canvas is empty or has been cleared.

    Returns:
        A dict mapping each digit label ("0".."9") to its softmax
        probability, or ``None`` for an empty canvas, so the output label is
        cleared instead of the UI showing an error.
    """
    if image is None:
        # With live=True the function also fires when the canvas is cleared;
        # there is nothing to classify, and preprocess_image would raise.
        return None
    tensor = preprocess_image(image)
    with torch.no_grad():
        logits = model(tensor)
    # Only one image is processed per call, so take the first row of scores.
    probabilities = torch.nn.functional.softmax(logits[0], dim=0).cpu().numpy()
    return {str(digit): float(prob) for digit, prob in enumerate(probabilities)}
# Create the interface: a sketchpad input that is re-classified on every
# change (live=True), showing the three most likely digits.
interface = gr.Interface(
    fn=predict_sketch,
    inputs="sketchpad",
    outputs=gr.Label(num_top_classes=3),
    title='MNIST Realtime Recognition',
    live=True,
    css=".footer {display:none !important}",  # hide elements with the .footer class
    description="Draw a number 0 through 9 on the sketchpad, and see predictions in real time. Model accuracy is 96%.",
)
# Run the interface. Queueing is enabled through .queue(): the
# launch(enable_queue=...) keyword was deprecated in Gradio 3 and removed in
# Gradio 4, where passing it raises TypeError at startup.
interface.queue().launch()