# inoryQwQ's picture
# Optimize RTF and cer
# d56de90
import math
import os
import kaldiio
import kaldi_native_fbank as knf
import numpy as np
import torch
class ASRFeatExtractor:
    """Turn waveforms into padded filterbank feature batches.

    Features come from ``KaldifeatFbank`` (``num_mel_bins=80``,
    ``frame_length=25``, ``frame_shift=10``, ``dither=0.0``) and are
    optionally normalized with ``CMVN``.
    """

    def __init__(self, kaldi_cmvn_file):
        """Build the filterbank front end and, optionally, the CMVN step.

        Args:
            kaldi_cmvn_file: Path to a Kaldi CMVN statistics file. An empty
                string (or ``None``) disables normalization.
        """
        self.cmvn = CMVN(kaldi_cmvn_file) if kaldi_cmvn_file else None
        self.fbank = KaldifeatFbank(num_mel_bins=80, frame_length=25,
                                    frame_shift=10, dither=0.0)

    def __call__(self, wav_paths):
        """Extract features for a batch of audio files.

        Args:
            wav_paths: Iterable of paths that ``kaldiio.load_mat`` can read
                as ``(sample_rate, samples)``.

        Returns:
            Tuple ``(feats_pad, lengths, durs)``:
            ``feats_pad`` is the float tensor produced by ``pad_feat`` with
            padding value ``0.0``; ``lengths`` is a long tensor holding the
            number of frames of each utterance; ``durs`` is a list with, per
            utterance, the sample count divided by the sample rate.

        Raises:
            ValueError: If ``wav_paths`` is empty.
        """
        feats = []
        durs = []
        for wav_path in wav_paths:
            sample_rate, wav_np = kaldiio.load_mat(wav_path)
            feat, dur = self._extract(wav_np, sample_rate)
            feats.append(feat)
            durs.append(dur)
        lengths = torch.tensor([feat.size(0) for feat in feats]).long()
        feats_pad = self.pad_feat(feats, 0.0)
        return feats_pad, lengths, durs

    def run_chunk(self, wav_np, sample_rate):
        """Extract features for a single in-memory waveform.

        Args:
            wav_np: One-dimensional array of samples.
            sample_rate: Sampling rate of ``wav_np``.

        Returns:
            Tuple ``(feats, lengths, dur)``: ``feats`` is a NumPy array with
            a leading batch axis of size 1, ``lengths`` is a long tensor with
            the single frame count, and ``dur`` is the sample count divided
            by ``sample_rate``.
        """
        feat, dur = self._extract(wav_np, sample_rate)
        lengths = torch.tensor([feat.size(0)]).long()
        feats_pad = self.pad_feat([feat], 0.0)
        return feats_pad.numpy(), lengths, dur

    def _extract(self, wav_np, sample_rate):
        """Return ``(float feature tensor, duration)`` for one waveform."""
        dur = wav_np.shape[0] / sample_rate
        fbank = self.fbank((sample_rate, wav_np))
        if self.cmvn is not None:
            fbank = self.cmvn(fbank)
        return torch.from_numpy(fbank).float(), dur

    def pad_feat(self, xs, pad_value):
        # type: (List[Tensor], float) -> Tensor
        """Stack variable-length tensors into one padded batch tensor.

        Args:
            xs: Non-empty list of tensors that agree on every dimension
                except the first.
            pad_value: Value written to positions past each tensor's length.

        Returns:
            Tensor of shape ``(len(xs), max_len, *xs[0].shape[1:])`` with the
            dtype and device of ``xs[0]``.

        Raises:
            ValueError: If ``xs`` is empty.
        """
        if not xs:
            raise ValueError("pad_feat() requires at least one feature tensor")
        max_len = max(x.size(0) for x in xs)
        # Allocate directly with the reference tensor's dtype/device instead
        # of creating a CPU float tensor and converting it afterwards.
        pad = xs[0].new_empty((len(xs), max_len, *xs[0].shape[1:]))
        pad.fill_(pad_value)
        for i, x in enumerate(xs):
            pad[i, :x.size(0)] = x
        return pad
