Spaces:
Paused
Paused
Update main.py
Browse files
main.py
CHANGED
|
@@ -25,6 +25,9 @@ from fastapi.staticfiles import StaticFiles
|
|
| 25 |
|
| 26 |
from bearer_token import BearerTokenGenerator
|
| 27 |
|
|
|
|
|
|
|
|
|
|
| 28 |
# 模型列表
|
| 29 |
MODELS = ["gpt-4o", "gpt-4o-mini", "claude-3-5-sonnet", "claude"]
|
| 30 |
|
|
@@ -121,6 +124,18 @@ def is_base64_image(url: str) -> bool:
|
|
| 121 |
"""
|
| 122 |
return url.startswith("data:image/")
|
| 123 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
# 根路径GET请求处理
|
| 125 |
@app.get("/", response_class=HTMLResponse)
|
| 126 |
async def read_root():
|
|
@@ -140,7 +155,7 @@ async def read_root():
|
|
| 140 |
|
| 141 |
# 聊天完成处理
|
| 142 |
@app.post("/ai/v1/chat/completions")
|
| 143 |
-
async def chat_completions(request: Request, background_tasks: BackgroundTasks):
|
| 144 |
"""
|
| 145 |
处理聊天完成请求
|
| 146 |
"""
|
|
@@ -380,7 +395,7 @@ async def chat_completions(request: Request, background_tasks: BackgroundTasks):
|
|
| 380 |
|
| 381 |
# 图像生成处理
|
| 382 |
@app.post("/ai/v1/images/generations")
|
| 383 |
-
async def images_generations(request: Request):
|
| 384 |
"""
|
| 385 |
处理图像生成请求
|
| 386 |
"""
|
|
@@ -544,10 +559,13 @@ def main():
|
|
| 544 |
parser.add_argument('--base_url', type=str, default='http://localhost', help='Base URL for accessing images')
|
| 545 |
parser.add_argument('--port', type=int, default=INITIAL_PORT, help='服务器监听端口')
|
| 546 |
args = parser.parse_args()
|
| 547 |
-
|
| 548 |
base_url = args.base_url
|
| 549 |
port = args.port
|
| 550 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 551 |
# 确保 images 目录存在
|
| 552 |
if not os.path.exists("images"):
|
| 553 |
os.makedirs("images")
|
|
|
|
| 25 |
|
| 26 |
from bearer_token import BearerTokenGenerator
|
| 27 |
|
| 28 |
+
from fastapi import Depends, HTTPException, Security
|
| 29 |
+
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
| 30 |
+
|
| 31 |
# 模型列表
|
| 32 |
MODELS = ["gpt-4o", "gpt-4o-mini", "claude-3-5-sonnet", "claude"]
|
| 33 |
|
|
|
|
| 124 |
"""
|
| 125 |
return url.startswith("data:image/")
|
| 126 |
|
| 127 |
+
# 添加 HTTPBearer 实例
|
| 128 |
+
security = HTTPBearer()
|
| 129 |
+
|
| 130 |
+
# 添加 API_KEY 验证函数
|
| 131 |
+
def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)):
|
| 132 |
+
api_key = os.environ.get("API_KEY")
|
| 133 |
+
if api_key is None:
|
| 134 |
+
raise HTTPException(status_code=500, detail="API_KEY not set in environment variables")
|
| 135 |
+
if credentials.credentials != api_key:
|
| 136 |
+
raise HTTPException(status_code=401, detail="Invalid API key")
|
| 137 |
+
return credentials.credentials
|
| 138 |
+
|
| 139 |
# 根路径GET请求处理
|
| 140 |
@app.get("/", response_class=HTMLResponse)
|
| 141 |
async def read_root():
|
|
|
|
| 155 |
|
| 156 |
# 聊天完成处理
|
| 157 |
@app.post("/ai/v1/chat/completions")
|
| 158 |
+
async def chat_completions(request: Request, background_tasks: BackgroundTasks, api_key: str = Depends(verify_api_key)):
|
| 159 |
"""
|
| 160 |
处理聊天完成请求
|
| 161 |
"""
|
|
|
|
| 395 |
|
| 396 |
# 图像生成处理
|
| 397 |
@app.post("/ai/v1/images/generations")
|
| 398 |
+
async def images_generations(request: Request, api_key: str = Depends(verify_api_key)):
|
| 399 |
"""
|
| 400 |
处理图像生成请求
|
| 401 |
"""
|
|
|
|
| 559 |
parser.add_argument('--base_url', type=str, default='http://localhost', help='Base URL for accessing images')
|
| 560 |
parser.add_argument('--port', type=int, default=INITIAL_PORT, help='服务器监听端口')
|
| 561 |
args = parser.parse_args()
|
|
|
|
| 562 |
base_url = args.base_url
|
| 563 |
port = args.port
|
| 564 |
|
| 565 |
+
# 检查 API_KEY 是否设置
|
| 566 |
+
if not os.environ.get("API_KEY"):
|
| 567 |
+
print("警告: API_KEY 环境变量未设置。客户端验证将无法正常工作。")
|
| 568 |
+
|
| 569 |
# 确保 images 目录存在
|
| 570 |
if not os.path.exists("images"):
|
| 571 |
os.makedirs("images")
|