from io import BytesIO
from urllib.request import urlopen
import soundfile
import torch
from datasets import load_dataset, Audio
import numpy as np
from transformers import AutoModel, AutoProcessor, BatchFeature
from tqdm import tqdm
import json
import os
import time
from datetime import datetime
from whisper_normalizer.english import EnglishTextNormalizer
from whisper_normalizer.basic import BasicTextNormalizer
import sacrebleu
from jiwer import cer, wer
from torch.utils.data import Dataset, DataLoader
import soundfile as sf
import re
from pathlib import Path
import opencc

from ASRDataset import *

# Simplified -> Traditional (Taiwan) Chinese converter, applied to model output.
converter = opencc.OpenCC('s2tw.json')

# Text normalizers used before scoring; English gets the Whisper English
# normalizer, everything else the basic one.
normalizer = {
    "en_us": EnglishTextNormalizer(),
    "other": BasicTextNormalizer(),
}

model_id = "/mnt/jeff/InCar/LlamaNemotronOmni/test_nemotron_omni"
revision = "main"  # "v1.0"

model = AutoModel.from_pretrained(
    model_id,
    device_map="cuda",
    revision=revision,
    trust_remote_code=True,
).eval()

processor = AutoProcessor.from_pretrained(
    model_id,
    revision=revision,
    trust_remote_code=True,
)

# This checkpoint ships without a pad token; reuse EOS so batched generation works.
if 'LlamaNemotronOmni' in model_id:
    processor.tokenizer.pad_token_id = processor.tokenizer.eos_token_id

# Per-run output directory, timestamped so runs never clobber each other.
results_dir = f"evaluation_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
os.makedirs(results_dir, exist_ok=True)

INSTRUCTION = {
    "ast": "Translate the audio to {0}.",
    "asr": "Transcribe the audio clip into text.",
}


def covost_collate_fn_test(batch):
    """Collate processed samples into one padded BatchFeature for generation.

    Each element of ``batch`` is a dict-like object produced by the ASRDataset
    processors with ``input_ids``, ``input_audio_embeds``, ``audio_embed_sizes``,
    ``input_modes`` and the reference ``answer`` text.  Token ids are
    left-padded with 0 (decoder-style batching); audio attention masks are
    right-padded with False.
    """
    input_ids_list = []
    input_audio_embeds_list = []
    audio_embed_sizes_list = []
    audio_attention_mask_list = []
    input_modes_list = []
    answer_list = []
    for inputs in batch:
        input_ids_list.append(inputs['input_ids'][0])
        input_audio_embeds_list.append(inputs['input_audio_embeds'])
        audio_embed_sizes_list.append(inputs['audio_embed_sizes'])
        # One True per audio frame; right-padding below fills False.
        audio_attention_mask_list.append(
            inputs['input_audio_embeds'].new_full(
                (inputs['input_audio_embeds'].size(1),), True, dtype=torch.bool
            )
        )
        input_modes_list.append(inputs['input_modes'])
        answer_list.append(inputs['answer'])

    try:
        input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
        audio_attention_mask = (
            pad_sequence(audio_attention_mask_list, padding_side='right', padding_value=False)
            if len(audio_attention_mask_list) > 1
            else None
        )
    except Exception as e:
        # BUGFIX: the original also printed audio_attention_mask here, which is
        # unbound when the first pad_sequence call is the one that failed and
        # turned the real error into a NameError.
        print(e)
        print(input_ids_list)
        raise

    # Padding value is 0, so non-zero positions are real tokens.
    attention_mask = (input_ids != 0).long()
    input_audio_embeds = cat_with_pad(input_audio_embeds_list, dim=0)
    audio_embed_sizes = torch.cat(audio_embed_sizes_list)
    input_modes = torch.cat(input_modes_list)

    return BatchFeature(
        {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'input_audio_embeds': input_audio_embeds,
            'audio_embed_sizes': audio_embed_sizes,
            'audio_attention_mask': audio_attention_mask,
            'input_modes': input_modes,
            'answer': answer_list,
        }
    )


