import math
from functools import reduce, wraps
from inspect import isfunction
from operator import mul

import torch
import torch.nn as nn
import torch.nn.functional as F

from aml.multimodal_video.utils.einops.lib import rearrange, repeat
from aml.multimodal_video.utils.einops.lib.layers.torch import Rearrange

from fairseq.modules.local_attention import LocalAttention

TOKEN_SELF_ATTN_VALUE = -5e4
KMEAN_INIT_ITERS = 10
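# TOKEN_SELF_ATTN_VALUE is the large negative (but finite) logit assigned to a token
# attending to itself when queries and keys are shared, so a token only attends to
# itself when no other key in its bucket is useful.
# KMEAN_INIT_ITERS is the number of k-means refinement steps run once, when the
# cluster means are first initialized from data.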


def exists(val):
    return val is not None


def identity(x, *args, **kwargs):
    return x


def default(x, d):
    if not exists(x):
        return d if not isfunction(d) else d()
    return x


def cast_tuple(x):
    return x if isinstance(x, tuple) else (x,)


def cache_fn(f):
    cache = None

    @wraps(f)
    def cached_fn(*args, **kwargs):
        nonlocal cache
        if exists(cache):
            return cache
        cache = f(*args, **kwargs)
        return cache

    return cached_fn


def to(t):
    return {"device": t.device, "dtype": t.dtype}


def find_modules(nn_module, type):
    return [module for module in nn_module.modules() if isinstance(module, type)]


def is_empty(t):
    return t.nelement() == 0


def max_neg_value(tensor):
    return -torch.finfo(tensor.dtype).max


def batched_index_select(values, indices):
    last_dim = values.shape[-1]
    return values.gather(2, expand_dim(indices, -1, last_dim))


def merge_dims(ind_from, ind_to, tensor):
    shape = list(tensor.shape)
    arr_slice = slice(ind_from, ind_to + 1)
    shape[arr_slice] = [reduce(mul, shape[arr_slice])]
    return tensor.reshape(*shape)


def expand_dim(t, dim, k):
    t = t.unsqueeze(dim)
    expand_shape = [-1] * len(t.shape)
    expand_shape[dim] = k
    return t.expand(*expand_shape)


def scatter_mean(src, t, index, dim, eps=1e-5):
    numer = src.scatter_add(dim, index, t)
    denom = src.scatter_add(dim, index, torch.ones_like(t))
    return numer / (denom + eps)


def split_at_index(dim, index, t):
    pre_slices = (slice(None),) * dim
    l = (*pre_slices, slice(None, index))
    r = (*pre_slices, slice(index, None))
    return t[l], t[r]


def reshape_dim(t, dim, split_dims):
    shape = list(t.shape)
    num_dims = len(shape)
    dim = (dim + num_dims) % num_dims
    shape[dim : dim + 1] = split_dims
    return t.reshape(shape)


def ema(old, new, decay):
    if not exists(old):
        return new
    return old * decay + new * (1 - decay)


def ema_inplace(moving_avg, new, decay):
    if is_empty(moving_avg):
        moving_avg.data.copy_(new)
        return
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


def map_first_tuple_or_el(x, fn):
    if isinstance(x, tuple):
        return (fn(x[0]),) + x[1:]
    return fn(x)


class Chunk(nn.Module):
    def __init__(self, chunks, fn, along_dim=-1):
        super().__init__()
        self.dim = along_dim
        self.chunks = chunks
        self.fn = fn

    def forward(self, x, **kwargs):
        if self.chunks <= 1:
            return self.fn(x, **kwargs)
        chunks = x.chunk(self.chunks, dim=self.dim)
        return torch.cat([self.fn(c, **kwargs) for c in chunks], dim=self.dim)


class PreNorm(nn.Module):
    def __init__(self, norm_class, dim, fn):
        super().__init__()
        self.norm = norm_class(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)


class ReZero(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.residual_weight = nn.Parameter(torch.zeros(1))
        self.fn = fn

    def forward(self, x, **kwargs):
        x = self.fn(x, **kwargs)
        return map_first_tuple_or_el(x, lambda t: t * self.residual_weight)


class ScaleNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1))
        self.eps = eps

    def forward(self, x):
        def norm(t):
            n = torch.norm(t, dim=-1, keepdim=True).clamp(min=self.eps)
            return t / n * self.g

        return map_first_tuple_or_el(x, norm)


class ProjectInOut(nn.Module):
    def __init__(self, fn, dim_in, dim_out, project_out=True):
        super().__init__()
        self.fn = fn
        self.project_in = nn.Linear(dim_in, dim_out)
        self.project_out = nn.Linear(dim_out, dim_in) if project_out else identity

    def forward(self, x, **kwargs):
        x = self.project_in(x)
        x, loss = self.fn(x, **kwargs)
        x = self.project_out(x)
        return x, loss


class MatrixMultiply(nn.Module):
    def __init__(self, tensor, transpose=False):
        super().__init__()
        self.tensor = tensor
        self.transpose = transpose

    def forward(self, x):
        tensor = self.tensor
        if self.transpose:
            tensor = tensor.t()
        return x @ tensor


class DepthWiseConv1d(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, stride=1, bias=True, causal=False):
        super().__init__()
        self.padding = (
            ((kernel_size - 1), 0) if causal else (kernel_size // 2, kernel_size // 2)
        )

        self.net = nn.Sequential(
            nn.Conv1d(
                dim_in,
                dim_in,
                kernel_size=kernel_size,
                groups=dim_in,
                stride=stride,
                bias=bias,
            ),
            nn.Conv1d(dim_in, dim_out, 1, bias=bias),
        )

    def forward(self, x):
        x = F.pad(x, self.padding, value=0.0)
        return self.net(x)


class FixedPositionalEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        position = torch.arange(0, max_seq_len, dtype=torch.float)
        sinusoid_inp = torch.einsum("i,j->ij", position, inv_freq)
        emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
        self.register_buffer("emb", emb)

    def forward(self, x):
        return self.emb[None, : x.shape[1], :].to(x)


def rotate_every_two(x):
    x = rearrange(x, "... (d j) -> ... d j", j=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d j -> ... (d j)")


def apply_rotary_pos_emb(q, k, sinu_pos):
    sinu_pos = rearrange(sinu_pos, "() n (j d) -> n j d", j=2)
    sin, cos = sinu_pos.unbind(dim=-2)
    sin, cos = map(lambda t: repeat(t, "b n -> b (n j)", j=2), (sin, cos))
    q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
    return q, k
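# Rotary position embedding helpers. `rotate_every_two` rotates each adjacent
# feature pair (x1, x2) -> (-x2, x1); `apply_rotary_pos_emb` combines that with the
# sin/cos table produced by FixedPositionalEmbedding. A minimal shape sketch
# (illustrative sizes only; the embedding dim must match the per-head dim):
#
#   pos = FixedPositionalEmbedding(dim=64, max_seq_len=128)
#   sinu = pos(torch.zeros(1, 128, 64))       # (1, seq, head_dim)
#   q = k = torch.randn(2, 8, 128, 64)        # (batch, heads, seq, head_dim)
#   q, k = apply_rotary_pos_emb(q, k, sinu)   # same shapes, position-rotated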


def update_kmeans_on_backwards(module):
    module.kmean_modules = find_modules(module, Kmeans)

    def hook(_, grad_in, grad_out):
        for m in module.kmean_modules:
            m.update()

    return module.register_backward_hook(hook)
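# Registers a backward hook so that every Kmeans submodule folds its accumulated
# means into the EMA buffer after each backward pass. Note that
# Module.register_backward_hook is deprecated in recent PyTorch releases in favour
# of register_full_backward_hook; the original call is kept here unchanged.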


def similarity(x, means):
    return torch.einsum("bhld,hcd->bhlc", x, means)


def dists_and_buckets(x, means):
    dists = similarity(x, means)
    _, buckets = torch.max(dists, dim=-1)
    return dists, buckets


def batched_bincount(index, num_classes, dim=-1):
    shape = list(index.shape)
    shape[dim] = num_classes
    out = index.new_zeros(shape)
    out.scatter_add_(dim, index, torch.ones_like(index, dtype=index.dtype))
    return out


def kmeans_iter(x, means, buckets=None):
    b, h, _, d, dtype, num_clusters = *x.shape, x.dtype, means.shape[1]

    if not exists(buckets):
        _, buckets = dists_and_buckets(x, means)

    bins = batched_bincount(buckets, num_clusters).sum(0, keepdim=True)
    zero_mask = bins.long() == 0

    means_ = buckets.new_zeros(b, h, num_clusters, d, dtype=dtype)
    means_.scatter_add_(-2, expand_dim(buckets, -1, d), x)
    means_ = F.normalize(means_.sum(0, keepdim=True), dim=-1).type(dtype)

    means = torch.where(zero_mask.unsqueeze(-1), means, means_)
    means = means.squeeze(0)
    return means


def distribution(dists, window_size):
    _, topk_indices = dists.topk(k=window_size, dim=-2)
    indices = topk_indices.transpose(-2, -1)
    return indices.reshape(*indices.size()[:2], -1)
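# Routing helpers: `similarity`/`dists_and_buckets` assign each (normalized) token
# to its nearest cluster mean, `kmeans_iter` performs one hard-assignment k-means
# update of the means, and `distribution` picks the top `window_size` tokens per
# cluster so every cluster receives an equal-sized bucket, as required by
# KmeansAttention below.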


class Kmeans(nn.Module):
    def __init__(
        self, num_heads, head_dim, num_clusters, ema_decay=0.999, commitment=1e-4
    ):
        super().__init__()
        self.commitment = commitment
        self.ema_decay = ema_decay

        self.register_buffer("means", torch.randn(num_heads, num_clusters, head_dim))
        self.register_buffer("initted", torch.tensor(False))
        self.num_new_means = 0
        self.new_means = None

    @torch.no_grad()
    def init(self, x):
        if self.initted:
            return
        _, h, _, d, device, _ = *x.shape, x.device, x.dtype

        num_clusters = self.means.shape[1]

        means = x.transpose(0, 1).contiguous().view(h, -1, d)
        num_samples = means.shape[1]

        if num_samples >= num_clusters:
            indices = torch.randperm(num_samples, device=device)[:num_clusters]
        else:
            indices = torch.randint(0, num_samples, (num_clusters,), device=device)

        means = means[:, indices]

        for _ in range(KMEAN_INIT_ITERS):
            means = kmeans_iter(x, means)

        self.num_new_means = 0
        self.means.data.copy_(means)
        self.initted.data.copy_(torch.tensor(True))

    @torch.no_grad()
    def update(self, new_means=None):
        new_means = default(new_means, self.new_means)
        assert exists(new_means), "new kmeans has not been supplied"
        ema_inplace(self.means, new_means, self.ema_decay)

        del self.new_means
        self.new_means = None
        self.num_new_means = 0

    def forward(self, x, update_means=False):
        self.init(x)

        b, dtype = x.shape[0], x.dtype
        means = self.means.type(dtype)
        x = F.normalize(x, 2, dim=-1).type(dtype)

        with torch.no_grad():
            dists, buckets = dists_and_buckets(x, means)

        routed_means = batched_index_select(expand_dim(means, 0, b), buckets)
        loss = F.mse_loss(x, routed_means) * self.commitment

        if update_means:
            with torch.no_grad():
                means = kmeans_iter(x, means, buckets)
            self.new_means = ema(
                self.new_means, means, self.num_new_means / (self.num_new_means + 1)
            )
            self.num_new_means += 1

        return dists, loss
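# A minimal usage sketch of Kmeans (illustrative shapes; not from any particular
# training configuration):
#
#   kmeans = Kmeans(num_heads=8, head_dim=64, num_clusters=4)
#   x = torch.randn(2, 8, 256, 64)                  # (batch, heads, seq, head_dim)
#   dists, aux_loss = kmeans(x, update_means=True)  # dists: (2, 8, 256, 4)
#   kmeans.update()                                 # fold new means into the EMA buffer
#
# During training, `update_kmeans_on_backwards` arranges for `update()` to run
# automatically after the backward pass.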


class KmeansAttention(nn.Module):
    def __init__(
        self,
        num_clusters,
        window_size,
        num_heads,
        head_dim,
        causal=False,
        dropout=0.0,
        ema_decay=0.999,
        commitment=1e-4,
        context_window_size=None,
        receives_context=False,
        num_mem_kv=0,
        shared_qk=False,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.num_clusters = num_clusters
        self.head_dim = head_dim

        self.window_size = window_size
        self.context_window_size = default(context_window_size, window_size)
        self.causal = causal

        self.shared_qk = shared_qk
        self.receives_context = receives_context
        self.kmeans = Kmeans(num_heads, head_dim, num_clusters, ema_decay, commitment)
        self.dropout = nn.Dropout(dropout)

        self.num_mem_kv = max(num_mem_kv, 1 if causal and not shared_qk else 0)
        self.mem_key = nn.Parameter(
            torch.randn(num_heads, num_clusters, self.num_mem_kv, head_dim)
        )
        self.mem_value = nn.Parameter(
            torch.randn(num_heads, num_clusters, self.num_mem_kv, head_dim)
        )

    def forward(self, q, k, v, query_mask=None, key_mask=None, **kwargs):
        b, h, t, d, kv_t, wsz, c_wsz, nc, device, dtype = (
            *q.shape,
            k.shape[2],
            self.window_size,
            self.context_window_size,
            self.num_clusters,
            q.device,
            q.dtype,
        )
        is_reverse = kwargs.pop("_reverse", False)

        out = torch.zeros_like(q, dtype=dtype)

        update_kmeans = self.training and not is_reverse

        key_mask = (
            default(key_mask, query_mask) if not self.receives_context else key_mask
        )
        kv_wsz = wsz if not self.receives_context else c_wsz

        wsz = min(wsz, t)
        kv_wsz = min(kv_wsz, kv_t)

        if not self.shared_qk or self.receives_context:
            dists, aux_loss = self.kmeans(torch.cat((q, k), dim=2), update_kmeans)
            q_dists, k_dists = split_at_index(2, t, dists)
            indices = distribution(q_dists, wsz)
            kv_indices = distribution(k_dists, kv_wsz)
        else:
            dists, aux_loss = self.kmeans(q, update_kmeans)
            k = F.normalize(k, dim=-1).to(q)
            indices = distribution(dists, wsz)
            kv_indices = indices

        q = batched_index_select(q, indices)
        k = batched_index_select(k, kv_indices)
        v = batched_index_select(v, kv_indices)

        reshape_with_window = lambda x: x.reshape(b, h, nc, -1, d)
        q, k, v = map(reshape_with_window, (q, k, v))

        m_k, m_v = map(
            lambda x: expand_dim(x, 0, b).to(q), (self.mem_key, self.mem_value)
        )
        k, v = map(lambda x: torch.cat(x, dim=3), ((m_k, k), (m_v, v)))

        dots = torch.einsum("bhnid,bhnjd->bhnij", q, k) * (d**-0.5)

        mask_value = max_neg_value(dots)

        if exists(query_mask) or exists(key_mask):
            query_mask = default(
                query_mask, lambda: torch.ones((b, t), device=device).bool()
            )
            key_mask = default(
                key_mask, lambda: torch.ones((b, kv_t), device=device).bool()
            )

            q_mask = expand_dim(query_mask, 1, h).gather(2, indices)
            kv_mask = expand_dim(key_mask, 1, h).gather(2, kv_indices)
            q_mask, kv_mask = map(lambda t: t.reshape(b, h, nc, -1), (q_mask, kv_mask))
            mask = q_mask[:, :, :, :, None] * kv_mask[:, :, :, None, :]
            mask = F.pad(mask, (self.num_mem_kv, 0), value=1)
            dots.masked_fill_(~mask, mask_value)
            del mask

        if self.causal:
            q_mask, kv_mask = map(
                lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices)
            )
            mask = q_mask[:, :, :, :, None] >= kv_mask[:, :, :, None, :]
            mask = F.pad(mask, (self.num_mem_kv, 0), value=1)
            dots.masked_fill_(~mask, mask_value)
            del mask

        if self.shared_qk:
            q_mask, kv_mask = map(
                lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices)
            )
            mask = q_mask[:, :, :, :, None] == kv_mask[:, :, :, None, :]
            mask = F.pad(mask, (self.num_mem_kv, 0), value=0)
            dots.masked_fill_(mask, TOKEN_SELF_ATTN_VALUE)
            del mask

        dots = dots.softmax(dim=-1)
        dots = self.dropout(dots)

        bo = torch.einsum("bhcij,bhcjd->bhcid", dots, v)
        so = torch.reshape(bo, (b, h, -1, bo.shape[-1])).type(dtype)
        out = scatter_mean(out, so, indices.unsqueeze(-1).expand_as(so), -2)
        return out, aux_loss
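# A minimal usage sketch of KmeansAttention (illustrative sizes; q, k and v are
# already split into heads):
#
#   attn = KmeansAttention(num_clusters=4, window_size=64, num_heads=8, head_dim=32)
#   q = k = v = torch.randn(2, 8, 256, 32)   # (batch, heads, seq, head_dim)
#   out, aux_loss = attn(q, k, v)            # out: (2, 8, 256, 32)
#
# `aux_loss` is the k-means commitment loss; SelfAttention below adds it to the
# returned `total_loss` so it can be folded into the training objective.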


class GELU_(nn.Module):
    def forward(self, x):
        return (
            0.5
            * x
            * (
                1
                + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))
            )
        )


GELU = nn.GELU if hasattr(nn, "GELU") else GELU_


class FeedForward(nn.Module):
    def __init__(self, dim, mult=4, dropout=0.0, activation=None, glu=False):
        super().__init__()
        activation = default(activation, GELU)

        self.glu = glu
        self.w1 = nn.Linear(dim, dim * mult * (2 if glu else 1))
        self.act = activation()
        self.dropout = nn.Dropout(dropout)
        self.w2 = nn.Linear(dim * mult, dim)

    def forward(self, x, **kwargs):
        if not self.glu:
            x = self.w1(x)
            x = self.act(x)
        else:
            x, v = self.w1(x).chunk(2, dim=-1)
            x = self.act(x) * v

        x = self.dropout(x)
        x = self.w2(x)
        return x
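# FeedForward is the standard transformer MLP. With `glu=True` the first projection
# is doubled and gated (GEGLU-style): w1 produces (x, v) and the hidden activation
# becomes act(x) * v before dropout and the output projection.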


class SelfAttention(nn.Module):
    def __init__(
        self,
        dim,
        max_seq_len,
        heads,
        local_attn_heads,
        window_size,
        dim_head=None,
        local_attn_window_size=None,
        local_attn_radius_blocks=1,
        causal=False,
        attn_dropout=0.0,
        dropout=0.0,
        kmeans_ema_decay=0.999,
        commitment_factor=1e-4,
        receives_context=False,
        context_window_size=None,
        rel_pos_emb=True,
        num_mem_kv=0,
        shared_qk=False,
        conv_query_kernel=9,
    ):
        super().__init__()
        assert (
            dim_head or (dim % heads) == 0
        ), "hidden dimension must be divisible by number of heads"
        assert (
            max_seq_len % window_size
        ) == 0, "maximum sequence length must be divisible by the target window size"
        assert (
            local_attn_heads <= heads
        ), "number of local attention heads must not exceed the total number of heads"
        assert not (
            receives_context and local_attn_heads > 0
        ), "local attention cannot be used for self attention with context"
        assert not (
            receives_context and causal
        ), "contextual attention layer cannot be causal"

        local_attn_window_size = default(local_attn_window_size, window_size)
        context_window_size = default(context_window_size, window_size)

        self.shared_qk = shared_qk
        self.receives_context = receives_context
        self.heads = heads
        self.local_attn_heads = local_attn_heads
        self.global_attn_heads = heads - local_attn_heads

        self.causal = causal
        self.window_size = window_size

        dim_head = default(dim_head, dim // heads)
        dim_heads = dim_head * heads
        self.dim_head = dim_head

        num_clusters = max_seq_len // window_size

        local_dim_heads = dim_head * self.local_attn_heads

        if self.local_attn_heads > 0:
            rel_pos_emb_config = (dim_head, local_attn_heads) if rel_pos_emb else None
            self.local_attn = LocalAttention(
                dim=dim_head,
                window_size=local_attn_window_size,
                causal=causal,
                dropout=attn_dropout,
                rel_pos_emb_config=rel_pos_emb_config,
                look_backward=local_attn_radius_blocks,
                look_forward=0 if causal else local_attn_radius_blocks,
            )
            self.local_to_qkv = nn.Linear(dim, 3 * local_dim_heads)

        global_dim_heads = dim_head * self.global_attn_heads

        if self.global_attn_heads > 0:
            self.global_attn = KmeansAttention(
                num_clusters,
                window_size,
                self.global_attn_heads,
                dim_head,
                causal=causal,
                dropout=attn_dropout,
                ema_decay=kmeans_ema_decay,
                commitment=commitment_factor,
                receives_context=receives_context,
                num_mem_kv=num_mem_kv,
                shared_qk=shared_qk,
            )

        self.to_q = nn.Sequential(
            Rearrange("b n c -> b c n"),
            DepthWiseConv1d(dim, global_dim_heads, conv_query_kernel, causal=causal),
            Rearrange("b c n -> b n c"),
        )

        self.to_v = nn.Linear(dim, global_dim_heads, bias=False)

        if not self.shared_qk:
            self.to_k = nn.Linear(dim, global_dim_heads, bias=False)

        self.to_out = nn.Linear(dim_heads, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        query,
        key,
        value,
        context=None,
        key_padding_mask=None,
        context_mask=None,
        pos_emb=None,
        **kwargs
    ):
        assert not (
            self.receives_context and not exists(context)
        ), "context must be passed if self attention is set to receive context"
        input_mask = key_padding_mask
        x = query.transpose(0, 1)
        b, t, _, h, dh = *x.shape, self.heads, self.dim_head
        has_local, has_global = map(
            lambda x: x > 0, (self.local_attn_heads, self.global_attn_heads)
        )

        split_heads = (
            lambda v: reshape_dim(v, -1, (-1, dh)).transpose(1, 2).contiguous()
        )

        if has_local:
            local_qkv = self.local_to_qkv(x).chunk(3, dim=-1)
            lq, lk, lv = map(split_heads, local_qkv)

        if has_global:
            kv_input = x if not self.receives_context else context

            q, v = self.to_q(x), self.to_v(kv_input)

            if not self.shared_qk:
                k = self.to_k(kv_input)
            else:
                k = self.to_q(kv_input) if self.receives_context else q

            q, k, v = map(split_heads, (q, k, v))

        out = []
        total_loss = torch.tensor(0.0, requires_grad=True, **to(x))

        if has_local:
            local_out = self.local_attn(lq, lk, lv, input_mask=input_mask)
            out.append(local_out)

        if has_global:
            if not self.receives_context and exists(pos_emb):
                q, k = apply_rotary_pos_emb(q, k, pos_emb)

            global_out, loss = self.global_attn(
                q, k, v, query_mask=input_mask, key_mask=context_mask
            )
            total_loss = total_loss + loss

            out.append(global_out)

        out = torch.cat(out, dim=1)
        out = out.reshape(b, h, t, -1).transpose(1, 2).reshape(b, t, -1)
        # project the concatenated heads back to the model dimension, then return
        # to fairseq's (seq_len, batch, dim) layout
        out = self.to_out(out)
        out = self.dropout(out.transpose(0, 1))

        return out, total_loss
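

if __name__ == "__main__":
    # Minimal smoke test (illustrative hyperparameters only, not values from any
    # particular training configuration). Inputs follow the fairseq convention of
    # (seq_len, batch, dim); local attention heads are set to zero so only the
    # k-means routed global attention is exercised.
    attn = SelfAttention(
        dim=64,
        max_seq_len=256,
        heads=8,
        local_attn_heads=0,
        window_size=64,
    )
    x = torch.randn(256, 2, 64)
    out, aux_loss = attn(x, x, x)
    print(out.shape)        # torch.Size([256, 2, 64])
    print(aux_loss.item())  # scalar k-means commitment loss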