"""Gradio demo that classifies a single handwritten digit with ``net.Net``.

Loads weights from ``./model/mnist.pth`` onto the CPU, converts the submitted
image to a normalized 28x28 grayscale tensor, and displays the predicted
class index as ``"label: <n>"``.
"""
import cv2
import gradio as gr
import torch
import torch.nn as nn

from net import Net

# Softmax over the last (class) dimension. Passing ``dim`` explicitly avoids
# PyTorch's implicit-dim deprecation warning; for 2-D logits this is the same
# dimension the implicit choice used.
# NOTE(review): assumes Net returns (batch, num_classes) scores -- confirm.
soft_layer = nn.Softmax(dim=-1)

model_dir = "./model/mnist.pth"  # path of the checkpoint file to load
checkpoint = torch.load(model_dir, map_location="cpu")
# Accept either a bare state_dict or a checkpoint dict that wraps it under
# the "state_dict" key.
state_dict = checkpoint.get("state_dict", checkpoint)
model = Net()
model.load_state_dict(state_dict)
# Inference only: disable training-time behaviour such as dropout or
# batch-norm statistic updates, in case Net contains those layers.
model.eval()


def _to_gray(image):
    """Return a 2-D single-channel view/copy of *image*.

    Gradio delivers numpy images in RGB (or RGBA) channel order, not
    OpenCV's BGR, so the RGB conversions are used here.
    """
    if image.ndim == 2:
        return image
    channels = image.shape[2]
    if channels == 1:
        return image[:, :, 0]
    if channels == 4:
        return cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
    return cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)


def _to_model_input(image):
    """Convert an image array into a (1, 1, 28, 28) float tensor in [-1, 1)."""
    gray = _to_gray(image)
    resized = cv2.resize(gray, (28, 28), interpolation=cv2.INTER_AREA)
    # Convert to float BEFORE shifting: on a uint8 tensor ``x - 128`` wraps
    # around (0 - 128 == 128) instead of going negative.
    # NOTE(review): (x - 128) / 128 is assumed to match the normalization
    # used when training the checkpoint -- confirm.
    tensor = torch.tensor(resized, dtype=torch.float32)
    tensor = (tensor - 128.0) / 128.0
    # Add batch and channel dimensions: (28, 28) -> (1, 1, 28, 28).
    return tensor.unsqueeze(0).unsqueeze(0)


def classify(input):  # noqa: A002 -- parameter name kept for compatibility
    """Predict the digit shown in *input*.

    Args:
        input: numpy image from the Gradio ``"image"`` input (H x W x 3/4
            RGB(A), H x W x 1, or H x W grayscale), or ``None`` when the
            user submitted nothing.

    Returns:
        A string ``"label: <index>"`` where ``<index>`` is the argmax of the
        model's class scores.

    Raises:
        gr.Error: if no image was provided.
    """
    if input is None:
        raise gr.Error("Please provide an image.")
    batch = _to_model_input(input)
    with torch.no_grad():  # no autograd graph needed for inference
        probs = soft_layer(model(batch))
    return 'label: {}'.format(torch.argmax(probs).item())


demo = gr.Interface(fn=classify, inputs="image", outputs="text")

if __name__ == "__main__":
    demo.launch()