Spaces:
Sleeping
Sleeping
| import math | |
| from lightning.pytorch.utilities.types import EVAL_DATALOADERS | |
| import torch | |
| from typing import Dict,Optional,Tuple,Union | |
| from dataclasses import dataclass | |
| import lightning as pl | |
| from torchmetrics import Accuracy | |
| # @dataclass | |
| # class ViTCfg: | |
| # image_size: int | |
| # patch_size: int | |
| # num_channels: int | |
| # model_dim: int | |
| # num_attn_heads:int | |
| # attn_dropout: int | |
| # d_ff: int | |
| # number_encoders:int | |
| # classification_heads:int | |
class PatchEmbedding(torch.nn.Module):
    """Split an image into non-overlapping square patches and project each to ``model_dim``.

    The projection is a single ``Conv2d`` whose kernel size and stride both
    equal ``patch_size``, so every output position sees exactly one patch.

    Args:
        cfg: Hyper-parameter mapping. Every key is copied onto the module as an
            attribute; this class reads ``image_size``, ``patch_size``,
            ``num_channels`` and ``model_dim``.

    Raises:
        ValueError: if ``image_size`` is not a multiple of ``patch_size``.
    """

    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for k, v in cfg.items():
            setattr(self, k, v)
        # Explicit check rather than ``assert`` so it still runs under ``python -O``;
        # without it the strided conv would silently drop the border pixels.
        if self.image_size % self.patch_size != 0:
            raise ValueError(
                f"image_size ({self.image_size}) must be divisible by "
                f"patch_size ({self.patch_size})"
            )
        # Patches per image. The (misspelled) attribute name is kept because other
        # code may read it.
        self.num_patchs = (self.image_size // self.patch_size) ** 2
        self.img2flattn: torch.nn.Conv2d = torch.nn.Conv2d(
            in_channels=self.num_channels,
            out_channels=self.model_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return patch embeddings of shape ``(bs, num_patchs, model_dim)``.

        Args:
            x: Image batch, presumably ``(bs, num_channels, image_size, image_size)``.
        """
        # (bs, C, H, W) -> (bs, model_dim, H/p, W/p) -> (bs, model_dim, num_patchs)
        #               -> (bs, num_patchs, model_dim)
        return self.img2flattn(x).flatten(2).transpose(1, 2)
class Embedding(torch.nn.Module):
    """Turn an image batch into the token sequence fed to the encoder.

    The sequence is ``[CLS]`` followed by the patch embeddings, with a learned
    positional embedding added at every position.

    Args:
        cfg: Hyper-parameter mapping. Every key is copied onto the module;
            ``model_dim``, ``image_size`` and ``patch_size`` are read here and
            the whole mapping is forwarded to ``PatchEmbedding``.
    """

    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for name, value in cfg.items():
            setattr(self, name, value)
        self.patch_embedding: PatchEmbedding = PatchEmbedding(cfg=cfg)
        dim = self.model_dim
        # A single learned [CLS] token, broadcast over the batch in forward().
        self.cls_token: torch.nn.Parameter = torch.nn.Parameter(torch.randn(1, 1, dim))
        # One position per patch plus one for the [CLS] token.
        seq_len = int((self.image_size // self.patch_size) ** 2 + 1)
        self.position_embd: torch.nn.Parameter = torch.nn.Parameter(
            torch.randn(1, seq_len, dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``(bs, num_patches + 1, model_dim)`` tokens for the image batch ``x``."""
        patches = self.patch_embedding(x)
        cls = self.cls_token.expand(patches.shape[0], -1, -1)
        tokens = torch.cat((cls, patches), dim=1)
        return tokens + self.position_embd
class AttentionBlock(torch.nn.Module):
    """Unmasked multi-head scaled dot-product self-attention.

    Every token attends to every other token; no causal mask is applied.

    Args:
        cfg: Hyper-parameter mapping. Every key is copied onto the module.
            Read here: ``model_dim``, ``num_attn_heads`` and, optionally,
            ``attn_dropout`` / ``resid_dropout`` (dropout probabilities; each
            defaults to 0.5, the ``torch.nn.Dropout`` default used previously).

    Raises:
        ValueError: if ``model_dim`` is not divisible by ``num_attn_heads``.
    """

    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for k, v in cfg.items():
            self.__setattr__(k, v)
        if self.model_dim % self.num_attn_heads != 0:
            raise ValueError(
                f"model_dim ({self.model_dim}) must be divisible by "
                f"num_attn_heads ({self.num_attn_heads})"
            )
        # Fused projection that produces q, k and v with one matmul.
        self.attn_layer: torch.nn.Linear = torch.nn.Linear(self.model_dim, 3 * self.model_dim, bias=False)
        self.out: torch.nn.Linear = torch.nn.Linear(self.model_dim, self.model_dim, bias=False)
        # Read the probabilities from ``cfg``, not ``self``: the module assignment
        # below replaces any ``attn_dropout`` attribute copied from ``cfg``.
        attn_p = float(cfg.get("attn_dropout", 0.5))
        resid_p = float(cfg.get("resid_dropout", 0.5))
        self.attn_dropout: torch.nn.Dropout = torch.nn.Dropout(attn_p)
        self.resid_dropout: torch.nn.Dropout = torch.nn.Dropout(resid_p)

    def forward(self, x: torch.Tensor, attention_outputs: bool) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
        """Apply self-attention to ``x``.

        Args:
            x: Input of shape ``(bs, seq_len, model_dim)``.
            attention_outputs: If truthy, also return the attention weights.

        Returns:
            ``(output, attn)``. ``output`` has the same shape as ``x``. ``attn``
            has shape ``(bs, n_heads, seq_len, seq_len)`` (taken after
            attention dropout), or ``None`` when ``attention_outputs`` is falsy.
        """
        B, T, C = x.size()
        head_dim = C // self.num_attn_heads
        q: torch.Tensor
        k: torch.Tensor
        v: torch.Tensor
        q, k, v = self.attn_layer(x).split(self.model_dim, dim=2)
        # (B, T, C) -> (B, n_heads, T, head_dim) so each head attends independently.
        q = q.view(B, T, self.num_attn_heads, head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_attn_heads, head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_attn_heads, head_dim).transpose(1, 2)
        # Scale by 1/sqrt(head_dim) to keep softmax logits in a stable range.
        attn = (q @ k.transpose(-2, -1)) * (1 / math.sqrt(head_dim))
        attn = torch.nn.functional.softmax(attn, dim=-1)
        attn = self.attn_dropout(attn)
        # (B, n_heads, T, T) @ (B, n_heads, T, head_dim) -> (B, n_heads, T, head_dim)
        y: torch.Tensor = attn @ v
        # Merge heads back: (B, n_heads, T, head_dim) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.out(y)), (attn if attention_outputs else None)
class MLP(torch.nn.Module):
    """Position-wise feed-forward network used inside each encoder block.

    Computes ``dropout(dense_2(layernorm(relu(dense_1(x)))))``, mapping the
    last dimension ``model_dim -> d_ff -> model_dim``.

    Args:
        cfg: Hyper-parameter mapping. Every key is copied onto the module.
            Read here: ``model_dim``, ``d_ff`` and, optionally, ``mlp_dropout``
            (dropout probability, default 0.2).
    """

    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for k, v in cfg.items():
            self.__setattr__(k, v)
        self.dense_1 = torch.nn.Linear(self.model_dim, self.d_ff)
        self.activation = torch.nn.ReLU()
        # LayerNorm over the hidden width. Submodule names and order are kept as-is
        # so that existing state_dicts keep loading.
        self.layernorm = torch.nn.LayerNorm(self.d_ff)
        self.dense_2 = torch.nn.Linear(self.d_ff, self.model_dim)
        self.dropout = torch.nn.Dropout(float(cfg.get("mlp_dropout", 0.2)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward network along the last dimension of ``x`` (size ``model_dim``)."""
        hidden = self.layernorm(self.activation(self.dense_1(x)))
        return self.dropout(self.dense_2(hidden))
class EncoderBlock(torch.nn.Module):
    """One pre-norm transformer encoder layer.

    ``x -> x + attn(LN1(x))`` followed by ``x -> x + mlp(LN2(x))``.

    Args:
        cfg: Hyper-parameter mapping. Every key is copied onto the module and the
            mapping is forwarded to ``AttentionBlock`` and ``MLP``; ``model_dim``
            sizes both LayerNorms.
    """

    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for key in cfg:
            self.__setattr__(key, cfg[key])
        self.attn_block = AttentionBlock(cfg)
        self.layernorm_1 = torch.nn.LayerNorm(self.model_dim)
        self.mlp = MLP(cfg)
        self.layernorm_2 = torch.nn.LayerNorm(self.model_dim)

    def forward(self, x: torch.Tensor, attention_outputs: bool) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
        """Return ``(output, attn)``; ``attn`` is ``None`` unless ``attention_outputs == True``."""
        attended, attn = self.attn_block(self.layernorm_1(x), attention_outputs=attention_outputs)
        hidden = x + attended  # residual around self-attention
        hidden = hidden + self.mlp(self.layernorm_2(hidden))  # residual around the MLP
        if attention_outputs == True:
            return hidden, attn
        return hidden, None
class Encoder(torch.nn.Module):
    """The transformer encoder: ``number_encoders`` ``EncoderBlock`` layers applied in order.

    Args:
        cfg: Hyper-parameter mapping. Every key is copied onto the module;
            ``number_encoders`` sets the depth and the mapping is forwarded to
            each ``EncoderBlock``.
    """

    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for key, value in cfg.items():
            self.__setattr__(key, value)
        self.blocks = torch.nn.ModuleList(
            [EncoderBlock(cfg) for _ in range(self.number_encoders)]
        )

    def forward(self, x: torch.Tensor, attention_outputs: bool):
        """Run every block in sequence.

        Returns:
            ``(output, attentions)``, where ``attentions`` is the list of
            per-block attention results when ``attention_outputs == True``,
            otherwise ``None``.
        """
        per_block_attn = []
        hidden = x
        for layer in self.blocks:
            hidden, layer_attn = layer(hidden, attention_outputs=attention_outputs)
            per_block_attn.append(layer_attn)
        return hidden, (per_block_attn if attention_outputs == True else None)
class ViTClassifier(torch.nn.Module):
    """Vision Transformer image classifier.

    Embeds the image into ``[CLS] + patch`` tokens, runs them through the
    encoder stack, and maps the final ``[CLS]`` token to
    ``classification_heads`` logits with a bias-free linear layer.

    Args:
        cfg: Hyper-parameter mapping. Every key is copied onto the module and the
            mapping is forwarded to ``Embedding`` and ``Encoder``.
    """

    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for key, value in cfg.items():
            self.__setattr__(key, value)
        self.embed: Embedding = Embedding(cfg)
        self.encoders: Encoder = Encoder(cfg=cfg)
        self.classifier: torch.nn.Linear = torch.nn.Linear(
            self.model_dim, self.classification_heads, bias=False
        )

    def forward(self, x: torch.Tensor, attention_outputs=False):
        """Return ``(logits, attentions)``; ``attentions`` is ``None`` unless requested."""
        tokens = self.embed(x)
        encoded, attn = self.encoders(tokens, attention_outputs=attention_outputs)
        cls_repr = encoded[:, 0]  # the [CLS] token sits at position 0
        logits = self.classifier(cls_repr)
        if attention_outputs:
            return logits, attn
        return logits, None