# Allofomer_ru / modeling.py
# Committed by RicardoQi
# Last commit: "Update modeling.py" (1f29f06, verified)
# Auto-generated to contain necessary class definitions for loading the recognizer.
import torch
import torch.nn as nn
import torchaudio
import torchaudio.transforms as T
from torchaudio.models import Conformer
from allosaurus.audio import Audio
class ConformerAcousticModel(nn.Module):
    """Frame-level phoneme scorer built around a torchaudio Conformer encoder.

    Pipeline: linear projection of the input features to ``d_model`` followed by
    LayerNorm and dropout, then a ``Conformer`` stack, then a linear layer that
    maps every frame to ``num_phonemes`` scores.

    Args:
        input_dim: Size of the last dimension of the incoming features.
        num_phonemes: Number of output classes produced per frame.
        d_model: Hidden width fed to and produced by the Conformer.
        ffn_dim: Hidden width of the Conformer feed-forward sublayers.
        num_heads: Attention heads per Conformer layer.
        num_layers: Number of Conformer layers.
        depthwise_conv_kernel_size: Kernel size of the Conformer depthwise conv.
        dropout: Dropout probability for the input projection and the Conformer.
    """

    def __init__(
        self,
        input_dim: int,
        num_phonemes: int,
        d_model: int,
        ffn_dim: int = 2560,
        num_heads: int = 4,
        num_layers: int = 8,
        depthwise_conv_kernel_size: int = 31,
        dropout: float = 0.1,
    ):
        super().__init__()
        # The attribute names below become state_dict key prefixes. Keep them
        # stable so that previously saved checkpoints still load.
        front_end_layers = [
            nn.Linear(input_dim, d_model),
            nn.LayerNorm(d_model),
            nn.Dropout(dropout),
        ]
        self.input_projection = nn.Sequential(*front_end_layers)
        self.conformer = Conformer(
            input_dim=d_model,
            num_heads=num_heads,
            ffn_dim=ffn_dim,
            num_layers=num_layers,
            depthwise_conv_kernel_size=depthwise_conv_kernel_size,
            dropout=dropout,
        )
        self.output_projection = nn.Linear(d_model, num_phonemes)

    def forward(self, features: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """Score every frame of ``features``.

        Args:
            features: Input features whose last dimension equals ``input_dim``.
                Presumably (batch, time, input_dim), as torchaudio's Conformer
                expects; confirm this with the callers.
            lengths: Number of valid frames for each batch element, passed
                straight to the Conformer.

        Returns:
            Per-frame logits whose last dimension is ``num_phonemes``. The
            lengths returned by the Conformer are discarded.
        """
        projected = self.input_projection(features)
        encoded, _encoded_lengths = self.conformer(projected, lengths)
        return self.output_projection(encoded)
class UpgradedRecognizer:
    """Phoneme recognizer that chains a feature extractor, an acoustic model and a decoder.

    Args:
        pm_module: Feature extractor. ``recognize`` calls its ``compute(audio)``
            with an ``allosaurus.audio.Audio`` object.
        am_module: Acoustic model (an ``nn.Module`` such as
            ``ConformerAcousticModel``), called as ``am(features, lengths)``.
            It is moved to ``device`` and switched to eval mode here.
        lm_module: Decoder. ``recognize`` calls its
            ``compute(logits, lang_id='ipa', topk=1)``.
        device: Device holding the acoustic model and its inputs.
    """

    # Sample rate (Hz) that audio is resampled to before feature extraction.
    target_sample_rate = 16000

    def __init__(self, pm_module, am_module, lm_module, device):
        self.pm = pm_module
        self.am = am_module
        self.lm = lm_module
        self.device = device
        # Move the model once up front and disable training-only behaviour
        # (e.g. dropout) for inference.
        self.am.to(self.device)
        self.am.eval()

    def _feature_dtype(self) -> torch.dtype:
        """Return the dtype of the AM's first floating-point parameter, or float32 if there is none."""
        for param in self.am.parameters():
            if param.is_floating_point():
                return param.dtype
        return torch.float32

    def recognize(self, audio_path: str) -> str:
        """Transcribe an audio file into a phoneme string.

        The audio is loaded with ``torchaudio``, resampled to
        ``target_sample_rate`` if needed, and downmixed to mono. It is then
        turned into features by ``self.pm``, scored by ``self.am`` and decoded
        by ``self.lm``.

        Args:
            audio_path: Path of an audio file readable by ``torchaudio.load``.

        Returns:
            Space-separated phonemes. If the decoder already returns a string,
            that string is returned unchanged. Otherwise its items are joined
            with single spaces.
        """
        waveform, sr = torchaudio.load(audio_path)
        target_sr = self.target_sample_rate
        if sr != target_sr:
            resampler = T.Resample(sr, target_sr).to(waveform.device)
            waveform = resampler(waveform)
        # Downmix multi-channel audio to mono by averaging the channels.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        # The waveform is (1, num_samples) at this point. Indexing the channel
        # (rather than squeeze()) keeps a 1-sample clip 1-D instead of 0-d.
        # NOTE(review): torchaudio.load yields floating-point samples. Confirm
        # that pm_module expects this scale rather than raw int16 values.
        audio_object = Audio(waveform[0].cpu().numpy(), target_sr)
        features = self.pm.compute(audio_object)
        # Cast to the model's dtype so that, e.g., float64 features do not
        # clash with float32 weights. as_tensor also accepts tensors.
        features_tensor = torch.as_tensor(
            features, dtype=self._feature_dtype(), device=self.device
        ).unsqueeze(0)
        lengths_tensor = torch.tensor([features_tensor.shape[1]], device=self.device)
        with torch.no_grad():
            logits = self.am(features_tensor, lengths_tensor)
        logits_numpy = logits.squeeze(0).cpu().numpy()
        phonemes = self.lm.compute(logits_numpy, lang_id='ipa', topk=1)
        # An allosaurus-style decoder may already return a space-separated
        # string. " ".join() on a str would put a space between every character.
        if isinstance(phonemes, str):
            return phonemes
        return " ".join(phonemes)