"""Run a speech model over evaluation datasets and report BLEU / WER / CER.

Importing this module loads the model and processor from ``model_id`` and
creates a timestamped results directory; running it as a script evaluates the
datasets listed in the ``__main__`` section and prints the averaged metrics.
"""
import json
import os
import re
import time
from datetime import datetime
from io import BytesIO
from pathlib import Path
from urllib.request import urlopen

import numpy as np
import opencc
import sacrebleu
import soundfile
import soundfile as sf
import torch
from datasets import load_dataset, Audio
from jiwer import cer, wer
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoModel, AutoProcessor, BatchFeature
from whisper_normalizer.basic import BasicTextNormalizer
from whisper_normalizer.english import EnglishTextNormalizer

# Kept last so that any name it exports shadows the imports above exactly as
# before. Presumably the source of MultiturnAudioDataset and covost_collate_fn,
# which are used below but not defined in this file.
from ASRDataset import *  # noqa: F401,F403

# Simplified Chinese -> Traditional Chinese (Taiwan) conversion for predictions.
converter = opencc.OpenCC('s2tw.json')

# Text normalizers keyed by language code; "other" is the fallback.
normalizer = {
    "en_us": EnglishTextNormalizer(),
    "other": BasicTextNormalizer(),
}

model_id = "/mnt/jeff/gemma_test"
revision = "main"  # "v1.0"

# NOTE: trust_remote_code=True executes Python shipped with the checkpoint.
model = AutoModel.from_pretrained(
    model_id,
    device_map="cuda",
    revision=revision,
    trust_remote_code=True,
).eval()

processor = AutoProcessor.from_pretrained(
    model_id,
    revision=revision,
    trust_remote_code=True,
)

results_dir = f"evaluation_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
os.makedirs(results_dir, exist_ok=True)

INSTRUCTION = {
    "ast": "Translate the audio to {0}.",
    "asr": "Transcribe the audio clip into text.",
}

# Compiled once; used to strip all whitespace before computing CER.
_WHITESPACE_RE = re.compile(r"\s+")


