lmt committed on
Commit
5d9cc33
·
1 Parent(s): 5572682

增加ws接口

Browse files
Files changed (1) hide show
  1. main.py +76 -3
main.py CHANGED
@@ -1,14 +1,20 @@
 
 
1
  import os
 
2
  import openai
3
- from fastapi import FastAPI
4
  from pydantic import BaseModel
5
- from typing import Dict, List
 
6
 
7
- openai.api_key = os.environ.get('api_key')
8
  initial_prompt = "You are a helpful assistant."
 
9
 
10
  app = FastAPI()
11
 
 
 
12
 
13
  @app.get("/")
14
  def read_root():
@@ -16,18 +22,81 @@ def read_root():
16
 
17
 
18
  class Item(BaseModel):
 
19
  input: str
20
  history: List[Dict] = []
21
 
22
 
23
  @app.post("/api/chat")
24
  def chat(item: Item):
 
25
  history = [construct_user(item.input)]
26
  res = get_response(initial_prompt, history)
27
  return res
28
 
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  def get_response(system_prompt, history):
 
31
  history = [construct_system(system_prompt), *history]
32
 
33
  payload = {
@@ -55,3 +124,7 @@ def construct_assistant(text):
55
 
56
  def construct_text(role, text):
57
  return {"role": role, "content": text}
 
 
 
 
 
import json
import os
from typing import Dict, List, Union

import openai
import requests
import uvicorn
from fastapi import FastAPI, WebSocket
from pydantic import BaseModel, Field
9
 
10
+ api_key = os.environ.get('api_key')
11
  initial_prompt = "You are a helpful assistant."
12
+ API_URL = "https://api.openai.com/v1/chat/completions"
13
 
14
  app = FastAPI()
15
 
16
+ openai.api_key = api_key
17
+
18
 
19
  @app.get("/")
20
  def read_root():
 
22
 
23
 
24
class Item(BaseModel):
    """Request body for POST /api/chat."""

    # Optional client-side correlation id.  The wire key is "_msgid", but
    # pydantic silently drops underscore-prefixed field names, so the original
    # `_msgid: Union[str, None] = None` declaration never captured the value;
    # expose it through an alias instead (wire format unchanged).
    msgid: Union[str, None] = Field(default=None, alias="_msgid")
    # The new user utterance to send to the model.
    input: str
    # Prior conversation turns as {"role": ..., "content": ...} dicts.
    history: List[Dict] = []
28
 
29
 
30
@app.post("/api/chat")
def chat(item: Item):
    """Single-shot chat endpoint.

    Forwards the user's input (plus any prior turns the client supplied)
    to the model and returns the raw model response.

    Fix: the original built `history` from `item.input` alone, silently
    ignoring the `history` field declared on `Item`; prior turns are now
    prepended.  Callers sending the default empty history are unaffected.
    """
    print(item)  # debug trace of the incoming payload
    history = [*item.history, construct_user(item.input)]
    res = get_response(initial_prompt, history)
    return res
36
 
37
 
38
class Message(BaseModel):
    """Inbound WebSocket frame: one user utterance."""

    # Raw text of the user's message.
    msg: str
40
+
41
+
42
class ResponseMessage(BaseModel):
    """Outbound WebSocket frame carrying one streamed delta."""

    # Text fragment produced by the model for this chunk (may be empty).
    msg: str
    # True on the final frame of the stream.
    finished: bool
45
+
46
+
47
@app.websocket("/api/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Stream chat completions to the client over a WebSocket.

    Protocol: the client sends a JSON-encoded ``Message`` ({"msg": ...});
    the server forwards it to the OpenAI chat-completions API with
    ``stream=True`` and relays each delta back as a ``ResponseMessage``
    ({"msg": ..., "finished": ...}).  The final frame has finished=True.

    Fix: the original called ``json.loads(decoded_line[6:])`` on every
    non-empty line, which raises JSONDecodeError on the SSE terminal
    sentinel ``data: [DONE]`` and on any non-data line.
    """
    await websocket.accept()
    while True:
        # Wait for the next user message; a client disconnect raises here
        # and ends the handler, which also terminates this loop.
        data = await websocket.receive_text()
        message = Message(**json.loads(data))

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        }
        payload = json.dumps({
            "model": "gpt-3.5-turbo",
            "messages": [
                {"role": "system", "content": "start"},
                {"role": "user", "content": message.msg},
            ],
            "stream": True,
        })

        # NOTE(review): `requests` is blocking and stalls the event loop for
        # the duration of the stream; consider httpx.AsyncClient or
        # loop.run_in_executor for a production deployment.
        response = requests.post(
            API_URL,
            headers=headers,
            data=payload,
            stream=True,
        )

        finished = False
        for line in response.iter_lines():
            if not line:
                continue
            decoded_line = line.decode("utf-8")
            # Server-sent events arrive as "data: <json>"; skip anything
            # else (SSE comments, keep-alives) rather than crashing on it.
            if not decoded_line.startswith("data: "):
                continue
            body = decoded_line[6:]
            if body.strip() == "[DONE]":
                # Terminal sentinel: guarantee the client sees finished=True
                # even if no chunk carried finish_reason == "stop".
                if not finished:
                    await websocket.send_text(
                        ResponseMessage(msg="", finished=True).json())
                finished = True
                break
            content = json.loads(body)
            choice = content["choices"][0]
            # Delta chunks may omit "content" (e.g. the initial role chunk).
            response_msg = choice["delta"].get("content", "")
            if choice.get("finish_reason") == "stop":
                finished = True
            response_message = ResponseMessage(
                msg=response_msg, finished=finished)
            await websocket.send_text(response_message.json())
            if finished:
                break
96
+
97
+
98
  def get_response(system_prompt, history):
99
+
100
  history = [construct_system(system_prompt), *history]
101
 
102
  payload = {
 
124
 
125
def construct_text(role, text):
    """Build a single OpenAI chat message dict for *role* with body *text*."""
    message = {"role": role, "content": text}
    return message
127
+
128
+
129
if __name__ == "__main__":
    # Direct-execution entry point: serve on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)