|
|
--- |
|
|
license: cc-by-4.0 |
|
|
--- |
|
|
|
|
|
# This model is not built on Hugging Face `transformers`, so the checkpoint file must be downloaded manually |
|
|
``` |
|
|
# Use /resolve/ (raw file) rather than /blob/ (HTML file-viewer page).
wget https://huggingface.co/Lancelot53/icon_classifier_maxvit/resolve/main/best_model_89.pth
|
|
``` |
|
|
|
|
|
# Inference Code |
|
|
``` |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torchvision import transforms, models |
|
|
from PIL import Image |
|
|
import torch.nn.functional as F |
|
|
|
|
|
# Contents of id_2_class.json, inlined so no extra file has to be loaded.
import json

# Maps the model's output index (as a string) to its icon class name.
id_2_class = {
    "0": "back",
    "1": "Briefcase",
    "2": "Call",
    "3": "Camera",
    "4": "Circle",
    "5": "Cloud",
    "6": "delete",
    "7": "Down",
    "8": "edit",
    "9": "Export",
    "10": "Face",
    "11": "Folder",
    "12": "Globe",
    "13": "Google",
    "14": "Heart",
    "15": "Home",
    "16": "Image",
    "17": "Import",
    "18": "Info",
    "19": "Link",
    "20": "Location",
    "21": "Mail",
    "22": "menu",
    "23": "Merge",
    "24": "Message",
    "25": "Microphone",
    "26": "more",
    "27": "Music",
    "28": "Mute",
    "29": "Person",
    "30": "Phone",
    "31": "plus",
    "32": "QRCODE",
    "33": "Refresh",
    "34": "search",
    "35": "settings",
    "36": "share",
    "37": "Star",
    "38": "Tick",
    "39": "Up",
    "40": "vidCam",
    "41": "Video",
    "42": "Volume",
}

# Reverse lookup: class name -> string index.
class_2_id = {name: idx for idx, name in id_2_class.items()}
|
|
|
|
|
# Inference-time preprocessing: resize to 224x224, convert to a tensor, then
# normalize each of the three channels with mean 0.5 and std 0.5.
test_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
|
|
|
|
|
class MaxViT(nn.Module):
    """torchvision MaxViT-T with its final classifier layer resized for icon classes.

    Args:
        num_classes: Number of output classes. Defaults to ``len(class_2_id)``,
            which is what the original snippet always used.
        weights: Forwarded to ``torchvision.models.maxvit_t``. The default
            ``"DEFAULT"`` fetches the pretrained backbone; pass ``None`` to
            skip that download when a full checkpoint is loaded afterwards.
    """

    def __init__(self, num_classes=None, weights="DEFAULT"):
        super().__init__()
        if num_classes is None:
            num_classes = len(class_2_id)

        backbone = models.maxvit_t(weights=weights)
        # classifier[5] is the final Linear layer; replace it with one that
        # has one output per icon class.
        head_in_features = backbone.classifier[5].in_features
        backbone.classifier[5] = nn.Linear(head_in_features, num_classes)

        # Keep the attribute name "model": a state dict loaded into this
        # wrapper must use matching "model."-prefixed keys.
        self.model = backbone

    def forward(self, x):
        """Return the raw class logits produced by the wrapped network for ``x``."""
        return self.model(x)
|
|
|
|
|
# Instantiate the model |
|
|
model = MaxViT()
# map_location="cpu" lets a checkpoint that was saved on a GPU load on a
# CPU-only machine. weights_only=True restricts unpickling to tensors and
# plain containers, which is the safe choice for a downloaded file.
model.load_state_dict(
    torch.load("best_model_89.pth", map_location="cpu", weights_only=True)
)
# Switch layers such as dropout and batch norm to evaluation behaviour.
model.eval()
|
|
|
|
|
def inference(image_path, CONFIDENT_THRESHOLD=None):
    """Classify a single icon image.

    Args:
        image_path: Path (or file object) passed to ``PIL.Image.open``.
        CONFIDENT_THRESHOLD: Optional minimum softmax probability. When set
            and the top probability is below it, the label
            ``"UNKNOWN_CLASS"`` is returned instead of a class name.

    Returns:
        A ``(label, confidence)`` tuple, where ``label`` is a class name from
        ``id_2_class`` (or ``"UNKNOWN_CLASS"``) and ``confidence`` is the top
        softmax probability as a Python float.
    """
    # Open via a context manager so the file handle is closed promptly.
    # The image is collapsed to grayscale and expanded back to three channels,
    # so colour is discarded while the 3-channel input shape is kept.
    # NOTE(review): presumably this mirrors the training preprocessing - confirm.
    with Image.open(image_path) as raw:
        img = raw.convert("L").convert("RGB")

    # Add a leading batch dimension of size 1.
    batch = test_transform(img).unsqueeze(0)

    with torch.no_grad():
        probabilities = F.softmax(model(batch), dim=1)
        confidence, predicted = torch.max(probabilities, 1)

    confidence = confidence.item()
    if CONFIDENT_THRESHOLD is not None and confidence < CONFIDENT_THRESHOLD:
        return "UNKNOWN_CLASS", confidence

    # id_2_class is keyed by string indices.
    return id_2_class[str(predicted.item())], confidence
|
|
|
|
|
# Example call; print the result so it is visible when run as a script.
label, confidence = inference("images/7820.jpg", 0.9)  # 0.9 should be good enough
print(f"{label} ({confidence:.3f})")
|
|
``` |
|
|
|
|
|
|
|
|
# Training |
|
|
See the repository for the training code. |
|
|
|
|
|
# Dataset |
|
|
Trained on Lancelot53/android_icon_dataset |