|
|
import math |
|
|
import os |
|
|
|
|
|
import kaldiio |
|
|
import kaldi_native_fbank as knf |
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
|
|
|
class ASRFeatExtractor:
    """Convert waveforms into padded, optionally CMVN-normalised fbank batches.

    The extractor chains three steps for every utterance: fbank computation
    (``KaldifeatFbank`` with 80 mel bins and dither disabled), optional CMVN,
    and conversion to a ``float32`` torch tensor.

    Args:
        kaldi_cmvn_file: Path to a Kaldi CMVN statistics file. An empty string
            or ``None`` disables normalisation.
    """

    def __init__(self, kaldi_cmvn_file):
        # Any falsy value ("" or None) means "no CMVN". Previously only ""
        # was recognised, so None was forwarded to CMVN and failed there.
        self.cmvn = CMVN(kaldi_cmvn_file) if kaldi_cmvn_file else None
        self.fbank = KaldifeatFbank(num_mel_bins=80, frame_length=25,
                                    frame_shift=10, dither=0.0)

    def __call__(self, wav_paths):
        """Extract features for a batch of audio files.

        Args:
            wav_paths: Iterable of paths handed to ``kaldiio.load_mat``; each
                result is unpacked as ``(sample_rate, samples)``.

        Returns:
            Tuple ``(feats_pad, lengths, durs)`` where ``feats_pad`` is the
            zero-padded feature tensor (batch first), ``lengths`` is a long
            tensor with the unpadded frame count of every item, and ``durs``
            is a list holding ``samples.shape[0] / sample_rate`` per item.

        Raises:
            ValueError: If ``wav_paths`` is empty.
        """
        feats = []
        durs = []
        for wav_path in wav_paths:
            sample_rate, wav_np = kaldiio.load_mat(wav_path)
            feat, dur = self._extract(wav_np, sample_rate)
            feats.append(feat)
            durs.append(dur)
        lengths = torch.tensor([feat.size(0) for feat in feats]).long()
        feats_pad = self.pad_feat(feats, 0.0)
        return feats_pad, lengths, durs

    def run_chunk(self, wav_np, sample_rate):
        """Extract features for one in-memory waveform chunk.

        Args:
            wav_np: Waveform samples as a NumPy array.
            sample_rate: Sampling rate of ``wav_np``.

        Returns:
            Tuple ``(feats, lengths, dur)``: ``feats`` is a NumPy array with a
            leading batch axis of size 1, ``lengths`` is a one-element long
            tensor with the frame count, and ``dur`` is
            ``wav_np.shape[0] / sample_rate``.
        """
        feat, dur = self._extract(wav_np, sample_rate)
        lengths = torch.tensor([feat.size(0)]).long()
        feats_pad = self.pad_feat([feat], 0.0)
        return feats_pad.numpy(), lengths, dur

    def _extract(self, wav_np, sample_rate):
        """Return ``(float32 feature tensor, duration)`` for one waveform."""
        dur = wav_np.shape[0] / sample_rate
        fbank = self.fbank((sample_rate, wav_np))
        if self.cmvn is not None:
            fbank = self.cmvn(fbank)
        return torch.from_numpy(fbank).float(), dur

    def pad_feat(self, xs, pad_value):
        """Stack variable-length tensors into one padded batch tensor.

        Args:
            xs: Non-empty sequence of tensors that may differ only in their
                first dimension.
            pad_value: Fill value for positions past the end of each item.

        Returns:
            Tensor of shape ``(len(xs), max_len, *xs[0].shape[1:])`` with the
            dtype and device of ``xs[0]``.

        Raises:
            ValueError: If ``xs`` is empty.
        """
        if not xs:
            raise ValueError("pad_feat() needs at least one feature tensor")
        max_len = max(x.size(0) for x in xs)
        # new_full inherits dtype and device from the first item.
        pad = xs[0].new_full((len(xs), max_len, *xs[0].shape[1:]), pad_value)
        for i, x in enumerate(xs):
            pad[i, :x.size(0)] = x
        return pad
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CMVN:
    """Apply global mean/variance normalisation loaded from a Kaldi stats file.

    Attributes set by ``__init__``:
        dim: Feature dimension the statistics were computed for.
        means: Per-dimension mean, shape ``(dim,)``.
        inverse_std_variences: Per-dimension ``1 / std``, shape ``(dim,)``.
            (The attribute keeps its historical spelling for compatibility.)
    """

    def __init__(self, kaldi_cmvn_file):
        self.dim, self.means, self.inverse_std_variences = \
            self.read_kaldi_cmvn(kaldi_cmvn_file)

    def __call__(self, x, is_train=False):
        """Normalise ``x`` along its last axis.

        Args:
            x: Array whose last dimension equals ``self.dim``.
            is_train: Unused; kept for interface compatibility.

        Returns:
            ``(x - means) * inverse_std_variences``.

        Raises:
            ValueError: If the last dimension of ``x`` is not ``self.dim``.
        """
        # Explicit raise instead of assert so the check survives `python -O`.
        if x.shape[-1] != self.dim:
            raise ValueError(
                f"CMVN dim mismatch: expected {self.dim}, got {x.shape[-1]}")
        return (x - self.means) * self.inverse_std_variences

    def read_kaldi_cmvn(self, kaldi_cmvn_file):
        """Load CMVN statistics and derive the normalisation parameters.

        Args:
            kaldi_cmvn_file: Path to a matrix readable by ``kaldiio.load_mat``.

        Returns:
            Tuple ``(dim, means, inverse_std)`` as described on
            ``_params_from_stats``.

        Raises:
            FileNotFoundError: If the path does not exist.
            ValueError: If the statistics matrix is malformed.
        """
        if not os.path.exists(kaldi_cmvn_file):
            raise FileNotFoundError(
                f"CMVN file not found: {kaldi_cmvn_file}")
        stats = kaldiio.load_mat(kaldi_cmvn_file)
        return self._params_from_stats(stats)

    @staticmethod
    def _params_from_stats(stats, floor=1e-20):
        """Convert a ``2 x (dim + 1)`` statistics matrix into mean and 1/std.

        Row 0 is read as per-dimension sums followed by the frame count; row 1
        as per-dimension sums of squares. Variances below ``floor`` are
        clamped to ``floor`` before inversion.

        Returns:
            Tuple ``(dim, means, inverse_std)`` with float64 arrays of shape
            ``(dim,)``.

        Raises:
            ValueError: If the matrix does not have exactly two rows or the
                frame count is below 1.
        """
        # Work in float64 regardless of how the matrix was stored.
        stats = np.asarray(stats, dtype=np.float64)
        if stats.ndim != 2 or stats.shape[0] != 2:
            raise ValueError(
                f"CMVN stats must have shape (2, dim + 1), got {stats.shape}")
        dim = stats.shape[-1] - 1
        count = stats[0, dim]
        # Written as `not >=` so a NaN count is rejected as well.
        if not count >= 1:
            raise ValueError(f"CMVN frame count must be >= 1, got {count}")
        means = stats[0, :dim] / count
        variances = stats[1, :dim] / count - means * means
        variances = np.maximum(variances, floor)
        return dim, means, 1.0 / np.sqrt(variances)
