File size: 4,503 Bytes
2807ff7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
"""
This demo script is designed for interacting with the ChatGLM3-6B in Function, to show Function Call capabilities.
"""
import os
import platform
import torch
from transformers import AutoTokenizer, AutoModel
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()
os_name = platform.system()
clear_command = 'cls' if os_name == 'Windows' else 'clear'
stop_stream = False
def build_prompt(history):
prompt = "欢迎使用 ChatGLM3-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
for query, response in history:
prompt += f"\n\n用户:{query}"
prompt += f"\n\nChatGLM3-6B:{response}"
return prompt
tools = [
{'name': 'track', 'description': '追踪指定股票的实时价格',
'parameters':
{
'type': 'object', 'properties':
{'symbol':
{
'description': '需要追踪的股票代码'
}
},
'required': []
}
}, {
'name': '/text-to-speech', 'description': '将文本转换为语音',
'parameters':
{
'type': 'object', 'properties':
{
'text':
{
'description': '需要转换成语音的文本'
},
'voice':
{
'description': '要使用的语音类型(男声、女声等)'
},
'speed': {
'description': '语音的速度(快、中等、慢等)'
}
}, 'required': []
}
},
{
'name': '/image_resizer', 'description': '调整图片的大小和尺寸',
'parameters': {'type': 'object',
'properties':
{
'image_file':
{
'description': '需要调整大小的图片文件'
},
'width':
{
'description': '需要调整的宽度值'
},
'height':
{
'description': '需要调整的高度值'
}
},
'required': []
}
},
{
'name': '/foodimg', 'description': '通过给定的食品名称生成该食品的图片',
'parameters': {
'type': 'object', 'properties':
{
'food_name':
{
'description': '需要生成图片的食品名称'
}
},
'required': []
}
}
]
system_item = {
"role": "system",
"content": "Answer the following questions as best as you can. You have access to the following tools:",
"tools": tools
}
def main():
past_key_values, history = None, [system_item]
role = "user"
global stop_stream
print("欢迎使用 ChatGLM3-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
while True:
query = input("\n用户:") if role == "user" else input("\n结果:")
if query.strip() == "stop":
break
if query.strip() == "clear":
past_key_values, history = None, [system_item]
role = "user"
os.system(clear_command)
print("欢迎使用 ChatGLM3-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
continue
print("\nChatGLM:", end="")
response, history = model.chat(tokenizer, query, history=history, role=role)
print(response, end="", flush=True)
print("")
if isinstance(response, dict):
role = "observation"
else:
role = "user"
if __name__ == "__main__":
main()
|