# pardi-speech/tts/layers/attention.py
import math
from typing import Literal

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from fla.models.utils import Cache
from torch import nn
def apply_causal_sliding_window(mask: torch.Tensor, window_size: int) -> torch.Tensor:
    """Intersect an attention mask with a causal sliding window.

    Each query may attend to itself and to the ``window_size - 1`` positions
    preceding it. When Q != KV (e.g. decoding with a KV cache), queries are
    taken to be the last Q positions along the key axis.
    """
    B, H, Q, KV = mask.shape
    device = mask.device
    offset = KV - Q  # align queries with the tail of the key axis
    q_idx = torch.arange(Q, device=device).unsqueeze(1) + offset  # (Q, 1)
    k_idx = torch.arange(KV, device=device).unsqueeze(0)  # (1, KV)
    lower_bound = q_idx - (window_size - 1)  # (Q, 1), may be negative
    allowed_2d = (k_idx <= q_idx) & (k_idx >= lower_bound)  # (Q, KV), bool
    allowed_4d = allowed_2d.unsqueeze(0).unsqueeze(0).expand(B, H, Q, KV)
    orig_dtype = mask.dtype
    mask_bool = mask if orig_dtype == torch.bool else mask.to(torch.bool)
    new_mask = mask_bool & allowed_4d
    return new_mask if orig_dtype == torch.bool else new_mask.to(orig_dtype)
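
# Usage sketch (illustrative, not part of the original module): with a
# window of 2, each query keeps itself and one predecessor.
#
#   mask = torch.ones(1, 1, 4, 4, dtype=torch.bool)
#   windowed = apply_causal_sliding_window(mask, window_size=2)
#   # windowed[0, 0]:
#   # [[1, 0, 0, 0],
#   #  [1, 1, 0, 0],
#   #  [0, 1, 1, 0],
#   #  [0, 0, 1, 1]]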
def precompute_freqs_cis_(
    t: torch.Tensor,
    n_elem: int,
    base: float = 10000,
) -> torch.Tensor:
    """Unbatched rotary frequency table for timesteps `t`; shape (T, n_elem)."""
    freqs = 1.0 / (
        base
        ** (
            torch.arange(0, n_elem, 2, device=t.device)[: (n_elem // 2)].float()
            / n_elem
        )
    )
    freqs = torch.outer(t.float(), freqs)
    # duplicate each frequency so entries pair up with rotate_half's (d 2) layout
    cache = repeat(freqs, "... d -> ... (d 2)")
    return cache
def precompute_freqs_cis(
    t: torch.Tensor,  # shape: (B, T) or (T,)
    n_elem: int,
    base: float = 10000,
) -> torch.Tensor:
    """
    Batched version of precompute_freqs_cis_.

    Args:
        t: torch.Tensor, shape (B, T) or (T,)
            Timesteps to compute frequencies for.
        n_elem: int
            Embedding dimension (must be even).
        base: float
            Base for frequency computation (default: 10000).

    Returns:
        cache: torch.Tensor, shape (B, T, n_elem). Unbatched input is
            promoted to a batch of one, so B == 1 in that case.
    """
    if t.dim() == 1:  # unbatched
        t = t.unsqueeze(0)  # (1, T)
    B, T = t.shape
    device = t.device
    # frequencies (half dimension, then expand back)
    freqs = 1.0 / (
        base
        ** (torch.arange(0, n_elem, 2, device=device)[: (n_elem // 2)].float() / n_elem)
    )  # shape: (n_elem // 2,)
    # outer product for each batch: (B, T, n_elem // 2); einsum needs float t
    freqs = torch.einsum("bt,d->btd", t.float(), freqs)
    # duplicate last dim to interleave sin/cos pairs: (B, T, n_elem)
    cache = repeat(freqs, "... d -> ... (d 2)")
    return cache
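
# Usage sketch (illustrative): shared timesteps across a batch of two.
#
#   t = torch.arange(16).unsqueeze(0).expand(2, -1)  # (2, 16)
#   freqs = precompute_freqs_cis(t, n_elem=64)       # (2, 16, 64)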
def rotate_half(x):
    """Rotate adjacent pairs (x1, x2) -> (-x2, x1) along the last dim."""
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d r -> ... (d r)")


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embeddings given a precomputed frequency table."""
    out = x * freqs_cis.cos() + rotate_half(x) * freqs_cis.sin()
    return out
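
# Usage sketch (illustrative): rotate a tensor with a matching table. The
# last dim (64) must equal n_elem used for the table.
#
#   q = torch.randn(2, 16, 64)
#   freqs = precompute_freqs_cis(torch.arange(16).unsqueeze(0).expand(2, -1), 64)
#   q_rot = apply_rotary_emb(q, freqs)  # same shape as q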
def scaled_dot_product_attention(query, key, value, mask=None):
    """Eager attention that also returns the attention weights.

    `mask` is boolean, with True marking positions that may be attended.
    """
    scale_factor = 1 / math.sqrt(query.size(-1))
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    if mask is not None:
        attn_weight.masked_fill_(~mask, -torch.finfo(attn_weight.dtype).max)
    attn_weight = torch.softmax(attn_weight, dim=-1)
    return attn_weight @ value, attn_weight
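
# Usage sketch (illustrative): a causal boolean mask on toy tensors.
#
#   q = k = v = torch.randn(1, 1, 4, 8)
#   mask = torch.ones(1, 1, 4, 4, dtype=torch.bool).tril()
#   out, weights = scaled_dot_product_attention(q, k, v, mask=mask)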
class SelfAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
layer_idx: int,
is_causal: bool = False,
sliding_window: int | None = None,
):
super().__init__()
self.qkv = nn.Linear(dim, 3 * dim)
assert dim % num_heads == 0
self.heads = num_heads
self.is_causal = is_causal
self.layer_idx = layer_idx
self.output_proj = nn.Linear(dim, dim)
        self.sliding_window = sliding_window
        if self.sliding_window is not None:
            # the sliding-window mask built in forward() already enforces causality
            self.is_causal = False
def forward(
self,
x,
freqs: torch.Tensor | None = None,
mask: torch.Tensor | None = None,
cache: Cache | None = None,
):
B, T, D = x.shape
q, k, v = self.qkv(x).chunk(3, dim=-1)
q, k, v = map(
lambda x: rearrange(x, "b n (h d) -> b h n d", h=self.heads), (q, k, v)
)
if freqs is not None:
q = apply_rotary_emb(q, freqs)
k = apply_rotary_emb(k, freqs)
        if cache is not None:
            cache.update(attn_state=(k, v), layer_idx=self.layer_idx, offset=T)
            k, v = cache[self.layer_idx]["attn_state"]
        if self.sliding_window is not None:
            # boolean mask so SDPA treats it as keep/drop rather than additive;
            # sized to the (possibly cached) key length
            mask = torch.ones(
                B, 1, T, k.shape[-2], dtype=torch.bool, device=x.device
            )
            mask = apply_causal_sliding_window(mask, self.sliding_window)
        y = F.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, is_causal=self.is_causal and T > 1
        )
y = rearrange(y, "b h n d -> b n (h d)")
y = self.output_proj(y)
return y
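
# Usage sketch (illustrative): causal self-attention, no RoPE and no cache.
#
#   attn = SelfAttention(dim=256, num_heads=4, layer_idx=0, is_causal=True)
#   y = attn(torch.randn(2, 100, 256))  # (2, 100, 256)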
class CrossAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
layer_idx: int | None = None,
dropout: float = 0.1,
):
super().__init__()
assert dim % num_heads == 0
self.pre_norm_q = nn.LayerNorm(dim)
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.layer_idx = layer_idx
self.heads = num_heads
self.dropout_att = dropout
    def _prepare_kv(self, text_hidden_states: torch.Tensor):
        # project encoder states once; CrossAttention has no k/v layer norms
        v = self.v(text_hidden_states)
        k = self.k(text_hidden_states)
        return k, v

    def _query(self, x):
        return self.q(self.pre_norm_q(x))
def forward(
self,
q: torch.Tensor,
k: torch.Tensor | None = None,
v: torch.Tensor | None = None,
mask: torch.Tensor | None = None,
output_attention: bool = False,
cache: Cache | None = None,
**kwargs,
):
        if v is None:
            v = k
        q = self.q(self.pre_norm_q(q))
        if cache is not None:
            ca_state = None
            if cache[self.layer_idx] is not None:
                ca_state = cache[self.layer_idx]["crossatt_state"]
            if ca_state is not None:
                k, v = ca_state
            else:
                # project once and memoize for subsequent decoding steps
                v = self.v(v)
                k = self.k(k)
                cache.update(crossatt_state=(k, v), layer_idx=self.layer_idx)
        else:
            v = self.v(v)
            k = self.k(k)
q, k, v = map(
lambda x: rearrange(x, "b n (h d) -> b h n d", h=self.heads), (q, k, v)
)
if mask is not None:
if mask.ndim == 3:
mask = mask[:, None]
        if not self.training:
x, att = scaled_dot_product_attention(q, k, v, mask=mask)
else:
x = nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=mask, dropout_p=self.dropout_att
)
att = None
x = rearrange(x, "b h n d -> b n (h d)")
if att is not None:
if cache is not None:
cache.update(crossatt_weights=att, layer_idx=self.layer_idx)
else:
self.att = att
return x
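
# Usage sketch (illustrative): audio-side queries attend over encoded text.
# In eval mode the attention map is kept on `xattn.att`.
#
#   xattn = CrossAttention(dim=256, num_heads=4, layer_idx=0).eval()
#   out = xattn(q=torch.randn(2, 100, 256), k=torch.randn(2, 30, 256))
#   # out: (2, 100, 256); xattn.att: (2, 4, 100, 30)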
class ConvPos(nn.Module):
def __init__(self, dim, max_seq_len=1000, kernel_size=7, n_parallel_codebook=2):
super().__init__()
self.embed = nn.Embedding(max_seq_len * n_parallel_codebook, dim)
self.dw_conv = nn.Conv1d(dim, dim, kernel_size, groups=dim, padding="same")
self.max_seq_len = max_seq_len
self.n_parallel_codebook = n_parallel_codebook
    def forward(self, x, left_shift=0, random_shift=False):
        # left_shift is accepted for interface parity but is currently unused here
if random_shift:
bias = torch.randint(
0,
self.n_parallel_codebook,
(x.shape[0],),
device=x.device,
)
x = x + bias * self.max_seq_len
y = self.embed(x)
y = rearrange(y, "b n c -> b c n")
y = self.dw_conv(y)
        y = rearrange(y, "b c n -> b n c")
return y
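
# Usage sketch (illustrative): learned positional table smoothed by a
# depthwise convolution; `random_shift` picks a per-sample codebook offset.
#
#   conv_pos = ConvPos(dim=64, max_seq_len=1000)
#   pos = torch.arange(30).unsqueeze(0)  # (1, 30) integer positions
#   emb = conv_pos(pos)                  # (1, 30, 64)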
class SinPos(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x, left_shift=0):
        # left_shift is accepted for interface parity with ConvPos (unused here)
        exp = torch.arange(self.dim // 2, device=x.device)
        exp = 2 * exp / self.dim
        exp = rearrange(exp, "e -> 1 1 e")
        x = rearrange(x, "b p -> b p 1")
        pos = x * torch.pow(10000, -exp)
        # sin(theta + pi/2) == cos(theta): first half sin, second half cos
        pos = torch.cat((pos, pos + math.pi / 2), dim=2)
        pos = torch.sin(pos)
        return pos
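
# Usage sketch (illustrative): sinusoidal table for 10 positions.
#
#   pos = torch.arange(10).unsqueeze(0)  # (1, 10)
#   emb = SinPos(dim=64)(pos)            # (1, 10, 64): sin half, cos half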
class BlindCrossAttention(nn.Module):
def __init__(
self,
q_dim,
k_dim,
att_dim,
pos_net,
dropout=0.1,
pos_dim=64,
pos_type="sinusoidal",
layer_idx: int | None = None,
):
super().__init__()
self.q = nn.Linear(q_dim, att_dim)
self.k = nn.Linear(k_dim, att_dim)
self.v = nn.Linear(k_dim, att_dim)
self.pos_net = pos_net
if pos_type == "sinusoidal":
self.pos_embed = SinPos(pos_dim)
elif pos_type == "convolutional":
self.pos_embed = ConvPos(pos_dim)
self.ln_q = nn.LayerNorm(att_dim)
self.ln_k = nn.LayerNorm(att_dim)
self.ln_v = nn.LayerNorm(att_dim)
self.dropout_att = nn.Dropout(dropout)
self.layer_idx = layer_idx
    def _prepare_kv(self, text_hidden_states: torch.Tensor):
        v = self.ln_v(self.v(text_hidden_states))
        k = self.ln_k(self.k(text_hidden_states))
        b, j, d = k.shape  # projected states are (B, J, D) at this point
        pos = torch.arange(j, device=k.device).unsqueeze(0)
        pos_emb = self.pos_embed(pos)
        return {"k": k, "v": v, "pos_emb": pos_emb}
def _query(self, x):
return self.ln_q(self.q(x))
    def forward(
        self,
        q,
        k,
        kv_cached=None,
        mask=None,
        time_step=None,
        pos=None,
        left_shift=0,
        past_key_values=None,
        cache=None,
        **kwargs,
    ):
        q = self.ln_q(self.q(q))
        if mask is not None:
            mask = mask.unsqueeze(1)
        if cache is not None:
            ca_state = None
            if cache[self.layer_idx] is not None:
                ca_state = cache[self.layer_idx]["crossatt_state"]
            if ca_state is not None:
                k, v, pos_emb = ca_state
            else:
                # project once and memoize keys, values, and positional table
                v = self.ln_v(self.v(k))
                k = self.ln_k(self.k(k))
                pos = torch.arange(k.shape[-2], device=k.device).unsqueeze(0)
                pos_emb = self.pos_embed(pos, left_shift=left_shift)
                cache.update(
                    crossatt_state=(k, v, pos_emb), layer_idx=self.layer_idx
                )
        else:
            v = self.ln_v(self.v(k))
            k = self.ln_k(self.k(k))
            if pos is None:
                pos = torch.arange(k.shape[-2], device=k.device).unsqueeze(0)
            pos_emb = self.pos_embed(pos, left_shift=left_shift)
q, k, v = map(lambda x: rearrange(x, "b n d -> b 1 n d"), (q, k, v))
if self.training:
sdpa = lambda q, k, pos: (
nn.functional.scaled_dot_product_attention(
q, k, pos, attn_mask=mask, dropout_p=self.dropout_att.p
),
None,
)
else:
sdpa = lambda q, k, pos: scaled_dot_product_attention(q, k, pos, mask=mask)
x, att1 = sdpa(q, k, pos_emb.unsqueeze(1))
x = rearrange(x, "b 1 n d -> b n d")
x = self.pos_net(x, cache=cache)
x = rearrange(x, "b n d -> b 1 n d")
pos_emb = rearrange(pos_emb, "b n d -> b 1 n d")
x, att2 = sdpa(x, pos_emb, v)
x = rearrange(x, "b 1 n d -> b n d")
self.att1 = att1
self.att2 = att2
if att2 is not None:
if cache is not None:
cache.update(
crossatt_weights=torch.cat((att1, att2), dim=1),
layer_idx=self.layer_idx,
)
return x
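
# Usage sketch (illustrative; `PosNetStub` is a hypothetical stand-in for the
# real positional network, which must map (B, N, pos_dim) -> (B, N, pos_dim)
# and accept a `cache` keyword):
#
#   class PosNetStub(nn.Module):
#       def __init__(self, pos_dim):
#           super().__init__()
#           self.proj = nn.Linear(pos_dim, pos_dim)
#       def forward(self, x, cache=None):
#           return self.proj(x)
#
#   bca = BlindCrossAttention(256, 256, 256, PosNetStub(64), pos_dim=64).eval()
#   out = bca(q=torch.randn(2, 100, 256), k=torch.randn(2, 30, 256))
#   # out: (2, 100, 256)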
class ListenReadCrossAttention(nn.Module):
def __init__(
self,
q_dim: int,
k_dim: int,
att_dim: int,
crossatt_type: Literal["listen", "read"],
num_heads: int = 1,
dropout: float = 0.1,
layer_idx: int | None = None,
):
super().__init__()
self.q = nn.Linear(q_dim, att_dim)
self.k = nn.Linear(k_dim, att_dim)
self.ln_q = nn.LayerNorm(att_dim)
self.ln_k = nn.LayerNorm(att_dim)
self.dropout_att = nn.Dropout(dropout)
self.crossatt_type = crossatt_type
self.layer_idx = layer_idx
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
text_freqs: torch.Tensor,
mask: torch.Tensor | None = None,
past_key_values=None,
cache=None,
**kwargs,
):
q = self.ln_q(self.q(q))
k = self.ln_k(self.k(k))
if mask is not None:
mask = mask.unsqueeze(1)
q, k = map(lambda x: rearrange(x, "b n d -> b 1 n d"), (q, k))
if self.training:
sdpa = lambda q, k, pos: (
nn.functional.scaled_dot_product_attention(
q, k, pos, attn_mask=mask, dropout_p=self.dropout_att.p
),
None,
)
else:
sdpa = lambda q, k, pos: scaled_dot_product_attention(q, k, pos, mask=mask)
text_freqs = rearrange(text_freqs, "b n d -> b 1 n d")
        if self.crossatt_type == "listen":
            x, att = sdpa(q, k, text_freqs)
        elif self.crossatt_type == "read":
            x, att = sdpa(q, text_freqs, k)
        else:
            raise ValueError(f"unknown crossatt_type: {self.crossatt_type!r}")
x = rearrange(x, "b 1 n d -> b n d")
if att is not None:
if cache is not None:
cache.update(
crossatt_weights=att,
layer_idx=self.layer_idx,
)
self.att = att
return x
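
# Usage sketch (illustrative): "listen" retrieves text positional frequencies
# for audio queries; "read" swaps frequencies and keys.
#
#   lr = ListenReadCrossAttention(256, 256, 256, crossatt_type="listen").eval()
#   out = lr(
#       q=torch.randn(2, 100, 256),
#       k=torch.randn(2, 30, 256),
#       text_freqs=torch.randn(2, 30, 256),
#   )  # (2, 100, 256)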