"""Batch inference with a LoRA fine-tuned ms-swift model on a JSON-Lines test set.

Loads a merged base checkpoint, applies a LoRA adapter on top of it, runs
inference on every request in the test file (requests carry audio inputs) and
writes reference/prediction pairs to a JSON file.
"""
import os

# Set before importing swift/transformers below so the device restriction is in
# place when the CUDA libraries are initialised.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from swift.llm import PtEngine, RequestConfig, safe_snapshot_download, get_model_tokenizer, get_template, InferRequest
import json
from transformers import AutoProcessor
from swift.tuners import Swift

# LoRA adapter checkpoint applied on top of the base model.
last_model_checkpoint = '/root/autodl-tmp/output_7B_Lora_cotSFT/v2-20250613-111902/checkpoint-3'

# Base model
model_id_or_path = '/root/autodl-tmp/output_7B_Lora_allmission/v2-20250610-190504/checkpoint-1000-merged'  # model_id or model_path
system = 'You are a helpful assistant.'
infer_backend = 'pt'  # informational: the script always builds a PtEngine below

# Generation parameters (single source of truth for the RequestConfig below)
max_new_tokens = 8192
temperature = 0
# Not forwarded to RequestConfig: the result loop in main() reads complete
# (non-streamed) responses via resp.choices[0].message.
stream = False
template_type = None  # None: use the template declared in the model's metadata
default_system = system  # None: use the model's own default system prompt

# Load the base model and its tokenizer.
model, tokenizer = get_model_tokenizer(model_id_or_path)
# Apply the LoRA adapter weights from the fine-tuning checkpoint.
model = Swift.from_pretrained(model, last_model_checkpoint)
template_type = template_type or model.model_meta.template
template = get_template(template_type, tokenizer, default_system=default_system)
# Build the inference engine; max_batch_size caps how many requests are batched together.
engine = PtEngine.from_model_template(model, template, max_batch_size=2)
request_config = RequestConfig(max_tokens=max_new_tokens, temperature=temperature)


def _load_test_samples(json_file):
    """Parse a JSON-Lines test file into (InferRequest, reference answer) pairs.

    Each non-blank line must be a JSON object with 'messages' and 'audios' keys.
    If the last message is from the assistant, it is treated as the reference
    answer. It is removed from the conversation, so the model is not handed the
    answer as part of its input, and its content is returned next to the request.
    The reference is None when there is no trailing assistant message.

    Args:
        json_file: Path to the UTF-8 encoded JSON-Lines file.

    Returns:
        A list of (InferRequest, reference_or_None) tuples, in file order.

    Raises:
        ValueError: If a line is not valid JSON (the message names file and line).
        KeyError: If a record lacks 'messages'/'audios' or a message lacks 'role'.
    """
    samples = []
    with open(json_file, 'r', encoding='utf-8') as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue  # tolerate blank / trailing empty lines
            try:
                data = json.loads(line)
            except json.JSONDecodeError as err:
                raise ValueError(f'{json_file}:{lineno}: invalid JSON ({err})') from err
            messages = data['messages']
            reference = None
            if messages and messages[-1]['role'] == 'assistant':
                reference = messages.pop()['content']
            samples.append((InferRequest(messages=messages, audios=data['audios']), reference))
    return samples


def load_test_data(json_file):
    """Load the test file and return only the inference requests.

    A trailing assistant message (the reference answer) is stripped from each
    request so it is not sent to the model. Use _load_test_samples to keep the
    references as well.
    """
    return [request for request, _ in _load_test_samples(json_file)]


def main(test_file='dataset_allmissiontest.json', output_file='inference_results.json'):
    """Run inference on every test sample and save reference/prediction pairs.

    Args:
        test_file: JSON-Lines test set to evaluate.
        output_file: Destination JSON file; it receives a list of
            {"index", "truth", "response"} records in test-file order.
    """
    samples = _load_test_samples(test_file)
    infer_requests = [request for request, _ in samples]
    # Skip the engine call entirely for an empty test set.
    responses = engine.infer(infer_requests, request_config) if infer_requests else []

    results = []
    for index, ((_, truth), resp) in enumerate(zip(samples, responses)):
        response = resp.choices[0].message.content
        results.append({
            "index": index,
            "truth": truth,
            "response": response,
        })
        print(f'truth{index}: {truth}')
        print(f'response{index}: {response}')

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)


if __name__ == '__main__':
    main()