| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| |
|
| | import copy |
| | import fnmatch |
| | import logging |
| | from functools import partial |
| | from typing import Callable, List |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.utils.checkpoint as checkpoint |
| |
|
| | from timm.models.layers import DropPath, trunc_normal_ |
| |
|
| |
|
class Attention(nn.Module):
    """Multi-head self-attention as used in ViT-style models.

    A single fused linear layer produces queries, keys and values; scaled
    dot-product attention is applied per head, and the result is projected
    back to the embedding dimension.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Default to the standard 1/sqrt(head_dim) scaling unless overridden.
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, seq_len, channels = x.shape
        head_dim = channels // self.num_heads

        # Fused projection, reshaped to (3, B, heads, N, head_dim) so the
        # three leading slices are q, k and v.
        fused = self.qkv(x)
        fused = fused.reshape(batch, seq_len, 3, self.num_heads, head_dim)
        fused = fused.permute(2, 0, 3, 1, 4)
        q, k, v = fused[0], fused[1], fused[2]

        # Scaled dot-product attention weights over the sequence axis.
        scores = q @ k.transpose(-2, -1)
        scores = scores * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        # Merge heads back into the channel dimension and project out.
        out = weights @ v
        out = out.transpose(1, 2).reshape(batch, seq_len, channels)
        out = self.proj(out)
        return self.proj_drop(out)
| |
|
| |
|
class Mlp(nn.Module):
    """Transformer feed-forward block: linear -> activation -> linear,
    with dropout applied after the activation and after the output layer.

    When not given (or falsy), ``hidden_features`` and ``out_features``
    both default to ``in_features``.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        # Preserve the original falsy fallback (0/None -> in_features).
        out_features = in_features if not out_features else out_features
        hidden_features = in_features if not hidden_features else hidden_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
| |
|
| |
|
class MultiheadAttention(nn.MultiheadAttention):
    """Self-attention wrapper around ``nn.MultiheadAttention``.

    Uses the input tensor as query, key and value, skips computing the
    attention-weight matrix, and returns only the attended output.
    """

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
        out, _ = super().forward(
            query=x,
            key=x,
            value=x,
            need_weights=False,
            attn_mask=attn_mask,
        )
        return out
| |
|
| |
|
class ViTAttention(Attention):
    """``Attention`` exposing the two-argument block interface.

    Accepts ``attn_mask`` only for signature compatibility with
    ``MultiheadAttention``; this implementation does not support masking,
    so the mask must be ``None``.
    """

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
        # Masked attention is not implemented for this variant.
        assert attn_mask is None
        return super().forward(x)
| |
|
| |
|
class BlockWithMasking(nn.Module):
    """Pre-norm transformer block (attention + FFN) that forwards an
    attention mask, with optional stochastic depth (DropPath) and
    LayerScale residual scaling.

    ``attn_target`` must be a factory callable so that each block builds
    its own attention module rather than sharing one instance.
    """

    def __init__(
        self,
        dim: int,
        attn_target: Callable,
        mlp_ratio: int = 4,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        ffn_dropout_rate: float = 0.0,
        drop_path: float = 0.0,
        layer_scale_type: str = None,
        layer_scale_init_value: float = 1e-4,
    ):
        super().__init__()

        assert not isinstance(
            attn_target, nn.Module
        ), "attn_target should be a Callable. Otherwise attn_target is shared across blocks!"
        self.attn = attn_target()
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm_1 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(mlp_ratio * dim),
            act_layer=act_layer,
            drop=ffn_dropout_rate,
        )
        self.norm_2 = norm_layer(dim)

        self.layer_scale_type = layer_scale_type
        if layer_scale_type is not None:
            assert layer_scale_type in [
                "per_channel",
                "scalar",
            ], f"Found Layer scale type {self.layer_scale_type}"
            # "per_channel": one gamma per embedding channel;
            # "scalar": a single gamma shared across all channels.
            if layer_scale_type == "per_channel":
                gamma_shape = [1, 1, dim]
            else:
                gamma_shape = [1, 1, 1]
            self.layer_scale_gamma1 = nn.Parameter(
                torch.ones(size=gamma_shape) * layer_scale_init_value,
                requires_grad=True,
            )
            self.layer_scale_gamma2 = nn.Parameter(
                torch.ones(size=gamma_shape) * layer_scale_init_value,
                requires_grad=True,
            )

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
        attn_out = self.drop_path(self.attn(self.norm_1(x), attn_mask))
        if self.layer_scale_type is None:
            x = x + attn_out
            x = x + self.drop_path(self.mlp(self.norm_2(x)))
        else:
            # Scale each residual branch by its learned gamma.
            x = x + attn_out * self.layer_scale_gamma1
            x = x + self.drop_path(self.mlp(self.norm_2(x))) * self.layer_scale_gamma2
        return x
