File size: 12,838 Bytes
7feac49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
import os

# Set SWIFT_DEBUG for this process before any test runs
# (presumably switches swift to verbose debug output - TODO confirm).
os.environ.update({'SWIFT_DEBUG': '1'})

# System prompt handed to `_format_tools` by every test.
system = 'You are a helpful assistant.'

# Default tool list: one OpenAI-style function spec and one spec that uses
# `name_for_model` / `name_for_human` keys.
tools = [
    {
        'type': 'function',
        'function': {
            'name': 'get_current_weather',
            'description': 'Get the current weather in a given location',
            'parameters': {
                'type': 'object',
                'properties': {
                    'location': {
                        'type': 'string',
                        'description': 'The city and state, e.g. San Francisco, CA',
                    },
                    'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']},
                },
                'required': ['location'],
            },
        },
    },
    {'name_for_model': 'tool2', 'name_for_human': '工具2', 'description': 'Tool2的描述'},
]

# Tool list, canned tool replies and query passed to `_infer` by the GLM4 tests.
glm4_tools = [
    {
        'type': 'function',
        'function': {
            'name': 'realtime_aqi',
            'description': '天气预报。获取实时空气质量。当前空气质量,PM2.5,PM10信息',
            'parameters': {
                'type': 'object',
                'properties': {'city': {'description': '城市名'}},
                'required': ['city'],
            },
        },
    },
]
glm4_tool_messasges = [
    {'role': 'tool', 'content': payload}
    for payload in (
        '{"city": "北京", "aqi": "10", "unit": "celsius"}',
        '{"city": "上海", "aqi": "72", "unit": "fahrenheit"}',
    )
]
glm4_query = '北京和上海今天的天气情况'


def _infer(engine, num_tools: int = 1, agent_tools=None, tool_messages=None, query=None):
    """Run a two-round tool-calling conversation and return its messages.

    Round one sends the user query together with the tool list and expects the
    reply to contain at least one parsed tool call. The assistant reply and the
    tool results are then appended, and round two asks for the final answer.

    Args:
        engine: Inference engine exposing `default_template` and `infer`.
        num_tools: Number of canned weather tool results to generate when
            `tool_messages` is not supplied.
        agent_tools: Tool specs sent with the request; defaults to the
            module-level `tools`.
        tool_messages: Ready-made `role='tool'` messages to feed back after
            the first round; generated from `num_tools` when omitted.
        query: User question; defaults to an English Beijing weather question.

    Returns:
        The request's message list: user query, first assistant reply, tool
        messages, and second assistant reply.

    Raises:
        AssertionError: If the first reply contains no tool call.
    """
    if agent_tools is None:
        agent_tools = tools
    if tool_messages is None:
        # One independent dict per simulated tool result.
        tool_messages = [{
            'role': 'tool',
            'content': '{"temperature": 32, "condition": "Sunny", "humidity": 50}'
        } for _ in range(num_tools)]
    # Stop generating at the agent template's observation keyword.
    stop = [engine.default_template.agent_template.keyword.observation]
    query = query or "How's the weather in Beijing today?"
    infer_request = InferRequest([{'role': 'user', 'content': query}], tools=agent_tools)
    request_config = RequestConfig(max_tokens=512, stop=stop, temperature=0)

    # Round 1: the model is expected to answer with a tool call.
    resp_list = engine.infer([infer_request], request_config=request_config)
    message = resp_list[0].choices[0].message
    response = message.content
    print(f'response: {response}')
    # Check before indexing: a missing or empty `tool_calls` would otherwise
    # surface as an opaque TypeError/IndexError instead of a test failure.
    tool_calls = message.tool_calls
    assert tool_calls, f'no tool call was parsed from the first response: {response!r}'
    toolcall = tool_calls[0].function
    print(f'toolcall: {toolcall}')
    assert toolcall is not None

    # Round 2: feed the tool results back and collect the final answer.
    infer_request.messages.append({'role': 'assistant', 'content': response})
    infer_request.messages += tool_messages
    resp_list = engine.infer([infer_request], request_config=request_config)
    response2 = resp_list[0].choices[0].message.content
    print(f'response2: {response2}')
    infer_request.messages.append({'role': 'assistant', 'content': response2})
    return infer_request.messages


