# (removed notebook-export artifacts: stray "Spaces:" / "Runtime error" lines)
| # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/2A. Whisper quantization dataset preparation.ipynb. | |
| # %% auto 0 | |
| __all__ = [] | |
| # %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 3 | |
| import os | |
| import io | |
| import time | |
| import torch | |
| import torchaudio | |
| # %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 4 | |
| from pathlib import Path | |
| import json | |
| from fastprogress import progress_bar, master_bar | |
| import numpy as np | |
| import random | |
| import whisper | |
| from torch import nn | |
| import torch.nn.functional as F | |
| from torch.utils.data.dataloader import DataLoader | |
| from fastcore.script import * | |
| from . import vad | |
| import webdataset as wds | |
| # %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 9 | |
| # let's make it a bit more conservative | |
| # with full 30 second chunks it sometimes misses a small part of the transcript | |
def random_cutter(dur):
    """Randomized cut predicate for `chunk_merger`.

    Half of the time it behaves like the plain "longer than 28s" rule; the
    other half it cuts at a random threshold between ~1.4s and 28s, so the
    training data also contains shorter-than-maximum chunks.
    """
    if random.random() < 0.5:
        # random threshold in [0.05, 1.0] * 28 seconds
        threshold = 28 * (random.random() * 0.95 + 0.05)
    else:
        threshold = 28
    return dur > threshold
def chunk_merger(segments, should_cut=lambda x: x > 28):
    """Greedily merge consecutive (start, end) segments.

    Segments are accumulated into one span until adding the next one would
    make the span satisfy `should_cut` (by default: longer than 28 seconds),
    at which point the span collected so far is emitted and a new one starts.
    Returns a list of merged (start, end) tuples; empty input is passed through.
    """
    if not segments:
        return segments
    merged = []
    span_start = segments[0][0]
    span_end = 0
    for seg_start, seg_end in segments:
        # only cut if we have already accumulated something non-empty
        if should_cut(seg_end - span_start) and span_end > span_start:
            merged.append((span_start, span_end))
            span_start = seg_start
        span_end = seg_end
    merged.append((span_start, span_end))
    return merged
| # %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 18 | |
def merge_in(*datasets):
    """Merge multiple datasets into the current one returning samples with the union of keys.

    It requires (and validates) all datasets to have the same ordering of keys so you have
    to use it before any sample shuffling. Shard shuffling is ok.
    """
    def merge_loop(main_samples):
        side_iters = [iter(ds) for ds in datasets]
        for group in zip(main_samples, *side_iters):
            key = group[0]['__key__']
            merged = {}
            for sample in group:
                # all datasets must be aligned on the same key order
                assert sample['__key__'] == key
                merged.update(sample)
            yield merged
    return merge_loop
| # %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 19 | |
| import copy | |
| # %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 20 | |
| # a workaround for https://github.com/webdataset/webdataset/issues/297 | |
| # should be possible to use ds.compose here | |
# a workaround for https://github.com/webdataset/webdataset/issues/297
# should be possible to use ds.compose here
def wds_compose(ds, *args):
    """Return a copy of `ds` with the given pipeline stages appended.

    Shallow-copies the dataset and its pipeline first so the original
    dataset object is left untouched.
    """
    clone = copy.copy(ds)
    clone.pipeline = copy.copy(clone.pipeline)
    for stage in args:
        clone.append(stage)
    return clone