| |
|
| |
|
# Default norm-layer factory: LayerNorm with eps=1e-6 (matches the JAX/ViT
# convention rather than PyTorch's 1e-5 default).
_LAYER_NORM = partial(nn.LayerNorm, eps=1e-6)
| |
|
| |
|
class SimpleTransformer(nn.Module):
    def __init__(
        self,
        attn_target: Callable,
        embed_dim: int,
        num_blocks: int,
        block: Callable = BlockWithMasking,
        pre_transformer_layer: Callable = None,
        post_transformer_layer: Callable = None,
        drop_path_rate: float = 0.0,
        drop_path_type: str = "progressive",
        norm_layer: Callable = _LAYER_NORM,
        mlp_ratio: int = 4,
        ffn_dropout_rate: float = 0.0,
        layer_scale_type: str = None,
        layer_scale_init_value: float = 1e-4,
        weight_init_style: str = "jax",
    ):
        """
        Simple Transformer with the following features
        1. Supports masked attention
        2. Supports DropPath
        3. Supports LayerScale
        4. Supports Dropout in Attention and FFN
        5. Makes few assumptions about the input except that it is a Tensor

        attn_target: factory producing a fresh attention module per block
            (passed through to ``block``).
        drop_path_type: "progressive" ramps the per-block DropPath rate
            linearly from 0 to drop_path_rate; "uniform" uses
            drop_path_rate for every block. Any other value raises
            ValueError.
        weight_init_style: "jax" (xavier_uniform) or "pytorch"
            (trunc_normal) init for Linear weights; other values leave
            Linear weights at their default init.
        """
        super().__init__()
        self.pre_transformer_layer = pre_transformer_layer
        # Per-block stochastic-depth rates.
        if drop_path_type == "progressive":
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_blocks)]
        elif drop_path_type == "uniform":
            dpr = [drop_path_rate for i in range(num_blocks)]
        else:
            raise ValueError(f"Unknown drop_path_type: {drop_path_type}")

        self.blocks = nn.Sequential(
            *[
                block(
                    dim=embed_dim,
                    attn_target=attn_target,
                    mlp_ratio=mlp_ratio,
                    ffn_dropout_rate=ffn_dropout_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    layer_scale_type=layer_scale_type,
                    layer_scale_init_value=layer_scale_init_value,
                )
                for i in range(num_blocks)
            ]
        )
        self.post_transformer_layer = post_transformer_layer
        self.weight_init_style = weight_init_style
        # Recursively (re-)initialize all submodule weights.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Initialize Linear and LayerNorm modules; everything else keeps
        # its default initialization.
        if isinstance(m, nn.Linear):
            if self.weight_init_style == "jax":
                # JAX-style ViT init: xavier uniform weights.
                torch.nn.init.xavier_uniform_(m.weight)
            elif self.weight_init_style == "pytorch":
                # PyTorch-style ViT init: truncated normal, std 0.02.
                trunc_normal_(m.weight, std=0.02)

            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(
        self,
        tokens: torch.Tensor,
        attn_mask: torch.Tensor = None,
        use_checkpoint: bool = False,
        checkpoint_every_n: int = 1,
        checkpoint_blk_ids: List[int] = None,
    ):
        """
        Inputs
        - tokens: data of shape N x L x D (or L x N x D depending on the attention implementation)
        - attn: mask of shape L x L
        - use_checkpoint: if True, run the selected blocks under activation
          checkpointing (recompute activations in backward to save memory)
        - checkpoint_every_n: checkpoint every n-th block; used only when
          checkpoint_blk_ids is None
        - checkpoint_blk_ids: explicit block indices to checkpoint;
          overrides the checkpoint_every_n policy

        Output
        - x: data of shape N x L x D (or L x N x D depending on the attention implementation)
        """
        if self.pre_transformer_layer:
            tokens = self.pre_transformer_layer(tokens)
        # Derive the default checkpointing schedule when none was given.
        if use_checkpoint and checkpoint_blk_ids is None:
            checkpoint_blk_ids = [
                blk_id
                for blk_id in range(len(self.blocks))
                if blk_id % checkpoint_every_n == 0
            ]
        if checkpoint_blk_ids:
            # Set gives O(1) membership tests in the loop below.
            checkpoint_blk_ids = set(checkpoint_blk_ids)
        for blk_id, blk in enumerate(self.blocks):
            if use_checkpoint and blk_id in checkpoint_blk_ids:
                tokens = checkpoint.checkpoint(
                    blk, tokens, attn_mask, use_reentrant=False
                )
            else:
                tokens = blk(tokens, attn_mask=attn_mask)
        if self.post_transformer_layer:
            tokens = self.post_transformer_layer(tokens)
        return tokens
| |
|