aadex committed · verified
Commit 49c9a09 · Parent(s): 519bf7f

Add model card

Files changed (1): README.md (+108, -0)
README.md ADDED

---
tags:
- vision-transformer
- image-classification
- rope
- imagenet100
- pytorch
license: apache-2.0
datasets:
- imagenet100
metrics:
- accuracy
---

# RoPE ViT - ImageNet-100

This model was trained using the [vit-analysis](https://github.com/your-repo/vit-analysis) framework for analyzing positional-encoding methods in Vision Transformers.

## Model Details

| Property | Value |
|----------|-------|
| **Model Type** | RoPE Vision Transformer |
| **Dataset** | ImageNet-100 |
| **Best Accuracy** | 77.30% |
| **Image Size** | 224 × 224 |
| **Patch Size** | 16 |
| **Hidden Dim** | 192 |
| **Depth** | 12 |
| **Num Heads** | 3 |
| **MLP Dim** | 768 |
| **Num Classes** | 100 |

## Model Description

This is a Vision Transformer with **Rotary Position Embeddings (RoPE)**. Rather than adding learned position embeddings to the patch tokens, RoPE encodes position directly in the attention mechanism by rotating query and key vectors as a function of token position, so attention scores depend on relative rather than absolute position. This tends to generalize better to sequence lengths (and hence image resolutions) not seen during training.

- **RoPE Theta:** 10.0
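
For intuition, here is a minimal, self-contained sketch of a 1D rotary embedding applied to queries and keys before the attention product. It is illustrative only: `rope_rotate` is a hypothetical helper, and the actual rotation used by `RoPESimpleVisionTransformer` (for images, typically axial 2D frequencies over patch rows and columns) lives in the repo.

```python
import torch

def rope_rotate(x: torch.Tensor, theta: float = 10.0) -> torch.Tensor:
    """Rotate channel pairs of x by position-dependent angles.

    x: (seq_len, head_dim) with head_dim even.
    """
    seq_len, dim = x.shape
    # One frequency per channel pair, geometrically spaced by `theta`
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    angles = torch.outer(torch.arange(seq_len).float(), freqs)  # (seq_len, dim // 2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[:, 0::2], x[:, 1::2]
    # Standard 2D rotation applied to each channel pair
    out = torch.empty_like(x)
    out[:, 0::2] = x1 * cos - x2 * sin
    out[:, 1::2] = x1 * sin + x2 * cos
    return out

# Because both sides are rotated, each q·k score depends only on relative position
q, k = torch.randn(196, 64), torch.randn(196, 64)  # 14x14 patches, head_dim 192/3 = 64
attn_scores = rope_rotate(q) @ rope_rotate(k).T
```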

## Usage

```python
import torch
from models import RoPESimpleVisionTransformer

# Initialize the model with the same hyperparameters it was trained with
model = RoPESimpleVisionTransformer(
    image_size=224,
    patch_size=16,
    num_layers=12,
    num_heads=3,
    hidden_dim=192,
    mlp_dim=768,
    num_classes=100,
)

# Load checkpoint; fall back to the raw dict if weights are stored directly
checkpoint = torch.load('rope_vit_imagenet100_best.pth', map_location='cpu')
state_dict = checkpoint.get('state_dict', checkpoint)

# Remove the 'module.' prefix if present (left over from DDP training)
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.eval()

# Inference
from torchvision import transforms
from PIL import Image

# Standard ImageNet preprocessing: resize, center-crop, normalize
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

image = Image.open('your_image.jpg').convert('RGB')
input_tensor = transform(image).unsqueeze(0)  # add batch dimension

with torch.no_grad():
    output = model(input_tensor)       # (1, 100) logits
    prediction = output.argmax(dim=1)  # predicted class index
```
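
To turn the logits into ranked predictions, you can softmax them and take the top-k; the mapping from class index to class name depends on your particular ImageNet-100 split, so it is not shown here:

```python
# Top-5 predictions from the `output` logits above
probs = output.softmax(dim=1)
top5 = probs.topk(5, dim=1)
for p, idx in zip(top5.values[0], top5.indices[0]):
    print(f"class {idx.item()}: {p.item():.3f}")
```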

## Training

This model was trained with:
- **Framework:** PyTorch
- **Optimizer:** AdamW
- **Mixed Precision:** Enabled (see the sketch below)
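
Hyperparameters such as the learning rate and schedule are not recorded in this card. As a rough sketch with placeholder values, one AdamW training step under `torch.cuda.amp` mixed precision looks like this:

```python
import torch
import torch.nn.functional as F

# lr and weight_decay are placeholders, not the values used for this checkpoint
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)
scaler = torch.cuda.amp.GradScaler()

def train_step(images, labels):
    optimizer.zero_grad(set_to_none=True)
    # Forward pass runs in reduced precision where safe
    with torch.cuda.amp.autocast():
        logits = model(images)
        loss = F.cross_entropy(logits, labels)
    # Scale the loss to avoid fp16 gradient underflow, then unscale and step
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()
```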

## Citation

If you use this model, please cite:

```bibtex
@misc{vit-analysis,
  title={Vision Transformer Position Encoding Analysis},
  year={2024},
  url={https://github.com/your-repo/vit-analysis}
}
```

## License

Apache 2.0