tut-pictionary / app.py
aassal's picture
model and clases update
1a1b187
from pathlib import Path
import os
import torch
from torch import nn
import gradio as gr
LABELS = Path("class_names.txt").read_text().splitlines()
model = nn.Sequential(
nn.Conv2d(1, 32, 3, padding='same'),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, 3, padding='same'),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(64, 128, 3, padding='same'),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Flatten(),
nn.Linear(1152, 256),
nn.ReLU(),
nn.Linear(256, len(LABELS)),
)
state_dict = torch.load('pytorch_model.bin', map_location='cpu')
model.load_state_dict(state_dict, strict=False)
model.eval()
model = nn.Sequential(
nn.Conv2d(1, 32, 3, padding='same'),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, 3, padding='same'),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(64, 128, 3, padding='same'),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Flatten(),
nn.Linear(1152, 256),
nn.ReLU(),
nn.Linear(256, len(LABELS)),
)
state_dict = torch.load('pytorch_model.bin', map_location='cpu')
model.load_state_dict(state_dict, strict=False)
model.eval()
def predict(inp):
x=torch.tensor(inp).unsqueeze(0).unsqueeze(0) / 255.
with torch.no_grad():
out = model(x)
probs = torch.nn.functional.softmax(out[0], dim=0)
values, indices = probs.topk(probs, k=5)
return {Labels[i]: v.item() for i, v in zip(indices, values)}
interface = gr.Interface(predict, inputs="sketchpad", outputs="label", live=True)
interface.launch()