Add feature: support OpenAI dall-e-3 image generation
Browse files- README.md +14 -9
- main.py +21 -10
- models.py +7 -0
- request.py +22 -1
- response.py +33 -39
- utils.py +27 -1
README.md
CHANGED
|
@@ -12,20 +12,23 @@
|
|
| 12 |
|
| 13 |
## Introduction
|
| 14 |
|
| 15 |
-
|
| 16 |
|
| 17 |
## Features
|
| 18 |
|
| 19 |
-
-
|
| 20 |
-
-
|
| 21 |
-
-
|
| 22 |
-
-
|
| 23 |
-
-
|
| 24 |
-
-
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
## Configuration
|
| 27 |
|
| 28 |
-
使用api.yaml配置文件,可以配置多个模型,每个模型可以配置多个后端服务,支持负载均衡。下面是 api.yaml 配置文件的示例:
|
| 29 |
|
| 30 |
```yaml
|
| 31 |
providers:
|
|
@@ -35,6 +38,7 @@ providers:
|
|
| 35 |
model: # 至少填一个模型
|
| 36 |
- gpt-4o # 可以使用的模型名称,必填
|
| 37 |
- claude-3-5-sonnet-20240620: claude-3-5-sonnet # 重命名模型,claude-3-5-sonnet-20240620 是服务商的模型名称,claude-3-5-sonnet 是重命名后的名字,可以使用简洁的名字代替原来复杂的名称,选填
|
|
|
|
| 38 |
|
| 39 |
- provider: anthropic
|
| 40 |
base_url: https://api.anthropic.com/v1/messages
|
|
@@ -86,7 +90,7 @@ api_keys:
|
|
| 86 |
model:
|
| 87 |
- anthropic/claude-3-5-sonnet # 可以使用的模型名称,仅可以使用名为 anthropic 提供商提供的 claude-3-5-sonnet 模型。其他提供商的 claude-3-5-sonnet 模型不可以使用。
|
| 88 |
preferences:
|
| 89 |
-
USE_ROUND_ROBIN: true # 是否使用轮询负载均衡,true 为使用,false 为不使用,默认为 true
|
| 90 |
AUTO_RETRY: true # 是否自动重试,自动重试下一个提供商,true 为自动重试,false 为不自动重试,默认为 true
|
| 91 |
```
|
| 92 |
|
|
@@ -152,6 +156,7 @@ curl -X POST http://127.0.0.1:8000/v1/chat/completions \
|
|
| 152 |
-d '{"model": "gpt-4o","messages": [{"role": "user", "content": "Hello"}],"stream": true}'
|
| 153 |
```
|
| 154 |
|
|
|
|
| 155 |
## Star History
|
| 156 |
|
| 157 |
<a href="https://github.com/yym68686/uni-api/stargazers">
|
|
|
|
| 12 |
|
| 13 |
## Introduction
|
| 14 |
|
| 15 |
+
如果个人使用的话,one/new-api 过于复杂,有很多个人不需要使用的商用功能,如果你不想要复杂的前端界面,有想要支持的模型多一点,可以试试 uni-api。这是一个统一管理大模型API的项目,可以通过一个统一的API接口调用多个后端服务,统一转换为 OpenAI 格式,支持负载均衡。目前支持的后端服务有:OpenAI、Anthropic、Gemini、Vertex、DeepBricks、OpenRouter 等。
|
| 16 |
|
| 17 |
## Features
|
| 18 |
|
| 19 |
+
- 无前端,纯配置文件配置 API 渠道。只要写一个文件就能运行起一个属于自己的 API 站,文档有详细的配置指南,小白友好。
|
| 20 |
+
- 统一管理多个后端服务,支持 OpenAI、Deepseek、DeepBricks、OpenRouter 等其他API 是 OpenAI 格式的提供商。支持 OpenAI Dalle-3 图像生成。
|
| 21 |
+
- 同时支持 Anthropic、Gemini、Vertex API。Vertex 同时支持 Claude 和 Gemini API。
|
| 22 |
+
- 支持 OpenAI、 Anthropic、Gemini、Vertex 原生 tool use 函数调用。
|
| 23 |
+
- 支持 OpenAI、Anthropic、Gemini、Vertex 原生识图 API。
|
| 24 |
+
- 支持负载均衡,支持 Vertex 区域负载均衡,支持 Vertex 高并发,最高可将 Gemini,Claude 并发提高 (API数量 * 区域数量) 倍。除了 Vertex 区域负载均衡,所有 API 均支持渠道级负载均衡,提高沉浸式翻译体验。
|
| 25 |
+
- 支持自动重试,当一个 API 渠道响应失败时,自动重试下一个 API 渠道。
|
| 26 |
+
- 支持细粒度的权限控制。支持使用通配符设置 API key 可用渠道的特定模型。
|
| 27 |
+
- 支持多个 API Key。
|
| 28 |
|
| 29 |
## Configuration
|
| 30 |
|
| 31 |
+
使用 api.yaml 配置文件,可以配置多个模型,每个模型可以配置多个后端服务,支持负载均衡。下面是 api.yaml 配置文件的示例:
|
| 32 |
|
| 33 |
```yaml
|
| 34 |
providers:
|
|
|
|
| 38 |
model: # 至少填一个模型
|
| 39 |
- gpt-4o # 可以使用的模型名称,必填
|
| 40 |
- claude-3-5-sonnet-20240620: claude-3-5-sonnet # 重命名模型,claude-3-5-sonnet-20240620 是服务商的模型名称,claude-3-5-sonnet 是重命名后的名字,可以使用简洁的名字代替原来复杂的名称,选填
|
| 41 |
+
- dall-e-3
|
| 42 |
|
| 43 |
- provider: anthropic
|
| 44 |
base_url: https://api.anthropic.com/v1/messages
|
|
|
|
| 90 |
model:
|
| 91 |
- anthropic/claude-3-5-sonnet # 可以使用的模型名称,仅可以使用名为 anthropic 提供商提供的 claude-3-5-sonnet 模型。其他提供商的 claude-3-5-sonnet 模型不可以使用。
|
| 92 |
preferences:
|
| 93 |
+
USE_ROUND_ROBIN: true # 是否使用轮询负载均衡,true 为使用,false 为不使用,默认为 true。开启轮询后每次请求模型按照 model 配置的顺序依次请求。与 providers 里面原始的渠道顺序无关。因此你可以设置每个 API key 请求顺序不一样。
|
| 94 |
AUTO_RETRY: true # 是否自动重试,自动重试下一个提供商,true 为自动重试,false 为不自动重试,默认为 true
|
| 95 |
```
|
| 96 |
|
|
|
|
| 156 |
-d '{"model": "gpt-4o","messages": [{"role": "user", "content": "Hello"}],"stream": true}'
|
| 157 |
```
|
| 158 |
|
| 159 |
+
|
| 160 |
## Star History
|
| 161 |
|
| 162 |
<a href="https://github.com/yym68686/uni-api/stargazers">
|
main.py
CHANGED
|
@@ -5,16 +5,16 @@ import secrets
|
|
| 5 |
from contextlib import asynccontextmanager
|
| 6 |
|
| 7 |
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
-
from fastapi import FastAPI, HTTPException, Depends
|
| 9 |
from fastapi.responses import StreamingResponse, JSONResponse
|
| 10 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 11 |
|
| 12 |
-
from models import RequestModel
|
| 13 |
from utils import error_handling_wrapper, get_all_models, post_all_models, load_config
|
| 14 |
from request import get_payload
|
| 15 |
from response import fetch_response, fetch_response_stream
|
| 16 |
|
| 17 |
-
from typing import List, Dict
|
| 18 |
from urllib.parse import urlparse
|
| 19 |
|
| 20 |
@asynccontextmanager
|
|
@@ -80,7 +80,7 @@ app.add_middleware(
|
|
| 80 |
allow_headers=["*"], # 允许所有头部字段
|
| 81 |
)
|
| 82 |
|
| 83 |
-
async def process_request(request: RequestModel, provider: Dict):
|
| 84 |
url = provider['base_url']
|
| 85 |
parsed_url = urlparse(url)
|
| 86 |
# print(parsed_url)
|
|
@@ -101,6 +101,10 @@ async def process_request(request: RequestModel, provider: Dict):
|
|
| 101 |
and "gemini" not in provider['model'][request.model]:
|
| 102 |
engine = "openrouter"
|
| 103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
if provider.get("engine"):
|
| 105 |
engine = provider["engine"]
|
| 106 |
|
|
@@ -122,7 +126,7 @@ async def process_request(request: RequestModel, provider: Dict):
|
|
| 122 |
wrapped_generator = await error_handling_wrapper(generator, status_code=500)
|
| 123 |
return StreamingResponse(wrapped_generator, media_type="text/event-stream")
|
| 124 |
else:
|
| 125 |
-
return await fetch_response(app.state.client, url, headers, payload)
|
| 126 |
|
| 127 |
import asyncio
|
| 128 |
class ModelRequestHandler:
|
|
@@ -171,7 +175,7 @@ class ModelRequestHandler:
|
|
| 171 |
# print(json.dumps(provider, indent=4, ensure_ascii=False))
|
| 172 |
return provider_list
|
| 173 |
|
| 174 |
-
async def request_model(self, request: RequestModel, token: str):
|
| 175 |
config = app.state.config
|
| 176 |
# api_keys_db = app.state.api_keys_db
|
| 177 |
api_list = app.state.api_list
|
|
@@ -193,9 +197,9 @@ class ModelRequestHandler:
|
|
| 193 |
if config['api_keys'][api_index]["preferences"].get("AUTO_RETRY") == False:
|
| 194 |
auto_retry = False
|
| 195 |
|
| 196 |
-
return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry)
|
| 197 |
|
| 198 |
-
async def try_all_providers(self, request: RequestModel, providers: List[Dict], use_round_robin: bool, auto_retry: bool):
|
| 199 |
num_providers = len(providers)
|
| 200 |
start_index = self.last_provider_index + 1 if use_round_robin else 0
|
| 201 |
|
|
@@ -203,7 +207,7 @@ class ModelRequestHandler:
|
|
| 203 |
self.last_provider_index = (start_index + i) % num_providers
|
| 204 |
provider = providers[self.last_provider_index]
|
| 205 |
try:
|
| 206 |
-
response = await process_request(request, provider)
|
| 207 |
return response
|
| 208 |
except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError) as e:
|
| 209 |
logger.error(f"Error with provider {provider['provider']}: {str(e)}")
|
|
@@ -228,7 +232,7 @@ def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)
|
|
| 228 |
return token
|
| 229 |
|
| 230 |
@app.post("/v1/chat/completions")
|
| 231 |
-
async def request_model(request: RequestModel, token: str = Depends(verify_api_key)):
|
| 232 |
return await model_handler.request_model(request, token)
|
| 233 |
|
| 234 |
@app.options("/v1/chat/completions")
|
|
@@ -251,6 +255,13 @@ async def list_models():
|
|
| 251 |
"data": models
|
| 252 |
})
|
| 253 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
@app.get("/generate-api-key")
|
| 255 |
def generate_api_key():
|
| 256 |
api_key = "sk-" + secrets.token_urlsafe(32)
|
|
|
|
| 5 |
from contextlib import asynccontextmanager
|
| 6 |
|
| 7 |
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
+
from fastapi import FastAPI, HTTPException, Depends
|
| 9 |
from fastapi.responses import StreamingResponse, JSONResponse
|
| 10 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 11 |
|
| 12 |
+
from models import RequestModel, ImageGenerationRequest
|
| 13 |
from utils import error_handling_wrapper, get_all_models, post_all_models, load_config
|
| 14 |
from request import get_payload
|
| 15 |
from response import fetch_response, fetch_response_stream
|
| 16 |
|
| 17 |
+
from typing import List, Dict, Union
|
| 18 |
from urllib.parse import urlparse
|
| 19 |
|
| 20 |
@asynccontextmanager
|
|
|
|
| 80 |
allow_headers=["*"], # 允许所有头部字段
|
| 81 |
)
|
| 82 |
|
| 83 |
+
async def process_request(request: Union[RequestModel, ImageGenerationRequest], provider: Dict, endpoint=None):
|
| 84 |
url = provider['base_url']
|
| 85 |
parsed_url = urlparse(url)
|
| 86 |
# print(parsed_url)
|
|
|
|
| 101 |
and "gemini" not in provider['model'][request.model]:
|
| 102 |
engine = "openrouter"
|
| 103 |
|
| 104 |
+
if endpoint == "/v1/images/generations":
|
| 105 |
+
engine = "dalle"
|
| 106 |
+
request.stream = False
|
| 107 |
+
|
| 108 |
if provider.get("engine"):
|
| 109 |
engine = provider["engine"]
|
| 110 |
|
|
|
|
| 126 |
wrapped_generator = await error_handling_wrapper(generator, status_code=500)
|
| 127 |
return StreamingResponse(wrapped_generator, media_type="text/event-stream")
|
| 128 |
else:
|
| 129 |
+
return await anext(fetch_response(app.state.client, url, headers, payload))
|
| 130 |
|
| 131 |
import asyncio
|
| 132 |
class ModelRequestHandler:
|
|
|
|
| 175 |
# print(json.dumps(provider, indent=4, ensure_ascii=False))
|
| 176 |
return provider_list
|
| 177 |
|
| 178 |
+
async def request_model(self, request: Union[RequestModel, ImageGenerationRequest], token: str, endpoint=None):
|
| 179 |
config = app.state.config
|
| 180 |
# api_keys_db = app.state.api_keys_db
|
| 181 |
api_list = app.state.api_list
|
|
|
|
| 197 |
if config['api_keys'][api_index]["preferences"].get("AUTO_RETRY") == False:
|
| 198 |
auto_retry = False
|
| 199 |
|
| 200 |
+
return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint)
|
| 201 |
|
| 202 |
+
async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None):
|
| 203 |
num_providers = len(providers)
|
| 204 |
start_index = self.last_provider_index + 1 if use_round_robin else 0
|
| 205 |
|
|
|
|
| 207 |
self.last_provider_index = (start_index + i) % num_providers
|
| 208 |
provider = providers[self.last_provider_index]
|
| 209 |
try:
|
| 210 |
+
response = await process_request(request, provider, endpoint)
|
| 211 |
return response
|
| 212 |
except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError) as e:
|
| 213 |
logger.error(f"Error with provider {provider['provider']}: {str(e)}")
|
|
|
|
| 232 |
return token
|
| 233 |
|
| 234 |
@app.post("/v1/chat/completions")
|
| 235 |
+
async def request_model(request: Union[RequestModel, ImageGenerationRequest], token: str = Depends(verify_api_key)):
|
| 236 |
return await model_handler.request_model(request, token)
|
| 237 |
|
| 238 |
@app.options("/v1/chat/completions")
|
|
|
|
| 255 |
"data": models
|
| 256 |
})
|
| 257 |
|
| 258 |
+
@app.post("/v1/images/generations")
async def images_generations(
    request: ImageGenerationRequest,
    token: str = Depends(verify_api_key)
):
    # OpenAI-compatible image-generation endpoint (e.g. dall-e-3).
    # Delegates to the shared handler; passing the endpoint string lets
    # process_request switch the engine to "dalle" and force stream=False.
    return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
