# eval.py
# (Removed non-Python residue from a Hugging Face file-page scrape:
#  uploader name, "Upload 22 files", commit hash "3284d90 verified".)
from io import BytesIO
from urllib.request import urlopen
import soundfile
import torch
from datasets import load_dataset, Audio
import numpy as np
from transformers import AutoModel, AutoProcessor, BatchFeature
from tqdm import tqdm
import json
import os
import time
from datetime import datetime
from whisper_normalizer.english import EnglishTextNormalizer
from whisper_normalizer.basic import BasicTextNormalizer
import sacrebleu
from jiwer import cer, wer
from torch.utils.data import Dataset, DataLoader
import soundfile as sf
import re
from pathlib import Path
import opencc
from ASRDataset import *
# ---------------------------------------------------------------------------
# Module-level evaluation setup. NOTE: these statements run at import time and
# have side effects (CUDA model load, directory creation) — keep their order.
# ---------------------------------------------------------------------------

# OpenCC converter: Simplified Chinese -> Traditional Chinese (Taiwan variant).
converter = opencc.OpenCC('s2tw.json')

# Text normalizers applied to both reference and prediction before scoring:
# Whisper's English normalizer for "en_us", the basic one for everything else.
normalizer = {
    "en_us": EnglishTextNormalizer(),
    "other": BasicTextNormalizer()
}

# Local checkpoint path of the model under evaluation.
model_id = "/mnt/jeff/InCar/LlamaNemotronOmni/test_nemotron_omni"
revision = "main"  # "v1.0"

model = AutoModel.from_pretrained(
    model_id, device_map="cuda", revision=revision, trust_remote_code=True
).eval()
processor = AutoProcessor.from_pretrained(
    model_id, revision=revision, trust_remote_code=True
)

# This tokenizer ships without a pad token; reuse EOS so batched (padded)
# generation works.
if 'LlamaNemotronOmni' in model_id:
    processor.tokenizer.pad_token_id = processor.tokenizer.eos_token_id

# Timestamped output directory shared by all result JSON files of this run.
results_dir = f"evaluation_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
os.makedirs(results_dir, exist_ok=True)

# Prompt templates per task. NOTE(review): not referenced in this chunk —
# presumably consumed by the dataset classes in ASRDataset; confirm.
INSTRUCTION = {
    "ast": "Translate the audio to {0}.",
    "asr": "Transcribe the audio clip into text.",
}
def covost_collate_fn_test(batch):
    """Collate per-sample feature dicts into one padded BatchFeature.

    Each element of *batch* is expected to provide 'input_ids' (shape
    ``[1, T]``), 'input_audio_embeds', 'audio_embed_sizes', 'input_modes',
    and a text 'answer'.

    Returns:
        BatchFeature with left-padded ``input_ids``, a matching
        ``attention_mask`` (non-zero ids = real tokens), padded audio
        embeddings/masks, and the raw reference answers as a list.
    """
    input_ids_list = []
    input_audio_embeds_list = []
    audio_embed_sizes_list = []
    audio_attention_mask_list = []
    input_modes_list = []
    answer_list = []
    for inputs in batch:
        input_ids_list.append(inputs['input_ids'][0])
        input_audio_embeds_list.append(inputs['input_audio_embeds'])
        audio_embed_sizes_list.append(inputs['audio_embed_sizes'])
        # One True per audio frame (dim 1 of the embeddings); frames added by
        # padding below become False.
        audio_attention_mask_list.append(
            inputs['input_audio_embeds'].new_full(
                (inputs['input_audio_embeds'].size(1),), True, dtype=torch.bool
            )
        )
        input_modes_list.append(inputs['input_modes'])
        answer_list.append(inputs['answer'])
    try:
        input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
        # A single-sample batch needs no audio mask (None), matching the
        # original behavior.
        audio_attention_mask = (
            pad_sequence(audio_attention_mask_list, padding_side='right', padding_value=False)
            if len(audio_attention_mask_list) > 1
            else None
        )
    except Exception as e:
        # BUGFIX: the original printed `audio_attention_mask`, which is unbound
        # when pad_sequence fails on `input_ids_list`, so the handler raised a
        # NameError that masked the real error. Print the inputs instead.
        print(e)
        print(input_ids_list)
        print(audio_attention_mask_list)
        raise
    # Padding value is 0, so any non-zero id marks a real token.
    attention_mask = (input_ids != 0).long()
    input_audio_embeds = cat_with_pad(input_audio_embeds_list, dim=0)
    audio_embed_sizes = torch.cat(audio_embed_sizes_list)
    input_modes = torch.cat(input_modes_list)
    return BatchFeature(
        {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'input_audio_embeds': input_audio_embeds,
            'audio_embed_sizes': audio_embed_sizes,
            'audio_attention_mask': audio_attention_mask,
            'input_modes': input_modes,
            'answer': answer_list,
        }
    )
