import math
import os
import time
from typing import Literal

import torch
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from torch import nn

# Per-layer cache with dict-style access, e.g. cache[layer_idx]["attn_state"].
from fla.models.utils import Cache


def apply_causal_sliding_window(mask: torch.Tensor, window_size: int) -> torch.Tensor:
    """Intersect a (B, H, Q, KV) attention mask with a causal sliding-window pattern."""
    B, H, Q, KV = mask.shape
    device = mask.device
    q_idx = torch.arange(Q, device=device).unsqueeze(1)  # (Q, 1)
    k_idx = torch.arange(KV, device=device).unsqueeze(0)  # (1, KV)
    lower_bound = q_idx - (window_size - 1)  # (Q, 1), may be negative
    allowed_2d = (k_idx <= q_idx) & (k_idx >= lower_bound)  # (Q, KV), dtype=torch.bool
    allowed_4d = allowed_2d.unsqueeze(0).unsqueeze(0).expand(B, H, Q, KV)
    orig_dtype = mask.dtype
    if mask.dtype != torch.bool:
        mask_bool = mask.to(torch.bool)
    else:
        mask_bool = mask
    new_mask = mask_bool & allowed_4d
    if orig_dtype != torch.bool:
        return new_mask.to(orig_dtype)
    else:
        return new_mask
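

def _example_sliding_window_mask() -> torch.Tensor:
    # Illustrative sketch, not part of the original module: what the helper produces
    # for a short sequence. With Q = KV = 4 and window_size = 2 the result keeps, for
    # every query, only itself and the single previous position:
    #   [[1, 0, 0, 0],
    #    [1, 1, 0, 0],
    #    [0, 1, 1, 0],
    #    [0, 0, 1, 1]]
    full = torch.ones(1, 1, 4, 4, dtype=torch.bool)
    return apply_causal_sliding_window(full, window_size=2)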


def precompute_freqs_cis_(
    t: torch.Tensor,
    n_elem: int,
    base: float = 10000,
) -> torch.Tensor:
    freqs = 1.0 / (
        base
        ** (
            torch.arange(0, n_elem, 2, device=t.device)[: (n_elem // 2)].float()
            / n_elem
        )
    )
    freqs = torch.outer(t, freqs)
    cache = repeat(freqs, "... d -> ... (d 2)")
    return cache


def precompute_freqs_cis(
    t: torch.Tensor,  # shape: (B, T) or (T,)
    n_elem: int,
    base: float = 10000,
) -> torch.Tensor:
    """
    Batched version of precompute_freqs_cis_.

    Args:
        t: torch.Tensor, shape (B, T) or (T,)
            Timesteps to compute frequencies for (expected to be floating point).
        n_elem: int
            Embedding dimension (must be even).
        base: float
            Base for frequency computation (default: 10000).

    Returns:
        cache: torch.Tensor, shape (B, T, n_elem) if batched,
            (1, T, n_elem) if unbatched.
    """
    if t.dim() == 1:  # unbatched
        t = t.unsqueeze(0)  # (1, T)
    B, T = t.shape
    device = t.device
    # frequencies (half dimension, then expand back)
    freqs = 1.0 / (
        base
        ** (torch.arange(0, n_elem, 2, device=device)[: (n_elem // 2)].float() / n_elem)
    )  # shape: (n_elem // 2,)
    # outer product for each batch
    # (B, T, n_elem // 2)
    freqs = torch.einsum("bt,d->btd", t, freqs)
    # duplicate last dim to interleave sin/cos pairs
    # (B, T, n_elem)
    cache = repeat(freqs, "... d -> ... (d 2)")
    # if cache.shape[0] == 1:  # if originally unbatched
    #     cache = cache.squeeze(0)  # (T, n_elem)
    return cache


def rotate_half(x):
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d r -> ... (d r)")


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    out = x * freqs_cis.cos() + rotate_half(x) * freqs_cis.sin()
    return out
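

def _example_rotary_embedding() -> None:
    # Illustrative sketch, not part of the original module: wiring precompute_freqs_cis
    # together with apply_rotary_emb on a (B, H, T, D) tensor. The extra unsqueeze(1)
    # broadcasts the (B, T, D) frequency cache over the head dimension.
    B, H, T, D = 2, 4, 8, 16
    t = torch.arange(T, dtype=torch.float32).unsqueeze(0)  # (1, T) timesteps
    freqs = precompute_freqs_cis(t, D).unsqueeze(1)        # (1, 1, T, D)
    x = torch.randn(B, H, T, D)
    y = apply_rotary_emb(x, freqs)
    # Rotations are norm-preserving on each (even, odd) channel pair.
    assert torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-5)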


def scaled_dot_product_attention(query, key, value, mask=None):
    # Eager attention that also returns the attention weights (used at inference time,
    # when the weights are stored for inspection).
    scale_factor = 1 / math.sqrt(query.size(-1))
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    if mask is not None:
        attn_weight.masked_fill_(~mask, -torch.finfo(attn_weight.dtype).max)
    attn_weight = torch.softmax(attn_weight, dim=-1)
    return attn_weight @ value, attn_weight


class SelfAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        layer_idx: int,
        is_causal: bool = False,
        sliding_window: int | None = None,
    ):
        super().__init__()
        self.qkv = nn.Linear(dim, 3 * dim)
        assert dim % num_heads == 0
        self.heads = num_heads
        self.is_causal = is_causal
        self.layer_idx = layer_idx
        self.output_proj = nn.Linear(dim, dim)
        self.sliding_window = sliding_window
        if self.sliding_window is not None:
            # Causality is enforced through the sliding-window mask instead of the
            # is_causal flag of F.scaled_dot_product_attention.
            self.is_causal = False

    def forward(
        self,
        x,
        freqs: torch.Tensor | None = None,
        mask: torch.Tensor | None = None,
        cache: Cache | None = None,
    ):
        B, T, D = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = map(
            lambda x: rearrange(x, "b n (h d) -> b h n d", h=self.heads), (q, k, v)
        )
        if freqs is not None:
            q = apply_rotary_emb(q, freqs)
            k = apply_rotary_emb(k, freqs)
        if cache is not None:
            cache.update(attn_state=(k, v), layer_idx=self.layer_idx, offset=T)
            k, v = cache[self.layer_idx]["attn_state"]
        if self.sliding_window is not None:
            # Boolean mask so SDPA treats it as keep/drop rather than as an additive
            # bias; built for the full-sequence (non-cached) path.
            mask = torch.ones(B, 1, T, T, dtype=torch.bool, device=x.device)
            mask = apply_causal_sliding_window(mask, self.sliding_window)
        y = F.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, is_causal=self.is_causal and T > 1
        )
        y = rearrange(y, "b h n d -> b n (h d)")
        y = self.output_proj(y)
        return y
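

def _example_self_attention() -> torch.Tensor:
    # Illustrative sketch, not part of the original module: a causal SelfAttention layer
    # with rotary embeddings applied to queries and keys. Sizes below are arbitrary.
    dim, heads, B, T = 64, 4, 2, 10
    attn = SelfAttention(dim=dim, num_heads=heads, layer_idx=0, is_causal=True)
    x = torch.randn(B, T, dim)
    t = torch.arange(T, dtype=torch.float32).unsqueeze(0)       # (1, T)
    freqs = precompute_freqs_cis(t, dim // heads).unsqueeze(1)  # (1, 1, T, head_dim)
    return attn(x, freqs=freqs)                                 # (B, T, dim)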


class CrossAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        layer_idx: int | None = None,
        dropout: float = 0.1,
    ):
        super().__init__()
        assert dim % num_heads == 0
        self.pre_norm_q = nn.LayerNorm(dim)
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.layer_idx = layer_idx
        self.heads = num_heads
        self.dropout_att = dropout

    def _prepare_kv(self, text_hidden_states: torch.Tensor):
        # Project the encoder states once so they can be cached across decoding steps
        # (mirrors BlindCrossAttention._prepare_kv).
        k = self.k(text_hidden_states)
        v = self.v(text_hidden_states)
        return {"k": k, "v": v}

    def _query(self, x):
        return self.q(self.pre_norm_q(x))

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor | None = None,
        v: torch.Tensor | None = None,
        mask: torch.Tensor | None = None,
        output_attention: bool = False,
        cache: Cache | None = None,
        **kwargs,
    ):
        if v is None:
            v = k
        q = self.q(self.pre_norm_q(q))
        if cache is not None:
            ca_state = None
            if cache[self.layer_idx] is not None:
                ca_state = cache[self.layer_idx]["crossatt_state"]
            if ca_state is not None:
                k, v = ca_state
            else:
                v = self.v(v)
                k = self.k(k)
                cache.update(crossatt_state=(k, v), layer_idx=self.layer_idx)
        else:
            v = self.v(v)
            k = self.k(k)
        q, k, v = map(
            lambda x: rearrange(x, "b n (h d) -> b h n d", h=self.heads), (q, k, v)
        )
        if mask is not None:
            if mask.ndim == 3:
                mask = mask[:, None]
        if not self.training:
            x, att = scaled_dot_product_attention(q, k, v, mask=mask)
        else:
            x = nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=mask, dropout_p=self.dropout_att
            )
            att = None
        x = rearrange(x, "b h n d -> b n (h d)")
        if att is not None:
            if cache is not None:
                cache.update(crossatt_weights=att, layer_idx=self.layer_idx)
            else:
                self.att = att
        return x
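

def _example_cross_attention() -> torch.Tensor:
    # Illustrative sketch, not part of the original module: decoder-side queries attending
    # over encoder states (e.g. text). eval() selects the eager path, which also stores the
    # attention weights on the module as `att`.
    xattn = CrossAttention(dim=64, num_heads=4, layer_idx=0).eval()
    queries = torch.randn(2, 10, 64)  # e.g. decoder states
    context = torch.randn(2, 7, 64)   # e.g. text encoder states
    return xattn(queries, k=context)  # (2, 10, 64); xattn.att holds (2, 4, 10, 7)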


class ConvPos(nn.Module):
    def __init__(self, dim, max_seq_len=1000, kernel_size=7, n_parallel_codebook=2):
        super().__init__()
        self.embed = nn.Embedding(max_seq_len * n_parallel_codebook, dim)
        self.dw_conv = nn.Conv1d(dim, dim, kernel_size, groups=dim, padding="same")
        self.max_seq_len = max_seq_len
        self.n_parallel_codebook = n_parallel_codebook

    def forward(self, x, left_shift=0, random_shift=False):
        # left_pad = 31 if left_shift > 0 else 0
        # x = torch.cat((torch.arange(left_shift - left_pad, left_shift).to(x).unsqueeze(0), x, torch.arange(31).to(x).unsqueeze(0)), dim=1).clamp_min_(0)
        if random_shift:
            # Pick one codebook offset per batch element so the same position can map
            # to n_parallel_codebook different embedding rows.
            bias = torch.randint(
                0,
                self.n_parallel_codebook,
                (x.shape[0],),
                device=x.device,
            )
            x = x + bias.unsqueeze(1) * self.max_seq_len
        y = self.embed(x)
        y = rearrange(y, "b n c -> b c n")
        y = self.dw_conv(y)
        y = rearrange(y, "b c n -> b n c")  # [:, left_pad:-31]
        return y


class SinPos(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x, left_shift=0, **kwargs):
        # left_shift is accepted for interface parity with ConvPos and ignored here.
        exp = torch.arange(self.dim // 2, device=x.device)
        exp = 2 * exp / self.dim
        exp = rearrange(exp, "e -> 1 1 e")
        x = rearrange(x, "b p -> b p 1")
        pos = x * torch.pow(10000, -exp)
        # Concatenating a quarter-period phase shift yields [sin, cos] features.
        pos = torch.cat((pos, pos + math.pi / 2), dim=2)
        pos = torch.sin(pos)
        return pos
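

def _example_positional_encoders() -> tuple[torch.Size, torch.Size]:
    # Illustrative sketch, not part of the original module: both encoders map integer
    # positions of shape (B, N) to embeddings of shape (B, N, dim). Sizes are arbitrary.
    pos = torch.arange(12).unsqueeze(0)        # (1, 12)
    sin_pos = SinPos(dim=64)
    conv_pos = ConvPos(dim=64, max_seq_len=100)
    y_sin = sin_pos(pos)                       # (1, 12, 64), fixed sinusoids
    y_conv = conv_pos(pos, random_shift=True)  # (1, 12, 64), learned embedding + depthwise conv
    return y_sin.shape, y_conv.shape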


class BlindCrossAttention(nn.Module):
    def __init__(
        self,
        q_dim,
        k_dim,
        att_dim,
        pos_net,
        dropout=0.1,
        pos_dim=64,
        pos_type="sinusoidal",
        layer_idx: int | None = None,
    ):
        super().__init__()
        self.q = nn.Linear(q_dim, att_dim)
        self.k = nn.Linear(k_dim, att_dim)
        self.v = nn.Linear(k_dim, att_dim)
        self.pos_net = pos_net
        if pos_type == "sinusoidal":
            self.pos_embed = SinPos(pos_dim)
        elif pos_type == "convolutional":
            self.pos_embed = ConvPos(pos_dim)
        self.ln_q = nn.LayerNorm(att_dim)
        self.ln_k = nn.LayerNorm(att_dim)
        self.ln_v = nn.LayerNorm(att_dim)
        self.dropout_att = nn.Dropout(dropout)
        self.layer_idx = layer_idx

    def _prepare_kv(self, text_hidden_states: torch.Tensor):
        v = self.ln_v(self.v(text_hidden_states))
        k = self.ln_k(self.k(text_hidden_states))
        b, j, d = k.shape
        pos = torch.arange(j, device=k.device).unsqueeze(0)
        pos_emb = self.pos_embed(pos)
        return {"k": k, "v": v, "pos_emb": pos_emb}

    def _query(self, x):
        return self.ln_q(self.q(x))

    def forward(
        self,
        q,
        k,
        kv_cached=None,
        mask=None,
        time_step=None,
        pos=None,
        left_shift=0,
        past_key_values=None,
        cache=None,
        **kwargs,
    ):
        q = self.ln_q(self.q(q))
        if mask is not None:
            mask = mask.unsqueeze(1)
        if cache is not None:
            ca_state = None
            if cache[self.layer_idx] is not None:
                ca_state = cache[self.layer_idx]["crossatt_state"]
            if ca_state is not None:
                k, v, pos_emb = ca_state
            else:
                v = self.ln_v(self.v(k))
                k = self.ln_k(self.k(k))
                pos = torch.arange(k.shape[-2], device=k.device).unsqueeze(0)
                pos_emb = self.pos_embed(pos, left_shift=left_shift)
                cache.update(
                    crossatt_state=(k, v, pos_emb), layer_idx=self.layer_idx
                )
        else:
            v = self.ln_v(self.v(k))
            k = self.ln_k(self.k(k))
            if pos is None:
                pos = torch.arange(k.shape[-2], device=k.device).unsqueeze(0)
            pos_emb = self.pos_embed(pos, left_shift=left_shift)
        q, k, v = map(lambda x: rearrange(x, "b n d -> b 1 n d"), (q, k, v))
        b, h, j, d = k.shape
        if self.training:
            sdpa = lambda q, k, pos: (
                nn.functional.scaled_dot_product_attention(
                    q, k, pos, attn_mask=mask, dropout_p=self.dropout_att.p
                ),
                None,
            )
        else:
            sdpa = lambda q, k, pos: scaled_dot_product_attention(q, k, pos, mask=mask)
        # Stage 1: queries attend over the text keys and retrieve positional codes.
        x, att1 = sdpa(q, k, pos_emb.unsqueeze(1))
        x = rearrange(x, "b 1 n d -> b n d")
        x = self.pos_net(x, cache=cache)
        x = rearrange(x, "b n d -> b 1 n d")
        pos_emb = rearrange(pos_emb, "b n d -> b 1 n d")
        # Stage 2: the positional codes act as keys to gather the text values.
        x, att2 = sdpa(x, pos_emb, v)
        x = rearrange(x, "b 1 n d -> b n d")
        self.att1 = att1
        self.att2 = att2
        if att2 is not None:
            if cache is not None:
                cache.update(
                    crossatt_weights=torch.cat((att1, att2), dim=1),
                    layer_idx=self.layer_idx,
                )
        return x
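

class _LinearPosNet(nn.Module):
    # Hypothetical stand-in for the pos_net argument (the real network is defined
    # elsewhere in the project); it only needs to map (B, N, pos_dim) -> (B, N, pos_dim)
    # and accept a `cache` keyword.
    def __init__(self, pos_dim: int):
        super().__init__()
        self.proj = nn.Linear(pos_dim, pos_dim)

    def forward(self, x, cache=None):
        return self.proj(x)


def _example_blind_cross_attention() -> torch.Tensor:
    # Illustrative sketch, not part of the original module: queries attend over the text
    # keys to fetch positional codes (stage 1), the codes pass through pos_net, then serve
    # as keys to gather the text values (stage 2). Sizes are arbitrary; eval() selects the
    # eager path so the attention maps are stored as att1/att2.
    xattn = BlindCrossAttention(
        q_dim=64, k_dim=48, att_dim=64, pos_net=_LinearPosNet(32), pos_dim=32
    ).eval()
    queries = torch.randn(2, 10, 64)
    text = torch.randn(2, 7, 48)
    return xattn(queries, text)  # (2, 10, 64)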


class ListenReadCrossAttention(nn.Module):
    def __init__(
        self,
        q_dim: int,
        k_dim: int,
        att_dim: int,
        crossatt_type: Literal["listen", "read"],
        num_heads: int = 1,
        dropout: float = 0.1,
        layer_idx: int | None = None,
    ):
        super().__init__()
        self.q = nn.Linear(q_dim, att_dim)
        self.k = nn.Linear(k_dim, att_dim)
        self.ln_q = nn.LayerNorm(att_dim)
        self.ln_k = nn.LayerNorm(att_dim)
        self.dropout_att = nn.Dropout(dropout)
        self.crossatt_type = crossatt_type
        self.layer_idx = layer_idx

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        text_freqs: torch.Tensor,
        mask: torch.Tensor | None = None,
        past_key_values=None,
        cache=None,
        **kwargs,
    ):
        q = self.ln_q(self.q(q))
        k = self.ln_k(self.k(k))
        if mask is not None:
            mask = mask.unsqueeze(1)
        q, k = map(lambda x: rearrange(x, "b n d -> b 1 n d"), (q, k))
        if self.training:
            sdpa = lambda q, k, pos: (
                nn.functional.scaled_dot_product_attention(
                    q, k, pos, attn_mask=mask, dropout_p=self.dropout_att.p
                ),
                None,
            )
        else:
            sdpa = lambda q, k, pos: scaled_dot_product_attention(q, k, pos, mask=mask)
        text_freqs = rearrange(text_freqs, "b n d -> b 1 n d")
        if self.crossatt_type == "listen":
            # "listen": queries attend over the text keys and return positional codes.
            x, att = sdpa(q, k, text_freqs)
        elif self.crossatt_type == "read":
            # "read": positional codes act as keys and the projected text as values.
            x, att = sdpa(q, text_freqs, k)
        else:
            raise ValueError(f"unknown crossatt_type: {self.crossatt_type}")
        x = rearrange(x, "b 1 n d -> b n d")
        if att is not None:
            if cache is not None:
                cache.update(
                    crossatt_weights=att,
                    layer_idx=self.layer_idx,
                )
            self.att = att
        return x
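

def _example_listen_read_attention() -> torch.Tensor:
    # Illustrative sketch, not part of the original module: in "listen" mode the queries
    # attend over the projected text states and return a mixture of the per-token
    # positional codes (text_freqs). The tensors and sizes below are arbitrary stand-ins.
    listen = ListenReadCrossAttention(
        q_dim=64, k_dim=48, att_dim=32, crossatt_type="listen"
    ).eval()
    queries = torch.randn(2, 10, 64)
    text = torch.randn(2, 7, 48)
    text_freqs = torch.randn(2, 7, 32)        # stand-in positional codes, one per text token
    return listen(queries, text, text_freqs)  # (2, 10, 32); listen.att holds (2, 1, 10, 7)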