File size: 2,371 Bytes
ee3af03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # pin to GPU 0; must be set before any CUDA-touching import below

from swift.llm import PtEngine, RequestConfig, safe_snapshot_download, get_model_tokenizer, get_template, InferRequest
import json
from transformers import AutoProcessor
from swift.tuners import Swift
# LoRA adapter checkpoint applied on top of the merged base model below.
last_model_checkpoint = '/root/autodl-tmp/output_7B_Lora_cotSFT/v2-20250613-111902/checkpoint-3'

# Model
model_id_or_path = '/root/autodl-tmp/output_7B_Lora_allmission/v2-20250610-190504/checkpoint-1000-merged'  # model_id or model_path
system = 'You are a helpful assistant.'
infer_backend = 'pt'

# Generation parameters
max_new_tokens = 2048  # NOTE(review): unused — RequestConfig below hard-codes max_tokens=8192; confirm which limit is intended
temperature = 0  # NOTE(review): unused — RequestConfig below hard-codes temperature=0
stream = False
template_type = None
default_system = system  # None: fall back to the model's own default_system
# Load base model and tokenizer
model, tokenizer = get_model_tokenizer(model_id_or_path)
# Initialize the inference engine
model = Swift.from_pretrained(model, last_model_checkpoint)  # attach the LoRA adapter to the base model
template_type = template_type or model.model_meta.template
template = get_template(template_type, tokenizer, default_system=default_system)
engine = PtEngine.from_model_template(model, template, max_batch_size=2)
request_config = RequestConfig(max_tokens=8192, temperature=0)

def load_test_data(json_file):
    """Load inference requests from a JSONL file.

    Each non-empty line must be a JSON object containing a 'messages' list
    and an 'audios' list; one InferRequest is built per line.

    Args:
        json_file: Path to the JSONL file (one JSON object per line).

    Returns:
        List of InferRequest objects, in file order.

    Raises:
        json.JSONDecodeError: If a non-empty line is not valid JSON.
        KeyError: If a line lacks the 'messages' or 'audios' key.
    """
    test_requests = []
    with open(json_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            # Skip blank/trailing lines — json.loads('') would raise.
            if not line:
                continue
            data = json.loads(line)
            test_requests.append(InferRequest(
                messages=data['messages'],
                audios=data['audios']
            ))
    return test_requests

def main():
    """Run batch inference over the test set and dump truth/response pairs to JSON."""
    # Load test data
    test_file = 'dataset_allmissiontest.json'
    infer_requests = load_test_data(test_file)

    responses = engine.infer(infer_requests, request_config)

    results = []
    for idx, (request, response) in enumerate(zip(infer_requests, responses)):
        # Ground truth = content of the first assistant turn in the request, if any.
        truth = None
        for message in request.messages:
            if message['role'] == 'assistant':
                truth = message['content']
                break
        predicted = response.choices[0].message.content
        results.append({
            "index": idx,
            "truth": truth,
            "response": predicted,
        })
        print(f'truth{idx}: {truth}')
        print(f'response{idx}: {predicted}')

    output_file = 'inference_results.json'
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)


if __name__ == '__main__':
    main()