chat-api / main.py
lmt
post接口支持上下文
f00d827
import asyncio
from datetime import datetime
import json
from typing import Dict, List, Union
import os
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
import openai
from pydantic import BaseModel
import requests
import uvicorn
API_URL = "https://api.openai.com/v1/chat/completions"
# OpenAI API key, read from the `api_key` environment variable (None when it is unset).
api_key = os.getenv('api_key')
# The FastAPI application that every route below is registered on.
app = FastAPI()
@app.get("/")
def read_root():
    """Root route: return a fixed greeting payload."""
    greeting = {"Hello": "World!"}
    return greeting
# Conversation history for each open WebSocket connection, keyed by connection id.
# Each value is the list of {"role": ..., "content": ...} messages exchanged so far.
connection_history: Dict[str, List[Dict[str, str]]] = dict()
def get_sys_prompt():
    """Build the system-prompt message list that get_ai_response prepends to each request.

    Returns:
        A one-element list holding a ``system`` message whose text includes the
        current local time formatted as ``YYYY-MM-DD HH:MM:SS``.
    """
    timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    system_message = {
        "role": "system",
        "content": f"You are a helpful assistant. Current time is {timestamp}.",
    }
    return [system_message]
# Chat commands that clear the conversation: "reset conversation" / "start a new conversation".
_RESET_COMMANDS = ("重置对话", "开启新对话")


def test_reset(msg):
    """Return True if *msg* is exactly one of the conversation-reset commands.

    Matching is exact: surrounding whitespace or extra text means "not a reset".
    """
    return msg in _RESET_COMMANDS


# The name matches pytest's ``test_*`` pattern. Without this flag, any test module
# that imports this function gets it collected as a test, and collection fails on
# the unknown fixture ``msg``.
test_reset.__test__ = False
def get_ai_response(messages, stream=True, *, model="gpt-3.5-turbo", timeout=None):
    """Request a ChatGPT reply for *messages* from the OpenAI chat-completions API.

    The system prompt from ``get_sys_prompt()`` is prepended on every call.

    Args:
        messages: Conversation turns (``{"role": ..., "content": ...}`` dicts),
            sent after the system prompt.
        stream: Sent as the payload's ``"stream"`` flag. It also selects the
            HTTP transfer mode, so with the default ``True`` the body is read
            lazily and the caller iterates the reply via ``iter_lines()``.
        model: Model name placed in the payload.
        timeout: Forwarded to ``requests.post``. ``None`` (the default) waits
            indefinitely, as before.

    Returns:
        The raw ``requests.Response``. Its status is not checked here; the
        caller should check it and close the response when done.
    """
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    payload = json.dumps({
        "model": model,
        "messages": get_sys_prompt() + messages,
        "stream": stream,
    })
    return requests.post(
        API_URL,
        headers=headers,
        data=payload,
        # Match the HTTP transfer mode to the API's stream flag; a non-streamed
        # reply is then downloaded eagerly instead of holding the connection open.
        stream=stream,
        timeout=timeout,
    )
class Message(BaseModel):
    """Incoming frame on /api/ws: the JSON text the client sends."""
    # The user's chat text, or a reset command recognised by test_reset().
    msg: str
class ResponseMessage(BaseModel):
    """Outgoing frame on /api/ws carrying one streamed chunk of the assistant reply."""
    # Text of this chunk; may be an empty string.
    msg: str
    # True on the last frame of a reply.
    finished: bool
