# Hugging Face Spaces — status: Sleeping
| import torch | |
| from torch import nn | |
| import torchvision.transforms as transforms | |
| import torch.nn.functional as F | |
| from pathlib import Path | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
# Class names, one per line in classes.txt; a name's position in this list is
# the index of the model output it labels. Lines are read as UTF-8 (not the
# platform default encoding), surrounding whitespace is stripped, and blank
# lines are skipped. Skipping blanks matters because num_classes sets the final
# layer's width: a stray empty line would make the architecture disagree with
# the saved weights.
LABELS = [
    name
    for name in (
        line.strip()
        for line in Path("classes.txt").read_text(encoding="utf-8").splitlines()
    )
    if name
]
num_classes = len(LABELS)
def _conv_stage(in_channels, out_channels):
    """Return one 3x3 conv -> ReLU -> 2x2 max-pool stage as a tuple of layers."""
    return (
        nn.Conv2d(in_channels, out_channels, 3, padding="same"),
        nn.ReLU(),
        nn.MaxPool2d(2),
    )


# Three conv stages followed by a two-layer classifier head. The flattened layer
# order determines the numeric state-dict keys ("0.weight", "10.weight", ...),
# so it must stay exactly as below to load the saved checkpoint.
model = nn.Sequential(
    *_conv_stage(1, 64),
    *_conv_stage(64, 128),
    *_conv_stage(128, 256),
    nn.Flatten(),
    # A 28x28 input is pooled 28 -> 14 -> 7 -> 3, leaving 256 * 3 * 3 = 2304 features.
    nn.Linear(256 * 3 * 3, 512),
    nn.ReLU(),
    nn.Linear(512, num_classes),
)
# Download the trained weights from the Hugging Face Hub and load them into the
# architecture defined above, on CPU.
model_path = hf_hub_download(repo_id="jerilseb/quickdraw-small", filename="model.pth")
# The checkpoint is only a state dict (tensors in plain containers).
# weights_only=True restricts unpickling to exactly that, so a tampered remote
# file cannot execute arbitrary code on load. Requires torch >= 1.13.
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
# Switch to evaluation mode for inference.
model.eval()
# Preprocessing for a PIL sketch:
# - resize to 28x28 (the size the first Linear layer's 2304 inputs correspond to);
# - ToTensor scales pixels to [0, 1];
# - Normalize maps them to [-1, 1] via (x - 0.5) / 0.5.
_IMAGE_SIZE = (28, 28)
transform = transforms.Compose(
    [
        transforms.Resize(_IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
def predict(image):
    """Classify the current sketch and return the most likely classes.

    Args:
        image: Value of the ``gr.ImageEditor`` input; the drawing is read from
            its ``"composite"`` entry. The value itself, or its composite, may
            be ``None`` (presumably before anything is drawn or after the canvas
            is cleared — TODO confirm against the Gradio version in use).

    Returns:
        A dict mapping class name to softmax probability for the top (up to 5)
        classes, or ``None`` when there is no image to classify, which clears
        the output label instead of raising.
    """
    composite = image.get("composite") if image else None
    if composite is None:
        return None
    # The first conv layer takes a single channel; convert defensively in case
    # the editor hands back an RGB/RGBA composite despite image_mode="L".
    tensor = transform(composite.convert("L")).unsqueeze(0)  # add batch dim
    with torch.no_grad():
        logits = model(tensor)
    probabilities = F.softmax(logits[0], dim=0)
    # torch.topk raises if k exceeds the number of classes, so clamp it.
    k = min(5, probabilities.numel())
    values, indices = torch.topk(probabilities, k)
    return {LABELS[int(i)]: v.item() for i, v in zip(indices, values)}
# Drawing pad: a 720x720 grayscale canvas returned as a PIL image, with layer
# controls turned off and no image sources (sources is empty). The user can
# only sketch with a size-20 white brush.
inputs = gr.ImageEditor(
    label="Draw a shape",
    type="pil",
    image_mode="L",
    width=720,
    height=720,
    layers=False,
    sources=[],
    brush=gr.Brush(default_size=20, default_color="white"),
)
# Re-run `predict` on every canvas change (live=True) and show the result as a
# label with per-class confidences.
demo = gr.Interface(predict, inputs=inputs, outputs="label", live=True)

# Start the server only when executed as a script. Importing this module (e.g.
# from tests or another app) then does not launch a server, which blocks the
# main thread with debug=True.
if __name__ == "__main__":
    demo.launch(debug=True)