bobocup commited on
Commit
3a461ba
·
verified ·
1 Parent(s): ddd1134

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -252
app.py CHANGED
@@ -1,16 +1,14 @@
1
  from fastapi import FastAPI, HTTPException, Request, Response
2
  from fastapi.middleware.cors import CORSMiddleware
3
- from fastapi.responses import StreamingResponse, HTMLResponse
4
- from fastapi.staticfiles import StaticFiles
5
  import httpx
6
  import os
7
  import json
8
- from typing import List, Optional, Dict
9
  import requests
10
  from itertools import cycle
11
  import asyncio
12
  import time
13
- from datetime import datetime, timedelta
14
 
15
  # 创建FastAPI应用
16
  app = FastAPI()
@@ -29,107 +27,59 @@ class Config:
29
  OPENAI_API_BASE = "https://api.x.ai/v1"
30
  KEYS_URL = os.getenv("KEYS_URL", "")
31
  WHITELIST_IPS = os.getenv("WHITELIST_IPS", "").split(",")
32
- ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD", "")
33
- CHUNK_SIZE = 1
34
-
35
- # Key信息类
36
- class KeyInfo:
37
- def __init__(self, key: str):
38
- self.key = key
39
- self.is_valid = True
40
- self.cooling_until = None
41
- self.last_used = None
42
- self.error_count = 0
43
-
44
- def to_dict(self):
45
- return {
46
- "key": self.key,
47
- "is_valid": self.is_valid,
48
- "cooling_until": self.cooling_until.isoformat() if self.cooling_until else None,
49
- "last_used": self.last_used.isoformat() if self.last_used else None,
50
- "error_count": self.error_count,
51
- "status": self.get_status()
52
- }
53
-
54
- def get_status(self):
55
- if self.cooling_until and self.cooling_until > datetime.now():
56
- return "cooling"
57
- if not self.is_valid:
58
- return "invalid"
59
- return "valid"
60
 
61
  # 全局变量
62
- keys_info: Dict[str, KeyInfo] = {}
63
- chat_keys = []
64
  first_key = None
 
65
  # 初始化keys
66
  def init_keys():
67
- global keys_info, chat_keys, first_key
68
  try:
69
  if Config.KEYS_URL:
70
  response = requests.get(Config.KEYS_URL)
71
- raw_keys = [k.strip() for k in response.text.splitlines() if k.strip()]
72
  else:
73
  with open("key.txt", "r") as f:
74
- raw_keys = [k.strip() for k in f.readlines() if k.strip()]
75
 
76
- keys_info = {k: KeyInfo(k) for k in raw_keys}
77
- chat_keys = list(raw_keys)
78
- first_key = raw_keys[0] if raw_keys else None
79
- print(f"Loaded {len(raw_keys)} API keys")
80
  except Exception as e:
81
  print(f"Error loading keys: {e}")
82
- keys_info = {}
83
- chat_keys = []
84
  first_key = None
85
 
86
- # 获取可用的chat key
87
  def get_chat_key():
88
- valid_keys = [k for k in chat_keys if is_key_available(k)]
89
- if not valid_keys:
90
- raise HTTPException(status_code=500, detail="No available API keys")
91
-
92
- # 简单轮询
93
- key = valid_keys[0]
94
- chat_keys.append(chat_keys.pop(0))
95
- return key
96
-
97
- # 检查key是否可用
98
- def is_key_available(key: str) -> bool:
99
- info = keys_info.get(key)
100
- if not info or not info.is_valid:
101
- return False
102
- if info.cooling_until and info.cooling_until > datetime.now():
103
- return False
104
- return True
105
-
106
- # 设置key冷却
107
- def set_key_cooling(key: str, days: int = 30):
108
- if key in keys_info:
109
- keys_info[key].cooling_until = datetime.now() + timedelta(days=days)
110
-
111
  # 获取真实IP地址
112
  def get_client_ip(request: Request) -> str:
 
113
  forwarded_for = request.headers.get("x-forwarded-for")
114
  if forwarded_for:
115
  return forwarded_for.split(",")[0].strip()
 
 
 
 
 
116
  return request.client.host
117
 
118
  # IP白名单中间件
119
  @app.middleware("http")
120
- async def access_control(request: Request, call_next):
121
- path = request.url.path
122
-
123
- # 管理后台相关路径不检查IP
124
- if path.startswith("/admin") or path.startswith("/api/admin") or path.startswith("/api/keys"):
125
- return await call_next(request)
126
-
127
- # API调用检查白名单
128
  if Config.WHITELIST_IPS and Config.WHITELIST_IPS[0]:
