| import os |
| import numpy as np |
| import re |
| import argparse |
|
|
# Sanitize CUDA_VISIBLE_DEVICES before torch/transformers are imported.
# Presumably the scheduler exports values like "CUDA0" — strip the "CUDA"
# prefix so only the numeric ids remain (TODO confirm against cluster env).
# Bugfix: use .get() so an unset variable no longer raises KeyError.
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "").replace("CUDA", "")
|
|
| from transformers import pipeline |
| from transformers.models.whisper.english_normalizer import BasicTextNormalizer |
| from datasets import load_dataset, Audio |
|
|
# Whisper's own text normalizer, used for the "normW" scoring variant below.
whisper_norm = BasicTextNormalizer()
|
|
def simple_norm(utt):
    """Lightweight text normalization: drop punctuation, collapse runs of
    whitespace to single spaces, and lowercase the result."""
    no_punct = re.sub(r'[^\w\s]', '', utt)
    return " ".join(no_punct.lower().split())
|
|
def data(dataset):
    """Yield pipeline-ready samples from *dataset*.

    Each yielded dict merges the sample's ``audio`` fields with the
    passthrough keys ``reference`` (ground-truth text) and ``utt_id``.
    The unused ``enumerate`` index from the original was removed.
    """
    for item in dataset:
        yield {**item["audio"], "reference": item["text"], "utt_id": item["id"]}
|
|
def get_ckpt(path, ckpt_id):
    """Return the path of a ``checkpoint-<step>`` directory under *path*.

    Args:
        path: experiment directory containing ``checkpoint-*`` subdirectories.
        ckpt_id: specific checkpoint step to select; 0 means "latest".

    Raises:
        FileNotFoundError: if ``ckpt_id == 0`` and no checkpoint dirs exist.
    """
    if ckpt_id != 0:
        # Bugfix: original referenced the undefined name `ckpt` here
        # (NameError whenever a specific checkpoint was requested).
        return os.path.join(path, "checkpoint-%i" % ckpt_id)
    steps = [
        int(d.split('-')[-1])
        for d in os.listdir(path)
        if d.startswith("checkpoint-")
    ]
    if not steps:
        raise FileNotFoundError("no checkpoint-* directories found in %s" % path)
    # max() replaces sorted(...)[-1]: same result, O(n) instead of O(n log n).
    return os.path.join(path, "checkpoint-%s" % max(steps))
|
|
def main(args):
    """Run Whisper ASR over a streamed evaluation split and write the raw,
    Whisper-normalized, and simply-normalized hypotheses to text files
    under ``<expdir>/results/<dataset>``."""
    batch_size = args.batch_size

    # Map the user-facing device string to the transformers pipeline
    # device index (-1 = CPU, 0 = first visible GPU).
    if args.device == "cpu":
        device_id = -1
    elif args.device == "gpu":
        device_id = 0
    else:
        raise NotImplementedError("unknown device %s, should be cpu/gpu" % args.device)


    model_dir = os.path.join(args.expdir, args.model_size)

    # NOTE(review): args.checkpoint / get_ckpt() are never used here — the
    # pipeline always loads `model_dir` directly. Confirm whether a specific
    # checkpoint subdirectory was intended to be selected instead.
    model = model_dir

    whisper_asr = pipeline(
        "automatic-speech-recognition", model=model, device=device_id
    )

    # Force the decoder to transcribe in the requested language rather than
    # auto-detecting it per utterance.
    whisper_asr.model.config.forced_decoder_ids = (
        whisper_asr.tokenizer.get_decoder_prompt_ids(
            language=args.language, task="transcribe"
        )
    )

    # Both datasets are local HF dataset loading scripts.
    if args.dataset == 'cgn-dev':
        dataset_path = "./cgn-dev/cgn-dev.py"
    elif args.dataset == 'subs-annot':
        dataset_path = "./subs-annot/subs-annot.py"
    else:
        raise NotImplementedError('unknown dataset %s' % args.dataset)

    # Stream the test split (no full download) and resample audio to the
    # 16 kHz rate Whisper expects.
    cache_dir = "/esat/audioslave/jponcele/hf_cache"
    dataset = load_dataset(dataset_path, name="raw", split="test", cache_dir=cache_dir, streaming=True)
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

    utterances = []
    predictions = []
    references = []

    # The passthrough fields ("utt_id", "reference") presumably come back
    # from the pipeline wrapped in single-element lists, hence the [0]
    # indexing — TODO confirm against the transformers pipeline version used.
    for out in whisper_asr(data(dataset), batch_size=batch_size):
        predictions.append(out["text"])
        utterances.append(out["utt_id"][0])
        references.append(out["reference"][0])

    result_dir = os.path.join(args.expdir, "results", args.dataset)
    os.makedirs(result_dir, exist_ok=True)

    # Raw (un-normalized) hypotheses: "<utt_id> <hypothesis>" per line.
    with open(os.path.join(result_dir, "whisper_%s.txt" % args.model_size), "w") as pd:
        for i, utt in enumerate(utterances):
            pred = predictions[i]
            pd.write(utt + ' ' + pred + '\n')

    # Hypotheses normalized with Whisper's BasicTextNormalizer.
    with open(os.path.join(result_dir, "whisper_%s_normW.txt" % args.model_size), "w") as pd:
        for i, utt in enumerate(utterances):
            pred = whisper_norm(predictions[i])
            pd.write(utt + ' ' + pred + '\n')

    # Hypotheses normalized with the simple regex-based normalizer above.
    with open(os.path.join(result_dir, "whisper_%s_normS.txt" % args.model_size), "w") as pd:
        for i, utt in enumerate(utterances):
            pred = simple_norm(predictions[i])
            pd.write(utt + ' ' + pred + '\n')
|
|
if __name__ == "__main__":
    # CLI entry point: parse evaluation options and run the transcription.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--expdir",
        type=str,
        default="/esat/audioslave/jponcele/whisper/finetuning_event/CGN",
        help="Directory with finetuned models",
    )
    parser.add_argument(
        "--model_size",
        type=str,
        default="tiny",
        help="Model size",
    )
    # NOTE(review): --checkpoint is parsed but never consumed by main();
    # see the get_ckpt() helper, which is also unused.
    parser.add_argument(
        "--checkpoint",
        type=int,
        default=0,
        help="Load specific checkpoint. 0 means latest",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="cgn-dev",
        help="Dataset name to evaluate the `model_id`. Should be loadable with 🤗 Datasets",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="cpu/gpu",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Number of samples to go through each streamed batch.",
    )
    # NOTE(review): the help text asks for a two-letter code but the default
    # is the full name "dutch" — confirm which form the tokenizer expects.
    parser.add_argument(
        "--language",
        type=str,
        default="dutch",
        help="Two letter language code for the transcription language, e.g. use 'en' for English.",
    )

    args = parser.parse_args()
    main(args)
|
|
|
|