import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from typing import Optional

from comfy.ldm.modules.attention import optimized_attention_masked
import comfy.ops


class WhisperFeatureExtractor(nn.Module):
    """Converts raw audio into the log-mel spectrogram Whisper expects."""

    def __init__(self, n_mels=128, device=None):
        super().__init__()
        self.sample_rate = 16000
        self.n_fft = 400
        self.hop_length = 160
        self.n_mels = n_mels
        self.chunk_length = 30  # seconds per chunk
        self.n_samples = 480000  # 30 s at 16 kHz
        self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sample_rate,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            n_mels=self.n_mels,
            f_min=0,
            f_max=8000,
            norm="slaney",
            mel_scale="slaney",
        ).to(device)

    def __call__(self, audio):
        # Downmix [batch, channels, samples] to mono.
        audio = torch.mean(audio, dim=1)
        batch_size = audio.shape[0]

        # Truncate or zero-pad each clip to exactly 30 s (480000 samples).
        processed_audio = []
        for i in range(batch_size):
            aud = audio[i]
            if aud.shape[0] > self.n_samples:
                aud = aud[:self.n_samples]
            elif aud.shape[0] < self.n_samples:
                aud = F.pad(aud, (0, self.n_samples - aud.shape[0]))
            processed_audio.append(aud)
        audio = torch.stack(processed_audio)

        # Compute mels on the transform's device, dropping the trailing frame
        # so the output has exactly n_samples / hop_length frames.
        mel_spec = self.mel_spectrogram(audio.to(self.mel_spectrogram.spectrogram.window.device))[:, :, :-1].to(audio.device)

        # Whisper's log-mel normalization: log10, clamp values to within
        # 8 (log10 units) of the peak, then rescale to roughly [-1, 1].
        log_mel_spec = torch.clamp(mel_spec, min=1e-10).log10()
        log_mel_spec = torch.maximum(log_mel_spec, log_mel_spec.max() - 8.0)
        log_mel_spec = (log_mel_spec + 4.0) / 4.0
        return log_mel_spec
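

# Shape sketch (illustrative, not from the source): inputs are padded or
# truncated to 30 s, so the frame count is fixed regardless of clip length.
#
#   fe = WhisperFeatureExtractor(n_mels=128)
#   mels = fe(torch.randn(1, 2, 16000 * 5))  # 5 s of stereo audio at 16 kHz
#   mels.shape  # -> torch.Size([1, 128, 3000]); 480000 samples / 160 hop = 3000 frames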


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, dtype=None, device=None, operations=None):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.q_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
        # Whisper's key projection has no bias.
        self.k_proj = operations.Linear(d_model, d_model, bias=False, dtype=dtype, device=device)
        self.v_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
        self.out_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)
        # optimized_attention_masked takes [batch, seq, d_model] tensors and
        # handles the head split/merge internally.
        attn_output = optimized_attention_masked(q, k, v, self.n_heads, mask)
        attn_output = self.out_proj(attn_output)
        return attn_output
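

# For reference, a plain-PyTorch equivalent of the fused attention call above.
# A sketch under the assumption that optimized_attention_masked splits heads
# internally and returns [batch, seq, d_model]; not used by the model itself.
def _reference_attention(q, k, v, n_heads, mask=None):
    b, t, d = q.shape
    # Split [b, t, d] into [b, n_heads, t, d_k] for per-head attention.
    q, k, v = (x.view(b, -1, n_heads, d // n_heads).transpose(1, 2) for x in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    # Merge heads back to [b, t, d].
    return out.transpose(1, 2).reshape(b, t, d)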


class EncoderLayer(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dtype=dtype, device=device, operations=operations)
        self.self_attn_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)
        self.fc1 = operations.Linear(d_model, d_ff, dtype=dtype, device=device)
        self.fc2 = operations.Linear(d_ff, d_model, dtype=dtype, device=device)
        self.final_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # Pre-norm self-attention block.
        residual = x
        x = self.self_attn_layer_norm(x)
        x = self.self_attn(x, x, x, attention_mask)
        x = residual + x

        # Pre-norm feed-forward block with GELU.
        residual = x
        x = self.final_layer_norm(x)
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        x = residual + x
        return x
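

# Design note: these are pre-LayerNorm transformer blocks (norm before each
# sublayer, residual after), and AudioEncoder below instantiates them with
# d_ff = 4 * d_model, i.e. 5120 for the 1280-wide large-v3 configuration.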


class AudioEncoder(nn.Module):
    def __init__(
        self,
        n_mels: int = 128,
        n_ctx: int = 1500,
        n_state: int = 1280,
        n_head: int = 20,
        n_layer: int = 32,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.conv1 = operations.Conv1d(n_mels, n_state, kernel_size=3, padding=1, dtype=dtype, device=device)
        # Stride-2 conv halves the mel frame count (3000 -> 1500 for 30 s input).
        self.conv2 = operations.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1, dtype=dtype, device=device)
        self.embed_positions = operations.Embedding(n_ctx, n_state, dtype=dtype, device=device)
        self.layers = nn.ModuleList([
            EncoderLayer(n_state, n_head, n_state * 4, dtype=dtype, device=device, operations=operations)
            for _ in range(n_layer)
        ])
        self.layer_norm = operations.LayerNorm(n_state, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor):
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.transpose(1, 2)  # [batch, n_state, seq] -> [batch, seq, n_state]
        # Slice the position table along the context axis; the original code
        # indexed the feature axis, which only worked because seq_len == n_ctx.
        x = x + comfy.ops.cast_to_input(self.embed_positions.weight[:x.shape[1]], x)
        # Collect the hidden state fed into every layer, plus the final output.
        all_x = ()
        for layer in self.layers:
            all_x += (x,)
            x = layer(x)
        x = self.layer_norm(x)
        all_x += (x,)
        return x, all_x
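

# Shape sketch (illustrative): a [batch, 128, 3000] mel input becomes
# [batch, 1500, 1280] after the stride-2 conv stack, and all_x carries
# n_layer + 1 = 33 tensors: the input to each encoder layer plus the
# final normalized output.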


class WhisperLargeV3(nn.Module):
    def __init__(
        self,
        n_mels: int = 128,
        n_audio_ctx: int = 1500,
        n_audio_state: int = 1280,
        n_audio_head: int = 20,
        n_audio_layer: int = 32,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.feature_extractor = WhisperFeatureExtractor(n_mels=n_mels, device=device)
        self.encoder = AudioEncoder(
            n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer,
            dtype=dtype, device=device, operations=operations
        )

    def forward(self, audio):
        mel = self.feature_extractor(audio)
        x, all_x = self.encoder(mel)
        return x, all_x
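

# A minimal smoke test (a sketch: using comfy.ops.disable_weight_init as the
# `operations` provider is an assumption here, mirroring how other ComfyUI
# models are constructed; it is not part of this module).
if __name__ == "__main__":
    model = WhisperLargeV3(dtype=torch.float32, device="cpu", operations=comfy.ops.disable_weight_init)
    audio = torch.randn(1, 2, 16000 * 3)  # 3 s of stereo audio at 16 kHz
    x, all_x = model(audio)
    print(x.shape)     # expected: torch.Size([1, 1500, 1280])
    print(len(all_x))  # expected: 33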