@app.websocket("/api/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Chat over a WebSocket, streaming each assistant reply back chunk by chunk.

    Protocol as implemented here:
      * the client sends JSON text frames shaped like ``Message`` (``{"msg": ...}``);
      * a reset command (see ``test_reset``) clears this connection's history
        and produces no reply frame;
      * any other message is appended to the history as a user turn and sent to
        the model. The reply is streamed back as ``ResponseMessage`` JSON
        frames, the last of which has ``finished=True``.

    The history lives in ``connection_history`` under this connection's id and
    is removed when the connection ends.
    """
    await websocket.accept()
    connection_id = str(id(websocket))
    connection_history[connection_id] = []
    try:
        while True:
            data = await websocket.receive_text()
            message = Message(**json.loads(data))
            print(message)
            # Reset commands clear the history instead of being sent to the model.
            if test_reset(message.msg):
                connection_history[connection_id] = []
                print("OK,对话已重置")
                continue
            history = connection_history[connection_id]
            history.append({"role": "user", "content": message.msg})
            await _stream_reply(websocket, history)
    except WebSocketDisconnect as e:
        # Normal client-side disconnect.
        print(f"WebSocket disconnected with code: {e.code}")
    except Exception as e:
        # Any other failure ends the connection: bad client JSON, HTTP/API errors,
        # network errors. Log the cause instead of hiding it. Cancellation is not
        # caught here (CancelledError is a BaseException on Python 3.8+), so
        # server shutdown can still cancel this task.
        print(f"WebSocket disconnected with unknown reason: {e!r}")
    finally:
        # Drop this connection's history so closed connections do not leak memory.
        connection_history.pop(connection_id, None)


# Returned by _parse_sse_line for the terminal "data: [DONE]" event.
_SSE_DONE = object()


def _parse_sse_line(line):
    """Decode one raw event-stream line from the chat-completions response.

    Returns None for lines carrying no data (blank keep-alives, comments, other
    fields), the ``_SSE_DONE`` sentinel for ``data: [DONE]``, and otherwise the
    JSON-decoded chunk.
    """
    decoded = line.decode("utf-8")
    if not decoded.startswith("data:"):
        return None
    payload = decoded[len("data:"):].strip()
    if payload == "[DONE]":
        return _SSE_DONE
    return json.loads(payload)


async def _stream_reply(websocket, history):
    """Stream one model reply for *history* to *websocket*, then append it to *history*."""
    loop = asyncio.get_running_loop()
    # requests is blocking. Run the request and every read in the default thread
    # pool so the event loop keeps serving other connections (this also replaces
    # the old 1 ms sleep used to let sends go out).
    response = await loop.run_in_executor(None, get_ai_response, history)
    try:
        # Error replies are plain JSON bodies, not an event stream; fail loudly.
        response.raise_for_status()
        lines = response.iter_lines()
        whole_message = ""
        finished = False
        while not finished:
            line = await loop.run_in_executor(None, next, lines, None)
            if line is None:
                # The stream ended without a finish marker; close out the reply anyway.
                response_msg, finished = "", True
            else:
                chunk = _parse_sse_line(line)
                if chunk is None:
                    continue
                if chunk is _SSE_DONE:
                    response_msg, finished = "", True
                else:
                    choice = (chunk.get("choices") or [{}])[0]
                    response_msg = (choice.get("delta") or {}).get("content") or ""
                    # Any finish reason ("stop", "length", ...) ends the reply,
                    # not only "stop".
                    finished = bool(choice.get("finish_reason"))
            whole_message += response_msg
            if finished:
                print(whole_message)
                history.append({"role": "assistant", "content": whole_message})
            response_message = ResponseMessage(msg=response_msg, finished=finished)
            await websocket.send_text(response_message.json())
    finally:
        # Release the streamed HTTP connection even when the loop exits early.
        response.close()
class Item(BaseModel):
    """Request body for POST /api/chat.

    Attributes:
        input: Latest user message; chat() appends it to the history when non-empty.
        history: Prior chat turns, forwarded to the model by get_response().
            Presumably ``{"role": ..., "content": ...}`` dicts; not validated here.
    """
    # NOTE(review): pydantic does not treat underscore-prefixed names as fields, so
    # _msgid is probably never populated from the request body (and is unused in
    # this file). Confirm before relying on it.
    _msgid: Union[str, None] = None
    input: str = ""
    # NOTE(review): relies on pydantic copying the [] default per instance;
    # confirm for the pydantic version in use.
    history: List[Dict] = []
@app.post("/api/chat")
def chat(item: Item):
    """HTTP chat endpoint whose conversation context comes from the request itself.

    The client supplies prior turns in ``item.history``. ``item.input``, when
    non-empty, is added as a new user turn. A new list is built, so
    ``item.history`` itself is not modified.

    Returns:
        Whatever ``get_response`` returns for the combined conversation.
    """
    print(item)
    if item.input:
        new_turns = [{"role": "user", "content": f"{item.input}"}]
    else:
        new_turns = []
    return get_response(initial_prompt, list(item.history) + new_turns)
# System prompt that chat() passes to get_response() for POST /api/chat.
initial_prompt = "You are a helpful assistant."
# Key for the openai SDK calls in get_response(); same value the raw-HTTP path uses.
openai.api_key = api_key
def get_response(system_prompt, history):
    """Ask the model for one completion through the openai SDK.

    Args:
        system_prompt: Text for the leading ``system`` message.
        history: Chat turns placed after the system message, in order.

    Returns:
        The object returned by ``openai.ChatCompletion.create``.
    """
    system_message = {"role": "system", "content": f"{system_prompt}"}
    payload = dict(
        model="gpt-3.5-turbo",
        messages=[system_message, *history],
    )
    # NOTE: this logs the whole conversation, including user-supplied content.
    print(f"payload: -->{payload}")
    return openai.ChatCompletion.create(**payload)
def construct_text(role, text):
    """Return a chat message dict of the form ``{"role": role, "content": text}``."""
    return dict(role=role, content=text)


def construct_user(text):
    """Return a ``user`` chat message containing *text*."""
    return construct_text(role="user", text=text)


def construct_system(text):
    """Return a ``system`` chat message containing *text*."""
    return construct_text(role="system", text=text)


def construct_assistant(text):
    """Return an ``assistant`` chat message containing *text*."""
    return construct_text(role="assistant", text=text)
if __name__ == "__main__":
    # Running this file directly serves the app with uvicorn on all interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")