| import torch |
| import torchaudio |
| from datasets import load_dataset, load_metric |
| from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, set_seed |
| import argparse |
| from pyctcdecode import build_ctcdecoder |
| from multiprocessing import Pool |
| import re |
|
|
| """ |
| This is the script to run the prediction on test set of Turkish speech dataset. |
| Usage: |
| python run_evaluation.y -m <wav2vec2 model_name> -d <Zindi dataset directory> -o <output file name> \ |
| -b <optional batch size, default=16> |
| """ |
|
|
|
|
class KenLM:
    """CTC beam-search decoder backed by a KenLM language model (pyctcdecode).

    Builds the decoder vocabulary from the wav2vec2 tokenizer and decodes
    batches of CTC logits in parallel worker processes.
    """

    def __init__(self, tokenizer, model_name, unigrams=None, num_workers=8, beam_width=128):
        """
        Args:
            tokenizer: wav2vec2 tokenizer providing the id -> token vocabulary.
            model_name: path to the KenLM ARPA/binary language-model file.
            unigrams: optional unigram list to improve beam-search scoring.
            num_workers: size of the multiprocessing pool used for decoding.
            beam_width: beam width used during beam-search decoding.
        """
        self.num_workers = num_workers
        self.beam_width = beam_width
        vocab_dict = tokenizer.get_vocab()
        # Sort tokens by their ids so list index i corresponds to token id i.
        self.vocabulary = [x[0] for x in sorted(vocab_dict.items(), key=lambda x: x[1])]
        # Drop the last entry — presumably a special token the CTC head never
        # emits; TODO confirm against the tokenizer's vocabulary layout.
        self.vocabulary = self.vocabulary[:-1]
        # BUGFIX: `unigrams` was accepted but never forwarded to the decoder.
        self.decoder = build_ctcdecoder(self.vocabulary, model_name, unigrams=unigrams)

    @staticmethod
    def lm_postprocess(text):
        """Drop single-character tokens (LM decoding artifacts) from *text*.

        BUGFIX: the previous version replaced dropped tokens with empty
        strings before joining, which left doubled spaces in the output.
        """
        return ' '.join(w for w in text.split() if len(w) > 1)

    def decode(self, logits):
        """Beam-search decode a batch of CTC logits into a list of strings."""
        probs = logits.cpu().numpy()
        with Pool(self.num_workers) as pool:
            # BUGFIX: `beam_width` was stored but never passed to decode_batch.
            text = self.decoder.decode_batch(pool, probs, beam_width=self.beam_width)
            text = [KenLM.lm_postprocess(x) for x in text]
        return text
|
|
| chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“\‘\”\'\`…\’»«]' |
|
|
def main():
    """Evaluate a wav2vec2 CTC model on the test split of a speech dataset.

    Parses CLI arguments, loads the model/processor and dataset, resamples
    audio to 16 kHz, decodes (greedy argmax or KenLM beam search) and prints
    the word error rate.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model_name", type=str, required=True,
                        help="The wav2vec2 model name")
    parser.add_argument("-n", "--name", type=str, required=True,
                        help="The name of dataset")
    parser.add_argument("-c", "--config_name", type=str, required=True,
                        help="The config name of the dataset")
    parser.add_argument("-d", "--data_dir", type=str, required=False, default=None,
                        help="The directory contains the dataset")
    parser.add_argument("-b", "--batch_size", type=int, required=False, default=16,
                        help="Batch size")
    # default=None instead of False: `if args.kenlm:` behaves identically and
    # the value is consistently a str-or-None.
    parser.add_argument("-k", "--kenlm", type=str, required=False, default=None,
                        help="Path to KenLM model")
    parser.add_argument("-u", "--unigrams", type=str, required=False, default=None,
                        help="Path to unigrams file")
    parser.add_argument("--num_workers", type=int, required=False, default=8,
                        help="KenLM's number of workers")
    parser.add_argument("-w", "--beam_width", type=int, required=False, default=128,
                        help="KenLM's beam width")
    parser.add_argument("--test_pct", type=int, required=False, default=100,
                        help="Percentage of the test set")
    parser.add_argument("--cpu", required=False, action='store_true',
                        help="Force to use CPU")
    args = parser.parse_args()

    set_seed(42)  # deterministic behavior across runs
    processor = Wav2Vec2Processor.from_pretrained(args.model_name)
    model = Wav2Vec2ForCTC.from_pretrained(args.model_name)
    kenlm = None
    if args.kenlm:
        # BUGFIX: --num_workers and --beam_width were parsed but never
        # forwarded to KenLM, so those CLI options were silently ignored.
        kenlm = KenLM(processor.tokenizer, args.kenlm, unigrams=args.unigrams,
                      num_workers=args.num_workers, beam_width=args.beam_width)

    def speech_file_to_array_fn(batch):
        """Load one utterance, resample to 16 kHz, store it in batch["speech"]."""
        if "audio" in batch:
            speech_array = torch.tensor(batch["audio"]["array"])
            # BUGFIX: use the sampling rate reported by the dataset instead of
            # assuming 48 kHz; fall back to 48 kHz only if it is missing.
            orig_sr = batch["audio"].get("sampling_rate", 48_000)
            resampler = torchaudio.transforms.Resample(orig_sr, 16_000)
        else:
            speech_array, sampling_rate = torchaudio.load(batch["path"])
            resampler = torchaudio.transforms.Resample(sampling_rate, 16_000)
        batch["speech"] = resampler(speech_array).squeeze().numpy()
        return batch

    def remove_special_characters(batch):
        """Normalize the reference transcript: strip punctuation, lowercase."""
        batch["norm_text"] = re.sub(chars_to_ignore_regex, "", batch["sentence"]).lower().strip()
        return batch

    lg_test = load_dataset(args.name, args.config_name, data_dir=args.data_dir,
                           split=f"test[:{args.test_pct}%]", use_auth_token=True)
    if args.cpu:
        device = "cpu"
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    lg_test = lg_test.map(speech_file_to_array_fn)
    lg_test = lg_test.map(remove_special_characters)
    model = model.to(device)
    model.eval()  # disable dropout etc. for deterministic inference
    wer = load_metric("wer")

    batch_size = args.batch_size

    def evaluate(batch):
        """Run the model on one batch and attach decoded predictions."""
        inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)

        with torch.no_grad():
            logits = model(inputs.input_values.to(device)).logits

        if args.kenlm:
            batch["pred_strings"] = kenlm.decode(logits)
        else:
            # Greedy CTC decoding: argmax over the vocabulary per frame.
            predicted_ids = torch.argmax(logits, dim=-1)
            batch["pred_strings"] = processor.batch_decode(predicted_ids)
        return batch

    result = lg_test.map(evaluate, batched=True, batch_size=batch_size)
    WER = 100 * wer.compute(predictions=result["pred_strings"], references=result["norm_text"])
    print(f"WER: {WER:.2f}%")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|