bobocup commited on
Commit
098232e
·
verified ·
1 Parent(s): 3a461ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +277 -42
app.py CHANGED
@@ -1,14 +1,16 @@
1
  from fastapi import FastAPI, HTTPException, Request, Response
2
  from fastapi.middleware.cors import CORSMiddleware
3
- from fastapi.responses import StreamingResponse
 
4
  import httpx
5
  import os
6
  import json
7
- from typing import List, Optional
8
  import requests
9
  from itertools import cycle
10
  import asyncio
11
  import time
 
12
 
13
  # 创建FastAPI应用
14
  app = FastAPI()
@@ -27,15 +29,56 @@ class Config:
27
  OPENAI_API_BASE = "https://api.x.ai/v1"
28
  KEYS_URL = os.getenv("KEYS_URL", "")
29
  WHITELIST_IPS = os.getenv("WHITELIST_IPS", "").split(",")
 
 
 
 
 
 
 
 
 
30
 
31
  # 全局变量
32
  keys = []
33
  key_cycle = None
34
  first_key = None
 
 
 
 
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  # 初始化keys
37
  def init_keys():
38
- global keys, key_cycle, first_key
39
  try:
40
  if Config.KEYS_URL:
41
  response = requests.get(Config.KEYS_URL)
@@ -47,40 +90,69 @@ def init_keys():
47
  if keys:
48
  first_key = keys[0]
49
  key_cycle = cycle(keys)
 
 
50
  print(f"Loaded {len(keys)} API keys")
51
  except Exception as e:
52
  print(f"Error loading keys: {e}")
53
  keys = []
54
  key_cycle = None
55
  first_key = None
 
56
 
57
- # 获取聊天key
58
- def get_chat_key():
59
  global key_cycle
60
  if not key_cycle:
61
  raise HTTPException(status_code=500, detail="No API keys available")
62
- return next(key_cycle)
63
- # 获取真实IP地址
64
- def get_client_ip(request: Request) -> str:
65
- # 尝试从各种头部获取IP
66
- forwarded_for = request.headers.get("x-forwarded-for")
67
- if forwarded_for:
68
- return forwarded_for.split(",")[0].strip()
69
-
70
- real_ip = request.headers.get("x-real-ip")
71
- if real_ip:
72
- return real_ip
73
 
74
- return request.client.host
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
- # IP白名单中间件
77
- @app.middleware("http")
78
- async def ip_whitelist(request: Request, call_next):
79
- if Config.WHITELIST_IPS and Config.WHITELIST_IPS[0]:
80
- client_ip = get_client_ip(request)
81
- if client_ip not in Config.WHITELIST_IPS:
82
- raise HTTPException(status_code=403, detail="IP not allowed")
83
- return await call_next(request)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  # 流式响应生成器
86
  async def stream_generator(response):
@@ -90,7 +162,6 @@ async def stream_generator(response):
90
  chunk_str = chunk.decode('utf-8')
91
  buffer += chunk_str
92
 
93
- # 处理buffer中的完整事件
94
  while '\n\n' in buffer:
95
  event, buffer = buffer.split('\n\n', 1)
96
  if event.startswith('data: '):
@@ -99,35 +170,165 @@ async def stream_generator(response):
99
  yield f"data: [DONE]\n\n"
100
  else:
101
  try:
102
- # 解析JSON并重新格式化
103
  json_data = json.loads(data)
104
  yield f"data: {json.dumps(json_data)}\n\n"
105
- # 添加小延迟使流更平滑
106
- await asyncio.sleep(0.01)
107
  except json.JSONDecodeError:
108
  print(f"JSON decode error for data: {data}")
109
  continue
110
  except Exception as e:
111
  print(f"Stream Error: {str(e)}")
112
  yield f"data: {json.dumps({'error': str(e)})}\n\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  # 模型列表路由
114
  @app.get("/api/v1/models")
115
  async def list_models():
116
- headers = {
117
- "Authorization": f"Bearer {get_chat_key()}",
118
- "Content-Type": "application/json"
119
- }
120
-
121
- async with httpx.AsyncClient() as client:
122
- try:
 
 
 
 
 
