models/vit/__init__.py
ADDED (+68 −0)
from ..base import BaseModel
from .vit import VisionTransformer
from torchvision import transforms


class ViTModel(BaseModel):
    """
    A class representing a Vision Transformer (ViT) model that inherits from the BaseModel class.

    This model applies the transformer architecture to image analysis, utilizing patches of images
    as input sequences, allowing for attention-based processing of visual elements.
    https://arxiv.org/abs/2010.11929
    ```
    @article{dosovitskiy2020image,
      title={An image is worth 16x16 words: Transformers for image recognition at scale},
      author={Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and others},
      journal={arXiv preprint arXiv:2010.11929},
      year={2020}
    }
    ```
    """

    def __init__(self, net, config):
        super(ViTModel, self).__init__(config)
        self.net = net

    @classmethod
    def from_config(cls, config):
        if config.name == 'small':
            net = VisionTransformer(img_size=112, patch_size=8, num_classes=config.output_dim, embed_dim=512, depth=12,
                                    mlp_ratio=5, num_heads=8, drop_path_rate=0.1, norm_layer="ln",
                                    mask_ratio=config.mask_ratio)
        elif config.name == 'base':
            net = VisionTransformer(img_size=112, patch_size=8, num_classes=config.output_dim, embed_dim=512, depth=24,
                                    mlp_ratio=3, num_heads=16, drop_path_rate=0.1, norm_layer="ln",
                                    mask_ratio=config.mask_ratio)
        else:
            raise NotImplementedError(f"unknown ViT variant: {config.name}")

        model = cls(net, config)
        model.eval()
        return model

    def forward(self, x):
        # `input_color_flip` is expected to be set by BaseModel; flipping dim 1
        # reverses the channel order (e.g. RGB <-> BGR) of an NCHW batch.
        if self.input_color_flip:
            x = x.flip(1)
        return self.net(x)

    def make_train_transform(self):
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        return transform

    def make_test_transform(self):
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        return transform


def load_model(model_config):
    model = ViTModel.from_config(model_config)
    return model
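For reference, a minimal usage sketch of `load_model`. It assumes only what this file shows: the config must expose `name`, `output_dim`, and `mask_ratio`; `SimpleNamespace` here is a hypothetical stand-in for whatever config class the surrounding repo actually uses, and `BaseModel` may require additional fields not visible in this file.

```python
from types import SimpleNamespace

import torch

# Hypothetical config object; the real repo presumably has its own config class.
# `name` picks the 'small' or 'base' variant, `output_dim` becomes num_classes,
# and `mask_ratio` is forwarded to VisionTransformer.
config = SimpleNamespace(name='small', output_dim=512, mask_ratio=0.0)

model = load_model(config)  # from_config returns the model already in eval() mode

# img_size=112 with patch_size=8 gives a 14x14 grid, i.e. 196 patch tokens.
x = torch.randn(1, 3, 112, 112)
with torch.no_grad():
    out = model(x)  # shape (1, output_dim), assuming a standard classification head
```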