File size: 4,949 Bytes
1f29eb4
8aefd0f
5d9cc33
 
4bf9eff
c3ba971
5572682
9c1b67f
5d9cc33
 
4bf9eff
5d9cc33
 
131d29f
 
 
4bf9eff
4f71b16
131d29f
4bf9eff
 
 
337828e
 
 
 
8aefd0f
 
 
566ac72
8aefd0f
 
 
566ac72
 
 
 
8aefd0f
337828e
 
 
 
 
 
8aefd0f
 
337828e
 
8aefd0f
 
337828e
 
 
 
 
 
 
 
 
 
 
5d9cc33
 
 
 
 
 
 
 
 
 
 
 
337828e
8aefd0f
337828e
c3ba971
 
 
 
337828e
c3ba971
566ac72
 
 
 
 
 
337828e
 
c3ba971
337828e
c3ba971
337828e
c3ba971
 
 
 
 
 
 
 
 
 
337828e
c3ba971
 
 
 
 
337828e
 
 
 
 
c3ba971
 
 
 
337828e
1f29eb4
c3ba971
 
 
 
 
 
566ac72
 
5d9cc33
 
8aefd0f
 
f00d827
8aefd0f
 
 
 
 
 
f00d827
 
 
 
8aefd0f
 
 
 
566ac72
 
 
 
4bf9eff
5d9cc33
f00d827
 
 
 
4bf9eff
 
 
f00d827
4bf9eff
 
5572682
 
 
4bf9eff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d9cc33
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import asyncio
from datetime import datetime
import json
from typing import Dict, List, Union
import os
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
import openai
from pydantic import BaseModel
import requests
import uvicorn

# OpenAI credential, read from the "api_key" environment variable (None if unset).
api_key = os.getenv('api_key')
# Chat-completions endpoint called directly over HTTP by get_ai_response().
API_URL = "https://api.openai.com/v1/chat/completions"

app = FastAPI()


@app.get("/")
def read_root():
    """Root endpoint: return a fixed greeting payload."""
    greeting = {"Hello": "World!"}
    return greeting


# Conversation history per websocket connection, keyed by str(id(websocket)).
# Each value is the list of {"role", "content"} messages exchanged so far.
connection_history: Dict[str, List[Dict[str, str]]] = dict()


def get_sys_prompt():
    """Return a one-element message list holding the system prompt.

    The prompt embeds the current local time (naive ``datetime.now()``),
    formatted as ``YYYY-MM-DD HH:MM:SS``, so it is rebuilt on every call.
    """
    timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    system_message = {
        "role": "system",
        "content": f"You are a helpful assistant. Current time is {timestamp}.",
    }
    return [system_message]


def test_reset(msg):
    return msg == "重置对话" or msg == "开启新对话"


def get_ai_response(messages, stream=True, model="gpt-3.5-turbo", timeout=(10, 120)):
    """Send a chat-completion request to the OpenAI HTTP API.

    Args:
        messages: conversation so far, as a list of {"role", "content"} dicts.
            The time-stamped system prompt from get_sys_prompt() is prepended.
        stream: forwarded as the API's "stream" flag. The websocket handler in
            this file expects True and parses the body line by line.
        model: chat model name sent in the request payload.
        timeout: requests timeout as (connect, read) seconds. The read timeout
            bounds the wait between received chunks. Pass None to wait forever
            (the previous behavior).

    Returns:
        The requests.Response. It is opened with stream=True, so the body is
        read lazily; the caller is responsible for consuming and closing it.

    Raises:
        requests.exceptions.RequestException (e.g. Timeout, ConnectionError)
        when the request cannot be completed.
    """
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }

    payload = json.dumps({
        "model": model,
        # Rebuilt on every call so the timestamp in the system prompt is current.
        "messages": get_sys_prompt() + messages,
        "stream": stream,
    })

    # stream=True is the *HTTP* setting: the body is not downloaded up front,
    # so callers can relay incremental output as it arrives. Without a timeout,
    # requests would block indefinitely on a stalled connection.
    return requests.post(
        API_URL,
        headers=headers,
        data=payload,
        stream=True,
        timeout=timeout,
    )


class Message(BaseModel):
    """Incoming websocket frame from the client, parsed from a JSON text frame."""
    msg: str  # user chat text, or a reset command recognised by test_reset()


class ResponseMessage(BaseModel):
    """Outgoing websocket frame carrying one streamed chunk of the assistant reply."""
    msg: str  # incremental text for this chunk (may be empty)
    finished: bool  # True only on the last frame of a reply


def _iter_stream_deltas(response):
    """Parse a streamed chat-completions HTTP response into text deltas.

    Blocking: iterates ``response.iter_lines()``. Yields one
    ``(delta_text, finish_reason)`` tuple per ``data:`` event. ``delta_text``
    is always a ``str`` (``""`` when the chunk carries no text), and
    ``finish_reason`` stays ``None`` until the model stops. The ``[DONE]``
    end-of-stream sentinel is reported as ``("", "stop")``.
    """
    for raw_line in response.iter_lines():
        if not raw_line:
            continue  # blank separator between events
        decoded_line = raw_line.decode("utf-8")
        if not decoded_line.startswith("data:"):
            continue  # not a data event (e.g. a plain JSON error body)
        data = decoded_line[len("data:"):].strip()
        if data == "[DONE]":
            yield "", "stop"
            return
        chunk = json.loads(data)
        choices = chunk.get("choices") or []
        if not choices:
            continue
        choice = choices[0]
        # "content" can be missing or null in role-only / final chunks.
        delta_text = (choice.get("delta") or {}).get("content") or ""
        yield delta_text, choice.get("finish_reason")


