lmt commited on
Commit
337828e
·
1 Parent(s): 1f29eb4

ws链接支持上下文

Browse files
Files changed (1) hide show
  1. main.py +42 -18
main.py CHANGED
@@ -36,6 +36,32 @@ def chat(item: Item):
36
  return res
37
 
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  class Message(BaseModel):
40
  msg: str
41
 
@@ -48,31 +74,22 @@ class ResponseMessage(BaseModel):
48
  @app.websocket("/api/ws")
49
  async def websocket_endpoint(websocket: WebSocket):
50
  await websocket.accept()
 
 
 
 
51
  try:
52
  while True:
53
  data = await websocket.receive_text()
54
  message = Message(**json.loads(data))
 
55
 
56
- headers = {
57
- "Content-Type": "application/json",
58
- "Authorization": f"Bearer {api_key}",
59
- }
60
-
61
- payload = json.dumps({
62
- "model": "gpt-3.5-turbo",
63
- "messages": [{"role": "system", "content": "start"}, {"role": "user", "content": message.msg}],
64
- "stream": True,
65
- })
66
-
67
- response = requests.post(
68
- API_URL,
69
- headers=headers,
70
- data=payload,
71
- stream=True
72
- )
73
 
74
- print(message)
75
 
 
76
  finished = False
77
  for line in response.iter_lines():
78
  if line:
@@ -83,15 +100,22 @@ async def websocket_endpoint(websocket: WebSocket):
83
 
84
  if "content" in content['choices'][0]["delta"]:
85
  response_msg = content["choices"][0]["delta"]["content"]
 
86
  else:
87
  response_msg = ""
88
 
89
  if finish_reason == "stop":
90
  finished = True
 
 
 
 
 
91
 
92
  response_message = ResponseMessage(
93
  msg=response_msg, finished=finished)
94
  await websocket.send_text(response_message.json())
 
95
  await asyncio.sleep(0.001)
96
 
97
  if finished:
 
36
  return res
37
 
38
 
39
+ # 存储每个连接的对话历史
40
+ connection_history: Dict[str, List[Dict[str, str]]] = {}
41
+
42
+
43
+ def get_ai_response(messages):
44
+ '''获取ChatGPT答复,使用流式返回'''
45
+ headers = {
46
+ "Content-Type": "application/json",
47
+ "Authorization": f"Bearer {api_key}",
48
+ }
49
+
50
+ payload = json.dumps({
51
+ "model": "gpt-3.5-turbo",
52
+ "messages": messages,
53
+ "stream": True,
54
+ })
55
+
56
+ response = requests.post(
57
+ API_URL,
58
+ headers=headers,
59
+ data=payload,
60
+ stream=True
61
+ )
62
+ return response
63
+
64
+
65
  class Message(BaseModel):
66
  msg: str
67
 
 
74
  @app.websocket("/api/ws")
75
  async def websocket_endpoint(websocket: WebSocket):
76
  await websocket.accept()
77
+ connection_id = str(id(websocket))
78
+ connection_history[connection_id] = [
79
+ {"role": "system", "content": "You are a helpful assistant."}]
80
+
81
  try:
82
  while True:
83
  data = await websocket.receive_text()
84
  message = Message(**json.loads(data))
85
+ print(message)
86
 
87
+ user_message = {"role": "user", "content": message.msg}
88
+ connection_history[connection_id].append(user_message)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
+ response = get_ai_response(connection_history[connection_id])
91
 
92
+ whole_message = ""
93
  finished = False
94
  for line in response.iter_lines():
95
  if line:
 
100
 
101
  if "content" in content['choices'][0]["delta"]:
102
  response_msg = content["choices"][0]["delta"]["content"]
103
+ whole_message += response_msg
104
  else:
105
  response_msg = ""
106
 
107
  if finish_reason == "stop":
108
  finished = True
109
+ print(whole_message)
110
+ assistant_message = {
111
+ "role": "assistant", "content": whole_message}
112
+ connection_history[connection_id].append(
113
+ assistant_message)
114
 
115
  response_message = ResponseMessage(
116
  msg=response_msg, finished=finished)
117
  await websocket.send_text(response_message.json())
118
+ # sleep 1ms 给发送协程让出执行窗口
119
  await asyncio.sleep(0.001)
120
 
121
  if finished: