| import os |
| import unittest |
|
|
| from swift.llm import PtEngine, RequestConfig, get_model_tokenizer, get_template |
| from swift.utils import get_logger, seed_everything |
|
|
| |
# Export SWIFT_DEBUG=1 into this process's environment at import time
# (presumably read by swift to enable debug output -- confirm against swift).
os.environ['SWIFT_DEBUG'] = '1'

# Module-level logger from swift.utils; _infer_model logs each conversation through it.
logger = get_logger()
|
|
|
|
def _infer_model(pt_engine, system=None, messages=None):
    """Run a short chat against ``pt_engine`` and return the final reply text.

    The random seed is fixed to 42 and every request uses ``max_tokens=128`` and
    ``temperature=0``.

    When ``messages`` is None a conversation is built from scratch: an optional system
    prompt, a greeting from the user, the engine's reply to that greeting, and a follow-up
    user question. When ``messages`` is supplied it is used as the conversation as-is.
    In both cases the engine is then queried once, its reply is appended to ``messages``
    (a caller-supplied list is extended in place), the transcript is logged, and the
    reply is returned.

    Args:
        pt_engine: Engine object providing ``infer`` and ``model_info.model_name``.
        system: Optional system prompt; only consulted when ``messages`` is None.
        messages: Optional existing conversation as a list of role/content dicts.

    Returns:
        The message content of the first choice of the last inference call.
    """
    seed_everything(42)
    infer_config = RequestConfig(max_tokens=128, temperature=0)

    def _chat(history):
        # One inference call over the whole history; yields the first choice's text.
        outputs = pt_engine.infer([{'messages': history}], request_config=infer_config)
        return outputs[0].choices[0].message.content

    if messages is None:
        # No history supplied: build the opening exchange ourselves.
        messages = [] if system is None else [{'role': 'system', 'content': system}]
        messages.append({'role': 'user', 'content': '你好'})
        greeting_reply = _chat(messages)
        messages.append({'role': 'assistant', 'content': greeting_reply})
        messages.append({'role': 'user', 'content': '<image>这是什么'})

    response = _chat(messages)
    # In-place extension so a caller-provided list also receives the final reply.
    messages += [{'role': 'assistant', 'content': response}]
    logger.info(f'model: {pt_engine.model_info.model_name}, messages: {messages}')
    return response
|
|
|
|
| class TestTemplate(unittest.TestCase): |
|
|
| def test_template(self): |
| pt_engine = PtEngine('Qwen/Qwen2.5-3B-Instruct-GPTQ-Int4') |
| response = _infer_model(pt_engine) |
| pt_engine.default_template.template_backend = 'jinja' |
| response2 = _infer_model(pt_engine) |
| assert response == response2 |
|
|
| def test_tool_message_join(self): |
| from copy import deepcopy |
|
|
| from swift.plugin import agent_templates |
|
|
| messages = [ |
| |
| { |
| 'role': 'user', |
| 'content': 'user1' |
| }, |
| { |
| 'role': 'assistant', |
| 'content': 'assistant1' |
| }, |
| { |
| 'role': 'assistant', |
| 'content': 'assistant2' |
| }, |
| { |
| 'role': 'tool', |
| 'content': 'tool1' |
| }, |
| |
| { |
| 'role': 'assistant', |
| 'content': 'assistant3' |
| }, |
| { |
| 'role': 'tool', |
| 'content': 'tool2' |
| }, |
| { |
| 'role': 'tool', |
| 'content': 'tool3' |
| }, |
| ] |
|
|
| |
| tokenizer = get_model_tokenizer('Qwen/Qwen2.5-7B-Instruct', load_model=False)[1] |
| template = get_template(tokenizer.model_meta.template, tokenizer) |
| for agent_template_type in ('react_zh', 'qwen_zh'): |
| agent_template = agent_templates[agent_template_type]() |
| template.agent_template = agent_template |
| observation = agent_template.keyword.observation |
| test_messages = deepcopy(messages) |
| test_messages[2]['content'] = 'assistant2' + observation |
| test_messages[4]['content'] = ( |
| agent_template.keyword.action + agent_template.keyword.action_input + 'assistant3' + observation) |
| encoded = template.encode({'messages': test_messages}) |
| res = template.safe_decode(encoded['input_ids']) |
|
|
| ground_truth = ( |
| '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n' |
| '<|im_start|>user\nuser1<|im_end|>\n' |
| f'<|im_start|>assistant\nassistant1assistant2{observation}tool1' |
| f'{agent_template.keyword.action}{agent_template.keyword.action_input}assistant3' |
| f'{observation}tool2\n{observation}tool3\n') |
| assert res == ground_truth |
|
|
|
|
# Run the tests in this module when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
|
|