import torch
from torch import Tensor, nn
from . import vb_layers_initialize as init
class PairWeightedAveraging(nn.Module):
    """Pair weighted averaging layer.

    Averages the per-sequence values with attention-like weights derived
    from the pairwise representation, modulated by a sigmoid gate.
    """

    def __init__(
        self,
        c_m: int,
        c_z: int,
        c_h: int,
        num_heads: int,
        inf: float = 1e6,
    ) -> None:
        """Initialize the pair weighted averaging layer.

        Parameters
        ----------
        c_m: int
            The dimension of the input sequence.
        c_z: int
            The dimension of the input pairwise tensor.
        c_h: int
            The dimension of the hidden.
        num_heads: int
            The number of heads.
        inf: float
            The value to use for masking, default 1e6.

        """
        super().__init__()
        self.c_m = c_m
        self.c_z = c_z
        self.c_h = c_h
        self.num_heads = num_heads
        self.inf = inf

        self.norm_m = nn.LayerNorm(c_m)
        self.norm_z = nn.LayerNorm(c_z)
        self.proj_m = nn.Linear(c_m, c_h * num_heads, bias=False)
        self.proj_g = nn.Linear(c_m, c_h * num_heads, bias=False)
        self.proj_z = nn.Linear(c_z, num_heads, bias=False)
        self.proj_o = nn.Linear(c_h * num_heads, c_m, bias=False)
        init.final_init_(self.proj_o.weight)

    def forward(
        self, m: Tensor, z: Tensor, mask: Tensor, chunk_heads: bool = False
    ) -> Tensor:
        """Forward pass.

        Parameters
        ----------
        m : torch.Tensor
            The input sequence tensor (B, S, N, D)
        z : torch.Tensor
            The input pairwise tensor (B, N, N, D)
        mask : torch.Tensor
            The pairwise mask tensor (B, N, N)
        chunk_heads : bool
            Whether to process the heads one at a time to lower peak
            memory (only takes effect in eval mode), default False.

        Returns
        -------
        torch.Tensor
            The output sequence tensor (B, S, N, D)

        """
        # Compute layer norms
        m = self.norm_m(m)
        z = self.norm_z(z)

        if chunk_heads and not self.training:
            # Compute heads sequentially: slice each projection's weight so
            # only one head's activations are materialized at a time.
            o_out = None
            for head_idx in range(self.num_heads):
                sliced_weight_proj_m = self.proj_m.weight[
                    head_idx * self.c_h : (head_idx + 1) * self.c_h, :
                ]
                sliced_weight_proj_g = self.proj_g.weight[
                    head_idx * self.c_h : (head_idx + 1) * self.c_h, :
                ]
                sliced_weight_proj_z = self.proj_z.weight[head_idx : (head_idx + 1), :]
                sliced_weight_proj_o = self.proj_o.weight[
                    :, head_idx * self.c_h : (head_idx + 1) * self.c_h
                ]

                # Project input tensors for this head: (B, 1, S, N, c_h)
                v: Tensor = m @ sliced_weight_proj_m.T
                v = v.reshape(*v.shape[:3], 1, self.c_h)
                v = v.permute(0, 3, 1, 2, 4)

                # Compute attention weights from the pairwise tensor
                b: Tensor = z @ sliced_weight_proj_z.T
                b = b.permute(0, 3, 1, 2)
                b = b + (1 - mask[:, None]) * -self.inf
                w = torch.softmax(b, dim=-1)

                # Compute sigmoid gating
                g: Tensor = m @ sliced_weight_proj_g.T
                g = g.sigmoid()

                # Weighted average over the pairwise axis
                o = torch.einsum("bhij,bhsjd->bhsid", w, v)
                o = o.permute(0, 2, 3, 1, 4)
                o = o.reshape(*o.shape[:3], self.c_h)

                # Gate, project, and accumulate this head's contribution
                contribution = (g * o) @ sliced_weight_proj_o.T
                o_out = contribution if o_out is None else o_out + contribution
            return o_out
        else:
            # Project input tensors: (B, H, S, N, c_h)
            v: Tensor = self.proj_m(m)
            v = v.reshape(*v.shape[:3], self.num_heads, self.c_h)
            v = v.permute(0, 3, 1, 2, 4)

            # Compute attention weights from the pairwise tensor
            b: Tensor = self.proj_z(z)
            b = b.permute(0, 3, 1, 2)
            b = b + (1 - mask[:, None]) * -self.inf
            w = torch.softmax(b, dim=-1)

            # Compute sigmoid gating
            g: Tensor = self.proj_g(m)
            g = g.sigmoid()

            # Weighted average over the pairwise axis, then gate and project
            o = torch.einsum("bhij,bhsjd->bhsid", w, v)
            o = o.permute(0, 2, 3, 1, 4)
            o = o.reshape(*o.shape[:3], self.num_heads * self.c_h)
            return self.proj_o(g * o)