| import torch | |
| from torch import nn | |
class Patch_Embedding(nn.Module):
    """Cut an image batch into non-overlapping patches and embed each patch.

    A single ``nn.Conv2d`` whose kernel size and stride both equal
    ``patch_size`` performs the patch split and the linear projection in
    one step.

    Args:
        img_size: Side length used to compute ``n_patch``. ``forward`` does
            not check the actual input against this value.
        patch_size: Side length of one square patch (conv kernel and stride).
        embed_dim: Number of output channels, i.e. the size of each token.
        in_channels: Number of channels of the input images (default 3).

    Attributes:
        n_patch: ``(img_size // patch_size) ** 2``. Floor division matches
            the un-padded strided convolution, which ignores any trailing
            rows/columns that do not fill a whole patch.
    """

    def __init__(self, img_size, patch_size, embed_dim, in_channels=3):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.in_channels = in_channels
        self.n_patch = (img_size // patch_size) ** 2
        self.projection_layers = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Return patch tokens of shape ``(batch, tokens, embed_dim)``.

        Args:
            x: 4-D image batch ``(batch, in_channels, height, width)``.

        Raises:
            ValueError: If the projected tensor is not 4-D (for example when
                an unbatched 3-D image is passed).
        """
        x = self.projection_layers(x)
        # Conv2d also accepts unbatched 3-D input; reject it explicitly so the
        # flatten/transpose below cannot silently mix up the axes.
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D (batch, channels, height, width) input, "
                f"got a projection with {x.dim()} dimensions"
            )
        # (B, D, H', W') -> (B, D, H'*W') -> (B, H'*W', D)
        return x.flatten(2).transpose(1, 2)
class Positional_Encoding(nn.Module):
    """Prepend a learnable class token and add learnable position embeddings.

    Args:
        n_patch: Number of patch tokens per sample (the class token adds one
            more position).
        embedd_dim: Size of each token vector.

    Both parameters are initialised from a normal distribution with mean 0
    and standard deviation 0.02.
    """

    def __init__(self, n_patch, embedd_dim):
        super().__init__()
        self.n_patch = n_patch
        self.embedd_dim = embedd_dim

        # Draw the position table first and the class token second so the
        # random stream and the parameter registration order stay the same.
        position_init = torch.normal(0, 0.02, size=(1, n_patch + 1, embedd_dim))
        self.positional_encoding = nn.Parameter(position_init)
        class_init = torch.normal(0, 0.02, size=(1, 1, embedd_dim))
        self.cls_token = nn.Parameter(class_init)

    def forward(self, x):
        """Return ``x`` with the class token at index 0 plus position embeddings.

        Args:
            x: Token batch whose first axis is the batch size and whose last
               axis is ``embedd_dim``; tokens are concatenated along axis 1.
        """
        n_samples = x.shape[0]
        # One view of the shared class token per sample (no copy is made).
        class_tokens = self.cls_token.expand(n_samples, 1, self.embedd_dim)
        with_class = torch.cat((class_tokens, x), dim=1)
        return with_class + self.positional_encoding
class BlockTransformers(nn.Module):
    """Pre-norm Transformer encoder block: self-attention then feed-forward.

    Each sub-layer is applied as ``x + Dropout(SubLayer(LayerNorm(x)))``.

    Args:
        d_Model: Token width; input and output of the block have this size
            on the last axis.
        d_ff: Hidden width of the feed-forward network.
        n_head: Number of attention heads.
        dropout: Dropout probability applied to the output of each sub-layer
            (default 0.1).
    """

    def __init__(self, d_Model, d_ff, n_head, dropout=0.1):
        super().__init__()
        self.MHA = nn.MultiheadAttention(
            embed_dim=d_Model, num_heads=n_head, batch_first=True
        )
        self.FFN = nn.Sequential(
            nn.Linear(d_Model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_Model),
        )
        self.drop_out = nn.Dropout(p=dropout)
        self.drop_out2 = nn.Dropout(p=dropout)
        self.layer_norm = nn.LayerNorm(d_Model)
        self.layer_norm2 = nn.LayerNorm(d_Model)

    def forward(self, x):
        """Apply the block to ``x`` (batch-first tokens) and return the same shape."""
        # Attention sub-layer. The skip connection must carry the ORIGINAL
        # input, not the layer-normed copy that is fed to the attention.
        normed = self.layer_norm(x)
        # The attention weights are never used, so skip computing them.
        attention, _ = self.MHA(normed, normed, normed, need_weights=False)
        x = x + self.drop_out(attention)

        # Feed-forward sub-layer with its own skip connection.
        ffn = self.FFN(self.layer_norm2(x))
        return x + self.drop_out2(ffn)
class MiniVisualTransformers(nn.Module):
    """Small Vision Transformer backbone: patch embedding, positions, encoder blocks.

    Every argument defaults to the value that used to be hard-coded, so
    ``MiniVisualTransformers()`` builds the same architecture as before.

    Args:
        img_size: Image side length used to size the position table.
        patch_size: Side length of one square patch.
        embed_dim: Token width shared by every stage.
        d_ff: Hidden width of each block's feed-forward network.
        n_head: Attention heads per block.
        n_layers: Number of stacked encoder blocks.
    """

    def __init__(self, img_size=144, patch_size=32, embed_dim=64,
                 d_ff=256, n_head=4, n_layers=4):
        super().__init__()
        # NOTE: with the defaults, 144 // 32 == 4 patches per side, so only
        # 128 of the 144 pixels along each axis fall inside a patch.
        self.Patch_Embedding = Patch_Embedding(
            img_size=img_size, patch_size=patch_size, embed_dim=embed_dim
        )
        self.Positional_Encoding = Positional_Encoding(
            n_patch=self.Patch_Embedding.n_patch,
            embedd_dim=self.Patch_Embedding.embed_dim,
        )
        # The width is passed once so the blocks cannot drift out of sync
        # with the embedding stage.
        self.BT = nn.ModuleList(
            [
                BlockTransformers(d_Model=embed_dim, d_ff=d_ff, n_head=n_head)
                for _ in range(n_layers)
            ]
        )

    def forward(self, x):
        """Return the token sequence (class token at index 0) after all blocks."""
        x = self.Patch_Embedding(x)
        x = self.Positional_Encoding(x)
        for block in self.BT:
            x = block(x)
        return x
class Classifier(nn.Module):
    """Image classifier: MiniVisualTransformers backbone plus a linear head.

    Args:
        n_class: Number of output classes (width of the returned logits).
    """

    def __init__(self, n_class):
        super().__init__()
        self.MiniVIT = MiniVisualTransformers()
        # Read the token width from the backbone instead of repeating the
        # literal 64, so the head always matches the backbone's output size.
        embed_dim = self.MiniVIT.Patch_Embedding.embed_dim
        self.linear = nn.Linear(embed_dim, n_class)

    def forward(self, x):
        """Return class logits computed from the token at sequence index 0."""
        tokens = self.MiniVIT(x)
        # Index 0 is where the backbone's positional stage places the class token.
        cls_repr = tokens[:, 0, :]
        return self.linear(cls_repr)