Fixed the bug where the Claude role could not be obtained and the SSE format was incorrect.
Files changed:
- json_str/gpt/mess_sse.json (+12 -0)
- main.py (+38 -11)
- request.py (+6 -4)
- response.py (+7 -1)
json_str/gpt/mess_sse.json  ADDED

@@ -0,0 +1,12 @@
+data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}],"usage":null}
+data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}],"usage":null}
+data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":"!"},"logprobs":null,"finish_reason":null}],"usage":null}
+data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":" How"},"logprobs":null,"finish_reason":null}],"usage":null}
+data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":" can"},"logprobs":null,"finish_reason":null}],"usage":null}
+data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}],"usage":null}
+data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":" assist"},"logprobs":null,"finish_reason":null}],"usage":null}
+data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":null}
+data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":" today"},"logprobs":null,"finish_reason":null}],"usage":null}
+data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}],"usage":null}
+data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null}
+data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[],"usage":{"prompt_tokens":178,"completion_tokens":10,"total_tokens":188}}
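For reference, the capture above follows the OpenAI chat.completion.chunk SSE shape: the first event carries the assistant role with an empty content string, the middle events carry content deltas, an event with finish_reason "stop" closes the message, and a final event with an empty choices list reports token usage. A minimal sketch of replaying the capture and reassembling the message (the helper name and the assumption that every event sits on a single "data: " line are ours, not part of the commit):

import json

def reassemble(path="json_str/gpt/mess_sse.json"):
    """Rebuild the assistant message from a captured OpenAI-style SSE stream."""
    role, parts, usage = None, [], None
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line.startswith("data: "):
                continue  # skip blank separator lines
            chunk = json.loads(line[len("data: "):])
            if chunk.get("usage"):
                usage = chunk["usage"]            # last event: empty choices, token counts
            for choice in chunk.get("choices", []):
                delta = choice.get("delta", {})
                role = delta.get("role", role)    # first event sets the role
                parts.append(delta.get("content") or "")
    return role, "".join(parts), usage

# Expected result for the capture above:
# ('assistant', 'Hello! How can I assist you today?',
#  {'prompt_tokens': 178, 'completion_tokens': 10, 'total_tokens': 188})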
main.py  CHANGED

@@ -1,11 +1,12 @@
 import os
 import json
 import httpx
+import logging
 import yaml
 import traceback
 from contextlib import asynccontextmanager

-from fastapi import FastAPI, HTTPException, Depends
+from fastapi import FastAPI, Request, HTTPException, Depends
 from fastapi.responses import StreamingResponse
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials

@@ -48,7 +49,16 @@ def load_config():
         return []

 config = load_config()
-
+for index, provider in enumerate(config):
+    model_dict = {}
+    for model in provider['model']:
+        if type(model) == str:
+            model_dict[model] = model
+        if type(model) == dict:
+            model_dict.update({value: key for key, value in model.items()})
+    provider['model'] = model_dict
+    config[index] = provider
+# print(json.dumps(config, indent=4, ensure_ascii=False))

 async def process_request(request: RequestModel, provider: Dict):
     print("provider: ", provider['provider'])

@@ -64,15 +74,16 @@ async def process_request(request: RequestModel, provider: Dict):

     url, headers, payload = await get_payload(request, engine, provider)

-    request_info = {
-        "url": url,
-        "headers": headers,
-        "payload": payload
-    }
-    print(f"Request details: {json.dumps(request_info, indent=4, ensure_ascii=False)}")
+    # request_info = {
+    #     "url": url,
+    #     "headers": headers,
+    #     "payload": payload
+    # }
+    # print(f"Request details: {json.dumps(request_info, indent=4, ensure_ascii=False)}")

     if request.stream:
-
+        model = provider['model'][request.model]
+        return StreamingResponse(fetch_response_stream(app.state.client, url, headers, payload, engine, model), media_type="text/event-stream")
     else:
         return await fetch_response(app.state.client, url, headers, payload)

@@ -81,7 +92,11 @@ class ModelRequestHandler:
         self.last_provider_index = -1

     def get_matching_providers(self, model_name):
-
+        # for provider in config:
+        #     print("provider", model_name, list(provider['model'].keys()))
+        #     if model_name in provider['model'].keys():
+        #         print("provider", provider)
+        return [provider for provider in config if model_name in provider['model'].keys()]

     async def request_model(self, request: RequestModel, token: str):
         model_name = request.model

@@ -122,6 +137,18 @@

 model_handler = ModelRequestHandler()

+@app.middleware("http")
+async def log_requests(request: Request, call_next):
+    # Log the request line
+    logging.info(f"Request: {request.method} {request.url}")
+    # Log the request body (if any)
+    if request.method in ["POST", "PUT", "PATCH"]:
+        body = await request.body()
+        logging.info(f"Request Body: {body.decode('utf-8')}")
+
+    response = await call_next(request)
+    return response
+
 def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
     token = credentials.credentials
     if token not in api_keys_db:

@@ -137,7 +164,7 @@ def get_all_models():
     unique_models = set()

     for provider in config:
-        for model in provider['model']:
+        for model in provider['model'].keys():
             if model not in unique_models:
                 unique_models.add(model)
                 model_info = {
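The loop added after load_config() normalizes each provider's model list into a lookup dict keyed by the client-facing model name: plain string entries map to themselves, and dict entries are inverted so the value (the name clients request) becomes the key and the original key (the name sent upstream) becomes the value. A small illustration with an invented provider entry; the alias and upstream names below are placeholders, not from any real config:

provider = {"provider": "demo", "model": ["gpt-4o", {"gpt-4o-2024-05-13": "gpt-4o-mini"}]}

model_dict = {}
for model in provider["model"]:
    if type(model) == str:
        model_dict[model] = model  # "gpt-4o" -> "gpt-4o"
    if type(model) == dict:
        # {upstream_name: alias} becomes {alias: upstream_name}
        model_dict.update({value: key for key, value in model.items()})
provider["model"] = model_dict

print(model_dict)
# {'gpt-4o': 'gpt-4o', 'gpt-4o-mini': 'gpt-4o-2024-05-13'}
# so get_matching_providers("gpt-4o-mini") matches this provider, and the request
# is forwarded upstream under the name "gpt-4o-2024-05-13"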
request.py  CHANGED

@@ -38,9 +38,10 @@ async def get_gemini_payload(request, engine, provider):
         'Content-Type': 'application/json'
     }
     url = provider['base_url']
+    model = provider['model'][request.model]
     if request.stream:
         gemini_stream = "streamGenerateContent"
-        url = url.format(model=
+        url = url.format(model=model, stream=gemini_stream, api_key=provider['api'])

     messages = []
     for msg in request.messages:

@@ -112,7 +113,6 @@ async def get_gpt_payload(request, engine, provider):
         'Content-Type': 'application/json'
     }
     url = provider['base_url']
-    url = url.format(model=request.model, stream=request.stream, api_key=provider['api'])

     messages = []
     for msg in request.messages:

@@ -133,8 +133,9 @@
         else:
             messages.append({"role": msg.role, "content": content})

+    model = provider['model'][request.model]
     payload = {
-        "model":
+        "model": model,
         "messages": messages,
     }

@@ -222,8 +223,9 @@ async def get_claude_payload(request, engine, provider):
         elif msg.role == "system":
             system_prompt = content

+    model = provider['model'][request.model]
     payload = {
-        "model":
+        "model": model,
         "messages": messages,
         "system": system_prompt,
     }
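In each payload builder the outgoing model name is now resolved through that per-provider lookup, provider['model'][request.model], instead of echoing request.model, so an aliased request reaches the upstream API under its provider-side name. A rough sketch of the effect, with made-up provider and request values standing in for the real config and RequestModel:

provider = {"model": {"gpt-4o-mini": "gpt-4o-2024-05-13"}}  # alias -> upstream name (invented values)

class FakeRequest:  # stand-in for RequestModel
    model = "gpt-4o-mini"

request = FakeRequest()
model = provider["model"][request.model]      # -> "gpt-4o-2024-05-13"
payload = {"model": model, "messages": []}    # the upstream API sees the resolved name
print(payload)  # {'model': 'gpt-4o-2024-05-13', 'messages': []}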
response.py  CHANGED

@@ -2,7 +2,7 @@ from datetime import datetime
 import json
 import httpx

-async def generate_sse_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None):
+async def generate_sse_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None, tokens_use=None, total_tokens=None):
     sample_data = {
         "id": "chatcmpl-9ijPeRHa0wtyA2G8wq5z8FC3wGMzc",
         "object": "chat.completion.chunk",

@@ -24,6 +24,8 @@ async def generate_sse_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None):
     if tools_id and function_call_name:
         sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"id":tools_id,"type":"function","function":{"name":function_call_name,"arguments":""}}]}
         # sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"function":{"id": tools_id, "name": function_call_name}}]}
+    if role:
+        sample_data["choices"][0]["delta"] = {"role": role, "content": ""}
     json_data = json.dumps(sample_data, ensure_ascii=False)

     # Build the SSE response

@@ -91,6 +93,10 @@ async def fetch_claude_response_stream(client, url, headers, payload, model):
             message = resp.get("message")
             if message:
                 tokens_use = resp.get("usage")
+                role = message.get("role")
+                if role:
+                    sse_string = await generate_sse_response(timestamp, model, None, None, None, None, role)
+                    yield sse_string
                 if tokens_use:
                     total_tokens = tokens_use["input_tokens"] + tokens_use["output_tokens"]
                     # print("\n\rtotal_tokens", total_tokens)