minchul committed on
Commit
40e2e98
·
verified ·
1 Parent(s): ec8ff86

Upload directory

Browse files
Files changed (1) hide show
  1. models/vit/__init__.py +68 -0
models/vit/__init__.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..base import BaseModel
2
+ from .vit import VisionTransformer
3
+ from torchvision import transforms
4
+
5
+
6
class ViTModel(BaseModel):
    """
    A Vision Transformer (ViT) model that inherits from the BaseModel class.

    This model applies the transformer architecture to image analysis, utilizing
    patches of images as input sequences, allowing for attention-based processing
    of visual elements.
    https://arxiv.org/abs/2010.11929
    ```
    @article{dosovitskiy2020image,
    title={An image is worth 16x16 words: Transformers for image recognition at scale},
    author={Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and others},
    journal={arXiv preprint arXiv:2010.11929},
    year={2020}
    }
    ```
    """

    def __init__(self, net, config):
        """Wrap a VisionTransformer backbone.

        Args:
            net: the VisionTransformer backbone module.
            config: model configuration object (passed to BaseModel).
        """
        super().__init__(config)
        self.net = net  # VisionTransformer backbone used by forward()

    @classmethod
    def from_config(cls, config):
        """Build a ViTModel (in eval mode) from *config*.

        Args:
            config: must provide ``name`` ('small' or 'base'), ``output_dim``
                and ``mask_ratio``.

        Returns:
            A ViTModel instance with ``eval()`` already applied.

        Raises:
            NotImplementedError: if ``config.name`` is not a known variant.
        """
        # Keyword arguments shared by every ViT variant; only depth / mlp_ratio /
        # num_heads differ between 'small' and 'base'.
        common = dict(
            img_size=112,
            patch_size=8,
            num_classes=config.output_dim,
            embed_dim=512,
            drop_path_rate=0.1,
            norm_layer="ln",
            mask_ratio=config.mask_ratio,
        )
        if config.name == 'small':
            net = VisionTransformer(depth=12, mlp_ratio=5, num_heads=8, **common)
        elif config.name == 'base':
            net = VisionTransformer(depth=24, mlp_ratio=3, num_heads=16, **common)
        else:
            # Name the offending value instead of raising a bare error.
            raise NotImplementedError(f"Unknown ViT variant: {config.name!r}")

        model = cls(net, config)
        model.eval()
        return model

    def forward(self, x):
        """Run the backbone on a batch of images.

        Args:
            x: input image batch tensor.

        Returns:
            The backbone's output for ``x``.
        """
        if self.input_color_flip:
            # Reverse the channel dimension (e.g. RGB <-> BGR).
            # NOTE(review): assumes NCHW layout so dim 1 is channels — confirm.
            x = x.flip(1)
        return self.net(x)

    @staticmethod
    def _make_transform():
        """Shared preprocessing: ToTensor, then normalize each channel to [-1, 1]."""
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def make_train_transform(self):
        """Return the training-time preprocessing transform."""
        # Train and test preprocessing are intentionally identical.
        return self._make_transform()

    def make_test_transform(self):
        """Return the test-time preprocessing transform."""
        return self._make_transform()
65
+
66
def load_model(model_config):
    """Instantiate a ViTModel (already in eval mode) from *model_config*."""
    return ViTModel.from_config(model_config)