Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from torch import nn | |
| from torchvision import transforms | |
| import matplotlib.pyplot as plt | |
| from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas | |
| import PIL.Image | |
| # Load the model | |
| class MyMnist_ModelV0(nn.Module): | |
| def __init__(self, input_shape: int, hidden_units: int, hidden_units2: int, output_shape: int): | |
| super().__init__() | |
| self.layer_stack = nn.Sequential( | |
| nn.Flatten(), | |
| nn.Linear(in_features=input_shape, out_features=hidden_units), | |
| nn.ReLU(), | |
| nn.Linear(in_features=hidden_units, out_features=hidden_units2), | |
| nn.ReLU(), | |
| nn.Linear(in_features=hidden_units2, out_features=output_shape) | |
| ) | |
| def forward(self, x): | |
| return self.layer_stack(x) | |
# Load the pre-trained model
load_model = MyMnist_ModelV0(input_shape=784,  # 28 * 28 flattened pixels
                             hidden_units=256,
                             hidden_units2=128,
                             output_shape=10)   # one logit per digit 0-9
PATH = "state_dict_model.pth"  # Path to the trained model (relative to the working directory)
# map_location="cpu" remaps tensors that were saved from a CUDA device, so the checkpoint
# also loads on CPU-only hardware; without it torch.load raises when CUDA is unavailable.
load_model.load_state_dict(torch.load(PATH, map_location="cpu"))
# Put the module in evaluation mode before serving predictions.
load_model.eval()
# Function to recognize digit
def recognize_digit(image):
    """Classify a hand-drawn digit and return per-class probabilities.

    Args:
        image: The drawing delivered by the Gradio input. Accepted forms are None;
            a PIL image; a numpy array of shape (H, W), (H, W, 1), (H, W, 3) or
            (H, W, 4) with 0-255 values; or a Gradio 4 editor dict whose
            "composite" entry is one of those.

    Returns:
        A dict mapping "0".."9" to the softmax probability of each digit, or ""
        when no image is provided.
    """
    if isinstance(image, dict):
        # Gradio 4's Sketchpad/ImageEditor returns {"background", "layers", "composite"}.
        image = image.get("composite")
    if image is None:
        return ""

    batch = _preprocess(image)  # (1, 28, 28); the model's Flatten turns it into (1, 784)
    with torch.inference_mode():
        logits = load_model(batch)
    probs = torch.softmax(logits, dim=1)[0]
    return {str(digit): float(p) for digit, p in enumerate(probs)}


def _preprocess(image):
    """Turn a drawing into a (1, 28, 28) float32 tensor in [0, 1] with a light digit on a dark background."""
    if hasattr(image, "convert"):
        # PIL image: normalise palette/alpha/grayscale modes so one code path handles them all.
        image = image.convert("RGBA")
    pixels = np.asarray(image, dtype=np.float32)

    if pixels.ndim == 3:
        channels = pixels.shape[-1]
        if channels == 1:
            pixels = pixels[..., 0]
        else:
            if channels == 4:
                # Composite over white so transparent canvas regions count as background, not black.
                alpha = pixels[..., 3:] / 255.0
                pixels = pixels[..., :3] * alpha + 255.0 * (1.0 - alpha)
            # ITU-R 601-2 luma weights (the ones PIL's convert("L") uses).
            luma = np.array([0.299, 0.587, 0.114], dtype=np.float32)
            pixels = pixels[..., :3] @ luma

    tensor = torch.as_tensor(pixels / 255.0, dtype=torch.float32)
    # Resize to the 28x28 grid the model expects (784 inputs); "area" averages pixels when downscaling.
    tensor = nn.functional.interpolate(tensor[None, None], size=(28, 28), mode="area")[0]

    # MNIST digits are light strokes on a dark background; drawing widgets often produce the
    # opposite. A mostly-light image is assumed to have a light background and is inverted.
    if tensor.mean() > 0.5:
        tensor = 1.0 - tensor
    return tensor
# Function to create a canvas for drawing
def create_canvas():
    """Build a tick-free single-axes figure titled "Draw your digit".

    Returns:
        A FigureCanvasAgg (imported as FigureCanvas) wrapping the new figure.
    """
    figure, axes = plt.subplots()
    # Axes.set applies each keyword through the matching set_* method.
    axes.set(title="Draw your digit", xticks=[], yticks=[])
    return FigureCanvas(figure)
# Define Gradio interface.
# gr.inputs / gr.outputs were deprecated in Gradio 3.x and removed in 4.0, and the Image
# component never accepted a matplotlib canvas, so the previous definition failed at startup.
# gr.Sketchpad is Gradio's built-in drawing surface; type="pil" requests PIL images.
demo = gr.Interface(
    fn=recognize_digit,
    inputs=gr.Sketchpad(type="pil"),
    outputs=gr.Label(num_top_classes=1),
)

# Launch the interface only when run as a script, so importing this module has no server side effect.
if __name__ == "__main__":
    demo.launch(share=True)