Update app.py
Browse files
app.py
CHANGED
|
@@ -2,7 +2,8 @@ import os
|
|
| 2 |
import time
|
| 3 |
import json
|
| 4 |
import logging
|
| 5 |
-
from fastapi import FastAPI, Request, HTTPException
|
|
|
|
| 6 |
from fastapi.responses import JSONResponse
|
| 7 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 8 |
import torch
|
|
@@ -21,6 +22,40 @@ MODEL_NAME = "Qwen/Qwen1.5-0.5B-Chat"
|
|
| 21 |
MAX_TOKENS = 256
|
| 22 |
DEVICE = "cpu" # 强制使用CPU
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
def load_model():
|
| 25 |
"""极简模型加载"""
|
| 26 |
global model, tokenizer
|
|
@@ -81,6 +116,14 @@ def generate_response(messages):
|
|
| 81 |
add_generation_prompt=True
|
| 82 |
)
|
| 83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
# 编码输入
|
| 85 |
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
|
| 86 |
inputs = inputs.to(DEVICE)
|
|
@@ -112,25 +155,46 @@ def generate_response(messages):
|
|
| 112 |
return {"error": f"生成失败: {str(e)}"}
|
| 113 |
|
| 114 |
# 创建极简FastAPI应用
|
| 115 |
-
app = FastAPI(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
|
| 117 |
# Load the model when the application starts.
@app.on_event("startup")
async def startup_event():
    """Startup hook: call load_model() once as the application boots."""
    load_model()
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
-
# Health-check endpoint.
@app.get("/health")
async def health_check():
    """Report whether the model is loaded, plus the current Unix timestamp.

    Returns:
        dict with "status" ("healthy" once the module-level ``model`` is set,
        otherwise "loading"), "model_loaded" (bool) and "timestamp" (int
        seconds since the epoch).
    """
    is_loaded = model is not None
    if is_loaded:
        state = "healthy"
    else:
        state = "loading"
    return dict(
        status=state,
        model_loaded=is_loaded,
        timestamp=int(time.time()),
    )
|
| 130 |
|
| 131 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
@app.post("/v1/chat/completions")
|
| 133 |
-
async def chat_completion(
|
|
|
|
|
|
|
|
|
|
| 134 |
"""极简版OpenAI兼容端点"""
|
| 135 |
try:
|
| 136 |
# 解析请求
|
|
@@ -185,13 +249,14 @@ async def chat_completion(request: Request):
|
|
| 185 |
}
|
| 186 |
)
|
| 187 |
|
| 188 |
-
#
|
| 189 |
-
@app.
|
| 190 |
-
async def
|
|
|
|
| 191 |
return {
|
| 192 |
-
"
|
| 193 |
-
"
|
| 194 |
-
"
|
| 195 |
}
|
| 196 |
|
| 197 |
if __name__ == "__main__":
|
|
|
|
| 2 |
import time
|
| 3 |
import json
|
| 4 |
import logging
|
| 5 |
+
from fastapi import FastAPI, Request, HTTPException, Depends, status
|
| 6 |
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 7 |
from fastapi.responses import JSONResponse
|
| 8 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 9 |
import torch
|
|
|
|
| 22 |
MAX_TOKENS = 256
|
| 23 |
DEVICE = "cpu" # 强制使用CPU
|
| 24 |
|
| 25 |
+
# API key configuration.
# Keys are read from the comma-separated API_KEYS environment variable.
# Surrounding whitespace and empty entries are dropped, so "a, b," yields
# ["a", "b"] and an empty variable can never make the empty token a valid key.
# SECURITY NOTE(review): the fallback keys below are publicly known
# placeholders kept only for backward compatibility -- always set API_KEYS in
# a real deployment.
API_KEYS = [
    key.strip()
    for key in os.getenv("API_KEYS", "your-secret-key-1,your-secret-key-2").split(",")
    if key.strip()
]
# Whether API-key verification is enforced. Any value other than "true"
# (compared case-insensitively) disables it.
API_AUTH_ENABLED = os.getenv("API_AUTH_ENABLED", "true").lower() == "true"

# Bearer authentication scheme. auto_error=False makes the dependency yield
# None instead of rejecting the request by itself, so verify_api_key can
# honour API_AUTH_ENABLED and answer a missing header with a proper 401.
security = HTTPBearer(auto_error=False)