def test_react_en():
    """Exercise the `react_en` agent template: formatting, inference, encoding."""
    react_template = agent_templates['react_en']()
    formatted_system = react_template._format_tools(tools, system)
    assert len(formatted_system) == 1144

    pt_engine = PtEngine('Qwen/Qwen2.5-7B-Instruct')
    chat_template = pt_engine.default_template
    chat_template.agent_template = react_template
    conversation = _infer(pt_engine)
    # Decoding is greedy (temperature=0), so the final answer is compared verbatim.
    assert conversation[-1]['content'] == (
        'Thought: The current temperature in Beijing is 32 degrees Celsius, and the condition is sunny '
        'with a humidity of 50%.\nFinal Answer: The current temperature in Beijing is 32 degrees Celsius,'
        ' and the condition is sunny with a humidity of 50%.')

    # Encode the generated conversation in train mode and show the result.
    chat_template.set_mode('train')
    train_encoded = chat_template.encode({'messages': conversation})
    for field in ('input_ids', 'labels'):
        print(f'{field}: {chat_template.safe_decode(train_encoded[field])}')

    # Encode a dataset sample with the messages at index 1 and 3 duplicated
    # (presumably to produce consecutive tool calls / tool results - TODO confirm).
    sample = load_dataset('AI-ModelScope/function-calling-chatml')[0][6]
    sample_messages = sample['messages']
    sample_messages.insert(1, sample_messages[1])
    sample_messages.insert(3, sample_messages[3])
    chat_template.template_backend = 'swift'
    sample_encoded = chat_template.encode(sample)
    for field in ('input_ids', 'labels'):
        print(f'{field}: {chat_template.safe_decode(sample_encoded[field])}')


def test_react_zh():
    """Check `react_zh` tool formatting length and run one tool-calling round trip."""
    react_template = agent_templates['react_zh']()
    formatted_system = react_template._format_tools(tools, system)
    assert len(formatted_system) == 712

    pt_engine = PtEngine('Qwen/Qwen2.5-7B-Instruct')
    # Swap the engine's default agent template for the one under test.
    pt_engine.default_template.agent_template = react_template
    _infer(pt_engine)


def test_qwen_en():
    """Exercise the `qwen_en` agent template: formatting, inference, encoding."""
    qwen_template = agent_templates['qwen_en']()
    formatted_system = qwen_template._format_tools(tools, system)
    assert len(formatted_system) == 879

    pt_engine = PtEngine('Qwen/Qwen2.5-7B-Instruct')
    chat_template = pt_engine.default_template
    chat_template.agent_template = qwen_template
    conversation = _infer(pt_engine)
    # Decoding is greedy (temperature=0), so the final answer is compared verbatim.
    assert conversation[-1]['content'] == (
        '✿RETURN✿: Today in Beijing, the temperature is 32°C with sunny conditions and the humidity '
        'is at 50%. Enjoy the nice weather!')

    # Encode the generated conversation in train mode and show the result.
    chat_template.set_mode('train')
    train_encoded = chat_template.encode({'messages': conversation})
    for field in ('input_ids', 'labels'):
        print(f'{field}: {chat_template.safe_decode(train_encoded[field])}')

    # Encode a dataset sample with the messages at index 1 and 3 duplicated
    # (presumably to produce consecutive tool calls / tool results - TODO confirm).
    sample = load_dataset('AI-ModelScope/function-calling-chatml')[0][6]
    sample_messages = sample['messages']
    sample_messages.insert(1, sample_messages[1])
    sample_messages.insert(3, sample_messages[3])
    chat_template.template_backend = 'swift'
    sample_encoded = chat_template.encode(sample)
    for field in ('input_ids', 'labels'):
        print(f'{field}: {chat_template.safe_decode(sample_encoded[field])}')


def test_qwen_zh():
    """Check `qwen_zh` tool formatting length and run one tool-calling round trip."""
    qwen_template = agent_templates['qwen_zh']()
    formatted_system = qwen_template._format_tools(tools, system)
    assert len(formatted_system) == 577

    pt_engine = PtEngine('Qwen/Qwen2.5-7B-Instruct')
    # Swap the engine's default agent template for the one under test.
    pt_engine.default_template.agent_template = qwen_template
    _infer(pt_engine)