|
| 264 |
+
|
| 265 |
@app.get("/generate-api-key")
|
| 266 |
def generate_api_key():
|
| 267 |
api_key = "sk-" + secrets.token_urlsafe(32)
|
models.py
CHANGED
|
@@ -1,6 +1,13 @@
|
|
| 1 |
from pydantic import BaseModel, Field
|
| 2 |
from typing import List, Dict, Optional, Union
|
| 3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
class FunctionParameter(BaseModel):
|
| 5 |
type: str
|
| 6 |
properties: Dict[str, Dict[str, str]]
|
|
|
|
| 1 |
from pydantic import BaseModel, Field
|
| 2 |
from typing import List, Dict, Optional, Union
|
| 3 |
|
| 4 |
+
class ImageGenerationRequest(BaseModel):
    """Request body for POST /v1/images/generations (OpenAI-compatible).

    Fix: ``n`` and ``size`` were required fields, but the OpenAI images API
    treats them as optional (defaulting to 1 and "1024x1024"), so clients
    omitting them were rejected with a validation error. Defaults are added;
    callers that already supply both fields are unaffected.
    """
    model: str               # provider model name or alias, e.g. "dall-e-3"
    prompt: str              # text description of the desired image(s)
    n: int = 1               # number of images to generate (dall-e-3 supports only 1)
    size: str = "1024x1024"  # output resolution, e.g. "1024x1024", "1792x1024"
    stream: bool = False     # images API never streams; kept for uniform request handling
