---
license: apache-2.0
base_model: microsoft/swinv2-large-patch4-window12to24-192to384-22kto1k-ft
model-index:
- name: THW
  results:
  - task:
      name: Image Classification
      type: image-classification
    dataset:
      name: None
      type: None
      config: None
      split: None
      args: None
    metrics:
    - name: None
      type: None
      value: None
---

# Normal1919/THW

This model is a fine-tuned version of [microsoft/swinv2-large-patch4-window12to24-192to384-22kto1k-ft](https://huggingface.co/microsoft/swinv2-large-patch4-window12to24-192to384-22kto1k-ft) on a private dataset.

# How to use

```python
import torch
import torchvision
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from transformers import AutoModelForImageClassification

model_name = "Normal1919/THW"

model = AutoModelForImageClassification.from_pretrained(model_name)
model.eval()
# model = torch.compile(model)  # optional speed-up on PyTorch 2.x

# Resize to 256x256 and normalize with the per-channel mean/std used by this model.
image_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.697, 0.633, 0.635], std=[0.3135, 0.320, 0.315])
])

with torch.no_grad():
    # Replace with the path to your own image; read_image returns a uint8 tensor (C, H, W).
    image_raw = torchvision.io.read_image("test_img/c9f00dbb7e8fe20538fcc71b1dc0fbb913029959.png")
    if image_raw.size(0) in (2, 4):  # gray+alpha or RGBA: drop the alpha channel
        image_raw = image_raw[:-1]
    if image_raw.size(0) == 1:  # grayscale: repeat to 3 channels
        image_raw = image_raw.repeat(3, 1, 1)
    # Add a batch dimension: (1, 3, 256, 256)
    pixel_values = image_transform(image_raw).unsqueeze(0)

    outputs = model(pixel_values=pixel_values)
    # Sigmoid gives an independent probability per label (not a softmax distribution).
    probs = torch.sigmoid(outputs.logits)[0]
    ind = probs.argmax().item()
    print(model.config.id2label[ind])

    num_labels = model.config.num_labels  # 146 for this model
    cha_names = [model.config.id2label[i] for i in range(num_labels)]
    cha_probs = probs.numpy()
    names_probs = sorted(zip(cha_names, cha_probs), key=lambda x: x[1], reverse=True)
    print(names_probs)

    # Plot the ten highest-scoring labels.
    top_k = 10
    names_show = [name for name, _ in names_probs[:top_k]]
    probs_show = [prob for _, prob in names_probs[:top_k]]

    # SimHei is set so that CJK label names render; it must be installed locally.
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.figure(figsize=(12, 8))
    plt.bar(names_show, probs_show)
    plt.show()
```
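The snippet applies a sigmoid rather than a softmax to the logits, which suggests the classification head produces independent per-label scores, so several labels (or none) can score high for the same image, while the `argmax` print shows only one. The helper below is a minimal sketch for reading those scores, not part of the model or of `transformers`: `decode_multilabel` is a hypothetical name, and the 0.5 threshold is an assumption, since this card publishes no decision cut-off.

```python
from typing import Dict, List, Sequence, Tuple


def decode_multilabel(
    probs: Sequence[float],
    id2label: Dict[int, str],
    threshold: float = 0.5,  # assumed cut-off; the model card does not specify one
    top_k: int = 10,
) -> Tuple[List[Tuple[str, float]], List[str]]:
    """Turn per-label sigmoid scores into a ranked top-k list and a thresholded label set."""
    if len(probs) != len(id2label):
        raise ValueError(f"expected {len(id2label)} scores, got {len(probs)}")
    # Pair every label with its score and rank from most to least likely.
    ranked = sorted(
        ((id2label[i], float(p)) for i, p in enumerate(probs)),
        key=lambda pair: pair[1],
        reverse=True,
    )
    # Scores are independent, so any number of labels may pass the threshold.
    selected = [name for name, p in ranked if p >= threshold]
    return ranked[:top_k], selected


if __name__ == "__main__":
    # Toy scores and label names for illustration only.
    toy_id2label = {0: "a", 1: "b", 2: "c"}
    top, chosen = decode_multilabel([0.2, 0.9, 0.6], toy_id2label, top_k=2)
    print(top)     # [('b', 0.9), ('c', 0.6)]
    print(chosen)  # ['b', 'c']
```

With the real model, pass the sigmoid scores for the single image (converted with `.tolist()`) together with `model.config.id2label`. Tune the threshold on your own validation images rather than relying on the assumed default.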