|
|
import math
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn
import torch.nn.functional as F
from matplotlib import cm

import forgetting_transformer.ops.framework_mock as framework
from forgetting_transformer.ops.multi_head_attention import AttentionMask, MultiHeadAttentionBase, AttentionMergeMixin
|
|
|
|
|
def shift(posmat: torch.Tensor) -> torch.Tensor:
    """Read every row of ``posmat`` through a window that slides left by one per row.

    With ``W = posmat.shape[-1]``, ``c = W // 2`` and ``w = (W + 1) // 2``, row
    ``i`` of the result equals ``posmat[..., i, c - i : c - i + w]`` for every
    row index ``i <= c``. The re-indexing is done without a gather: the input is
    padded by one column and one row, flattened, re-cut with an offset of ``c``
    and viewed back in the original shape, which skews each row by one element.

    Args:
        posmat: Tensor with at least two dimensions; only the last two are
            touched, leading dimensions are carried through unchanged.

    Returns:
        Tensor of shape ``posmat.shape[:-1] + (w,)``.
    """
    n_rows, n_cols = posmat.shape[-2:]

    # One extra column makes every flattened row one element longer than the
    # rows of the target view; the extra row only supplies enough trailing
    # elements for the narrow() below.
    flat = F.pad(posmat, (0, 1, 0, 1)).flatten(-2)

    # Starting at the centre column and re-cutting into rows of the original
    # width shifts row i left by i positions.
    skewed = flat.narrow(-1, n_cols // 2, n_cols * n_rows).view_as(posmat)

    # Only the first half (rounded up) of each skewed row is kept.
    return skewed.narrow(-1, 0, (n_cols + 1) // 2)
|
|
|
|
|
|
|
|
class RelativeAttentionBase(MultiHeadAttentionBase):
    """Multi-head attention whose logits are a content term plus a relative-position term.

    The class adds the score computation (``get_attention_scores``), the slice
    selection for a relative position table (``_get_pos_subset``) and a
    content-vs-position visualization (``plot``) on top of
    ``MultiHeadAttentionBase``.
    """

    def __init__(self, state_size: int, n_heads: int, dropout: float, projection_size: Optional[int] = None):
        """Forward the arguments to the parent and register the extra parameters."""
        super().__init__(state_size, n_heads, dropout=dropout, projection_size=projection_size)

        # ``self.scale`` is read before it is reassigned, so it must already
        # exist here (presumably set by the parent constructor); it is wrapped
        # into a one-element learnable parameter.
        self.scale = torch.nn.Parameter(torch.tensor([self.scale]))
        self.s_bias = torch.nn.Parameter(torch.tensor([0.0]))

        # (content_logits, position_logits) pairs collected while
        # visualization is enabled; consumed and cleared by plot().
        self.vis_pos_vs_content = []

    def get_attention_scores(self, mask: Optional[torch.Tensor],
                             q_content: torch.Tensor, k_content: torch.Tensor,
                             q_pos: torch.Tensor, k_pos: torch.Tensor,
                             pos_offset: int, ar_gate: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Combine content and relative-position logits and normalise them.

        Args:
            mask: Passed unchanged to ``_masked_softmax``.
            q_content: 3-D query tensor for the content term; its first
                dimension is interpreted as ``n_batch * n_heads``.
            k_content: 3-D key tensor for the content term (batched with
                ``q_content`` through ``torch.bmm``).
            q_pos: Query tensor for the position term; viewed as
                ``(n_batch, n_heads, n_out_steps, -1)``.
            k_pos: Key tensor for the position term; must broadcast against
                the 4-D view of ``q_pos`` in ``torch.matmul`` (assumed to be
                ``(n_heads, n_positions, features)`` - TODO confirm).
            pos_offset: Not used by this method.
            ar_gate: Optional gate; where given, the shifted position logits
                are blended with the un-shifted trailing columns of the raw
                position logits using weights ``ar_gate`` and ``1 - ar_gate``.

        Returns:
            Whatever ``_masked_softmax`` returns for the scaled sum of the two
            logit terms.
        """
        n_batch = q_content.shape[0] // self.n_heads
        n_out_steps = q_content.shape[1]

        # Content-to-content logits.
        content = torch.bmm(q_content, self.dropout(k_content).transpose(1, 2))

        # Content-to-position logits, still indexed by relative position.
        pos = torch.matmul(q_pos.view(n_batch, self.n_heads, n_out_steps, -1),
                           self.dropout(k_pos).transpose(-1, -2))

        # Re-index to per-key columns and merge batch and head dimensions so
        # the result lines up with ``content``.
        fpos = shift(pos).flatten(0, 1)
        if ar_gate is not None:
            unshifted_tail = pos.flatten(0, 1)[..., fpos.shape[-1] - 1:]
            fpos = fpos * ar_gate + unshifted_tail * (1 - ar_gate)

        if self.visualization_enabled:
            # Keep only the first batch element. The tensors are detached so
            # the stored history neither keeps the autograd graph alive nor
            # breaks the later ``.numpy()`` conversion in plot().
            vis_content = content.view(n_batch, self.n_heads, *content.shape[1:])[0] * self.scale
            vis_pos = fpos.view(n_batch, self.n_heads, *fpos.shape[1:])[0] * self.scale
            self.vis_pos_vs_content.append((vis_content.detach(), vis_pos.detach()))

        return self._masked_softmax((content + fpos) * self.scale, mask)

    def _attention(self, mask: Optional[torch.Tensor],
                   q_content: torch.Tensor, k_content: torch.Tensor,
                   q_pos: torch.Tensor, k_pos: torch.Tensor,
                   v: torch.Tensor, pos_offset: int,
                   ar_gate: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the attention scores and hand them to ``_attention_read`` together with ``v``."""
        scores = self.get_attention_scores(mask, q_content, k_content, q_pos, k_pos, pos_offset, ar_gate)
        return self._attention_read(mask, scores, v)

    def _get_pos_subset(self, pos_encoding: torch.Tensor, length: int, offset: int) -> torch.Tensor:
        """Cut ``2 * length - 1`` consecutive rows out of ``pos_encoding``.

        The slice starts at row ``pos_encoding.shape[0] // 2 - length + 1 - offset``.

        Args:
            pos_encoding: Table to slice along dimension 0.
            length: Determines the slice size ``2 * length - 1``.
            offset: Moves the slice start towards row 0 by this many rows.

        Returns:
            A view of ``pos_encoding`` with ``2 * length - 1`` rows.
        """
        l_slice = 2 * length - 1
        # ``>=`` rather than ``>``: a slice covering the entire table is valid.
        # Out-of-range starts (for example from a large offset) are still
        # rejected by narrow() itself.
        assert pos_encoding.shape[0] >= l_slice, "pos_encoding is too short for the requested length"
        return pos_encoding.narrow(0, pos_encoding.shape[0] // 2 - length + 1 - offset, l_slice)

    def plot(self, options: Dict[str, Any]) -> Dict[str, Any]:
        """Build one content-vs-position heatmap per head and merge in the parent's plots.

        Args:
            options: Plot options. ``"steplabel"`` supplies the axis marks and
                a truthy ``"mha.plot_head_details"`` enables the per-head
                heatmaps.

        Returns:
            Dictionary of plots: the ``content_vs_pos_<head>`` entries created
            here (if any) updated with the result of ``super().plot(options)``.
        """
        r = {}
        marks = options.get("steplabel")

        if options.get("mha.plot_head_details") and self.vis_pos_vs_content:
            n_heads = self.vis_pos_vs_content[0][0].shape[0]
            for head in range(n_heads):
                cont = torch.stack([entry[0][head] for entry in self.vis_pos_vs_content], 0)
                pos = torch.stack([entry[1][head] for entry in self.vis_pos_vs_content], 0)
                i = torch.stack([att[head] for att in self.attention_to_visualize], 0)

                # Softmax over the (content, position) pair; channel 0 is the
                # share attributed to the content logit.
                content = torch.stack([cont, pos], -1).softmax(-1)[..., 0]

                # NOTE(review): ``cm.get_cmap`` was removed in Matplotlib 3.9;
                # use ``matplotlib.colormaps["brg"]`` if newer versions must be
                # supported.
                color = cm.get_cmap("brg")(content.detach().cpu().numpy())
                # The attention weight drives the alpha channel, with a small
                # floor so that no cell is fully transparent.
                color[..., -1] = (i * 0.95 + 0.05).detach().cpu().numpy()

                r[f"content_vs_pos_{head}"] = framework.visualize.plot.AnimatedHeatmap(
                    color, ylabel="dest", xlabel="src", textval=False, x_marks=marks, y_marks=marks,
                    cmap="brg", colorbar=True, colorbar_ticks=[0, 0.99], colorbar_labels=["pos", "con"],
                    ignore_wrong_marks=True)

        # Always start the next visualization round from an empty history.
        self.vis_pos_vs_content = []

        r.update(super().plot(options))
        return r
|
|
|
|
|
|
|
|
|
|
|
class FixedRelativeMultiheadAttentionBase(RelativeAttentionBase):
    """Relative attention base that owns a position table and its linear projection.

    The table is stored in the ``pos_encoding`` buffer, produced by
    ``framework.layers.sinusoidal_pos_embedding``, and is replaced by a larger
    one on demand in ``get_pos``.
    """

    def __init__(self, state_size: int, n_heads: int, dropout: float = 0.0, input_size: Optional[int] = None,
                 projection_size: Optional[int] = None):
        """Create the position projection and the initial position table.

        Args:
            state_size: Forwarded to the parent; also the input width of
                ``pos_to_pq``.
            n_heads: Forwarded to the parent.
            dropout: Forwarded to the parent.
            input_size: Stored as ``self.input_size``; falls back to
                ``state_size`` when ``None``.
            projection_size: Forwarded to the parent.
        """
        super().__init__(state_size, n_heads, dropout, projection_size)

        self.input_size = state_size if input_size is None else input_size

        self.pos_to_pq = torch.nn.Linear(state_size, self.n_heads * self.projection_size, bias=False)
        # Created after pos_to_pq because _create_buffer reads its device.
        self.register_buffer("pos_encoding", self._create_buffer(1000))

    def _create_buffer(self, max_len: int):
        """Request a position table from the framework helper.

        The helper is called with ``2 * max_len - 1`` as its second argument
        and ``-max_len + 1`` as its third, on the device of
        ``pos_to_pq.weight``.
        """
        return framework.layers.sinusoidal_pos_embedding(self.state_size, 2 * max_len - 1, -max_len + 1,
                                                         device=self.pos_to_pq.weight.device)

    def get_pos(self, l: int, offset: int) -> torch.Tensor:
        """Return the projected position-table slice for ``l`` steps at ``offset``.

        If the current table has fewer than ``2 * (l + offset) - 1`` rows it
        is first replaced by a table built for the next power of two at or
        above that number.

        Args:
            l: Length forwarded to ``_get_pos_subset``.
            offset: Offset forwarded to ``_get_pos_subset``.

        Returns:
            ``pos_to_pq`` applied to the selected rows of ``pos_encoding``.
        """
        n_required = 2 * (l + offset) - 1
        if self.pos_encoding.shape[0] < n_required:
            grown = self._create_buffer(int(2**math.ceil(math.log2(n_required))))
            # Keep the dtype of the table being replaced so that a module that
            # was cast (e.g. .half()) stays internally consistent.
            self.pos_encoding = grown.to(dtype=self.pos_encoding.dtype)

        return self.pos_to_pq(self._get_pos_subset(self.pos_encoding, l, offset))
|
|
|
|
|
|
|
|
class FixedRelativeMultiheadAttention(AttentionMergeMixin, FixedRelativeMultiheadAttentionBase):
    """Relative-position multi-head attention with optional global biases and an optional gate.

    Queries come from ``curr_state`` and keys/values from ``attend_to``. The
    heads are combined by ``merged_attention``, which is not defined in this
    class.
    """

    def __init__(self, state_size: int, n_heads: int, dropout: float = 0.0, global_pos_bias: bool = True,
                 global_content_bias: bool = True, input_size: Optional[int] = None, absolute_gate: bool = False,
                 projection_size: Optional[int] = None, output_size: Optional[int] = None):
        """Build the projections and optional parameters, then initialise them.

        Args:
            state_size: Input width of the key/value projection; forwarded to
                the attention base.
            n_heads: Number of heads.
            dropout: Forwarded to the attention base.
            global_pos_bias: Create a learnable per-head bias that is added to
                the query used for the position term.
            global_content_bias: Create a learnable per-head bias that is added
                to the query used for the content term.
            input_size: Forwarded to the attention base, which stores it as
                ``self.input_size``; that value is the input width of the
                query and gate projections.
            absolute_gate: Create the ``data_to_absgate`` projection whose
                sigmoid is passed on as ``ar_gate``.
            projection_size: Forwarded to the attention base.
            output_size: Forwarded to the second constructor call below.
        """
        # Two-phase construction: first run the constructors that follow
        # AttentionMergeMixin in the MRO (the attention base) ...
        super(AttentionMergeMixin, self).__init__(state_size, n_heads, dropout, input_size,
                                                  projection_size=projection_size)

        self.data_to_kv = torch.nn.Linear(state_size, 2 * n_heads * self.projection_size, bias=False)
        self.data_to_q = torch.nn.Linear(self.input_size, n_heads * self.projection_size, bias=False)

        if absolute_gate:
            self.data_to_absgate = torch.nn.Linear(self.input_size, n_heads)
        else:
            self.data_to_absgate = None

        if global_content_bias:
            self.global_content_bias = torch.nn.Parameter(torch.zeros([n_heads, self.projection_size]))
        else:
            self.global_content_bias = None

        if global_pos_bias:
            self.global_pos_bias = torch.nn.Parameter(torch.zeros([n_heads, self.projection_size]))
        else:
            self.global_pos_bias = None

        # ... then the next constructor after this class in the MRO, which
        # receives ``output_size``.
        super(FixedRelativeMultiheadAttention, self).__init__(output_size)
        self.reset_parameters()

    def add_head_specific_bias(self, data: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Add one bias vector per head to ``data``.

        Args:
            data: Tensor whose first dimension is a multiple of
                ``bias.shape[0]``; it is temporarily viewed with an explicit
                head dimension of that size.
            bias: Per-head bias, or ``None`` to leave ``data`` untouched.

        Returns:
            ``data`` itself when ``bias`` is ``None``, otherwise the biased
            tensor in the shape of ``data``.
        """
        if bias is None:
            return data

        per_head = data.view(-1, bias.shape[0], *data.shape[1:])
        # unsqueeze(1) lets the bias broadcast over the dimension that follows
        # the head dimension.
        return (per_head + bias.unsqueeze(1).type_as(data)).view_as(data)

    def forward(self, curr_state: torch.Tensor, attend_to: torch.Tensor, mask: Optional[AttentionMask],
                pos_offset: int = 0, need_weights: bool = False):
        """Attend from ``curr_state`` to ``attend_to``.

        Args:
            curr_state: Source of the queries; ``shape[1]`` is used as the
                output length.
            attend_to: Source of keys and values; ``shape[0:2]`` is read as
                ``(batch_size, in_len)``.
            mask: Forwarded to ``merged_attention``.
            pos_offset: Forwarded to ``get_pos`` and ``merged_attention``.
            need_weights: When true the scores are returned as well.

        Returns:
            ``data`` alone, or ``(data, scores)`` when ``need_weights`` is true.
        """
        batch_size, in_len = attend_to.shape[0:2]
        out_len = curr_state.shape[1]

        k_content, v = self.transform_data(attend_to, self.data_to_kv, 2)
        (q,) = self.transform_data(curr_state, self.data_to_q, 1)

        # Projected position rows, split per head and moved head-first.
        k_pos = self.get_pos(in_len, pos_offset).view(-1, self.n_heads, self.projection_size).transpose(0, 1)

        # The same query is used for both terms, each with its own optional
        # global bias.
        q_content = self.add_head_specific_bias(q, self.global_content_bias)
        q_pos = self.add_head_specific_bias(q, self.global_pos_bias)

        if self.data_to_absgate is not None:
            absgate = torch.sigmoid(self.transform_data(curr_state, self.data_to_absgate, 1)[0])
        else:
            absgate = None

        data, scores = self.merged_attention(batch_size, out_len, mask, q_content, k_content, q_pos, k_pos, v,
                                             pos_offset, ar_gate=absgate, need_weights=need_weights)

        if need_weights:
            return data, scores
        return data

    def reset_parameters(self):
        """Xavier-initialise the projection weights and zero the global biases."""
        torch.nn.init.xavier_uniform_(self.data_to_q.weight)
        torch.nn.init.xavier_uniform_(self.pos_to_pq.weight)
        torch.nn.init.xavier_uniform_(self.data_to_kv.weight)
        # NOTE(review): data_to_kv.weight is initialised a second time with
        # identical arguments, which only overwrites the previous draw. Kept
        # so seeded runs consume the same random numbers as before - confirm
        # whether separate key / value initialisation was intended.
        torch.nn.init.xavier_uniform_(self.data_to_kv.weight)

        if self.global_content_bias is not None:
            self.global_content_bias.data.fill_(0)

        if self.global_pos_bias is not None:
            self.global_pos_bias.data.fill_(0)
|
|
|