class CMVN:
    """Mean/variance normalization driven by a Kaldi CMVN statistics file.

    Attributes set by ``__init__``:
        dim: Feature dimension expected on the last axis of the input.
        means: float64 array of per-dimension means.
        inverse_std_variences: float64 array of per-dimension
            ``1 / sqrt(variance)`` factors.
    """

    def __init__(self, kaldi_cmvn_file):
        """Load normalization parameters from ``kaldi_cmvn_file``."""
        self.dim, self.means, self.inverse_std_variences = \
            self.read_kaldi_cmvn(kaldi_cmvn_file)

    def __call__(self, x, is_train=False):
        """Return ``(x - means) * inverse_std_variences``.

        Args:
            x: Array whose last axis has size ``self.dim``.
            is_train: Accepted for interface compatibility; not used.

        Raises:
            AssertionError: If the last axis of ``x`` is not ``self.dim``.
        """
        assert x.shape[-1] == self.dim, "CMVN dim mismatch"
        return (x - self.means) * self.inverse_std_variences

    def read_kaldi_cmvn(self, kaldi_cmvn_file):
        """Read a Kaldi CMVN stats matrix and derive normalization terms.

        Args:
            kaldi_cmvn_file: Path to a matrix readable by
                ``kaldiio.load_mat``.

        Returns:
            Tuple ``(dim, means, inverse_std_variences)``.

        Raises:
            AssertionError: If the file does not exist, the matrix does not
                have two rows, or the stored count is below 1.
        """
        assert os.path.exists(kaldi_cmvn_file), \
            "CMVN file not found: {}".format(kaldi_cmvn_file)
        stats = kaldiio.load_mat(kaldi_cmvn_file)
        return self._stats_to_cmvn(stats)

    @staticmethod
    def _stats_to_cmvn(stats, floor=1e-20):
        """Convert a 2-row stats matrix into ``(dim, means, inverse_std)``.

        Row 0 is read as per-dimension sums followed by the count in the
        last column; row 1 is read as per-dimension sums of squares.
        Variances below ``floor`` are clamped to ``floor`` before inversion.
        """
        stats = np.asarray(stats)
        assert stats.shape[0] == 2, "CMVN stats must have exactly two rows"
        dim = stats.shape[-1] - 1
        count = stats[0, dim]
        assert count >= 1, "CMVN stats count must be >= 1"
        # Moments are computed in the matrix's own dtype and only then
        # widened, so the float64 results match a per-element computation.
        mean = stats[0, :dim] / count
        variance = stats[1, :dim] / count - mean * mean
        variance = np.maximum(variance.astype(np.float64), floor)
        return dim, mean.astype(np.float64), 1.0 / np.sqrt(variance)
class KaldifeatFbank:
    """Compute filterbank features with ``kaldi_native_fbank``."""

    def __init__(self, num_mel_bins=80, frame_length=25, frame_shift=10,
                 dither=1.0):
        """Configure the filterbank options.

        Args:
            num_mel_bins: Number of mel bins per frame.
            frame_length: Frame length, forwarded to
                ``frame_opts.frame_length_ms``.
            frame_shift: Frame shift, forwarded to
                ``frame_opts.frame_shift_ms``.
            dither: Dither amount used when ``__call__`` is invoked with
                ``is_train=True``; otherwise dither is forced to 0.0.
        """
        self.dither = dither
        opts = knf.FbankOptions()
        opts.frame_opts.dither = dither
        # Forward the framing arguments; they used to be accepted but
        # silently ignored in favour of the library defaults.
        opts.frame_opts.frame_length_ms = frame_length
        opts.frame_opts.frame_shift_ms = frame_shift
        opts.mel_opts.num_bins = num_mel_bins
        opts.frame_opts.snip_edges = True
        opts.mel_opts.debug_mel = False
        self.opts = opts

    def __call__(self, wav, is_train=False):
        """Compute features for one waveform.

        Args:
            wav: Either a path (``str``) readable by ``kaldiio.load_mat`` or
                a 2-element tuple/list ``(sample_rate, samples)`` where
                ``samples`` is a one-dimensional array.
            is_train: When true, apply the configured dither; otherwise
                dither is disabled.

        Returns:
            Array with one row per frame. When no frame could be produced, a
            ``(0, num_bins)`` array of zeros is returned and a warning line
            is printed.

        Raises:
            TypeError: If ``wav`` is neither a path nor a 2-element pair.
            AssertionError: If the samples are not one-dimensional.
        """
        if isinstance(wav, str):
            sample_rate, wav_np = kaldiio.load_mat(wav)
        elif isinstance(wav, (tuple, list)) and len(wav) == 2:
            sample_rate, wav_np = wav
        else:
            # Previously fell through and failed later with an unrelated
            # UnboundLocalError.
            raise TypeError(
                "wav must be a file path (str) or a (sample_rate, samples) "
                "pair, got {}".format(type(wav).__name__))
        assert len(wav_np.shape) == 1

        # NOTE(review): sample_rate is passed to accept_waveform but
        # frame_opts.samp_freq is left at the library default - confirm the
        # two always agree for the audio fed to this class.
        self.opts.frame_opts.dither = self.dither if is_train else 0.0
        fbank = knf.OnlineFbank(self.opts)
        fbank.accept_waveform(sample_rate, wav_np.tolist())
        frames = [fbank.get_frame(i) for i in range(fbank.num_frames_ready)]
        if not frames:
            print("Check data, len(feat) == 0", wav, flush=True)
            return np.zeros((0, self.opts.mel_opts.num_bins))
        return np.vstack(frames)