import os

# Pin the process to GPU 0; this must be set before torch initializes CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import json

from swift.llm import PtEngine, RequestConfig, get_model_tokenizer, get_template, InferRequest
from swift.tuners import Swift

# LoRA adapter from the CoT SFT run; it is applied on top of the merged base model below.
last_model_checkpoint = '/root/autodl-tmp/output_7B_Lora_cotSFT/v2-20250613-111902/checkpoint-3'

# Base model: the all-mission LoRA checkpoint, already merged into the base weights.
model_id_or_path = '/root/autodl-tmp/output_7B_Lora_allmission/v2-20250610-190504/checkpoint-1000-merged'

system = 'You are a helpful assistant.'

infer_backend = 'pt'  # PtEngine below is ms-swift's PyTorch ('pt') backend

# Generation settings: temperature=0 gives greedy, deterministic decoding.
max_new_tokens = 2048
temperature = 0
stream = False
template_type = None  # None -> fall back to the template registered in the model's metadata
default_system = system

# Load the merged base model and its tokenizer, then attach the LoRA adapter on top.
model, tokenizer = get_model_tokenizer(model_id_or_path)
model = Swift.from_pretrained(model, last_model_checkpoint)

template_type = template_type or model.model_meta.template
template = get_template(template_type, tokenizer, default_system=default_system)
engine = PtEngine.from_model_template(model, template, max_batch_size=2)
# Build the generation config from the settings above.
request_config = RequestConfig(max_tokens=max_new_tokens, temperature=temperature, stream=stream)
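
# Note: per ms-swift's own demos, RequestConfig also accepts stream=True, in which
# case engine.infer returns one generator per request for incremental output. This
# script keeps stream=False so each response arrives as a single completed object.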


def load_test_data(json_file):
    """Read one InferRequest per line from a JSON-Lines file."""
    test_requests = []
    with open(json_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:  # tolerate blank or trailing lines
                continue
            data = json.loads(line)
            test_requests.append(InferRequest(
                messages=data['messages'],
                audios=data['audios'],
            ))
    return test_requests
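
# Hypothetical shape of one input line (the exact fields depend on your dataset):
# {"messages": [{"role": "user", "content": "<audio>What is said in this clip?"},
#               {"role": "assistant", "content": "ground-truth transcript"}],
#  "audios": ["/path/to/clip.wav"]}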


def main():
    test_file = 'dataset_allmissiontest.json'  # JSON-Lines content despite the .json extension
    infer_requests = load_test_data(test_file)
    results = []
    resp_list = engine.infer(infer_requests, request_config)
    for i, resp in enumerate(resp_list):
        # The assistant turn in the source data carries the ground-truth answer.
        # NB: infer_requests[i].messages still includes this turn; verify that your
        # ms-swift version drops it when building the prompt, or strip it beforehand.
        assistant_content = next(
            (msg['content'] for msg in infer_requests[i].messages if msg['role'] == 'assistant'),
            None)
        result = {
            'index': i,
            'truth': assistant_content,
            'response': resp.choices[0].message.content,
        }
        results.append(result)
        print(f'truth{i}: {assistant_content}')
        print(f'response{i}: {resp.choices[0].message.content}')

    output_file = 'inference_results.json'
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)


if __name__ == '__main__':
    main()
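
# Minimal follow-up sketch (assumes exact string match is a meaningful metric for this task):
# with open('inference_results.json', encoding='utf-8') as f:
#     results = json.load(f)
# correct = sum(r['truth'] == r['response'] for r in results)
# print(f'exact match: {correct}/{len(results)} = {correct / len(results):.2%}')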