aadex committed on
Commit
6a8c801
·
verified ·
1 Parent(s): 5538ad8

Add model card

Files changed (1)
  1. README.md +106 -0
README.md ADDED
@@ -0,0 +1,106 @@
+ ---
+ tags:
+ - vision-transformer
+ - image-classification
+ - fire
+ - imagenet100
+ - pytorch
+ license: apache-2.0
+ datasets:
+ - imagenet100
+ metrics:
+ - accuracy
+ ---
+
+ # FIRE ViT - ImageNet100
+
+ This model was trained using the [vit-analysis](https://github.com/your-repo/vit-analysis) framework for analyzing Vision Transformer positional encoding methods.
+
+ ## Model Details
+
+ | Property | Value |
+ |----------|-------|
+ | **Model Type** | FIRE Vision Transformer |
+ | **Dataset** | imagenet100 |
+ | **Best Accuracy** | 74.22% |
+ | **Image Size** | 224 |
+ | **Patch Size** | 16 |
+ | **Hidden Dim** | 192 |
+ | **Depth** | 12 |
+ | **Num Heads** | 3 |
+ | **MLP Dim** | 768 |
+ | **Num Classes** | 100 |
+
+ ## Model Description
+
+ This is a Vision Transformer with **FIRE (Functional Interpolation for Relative Position Encoding)**.
+ FIRE uses learnable continuous functions to generate position-dependent bias terms, providing
+ flexible and generalizable position encoding.
+
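+ The exact FIRE module lives in the vit-analysis repository and is not reproduced here. As a rough
+ illustration only (not the code behind `FIRESimpleVisionTransformer`), a FIRE-style bias can be
+ written as a small MLP over log-scaled, normalized relative distances, producing one additive
+ attention-bias map per head:
+
+ ```python
+ import torch
+ import torch.nn as nn
+
+ class FIREBias(nn.Module):
+     """Toy FIRE-style bias: a learnable continuous function (an MLP) maps a
+     normalized relative distance to one additive attention bias per head."""
+
+     def __init__(self, num_heads: int, hidden: int = 32):
+         super().__init__()
+         self.mlp = nn.Sequential(
+             nn.Linear(1, hidden),
+             nn.ReLU(),
+             nn.Linear(hidden, num_heads),
+         )
+
+     def forward(self, seq_len: int) -> torch.Tensor:
+         pos = torch.arange(seq_len, dtype=torch.float32)
+         rel = pos[:, None] - pos[None, :]  # (L, L) signed token distances
+         # Log-scaling and normalization keep the input bounded, which is what lets the
+         # learned function interpolate to sequence lengths not seen during training.
+         rel = torch.sign(rel) * torch.log1p(rel.abs()) / torch.log1p(torch.tensor(float(seq_len)))
+         bias = self.mlp(rel.unsqueeze(-1))  # (L, L, num_heads)
+         return bias.permute(2, 0, 1)        # (num_heads, L, L)
+ ```
+
+ The returned tensor would be added to the pre-softmax attention scores, e.g.
+ `attn = q @ k.transpose(-2, -1) / scale + bias`, broadcasting over the batch dimension.
+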
+ ## Usage
+
+ ```python
+ import torch
+ from models import FIRESimpleVisionTransformer
+
+ # Initialize model
+ model = FIRESimpleVisionTransformer(
+     image_size=224,
+     patch_size=16,
+     num_layers=12,
+     num_heads=3,
+     hidden_dim=192,
+     mlp_dim=768,
+     num_classes=100,
+ )
+
+ # Load checkpoint
+ checkpoint = torch.load('fire_vit_imagenet100_best.pth', map_location='cpu')
+ state_dict = checkpoint['state_dict']
+
+ # Remove 'module.' prefix if present (from DDP training)
+ state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
+ model.load_state_dict(state_dict)
+ model.eval()
+
+ # Inference
+ from torchvision import transforms
+ from PIL import Image
+
+ transform = transforms.Compose([
+     transforms.Resize(256),
+     transforms.CenterCrop(224),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ ])
+
+ image = Image.open('your_image.jpg').convert('RGB')
+ input_tensor = transform(image).unsqueeze(0)
+
+ with torch.no_grad():
+     output = model(input_tensor)
+     prediction = output.argmax(dim=1)
+ ```
+
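+ The checkpoint can also be pulled from the Hub with `huggingface_hub`. The snippet below is a
+ convenience sketch; `your-username/fire-vit-imagenet100` is a placeholder, substitute this
+ repository's actual id:
+
+ ```python
+ import torch
+ from huggingface_hub import hf_hub_download
+
+ # Placeholder repo id -- replace with the id of this model repository
+ ckpt_path = hf_hub_download(
+     repo_id="your-username/fire-vit-imagenet100",
+     filename="fire_vit_imagenet100_best.pth",
+ )
+ checkpoint = torch.load(ckpt_path, map_location='cpu')
+ ```
+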
+ ## Training
+
+ This model was trained with:
+ - **Framework:** PyTorch
+ - **Optimizer:** AdamW
+ - **Mixed Precision:** Enabled (see the sketch below)
+
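+ The full training script is part of the vit-analysis framework and is not included in this card. As a
+ minimal sketch only, "AdamW with mixed precision" in PyTorch usually looks like the following;
+ the learning rate, weight decay, and `train_loader` are illustrative assumptions, not the values
+ used for this checkpoint:
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ # Illustrative hyperparameters; the actual training configuration may differ.
+ optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)
+ scaler = torch.cuda.amp.GradScaler()  # loss scaling for fp16 stability
+
+ model.train().cuda()
+ for images, labels in train_loader:  # train_loader: an imagenet100 DataLoader (assumed)
+     images, labels = images.cuda(), labels.cuda()
+     optimizer.zero_grad(set_to_none=True)
+     with torch.cuda.amp.autocast():          # forward pass in mixed precision
+         loss = F.cross_entropy(model(images), labels)
+     scaler.scale(loss).backward()            # scale loss before backward to avoid underflow
+     scaler.step(optimizer)
+     scaler.update()
+ ```
+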
+ ## Citation
+
+ If you use this model, please cite:
+
+ ```bibtex
+ @misc{vit-analysis,
+   title={Vision Transformer Position Encoding Analysis},
+   year={2024},
+   url={https://github.com/your-repo/vit-analysis}
+ }
+ ```
+
+ ## License
+
+ Apache 2.0