Classifier / app.py
Hamidreza-Hashemp's picture
Update app.py
ced629a
raw
history blame contribute delete
867 Bytes
import gradio as gr
import torch
import torch.nn as nn
import gradio as gr
import cv2
from net import Net
# Softmax over the class axis. Passing ``dim`` explicitly avoids PyTorch's
# "implicit dimension choice" deprecation warning; for a 1-D or 2-D
# (batch, classes) output, -1 is the same axis the implicit rule would pick.
soft_layer = nn.Softmax(dim=-1)

# Path to the trained checkpoint (despite the name, this is a file, not a directory).
model_dir = "./model/mnist.pth"
# map_location="cpu" remaps any GPU-saved tensors so the app runs on CPU-only hosts.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(model_dir, map_location="cpu")
# Accept either a bare state_dict or a training checkpoint that nests it
# under the "state_dict" key.
state_dict = checkpoint.get("state_dict", checkpoint)

model = Net()
model.load_state_dict(state_dict)
# Inference only: put dropout / batch-norm layers (if Net has any) into eval mode
# so predictions are deterministic.
model.eval()
def classify(input):
    """Predict the digit in an uploaded image and format it for display.

    Args:
        input: Image array delivered by the Gradio ``"image"`` input, or
            ``None`` when the user submits without an image. Gradio's default
            numpy image is RGB ``(H, W, 3)`` uint8; a 2-D grayscale array or a
            4-channel RGBA array is also accepted.

    Returns:
        ``"label: <class index>"`` for the highest-scoring output of ``model``,
        or a short prompt string when no image was supplied.
    """
    if input is None:
        return "Please upload an image."

    # Gradio hands over RGB(A) arrays, not OpenCV's native BGR order, so use
    # the RGB conversion codes; BGR2GRAY would swap the red/blue luma weights.
    if input.ndim == 2:
        gray = input
    elif input.shape[2] == 4:
        gray = cv2.cvtColor(input, cv2.COLOR_RGBA2GRAY)
    else:
        gray = cv2.cvtColor(input, cv2.COLOR_RGB2GRAY)

    # Shrink to 28x28 (presumably the size Net was trained on); INTER_AREA is
    # the interpolation OpenCV recommends for downsampling.
    resized = cv2.resize(gray, (28, 28), interpolation=cv2.INTER_AREA)

    # Convert to float *before* centring: on a uint8 tensor `x - 128` wraps
    # modulo 256 (0 - 128 -> 128), which scrambled every pixel below 128.
    # Maps 0..255 to roughly [-1, 1), then adds batch and channel axes
    # to get shape (1, 1, 28, 28).
    pixels = torch.tensor(resized, dtype=torch.float32)
    batch = ((pixels - 128.0) / 128.0).unsqueeze(0).unsqueeze(0)

    # No gradients are needed for inference.
    with torch.no_grad():
        scores = model(batch)

    # Softmax is monotonic, so the arg-max of the raw scores picks the same
    # class as the arg-max of the probabilities.
    label = int(torch.argmax(scores))
    return "label: {}".format(label)
# Web UI: an uploaded image is passed to `classify`, and its string result
# is shown in a text box.
demo = gr.Interface(
    fn=classify,
    inputs="image",
    outputs="text",
)

# Start the Gradio server (launch(debug=True) is handy while developing).
demo.launch()