129
  client_ip = get_client_ip(request)
130
  if client_ip not in Config.WHITELIST_IPS:
131
  raise HTTPException(status_code=403, detail="IP not allowed")
132
-
133
  return await call_next(request)
134
 
135
  # 流式响应生成器
@@ -160,91 +110,11 @@ async def stream_generator(response):
160
  except Exception as e:
161
  print(f"Stream Error: {str(e)}")
162
  yield f"data: {json.dumps({'error': str(e)})}\n\n"
163
-
164
-
165
- # ���改 handle_api_request 函数中的超时设置
166
- async def handle_api_request(url: str, headers: dict, method: str = "GET", body: dict = None, for_chat: bool = False):
167
- max_retries = 3
168
- current_try = 0
169
-
170
- while current_try < max_retries:
171
- try:
172
- if for_chat:
173
- key = get_chat_key()
174
- else:
175
- key = first_key
176
-
177
- if not key:
178
- raise HTTPException(status_code=500, detail="No API keys available")
179
-
180
- headers["Authorization"] = f"Bearer {key}"
181
-
182
- # 减少连接超时,增加读取超时
183
- timeout = httpx.Timeout(
184
- connect=5.0, # 连接超时5秒
185
- read=60.0, # 读取超时60秒
186
- write=5.0, # 写入超时5秒
187
- pool=5.0 # 连接池超时5秒
188
- )
189
-
190
- async with httpx.AsyncClient(timeout=timeout) as client:
191
- response = await client.request(
192
- method=method,
193
- url=url,
194
- headers=headers,
195
- json=body
196
- )
197
-
198
- # 检查配额不足
199
- if response.status_code == 429 or "insufficient_quota" in response.text.lower():
200
- print(f"Key {key} quota exceeded, trying next key...")
201
- set_key_cooling(key)
202
- current_try += 1
203
- continue
204
-
205
- return response
206
-
207
- except Exception as e:
208
- print(f"Request error: {str(e)}")
209
- current_try += 1
210
- if current_try == max_retries:
211
- raise HTTPException(status_code=500, detail=f"API request failed after {max_retries} retries")
212
- # 添加静态文件支持
213
- app.mount("/static", StaticFiles(directory="static"), name="static")
214
-
215
- # 管理页面路由
216
- @app.get("/admin", response_class=HTMLResponse)
217
- async def admin_page():
218
- try:
219
- with open("static/admin.html", "r", encoding="utf-8") as f:
220
- return f.read()
221
- except Exception as e:
222
- print(f"Error loading admin page: {e}")
223
- raise HTTPException(status_code=500, detail="Error loading admin page")
224
-
225
- # 验证管理密码
226
- @app.post("/api/admin/login")
227
- async def admin_login(request: Request):
228
- try:
229
- body = await request.json()
230
- password = body.get("password")
231
- if password == Config.ADMIN_PASSWORD:
232
- return {"status": "success"}
233
- raise HTTPException(status_code=401, detail="Invalid password")
234
- except json.JSONDecodeError:
235
- raise HTTPException(status_code=400, detail="Invalid JSON")
236
-
237
- # 获取所有keys状态
238
- @app.get("/api/keys")
239
- async def get_keys(password: str):
240
- if password != Config.ADMIN_PASSWORD:
241
- raise HTTPException(status_code=401, detail="Invalid password")
242
- return {"keys": [info.to_dict() for info in keys_info.values()]}
243
-
244
- # 检查key是否有效
245
- async def check_key_valid(key: str) -> bool:
246
  headers = {
247
- "Authorization": f"Bearer {key}",
248
  "Content-Type": "application/json"
249
  }
250
 
@@ -254,77 +124,14 @@ async def check_key_valid(key: str) -> bool:
254
  f"{Config.OPENAI_API_BASE}/models",
255
  headers=headers
256
  )