|
| 10 |
+
|
| 11 |
class FunctionParameter(BaseModel):
|
| 12 |
type: str
|
| 13 |
properties: Dict[str, Dict[str, str]]
|
request.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
import json
|
| 2 |
from models import RequestModel
|
| 3 |
-
from utils import c35s, c3s, c3o, c3h, gem,
|
| 4 |
|
| 5 |
async def get_image_message(base64_image, engine = None):
|
| 6 |
if "gpt" == engine:
|
|
@@ -748,6 +748,25 @@ async def get_claude_payload(request, engine, provider):
|
|
| 748 |
|
| 749 |
return url, headers, payload
|
| 750 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 751 |
async def get_payload(request: RequestModel, engine, provider):
|
| 752 |
if engine == "gemini":
|
| 753 |
return await get_gemini_payload(request, engine, provider)
|
|
@@ -761,5 +780,7 @@ async def get_payload(request: RequestModel, engine, provider):
|
|
| 761 |
return await get_gpt_payload(request, engine, provider)
|
| 762 |
elif engine == "openrouter":
|
| 763 |
return await get_openrouter_payload(request, engine, provider)
|
|
|
|
|
|
|
| 764 |
else:
|
| 765 |
raise ValueError("Unknown payload")
|
|
|
|
| 1 |
import json
|
| 2 |
from models import RequestModel
|
| 3 |
+
from utils import c35s, c3s, c3o, c3h, gem, BaseAPI
|
| 4 |
|
| 5 |
async def get_image_message(base64_image, engine = None):
|
| 6 |
if "gpt" == engine:
|
|
|
|
| 748 |
|
| 749 |
return url, headers, payload
|
| 750 |
|
| 751 |
+
async def get_dalle_payload(request, engine, provider):
    """Build the URL, headers and JSON body for an OpenAI image-generation call.

    Resolves the provider-side model name from the requested alias, attaches
    the bearer token when configured, and rewrites the provider base_url to
    the /v1/images/generations endpoint. Returns (url, headers, payload).
    """
    model_name = provider['model'][request.model]

    headers = {"Content-Type": "application/json"}
    api_key = provider.get("api")
    if api_key:
        headers['Authorization'] = f"Bearer {api_key}"

    # Derive the images endpoint from whatever URL the channel was configured with.
    url = BaseAPI(provider['base_url']).image_url

    payload = {
        "model": model_name,
        "prompt": request.prompt,
        "n": request.n,
        "size": request.size,
    }

    return url, headers, payload
