File size: 7,091 Bytes
731dcab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import torch
import torch.nn
import torch.nn.functional as F
import math
from typing import Optional, Callable, List, Union, Tuple, Dict, Any
from dataclasses import dataclass
from forgetting_transformer.ops.layer_with_visualization import LayerWithVisualization
import forgetting_transformer.ops.framework_mock as framework


@dataclass
class AttentionMask:
    # Boolean mask over source positions, shape [batch, n_src]
    # (see the 2-D unpack in apply_logit_masks); True = position is masked out
    # (filled with -inf before the softmax).
    src_length_mask: Optional[torch.Tensor]
    # Boolean mask over (dest, src) position pairs, broadcastable against
    # [..., n_dest, n_src] attention logits; True = pair is masked out.
    position_mask: Optional[torch.Tensor]


class MultiHeadAttentionBase(LayerWithVisualization):
    """Shared plumbing for multi-head attention variants.

    Provides logit masking, masked softmax, the attention read
    (``scores @ values``), head-wise input projection, and attention-map
    visualization hooks. Subclasses supply the score computation.
    """

    def __init__(self, state_size: int, n_heads: int, dropout: float = 0.1,
                 projection_size: Optional[int] = None):
        """
        :param state_size: model width; must be divisible by ``n_heads``
        :param n_heads: number of attention heads
        :param dropout: dropout probability (the module is created here but
            is not applied in this base class — see the commented-out call
            in ``_attention_read``)
        :param projection_size: per-head width; defaults to
            ``state_size // n_heads``
        """
        assert state_size % n_heads == 0
        super().__init__()
        # Per-step attention maps collected for plotting when
        # visualization is enabled.
        self.attention_to_visualize = []

        self.state_size = state_size
        self.projection_size = projection_size or (state_size // n_heads)
        self.n_heads = n_heads
        # Standard 1/sqrt(d_head) scaling of the dot-product logits.
        self.scale = 1.0 / math.sqrt(self.projection_size)

        self.dropout = torch.nn.Dropout(dropout)

    @staticmethod
    def apply_logit_masks(logits: torch.Tensor, mask: Optional[AttentionMask],
                          val: float = float("-inf")) -> torch.Tensor:
        """Fill masked positions of ``logits`` with ``val``.

        Fix: the parameter is declared ``Optional`` but was dereferenced
        unconditionally; a ``None`` mask is now a no-op.
        """
        if mask is None:
            return logits

        if mask.position_mask is not None:
            # [..., N_out, N_in], broadcast works
            logits = logits.masked_fill(mask.position_mask, val)

        if mask.src_length_mask is not None:
            # [B, ...., N_in], needs manual shaping
            b, i = mask.src_length_mask.shape
            pad_dims = logits.ndim - 2
            logits = logits.masked_fill(mask.src_length_mask.view([b] + [1] * pad_dims + [i]), val)

        return logits

    def _masked_softmax(self, logits: torch.Tensor, mask: Optional[AttentionMask]) -> torch.Tensor:
        """Softmax over the last axis with ``mask`` applied to the logits.

        ``logits`` shape: [n_batch * n_heads, n_time_dest, n_time_src];
        the output has the same shape.
        """
        if mask is None or (mask.src_length_mask is None and mask.position_mask is None):
            return F.softmax(logits, -1)

        # Unfold the head dimension so the [B, n_src] length mask can be
        # broadcast per head, then fold it back.
        bb, n_time_dest, n_time_src = logits.shape

        logits = logits.view(bb // self.n_heads, self.n_heads, n_time_dest, n_time_src)
        logits = self.apply_logit_masks(logits, mask)

        logits = F.softmax(logits, -1)
        return logits.view(bb, n_time_dest, n_time_src)

    def _attention_read(self, mask: Optional[AttentionMask], scores: torch.Tensor, v: torch.Tensor) -> \
            Tuple[torch.Tensor, torch.Tensor]:
        """Weigh the values by the attention scores.

        :param mask: unused here; kept so subclasses share one signature
        :param scores: [n_batch * n_heads, n_out, n_in]
        :param v: [n_batch * n_heads, n_in, data_size]
        :return: (data of shape [n_batch * n_heads, n_time_dest, data_size],
                  scores reshaped to [n_batch, n_heads, n_time_dest, n_time_src])
        """
        s_reshape = scores.view(-1, self.n_heads, *scores.shape[1:])
        # scores = self.dropout(scores)
        if self.visualization_enabled:
            # Only the first batch element is recorded for plotting.
            self.attention_to_visualize.append(s_reshape[0])
        return torch.bmm(scores, v), s_reshape

    def transform_data(self, input: torch.Tensor, proj: Callable[[torch.Tensor], torch.Tensor],
                       n_projs: int) -> List[torch.Tensor]:
        """Project the input and split it into per-head, per-projection tensors.

        Input shape: [n_batch, n_steps, n_channels]. Returns ``n_projs``
        tensors of shape [n_batch * n_heads, n_steps, projection_size]
        (e.g. k and v from a fused kv projection).
        """
        n_batch, n_steps, _ = input.shape
        transformed = proj(input).view(n_batch, n_steps, self.n_heads, n_projs, -1). \
            permute(0, 2, 1, 3, 4).contiguous().view(n_batch * self.n_heads, n_steps, n_projs, -1)
        return transformed.unbind(dim=2)

    def plot(self, options: Dict[str, Any]) -> Dict[str, Any]:
        """Build animated heatmaps from the recorded attention maps.

        Emits one heatmap per head when ``mha.plot_head_details`` is set,
        plus a max-over-heads summary; clears the recording buffer.
        NOTE(review): raises IndexError if called with nothing recorded —
        presumably callers only invoke this after a visualized forward pass.
        """
        r = {}
        marks = options.get("steplabel")
        if options.get("mha.plot_head_details") and self.attention_to_visualize[0].shape[0] > 1:
            for head in range(self.attention_to_visualize[0].shape[0]):
                r[f"head_{head}"] = framework.visualize.plot.AnimatedHeatmap(
                        torch.stack([layer[head] for _, layer in enumerate(self.attention_to_visualize)], 0),
                        ylabel="dest", xlabel="src", textval=False, x_marks=marks, y_marks=marks, ignore_wrong_marks=True)

        r["attention_max"] = framework.visualize.plot.AnimatedHeatmap(
            torch.stack([layer.max(0)[0] for _, layer in enumerate(self.attention_to_visualize)], 0),
            ylabel="dest", xlabel="src", textval=False, x_marks=marks, y_marks=marks, ignore_wrong_marks=True)
        self.attention_to_visualize = []
        return r


class AttentionMergeMixin:
    """Mixin that folds per-head attention outputs back into one tensor via a
    single bias-free linear projection.

    The host class must provide ``n_heads``, ``projection_size``,
    ``state_size`` and an ``_attention`` method.
    """

    def __init__(self, out_size: Optional[int]) -> None:
        # Concatenated head width -> output width (state_size when
        # out_size is falsy).
        merged_width = self.n_heads * self.projection_size
        self.multi_head_merge = torch.nn.Linear(
            merged_width, out_size or self.state_size, bias=False)

    def merged_attention(self, n_batch: int, n_out_steps: int, *args, need_weights: bool = False, **kwargs) -> \
            Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run ``self._attention`` and merge the heads through the projection.

        Returns (merged data of shape [n_batch, n_out_steps, out_size],
        attention scores as produced by ``_attention``).
        """
        attn_out, attn_scores = self._attention(*args, **kwargs)

        # [B*H, T, d] -> [B, H, T, d] -> [B, T, H, d] -> [B, T, H*d]
        per_head = attn_out.view(n_batch, self.n_heads, n_out_steps, -1)
        merged = per_head.permute(0, 2, 1, 3).contiguous().view(n_batch, n_out_steps, -1)

        return self.multi_head_merge(merged), attn_scores


class AbsPosAttentionBase(MultiHeadAttentionBase):
    """Scaled dot-product attention where positions are assumed to be encoded
    in the inputs themselves (no relative-position terms in the scores)."""

    def get_attention_scores(self, mask: Optional[AttentionMask], q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        """Return ``softmax(q @ k^T * scale)`` with ``mask`` applied.

        q, k: [n_batch * n_heads, n_steps, projection_size].
        Fix: ``mask`` annotated as ``AttentionMask`` (it is forwarded to
        ``_masked_softmax``, which reads ``.src_length_mask``/``.position_mask``),
        not a bare tensor.
        """
        logits = torch.bmm(q, k.transpose(1, 2))
        return self._masked_softmax(logits * self.scale, mask)

    def _attention(self, mask: Optional[AttentionMask], q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> \
                   Tuple[torch.Tensor, torch.Tensor]:
        # all inputs should have a shape of [n_batch, n_steps, data_size]
        # Output data shape [n_batch * n_heads, n_time_dest, data_size]
        # Fix: return annotation corrected — _attention_read returns the
        # (data, reshaped scores) pair, not a single tensor.
        scores = self.get_attention_scores(mask, q, k)
        return self._attention_read(mask, scores, v)


class MultiHeadAttention(AttentionMergeMixin, AbsPosAttentionBase):
    """Standard multi-head attention with separate query and key/value inputs."""

    def __init__(self, state_size: int, n_heads: int, dropout: float = 0.1, input_size: Optional[int] = None,
                 out_size: Optional[int] = None):
        """
        :param state_size: width of the ``attend_to`` input (keys/values)
        :param n_heads: number of attention heads
        :param dropout: dropout probability passed to the attention base
        :param input_size: width of the query input; defaults to ``state_size``
        :param out_size: output width; defaults to ``state_size``
        """
        # Explicit MRO split: initialize the attention base first (skipping
        # AttentionMergeMixin), then the merge mixin once projection_size exists.
        super(AttentionMergeMixin, self).__init__(state_size, n_heads, dropout)

        # Fused key/value projection (split later by transform_data) and a
        # separate query projection.
        self.data_to_kv = torch.nn.Linear(state_size, 2 * n_heads * self.projection_size, bias=False)
        self.data_to_q = torch.nn.Linear(input_size or state_size, n_heads * self.projection_size, bias=False)

        super(MultiHeadAttention, self).__init__(out_size)
        self.reset_parameters()

    def forward(self, curr_state: torch.Tensor, attend_to: torch.Tensor, mask: Optional[AttentionMask],
                need_weights: bool = False):
        """Attend from ``curr_state`` to ``attend_to``.

        Input and output shape: [n_batch, n_steps, data_size]. When
        ``need_weights`` is set, also returns the attention scores.
        """
        k, v = self.transform_data(attend_to, self.data_to_kv, 2)
        q, = self.transform_data(curr_state, self.data_to_q, 1)

        data, scores = self.merged_attention(curr_state.shape[0], q.shape[1], mask, q, k, v)
        if need_weights:
            return data, scores
        else:
            return data

    def reset_parameters(self):
        """Xavier-initialize the projection weights.

        Fix: the original initialized ``data_to_kv.weight`` twice; the
        duplicate call is removed (it re-drew the weights and consumed an
        extra RNG draw for no effect on the distribution).
        """
        torch.nn.init.xavier_uniform_(self.data_to_q.weight)
        torch.nn.init.xavier_uniform_(self.data_to_kv.weight)