|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
class WhisperConfig:
    """Hyper-parameters for the Whisper-style model and its audio front-end."""

    def __init__(self):
        # Audio front-end / spectrogram parameters.
        self.sampling_rate = 16000
        self.n_fft = 400
        self.hop_length = 160
        self.n_mels = 80
        # Transformer dimensions.
        self.d_model = 384
        self.n_heads = 6
        self.n_layers = 4
        self.vocab_size = 1000
        # Tokenizer instance; must be assigned before calling
        # WhisperModel.transcribe(), which reads config.tokenizer.
        # (Previously undefined, so transcribe() raised AttributeError.)
        self.tokenizer = None
|
|
class SimpleTokenizer:
    """Minimal character-level tokenizer with <pad>/<s>/</s>/<unk> specials.

    Unknown characters encode to <unk>; decode skips <pad>/<s>/</s> but
    renders <unk> literally (matching the original behavior).
    """

    def __init__(self):
        self.token_to_id = {}
        self.id_to_token = {}
        self.special_tokens = {
            "<pad>": 0,
            "<s>": 1,
            "</s>": 2,
            "<unk>": 3,
        }

        # Seed both lookup tables with the special tokens.
        for token, idx in self.special_tokens.items():
            self.token_to_id[token] = idx
            self.id_to_token[idx] = token

        self.next_id = len(self.special_tokens)

    def load_vocab(self, vocab_file):
        """Replace the vocabulary with a token->id mapping loaded from a JSON file.

        NOTE(review): this overwrites the table wholesale — the JSON file is
        assumed to include the special tokens at their reserved ids; confirm.
        """
        import json
        with open(vocab_file, 'r', encoding='utf-8') as f:
            self.token_to_id = json.load(f)

        # JSON keys are strings; ids may arrive as strings too, so coerce.
        self.id_to_token = {int(v): k for k, v in self.token_to_id.items()}
        self.next_id = max(self.id_to_token) + 1

    def encode(self, text):
        """Encode text as [<s>] + one id per character + [</s>]."""
        if not isinstance(text, str):
            text = str(text)

        unk = self.special_tokens["<unk>"]
        ids = [self.special_tokens["<s>"]]
        # dict.get avoids the double lookup of `in` + indexing.
        ids.extend(self.token_to_id.get(char, unk) for char in text)
        ids.append(self.special_tokens["</s>"])
        return ids

    def decode(self, ids):
        """Decode a sequence of ids back to text.

        Skips <pad>/<s>/</s>; ids missing from the vocabulary render as the
        <unk> token string. Accepts plain ints or numpy/tensor integer scalars.
        """
        skip = {
            self.special_tokens["<pad>"],
            self.special_tokens["<s>"],
            self.special_tokens["</s>"],
        }
        unk_text = self.id_to_token[self.special_tokens["<unk>"]]

        pieces = []
        for token_id in ids:  # renamed: `id` shadowed the builtin
            token_id = int(token_id)  # tolerate numpy int64 etc. from generate()
            if token_id in skip:
                continue
            pieces.append(self.id_to_token.get(token_id, unk_text))

        # str.join instead of repeated += (quadratic in the worst case).
        return "".join(pieces)
|
|
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding, added to the input embeddings.

    Even feature indices carry sin, odd indices cos, at geometrically
    spaced frequencies (the classic "Attention Is All You Need" table).
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        import math

        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Frequency term: exp(-ln(10000) * (2i / d_model)) for even indices 2i.
        freqs = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)
        table[:, 1::2] = torch.cos(positions * freqs)

        # Register as a buffer (not a parameter): moves with .to()/.cuda(),
        # saved in state_dict, never trained. Shape (1, max_len, d_model).
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        """Add the first seq_len rows of the table to x (batch, seq_len, d_model)."""
        return x + self.pe[:, :x.size(1)]
