Spaces:
Runtime error
Runtime error
| # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/4A. S2A dataset preparation.ipynb. | |
| # %% auto 0 | |
| __all__ = ['flac_to_s2a_name'] | |
| # %% ../nbs/4A. S2A dataset preparation.ipynb 2 | |
| import sys | |
| import os | |
| import itertools | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| import torchaudio | |
| import torch.nn.functional as F | |
| from torch.profiler import profile, record_function, ProfilerActivity | |
| from fastprogress import progress_bar | |
| from fastcore.script import * | |
| import whisper | |
| from . import vad, wh_transcribe, vq_stoks, extract_acoustic | |
| import webdataset as wds | |
| # %% ../nbs/4A. S2A dataset preparation.ipynb 4 | |
def flac_to_s2a_name(input):
    """Derive the default S2A output shard name from an input shard path.

    The directory part of `input` is dropped. If `input` contains `-flac-`
    every `flac` in the file name is replaced with `s2a`, otherwise every
    `raw` is replaced with `s2a`. `.gz` is appended in both cases.

    A path without any `/` (a file in the current directory) is accepted and
    treated as a bare file name.
    """
    # [-1] instead of [1]: rsplit returns a single element when there is no
    # "/" in the path, and indexing [1] would raise IndexError in that case.
    basename = input.rsplit("/", 1)[-1]
    marker = 'flac' if '-flac-' in input else 'raw'
    return basename.replace(marker, 's2a') + ".gz"
| # %% ../nbs/4A. S2A dataset preparation.ipynb 6 | |
def resampler(newsr = 24000, key = 'samples_24k'):
    """Build a stream filter that adds a resampled copy of each sample's audio.

    The returned function takes an iterable of sample dicts, each with a
    `sample_rate` and a `samples` entry, and yields the same dicts with
    `s[key]` set to the audio at `newsr`. Samples already at `newsr` are
    passed through (`s[key]` is the same object as `s['samples']`).

    The `torchaudio.transforms.Resample` transform is cached and only rebuilt
    when the source sample rate differs from the previous resampled sample.
    """
    _last_sr = None
    tform = None

    def _resample(samples):
        # Without `nonlocal` the assignments below would create locals and the
        # cache would never be hit (a new transform for every single sample).
        nonlocal _last_sr, tform
        for s in samples:
            sr = s['sample_rate']
            if sr != newsr:
                if sr != _last_sr:
                    tform = torchaudio.transforms.Resample(sr, newsr)
                    _last_sr = sr
                s[key] = tform(s['samples'])
            else:
                s[key] = s['samples']
            yield s
    return _resample
| # %% ../nbs/4A. S2A dataset preparation.ipynb 9 | |
def prepare_s2a(
    input:str, # FLAC webdataset file path (or - to read the names from stdin)
    proc_dataset_path:Path, # processed VAD files path
    output:str=None, # output file name
    vq_model:str="collabora/spear-tts-pytorch:whisper-vq-stoks.model", # the model path (use repo_id:filename to download it from huggingface)
    n_samples:int=None, # process a limited amount of samples
    batch_size:int=1, # process several segments at once
    fix_dots:bool=False, # fix dots in file names
):
    """Encode audio shards into acoustic (`atoks`) and semantic (`stoks`) tokens.

    Reads one webdataset shard (or a list of shard names from stdin when
    `input` is `-`), merges in the derived `vad` dataset from
    `proc_dataset_path`, splits the audio into chunks, resamples every chunk
    to 24 kHz and 16 kHz and encodes the batches on CUDA with the acoustic
    model and the VQ model.

    Side effects:
    - writes a tar archive to `output + ".tmp"` holding `atoks.npy` and
      `stoks.npy` (int16) for every chunk key;
    - writes the set of speakers (second `/`-separated component of every key)
      to `output + ".speakers.txt"`;
    - renames the `.tmp` archive to `output` only when `n_samples` is not set,
      so limited runs leave just the `.tmp` file behind.

    When `output` is omitted it is derived with `flac_to_s2a_name(input)`;
    it is mandatory when the shard names come from stdin.
    """
    # "repo_id:filename" selects a download, anything else is a local file.
    if ":" in vq_model:
        repo, fname = vq_model.split(":", 1)
        vq_model = vq_stoks.RQBottleneckTransformer.load_model(repo, fname).cuda()
    else:
        vq_model = vq_stoks.RQBottleneckTransformer.load_model(local_filename=vq_model).cuda()
    amodel = extract_acoustic.load_model()
    amodel.set_target_bandwidth(3)

    if input == "-":
        input = [f.strip() for f in sys.stdin.readlines()]
        assert output, "please provide the output shard name"
    else:
        if output is None: output = flac_to_s2a_name(input)
        input = [input]

    # NOTE(review): 'noinfer' is presumably fastprogress' marker for "unknown
    # length, do not call len()" — confirm against the fastprogress version in use.
    total = n_samples//batch_size if n_samples else 'noinfer'

    ds = wds.WebDataset(input, shardshuffle=True, rename_files=vad.fix_dots_in_names if fix_dots else None).compose(
        wds.decode(wds.torch_audio),
        wds.select(lambda x: 'wav' in x or 'flac' in x),
        vq_stoks.merge_in(vq_stoks.derived_dataset(proc_dataset_path, 'vad')),
        wds.map_dict(**{"vad.npy":wh_transcribe.chunk_merger}),
        lambda x: wh_transcribe.split_to_chunks(x),
        resampler(),
        resampler(16000, 'samples_16k'),
        wds.to_tuple('__key__', 'rpad_s', 'samples_16k', 'samples_24k'),
        wds.batched(64),
    )

    # Batches of 64 are produced inside the workers, then flattened, shuffled
    # and re-batched to the requested batch size in the main process.
    dl = wds.WebLoader(ds, num_workers=4, batch_size=None).unbatched().shuffle(2000).batched(batch_size)

    speakers = set()
    tmp = output+".tmp"
    with wds.TarWriter(tmp) as sink:
        for keys, rpad_ss, samples, samples24k in progress_bar(dl, total=total):
            with record_function('to_cuda'):
                samples, samples24k = samples.cuda(), samples24k.unsqueeze(1).cuda()
            with record_function('encodec'):
                atoks = amodel.encode(samples24k)[0][0]
            with record_function('vq_stoks'):
                stoks = vq_model.encode_audio(samples)
            with record_function('from_cuda'):
                atoks, stoks = atoks.cpu().numpy().astype(np.int16), stoks.cpu().numpy().astype(np.int16)
            for key, rpad_s, _atoks, _stoks in zip(keys, rpad_ss, atoks, stoks):
                speakers.add(key.split('/')[1])
                # rpad_s is scaled by 75 (atoks) and 25 (stoks) to get the number
                # of trailing tokens to drop — presumably the tokens-per-second
                # rates of the two models and the right padding in seconds; TODO confirm.
                # `or None` matters: a padding shorter than one token gives 0,
                # and `[:0]` would drop *all* tokens instead of none.
                atoks_end = int(-rpad_s * 75) or None
                stoks_end = int(-rpad_s * 25) or None
                sink.write({
                    "__key__": key,
                    "atoks.npy": _atoks[:,:atoks_end],
                    "stoks.npy": _stoks[:stoks_end],
                })
    with open(output+".speakers.txt", "w") as f: f.write("\n".join(speakers))
    if not n_samples:
        os.rename(tmp, output)