CassiopeiaCode commited on
Commit
e00ef83
·
1 Parent(s): 93d4c04

feat: 实现账号统计和自动禁用功能

Browse files

- 添加error_count和success_count字段统计请求结果
- 成功时重置error_count为0
- 失败时error_count+1,超过阈值自动禁用账号
- 添加MAX_ERROR_COUNT环境变量配置(默认100)
- 使用StreamTracker追踪流式和非流式响应是否返回有效字符

Files changed (3) hide show
  1. .env.example +4 -1
  2. app.py +60 -26
  3. replicate.py +15 -4
.env.example CHANGED
@@ -1,4 +1,7 @@
1
  # OpenAI 风格 API Key 白名单(仅用于授权,与账号无关)
2
  # 多个用逗号分隔,例如:
3
  # OPENAI_KEYS="key1,key2,key3"
4
- OPENAI_KEYS=""
 
 
 
 
1
  # OpenAI 风格 API Key 白名单(仅用于授权,与账号无关)
2
  # 多个用逗号分隔,例如:
3
  # OPENAI_KEYS="key1,key2,key3"
4
+ OPENAI_KEYS=""
5
+
6
+ # 出错次数阈值,超过此值自动禁用账号
7
+ MAX_ERROR_COUNT=100
app.py CHANGED
@@ -73,13 +73,16 @@ def _ensure_db():
73
  )
74
  """
75
  )
76
- # add enabled column if missing
77
  try:
78
  cols = [row[1] for row in conn.execute("PRAGMA table_info(accounts)").fetchall()]
79
  if "enabled" not in cols:
80
  conn.execute("ALTER TABLE accounts ADD COLUMN enabled INTEGER DEFAULT 1")
 
 
 
 
81
  except Exception:
82
- # best-effort; ignore if cannot alter (should not happen for SQLite)
83
  pass
84
  conn.commit()
85
 
@@ -122,6 +125,7 @@ def _parse_allowed_keys_env() -> List[str]:
122
  return keys
123
 
124
  ALLOWED_API_KEYS: List[str] = _parse_allowed_keys_env()
 
125
 
126
  def _extract_bearer(token_header: Optional[str]) -> Optional[str]:
127
  if not token_header:
@@ -257,6 +261,23 @@ def get_account(account_id: str) -> Dict[str, Any]:
257
  raise HTTPException(status_code=404, detail="Account not found")
258
  return _row_to_dict(row)
259
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  # ------------------------------------------------------------------------------
261
  # Dependencies
262
  # ------------------------------------------------------------------------------
@@ -307,7 +328,7 @@ def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] = Depen
307
  model = req.model
308
  do_stream = bool(req.stream)
309
 
310
- def _send_upstream(stream: bool) -> Tuple[Optional[str], Optional[Generator[str, None, None]]]:
311
  access = account.get("accessToken")
312
  if not access:
313
  refreshed = refresh_access_token_in_db(account["id"])
@@ -327,41 +348,54 @@ def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] = Depen
327
  raise
328
 
329
  if not do_stream:
330
- text, _ = _send_upstream(stream=False)
331
- return JSONResponse(content=_openai_non_streaming_response(text or "", model))
 
 
 
 
 
332
  else:
333
  created = int(time.time())
334
  stream_id = f"chatcmpl-{uuid.uuid4()}"
335
  model_used = model or "unknown"
336
 
337
  def event_gen() -> Generator[str, None, None]:
338
- yield _sse_format({
339
- "id": stream_id,
340
- "object": "chat.completion.chunk",
341
- "created": created,
342
- "model": model_used,
343
- "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}],
344
- })
345
- _, it = _send_upstream(stream=True)
346
- assert it is not None
347
- for piece in it:
348
- if not piece:
349
- continue
 
 
 
 
 
 
 
 
 
350
  yield _sse_format({
351
  "id": stream_id,
352
  "object": "chat.completion.chunk",
353
  "created": created,
354
  "model": model_used,
355
- "choices": [{"index": 0, "delta": {"content": piece}, "finish_reason": None}],
356
  })
357
- yield _sse_format({
358
- "id": stream_id,
359
- "object": "chat.completion.chunk",
360
- "created": created,
361
- "model": model_used,
362
- "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
363
- })
364
- yield "data: [DONE]\n\n"
365
 
366
  return StreamingResponse(event_gen(), media_type="text/event-stream")
367
 
 
73
  )
74
  """
