|
|
import torch |
|
|
import torch.nn |
|
|
import torch.nn.functional as F |
|
|
import math |
|
|
from typing import Optional, Callable, List, Union, Tuple, Dict, Any |
|
|
from dataclasses import dataclass |
|
|
from forgetting_transformer.ops.layer_with_visualization import LayerWithVisualization |
|
|
import forgetting_transformer.ops.framework_mock as framework |
|
|
|
|
|
|
|
|
@dataclass
class AttentionMask:
    """Container for the optional masks consumed by the attention layers.

    Either field may be ``None``, in which case the corresponding masking
    step is skipped by ``MultiHeadAttentionBase.apply_logit_masks``. Where a
    mask is set, the matching attention logits are overwritten with the fill
    value (``-inf`` by default) before the softmax.
    """

    # 2-D mask that is unpacked as (batch, src_len) and broadcast over all
    # remaining logit dimensions before being applied.
    src_length_mask: Optional[torch.Tensor]
    # Mask handed to ``masked_fill`` unchanged, so it must broadcast against
    # the logits (presumably (dest_len, src_len) -- TODO confirm with callers).
    position_mask: Optional[torch.Tensor]
|
|
|
|
|
|
|
|
class MultiHeadAttentionBase(LayerWithVisualization):
    """Shared building blocks for multi-head attention layers.

    Provides head splitting (``transform_data``), masked softmax over the
    attention logits, the weighted read of the values, and plotting of the
    attention maps recorded while visualization is enabled.
    """

    def __init__(self, state_size: int, n_heads: int, dropout: float = 0.1,
                 projection_size: Optional[int] = None):
        """Store the layer geometry.

        Args:
            state_size: Model width; must be divisible by ``n_heads``.
            n_heads: Number of attention heads.
            dropout: Probability used to build ``self.dropout``.
            projection_size: Per-head width; defaults to
                ``state_size // n_heads`` when ``None`` (or 0).
        """
        assert state_size % n_heads == 0
        super().__init__()
        # Attention maps collected by ``_attention_read`` and consumed by ``plot``.
        self.attention_to_visualize = []

        self.state_size = state_size
        self.projection_size = projection_size or (state_size // n_heads)
        self.n_heads = n_heads
        # Factor applied to dot-product logits: 1 / sqrt(per-head width).
        self.scale = 1.0 / math.sqrt(self.projection_size)

        # NOTE: no method of this class applies ``self.dropout``; it is only created here.
        self.dropout = torch.nn.Dropout(dropout)

    @staticmethod
    def apply_logit_masks(logits: torch.Tensor, mask: Optional[AttentionMask],
                          val: float = float("-inf")) -> torch.Tensor:
        """Overwrite masked logits with ``val``.

        Args:
            logits: Attention logits whose last dimension indexes the source.
            mask: Masks to apply; ``None`` (or a mask whose fields are both
                ``None``) leaves ``logits`` untouched.
            val: Fill value written wherever a mask is set.

        Returns:
            The masked logits (the input tensor itself if nothing was masked).
        """
        if mask is None:
            # The parameter is Optional: treat a missing mask as "mask nothing".
            return logits

        if mask.position_mask is not None:
            # Applied as-is, so it must broadcast against ``logits``.
            logits = logits.masked_fill(mask.position_mask, val)

        if mask.src_length_mask is not None:
            # Reshape [b, i] -> [b, 1, ..., 1, i] so it broadcasts over every
            # dimension between the batch and the source axis.
            b, i = mask.src_length_mask.shape
            pad_dims = logits.ndim - 2
            logits = logits.masked_fill(mask.src_length_mask.view([b] + [1] * pad_dims + [i]), val)

        return logits

    def _masked_softmax(self, logits: torch.Tensor, mask: Optional[AttentionMask]) -> torch.Tensor:
        """Softmax over the last (source) axis, honoring ``mask``.

        Args:
            logits: 3-D tensor unpacked as
                ``[n_batch * n_heads, n_time_dest, n_time_src]``.
            mask: Optional masks; skipped entirely when nothing is set.

        Returns:
            Attention weights with the same shape as ``logits``.
        """
        if mask is None or (mask.src_length_mask is None and mask.position_mask is None):
            return F.softmax(logits, -1)

        bb, n_time_dest, n_time_src = logits.shape

        # Expose the head axis so the batch-wise length mask lines up with dim 0.
        logits = logits.view(bb // self.n_heads, self.n_heads, n_time_dest, n_time_src)
        logits = self.apply_logit_masks(logits, mask)

        logits = F.softmax(logits, -1)
        return logits.view(bb, n_time_dest, n_time_src)

    def _attention_read(self, mask: Optional[AttentionMask], scores: torch.Tensor, v: torch.Tensor) -> \
            Tuple[torch.Tensor, torch.Tensor]:
        """Read the values with the given attention weights.

        Args:
            mask: Not used by this method.
            scores: Attention weights, ``[n_batch * n_heads, n_out, n_in]``.
            v: Values, batched the same way as ``scores`` (consumed by ``bmm``).

        Returns:
            ``(torch.bmm(scores, v), scores viewed as [-1, n_heads, n_out, n_in])``.
        """
        s_reshape = scores.view(-1, self.n_heads, *scores.shape[1:])

        if self.visualization_enabled:
            # Only the first batch element is recorded, with all of its heads.
            self.attention_to_visualize.append(s_reshape[0])
        return torch.bmm(scores, v), s_reshape

    def transform_data(self, input: torch.Tensor, proj: Callable[[torch.Tensor], torch.Tensor],
                       n_projs: int) -> Tuple[torch.Tensor, ...]:
        """Project ``input`` and split the result into heads and projections.

        Args:
            input: 3-D tensor unpacked as ``[n_batch, n_steps, channels]``.
            proj: Callable producing, per step, ``n_heads * n_projs`` chunks
                laid out head-major.
            n_projs: Number of projections packed into the output of ``proj``
                (e.g. 2 for a joint key/value projection).

        Returns:
            ``n_projs`` tensors, each ``[n_batch * n_heads, n_steps, -1]``.
        """
        n_batch, n_steps, _ = input.shape
        per_head = proj(input).view(n_batch, n_steps, self.n_heads, n_projs, -1)
        # Move heads in front of the steps, then fold them into the batch axis.
        folded = per_head.permute(0, 2, 1, 3, 4).contiguous().view(n_batch * self.n_heads, n_steps, n_projs, -1)
        return folded.unbind(dim=2)

    def plot(self, options: Dict[str, Any]) -> Dict[str, Any]:
        """Build heatmaps from the recorded attention maps and clear them.

        Args:
            options: Plot options. ``"steplabel"`` is passed as the axis marks;
                a truthy ``"mha.plot_head_details"`` adds one ``head_<i>`` plot
                per head when there is more than one head.

        Returns:
            Mapping from plot name to heatmap. Always contains
            ``"attention_max"`` (maximum over heads) when something was
            recorded; empty when nothing was recorded.
        """
        recorded = self.attention_to_visualize
        if not recorded:
            # Nothing was captured since the last call: there is nothing to stack.
            return {}

        r = {}
        marks = options.get("steplabel")
        n_heads = recorded[0].shape[0]
        if options.get("mha.plot_head_details") and n_heads > 1:
            for head in range(n_heads):
                r[f"head_{head}"] = framework.visualize.plot.AnimatedHeatmap(
                    torch.stack([layer[head] for layer in recorded], 0),
                    ylabel="dest", xlabel="src", textval=False, x_marks=marks, y_marks=marks,
                    ignore_wrong_marks=True)

        r["attention_max"] = framework.visualize.plot.AnimatedHeatmap(
            torch.stack([layer.max(0)[0] for layer in recorded], 0),
            ylabel="dest", xlabel="src", textval=False, x_marks=marks, y_marks=marks, ignore_wrong_marks=True)
        self.attention_to_visualize = []
        return r
|
|
|
|
|
|
|
|
class AttentionMergeMixin:
    """Mixin that concatenates the per-head outputs and projects them.

    The host class must provide ``n_heads``, ``projection_size``,
    ``state_size`` and an ``_attention`` method returning ``(data, scores)``.
    """

    def __init__(self, out_size: Optional[int]) -> None:
        """Create the bias-free output projection.

        Args:
            out_size: Output width; falls back to ``self.state_size`` when
                ``None`` (or 0).
        """
        merged_width = self.n_heads * self.projection_size
        target_width = out_size or self.state_size
        self.multi_head_merge = torch.nn.Linear(merged_width, target_width, bias=False)

    def merged_attention(self, n_batch: int, n_out_steps: int, *args, need_weights: bool = False, **kwargs) -> \
            Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run ``self._attention`` and merge its heads.

        Args:
            n_batch: Batch size used to unfold the head axis.
            n_out_steps: Number of destination steps.
            *args: Forwarded to ``self._attention``.
            need_weights: Accepted but not consulted; the scores are always
                returned.
            **kwargs: Forwarded to ``self._attention``.

        Returns:
            ``(merged, scores)`` where ``merged`` is
            ``[n_batch, n_out_steps, out_size]`` and ``scores`` is whatever
            ``self._attention`` returned as its second element.
        """
        per_head, scores = self._attention(*args, **kwargs)

        # [n_batch * n_heads, steps, d] -> [n_batch, n_heads, steps, d]
        split_heads = per_head.view(n_batch, self.n_heads, n_out_steps, -1)
        # Put the head axis next to the features so both flatten together.
        steps_first = split_heads.permute(0, 2, 1, 3).contiguous()
        concatenated = steps_first.view(n_batch, n_out_steps, -1)

        return self.multi_head_merge(concatenated), scores
|
|
|
|
|
|
|
|
class AbsPosAttentionBase(MultiHeadAttentionBase):
    """Attention whose logits are plain scaled query/key dot products."""

    def get_attention_scores(self, mask: Optional[AttentionMask], q: torch.Tensor,
                             k: torch.Tensor) -> torch.Tensor:
        """Compute the attention weights for ``q`` against ``k``.

        Args:
            mask: Optional masks, forwarded to ``_masked_softmax``.
            q: Queries, 3-D and batched for ``torch.bmm``.
            k: Keys, batched like ``q``; transposed on its last two axes.

        Returns:
            Softmax-normalized weights over the last (source) axis.
        """
        logits = torch.bmm(q, k.transpose(1, 2))
        # ``self.scale`` is 1 / sqrt(projection_size), set by the base class.
        return self._masked_softmax(logits * self.scale, mask)

    def _attention(self, mask: Optional[AttentionMask], q: torch.Tensor, k: torch.Tensor,
                   v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Score ``q`` against ``k`` and read ``v`` with the result.

        Args:
            mask: Optional masks applied to the logits.
            q: Queries.
            k: Keys.
            v: Values.

        Returns:
            The ``(data, scores)`` pair produced by ``_attention_read``.
        """
        scores = self.get_attention_scores(mask, q, k)
        return self._attention_read(mask, scores, v)
|
|
|
|
|
|
|
|
class MultiHeadAttention(AttentionMergeMixin, AbsPosAttentionBase):
    """Multi-head dot-product attention with a merged output projection."""

    def __init__(self, state_size: int, n_heads: int, dropout: float = 0.1, input_size: Optional[int] = None,
                 out_size: Optional[int] = None):
        """Build the query, key/value and output projections.

        Args:
            state_size: Width of the attended sequence (input of the joint
                key/value projection).
            n_heads: Number of attention heads.
            dropout: Forwarded to the attention base class.
            input_size: Width of the querying sequence; defaults to
                ``state_size`` when ``None`` (or 0).
            out_size: Output width; handled by ``AttentionMergeMixin``.
        """
        # Skip the mixin in the MRO and initialize the attention base first:
        # the mixin's __init__ reads n_heads / projection_size / state_size,
        # which only exist once the base initializer has run.
        super(AttentionMergeMixin, self).__init__(state_size, n_heads, dropout)

        # Keys and values share one projection; ``forward`` splits it in two.
        self.data_to_kv = torch.nn.Linear(state_size, 2 * n_heads * self.projection_size, bias=False)
        self.data_to_q = torch.nn.Linear(input_size or state_size, n_heads * self.projection_size, bias=False)

        # Now run the mixin's initializer to create the output projection.
        super(MultiHeadAttention, self).__init__(out_size)
        self.reset_parameters()

    def forward(self, curr_state: torch.Tensor, attend_to: torch.Tensor, mask: Optional[AttentionMask],
                need_weights: bool = False):
        """Attend from ``curr_state`` to ``attend_to``.

        Args:
            curr_state: Querying sequence, ``[n_batch, n_out_steps, input_size]``.
            attend_to: Attended sequence, ``[n_batch, n_in_steps, state_size]``.
            mask: Optional masks applied to the attention logits.
            need_weights: When true, also return the attention scores.

        Returns:
            The merged attention output, or ``(output, scores)`` when
            ``need_weights`` is true.
        """
        k, v = self.transform_data(attend_to, self.data_to_kv, 2)
        q, = self.transform_data(curr_state, self.data_to_q, 1)

        data, scores = self.merged_attention(curr_state.shape[0], q.shape[1], mask, q, k, v)
        if need_weights:
            return data, scores
        return data

    def reset_parameters(self):
        """Xavier-uniform initialize the query and key/value projections.

        Each weight is initialized exactly once. ``multi_head_merge`` is not
        touched here and keeps the initialization given by ``torch.nn.Linear``.
        """
        torch.nn.init.xavier_uniform_(self.data_to_q.weight)
        torch.nn.init.xavier_uniform_(self.data_to_kv.weight)
|
|
|