| # %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 24 | |
def split_to_chunks(stream, pad_to_seconds=30, random_shift=False):
    """Split each webdataset sample into per-VAD-segment audio chunks.

    Parameters:
    - stream: iterable of samples holding an ('flac' or 'wav') (audio, sample_rate)
      tuple and a 'vad.npy' sequence of (start, end) times in seconds
    - pad_to_seconds: pad every chunk to this many seconds (chunks longer than
      this are trimmed by F.pad's negative padding); None disables padding
    - random_shift: place the chunk at a random offset inside the padded
      window instead of left-aligning it

    Yields one dict per chunk with the samples, padding metadata and the
    chunk's position (i / imax) inside the source sample.
    """
    for s in stream:
        audio, sr = s.get('flac', s.get('wav', (None, None)))
        if audio is None:
            print(f"warning: '{s['__key__']}' does not contain an audio file")
            continue
        imax = len(s['vad.npy']) - 1
        for i, (ts, te) in enumerate(s['vad.npy']):
            samples = audio[0, int(ts*sr):int(te*sr)]
            # default to zero padding so the metadata below is always defined
            # (previously lpad/padding were unbound when pad_to_seconds was None,
            # making that code path raise NameError)
            lpad = padding = 0
            if pad_to_seconds is not None:
                padding = pad_to_seconds*sr - samples.shape[-1]
                # a chunk longer than pad_to_seconds yields negative padding;
                # clamp the random shift so randint does not raise ValueError
                lpad = random.randint(0, max(0, padding)) if random_shift else 0
                samples = F.pad(samples, (lpad, padding-lpad))
            yield {"__key__": s['__key__'] + f"_{i:03d}",
                   "__url__": s['__url__'],
                   "i": i, "imax": imax,
                   "tstart": ts, "tend": te, "total_seconds": audio.shape[-1]/sr,
                   "lpad": lpad, "rpad": padding-lpad,
                   "lpad_s": lpad/sr, "rpad_s": (padding-lpad)/sr,
                   "samples": samples, "sample_rate": sr}
| # %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 38 | |
def flac_to_txt_name(input, model_size):
    """Derive the transcript shard name from a flac shard path.

    Takes the basename, swaps 'flac' for '{model_size}-txt' and appends '.gz'.
    NOTE(review): str.replace substitutes every occurrence of 'flac' in the
    name, not just the extension — presumably shard names contain it only once.
    """
    # rsplit(...)[-1] also handles bare file names with no directory part
    # (the previous [1] indexing raised IndexError for those)
    return input.rsplit("/", 1)[-1].replace('flac', f'{model_size}-txt') + ".gz"
def process_shard(
    input:str, # input shard URL/path
    output:str=None, # output shard URL/path
    bs:int=None, # batch size (16 uses around 11GB of VRAM)
    n_samples:int=None, # limit the number of samples (useful for quick benchmarking)
    whisper_model:str="base.en" # Whisper model size
):
    """Transcribe one flac audio shard with Whisper and write a transcript shard.

    Pipeline: load the audio shard, merge in its precomputed VAD segmentation,
    merge/split the segments into <=30s chunks, batch them, run the Whisper
    encoder + greedy English decode, and store one 'txt' entry per chunk key.
    NOTE(review): hard-codes `.cuda()`, so a CUDA device is required — confirm.
    """
    if output is None: output = flac_to_txt_name(input, whisper_model)
    if bs is None: bs = 16
    # 'noinfer' sentinel — presumably understood by fastprogress's progress_bar
    # as "unknown total"; verify against the fastprogress version in use
    if n_samples is None: n_samples = 'noinfer'
    else: n_samples = n_samples // bs  # progress total is counted in batches
    ds = wds_compose(vad.load_dataset(input),
        # the VAD shard is keyed identically to the audio shard, so merge_in
        # can zip them together sample-by-sample
        merge_in(wds.WebDataset(vad.flac_to_vad_name(input)).decode()),
        wds.map_dict(**{"vad.npy":chunk_merger}),
        split_to_chunks,
        wds.to_tuple('__key__', 'samples'),
        wds.batched(bs),
    )
    # batching is done inside the webdataset pipeline, hence batch_size=None
    dl = DataLoader(ds, num_workers=2, batch_size=None)
    whmodel = whisper.load_model(whisper_model)
    decoding_options = whisper.DecodingOptions(language='en')
    # write to a temp file and rename only at the end so an interrupted run
    # never leaves a partial tar that looks like a finished shard
    tmp = output+".tmp"
    with wds.TarWriter(tmp) as sink:
        for keys, samples in progress_bar(dl, total=n_samples):
            with torch.no_grad():
                embs = whmodel.encoder(whisper.log_mel_spectrogram(samples).cuda())
                decs = whmodel.decode(embs, decoding_options)
            for key, dec in zip(keys, decs):
                sink.write({
                    "__key__": key,
                    "txt": dec.text,
                })
    os.rename(tmp, output)