| import base64 |
| import gzip |
| from contextlib import contextmanager |
| from dataclasses import dataclass |
| from typing import Dict, Iterable, Optional, Tuple |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from torch import Tensor, nn |
|
|
# Optional fused attention kernel: record whether it could be imported so the
# attention code can fall back to an explicit implementation when it cannot.
try:
    from torch.nn.functional import scaled_dot_product_attention
except (ImportError, RuntimeError, OSError):
    scaled_dot_product_attention = None
    SDPA_AVAILABLE = False
else:
    SDPA_AVAILABLE = True
|
|
|
|
@dataclass
class ModelDimensions:
    """Plain record of the integer sizes used to build the model.

    The generated ``__init__`` takes the fields positionally in the order
    listed below, so that order must stay stable.
    """

    # NOTE(review): by naming, n_mels / n_audio_* presumably size the audio
    # encoder and n_vocab / n_text_* the text side; confirm against the model
    # constructor, which is not part of this chunk.
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int
|
|
|
|
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` that normalizes in float32 and returns the input dtype."""

    def forward(self, x: Tensor) -> Tensor:
        input_dtype = x.dtype
        # Run the normalization itself at float32 precision.
        normalized = super().forward(x.float())
        return normalized.type(input_dtype)
|
|
|
|
class Linear(nn.Linear):
    """``nn.Linear`` whose weight and bias are cast to the input's dtype per call."""

    def forward(self, x: Tensor) -> Tensor:
        target_dtype = x.dtype
        weight = self.weight.to(target_dtype)
        bias = self.bias
        if bias is not None:
            bias = bias.to(target_dtype)
        return F.linear(x, weight, bias)
|
|
|
|
class Conv1d(nn.Conv1d):
    """``nn.Conv1d`` that casts its weight and bias to the input's dtype."""

    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor:
        dtype = x.dtype
        cast_weight = weight.to(dtype)
        cast_bias = bias if bias is None else bias.to(dtype)
        return super()._conv_forward(x, cast_weight, cast_bias)
|
|
|
|
def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding.

    Args:
        length: number of positions (rows of the result).
        channels: embedding width; must be even. The first ``channels // 2``
            columns hold sines and the remaining ``channels // 2`` hold cosines.
        max_timescale: largest timescale; the ``channels // 2`` inverse
            timescales are spaced geometrically from 1 down to
            ``1 / max_timescale``.

    Returns:
        Tensor of shape ``(length, channels)``.
    """
    assert channels % 2 == 0, "channels must be even"
    half = channels // 2
    # Geometric spacing across `half` frequencies. With half == 1 the divisor
    # `half - 1` used to be 0, giving an infinite increment and then
    # -inf * 0 = NaN everywhere. A single frequency just uses timescale 1.
    log_timescale_increment = np.log(max_timescale) / max(half - 1, 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(half))
    # Outer product: position index times each inverse timescale.
    scaled_time = torch.arange(length)[:, None] * inv_timescales[None, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
|
|
|
@contextmanager
def disable_sdpa():
    """Temporarily force ``MultiHeadAttention`` onto its explicit attention path.

    The class-level ``use_sdpa`` flag is set to False for the duration of the
    ``with`` body and restored to its previous value afterwards, even if the
    body raises. Because the flag is class-wide, this affects every instance.
    """
    saved_flag = MultiHeadAttention.use_sdpa
    MultiHeadAttention.use_sdpa = False
    try:
        yield
    finally:
        MultiHeadAttention.use_sdpa = saved_flag
|
|
|
|
class MultiHeadAttention(nn.Module):
    """Multi-head attention supporting self- and cross-attention.

    The class attribute ``use_sdpa`` is read on every call (through the class,
    not the instance). When it is truthy and the fused
    ``scaled_dot_product_attention`` import succeeded, the fused kernel is used
    and no attention logits are returned. Otherwise attention is computed
    explicitly and the float32 logits are returned alongside the output.
    """

    use_sdpa = True

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        # Independent projections; only the key projection omits a bias.
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        """Attend from ``x`` to itself, or to ``xa`` when ``xa`` is given.

        Args:
            x: query-side input; also the key/value source when ``xa`` is None.
            xa: optional key/value source (cross-attention).
            mask: optional mask forwarded to ``qkv_attention``.
            kv_cache: optional dict keyed by the ``key`` / ``value`` submodules.
                Only consulted for cross-attention: if ``self.key`` is already
                present, ``kv_cache[self.key]`` / ``kv_cache[self.value]`` are
                reused instead of projecting ``xa`` again.

        Returns:
            ``(output, qk)``; ``qk`` is None when the fused SDPA path runs.
        """
        q = self.query(x)

        use_cached = (
            kv_cache is not None and xa is not None and self.key in kv_cache
        )
        if use_cached:
            k = kv_cache[self.key]
            v = kv_cache[self.value]
        else:
            # NOTE(review): kv_cache is never written in this class; presumably
            # it is filled elsewhere (e.g. hooks on self.key / self.value).
            source = x if xa is None else xa
            k = self.key(source)
            v = self.value(source)

        wv, qk = self.qkv_attention(q, k, v, mask)
        return self.out(wv), qk

    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Scaled dot-product attention split across ``self.n_head`` heads.

        ``q`` must be 3-D; its last dimension is divided evenly among heads.
        On the explicit path ``mask[:n_ctx, :n_ctx]`` is added to the logits.
        On the fused path only the mask's presence matters: it enables
        ``is_causal`` when there is more than one query position.
        NOTE(review): the fused path therefore assumes the mask is causal.

        Returns:
            ``(out, qk)``: ``out`` has the heads merged back into the last
            dimension; ``qk`` is the detached float32 logits, or None on the
            fused path.
        """
        _, n_ctx, n_state = q.shape

        def split_heads(t: Tensor) -> Tensor:
            # (dim0, dim1, n_state) -> (dim0, n_head, dim1, head_dim)
            return t.view(*t.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        q, k, v = (split_heads(t) for t in (q, k, v))

        if not (SDPA_AVAILABLE and MultiHeadAttention.use_sdpa):
            # Scaling q and k by head_dim ** -0.25 each yields the usual
            # 1/sqrt(head_dim) on their product.
            scale = (n_state // self.n_head) ** -0.25
            logits = (q * scale) @ (k * scale).transpose(-1, -2)
            if mask is not None:
                logits = logits + mask[:n_ctx, :n_ctx]
            # Softmax in float32, then return to the activation dtype.
            logits = logits.float()
            weights = F.softmax(logits, dim=-1).to(q.dtype)
            merged = (weights @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
            return merged, logits.detach()

        attended = scaled_dot_product_attention(
            q, k, v, is_causal=mask is not None and n_ctx > 1
        )
        return attended.permute(0, 2, 1, 3).flatten(start_dim=2), None
|
|
|
|
class ResidualAttentionBlock(nn.Module):
    """Pre-norm residual block: self-attention, optional cross-attention, MLP.

    Each sub-layer receives a LayerNorm'd copy of the running activations, and
    its output is added back residually.
    """

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        # Cross-attention modules exist only when requested; otherwise None.
        self.cross_attn = (
            MultiHeadAttention(n_state, n_head) if cross_attention else None
        )
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        # Feed-forward sub-layer with a 4x hidden expansion.
        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        """Apply the block.

        Args:
            x: input activations.
            xa: key/value source for cross-attention; only used when the block
                was built with ``cross_attention=True``.
            mask: passed to self-attention only.
            kv_cache: passed to both attention sub-layers.

        Returns:
            Activations with the same shape as ``x``.
        """
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
        # Explicit None check rather than relying on nn.Module truthiness.
        if self.cross_attn is not None:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
        x = x + self.mlp(self.mlp_ln(x))
        return x
|
|
|
|
class AudioEncoder(nn.Module):
    """Convolutional stem followed by a stack of self-attention blocks.

    The stride-2 second convolution halves the time axis. The result must line
    up with the fixed ``(n_ctx, n_state)`` positional embedding.
    """

    def __init__(
        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        # kernel 3 / stride 2 / padding 1: out_len = (in_len - 1) // 2 + 1.
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        # Registered as a buffer (not a Parameter): shape (n_ctx, n_state).
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = LayerNorm(n_state)

    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_frames)
            the mel spectrogram of the audio. Because conv2 has stride 2,
            n_frames must be 2 * n_ctx (or 2 * n_ctx - 1) so that the
            downsampled length equals n_ctx, the positional embedding length.

        Returns
        -------
        torch.Tensor, shape = (batch_size, n_ctx, n_state)
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        # (batch, n_state, time) -> (batch, time, n_state)
        x = x.permute(0, 2, 1)

        assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
        # Keep the activation dtype even if the buffer's dtype differs.
        x = (x + self.positional_embedding).to(x.dtype)

        for block in self.blocks:
            x = block(x)

        x = self.ln_post(x)
        return x