123
  response = await client.get(
124
  f"{Config.OPENAI_API_BASE}/models",
125
  headers=headers
126
  )
 
 
 
 
 
 
 
 
 
 
 
127
  return response.json()
128
- except Exception as e:
129
- print(f"Models Error: {str(e)}")
130
- raise HTTPException(status_code=500, detail=str(e))
 
131
 
132
  # 聊天完成路由
133
  @app.post("/api/v1/chat/completions")
@@ -137,9 +338,12 @@ async def chat_completions(request: Request):
137
  body = await request.body()
138
  body_json = json.loads(body)
139
 
 
 
 
140
  # 获取headers
141
  headers = {
142
- "Authorization": f"Bearer {get_chat_key()}",
143
  "Content-Type": "application/json",
144
  "Accept": "text/event-stream" if body_json.get("stream") else "application/json"
145
  }
@@ -154,6 +358,18 @@ async def chat_completions(request: Request):
154
  json=body_json
155
  )
156
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  # 检查响应状态
158
  if response.status_code != 200:
159
  return Response(
@@ -183,6 +399,7 @@ async def chat_completions(request: Request):
183
  except Exception as e:
184
  print(f"Chat Error: {str(e)}")
185
  raise HTTPException(status_code=500, detail=str(e))
 
186
  # 代理其他请求
187
  @app.api_route("/api/v1/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
188
  async def proxy(path: str, request: Request):
@@ -192,8 +409,14 @@ async def proxy(path: str, request: Request):
192
  try:
193
  method = request.method
194
  body = await request.body() if method in ["POST", "PUT"] else None
 
 
 
 
 
 
195
  headers = {
196
- "Authorization": f"Bearer {get_chat_key()}",
197
  "Content-Type": "application/json"
198
  }
199
 
@@ -205,12 +428,24 @@ async def proxy(path: str, request: Request):
205
  content=body
206
  )
207
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  return Response(
209
  content=response.text,
210
  status_code=response.status_code,
211
  media_type=response.headers.get("content-type", "application/json")
212
  )
213
-
214
  except Exception as e:
215
  print(f"Proxy Error: {str(e)}")
216
  raise HTTPException(status_code=500, detail=str(e))
 
1
  from fastapi import FastAPI, HTTPException, Request, Response
2
  from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import StreamingResponse, FileResponse
4
+ from fastapi.staticfiles import StaticFiles
5
  import httpx
6
  import os
7
  import json
8
+ from typing import List, Optional, Dict
9
  import requests
10
  from itertools import cycle
11
  import asyncio
12
  import time
13
+ from datetime import datetime, timedelta
14
 
15
  # 创建FastAPI应用
16
  app = FastAPI()
 
29
  OPENAI_API_BASE = "https://api.x.ai/v1"
30
  KEYS_URL = os.getenv("KEYS_URL", "")
31
  WHITELIST_IPS = os.getenv("WHITELIST_IPS", "").split(",")
32
+ ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD", "admin")
33
+
34
+ # 密钥状态类
35
+ class KeyStatus:
36
+ def __init__(self, key: str):
37
+ self.key = key
38
+ self.status = "valid" # valid, invalid, cooling
39
+ self.cooling_until = None
40
+ self.last_check = None
41
 
42
  # 全局变量
43
  keys = []
44
  key_cycle = None
45
  first_key = None
46
+ key_status_map: Dict[str, KeyStatus] = {}
47
+
48
+ # 静态文件挂载
49
+ app.mount("/static", StaticFiles(directory="static"), name="static")
50
 
51
+ # 管理页面路由
52
+ @app.get("/admin")
53
+ async def admin():
54
+ return FileResponse("static/admin.html")
55
+ # 获取真实IP地址
56
+ def get_client_ip(request: Request) -> str:
57
+ # 尝试从各种头部获取IP
58
+ forwarded_for = request.headers.get("x-forwarded-for")
59
+ if forwarded_for:
60
+ return forwarded_for.split(",")[0].strip()
61
+
62
+ real_ip = request.headers.get("x-real-ip")
63
+ if real_ip:
64
+ return real_ip
65
+
66
+ return request.client.host
67
+
68
+ # IP白名单中间件
69
+ @app.middleware("http")
70
+ async def ip_whitelist(request: Request, call_next):
71
+ # 只对API接口启用白名单,排除管理后台
72
+ if "/api/" in request.url.path and "/api/admin/" not in request.url.path and "/api/keys" not in request.url.path:
73
+ if Config.WHITELIST_IPS and Config.WHITELIST_IPS[0]:
74
+ client_ip = get_client_ip(request)
75
+ if client_ip not in Config.WHITELIST_IPS:
76
+ raise HTTPException(status_code=403, detail="IP not allowed")
77
+ return await call_next(request)
78
+
79
  # 初始化keys
80
  def init_keys():
81
+ global keys, key_cycle, first_key, key_status_map
82
  try:
83
  if Config.KEYS_URL:
84
  response = requests.get(Config.KEYS_URL)
 
90
  if keys:
91
  first_key = keys[0]
92
  key_cycle = cycle(keys)
93
+ # 初始化所有key的状态
94
+ key_status_map = {key: KeyStatus(key) for key in keys}
95
  print(f"Loaded {len(keys)} API keys")
96
  except Exception as e:
97
  print(f"Error loading keys: {e}")
98
  keys = []
99
  key_cycle = None
100
  first_key = None
101
+ key_status_map = {}
102
 
103
+ # 获取有效的key
104
+ def get_valid_key():
105
  global key_cycle
106
  if not key_cycle:
107
  raise HTTPException(status_code=500, detail="No API keys available")
 
 
 
 
 
 
 
 
 
 
 
108
 
109
+ # 尝试最多len(keys)次
110
+ for _ in range(len(keys)):
111
+ key = next(key_cycle)
112
+ key_info = key_status_map.get(key)
113
+ if not key_info:
114
+ key_info = KeyStatus(key)
115
+ key_status_map[key] = key_info
116
+
117
+ # 检查key是否可用
118
+ if key_info.status == "valid":
119
+ return key
120
+ elif key_info.status == "cooling":
121
+ if key_info.cooling_until and datetime.now() > key_info.cooling_until:
122
+ key_info.status = "valid"
123
+ key_info.cooling_until = None
124
+ return key
125
+
126
+ raise HTTPException(status_code=500, detail="No valid API keys available")
127
 
128
+ # 标记key为冷却状态
129
+ def mark_key_cooling(key: str):
130
+ if key in key_status_map:
131
+ key_status_map[key].status = "cooling"
132
+ key_status_map[key].cooling_until = datetime.now() + timedelta(days=30)
133
+
134
+ # 检查key状态
135
+ async def check_key_status(key: str) -> bool:
136
+ try:
137
+ async with httpx.AsyncClient() as client:
138
+ response = await client.get(
139
+ f"{Config.OPENAI_API_BASE}/models",
140
+ headers={"Authorization": f"Bearer {key}"}
141
+ )
142
+ is_valid = response.status_code == 200
143
+ key_info = key_status_map.get(key)
144
+ if key_info:
145
+ if not is_valid:
146
+ key_info.status = "cooling"
147
+ key_info.cooling_until = datetime.now() + timedelta(days=30)
148
+ else:
149
+ key_info.status = "valid"
150
+ key_info.cooling_until = None
151
+ key_info.last_check = datetime.now()
152
+ return is_valid
153
+ except Exception as e:
154
+ print(f"Error checking key {key}: {e}")
155
+ return False
156
 
157
  # 流式响应生成器
158
  async def stream_generator(response):
 
162
  chunk_str = chunk.decode('utf-8')
163
  buffer += chunk_str
164
 
 
165
  while '\n\n' in buffer:
166
  event, buffer = buffer.split('\n\n', 1)
167
  if event.startswith('data: '):
 
170
  yield f"data: [DONE]\n\n"
171
  else:
172
  try:
 
173
  json_data = json.loads(data)
174
  yield f"data: {json.dumps(json_data)}\n\n"
 
 
175
  except json.JSONDecodeError:
176
  print(f"JSON decode error for data: {data}")
177
  continue
178
  except Exception as e:
179
  print(f"Stream Error: {str(e)}")
180
  yield f"data: {json.dumps({'error': str(e)})}\n\n"
181
+ # 验证管理员密码
182
+ def verify_admin(password: str):
183
+ if not password or password != Config.ADMIN_PASSWORD:
184
+ raise HTTPException(status_code=403, detail="Invalid admin password")
185
+
186
+ # 管理员登录
187
+ @app.post("/api/admin/login")
188
+ async def admin_login(request: Request):
189
+ data = await request.json()
190
+ password = data.get("password")
191
+ verify_admin(password)
192
+ return {"status": "success"}
193
+
194
+ # 获取所有密钥状态
195
+ @app.get("/api/keys")
196
+ async def list_keys(password: str):
197
+ verify_admin(password)
198
+ return {
199
+ "keys": [
200
+ {
201
+ "key": k,
202
+ "status": key_status_map[k].status if k in key_status_map else "valid",
203
+ "cooling_until": key_status_map[k].cooling_until.isoformat() if k in key_status_map and key_status_map[k].cooling_until else None,
204
+ "last_check": key_status_map[k].last_check.isoformat() if k in key_status_map and key_status_map[k].last_check else None
205
+ }
206
+ for k in keys
207
+ ]
208
+ }
209
+
210
+ # 添加新密钥
211
+ @app.post("/api/keys/add")
212
+ async def add_key(request: Request):
213
+ data = await request.json()
214
+ verify_admin(data.get("password"))
215
+ new_key = data.get("key", "").strip()
216
+
217
+ if not new_key:
218
+ raise HTTPException(status_code=400, detail="Key is required")
219
+ if new_key in keys:
220
+ raise HTTPException(status_code=400, detail="Key already exists")
221
+
222
+ keys.append(new_key)
223
+ key_status_map[new_key] = KeyStatus(new_key)
224
+
225
+ # 重新初始化key_cycle
226
+ global key_cycle
227
+ key_cycle = cycle(keys)
228
+
229
+ # 保存到文件
230
+ if not Config.KEYS_URL:
231
+ with open("key.txt", "w") as f:
232
+ f.write("\n".join(keys))
233
+
234
+ return {"status": "success"}
235
+
236
+ # 删除密钥
237
+ @app.delete("/api/keys/{key}")
238
+ async def delete_key(key: str, password: str):
239
+ verify_admin(password)
240
+ if key in keys:
241
+ keys.remove(key)
242
+ if key in key_status_map:
243
+ del key_status_map[key]
244
+
245
+ # 重新初始化key_cycle
246
+ global key_cycle
247
+ key_cycle = cycle(keys)
248
+
249
+ # 保存到文件
250
+ if not Config.KEYS_URL:
251
+ with open("key.txt", "w") as f:
252
+ f.write("\n".join(keys))
253
+
254
+ return {"status": "success"}
255
+
256
+ # 批量删除密钥
257
+ @app.post("/api/keys/delete-batch")
258
+ async def delete_keys_batch(request: Request):
259
+ data = await request.json()
260
+ verify_admin(data.get("password"))
261
+ keys_to_delete = data.get("keys", [])
262
+
263
+ for key in keys_to_delete:
264
+ if key in keys:
265
+ keys.remove(key)
266
+ if key in key_status_map:
267
+ del key_status_map[key]
268
+
269
+ # 重新初始化key_cycle
270
+ global key_cycle
271
+ key_cycle = cycle(keys)
272
+
273
+ # 保存到文件
274
+ if not Config.KEYS_URL:
275
+ with open("key.txt", "w") as f:
276
+ f.write("\n".join(keys))
277
+
278
+ return {"status": "success"}
279
+
280
+ # 检查单个密钥
281
+ @app.get("/api/keys/check/{key}")
282
+ async def check_single_key(key: str, password: str):
283
+ verify_admin(password)
284
+ if key not in keys:
285
+ raise HTTPException(status_code=404, detail="Key not found")
286
+
287
+ is_valid = await check_key_status(key)
288
+ return {"status": "success", "valid": is_valid}
289
+
290
+ # 检查所有密钥
291
+ @app.post("/api/keys/check-all")
292
+ async def check_all_keys(password: str):
293
+ verify_admin(password)
294
+ for key in keys:
295
+ await check_key_status(key)
296
+ return {"status": "success"}
297
  # 模型列表路由
298
  @app.get("/api/v1/models")
299
  async def list_models():
300
+ # 对于模型列表,使用第一个可用的key
301
+ try:
302
+ key = first_key
303
+ if key in key_status_map and key_status_map[key].status == "cooling":
304
+ key = get_valid_key() # 如果第一个key在冷却,则使用轮询获取
305
+
306
+ headers = {
307
+ "Authorization": f"Bearer {key}",
308
+ "Content-Type": "application/json"
309
+ }
310
+
311
+ async with httpx.AsyncClient() as client:
312
  response = await client.get(
313
  f"{Config.OPENAI_API_BASE}/models",
314
  headers=headers
315
  )
316
+
317
+ if response.status_code == 429: # 如果遇到限流
318
+ mark_key_cooling(key)
319
+ # 重试一次,使用轮询的key
320
+ key = get_valid_key()
321
+ headers["Authorization"] = f"Bearer {key}"
322
+ response = await client.get(
323
+ f"{Config.OPENAI_API_BASE}/models",
324
+ headers=headers
325
+ )
326
+
327
  return response.json()
328
+
329
+ except Exception as e:
330
+ print(f"Models Error: {str(e)}")
331
+ raise HTTPException(status_code=500, detail=str(e))
332
 
333
  # 聊天完成路由
334
  @app.post("/api/v1/chat/completions")
 
338
  body = await request.body()
339
  body_json = json.loads(body)
340
 
341
+ # 使用轮询获取key
342
+ key = get_valid_key()
343
+
344
  # 获取headers
345
  headers = {
346
+ "Authorization": f"Bearer {key}",
347
  "Content-Type": "application/json",
348
  "Accept": "text/event-stream" if body_json.get("stream") else "application/json"
349
  }
 
358
  json=body_json
359
  )