|
|
class EncoderBlock(nn.Module):
    """One transformer encoder layer: self-attention then a position-wise
    feed-forward network, each wrapped in residual-add + LayerNorm (post-norm).
    """

    def __init__(self, d_model, n_heads, d_ff=2048, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run the layer on x (batch, seq, d_model).

        `mask` is forwarded as a key-padding mask to the attention module.
        """
        attended, _ = self.self_attn(x, x, x, key_padding_mask=mask)
        x = self.norm1(x + self.dropout(attended))
        x = self.norm2(x + self.dropout(self.ff(x)))
        return x
|
|
class DecoderBlock(nn.Module):
    """One transformer decoder layer: masked self-attention, cross-attention
    over the encoder output, then a feed-forward network — each followed by
    residual-add + LayerNorm (post-norm ordering).
    """

    def __init__(self, d_model, n_heads, d_ff=2048, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.cross_attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, tgt_mask=None, src_mask=None):
        """Decode one layer.

        tgt_mask: attention mask for self-attention (e.g. a causal mask).
        src_mask: key-padding mask applied to enc_output in cross-attention.
        """
        causal_out, _ = self.self_attn(x, x, x, attn_mask=tgt_mask)
        x = self.norm1(x + self.dropout(causal_out))

        cross_out, _ = self.cross_attn(
            x, enc_output, enc_output, key_padding_mask=src_mask
        )
        x = self.norm2(x + self.dropout(cross_out))

        x = self.norm3(x + self.dropout(self.ff(x)))
        return x
|
|
class AudioEncoder(nn.Module):
    """Audio encoder: a conv front-end (three stride-2 convs, so ~8x temporal
    downsampling) followed by a stack of transformer encoder layers.
    """

    def __init__(self, config):
        super().__init__()
        width = config.d_model

        # conv1 lifts n_mels channels to d_model; conv2-4 each halve the time axis.
        self.conv1 = nn.Conv1d(config.n_mels, width, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv1d(width, width, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv1d(width, width, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv1d(width, width, kernel_size=3, stride=2, padding=1)

        self.norm = nn.LayerNorm(width)
        self.pos_encoder = PositionalEncoding(width)

        self.layers = nn.ModuleList(
            EncoderBlock(width, config.n_heads, width * 4)
            for _ in range(config.n_layers)
        )

        self.dropout = nn.Dropout(0.1)  # NOTE(review): never applied in forward()

    def forward(self, x):
        """Encode mel features x (batch, n_mels, time) to (batch, frames, d_model)."""
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.gelu(conv(x))

        # Swap to (batch, frames, d_model) for LayerNorm / attention.
        x = self.pos_encoder(self.norm(x.transpose(1, 2)))

        for block in self.layers:
            x = block(x)

        return x
|
|
class TextDecoder(nn.Module):
    """Autoregressive transformer decoder: embeds token ids, runs decoder
    layers that cross-attend to the encoder output, and projects to logits.
    """

    def __init__(self, config):
        super().__init__()
        width = config.d_model

        self.token_embedding = nn.Embedding(config.vocab_size, width)
        self.pos_encoder = PositionalEncoding(width)

        self.layers = nn.ModuleList(
            DecoderBlock(width, config.n_heads, width * 4)
            for _ in range(config.n_layers)
        )

        self.output_projection = nn.Linear(width, config.vocab_size)
        self.dropout = nn.Dropout(0.1)  # NOTE(review): never applied in forward()

    def forward(self, x, encoder_output, tgt_mask=None):
        """Map token ids x (batch, seq) to logits (batch, seq, vocab_size)."""
        hidden = self.pos_encoder(self.token_embedding(x))

        for block in self.layers:
            hidden = block(hidden, encoder_output, tgt_mask=tgt_mask)

        return self.output_projection(hidden)
|
|
class WhisperModel(nn.Module):
    """Whisper-style encoder-decoder speech-to-text model.

    The encoder consumes log-mel spectrograms; the decoder emits token ids,
    either teacher-forced (forward) or greedily autoregressive (generate).
    """

    def __init__(self, config):
        super().__init__()
        self.encoder = AudioEncoder(config)
        self.decoder = TextDecoder(config)
        self.config = config

    def _create_causal_mask(self, size):
        """Boolean (size, size) mask, True strictly above the diagonal
        (future positions are blocked in self-attention)."""
        mask = torch.triu(torch.ones(size, size), diagonal=1).bool()
        return mask.to(next(self.parameters()).device)

    def forward(self, audio_features, token_ids, attention_mask=None):
        """Teacher-forced forward pass.

        Args:
            audio_features: mel-spectrogram input fed to the encoder.
            token_ids: (batch, seq) decoder input token ids.
            attention_mask: accepted for API compatibility; currently unused.

        Returns:
            (batch, seq, vocab_size) logits from the decoder.
        """
        encoder_output = self.encoder(audio_features)
        causal_mask = self._create_causal_mask(token_ids.size(1))
        return self.decoder(token_ids, encoder_output, tgt_mask=causal_mask)

    def generate(self, audio_features, tokenizer, max_len=100):
        """Greedy autoregressive decoding.

        Returns (batch, <=max_len) token ids, starting with <s>; stops early
        once every sequence in the batch has emitted </s>.
        """
        device = next(self.parameters()).device

        # Inference only: run the encoder under no_grad too (the original
        # built an autograd graph for the encoder pass unnecessarily).
        with torch.no_grad():
            encoder_output = self.encoder(audio_features)

            # Every sequence starts with the <s> token.
            curr_tokens = torch.full(
                (audio_features.size(0), 1),
                tokenizer.special_tokens["<s>"],
                dtype=torch.long,
                device=device,
            )

            for _ in range(max_len - 1):
                causal_mask = self._create_causal_mask(curr_tokens.size(1))
                output = self.decoder(curr_tokens, encoder_output, tgt_mask=causal_mask)

                # Greedy pick from the last position's logits.
                next_token = torch.argmax(output[:, -1, :], dim=-1, keepdim=True)
                curr_tokens = torch.cat([curr_tokens, next_token], dim=1)

                if (next_token == tokenizer.special_tokens["</s>"]).all():
                    break

        return curr_tokens

    def _compute_log_mel(self, audio):
        """Mean/std-normalized log-mel spectrogram of `audio`.

        Falls back to zero features when torchaudio is not installed
        (preserving the original best-effort behavior).
        """
        try:
            import torchaudio

            if not isinstance(audio, torch.Tensor):
                audio = torch.from_numpy(audio)

            mel_spec = torchaudio.transforms.MelSpectrogram(
                sample_rate=self.config.sampling_rate,
                n_fft=self.config.n_fft,
                hop_length=self.config.hop_length,
                n_mels=self.config.n_mels,
            )(audio)

            log_mel_spec = torch.log(mel_spec + 1e-9)  # epsilon avoids log(0)

            # Global (not per-feature) mean/variance normalization.
            mean = log_mel_spec.mean()
            std = log_mel_spec.std()
            return (log_mel_spec - mean) / (std + 1e-9)

        except ImportError:
            print("torchaudio not available. Using dummy features.")
            return torch.zeros(1, self.config.n_mels, 100)

    def transcribe(self, audio, beam_size=5):
        """Transcribe `audio` (file path, waveform array, or sequence) to text.

        `beam_size` is accepted for interface compatibility but unused —
        decoding is greedy via generate(). Requires self.config.tokenizer
        to be set to a tokenizer instance.

        Returns (segments, info): a one-element list of objects with a
        .text attribute, and an info dict.
        """
        import numpy as np

        # Load and resample from a file path; fall back to 1s of silence.
        if isinstance(audio, str):
            try:
                from pydub import AudioSegment
                audio_seg = AudioSegment.from_file(audio)
                audio_seg = audio_seg.set_channels(1).set_frame_rate(16000)
                audio = np.array(audio_seg.get_array_of_samples()).astype(np.float32) / 32768.0
            except Exception:  # was a bare except: don't swallow KeyboardInterrupt/SystemExit
                print("Error loading audio file. Using dummy audio.")
                audio = np.zeros(16000, dtype=np.float32)

        if not isinstance(audio, np.ndarray):
            audio = np.array(audio, dtype=np.float32)

        # Ensure a leading channel dimension: (1, samples).
        if len(audio.shape) == 1:
            audio = audio.reshape(1, -1)

        log_mel_spec = self._compute_log_mel(audio)

        # Normalize rank to (batch, n_mels, frames) for the encoder.
        if log_mel_spec.dim() == 2:
            log_mel_spec = log_mel_spec.unsqueeze(0)
        elif log_mel_spec.dim() == 4:
            log_mel_spec = log_mel_spec.squeeze(0)

        log_mel_spec = log_mel_spec.to(next(self.parameters()).device)

        with torch.no_grad():
            generated = self.generate(log_mel_spec, self.config.tokenizer)

        transcription = self.config.tokenizer.decode(generated[0].cpu().numpy())

        # Lightweight result holder so callers can read segment.text.
        class Segment:
            def __init__(self, text):
                self.text = text

        segments = [Segment(transcription)]
        # NOTE(review): language is hard-coded; confirm "mn" is intended.
        info = {"language": "mn"}

        return segments, info
|
|