import torch
import torch.nn
import torch.nn.functional as F
import math
from typing import Optional, Callable, List, Union, Tuple, Dict, Any
from dataclasses import dataclass
from forgetting_transformer.ops.layer_with_visualization import LayerWithVisualization
import forgetting_transformer.ops.framework_mock as framework
@dataclass
class AttentionMask:
    """Boolean attention masks; True marks positions to exclude (filled with -inf)."""
    # [n_batch, n_in] padding mask over source positions, or None.
    src_length_mask: Optional[torch.Tensor]
    # Broadcastable [..., n_out, n_in] mask (e.g. a causal mask — depends on caller), or None.
    position_mask: Optional[torch.Tensor]
class MultiHeadAttentionBase(LayerWithVisualization):
    """Shared machinery for multi-head attention variants.

    Provides head bookkeeping (scaling, dropout), mask application, the
    per-head softmax, the attention read, and optional attention-map
    visualization inherited from LayerWithVisualization.
    """

    def __init__(self, state_size: int, n_heads: int, dropout: float = 0.1,
                 projection_size: Optional[int] = None):
        """
        :param state_size: model width; must be divisible by n_heads.
        :param n_heads: number of attention heads.
        :param dropout: dropout probability for the Dropout module.
        :param projection_size: per-head width; defaults to state_size // n_heads.
        """
        assert state_size % n_heads == 0
        super().__init__()
        # Attention maps collected while visualization is enabled; consumed by plot().
        self.attention_to_visualize = []
        self.state_size = state_size
        self.projection_size = projection_size or (state_size // n_heads)
        self.n_heads = n_heads
        # 1/sqrt(d_head) logit scaling.
        self.scale = 1.0 / math.sqrt(self.projection_size)
        self.dropout = torch.nn.Dropout(dropout)

    @staticmethod
    def apply_logit_masks(logits: torch.Tensor, mask: Optional[AttentionMask],
                          val: float = float("-inf")) -> torch.Tensor:
        """Fill masked positions of `logits` with `val`. A None mask is a no-op.

        (Previously a None mask raised AttributeError despite the Optional annotation.)
        """
        if mask is None:
            return logits
        if mask.position_mask is not None:
            # [..., N_out, N_in], broadcast works
            logits = logits.masked_fill(mask.position_mask, val)
        if mask.src_length_mask is not None:
            # [B, ..., N_in], needs manual shaping to line up with the batch dim
            b, i = mask.src_length_mask.shape
            pad_dims = logits.ndim - 2
            logits = logits.masked_fill(
                mask.src_length_mask.view([b] + [1] * pad_dims + [i]), val)
        return logits

    def _masked_softmax(self, logits: torch.Tensor, mask: Optional[AttentionMask]) -> torch.Tensor:
        """Softmax over the last dim, applying `mask` first when present.

        logits: [n_batch * n_heads, n_time_dest, n_time_src]; output same shape.
        """
        if mask is None or (mask.src_length_mask is None and mask.position_mask is None):
            return F.softmax(logits, -1)
        bb, n_time_dest, n_time_src = logits.shape
        # Unfold the fused batch*head dim so the [B, N_in] length mask aligns.
        logits = logits.view(bb // self.n_heads, self.n_heads, n_time_dest, n_time_src)
        logits = self.apply_logit_masks(logits, mask)
        logits = F.softmax(logits, -1)
        return logits.view(bb, n_time_dest, n_time_src)

    def _attention_read(self, mask: Optional[AttentionMask], scores: torch.Tensor,
                        v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Weighted read of `v` with the (already-softmaxed) `scores`.

        scores: [n_batch * n_heads, n_out, n_in]
        v: [n_batch * n_heads, n_in, data_size]
        Returns (data [n_batch * n_heads, n_time_dest, data_size],
                 scores [n_batch, n_heads, n_time_dest, n_time_src]).
        """
        s_reshape = scores.view(-1, self.n_heads, *scores.shape[1:])
        if self.visualization_enabled:
            # Keep only the first batch element for plotting.
            self.attention_to_visualize.append(s_reshape[0])
        return torch.bmm(scores, v), s_reshape

    def transform_data(self, input: torch.Tensor, proj: Callable[[torch.Tensor], torch.Tensor],
                       n_projs: int) -> List[torch.Tensor]:
        """Apply `proj` and split the result into n_projs per-head tensors.

        input: [n_batch, n_steps, n_channels]
        Returns n_projs tensors of shape [n_batch * n_heads, n_steps, projection_size].
        """
        n_batch, n_steps, _ = input.shape
        transformed = proj(input).view(n_batch, n_steps, self.n_heads, n_projs, -1). \
            permute(0, 2, 1, 3, 4).contiguous().view(n_batch * self.n_heads, n_steps, n_projs, -1)
        return transformed.unbind(dim=2)

    def plot(self, options: Dict[str, Any]) -> Dict[str, Any]:
        """Build heatmap plots from collected attention maps, then clear the buffer."""
        r = {}
        marks = options.get("steplabel")
        if options.get("mha.plot_head_details") and self.attention_to_visualize[0].shape[0] > 1:
            for head in range(self.attention_to_visualize[0].shape[0]):
                r[f"head_{head}"] = framework.visualize.plot.AnimatedHeatmap(
                    torch.stack([layer[head] for layer in self.attention_to_visualize], 0),
                    ylabel="dest", xlabel="src", textval=False, x_marks=marks,
                    y_marks=marks, ignore_wrong_marks=True)
        r["attention_max"] = framework.visualize.plot.AnimatedHeatmap(
            torch.stack([layer.max(0)[0] for layer in self.attention_to_visualize], 0),
            ylabel="dest", xlabel="src", textval=False, x_marks=marks,
            y_marks=marks, ignore_wrong_marks=True)
        self.attention_to_visualize = []
        return r
class AttentionMergeMixin:
    """Mixin that folds per-head attention outputs back together via one linear map.

    The inheriting class must provide `n_heads`, `projection_size`, `state_size`
    and an `_attention(...)` implementation before/alongside this mixin's init.
    """

    def __init__(self, out_size: Optional[int]) -> None:
        in_features = self.n_heads * self.projection_size
        out_features = out_size or self.state_size
        self.multi_head_merge = torch.nn.Linear(in_features, out_features, bias=False)

    def merged_attention(self, n_batch: int, n_out_steps: int, *args, need_weights: bool = False,
                         **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run `_attention`, move heads into the feature dimension, and project."""
        data, scores = self._attention(*args, **kwargs)
        # [n_batch * n_heads, n_out_steps, proj] -> [n_batch, n_out_steps, n_heads * proj]
        per_head = data.view(n_batch, self.n_heads, n_out_steps, -1)
        merged = per_head.permute(0, 2, 1, 3).contiguous().view(n_batch, n_out_steps, -1)
        return self.multi_head_merge(merged), scores
class AbsPosAttentionBase(MultiHeadAttentionBase):
    """Plain scaled dot-product attention (no relative-position logit terms)."""

    def get_attention_scores(self, mask: Optional[AttentionMask], q: torch.Tensor,
                             k: torch.Tensor) -> torch.Tensor:
        """Return softmax(q k^T * scale) with `mask` applied before the softmax.

        q, k: [n_batch * n_heads, n_steps, projection_size]
        Returns: [n_batch * n_heads, n_time_dest, n_time_src]
        """
        logits = torch.bmm(q, k.transpose(1, 2))
        return self._masked_softmax(logits * self.scale, mask)

    def _attention(self, mask: Optional[AttentionMask], q: torch.Tensor, k: torch.Tensor,
                   v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # q, k, v have shape [n_batch * n_heads, n_steps, data_size]
        # (the output of transform_data), not [n_batch, ...] as the old comment said.
        # Returns (data [n_batch * n_heads, n_time_dest, data_size],
        #          scores [n_batch, n_heads, n_time_dest, n_time_src]).
        scores = self.get_attention_scores(mask, q, k)
        return self._attention_read(mask, scores, v)
class MultiHeadAttention(AttentionMergeMixin, AbsPosAttentionBase):
    """Standard multi-head scaled dot-product attention."""

    def __init__(self, state_size: int, n_heads: int, dropout: float = 0.1,
                 input_size: Optional[int] = None, out_size: Optional[int] = None):
        """
        :param state_size: width of the attended-to (key/value) stream.
        :param n_heads: number of attention heads.
        :param dropout: dropout probability passed to the base class.
        :param input_size: width of the query stream; defaults to state_size.
        :param out_size: output width of the head merge; defaults to state_size.
        """
        # Initialize the attention base first (skipping AttentionMergeMixin in the
        # MRO) because the mixin's init needs n_heads/projection_size to exist.
        super(AttentionMergeMixin, self).__init__(state_size, n_heads, dropout)
        self.data_to_kv = torch.nn.Linear(state_size, 2 * n_heads * self.projection_size, bias=False)
        self.data_to_q = torch.nn.Linear(input_size or state_size, n_heads * self.projection_size, bias=False)
        super(MultiHeadAttention, self).__init__(out_size)
        self.reset_parameters()

    def forward(self, curr_state: torch.Tensor, attend_to: torch.Tensor,
                mask: Optional[AttentionMask], need_weights: bool = False):
        """Attend from `curr_state` (queries) to `attend_to` (keys/values).

        Input and output shape: [n_batch, n_steps, data_size].
        Returns (data, scores) when need_weights is True, otherwise just data.
        """
        k, v = self.transform_data(attend_to, self.data_to_kv, 2)
        q, = self.transform_data(curr_state, self.data_to_q, 1)
        data, scores = self.merged_attention(curr_state.shape[0], q.shape[1], mask, q, k, v)
        return (data, scores) if need_weights else data

    def reset_parameters(self):
        """Xavier-initialize the q and kv projection weights."""
        torch.nn.init.xavier_uniform_(self.data_to_q.weight)
        # Fixed: data_to_kv was initialized twice; a single init suffices.
        torch.nn.init.xavier_uniform_(self.data_to_kv.weight)