360
 
361
+ # 如果遇到限流,标记key为冷却状态并重试
362
+ if response.status_code == 429:
363
+ mark_key_cooling(key)
364
+ # 重试一次
365
+ key = get_valid_key()
366
+ headers["Authorization"] = f"Bearer {key}"
367
+ response = await client.post(
368
+ url,
369
+ headers=headers,
370
+ json=body_json
371
+ )
372
+
373
  # 检查响应状态
374
  if response.status_code != 200:
375
  return Response(
 
399
  except Exception as e:
400
  print(f"Chat Error: {str(e)}")
401
  raise HTTPException(status_code=500, detail=str(e))
402
+
403
  # 代理其他请求
404
  @app.api_route("/api/v1/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
405
  async def proxy(path: str, request: Request):
 
409
  try:
410
  method = request.method
411
  body = await request.body() if method in ["POST", "PUT"] else None
412
+
413
+ # 使用第一个可用的key
414
+ key = first_key
415
+ if key in key_status_map and key_status_map[key].status == "cooling":
416
+ key = get_valid_key()
417
+
418
  headers = {
419
+ "Authorization": f"Bearer {key}",
420
  "Content-Type": "application/json"
421
  }
422
 
 
428
  content=body
429
  )
430
 
431
+ if response.status_code == 429: # 如果遇到限流
432
+ mark_key_cooling(key)
433
+ # 重试一次,使用轮询的key
434
+ key = get_valid_key()
435
+ headers["Authorization"] = f"Bearer {key}"
436
+ response = await client.request(
437
+ method=method,
438
+ url=f"{Config.OPENAI_API_BASE}/{path}",
439
+ headers=headers,
440
+ content=body
441
+ )
442
+
443
  return Response(
444
  content=response.text,
445
  status_code=response.status_code,
446
  media_type=response.headers.get("content-type", "application/json")
447
  )
448
+
449
  except Exception as e:
450
  print(f"Proxy Error: {str(e)}")
451
  raise HTTPException(status_code=500, detail=str(e))