| import torch |
| import torch.nn |
| import torch.nn.functional as F |
| import math |
| from typing import Optional, Callable, List, Union, Tuple, Dict, Any |
| from dataclasses import dataclass |
| from forgetting_transformer.ops.layer_with_visualization import LayerWithVisualization |
| import forgetting_transformer.ops.framework_mock as framework |
|
|
|
|
@dataclass
class AttentionMask:
    """Boolean masks for attention logits.

    Wherever a mask is True, the corresponding logit is overwritten (with -inf
    by default) before the softmax. Either field may be None to skip that kind
    of masking.
    """

    # Per-sample source mask of shape (batch, src_len); it is broadcast over
    # every dimension between the batch and source axes of the logits.
    src_length_mask: Optional[torch.Tensor]
    # Mask applied directly to the logits tensor, so it must broadcast against it.
    position_mask: Optional[torch.Tensor]
|
|
|
|
class MultiHeadAttentionBase(LayerWithVisualization):
    """Shared helpers for multi-head attention layers.

    Provides logit masking + softmax, the scores-times-values read, splitting of
    projected inputs into per-head tensors, and plotting of attention maps that
    were collected while visualization was enabled. Attention tensors fold batch
    and heads into one leading dimension (``n_batch * n_heads``) so that
    ``torch.bmm`` can be used.
    """

    def __init__(self, state_size: int, n_heads: int, dropout: float = 0.1,
                 projection_size: Optional[int] = None):
        """
        Args:
            state_size: model width; must be divisible by ``n_heads``.
            n_heads: number of attention heads.
            dropout: probability passed to ``self.dropout`` (see the note below).
            projection_size: per-head width; falls back to ``state_size // n_heads``.
        """
        assert state_size % n_heads == 0, \
            f"state_size ({state_size}) must be divisible by n_heads ({n_heads})"
        super().__init__()
        # Attention maps appended by _attention_read while visualization is on;
        # consumed and cleared by plot().
        self.attention_to_visualize: List[torch.Tensor] = []

        self.state_size = state_size
        self.projection_size = projection_size or (state_size // n_heads)
        self.n_heads = n_heads
        # 1 / sqrt(d_head) scaling of the dot-product logits.
        self.scale = 1.0 / math.sqrt(self.projection_size)

        # NOTE(review): this Dropout is constructed but not applied anywhere in the
        # code visible here, so the `dropout` argument currently has no effect on
        # attention. Enabling it would change training behaviour — confirm intent first.
        self.dropout = torch.nn.Dropout(dropout)

    @staticmethod
    def apply_logit_masks(logits: torch.Tensor, mask: Optional[AttentionMask],
                          val: float = float("-inf")) -> torch.Tensor:
        """Overwrite masked logit positions with ``val``.

        Args:
            logits: attention logits; dim 0 is the batch and the last dim is the source axis.
            mask: masks to apply; True entries are filled. ``None`` means no masking.
            val: fill value (``-inf`` so the softmax assigns zero weight).

        Returns:
            A new tensor (``masked_fill`` is out-of-place), or ``logits`` itself if nothing is masked.
        """
        if mask is None:
            return logits

        if mask.position_mask is not None:
            # Must broadcast against the full logits tensor.
            logits = logits.masked_fill(mask.position_mask, val)

        if mask.src_length_mask is not None:
            # (b, i) -> (b, 1, ..., 1, i): broadcast over every middle dimension
            # (e.g. heads and destination steps).
            b, i = mask.src_length_mask.shape
            pad_dims = logits.ndim - 2
            logits = logits.masked_fill(mask.src_length_mask.view([b] + [1] * pad_dims + [i]), val)

        return logits

    def _masked_softmax(self, logits: torch.Tensor, mask: Optional[AttentionMask]) -> torch.Tensor:
        """Softmax over the source axis, after applying ``mask``.

        ``logits`` has shape (n_batch * n_heads, n_time_dest, n_time_src); the
        result has the same shape.
        """
        if mask is None or (mask.src_length_mask is None and mask.position_mask is None):
            return F.softmax(logits, -1)

        bb, n_time_dest, n_time_src = logits.shape

        # Unfold batch and heads so that per-batch masks line up with dim 0.
        logits = logits.view(bb // self.n_heads, self.n_heads, n_time_dest, n_time_src)
        logits = self.apply_logit_masks(logits, mask)

        logits = F.softmax(logits, -1)
        return logits.view(bb, n_time_dest, n_time_src)

    def _attention_read(self, mask: Optional[AttentionMask], scores: torch.Tensor, v: torch.Tensor) -> \
            Tuple[torch.Tensor, torch.Tensor]:
        """Weight ``v`` by the attention ``scores``.

        Args:
            mask: unused here; kept for interface compatibility.
            scores: (n_batch * n_heads, n_time_dest, n_time_src) attention weights.
            v: (n_batch * n_heads, n_time_src, d) values.

        Returns:
            ``(bmm(scores, v), scores viewed as (n_batch, n_heads, n_time_dest, n_time_src))``.
        """
        s_reshape = scores.view(-1, self.n_heads, *scores.shape[1:])

        # `visualization_enabled` presumably comes from LayerWithVisualization.
        # Only the first batch element's (n_heads, dest, src) map is kept.
        if self.visualization_enabled:
            self.attention_to_visualize.append(s_reshape[0])
        return torch.bmm(scores, v), s_reshape

    def transform_data(self, input: torch.Tensor, proj: Callable[[torch.Tensor], torch.Tensor],
                       n_projs: int) -> Tuple[torch.Tensor, ...]:
        """Project ``input`` and split it into ``n_projs`` per-head tensors.

        ``proj(input)`` is interpreted with its last dimension laid out as
        (n_heads, n_projs, d). Returns ``n_projs`` tensors, each of shape
        (n_batch * n_heads, n_steps, d).
        """
        n_batch, n_steps, _ = input.shape
        transformed = proj(input).view(n_batch, n_steps, self.n_heads, n_projs, -1). \
            permute(0, 2, 1, 3, 4).contiguous().view(n_batch * self.n_heads, n_steps, n_projs, -1)
        return transformed.unbind(dim=2)

    def plot(self, options: Dict[str, Any]) -> Dict[str, Any]:
        """Build heatmap animations from the recorded attention maps, then clear them.

        Options read: ``"steplabel"`` (axis marks) and ``"mha.plot_head_details"``
        (also emit one heatmap per head). Returns an empty dict when nothing was
        recorded; previously that case raised IndexError or RuntimeError.
        """
        r = {}
        maps = self.attention_to_visualize
        if not maps:
            return r

        marks = options.get("steplabel")
        n_heads_recorded = maps[0].shape[0]
        if options.get("mha.plot_head_details") and n_heads_recorded > 1:
            for head in range(n_heads_recorded):
                r[f"head_{head}"] = framework.visualize.plot.AnimatedHeatmap(
                    torch.stack([layer[head] for layer in maps], 0),
                    ylabel="dest", xlabel="src", textval=False, x_marks=marks, y_marks=marks, ignore_wrong_marks=True)

        # Max over heads for each recorded step.
        r["attention_max"] = framework.visualize.plot.AnimatedHeatmap(
            torch.stack([layer.max(0)[0] for layer in maps], 0),
            ylabel="dest", xlabel="src", textval=False, x_marks=marks, y_marks=marks, ignore_wrong_marks=True)
        self.attention_to_visualize = []
        return r
|
|
|
|
class AttentionMergeMixin:
    """Mixin that merges per-head attention outputs through one linear layer.

    The host class must already define ``n_heads``, ``projection_size`` and
    ``state_size`` before this ``__init__`` runs. It must also provide an
    ``_attention(*args, **kwargs)`` method that returns ``(data, scores)``.
    """

    def __init__(self, out_size: Optional[int]) -> None:
        concat_width = self.n_heads * self.projection_size
        target_width = out_size or self.state_size
        self.multi_head_merge = torch.nn.Linear(concat_width, target_width, bias=False)

    def merged_attention(self, n_batch: int, n_out_steps: int, *args, need_weights: bool = False, **kwargs) -> \
            Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run ``self._attention`` and project the concatenated heads.

        Positional/keyword arguments other than ``n_batch``, ``n_out_steps`` and
        ``need_weights`` are forwarded to ``_attention``. ``need_weights`` is
        accepted but ignored: the ``(output, scores)`` pair is always returned.
        """
        per_head, scores = self._attention(*args, **kwargs)

        # (n_batch * n_heads, n_out_steps, d) -> (n_batch, n_out_steps, n_heads * d)
        per_head = per_head.view(n_batch, self.n_heads, n_out_steps, -1)
        concatenated = per_head.transpose(1, 2).reshape(n_batch, n_out_steps, -1)

        merged = self.multi_head_merge(concatenated)
        return merged, scores
|
|
|
|
class AbsPosAttentionBase(MultiHeadAttentionBase):
    """Plain scaled dot-product attention over head-split q/k/v tensors.

    Inputs fold batch and heads into dim 0 (``n_batch * n_heads``). Logits are
    just ``q @ k^T * scale``; no positional term is added here.
    """

    def get_attention_scores(self, mask: Optional[AttentionMask], q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        """Return the masked softmax of the scaled logits ``q @ k^T``.

        Args:
            mask: optional ``AttentionMask`` passed to ``_masked_softmax``.
            q: (n_batch * n_heads, n_dest, d) queries.
            k: (n_batch * n_heads, n_src, d) keys.

        Returns:
            (n_batch * n_heads, n_dest, n_src) attention weights.
        """
        logits = torch.bmm(q, k.transpose(1, 2))
        return self._masked_softmax(logits * self.scale, mask)

    def _attention(self, mask: Optional[AttentionMask], q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> \
            Tuple[torch.Tensor, torch.Tensor]:
        """Compute attention and read ``v``.

        Returns ``(data, scores)``: ``data`` is (n_batch * n_heads, n_dest, d_v)
        and ``scores`` is viewed as (n_batch, n_heads, n_dest, n_src).
        """
        scores = self.get_attention_scores(mask, q, k)
        return self._attention_read(mask, scores, v)
|
|
|
|
class MultiHeadAttention(AttentionMergeMixin, AbsPosAttentionBase):
    """Multi-head attention with bias-free q and kv projections and an output merge.

    Queries come from ``curr_state``; keys and values come from ``attend_to``.
    """

    def __init__(self, state_size: int, n_heads: int, dropout: float = 0.1, input_size: Optional[int] = None,
                 out_size: Optional[int] = None):
        # Start the MRO lookup *after* AttentionMergeMixin, so MultiHeadAttentionBase.__init__
        # runs first and defines n_heads / projection_size, which the merge layer's shape needs.
        super(AttentionMergeMixin, self).__init__(state_size, n_heads, dropout)

        all_heads_width = n_heads * self.projection_size
        # Construction order (kv, then q) is kept so that seeded initialization is unchanged.
        self.data_to_kv = torch.nn.Linear(state_size, 2 * all_heads_width, bias=False)
        self.data_to_q = torch.nn.Linear(input_size or state_size, all_heads_width, bias=False)

        # Now AttentionMergeMixin.__init__ builds multi_head_merge.
        super(MultiHeadAttention, self).__init__(out_size)
        self.reset_parameters()

    def forward(self, curr_state: torch.Tensor, attend_to: torch.Tensor, mask: Optional[AttentionMask],
                need_weights: bool = False):
        """Attend from ``curr_state`` to ``attend_to``.

        Returns the merged output, or ``(output, scores)`` when ``need_weights``
        is set. ``scores`` is viewed as (n_batch, n_heads, n_dest, n_src).
        """
        keys, values = self.transform_data(attend_to, self.data_to_kv, 2)
        (queries,) = self.transform_data(curr_state, self.data_to_q, 1)

        n_batch = curr_state.shape[0]
        n_out_steps = queries.shape[1]
        output, weights = self.merged_attention(n_batch, n_out_steps, mask, queries, keys, values)
        return (output, weights) if need_weights else output

    def reset_parameters(self):
        """Xavier-uniform initialize the q and kv projection weights."""
        torch.nn.init.xavier_uniform_(self.data_to_q.weight)
        # NOTE(review): data_to_kv is initialized twice; the second call overwrites the
        # first. It is kept because removing it would shift the RNG stream and change
        # seeded initialization. multi_head_merge is never re-initialized here (possibly
        # the intended target — confirm).
        torch.nn.init.xavier_uniform_(self.data_to_kv.weight)
        torch.nn.init.xavier_uniform_(self.data_to_kv.weight)
|
|