# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from mmengine.model import BaseModule

from mmaction.registry import MODELS


class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Defines the computation performed at every call."""
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
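

# Usage sketch (illustrative, not part of the upstream file): this LayerNorm
# accepts fp16 inputs, computes the normalization in fp32 for numerical
# stability, then casts the result back to the input dtype.
#
#   ln = LayerNorm(512)
#   x = torch.randn(2, 8, 512, dtype=torch.float16)
#   y = ln(x)
#   assert y.dtype == torch.float16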


class QuickGELU(nn.Module):
    """Quick GELU activation: a sigmoid-based approximation of GELU."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Perform quick gelu."""
return x * torch.sigmoid(1.702 * x)
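

# Usage sketch (illustrative): QuickGELU is the sigmoid-based GELU
# approximation used by CLIP; it tracks torch.nn.functional.gelu to within
# roughly 1e-2 over typical activation ranges.
#
#   x = torch.linspace(-3, 3, 7)
#   approx = QuickGELU()(x)
#   exact = torch.nn.functional.gelu(x)  # close, but not identical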


class ResidualAttentionBlock(nn.Module):
    """ResidualAttentionBlock.

    Args:
        d_model (int): The dimension of the model.
        n_head (int): The number of heads.
        attn_mask (torch.Tensor, optional): The attention mask.
            Defaults to None.
    """

    def __init__(self,
d_model: int,
n_head: int,
attn_mask: Optional[torch.Tensor] = None) -> None:
super().__init__()
self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(
OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
('gelu', QuickGELU()),
('c_proj', nn.Linear(d_model * 4, d_model))]))
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor) -> torch.Tensor:
"""Perform attention."""
self.attn_mask = self.attn_mask.to(
dtype=x.dtype,
device=x.device) if self.attn_mask is not None else None
return self.attn(
x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Defines the computation performed at every call."""
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
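

# Usage sketch (illustrative; shapes follow nn.MultiheadAttention's default
# (seq_len, batch, embed_dim) layout): the block is shape-preserving, applying
# pre-norm self-attention and a pre-norm MLP, each with a residual connection.
#
#   block = ResidualAttentionBlock(d_model=512, n_head=8)
#   x = torch.randn(8, 2, 512)  # (seq_len, batch, d_model)
#   out = block(x)              # (8, 2, 512)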


class Transformer(nn.Module):
    """Transformer.

    Args:
        width (int): The width of transformer.
        layers (int): The number of layers of transformer.
        heads (int): The number of heads of transformer.
        attn_mask (torch.Tensor, optional): The attention mask.
            Defaults to None.
    """

    def __init__(self,
width: int,
layers: int,
heads: int,
                 attn_mask: Optional[torch.Tensor] = None) -> None:
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.Sequential(*[
ResidualAttentionBlock(width, heads, attn_mask)
for _ in range(layers)
])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Defines the computation performed at every call."""
return self.resblocks(x)
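

# Usage sketch (illustrative): the Transformer just stacks `layers`
# shape-preserving blocks, so output shape equals input shape.
#
#   transformer = Transformer(width=512, layers=6, heads=8)
#   x = torch.randn(8, 2, 512)  # (seq_len, batch, width)
#   out = transformer(x)        # (8, 2, 512)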


@MODELS.register_module()
class TransformerAdapter(BaseModule):
""""Transformer adapter, modified from github.com/openai/CLIP.
Args:
num_segs (int): The number of segments.
transformer_width (int): The width of transformer.
transformer_heads (int): The number of heads of transformer.
transformer_layers (int): The number of layers of transformer.
"""

    def __init__(self, num_segs: int, transformer_width: int,
transformer_heads: int, transformer_layers: int) -> None:
        super().__init__()
self.num_segs = num_segs
self.positional_embedding = nn.Parameter(
torch.empty(num_segs, transformer_width))
self.transformer = Transformer(
width=transformer_width,
layers=transformer_layers,
heads=transformer_heads)

    def init_weights(self) -> None:
"""Initialize the weights."""
nn.init.normal_(self.positional_embedding, std=0.01)
proj_std = (self.transformer.width**-0.5) * (
(2 * self.transformer.layers)**-0.5)
attn_std = self.transformer.width**-0.5
fc_std = (2 * self.transformer.width)**-0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Defines the computation performed at every call."""
        b, seq_length, c = x.size()  # (batch, num_segs, width)
x_original = x
x = x + self.positional_embedding
x = x.transpose(0, 1) # NLD -> LND
x = self.transformer(x)
x = x.transpose(0, 1) # LND -> NLD
x = x.type(x_original.dtype) + x_original
return x.mean(dim=1)
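

# Usage sketch (illustrative): the adapter takes per-segment features of shape
# (batch, num_segs, width), adds a learned positional embedding, runs the
# residual transformer, and averages over segments into a (batch, width) video
# feature. In standalone use, call init_weights() first: positional_embedding
# is created with torch.empty and is otherwise uninitialized.
#
#   adapter = TransformerAdapter(num_segs=8, transformer_width=512,
#                                transformer_heads=8, transformer_layers=6)
#   adapter.init_weights()
#   x = torch.randn(2, 8, 512)  # (batch, num_segs, width)
#   out = adapter(x)            # (2, 512)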


@MODELS.register_module()
class SimpleMeanAdapter(BaseModule):
"""Average features adapter.
Args:
dim (int): The dimension to perform averaging. Defaults to 1.
"""

    def __init__(self, dim: Union[int, Tuple[int, ...]] = 1) -> None:
super().__init__()
self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Defines the computation performed at every call."""
return x.mean(dim=self.dim)
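

# Usage sketch (illustrative): SimpleMeanAdapter is a parameter-free
# alternative to TransformerAdapter that simply averages over the given
# axis (or axes).
#
#   adapter = SimpleMeanAdapter(dim=1)
#   x = torch.randn(2, 8, 512)  # (batch, num_segs, width)
#   out = adapter(x)            # (2, 512)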