async def _relay_completion(websocket, history):
    """Stream one assistant reply for *history* to *websocket*.

    Every chunk is sent as a ResponseMessage JSON frame. The final frame has
    finished=True, even if the stream ends without a finish_reason. On finish,
    the full reply is appended to *history* as an assistant message.
    """
    loop = asyncio.get_running_loop()
    # requests is synchronous: run the POST and each blocking read in the
    # default executor so other connections keep being served.
    response = await loop.run_in_executor(None, get_ai_response, history)
    try:
        deltas = _iter_stream_deltas(response)
        whole_message = ""
        finished = False
        while not finished:
            item = await loop.run_in_executor(None, next, deltas, None)
            if item is None:
                # Stream ended with no finish_reason/[DONE]; still close out the reply.
                response_msg = ""
                finished = True
            else:
                response_msg, finish_reason = item
                whole_message += response_msg
                # Any finish reason ("stop", "length", ...) ends the reply.
                finished = finish_reason is not None

            if finished:
                print(whole_message)
                history.append({"role": "assistant", "content": whole_message})

            response_message = ResponseMessage(msg=response_msg, finished=finished)
            await websocket.send_text(response_message.json())
    finally:
        # Release the pooled HTTP connection even if we stop early or fail.
        response.close()


@app.websocket("/api/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Chat over a websocket, streaming each assistant reply chunk by chunk.

    Each incoming text frame must be JSON matching Message. A reset command
    (see test_reset) clears this connection's history. Any other text is sent
    to the model together with the history. The reply is streamed back as
    ResponseMessage frames; the last one has finished=True. The history lives
    in connection_history while the connection is open and is removed on exit.
    """
    await websocket.accept()
    connection_id = str(id(websocket))
    connection_history[connection_id] = []

    try:
        while True:
            data = await websocket.receive_text()
            message = Message(**json.loads(data))
            print(message)

            # Reset the conversation when the client asks for it.
            if test_reset(message.msg):
                connection_history[connection_id] = []
                print("OK,对话已重置")
                continue

            history = connection_history[connection_id]
            history.append({"role": "user", "content": message.msg})
            await _relay_completion(websocket, history)
    except WebSocketDisconnect as e:
        # Normal client disconnect.
        print(f"WebSocket disconnected with code: {e.code}")
    except Exception as e:
        # Unexpected failure (bad frame, API/network error): log it instead of
        # silently swallowing it. CancelledError is no longer caught.
        print(f"WebSocket disconnected with unknown reason: {e!r}")
    finally:
        # Drop this connection's history so it does not accumulate forever.
        connection_history.pop(connection_id, None)


class Item(BaseModel):
    """Request body for POST /api/chat."""
    # NOTE(review): pydantic does not treat underscore-prefixed names as regular
    # fields, so a "_msgid" key in the request body is likely not parsed into
    # this attribute -- confirm against the installed pydantic version. Nothing
    # in the visible code reads it.
    _msgid: Union[str, None] = None
    input: str = ""  # new user text; appended to history as a user message when non-empty
    history: List[Dict] = []  # prior messages, forwarded to the model as-is


@app.post("/api/chat")
def chat(item: Item):
    """Non-streaming chat endpoint.

    Copies the client-supplied history and appends ``item.input`` as a user
    message when it is non-empty. The result is sent to get_response() with
    the fixed ``initial_prompt`` as system prompt, and whatever get_response()
    returns is returned unchanged.
    """
    print(item)
    messages = list(item.history)  # copy; the request model is left untouched
    if item.input:
        messages.append({"role": "user", "content": f"{item.input}"})
    return get_response(initial_prompt, messages)


# Fixed system prompt used by the non-streaming /api/chat endpoint.
initial_prompt = "You are a helpful assistant."
# Configure the openai SDK (used by get_response) with the same credential.
openai.api_key = api_key


def get_response(system_prompt, history):
    """Run a non-streaming chat completion through the openai SDK.

    Args:
        system_prompt: text for the leading system message (converted with str()).
        history: the remaining messages, appended after the system message.

    Returns:
        The object returned by ``openai.ChatCompletion.create``.
    """
    system_message = {"role": "system", "content": f"{system_prompt}"}
    payload = dict(model="gpt-3.5-turbo", messages=[system_message, *history])

    print(f"payload: -->{payload}")

    return openai.ChatCompletion.create(**payload)


def construct_user(text):
    """Build a chat message with role "user" and *text* as content."""
    return construct_text(role="user", text=text)


def construct_system(text):
    """Build a chat message with role "system" and *text* as content."""
    return construct_text(role="system", text=text)


def construct_assistant(text):
    """Build a chat message with role "assistant" and *text* as content."""
    return construct_text(role="assistant", text=text)


def construct_text(role, text):
    """Return a chat message dict ``{"role": role, "content": text}``."""
    message = dict(role=role, content=text)
    return message


if __name__ == "__main__":
    # Serve the app on all interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")