Lancelot53 committed on
Commit
2a30c76
·
1 Parent(s): 09d4108

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +71 -0
README.md CHANGED
@@ -1,3 +1,74 @@
1
  ---
2
  license: cc-by-4.0
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: cc-by-4.0
3
  ---
4
+
5
+ # This model does not inherit from huggingface/transformers, so its files must be downloaded manually
6
+ ```
7
+ wget https://huggingface.co/Lancelot53/icon_classifier_maxvit/resolve/main/id_2_class_89.json
8
+ wget https://huggingface.co/Lancelot53/icon_classifier_maxvit/resolve/main/best_model_89.pth
9
+ ```
10
+
11
+ # Inference Code
12
+ ```
13
+ import torch
14
+ import torch.nn as nn
15
+ from torchvision import transforms, models
16
+ from PIL import Image
17
+ import torch.nn.functional as F
18
+
19
+ # load the id -> class-name mapping from id_2_class_89.json
20
+ import json
21
+
22
+ with open('id_2_class_89.json') as json_file:
23
+ id_2_class = json.load(json_file)
24
+
25
+ #make class_2_id dict
26
+
27
+ class_2_id = {}
28
+ for key, value in id_2_class.items():
29
+ class_2_id[value] = key
30
+
31
+ test_transform = transforms.Compose([
32
+ transforms.Resize((224, 224)),
33
+ transforms.ToTensor(),
34
+ transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5])
35
+ ])
36
+
37
+ class MaxViT(nn.Module):
38
+ def __init__(self):
39
+ super(MaxViT, self).__init__()
40
+ model = models.maxvit_t(weights="DEFAULT")
41
+ num_ftrs = model.classifier[5].in_features
42
+ model.classifier[5] = nn.Linear(num_ftrs, len(class_2_id))
43
+ self.model = model
44
+ def forward(self, x):
45
+ return self.model(x)
46
+
47
+ # Instantiate the model
48
+ model = MaxViT()
49
+ model.load_state_dict(torch.load('best_model_89.pth'))
50
+ model.eval()
51
+
52
+ def inference(image_path, CONFIDENT_THRESHOLD=None):
53
+ img = Image.open(image_path).convert("L").convert("RGB")
54
+ img = test_transform(img)
55
+ img = img.unsqueeze(0)
56
+
57
+ with torch.no_grad():
58
+ output = F.softmax(model(img), dim=1)
59
+ confidence, predicted = torch.max(output.data, 1)
60
+
61
+ if CONFIDENT_THRESHOLD is not None and confidence.item() < CONFIDENT_THRESHOLD:
62
+ return "UNKNOWN_CLASS", confidence.item()
63
+
64
+ return id_2_class[str(predicted.item())], confidence.item()
65
+
66
+ inference("images/7820.jpg", 0.9) #0.9 should be good enough
67
+ ```
68
+
69
+
70
+ # Training
71
+ Check the repo for the training code.
72
+
73
+ # Dataset
74
+ Trained on 8K icons in 43 classes. The dataset is proprietary for now (email me if you want it).