# interactSpeech / .ipynb_checkpoints / infer-checkpoint.py
# Uploaded by Student0809 ("Add files using upload-large-folder tool", ee3af03 verified)
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from swift.llm import PtEngine, RequestConfig, safe_snapshot_download, get_model_tokenizer, get_template, InferRequest
import json
from transformers import AutoProcessor
from swift.tuners import Swift
# Path to the LoRA adapter checkpoint produced by CoT SFT training.
last_model_checkpoint = '/root/autodl-tmp/output_7B_Lora_cotSFT/v2-20250613-111902/checkpoint-3'
# Model
model_id_or_path = '/root/autodl-tmp/output_7B_Lora_allmission/v2-20250610-190504/checkpoint-1000-merged' # model_id or model_path
system = 'You are a helpful assistant.'
infer_backend = 'pt'
# Generation parameters
max_new_tokens = 2048
temperature = 0
stream = False
template_type = None
default_system = system # None: use the model's own default_system
# Initialize the audio processor / base model
model, tokenizer = get_model_tokenizer(model_id_or_path)
# Initialize the engine: load the LoRA adapter on top of the merged base model
model = Swift.from_pretrained(model, last_model_checkpoint)
template_type = template_type or model.model_meta.template
template = get_template(template_type, tokenizer, default_system=default_system)
engine = PtEngine.from_model_template(model, template, max_batch_size=2)
# NOTE(review): max_tokens here (8192) overrides max_new_tokens above (2048) — confirm intended
request_config = RequestConfig(max_tokens=8192, temperature=0)
def load_test_data(json_file):
    """Load inference requests from a JSON-lines file.

    Each non-empty line must be a JSON object with a ``messages`` list
    (chat-format messages) and an ``audios`` list (audio file paths).

    Args:
        json_file: Path to the JSONL test file.

    Returns:
        list[InferRequest]: One request per data line.

    Raises:
        json.JSONDecodeError: If a non-empty line is not valid JSON.
        KeyError: If a line lacks the ``messages`` key.
    """
    test_requests = []
    with open(json_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            # Skip blank lines (e.g. a trailing newline at EOF) which would
            # otherwise raise json.JSONDecodeError in json.loads.
            if not line:
                continue
            data = json.loads(line)
            test_requests.append(InferRequest(
                messages=data['messages'],
                # .get keeps text-only samples (no 'audios' key) loadable.
                audios=data.get('audios', [])
            ))
    return test_requests
def main():
    """Run batch inference over the test set and save results to JSON.

    Reads requests from ``dataset_allmissiontest.json`` (JSONL), runs them
    through the module-level ``engine`` with ``request_config``, pairs each
    model response with the ground-truth assistant message from the request,
    prints both, and writes all results to ``inference_results.json``.
    """
    # Load test data
    test_file = 'dataset_allmissiontest.json'
    infer_requests = load_test_data(test_file)
    results = []
    # Batch inference via the global PtEngine (batched internally, max_batch_size=2).
    resp_list = engine.infer(infer_requests, request_config)
    for i, resp in enumerate(resp_list):
        # Ground truth: the first assistant-role message in the request, if any.
        assistant_content = next((msg['content'] for msg in infer_requests[i].messages if msg['role'] == 'assistant'), None)
        result = {
            "index": i,
            "truth": assistant_content,
            "response": resp.choices[0].message.content,
        }
        results.append(result)
        print(f'truth{i}: {assistant_content}')
        print(f'response{i}: {resp.choices[0].message.content}')
    # Persist all truth/response pairs for later evaluation.
    output_file = 'inference_results.json'
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
if __name__ == '__main__':
    main()