from contextlib import contextmanager
from typing import Dict, Iterable, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn

from torch.nn.functional import scaled_dot_product_attention

SDPA_AVAILABLE = True

class LayerNorm(nn.LayerNorm):
    def forward(self, x: Tensor) -> Tensor:
        return super().forward(x.float()).type(x.dtype)


class Linear(nn.Linear):
    def forward(self, x: Tensor) -> Tensor:
        return F.linear(
            x,
            self.weight.to(x.dtype),
            None if self.bias is None else self.bias.to(x.dtype),
        )


class Conv1d(nn.Conv1d):
    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor:
        return super()._conv_forward(
            x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
        )


def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding"""
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
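

# Illustrative note (not from the original source): sinusoids(1500, 512) returns a
# (1500, 512) tensor whose first 256 columns are sines and last 256 columns are
# cosines, over timescales spaced geometrically from 1 up to max_timescale.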


@contextmanager
def disable_sdpa():
    prev_state = MultiHeadAttention.use_sdpa
    try:
        MultiHeadAttention.use_sdpa = False
        yield
    finally:
        MultiHeadAttention.use_sdpa = prev_state
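

# Intended usage (illustrative sketch, not part of the original file): temporarily
# turn off the fused SDPA kernel, e.g. when attention needs to be inspected:
#
#     with disable_sdpa():
#         logits = model(mel, mel_len, token, token_len)
#
# Note that qkv_attention below always calls scaled_dot_product_attention, so the
# flag currently only serves as a hook for alternative attention implementations.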


class MultiHeadAttention(nn.Module):
    use_sdpa = True

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
        causal: Optional[bool] = None,
    ):
        q = self.query(x)

        if kv_cache is None or xa is None or self.key not in kv_cache:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # for cross-attention, keys and values are computed once and then reused
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        wv = self.qkv_attention(q, k, v, mask, causal)
        return self.out(wv)

    def qkv_attention(
        self,
        q: Tensor,
        k: Tensor,
        v: Tensor,
        mask: Optional[Tensor] = None,
        causal: Optional[bool] = None,
    ) -> Tensor:
        n_ctx = q.shape[1]
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        # `mask` is a boolean padding mask over the key positions, shape (n_batch, n_key)
        attn_mask = mask[:, None, None, :] if mask is not None else None
        is_causal = bool(causal) and n_ctx > 1
        if is_causal and attn_mask is not None:
            # SDPA does not accept both `attn_mask` and `is_causal=True`,
            # so fold the causal constraint into the padding mask instead
            causal_mask = torch.tril(
                torch.ones(n_ctx, k.shape[2], dtype=torch.bool, device=q.device)
            )
            attn_mask = attn_mask & causal_mask
            is_causal = False

        a = scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=is_causal)
        out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
        return out


class ResidualAttentionBlock(nn.Module):
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        self.cross_attn = (
            MultiHeadAttention(n_state, n_head) if cross_attention else None
        )
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
        causal: Optional[bool] = None,
    ):
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache, causal=causal)
        if self.cross_attn:
            # cross-attention over the encoded audio is never causal
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache, causal=False)
        x = x + self.mlp(self.mlp_ln(x))
        return x


class AudioEncoder(nn.Module):
    def __init__(
        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = LayerNorm(n_state)

    def forward(self, x: Tensor, attn_mask: Optional[Tensor]):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio
        attn_mask : torch.Tensor, shape = (batch_size, n_ctx // 2), optional
            boolean mask marking the valid (non-padded) frames after downsampling
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)

        # only the first x.size(1) positional embeddings are needed after the stride-2 conv
        x = (x + self.positional_embedding[: x.size(1)]).to(x.dtype)

        for block in self.blocks:
            x = block(x, mask=attn_mask, causal=False)

        x = self.ln_post(x)
        return x
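

# Shape sketch (illustrative, with the sizes used by the Whisper class below): an
# (n_batch, 80, 3000) mel input passes through two convolutions, the second with
# stride 2, giving (n_batch, 1500, 512) encoder states for the decoder to cross-attend to.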


class TextDecoder(nn.Module):
    def __init__(
        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [
                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
                for _ in range(n_layer)
            ]
        )
        self.ln = LayerNorm(n_state)

        self.out_proj = nn.Linear(n_state, n_vocab)

    def forward(self, x: Tensor, attn_mask: Optional[Tensor], xa: Tensor, kv_cache: Optional[dict] = None):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        attn_mask : torch.Tensor, shape = (batch_size, <= n_ctx), optional
            boolean padding mask over the text tokens
        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
            the encoded audio features to be attended on
        """
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        x = (
            self.token_embedding(x)
            + self.positional_embedding[offset : offset + x.shape[-1]]
        )
        x = x.to(xa.dtype)

        for block in self.blocks:
            x = block(x, xa, mask=attn_mask, kv_cache=kv_cache, causal=True)

        x = self.ln(x)
        # unlike the original Whisper decoder, which ties the output projection to the
        # token embedding, this variant uses a separate learned projection
        logits = self.out_proj(x)

        return logits
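

# Note on incremental decoding (illustrative): when a kv_cache dict is supplied, `offset`
# above equals the number of positions already cached, so only the newly appended token(s)
# need to be passed in x; their positional embeddings are taken from that offset onward.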


class Whisper(nn.Module):
    def __init__(self):
        super().__init__()
        self.n_vocab = 6800
        self.n_text_layer = 6
        self.n_text_head = 8
        self.n_text_ctx = 2048

        self.encoder = AudioEncoder(
            n_mels=80, n_ctx=3000, n_state=512, n_head=8, n_layer=6,
        )
        self.decoder = TextDecoder(
            n_vocab=6800, n_ctx=2048, n_state=512, n_head=8, n_layer=6,
        )

    def embed_audio(self, mel: torch.Tensor):
        return self.encoder(mel, None)

    def logits(self, tokens, audio_features, kv_cache=None):
        return self.decoder(tokens, None, audio_features, kv_cache=kv_cache)

    def forward(
        self, mel, mel_len, token, token_len
    ) -> torch.Tensor:
        # the encoder halves the time axis with its stride-2 conv, hence mel_len // 2
        attn_mask_enc = self.sequence_mask(mel_len // 2, device=mel.device) > 0
        attn_mask_dec = self.sequence_mask(token_len, device=mel.device) > 0
        return self.decoder(token, attn_mask_dec, self.encoder(mel, attn_mask_enc))

    @property
    def device(self):
        return next(self.parameters()).device

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
        tensors calculated for the previous positions. This method returns a dictionary that stores
        all caches, and the necessary hooks for the key and value projection modules that save the
        intermediate tensors to be reused during later calculations.

        Returns
        -------
        cache : Dict[nn.Module, torch.Tensor]
            A dictionary object mapping the key/value projection modules to their cache
        hooks : List[RemovableHandle]
            List of PyTorch RemovableHandle objects to stop the hooks being called
        """
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if module not in cache or output.shape[1] > self.n_text_ctx:
                # save as-is, for the first token or cross attention
                cache[module] = output
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks
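
    # Illustrative sketch (not part of the original API): a greedy incremental decoding
    # loop using the cache could look roughly like this, assuming `sot_tokens` is a
    # (batch, n_prompt) LongTensor of start-of-transcript prompt tokens:
    #
    #     cache, hooks = model.install_kv_cache_hooks()
    #     audio = model.embed_audio(mel)                  # (batch, n_audio_ctx, 512)
    #     tokens = sot_tokens
    #     for i in range(max_new_tokens):
    #         inp = tokens if i == 0 else tokens[:, -1:]  # full prompt once, then one token
    #         logits = model.logits(inp, audio, kv_cache=cache)
    #         next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
    #         tokens = torch.cat([tokens, next_token], dim=-1)
    #     for hook in hooks:
    #         hook.remove()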

    def sequence_mask(self, seq_lens, max_len=None, device='cpu'):
        """Returns a (batch, max_len) float mask: 1.0 where t < seq_lens[b], else 0.0."""
        b = seq_lens.shape[0]
        if max_len is None:
            max_len = seq_lens.max()
        mask = torch.arange(max_len).unsqueeze(0).to(device)
        mask = mask < (seq_lens.unsqueeze(1))
        mask = mask.float()
        return mask
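

if __name__ == "__main__":
    # Minimal smoke test with random inputs (illustrative only; the shapes follow the
    # hyper-parameters hard-coded in Whisper.__init__, not any released checkpoint).
    model = Whisper().eval()
    mel = torch.randn(2, 80, 3000)                    # (batch, n_mels, n_frames)
    mel_len = torch.tensor([3000, 2400])              # valid frames per example
    token = torch.randint(0, model.n_vocab, (2, 12))  # dummy text tokens
    token_len = torch.tensor([12, 9])                 # valid tokens per example
    with torch.no_grad():
        logits = model(mel, mel_len, token, token_len)
    print(logits.shape)                               # expected: torch.Size([2, 12, 6800])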