75
  )
76
+ # add columns if missing
77
  try:
78
  cols = [row[1] for row in conn.execute("PRAGMA table_info(accounts)").fetchall()]
79
  if "enabled" not in cols:
80
  conn.execute("ALTER TABLE accounts ADD COLUMN enabled INTEGER DEFAULT 1")
81
+ if "error_count" not in cols:
82
+ conn.execute("ALTER TABLE accounts ADD COLUMN error_count INTEGER DEFAULT 0")
83
+ if "success_count" not in cols:
84
+ conn.execute("ALTER TABLE accounts ADD COLUMN success_count INTEGER DEFAULT 0")
85
  except Exception:
 
86
  pass
87
  conn.commit()
88
 
 
125
  return keys
126
 
127
  ALLOWED_API_KEYS: List[str] = _parse_allowed_keys_env()
128
+ MAX_ERROR_COUNT: int = int(os.getenv("MAX_ERROR_COUNT", "100"))
129
 
130
  def _extract_bearer(token_header: Optional[str]) -> Optional[str]:
131
  if not token_header:
 
261
  raise HTTPException(status_code=404, detail="Account not found")
262
  return _row_to_dict(row)
263
 
264
+ def _update_stats(account_id: str, success: bool) -> None:
265
+ with _conn() as conn:
266
+ if success:
267
+ conn.execute("UPDATE accounts SET success_count=success_count+1, error_count=0, updated_at=? WHERE id=?",
268
+ (time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), account_id))
269
+ else:
270
+ row = conn.execute("SELECT error_count FROM accounts WHERE id=?", (account_id,)).fetchone()
271
+ if row:
272
+ new_count = (row[0] or 0) + 1
273
+ if new_count >= MAX_ERROR_COUNT:
274
+ conn.execute("UPDATE accounts SET error_count=?, enabled=0, updated_at=? WHERE id=?",
275
+ (new_count, time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), account_id))
276
+ else:
277
+ conn.execute("UPDATE accounts SET error_count=?, updated_at=? WHERE id=?",
278
+ (new_count, time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), account_id))
279
+ conn.commit()
280
+
281
  # ------------------------------------------------------------------------------
282
  # Dependencies
283
  # ------------------------------------------------------------------------------
 
328
  model = req.model
329
  do_stream = bool(req.stream)
330
 
331
+ def _send_upstream(stream: bool) -> Tuple[Optional[str], Optional[Generator[str, None, None]], Any]:
332
  access = account.get("accessToken")
333
  if not access:
334
  refreshed = refresh_access_token_in_db(account["id"])
 
348
  raise
349
 
350
  if not do_stream:
351
+ try:
352
+ text, _, tracker = _send_upstream(stream=False)
353
+ _update_stats(account["id"], bool(text))
354
+ return JSONResponse(content=_openai_non_streaming_response(text or "", model))
355
+ except Exception as e:
356
+ _update_stats(account["id"], False)
357
+ raise
358
  else:
359
  created = int(time.time())
360
  stream_id = f"chatcmpl-{uuid.uuid4()}"
361
  model_used = model or "unknown"
362
 
363
  def event_gen() -> Generator[str, None, None]:
364
+ tracker = None
365
+ try:
366
+ yield _sse_format({
367
+ "id": stream_id,
368
+ "object": "chat.completion.chunk",
369
+ "created": created,
370
+ "model": model_used,
371
+ "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}],
372
+ })
373
+ _, it, tracker = _send_upstream(stream=True)
374
+ assert it is not None
375
+ for piece in it:
376
+ if not piece:
377
+ continue
378
+ yield _sse_format({
379
+ "id": stream_id,
380
+ "object": "chat.completion.chunk",
381
+ "created": created,
382
+ "model": model_used,
383
+ "choices": [{"index": 0, "delta": {"content": piece}, "finish_reason": None}],
384
+ })
385
  yield _sse_format({
386
  "id": stream_id,
387
  "object": "chat.completion.chunk",
388
  "created": created,
389
  "model": model_used,
390
+ "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
391
  })