|
|
|
|
|
|
|
|
|
|
|
class KaldifeatFbank:
    """Compute log-mel filterbank features with ``kaldi_native_fbank``.

    Args:
        num_mel_bins: Number of mel bins written to ``mel_opts.num_bins``.
        frame_length: Analysis window length, applied as
            ``frame_opts.frame_length_ms``.
        frame_shift: Hop between windows, applied as
            ``frame_opts.frame_shift_ms``.
        dither: Dither amount used when ``is_train=True``; calls with
            ``is_train=False`` always run with dither 0.
    """

    def __init__(self, num_mel_bins=80, frame_length=25, frame_shift=10,
                 dither=1.0):
        self.dither = dither
        opts = knf.FbankOptions()
        opts.frame_opts.dither = dither
        # These two arguments used to be accepted but silently ignored, so
        # non-default values had no effect. The defaults (25 / 10) are
        # unchanged.
        opts.frame_opts.frame_length_ms = frame_length
        opts.frame_opts.frame_shift_ms = frame_shift
        opts.frame_opts.snip_edges = True
        opts.mel_opts.num_bins = num_mel_bins
        opts.mel_opts.debug_mel = False
        self.opts = opts

    def __call__(self, wav, is_train=False):
        """Return the fbank matrix for one utterance.

        Args:
            wav: Either a path string loaded through ``kaldiio.load_mat`` or
                a 2-element tuple/list ``(sample_rate, samples)`` where
                ``samples`` is a 1-D NumPy array.
            is_train: When true, use the configured dither; otherwise 0.

        Returns:
            Array of shape ``(num_frames, num_mel_bins)``; an empty
            ``(0, num_mel_bins)`` array when no frame could be computed.

        Raises:
            TypeError: If ``wav`` is neither a string nor a 2-element
                tuple/list.
            ValueError: If the waveform is not one-dimensional.
        """
        if isinstance(wav, str):
            sample_rate, wav_np = kaldiio.load_mat(wav)
        elif isinstance(wav, (tuple, list)) and len(wav) == 2:
            sample_rate, wav_np = wav
        else:
            # Previously fell through to an UnboundLocalError further down.
            raise TypeError(
                "wav must be a path string or a (sample_rate, samples) pair, "
                f"got {type(wav).__name__}")
        if len(wav_np.shape) != 1:
            raise ValueError(
                f"waveform must be 1-D, got shape {tuple(wav_np.shape)}")

        # The options object is shared between calls, so dither is reset on
        # every call before the extractor is built.
        self.opts.frame_opts.dither = self.dither if is_train else 0.0
        fbank = knf.OnlineFbank(self.opts)

        fbank.accept_waveform(sample_rate, wav_np.tolist())
        frames = [fbank.get_frame(i) for i in range(fbank.num_frames_ready)]
        if not frames:
            print("Check data, len(feat) == 0", wav, flush=True)
            return np.zeros((0, self.opts.mel_opts.num_bins))
        return np.vstack(frames)
|
|
|