han145 committed on
Commit
a7af7f9
·
verified ·
1 Parent(s): 888b613

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -11
app.py CHANGED
@@ -2,7 +2,8 @@ import os
2
  import time
3
  import json
4
  import logging
5
- from fastapi import FastAPI, Request, HTTPException
 
6
  from fastapi.responses import JSONResponse
7
  from transformers import AutoTokenizer, AutoModelForCausalLM
8
  import torch
@@ -21,6 +22,40 @@ MODEL_NAME = "Qwen/Qwen1.5-0.5B-Chat"
21
  MAX_TOKENS = 256
22
  DEVICE = "cpu" # 强制使用CPU
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  def load_model():
25
  """极简模型加载"""
26
  global model, tokenizer
@@ -81,6 +116,14 @@ def generate_response(messages):
81
  add_generation_prompt=True
82
  )
83
 
 
 
 
 
 
 
 
 
84
  # 编码输入
85
  inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
86
  inputs = inputs.to(DEVICE)
@@ -112,25 +155,46 @@ def generate_response(messages):
112
  return {"error": f"生成失败: {str(e)}"}
113
 
114
  # 创建极简FastAPI应用
115
- app = FastAPI(title="Qwen1.5-0.5B API", version="1.0")
 
 
 
 
116
 
117
  # 启动时加载模型
118
  @app.on_event("startup")
119
  async def startup_event():
120
  load_model()
 
 
 
121
 
122
- # 健康检查端点(OpenClaw可能
123
  @app.get("/health")
124
  async def health_check():
125
  return {
126
  "status": "healthy" if model is not None else "loading",
127
  "model_loaded": model is not None,
 
128
  "timestamp": int(time.time())
129
  }
130
 
131
- # OpenAI兼容的聊天端点
 
 
 
 
 
 
 
 
 
 
132
  @app.post("/v1/chat/completions")
133
- async def chat_completion(request: Request):
 
 
 
134
  """极简版OpenAI兼容端点"""
135
  try:
136
  # 解析请求
@@ -185,13 +249,14 @@ async def chat_completion(request: Request):
185
  }
186
  )
187
 
188
- # 端点
189
- @app.get("/")
190
- async def root():
 
191
  return {
192
- "message": "Qwen1.5-0.5B-Chat API服务运行中",
193
- "model_loaded": model is not None,
194
- "endpoint": "/v1/chat/completions"
195
  }
196
 
197
  if __name__ == "__main__":
 
2
  import time
3
  import json
4
  import logging
5
+ from fastapi import FastAPI, Request, HTTPException, Depends, status
6
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
7
  from fastapi.responses import JSONResponse
8
  from transformers import AutoTokenizer, AutoModelForCausalLM
9
  import torch
 
22
  MAX_TOKENS = 256
23
  DEVICE = "cpu" # 强制使用CPU
24
 
25
# API key configuration.
# Keys are read from the comma-separated API_KEYS environment variable.
# Surrounding whitespace is stripped and empty entries are dropped, so a value
# such as "k1, k2," yields exactly two usable keys.
# SECURITY NOTE(review): the fallback value below consists of publicly known
# placeholder keys; set API_KEYS in the deployment environment.
API_KEYS = [
    key.strip()
    for key in os.getenv("API_KEYS", "your-secret-key-1,your-secret-key-2").split(",")
    if key.strip()
]
# API key verification is enforced only when API_AUTH_ENABLED equals "true"
# (case-insensitive, surrounding whitespace ignored); it defaults to enabled.
API_AUTH_ENABLED = os.getenv("API_AUTH_ENABLED", "true").strip().lower() == "true"

# Bearer authentication scheme.
# auto_error=False makes a missing or non-Bearer Authorization header reach
# verify_api_key as None instead of being rejected inside HTTPBearer, so that
# disabling authentication really admits requests that carry no header.
security = HTTPBearer(auto_error=False)