257
- return response.status_code == 200
258
- except:
259
- return False
260
-
261
- # 批量检查keys
262
- @app.post("/api/keys/check-all")
263
- async def check_all_keys(password: str):
264
- if password != Config.ADMIN_PASSWORD:
265
- raise HTTPException(status_code=401, detail="Invalid password")
266
-
267
- results = []
268
- for key in keys_info:
269
- is_valid = await check_key_valid(key)
270
- keys_info[key].is_valid = is_valid
271
- results.append({"key": key, "valid": is_valid})
272
-
273
- return {"results": results}
274
-
275
- # 批量删除keys
276
- @app.post("/api/keys/delete-batch")
277
- async def delete_batch_keys(request: Request):
278
- try:
279
- body = await request.json()
280
- password = body.get("password")
281
- keys_to_delete = body.get("keys", [])
282
-
283
- if password != Config.ADMIN_PASSWORD:
284
- raise HTTPException(status_code=401, detail="Invalid password")
285
-
286
- for key in keys_to_delete:
287
- if key in keys_info:
288
- del keys_info[key]
289
- if key in chat_keys:
290
- chat_keys.remove(key)
291
-
292
- return {"status": "success", "deleted_count": len(keys_to_delete)}
293
- except Exception as e:
294
- raise HTTPException(status_code=400, detail=str(e))
295
 
296
- # 添加新key
297
- @app.post("/api/keys/add")
298
- async def add_key(request: Request):
299
- try:
300
- body = await request.json()
301
- password = body.get("password")
302
- key = body.get("key")
303
-
304
- if password != Config.ADMIN_PASSWORD:
305
- raise HTTPException(status_code=401, detail="Invalid password")
306
-
307
- if key in keys_info:
308
- raise HTTPException(status_code=400, detail="Key already exists")
309
-
310
- # 检查key有效性
311
- is_valid = await check_key_valid(key)
312
- if not is_valid:
313
- raise HTTPException(status_code=400, detail="Invalid key")
314
-
315
- keys_info[key] = KeyInfo(key)
316
- chat_keys.append(key)
317
-
318
- return {"status": "success", "message": "Key added successfully"}
319
- except HTTPException:
320
- raise
321
- except Exception as e:
322
- raise HTTPException(status_code=400, detail=str(e))
323
  # 聊天完成路由
324
  @app.post("/api/v1/chat/completions")
325
  async def chat_completions(request: Request):
326
- global key_cycle # 声明使用全局变量
327
-
328
  try:
329
  # 获取请求体
330
  body = await request.body()
@@ -332,7 +139,7 @@ async def chat_completions(request: Request):
332
 
333
  # 获取headers
334
  headers = {
335
- "Authorization": f"Bearer {next(key_cycle)}",
336
  "Content-Type": "application/json",
337
  "Accept": "text/event-stream" if body_json.get("stream") else "application/json"
338
  }
@@ -376,7 +183,6 @@ async def chat_completions(request: Request):
376
  except Exception as e:
377
  print(f"Chat Error: {str(e)}")
378
  raise HTTPException(status_code=500, detail=str(e))
379
-
380
  # 代理其他请求
