yuccaaa committed on
Commit
88d804f
·
verified ·
1 Parent(s): 0ccf423

Upload ms-swift/examples/infer/demo_agent.py with huggingface_hub

Browse files
ms-swift/examples/infer/demo_agent.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import os
3
+
4
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
5
+
6
+
7
def infer(engine: 'InferEngine', infer_request: 'InferRequest'):
    """Run a two-turn, non-streaming agent exchange and print every step.

    The first call asks the engine for a reply, with ``'Observation:'``
    configured as a stop sequence. A hard-coded weather payload is then
    appended to the conversation as the tool message (no tool is actually
    invoked), and the engine is called a second time to produce the final
    answer.

    Args:
        engine: Object exposing ``infer(requests, request_config)``.
        infer_request: The request to send. Its ``messages`` list is extended
            in place with the assistant reply and the tool message.
    """
    config = RequestConfig(max_tokens=512, temperature=0, stop=['Observation:'])

    # Turn 1: let the model produce its reply up to the stop sequence.
    first_reply = engine.infer([infer_request], config)[0].choices[0].message.content
    user_query = infer_request.messages[0]['content']
    tool_output = '{"temperature": 72, "condition": "Sunny", "humidity": 50}\n'
    print(f'query: {user_query}')
    # The tool output already ends with a newline, hence end=''.
    print(f'response: {first_reply}{tool_output}', end='')

    # Turn 2: feed the reply and the canned tool result back to the model.
    assistant_turn = {'role': 'assistant', 'content': first_reply}
    tool_turn = {'role': 'tool', 'content': tool_output}
    infer_request.messages += [assistant_turn, tool_turn]
    final_reply = engine.infer([infer_request], config)[0].choices[0].message.content
    print(final_reply)
20
+
21
+
22
def _print_stream(stream) -> str:
    """Print streamed content deltas as they arrive and return the joined text.

    Args:
        stream: Iterable of streamed response chunks; ``None`` chunks are
            skipped.

    Returns:
        The concatenation of every non-``None`` ``delta.content`` received.
    """
    pieces = []
    for chunk in stream:
        if chunk is None:
            continue
        delta = chunk.choices[0].delta.content
        if delta is None:
            # A chunk without text content would otherwise raise TypeError
            # during concatenation or be printed as the literal 'None'.
            continue
        pieces.append(delta)
        print(delta, end='', flush=True)
    return ''.join(pieces)


def infer_stream(engine: 'InferEngine', infer_request: 'InferRequest'):
    """Run a two-turn, streaming agent exchange and print tokens as they arrive.

    The first streamed reply is accumulated, a hard-coded weather payload is
    appended to the conversation as the tool message (no tool is actually
    invoked), and the engine is streamed a second time for the final answer.

    Args:
        engine: Object exposing ``infer(requests, request_config)``; with
            ``stream=True`` the first element of its result is iterated for
            response chunks.
        infer_request: The request to send. Its ``messages`` list is extended
            in place with the assistant reply and the tool message.
    """
    request_config = RequestConfig(max_tokens=512, temperature=0, stop=['Observation:'], stream=True)
    gen_list = engine.infer([infer_request], request_config)
    query = infer_request.messages[0]['content']
    tool = '{"temperature": 72, "condition": "Sunny", "humidity": 50}'
    print(f'query: {query}')
    response = _print_stream(gen_list[0])
    # Printed right after the streamed text, so it continues the same line.
    print(tool)

    infer_request.messages += [{'role': 'assistant', 'content': response}, {'role': 'tool', 'content': tool}]
    gen_list = engine.infer([infer_request], request_config)
    _print_stream(gen_list[0])
    print()
44
+
45
+
46
def get_infer_request():
    """Build a fresh demo request: one user question plus one weather tool.

    Returns:
        A new ``InferRequest`` holding a single user message and the
        ``get_current_weather`` tool description.
    """
    user_message = {'role': 'user', 'content': "How's the weather today?"}

    # Schema of the arguments accepted by the tool.
    weather_parameters = {
        'type': 'object',
        'properties': {
            'location': {
                'type': 'string',
                'description': 'The city and state, e.g. San Francisco, CA',
            },
            'unit': {
                'type': 'string',
                'enum': ['celsius', 'fahrenheit'],
            },
        },
        'required': ['location'],
    }
    weather_tool = {
        'name': 'get_current_weather',
        'description': 'Get the current weather in a given location',
        'parameters': weather_parameters,
    }
    return InferRequest(messages=[user_message], tools=[weather_tool])
70
+
71
+
72
if __name__ == '__main__':
    # infer, infer_stream and get_infer_request look up RequestConfig and
    # InferRequest as module globals; this import is what binds them.
    from swift.llm import InferEngine, InferRequest, PtEngine, RequestConfig
    model = 'Qwen/Qwen2.5-1.5B-Instruct'
    infer_backend = 'pt'  # one of: 'pt', 'vllm', 'lmdeploy'

    # The optional backends are imported only when they are selected.
    if infer_backend == 'pt':
        engine = PtEngine(model, max_batch_size=64)
    elif infer_backend == 'vllm':
        from swift.llm import VllmEngine
        engine = VllmEngine(model, max_model_len=8192)
    elif infer_backend == 'lmdeploy':
        from swift.llm import LmdeployEngine
        engine = LmdeployEngine(model)
    else:
        # Fail here with a clear message instead of a NameError on `engine` below.
        raise ValueError(f'Unsupported infer_backend: {infer_backend!r}')

    # Each demo gets its own request because both extend the messages in place.
    infer(engine, get_infer_request())
    infer_stream(engine, get_infer_request())