def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Validate the request's Bearer token against the configured API keys.

    Args:
        credentials: Parsed Authorization header supplied by the ``security``
            dependency. Because ``security`` is created with
            ``auto_error=False`` this is ``None`` when the header is missing
            or does not use the Bearer scheme.

    Returns:
        True when authentication is disabled or the presented key is valid.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            authentication is enabled and the credentials are missing, use
            another scheme, or do not match any configured key.
    """
    # Function-scope import keeps this block self-contained.
    import hmac

    def _unauthorized(detail):
        # Build the 401 response shared by every rejection path.
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Authentication disabled: accept the request, with or without a header.
    if not API_AUTH_ENABLED:
        return True

    # No usable Authorization header was supplied.
    if credentials is None:
        raise _unauthorized("Not authenticated")

    # The authentication scheme name is case-insensitive ("bearer" is valid).
    if credentials.scheme.lower() != "bearer":
        raise _unauthorized("Invalid authentication scheme. Use 'Bearer' token")

    # Compare against every configured key in constant time and without
    # short-circuiting, so response timing does not reveal key contents.
    # Bytes are compared because compare_digest rejects non-ASCII str values.
    presented = credentials.credentials.encode("utf-8")
    matched = False
    for configured in API_KEYS:
        candidate = configured.strip()
        if candidate and hmac.compare_digest(presented, candidate.encode("utf-8")):
            matched = True

    if not matched:
        raise _unauthorized("Invalid API key")

    return True
58
+
59
  def load_model():
60
  """极简模型加载"""
61
  global model, tokenizer
 
116
  add_generation_prompt=True
117
  )
118
 
119
+ # 确保text是字符串
120
+ if not isinstance(text, str):
121
+ # 如果返回的是列表,则连接成字符串
122
+ if isinstance(text, list):
123
+ text = "".join(text)
124
+ else:
125
+ text = str(text)
126
+
127
  # 编码输入
128
  inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
129
  inputs = inputs.to(DEVICE)
 
155
  return {"error": f"生成失败: {str(e)}"}
156
 
157
  # 创建极简FastAPI应用
158
# Application object; the metadata below is passed to the FastAPI constructor.
_APP_METADATA = {
    "title": "Qwen1.5-0.5B API",
    "version": "1.0",
    "description": "带有API密钥验证的Qwen1.5-0.5B-Chat API服务",
}
app = FastAPI(**_APP_METADATA)
163
 
164
  # 启动时加载模型
165
@app.on_event("startup")
async def startup_event():
    """Load the model on application startup and log the API-auth settings."""
    load_model()

    # Log whether API key verification is switched on.
    auth_state = "已启用" if API_AUTH_ENABLED else "已禁用"
    logger.info(f"API认证状态: {auth_state}")

    # The key count is only logged when verification is enabled.
    if not API_AUTH_ENABLED:
        return
    key_count = len(API_KEYS)
    logger.info(f"有效的API密钥数量: {key_count}")
171
 
172
# Health-check endpoint; no authentication dependency is attached to it.
@app.get("/health")
async def health_check():
    """Return model-load state, the auth setting and the current Unix time."""
    is_loaded = model is not None
    payload = {"status": "healthy" if is_loaded else "loading"}
    payload["model_loaded"] = is_loaded
    payload["api_auth_enabled"] = API_AUTH_ENABLED
    payload["timestamp"] = int(time.time())
    return payload
181
 
182
# Root endpoint; no authentication dependency is attached to it.
@app.get("/")
async def root():
    """Return a short service banner with model and auth status."""
    is_loaded = model is not None
    return dict(
        message="Qwen1.5-0.5B-Chat API服务运行中",
        model_loaded=is_loaded,
        api_auth_enabled=API_AUTH_ENABLED,
        endpoint="/v1/chat/completions",
    )
191
+
192
+ # OpenAI兼容的聊天端点(需要认证)
193
  @app.post("/v1/chat/completions")
194
+ async def chat_completion(
195
+ request: Request,
196
+ auth_valid: bool = Depends(verify_api_key)
197
+ ):
198
  """极简版OpenAI兼容端点"""
199
  try:
200
  # 解析请求
 
249
  }
250
  )
251
 
252
# Simple test endpoint guarded by the verify_api_key dependency.
@app.post("/v1/test")
async def test_endpoint(auth_valid: bool = Depends(verify_api_key)):
    """Confirm that the caller passed API key verification."""
    now = int(time.time())
    return dict(
        status="success",
        message="API密钥验证通过",
        timestamp=now,
    )
261
 
262
  if __name__ == "__main__":