Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import gradio as gr | |
| import cv2 | |
| from net import Net | |
# Softmax over the last (class) dimension. Passing ``dim`` explicitly avoids
# PyTorch's "implicit dimension choice" deprecation warning; -1 equals the
# implicit choice for 1-D and 2-D outputs. NOTE(review): assumes Net returns
# (batch, classes) logits — confirm against net.py.
soft_layer = nn.Softmax(dim=-1)

# Trained MNIST checkpoint, resolved relative to the working directory.
model_dir = "./model/mnist.pth"

# Load onto the CPU so the app also runs on hosts without a GPU. The file may
# hold either a bare state dict or a dict wrapping it under "state_dict".
checkpoint = torch.load(model_dir, map_location="cpu")
state_dict = checkpoint.get("state_dict", checkpoint)

model = Net()
model.load_state_dict(state_dict)
# Inference only: put dropout / batch-norm layers (if Net has any) into
# evaluation mode so predictions are deterministic.
model.eval()
| def classify(input): | |
| gray = cv2.cvtColor(input, cv2.COLOR_BGR2GRAY) | |
| resized = cv2.resize(gray, (28,28), interpolation = cv2.INTER_AREA) | |
| resized_t = torch.tensor(resized) | |
| resized_t = torch.unsqueeze(resized_t, 0) | |
| resized_t = torch.unsqueeze(resized_t, 0) | |
| resized_t = (resized_t -128 )/128. | |
| return 'label: {}'.format(torch.argmax( soft_layer( model(resized_t))).numpy()) | |
| # return "Hello " + name + "!" | |
# Minimal Gradio UI: upload an image, get the predicted digit back as text.
demo = gr.Interface(
    fn=classify,
    inputs="image",
    outputs="text",
)
demo.launch()