from io import BytesIO
from urllib.request import urlopen
import soundfile
import torch
from datasets import load_dataset, Audio
import numpy as np
from transformers import AutoModel, AutoProcessor, BatchFeature, Gemma3ForCausalLM, Gemma3Processor
from tqdm import tqdm
import json
import os
import time
from datetime import datetime
from whisper_normalizer.english import EnglishTextNormalizer
from whisper_normalizer.basic import BasicTextNormalizer
import sacrebleu
from jiwer import cer, wer
from torch.utils.data import Dataset, DataLoader
import soundfile as sf
import re
from pathlib import Path
import opencc
from ASRDataset import *

# converter = opencc.OpenCC('s2tw.json')

model_id = "./"
revision = "main" #"v1.0"

processor = AutoProcessor.from_pretrained(
    model_id, revision = revision, trust_remote_code=True
)

results_dir = f"evaluation_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
# os.makedirs(results_dir, exist_ok=True)

def eval_text(model,dataloader,with_input_mode=False,save_path="",start_idx=0):
    res = {'label':[],"pred":[],'cer':[]}
    func_error = 0
    total_func_call = 0
    total_error = 0
    all_output_text = []
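    # Strip prompt prefixes, whitespace, and ASCII/full-width sentence punctuation
    # so that labels and outputs can be compared by exact match.
    remove_sign = lambda x:x.replace('User transcribe is','').replace('GPT output is','').replace('\n','').\
                            replace(' ','').replace('?','').replace('?','').replace('!','').replace('。','').\
                            replace('.','').replace('!','')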
    for batch_idx, batch in enumerate(tqdm(dataloader)):
        # Resume support: skip only the batches before start_idx (batch 0 is evaluated by default).
        if batch_idx < start_idx: continue
        batch = {k: v.to("cuda") for k, v in batch.items() if v is not None}
        try:
            with torch.inference_mode():
                # Drop 'input_modes' unless requested; the default avoids a KeyError when the key is absent.
                if not with_input_mode: batch.pop('input_modes', None)
                generate_ids = model.generate(
                    **batch,
                    max_new_tokens=256,
                    temperature=0.001, top_p=0.95, top_k=64, do_sample=True,
                )
                batch_inputs = processor.batch_decode(
                    generate_ids[:, :batch['input_ids'].shape[1]], skip_special_tokens=True, 
                    clean_up_tokenization_spaces=False
                )
                batch_predictions = processor.batch_decode(
                    generate_ids[:, batch['input_ids'].shape[1]:], skip_special_tokens=True, 
                    clean_up_tokenization_spaces=False
                )
                batch_references = processor.batch_decode(
                    batch['labels'], skip_special_tokens=True, clean_up_tokenization_spaces=False
                )
                for inp,label,output in zip(batch_inputs,batch_references,batch_predictions):
                
                    cer_o = min(100,round(cer(re.sub(r"\s+", "", label), re.sub(r"\s+", "", output)) * 100, 2))
                    # Store the per-sample strings, not the whole batch lists.
                    res['label'].append(label)
                    res['pred'].append(output)
                    res['cer'].append(cer_o)
                    all_output_text.append({
                        'input':inp,
                        'label':label,
                        'output':output,
                        'cer':cer_o,
                    })
                    if 'Action:' in label:
                        func_error+=(remove_sign(label)!=remove_sign(output))
                        total_func_call+=1
                if batch_idx%100==0:
                    with open(save_path,'w', encoding='utf-8') as f:
                        json.dump(all_output_text,f, ensure_ascii=False, indent=4)
                    avg_cer = sum(a['cer'] for a in all_output_text)/len(all_output_text)
                    total_error = sum(a['cer']!=0 for a in all_output_text)
                    print('total',len(all_output_text))
                    print('total_error & rate',total_error,total_error/len(all_output_text))
                    print('avg_cer',avg_cer)
                    print('total_func_call',total_func_call)
                    # max(..., 1) avoids ZeroDivisionError before any 'Action:' label has been seen.
                    print('func_error & rate',func_error,',',func_error/max(total_func_call,1))
        except Exception as e:
            # Report the failure cause instead of silently swallowing it.
            print("error at ",batch_idx,repr(e))
            time.sleep(2)
    # Save first so the results survive any failure while printing the summary.
    with open(save_path,'w', encoding='utf-8') as f:
        json.dump(all_output_text,f, ensure_ascii=False, indent=4)
    # max(..., 1) avoids ZeroDivisionError when no sample or no function call was evaluated.
    n_samples = max(len(all_output_text),1)
    avg_cer = sum(a['cer'] for a in all_output_text)/n_samples
    total_error = sum(a['cer']!=0 for a in all_output_text)
    print('total',len(all_output_text))
    print('total_error & rate',total_error,total_error/n_samples)
    print('avg_cer',avg_cer)
    print('total_func_call',total_func_call)
    print('func_error & rate',func_error,',',func_error/max(total_func_call,1))
    return res,all_output_text
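

# Sketch, not part of the original evaluation flow: the summary that eval_text
# prints could be computed by a small helper like the one below. The name
# summarize_outputs and its return keys are hypothetical; the record layout
# ({'cer': ...}) is the one eval_text appends to all_output_text. Empty inputs
# yield None for the affected rates instead of raising ZeroDivisionError.
def summarize_outputs(records, func_error=0, total_func_call=0):
    total = len(records)
    # A sample counts as an error when its CER is non-zero.
    total_error = sum(r['cer'] != 0 for r in records)
    return {
        'total': total,
        'total_error': total_error,
        'error_rate': total_error / total if total else None,
        'avg_cer': sum(r['cer'] for r in records) / total if total else None,
        'func_error_rate': func_error / total_func_call if total_func_call else None,
    }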


nav_data = MultiturnAudioDataset(split='eval',text_only=True,processor=processor,json_path='/mnt/data-2t/jeff/codes/LLaMA-Factory/data/nav_toolcall_train.json')
ctrl_data = MultiturnAudioDataset(split='eval',text_only=True,processor=processor,json_path='/mnt/data-2t/jeff/codes/LLaMA-Factory/data/ctrl_toolcall_train.json')
ctrl_dataloader = DataLoader(ctrl_data, batch_size=1, shuffle=False, collate_fn=covost_collate_fn)
nav_dataloader = DataLoader(nav_data, batch_size=1, shuffle=False, collate_fn=covost_collate_fn)



from transformers import Gemma3ForConditionalGeneration

model_id_org = "google/gemma-3-4b-it"

model_org = Gemma3ForConditionalGeneration.from_pretrained(
    model_id_org, device_map="auto",attn_implementation="eager"
).eval()

from peft import PeftModel
model_org = PeftModel.from_pretrained(model_org, '/mnt/data-2t/jeff/codes/LLaMA-Factory/saves/Gemma-3-4B-Instruct/lora/train_123/checkpoint-3270')



# Timestamp without spaces or colons so the file names are portable across filesystems.
run_stamp = datetime.now().strftime('%Y%m%d_%H%M')
res_org_nav,output_org_nav = eval_text(model_org,nav_dataloader,save_path='./output_org_nav_{}.json'.format(run_stamp))
res_org_ctrl,output_org_ctrl = eval_text(model_org,ctrl_dataloader,save_path='./output_org_ctrl_{}.json'.format(run_stamp))