# Last commit (28a264f) by thomasinovic: fix weight file name import in app.py file
import gradio as gr
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import json
from CNN import CNN
# Load the trained sketch classifier.
# The architecture hyper-parameters below must match the ones used when
# model_weights.pth was saved; otherwise load_state_dict() raises on the
# mismatched parameter shapes.
n_classes = 345  # number of output classes; labels.json should hold one name per class
params = {
    'n_filters': 30,
    'hidden_dim': 100,
    'n_layers': 2,
    'n_classes': n_classes,
}
model = CNN(**params)
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict needs, so a tampered checkpoint cannot execute code.
# map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only host.
state_dict = torch.load('model_weights.pth', map_location='cpu', weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # inference mode (matters for dropout/batch-norm layers, if CNN has any)
# Class-index -> human-readable label lookup used by predict(), which reads
# names[i] for every output class i.
labels_path = 'labels.json'
with open(labels_path, 'r', encoding='utf-8') as f:
    names = json.load(f)
# Fail fast at startup rather than raising IndexError inside predict() (too few
# labels) or silently ignoring extra labels (too many) when the label file does
# not match the model's output size.
if len(names) != n_classes:
    raise ValueError(
        f"{labels_path} has {len(names)} labels but the model outputs {n_classes} classes"
    )
# Preprocessing: sketchpad image -> (1, 28, 28) float tensor with ink=1, background=0.
transform = T.Compose([
    T.ToTensor(),                 # (1, H, W), values in [0, 1], white=1 black=0
    T.Lambda(lambda x: 1.0 - x),  # invert -> white=0, black=1
    # antialias=True so that shrinking a large canvas to 28x28 averages thin
    # strokes instead of skipping over them; older torchvision versions do not
    # antialias tensor inputs by default.
    T.Resize((28, 28), interpolation=T.InterpolationMode.BILINEAR, antialias=True),
    # T.Normalize((0.5,), (0.5,))  # optional if your model expects [-1, 1]
])
def predict(input_image):
    """Classify a sketch drawn in the Gradio Sketchpad.

    Args:
        input_image: The Sketchpad value. Only its ``'composite'`` entry (the
            flattened drawing) is used. ``None`` is treated as "no input".

    Returns:
        A dict mapping each label in ``names`` to its softmax probability,
        suitable for ``gr.Label``. Returns ``{"No drawing detected": 1.0}``
        when there is no image or the canvas is completely blank.
    """
    img = None if input_image is None else input_image['composite']
    if img is None:
        return {"No drawing detected": 1.0}

    x = transform(img)
    # After inversion, ink is > 0 and pure-white background is exactly 0, so an
    # all-zero tensor means nothing was drawn and any prediction would be noise.
    if not bool((x > 0).any()):
        return {"No drawing detected": 1.0}
    x = x.unsqueeze(0).to(torch.float32)  # add batch dimension

    with torch.no_grad():
        logits = model(x)
    probs = F.softmax(logits, dim=1).squeeze(0)
    # One bulk tensor -> list conversion instead of an .item() call per class.
    return {names[i]: p for i, p in enumerate(probs.tolist())}
# Gradio UI: a sketchpad restricted to a fixed black brush. With live=True the
# drawing is re-classified whenever the input changes, and gr.Label shows the
# 5 most probable classes.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Sketchpad(
        label="Draw a sketch",
        image_mode='L',
        brush=gr.Brush(default_size=15, default_color='black', colors=['black'], color_mode='fixed'),
    ),
    outputs=gr.Label(num_top_classes=5),
    title="Sketch Recognition model",
    clear_btn=gr.ClearButton(),
    live=True,
)

# Only start the server when run as a script (e.g. `python app.py`), so that
# importing this module does not block on launch().
if __name__ == "__main__":
    demo.launch()