| import torch |
| import numpy as np |
| import re |
| import sentencepiece as spm |
| import kenlm |
|
|
| from data_loading import test_dataloader |
| from squeezeformer import MySqueezeformer |
| from torchmetrics.functional import word_error_rate, char_error_rate |
| from torchaudio.models.decoder import ctc_decoder |
|
|
| |
| |
| |
# Run inference on the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# 128-piece SentencePiece model matching the acoustic model's output vocabulary.
sp = spm.SentencePieceProcessor()
sp.Load("ressources/tokenizer/128_v7.model")


# Separate 5K-piece SentencePiece model matching the KenLM training tokenization.
sp_lm = spm.SentencePieceProcessor()
sp_lm.Load("ressources/tokenizer/5K.model")


# Kabyle trigram language model (KenLM binary) used to rescore beam hypotheses.
lm = kenlm.Model("./ressources/kenLM_model/kab_5k_trigram.bin")


MODEL = MySqueezeformer().to(device)

# strict=False tolerates missing/unexpected keys in the checkpoint.
# NOTE(review): mismatched keys are silently ignored — confirm this is intended.
MODEL.load_state_dict(
    torch.load("ressources/e2e_model/squeezeformer", map_location=device), strict=False
)

# Evaluation mode: disables dropout and uses running batch-norm statistics.
MODEL.eval()


# Lexicon-free CTC beam-search decoder. With beam_size=1 / nbest=1 this is
# effectively greedy decoding; "_" is the CTC blank token, "|" the silence token.
decoder = ctc_decoder(
    tokens="ressources/tokenizer/128_v7.txt",
    lexicon=None,
    beam_size=1,
    beam_threshold=1,
    beam_size_token=1,
    nbest=1,
    log_add=True,
    blank_token="_",
    sil_token="|",
    unk_word="<unk>",
)
|
|
|
|
| |
| |
| |
def clean_text(tokens):
    """Turn a list of decoder tokens into a readable sentence.

    Drops the CTC blank ("_") and silence ("|") markers, converts the
    SentencePiece word-boundary marker ("▁") into a space, collapses runs
    of whitespace, and squeezes repeated hyphens down to a single one.
    """
    joined = "".join(tokens)
    without_markers = joined.replace("_", "").replace("|", "")
    with_spaces = without_markers.replace("▁", " ")
    normalized = " ".join(with_spaces.split())
    dehyphenated = re.sub(r"-{2,}", "-", normalized)
    return dehyphenated.strip()
|
|
|
|
@torch.no_grad()
def evaluate(lm_weight=0.25, acoustic_weight=0.75):
    """Decode the test set, rescore hypotheses with the n-gram LM, and print
    corpus-level WER and CER.

    Args:
        lm_weight: interpolation weight for the KenLM language-model score.
        acoustic_weight: interpolation weight for the decoder (acoustic) score.
    """
    all_transcriptions = []
    all_targets = []

    for batch in test_dataloader:
        # The dataloader may emit None for samples it could not load.
        if batch is None:
            continue

        inputs, targets, input_lengths, target_lengths = batch
        inputs = inputs.to(device)
        input_lengths = input_lengths.to(device)

        # Call the module (not .forward) so registered hooks run.
        outputs, _ = MODEL(inputs, input_lengths)

        # The torchaudio CTC decoder expects CPU tensors.
        batch_results = decoder(outputs.cpu())

        # Reference sentences: trim padding, then detokenize.
        for i in range(len(targets)):
            tgt = targets[i][: target_lengths[i]].tolist()
            all_targets.append(sp.Decode(tgt))

        # One hypothesis list per utterance; keep best-scoring transcription.
        # Appending "" for empty hypothesis lists keeps predictions aligned
        # with targets instead of crashing in np.argmax.
        for results_array in batch_results:
            best_transcription = _pick_best_hypothesis(
                results_array, lm_weight, acoustic_weight
            )
            print(best_transcription)
            all_transcriptions.append(best_transcription)

    if not all_transcriptions:
        # Avoid a crash in the metric functions when nothing was decoded.
        print("No utterances decoded; nothing to score.")
        return

    wer = word_error_rate(all_transcriptions, all_targets)
    cer = char_error_rate(all_transcriptions, all_targets)

    print(f"Average Word Error Rate: {wer * 100:.2f}%")
    print(f"Average Character Error Rate: {cer * 100:.2f}%")


def _pick_best_hypothesis(results_array, lm_weight, acoustic_weight):
    """Rescore each beam hypothesis with the LM and return the best text.

    Returns "" when the decoder produced no hypotheses for the utterance.
    """
    transcriptions = []
    scores = []

    for result in results_array:
        tokens = decoder.idxs_to_tokens(result.tokens)
        transcription = clean_text(tokens)
        transcriptions.append(transcription)

        # Re-tokenize with the LM's 5K SentencePiece model and re-join
        # hyphenated pieces so the text matches the LM's training format.
        lm_input = " ".join(sp_lm.Encode(transcription, out_type=str))
        lm_input = lm_input.replace("- ", "-").replace(" -", "-")

        lm_score = lm.score(lm_input)
        scores.append(lm_score * lm_weight + result.score * acoustic_weight)

    if not scores:
        return ""
    return transcriptions[int(np.argmax(scores))]
|
|
|
|
if __name__ == "__main__":
    # Script entry point: decode the whole test set and print WER/CER.
    evaluate()
|
|