from io import BytesIO
from urllib.request import urlopen
import soundfile
import torch
from datasets import load_dataset, Audio
import numpy as np
from transformers import AutoModel, AutoProcessor, BatchFeature, Gemma3ForCausalLM, Gemma3Processor
from tqdm import tqdm
import json
import os
import time
from datetime import datetime
from whisper_normalizer.english import EnglishTextNormalizer
from whisper_normalizer.basic import BasicTextNormalizer
import sacrebleu
from jiwer import cer, wer
from torch.utils.data import Dataset, DataLoader
import soundfile as sf
import re
from pathlib import Path
import opencc
from ASRDataset import *
# converter = opencc.OpenCC('s2tw.json')
model_id = "./"
revision = "main" #"v1.0"
processor = AutoProcessor.from_pretrained(
    model_id, revision=revision, trust_remote_code=True
)
results_dir = f"evaluation_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
# os.makedirs(results_dir, exist_ok=True)
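# eval_text: generation-based evaluation loop. For each batch it generates a
# response, decodes prompt/prediction/reference text, tracks per-sample CER, and
# counts exact-match errors on tool-call turns (references containing 'Action:').
# Intermediate results are checkpointed to `save_path` every 100 batches.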
def eval_text(model, dataloader, with_input_mode=False, save_path="", start_idx=0):
    res = {'label': [], 'pred': [], 'cer': []}
    func_error = 0
    total_func_call = 0
    total_error = 0
    all_output_text = []
    # Strip boilerplate prefixes and punctuation (half- and full-width) before
    # comparing tool-call strings for exact match.
    remove_sign = lambda x: x.replace('User transcribe is', '').replace('GPT output is', '') \
        .replace('\n', '').replace(' ', '').replace('?', '').replace('？', '') \
        .replace('!', '').replace('！', '').replace('。', '').replace('.', '')
    for batch_idx, batch in enumerate(tqdm(dataloader)):
        # Resume support: skip batches up to and including start_idx.
        if batch_idx <= start_idx:
            continue
        batch = {k: v.to("cuda") for k, v in batch.items() if v is not None}
        try:
            with torch.inference_mode():
                if not with_input_mode:
                    batch.pop('input_modes', None)
                generate_ids = model.generate(
                    **batch,
                    max_new_tokens=256,
                    temperature=0.001, top_p=0.95, top_k=64, do_sample=True
                )
            # Split the generated ids into the prompt part and the newly generated part.
            batch_inputs = processor.batch_decode(
                generate_ids[:, :batch['input_ids'].shape[1]], skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )
            batch_predictions = processor.batch_decode(
                generate_ids[:, batch['input_ids'].shape[1]:], skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )
            batch_references = processor.batch_decode(
                batch['labels'], skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
            for inp, label, output in zip(batch_inputs, batch_references, batch_predictions):
                # Character error rate on whitespace-stripped strings, capped at 100.
                cer_o = min(100, round(cer(re.sub(r"\s+", "", label), re.sub(r"\s+", "", output)) * 100, 2))
                res['label'].append(label)
                res['pred'].append(output)
                res['cer'].append(cer_o)
                all_output_text.append({
                    'input': inp,
                    'label': label,
                    'output': output,
                    'cer': cer_o,
                })
                # Tool-call turns are scored as exact matches after normalization.
                if 'Action:' in label:
                    func_error += (remove_sign(label) != remove_sign(output))
                    total_func_call += 1
            # Checkpoint intermediate results and print running statistics every 100 batches.
            if batch_idx % 100 == 0:
                with open(save_path, 'w', encoding='utf-8') as f:
                    json.dump(all_output_text, f, ensure_ascii=False, indent=4)
                avg_cer = sum(a['cer'] for a in all_output_text) / max(len(all_output_text), 1)
                total_error = sum(a['cer'] != 0 for a in all_output_text)
                print('total', len(all_output_text))
                print('total_error & rate', total_error, total_error / max(len(all_output_text), 1))
                print('avg_cer', avg_cer)
                print('total_func_call', total_func_call)
                print('func_error & rate', func_error, ',', func_error / max(total_func_call, 1))
        except Exception as e:
            print("error at", batch_idx, ":", e)
            time.sleep(2)
    # Final statistics over everything that was successfully decoded.
    avg_cer = sum(a['cer'] for a in all_output_text) / max(len(all_output_text), 1)
    total_error = sum(a['cer'] != 0 for a in all_output_text)
    print('total', len(all_output_text))
    print('total_error & rate', total_error, total_error / max(len(all_output_text), 1))
    print('avg_cer', avg_cer)
    print('total_func_call', total_func_call)
    print('func_error & rate', func_error, ',', func_error / max(total_func_call, 1))
    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(all_output_text, f, ensure_ascii=False, indent=4)
    return res, all_output_text
nav_data = MultiturnAudioDataset(
    split='eval', text_only=True, processor=processor,
    json_path='/mnt/data-2t/jeff/codes/LLaMA-Factory/data/nav_toolcall_train.json'
)
ctrl_data = MultiturnAudioDataset(
    split='eval', text_only=True, processor=processor,
    json_path='/mnt/data-2t/jeff/codes/LLaMA-Factory/data/ctrl_toolcall_train.json'
)
ctrl_dataloader = DataLoader(ctrl_data, batch_size=1, shuffle=False, collate_fn=covost_collate_fn)
nav_dataloader = DataLoader(nav_data, batch_size=1, shuffle=False, collate_fn=covost_collate_fn)
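# Optional sanity check before the full run: pull one collated batch and print the
# keys/shapes produced by covost_collate_fn (assumes the dataloaders defined above).
sample_batch = next(iter(nav_dataloader))
print({k: (tuple(v.shape) if hasattr(v, 'shape') else type(v).__name__)
       for k, v in sample_batch.items() if v is not None})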
from transformers import Gemma3ForConditionalGeneration
model_id_org = "google/gemma-3-4b-it"
model_org = Gemma3ForConditionalGeneration.from_pretrained(
    model_id_org, device_map="auto", attn_implementation="eager"
).eval()
from peft import PeftModel
model_org = PeftModel.from_pretrained(model_org, '/mnt/data-2t/jeff/codes/LLaMA-Factory/saves/Gemma-3-4B-Instruct/lora/train_123/checkpoint-3270')
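# The LoRA adapter is left unmerged here, so generation runs through the PEFT hooks.
# If evaluation speed matters, the adapter can typically be folded into the base
# weights first, e.g. model_org = model_org.merge_and_unload().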
res_org_nav, output_org_nav = eval_text(model_org, nav_dataloader, save_path='./output_org_nav_{}.json'.format(str(datetime.now())[:16]))
res_org_ctrl, output_org_ctrl = eval_text(model_org, ctrl_dataloader, save_path='./output_org_ctrl_{}.json'.format(str(datetime.now())[:16]))
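# Combined summary across the two runs, derived only from the per-sample 'cer'
# values returned by eval_text. results_dir is created here because the makedirs
# call near the top of the script is commented out.
os.makedirs(results_dir, exist_ok=True)
summary = {
    name: {
        'samples': len(outputs),
        'avg_cer': sum(o['cer'] for o in outputs) / max(len(outputs), 1),
        'error_rate': sum(o['cer'] != 0 for o in outputs) / max(len(outputs), 1),
    }
    for name, outputs in [('nav', output_org_nav), ('ctrl', output_org_ctrl)]
}
with open(os.path.join(results_dir, 'summary.json'), 'w', encoding='utf-8') as f:
    json.dump(summary, f, ensure_ascii=False, indent=4)
print(summary)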