381
  @app.api_route("/api/v1/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
382
  async def proxy(path: str, request: Request):
@@ -385,36 +191,34 @@ async def proxy(path: str, request: Request):
385
 
386
  try:
387
  method = request.method
388
- body = await request.json() if method in ["POST", "PUT"] else None
389
- headers = {"Content-Type": "application/json"}
390
-
391
- response = await handle_api_request(
392
- url=f"{Config.OPENAI_API_BASE}/{path}",
393
- headers=headers,
394
- method=method,
395
- body=body,
396
- for_chat=False
397
- )
398
-
399
- return Response(
400
- content=response.text,
401
- status_code=response.status_code,
402
- media_type=response.headers.get("content-type", "application/json")
403
- )
404
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
405
  except Exception as e:
406
- print(f"Proxy error: {str(e)}")
407
  raise HTTPException(status_code=500, detail=str(e))
408
 
409
  # 健康检查路由
410
  @app.get("/api/health")
411
  async def health_check():
412
- available_count = sum(1 for k in keys_info.values() if is_key_available(k.key))
413
- return {
414
- "status": "healthy",
415
- "total_keys": len(keys_info),
416
- "available_keys": available_count
417
- }
418
 
419
  # 启动时初始化
420
  @app.on_event("startup")
 
1
  from fastapi import FastAPI, HTTPException, Request, Response
2
  from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import StreamingResponse
 
4
  import httpx
5
  import os
6
  import json
7
+ from typing import List, Optional
8
  import requests
9
  from itertools import cycle
10
  import asyncio
11
  import time
 
12
 
13
  # 创建FastAPI应用
14
  app = FastAPI()
 
27
  OPENAI_API_BASE = "https://api.x.ai/v1"
28
  KEYS_URL = os.getenv("KEYS_URL", "")
29
  WHITELIST_IPS = os.getenv("WHITELIST_IPS", "").split(",")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  # 全局变量
32
+ keys = []
33
+ key_cycle = None
34
  first_key = None
35
+
36
  # 初始化keys
37
  def init_keys():
38
+ global keys, key_cycle, first_key
39
  try:
40
  if Config.KEYS_URL:
41
  response = requests.get(Config.KEYS_URL)
42
+ keys = [k.strip() for k in response.text.splitlines() if k.strip()]
43
  else:
44
  with open("key.txt", "r") as f:
45
+ keys = [k.strip() for k in f.readlines() if k.strip()]
46
 
47
+ if keys:
48
+ first_key = keys[0]
49
+ key_cycle = cycle(keys)
50
+ print(f"Loaded {len(keys)} API keys")
51
  except Exception as e:
52
  print(f"Error loading keys: {e}")
53
+ keys = []
54
+ key_cycle = None
55
  first_key = None
56
 
57
+ # 获取聊天key
58
  def get_chat_key():
59
+ global key_cycle
60
+ if not key_cycle:
61
+ raise HTTPException(status_code=500, detail="No API keys available")
62
+ return next(key_cycle)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  # 获取真实IP地址
64
  def get_client_ip(request: Request) -> str:
65
+ # 尝试从各种头部获取IP
66
  forwarded_for = request.headers.get("x-forwarded-for")
67
  if forwarded_for:
68
  return forwarded_for.split(",")[0].strip()
69
+
70
+ real_ip = request.headers.get("x-real-ip")
71
+ if real_ip:
72
+ return real_ip
73
+
74
  return request.client.host
75
 
76
  # IP白名单中间件
77
  @app.middleware("http")
78
+ async def ip_whitelist(request: Request, call_next):
 
 
 
 
 
 
 
79
  if Config.WHITELIST_IPS and Config.WHITELIST_IPS[0]:
80
  client_ip = get_client_ip(request)
81
  if client_ip not in Config.WHITELIST_IPS:
82
  raise HTTPException(status_code=403, detail="IP not allowed")
 
83
  return await call_next(request)
84
 
85
  # 流式响应生成器
 
110
  except Exception as e:
111
  print(f"Stream Error: {str(e)}")
112
  yield f"data: {json.dumps({'error': str(e)})}\n\n"
113
+ # 模型列表路由
114
+ @app.get("/api/v1/models")
115
+ async def list_models():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  headers = {
117
+ "Authorization": f"Bearer {get_chat_key()}",
118
  "Content-Type": "application/json"
119
  }
120
 
 
124
  f"{Config.OPENAI_API_BASE}/models",
125
  headers=headers
126
  )
127
+ return response.json()
128
+ except Exception as e:
129
+ print(f"Models Error: {str(e)}")
130
+ raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  # 聊天完成路由
133
  @app.post("/api/v1/chat/completions")
134
  async def chat_completions(request: Request):
 
 
135
  try:
136
  # 获取请求体
137
  body = await request.body()
 
139
 
140
  # 获取headers
141
  headers = {
142
+ "Authorization": f"Bearer {get_chat_key()}",
143
  "Content-Type": "application/json",
144
  "Accept": "text/event-stream" if body_json.get("stream") else "application/json"
145
  }
 
183
  except Exception as e:
184
  print(f"Chat Error: {str(e)}")
185
  raise HTTPException(status_code=500, detail=str(e))
 
186
  # 代理其他请求
187
  @app.api_route("/api/v1/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
188
  async def proxy(path: str, request: Request):
 
191
 
192
  try:
193
  method = request.method
194
+ body = await request.body() if method in ["POST", "PUT"] else None
195
+ headers = {
196
+ "Authorization": f"Bearer {get_chat_key()}",
197
+ "Content-Type": "application/json"
198
+ }
 
 
 
 
 
 
 
 
 
 
 
199
 
200
+ async with httpx.AsyncClient() as client:
201
+ response = await client.request(
202
+ method=method,
203
+ url=f"{Config.OPENAI_API_BASE}/{path}",
204
+ headers=headers,
205
+ content=body
206
+ )
207
+
208
+ return Response(
209
+ content=response.text,
210
+ status_code=response.status_code,
211
+ media_type=response.headers.get("content-type", "application/json")
212
+ )
213
+
214
  except Exception as e:
215
+ print(f"Proxy Error: {str(e)}")
216
  raise HTTPException(status_code=500, detail=str(e))
217
 
218
  # 健康检查路由
219
  @app.get("/api/health")
220
  async def health_check():
221
+ return {"status": "healthy", "key_count": len(keys)}
 
 
 
 
 
222
 
223
  # 启动时初始化
224
  @app.on_event("startup")