# Simple ViT - IMAGENET100
This model was trained using the vit-analysis framework for analyzing Vision Transformer positional encoding methods.
## Model Details
| Property | Value |
|---|---|
| Model Type | Simple Vision Transformer |
| Dataset | imagenet100 |
| Best Accuracy | 71.94% |
| Image Size | 224 |
| Patch Size | 16 |
| Hidden Dim | 192 |
| Depth | 12 |
| Num Heads | 3 |
| MLP Dim | 768 |
| Num Classes | 100 |
## Model Description

This is a Vision Transformer with learnable absolute positional embeddings: each patch position has its own trainable embedding vector, added to the patch tokens and updated along with the rest of the weights during training.
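The table above fixes the sizes that the positional-embedding table must match. A minimal sketch of the arithmetic, assuming a prepended `[CLS]` token as in the standard ViT (the actual `SimpleVisionTransformer` may differ):

```python
# Sizes implied by the hyperparameters above.
# Assumption: one extra [CLS] token, as in the standard ViT recipe.
image_size, patch_size, hidden_dim = 224, 16, 192

num_patches = (image_size // patch_size) ** 2   # 14 x 14 = 196 patches
seq_len = num_patches + 1                        # +1 for the [CLS] token
pos_embed_params = seq_len * hidden_dim          # entries in the learned table

print(num_patches, seq_len, pos_embed_params)   # 196 197 37824
```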
## Usage

```python
import torch

from models import SimpleVisionTransformer

# Initialize model
model = SimpleVisionTransformer(
    image_size=224,
    patch_size=16,
    num_layers=12,
    num_heads=3,
    hidden_dim=192,
    mlp_dim=768,
    num_classes=100,
)

# Load checkpoint
checkpoint = torch.load('simple_vit_imagenet100_best.pth', map_location='cpu')
state_dict = checkpoint['state_dict']

# Remove 'module.' prefix if present (from DDP training)
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.eval()

# Inference
from torchvision import transforms
from PIL import Image

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

image = Image.open('your_image.jpg').convert('RGB')
input_tensor = transform(image).unsqueeze(0)

with torch.no_grad():
    output = model(input_tensor)

prediction = output.argmax(dim=1)
```
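Beyond the top-1 index, you may want calibrated class probabilities and a top-5 list. A short sketch; `output` stands in for the model's logits (random values here, since the checkpoint is not loaded):

```python
import torch

# Stand-in for the model's logits: shape (batch, num_classes) = (1, 100).
output = torch.randn(1, 100)

# Softmax over the class dimension gives probabilities summing to 1 per row.
probs = torch.softmax(output, dim=1)

# Top-5 predictions, highest probability first.
top5_prob, top5_idx = probs.topk(5, dim=1)
```

`top5_idx` holds class indices into whatever label mapping was used for imagenet100 during training.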
## Training
This model was trained with:
- Framework: PyTorch
- Optimizer: AdamW
- Mixed Precision: Enabled
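The setup above (AdamW plus mixed precision) can be sketched roughly as follows. The learning rate, weight decay, and loop body are illustrative assumptions, not the repository's actual values, and a plain `nn.Linear` stands in for the ViT:

```python
import torch

# Stand-in for the ViT; the real model comes from SimpleVisionTransformer.
model = torch.nn.Linear(192, 100)

# AdamW optimizer; lr and weight_decay here are assumed, not the repo's values.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)

# Gradient scaler for mixed precision; disabled automatically on CPU-only hosts.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

def train_step(images, labels):
    optimizer.zero_grad(set_to_none=True)
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    # Forward pass under autocast so matmuls run in reduced precision on GPU.
    with torch.autocast(device_type, enabled=torch.cuda.is_available()):
        loss = torch.nn.functional.cross_entropy(model(images), labels)
    # Scaled backward pass guards against underflow in fp16 gradients.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()
```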
## Citation
If you use this model, please cite:
```bibtex
@misc{vit-analysis,
  title={Vision Transformer Position Encoding Analysis},
  year={2024},
  url={https://github.com/your-repo/vit-analysis}
}
```
## License
Apache 2.0