from transformers import AutoTokenizer, GPT2LMHeadModel
from datasets import load_dataset, Dataset, DatasetDict
import random
import string
import torch
from torchmetrics.text import WordErrorRate, CharErrorRate
wer = WordErrorRate()
cer = CharErrorRate()
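# torchmetrics metrics are called as metric(preds, target), e.g.
# wer(["the cat sat"], ["the cat sat down"]) -> tensor(0.25)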
def process(text):
    # Lower-case every letter
    text = text.lower()
    # Remove punctuation, but keep apostrophes so contractions survive
    punctuation_to_remove = string.punctuation.replace("'", "")
    translation_table = str.maketrans('', '', punctuation_to_remove)
    text = text.translate(translation_table)
    # Strip leading and trailing whitespace (safe on empty strings)
    return text.strip()
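# Example: process("  Hello, world! It's me.  ") -> "hello world it's me"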
import jiwer
from edit_distance import SequenceMatcher
def correct_text(text):
    # jiwer normalization pipeline: expand contractions, lower-case,
    # collapse repeated spaces, strip, drop punctuation, and split each
    # string into a list of words
    transforms = jiwer.Compose(
        [
            jiwer.ExpandCommonEnglishContractions(),
            jiwer.ToLowerCase(),
            jiwer.RemoveMultipleSpaces(),
            jiwer.Strip(),
            jiwer.RemovePunctuation(),
            jiwer.ReduceToListOfListOfWords(),
        ]
    )
    return transforms(text)
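# Example: correct_text(["don't stop!"]) -> [["do", "not", "stop"]]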
def align_gt_asr(gt, asr):
    # Align reference (gt) and hypothesis (asr) word lists via edit distance.
    # Deleted reference words pair with "", and inserted hypothesis words
    # pair with "" on the reference side.
    sm = SequenceMatcher(a=gt, b=asr)
    best_path = []
    opcodes = sm.get_opcodes()
    for tag, i1, i2, j1, j2 in opcodes:
        if tag == "delete":
            for i in range(i1, i2):
                best_path.append([gt[i], ""])
        elif tag == "replace" or tag == "equal":
            for i, j in zip(range(i1, i2), range(j1, j2)):
                best_path.append([gt[i], asr[j]])
        elif tag == "insert":
            for j in range(j1, j2):
                best_path.append(["", asr[j]])
    return best_path
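# Example:
# align_gt_asr(["the", "cat", "sat"], ["the", "sat"])
# -> [["the", "the"], ["cat", ""], ["sat", "sat"]]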
dtype = torch.float16  # inference dtype used when loading the model
dataset_name = "./../libripseech_tokenized"  # tokenized LibriSpeech splits
dataset = DatasetDict.load_from_disk(dataset_name)
with open("./../prompting/blist/all_rare_words.txt") as fin:
    rarewords = [process(word.strip()) for word in fin]
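# rarewords serves as the distractor pool: prepare() below pads each
# utterance's biasing words up to n_bwords with samples drawn from it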
tokenizer = AutoTokenizer.from_pretrained("./../tokenizer")
tokenizer.pad_token_id = 0
tokenizer.pad_token = "<|padding|>"
# Left padding so the generated tokens start right after each prompt
tokenizer.padding_side = "left"
# Add special tokens that delimit the biasing-word prompt
tokenizer.add_tokens(["<|startofprompt|>", "<|sepofprompt|>", "<|endofprompt|>"])
sot_token = tokenizer.encode("<|startoftranscript|>")[0]
eot_token = tokenizer.encode("<|endoftranscript|>")[0]
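# <|startoftranscript|> marks where generation begins; <|endoftranscript|>
# doubles as the eos_token_id that stops generation in evaluate_model()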
from math import ceil
from tqdm import tqdm
val_bs = 32  # evaluation batch size
n_bwords = 25  # biasing words per utterance (padded with distractors)
context_length = 2048  # maximum total sequence length in tokens
def prepare(element):
    # Serialize the discrete audio tokens
    audio_tkns = element["audio_tokens"]
    data = "".join([f"<|audio:{tkn}|>" for tkn in audio_tkns])
    # Build the biasing context: pad with random rare-word distractors when
    # the utterance has fewer than n_bwords biasing words, otherwise
    # subsample down to n_bwords
    b_words = element["b_words"]
    if n_bwords > len(b_words):
        context = b_words + random.sample(rarewords, n_bwords - len(b_words))
    else:
        context = random.sample(b_words, n_bwords)
    random.shuffle(context)
    # Append the context words as a prompt
    data += "<|startofprompt|>" + "<|sepofprompt|>".join(context) + "<|endofprompt|>"
    # Mark where the transcript should begin
    data += "<|startoftranscript|>"
    return {"data": data, "context": context}
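# The resulting string (with hypothetical audio token ids) looks like:
# <|audio:17|><|audio:332|>...<|startofprompt|>word1<|sepofprompt|>word2
# ...<|endofprompt|><|startoftranscript|>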
@torch.no_grad()
def evaluate_model(model):
    transcripts = []
    processed_data = dataset["test.clean"].map(prepare)
    data = processed_data["data"]
    # Batched decoding of the test set
    for idx in tqdm(range(ceil(len(data) / val_bs))):
        outputs = tokenizer(
            data[idx * val_bs: (idx + 1) * val_bs],
            truncation=False,
            max_length=None,
            padding=True,
            return_tensors="pt",
        ).to(model.device)
        input_ids = outputs["input_ids"]
        par = input_ids.shape[-1]
        generations = model.generate(
            input_ids,
            attention_mask=outputs["attention_mask"],  # needed with left padding
            max_new_tokens=context_length - par - 1,
            eos_token_id=eot_token,
        )
        # Keep only the newly generated tokens (the transcript)
        transcripts += tokenizer.batch_decode(generations[:, par:], skip_special_tokens=True)
    # Word-level scoring: B-WER counts errors on words from the biasing
    # context, U-WER counts errors on all remaining words
    bias_word_cnt = 0
    normal_word_cnt = 0
    u_wer = 0.0
    b_wer = 0.0
    pred_list = correct_text(transcripts)
    text_list = correct_text(processed_data["text"])
    prompt_list = processed_data["context"]
    for a, b, c in zip(pred_list, text_list, prompt_list):
        aligned_pair = align_gt_asr(b, a)
        for gt_word, asr_word in aligned_pair:
            if gt_word in c or asr_word in c:
                if gt_word != asr_word:
                    b_wer += 1.0
                if gt_word in c:
                    bias_word_cnt += 1
            else:
                if gt_word != asr_word:
                    u_wer += 1.0
                if gt_word != "":
                    normal_word_cnt += 1
    u_wer = u_wer / normal_word_cnt * 100
    b_wer = b_wer / bias_word_cnt * 100
    return (
        wer(transcripts, processed_data["text"]).item() * 100,
        cer(transcripts, processed_data["text"]).item() * 100,
        b_wer,
        u_wer,
    )
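# Minimal usage sketch: "./../checkpoints/best" is a hypothetical path; the
# checkpoint is assumed to already contain the audio and prompt special
# tokens in its embedding matrix.
if __name__ == "__main__":
    model = GPT2LMHeadModel.from_pretrained(
        "./../checkpoints/best", torch_dtype=dtype
    ).to("cuda").eval()
    wer_pct, cer_pct, b_wer_pct, u_wer_pct = evaluate_model(model)
    print(f"WER {wer_pct:.2f} | CER {cer_pct:.2f} | "
          f"B-WER {b_wer_pct:.2f} | U-WER {u_wer_pct:.2f}")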