✨ Feature: Treat an API key as valid as long as its prefix matches an API key listed in the configuration file.
Browse files
main.py
CHANGED
|
@@ -418,14 +418,20 @@ class StatsMiddleware(BaseHTTPMiddleware):
|
|
| 418 |
)
|
| 419 |
else:
|
| 420 |
token = None
|
|
|
|
|
|
|
| 421 |
if token:
|
| 422 |
try:
|
| 423 |
api_list = app.state.api_list
|
| 424 |
api_index = api_list.index(token)
|
| 425 |
-
enable_moderation = safe_get(config, 'api_keys', api_index, "preferences", "ENABLE_MODERATION", default=False)
|
| 426 |
except ValueError:
|
|
|
|
|
|
|
| 427 |
# token不在api_list中,使用默认值(不开启)
|
| 428 |
pass
|
|
|
|
|
|
|
|
|
|
| 429 |
else:
|
| 430 |
# 如果token为None,检查全局设置
|
| 431 |
enable_moderation = config.get('ENABLE_MODERATION', False)
|
|
@@ -473,7 +479,7 @@ class StatsMiddleware(BaseHTTPMiddleware):
|
|
| 473 |
|
| 474 |
|
| 475 |
if enable_moderation and moderated_content:
|
| 476 |
-
moderation_response = await self.moderate_content(moderated_content,
|
| 477 |
is_flagged = moderation_response.get('results', [{}])[0].get('flagged', False)
|
| 478 |
|
| 479 |
if is_flagged:
|
|
@@ -518,11 +524,11 @@ class StatsMiddleware(BaseHTTPMiddleware):
|
|
| 518 |
# print("current_request_info", current_request_info)
|
| 519 |
request_info.reset(current_request_info)
|
| 520 |
|
| 521 |
-
async def moderate_content(self, content,
|
| 522 |
moderation_request = ModerationRequest(input=content)
|
| 523 |
|
| 524 |
# 直接调用 moderations 函数
|
| 525 |
-
response = await moderations(moderation_request,
|
| 526 |
|
| 527 |
# 读取流式响应的内容
|
| 528 |
moderation_result = b""
|
|
@@ -640,7 +646,7 @@ async def ensure_config(request: Request, call_next):
|
|
| 640 |
return await call_next(request)
|
| 641 |
|
| 642 |
# 在 process_request 函数中更新成功和失败计数
|
| 643 |
-
async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], provider: Dict, endpoint=None
|
| 644 |
url = provider['base_url']
|
| 645 |
parsed_url = urlparse(url)
|
| 646 |
# print("parsed_url", parsed_url)
|
|
@@ -745,17 +751,14 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
|
|
| 745 |
# response = JSONResponse(first_element)
|
| 746 |
|
| 747 |
# 更新成功计数和首次响应时间
|
| 748 |
-
await update_channel_stats(current_info["request_id"], provider['provider'], request.model,
|
| 749 |
-
# await app.middleware_stack.app.update_channel_stats(current_info["request_id"], provider['provider'], request.model, token, success=True)
|
| 750 |
current_info["first_response_time"] = first_response_time
|
| 751 |
current_info["success"] = True
|
| 752 |
current_info["provider"] = provider['provider']
|
| 753 |
return response
|
| 754 |
|
| 755 |
except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError, httpx.RemoteProtocolError, httpx.ReadTimeout) as e:
|
| 756 |
-
await update_channel_stats(current_info["request_id"], provider['provider'], request.model,
|
| 757 |
-
# await app.middleware_stack.app.update_channel_stats(current_info["request_id"], provider['provider'], request.model, token, success=False)
|
| 758 |
-
|
| 759 |
raise e
|
| 760 |
|
| 761 |
def weighted_round_robin(weights):
|
|
@@ -950,11 +953,8 @@ class ModelRequestHandler:
|
|
| 950 |
self.last_provider_indices = defaultdict(lambda: -1)
|
| 951 |
self.locks = defaultdict(asyncio.Lock)
|
| 952 |
|
| 953 |
-
async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest],
|
| 954 |
config = app.state.config
|
| 955 |
-
api_list = app.state.api_list
|
| 956 |
-
api_index = api_list.index(token)
|
| 957 |
-
|
| 958 |
request_model = request.model
|
| 959 |
if not safe_get(config, 'api_keys', api_index, 'model'):
|
| 960 |
raise HTTPException(status_code=404, detail=f"No matching model found: {request_model}")
|
|
@@ -988,7 +988,7 @@ class ModelRequestHandler:
|
|
| 988 |
index += 1
|
| 989 |
provider = matching_providers[current_index]
|
| 990 |
try:
|
| 991 |
-
response = await process_request(request, provider, endpoint
|
| 992 |
return response
|
| 993 |
except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError, httpx.RemoteProtocolError, httpx.ReadTimeout) as e:
|
| 994 |
|
|
@@ -1058,9 +1058,12 @@ async def rate_limit_dependency(request: Request, credentials: HTTPAuthorization
|
|
| 1058 |
try:
|
| 1059 |
api_index = api_list.index(token)
|
| 1060 |
except ValueError:
|
| 1061 |
-
|
| 1062 |
-
api_index = None
|
| 1063 |
-
|
|
|
|
|
|
|
|
|
|
| 1064 |
|
| 1065 |
# 使用 IP 地址和 token(如果有)作为限制键
|
| 1066 |
client_ip = request.client.host
|
|
@@ -1073,32 +1076,44 @@ async def rate_limit_dependency(request: Request, credentials: HTTPAuthorization
|
|
| 1073 |
def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
| 1074 |
api_list = app.state.api_list
|
| 1075 |
token = credentials.credentials
|
| 1076 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1077 |
raise HTTPException(status_code=403, detail="Invalid or missing API Key")
|
| 1078 |
-
return
|
| 1079 |
|
| 1080 |
def verify_admin_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
| 1081 |
api_list = app.state.api_list
|
| 1082 |
token = credentials.credentials
|
| 1083 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1084 |
raise HTTPException(status_code=403, detail="Invalid or missing API Key")
|
| 1085 |
-
for api_key in app.state.api_keys_db:
|
| 1086 |
-
|
| 1087 |
-
|
| 1088 |
-
|
| 1089 |
return token
|
| 1090 |
|
| 1091 |
@app.post("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
|
| 1092 |
-
async def request_model(request: RequestModel,
|
| 1093 |
-
return await model_handler.request_model(request,
|
| 1094 |
|
| 1095 |
@app.options("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
|
| 1096 |
async def options_handler():
|
| 1097 |
return JSONResponse(status_code=200, content={"detail": "OPTIONS allowed"})
|
| 1098 |
|
| 1099 |
@app.get("/v1/models", dependencies=[Depends(rate_limit_dependency)])
|
| 1100 |
-
async def list_models(
|
| 1101 |
-
models = post_all_models(
|
| 1102 |
return JSONResponse(content={
|
| 1103 |
"object": "list",
|
| 1104 |
"data": models
|
|
@@ -1107,23 +1122,23 @@ async def list_models(token: str = Depends(verify_api_key)):
|
|
| 1107 |
@app.post("/v1/images/generations", dependencies=[Depends(rate_limit_dependency)])
|
| 1108 |
async def images_generations(
|
| 1109 |
request: ImageGenerationRequest,
|
| 1110 |
-
|
| 1111 |
):
|
| 1112 |
-
return await model_handler.request_model(request,
|
| 1113 |
|
| 1114 |
@app.post("/v1/embeddings", dependencies=[Depends(rate_limit_dependency)])
|
| 1115 |
async def embeddings(
|
| 1116 |
request: EmbeddingRequest,
|
| 1117 |
-
|
| 1118 |
):
|
| 1119 |
-
return await model_handler.request_model(request,
|
| 1120 |
|
| 1121 |
@app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
|
| 1122 |
async def moderations(
|
| 1123 |
request: ModerationRequest,
|
| 1124 |
-
|
| 1125 |
):
|
| 1126 |
-
return await model_handler.request_model(request,
|
| 1127 |
|
| 1128 |
from fastapi import UploadFile, File, Form, HTTPException
|
| 1129 |
import io
|
|
@@ -1131,7 +1146,7 @@ import io
|
|
| 1131 |
async def audio_transcriptions(
|
| 1132 |
file: UploadFile = File(...),
|
| 1133 |
model: str = Form(...),
|
| 1134 |
-
|
| 1135 |
):
|
| 1136 |
try:
|
| 1137 |
# 读取上传的文件内容
|
|
@@ -1144,7 +1159,7 @@ async def audio_transcriptions(
|
|
| 1144 |
model=model
|
| 1145 |
)
|
| 1146 |
|
| 1147 |
-
return await model_handler.request_model(request,
|
| 1148 |
except UnicodeDecodeError:
|
| 1149 |
raise HTTPException(status_code=400, detail="Invalid audio file encoding")
|
| 1150 |
except Exception as e:
|
|
|
|
| 418 |
)
|
| 419 |
else:
|
| 420 |
token = None
|
| 421 |
+
|
| 422 |
+
api_index = None
|
| 423 |
if token:
|
| 424 |
try:
|
| 425 |
api_list = app.state.api_list
|
| 426 |
api_index = api_list.index(token)
|
|
|
|
| 427 |
except ValueError:
|
| 428 |
+
# 如果 token 不在 api_list 中,检查是否以 api_list 中的任何一个开头
|
| 429 |
+
api_index = next((i for i, api in enumerate(api_list) if token.startswith(api)), None)
|
| 430 |
# token不在api_list中,使用默认值(不开启)
|
| 431 |
pass
|
| 432 |
+
|
| 433 |
+
if api_index is not None:
|
| 434 |
+
enable_moderation = safe_get(config, 'api_keys', api_index, "preferences", "ENABLE_MODERATION", default=False)
|
| 435 |
else:
|
| 436 |
# 如果token为None,检查全局设置
|
| 437 |
enable_moderation = config.get('ENABLE_MODERATION', False)
|
|
|
|
| 479 |
|
| 480 |
|
| 481 |
if enable_moderation and moderated_content:
|
| 482 |
+
moderation_response = await self.moderate_content(moderated_content, api_index)
|
| 483 |
is_flagged = moderation_response.get('results', [{}])[0].get('flagged', False)
|
| 484 |
|
| 485 |
if is_flagged:
|
|
|
|
| 524 |
# print("current_request_info", current_request_info)
|
| 525 |
request_info.reset(current_request_info)
|
| 526 |
|
| 527 |
+
async def moderate_content(self, content, api_index):
|
| 528 |
moderation_request = ModerationRequest(input=content)
|
| 529 |
|
| 530 |
# 直接调用 moderations 函数
|
| 531 |
+
response = await moderations(moderation_request, api_index)
|
| 532 |
|
| 533 |
# 读取流式响应的内容
|
| 534 |
moderation_result = b""
|
|
|
|
| 646 |
return await call_next(request)
|
| 647 |
|
| 648 |
# 在 process_request 函数中更新成功和失败计数
|
| 649 |
+
async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], provider: Dict, endpoint=None):
|
| 650 |
url = provider['base_url']
|
| 651 |
parsed_url = urlparse(url)
|
| 652 |
# print("parsed_url", parsed_url)
|
|
|
|
| 751 |
# response = JSONResponse(first_element)
|
| 752 |
|
| 753 |
# 更新成功计数和首次响应时间
|
| 754 |
+
await update_channel_stats(current_info["request_id"], provider['provider'], request.model, current_info["api_key"], success=True)
|
|
|
|
| 755 |
current_info["first_response_time"] = first_response_time
|
| 756 |
current_info["success"] = True
|
| 757 |
current_info["provider"] = provider['provider']
|
| 758 |
return response
|
| 759 |
|
| 760 |
except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError, httpx.RemoteProtocolError, httpx.ReadTimeout) as e:
|
| 761 |
+
await update_channel_stats(current_info["request_id"], provider['provider'], request.model, current_info["api_key"], success=False)
|
|
|
|
|
|
|
| 762 |
raise e
|
| 763 |
|
| 764 |
def weighted_round_robin(weights):
|
|
|
|
| 953 |
self.last_provider_indices = defaultdict(lambda: -1)
|
| 954 |
self.locks = defaultdict(asyncio.Lock)
|
| 955 |
|
| 956 |
+
async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], api_index: int = None, endpoint=None):
|
| 957 |
config = app.state.config
|
|
|
|
|
|
|
|
|
|
| 958 |
request_model = request.model
|
| 959 |
if not safe_get(config, 'api_keys', api_index, 'model'):
|
| 960 |
raise HTTPException(status_code=404, detail=f"No matching model found: {request_model}")
|
|
|
|
| 988 |
index += 1
|
| 989 |
provider = matching_providers[current_index]
|
| 990 |
try:
|
| 991 |
+
response = await process_request(request, provider, endpoint)
|
| 992 |
return response
|
| 993 |
except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError, httpx.RemoteProtocolError, httpx.ReadTimeout) as e:
|
| 994 |
|
|
|
|
| 1058 |
try:
|
| 1059 |
api_index = api_list.index(token)
|
| 1060 |
except ValueError:
|
| 1061 |
+
# 如果 token 不在 api_list 中,检查是否以 api_list 中的任何一个开头
|
| 1062 |
+
api_index = next((i for i, api in enumerate(api_list) if token.startswith(api)), None)
|
| 1063 |
+
if api_index is None:
|
| 1064 |
+
print("error: Invalid or missing API Key:", token)
|
| 1065 |
+
api_index = None
|
| 1066 |
+
token = None
|
| 1067 |
|
| 1068 |
# 使用 IP 地址和 token(如果有)作为限制键
|
| 1069 |
client_ip = request.client.host
|
|
|
|
| 1076 |
def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
| 1077 |
api_list = app.state.api_list
|
| 1078 |
token = credentials.credentials
|
| 1079 |
+
api_index = None
|
| 1080 |
+
try:
|
| 1081 |
+
api_index = api_list.index(token)
|
| 1082 |
+
except ValueError:
|
| 1083 |
+
# 如果 token 不在 api_list 中,检查是否以 api_list 中的任何一个开头
|
| 1084 |
+
api_index = next((i for i, api in enumerate(api_list) if token.startswith(api)), None)
|
| 1085 |
+
if api_index is None:
|
| 1086 |
raise HTTPException(status_code=403, detail="Invalid or missing API Key")
|
| 1087 |
+
return api_index
|
| 1088 |
|
| 1089 |
def verify_admin_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
| 1090 |
api_list = app.state.api_list
|
| 1091 |
token = credentials.credentials
|
| 1092 |
+
api_index = None
|
| 1093 |
+
try:
|
| 1094 |
+
api_index = api_list.index(token)
|
| 1095 |
+
except ValueError:
|
| 1096 |
+
# 如果 token 不在 api_list 中,检查是否以 api_list 中的任何一个开头
|
| 1097 |
+
api_index = next((i for i, api in enumerate(api_list) if token.startswith(api)), None)
|
| 1098 |
+
if api_index is None:
|
| 1099 |
raise HTTPException(status_code=403, detail="Invalid or missing API Key")
|
| 1100 |
+
# for api_key in app.state.api_keys_db:
|
| 1101 |
+
# if token.startswith(api_key['api']):
|
| 1102 |
+
if app.state.api_keys_db[api_index].get('role') != "admin":
|
| 1103 |
+
raise HTTPException(status_code=403, detail="Permission denied")
|
| 1104 |
return token
|
| 1105 |
|
| 1106 |
@app.post("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
|
| 1107 |
+
async def request_model(request: RequestModel, api_index: int = Depends(verify_api_key)):
|
| 1108 |
+
return await model_handler.request_model(request, api_index)
|
| 1109 |
|
| 1110 |
@app.options("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
|
| 1111 |
async def options_handler():
|
| 1112 |
return JSONResponse(status_code=200, content={"detail": "OPTIONS allowed"})
|
| 1113 |
|
| 1114 |
@app.get("/v1/models", dependencies=[Depends(rate_limit_dependency)])
|
| 1115 |
+
async def list_models(api_index: int = Depends(verify_api_key)):
|
| 1116 |
+
models = post_all_models(api_index, app.state.config)
|
| 1117 |
return JSONResponse(content={
|
| 1118 |
"object": "list",
|
| 1119 |
"data": models
|
|
|
|
| 1122 |
@app.post("/v1/images/generations", dependencies=[Depends(rate_limit_dependency)])
|
| 1123 |
async def images_generations(
|
| 1124 |
request: ImageGenerationRequest,
|
| 1125 |
+
api_index: int = Depends(verify_api_key)
|
| 1126 |
):
|
| 1127 |
+
return await model_handler.request_model(request, api_index, endpoint="/v1/images/generations")
|
| 1128 |
|
| 1129 |
@app.post("/v1/embeddings", dependencies=[Depends(rate_limit_dependency)])
|
| 1130 |
async def embeddings(
|
| 1131 |
request: EmbeddingRequest,
|
| 1132 |
+
api_index: int = Depends(verify_api_key)
|
| 1133 |
):
|
| 1134 |
+
return await model_handler.request_model(request, api_index, endpoint="/v1/embeddings")
|
| 1135 |
|
| 1136 |
@app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
|
| 1137 |
async def moderations(
|
| 1138 |
request: ModerationRequest,
|
| 1139 |
+
api_index: int = Depends(verify_api_key)
|
| 1140 |
):
|
| 1141 |
+
return await model_handler.request_model(request, api_index, endpoint="/v1/moderations")
|
| 1142 |
|
| 1143 |
from fastapi import UploadFile, File, Form, HTTPException
|
| 1144 |
import io
|
|
|
|
| 1146 |
async def audio_transcriptions(
|
| 1147 |
file: UploadFile = File(...),
|
| 1148 |
model: str = Form(...),
|
| 1149 |
+
api_index: int = Depends(verify_api_key)
|
| 1150 |
):
|
| 1151 |
try:
|
| 1152 |
# 读取上传的文件内容
|
|
|
|
| 1159 |
model=model
|
| 1160 |
)
|
| 1161 |
|
| 1162 |
+
return await model_handler.request_model(request, api_index, endpoint="/v1/audio/transcriptions")
|
| 1163 |
except UnicodeDecodeError:
|
| 1164 |
raise HTTPException(status_code=400, detail="Invalid audio file encoding")
|
| 1165 |
except Exception as e:
|
utils.py
CHANGED
|
@@ -63,7 +63,7 @@ class InMemoryRateLimiter:
|
|
| 63 |
|
| 64 |
rate_limiter = InMemoryRateLimiter()
|
| 65 |
|
| 66 |
-
async def get_user_rate_limit(app, api_index:
|
| 67 |
# 这里应该实现根据 token 获取用户速率限制的逻辑
|
| 68 |
# 示例: 返回 (次数, 秒数)
|
| 69 |
config = app.state.config
|
|
@@ -457,13 +457,10 @@ async def error_handling_wrapper(generator):
|
|
| 457 |
except StopAsyncIteration:
|
| 458 |
raise HTTPException(status_code=400, detail="data: {'error': 'No data returned'}")
|
| 459 |
|
| 460 |
-
def post_all_models(
|
| 461 |
all_models = []
|
| 462 |
unique_models = set()
|
| 463 |
|
| 464 |
-
if token not in api_list:
|
| 465 |
-
raise HTTPException(status_code=403, detail="Invalid or missing API Key")
|
| 466 |
-
api_index = api_list.index(token)
|
| 467 |
if config['api_keys'][api_index]['model']:
|
| 468 |
for model in config['api_keys'][api_index]['model']:
|
| 469 |
if model == "all":
|
|
|
|
| 63 |
|
| 64 |
rate_limiter = InMemoryRateLimiter()
|
| 65 |
|
| 66 |
+
async def get_user_rate_limit(app, api_index: int = None):
|
| 67 |
# 这里应该实现根据 token 获取用户速率限制的逻辑
|
| 68 |
# 示例: 返回 (次数, 秒数)
|
| 69 |
config = app.state.config
|
|
|
|
| 457 |
except StopAsyncIteration:
|
| 458 |
raise HTTPException(status_code=400, detail="data: {'error': 'No data returned'}")
|
| 459 |
|
| 460 |
+
def post_all_models(api_index, config):
|
| 461 |
all_models = []
|
| 462 |
unique_models = set()
|
| 463 |
|
|
|
|
|
|
|
|
|
|
| 464 |
if config['api_keys'][api_index]['model']:
|
| 465 |
for model in config['api_keys'][api_index]['model']:
|
| 466 |
if model == "all":
|