392
+ yield "data: [DONE]\n\n"
393
+ if tracker:
394
+ _update_stats(account["id"], tracker.has_content)
395
+ except Exception:
396
+ if tracker:
397
+ _update_stats(account["id"], tracker.has_content)
398
+ raise
 
399
 
400
  return StreamingResponse(event_gen(), media_type="text/event-stream")
401
 
replicate.py CHANGED
@@ -5,6 +5,16 @@ from typing import Dict, Optional, Tuple, Iterator, List, Generator, Any
5
  import struct
6
  import requests
7
 
 
 
 
 
 
 
 
 
 
 
8
  BASE_DIR = Path(__file__).resolve().parent
9
  TEMPLATE_PATH = BASE_DIR / "templates" / "streaming_request.json"
10
 
@@ -175,7 +185,7 @@ def inject_model(body_json: Dict[str, Any], model: Optional[str]) -> None:
175
  except Exception:
176
  pass
177
 
178
- def send_chat_request(access_token: str, messages: List[Dict[str, Any]], model: Optional[str] = None, stream: bool = False, timeout: Tuple[int,int] = (15,300)) -> Tuple[Optional[str], Optional[Generator[str, None, None]]]:
179
  url, headers_from_log, body_json = load_template()
180
  headers_from_log["amz-sdk-invocation-id"] = str(uuid.uuid4())
181
  try:
@@ -196,6 +206,7 @@ def send_chat_request(access_token: str, messages: List[Dict[str, Any]], model:
196
  err = f"HTTP {resp.status_code}"
197
  raise requests.HTTPError(f"Upstream error {resp.status_code}: {err}", response=resp)
198
  parser = AwsEventStreamParser()
 
199
  def _iter_text() -> Generator[str, None, None]:
200
  for chunk in resp.iter_content(chunk_size=None):
201
  if not chunk:
@@ -215,9 +226,9 @@ def send_chat_request(access_token: str, messages: List[Dict[str, Any]], model:
215
  except Exception:
216
  pass
217
  if stream:
218
- return None, _iter_text()
219
  else:
220
  buf = []
221
- for t in _iter_text():
222
  buf.append(t)
223
- return "".join(buf), None
 
5
  import struct
6
  import requests
7
 
8
+ class StreamTracker:
9
+ def __init__(self):
10
+ self.has_content = False
11
+
12
+ def track(self, gen: Generator[str, None, None]) -> Generator[str, None, None]:
13
+ for item in gen:
14
+ if item:
15
+ self.has_content = True
16
+ yield item
17
+
18
  BASE_DIR = Path(__file__).resolve().parent
19
  TEMPLATE_PATH = BASE_DIR / "templates" / "streaming_request.json"
20
 
 
185
  except Exception:
186
  pass
187
 
188
+ def send_chat_request(access_token: str, messages: List[Dict[str, Any]], model: Optional[str] = None, stream: bool = False, timeout: Tuple[int,int] = (15,300)) -> Tuple[Optional[str], Optional[Generator[str, None, None]], bool]:
189
  url, headers_from_log, body_json = load_template()
190
  headers_from_log["amz-sdk-invocation-id"] = str(uuid.uuid4())
191
  try:
 
206
  err = f"HTTP {resp.status_code}"
207
  raise requests.HTTPError(f"Upstream error {resp.status_code}: {err}", response=resp)
208
  parser = AwsEventStreamParser()
209
+ tracker = StreamTracker()
210
  def _iter_text() -> Generator[str, None, None]:
211
  for chunk in resp.iter_content(chunk_size=None):
212
  if not chunk:
 
226
  except Exception:
227
  pass
228
  if stream:
229
+ return None, tracker.track(_iter_text()), tracker
230
  else:
231
  buf = []
232
+ for t in tracker.track(_iter_text()):
233
  buf.append(t)
234
+ return "".join(buf), None, tracker