def test_qwen_en_parallel():
    """Exercise the `qwen_en_parallel` agent template with two tool results."""
    parallel_template = agent_templates['qwen_en_parallel']()
    formatted_system = parallel_template._format_tools(tools, system)
    assert len(formatted_system) == 1012

    pt_engine = PtEngine('Qwen/Qwen2.5-7B-Instruct')
    chat_template = pt_engine.default_template
    chat_template.agent_template = parallel_template
    # Two tool messages are fed back after the first round.
    conversation = _infer(pt_engine, num_tools=2)
    assert conversation[-1]['content'] == (
        '✿RETURN✿: Today in Beijing, the temperature is 32 degrees Celsius with sunny conditions '
        'and the humidity is at 50%. Enjoy the nice weather!')

    # Encode the generated conversation in train mode and show the result.
    chat_template.set_mode('train')
    train_encoded = chat_template.encode({'messages': conversation})
    for field in ('input_ids', 'labels'):
        print(f'{field}: {chat_template.safe_decode(train_encoded[field])}')

    # Encode a dataset sample with the messages at index 1 and 3 duplicated
    # (presumably to produce consecutive tool calls / tool results - TODO confirm).
    sample = load_dataset('AI-ModelScope/function-calling-chatml')[0][6]
    sample_messages = sample['messages']
    sample_messages.insert(1, sample_messages[1])
    sample_messages.insert(3, sample_messages[3])
    chat_template.template_backend = 'swift'
    sample_encoded = chat_template.encode(sample)
    for field in ('input_ids', 'labels'):
        print(f'{field}: {chat_template.safe_decode(sample_encoded[field])}')


def test_qwen_zh_parallel():
    """Check `qwen_zh_parallel` formatting length and run a round trip with two tool results."""
    parallel_template = agent_templates['qwen_zh_parallel']()
    formatted_system = parallel_template._format_tools(tools, system)
    assert len(formatted_system) == 688

    pt_engine = PtEngine('Qwen/Qwen2.5-7B-Instruct')
    # Swap the engine's default agent template for the one under test.
    pt_engine.default_template.agent_template = parallel_template
    _infer(pt_engine, num_tools=2)


def test_hermes():
    """Exercise the `hermes` agent template and compare the swift and jinja backends."""
    hermes_template = agent_templates['hermes']()
    formatted_system = hermes_template._format_tools(tools, system)
    assert len(formatted_system) == 875

    pt_engine = PtEngine('Qwen/Qwen2.5-7B-Instruct')
    chat_template = pt_engine.default_template
    chat_template.agent_template = hermes_template

    # Run the same conversation with the template's initial backend, then with jinja.
    first_messages = _infer(pt_engine, num_tools=2)
    chat_template.template_backend = 'jinja'
    jinja_messages = _infer(pt_engine, num_tools=2)
    first_answer = first_messages[-1]['content']
    jinja_answer = jinja_messages[-1]['content']
    # Both runs must agree with each other and with the expected final answer.
    assert first_answer == jinja_answer == (
        'Today in Beijing, the temperature is 32 degrees Celsius with sunny conditions '
        'and the humidity is at 50%. Enjoy the nice weather!')

    # Encode the first conversation in train mode and show the result.
    chat_template.set_mode('train')
    train_encoded = chat_template.encode({'messages': first_messages})
    for field in ('input_ids', 'labels'):
        print(f'{field}: {chat_template.safe_decode(train_encoded[field])}')

    # Encode a dataset sample with the messages at index 1 and 3 duplicated
    # (presumably to produce consecutive tool calls / tool results - TODO confirm).
    sample = load_dataset('AI-ModelScope/function-calling-chatml')[0][6]
    sample_messages = sample['messages']
    sample_messages.insert(1, sample_messages[1])
    sample_messages.insert(3, sample_messages[3])

    chat_template.template_backend = 'swift'
    swift_encoded = chat_template.encode(sample)
    for field in ('input_ids', 'labels'):
        print(f'{field}: {chat_template.safe_decode(swift_encoded[field])}')

    chat_template.template_backend = 'jinja'
    jinja_encoded = chat_template.encode(sample)
    for field in ('input_ids', 'labels'):
        print(f'{field}: {chat_template.safe_decode(jinja_encoded[field])}')

    # The jinja encoding is expected to match the swift one plus a single trailing token.
    assert swift_encoded['input_ids'] == jinja_encoded['input_ids'][:-1]


def test_toolbench():
    """Check `toolbench` tool formatting length and run one tool-calling round trip."""
    toolbench_template = agent_templates['toolbench']()
    formatted_system = toolbench_template._format_tools(tools, system)
    assert len(formatted_system) == 1833

    pt_engine = PtEngine('Qwen/Qwen2.5-7B-Instruct')
    # Swap the engine's default agent template for the one under test.
    pt_engine.default_template.agent_template = toolbench_template
    _infer(pt_engine)


def test_glm4():
    """Check `glm4` tool formatting length and run a round trip with the GLM4 fixtures."""
    glm_template = agent_templates['glm4']()
    # The length check uses the default `tools`; inference below uses `glm4_tools`.
    formatted_system = glm_template._format_tools(tools, system)
    assert len(formatted_system) == 846

    pt_engine = PtEngine('ZhipuAI/glm-4-9b-chat')
    pt_engine.default_template.agent_template = glm_template
    _infer(
        pt_engine,
        agent_tools=glm4_tools,
        tool_messages=glm4_tool_messasges,
        query=glm4_query,
    )


