import abc
import json

from websockets.exceptions import ConnectionClosedOK
from websockets.sync.client import connect

from chatglm2_6b.modelClient import ChatGLM2

class ChatClient(abc.ABC):
    """Common interface for the API-based and in-process ChatGLM2 clients."""

    @abc.abstractmethod
    def simple_chat(self, query, history, temperature, top_p):
        """Stream (response, history) tuples for a plain multi-turn chat."""

    @abc.abstractmethod
    def instruct_chat(self, message, chat_history, instructions, temperature, top_p):
        """Stream (response, chat_history) tuples for an instruction-grounded chat."""

def format_chat_prompt(message: str, chat_history, instructions: str) -> str:
    """Build the round-based 问/答 prompt used by ChatGLM2, prefixed with the scenario setting."""
    instructions = instructions.strip(" ").strip("\n")
    prompt = f"对话背景设定:{instructions}"
    for i, (user_message, bot_message) in enumerate(chat_history):
        prompt = f"{prompt}\n\n[Round {i + 1}]\n\n问:{user_message}\n\n答:{bot_message}"
    prompt = f"{prompt}\n\n[Round {len(chat_history) + 1}]\n\n问:{message}\n\n答:"
    return prompt
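
# Illustration only (not part of the original file): with one prior turn,
# format_chat_prompt lays the conversation out as ChatGLM2-style rounds, e.g.
#
#   format_chat_prompt("How are you?", [["Hi", "Hello!"]], "You are a support agent")
#   ->
#   对话背景设定:You are a support agent
#
#   [Round 1]
#
#   问:Hi
#
#   答:Hello!
#
#   [Round 2]
#
#   问:How are you?
#
#   答: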

class ChatGLM2APIClient(ChatClient):
    """Talks to a chatglm2-6b websocket API server."""

    def __init__(self, ws_url=None):
        self.ws_url = "ws://localhost:10001"
        if ws_url:
            self.ws_url = ws_url

    def simple_chat(self, query, history, temperature, top_p):
        """Chat via the dialogue method defined by the chatglm2-6b model."""
        url = f"{self.ws_url}/streamChat"
        with connect(url) as websocket:
            msg = json.dumps({
                "query": query, "history": history,
                "temperature": temperature, "top_p": top_p,
            })
            websocket.send(msg)
            try:
                # The server streams JSON frames until it closes the connection.
                while True:
                    data = json.loads(websocket.recv())
                    yield data['resp'], data['history']
            except ConnectionClosedOK:
                print("generation is finished")

    def instruct_chat(self, message, chat_history, instructions, temperature, top_p):
        """Preset-instruction chat implemented on top of chatglm2-6b's text_generate."""
        url = f"{self.ws_url}/streamGenerate"
        prompt = format_chat_prompt(message, chat_history, instructions)
        chat_history = chat_history + [[message, ""]]
        params = json.dumps({"prompt": prompt, "temperature": temperature, "top_p": top_p})
        with connect(url) as websocket:
            websocket.send(params)
            try:
                while True:
                    data = json.loads(websocket.recv())
                    resp = data['text']
                    # Replace the placeholder answer of the newest turn with the
                    # text streamed back so far.
                    last_turn = list(chat_history.pop(-1))
                    last_turn[-1] = resp
                    chat_history = chat_history + [last_turn]
                    yield resp, chat_history
            except ConnectionClosedOK:
                print("generation is finished")

class ChatGLM2ModelClient(ChatClient):
    """Runs chatglm2-6b in-process instead of going through the websocket API."""

    def __init__(self, model_path=None):
        self.model = ChatGLM2(model_path)

    def simple_chat(self, query, history, temperature, top_p):
        kwargs = {
            "query": query, "history": history,
            "temperature": temperature, "top_p": top_p,
        }
        for resp, history in self.model.stream_chat(**kwargs):
            yield resp, history

    def instruct_chat(self, message, chat_history, instructions, temperature, top_p):
        prompt = format_chat_prompt(message, chat_history, instructions)
        chat_history = chat_history + [[message, ""]]
        kwargs = {"prompt": prompt, "temperature": temperature, "top_p": top_p}
        for resp in self.model.stream_generate(**kwargs):
            # Keep the newest turn's answer in sync with the streamed text.
            last_turn = list(chat_history.pop(-1))
            last_turn[-1] = resp
            chat_history = chat_history + [last_turn]
            yield resp, chat_history
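

# Small self-test, not present in the original file. It assumes either a local
# websocket server (ChatGLM2APIClient) or local chatglm2-6b weights reachable by
# chatglm2_6b.modelClient.ChatGLM2; the model path below is a placeholder.
if __name__ == "__main__":
    client = ChatGLM2APIClient()  # or: ChatGLM2ModelClient("/path/to/chatglm2-6b")
    resp, history = "", []
    # Drain the stream; the last yielded resp/history hold the final answer.
    for resp, history in client.simple_chat("Hello", history, temperature=0.8, top_p=0.8):
        pass
    print(resp)
    print(history)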