| | import torch |
| | import torch.nn |
| | import torch.nn.functional as F |
| | import math |
| | from typing import Optional, Callable, List, Union, Tuple, Dict, Any |
| | from dataclasses import dataclass |
| | from forgetting_transformer.ops.layer_with_visualization import LayerWithVisualization |
| | import forgetting_transformer.ops.framework_mock as framework |
| |
|
| |
|
@dataclass
class AttentionMask:
    """Container for the two kinds of attention masks.

    Entries that are ``True`` are masked out (filled with ``-inf``) by
    ``MultiHeadAttentionBase.apply_logit_masks``.
    """
    # (batch, src_len) boolean mask of padded source positions; may be None.
    src_length_mask: Optional[torch.Tensor]
    # Positional (e.g. causal) mask broadcast against the logits; may be None.
    # NOTE(review): exact shape is not fixed here — it must broadcast with
    # (batch, heads, dest, src) logits; confirm against callers.
    position_mask: Optional[torch.Tensor]
| |
|
| |
|
class MultiHeadAttentionBase(LayerWithVisualization):
    """Shared machinery for multi-head attention layers.

    Provides mask application, masked softmax over attention logits,
    per-head projection reshaping, and optional capture of attention maps
    for visualization.
    """

    def __init__(self, state_size: int, n_heads: int, dropout: float = 0.1,
                 projection_size: Optional[int] = None):
        """
        :param state_size: model state width; must be divisible by ``n_heads``
        :param n_heads: number of attention heads
        :param dropout: dropout probability (the module is created here; it is
            applied by subclasses, not by this base class)
        :param projection_size: per-head width; defaults to ``state_size // n_heads``
        """
        assert state_size % n_heads == 0
        super().__init__()
        # Attention maps collected while visualization is enabled; consumed
        # and cleared by plot().
        self.attention_to_visualize: List[torch.Tensor] = []

        self.state_size = state_size
        self.projection_size = projection_size or (state_size // n_heads)
        self.n_heads = n_heads
        # 1/sqrt(d_head) scaling for the dot-product logits.
        self.scale = 1.0 / math.sqrt(self.projection_size)

        self.dropout = torch.nn.Dropout(dropout)

    @staticmethod
    def apply_logit_masks(logits: torch.Tensor, mask: Optional[AttentionMask],
                          val: float = float("-inf")) -> torch.Tensor:
        """Fill masked (``True``) positions of ``logits`` with ``val``.

        Fix: ``mask`` is annotated as Optional but was dereferenced
        unconditionally; a ``None`` mask is now a no-op.
        """
        if mask is None:
            return logits

        if mask.position_mask is not None:
            logits = logits.masked_fill(mask.position_mask, val)

        if mask.src_length_mask is not None:
            # Broadcast the (batch, src_len) mask across all middle
            # dimensions of logits.
            b, i = mask.src_length_mask.shape
            pad_dims = logits.ndim - 2
            logits = logits.masked_fill(
                mask.src_length_mask.view([b] + [1] * pad_dims + [i]), val)

        return logits

    def _masked_softmax(self, logits: torch.Tensor, mask: Optional[AttentionMask]) -> torch.Tensor:
        """Softmax over the last dim, applying ``mask`` first if present.

        ``logits`` is (batch * n_heads, dest, src); the head dimension is
        temporarily exposed so the masks can broadcast per head.
        """
        if mask is None or (mask.src_length_mask is None and mask.position_mask is None):
            return F.softmax(logits, -1)

        bb, n_time_dest, n_time_src = logits.shape

        logits = logits.view(bb // self.n_heads, self.n_heads, n_time_dest, n_time_src)
        logits = self.apply_logit_masks(logits, mask)

        logits = F.softmax(logits, -1)
        return logits.view(bb, n_time_dest, n_time_src)

    def _attention_read(self, mask: Optional[AttentionMask], scores: torch.Tensor, v: torch.Tensor) -> \
            Tuple[torch.Tensor, torch.Tensor]:
        """Read values with the given (already normalized) attention scores.

        ``mask`` is unused here; it is kept in the signature for interface
        compatibility with overriding implementations.

        :return: (read values of shape like ``v``,
                  scores reshaped to (batch, heads, dest, src))
        """
        s_reshape = scores.view(-1, self.n_heads, *scores.shape[1:])

        if self.visualization_enabled:
            # Keep only the first batch element's map for plotting.
            self.attention_to_visualize.append(s_reshape[0])
        return torch.bmm(scores, v), s_reshape

    def transform_data(self, input: torch.Tensor, proj: Callable[[torch.Tensor], torch.Tensor],
                       n_projs: int) -> List[torch.Tensor]:
        """Project ``input`` and split it into ``n_projs`` per-head tensors.

        :param input: (batch, steps, features)
        :param proj: linear projection producing ``n_heads * n_projs * d`` features
        :param n_projs: number of distinct projections packed into ``proj``
            (e.g. 2 for a fused key/value projection)
        :return: ``n_projs`` tensors of shape (batch * n_heads, steps, d)
        """
        n_batch, n_steps, _ = input.shape
        transformed = proj(input).view(n_batch, n_steps, self.n_heads, n_projs, -1). \
            permute(0, 2, 1, 3, 4).contiguous().view(n_batch * self.n_heads, n_steps, n_projs, -1)
        return transformed.unbind(dim=2)

    def plot(self, options: Dict[str, Any]) -> Dict[str, Any]:
        """Render collected attention maps as animated heatmaps, then clear them.

        NOTE(review): assumes at least one map was collected (layer ran with
        visualization enabled) — confirm callers guarantee this before plot().
        """
        r = {}
        marks = options.get("steplabel")
        if options.get("mha.plot_head_details") and self.attention_to_visualize[0].shape[0] > 1:
            # One animation per head.
            for head in range(self.attention_to_visualize[0].shape[0]):
                r[f"head_{head}"] = framework.visualize.plot.AnimatedHeatmap(
                    torch.stack([layer[head] for layer in self.attention_to_visualize], 0),
                    ylabel="dest", xlabel="src", textval=False, x_marks=marks, y_marks=marks,
                    ignore_wrong_marks=True)

        # Max over heads gives a single summary map per recorded step.
        r["attention_max"] = framework.visualize.plot.AnimatedHeatmap(
            torch.stack([layer.max(0)[0] for layer in self.attention_to_visualize], 0),
            ylabel="dest", xlabel="src", textval=False, x_marks=marks, y_marks=marks,
            ignore_wrong_marks=True)
        self.attention_to_visualize = []
        return r
| |
|
| |
|
class AttentionMergeMixin:
    """Mixin that folds per-head attention outputs back into one feature vector.

    The host class must provide ``n_heads``, ``projection_size``,
    ``state_size`` and an ``_attention(...)`` implementation returning
    ``(data, scores)``.
    """

    def __init__(self, out_size: Optional[int]) -> None:
        # A single bias-free projection over the concatenation of all heads.
        merged_width = self.n_heads * self.projection_size
        self.multi_head_merge = torch.nn.Linear(
            merged_width, out_size or self.state_size, bias=False)

    def merged_attention(self, n_batch: int, n_out_steps: int, *args, need_weights: bool = False,
                         **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run ``_attention`` and merge the head dimension into the features.

        :param n_batch: batch size (heads are folded into the batch dim of
            ``_attention``'s output)
        :param n_out_steps: number of destination time steps
        :return: (merged output of shape (n_batch, n_out_steps, out_size), scores)
        """
        data, scores = self._attention(*args, **kwargs)

        # (batch * heads, steps, proj) -> (batch, steps, heads * proj)
        per_head = data.view(n_batch, self.n_heads, n_out_steps, -1)
        merged = per_head.permute(0, 2, 1, 3).contiguous()
        merged = merged.view(n_batch, n_out_steps, -1)

        return self.multi_head_merge(merged), scores
| |
|
| |
|
class AbsPosAttentionBase(MultiHeadAttentionBase):
    """Plain scaled dot-product attention (no relative position terms)."""

    def get_attention_scores(self, mask: Optional[AttentionMask],
                             q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        """Return ``softmax(q @ k^T * scale)`` with ``mask`` applied.

        Fix: ``mask`` was annotated ``Optional[torch.Tensor]``, but it is
        forwarded to ``_masked_softmax``, which takes an ``AttentionMask``;
        the annotation now matches the base class.
        """
        logits = torch.bmm(q, k.transpose(1, 2))
        return self._masked_softmax(logits * self.scale, mask)

    def _attention(self, mask: Optional[AttentionMask], q: torch.Tensor, k: torch.Tensor,
                   v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the attention read.

        Fix: return annotation corrected — ``_attention_read`` returns a
        ``(values, reshaped scores)`` tuple, not a single tensor.
        """
        scores = self.get_attention_scores(mask, q, k)
        return self._attention_read(mask, scores, v)
| |
|
| |
|
class MultiHeadAttention(AttentionMergeMixin, AbsPosAttentionBase):
    """Standard multi-head attention with separate query and key/value inputs."""

    def __init__(self, state_size: int, n_heads: int, dropout: float = 0.1,
                 input_size: Optional[int] = None, out_size: Optional[int] = None):
        """
        :param state_size: width of the key/value input (and default query input)
        :param n_heads: number of attention heads
        :param dropout: dropout probability passed to the attention base
        :param input_size: query input width; defaults to ``state_size``
        :param out_size: output width of the merge projection; defaults to
            ``state_size``
        """
        # Initialize the attention base first, deliberately skipping
        # AttentionMergeMixin in the MRO...
        super(AttentionMergeMixin, self).__init__(state_size, n_heads, dropout)

        # Fused key/value projection (2x) and separate query projection.
        self.data_to_kv = torch.nn.Linear(state_size, 2 * n_heads * self.projection_size, bias=False)
        self.data_to_q = torch.nn.Linear(input_size or state_size, n_heads * self.projection_size,
                                         bias=False)

        # ...then run the mixin's __init__ (next after this class in the MRO),
        # which needs self.projection_size to exist already.
        super(MultiHeadAttention, self).__init__(out_size)
        self.reset_parameters()

    def forward(self, curr_state: torch.Tensor, attend_to: torch.Tensor, mask: Optional[AttentionMask],
                need_weights: bool = False):
        """Attend from ``curr_state`` to ``attend_to``.

        :param curr_state: (batch, out_steps, input_size) query input
        :param attend_to: (batch, src_steps, state_size) key/value input
        :param mask: optional attention mask
        :param need_weights: when True, also return the attention scores
        """
        k, v = self.transform_data(attend_to, self.data_to_kv, 2)
        q, = self.transform_data(curr_state, self.data_to_q, 1)

        data, scores = self.merged_attention(curr_state.shape[0], q.shape[1], mask, q, k, v)
        if need_weights:
            return data, scores
        else:
            return data

    def reset_parameters(self):
        """Xavier-initialize the projection weights.

        Fix: the original initialized ``data_to_kv.weight`` twice; the
        redundant second call is removed.
        """
        torch.nn.init.xavier_uniform_(self.data_to_q.weight)
        torch.nn.init.xavier_uniform_(self.data_to_kv.weight)
| |
|