def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Validate the Bearer API key carried by the request.

    Args:
        credentials: Parsed ``Authorization`` header supplied by the
            ``security`` dependency. It is ``None`` when the header is
            missing or is not a Bearer header, because ``security`` is
            created with ``auto_error=False``.

    Returns:
        True when authentication is disabled or the supplied key is one of
        ``API_KEYS``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            authentication is enabled and the header is missing, the scheme
            is not Bearer, or the key is not recognised.
    """
    # Local import keeps this block self-contained (stdlib only).
    import secrets

    # Authentication disabled: let every request through, header or not.
    if not API_AUTH_ENABLED:
        return True

    if credentials is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # HTTP authentication scheme names are case-insensitive, so "bearer"
    # must be accepted just like "Bearer".
    if credentials.scheme.lower() != "bearer":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication scheme. Use 'Bearer' token",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Constant-time comparison against every configured key. Comparing bytes
    # avoids the ASCII-only restriction compare_digest places on str, and the
    # list (rather than a short-circuiting generator) checks all keys.
    supplied = credentials.credentials.encode("utf-8")
    matches = [
        secrets.compare_digest(supplied, key.encode("utf-8"))
        for key in API_KEYS
    ]
    if not any(matches):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
            headers={"WWW-Authenticate": "Bearer"},
        )

    return True
|
| 58 |
+
|
| 59 |
def load_model():
|
| 60 |
"""极简模型加载"""
|
| 61 |
global model, tokenizer
|
|
|
|
| 116 |
add_generation_prompt=True
|
| 117 |
)
|
| 118 |
|
| 119 |
+
# 确保text是字符串
|
| 120 |
+
if not isinstance(text, str):
|
| 121 |
+
# 如果返回的是列表,则连接成字符串
|
| 122 |
+
if isinstance(text, list):
|
| 123 |
+
text = "".join(text)
|
| 124 |
+
else:
|
| 125 |
+
text = str(text)
|
| 126 |
+
|
| 127 |
# 编码输入
|
| 128 |
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
|
| 129 |
inputs = inputs.to(DEVICE)
|
|
|
|
| 155 |
return {"error": f"生成失败: {str(e)}"}
|
| 156 |
|
| 157 |
# Create the minimal FastAPI application; the metadata below is passed
# straight to the FastAPI constructor.
app = FastAPI(
    description="带有API密钥验证的Qwen1.5-0.5B-Chat API服务",
    version="1.0",
    title="Qwen1.5-0.5B API",
)
|
| 163 |
|
| 164 |
# Load the model when the application starts.
@app.on_event("startup")
async def startup_event():
    """Startup hook: load the model, then log the API-auth configuration.

    Logs whether API authentication is enabled and, only when it is, how
    many API keys are configured.
    """
    load_model()

    auth_state = '已启用' if API_AUTH_ENABLED else '已禁用'
    logger.info(f"API认证状态: {auth_state}")

    # The key count is only meaningful (and only logged) with auth enabled.
    if not API_AUTH_ENABLED:
        return
    key_count = len(API_KEYS)
    logger.info(f"有效的API密钥数量: {key_count}")
|
| 171 |
|
| 172 |
+
# Health-check endpoint (no authentication required).
@app.get("/health")
async def health_check():
    """Report model readiness, the auth setting and the current timestamp.

    Returns:
        dict with "status" ("healthy" once the module-level ``model`` is set,
        otherwise "loading"), "model_loaded" (bool), "api_auth_enabled"
        (bool) and "timestamp" (int seconds since the epoch).
    """
    is_loaded = model is not None
    if is_loaded:
        state = "healthy"
    else:
        state = "loading"
    return dict(
        status=state,
        model_loaded=is_loaded,
        api_auth_enabled=API_AUTH_ENABLED,
        timestamp=int(time.time()),
    )
|
| 181 |
|
| 182 |
+
# Root endpoint (no authentication required).
@app.get("/")
async def root():
    """Return a short service banner.

    Returns:
        dict with a status "message", "model_loaded" (whether the
        module-level ``model`` is set), "api_auth_enabled" and the path of
        the chat "endpoint".
    """
    banner = {"message": "Qwen1.5-0.5B-Chat API服务运行中"}
    banner["model_loaded"] = model is not None
    banner["api_auth_enabled"] = API_AUTH_ENABLED
    banner["endpoint"] = "/v1/chat/completions"
    return banner
|
| 191 |
+
|
| 192 |
+
# OpenAI兼容的聊天端点(需要认证)
|
| 193 |
@app.post("/v1/chat/completions")
|
| 194 |
+
async def chat_completion(
|
| 195 |
+
request: Request,
|
| 196 |
+
auth_valid: bool = Depends(verify_api_key)
|
| 197 |
+
):
|
| 198 |
"""极简版OpenAI兼容端点"""
|
| 199 |
try:
|
| 200 |
# 解析请求
|
|
|
|
| 249 |
}
|
| 250 |
)
|
| 251 |
|
| 252 |
+
# Simple test endpoint (authentication required).
@app.post("/v1/test")
async def test_endpoint(auth_valid: bool = Depends(verify_api_key)):
    """Confirm that the caller passed the verify_api_key dependency.

    Args:
        auth_valid: Result of the ``verify_api_key`` dependency; the handler
            body only runs when that dependency did not raise.

    Returns:
        dict with "status", a confirmation "message" and the current
        "timestamp" (int seconds since the epoch).
    """
    now = int(time.time())
    reply = {
        "status": "success",
        "message": "API密钥验证通过",
        "timestamp": now,
    }
    return reply
|
| 261 |
|
| 262 |
if __name__ == "__main__":
|