| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import io |
| import json |
| from subprocess import PIPE, Popen |
| from urllib.parse import urlparse |
| from langid.langid import LanguageIdentifier, model |
| import logging |
| import librosa |
| import random |
|
|
| import torch |
| from torch.nn.utils.rnn import pad_sequence |
| import torchaudio |
| import torchaudio.compliance.kaldi as kaldi |
| import torch.nn.functional as F |
| from wenet.text.base_tokenizer import BaseTokenizer |
|
|
| torchaudio.utils.sox_utils.set_buffer_size(16500) |
|
|
| AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma']) |
|
|
| lid = LanguageIdentifier.from_modelstring(model, norm_probs=True) |
|
|
| logging.getLogger('langid').setLevel(logging.INFO) |
|
|
| import os |
| try: |
| cpu_info = os.popen("lscpu | grep 'Vendor ID'").read() |
| |
| if (cpu_info.rstrip().split(" ")[-1] == "0x48"): |
| |
| |
| |
| |
| torch.set_num_threads(1) |
| except Exception as ex: |
| logging.warning('Failed to set number of thread in Kunpeng, \ |
| this may cause segmentfault while dataloading, \ |
| ignore this warning if you are not using Kunpeng') |
|
|
|
|
| class UrlOpenError(Exception): |
|
|
| def __init__(self, msg: str, *args: object) -> None: |
| super().__init__(*args) |
| self.err_msg = msg |
|
|
| def __str__(self) -> str: |
| return self.err_msg |
|
|
|
|
| def parse_json(elem): |
| line = elem['line'] |
| obj = json.loads(line) |
| obj['file_name'] = elem['file_name'] |
| return dict(obj) |
|
|
|
|
| def parse_url(elem): |
| assert 'file_name' in elem |
| assert 'line' in elem |
| assert isinstance(elem, dict) |
| url = elem['line'] |
| try: |
| pr = urlparse(url) |
| |
| if pr.scheme == '' or pr.scheme == 'file': |
| stream = open(url, 'rb') |
| |
| else: |
| cmd = f'wget -q -O - {url}' |
| process = Popen(cmd, shell=True, stdout=PIPE) |
| elem.update(process=process) |
| stream = process.stdout |
| elem.update(stream=stream) |
| return elem |
| except Exception as ex: |
| err_msg = 'Failed to open {}'.format(url) |
| raise UrlOpenError(err_msg) from ex |
|
|
|
|
| def parse_speaker(sample, speaker_dict): |
| assert 'speaker' in sample |
| speaker = sample['speaker'] |
| sample['speaker'] = speaker_dict.get(speaker, 0) |
| return sample |
|
|
|
|
| def detect_language(sample, limited_langs): |
| assert 'txt' in sample |
| |
| |
| |
| |
| |
| lid.set_languages(limited_langs) |
| |
| sample['lang'] = lid.classify(sample['txt'])[0] |
| return sample |
|
|
|
|
| def detect_task(sample): |
| |
| |
| |
| |
| sample['task'] = "transcribe" |
| return sample |
|
|
|
|
| def decode_wav(sample): |
| """ Parse key/wav/txt from json line |
| |
| Args: |
| sample: str, str is a json line has key/wav |
| |
| Returns: |
| {key, wav, sample_rate, ...} |
| """ |
| assert 'key' in sample |
| assert 'wav' in sample |
| wav_file = sample['wav'] |
| if isinstance(wav_file, str): |
| with open(wav_file, 'rb') as f: |
| wav_file = f.read() |
| if 'start' in sample: |
| assert 'end' in sample |
| sample_rate = torchaudio.info(wav_file).sample_rate |
| start_frame = int(sample['start'] * sample_rate) |
| end_frame = int(sample['end'] * sample_rate) |
| with io.BytesIO(wav_file) as file_obj: |
| waveform, _ = torchaudio.load(file_obj, |
| num_frames=end_frame - start_frame, |
| frame_offset=start_frame) |
| else: |
| with io.BytesIO(wav_file) as file_obj: |
| waveform, sample_rate = torchaudio.load(file_obj) |
| |
| del sample['wav'] |
| sample['wav'] = waveform |
| sample['sample_rate'] = sample_rate |
| return sample |
|
|
|
|
| def singal_channel(sample, channel=0): |
| """ Choose a channel of sample. |
| Inplace operation. |
| |
| Args: |
| sample: {key, wav, label, sample_rate} |
| channel: target channel index |
| |
| Returns: |
| {key, wav, label, sample_rate} |
| """ |
| assert 'wav' in sample |
| waveform = sample['wav'] |
| channel_nums = waveform.size(0) |
| assert channel < channel_nums |
| if channel_nums != 1: |
| waveform = waveform[channel, :].unsqueeze(0) |
| sample['wav'] = waveform |
| return sample |
|
|
|
|
| def resample(sample, resample_rate=16000): |
| """ Resample sample. |
| Inplace operation. |
| |
| Args: |
| sample: {key, wav, label, sample_rate} |
| resample_rate: target resample rate |
| |
| Returns: |
| {key, wav, label, sample_rate} |
| """ |
| assert 'sample_rate' in sample |
| assert 'wav' in sample |
| sample_rate = sample['sample_rate'] |
| waveform = sample['wav'] |
| if sample_rate != resample_rate: |
| sample['sample_rate'] = resample_rate |
| sample['wav'] = torchaudio.transforms.Resample( |
| orig_freq=sample_rate, new_freq=resample_rate)(waveform) |
| return sample |
|
|
|
|
| def speed_perturb(sample, speeds=None): |
| """ Apply speed perturb to the sample. |
| Inplace operation. |
| |
| Args: |
| sample: {key, wav, label, sample_rate} |
| speeds(List[float]): optional speed |
| |
| Returns: |
| key, wav, label, sample_rate} |
| """ |
| if speeds is None: |
| speeds = [0.9, 1.0, 1.1] |
| assert 'sample_rate' in sample |
| assert 'wav' in sample |
| sample_rate = sample['sample_rate'] |
| waveform = sample['wav'] |
| speed = random.choice(speeds) |
| if speed != 1.0: |
| wav, _ = torchaudio.sox_effects.apply_effects_tensor( |
| waveform, sample_rate, |
| [['speed', str(speed)], ['rate', str(sample_rate)]]) |
| sample['wav'] = wav |
|
|
| return sample |
|
|
|
|
| def compute_fbank(sample, |
| num_mel_bins=23, |
| frame_length=25, |
| frame_shift=10, |
| dither=0.0, |
| window_type="povey"): |
| """ Extract fbank |
| |
| Args: |
| sample: {key, wav, sample_rate, ...} |
| |
| Returns: |
| {key, feat, wav, sample_rate, ...} |
| """ |
| assert 'sample_rate' in sample |
| assert 'wav' in sample |
| assert 'key' in sample |
| sample_rate = sample['sample_rate'] |
| waveform = sample['wav'] |
| waveform = waveform * (1 << 15) |
| |
| mat = kaldi.fbank(waveform, |
| num_mel_bins=num_mel_bins, |
| frame_length=frame_length, |
| frame_shift=frame_shift, |
| dither=dither, |
| energy_floor=0.0, |
| sample_frequency=sample_rate, |
| window_type=window_type) |
| sample['feat'] = mat |
| return sample |
|
|
|
|
| def compute_w2vbert_fbank(sample, |
| num_mel_bins=23, |
| frame_length=25, |
| frame_shift=10, |
| dither=0.0): |
| """ Extract Pretrain w2vbert(4.5M hours) fbank |
| """ |
| sample = compute_fbank(sample, num_mel_bins, frame_length, frame_shift, |
| dither) |
| mat = sample['feat'] |
| std, mean = torch.std_mean(mat, dim=0) |
| mat = mat.subtract(mean).divide(std) |
| sample['feat'] = mat |
| return sample |
|
|
|
|
| def sort_by_feats(sample): |
| assert 'feat' in sample |
| assert isinstance(sample['feat'], torch.Tensor) |
| return sample['feat'].size(0) |
|
|
|
|
| def feats_length_fn(sample) -> int: |
| assert 'feat' in sample |
| return sample['feat'].size(0) |
|
|
|
|
| def compute_mfcc(sample, |
| num_mel_bins=23, |
| frame_length=25, |
| frame_shift=10, |
| dither=0.0, |
| num_ceps=40, |
| high_freq=0.0, |
| low_freq=20.0): |
| """ Extract mfcc |
| |
| Args: |
| sample: {key, wav, sample_rate, ...} |
| |
| Returns: |
| {key, wav, feat, sample_rate, ...} |
| """ |
| assert 'wav' in sample |
| assert 'key' in sample |
| sample_rate = sample['sample_rate'] |
| waveform = sample['wav'] |
| waveform = waveform * (1 << 15) |
| mat = kaldi.mfcc(waveform, |
| num_mel_bins=num_mel_bins, |
| frame_length=frame_length, |
| frame_shift=frame_shift, |
| dither=dither, |
| num_ceps=num_ceps, |
| high_freq=high_freq, |
| low_freq=low_freq, |
| sample_frequency=sample_rate) |
| sample['feat'] = mat |
| return sample |
|
|
|
|
| def compute_log_mel_spectrogram(sample, |
| n_fft=400, |
| hop_length=160, |
| num_mel_bins=80, |
| padding=0, |
| pad_or_trim: bool = False, |
| max_duration: int = 30): |
| """ Extract log mel spectrogram, modified from openai-whisper, see: |
| - https://github.com/openai/whisper/blob/main/whisper/audio.py |
| - https://github.com/wenet-e2e/wenet/pull/2141#issuecomment-1811765040 |
| |
| Args: |
| sample: {key, wav, sample_rate, ...} |
| max_duration: valid when pad_or_trim is True (orign whisper style) |
| |
| Returns: |
| {key, feat, wav, sample_rate, ...} |
| """ |
| assert 'sample_rate' in sample |
| assert 'wav' in sample |
| assert 'key' in sample |
| sample_rate = sample['sample_rate'] |
| waveform = sample['wav'].squeeze(0) |
| if padding > 0: |
| waveform = F.pad(waveform, (0, padding)) |
| if pad_or_trim: |
| length = max_duration * sample_rate |
| if waveform.size(0) >= length: |
| waveform = waveform[:length] |
| else: |
| waveform = F.pad(waveform, (0, length - waveform.size(0))) |
|
|
| window = torch.hann_window(n_fft) |
| stft = torch.stft(waveform, |
| n_fft, |
| hop_length, |
| window=window, |
| return_complex=True) |
| magnitudes = stft[..., :-1].abs()**2 |
|
|
| filters = torch.from_numpy( |
| librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=num_mel_bins)) |
| mel_spec = filters @ magnitudes |
|
|
| |
| log_spec = torch.clamp(mel_spec, min=1e-10).log10() |
| log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) |
| log_spec = (log_spec + 4.0) / 4.0 |
| sample['feat'] = log_spec.transpose(0, 1) |
| return sample |
|
|
|
|
| def tokenize(sample, tokenizer: BaseTokenizer): |
| """ Decode text to chars or BPE |
| Inplace operation |
| |
| Args: |
| sample: {key, wav, txt, sample_rate, ...} |
| |
| Returns: |
| {key, wav, txt, tokens, label, sample_rate, ...} |
| """ |
| assert 'txt' in sample |
| tokens, label = tokenizer.tokenize(sample['txt']) |
| sample['tokens'] = tokens |
| sample['label'] = label |
| return sample |
|
|
|
|
| def filter(sample, |
| max_length=10240, |
| min_length=10, |
| token_max_length=200, |
| token_min_length=1, |
| min_output_input_ratio=0.0005, |
| max_output_input_ratio=1): |
| """ Filter sample according to feature and label length |
| Inplace operation. |
| |
| Args:: |
| sample: {key, wav, label, sample_rate, ...}] |
| max_length: drop utterance which is greater than max_length(10ms) |
| min_length: drop utterance which is less than min_length(10ms) |
| token_max_length: drop utterance which is greater than |
| token_max_length, especially when use char unit for |
| english modeling |
| token_min_length: drop utterance which is |
| less than token_max_length |
| min_output_input_ratio: minimal ration of |
| token_length / feats_length(10ms) |
| max_output_input_ratio: maximum ration of |
| token_length / feats_length(10ms) |
| |
| Returns: |
| bool: True to keep, False to filter |
| """ |
| assert 'sample_rate' in sample |
| assert 'wav' in sample |
| |
| num_frames = sample['wav'].size(1) / sample['sample_rate'] * 100 |
| if num_frames < min_length: |
| return False |
| if num_frames > max_length: |
| return False |
|
|
| if 'label' in sample: |
| if len(sample['label']) < token_min_length: |
| return False |
| if len(sample['label']) > token_max_length: |
| return False |
| if num_frames != 0: |
| if len(sample['label']) / num_frames < min_output_input_ratio: |
| return False |
| if len(sample['label']) / num_frames > max_output_input_ratio: |
| return False |
| return True |
|
|
|
|
| def spec_aug(sample, num_t_mask=2, num_f_mask=2, max_t=50, max_f=10, max_w=80): |
| """ Do spec augmentation |
| Inplace operation |
| |
| Args: |
| sample: {key, feat, ...} |
| num_t_mask: number of time mask to apply |
| num_f_mask: number of freq mask to apply |
| max_t: max width of time mask |
| max_f: max width of freq mask |
| max_w: max width of time warp |
| |
| Returns |
| {key, feat, ....} |
| """ |
| assert 'feat' in sample |
| x = sample['feat'] |
| assert isinstance(x, torch.Tensor) |
| y = x.clone().detach() |
| max_frames = y.size(0) |
| max_freq = y.size(1) |
| |
| for i in range(num_t_mask): |
| start = random.randint(0, max_frames - 1) |
| length = random.randint(1, max_t) |
| end = min(max_frames, start + length) |
| y[start:end, :] = 0 |
| |
| for _ in range(num_f_mask): |
| start = random.randint(0, max_freq - 1) |
| length = random.randint(1, max_f) |
| end = min(max_freq, start + length) |
| y[:, start:end] = 0 |
| sample['feat'] = y |
| return sample |
|
|
|
|
| def spec_sub(sample, max_t=20, num_t_sub=3): |
| """ Do spec substitute |
| Inplace operation |
| ref: U2++, section 3.2.3 [https://arxiv.org/abs/2106.05642] |
| |
| Args: |
| sample: Iterable{key, feat, ...} |
| max_t: max width of time substitute |
| num_t_sub: number of time substitute to apply |
| |
| Returns |
| {key, feat, ...} |
| """ |
| assert 'feat' in sample |
| x = sample['feat'] |
| assert isinstance(x, torch.Tensor) |
| y = x.clone().detach() |
| max_frames = y.size(0) |
| for _ in range(num_t_sub): |
| start = random.randint(0, max_frames - 1) |
| length = random.randint(1, max_t) |
| end = min(max_frames, start + length) |
| |
| pos = random.randint(0, start) |
| y[start:end, :] = x[start - pos:end - pos, :] |
| sample['feat'] = y |
| return sample |
|
|
|
|
| def spec_trim(sample, max_t=20): |
| """ Trim tailing frames. Inplace operation. |
| ref: TrimTail [https://arxiv.org/abs/2211.00522] |
| |
| Args: |
| sample: {key, feat, label} |
| max_t: max width of length trimming |
| |
| Returns: |
| {key, feat, label} |
| """ |
| assert 'feat' in sample |
| x = sample['feat'] |
| assert isinstance(x, torch.Tensor) |
| max_frames = x.size(0) |
| length = random.randint(1, max_t) |
| if length < max_frames / 2: |
| y = x.clone().detach()[:max_frames - length] |
| sample['feat'] = y |
| return sample |
|
|
|
|
| def padding(data): |
| """ Padding the data into training data |
| |
| Args: |
| data: List[{key, feat, label} |
| |
| Returns: |
| Tuple(keys, feats, labels, feats lengths, label lengths) |
| """ |
| sample = data |
| assert isinstance(sample, list) |
| feats_length = torch.tensor([x['feat'].size(0) for x in sample], |
| dtype=torch.int32) |
| order = torch.argsort(feats_length, descending=True) |
| feats_lengths = torch.tensor([sample[i]['feat'].size(0) for i in order], |
| dtype=torch.int32) |
| sorted_feats = [sample[i]['feat'] for i in order] |
| sorted_keys = [sample[i]['key'] for i in order] |
| sorted_labels = [ |
| torch.tensor(sample[i]['label'], dtype=torch.int64) for i in order |
| ] |
| sorted_wavs = [sample[i]['wav'].squeeze(0) for i in order] |
| langs = [sample[i]['lang'] for i in order] |
| tasks = [sample[i]['task'] for i in order] |
| label_lengths = torch.tensor([x.size(0) for x in sorted_labels], |
| dtype=torch.int32) |
| wav_lengths = torch.tensor([x.size(0) for x in sorted_wavs], |
| dtype=torch.int32) |
| padded_feats = pad_sequence(sorted_feats, |
| batch_first=True, |
| padding_value=0) |
| padding_labels = pad_sequence(sorted_labels, |
| batch_first=True, |
| padding_value=-1) |
| padded_wavs = pad_sequence(sorted_wavs, batch_first=True, padding_value=0) |
|
|
| batch = { |
| "keys": sorted_keys, |
| "feats": padded_feats, |
| "target": padding_labels, |
| "feats_lengths": feats_lengths, |
| "target_lengths": label_lengths, |
| "pcm": padded_wavs, |
| "pcm_length": wav_lengths, |
| "langs": langs, |
| "tasks": tasks, |
| } |
| if 'speaker' in sample[0]: |
| speaker = torch.tensor([sample[i]['speaker'] for i in order], |
| dtype=torch.int32) |
| batch['speaker'] = speaker |
| return batch |
|
|
|
|
| class DynamicBatchWindow: |
|
|
| def __init__(self, max_frames_in_batch=12000): |
| self.longest_frames = 0 |
| self.max_frames_in_batch = max_frames_in_batch |
|
|
| def __call__(self, sample, buffer_size): |
| assert isinstance(sample, dict) |
| assert 'feat' in sample |
| assert isinstance(sample['feat'], torch.Tensor) |
| new_sample_frames = sample['feat'].size(0) |
| self.longest_frames = max(self.longest_frames, new_sample_frames) |
| frames_after_padding = self.longest_frames * (buffer_size + 1) |
| if frames_after_padding > self.max_frames_in_batch: |
| self.longest_frames = new_sample_frames |
| return True |
| return False |
|
|