|
| 769 |
+
|
| 770 |
async def get_payload(request: RequestModel, engine, provider):
|
| 771 |
if engine == "gemini":
|
| 772 |
return await get_gemini_payload(request, engine, provider)
|
|
|
|
| 780 |
return await get_gpt_payload(request, engine, provider)
|
| 781 |
elif engine == "openrouter":
|
| 782 |
return await get_openrouter_payload(request, engine, provider)
|
| 783 |
+
elif engine == "dalle":
|
| 784 |
+
return await get_dalle_payload(request, engine, provider)
|
| 785 |
else:
|
| 786 |
raise ValueError("Unknown payload")
|
response.py
CHANGED
|
@@ -36,17 +36,24 @@ async def generate_sse_response(timestamp, model, content=None, tools_id=None, f
|
|
| 36 |
|
| 37 |
return sse_response
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
async def fetch_gemini_response_stream(client, url, headers, payload, model):
|
| 40 |
timestamp = datetime.timestamp(datetime.now())
|
| 41 |
async with client.stream('POST', url, headers=headers, json=payload) as response:
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
error_json = json.loads(error_str)
|
| 47 |
-
except json.JSONDecodeError:
|
| 48 |
-
error_json = error_str
|
| 49 |
-
yield {"error": f"fetch_gpt_response_stream HTTP Error {response.status_code}", "details": error_json}
|
| 50 |
buffer = ""
|
| 51 |
revicing_function_call = False
|
| 52 |
function_full_response = "{"
|
|
@@ -87,14 +94,11 @@ async def fetch_gemini_response_stream(client, url, headers, payload, model):
|
|
| 87 |
async def fetch_vertex_claude_response_stream(client, url, headers, payload, model):
|
| 88 |
timestamp = datetime.timestamp(datetime.now())
|
| 89 |
async with client.stream('POST', url, headers=headers, json=payload) as response:
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
except json.JSONDecodeError:
|
| 96 |
-
error_json = error_str
|
| 97 |
-
yield {"error": f"fetch_gpt_response_stream HTTP Error {response.status_code}", "details": error_json}
|
| 98 |
buffer = ""
|
| 99 |
revicing_function_call = False
|
| 100 |
function_full_response = "{"
|
|
@@ -138,14 +142,9 @@ async def fetch_gpt_response_stream(client, url, headers, payload, max_redirects
|
|
| 138 |
while redirect_count < max_redirects:
|
| 139 |
# logger.info(f"fetch_gpt_response_stream: {url}")
|
| 140 |
async with client.stream('POST', url, headers=headers, json=payload) as response:
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
try:
|
| 145 |
-
error_json = json.loads(error_str)
|
| 146 |
-
except json.JSONDecodeError:
|
| 147 |
-
error_json = error_str
|
| 148 |
-
yield {"error": f"fetch_gpt_response_stream HTTP Error {response.status_code}", "details": error_json}
|
| 149 |
return
|
| 150 |
|
| 151 |
buffer = ""
|
|
@@ -185,14 +184,10 @@ async def fetch_gpt_response_stream(client, url, headers, payload, max_redirects
|
|
| 185 |
async def fetch_claude_response_stream(client, url, headers, payload, model):
|
| 186 |
timestamp = datetime.timestamp(datetime.now())
|
| 187 |
async with client.stream('POST', url, headers=headers, json=payload) as response:
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
error_json = json.loads(error_str)
|
| 193 |
-
except json.JSONDecodeError:
|
| 194 |
-
error_json = error_str
|
| 195 |
-
yield {"error": f"fetch_claude_response_stream HTTP Error {response.status_code}", "details": error_json}
|
| 196 |
buffer = ""
|
| 197 |
async for chunk in response.aiter_text():
|
| 198 |
# logger.info(f"chunk: {repr(chunk)}")
|
|
@@ -241,13 +236,12 @@ async def fetch_claude_response_stream(client, url, headers, payload, model):
|
|
| 241 |
yield sse_string
|
| 242 |
|
| 243 |
async def fetch_response(client, url, headers, payload):
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
return
|
| 249 |
-
|
| 250 |
-
return {"error": f"500", "details": "fetch_response Read Response Timeout"}
|
| 251 |
|
| 252 |
async def fetch_response_stream(client, url, headers, payload, engine, model):
|
| 253 |
try:
|
|
|
|
| 36 |
|
| 37 |
return sse_response
|
| 38 |
|
| 39 |
+
async def check_response(response, error_log):
    """Return an error dict when *response* is not HTTP 200, else ``None``.

    The body is decoded (replacing undecodable bytes) and parsed as JSON when
    possible, so callers can forward structured upstream error details.
    """
    if response.status_code == 200:
        return None
    raw = await response.aread()
    text = raw.decode('utf-8', errors='replace')
    try:
        details = json.loads(text)
    except json.JSONDecodeError:
        # Upstream sent a non-JSON body; pass it through verbatim.
        details = text
    return {"error": f"{error_log} HTTP Error {response.status_code}", "details": details}
|
| 49 |
+
|
| 50 |
async def fetch_gemini_response_stream(client, url, headers, payload, model):
|
| 51 |
timestamp = datetime.timestamp(datetime.now())
|
| 52 |
async with client.stream('POST', url, headers=headers, json=payload) as response:
|
| 53 |
+
error_message = await check_response(response, "fetch_gemini_response_stream")
|
| 54 |
+
if error_message:
|
| 55 |
+
yield error_message
|
| 56 |
+
return
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
buffer = ""
|
| 58 |
revicing_function_call = False
|
| 59 |
function_full_response = "{"
|
|
|
|
| 94 |
async def fetch_vertex_claude_response_stream(client, url, headers, payload, model):
|
| 95 |
timestamp = datetime.timestamp(datetime.now())
|
| 96 |
async with client.stream('POST', url, headers=headers, json=payload) as response:
|
| 97 |
+
error_message = await check_response(response, "fetch_vertex_claude_response_stream")
|
| 98 |
+
if error_message:
|
| 99 |
+
yield error_message
|
| 100 |
+
return
|
| 101 |
+
|
|
|
|
|
|
|
|
|
|
| 102 |
buffer = ""
|
| 103 |
revicing_function_call = False
|
| 104 |
function_full_response = "{"
|
|
|
|
| 142 |
while redirect_count < max_redirects:
|
| 143 |
# logger.info(f"fetch_gpt_response_stream: {url}")
|
| 144 |
async with client.stream('POST', url, headers=headers, json=payload) as response:
|
| 145 |
+
error_message = await check_response(response, "fetch_gpt_response_stream")
|
| 146 |
+
if error_message:
|
| 147 |
+
yield error_message
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
return
|
| 149 |
|
| 150 |
buffer = ""
|
|
|
|
| 184 |
async def fetch_claude_response_stream(client, url, headers, payload, model):
|
| 185 |
timestamp = datetime.timestamp(datetime.now())
|
| 186 |
async with client.stream('POST', url, headers=headers, json=payload) as response:
|
| 187 |
+
error_message = await check_response(response, "fetch_claude_response_stream")
|
| 188 |
+
if error_message:
|
| 189 |
+
yield error_message
|
| 190 |
+
return
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
buffer = ""
|
| 192 |
async for chunk in response.aiter_text():
|
| 193 |
# logger.info(f"chunk: {repr(chunk)}")
|
|
|
|
| 236 |
yield sse_string
|
| 237 |
|
| 238 |
async def fetch_response(client, url, headers, payload):
    """POST *payload* to *url* and yield exactly one item.

    Yields either an error dict (non-200 upstream status) or the parsed JSON
    body of a successful response. Implemented as an async generator so the
    caller can consume it uniformly with the streaming fetchers.
    """
    response = await client.post(url, headers=headers, json=payload)
    failure = await check_response(response, "fetch_response")
    if failure:
        yield failure
    else:
        yield response.json()
|
|
|
|
| 245 |
|
| 246 |
async def fetch_response_stream(client, url, headers, payload, engine, model):
|
| 247 |
try:
|
utils.py
CHANGED
|
@@ -222,4 +222,30 @@ c35s = CircularList(["us-east5", "europe-west1"])
|
|
| 222 |
c3s = CircularList(["us-east5", "us-central1", "asia-southeast1"])
|
| 223 |
c3o = CircularList(["us-east5"])
|
| 224 |
c3h = CircularList(["us-east5", "us-central1", "europe-west1", "europe-west4"])
|
| 225 |
-
gem = CircularList(["us-central1", "us-east4", "us-west1", "us-west4", "europe-west1", "europe-west2"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
c3s = CircularList(["us-east5", "us-central1", "asia-southeast1"])
|
| 223 |
c3o = CircularList(["us-east5"])
|
| 224 |
c3h = CircularList(["us-east5", "us-central1", "europe-west1", "europe-west4"])
|
| 225 |
+
gem = CircularList(["us-central1", "us-east4", "us-west1", "us-west4", "europe-west1", "europe-west2"])
|
| 226 |
+
|
| 227 |
+
class BaseAPI:
    """Derive OpenAI-style endpoint URLs from a configured provider API URL.

    Given any provider URL (e.g. "https://host/prefix/v1/chat/completions"),
    exposes the bare base URL plus the /v1, /v1/models, chat-completions,
    image-generation and audio-transcription endpoints, all rooted at the
    same path prefix that precedes "/v1" in the original URL.
    """

    def __init__(
        self,
        api_url: str = "https://api.openai.com/v1/chat/completions",
    ):
        from urllib.parse import urlparse, urlunparse
        if api_url == "":
            api_url = "https://api.openai.com/v1/chat/completions"
        self.source_api_url: str = api_url
        parts = urlparse(self.source_api_url)
        if parts.scheme == "":
            raise Exception("Error: API_URL is not set")

        # Everything before the first "/v1" segment is the provider's path prefix.
        prefix = parts.path.split("/v1")[0] if parts.path != '/' else ""

        def endpoint(path: str) -> str:
            # Rebuild the URL keeping scheme+netloc, swapping in *path*.
            return urlunparse(parts[:2] + (path,) + ("",) * 3)

        self.base_url: str = endpoint(prefix)
        self.v1_url: str = endpoint(prefix + "/v1")
        self.v1_models: str = endpoint(prefix + "/v1/models")
        # deepseek serves chat at the bare /chat/completions path (no /v1).
        if parts.netloc == "api.deepseek.com":
            self.chat_url: str = endpoint("/chat/completions")
        else:
            self.chat_url: str = endpoint(prefix + "/v1/chat/completions")
        self.image_url: str = endpoint(prefix + "/v1/images/generations")
        self.audio_transcriptions: str = endpoint(prefix + "/v1/audio/transcriptions")
|