def save_results(results, dataset_name, task, source_lang, target_lang=None, sample_idx=None):
    """Serialize *results* (with an added timestamp) to JSON under results_dir.

    The file name encodes task, dataset, and language pair, e.g.
    ``asr_fleurs_en_us.json`` or ``translation_covost_en_us_to_zh-TW.json``,
    optionally suffixed with ``_sample_<idx>``.

    Returns:
        str: path of the written JSON file.
    """
    filename = f"{task}_{dataset_name}_{source_lang}"
    if target_lang:
        filename += f"_to_{target_lang}"
    if sample_idx is not None:
        filename += f"_sample_{sample_idx}"
    # BUGFIX: the original wrote every result to the literal "(unknown).json",
    # so successive saves clobbered one another and the carefully built
    # `filename` was never used.
    filepath = os.path.join(results_dir, f"{filename}.json")
    results["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    return filepath
def _score_pair(reference, prediction, eval_normalizer):
    """Return BLEU / CER / WER metrics (percent) for one ref/pred pair.

    Both strings are normalized first; CER is computed on whitespace-stripped
    text and capped at 100. Raises on malformed input (callers that need
    best-effort behavior wrap this in try/except).
    """
    ref = eval_normalizer(reference)
    pred = eval_normalizer(prediction)
    utt_bleu = sacrebleu.sentence_bleu(pred, [ref]).score
    utt_cer = round(cer(re.sub(r"\s+", "", ref), re.sub(r"\s+", "", pred)) * 100, 2)
    utt_wer = round(wer(ref, pred) * 100, 2)
    return {"bleu": utt_bleu, "cer": min(100, utt_cer), "wer": utt_wer}


def evaluate_task(dataset, source_lang, target_lang, num_samples=-1, batch_size=4, is_asr=True):
    """Run batched generation over *dataset* and score the predictions.

    Args:
        dataset: torch-style dataset yielding feature dicts (see collate fn);
            must expose ``.name`` and, for subsampling, ``.select``.
        source_lang / target_lang: language codes; normalization uses the
            source for ASR and the target for translation.
        num_samples: if > 0, evaluate a random subset of this size.
        batch_size: generation batch size.
        is_asr: True for transcription, False for translation.

    Returns:
        dict with averaged BLEU/CER/WER and per-sample results (also written
        to disk via save_results).
    """
    task_type = "asr" if is_asr else "translation"
    eval_lang = source_lang if is_asr else target_lang
    eval_normalizer = normalizer.get(eval_lang, normalizer['other'])

    sample_results = []
    # Optional random subsampling for quick runs.
    if num_samples > 0 and num_samples < len(dataset):
        indices = np.random.choice(len(dataset), num_samples, replace=False)
        dataset = dataset.select(indices)

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                            collate_fn=covost_collate_fn_test)
    evaluated_samples = {}  # sample id -> cached metrics for periodic checkpoints

    for batch_idx, batch in enumerate(tqdm(dataloader)):
        batch_references = batch.pop("answer")
        if torch.cuda.is_available():
            try:
                batch = {k: v.to("cuda") for k, v in batch.items()}
            except Exception as e:
                # BUGFIX: was a bare `except:` printing only 'error', which
                # also swallowed KeyboardInterrupt and hid the cause (e.g. the
                # audio mask being None for a single-sample batch).
                print(f"error moving batch {batch_idx} to GPU: {e}")
                break

        with torch.inference_mode():
            generate_ids = model.generate(
                **batch,
                max_new_tokens=256,
                # temperature=1.0, top_p=0.95, top_k=64, do_sample=True
            )
        # Strip the prompt; keep only newly generated tokens.
        input_lengths = batch['input_ids'].shape[1]
        generate_ids = generate_ids[:, input_lengths:]
        batch_predictions = processor.batch_decode(
            generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        for i, (reference, prediction) in enumerate(zip(batch_references, batch_predictions)):
            sample_results.append({
                "id": batch_idx * batch_size + i,
                "reference": reference,
                # Convert prediction to Traditional Chinese (TW) for scoring.
                "prediction": converter.convert(prediction),
            })

        # Periodic checkpoint every 10 batches: score what we have and write a
        # partial JSON so long runs are recoverable.
        if (batch_idx + 1) % 10 == 0:
            temp_results = []
            for item in sample_results:
                sample_id = item["id"]
                temp_item = item.copy()
                if sample_id not in evaluated_samples:
                    try:
                        metrics = _score_pair(item["reference"], item["prediction"], eval_normalizer)
                    except Exception as e:
                        print(f"Error evaluating sample {sample_id}: {e}")
                        metrics = {"bleu": 0, "cer": 100, "wer": 100, "error": str(e)}
                    evaluated_samples[sample_id] = metrics
                temp_item.update(evaluated_samples[sample_id])
                temp_results.append(temp_item)
            partial_results = {
                "task": task_type,
                "source_lang": source_lang,
                "target_lang": target_lang,
                "num_samples": len(temp_results),
                "sample_results": temp_results,
            }
            save_results(partial_results, dataset.name, task_type, source_lang, target_lang)

    # Final pass: (re)score every sample; errors here propagate, matching the
    # original's final loop which had no try/except.
    for item in sample_results:
        item.update(_score_pair(item["reference"], item["prediction"], eval_normalizer))

    # BUGFIX: guard against an empty dataset (original divided by zero).
    n = len(sample_results)
    avg_bleu = sum(item["bleu"] for item in sample_results) / n if n else 0.0
    avg_cer = sum(item["cer"] for item in sample_results) / n if n else 0.0
    avg_wer = sum(item["wer"] for item in sample_results) / n if n else 0.0

    results = {
        "dataset": dataset.name,
        "task": task_type,
        "source_lang": source_lang,
        "target_lang": target_lang,
        "num_samples": n,
        "metrics": {
            "bleu": avg_bleu,
            "cer": avg_cer,
            "wer": avg_wer,
        },
        "sample_results": sample_results,
    }
    save_results(results, dataset.name, task_type, source_lang, target_lang)
    return results
if __name__ == "__main__":
    # Language pairs under test: (code, human-readable name).
    source_languages = [
        ("en_us", "English"),
    ]
    target_languages = [
        ("zh-TW", "zh-TW"),
    ]
    num_samples = -1   # -1 => evaluate the full split
    batch_size = 32

    for source_lang, target_lang in zip(source_languages, target_languages):
        print(f"\n===== {source_lang[0]} ASR =====")
        split = "test"

        # Datasets to evaluate; add LibriSpeechDataset entries here if needed.
        eval_sets = [
            # Common Voice zh-TW test split.
            CommonVoiceDataset(
                processor=processor,
                source_lang="zh-TW",
                split=split,
            ),
            # FLEURS English in ASR mode (English audio -> English text).
            FleursDataset(
                processor=processor,
                split=split,
                source_lang="en_us",
                mode="asr",
            ),
        ]

        for dataset in eval_sets:
            asr_results = evaluate_task(
                dataset,
                source_lang[0],
                target_lang[0],
                num_samples,
                batch_size=batch_size,
                is_asr=True,
            )
            metrics = asr_results.get('metrics', {})
            print(f"\n=== {asr_results.get('dataset', 'Dataset')} | {source_lang[0]} ASR===")
            print(f"BLEU: {metrics.get('bleu', 'N/A')}")
            print(f"WER: {metrics.get('wer', 'N/A')}")
            print(f"CER: {metrics.get('cer', 'N/A')}")