def test_glm4_0414():
    """Exercise the `glm4_0414` agent template: formatting, inference, encoding."""
    glm_template = agent_templates['glm4_0414']()
    # The length check uses the default `tools`; inference below uses `glm4_tools`.
    formatted_system = glm_template._format_tools(tools, system)
    assert len(formatted_system) == 769

    pt_engine = PtEngine('ZhipuAI/GLM-4-9B-0414')
    chat_template = pt_engine.default_template
    chat_template.agent_template = glm_template
    conversation = _infer(
        pt_engine,
        agent_tools=glm4_tools,
        tool_messages=glm4_tool_messasges,
        query=glm4_query,
    )
    # Decoding is greedy (temperature=0), so the final answer is compared verbatim.
    assert conversation[-1]['content'] == '根据天气预报工具,北京今天的空气质量指数为10,属于良好水平;上海今天的空气质量指数为72,属于轻度污染水平。'

    # Encode the generated conversation in train mode and show the result.
    chat_template.set_mode('train')
    train_encoded = chat_template.encode({'messages': conversation})
    for field in ('input_ids', 'labels'):
        print(f'{field}: {chat_template.safe_decode(train_encoded[field])}')

    # Encode a dataset sample with the messages at index 1 and 3 duplicated
    # (presumably to produce consecutive tool calls / tool results - TODO confirm).
    sample = load_dataset('AI-ModelScope/function-calling-chatml')[0][6]
    sample_messages = sample['messages']
    sample_messages.insert(1, sample_messages[1])
    sample_messages.insert(3, sample_messages[3])
    chat_template.template_backend = 'swift'
    sample_encoded = chat_template.encode(sample)
    for field in ('input_ids', 'labels'):
        print(f'{field}: {chat_template.safe_decode(sample_encoded[field])}')


def test_llama3():
    """Run the `llama3` agent template through inference and train-mode encoding."""
    llama_template = agent_templates['llama3']()
    pt_engine = PtEngine('LLM-Research/Llama-3.2-3B-Instruct')
    chat_template = pt_engine.default_template
    chat_template.agent_template = llama_template
    conversation = _infer(pt_engine)

    # Encode the generated conversation in train mode and show the result.
    chat_template.set_mode('train')
    train_encoded = chat_template.encode({'messages': conversation})
    for field in ('input_ids', 'labels'):
        print(f'{field}: {chat_template.safe_decode(train_encoded[field])}')

    # Encode a dataset sample with the messages at index 1 and 3 duplicated
    # (presumably to produce consecutive tool calls / tool results - TODO confirm).
    sample = load_dataset('AI-ModelScope/function-calling-chatml')[0][6]
    sample_messages = sample['messages']
    sample_messages.insert(1, sample_messages[1])
    sample_messages.insert(3, sample_messages[3])
    chat_template.template_backend = 'swift'
    sample_encoded = chat_template.encode(sample)
    for field in ('input_ids', 'labels'):
        print(f'{field}: {chat_template.safe_decode(sample_encoded[field])}')


def test_llama4():
    """Run the `llama4` agent template through inference and train-mode encoding."""
    llama_template = agent_templates['llama4']()
    pt_engine = PtEngine('LLM-Research/Llama-4-Scout-17B-16E-Instruct')
    chat_template = pt_engine.default_template
    chat_template.agent_template = llama_template
    conversation = _infer(pt_engine)

    # Encode the generated conversation in train mode and show the result.
    chat_template.set_mode('train')
    train_encoded = chat_template.encode({'messages': conversation})
    for field in ('input_ids', 'labels'):
        print(f'{field}: {chat_template.safe_decode(train_encoded[field])}')


if __name__ == '__main__':
    # These imports run only when the file is executed as a script; they bind
    # the module-level names that `_infer` and the test functions look up.
    from swift.plugin import agent_templates
    from swift.llm import (
        InferRequest,
        PtEngine,
        RequestConfig,
        load_dataset,
    )

    # Uncomment the test(s) to run; only the hermes test is enabled.
    # test_react_en()
    # test_react_zh()
    # test_qwen_en()
    # test_qwen_zh()
    # test_qwen_en_parallel()
    # test_qwen_zh_parallel()
    test_hermes()
    # test_toolbench()
    # test_glm4()
    # test_glm4_0414()
    # test_llama3()
    # test_llama4()