| |
| |
| |
| |
| |
| |
|
|
| import math |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from src.models.utils.modules import ( |
| Block, |
| CrossAttention, |
| CrossAttentionBlock |
| ) |
| from src.utils.tensors import trunc_normal_ |
|
|
|
|
class AttentivePooler(nn.Module):
    """Pool an input sequence into learned query tokens via cross-attention.

    A learnable parameter of shape ``(1, num_queries, embed_dim)`` is
    replicated once per input sample and passed, together with the input,
    through one cross-attention stage; ``depth - 1`` additional ``Block``
    modules are then applied to the queries alone.

    Args:
        num_queries: Number of learned query tokens.
        embed_dim: Size of the last dimension of the query tokens; forwarded
            as ``dim`` to the attention modules.
        num_heads: Forwarded as ``num_heads`` to the attention modules.
        mlp_ratio: Forwarded to ``CrossAttentionBlock`` and ``Block``.
        depth: Total number of stages. ``1`` means cross-attention only;
            each extra unit adds one ``Block`` after it.
        norm_layer: Forwarded to ``CrossAttentionBlock`` and ``Block``.
        init_std: Standard deviation passed to ``trunc_normal_`` for the
            query tokens and for Linear/Conv2d weights.
        qkv_bias: Forwarded as ``qkv_bias`` to the attention modules.
        complete_block: If true, use ``CrossAttentionBlock`` for the first
            stage; otherwise use a bare ``CrossAttention``.
    """

    def __init__(
        self,
        num_queries=1,
        embed_dim=768,
        num_heads=12,
        mlp_ratio=4.0,
        depth=1,
        norm_layer=nn.LayerNorm,
        init_std=0.02,
        qkv_bias=True,
        complete_block=True
    ):
        super().__init__()
        self.query_tokens = nn.Parameter(torch.zeros(1, num_queries, embed_dim))

        self.complete_block = complete_block
        if complete_block:
            self.cross_attention_block = CrossAttentionBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
            )
        else:
            self.cross_attention_block = CrossAttention(
                dim=embed_dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
            )

        # ``blocks`` stays None (rather than an empty ModuleList) when no
        # extra stages are requested; forward() and _rescale_blocks() rely
        # on that sentinel.
        self.blocks = None
        if depth > 1:
            self.blocks = nn.ModuleList([
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=False,
                    norm_layer=norm_layer,
                )
                for _ in range(depth - 1)
            ])

        # Initialisation order matters for RNG reproducibility: query
        # tokens first, then every submodule, then the depth rescaling.
        self.init_std = init_std
        trunc_normal_(self.query_tokens, std=self.init_std)
        self.apply(self._init_weights)
        self._rescale_blocks()

    def _rescale_blocks(self):
        """Divide output-projection weights in place by ``sqrt(2 * layer_id)``.

        The cross-attention stage counts as layer 1; the i-th entry of
        ``self.blocks`` (1-based) counts as layer ``i + 1``.
        """
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        # no_grad permits the in-place update of parameters that require
        # grad without going through the discouraged ``.data`` attribute.
        with torch.no_grad():
            if self.complete_block:
                rescale(self.cross_attention_block.xattn.proj.weight, 1)
                rescale(self.cross_attention_block.mlp.fc2.weight, 1)
            else:
                rescale(self.cross_attention_block.proj.weight, 1)

            if self.blocks is not None:
                for layer_id, layer in enumerate(self.blocks, 1):
                    rescale(layer.attn.proj.weight, layer_id + 1)
                    rescale(layer.mlp.fc2.weight, layer_id + 1)

    def _init_weights(self, m):
        """Initialise one submodule; intended to be used with ``self.apply``.

        Linear and Conv2d weights are drawn with ``trunc_normal_`` using
        ``self.init_std`` and their biases (when present) are zeroed.
        LayerNorm is reset to weight 1 and bias 0. Other modules are left
        untouched.
        """
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """Return the query tokens after attending to ``x``.

        Args:
            x: Input passed as the second argument to the cross-attention
                stage; ``len(x)`` determines how many copies of the query
                tokens are made. Presumably ``(batch, tokens, embed_dim)``
                -- TODO confirm against the attention modules.

        Returns:
            The output of the cross-attention stage, further transformed by
            each module in ``self.blocks`` when present.
        """
        q = self.query_tokens.repeat(len(x), 1, 1)
        q = self.cross_attention_block(q, x)
        if self.blocks is not None:
            for blk in self.blocks:
                q = blk(q)
        return q
|
|
|
|
class AttentiveClassifier(nn.Module):
    """Classifier head: single-query attentive pooling followed by a linear layer.

    Args:
        embed_dim: Forwarded to ``AttentivePooler`` and used as the input
            size of the final linear layer.
        num_heads: Forwarded to ``AttentivePooler``.
        mlp_ratio: Forwarded to ``AttentivePooler``.
        depth: Forwarded to ``AttentivePooler``.
        norm_layer: Forwarded to ``AttentivePooler``.
        init_std: Forwarded to ``AttentivePooler``.
        qkv_bias: Forwarded to ``AttentivePooler``.
        num_classes: Output size of the final linear layer.
        complete_block: Forwarded to ``AttentivePooler``; also selects which
            module is exposed as ``target_layer``.
    """

    def __init__(
        self,
        embed_dim=768,
        num_heads=12,
        mlp_ratio=4.0,
        depth=1,
        norm_layer=nn.LayerNorm,
        init_std=0.02,
        qkv_bias=True,
        num_classes=1000,
        complete_block=True,
    ):
        super().__init__()

        self.pooler = AttentivePooler(
            num_queries=1,
            embed_dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            depth=depth,
            norm_layer=norm_layer,
            init_std=init_std,
            qkv_bias=qkv_bias,
            complete_block=complete_block,
        )

        # Handle to the pooler's cross-attention module for external
        # consumers. With complete_block=True the attention lives at
        # `.xattn` inside the CrossAttentionBlock; with complete_block=False
        # the pooler holds the bare CrossAttention itself, so reaching for
        # `.xattn` unconditionally would raise AttributeError.
        # NOTE: assigning a module as an attribute registers it as a child,
        # so these weights also appear under the `target_layer.` prefix in
        # state_dict(); kept as-is for checkpoint compatibility.
        cross_attn = self.pooler.cross_attention_block
        self.target_layer = cross_attn.xattn if complete_block else cross_attn
        self.linear = nn.Linear(embed_dim, num_classes, bias=True)

    def forward(self, x):
        """Pool ``x``, drop the size-1 query dimension, and apply the linear layer.

        Args:
            x: Input forwarded unchanged to the pooler.

        Returns:
            Output of ``self.linear`` applied to the pooled representation
            (last dimension of size ``num_classes``).
        """
        pooled = self.pooler(x).squeeze(1)
        return self.linear(pooled)