---
license: mit
---

## Demo

A live demo is available as a Hugging Face Space: [jerilseb/quickdraw-small](https://huggingface.co/spaces/jerilseb/quickdraw-small)

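## Usage

The snippet below expects `classes.txt` (one class label per line) and `model.pth` (the saved `state_dict`) in the current working directory. `predict` takes a dict whose `"composite"` entry is the drawing to classify. The drawing should be a single-channel (grayscale) PIL image, because the first convolution takes exactly one input channel.

```python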
import torch
from torch import nn
import torchvision.transforms as transforms
import torch.nn.functional as F
from pathlib import Path

# Class names, one per line. Output index i of the model is assumed to
# correspond to LABELS[i] (the order used at training time -- not verifiable
# from this snippet). Blank lines are skipped so that an extra trailing
# newline in the file cannot inflate num_classes, which would make loading
# the checkpoint into the final Linear layer fail with a size mismatch.
LABELS = [
    line
    for line in Path("classes.txt").read_text(encoding="utf-8").splitlines()
    if line.strip()
]
num_classes = len(LABELS)

def _conv_stage(in_channels, out_channels):
    """Return one conv -> ReLU -> 2x2 max-pool stage as a flat list of layers."""
    return [
        nn.Conv2d(in_channels, out_channels, 3, padding="same"),
        nn.ReLU(),
        nn.MaxPool2d(2),
    ]


# Small CNN classifier. The stages are unpacked into ONE flat nn.Sequential
# (not nested), so the state_dict keys ("0.weight", "3.weight", "6.weight",
# "10.weight", "12.weight") keep the layout the checkpoint must match.
#
# Spatial size: padding="same" keeps H x W through each conv, and each
# MaxPool2d(2) halves it (floor): 28 -> 14 -> 7 -> 3 for the 28x28 input
# produced by the preprocessing transform. Hence 256 * 3 * 3 = 2304 features.
model = nn.Sequential(
    *_conv_stage(1, 64),
    *_conv_stage(64, 128),
    *_conv_stage(128, 256),
    nn.Flatten(),
    nn.Linear(2304, 512),
    nn.ReLU(),
    nn.Linear(512, num_classes),
)

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while loading
# (supported since torch 1.13; the default from torch 2.6).
state_dict = torch.load("model.pth", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
# Switch to inference mode. None of this model's layers (conv/ReLU/pool/linear)
# behave differently in eval mode, but calling it is the standard convention.
model.eval()

# Preprocessing applied to each drawing before it reaches the model:
# resize to 28x28 (the resolution the classifier head was sized for),
# convert to a float tensor in [0, 1], then map to [-1, 1] via (x - 0.5) / 0.5.
_preprocess_steps = [
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
]
transform = transforms.Compose(_preprocess_steps)

def predict(image, k=5):
    """Classify a drawing and return its top-k labels with probabilities.

    Args:
        image: Either a dict with a ``"composite"`` entry holding the drawing
            (the input shape this snippet was originally written for), or the
            drawing itself. The drawing must be accepted by ``transform``,
            e.g. a single-channel (grayscale) PIL image, because the model's
            first convolution expects exactly one input channel.
        k: Number of top classes to report; capped at the number of classes.

    Returns:
        A dict mapping each of the top-k labels to its softmax probability,
        ordered from most to least likely.
    """
    if isinstance(image, dict):
        image = image["composite"]
    # transform yields (C, H, W); the model expects a batch, so add dim 0.
    batch = transform(image).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)

    probabilities = F.softmax(logits[0], dim=0)
    # torch.topk raises if k exceeds the number of classes, so clamp it.
    values, indices = torch.topk(probabilities, min(k, probabilities.numel()))
    return {
        LABELS[idx]: float(prob)
        for prob, idx in zip(values.tolist(), indices.tolist())
    }
```