def save_results(results, dataset_name, task, source_lang, target_lang=None, sample_idx=None):
    """Write ``results`` as pretty-printed JSON into ``results_dir``.

    The file name encodes task, dataset and language pair, e.g.
    ``asr_fleurs_en_us.json`` or ``translation_covost_en_us_to_zh-TW.json``.
    A ``timestamp`` field is stamped into ``results`` before writing.
    Returns the path of the written file.
    """
    filename = f"{task}_{dataset_name}_{source_lang}"
    if target_lang:
        filename += f"_to_{target_lang}"
    if sample_idx is not None:
        filename += f"_sample_{sample_idx}"
    # BUGFIX: the constructed name was previously ignored and every call
    # overwrote the same hard-coded file.
    filepath = os.path.join(results_dir, f"{filename}.json")
    results["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    return filepath


def _compute_metrics(eval_normalizer, reference, prediction):
    """Return per-utterance metrics {bleu, cer, wer} (percentages).

    Both strings are normalized first; CER is computed on whitespace-stripped
    text so word segmentation does not affect it, and is clamped to 100.
    """
    ref = eval_normalizer(reference)
    pred = eval_normalizer(prediction)
    utt_bleu = sacrebleu.sentence_bleu(pred, [ref]).score
    utt_cer = round(cer(re.sub(r"\s+", "", ref), re.sub(r"\s+", "", pred)) * 100, 2)
    utt_wer = round(wer(ref, pred) * 100, 2)
    return {"bleu": utt_bleu, "cer": min(100, utt_cer), "wer": utt_wer}


def evaluate_task(dataset, source_lang, target_lang, num_samples=-1, batch_size=4, is_asr=True):
    """Run batched generation over ``dataset`` and score BLEU / WER / CER.

    Args:
        dataset: an ASRDataset-style dataset yielding processed samples.
        source_lang / target_lang: language codes; the normalizer is chosen by
            ``source_lang`` for ASR and ``target_lang`` for translation.
        num_samples: if > 0, evaluate a random subset of this size.
        batch_size: dataloader batch size.
        is_asr: True for transcription, False for speech translation.

    Partial results are checkpointed to disk every 10 batches; the final
    per-sample and averaged metrics are saved and returned as a dict.
    """
    task_type = "asr" if is_asr else "translation"
    eval_lang = source_lang if is_asr else target_lang
    eval_normalizer = normalizer.get(eval_lang, normalizer['other'])

    sample_results = []
    if 0 < num_samples < len(dataset):
        indices = np.random.choice(len(dataset), num_samples, replace=False)
        dataset = dataset.select(indices)

    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, collate_fn=covost_collate_fn_test
    )

    # Cache of already-scored samples so the periodic checkpoint does not
    # recompute metrics for earlier batches.
    evaluated_samples = {}
    for batch_idx, batch in enumerate(tqdm(dataloader)):
        batch_references = batch.pop("answer")
        if torch.cuda.is_available():
            try:
                batch = {k: v.to("cuda") for k, v in batch.items()}
            except Exception as exc:
                # BUGFIX: was a bare `except:` that printed only 'error'.
                print(f"error: failed to move batch {batch_idx} to GPU: {exc}")
                break

        with torch.inference_mode():
            generate_ids = model.generate(
                **batch,
                max_new_tokens=256,
                # temperature=1.0,
                top_p=0.95,
                top_k=64,
                do_sample=True,
            )

        # Strip the prompt tokens; generate() returns prompt + continuation.
        input_lengths = batch['input_ids'].shape[1]
        generate_ids = generate_ids[:, input_lengths:]
        batch_predictions = processor.batch_decode(
            generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        for i, (reference, prediction) in enumerate(zip(batch_references, batch_predictions)):
            sample_results.append(
                {
                    "id": batch_idx * batch_size + i,
                    "reference": reference,
                    # Convert to Traditional Chinese (no-op for non-Chinese text).
                    "prediction": converter.convert(prediction),
                }
            )

        # Checkpoint partial results every 10 batches so a crash loses little.
        if (batch_idx + 1) % 10 == 0:
            temp_results = []
            for item in sample_results:
                sample_id = item["id"]
                temp_item = item.copy()
                if sample_id not in evaluated_samples:
                    try:
                        evaluated_samples[sample_id] = _compute_metrics(
                            eval_normalizer, item["reference"], item["prediction"]
                        )
                    except Exception as e:
                        print(f"Error evaluating sample {sample_id}: {e}")
                        evaluated_samples[sample_id] = {
                            "bleu": 0, "cer": 100, "wer": 100, "error": str(e)
                        }
                temp_item.update(evaluated_samples[sample_id])
                temp_results.append(temp_item)

            partial_results = {
                "task": task_type,
                "source_lang": source_lang,
                "target_lang": target_lang,
                "num_samples": len(temp_results),
                "sample_results": temp_results,
            }
            save_results(partial_results, dataset.name, task_type, source_lang, target_lang)

    # Final pass: score every sample.
    for item in sample_results:
        item.update(_compute_metrics(eval_normalizer, item["reference"], item["prediction"]))

    n = len(sample_results)
    if n == 0:
        # BUGFIX: the original divided by zero when the GPU-transfer failure
        # aborted the loop before any batch completed.
        print("Warning: no samples evaluated; reporting zero metrics.")
        avg_bleu = avg_cer = avg_wer = 0.0
    else:
        avg_bleu = sum(item["bleu"] for item in sample_results) / n
        avg_cer = sum(item["cer"] for item in sample_results) / n
        avg_wer = sum(item["wer"] for item in sample_results) / n

    results = {
        "dataset": dataset.name,
        "task": task_type,
        "source_lang": source_lang,
        "target_lang": target_lang,
        "num_samples": n,
        "metrics": {"bleu": avg_bleu, "cer": avg_cer, "wer": avg_wer},
        "sample_results": sample_results,
    }
    save_results(results, dataset.name, task_type, source_lang, target_lang)
    return results


if __name__ == "__main__":
    source_languages = [
        ("en_us", "English"),
    ]
    target_languages = [
        ("zh-TW", "zh-TW"),
    ]
    num_samples = -1   # -1 => evaluate the full split
    batch_size = 32

    for source_lang, target_lang in zip(source_languages, target_languages):
        print(f"\n===== {source_lang[0]} ASR =====")
        split = "test"
        datasets = []

        commonvoice_speech_tw = CommonVoiceDataset(
            processor=processor,
            source_lang="zh-TW",
            split=split,
        )
        datasets.append(commonvoice_speech_tw)

        # Fleurs ASR mode (English -> English text)
        fleurs = FleursDataset(
            processor=processor,
            split=split,
            source_lang="en_us",  # English
            mode="asr",
        )
        datasets.append(fleurs)

        # Optional extra benchmarks (disabled):
        # libri_speech_clean = LibriSpeechDataset(processor=processor, subset="clean", split=split)
        # datasets.append(libri_speech_clean)
        # libri_speech_other = LibriSpeechDataset(processor=processor, subset="other", split=split)
        # datasets.append(libri_speech_other)

        for dataset in datasets:
            # ASR
            asr_results = evaluate_task(
                dataset,
                source_lang[0],
                target_lang[0],
                num_samples,
                batch_size=batch_size,
                is_asr=True,
            )
            print(f"\n=== {asr_results.get('dataset', 'Dataset')} | {source_lang[0]} ASR===")
            print(f"BLEU: {asr_results.get('metrics', {}).get('bleu', 'N/A')}")
            print(f"WER: {asr_results.get('metrics', {}).get('wer', 'N/A')}")
            print(f"CER: {asr_results.get('metrics', {}).get('cer', 'N/A')}")