Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn.functional as F | |
| import torchvision.transforms as T | |
| import json | |
| from CNN import CNN | |
# Build the sketch classifier and load its trained weights onto the CPU.
# NOTE(review): these hyper-parameters presumably have to match the ones used
# when model_weights.pth was saved; load_state_dict() is strict by default and
# raises on missing keys or mismatched tensor shapes.
n_classes = 345  # NOTE(review): presumably must equal the number of labels in labels.json — confirm
params = {
    'n_filters': 30,
    'hidden_dim': 100,
    'n_layers': 2,
    'n_classes': n_classes,
}
model = CNN(**params)
# A state_dict only contains tensors in plain containers, so weights_only=True
# is sufficient and keeps torch.load from unpickling (and thereby executing)
# arbitrary objects stored in the checkpoint file.
model.load_state_dict(
    torch.load('model_weights.pth', map_location='cpu', weights_only=True)
)
# Evaluation mode: affects layers such as dropout / batch-norm, if CNN has any.
model.eval()
# Class names; predict() looks them up by the model's output index.
# NOTE(review): because predict() indexes with integer positions, labels.json is
# presumably a JSON list of class-name strings in output order — confirm.
labels_path = 'labels.json'
# JSON text is UTF-8 (RFC 8259); without an explicit encoding open() would use
# the platform's locale encoding and could mis-decode non-ASCII names.
with open(labels_path, 'r', encoding='utf-8') as f:
    names = json.load(f)
# Preprocessing from the Sketchpad canvas to the model input.
# NOTE(review): assumes the canvas arrives as an 8-bit greyscale (H, W) array
# (image_mode='L'), which ToTensor turns into a (1, H, W) tensor — confirm.
transform = T.Compose([
    T.ToTensor(),  # (1, H, W), values in [0, 1], white=1 black=0
    T.Lambda(lambda x: 1.0 - x),  # invert -> white=0 (background), black=1 (ink)
    # antialias=True low-pass filters before the large canvas -> 28x28 shrink,
    # so thin strokes survive as grey values instead of mostly being skipped.
    # It is the default from torchvision 0.17 on; passing it explicitly keeps
    # behaviour identical on older releases and silences their warning.
    T.Resize(
        (28, 28),
        interpolation=T.InterpolationMode.BILINEAR,
        antialias=True,
    ),
    # If the model expects inputs in [-1, 1], append T.Normalize((0.5,), (0.5,)).
])
def predict(input_image):
    """Classify the current Sketchpad drawing.

    Args:
        input_image: The value gr.Sketchpad hands to the callback: a dict
            whose 'composite' entry is the flattened canvas image. None (for
            the whole value or for 'composite') is treated as "no drawing".
            NOTE(review): assumes the default numpy output with
            image_mode='L', i.e. a (H, W) uint8 array — confirm.

    Returns:
        A dict mapping every class name to its softmax probability (the
        gr.Label output displays only the top entries), or
        {"No drawing detected": 1.0} when there is nothing to classify.
    """
    img = None if input_image is None else input_image.get('composite')
    if img is None:
        return {"No drawing detected": 1.0}

    # Add a batch dimension; with the module's transform this is expected
    # to be (1, 1, 28, 28).
    batch = transform(img).unsqueeze(0).to(torch.float32)

    # After inversion an untouched (all-white) canvas is all zeros; report it
    # instead of letting the model guess a class for an empty image.
    if not torch.any(batch):
        return {"No drawing detected": 1.0}

    with torch.no_grad():
        logits = model(batch)
    # One tolist() instead of an .item() call per class.
    probs = F.softmax(logits, dim=1).squeeze(0).tolist()
    return {names[i]: p for i, p in enumerate(probs)}
# Web UI: a fixed black brush on a greyscale canvas. Because live=True, the
# drawing is re-classified on every change, and the five most probable
# classes are shown.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Sketchpad(
        label="Draw a sketch",
        image_mode='L',
        brush=gr.Brush(
            default_size=15,
            default_color='black',
            colors=['black'],
            color_mode='fixed',
        ),
    ),
    outputs=gr.Label(num_top_classes=5),
    title="Sketch Recognition model",
    clear_btn=gr.ClearButton(),
    live=True,
)

# Start the server only when executed as a script (`python app.py`), so that
# importing this module does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch()