# File size: 3,602 Bytes | 5403e87
import torch
import numpy as np
import re
import sentencepiece as spm
import kenlm
from data_loading import test_dataloader
from squeezeformer import MySqueezeformer
from torchmetrics.functional import word_error_rate, char_error_rate
from torchaudio.models.decoder import ctc_decoder
# -------------------------
# Device
# -------------------------
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# -------------------------
# Tokenizers & LM
# -------------------------
def _load_sentencepiece(model_path):
    """Return a SentencePieceProcessor with the model at *model_path* loaded."""
    processor = spm.SentencePieceProcessor()
    processor.Load(model_path)
    return processor


# `sp` decodes the reference token ids back into text.
sp = _load_sentencepiece("ressources/tokenizer/128_v7.model")
# `sp_lm` splits a hypothesis into the pieces that are fed to the KenLM model.
sp_lm = _load_sentencepiece("ressources/tokenizer/5K.model")
# KenLM language model used to score hypotheses.
lm = kenlm.Model("./ressources/kenLM_model/kab_5k_trigram.bin")
# -------------------------
# Model
# -------------------------
MODEL = MySqueezeformer().to(device)
# NOTE(review): torch.load unpickles the checkpoint file; only load
# checkpoints that come from a trusted source.
_checkpoint = torch.load("ressources/e2e_model/squeezeformer", map_location=device)
# strict=False tolerates key mismatches between checkpoint and model, but a
# mismatch means some weights keep their initial values. Report it instead of
# silently evaluating a partially loaded model.
_load_report = MODEL.load_state_dict(_checkpoint, strict=False)
if _load_report.missing_keys or _load_report.unexpected_keys:
    print(
        "Warning: checkpoint does not fully match the model "
        f"(missing keys: {list(_load_report.missing_keys)}, "
        f"unexpected keys: {list(_load_report.unexpected_keys)})"
    )
# Switch to inference mode (disables dropout, uses running norm statistics).
MODEL.eval()
# -------------------------
# Decoder
# -------------------------
# Lexicon-free CTC decoder with every beam setting at 1 and a single
# hypothesis returned per utterance.
# NOTE(review): with nbest=1 there is only one hypothesis per utterance, so
# the LM-weighted selection done in evaluate() cannot change which
# transcription is chosen - confirm this is intended.
_DECODER_OPTIONS = {
    "tokens": "ressources/tokenizer/128_v7.txt",
    "lexicon": None,
    "beam_size": 1,
    "beam_threshold": 1,
    "beam_size_token": 1,
    "nbest": 1,
    "log_add": True,
    "blank_token": "_",
    "sil_token": "|",
    "unk_word": "<unk>",
}
decoder = ctc_decoder(**_DECODER_OPTIONS)
# -------------------------
# Helpers
# -------------------------
# Single-pass character mapping: drop the CTC blank ("_") and the silence
# marker ("|"), and turn the SentencePiece word-boundary marker into a space.
_TOKEN_CLEANUP_TABLE = str.maketrans({"_": None, "|": None, "▁": " "})
# Two or more consecutive hyphens.
_HYPHEN_RUN = re.compile(r"-{2,}")


def clean_text(tokens):
    """Turn a sequence of decoder tokens into a plain-text transcription.

    Args:
        tokens: Iterable of token strings (e.g. the output of
            ``decoder.idxs_to_tokens``).

    Returns:
        The concatenated tokens with ``_`` and ``|`` removed, ``▁`` replaced
        by a space, whitespace runs collapsed to single spaces, runs of
        hyphens collapsed to one hyphen, and no leading/trailing whitespace.
    """
    text = "".join(tokens).translate(_TOKEN_CLEANUP_TABLE)
    # split()/join collapses any run of whitespace and trims both ends.
    text = " ".join(text.split())
    return _HYPHEN_RUN.sub("-", text).strip()
def _select_best_transcription(hypotheses, lm_weight, am_weight):
    """Return the cleaned text of the hypothesis with the best combined score.

    Each hypothesis is converted to text, scored by the KenLM model, and the
    LM score is mixed with the decoder score using the given weights.
    Returns an empty string when *hypotheses* is empty.
    """
    transcriptions = []
    scores = []
    for hypothesis in hypotheses:
        transcription = clean_text(decoder.idxs_to_tokens(hypothesis.tokens))
        transcriptions.append(transcription)

        # ---- LM scoring ----
        # The LM is scored on space-separated sentence pieces, with hyphens
        # glued back to their neighbouring pieces.
        lm_input = " ".join(sp_lm.Encode(transcription, out_type=str))
        lm_input = lm_input.replace("- ", "-").replace(" -", "-")
        scores.append(lm.score(lm_input) * lm_weight + hypothesis.score * am_weight)

    if not transcriptions:
        # np.argmax raises on an empty sequence; an empty prediction keeps
        # the predictions aligned one-to-one with the targets.
        return ""
    return transcriptions[int(np.argmax(scores))]


@torch.no_grad()
def evaluate(*, lm_weight=0.25, am_weight=0.75):
    """Transcribe the test set and print its word and character error rates.

    Iterates over ``test_dataloader``, runs ``MODEL`` on each batch, decodes
    the outputs with ``decoder``, picks the best hypothesis per utterance by
    a weighted sum of the KenLM score and the decoder score, prints each
    chosen transcription, and finally prints WER and CER over all utterances.

    Args:
        lm_weight: Weight of the KenLM score in the combined score.
        am_weight: Weight of the decoder hypothesis score in the combined score.
    """
    all_transcriptions = []
    all_targets = []

    for batch in test_dataloader:
        if batch is None:
            continue

        inputs, targets, input_lengths, target_lengths = batch
        inputs = inputs.to(device)
        input_lengths = input_lengths.to(device)

        # ---- Forward ----
        # Call the module itself rather than .forward() so registered hooks run.
        outputs, _ = MODEL(inputs, input_lengths)

        # decoder expects CPU
        outputs = outputs.cpu()

        # ---- Decode batch directly (faster) ----
        # NOTE(review): no `lengths` are passed, so padded frames are decoded
        # as well. If the second value returned by MODEL is the per-utterance
        # output length, it should probably be forwarded here - TODO confirm.
        batch_results = decoder(outputs)

        # ---- Targets ----
        for target, target_length in zip(targets, target_lengths):
            # Drop the padding before decoding the reference ids to text.
            all_targets.append(sp.Decode(target[:target_length].tolist()))

        # ---- Predictions ----
        for hypotheses in batch_results:
            best_transcription = _select_best_transcription(
                hypotheses, lm_weight, am_weight
            )
            print(best_transcription)
            all_transcriptions.append(best_transcription)

    # -------------------------
    # Metrics
    # -------------------------
    if not all_targets:
        # Nothing was evaluated (empty loader or every batch was None).
        print("No samples were evaluated; skipping WER/CER computation.")
        return

    wer = word_error_rate(all_transcriptions, all_targets)
    cer = char_error_rate(all_transcriptions, all_targets)
    print(f"Average Word Error Rate: {wer * 100:.2f}%")
    print(f"Average Character Error Rate: {cer * 100:.2f}%")
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    evaluate()