File size: 5,608 Bytes
d670799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from mmengine.model import BaseModule

from mmaction.registry import MODELS


class LayerNorm(nn.LayerNorm):
    """LayerNorm that always computes in float32.

    Half-precision activations are upcast before normalization and cast
    back afterwards, keeping the statistics numerically stable under fp16.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` in float32, returning the input dtype."""
        in_dtype = x.dtype
        out = super().forward(x.to(torch.float32))
        return out.to(in_dtype)


class QuickGELU(nn.Module):
    """Quick-GELU activation: a sigmoid-gated identity, ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the quick-GELU nonlinearity element-wise."""
        gate = torch.sigmoid(x * 1.702)
        return gate * x


class ResidualAttentionBlock(nn.Module):
    """"ResidualAttentionBlock.



    Args:

        d_model (int): The dimension of the model.

        n_head (int): The number of heads.

        attn_mask (torch.Tensor, optional): The attention mask.

            Defaults to None.

    """

    def __init__(self,

                 d_model: int,

                 n_head: int,

                 attn_mask: Optional[torch.Tensor] = None) -> None:
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
                         ('gelu', QuickGELU()),
                         ('c_proj', nn.Linear(d_model * 4, d_model))]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor) -> torch.Tensor:
        """Perform attention."""
        self.attn_mask = self.attn_mask.to(
            dtype=x.dtype,
            device=x.device) if self.attn_mask is not None else None
        return self.attn(
            x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Defines the computation performed at every call."""
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    """A stack of ``layers`` residual attention blocks.

    Args:
        width (int): The width of transformer.
        heads (int): The number of heads of transformer.
        layers (int): The number of layers of transformer.
        attn_mask (torch.Tensor, optional): The attention mask.
            Defaults to None.
    """

    def __init__(self,
                 width: int,
                 layers: int,
                 heads: int,
                 attn_mask: Optional[torch.Tensor] = None):
        super().__init__()
        self.width = width
        self.layers = layers
        # Every block shares the same (optional) attention mask.
        blocks = [
            ResidualAttentionBlock(width, heads, attn_mask)
            for _ in range(layers)
        ]
        self.resblocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pass ``x`` through every residual block in order."""
        return self.resblocks(x)


@MODELS.register_module()
class TransformerAdapter(BaseModule):
    """Transformer adapter, modified from github.com/openai/CLIP.

    Adds a learned positional embedding over segments, runs a small
    transformer, and averages the residual-added output over segments.

    Args:
        num_segs (int): The number of segments.
        transformer_width (int): The width of transformer.
        transformer_heads (int): The number of heads of transformer.
        transformer_layers (int): The number of layers of transformer.
    """

    def __init__(self, num_segs: int, transformer_width: int,
                 transformer_heads: int, transformer_layers: int) -> None:
        super().__init__()
        self.num_segs = num_segs

        # Left uninitialized here; filled by ``init_weights``.
        self.positional_embedding = nn.Parameter(
            torch.empty(num_segs, transformer_width))
        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads)

    def init_weights(self) -> None:
        """Initialize the positional embedding and transformer weights."""
        nn.init.normal_(self.positional_embedding, std=0.01)

        width = self.transformer.width
        # CLIP-style scaled init: projections shrink with depth.
        proj_std = width**-0.5 * (2 * self.transformer.layers)**-0.5
        attn_std = width**-0.5
        fc_std = (2 * width)**-0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and return the mean over the segment dimension.

        ``x`` is expected as (batch, segments, channels); the result is
        (batch, channels).
        """
        residual = x
        x = x + self.positional_embedding
        x = x.transpose(0, 1)  # NLD -> LND for nn.MultiheadAttention
        x = self.transformer(x)
        x = x.transpose(0, 1)  # LND -> NLD
        x = x.type(residual.dtype) + residual
        return x.mean(dim=1)


@MODELS.register_module()
class SimpleMeanAdapter(BaseModule):
    """Adapter that averages features over one or more dimensions.

    Args:
        dim (int): The dimension to perform averaging. Defaults to 1.
    """

    def __init__(self, dim: Union[int, Tuple[int]] = 1) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Collapse ``x`` by taking its mean along ``self.dim``."""
        return torch.mean(x, dim=self.dim)