Spaces:
Build error
Build error
| from pathlib import Path | |
| import gradio as gr | |
| from PIL import Image | |
| import numpy as np | |
| from mnist_model.database import transform | |
| from mnist_model.train import MNISTModel | |
| import torch | |
# Load the trained classifier once at import time so every request reuses it.
current_dir = Path(__file__).parent.absolute()
model_path = current_dir / "data" / "models" / "mnist_resnet18_4epochs.ckpt"
# map_location="cpu" makes a checkpoint that was saved on a GPU load on a
# CPU-only host instead of trying to restore tensors onto a missing device.
# `.model` pulls the wrapped network out of the loaded MNISTModel
# (presumably the bare torch module -- confirm against mnist_model.train).
model = MNISTModel.load_from_checkpoint(str(model_path), map_location="cpu").model
# Put layers such as batch-norm/dropout into inference behaviour.
model.eval()
def classify_image(inp):
    """Classify a hand-drawn digit coming from the Gradio sketchpad.

    Args:
        inp: Value produced by the sketchpad input. The drawing is read from
            its ``"composite"`` entry and converted with ``np.array``. May be
            ``None`` (or have a ``None`` composite) when nothing was drawn.

    Returns:
        Dict mapping the labels ``"0"`` .. ``"9"`` to softmax probabilities,
        as consumed by ``gr.Label``; ``None`` when there is no drawing, which
        clears the label instead of raising.
    """
    if inp is None or inp.get("composite") is None:
        return None

    img = np.array(inp["composite"])
    # Take only the last channel. NOTE(review): for an RGBA composite this is
    # the alpha channel, i.e. the stroke mask over a transparent canvas -- confirm.
    img = img[:, :, -1]
    # Preprocess with the training-time transform and add a batch dimension.
    batch = transform()(Image.fromarray(img)).unsqueeze(0)

    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        logits = model(batch)
    probs = torch.softmax(logits, dim=1)[0]
    return {str(digit): float(probs[digit]) for digit in range(10)}
# Gradio UI: a sketchpad feeds classify_image and the top-5 digit
# probabilities are shown in a Label component.
inputs = gr.Sketchpad(label="Draw a number", crop_size=(28, 28), type="pil")
outputs = gr.Label(num_top_classes=5)
title = "MNIST"
description = "A simple number recognition example model"
# Footer HTML, kept as a single well-formed paragraph.
# TODO(review): a source link used to follow "backbone." but its URL was
# truncated; re-add it as a complete <a href='...'>...</a> tag if desired.
article = (
    "<p style='text-align: center'>Here is a simple implementation of number "
    "recognition using MNIST dataset and a Resnet18 backbone.</p>"
)
demo = gr.Interface(
    fn=classify_image,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    article=article,
)
demo.launch()