RicardoQi committed on
Commit
1befda4
·
verified ·
1 Parent(s): 9b320f8

Upload distilled conformer recognizer object and loader script

Browse files
Files changed (3) hide show
  1. .gitattributes +1 -0
  2. modeling.py +62 -0
  3. recognizer.dill +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ recognizer.dill filter=lfs diff=lfs merge=lfs -text
modeling.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Auto-generated to contain necessary class definitions for loading the recognizer.
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchaudio
5
+ import torchaudio.transforms as T
6
+ from torchaudio.models import Conformer
7
+ from allosaurus.audio import Audio
8
+
9
+ class ConformerAcousticModel(nn.Module):
10
+ def __init__(self, input_dim: int, num_phonemes: int, d_model: int, ffn_dim: int = 2560, num_heads: int = 4, num_layers: int = 8, depthwise_conv_kernel_size: int = 31, dropout: float = 0.1):
11
+ super().__init__()
12
+ self.input_projection = nn.Sequential(
13
+ nn.Linear(input_dim, d_model),
14
+ nn.LayerNorm(d_model),
15
+ nn.Dropout(dropout)
16
+ )
17
+ self.conformer = Conformer(
18
+ input_dim=d_model,
19
+ num_heads=num_heads,
20
+ ffn_dim=ffn_dim,
21
+ num_layers=num_layers,
22
+ depthwise_conv_kernel_size=depthwise_conv_kernel_size,
23
+ dropout=dropout
24
+ )
25
+ self.output_projection = nn.Linear(d_model, num_phonemes)
26
+
27
+ def forward(self, features: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
28
+ x = self.input_projection(features)
29
+ x, _ = self.conformer(x, lengths)
30
+ logits = self.output_projection(x)
31
+ return logits
32
+
33
+ class UpgradedRecognizer:
34
+ def __init__(self, pm_module, am_module, lm_module, device):
35
+ self.pm = pm_module
36
+ self.am = am_module
37
+ self.lm = lm_module
38
+ self.device = device
39
+ self.am.to(self.device)
40
+ self.am.eval()
41
+
42
+ def recognize(self, audio_path: str) -> str:
43
+ waveform, sr = torchaudio.load(audio_path)
44
+ if sr != 16000:
45
+ resampler = T.Resample(sr, 16000).to(waveform.device)
46
+ waveform = resampler(waveform)
47
+ if waveform.shape[0] > 1:
48
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
49
+
50
+ audio_object = Audio(waveform.squeeze().cpu().numpy(), 16000)
51
+
52
+ features = self.pm.compute(audio_object)
53
+ features_tensor = torch.tensor(features).unsqueeze(0).to(self.device)
54
+ lengths_tensor = torch.tensor([features_tensor.shape[1]], device=self.device)
55
+
56
+ with torch.no_grad():
57
+ logits = self.am(features_tensor, lengths_tensor)
58
+
59
+ logits_numpy = logits.squeeze(0).cpu().numpy()
60
+ phoneme_list = self.lm.compute(logits_numpy, lang_id='ipa', topk=1)
61
+
62
+ return " ".join(phoneme_list)
recognizer.dill ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8b7fe6beee958af085db2ceb6fe6c30b1666a48ca8b710ef6bcdbea2e20faf5a
3
+ size 190360911