Normal1919 commited on
Commit
bd7df87
·
verified ·
1 Parent(s): 2885311

how to use

Browse files
Files changed (1) hide show
  1. README.md +58 -1
README.md CHANGED
@@ -23,4 +23,61 @@ should probably proofread and complete it, then remove this comment. -->
23
 
24
  # swinv2-large-patch4-window12to24-192to384-22kto1k-ft-microbes-merged
25
 
26
- This model is a fine-tuned version of [microsoft/swinv2-large-patch4-window12to24-192to384-22kto1k-ft](https://huggingface.co/microsoft/swinv2-large-patch4-window12to24-192to384-22kto1k-ft) on the private dataset.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  # swinv2-large-patch4-window12to24-192to384-22kto1k-ft-microbes-merged
25
 
26
+ This model is a fine-tuned version of [microsoft/swinv2-large-patch4-window12to24-192to384-22kto1k-ft](https://huggingface.co/microsoft/swinv2-large-patch4-window12to24-192to384-22kto1k-ft) on the private dataset.
27
+
28
+ # How to use
29
+
30
+ ```python
31
+ import torch
32
+ import torch.nn.functional as F
33
+ import torchvision
34
+ import torchvision.transforms as transforms
35
+ from transformers import AutoModelForImageClassification
36
+ from matplotlib import pyplot as plt
37
+
38
+ model_name = "THW_02"
39
+
40
+ model = AutoModelForImageClassification.from_pretrained(model_name)
41
+ model.eval()
42
+ # model = torch.compile(model)
43
+
44
+ image_transform = transforms.Compose([
45
+ transforms.ToPILImage(),
46
+ transforms.Resize((256, 256)),
47
+ transforms.ToTensor(),
48
+ transforms.Normalize(mean=[0.697, 0.633, 0.635], std=[0.3135, 0.320, 0.315])
49
+ ])
50
+
51
+ with torch.no_grad():
52
+ image_raw = torchvision.io.read_image("test_img/c9f00dbb7e8fe20538fcc71b1dc0fbb913029959.png")
53
+ if image_raw.size()[0] == 1:
54
+ image_raw = torch.cat([image_raw]*3, 0)
55
+ if image_raw.size()[0] == 4:
56
+ image_raw = image_raw[:3]
57
+ edit_image_tensor: torch.Tensor = image_transform(image_raw)
58
+ edit_image_tensor = edit_image_tensor.unsqueeze(0)
59
+
60
+ outputs = model(pixel_values=edit_image_tensor)
61
+ logits = F.sigmoid(outputs.logits)[0]
62
+ ind = logits.argmax().item()
63
+ print(model.config.id2label[ind])
64
+
65
+ cha_names = [model.config.id2label[i] for i in range(146)]
66
+ cha_probs = logits.numpy()
67
+ names_probs = list(zip(cha_names, cha_probs))
68
+ names_probs = sorted(names_probs, key=lambda x: x[1], reverse=True)
69
+
70
+ print(names_probs)
71
+
72
+ top_k = 10
73
+ names_show = []
74
+ probs_show = []
75
+ for i in range(top_k):
76
+ names_show.append(names_probs[i][0])
77
+ probs_show.append(names_probs[i][1])
78
+
79
+ plt.rcParams['font.sans-serif'] = ['SimHei']
80
+ plt.figure(figsize=(12, 8))
81
+ plt.bar(names_show, probs_show)
82
+ plt.show()
83
+ ```