def save_results(results, dataset_name, task, source_lang, target_lang=None, sample_idx=None):
    """Write ``results`` as JSON into ``results_dir`` and return the file path.

    The file name is ``{task}_{dataset_name}_{source_lang}``, followed by
    ``_to_{target_lang}`` and ``_sample_{sample_idx}`` when those are given.
    A ``"timestamp"`` entry is added to ``results`` in place before writing,
    so the caller's dict carries it afterwards as well.
    """
    filename = f"{task}_{dataset_name}_{source_lang}"
    if target_lang:
        filename += f"_to_{target_lang}"
    if sample_idx is not None:
        filename += f"_sample_{sample_idx}"

    filepath = os.path.join(results_dir, f"{filename}.json")

    results["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    return filepath


def _decode_references(batch):
    """Decode the reference transcripts of ``batch`` into a list of strings.

    NOTE(review): assumes the collate function puts the target token ids under
    ``batch["labels"]`` as one row per sample - confirm against
    ``covost_collate_fn``. Negative ids (such as the usual -100 ignore index)
    are dropped because the tokenizer cannot decode them.
    """
    try:
        labels = batch["labels"]
    except KeyError as err:
        raise KeyError(
            "batch has no 'labels' entry to decode reference transcripts from"
        ) from err

    label_ids = [row[row >= 0].tolist() for row in labels]
    return processor.batch_decode(
        label_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )


def _score_sample(reference, prediction, eval_normalizer, sample_id):
    """Return BLEU / CER / WER (in percent) for one reference/prediction pair.

    Scoring is best-effort: if normalization or a metric raises, the error is
    printed and worst-case scores plus an ``"error"`` message are returned so
    that one bad sample cannot abort the whole evaluation.
    """
    try:
        ref = eval_normalizer(reference)
        pred = eval_normalizer(prediction)

        utt_bleu = sacrebleu.sentence_bleu(pred, [ref]).score
        # CER is computed with all whitespace removed from both sides.
        utt_cer = round(
            cer(_WHITESPACE_RE.sub("", ref), _WHITESPACE_RE.sub("", pred)) * 100, 2
        )
        utt_wer = round(wer(ref, pred) * 100, 2)
    except Exception as e:  # deliberate catch-all, see docstring
        print(f"Error evaluating sample {sample_id}: {e}")
        return {"bleu": 0, "cer": 100, "wer": 100, "error": str(e)}

    return {"bleu": utt_bleu, "cer": min(100, utt_cer), "wer": utt_wer}


def evaluate_task(dataset, task_type="asr", source_lang="unknown", target_lang=None):
    """Generate predictions for ``dataset`` and score them against references.

    Args:
        dataset: Dataset to evaluate; batched one sample at a time with
            ``covost_collate_fn``. Its ``name`` attribute, when present, labels
            the output files; otherwise the class name is used.
        task_type: Task label stored in the results and used in file names.
        source_lang: Source language label. Also selects the text normalizer
            when ``target_lang`` is not given.
        target_lang: Optional target language label; when given it selects the
            text normalizer instead of ``source_lang``.

    Returns:
        A dict with the dataset/task labels, ``num_samples``, averaged
        ``metrics`` (``bleu``, ``cer``, ``wer``; empty when no sample was
        evaluated) and the per-sample ``sample_results``. The same dict is
        written to disk through ``save_results``; a partial file is also
        written every 10 batches and later overwritten by the final one.
    """
    dataset_name = getattr(dataset, "name", type(dataset).__name__)
    eval_normalizer = normalizer.get(target_lang or source_lang, normalizer["other"])

    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=covost_collate_fn)
    sample_results = []

    for batch_idx, batch in enumerate(tqdm(dataloader)):
        if torch.cuda.is_available():
            try:
                batch = {k: v.to("cuda") for k, v in batch.items()}
            except (AttributeError, RuntimeError) as exc:
                # Stop early but still aggregate and save what was evaluated.
                print(f"Stopping at batch {batch_idx}: could not move batch to CUDA: {exc}")
                break

        with torch.inference_mode():
            # NOTE: do_sample=True makes the output non-deterministic, so
            # scores can differ between runs.
            generate_ids = model.generate(
                **batch,
                max_new_tokens=256,
                # temperature = 1.0, top_p = 0.95, top_k = 64,
                do_sample=True,
            )

        # generate() returns prompt + continuation; keep only the continuation.
        prompt_length = batch['input_ids'].shape[1]
        batch_predictions = processor.batch_decode(
            generate_ids[:, prompt_length:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        batch_references = _decode_references(batch)

        for reference, prediction in zip(batch_references, batch_predictions):
            # A running counter keeps ids unique for any batch size.
            sample_result = {
                "id": len(sample_results),
                "reference": reference,
                "prediction": converter.convert(prediction),
            }
            # Score each sample exactly once, as soon as it is produced.
            sample_result.update(
                _score_sample(
                    sample_result["reference"],
                    sample_result["prediction"],
                    eval_normalizer,
                    sample_result["id"],
                )
            )
            sample_results.append(sample_result)

        if (batch_idx + 1) % 10 == 0:
            partial_results = {
                "task": task_type,
                "source_lang": source_lang,
                "target_lang": target_lang,
                "num_samples": len(sample_results),
                "sample_results": list(sample_results),
            }
            save_results(partial_results, dataset_name, task_type, source_lang, target_lang)

    if sample_results:
        metrics = {
            key: sum(item[key] for item in sample_results) / len(sample_results)
            for key in ("bleu", "cer", "wer")
        }
    else:
        # Nothing was evaluated; leave the averages out instead of dividing by zero.
        metrics = {}

    results = {
        "dataset": dataset_name,
        "task": task_type,
        "source_lang": source_lang,
        "target_lang": target_lang,
        "num_samples": len(sample_results),
        "metrics": metrics,
        "sample_results": sample_results,
    }

    save_results(results, dataset_name, task_type, source_lang, target_lang)
    return results


if __name__ == "__main__":
    datasets = []
    pickup_dataset = MultiturnAudioDataset(
        split='eval',
        processor=processor,
        json_path='/mnt/jeff/InCar/data/multiturn_data/pickup_processed.json',
    )
    datasets.append(pickup_dataset)

    for dataset in datasets:
        # ASR
        asr_results = evaluate_task(dataset)

        print(f"\n=== {asr_results.get('dataset', 'Dataset')}")
        print(f"BLEU: {asr_results.get('metrics', {}).get('bleu', 'N/A')}")
        print(f"WER: {asr_results.get('metrics', {}).get('wer', 'N/A')}")
        print(f"CER: {asr_results.get('metrics', {}).get('cer', 'N/A')}")