| |
| import torch |
| import torch.nn as nn |
| import torchaudio |
| import torchaudio.transforms as T |
| from torchaudio.models import Conformer |
| from allosaurus.audio import Audio |
|
|
class ConformerAcousticModel(nn.Module):
    """Frame-level acoustic model: linear front-end -> Conformer encoder -> phoneme logits.

    Input features of shape (batch, time, input_dim) are projected to
    ``d_model``, encoded by a torchaudio ``Conformer``, and mapped to
    per-frame scores over ``num_phonemes`` classes.
    """

    def __init__(self, input_dim: int, num_phonemes: int, d_model: int,
                 ffn_dim: int = 2560, num_heads: int = 4, num_layers: int = 8,
                 depthwise_conv_kernel_size: int = 31, dropout: float = 0.1):
        super().__init__()
        # Front-end: lift raw features into the encoder's model dimension,
        # normalize, and regularize. Built as a list so the Sequential keeps
        # the same child indices (0/1/2) as before — state dicts stay compatible.
        front_end = [
            nn.Linear(input_dim, d_model),
            nn.LayerNorm(d_model),
            nn.Dropout(dropout),
        ]
        self.input_projection = nn.Sequential(*front_end)
        self.conformer = Conformer(
            input_dim=d_model,
            num_heads=num_heads,
            ffn_dim=ffn_dim,
            num_layers=num_layers,
            depthwise_conv_kernel_size=depthwise_conv_kernel_size,
            dropout=dropout,
        )
        # Per-frame classifier over the phoneme inventory.
        self.output_projection = nn.Linear(d_model, num_phonemes)

    def forward(self, features: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """Return logits of shape (batch, time, num_phonemes).

        ``lengths`` holds the valid number of frames per batch element and is
        forwarded to the Conformer for padding-aware attention; the updated
        lengths it returns are discarded here.
        """
        projected = self.input_projection(features)
        encoded, _ = self.conformer(projected, lengths)
        return self.output_projection(encoded)
|
|
class UpgradedRecognizer:
    """End-to-end phoneme recognizer wiring a feature front-end (pm), a
    Conformer acoustic model (am), and a decoder/language module (lm).

    The acoustic model is moved to ``device`` and switched to eval mode once
    at construction; ``recognize`` then runs pure inference.
    """

    def __init__(self, pm_module, am_module, lm_module, device):
        self.pm = pm_module
        self.am = am_module
        self.lm = lm_module
        self.device = device
        self.am.to(self.device)
        self.am.eval()

    def recognize(self, audio_path: str) -> str:
        """Transcribe the audio file at ``audio_path`` into a space-joined
        string of phoneme symbols.

        Raises whatever ``torchaudio.load`` raises for unreadable files.
        """
        waveform, sr = torchaudio.load(audio_path)
        # Normalize to the 16 kHz mono signal the pipeline expects.
        if sr != 16000:
            resampler = T.Resample(sr, 16000).to(waveform.device)
            waveform = resampler(waveform)
        if waveform.shape[0] > 1:
            # Downmix multi-channel audio by averaging channels.
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        audio_object = Audio(waveform.squeeze().cpu().numpy(), 16000)

        # pm.compute presumably returns a (time, feat_dim) numpy array — the
        # unsqueeze adds the batch axis. BUG FIX: pin dtype to float32; numpy
        # front-ends often emit float64, which would make the float32 model
        # weights raise a dtype mismatch on the first Linear layer.
        features = self.pm.compute(audio_object)
        features_tensor = torch.as_tensor(
            features, dtype=torch.float32
        ).unsqueeze(0).to(self.device)
        lengths_tensor = torch.tensor([features_tensor.shape[1]], device=self.device)

        with torch.no_grad():
            logits = self.am(features_tensor, lengths_tensor)

        # Drop the batch axis and hand raw per-frame logits to the decoder.
        logits_numpy = logits.squeeze(0).cpu().numpy()
        phoneme_list = self.lm.compute(logits_numpy, lang_id='ipa', topk=1)

        return " ".join(phoneme_list)