OwenPowell commited on
Commit
7427c08
·
verified ·
1 Parent(s): fa1a195

Upload 86 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. app/__init__.py +1 -0
  3. app/api/v1/admin.py +1299 -0
  4. app/api/v1/chat.py +251 -0
  5. app/api/v1/files.py +72 -0
  6. app/api/v1/image.py +1065 -0
  7. app/api/v1/models.py +51 -0
  8. app/api/v1/uploads.py +64 -0
  9. app/api/v1/video.py +3 -0
  10. app/core/auth.py +159 -0
  11. app/core/config.py +329 -0
  12. app/core/exceptions.py +221 -0
  13. app/core/legacy_migration.py +285 -0
  14. app/core/logger.py +117 -0
  15. app/core/response_middleware.py +71 -0
  16. app/core/storage.py +720 -0
  17. app/services/api_keys.py +432 -0
  18. app/services/base.py +2 -0
  19. app/services/grok/assets.py +875 -0
  20. app/services/grok/chat.py +571 -0
  21. app/services/grok/imagine_experimental.py +416 -0
  22. app/services/grok/imagine_generation.py +137 -0
  23. app/services/grok/media.py +512 -0
  24. app/services/grok/model.py +226 -0
  25. app/services/grok/processor.py +596 -0
  26. app/services/grok/retry.py +178 -0
  27. app/services/grok/statsig.py +46 -0
  28. app/services/grok/usage.py +162 -0
  29. app/services/quota.py +70 -0
  30. app/services/register/__init__.py +5 -0
  31. app/services/register/account_settings_refresh.py +267 -0
  32. app/services/register/manager.py +332 -0
  33. app/services/register/runner.py +415 -0
  34. app/services/register/services/__init__.py +15 -0
  35. app/services/register/services/birth_date_service.py +97 -0
  36. app/services/register/services/email_service.py +90 -0
  37. app/services/register/services/nsfw_service.py +118 -0
  38. app/services/register/services/turnstile_service.py +161 -0
  39. app/services/register/services/user_agreement_service.py +115 -0
  40. app/services/register/solver.py +296 -0
  41. app/services/request_logger.py +143 -0
  42. app/services/request_stats.py +205 -0
  43. app/services/token/__init__.py +36 -0
  44. app/services/token/manager.py +654 -0
  45. app/services/token/models.py +221 -0
  46. app/services/token/pool.py +112 -0
  47. app/services/token/scheduler.py +104 -0
  48. app/services/token/service.py +156 -0
  49. app/static/.assetsignore +2 -0
  50. app/static/_worker.js +4 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ app/template/favicon.png filter=lfs diff=lfs merge=lfs -text
app/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """App Package"""
app/api/v1/admin.py ADDED
@@ -0,0 +1,1299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, Depends, HTTPException, Request, Query, Body, WebSocket
2
+ from fastapi.responses import HTMLResponse, RedirectResponse
3
+ from pydantic import BaseModel
4
+ from typing import Any, Optional
5
+
6
+ from app.core.auth import verify_api_key
7
+ from app.core.config import config, get_config
8
+ from app.core.storage import get_storage, LocalStorage, RedisStorage, SQLStorage
9
+ import os
10
+ from pathlib import Path
11
+ import aiofiles
12
+ import asyncio
13
+ import json
14
+ import time
15
+ import uuid
16
+ import orjson
17
+ from starlette.websockets import WebSocketDisconnect, WebSocketState
18
+ from app.core.logger import logger
19
+ from app.services.register import get_auto_register_manager
20
+ from app.services.register.account_settings_refresh import (
21
+ refresh_account_settings_for_tokens,
22
+ normalize_sso_token as normalize_refresh_token,
23
+ )
24
+ from app.services.api_keys import api_key_manager
25
+ from app.services.grok.model import ModelService
26
+ from app.services.grok.imagine_generation import (
27
+ collect_experimental_generation_images,
28
+ is_valid_image_value as is_valid_imagine_image_value,
29
+ resolve_aspect_ratio as resolve_imagine_aspect_ratio,
30
+ )
31
+ from app.services.token import get_token_manager
32
+ from app.core.auth import _load_legacy_api_keys
33
+
34
+
35
+ router = APIRouter()
36
+
37
+ TEMPLATE_DIR = Path(__file__).parent.parent.parent / "static"
38
+
39
+
40
+ class AdminLoginBody(BaseModel):
41
+ username: str | None = None
42
+ password: str | None = None
43
+
44
+ async def render_template(filename: str):
45
+ """渲染指定模板"""
46
+ template_path = TEMPLATE_DIR / filename
47
+ if not template_path.exists():
48
+ return HTMLResponse(f"Template {filename} not found.", status_code=404)
49
+
50
+ async with aiofiles.open(template_path, "r", encoding="utf-8") as f:
51
+ content = await f.read()
52
+ return HTMLResponse(content)
53
+
54
+ @router.get("/", include_in_schema=False)
55
+ async def root_redirect():
56
+ """Default entry -> /login (consistent with Workers/Pages)."""
57
+ return RedirectResponse(url="/login", status_code=302)
58
+
59
+
60
+ @router.get("/login", response_class=HTMLResponse, include_in_schema=False)
61
+ async def login_page():
62
+ """Login page (default)."""
63
+ return await render_template("login/login.html")
64
+
65
+
66
+ @router.get("/admin", response_class=HTMLResponse, include_in_schema=False)
67
+ async def admin_login_page():
68
+ """Legacy login entry (redirect to /login)."""
69
+ return RedirectResponse(url="/login", status_code=302)
70
+
71
+ @router.get("/admin/config", response_class=HTMLResponse, include_in_schema=False)
72
+ async def admin_config_page():
73
+ """配置管理页"""
74
+ return await render_template("config/config.html")
75
+
76
+ @router.get("/admin/token", response_class=HTMLResponse, include_in_schema=False)
77
+ async def admin_token_page():
78
+ """Token 管理页"""
79
+ return await render_template("token/token.html")
80
+
81
+ @router.get("/admin/datacenter", response_class=HTMLResponse, include_in_schema=False)
82
+ async def admin_datacenter_page():
83
+ """数据中心页"""
84
+ return await render_template("datacenter/datacenter.html")
85
+
86
+ @router.get("/admin/keys", response_class=HTMLResponse, include_in_schema=False)
87
+ async def admin_keys_page():
88
+ """API Key 管理页"""
89
+ return await render_template("keys/keys.html")
90
+
91
+ @router.get("/chat", response_class=HTMLResponse, include_in_schema=False)
92
+ async def chat_page():
93
+ """在线聊天页(公开入口)"""
94
+ return await render_template("chat/chat.html")
95
+
96
+ @router.get("/admin/chat", response_class=HTMLResponse, include_in_schema=False)
97
+ async def admin_chat_page():
98
+ """在线聊天页(后台入口)"""
99
+ return await render_template("chat/chat_admin.html")
100
+
101
+
102
+ async def _verify_ws_api_key(websocket: WebSocket) -> bool:
103
+ api_key = str(get_config("app.api_key", "") or "").strip()
104
+ legacy_keys = await _load_legacy_api_keys()
105
+ if not api_key and not legacy_keys:
106
+ return True
107
+ token = str(websocket.query_params.get("api_key") or "").strip()
108
+ if not token:
109
+ return False
110
+ if (api_key and token == api_key) or token in legacy_keys:
111
+ return True
112
+ try:
113
+ await api_key_manager.init()
114
+ if api_key_manager.validate_key(token):
115
+ return True
116
+ except Exception as e:
117
+ logger.warning(f"Imagine ws api_key validation fallback failed: {e}")
118
+ return False
119
+
120
+
121
+ async def _collect_imagine_batch(token: str, prompt: str, aspect_ratio: str) -> list[str]:
122
+ return await collect_experimental_generation_images(
123
+ token=token,
124
+ prompt=prompt,
125
+ n=6,
126
+ response_format="b64_json",
127
+ aspect_ratio=aspect_ratio,
128
+ concurrency=1,
129
+ )
130
+
131
+
132
+ @router.websocket("/api/v1/admin/imagine/ws")
133
+ async def admin_imagine_ws(websocket: WebSocket):
134
+ if not await _verify_ws_api_key(websocket):
135
+ await websocket.close(code=1008)
136
+ return
137
+
138
+ await websocket.accept()
139
+ stop_event = asyncio.Event()
140
+ run_task: Optional[asyncio.Task] = None
141
+
142
+ async def _send(payload: dict) -> bool:
143
+ try:
144
+ await websocket.send_text(orjson.dumps(payload).decode())
145
+ return True
146
+ except Exception:
147
+ return False
148
+
149
+ async def _stop_run():
150
+ nonlocal run_task
151
+ stop_event.set()
152
+ if run_task and not run_task.done():
153
+ run_task.cancel()
154
+ try:
155
+ await run_task
156
+ except asyncio.CancelledError:
157
+ pass
158
+ except Exception:
159
+ pass
160
+ run_task = None
161
+ stop_event.clear()
162
+
163
+ async def _run(prompt: str, aspect_ratio: str):
164
+ model_id = "grok-imagine-1.0"
165
+ model_info = ModelService.get(model_id)
166
+ if not model_info or not model_info.is_image:
167
+ await _send(
168
+ {
169
+ "type": "error",
170
+ "message": "Image model is not available.",
171
+ "code": "model_not_supported",
172
+ }
173
+ )
174
+ return
175
+
176
+ token_mgr = await get_token_manager()
177
+ sequence = 0
178
+ run_id = uuid.uuid4().hex
179
+ await _send(
180
+ {
181
+ "type": "status",
182
+ "status": "running",
183
+ "prompt": prompt,
184
+ "aspect_ratio": aspect_ratio,
185
+ "run_id": run_id,
186
+ }
187
+ )
188
+
189
+ while not stop_event.is_set():
190
+ try:
191
+ await token_mgr.reload_if_stale()
192
+ token = token_mgr.get_token_for_model(model_info.model_id)
193
+ if not token:
194
+ await _send(
195
+ {
196
+ "type": "error",
197
+ "message": "No available tokens. Please try again later.",
198
+ "code": "rate_limit_exceeded",
199
+ }
200
+ )
201
+ await asyncio.sleep(2)
202
+ continue
203
+
204
+ start_at = time.time()
205
+ images = await _collect_imagine_batch(token, prompt, aspect_ratio)
206
+ elapsed_ms = int((time.time() - start_at) * 1000)
207
+
208
+ sent_any = False
209
+ for image_b64 in images:
210
+ if not is_valid_imagine_image_value(image_b64):
211
+ continue
212
+ sent_any = True
213
+ sequence += 1
214
+ ok = await _send(
215
+ {
216
+ "type": "image",
217
+ "b64_json": image_b64,
218
+ "sequence": sequence,
219
+ "created_at": int(time.time() * 1000),
220
+ "elapsed_ms": elapsed_ms,
221
+ "aspect_ratio": aspect_ratio,
222
+ "run_id": run_id,
223
+ }
224
+ )
225
+ if not ok:
226
+ stop_event.set()
227
+ break
228
+
229
+ if sent_any:
230
+ try:
231
+ await token_mgr.sync_usage(
232
+ token,
233
+ model_info.model_id,
234
+ consume_on_fail=True,
235
+ is_usage=True,
236
+ )
237
+ except Exception as e:
238
+ logger.warning(f"Imagine ws token sync failed: {e}")
239
+ else:
240
+ await _send(
241
+ {
242
+ "type": "error",
243
+ "message": "Image generation returned empty data.",
244
+ "code": "empty_image",
245
+ }
246
+ )
247
+ except asyncio.CancelledError:
248
+ break
249
+ except Exception as e:
250
+ logger.warning(f"Imagine stream error: {e}")
251
+ await _send(
252
+ {
253
+ "type": "error",
254
+ "message": str(e),
255
+ "code": "internal_error",
256
+ }
257
+ )
258
+ await asyncio.sleep(1.5)
259
+
260
+ await _send({"type": "status", "status": "stopped", "run_id": run_id})
261
+
262
+ try:
263
+ while True:
264
+ try:
265
+ raw = await websocket.receive_text()
266
+ except (RuntimeError, WebSocketDisconnect):
267
+ break
268
+
269
+ try:
270
+ payload = orjson.loads(raw)
271
+ except Exception:
272
+ await _send(
273
+ {
274
+ "type": "error",
275
+ "message": "Invalid message format.",
276
+ "code": "invalid_payload",
277
+ }
278
+ )
279
+ continue
280
+
281
+ msg_type = payload.get("type")
282
+ if msg_type == "start":
283
+ prompt = str(payload.get("prompt") or "").strip()
284
+ if not prompt:
285
+ await _send(
286
+ {
287
+ "type": "error",
288
+ "message": "Prompt cannot be empty.",
289
+ "code": "empty_prompt",
290
+ }
291
+ )
292
+ continue
293
+ ratio = resolve_imagine_aspect_ratio(str(payload.get("aspect_ratio") or "2:3").strip())
294
+ await _stop_run()
295
+ run_task = asyncio.create_task(_run(prompt, ratio))
296
+ elif msg_type == "stop":
297
+ await _stop_run()
298
+ elif msg_type == "ping":
299
+ await _send({"type": "pong"})
300
+ else:
301
+ await _send(
302
+ {
303
+ "type": "error",
304
+ "message": "Unknown command.",
305
+ "code": "unknown_command",
306
+ }
307
+ )
308
+ except WebSocketDisconnect:
309
+ logger.debug("WebSocket disconnected by client")
310
+ except asyncio.CancelledError:
311
+ logger.debug("WebSocket handler cancelled")
312
+ except Exception as e:
313
+ logger.warning(f"WebSocket error: {e}")
314
+ finally:
315
+ await _stop_run()
316
+ try:
317
+ if websocket.client_state == WebSocketState.CONNECTED:
318
+ await websocket.close(code=1000, reason="Server closing connection")
319
+ except Exception as e:
320
+ logger.debug(f"WebSocket close ignored: {e}")
321
+
322
+
323
+ @router.post("/api/v1/admin/login")
324
+ async def admin_login_api(request: Request, body: AdminLoginBody | None = Body(default=None)):
325
+ """管理后台登录验证(用户名+密码)
326
+
327
+ - 默认账号/密码:admin/admin(可在配置管理的「应用设置」里修改)
328
+ - 兼容旧版本:允许 Authorization: Bearer <password> 仅密码登录(用户名默认为 admin)
329
+ """
330
+
331
+ admin_username = str(get_config("app.admin_username", "admin") or "admin").strip() or "admin"
332
+ admin_password = str(get_config("app.app_key", "admin") or "admin").strip()
333
+
334
+ username = (body.username.strip() if body and isinstance(body.username, str) else "").strip()
335
+ password = (body.password.strip() if body and isinstance(body.password, str) else "").strip()
336
+
337
+ # Legacy: password-only via Bearer token.
338
+ if not password:
339
+ auth = request.headers.get("Authorization") or ""
340
+ if auth.lower().startswith("bearer "):
341
+ password = auth[7:].strip()
342
+ if not username:
343
+ username = "admin"
344
+
345
+ if not username or not password:
346
+ raise HTTPException(status_code=400, detail="Missing username or password")
347
+
348
+ if username != admin_username or password != admin_password:
349
+ raise HTTPException(status_code=401, detail="Invalid username or password")
350
+
351
+ return {"status": "success", "api_key": get_config("app.api_key", "")}
352
+
353
+ @router.get("/api/v1/admin/config", dependencies=[Depends(verify_api_key)])
354
+ async def get_config_api():
355
+ """获取当前配置"""
356
+ # 暴露原始配置字典
357
+ return config._config
358
+
359
+ @router.post("/api/v1/admin/config", dependencies=[Depends(verify_api_key)])
360
+ async def update_config_api(data: dict):
361
+ """更新配置"""
362
+ try:
363
+ await config.update(data)
364
+ return {"status": "success", "message": "配置已更新"}
365
+ except Exception as e:
366
+ raise HTTPException(status_code=500, detail=str(e))
367
+
368
+
369
+ def _display_key(key: str) -> str:
370
+ k = str(key or "")
371
+ if len(k) <= 12:
372
+ return k
373
+ return f"{k[:6]}...{k[-4:]}"
374
+
375
+
376
+ def _normalize_limit(v: Any) -> int:
377
+ if v is None or v == "":
378
+ return -1
379
+ try:
380
+ return max(-1, int(v))
381
+ except Exception:
382
+ return -1
383
+
384
+
385
+ def _pool_to_token_type(pool_name: str) -> str:
386
+ return "ssoSuper" if str(pool_name or "").strip() == "ssoSuper" else "sso"
387
+
388
+
389
+ def _parse_quota_value(v: Any) -> tuple[int, bool]:
390
+ if v is None or v == "":
391
+ return -1, False
392
+ try:
393
+ n = int(v)
394
+ except Exception:
395
+ return -1, False
396
+ if n < 0:
397
+ return -1, False
398
+ return n, True
399
+
400
+
401
+ def _safe_int(v: Any, default: int = 0) -> int:
402
+ try:
403
+ return int(v)
404
+ except Exception:
405
+ return default
406
+
407
+
408
+ def _normalize_token_status(raw_status: Any) -> str:
409
+ s = str(raw_status or "active").strip().lower()
410
+ if s == "expired":
411
+ return "invalid"
412
+ if s in ("active", "cooling", "invalid", "disabled"):
413
+ return s
414
+ return "active"
415
+
416
+
417
+ def _normalize_admin_token_item(pool_name: str, item: Any) -> dict | None:
418
+ token_type = _pool_to_token_type(pool_name)
419
+
420
+ if isinstance(item, str):
421
+ token = item.strip()
422
+ if not token:
423
+ return None
424
+ if token.startswith("sso="):
425
+ token = token[4:]
426
+ return {
427
+ "token": token,
428
+ "status": "active",
429
+ "quota": 0,
430
+ "quota_known": False,
431
+ "heavy_quota": -1,
432
+ "heavy_quota_known": False,
433
+ "token_type": token_type,
434
+ "note": "",
435
+ "fail_count": 0,
436
+ "use_count": 0,
437
+ }
438
+
439
+ if not isinstance(item, dict):
440
+ return None
441
+
442
+ token = str(item.get("token") or "").strip()
443
+ if not token:
444
+ return None
445
+ if token.startswith("sso="):
446
+ token = token[4:]
447
+
448
+ quota, quota_known = _parse_quota_value(item.get("quota"))
449
+ heavy_quota, heavy_quota_known = _parse_quota_value(item.get("heavy_quota"))
450
+
451
+ return {
452
+ "token": token,
453
+ "status": _normalize_token_status(item.get("status")),
454
+ "quota": quota if quota_known else 0,
455
+ "quota_known": quota_known,
456
+ "heavy_quota": heavy_quota,
457
+ "heavy_quota_known": heavy_quota_known,
458
+ "token_type": token_type,
459
+ "note": str(item.get("note") or ""),
460
+ "fail_count": _safe_int(item.get("fail_count") or 0, 0),
461
+ "use_count": _safe_int(item.get("use_count") or 0, 0),
462
+ }
463
+
464
+
465
+ def _collect_tokens_from_pool_payload(payload: Any) -> list[str]:
466
+ if not isinstance(payload, dict):
467
+ return []
468
+
469
+ collected: list[str] = []
470
+ seen: set[str] = set()
471
+ for raw_items in payload.values():
472
+ if not isinstance(raw_items, list):
473
+ continue
474
+ for item in raw_items:
475
+ token_raw = item if isinstance(item, str) else (item.get("token") if isinstance(item, dict) else "")
476
+ token = normalize_refresh_token(str(token_raw or "").strip())
477
+ if not token or token in seen:
478
+ continue
479
+ seen.add(token)
480
+ collected.append(token)
481
+ return collected
482
+
483
+
484
+ def _resolve_nsfw_refresh_concurrency(override: Any = None) -> int:
485
+ source = override if override is not None else get_config("token.nsfw_refresh_concurrency", 10)
486
+ try:
487
+ value = int(source)
488
+ except Exception:
489
+ value = 10
490
+ return max(1, value)
491
+
492
+
493
+ def _resolve_nsfw_refresh_retries(override: Any = None) -> int:
494
+ source = override if override is not None else get_config("token.nsfw_refresh_retries", 3)
495
+ try:
496
+ value = int(source)
497
+ except Exception:
498
+ value = 3
499
+ return max(0, value)
500
+
501
+
502
+ def _trigger_account_settings_refresh_background(
503
+ tokens: list[str],
504
+ concurrency: int,
505
+ retries: int,
506
+ ) -> None:
507
+ if not tokens:
508
+ return
509
+
510
+ async def _run() -> None:
511
+ try:
512
+ result = await refresh_account_settings_for_tokens(
513
+ tokens=tokens,
514
+ concurrency=concurrency,
515
+ retries=retries,
516
+ )
517
+ summary = result.get("summary") or {}
518
+ logger.info(
519
+ "Background account-settings refresh finished: total={} success={} failed={} invalidated={}",
520
+ summary.get("total", 0),
521
+ summary.get("success", 0),
522
+ summary.get("failed", 0),
523
+ summary.get("invalidated", 0),
524
+ )
525
+ except Exception as exc:
526
+ logger.warning("Background account-settings refresh failed: {}", exc)
527
+
528
+ asyncio.create_task(_run())
529
+
530
+
531
+ @router.get("/api/v1/admin/keys", dependencies=[Depends(verify_api_key)])
532
+ async def list_api_keys():
533
+ """List API keys + daily usage/remaining (for admin UI)."""
534
+ await api_key_manager.init()
535
+ day, usage_map = await api_key_manager.usage_today()
536
+
537
+ out = []
538
+ for row in api_key_manager.get_all_keys():
539
+ key = str(row.get("key") or "")
540
+ used = usage_map.get(key) or {}
541
+ chat_used = int(used.get("chat_used", 0) or 0)
542
+ heavy_used = int(used.get("heavy_used", 0) or 0)
543
+ image_used = int(used.get("image_used", 0) or 0)
544
+ video_used = int(used.get("video_used", 0) or 0)
545
+
546
+ chat_limit = _normalize_limit(row.get("chat_limit", -1))
547
+ heavy_limit = _normalize_limit(row.get("heavy_limit", -1))
548
+ image_limit = _normalize_limit(row.get("image_limit", -1))
549
+ video_limit = _normalize_limit(row.get("video_limit", -1))
550
+
551
+ remaining = {
552
+ "chat": None if chat_limit < 0 else max(0, chat_limit - chat_used),
553
+ "heavy": None if heavy_limit < 0 else max(0, heavy_limit - heavy_used),
554
+ "image": None if image_limit < 0 else max(0, image_limit - image_used),
555
+ "video": None if video_limit < 0 else max(0, video_limit - video_used),
556
+ }
557
+
558
+ out.append({
559
+ **row,
560
+ "is_active": bool(row.get("is_active", True)),
561
+ "display_key": _display_key(key),
562
+ "usage_today": {
563
+ "chat_used": chat_used,
564
+ "heavy_used": heavy_used,
565
+ "image_used": image_used,
566
+ "video_used": video_used,
567
+ },
568
+ "remaining_today": remaining,
569
+ "day": day,
570
+ })
571
+
572
+ # New UI expects { success: true, data: [...] }
573
+ return {"success": True, "data": out}
574
+
575
+
576
+ @router.post("/api/v1/admin/keys", dependencies=[Depends(verify_api_key)])
577
+ async def create_api_key(data: dict):
578
+ """Create a new API key (optional name/key/limits)."""
579
+ await api_key_manager.init()
580
+ data = data or {}
581
+
582
+ name = str(data.get("name") or "").strip() or api_key_manager.generate_name()
583
+ key_val = str(data.get("key") or "").strip() or None
584
+ is_active = bool(data.get("is_active", True))
585
+
586
+ limits = data.get("limits") if isinstance(data.get("limits"), dict) else {}
587
+ try:
588
+ row = await api_key_manager.add_key(
589
+ name=name,
590
+ key=key_val,
591
+ is_active=is_active,
592
+ limits={
593
+ "chat_per_day": limits.get("chat_per_day"),
594
+ "heavy_per_day": limits.get("heavy_per_day"),
595
+ "image_per_day": limits.get("image_per_day"),
596
+ "video_per_day": limits.get("video_per_day"),
597
+ },
598
+ )
599
+ except ValueError as e:
600
+ raise HTTPException(status_code=400, detail=str(e))
601
+
602
+ return {"success": True, "data": {**row, "display_key": _display_key(row.get("key", ""))}}
603
+
604
+
605
+ @router.post("/api/v1/admin/keys/update", dependencies=[Depends(verify_api_key)])
606
+ async def update_api_key(data: dict):
607
+ """Update name/status/limits for an API key."""
608
+ await api_key_manager.init()
609
+ data = data or {}
610
+ key = str(data.get("key") or "").strip()
611
+ if not key:
612
+ raise HTTPException(status_code=400, detail="Missing key")
613
+
614
+ existing = api_key_manager.get_key_row(key)
615
+ if not existing:
616
+ raise HTTPException(status_code=404, detail="Key not found")
617
+
618
+ if "name" in data and data.get("name") is not None:
619
+ name = str(data.get("name") or "").strip()
620
+ if name:
621
+ await api_key_manager.update_key_name(key, name)
622
+
623
+ if "is_active" in data:
624
+ await api_key_manager.update_key_status(key, bool(data.get("is_active")))
625
+
626
+ limits = data.get("limits") if isinstance(data.get("limits"), dict) else None
627
+ if limits is not None:
628
+ await api_key_manager.update_key_limits(
629
+ key,
630
+ {
631
+ "chat_per_day": limits.get("chat_per_day"),
632
+ "heavy_per_day": limits.get("heavy_per_day"),
633
+ "image_per_day": limits.get("image_per_day"),
634
+ "video_per_day": limits.get("video_per_day"),
635
+ },
636
+ )
637
+
638
+ return {"success": True}
639
+
640
+
641
+ @router.post("/api/v1/admin/keys/delete", dependencies=[Depends(verify_api_key)])
642
+ async def delete_api_key(data: dict):
643
+ """Delete an API key."""
644
+ await api_key_manager.init()
645
+ data = data or {}
646
+ key = str(data.get("key") or "").strip()
647
+ if not key:
648
+ raise HTTPException(status_code=400, detail="Missing key")
649
+
650
+ ok = await api_key_manager.delete_key(key)
651
+ if not ok:
652
+ raise HTTPException(status_code=404, detail="Key not found")
653
+ return {"success": True}
654
+
655
+ @router.get("/api/v1/admin/storage", dependencies=[Depends(verify_api_key)])
656
+ async def get_storage_info():
657
+ """获取当前存储模式"""
658
+ storage_type = os.getenv("SERVER_STORAGE_TYPE", "local").lower()
659
+ logger.info(f"Storage type: {storage_type}")
660
+ if not storage_type:
661
+ storage_type = str(get_config("storage.type", "")).lower()
662
+ if not storage_type:
663
+ storage = get_storage()
664
+ if isinstance(storage, LocalStorage):
665
+ storage_type = "local"
666
+ elif isinstance(storage, RedisStorage):
667
+ storage_type = "redis"
668
+ elif isinstance(storage, SQLStorage):
669
+ if storage.dialect in ("mysql", "mariadb"):
670
+ storage_type = "mysql"
671
+ elif storage.dialect in ("postgres", "postgresql", "pgsql"):
672
+ storage_type = "pgsql"
673
+ else:
674
+ storage_type = storage.dialect
675
+ return {"type": storage_type or "local"}
676
+
677
+ @router.get("/api/v1/admin/tokens", dependencies=[Depends(verify_api_key)])
678
+ async def get_tokens_api():
679
+ """获取所有 Token"""
680
+ storage = get_storage()
681
+ tokens = await storage.load_tokens()
682
+ data = tokens if isinstance(tokens, dict) else {}
683
+ out: dict[str, list[dict]] = {}
684
+ for pool_name, raw_items in data.items():
685
+ arr = raw_items if isinstance(raw_items, list) else []
686
+ normalized: list[dict] = []
687
+ for item in arr:
688
+ obj = _normalize_admin_token_item(pool_name, item)
689
+ if obj:
690
+ normalized.append(obj)
691
+ out[str(pool_name)] = normalized
692
+ return out
693
+
694
+ @router.post("/api/v1/admin/tokens", dependencies=[Depends(verify_api_key)])
695
+ async def update_tokens_api(data: dict):
696
+ """Update token payload and trigger background account-settings refresh for new tokens."""
697
+ storage = get_storage()
698
+ try:
699
+ from app.services.token.manager import get_token_manager
700
+
701
+ posted_data = data if isinstance(data, dict) else {}
702
+ existing_tokens: list[str] = []
703
+ added_tokens: list[str] = []
704
+
705
+ async with storage.acquire_lock("tokens_save", timeout=10):
706
+ old_data = await storage.load_tokens()
707
+ existing_tokens = _collect_tokens_from_pool_payload(
708
+ old_data if isinstance(old_data, dict) else {}
709
+ )
710
+
711
+ await storage.save_tokens(posted_data)
712
+ mgr = await get_token_manager()
713
+ await mgr.reload()
714
+
715
+ new_tokens = _collect_tokens_from_pool_payload(posted_data)
716
+ existing_set = set(existing_tokens)
717
+ added_tokens = [token for token in new_tokens if token not in existing_set]
718
+
719
+ concurrency = _resolve_nsfw_refresh_concurrency()
720
+ retries = _resolve_nsfw_refresh_retries()
721
+ _trigger_account_settings_refresh_background(
722
+ tokens=added_tokens,
723
+ concurrency=concurrency,
724
+ retries=retries,
725
+ )
726
+
727
+ return {
728
+ "status": "success",
729
+ "message": "Token updated",
730
+ "nsfw_refresh": {
731
+ "mode": "background",
732
+ "triggered": len(added_tokens),
733
+ "concurrency": concurrency,
734
+ "retries": retries,
735
+ },
736
+ }
737
+ except Exception as e:
738
+ raise HTTPException(status_code=500, detail=str(e))
739
+
740
+ @router.post("/api/v1/admin/tokens/refresh", dependencies=[Depends(verify_api_key)])
741
+ async def refresh_tokens_api(data: dict):
742
+ """刷新 Token 状态"""
743
+ from app.services.token.manager import get_token_manager
744
+
745
+ try:
746
+ mgr = await get_token_manager()
747
+ tokens = []
748
+ if "token" in data:
749
+ tokens.append(data["token"])
750
+ if "tokens" in data and isinstance(data["tokens"], list):
751
+ tokens.extend(data["tokens"])
752
+
753
+ if not tokens:
754
+ raise HTTPException(status_code=400, detail="No tokens provided")
755
+
756
+ unique_tokens = list(set(tokens))
757
+
758
+ sem = asyncio.Semaphore(10)
759
+
760
+ async def _refresh_one(t):
761
+ async with sem:
762
+ return t, await mgr.sync_usage(t, "grok-3", consume_on_fail=False, is_usage=False)
763
+
764
+ results_list = await asyncio.gather(*[_refresh_one(t) for t in unique_tokens])
765
+ results = dict(results_list)
766
+
767
+ return {"status": "success", "results": results}
768
+ except Exception as e:
769
+ raise HTTPException(status_code=500, detail=str(e))
770
+
771
+
772
+ @router.post("/api/v1/admin/tokens/nsfw/refresh", dependencies=[Depends(verify_api_key)])
773
+ async def refresh_tokens_nsfw_api(data: dict):
774
+ """Refresh account settings (TOS + birth date + NSFW) for selected/all tokens."""
775
+ payload = data if isinstance(data, dict) else {}
776
+ mgr = await get_token_manager()
777
+
778
+ tokens: list[str] = []
779
+ seen: set[str] = set()
780
+
781
+ if bool(payload.get("all")):
782
+ for pool in mgr.pools.values():
783
+ for info in pool.list():
784
+ token = normalize_refresh_token(str(info.token or "").strip())
785
+ if not token or token in seen:
786
+ continue
787
+ seen.add(token)
788
+ tokens.append(token)
789
+ else:
790
+ candidates: list[str] = []
791
+ single = payload.get("token")
792
+ if isinstance(single, str):
793
+ candidates.append(single)
794
+ batch = payload.get("tokens")
795
+ if isinstance(batch, list):
796
+ candidates.extend([item for item in batch if isinstance(item, str)])
797
+
798
+ for raw in candidates:
799
+ token = normalize_refresh_token(str(raw or "").strip())
800
+ if not token or token in seen:
801
+ continue
802
+ seen.add(token)
803
+ tokens.append(token)
804
+
805
+ if not tokens:
806
+ raise HTTPException(status_code=400, detail="No tokens provided")
807
+
808
+ concurrency = _resolve_nsfw_refresh_concurrency(payload.get("concurrency"))
809
+ retries = _resolve_nsfw_refresh_retries(payload.get("retries"))
810
+ result = await refresh_account_settings_for_tokens(
811
+ tokens=tokens,
812
+ concurrency=concurrency,
813
+ retries=retries,
814
+ )
815
+ return {
816
+ "status": "success",
817
+ "summary": result.get("summary") or {},
818
+ "failed": result.get("failed") or [],
819
+ }
820
+
821
+
822
+ @router.post("/api/v1/admin/tokens/auto-register", dependencies=[Depends(verify_api_key)])
823
+ async def auto_register_tokens_api(data: dict):
824
+ """Start auto registration."""
825
+ try:
826
+ data = data or {}
827
+ count = data.get("count")
828
+ concurrency = data.get("concurrency")
829
+ pool = (data.get("pool") or "ssoBasic").strip() or "ssoBasic"
830
+
831
+ try:
832
+ count_val = int(count)
833
+ except Exception:
834
+ count_val = int(get_config("register.default_count", 100) or 100)
835
+
836
+ if count_val <= 0:
837
+ count_val = int(get_config("register.default_count", 100) or 100)
838
+
839
+ try:
840
+ concurrency_val = int(concurrency)
841
+ except Exception:
842
+ concurrency_val = None
843
+ if concurrency_val is not None and concurrency_val <= 0:
844
+ concurrency_val = None
845
+
846
+ manager = get_auto_register_manager()
847
+ job = await manager.start_job(count=count_val, pool=pool, concurrency=concurrency_val)
848
+ return {"status": "started", "job": job.to_dict()}
849
+ except RuntimeError as e:
850
+ raise HTTPException(status_code=409, detail=str(e))
851
+ except Exception as e:
852
+ raise HTTPException(status_code=500, detail=str(e))
853
+
854
+
855
+ @router.get("/api/v1/admin/tokens/auto-register/status", dependencies=[Depends(verify_api_key)])
856
+ async def auto_register_status_api(job_id: str | None = None):
857
+ """Get auto registration status."""
858
+ manager = get_auto_register_manager()
859
+ status = manager.get_status(job_id)
860
+ if status.get("status") == "not_found":
861
+ raise HTTPException(status_code=404, detail="Job not found")
862
+ return status
863
+
864
+
865
+ @router.post("/api/v1/admin/tokens/auto-register/stop", dependencies=[Depends(verify_api_key)])
866
+ async def auto_register_stop_api(job_id: str | None = None):
867
+ """Stop auto registration (best-effort)."""
868
+ manager = get_auto_register_manager()
869
+ status = manager.get_status(job_id)
870
+ if status.get("status") == "not_found":
871
+ raise HTTPException(status_code=404, detail="Job not found")
872
+ await manager.stop_job()
873
+ return {"status": "stopping"}
874
+
875
+ @router.get("/admin/cache", response_class=HTMLResponse, include_in_schema=False)
876
+ async def admin_cache_page():
877
+ """缓存管理页"""
878
+ return await render_template("cache/cache.html")
879
+
880
+ @router.get("/api/v1/admin/cache", dependencies=[Depends(verify_api_key)])
881
+ async def get_cache_stats_api(request: Request):
882
+ """获取缓存统计"""
883
+ from app.services.grok.assets import DownloadService, ListService
884
+ from app.services.token.manager import get_token_manager
885
+
886
+ try:
887
+ dl_service = DownloadService()
888
+ image_stats = dl_service.get_stats("image")
889
+ video_stats = dl_service.get_stats("video")
890
+
891
+ mgr = await get_token_manager()
892
+ pools = mgr.pools
893
+ accounts = []
894
+ for pool_name, pool in pools.items():
895
+ for info in pool.list():
896
+ raw_token = info.token[4:] if info.token.startswith("sso=") else info.token
897
+ masked = f"{raw_token[:8]}...{raw_token[-16:]}" if len(raw_token) > 24 else raw_token
898
+ accounts.append({
899
+ "token": raw_token,
900
+ "token_masked": masked,
901
+ "pool": pool_name,
902
+ "status": info.status,
903
+ "last_asset_clear_at": info.last_asset_clear_at
904
+ })
905
+
906
+ scope = request.query_params.get("scope")
907
+ selected_token = request.query_params.get("token")
908
+ tokens_param = request.query_params.get("tokens")
909
+ selected_tokens = []
910
+ if tokens_param:
911
+ selected_tokens = [t.strip() for t in tokens_param.split(",") if t.strip()]
912
+
913
+ online_stats = {"count": 0, "status": "unknown", "token": None, "last_asset_clear_at": None}
914
+ online_details = []
915
+ account_map = {a["token"]: a for a in accounts}
916
+ batch_size = get_config("performance.admin_assets_batch_size", 10)
917
+ try:
918
+ batch_size = int(batch_size)
919
+ except Exception:
920
+ batch_size = 10
921
+ batch_size = max(1, batch_size)
922
+
923
+ async def _fetch_assets(token: str):
924
+ list_service = ListService()
925
+ try:
926
+ return await list_service.count(token)
927
+ finally:
928
+ await list_service.close()
929
+
930
+ async def _fetch_detail(token: str):
931
+ account = account_map.get(token)
932
+ try:
933
+ count = await _fetch_assets(token)
934
+ return ({
935
+ "token": token,
936
+ "token_masked": account["token_masked"] if account else token,
937
+ "count": count,
938
+ "status": "ok",
939
+ "last_asset_clear_at": account["last_asset_clear_at"] if account else None
940
+ }, count)
941
+ except Exception as e:
942
+ return ({
943
+ "token": token,
944
+ "token_masked": account["token_masked"] if account else token,
945
+ "count": 0,
946
+ "status": f"error: {str(e)}",
947
+ "last_asset_clear_at": account["last_asset_clear_at"] if account else None
948
+ }, 0)
949
+
950
+ if selected_tokens:
951
+ total = 0
952
+ for i in range(0, len(selected_tokens), batch_size):
953
+ chunk = selected_tokens[i:i + batch_size]
954
+ results = await asyncio.gather(*[_fetch_detail(token) for token in chunk])
955
+ for detail, count in results:
956
+ online_details.append(detail)
957
+ total += count
958
+ online_stats = {"count": total, "status": "ok" if selected_tokens else "no_token", "token": None, "last_asset_clear_at": None}
959
+ scope = "selected"
960
+ elif scope == "all":
961
+ total = 0
962
+ tokens = [account["token"] for account in accounts]
963
+ for i in range(0, len(tokens), batch_size):
964
+ chunk = tokens[i:i + batch_size]
965
+ results = await asyncio.gather(*[_fetch_detail(token) for token in chunk])
966
+ for detail, count in results:
967
+ online_details.append(detail)
968
+ total += count
969
+ online_stats = {"count": total, "status": "ok" if accounts else "no_token", "token": None, "last_asset_clear_at": None}
970
+ else:
971
+ token = selected_token
972
+ if token:
973
+ try:
974
+ count = await _fetch_assets(token)
975
+ match = next((a for a in accounts if a["token"] == token), None)
976
+ online_stats = {
977
+ "count": count,
978
+ "status": "ok",
979
+ "token": token,
980
+ "token_masked": match["token_masked"] if match else token,
981
+ "last_asset_clear_at": match["last_asset_clear_at"] if match else None
982
+ }
983
+ except Exception as e:
984
+ match = next((a for a in accounts if a["token"] == token), None)
985
+ online_stats = {
986
+ "count": 0,
987
+ "status": f"error: {str(e)}",
988
+ "token": token,
989
+ "token_masked": match["token_masked"] if match else token,
990
+ "last_asset_clear_at": match["last_asset_clear_at"] if match else None
991
+ }
992
+ else:
993
+ online_stats = {"count": 0, "status": "not_loaded", "token": None, "last_asset_clear_at": None}
994
+
995
+ return {
996
+ "local_image": image_stats,
997
+ "local_video": video_stats,
998
+ "online": online_stats,
999
+ "online_accounts": accounts,
1000
+ "online_scope": scope or "none",
1001
+ "online_details": online_details
1002
+ }
1003
+ except Exception as e:
1004
+ raise HTTPException(status_code=500, detail=str(e))
1005
+
1006
+ @router.post("/api/v1/admin/cache/clear", dependencies=[Depends(verify_api_key)])
1007
+ async def clear_local_cache_api(data: dict):
1008
+ """清理本地缓存"""
1009
+ from app.services.grok.assets import DownloadService
1010
+ cache_type = data.get("type", "image")
1011
+
1012
+ try:
1013
+ dl_service = DownloadService()
1014
+ result = dl_service.clear(cache_type)
1015
+ return {"status": "success", "result": result}
1016
+ except Exception as e:
1017
+ raise HTTPException(status_code=500, detail=str(e))
1018
+
1019
+ @router.get("/api/v1/admin/cache/list", dependencies=[Depends(verify_api_key)])
1020
+ async def list_local_cache_api(
1021
+ cache_type: str = "image",
1022
+ type_: str = Query(default=None, alias="type"),
1023
+ page: int = 1,
1024
+ page_size: int = 1000
1025
+ ):
1026
+ """列出本地缓存文件"""
1027
+ from app.services.grok.assets import DownloadService
1028
+ try:
1029
+ if type_:
1030
+ cache_type = type_
1031
+ dl_service = DownloadService()
1032
+ result = dl_service.list_files(cache_type, page, page_size)
1033
+ return {"status": "success", **result}
1034
+ except Exception as e:
1035
+ raise HTTPException(status_code=500, detail=str(e))
1036
+
1037
+ @router.post("/api/v1/admin/cache/item/delete", dependencies=[Depends(verify_api_key)])
1038
+ async def delete_local_cache_item_api(data: dict):
1039
+ """删除单个本地缓存文件"""
1040
+ from app.services.grok.assets import DownloadService
1041
+ cache_type = data.get("type", "image")
1042
+ name = data.get("name")
1043
+ if not name:
1044
+ raise HTTPException(status_code=400, detail="Missing file name")
1045
+ try:
1046
+ dl_service = DownloadService()
1047
+ result = dl_service.delete_file(cache_type, name)
1048
+ return {"status": "success", "result": result}
1049
+ except Exception as e:
1050
+ raise HTTPException(status_code=500, detail=str(e))
1051
+
1052
+ @router.post("/api/v1/admin/cache/online/clear", dependencies=[Depends(verify_api_key)])
1053
+ async def clear_online_cache_api(data: dict):
1054
+ """清理在线缓存"""
1055
+ from app.services.grok.assets import DeleteService
1056
+ from app.services.token.manager import get_token_manager
1057
+
1058
+ delete_service = None
1059
+ try:
1060
+ mgr = await get_token_manager()
1061
+ tokens = data.get("tokens")
1062
+ delete_service = DeleteService()
1063
+
1064
+ if isinstance(tokens, list):
1065
+ token_list = [t.strip() for t in tokens if isinstance(t, str) and t.strip()]
1066
+ if not token_list:
1067
+ raise HTTPException(status_code=400, detail="No tokens provided")
1068
+
1069
+ results = {}
1070
+ batch_size = get_config("performance.admin_assets_batch_size", 10)
1071
+ try:
1072
+ batch_size = int(batch_size)
1073
+ except Exception:
1074
+ batch_size = 10
1075
+ batch_size = max(1, batch_size)
1076
+
1077
+ async def _clear_one(t: str):
1078
+ try:
1079
+ result = await delete_service.delete_all(t)
1080
+ await mgr.mark_asset_clear(t)
1081
+ return t, {"status": "success", "result": result}
1082
+ except Exception as e:
1083
+ return t, {"status": "error", "error": str(e)}
1084
+
1085
+ for i in range(0, len(token_list), batch_size):
1086
+ chunk = token_list[i:i + batch_size]
1087
+ res_list = await asyncio.gather(*[_clear_one(t) for t in chunk])
1088
+ for t, res in res_list:
1089
+ results[t] = res
1090
+
1091
+ return {"status": "success", "results": results}
1092
+
1093
+ token = data.get("token") or mgr.get_token()
1094
+ if not token:
1095
+ raise HTTPException(status_code=400, detail="No available token to perform cleanup")
1096
+
1097
+ result = await delete_service.delete_all(token)
1098
+ await mgr.mark_asset_clear(token)
1099
+ return {"status": "success", "result": result}
1100
+ except Exception as e:
1101
+ raise HTTPException(status_code=500, detail=str(e))
1102
+ finally:
1103
+ if delete_service:
1104
+ await delete_service.close()
1105
+
1106
+
1107
+ @router.get("/api/v1/admin/metrics", dependencies=[Depends(verify_api_key)])
1108
+ async def get_metrics_api():
1109
+ """数据中心:聚合常用指标(token/cache/request_stats)。"""
1110
+ try:
1111
+ from app.services.request_stats import request_stats
1112
+ from app.services.token.manager import get_token_manager
1113
+ from app.services.token.models import TokenStatus
1114
+ from app.services.grok.assets import DownloadService
1115
+
1116
+ mgr = await get_token_manager()
1117
+ await mgr.reload_if_stale()
1118
+
1119
+ total = 0
1120
+ active = 0
1121
+ cooling = 0
1122
+ expired = 0
1123
+ disabled = 0
1124
+ chat_quota = 0
1125
+ total_calls = 0
1126
+
1127
+ for pool in mgr.pools.values():
1128
+ for info in pool.list():
1129
+ total += 1
1130
+ total_calls += int(getattr(info, "use_count", 0) or 0)
1131
+ if info.status == TokenStatus.ACTIVE:
1132
+ active += 1
1133
+ chat_quota += int(getattr(info, "quota", 0) or 0)
1134
+ elif info.status == TokenStatus.COOLING:
1135
+ cooling += 1
1136
+ elif info.status == TokenStatus.EXPIRED:
1137
+ expired += 1
1138
+ elif info.status == TokenStatus.DISABLED:
1139
+ disabled += 1
1140
+
1141
+ dl = DownloadService()
1142
+ local_image = dl.get_stats("image")
1143
+ local_video = dl.get_stats("video")
1144
+
1145
+ await request_stats.init()
1146
+ stats = request_stats.get_stats(hours=24, days=7)
1147
+
1148
+ return {
1149
+ "tokens": {
1150
+ "total": total,
1151
+ "active": active,
1152
+ "cooling": cooling,
1153
+ "expired": expired,
1154
+ "disabled": disabled,
1155
+ "chat_quota": chat_quota,
1156
+ "image_quota": int(chat_quota // 2),
1157
+ "total_calls": total_calls,
1158
+ },
1159
+ "cache": {
1160
+ "local_image": local_image,
1161
+ "local_video": local_video,
1162
+ },
1163
+ "request_stats": stats,
1164
+ }
1165
+ except Exception as e:
1166
+ raise HTTPException(status_code=500, detail=str(e))
1167
+
1168
+
1169
+ @router.get("/api/v1/admin/cache/local", dependencies=[Depends(verify_api_key)])
1170
+ async def get_cache_local_stats_api():
1171
+ """仅获取本地缓存统计(用于前端实时刷新)。"""
1172
+ from app.services.grok.assets import DownloadService
1173
+
1174
+ try:
1175
+ dl_service = DownloadService()
1176
+ image_stats = dl_service.get_stats("image")
1177
+ video_stats = dl_service.get_stats("video")
1178
+ return {"local_image": image_stats, "local_video": video_stats}
1179
+ except Exception as e:
1180
+ raise HTTPException(status_code=500, detail=str(e))
1181
+
1182
+
1183
+ def _safe_log_file_path(name: str) -> Path:
1184
+ """Resolve a log file name under ./logs safely."""
1185
+ from app.core.logger import LOG_DIR
1186
+
1187
+ name = (name or "").strip()
1188
+ if not name:
1189
+ raise ValueError("Missing log file")
1190
+ # Disallow path traversal.
1191
+ if "/" in name or "\\" in name or ".." in name:
1192
+ raise ValueError("Invalid log file name")
1193
+
1194
+ p = (LOG_DIR / name).resolve()
1195
+ if LOG_DIR.resolve() not in p.parents:
1196
+ raise ValueError("Invalid log file path")
1197
+ if not p.exists() or not p.is_file():
1198
+ raise FileNotFoundError(name)
1199
+ return p
1200
+
1201
+
1202
+ def _format_log_line(raw: str) -> str:
1203
+ raw = (raw or "").rstrip("\r\n")
1204
+ if not raw:
1205
+ return ""
1206
+
1207
+ # Try JSON log line (our file sink uses json lines).
1208
+ try:
1209
+ obj = json.loads(raw)
1210
+ if not isinstance(obj, dict):
1211
+ return raw
1212
+ ts = str(obj.get("time", "") or "")
1213
+ ts = ts.replace("T", " ")
1214
+ if len(ts) >= 19:
1215
+ ts = ts[:19]
1216
+ level = str(obj.get("level", "") or "").upper()
1217
+ caller = str(obj.get("caller", "") or "")
1218
+ msg = str(obj.get("msg", "") or "")
1219
+ if not (ts and level and msg):
1220
+ return raw
1221
+ return f"{ts} | {level:<8} | {caller} - {msg}".rstrip()
1222
+ except Exception:
1223
+ return raw
1224
+
1225
+
1226
+ def _tail_lines(path: Path, max_lines: int = 2000, max_bytes: int = 1024 * 1024) -> list[str]:
1227
+ """Best-effort tail for a text file."""
1228
+ try:
1229
+ max_lines = int(max_lines)
1230
+ except Exception:
1231
+ max_lines = 2000
1232
+ max_lines = max(1, min(5000, max_lines))
1233
+ max_bytes = max(16 * 1024, min(5 * 1024 * 1024, int(max_bytes)))
1234
+
1235
+ with open(path, "rb") as f:
1236
+ f.seek(0, os.SEEK_END)
1237
+ end = f.tell()
1238
+ start = max(0, end - max_bytes)
1239
+ f.seek(start, os.SEEK_SET)
1240
+ data = f.read()
1241
+
1242
+ text = data.decode("utf-8", errors="replace")
1243
+ lines = text.splitlines()
1244
+ # If we read from the middle of a line, drop the first partial line.
1245
+ if start > 0 and lines:
1246
+ lines = lines[1:]
1247
+ lines = lines[-max_lines:]
1248
+ return [_format_log_line(ln) for ln in lines if ln is not None]
1249
+
1250
+
1251
+ @router.get("/api/v1/admin/logs/files", dependencies=[Depends(verify_api_key)])
1252
+ async def list_log_files_api():
1253
+ """列出可查看的日志文件(logs/*.log)。"""
1254
+ from app.core.logger import LOG_DIR
1255
+
1256
+ try:
1257
+ items = []
1258
+ for p in LOG_DIR.glob("*.log"):
1259
+ try:
1260
+ stat = p.stat()
1261
+ items.append(
1262
+ {
1263
+ "name": p.name,
1264
+ "size_bytes": stat.st_size,
1265
+ "mtime_ms": int(stat.st_mtime * 1000),
1266
+ }
1267
+ )
1268
+ except Exception:
1269
+ continue
1270
+ items.sort(key=lambda x: x["mtime_ms"], reverse=True)
1271
+ return {"files": items}
1272
+ except Exception as e:
1273
+ raise HTTPException(status_code=500, detail=str(e))
1274
+
1275
+
1276
+ @router.get("/api/v1/admin/logs/tail", dependencies=[Depends(verify_api_key)])
1277
+ async def tail_log_api(file: str | None = None, lines: int = 500):
1278
+ """读取后台日志(尾部)。"""
1279
+ from app.core.logger import LOG_DIR
1280
+
1281
+ try:
1282
+ # Default to latest log.
1283
+ if not file:
1284
+ candidates = sorted(LOG_DIR.glob("*.log"), key=lambda p: p.stat().st_mtime if p.exists() else 0, reverse=True)
1285
+ if not candidates:
1286
+ return {"file": None, "lines": []}
1287
+ path = candidates[0]
1288
+ file = path.name
1289
+ else:
1290
+ path = _safe_log_file_path(file)
1291
+
1292
+ data = await asyncio.to_thread(_tail_lines, path, lines)
1293
+ return {"file": str(file), "lines": data}
1294
+ except FileNotFoundError:
1295
+ raise HTTPException(status_code=404, detail="Log file not found")
1296
+ except ValueError as ve:
1297
+ raise HTTPException(status_code=400, detail=str(ve))
1298
+ except Exception as e:
1299
+ raise HTTPException(status_code=500, detail=str(e))
app/api/v1/chat.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chat Completions API 路由
3
+ """
4
+
5
+ from typing import Any, Dict, List, Optional, Union
6
+
7
+ from fastapi import APIRouter, Depends
8
+ from fastapi.responses import StreamingResponse, JSONResponse
9
+ from pydantic import BaseModel, Field, field_validator
10
+
11
+ from app.core.auth import verify_api_key
12
+ from app.services.grok.chat import ChatService
13
+ from app.services.grok.model import ModelService
14
+ from app.core.exceptions import ValidationException
15
+ from app.services.quota import enforce_daily_quota
16
+
17
+
18
+ router = APIRouter(tags=["Chat"])
19
+
20
+
21
+ VALID_ROLES = ["developer", "system", "user", "assistant"]
22
+ USER_CONTENT_TYPES = ["text", "image_url", "input_audio", "file"]
23
+
24
+
25
+ class MessageItem(BaseModel):
26
+ """消息项"""
27
+ role: str
28
+ content: Union[str, List[Dict[str, Any]]]
29
+
30
+ @field_validator("role")
31
+ @classmethod
32
+ def validate_role(cls, v):
33
+ if v not in VALID_ROLES:
34
+ raise ValueError(f"role must be one of {VALID_ROLES}")
35
+ return v
36
+
37
+
38
+ class VideoConfig(BaseModel):
39
+ """视频生成配置"""
40
+ aspect_ratio: Optional[str] = Field("3:2", description="视频比例: 3:2, 16:9, 1:1 等")
41
+ video_length: Optional[int] = Field(6, description="视频时长(秒): 5-15")
42
+ resolution: Optional[str] = Field("SD", description="视频分辨率: SD, HD")
43
+ preset: Optional[str] = Field("custom", description="风格预设: fun, normal, spicy")
44
+
45
+ @field_validator("aspect_ratio")
46
+ @classmethod
47
+ def validate_aspect_ratio(cls, v):
48
+ allowed = ["2:3", "3:2", "1:1", "9:16", "16:9"]
49
+ if v and v not in allowed:
50
+ raise ValidationException(
51
+ message=f"aspect_ratio must be one of {allowed}",
52
+ param="video_config.aspect_ratio",
53
+ code="invalid_aspect_ratio"
54
+ )
55
+ return v
56
+
57
+ @field_validator("video_length")
58
+ @classmethod
59
+ def validate_video_length(cls, v):
60
+ if v is not None:
61
+ if v < 5 or v > 15:
62
+ raise ValidationException(
63
+ message="video_length must be between 5 and 15 seconds",
64
+ param="video_config.video_length",
65
+ code="invalid_video_length"
66
+ )
67
+ return v
68
+
69
+ @field_validator("resolution")
70
+ @classmethod
71
+ def validate_resolution(cls, v):
72
+ allowed = ["SD", "HD"]
73
+ if v and v not in allowed:
74
+ raise ValidationException(
75
+ message=f"resolution must be one of {allowed}",
76
+ param="video_config.resolution",
77
+ code="invalid_resolution"
78
+ )
79
+ return v
80
+
81
+ @field_validator("preset")
82
+ @classmethod
83
+ def validate_preset(cls, v):
84
+ # 允许为空,默认 custom
85
+ if not v:
86
+ return "custom"
87
+ allowed = ["fun", "normal", "spicy", "custom"]
88
+ if v not in allowed:
89
+ raise ValidationException(
90
+ message=f"preset must be one of {allowed}",
91
+ param="video_config.preset",
92
+ code="invalid_preset"
93
+ )
94
+ return v
95
+
96
+
97
+ class ChatCompletionRequest(BaseModel):
98
+ """Chat Completions 请求"""
99
+ model: str = Field(..., description="模型名称")
100
+ messages: List[MessageItem] = Field(..., description="消息数组")
101
+ stream: Optional[bool] = Field(None, description="是否流式输出")
102
+ thinking: Optional[str] = Field(None, description="思考模式: enabled/disabled/None")
103
+
104
+ # 视频生成配置
105
+ video_config: Optional[VideoConfig] = Field(None, description="视频生成参数")
106
+
107
+ model_config = {
108
+ "extra": "ignore"
109
+ }
110
+
111
+
112
+ def validate_request(request: ChatCompletionRequest):
113
+ """验证请求参数"""
114
+ # 验证模型
115
+ if not ModelService.valid(request.model):
116
+ raise ValidationException(
117
+ message=f"The model `{request.model}` does not exist or you do not have access to it.",
118
+ param="model",
119
+ code="model_not_found"
120
+ )
121
+
122
+ # 验证消息
123
+ for idx, msg in enumerate(request.messages):
124
+ content = msg.content
125
+
126
+ # 字符串内容
127
+ if isinstance(content, str):
128
+ if not content.strip():
129
+ raise ValidationException(
130
+ message="Message content cannot be empty",
131
+ param=f"messages.{idx}.content",
132
+ code="empty_content"
133
+ )
134
+
135
+ # 列表内容
136
+ elif isinstance(content, list):
137
+ if not content:
138
+ raise ValidationException(
139
+ message="Message content cannot be an empty array",
140
+ param=f"messages.{idx}.content",
141
+ code="empty_content"
142
+ )
143
+
144
+ for block_idx, block in enumerate(content):
145
+ # 检查空对象
146
+ if not block:
147
+ raise ValidationException(
148
+ message="Content block cannot be empty",
149
+ param=f"messages.{idx}.content.{block_idx}",
150
+ code="empty_block"
151
+ )
152
+
153
+ # 检查 type 字段
154
+ if "type" not in block:
155
+ raise ValidationException(
156
+ message="Content block must have a 'type' field",
157
+ param=f"messages.{idx}.content.{block_idx}",
158
+ code="missing_type"
159
+ )
160
+
161
+ block_type = block.get("type")
162
+
163
+ # 检查 type 空值
164
+ if not block_type or not isinstance(block_type, str) or not block_type.strip():
165
+ raise ValidationException(
166
+ message="Content block 'type' cannot be empty",
167
+ param=f"messages.{idx}.content.{block_idx}.type",
168
+ code="empty_type"
169
+ )
170
+
171
+ # 验证 type 有效性
172
+ if msg.role == "user":
173
+ if block_type not in USER_CONTENT_TYPES:
174
+ raise ValidationException(
175
+ message=f"Invalid content block type: '{block_type}'",
176
+ param=f"messages.{idx}.content.{block_idx}.type",
177
+ code="invalid_type"
178
+ )
179
+ elif block_type != "text":
180
+ raise ValidationException(
181
+ message=f"The `{msg.role}` role only supports 'text' type, got '{block_type}'",
182
+ param=f"messages.{idx}.content.{block_idx}.type",
183
+ code="invalid_type"
184
+ )
185
+
186
+ # 验证字段是否存在 & 非空
187
+ if block_type == "text":
188
+ text = block.get("text", "")
189
+ if not isinstance(text, str) or not text.strip():
190
+ raise ValidationException(
191
+ message="Text content cannot be empty",
192
+ param=f"messages.{idx}.content.{block_idx}.text",
193
+ code="empty_text"
194
+ )
195
+ elif block_type == "image_url":
196
+ image_url = block.get("image_url")
197
+ if not image_url or not (isinstance(image_url, dict) and image_url.get("url")):
198
+ raise ValidationException(
199
+ message="image_url must have a 'url' field",
200
+ param=f"messages.{idx}.content.{block_idx}.image_url",
201
+ code="missing_url"
202
+ )
203
+
204
+
205
+ @router.post("/chat/completions")
206
+ async def chat_completions(request: ChatCompletionRequest, api_key: Optional[str] = Depends(verify_api_key)):
207
+ """Chat Completions API - 兼容 OpenAI"""
208
+
209
+ # 参数验证
210
+ validate_request(request)
211
+
212
+ # Daily quota (best-effort)
213
+ await enforce_daily_quota(api_key, request.model)
214
+
215
+ # 检测视频模型
216
+ model_info = ModelService.get(request.model)
217
+ if model_info and model_info.is_video:
218
+ from app.services.grok.media import VideoService
219
+
220
+ # 提取视频配置 (默认值在 Pydantic 模型中处理)
221
+ v_conf = request.video_config or VideoConfig()
222
+
223
+ result = await VideoService.completions(
224
+ model=request.model,
225
+ messages=[msg.model_dump() for msg in request.messages],
226
+ stream=request.stream,
227
+ thinking=request.thinking,
228
+ aspect_ratio=v_conf.aspect_ratio,
229
+ video_length=v_conf.video_length,
230
+ resolution=v_conf.resolution,
231
+ preset=v_conf.preset
232
+ )
233
+ else:
234
+ result = await ChatService.completions(
235
+ model=request.model,
236
+ messages=[msg.model_dump() for msg in request.messages],
237
+ stream=request.stream,
238
+ thinking=request.thinking
239
+ )
240
+
241
+ if isinstance(result, dict):
242
+ return JSONResponse(content=result)
243
+ else:
244
+ return StreamingResponse(
245
+ result,
246
+ media_type="text/event-stream",
247
+ headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
248
+ )
249
+
250
+
251
+ __all__ = ["router"]
app/api/v1/files.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 文件服务 API 路由
3
+ """
4
+
5
+ import aiofiles.os
6
+ from pathlib import Path
7
+ from fastapi import APIRouter, HTTPException
8
+ from fastapi.responses import FileResponse
9
+
10
+ from app.core.logger import logger
11
+
12
+ router = APIRouter(tags=["Files"])
13
+
14
+ # 缓存根目录
15
+ BASE_DIR = Path(__file__).parent.parent.parent.parent / "data" / "tmp"
16
+ IMAGE_DIR = BASE_DIR / "image"
17
+ VIDEO_DIR = BASE_DIR / "video"
18
+
19
+
20
+ @router.get("/image/{filename:path}")
21
+ async def get_image(filename: str):
22
+ """
23
+ 获取图片文件
24
+ """
25
+ if "/" in filename:
26
+ filename = filename.replace("/", "-")
27
+
28
+ file_path = IMAGE_DIR / filename
29
+
30
+ if await aiofiles.os.path.exists(file_path):
31
+ if await aiofiles.os.path.isfile(file_path):
32
+ content_type = "image/jpeg"
33
+ if file_path.suffix.lower() == ".png":
34
+ content_type = "image/png"
35
+ elif file_path.suffix.lower() == ".webp":
36
+ content_type = "image/webp"
37
+
38
+ # 增加缓存头,支持高并发场景下的浏览器/CDN缓存
39
+ return FileResponse(
40
+ file_path,
41
+ media_type=content_type,
42
+ headers={
43
+ "Cache-Control": "public, max-age=31536000, immutable"
44
+ }
45
+ )
46
+
47
+ logger.warning(f"Image not found: {filename}")
48
+ raise HTTPException(status_code=404, detail="Image not found")
49
+
50
+
51
+ @router.get("/video/{filename:path}")
52
+ async def get_video(filename: str):
53
+ """
54
+ 获取视频文件
55
+ """
56
+ if "/" in filename:
57
+ filename = filename.replace("/", "-")
58
+
59
+ file_path = VIDEO_DIR / filename
60
+
61
+ if await aiofiles.os.path.exists(file_path):
62
+ if await aiofiles.os.path.isfile(file_path):
63
+ return FileResponse(
64
+ file_path,
65
+ media_type="video/mp4",
66
+ headers={
67
+ "Cache-Control": "public, max-age=31536000, immutable"
68
+ }
69
+ )
70
+
71
+ logger.warning(f"Video not found: {filename}")
72
+ raise HTTPException(status_code=404, detail="Video not found")
app/api/v1/image.py ADDED
@@ -0,0 +1,1065 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Image Generation API 路由
3
+ """
4
+
5
+ import asyncio
6
+ import base64
7
+ import random
8
+ from pathlib import Path
9
+ from typing import Any, Awaitable, Callable, Dict, List, Optional, Union
10
+
11
+ import orjson
12
+ from fastapi import APIRouter, Depends, File, Form, UploadFile
13
+ from fastapi.responses import JSONResponse, StreamingResponse
14
+ from pydantic import BaseModel, Field, ValidationError
15
+
16
+ from app.core.auth import verify_api_key
17
+ from app.core.config import get_config
18
+ from app.core.exceptions import AppException, ErrorType, UpstreamException, ValidationException
19
+ from app.core.logger import logger
20
+ from app.services.grok.assets import UploadService
21
+ from app.services.grok.chat import GrokChatService
22
+ from app.services.grok.imagine_experimental import (
23
+ IMAGE_METHOD_IMAGINE_WS_EXPERIMENTAL,
24
+ IMAGE_METHOD_LEGACY,
25
+ ImagineExperimentalService,
26
+ resolve_image_generation_method,
27
+ )
28
+ from app.services.grok.imagine_generation import (
29
+ call_experimental_generation_once,
30
+ collect_experimental_generation_images,
31
+ dedupe_images as dedupe_imagine_images,
32
+ is_valid_image_value as is_valid_imagine_image_value,
33
+ resolve_aspect_ratio as resolve_imagine_aspect_ratio,
34
+ )
35
+ from app.services.grok.model import ModelService
36
+ from app.services.grok.processor import ImageCollectProcessor, ImageStreamProcessor
37
+ from app.services.quota import enforce_daily_quota
38
+ from app.services.request_stats import request_stats
39
+ from app.services.token import get_token_manager
40
+
41
+
42
+ router = APIRouter(tags=["Images"])
43
+ ALLOWED_RESPONSE_FORMATS = {"b64_json", "base64", "url"}
44
+
45
+
46
+ class ImageGenerationRequest(BaseModel):
47
+ """Image generation request - OpenAI compatible."""
48
+
49
+ prompt: str = Field(..., description="Image prompt")
50
+ model: Optional[str] = Field("grok-imagine-1.0", description="Model name")
51
+ n: Optional[int] = Field(1, ge=1, le=10, description="Image count (1-10)")
52
+ size: Optional[str] = Field("1024x1024", description="Image size / ratio")
53
+ quality: Optional[str] = Field("standard", description="Reserved")
54
+ response_format: Optional[str] = Field(None, description="Response format")
55
+ style: Optional[str] = Field(None, description="Reserved")
56
+ stream: Optional[bool] = Field(False, description="Enable streaming")
57
+ concurrency: Optional[int] = Field(1, ge=1, le=3, description="Experimental concurrency")
58
+
59
+
60
+ class ImageEditRequest(BaseModel):
61
+ """Image edit request - OpenAI compatible."""
62
+
63
+ prompt: str = Field(..., description="Edit prompt")
64
+ model: Optional[str] = Field("grok-imagine-1.0-edit", description="Model name")
65
+ image: Optional[Union[str, List[str]]] = Field(None, description="Input image(s)")
66
+ n: Optional[int] = Field(1, ge=1, le=10, description="Image count (1-10)")
67
+ size: Optional[str] = Field("1024x1024", description="Reserved")
68
+ quality: Optional[str] = Field("standard", description="Reserved")
69
+ response_format: Optional[str] = Field(None, description="Response format")
70
+ style: Optional[str] = Field(None, description="Reserved")
71
+ stream: Optional[bool] = Field(False, description="Enable streaming")
72
+
73
+
74
+ def validate_generation_request(request: ImageGenerationRequest):
75
+ """Validate image generation request parameters."""
76
+ model_id = request.model or "grok-imagine-1.0"
77
+ if model_id != "grok-imagine-1.0":
78
+ raise ValidationException(
79
+ message="The model `grok-imagine-1.0` is required for image generation.",
80
+ param="model",
81
+ code="model_not_supported",
82
+ )
83
+
84
+ model_info = ModelService.get(model_id)
85
+ if not model_info or not model_info.is_image:
86
+ raise ValidationException(
87
+ message=f"The model `{model_id}` is not supported for image generation.",
88
+ param="model",
89
+ code="model_not_supported",
90
+ )
91
+
92
+ if not request.prompt or not request.prompt.strip():
93
+ raise ValidationException(
94
+ message="Prompt cannot be empty",
95
+ param="prompt",
96
+ code="empty_prompt",
97
+ )
98
+
99
+ if request.n is None:
100
+ request.n = 1
101
+ if request.n < 1 or request.n > 10:
102
+ raise ValidationException(
103
+ message="n must be between 1 and 10",
104
+ param="n",
105
+ code="invalid_n",
106
+ )
107
+
108
+ if request.stream and request.n not in [1, 2]:
109
+ raise ValidationException(
110
+ message="Streaming is only supported when n=1 or n=2",
111
+ param="stream",
112
+ code="invalid_stream_n",
113
+ )
114
+
115
+ if request.concurrency is None:
116
+ request.concurrency = 1
117
+ if request.concurrency < 1 or request.concurrency > 3:
118
+ raise ValidationException(
119
+ message="concurrency must be between 1 and 3",
120
+ param="concurrency",
121
+ code="invalid_concurrency",
122
+ )
123
+
124
+ if request.response_format:
125
+ candidate = request.response_format.lower()
126
+ if candidate not in ALLOWED_RESPONSE_FORMATS:
127
+ raise ValidationException(
128
+ message=f"response_format must be one of {sorted(ALLOWED_RESPONSE_FORMATS)}",
129
+ param="response_format",
130
+ code="invalid_response_format",
131
+ )
132
+
133
+
134
+ def validate_edit_request(request: ImageEditRequest, images: List[UploadFile]):
135
+ """Validate image edit request parameters."""
136
+ model_id = request.model or "grok-imagine-1.0-edit"
137
+ if model_id != "grok-imagine-1.0-edit":
138
+ raise ValidationException(
139
+ message="The model `grok-imagine-1.0-edit` is required for image edits.",
140
+ param="model",
141
+ code="model_not_supported",
142
+ )
143
+
144
+ model_info = ModelService.get(model_id)
145
+ if not model_info or not model_info.is_image:
146
+ raise ValidationException(
147
+ message=f"The model `{model_id}` is not supported for image edits.",
148
+ param="model",
149
+ code="model_not_supported",
150
+ )
151
+
152
+ if not request.prompt or not request.prompt.strip():
153
+ raise ValidationException(
154
+ message="Prompt cannot be empty",
155
+ param="prompt",
156
+ code="empty_prompt",
157
+ )
158
+
159
+ if request.n is None:
160
+ request.n = 1
161
+ if request.n < 1 or request.n > 10:
162
+ raise ValidationException(
163
+ message="n must be between 1 and 10",
164
+ param="n",
165
+ code="invalid_n",
166
+ )
167
+
168
+ if request.stream and request.n not in [1, 2]:
169
+ raise ValidationException(
170
+ message="Streaming is only supported when n=1 or n=2",
171
+ param="stream",
172
+ code="invalid_stream_n",
173
+ )
174
+
175
+ if request.response_format:
176
+ candidate = request.response_format.lower()
177
+ if candidate not in ALLOWED_RESPONSE_FORMATS:
178
+ raise ValidationException(
179
+ message=f"response_format must be one of {sorted(ALLOWED_RESPONSE_FORMATS)}",
180
+ param="response_format",
181
+ code="invalid_response_format",
182
+ )
183
+
184
+ if not images:
185
+ raise ValidationException(
186
+ message="Image is required",
187
+ param="image",
188
+ code="missing_image",
189
+ )
190
+ if len(images) > 16:
191
+ raise ValidationException(
192
+ message="Too many images. Maximum is 16.",
193
+ param="image",
194
+ code="invalid_image_count",
195
+ )
196
+
197
+
198
+ def resolve_response_format(response_format: Optional[str]) -> str:
199
+ candidate = response_format
200
+ if not candidate:
201
+ candidate = get_config("app.image_format", "url")
202
+ if isinstance(candidate, str):
203
+ candidate = candidate.lower()
204
+ if candidate in ALLOWED_RESPONSE_FORMATS:
205
+ return candidate
206
+ raise ValidationException(
207
+ message=f"response_format must be one of {sorted(ALLOWED_RESPONSE_FORMATS)}",
208
+ param="response_format",
209
+ code="invalid_response_format",
210
+ )
211
+
212
+
213
+ def resolve_image_response_format(
214
+ response_format: Optional[str],
215
+ image_method: str,
216
+ ) -> str:
217
+ """
218
+ Keep legacy behavior, but for experimental imagine path:
219
+ if caller does not explicitly provide response_format and global default is `url`,
220
+ prefer `b64_json` to avoid loopback URL rendering issues in local deployments.
221
+ """
222
+ raw = response_format if not isinstance(response_format, str) else response_format.strip()
223
+ if not raw and image_method == IMAGE_METHOD_IMAGINE_WS_EXPERIMENTAL:
224
+ default_format = str(get_config("app.image_format", "url") or "url").strip().lower()
225
+ if default_format == "url":
226
+ return "b64_json"
227
+ return resolve_response_format(response_format)
228
+
229
+
230
+ def response_field_name(response_format: str) -> str:
231
+ if response_format == "url":
232
+ return "url"
233
+ if response_format == "base64":
234
+ return "base64"
235
+ return "b64_json"
236
+
237
+
238
+ def _image_generation_method() -> str:
239
+ return resolve_image_generation_method(
240
+ get_config("grok.image_generation_method", IMAGE_METHOD_LEGACY)
241
+ )
242
+
243
+
244
+ def resolve_aspect_ratio(size: Optional[str]) -> str:
245
+ return resolve_imagine_aspect_ratio(size)
246
+
247
+
248
+ def _is_valid_image_value(value: Any) -> bool:
249
+ return is_valid_imagine_image_value(value)
250
+
251
+
252
+ def _dedupe_images(images: List[str]) -> List[str]:
253
+ return dedupe_imagine_images(images)
254
+
255
+
256
+ async def _gather_limited(
257
+ task_factories: List[Callable[[], Awaitable[List[str]]]],
258
+ max_concurrency: int,
259
+ ) -> List[Any]:
260
+ sem = asyncio.Semaphore(max(1, int(max_concurrency or 1)))
261
+
262
+ async def _run(factory: Callable[[], Awaitable[List[str]]]) -> Any:
263
+ async with sem:
264
+ return await factory()
265
+
266
+ return await asyncio.gather(*[_run(factory) for factory in task_factories], return_exceptions=True)
267
+
268
+
269
+ async def call_grok_legacy(
270
+ token: str,
271
+ prompt: str,
272
+ model_info,
273
+ file_attachments: Optional[List[str]] = None,
274
+ response_format: str = "b64_json",
275
+ ) -> List[str]:
276
+ """
277
+ 调用 Grok 获取图片,返回图片列表
278
+ """
279
+ chat_service = GrokChatService()
280
+
281
+ try:
282
+ response = await chat_service.chat(
283
+ token=token,
284
+ message=prompt,
285
+ model=model_info.grok_model,
286
+ mode=model_info.model_mode,
287
+ think=False,
288
+ stream=True,
289
+ file_attachments=file_attachments,
290
+ )
291
+
292
+ processor = ImageCollectProcessor(
293
+ model_info.model_id,
294
+ token,
295
+ response_format=response_format,
296
+ )
297
+ return await processor.process(response)
298
+ except Exception as e:
299
+ logger.error(f"Grok image call failed: {e}")
300
+ return []
301
+
302
+
303
+ async def call_grok_experimental_ws(
304
+ token: str,
305
+ prompt: str,
306
+ response_format: str = "b64_json",
307
+ n: int = 4,
308
+ aspect_ratio: str = "2:3",
309
+ ) -> List[str]:
310
+ return await call_experimental_generation_once(
311
+ token=token,
312
+ prompt=prompt,
313
+ response_format=response_format,
314
+ n=n,
315
+ aspect_ratio=aspect_ratio,
316
+ )
317
+
318
+
319
+ async def call_grok_experimental_edit(
320
+ token: str,
321
+ prompt: str,
322
+ model_id: str,
323
+ file_uris: List[str],
324
+ response_format: str = "b64_json",
325
+ ) -> List[str]:
326
+ service = ImagineExperimentalService()
327
+ response = await service.chat_edit(token=token, prompt=prompt, file_uris=file_uris)
328
+ processor = ImageCollectProcessor(
329
+ model_id,
330
+ token,
331
+ response_format=response_format,
332
+ )
333
+ return await processor.process(response)
334
+
335
+
336
+ async def _collect_experimental_generation_images(
337
+ token: str,
338
+ prompt: str,
339
+ n: int,
340
+ response_format: str,
341
+ aspect_ratio: str,
342
+ concurrency: int,
343
+ ) -> List[str]:
344
+ return await collect_experimental_generation_images(
345
+ token=token,
346
+ prompt=prompt,
347
+ n=n,
348
+ response_format=response_format,
349
+ aspect_ratio=aspect_ratio,
350
+ concurrency=concurrency,
351
+ )
352
+
353
+
354
+ async def _experimental_stream_generation(
355
+ token: str,
356
+ prompt: str,
357
+ n: int,
358
+ response_format: str,
359
+ response_field: str,
360
+ aspect_ratio: str,
361
+ state: dict[str, Any],
362
+ ):
363
+ service = ImagineExperimentalService()
364
+ queue: asyncio.Queue[Optional[str]] = asyncio.Queue()
365
+ index_map: Dict[int, int] = {}
366
+ map_lock = asyncio.Lock()
367
+ next_output_index = 0
368
+
369
+ async def _resolve_output_index(raw_index: int) -> int:
370
+ nonlocal next_output_index
371
+ async with map_lock:
372
+ if raw_index not in index_map:
373
+ index_map[raw_index] = min(next_output_index, max(0, n - 1))
374
+ next_output_index += 1
375
+ return index_map[raw_index]
376
+
377
+ async def _progress_cb(raw_index: int, progress: float):
378
+ idx = await _resolve_output_index(raw_index)
379
+ await queue.put(
380
+ _sse_event(
381
+ "image_generation.partial_image",
382
+ {
383
+ "type": "image_generation.partial_image",
384
+ response_field: "",
385
+ "index": idx,
386
+ "progress": max(0, min(100, int(progress))),
387
+ },
388
+ )
389
+ )
390
+
391
+ async def _completed_cb(raw_index: int, raw_url: str):
392
+ idx = await _resolve_output_index(raw_index)
393
+ converted = await service.convert_url(
394
+ token=token,
395
+ url=raw_url,
396
+ response_format=response_format,
397
+ )
398
+ if not _is_valid_image_value(converted):
399
+ return
400
+
401
+ state["success"] = True
402
+ await queue.put(
403
+ _sse_event(
404
+ "image_generation.completed",
405
+ {
406
+ "type": "image_generation.completed",
407
+ response_field: converted,
408
+ "index": idx,
409
+ "usage": {
410
+ "total_tokens": 50,
411
+ "input_tokens": 25,
412
+ "output_tokens": 25,
413
+ "input_tokens_details": {"text_tokens": 5, "image_tokens": 20},
414
+ },
415
+ },
416
+ )
417
+ )
418
+
419
+ producer_error: Optional[Exception] = None
420
+
421
+ async def _producer():
422
+ nonlocal producer_error
423
+ try:
424
+ await service.generate_ws(
425
+ token=token,
426
+ prompt=prompt,
427
+ n=n,
428
+ aspect_ratio=aspect_ratio,
429
+ progress_cb=_progress_cb,
430
+ completed_cb=_completed_cb,
431
+ )
432
+ except Exception as exc:
433
+ producer_error = exc
434
+ finally:
435
+ await queue.put(None)
436
+
437
+ producer_task = asyncio.create_task(_producer())
438
+ try:
439
+ while True:
440
+ chunk = await queue.get()
441
+ if chunk is None:
442
+ break
443
+ yield chunk
444
+ finally:
445
+ await producer_task
446
+
447
+ if not state.get("success", False):
448
+ if isinstance(producer_error, Exception):
449
+ raise producer_error
450
+ raise UpstreamException("Experimental imagine websocket returned no images")
451
+
452
+
453
+ def _sse_event(event: str, data: dict) -> str:
454
+ return f"event: {event}\ndata: {orjson.dumps(data).decode()}\n\n"
455
+
456
+
457
+ async def _synthetic_image_stream(
458
+ selected_images: List[str],
459
+ response_field: str,
460
+ ):
461
+ emitted = False
462
+ for idx, image in enumerate(selected_images):
463
+ if not isinstance(image, str) or not image or image == "error":
464
+ continue
465
+ emitted = True
466
+ yield _sse_event(
467
+ "image_generation.partial_image",
468
+ {
469
+ "type": "image_generation.partial_image",
470
+ response_field: "",
471
+ "index": idx,
472
+ "progress": 100,
473
+ },
474
+ )
475
+ yield _sse_event(
476
+ "image_generation.completed",
477
+ {
478
+ "type": "image_generation.completed",
479
+ response_field: image,
480
+ "index": idx,
481
+ "usage": {
482
+ "total_tokens": 50,
483
+ "input_tokens": 25,
484
+ "output_tokens": 25,
485
+ "input_tokens_details": {"text_tokens": 5, "image_tokens": 20},
486
+ },
487
+ },
488
+ )
489
+ if not emitted:
490
+ yield _sse_event(
491
+ "image_generation.completed",
492
+ {
493
+ "type": "image_generation.completed",
494
+ response_field: "error",
495
+ "index": 0,
496
+ "usage": {
497
+ "total_tokens": 0,
498
+ "input_tokens": 0,
499
+ "output_tokens": 0,
500
+ "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
501
+ },
502
+ },
503
+ )
504
+
505
+
506
+ async def _record_request(model_id: str, success: bool):
507
+ try:
508
+ await request_stats.record_request(model_id, success=success)
509
+ except Exception:
510
+ pass
511
+
512
+
513
+ async def _get_token_for_model(model_id: str):
514
+ """获取指定模型可用 token,失败时抛出统一异常"""
515
+ try:
516
+ token_mgr = await get_token_manager()
517
+ await token_mgr.reload_if_stale()
518
+ token = token_mgr.get_token_for_model(model_id)
519
+ except Exception as e:
520
+ logger.error(f"Failed to get token: {e}")
521
+ await _record_request(model_id or "image", False)
522
+ raise AppException(
523
+ message="Internal service error obtaining token",
524
+ error_type=ErrorType.SERVER.value,
525
+ code="internal_error",
526
+ )
527
+
528
+ if not token:
529
+ await _record_request(model_id or "image", False)
530
+ raise AppException(
531
+ message="No available tokens. Please try again later.",
532
+ error_type=ErrorType.RATE_LIMIT.value,
533
+ code="rate_limit_exceeded",
534
+ status_code=429,
535
+ )
536
+ return token_mgr, token
537
+
538
+
539
+ def _pick_images(all_images: List[str], n: int) -> List[str]:
540
+ if len(all_images) >= n:
541
+ return random.sample(all_images, n)
542
+ selected = all_images.copy()
543
+ while len(selected) < n:
544
+ selected.append("error")
545
+ return selected
546
+
547
+
548
+ def _build_image_response(selected_images: List[str], response_field: str) -> JSONResponse:
549
+ import time
550
+
551
+ return JSONResponse(
552
+ content={
553
+ "created": int(time.time()),
554
+ "data": [{response_field: img} for img in selected_images],
555
+ "usage": {
556
+ "total_tokens": 0 * len([img for img in selected_images if img != "error"]),
557
+ "input_tokens": 0,
558
+ "output_tokens": 0 * len([img for img in selected_images if img != "error"]),
559
+ "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
560
+ },
561
+ }
562
+ )
563
+
564
+
565
+ @router.get("/images/method")
566
+ async def get_image_generation_method():
567
+ return {"image_generation_method": _image_generation_method()}
568
+
569
+
570
+ @router.post("/images/generations")
571
+ async def create_image(
572
+ request: ImageGenerationRequest,
573
+ api_key: Optional[str] = Depends(verify_api_key),
574
+ ):
575
+ """Image Generation API."""
576
+ if request.stream is None:
577
+ request.stream = False
578
+
579
+ validate_generation_request(request)
580
+ model_id = request.model or "grok-imagine-1.0"
581
+ n = int(request.n or 1)
582
+ concurrency = max(1, min(3, int(request.concurrency or 1)))
583
+ image_method = _image_generation_method()
584
+ response_format = resolve_image_response_format(request.response_format, image_method)
585
+ request.response_format = response_format
586
+ response_field = response_field_name(response_format)
587
+ aspect_ratio = resolve_aspect_ratio(request.size)
588
+
589
+ await enforce_daily_quota(api_key, model_id, image_count=n)
590
+ token_mgr, token = await _get_token_for_model(model_id)
591
+ model_info = ModelService.get(model_id)
592
+
593
+ if request.stream:
594
+ if image_method == IMAGE_METHOD_IMAGINE_WS_EXPERIMENTAL:
595
+ stream_state: Dict[str, Any] = {"success": False}
596
+
597
+ async def _wrapped_experimental_stream():
598
+ try:
599
+ try:
600
+ async for chunk in _experimental_stream_generation(
601
+ token=token,
602
+ prompt=request.prompt,
603
+ n=n,
604
+ response_format=response_format,
605
+ response_field=response_field,
606
+ aspect_ratio=aspect_ratio,
607
+ state=stream_state,
608
+ ):
609
+ yield chunk
610
+ except Exception as stream_err:
611
+ logger.warning(
612
+ f"Experimental image generation realtime stream failed: {stream_err}. "
613
+ "Fallback to synthetic stream."
614
+ )
615
+ try:
616
+ all_images = await _collect_experimental_generation_images(
617
+ token=token,
618
+ prompt=request.prompt,
619
+ n=n,
620
+ response_format=response_format,
621
+ aspect_ratio=aspect_ratio,
622
+ concurrency=concurrency,
623
+ )
624
+ selected_images = _pick_images(_dedupe_images(all_images), n)
625
+ stream_state["success"] = any(
626
+ _is_valid_image_value(item) for item in selected_images
627
+ )
628
+ async for chunk in _synthetic_image_stream(selected_images, response_field):
629
+ yield chunk
630
+ except Exception as synthetic_err:
631
+ logger.warning(
632
+ f"Experimental synthetic stream failed: {synthetic_err}. "
633
+ "Fallback to legacy stream."
634
+ )
635
+ chat_service = GrokChatService()
636
+ response = await chat_service.chat(
637
+ token=token,
638
+ message=f"Image Generation: {request.prompt}",
639
+ model=model_info.grok_model,
640
+ mode=model_info.model_mode,
641
+ think=False,
642
+ stream=True,
643
+ )
644
+ processor = ImageStreamProcessor(
645
+ model_info.model_id,
646
+ token,
647
+ n=n,
648
+ response_format=response_format,
649
+ )
650
+ async for chunk in processor.process(response):
651
+ yield chunk
652
+ stream_state["success"] = True
653
+ finally:
654
+ try:
655
+ if stream_state.get("success"):
656
+ await token_mgr.sync_usage(
657
+ token,
658
+ model_info.model_id,
659
+ consume_on_fail=True,
660
+ is_usage=True,
661
+ )
662
+ await _record_request(model_info.model_id, True)
663
+ else:
664
+ await _record_request(model_info.model_id, False)
665
+ except Exception:
666
+ pass
667
+
668
+ return StreamingResponse(
669
+ _wrapped_experimental_stream(),
670
+ media_type="text/event-stream",
671
+ headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
672
+ )
673
+
674
+ chat_service = GrokChatService()
675
+ try:
676
+ response = await chat_service.chat(
677
+ token=token,
678
+ message=f"Image Generation: {request.prompt}",
679
+ model=model_info.grok_model,
680
+ mode=model_info.model_mode,
681
+ think=False,
682
+ stream=True,
683
+ )
684
+ except Exception:
685
+ await _record_request(model_info.model_id, False)
686
+ raise
687
+
688
+ processor = ImageStreamProcessor(
689
+ model_info.model_id,
690
+ token,
691
+ n=n,
692
+ response_format=response_format,
693
+ )
694
+
695
+ async def _wrapped_stream():
696
+ completed = False
697
+ try:
698
+ async for chunk in processor.process(response):
699
+ yield chunk
700
+ completed = True
701
+ finally:
702
+ try:
703
+ if completed:
704
+ await token_mgr.sync_usage(
705
+ token,
706
+ model_info.model_id,
707
+ consume_on_fail=True,
708
+ is_usage=True,
709
+ )
710
+ await _record_request(model_info.model_id, True)
711
+ else:
712
+ await _record_request(model_info.model_id, False)
713
+ except Exception:
714
+ pass
715
+
716
+ return StreamingResponse(
717
+ _wrapped_stream(),
718
+ media_type="text/event-stream",
719
+ headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
720
+ )
721
+
722
+ all_images: List[str] = []
723
+ if image_method == IMAGE_METHOD_IMAGINE_WS_EXPERIMENTAL:
724
+ try:
725
+ all_images = await _collect_experimental_generation_images(
726
+ token=token,
727
+ prompt=request.prompt,
728
+ n=n,
729
+ response_format=response_format,
730
+ aspect_ratio=aspect_ratio,
731
+ concurrency=concurrency,
732
+ )
733
+ except Exception as e:
734
+ logger.warning(f"Experimental image generation failed, fallback to legacy: {e}")
735
+
736
+ if not all_images:
737
+ calls_needed = (n + 1) // 2
738
+ task_factories: List[Callable[[], Awaitable[List[str]]]] = [
739
+ lambda: call_grok_legacy(
740
+ token,
741
+ f"Image Generation: {request.prompt}",
742
+ model_info,
743
+ response_format=response_format,
744
+ )
745
+ for _ in range(calls_needed)
746
+ ]
747
+ results = await _gather_limited(
748
+ task_factories,
749
+ max_concurrency=min(calls_needed, concurrency),
750
+ )
751
+
752
+ all_images = []
753
+ for result in results:
754
+ if isinstance(result, Exception):
755
+ logger.error(f"Concurrent call failed: {result}")
756
+ elif isinstance(result, list):
757
+ all_images.extend(result)
758
+
759
+ selected_images = _pick_images(_dedupe_images(all_images), n)
760
+ success = any(_is_valid_image_value(img) for img in selected_images)
761
+ try:
762
+ if success:
763
+ await token_mgr.sync_usage(
764
+ token,
765
+ model_info.model_id,
766
+ consume_on_fail=True,
767
+ is_usage=True,
768
+ )
769
+ await _record_request(model_info.model_id, bool(success))
770
+ except Exception:
771
+ pass
772
+
773
+ return _build_image_response(selected_images, response_field)
774
+
775
+
776
+ @router.post("/images/edits")
777
+ async def edit_image(
778
+ prompt: str = Form(...),
779
+ image: Optional[List[UploadFile]] = File(None),
780
+ image_alias: Optional[List[UploadFile]] = File(None, alias="image[]"),
781
+ model: Optional[str] = Form("grok-imagine-1.0-edit"),
782
+ n: int = Form(1),
783
+ size: str = Form("1024x1024"),
784
+ quality: str = Form("standard"),
785
+ response_format: Optional[str] = Form(None),
786
+ style: Optional[str] = Form(None),
787
+ stream: Optional[bool] = Form(False),
788
+ api_key: Optional[str] = Depends(verify_api_key),
789
+ ):
790
+ """
791
+ Image Edits API
792
+
793
+ 同官方 API 格式,仅支持 multipart/form-data 文件上传
794
+ """
795
+ try:
796
+ edit_request = ImageEditRequest(
797
+ prompt=prompt,
798
+ model=model,
799
+ n=n,
800
+ size=size,
801
+ quality=quality,
802
+ response_format=response_format,
803
+ style=style,
804
+ stream=stream,
805
+ )
806
+ except ValidationError as exc:
807
+ errors = exc.errors()
808
+ if errors:
809
+ first = errors[0]
810
+ loc = first.get("loc", [])
811
+ msg = first.get("msg", "Invalid request")
812
+ code = first.get("type", "invalid_value")
813
+ param_parts = [str(x) for x in loc if not (isinstance(x, int) or str(x).isdigit())]
814
+ param = ".".join(param_parts) if param_parts else None
815
+ raise ValidationException(message=msg, param=param, code=code)
816
+ raise ValidationException(message="Invalid request", code="invalid_value")
817
+
818
+ if edit_request.stream is None:
819
+ edit_request.stream = False
820
+ if edit_request.n is None:
821
+ edit_request.n = 1
822
+
823
+ image_method = _image_generation_method()
824
+ response_format = resolve_image_response_format(edit_request.response_format, image_method)
825
+ edit_request.response_format = response_format
826
+ response_field = response_field_name(response_format)
827
+ images = (image or []) + (image_alias or [])
828
+ validate_edit_request(edit_request, images)
829
+
830
+ model_id = edit_request.model or "grok-imagine-1.0-edit"
831
+ n = int(edit_request.n or 1)
832
+
833
+ await enforce_daily_quota(api_key, model_id, image_count=n)
834
+
835
+ max_image_bytes = 50 * 1024 * 1024
836
+ allowed_types = {"image/png", "image/jpeg", "image/webp", "image/jpg"}
837
+ image_payloads: List[str] = []
838
+
839
+ for item in images:
840
+ content = await item.read()
841
+ await item.close()
842
+ if not content:
843
+ raise ValidationException(
844
+ message="File content is empty",
845
+ param="image",
846
+ code="empty_file",
847
+ )
848
+ if len(content) > max_image_bytes:
849
+ raise ValidationException(
850
+ message="Image file too large. Maximum is 50MB.",
851
+ param="image",
852
+ code="file_too_large",
853
+ )
854
+
855
+ mime = (item.content_type or "").lower()
856
+ if mime == "image/jpg":
857
+ mime = "image/jpeg"
858
+ ext = Path(item.filename or "").suffix.lower()
859
+ if mime not in allowed_types:
860
+ if ext in (".jpg", ".jpeg"):
861
+ mime = "image/jpeg"
862
+ elif ext == ".png":
863
+ mime = "image/png"
864
+ elif ext == ".webp":
865
+ mime = "image/webp"
866
+ else:
867
+ raise ValidationException(
868
+ message="Unsupported image type. Supported: png, jpg, webp.",
869
+ param="image",
870
+ code="invalid_image_type",
871
+ )
872
+
873
+ image_payloads.append(f"data:{mime};base64,{base64.b64encode(content).decode()}")
874
+
875
+ token_mgr, token = await _get_token_for_model(model_id)
876
+ model_info = ModelService.get(model_id)
877
+
878
+ file_ids: List[str] = []
879
+ file_uris: List[str] = []
880
+ upload_service = UploadService()
881
+ try:
882
+ for payload in image_payloads:
883
+ file_id, file_uri = await upload_service.upload(payload, token)
884
+ if file_id:
885
+ file_ids.append(file_id)
886
+ if file_uri:
887
+ file_uris.append(file_uri)
888
+ finally:
889
+ await upload_service.close()
890
+
891
+ if edit_request.stream:
892
+ if image_method == IMAGE_METHOD_IMAGINE_WS_EXPERIMENTAL:
893
+ try:
894
+ service = ImagineExperimentalService()
895
+ response = await service.chat_edit(
896
+ token=token,
897
+ prompt=edit_request.prompt,
898
+ file_uris=file_uris,
899
+ )
900
+ processor = ImageStreamProcessor(
901
+ model_info.model_id,
902
+ token,
903
+ n=n,
904
+ response_format=response_format,
905
+ )
906
+
907
+ async def _wrapped_experimental_stream():
908
+ completed = False
909
+ try:
910
+ async for chunk in processor.process(response):
911
+ yield chunk
912
+ completed = True
913
+ finally:
914
+ try:
915
+ if completed:
916
+ await token_mgr.sync_usage(
917
+ token,
918
+ model_info.model_id,
919
+ consume_on_fail=True,
920
+ is_usage=True,
921
+ )
922
+ await _record_request(model_info.model_id, True)
923
+ else:
924
+ await _record_request(model_info.model_id, False)
925
+ except Exception:
926
+ pass
927
+
928
+ return StreamingResponse(
929
+ _wrapped_experimental_stream(),
930
+ media_type="text/event-stream",
931
+ headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
932
+ )
933
+ except Exception as e:
934
+ logger.warning(f"Experimental image edit stream failed, fallback to legacy: {e}")
935
+
936
+ chat_service = GrokChatService()
937
+ try:
938
+ response = await chat_service.chat(
939
+ token=token,
940
+ message=f"Image Edit: {edit_request.prompt}",
941
+ model=model_info.grok_model,
942
+ mode=model_info.model_mode,
943
+ think=False,
944
+ stream=True,
945
+ file_attachments=file_ids,
946
+ )
947
+ except Exception:
948
+ await _record_request(model_info.model_id, False)
949
+ raise
950
+
951
+ processor = ImageStreamProcessor(
952
+ model_info.model_id,
953
+ token,
954
+ n=n,
955
+ response_format=response_format,
956
+ )
957
+
958
+ async def _wrapped_stream():
959
+ completed = False
960
+ try:
961
+ async for chunk in processor.process(response):
962
+ yield chunk
963
+ completed = True
964
+ finally:
965
+ try:
966
+ if completed:
967
+ await token_mgr.sync_usage(
968
+ token,
969
+ model_info.model_id,
970
+ consume_on_fail=True,
971
+ is_usage=True,
972
+ )
973
+ await _record_request(model_info.model_id, True)
974
+ else:
975
+ await _record_request(model_info.model_id, False)
976
+ except Exception:
977
+ pass
978
+
979
+ return StreamingResponse(
980
+ _wrapped_stream(),
981
+ media_type="text/event-stream",
982
+ headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
983
+ )
984
+
985
+ all_images: List[str] = []
986
+ if image_method == IMAGE_METHOD_IMAGINE_WS_EXPERIMENTAL:
987
+ try:
988
+ calls_needed = (n + 1) // 2
989
+ if calls_needed == 1:
990
+ all_images = await call_grok_experimental_edit(
991
+ token=token,
992
+ prompt=edit_request.prompt,
993
+ model_id=model_info.model_id,
994
+ file_uris=file_uris,
995
+ response_format=response_format,
996
+ )
997
+ else:
998
+ tasks = [
999
+ call_grok_experimental_edit(
1000
+ token=token,
1001
+ prompt=edit_request.prompt,
1002
+ model_id=model_info.model_id,
1003
+ file_uris=file_uris,
1004
+ response_format=response_format,
1005
+ )
1006
+ for _ in range(calls_needed)
1007
+ ]
1008
+ results = await asyncio.gather(*tasks, return_exceptions=True)
1009
+ for result in results:
1010
+ if isinstance(result, Exception):
1011
+ logger.warning(f"Experimental image edit call failed: {result}")
1012
+ elif isinstance(result, list):
1013
+ all_images.extend(result)
1014
+ if not all_images:
1015
+ raise UpstreamException("Experimental image edit returned no images")
1016
+ except Exception as e:
1017
+ logger.warning(f"Experimental image edit failed, fallback to legacy: {e}")
1018
+
1019
+ if not all_images:
1020
+ calls_needed = (n + 1) // 2
1021
+ if calls_needed == 1:
1022
+ all_images = await call_grok_legacy(
1023
+ token,
1024
+ f"Image Edit: {edit_request.prompt}",
1025
+ model_info,
1026
+ file_attachments=file_ids,
1027
+ response_format=response_format,
1028
+ )
1029
+ else:
1030
+ tasks = [
1031
+ call_grok_legacy(
1032
+ token,
1033
+ f"Image Edit: {edit_request.prompt}",
1034
+ model_info,
1035
+ file_attachments=file_ids,
1036
+ response_format=response_format,
1037
+ )
1038
+ for _ in range(calls_needed)
1039
+ ]
1040
+ results = await asyncio.gather(*tasks, return_exceptions=True)
1041
+ all_images = []
1042
+ for result in results:
1043
+ if isinstance(result, Exception):
1044
+ logger.error(f"Concurrent call failed: {result}")
1045
+ elif isinstance(result, list):
1046
+ all_images.extend(result)
1047
+
1048
+ selected_images = _pick_images(all_images, n)
1049
+ success = any(isinstance(img, str) and img and img != "error" for img in selected_images)
1050
+ try:
1051
+ if success:
1052
+ await token_mgr.sync_usage(
1053
+ token,
1054
+ model_info.model_id,
1055
+ consume_on_fail=True,
1056
+ is_usage=True,
1057
+ )
1058
+ await _record_request(model_info.model_id, bool(success))
1059
+ except Exception:
1060
+ pass
1061
+
1062
+ return _build_image_response(selected_images, response_field)
1063
+
1064
+
1065
+ __all__ = ["router"]
app/api/v1/models.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Models API 路由
3
+ """
4
+
5
+ import time
6
+
7
+ from fastapi import APIRouter, HTTPException
8
+
9
+ from app.services.grok.model import ModelService
10
+
11
+
12
+ router = APIRouter(tags=["Models"])
13
+
14
+
15
+ @router.get("/models")
16
+ async def list_models():
17
+ """OpenAI 兼容 models 列表接口"""
18
+ ts = int(time.time())
19
+ data = [
20
+ {
21
+ "id": m.model_id,
22
+ "object": "model",
23
+ "created": ts,
24
+ "owned_by": "grok2api",
25
+ "display_name": m.display_name,
26
+ "description": m.description,
27
+ }
28
+ for m in ModelService.list()
29
+ ]
30
+ return {"object": "list", "data": data}
31
+
32
+
33
+ @router.get("/models/{model_id}")
34
+ async def get_model(model_id: str):
35
+ """OpenAI compatible: single model detail."""
36
+ m = ModelService.get(model_id)
37
+ if not m:
38
+ raise HTTPException(status_code=404, detail=f"Model '{model_id}' not found")
39
+
40
+ ts = int(time.time())
41
+ return {
42
+ "id": m.model_id,
43
+ "object": "model",
44
+ "created": ts,
45
+ "owned_by": "grok2api",
46
+ "display_name": m.display_name,
47
+ "description": m.description,
48
+ }
49
+
50
+
51
+ __all__ = ["router"]
app/api/v1/uploads.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Uploads API (used by the web chat UI)
3
+ """
4
+
5
+ import uuid
6
+ from pathlib import Path
7
+
8
+ import aiofiles
9
+ from fastapi import APIRouter, UploadFile, File, HTTPException
10
+
11
+ from app.services.grok.assets import DownloadService
12
+
13
+
14
+ router = APIRouter(tags=["Uploads"])
15
+
16
+ BASE_DIR = Path(__file__).parent.parent.parent.parent / "data" / "tmp"
17
+ IMAGE_DIR = BASE_DIR / "image"
18
+
19
+
20
+ def _ext_from_mime(mime: str) -> str:
21
+ m = (mime or "").lower()
22
+ if m == "image/png":
23
+ return "png"
24
+ if m == "image/webp":
25
+ return "webp"
26
+ if m == "image/gif":
27
+ return "gif"
28
+ if m in ("image/jpeg", "image/jpg"):
29
+ return "jpg"
30
+ return "jpg"
31
+
32
+
33
+ @router.post("/uploads/image")
34
+ async def upload_image(file: UploadFile = File(...)):
35
+ content_type = (file.content_type or "").lower()
36
+ if not content_type.startswith("image/"):
37
+ raise HTTPException(status_code=400, detail=f"Unsupported file type: {file.content_type}")
38
+
39
+ IMAGE_DIR.mkdir(parents=True, exist_ok=True)
40
+ name = f"upload-{uuid.uuid4().hex}.{_ext_from_mime(content_type)}"
41
+ path = IMAGE_DIR / name
42
+
43
+ size = 0
44
+ async with aiofiles.open(path, "wb") as f:
45
+ while True:
46
+ chunk = await file.read(1024 * 1024)
47
+ if not chunk:
48
+ break
49
+ size += len(chunk)
50
+ await f.write(chunk)
51
+
52
+ # Best-effort: reuse existing cache cleanup policy (size-based).
53
+ try:
54
+ dl = DownloadService()
55
+ await dl.check_limit()
56
+ await dl.close()
57
+ except Exception:
58
+ pass
59
+
60
+ return {"url": f"/v1/files/image/{name}", "name": name, "size_bytes": size}
61
+
62
+
63
+ __all__ = ["router"]
64
+
app/api/v1/video.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """
2
+ TODO:Video Generation API 路由
3
+ """
app/core/auth.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ API 认证模块
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import asyncio
8
+ import json
9
+ from pathlib import Path
10
+ from typing import Optional, Set
11
+
12
+ from fastapi import HTTPException, Security, status
13
+ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
14
+
15
+ from app.core.config import get_config
16
+
17
+ # 定义 Bearer Scheme
18
+ security = HTTPBearer(
19
+ auto_error=False,
20
+ scheme_name="API Key",
21
+ description="Enter your API Key in the format: Bearer <key>",
22
+ )
23
+
24
+ LEGACY_API_KEYS_FILE = Path(__file__).parent.parent.parent / "data" / "api_keys.json"
25
+ _legacy_api_keys_cache: Set[str] | None = None
26
+ _legacy_api_keys_mtime: float | None = None
27
+ _legacy_api_keys_lock = asyncio.Lock()
28
+
29
+
30
+ async def _load_legacy_api_keys() -> Set[str]:
31
+ """
32
+ Backward-compatible API keys loader.
33
+
34
+ Older versions stored multiple API keys in `data/api_keys.json` with a shape like:
35
+ [{"key": "...", "is_active": true, ...}, ...]
36
+ """
37
+ global _legacy_api_keys_cache, _legacy_api_keys_mtime
38
+
39
+ if not LEGACY_API_KEYS_FILE.exists():
40
+ _legacy_api_keys_cache = set()
41
+ _legacy_api_keys_mtime = None
42
+ return set()
43
+
44
+ try:
45
+ stat = LEGACY_API_KEYS_FILE.stat()
46
+ mtime = stat.st_mtime
47
+ except Exception:
48
+ mtime = None
49
+
50
+ if _legacy_api_keys_cache is not None and mtime is not None and _legacy_api_keys_mtime == mtime:
51
+ return _legacy_api_keys_cache
52
+
53
+ async with _legacy_api_keys_lock:
54
+ # Re-check in lock
55
+ if not LEGACY_API_KEYS_FILE.exists():
56
+ _legacy_api_keys_cache = set()
57
+ _legacy_api_keys_mtime = None
58
+ return set()
59
+
60
+ try:
61
+ stat = LEGACY_API_KEYS_FILE.stat()
62
+ mtime = stat.st_mtime
63
+ except Exception:
64
+ mtime = None
65
+
66
+ if _legacy_api_keys_cache is not None and mtime is not None and _legacy_api_keys_mtime == mtime:
67
+ return _legacy_api_keys_cache
68
+
69
+ try:
70
+ raw = await asyncio.to_thread(LEGACY_API_KEYS_FILE.read_text, "utf-8")
71
+ data = json.loads(raw) if raw.strip() else []
72
+ except Exception:
73
+ data = []
74
+
75
+ keys: Set[str] = set()
76
+ if isinstance(data, list):
77
+ for item in data:
78
+ if not isinstance(item, dict):
79
+ continue
80
+ key = item.get("key")
81
+ is_active = item.get("is_active", True)
82
+ if isinstance(key, str) and key.strip() and is_active is not False:
83
+ keys.add(key.strip())
84
+
85
+ _legacy_api_keys_cache = keys
86
+ _legacy_api_keys_mtime = mtime
87
+ return keys
88
+
89
+
90
+ async def verify_api_key(
91
+ auth: Optional[HTTPAuthorizationCredentials] = Security(security),
92
+ ) -> Optional[str]:
93
+ """
94
+ 验证 Bearer Token
95
+
96
+ - 若 `app.api_key` 未配置且不存在 legacy keys,则跳过验证。
97
+ - 若配置了 `app.api_key` 或存在 legacy keys,则必须提供 Authorization: Bearer <key>。
98
+ """
99
+ api_key = str(get_config("app.api_key", "") or "").strip()
100
+ legacy_keys = await _load_legacy_api_keys()
101
+
102
+ # 如果未配置 API Key 且没有 legacy keys,直接放行
103
+ if not api_key and not legacy_keys:
104
+ return None
105
+
106
+ if not auth:
107
+ raise HTTPException(
108
+ status_code=status.HTTP_401_UNAUTHORIZED,
109
+ detail="Missing authentication token",
110
+ headers={"WWW-Authenticate": "Bearer"},
111
+ )
112
+
113
+ token = auth.credentials
114
+ if (api_key and token == api_key) or token in legacy_keys:
115
+ return token
116
+
117
+ raise HTTPException(
118
+ status_code=status.HTTP_401_UNAUTHORIZED,
119
+ detail="Invalid authentication token",
120
+ headers={"WWW-Authenticate": "Bearer"},
121
+ )
122
+
123
+
124
+ async def verify_app_key(
125
+ auth: Optional[HTTPAuthorizationCredentials] = Security(security),
126
+ ) -> Optional[str]:
127
+ """
128
+ 验证后台登录密钥(app_key)。
129
+
130
+ 如果未配置 app_key,则跳过验证。
131
+ """
132
+ app_key = str(get_config("app.app_key", "") or "").strip()
133
+
134
+ if not app_key:
135
+ raise HTTPException(
136
+ status_code=status.HTTP_401_UNAUTHORIZED,
137
+ detail="App key is not configured",
138
+ headers={"WWW-Authenticate": "Bearer"},
139
+ )
140
+
141
+ if not auth:
142
+ raise HTTPException(
143
+ status_code=status.HTTP_401_UNAUTHORIZED,
144
+ detail="Missing authentication token",
145
+ headers={"WWW-Authenticate": "Bearer"},
146
+ )
147
+
148
+ if auth.credentials != app_key:
149
+ raise HTTPException(
150
+ status_code=status.HTTP_401_UNAUTHORIZED,
151
+ detail="Invalid authentication token",
152
+ headers={"WWW-Authenticate": "Bearer"},
153
+ )
154
+
155
+ return auth.credentials
156
+
157
+
158
+ __all__ = ["verify_api_key", "verify_app_key"]
159
+
app/core/config.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 配置管理
3
+
4
+ - config.toml: 运行时配置
5
+ - config.defaults.toml: 默认配置基线
6
+ """
7
+
8
+ from copy import deepcopy
9
+ from pathlib import Path
10
+ from typing import Any, Dict
11
+ import tomllib
12
+
13
+ from app.core.logger import logger
14
+
15
+ DEFAULT_CONFIG_FILE = Path(__file__).parent.parent.parent / "config.defaults.toml"
16
+ LEGACY_CONFIG_FILE = Path(__file__).parent.parent.parent / "data" / "setting.toml"
17
+
18
+
19
+ def _as_str(v: Any) -> str:
20
+ if isinstance(v, str):
21
+ return v
22
+ return ""
23
+
24
+
25
+ def _as_int(v: Any) -> int | None:
26
+ try:
27
+ if v is None:
28
+ return None
29
+ return int(v)
30
+ except Exception:
31
+ return None
32
+
33
+
34
+ def _as_bool(v: Any) -> bool | None:
35
+ if isinstance(v, bool):
36
+ return v
37
+ return None
38
+
39
+
40
+ def _split_csv_tags(v: Any) -> list[str] | None:
41
+ if not isinstance(v, str):
42
+ return None
43
+ parts = [x.strip() for x in v.split(",")]
44
+ tags = [x for x in parts if x]
45
+ return tags or None
46
+
47
+
48
+ def _legacy_setting_to_config(legacy: Dict[str, Any]) -> Dict[str, Any]:
49
+ """
50
+ Migrate legacy `data/setting.toml` format (grok/global) to the new config schema.
51
+
52
+ Best-effort mapping only for stable fields. It does not delete or rename the legacy file.
53
+ """
54
+
55
+ grok = legacy.get("grok") if isinstance(legacy.get("grok"), dict) else {}
56
+ global_ = legacy.get("global") if isinstance(legacy.get("global"), dict) else {}
57
+
58
+ out: Dict[str, Any] = {}
59
+
60
+ # === app ===
61
+ app_url = _as_str(global_.get("base_url")).strip()
62
+ admin_username = _as_str(global_.get("admin_username")).strip()
63
+ app_key = _as_str(global_.get("admin_password")).strip()
64
+ api_key = _as_str(grok.get("api_key")).strip()
65
+ image_format = _as_str(global_.get("image_mode")).strip()
66
+
67
+ if app_url or admin_username or app_key or api_key or image_format:
68
+ out["app"] = {}
69
+ if app_url:
70
+ out["app"]["app_url"] = app_url
71
+ if admin_username:
72
+ out["app"]["admin_username"] = admin_username
73
+ if app_key:
74
+ out["app"]["app_key"] = app_key
75
+ if api_key:
76
+ out["app"]["api_key"] = api_key
77
+ if image_format:
78
+ out["app"]["image_format"] = image_format
79
+
80
+ # === grok ===
81
+ base_proxy_url = _as_str(grok.get("proxy_url")).strip()
82
+ asset_proxy_url = _as_str(grok.get("cache_proxy_url")).strip()
83
+ cf_clearance = _as_str(grok.get("cf_clearance")).strip()
84
+
85
+ temporary = _as_bool(grok.get("temporary"))
86
+ thinking = _as_bool(grok.get("show_thinking"))
87
+ dynamic_statsig = _as_bool(grok.get("dynamic_statsig"))
88
+ filter_tags = _split_csv_tags(grok.get("filtered_tags"))
89
+
90
+ retry_status_codes = grok.get("retry_status_codes")
91
+
92
+ timeout = None
93
+ total_timeout = _as_int(grok.get("stream_total_timeout"))
94
+ if total_timeout and total_timeout > 0:
95
+ timeout = total_timeout
96
+ else:
97
+ chunk_timeout = _as_int(grok.get("stream_chunk_timeout"))
98
+ if chunk_timeout and chunk_timeout > 0:
99
+ timeout = chunk_timeout
100
+
101
+ if (
102
+ base_proxy_url
103
+ or asset_proxy_url
104
+ or cf_clearance
105
+ or temporary is not None
106
+ or thinking is not None
107
+ or dynamic_statsig is not None
108
+ or filter_tags is not None
109
+ or timeout is not None
110
+ or isinstance(retry_status_codes, list)
111
+ ):
112
+ out["grok"] = {}
113
+ if base_proxy_url:
114
+ out["grok"]["base_proxy_url"] = base_proxy_url
115
+ if asset_proxy_url:
116
+ out["grok"]["asset_proxy_url"] = asset_proxy_url
117
+ if cf_clearance:
118
+ out["grok"]["cf_clearance"] = cf_clearance
119
+ if temporary is not None:
120
+ out["grok"]["temporary"] = temporary
121
+ if thinking is not None:
122
+ out["grok"]["thinking"] = thinking
123
+ if dynamic_statsig is not None:
124
+ out["grok"]["dynamic_statsig"] = dynamic_statsig
125
+ if filter_tags is not None:
126
+ out["grok"]["filter_tags"] = filter_tags
127
+ if timeout is not None:
128
+ out["grok"]["timeout"] = timeout
129
+ if isinstance(retry_status_codes, list) and retry_status_codes:
130
+ out["grok"]["retry_status_codes"] = retry_status_codes
131
+
132
+ # === cache ===
133
+ # Legacy had separate limits; new uses a single total limit_mb.
134
+ image_mb = _as_int(global_.get("image_cache_max_size_mb")) or 0
135
+ video_mb = _as_int(global_.get("video_cache_max_size_mb")) or 0
136
+ if image_mb > 0 or video_mb > 0:
137
+ out["cache"] = {"limit_mb": max(1, image_mb + video_mb)}
138
+
139
+ return out
140
+
141
+
142
+ def _apply_legacy_config(
143
+ config_data: Dict[str, Any],
144
+ legacy_cfg: Dict[str, Any],
145
+ defaults: Dict[str, Any],
146
+ ) -> bool:
147
+ """
148
+ Merge legacy settings into current config:
149
+ - fill missing keys
150
+ - override keys that are still default values
151
+ """
152
+
153
+ changed = False
154
+ for section, items in legacy_cfg.items():
155
+ if not isinstance(items, dict):
156
+ continue
157
+
158
+ current_section = config_data.get(section)
159
+ if not isinstance(current_section, dict):
160
+ current_section = {}
161
+ config_data[section] = current_section
162
+ changed = True
163
+
164
+ default_section = defaults.get(section) if isinstance(defaults.get(section), dict) else {}
165
+
166
+ for key, val in items.items():
167
+ if val is None:
168
+ continue
169
+ if key not in current_section:
170
+ current_section[key] = val
171
+ changed = True
172
+ continue
173
+
174
+ default_val = default_section.get(key) if isinstance(default_section, dict) else None
175
+ current_val = current_section.get(key)
176
+
177
+ # NOTE: The admin panel password default used to be `grok2api` in older versions.
178
+ # Treat it as "still default" so legacy `data/setting.toml` can override it during migration.
179
+ is_effective_default = current_val == default_val
180
+ if section == "app" and key == "app_key" and current_val == "grok2api":
181
+ is_effective_default = True
182
+
183
+ if is_effective_default and val != default_val:
184
+ current_section[key] = val
185
+ changed = True
186
+
187
+ return changed
188
+
189
+
190
+ def _deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
191
+ """深度合并字典:override 覆盖 base。"""
192
+ if not isinstance(base, dict):
193
+ return deepcopy(override) if isinstance(override, dict) else deepcopy(base)
194
+
195
+ result = deepcopy(base)
196
+ if not isinstance(override, dict):
197
+ return result
198
+
199
+ for key, val in override.items():
200
+ if isinstance(val, dict) and isinstance(result.get(key), dict):
201
+ result[key] = _deep_merge(result[key], val)
202
+ else:
203
+ result[key] = val
204
+ return result
205
+
206
+
207
+ def _load_defaults() -> Dict[str, Any]:
208
+ """加载默认配置文件"""
209
+ if not DEFAULT_CONFIG_FILE.exists():
210
+ return {}
211
+ try:
212
+ with DEFAULT_CONFIG_FILE.open("rb") as f:
213
+ return tomllib.load(f)
214
+ except Exception as e:
215
+ logger.warning(f"Failed to load defaults from {DEFAULT_CONFIG_FILE}: {e}")
216
+ return {}
217
+
218
+
219
+ class Config:
220
+ """配置管理器"""
221
+
222
+ _instance = None
223
+ _config = {}
224
+
225
+ def __init__(self):
226
+ self._config = {}
227
+ self._defaults = {}
228
+ self._defaults_loaded = False
229
+
230
+ def _ensure_defaults(self):
231
+ if self._defaults_loaded:
232
+ return
233
+ self._defaults = _load_defaults()
234
+ self._defaults_loaded = True
235
+
236
+ async def load(self):
237
+ """显式加载配置"""
238
+ try:
239
+ from app.core.storage import get_storage, LocalStorage
240
+
241
+ self._ensure_defaults()
242
+
243
+ storage = get_storage()
244
+ config_data = await storage.load_config()
245
+ from_remote = True
246
+
247
+ # 从本地 data/config.toml 初始化后端
248
+ if config_data is None:
249
+ local_storage = LocalStorage()
250
+ from_remote = False
251
+ try:
252
+ config_data = await local_storage.load_config()
253
+ except Exception as e:
254
+ logger.info(f"Failed to auto-init config from local: {e}")
255
+ config_data = {}
256
+
257
+ config_data = config_data or {}
258
+ before_legacy = deepcopy(config_data)
259
+
260
+ # Legacy migration: data/setting.toml -> config schema
261
+ if LEGACY_CONFIG_FILE.exists():
262
+ try:
263
+ with LEGACY_CONFIG_FILE.open("rb") as f:
264
+ legacy_raw = tomllib.load(f) or {}
265
+ legacy_cfg = _legacy_setting_to_config(legacy_raw)
266
+ if legacy_cfg and _apply_legacy_config(config_data, legacy_cfg, self._defaults):
267
+ logger.info(
268
+ "Detected legacy data/setting.toml, migrated into config (missing/default keys)."
269
+ )
270
+ except Exception as e:
271
+ logger.warning(f"Failed to migrate legacy config from {LEGACY_CONFIG_FILE}: {e}")
272
+
273
+ merged = _deep_merge(self._defaults, config_data)
274
+
275
+ # 自动回填缺失配置到存储
276
+ should_persist = (not from_remote) or (merged != before_legacy)
277
+ if should_persist:
278
+ async with storage.acquire_lock("config_save", timeout=10):
279
+ await storage.save_config(merged)
280
+ if not from_remote:
281
+ logger.info(
282
+ f"Initialized remote storage ({storage.__class__.__name__}) with config baseline."
283
+ )
284
+
285
+ self._config = merged
286
+ except Exception as e:
287
+ logger.error(f"Error loading config: {e}")
288
+ self._config = {}
289
+
290
+ def get(self, key: str, default: Any = None) -> Any:
291
+ """
292
+ 获取配置值
293
+
294
+ Args:
295
+ key: 配置键,格式 "section.key"
296
+ default: 默认值
297
+ """
298
+ if "." in key:
299
+ try:
300
+ section, attr = key.split(".", 1)
301
+ return self._config.get(section, {}).get(attr, default)
302
+ except (ValueError, AttributeError):
303
+ return default
304
+
305
+ return self._config.get(key, default)
306
+
307
+ async def update(self, new_config: dict):
308
+ """更新配置"""
309
+ from app.core.storage import get_storage
310
+
311
+ storage = get_storage()
312
+ async with storage.acquire_lock("config_save", timeout=10):
313
+ self._ensure_defaults()
314
+ base = _deep_merge(self._defaults, self._config or {})
315
+ merged = _deep_merge(base, new_config or {})
316
+ await storage.save_config(merged)
317
+ self._config = merged
318
+
319
+
320
+ # 全局配置实例
321
+ config = Config()
322
+
323
+
324
+ def get_config(key: str, default: Any = None) -> Any:
325
+ """获取配置"""
326
+ return config.get(key, default)
327
+
328
+
329
+ __all__ = ["Config", "config", "get_config"]
app/core/exceptions.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 全局异常处理 - OpenAI 兼容错误格式
3
+ """
4
+
5
+ from typing import Any, Optional
6
+ from enum import Enum
7
+ from fastapi import Request, HTTPException
8
+ from fastapi.responses import JSONResponse
9
+ from fastapi.exceptions import RequestValidationError
10
+
11
+ from app.core.logger import logger
12
+
13
+
14
+ # ============= 错误类型 =============
15
+
16
+ class ErrorType(str, Enum):
17
+ """OpenAI 错误类型"""
18
+ INVALID_REQUEST = "invalid_request_error"
19
+ AUTHENTICATION = "authentication_error"
20
+ PERMISSION = "permission_error"
21
+ NOT_FOUND = "not_found_error"
22
+ RATE_LIMIT = "rate_limit_error"
23
+ SERVER = "server_error"
24
+ SERVICE_UNAVAILABLE = "service_unavailable_error"
25
+
26
+
27
+ # ============= 辅助函数 =============
28
+
29
+ def error_response(
30
+ message: str,
31
+ error_type: str = ErrorType.INVALID_REQUEST.value,
32
+ param: str = None,
33
+ code: str = None
34
+ ) -> dict:
35
+ """构建 OpenAI 错误响应"""
36
+ return {
37
+ "error": {
38
+ "message": message,
39
+ "type": error_type,
40
+ "param": param,
41
+ "code": code
42
+ }
43
+ }
44
+
45
+
46
+ # ============= 异常类 =============
47
+
48
+ class AppException(Exception):
49
+ """应用基础异常"""
50
+
51
+ def __init__(
52
+ self,
53
+ message: str,
54
+ error_type: str = ErrorType.SERVER.value,
55
+ code: str = None,
56
+ param: str = None,
57
+ status_code: int = 500
58
+ ):
59
+ self.message = message
60
+ self.error_type = error_type
61
+ self.code = code
62
+ self.param = param
63
+ self.status_code = status_code
64
+ super().__init__(message)
65
+
66
+
67
+ class ValidationException(AppException):
68
+ """验证错误"""
69
+
70
+ def __init__(self, message: str, param: str = None, code: str = None):
71
+ super().__init__(
72
+ message=message,
73
+ error_type=ErrorType.INVALID_REQUEST.value,
74
+ code=code or "invalid_value",
75
+ param=param,
76
+ status_code=400
77
+ )
78
+
79
+
80
+ class AuthenticationException(AppException):
81
+ """认证错误"""
82
+
83
+ def __init__(self, message: str = "Invalid API key"):
84
+ super().__init__(
85
+ message=message,
86
+ error_type=ErrorType.AUTHENTICATION.value,
87
+ code="invalid_api_key",
88
+ status_code=401
89
+ )
90
+
91
+
92
+ class UpstreamException(AppException):
93
+ """上游服务错误"""
94
+
95
+ def __init__(self, message: str, details: Any = None):
96
+ super().__init__(
97
+ message=message,
98
+ error_type=ErrorType.SERVER.value,
99
+ code="upstream_error",
100
+ status_code=502
101
+ )
102
+ self.details = details
103
+
104
+
105
+ # ============= 异常处理器 =============
106
+
107
+ async def app_exception_handler(request: Request, exc: AppException) -> JSONResponse:
108
+ """处理应用异常"""
109
+ logger.warning(f"AppException: {exc.error_type} - {exc.message}")
110
+
111
+ return JSONResponse(
112
+ status_code=exc.status_code,
113
+ content=error_response(
114
+ message=exc.message,
115
+ error_type=exc.error_type,
116
+ param=exc.param,
117
+ code=exc.code
118
+ )
119
+ )
120
+
121
+
122
+ async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
123
+ """处理 HTTP 异常"""
124
+ type_map = {
125
+ 400: ErrorType.INVALID_REQUEST.value,
126
+ 401: ErrorType.AUTHENTICATION.value,
127
+ 403: ErrorType.PERMISSION.value,
128
+ 404: ErrorType.NOT_FOUND.value,
129
+ 429: ErrorType.RATE_LIMIT.value,
130
+ }
131
+ error_type = type_map.get(exc.status_code, ErrorType.SERVER.value)
132
+
133
+ # 默认 code 映射
134
+ code_map = {
135
+ 401: "invalid_api_key",
136
+ 403: "insufficient_quota",
137
+ 404: "model_not_found",
138
+ 429: "rate_limit_exceeded",
139
+ }
140
+ code = code_map.get(exc.status_code, None)
141
+
142
+ logger.warning(f"HTTPException: {exc.status_code} - {exc.detail}")
143
+
144
+ return JSONResponse(
145
+ status_code=exc.status_code,
146
+ content=error_response(
147
+ message=str(exc.detail),
148
+ error_type=error_type,
149
+ code=code
150
+ )
151
+ )
152
+
153
+
154
+ async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
155
+ """处理验证错误"""
156
+ errors = exc.errors()
157
+
158
+ if errors:
159
+ first = errors[0]
160
+ loc = first.get("loc", [])
161
+ msg = first.get("msg", "Invalid request")
162
+ code = first.get("type", "invalid_value")
163
+
164
+ # JSON 解析错误
165
+ if code == "json_invalid" or "JSON" in msg:
166
+ message = "Invalid JSON in request body. Please check for trailing commas or syntax errors."
167
+ param = "body"
168
+ else:
169
+ param_parts = [str(x) for x in loc if not (isinstance(x, int) or str(x).isdigit())]
170
+ param = ".".join(param_parts) if param_parts else None
171
+ message = msg
172
+ else:
173
+ param, message, code = None, "Invalid request", "invalid_value"
174
+
175
+ logger.warning(f"ValidationError: {param} - {message}")
176
+
177
+ return JSONResponse(
178
+ status_code=400,
179
+ content=error_response(
180
+ message=message,
181
+ error_type=ErrorType.INVALID_REQUEST.value,
182
+ param=param,
183
+ code=code
184
+ )
185
+ )
186
+
187
+
188
+ async def generic_exception_handler(request: Request, exc: Exception) -> JSONResponse:
189
+ """处理未捕获异常"""
190
+ logger.exception(f"Unhandled: {type(exc).__name__}: {str(exc)}")
191
+
192
+ return JSONResponse(
193
+ status_code=500,
194
+ content=error_response(
195
+ message="Internal server error",
196
+ error_type=ErrorType.SERVER.value,
197
+ code="internal_error"
198
+ )
199
+ )
200
+
201
+
202
+ # ============= 注册 =============
203
+
204
+ def register_exception_handlers(app):
205
+ """注册异常处理器"""
206
+ app.add_exception_handler(AppException, app_exception_handler)
207
+ app.add_exception_handler(HTTPException, http_exception_handler)
208
+ app.add_exception_handler(RequestValidationError, validation_exception_handler)
209
+ app.add_exception_handler(Exception, generic_exception_handler)
210
+ app.add_exception_handler(Exception, generic_exception_handler)
211
+
212
+
213
+ __all__ = [
214
+ "ErrorType",
215
+ "AppException",
216
+ "ValidationException",
217
+ "AuthenticationException",
218
+ "UpstreamException",
219
+ "error_response",
220
+ "register_exception_handlers",
221
+ ]
app/core/legacy_migration.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Legacy data migrations for local deployments (python/docker).
3
+
4
+ Goal: when upgrading the project, old on-disk data should still be readable and not lost.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import asyncio
10
+ import os
11
+ import shutil
12
+ import time
13
+ from pathlib import Path
14
+ from typing import Any, Dict
15
+
16
+ from app.core.logger import logger
17
+
18
+
19
+ def migrate_legacy_cache_dirs(data_dir: Path | None = None) -> Dict[str, Any]:
20
+ """
21
+ Migrate old cache directory layout:
22
+
23
+ - legacy: data/temp/{image,video}
24
+ - current: data/tmp/{image,video}
25
+
26
+ This keeps existing cached files (not yet cleaned) available after upgrades.
27
+ """
28
+
29
+ data_root = data_dir or (Path(__file__).parent.parent.parent / "data")
30
+ legacy_root = data_root / "temp"
31
+ current_root = data_root / "tmp"
32
+
33
+ if not legacy_root.exists() or not legacy_root.is_dir():
34
+ return {"migrated": False, "reason": "no_legacy_dir"}
35
+
36
+ lock_dir = data_root / ".locks"
37
+ lock_dir.mkdir(parents=True, exist_ok=True)
38
+
39
+ done_marker = lock_dir / "legacy_cache_dirs_v1.done"
40
+ if done_marker.exists():
41
+ return {"migrated": False, "reason": "already_done"}
42
+
43
+ lock_file = lock_dir / "legacy_cache_dirs_v1.lock"
44
+
45
+ # Best-effort cross-process lock (works on Windows/Linux).
46
+ fd: int | None = None
47
+ try:
48
+ try:
49
+ fd = os.open(str(lock_file), os.O_CREAT | os.O_EXCL | os.O_WRONLY)
50
+ except FileExistsError:
51
+ # Another worker/process is migrating. Wait briefly for completion.
52
+ deadline = time.monotonic() + 30.0
53
+ while time.monotonic() < deadline:
54
+ if done_marker.exists():
55
+ return {"migrated": False, "reason": "waited_for_other_process"}
56
+ time.sleep(0.2)
57
+ return {"migrated": False, "reason": "lock_timeout"}
58
+
59
+ current_root.mkdir(parents=True, exist_ok=True)
60
+
61
+ moved = 0
62
+ skipped = 0
63
+ errors = 0
64
+
65
+ for sub in ("image", "video"):
66
+ src_dir = legacy_root / sub
67
+ if not src_dir.exists() or not src_dir.is_dir():
68
+ continue
69
+
70
+ dst_dir = current_root / sub
71
+ dst_dir.mkdir(parents=True, exist_ok=True)
72
+
73
+ for item in src_dir.iterdir():
74
+ if not item.is_file():
75
+ continue
76
+ target = dst_dir / item.name
77
+ if target.exists():
78
+ skipped += 1
79
+ continue
80
+ try:
81
+ shutil.move(str(item), str(target))
82
+ moved += 1
83
+ except Exception:
84
+ errors += 1
85
+
86
+ # Cleanup empty legacy dirs (best-effort).
87
+ for sub in ("image", "video"):
88
+ p = legacy_root / sub
89
+ try:
90
+ if p.exists() and p.is_dir() and not any(p.iterdir()):
91
+ p.rmdir()
92
+ except Exception:
93
+ pass
94
+ try:
95
+ if legacy_root.exists() and legacy_root.is_dir() and not any(legacy_root.iterdir()):
96
+ legacy_root.rmdir()
97
+ except Exception:
98
+ pass
99
+
100
+ if errors == 0:
101
+ done_marker.write_text(str(int(time.time())), encoding="utf-8")
102
+ if moved or skipped or errors:
103
+ logger.info(
104
+ f"Legacy cache migration complete: moved={moved}, skipped={skipped}, errors={errors}"
105
+ )
106
+ return {"migrated": True, "moved": moved, "skipped": skipped, "errors": errors}
107
+ finally:
108
+ try:
109
+ if fd is not None:
110
+ os.close(fd)
111
+ except Exception:
112
+ pass
113
+ try:
114
+ if lock_file.exists():
115
+ lock_file.unlink()
116
+ except Exception:
117
+ pass
118
+
119
+
120
+ __all__ = ["migrate_legacy_cache_dirs", "migrate_legacy_account_settings"]
121
+
122
+
123
+ async def migrate_legacy_account_settings(
124
+ concurrency: int = 10,
125
+ data_dir: Path | None = None,
126
+ ) -> Dict[str, Any]:
127
+ """
128
+ After legacy data migration, run a one-time TOS + BirthDate + NSFW pass for existing accounts.
129
+
130
+ This is best-effort and guarded by a cross-process lock + done marker.
131
+ """
132
+
133
+ data_root = data_dir or (Path(__file__).parent.parent.parent / "data")
134
+ lock_dir = data_root / ".locks"
135
+ lock_dir.mkdir(parents=True, exist_ok=True)
136
+
137
+ done_marker = lock_dir / "legacy_accounts_tos_birth_nsfw_v2.done"
138
+ if done_marker.exists():
139
+ return {"migrated": False, "reason": "already_done"}
140
+
141
+ lock_file = lock_dir / "legacy_accounts_tos_birth_nsfw_v2.lock"
142
+ fd: int | None = None
143
+
144
+ try:
145
+ try:
146
+ fd = os.open(str(lock_file), os.O_CREAT | os.O_EXCL | os.O_WRONLY)
147
+ except FileExistsError:
148
+ deadline = time.monotonic() + 30.0
149
+ while time.monotonic() < deadline:
150
+ if done_marker.exists():
151
+ return {"migrated": False, "reason": "waited_for_other_process"}
152
+ await asyncio.sleep(0.2)
153
+ return {"migrated": False, "reason": "lock_timeout"}
154
+
155
+ from app.core.config import get_config
156
+ from app.core.storage import get_storage
157
+ from app.services.register.services import (
158
+ UserAgreementService,
159
+ BirthDateService,
160
+ NsfwSettingsService,
161
+ )
162
+
163
+ storage = get_storage()
164
+ try:
165
+ token_data = await storage.load_tokens()
166
+ except Exception as exc:
167
+ logger.warning("Legacy account migration: failed to load tokens: {}", exc)
168
+ return {"migrated": False, "reason": "load_tokens_failed"}
169
+
170
+ token_data = token_data or {}
171
+ tokens: list[str] = []
172
+ for items in token_data.values():
173
+ if not isinstance(items, list):
174
+ continue
175
+ for item in items:
176
+ if isinstance(item, str):
177
+ tokens.append(item)
178
+ elif isinstance(item, dict):
179
+ token_val = item.get("token")
180
+ if isinstance(token_val, str):
181
+ tokens.append(token_val)
182
+
183
+ # De-duplicate while preserving order.
184
+ tokens = list(dict.fromkeys([t.strip() for t in tokens if isinstance(t, str) and t.strip()]))
185
+ if not tokens:
186
+ done_marker.write_text(str(int(time.time())), encoding="utf-8")
187
+ return {"migrated": True, "total": 0, "ok": 0, "failed": 0}
188
+
189
+ try:
190
+ concurrency = max(1, int(concurrency))
191
+ except Exception:
192
+ concurrency = 10
193
+
194
+ cf_clearance = str(get_config("grok.cf_clearance", "") or "").strip()
195
+
196
+ def _extract_cookie_value(cookie_str: str, name: str) -> str | None:
197
+ needle = f"{name}="
198
+ if needle not in cookie_str:
199
+ return None
200
+ for part in cookie_str.split(";"):
201
+ part = part.strip()
202
+ if part.startswith(needle):
203
+ return part[len(needle):].strip()
204
+ return None
205
+
206
+ def _normalize_tokens(raw_token: str) -> tuple[str, str]:
207
+ raw_token = raw_token.strip()
208
+ if ";" in raw_token:
209
+ sso_val = _extract_cookie_value(raw_token, "sso") or ""
210
+ sso_rw_val = _extract_cookie_value(raw_token, "sso-rw") or sso_val
211
+ else:
212
+ sso_val = raw_token[4:] if raw_token.startswith("sso=") else raw_token
213
+ sso_rw_val = sso_val
214
+ return sso_val, sso_rw_val
215
+
216
+ def _apply_settings(raw_token: str) -> bool:
217
+ sso_val, sso_rw_val = _normalize_tokens(raw_token)
218
+ if not sso_val:
219
+ return False
220
+
221
+ user_service = UserAgreementService(cf_clearance=cf_clearance)
222
+ birth_service = BirthDateService(cf_clearance=cf_clearance)
223
+ nsfw_service = NsfwSettingsService(cf_clearance=cf_clearance)
224
+
225
+ tos_result = user_service.accept_tos_version(
226
+ sso=sso_val,
227
+ sso_rw=sso_rw_val or sso_val,
228
+ impersonate="chrome120",
229
+ )
230
+ if not tos_result.get("ok"):
231
+ return False
232
+
233
+ birth_result = birth_service.set_birth_date(
234
+ sso=sso_val,
235
+ sso_rw=sso_rw_val or sso_val,
236
+ impersonate="chrome120",
237
+ )
238
+ if not birth_result.get("ok"):
239
+ return False
240
+
241
+ nsfw_result = nsfw_service.enable_nsfw(
242
+ sso=sso_val,
243
+ sso_rw=sso_rw_val or sso_val,
244
+ impersonate="chrome120",
245
+ )
246
+ return bool(nsfw_result.get("ok"))
247
+
248
+ sem = asyncio.Semaphore(concurrency)
249
+
250
+ async def _run_one(token: str) -> bool:
251
+ async with sem:
252
+ return await asyncio.to_thread(_apply_settings, token)
253
+
254
+ tasks = [_run_one(token) for token in tokens]
255
+ results = await asyncio.gather(*tasks, return_exceptions=True)
256
+
257
+ ok = 0
258
+ failed = 0
259
+ for res in results:
260
+ if isinstance(res, Exception):
261
+ failed += 1
262
+ elif res:
263
+ ok += 1
264
+ else:
265
+ failed += 1
266
+
267
+ done_marker.write_text(str(int(time.time())), encoding="utf-8")
268
+ logger.info(
269
+ "Legacy account migration complete: total=%d, ok=%d, failed=%d",
270
+ len(tokens),
271
+ ok,
272
+ failed,
273
+ )
274
+ return {"migrated": True, "total": len(tokens), "ok": ok, "failed": failed}
275
+ finally:
276
+ try:
277
+ if fd is not None:
278
+ os.close(fd)
279
+ except Exception:
280
+ pass
281
+ try:
282
+ if lock_file.exists():
283
+ lock_file.unlink()
284
+ except Exception:
285
+ pass
app/core/logger.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 结构化 JSON 日志 - 极简格式
3
+ """
4
+
5
+ import sys
6
+ import json
7
+ import traceback
8
+ from pathlib import Path
9
+ from loguru import logger
10
+
11
+ # 日志目录
12
+ LOG_DIR = Path(__file__).parent.parent.parent / "logs"
13
+ LOG_DIR.mkdir(parents=True, exist_ok=True)
14
+
15
+
16
+ def _format_json(record) -> str:
17
+ """格式化日志"""
18
+ # ISO8601 时间
19
+ time_str = record["time"].strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
20
+ tz = record["time"].strftime("%z")
21
+ if tz:
22
+ time_str += tz[:3] + ":" + tz[3:]
23
+
24
+ log_entry = {
25
+ "time": time_str,
26
+ "level": record["level"].name.lower(),
27
+ "msg": record["message"],
28
+ "caller": f"{record['file'].name}:{record['line']}",
29
+ }
30
+
31
+ # trace 上下文
32
+ extra = record["extra"]
33
+ if extra.get("traceID"):
34
+ log_entry["traceID"] = extra["traceID"]
35
+ if extra.get("spanID"):
36
+ log_entry["spanID"] = extra["spanID"]
37
+
38
+ # 其他 extra 字段
39
+ for key, value in extra.items():
40
+ if key not in ("traceID", "spanID") and not key.startswith("_"):
41
+ log_entry[key] = value
42
+
43
+ # 错误及以上级别添加堆栈跟踪
44
+ if record["level"].no >= 40 and record["exception"]:
45
+ log_entry["stacktrace"] = "".join(traceback.format_exception(
46
+ record["exception"].type,
47
+ record["exception"].value,
48
+ record["exception"].traceback
49
+ ))
50
+
51
+ return json.dumps(log_entry, ensure_ascii=False)
52
+
53
+
54
+ def _make_json_sink(output):
55
+ """创建 JSON sink"""
56
+ def sink(message):
57
+ json_str = _format_json(message.record)
58
+ print(json_str, file=output, flush=True)
59
+ return sink
60
+
61
+
62
+ def _file_json_sink(message):
63
+ """写入日志文件"""
64
+ record = message.record
65
+ json_str = _format_json(record)
66
+ log_file = LOG_DIR / f"app_{record['time'].strftime('%Y-%m-%d')}.log"
67
+ with open(log_file, "a", encoding="utf-8") as f:
68
+ f.write(json_str + "\n")
69
+
70
+
71
+ def setup_logging(
72
+ level: str = "DEBUG",
73
+ json_console: bool = True,
74
+ file_logging: bool = True,
75
+ ):
76
+ """设置日志配置"""
77
+ logger.remove()
78
+
79
+ # 控制台输出
80
+ if json_console:
81
+ logger.add(
82
+ _make_json_sink(sys.stdout),
83
+ level=level,
84
+ format="{message}",
85
+ colorize=False,
86
+ )
87
+ else:
88
+ logger.add(
89
+ sys.stdout,
90
+ level=level,
91
+ format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{file.name}:{line}</cyan> - <level>{message}</level>",
92
+ colorize=True,
93
+ )
94
+
95
+ # 文件输出
96
+ if file_logging:
97
+ logger.add(
98
+ _file_json_sink,
99
+ level=level,
100
+ format="{message}",
101
+ enqueue=True,
102
+ )
103
+
104
+ return logger
105
+
106
+
107
+ def get_logger(trace_id: str = "", span_id: str = ""):
108
+ """获取绑定了 trace 上下文的 logger"""
109
+ bound = {}
110
+ if trace_id:
111
+ bound["traceID"] = trace_id
112
+ if span_id:
113
+ bound["spanID"] = span_id
114
+ return logger.bind(**bound) if bound else logger
115
+
116
+
117
+ __all__ = ["logger", "setup_logging", "get_logger", "LOG_DIR"]
app/core/response_middleware.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 响应中间件
3
+ Response Middleware
4
+
5
+ 用于记录请求日志、生成 TraceID 和计算请求耗时
6
+ """
7
+
8
+ import time
9
+ import uuid
10
+ from starlette.middleware.base import BaseHTTPMiddleware
11
+ from starlette.requests import Request
12
+ from starlette.types import ASGIApp
13
+
14
+ from app.core.logger import logger
15
+
16
+ class ResponseLoggerMiddleware(BaseHTTPMiddleware):
17
+ """
18
+ 请求日志/响应追踪中间件
19
+ Request Logging and Response Tracking Middleware
20
+ """
21
+
22
+ async def dispatch(self, request: Request, call_next):
23
+ # 生成请求 ID
24
+ trace_id = str(uuid.uuid4())
25
+ request.state.trace_id = trace_id
26
+
27
+ start_time = time.time()
28
+
29
+ # 记录请求信息
30
+ logger.info(
31
+ f"Request: {request.method} {request.url.path}",
32
+ extra={
33
+ "traceID": trace_id,
34
+ "method": request.method,
35
+ "path": request.url.path
36
+ }
37
+ )
38
+
39
+ try:
40
+ response = await call_next(request)
41
+
42
+ # 计算耗时
43
+ duration = (time.time() - start_time) * 1000
44
+
45
+ # 记录响应信息
46
+ logger.info(
47
+ f"Response: {request.method} {request.url.path} - {response.status_code} ({duration:.2f}ms)",
48
+ extra={
49
+ "traceID": trace_id,
50
+ "method": request.method,
51
+ "path": request.url.path,
52
+ "status": response.status_code,
53
+ "duration_ms": round(duration, 2)
54
+ }
55
+ )
56
+
57
+ return response
58
+
59
+ except Exception as e:
60
+ duration = (time.time() - start_time) * 1000
61
+ logger.error(
62
+ f"Response Error: {request.method} {request.url.path} - {str(e)} ({duration:.2f}ms)",
63
+ extra={
64
+ "traceID": trace_id,
65
+ "method": request.method,
66
+ "path": request.url.path,
67
+ "duration_ms": round(duration, 2),
68
+ "error": str(e)
69
+ }
70
+ )
71
+ raise e
app/core/storage.py ADDED
@@ -0,0 +1,720 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 统一存储服务 (Professional Storage Service)
3
+ 支持 Local (TOML), Redis, MySQL, PostgreSQL
4
+
5
+ 特性:
6
+ - 全异步 I/O (Async I/O)
7
+ - 连接池管理 (Connection Pooling)
8
+ - 分布式/本地锁 (Distributed/Local Locking)
9
+ - 内存优化 (序列化性能优化)
10
+ """
11
+
12
+ import abc
13
+ import os
14
+ import asyncio
15
+ import os
16
+ import hashlib
17
+ import time
18
+ import tomllib
19
+ from typing import Any, Dict, Optional
20
+ from pathlib import Path
21
+ from enum import Enum
22
+ try:
23
+ import fcntl
24
+ except ImportError: # pragma: no cover - non-posix platforms
25
+ fcntl = None
26
+ from contextlib import asynccontextmanager
27
+
28
+ import orjson
29
+ import aiofiles
30
+ from app.core.logger import logger
31
+
32
+ # 配置文件路径
33
+ CONFIG_FILE = Path(__file__).parent.parent.parent / "data" / "config.toml"
34
+ TOKEN_FILE = Path(__file__).parent.parent.parent / "data" / "token.json"
35
+ LOCK_DIR = Path(__file__).parent.parent.parent / "data" / ".locks"
36
+
37
+ # JSON 序列化优化助手函数
38
+ def json_dumps(obj: Any) -> str:
39
+ return orjson.dumps(obj).decode("utf-8")
40
+
41
+ def json_loads(obj: str | bytes) -> Any:
42
+ return orjson.loads(obj)
43
+
44
+ class StorageError(Exception):
45
+ """存储服务基础异常"""
46
+ pass
47
+
48
+ class BaseStorage(abc.ABC):
49
+ """存储基类"""
50
+
51
+ @abc.abstractmethod
52
+ async def load_config(self) -> Dict[str, Any]:
53
+ """加载配置"""
54
+ pass
55
+
56
+ @abc.abstractmethod
57
+ async def save_config(self, data: Dict[str, Any]):
58
+ """保存配置"""
59
+ pass
60
+
61
+ @abc.abstractmethod
62
+ async def load_tokens(self) -> Dict[str, Any]:
63
+ """加载所有 Token"""
64
+ pass
65
+
66
+ @abc.abstractmethod
67
+ async def save_tokens(self, data: Dict[str, Any]):
68
+ """保存所有 Token"""
69
+ pass
70
+
71
+ @abc.abstractmethod
72
+ async def close(self):
73
+ """关闭资源"""
74
+ pass
75
+
76
+ @asynccontextmanager
77
+ async def acquire_lock(self, name: str, timeout: int = 10):
78
+ """
79
+ 获取锁 (互斥访问)
80
+ 用于读写操作的临界区保护
81
+
82
+ Args:
83
+ name: 锁名称
84
+ timeout: 超时时间 (秒)
85
+ """
86
+ # 默认空实现,用于 fallback
87
+ yield
88
+
89
+ async def verify_connection(self) -> bool:
90
+ """健康检查"""
91
+ return True
92
+
93
+
94
+ class LocalStorage(BaseStorage):
95
+ """
96
+ 本地文件存储
97
+ - 使用 aiofiles 进行异步 I/O
98
+ - 使用 asyncio.Lock 进行进程内并发控制
99
+ - 如果需要多进程安全,需要系统级文件锁 (fcntl)
100
+ """
101
+
102
+ def __init__(self):
103
+ self._lock = asyncio.Lock()
104
+
105
+ @asynccontextmanager
106
+ async def acquire_lock(self, name: str, timeout: int = 10):
107
+ if fcntl is None:
108
+ try:
109
+ async with asyncio.timeout(timeout):
110
+ async with self._lock:
111
+ yield
112
+ except asyncio.TimeoutError:
113
+ logger.warning(f"LocalStorage: 获取锁 '{name}' 超时 ({timeout}s)")
114
+ raise StorageError(f"无法获取锁 '{name}'")
115
+ return
116
+
117
+ lock_path = LOCK_DIR / f"{name}.lock"
118
+ lock_path.parent.mkdir(parents=True, exist_ok=True)
119
+ fd = None
120
+ locked = False
121
+ start = time.monotonic()
122
+
123
+ async with self._lock:
124
+ try:
125
+ fd = open(lock_path, "a+")
126
+ while True:
127
+ try:
128
+ fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
129
+ locked = True
130
+ break
131
+ except BlockingIOError:
132
+ if time.monotonic() - start >= timeout:
133
+ raise StorageError(f"无法获取锁 '{name}'")
134
+ await asyncio.sleep(0.05)
135
+ yield
136
+ except StorageError:
137
+ logger.warning(f"LocalStorage: 获取锁 '{name}' 超时 ({timeout}s)")
138
+ raise
139
+ finally:
140
+ if fd:
141
+ if locked:
142
+ try:
143
+ fcntl.flock(fd, fcntl.LOCK_UN)
144
+ except Exception:
145
+ pass
146
+ try:
147
+ fd.close()
148
+ except Exception:
149
+ pass
150
+
151
+ async def load_config(self) -> Dict[str, Any]:
152
+ if not CONFIG_FILE.exists():
153
+ return {}
154
+ try:
155
+ async with aiofiles.open(CONFIG_FILE, "rb") as f:
156
+ content = await f.read()
157
+ return tomllib.loads(content.decode("utf-8"))
158
+ except Exception as e:
159
+ logger.error(f"LocalStorage: 加载配置失败: {e}")
160
+ return {}
161
+
162
+ async def save_config(self, data: Dict[str, Any]):
163
+ try:
164
+ lines = []
165
+ for section, items in data.items():
166
+ if not isinstance(items, dict): continue
167
+ lines.append(f"[{section}]")
168
+ for key, val in items.items():
169
+ if isinstance(val, bool):
170
+ val_str = "true" if val else "false"
171
+ elif isinstance(val, str):
172
+ escaped = val.replace('"', '\\"')
173
+ val_str = f'"{escaped}"'
174
+ elif isinstance(val, (int, float)):
175
+ val_str = str(val)
176
+ elif isinstance(val, (list, dict)):
177
+ val_str = json_dumps(val)
178
+ else:
179
+ val_str = f'"{str(val)}"'
180
+ lines.append(f"{key} = {val_str}")
181
+ lines.append("")
182
+
183
+ content = "\n".join(lines)
184
+
185
+ CONFIG_FILE.parent.mkdir(parents=True, exist_ok=True)
186
+ async with aiofiles.open(CONFIG_FILE, "w", encoding="utf-8") as f:
187
+ await f.write(content)
188
+ except Exception as e:
189
+ logger.error(f"LocalStorage: 保存配置失败: {e}")
190
+ raise StorageError(f"保存配置失败: {e}")
191
+
192
+ async def load_tokens(self) -> Dict[str, Any]:
193
+ if not TOKEN_FILE.exists():
194
+ return {}
195
+ try:
196
+ async with aiofiles.open(TOKEN_FILE, "rb") as f:
197
+ content = await f.read()
198
+ return json_loads(content)
199
+ except Exception as e:
200
+ logger.error(f"LocalStorage: 加载 Token 失败: {e}")
201
+ return {}
202
+
203
+ async def save_tokens(self, data: Dict[str, Any]):
204
+ try:
205
+ TOKEN_FILE.parent.mkdir(parents=True, exist_ok=True)
206
+ temp_path = TOKEN_FILE.with_suffix('.tmp')
207
+
208
+ # 原子写操作: 写入临时文件 -> 重命名
209
+ async with aiofiles.open(temp_path, "wb") as f:
210
+ await f.write(orjson.dumps(data, option=orjson.OPT_INDENT_2))
211
+
212
+ # 使用 os.replace 保证原子性
213
+ os.replace(temp_path, TOKEN_FILE)
214
+
215
+ except Exception as e:
216
+ logger.error(f"LocalStorage: 保存 Token 失败: {e}")
217
+ raise StorageError(f"保存 Token 失败: {e}")
218
+
219
+ async def close(self):
220
+ pass
221
+
222
+
223
+ class RedisStorage(BaseStorage):
224
+ """
225
+ Redis 存储
226
+ - 使用 redis-py 异步客户端 (自带连接池)
227
+ - 支持分布式锁 (redis.lock)
228
+ - 扁平化数据结构优化性能
229
+ """
230
+
231
+ def __init__(self, url: str):
232
+ try:
233
+ from redis import asyncio as aioredis
234
+ from redis.asyncio.lock import Lock
235
+ except ImportError:
236
+ raise ImportError("需要安装 redis 包: pip install redis")
237
+
238
+ # 显式配置连接池
239
+ # 使用 decode_responses=True 简化字符串处理,但在处理复杂对象时使用 orjson
240
+ self.redis = aioredis.from_url(
241
+ url,
242
+ decode_responses=True,
243
+ health_check_interval=30
244
+ )
245
+ self.config_key = "grok2api:config" # Hash: section.key -> value_json
246
+ self.key_pools = "grok2api:pools" # Set: pool_names
247
+ self.prefix_pool_set = "grok2api:pool:" # Set: pool -> token_ids
248
+ self.prefix_token_hash = "grok2api:token:"# Hash: token_id -> token_data
249
+ self.lock_prefix = "grok2api:lock:"
250
+
251
+ @asynccontextmanager
252
+ async def acquire_lock(self, name: str, timeout: int = 10):
253
+ # 使用 Redis 分布式锁
254
+ lock_key = f"{self.lock_prefix}{name}"
255
+ lock = self.redis.lock(lock_key, timeout=timeout, blocking_timeout=5)
256
+ acquired = False
257
+ try:
258
+ acquired = await lock.acquire()
259
+ if not acquired:
260
+ raise StorageError(f"RedisStorage: 无法获取锁 '{name}'")
261
+ yield
262
+ finally:
263
+ if acquired:
264
+ try:
265
+ await lock.release()
266
+ except Exception:
267
+ # 锁可能已过期或被意外释放,忽略异常
268
+ pass
269
+
270
+ async def verify_connection(self) -> bool:
271
+ try:
272
+ return await self.redis.ping()
273
+ except Exception:
274
+ return False
275
+
276
+ async def load_config(self) -> Dict[str, Any]:
277
+ """从 Redis Hash 加载配置"""
278
+ try:
279
+ raw_data = await self.redis.hgetall(self.config_key)
280
+ if not raw_data:
281
+ return None
282
+
283
+ config = {}
284
+ for composite_key, val_str in raw_data.items():
285
+ if "." not in composite_key: continue
286
+ section, key = composite_key.split(".", 1)
287
+
288
+ if section not in config: config[section] = {}
289
+
290
+ try:
291
+ val = json_loads(val_str)
292
+ except:
293
+ val = val_str
294
+ config[section][key] = val
295
+ return config
296
+ except Exception as e:
297
+ logger.error(f"RedisStorage: 加载配置失败: {e}")
298
+ return None
299
+
300
+ async def save_config(self, data: Dict[str, Any]):
301
+ """保存配置到 Redis Hash"""
302
+ if not data: return
303
+ try:
304
+ mapping = {}
305
+ for section, items in data.items():
306
+ if not isinstance(items, dict): continue
307
+ for key, val in items.items():
308
+ composite_key = f"{section}.{key}"
309
+ mapping[composite_key] = json_dumps(val)
310
+
311
+ if mapping:
312
+ await self.redis.hset(self.config_key, mapping=mapping)
313
+ except Exception as e:
314
+ logger.error(f"RedisStorage: 保存配置失败: {e}")
315
+ raise
316
+
317
+ async def load_tokens(self) -> Dict[str, Any]:
318
+ """加载所有 Token"""
319
+ try:
320
+ pool_names = await self.redis.smembers(self.key_pools)
321
+ if not pool_names: return None
322
+
323
+ pools = {}
324
+ async with self.redis.pipeline() as pipe:
325
+ for pool_name in pool_names:
326
+ # 获取该池下所有 Token ID
327
+ pipe.smembers(f"{self.prefix_pool_set}{pool_name}")
328
+ pool_tokens_res = await pipe.execute()
329
+
330
+ # 收集所有 Token ID 以便批量查询
331
+ all_token_ids = []
332
+ pool_map = {} # pool_name -> list[token_id]
333
+
334
+ for i, pool_name in enumerate(pool_names):
335
+ tids = list(pool_tokens_res[i])
336
+ pool_map[pool_name] = tids
337
+ all_token_ids.extend(tids)
338
+
339
+ if not all_token_ids:
340
+ return {name: [] for name in pool_names}
341
+
342
+ # 批量获取 Token 详情 (Hash)
343
+ async with self.redis.pipeline() as pipe:
344
+ for tid in all_token_ids:
345
+ pipe.hgetall(f"{self.prefix_token_hash}{tid}")
346
+ token_data_list = await pipe.execute()
347
+
348
+ # 重组数据结构
349
+ token_lookup = {}
350
+ for i, tid in enumerate(all_token_ids):
351
+ t_data = token_data_list[i]
352
+ if not t_data: continue
353
+
354
+ # 恢复 tags (JSON -> List)
355
+ if "tags" in t_data:
356
+ try: t_data["tags"] = json_loads(t_data["tags"])
357
+ except: t_data["tags"] = []
358
+
359
+ # 类型转换 (Redis 返回全 string)
360
+ for int_field in ["quota", "created_at", "use_count", "fail_count", "last_used_at", "last_fail_at", "last_sync_at"]:
361
+ if t_data.get(int_field) and t_data[int_field] != "None":
362
+ try: t_data[int_field] = int(t_data[int_field])
363
+ except: pass
364
+
365
+ token_lookup[tid] = t_data
366
+
367
+ # 按 Pool 分组返回
368
+ for pool_name in pool_names:
369
+ pools[pool_name] = []
370
+ for tid in pool_map[pool_name]:
371
+ if tid in token_lookup:
372
+ pools[pool_name].append(token_lookup[tid])
373
+
374
+ return pools
375
+
376
+ except Exception as e:
377
+ logger.error(f"RedisStorage: 加载 Token 失败: {e}")
378
+ return None
379
+
380
+ async def save_tokens(self, data: Dict[str, Any]):
381
+ """保存所有 Token"""
382
+ if data is None:
383
+ return
384
+ try:
385
+ new_pools = set(data.keys()) if isinstance(data, dict) else set()
386
+ pool_tokens_map = {}
387
+ new_token_ids = set()
388
+
389
+ for pool_name, tokens in (data or {}).items():
390
+ tids_in_pool = []
391
+ for t in tokens:
392
+ token_str = t.get("token")
393
+ if not token_str:
394
+ continue
395
+ tids_in_pool.append(token_str)
396
+ new_token_ids.add(token_str)
397
+ pool_tokens_map[pool_name] = tids_in_pool
398
+
399
+ existing_pools = await self.redis.smembers(self.key_pools)
400
+ existing_pools = set(existing_pools) if existing_pools else set()
401
+
402
+ existing_token_ids = set()
403
+ if existing_pools:
404
+ async with self.redis.pipeline() as pipe:
405
+ for pool_name in existing_pools:
406
+ pipe.smembers(f"{self.prefix_pool_set}{pool_name}")
407
+ pool_tokens_res = await pipe.execute()
408
+ for tokens in pool_tokens_res:
409
+ existing_token_ids.update(list(tokens or []))
410
+
411
+ tokens_to_delete = existing_token_ids - new_token_ids
412
+ all_pools = existing_pools.union(new_pools)
413
+
414
+ async with self.redis.pipeline() as pipe:
415
+ # Reset pool index
416
+ pipe.delete(self.key_pools)
417
+ if new_pools:
418
+ pipe.sadd(self.key_pools, *new_pools)
419
+
420
+ # Reset pool sets
421
+ for pool_name in all_pools:
422
+ pipe.delete(f"{self.prefix_pool_set}{pool_name}")
423
+ for pool_name, tids_in_pool in pool_tokens_map.items():
424
+ if tids_in_pool:
425
+ pipe.sadd(f"{self.prefix_pool_set}{pool_name}", *tids_in_pool)
426
+
427
+ # Remove deleted token hashes
428
+ for token_str in tokens_to_delete:
429
+ pipe.delete(f"{self.prefix_token_hash}{token_str}")
430
+
431
+ # Upsert token hashes
432
+ for pool_name, tokens in (data or {}).items():
433
+ for t in tokens:
434
+ token_str = t.get("token")
435
+ if not token_str:
436
+ continue
437
+ t_flat = t.copy()
438
+ if "tags" in t_flat:
439
+ t_flat["tags"] = json_dumps(t_flat["tags"])
440
+ status = t_flat.get("status")
441
+ if isinstance(status, str) and status.startswith("TokenStatus."):
442
+ t_flat["status"] = status.split(".", 1)[1].lower()
443
+ elif isinstance(status, Enum):
444
+ t_flat["status"] = status.value
445
+ t_flat = {k: str(v) for k, v in t_flat.items() if v is not None}
446
+ pipe.hset(f"{self.prefix_token_hash}{token_str}", mapping=t_flat)
447
+
448
+ await pipe.execute()
449
+
450
+ except Exception as e:
451
+ logger.error(f"RedisStorage: 保存 Token 失败: {e}")
452
+ raise
453
+
454
+ async def close(self):
455
+ try:
456
+ await self.redis.close()
457
+ except (RuntimeError, asyncio.CancelledError, Exception):
458
+ # 忽略关闭时的 Event loop is closed 错误
459
+ pass
460
+
461
+
462
+ class SQLStorage(BaseStorage):
463
+ """
464
+ SQL 数据库存储 (MySQL/PgSQL)
465
+ - 使用 SQLAlchemy 异步引擎
466
+ - 自动 Schema 初始化
467
+ - 内置连接池 (QueuePool)
468
+ """
469
+
470
+ def __init__(self, url: str):
471
+ try:
472
+ from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
473
+ from sqlalchemy import text, MetaData
474
+ except ImportError:
475
+ raise ImportError("需要安装 sqlalchemy 和 async 驱动: pip install sqlalchemy[asyncio]")
476
+
477
+ self.dialect = url.split(":", 1)[0].split("+", 1)[0].lower()
478
+
479
+ # 配置 robust 的连接池
480
+ self.engine = create_async_engine(
481
+ url,
482
+ echo=False,
483
+ pool_size=20,
484
+ max_overflow=10,
485
+ pool_recycle=3600,
486
+ pool_pre_ping=True
487
+ )
488
+ self.async_session = async_sessionmaker(self.engine, expire_on_commit=False)
489
+ self._initialized = False
490
+
491
+ async def _ensure_schema(self):
492
+ """确保数据库表存在"""
493
+ if self._initialized: return
494
+ try:
495
+ async with self.engine.begin() as conn:
496
+ from sqlalchemy import text
497
+
498
+ # Tokens 表 (通用 SQL)
499
+ await conn.execute(text("""
500
+ CREATE TABLE IF NOT EXISTS tokens (
501
+ token VARCHAR(512) PRIMARY KEY,
502
+ pool_name VARCHAR(64) NOT NULL,
503
+ data TEXT,
504
+ updated_at BIGINT
505
+ )
506
+ """))
507
+
508
+ # 配置表
509
+ await conn.execute(text("""
510
+ CREATE TABLE IF NOT EXISTS app_config (
511
+ section VARCHAR(64) NOT NULL,
512
+ key_name VARCHAR(64) NOT NULL,
513
+ value TEXT,
514
+ PRIMARY KEY (section, key_name)
515
+ )
516
+ """))
517
+
518
+ # 索引
519
+ try:
520
+ await conn.execute(text("CREATE INDEX idx_tokens_pool ON tokens (pool_name)"))
521
+ except Exception:
522
+ pass
523
+
524
+ # 尝试兼容旧表结构
525
+ try:
526
+ if self.dialect in ("mysql", "mariadb"):
527
+ await conn.execute(text("ALTER TABLE tokens MODIFY token VARCHAR(512)"))
528
+ await conn.execute(text("ALTER TABLE tokens MODIFY data TEXT"))
529
+ elif self.dialect in ("postgres", "postgresql", "pgsql"):
530
+ await conn.execute(text("ALTER TABLE tokens ALTER COLUMN token TYPE VARCHAR(512)"))
531
+ await conn.execute(text("ALTER TABLE tokens ALTER COLUMN data TYPE TEXT"))
532
+ except Exception:
533
+ pass
534
+
535
+ self._initialized = True
536
+ except Exception as e:
537
+ logger.error(f"SQLStorage: Schema 初始化失败: {e}")
538
+ raise
539
+
540
+ @asynccontextmanager
541
+ async def acquire_lock(self, name: str, timeout: int = 10):
542
+ # SQL 分布式锁: MySQL GET_LOCK / PG advisory_lock
543
+ from sqlalchemy import text
544
+ lock_name = f"g2a:{hashlib.sha1(name.encode('utf-8')).hexdigest()[:24]}"
545
+ if self.dialect in ("mysql", "mariadb"):
546
+ async with self.async_session() as session:
547
+ res = await session.execute(
548
+ text("SELECT GET_LOCK(:name, :timeout)"),
549
+ {"name": lock_name, "timeout": timeout}
550
+ )
551
+ got = res.scalar()
552
+ if got != 1:
553
+ raise StorageError(f"SQLStorage: 无法获取锁 '{name}'")
554
+ try:
555
+ yield
556
+ finally:
557
+ try:
558
+ await session.execute(text("SELECT RELEASE_LOCK(:name)"), {"name": lock_name})
559
+ await session.commit()
560
+ except Exception:
561
+ pass
562
+ elif self.dialect in ("postgres", "postgresql", "pgsql"):
563
+ lock_key = int.from_bytes(hashlib.sha256(name.encode("utf-8")).digest()[:8], "big", signed=False)
564
+ async with self.async_session() as session:
565
+ start = time.monotonic()
566
+ while True:
567
+ res = await session.execute(
568
+ text("SELECT pg_try_advisory_lock(:key)"),
569
+ {"key": lock_key}
570
+ )
571
+ if res.scalar():
572
+ break
573
+ if time.monotonic() - start >= timeout:
574
+ raise StorageError(f"SQLStorage: 无法获取锁 '{name}'")
575
+ await asyncio.sleep(0.1)
576
+ try:
577
+ yield
578
+ finally:
579
+ try:
580
+ await session.execute(text("SELECT pg_advisory_unlock(:key)"), {"key": lock_key})
581
+ await session.commit()
582
+ except Exception:
583
+ pass
584
+ else:
585
+ yield
586
+
587
+ async def load_config(self) -> Dict[str, Any]:
588
+ await self._ensure_schema()
589
+ from sqlalchemy import text
590
+ try:
591
+ async with self.async_session() as session:
592
+ res = await session.execute(text("SELECT section, key_name, value FROM app_config"))
593
+ rows = res.fetchall()
594
+ if not rows: return None
595
+
596
+ config = {}
597
+ for section, key, val_str in rows:
598
+ if section not in config: config[section] = {}
599
+ try:
600
+ val = json_loads(val_str)
601
+ except:
602
+ val = val_str
603
+ config[section][key] = val
604
+ return config
605
+ except Exception as e:
606
+ logger.error(f"SQLStorage: 加载配置失败: {e}")
607
+ return None
608
+
609
+ async def save_config(self, data: Dict[str, Any]):
610
+ await self._ensure_schema()
611
+ from sqlalchemy import text
612
+ try:
613
+ async with self.async_session() as session:
614
+ for section, items in data.items():
615
+ if not isinstance(items, dict): continue
616
+ for key, val in items.items():
617
+ val_str = json_dumps(val)
618
+
619
+ # Upsert 逻辑 (简单实现: Delete + Insert)
620
+ await session.execute(
621
+ text("DELETE FROM app_config WHERE section=:s AND key_name=:k"),
622
+ {"s": section, "k": key}
623
+ )
624
+ await session.execute(
625
+ text("INSERT INTO app_config (section, key_name, value) VALUES (:s, :k, :v)"),
626
+ {"s": section, "k": key, "v": val_str}
627
+ )
628
+ await session.commit()
629
+ except Exception as e:
630
+ logger.error(f"SQLStorage: 保存配置失败: {e}")
631
+ raise
632
+
633
+ async def load_tokens(self) -> Dict[str, Any]:
634
+ await self._ensure_schema()
635
+ from sqlalchemy import text
636
+ try:
637
+ async with self.async_session() as session:
638
+ res = await session.execute(text("SELECT pool_name, data FROM tokens"))
639
+ rows = res.fetchall()
640
+ if not rows: return None
641
+
642
+ pools = {}
643
+ for pool_name, data_json in rows:
644
+ if pool_name not in pools: pools[pool_name] = []
645
+
646
+ try:
647
+ if isinstance(data_json, str):
648
+ t_data = json_loads(data_json)
649
+ else:
650
+ t_data = data_json
651
+ pools[pool_name].append(t_data)
652
+ except:
653
+ pass
654
+ return pools
655
+ except Exception as e:
656
+ logger.error(f"SQLStorage: 加载 Token 失败: {e}")
657
+ return None
658
+
659
+ async def save_tokens(self, data: Dict[str, Any]):
660
+ await self._ensure_schema()
661
+ from sqlalchemy import text
662
+ try:
663
+ async with self.async_session() as session:
664
+ await session.execute(text("DELETE FROM tokens"))
665
+
666
+ params = []
667
+ for pool_name, tokens in data.items():
668
+ for t in tokens:
669
+ params.append({
670
+ "token": t.get("token"),
671
+ "pool_name": pool_name,
672
+ "data": json_dumps(t),
673
+ "updated_at": 0
674
+ })
675
+
676
+ if params:
677
+ # 批量插入
678
+ await session.execute(
679
+ text("INSERT INTO tokens (token, pool_name, data, updated_at) VALUES (:token, :pool_name, :data, :updated_at)"),
680
+ params
681
+ )
682
+ await session.commit()
683
+ except Exception as e:
684
+ logger.error(f"SQLStorage: 保存 Token 失败: {e}")
685
+ raise
686
+
687
+ async def close(self):
688
+ await self.engine.dispose()
689
+
690
+
691
+ class StorageFactory:
692
+ """存储后端工厂"""
693
+ _instance: Optional[BaseStorage] = None
694
+
695
+ @classmethod
696
+ def get_storage(cls) -> BaseStorage:
697
+ """获取全局存储实例 (单例)"""
698
+ if cls._instance:
699
+ return cls._instance
700
+
701
+ storage_type = os.getenv("SERVER_STORAGE_TYPE", "local").lower()
702
+ storage_url = os.getenv("SERVER_STORAGE_URL", "")
703
+
704
+ logger.info(f"StorageFactory: 初始化存储后端: {storage_type}")
705
+
706
+ if storage_type == "redis":
707
+ if not storage_url: raise ValueError("Redis 存储需要设置 SERVER_STORAGE_URL")
708
+ cls._instance = RedisStorage(storage_url)
709
+
710
+ elif storage_type in ("mysql", "pgsql"):
711
+ if not storage_url: raise ValueError("SQL 存储需要设置 SERVER_STORAGE_URL")
712
+ cls._instance = SQLStorage(storage_url)
713
+
714
+ else:
715
+ cls._instance = LocalStorage()
716
+
717
+ return cls._instance
718
+
719
+ def get_storage() -> BaseStorage:
720
+ return StorageFactory.get_storage()
app/services/api_keys.py ADDED
@@ -0,0 +1,432 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """API Key 管理器 - 多用户密钥管理"""
2
+
3
+ import orjson
4
+ import time
5
+ import os
6
+ import secrets
7
+ import asyncio
8
+ from datetime import datetime, timezone, timedelta
9
+ from typing import List, Dict, Optional, Any, Tuple
10
+ from pathlib import Path
11
+
12
+ from app.core.logger import logger
13
+ from app.core.config import get_config
14
+
15
+
16
+ class ApiKeyManager:
17
+ """API Key 管理服务"""
18
+
19
+ _instance = None
20
+
21
+ def __new__(cls):
22
+ if cls._instance is None:
23
+ cls._instance = super().__new__(cls)
24
+ return cls._instance
25
+
26
+ def __init__(self):
27
+ if hasattr(self, '_initialized'):
28
+ return
29
+
30
+ self.file_path = Path(__file__).parents[2] / "data" / "api_keys.json"
31
+ self.usage_path = Path(__file__).parents[2] / "data" / "api_key_usage.json"
32
+ self._keys: List[Dict] = []
33
+ self._lock = asyncio.Lock()
34
+ self._loaded = False
35
+
36
+ self._usage: Dict[str, Dict[str, Dict[str, int]]] = {}
37
+ self._usage_lock = asyncio.Lock()
38
+ self._usage_loaded = False
39
+
40
+ self._initialized = True
41
+ logger.debug(f"[ApiKey] 初始化完成: {self.file_path}")
42
+
43
+ async def init(self):
44
+ """初始化加载数据"""
45
+ if not self._loaded:
46
+ await self._load_data()
47
+ if not self._usage_loaded:
48
+ await self._load_usage_data()
49
+
50
+ async def _load_data(self):
51
+ """加载 API Keys"""
52
+ if self._loaded:
53
+ return
54
+
55
+ if not self.file_path.exists():
56
+ self._keys = []
57
+ self._loaded = True
58
+ return
59
+
60
+ try:
61
+ async with self._lock:
62
+ content = await asyncio.to_thread(self.file_path.read_bytes)
63
+ if content:
64
+ data = orjson.loads(content)
65
+ if isinstance(data, list):
66
+ out: List[Dict[str, Any]] = []
67
+ for item in data:
68
+ if not isinstance(item, dict):
69
+ continue
70
+ row = self._normalize_key_row(item)
71
+ if row.get("key"):
72
+ out.append(row)
73
+ self._keys = out
74
+ else:
75
+ self._keys = []
76
+ else:
77
+ self._keys = []
78
+ self._loaded = True
79
+ logger.debug(f"[ApiKey] 加载了 {len(self._keys)} 个 API Key")
80
+ except Exception as e:
81
+ logger.error(f"[ApiKey] 加载失败: {e}")
82
+ self._keys = []
83
+ self._loaded = True # 即使加载失败也认为已尝试加载,防止后续保存清空数据(或者抛出异常)
84
+
85
+ async def _save_data(self):
86
+ """保存 API Keys"""
87
+ if not self._loaded:
88
+ logger.warning("[ApiKey] 尝试在数据未加载时保存,已取消操作以防覆盖数据")
89
+ return
90
+
91
+ try:
92
+ # 确保目录存在
93
+ self.file_path.parent.mkdir(parents=True, exist_ok=True)
94
+
95
+ async with self._lock:
96
+ content = orjson.dumps(self._keys, option=orjson.OPT_INDENT_2)
97
+ await asyncio.to_thread(self.file_path.write_bytes, content)
98
+ except Exception as e:
99
+ logger.error(f"[ApiKey] 保存失败: {e}")
100
+
101
+ def _normalize_limit(self, v: Any) -> int:
102
+ """Normalize a daily limit value. -1 means unlimited."""
103
+ if v is None or v == "":
104
+ return -1
105
+ try:
106
+ n = int(v)
107
+ except Exception:
108
+ return -1
109
+ return max(-1, n)
110
+
111
+ def _normalize_key_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
112
+ out = dict(row or {})
113
+ out["key"] = str(out.get("key") or "").strip()
114
+ out["name"] = str(out.get("name") or "").strip()
115
+ try:
116
+ out["created_at"] = int(out.get("created_at") or int(time.time()))
117
+ except Exception:
118
+ out["created_at"] = int(time.time())
119
+ out["is_active"] = bool(out.get("is_active", True))
120
+
121
+ # Daily limits (-1 = unlimited)
122
+ out["chat_limit"] = self._normalize_limit(out.get("chat_limit", -1))
123
+ out["heavy_limit"] = self._normalize_limit(out.get("heavy_limit", -1))
124
+ out["image_limit"] = self._normalize_limit(out.get("image_limit", -1))
125
+ out["video_limit"] = self._normalize_limit(out.get("video_limit", -1))
126
+ return out
127
+
128
+ def _tz_offset_minutes(self) -> int:
129
+ raw = (os.getenv("CACHE_RESET_TZ_OFFSET_MINUTES", "") or "").strip()
130
+ try:
131
+ n = int(raw)
132
+ except Exception:
133
+ n = 480
134
+ return max(-720, min(840, n))
135
+
136
+ def _day_str(self, at_ms: Optional[int] = None, tz_offset_minutes: Optional[int] = None) -> str:
137
+ now_ms = int(at_ms if at_ms is not None else int(time.time() * 1000))
138
+ offset = self._tz_offset_minutes() if tz_offset_minutes is None else int(tz_offset_minutes)
139
+ dt = datetime.fromtimestamp(now_ms / 1000, tz=timezone.utc) + timedelta(minutes=offset)
140
+ return dt.strftime("%Y-%m-%d")
141
+
142
+ async def _load_usage_data(self):
143
+ """Load per-day per-key usage counters."""
144
+ if self._usage_loaded:
145
+ return
146
+
147
+ if not self.usage_path.exists():
148
+ self._usage = {}
149
+ self._usage_loaded = True
150
+ return
151
+
152
+ try:
153
+ async with self._usage_lock:
154
+ if self.usage_path.exists():
155
+ content = await asyncio.to_thread(self.usage_path.read_bytes)
156
+ if content:
157
+ data = orjson.loads(content)
158
+ if isinstance(data, dict):
159
+ # { day: { key: { chat_used, ... } } }
160
+ self._usage = data # type: ignore[assignment]
161
+ else:
162
+ self._usage = {}
163
+ else:
164
+ self._usage = {}
165
+ self._usage_loaded = True
166
+ except Exception as e:
167
+ logger.error(f"[ApiKey] Usage 加载失败: {e}")
168
+ self._usage = {}
169
+ self._usage_loaded = True
170
+
171
+ async def _save_usage_data(self):
172
+ if not self._usage_loaded:
173
+ return
174
+ try:
175
+ self.usage_path.parent.mkdir(parents=True, exist_ok=True)
176
+ async with self._usage_lock:
177
+ content = orjson.dumps(self._usage, option=orjson.OPT_INDENT_2)
178
+ await asyncio.to_thread(self.usage_path.write_bytes, content)
179
+ except Exception as e:
180
+ logger.error(f"[ApiKey] Usage 保存失败: {e}")
181
+
182
+ def generate_key(self) -> str:
183
+ """生成一个新的 sk- 开头的 key"""
184
+ return f"sk-{secrets.token_urlsafe(24)}"
185
+
186
+ def generate_name(self) -> str:
187
+ """生成一个随机 key 名称"""
188
+ return f"key-{secrets.token_urlsafe(6)}"
189
+
190
+ async def add_key(
191
+ self,
192
+ name: str | None = None,
193
+ key: str | None = None,
194
+ limits: Optional[Dict[str, Any]] = None,
195
+ is_active: bool = True,
196
+ ) -> Dict[str, Any]:
197
+ """添加 API Key(支持自定义 key 与每日额度)"""
198
+ await self.init()
199
+
200
+ name_val = str(name or "").strip() or self.generate_name()
201
+ key_val = str(key or "").strip() or self.generate_key()
202
+
203
+ limits = limits or {}
204
+ new_key: Dict[str, Any] = {
205
+ "key": key_val,
206
+ "name": name_val,
207
+ "created_at": int(time.time()),
208
+ "is_active": bool(is_active),
209
+ "chat_limit": self._normalize_limit(limits.get("chat_limit", limits.get("chat_per_day", -1))),
210
+ "heavy_limit": self._normalize_limit(limits.get("heavy_limit", limits.get("heavy_per_day", -1))),
211
+ "image_limit": self._normalize_limit(limits.get("image_limit", limits.get("image_per_day", -1))),
212
+ "video_limit": self._normalize_limit(limits.get("video_limit", limits.get("video_per_day", -1))),
213
+ }
214
+
215
+ # Ensure uniqueness
216
+ if any(k.get("key") == key_val for k in self._keys):
217
+ raise ValueError("Key already exists")
218
+
219
+ self._keys.append(new_key)
220
+ await self._save_data()
221
+ logger.info(f"[ApiKey] 添加新Key: {name_val}")
222
+ return new_key
223
+
224
+ async def batch_add_keys(self, name_prefix: str, count: int) -> List[Dict]:
225
+ """批量添加 API Key"""
226
+ new_keys = []
227
+ for i in range(1, count + 1):
228
+ name = f"{name_prefix}-{i}" if count > 1 else name_prefix
229
+ new_keys.append({
230
+ "key": self.generate_key(),
231
+ "name": name,
232
+ "created_at": int(time.time()),
233
+ "is_active": True,
234
+ "chat_limit": -1,
235
+ "heavy_limit": -1,
236
+ "image_limit": -1,
237
+ "video_limit": -1,
238
+ })
239
+
240
+ self._keys.extend(new_keys)
241
+ await self._save_data()
242
+ logger.info(f"[ApiKey] 批量添加 {count} 个 Key, 前缀: {name_prefix}")
243
+ return new_keys
244
+
245
+ async def delete_key(self, key: str) -> bool:
246
+ """删除 API Key"""
247
+ initial_len = len(self._keys)
248
+ self._keys = [k for k in self._keys if k["key"] != key]
249
+
250
+ if len(self._keys) != initial_len:
251
+ await self._save_data()
252
+ logger.info(f"[ApiKey] 删除Key: {key[:10]}...")
253
+ return True
254
+ return False
255
+
256
+ async def batch_delete_keys(self, keys: List[str]) -> int:
257
+ """批量删除 API Key"""
258
+ initial_len = len(self._keys)
259
+ self._keys = [k for k in self._keys if k["key"] not in keys]
260
+
261
+ deleted_count = initial_len - len(self._keys)
262
+ if deleted_count > 0:
263
+ await self._save_data()
264
+ logger.info(f"[ApiKey] 批量删除 {deleted_count} 个 Key")
265
+ return deleted_count
266
+
267
+ async def update_key_status(self, key: str, is_active: bool) -> bool:
268
+ """更新 Key 状态"""
269
+ for k in self._keys:
270
+ if k["key"] == key:
271
+ k["is_active"] = is_active
272
+ await self._save_data()
273
+ return True
274
+ return False
275
+
276
+ async def batch_update_keys_status(self, keys: List[str], is_active: bool) -> int:
277
+ """批量更新 Key 状态"""
278
+ updated_count = 0
279
+ for k in self._keys:
280
+ if k["key"] in keys:
281
+ if k["is_active"] != is_active:
282
+ k["is_active"] = is_active
283
+ updated_count += 1
284
+
285
+ if updated_count > 0:
286
+ await self._save_data()
287
+ logger.info(f"[ApiKey] 批量更新 {updated_count} 个 Key 状态为: {is_active}")
288
+ return updated_count
289
+
290
+ async def update_key_name(self, key: str, name: str) -> bool:
291
+ """更新 Key 备注"""
292
+ for k in self._keys:
293
+ if k["key"] == key:
294
+ k["name"] = name
295
+ await self._save_data()
296
+ return True
297
+ return False
298
+
299
+ async def update_key_limits(self, key: str, limits: Dict[str, Any]) -> bool:
300
+ """更新 Key 每日额度(-1 表示不限)"""
301
+ limits = limits or {}
302
+ for k in self._keys:
303
+ if k.get("key") != key:
304
+ continue
305
+ if "chat_limit" in limits or "chat_per_day" in limits:
306
+ k["chat_limit"] = self._normalize_limit(limits.get("chat_limit", limits.get("chat_per_day")))
307
+ if "heavy_limit" in limits or "heavy_per_day" in limits:
308
+ k["heavy_limit"] = self._normalize_limit(limits.get("heavy_limit", limits.get("heavy_per_day")))
309
+ if "image_limit" in limits or "image_per_day" in limits:
310
+ k["image_limit"] = self._normalize_limit(limits.get("image_limit", limits.get("image_per_day")))
311
+ if "video_limit" in limits or "video_per_day" in limits:
312
+ k["video_limit"] = self._normalize_limit(limits.get("video_limit", limits.get("video_per_day")))
313
+ await self._save_data()
314
+ return True
315
+ return False
316
+
317
+ def get_key_row(self, key: str) -> Optional[Dict[str, Any]]:
318
+ """获取 Key 原始记录(不要求 active)"""
319
+ for k in self._keys:
320
+ if k.get("key") == key:
321
+ return self._normalize_key_row(k)
322
+ return None
323
+
324
+ async def usage_for_day(self, day: str) -> Dict[str, Dict[str, int]]:
325
+ """返回指定 day 的 usage map: { key: {chat_used,...} }"""
326
+ await self.init()
327
+ if not self._usage_loaded:
328
+ await self._load_usage_data()
329
+ day_map = self._usage.get(day)
330
+ return day_map if isinstance(day_map, dict) else {}
331
+
332
+ async def usage_today(self) -> Tuple[str, Dict[str, Dict[str, int]]]:
333
+ day = self._day_str()
334
+ return day, await self.usage_for_day(day)
335
+
336
+ async def consume_daily_usage(
337
+ self,
338
+ key: str,
339
+ incs: Dict[str, int],
340
+ tz_offset_minutes: Optional[int] = None,
341
+ ) -> bool:
342
+ """
343
+ Consume per-day quota for the given API key.
344
+
345
+ incs keys: chat_used/heavy_used/image_used/video_used
346
+ """
347
+ await self.init()
348
+ row = self.get_key_row(key)
349
+ if not row or not row.get("is_active"):
350
+ # Unknown/disabled keys are already rejected by auth; keep best-effort safe here.
351
+ return True
352
+
353
+ if not self._usage_loaded:
354
+ await self._load_usage_data()
355
+
356
+ day = self._day_str(tz_offset_minutes=tz_offset_minutes)
357
+ at_ms = int(time.time() * 1000)
358
+
359
+ # Normalize incs
360
+ normalized: Dict[str, int] = {}
361
+ for k, v in (incs or {}).items():
362
+ try:
363
+ inc = int(v)
364
+ except Exception:
365
+ continue
366
+ if inc <= 0:
367
+ continue
368
+ normalized[k] = inc
369
+ if not normalized:
370
+ return True
371
+
372
+ limits = {
373
+ "chat_used": int(row.get("chat_limit", -1)),
374
+ "heavy_used": int(row.get("heavy_limit", -1)),
375
+ "image_used": int(row.get("image_limit", -1)),
376
+ "video_used": int(row.get("video_limit", -1)),
377
+ }
378
+
379
+ async with self._usage_lock:
380
+ day_map = self._usage.get(day)
381
+ if not isinstance(day_map, dict):
382
+ day_map = {}
383
+ self._usage[day] = day_map # type: ignore[assignment]
384
+
385
+ usage = day_map.get(key)
386
+ if not isinstance(usage, dict):
387
+ usage = {"chat_used": 0, "heavy_used": 0, "image_used": 0, "video_used": 0, "updated_at": at_ms}
388
+ day_map[key] = usage # type: ignore[assignment]
389
+
390
+ # Check all limits first (atomic for multi-bucket)
391
+ for bucket, inc in normalized.items():
392
+ lim = int(limits.get(bucket, -1))
393
+ used = int(usage.get(bucket, 0) or 0)
394
+ if lim >= 0 and used + inc > lim:
395
+ return False
396
+
397
+ # Apply
398
+ for bucket, inc in normalized.items():
399
+ usage[bucket] = int(usage.get(bucket, 0) or 0) + inc
400
+ usage["updated_at"] = at_ms
401
+
402
+ await self._save_usage_data()
403
+ return True
404
+
405
+ def validate_key(self, key: str) -> Optional[Dict]:
406
+ """验证 Key,返回 Key 信息"""
407
+ # 1. 检查全局配置的 Key (作为默认 admin key)
408
+ global_key = str(get_config("app.api_key", "") or "").strip()
409
+ if global_key and key == global_key:
410
+ return {
411
+ "key": global_key,
412
+ "name": "默认管理员",
413
+ "is_active": True,
414
+ "is_admin": True
415
+ }
416
+
417
+ # 2. 检查多 Key 列表
418
+ for k in self._keys:
419
+ if k["key"] == key:
420
+ if k["is_active"]:
421
+ return {**k, "is_admin": False} # 普通 Key 也可以视为非管理员? 暂不区分权限,只做身份识别
422
+ return None
423
+
424
+ return None
425
+
426
+ def get_all_keys(self) -> List[Dict]:
427
+ """获取所有 Keys"""
428
+ return [self._normalize_key_row(k) for k in self._keys]
429
+
430
+
431
+ # 全局实例
432
+ api_key_manager = ApiKeyManager()
app/services/base.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Base service interface will be defined here
2
+ # Placeholder for service abstraction with concurrency control
app/services/grok/assets.py ADDED
@@ -0,0 +1,875 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Grok 文件资产服务
3
+ """
4
+
5
+ import asyncio
6
+ import base64
7
+ import os
8
+ import time
9
+ import hashlib
10
+ import re
11
+ import uuid
12
+ from pathlib import Path
13
+ from contextlib import asynccontextmanager
14
+ try:
15
+ import fcntl
16
+ except ImportError: # pragma: no cover - non-posix platforms
17
+ fcntl = None
18
+ from typing import Tuple, List, Dict, Optional, Any
19
+ from urllib.parse import urlparse
20
+
21
+ import aiofiles
22
+ from curl_cffi.requests import AsyncSession
23
+
24
+ from app.core.logger import logger
25
+ from app.core.config import get_config
26
+ from app.core.exceptions import (
27
+ AppException,
28
+ UpstreamException,
29
+ ValidationException
30
+ )
31
+ from app.services.grok.statsig import StatsigService
32
+
33
+
34
+ # ==================== 常量 ====================
35
+
36
+ UPLOAD_API = "https://grok.com/rest/app-chat/upload-file"
37
+ LIST_API = "https://grok.com/rest/assets"
38
+ DELETE_API = "https://grok.com/rest/assets-metadata"
39
+ DOWNLOAD_API = "https://assets.grok.com"
40
+ LOCK_DIR = Path(__file__).parent.parent.parent.parent / "data" / ".locks"
41
+
42
+ TIMEOUT = 120
43
+ BROWSER = "chrome136"
44
+ DEFAULT_MIME = "application/octet-stream"
45
+
46
+ # 并发控制
47
+ DEFAULT_MAX_CONCURRENT = 25
48
+ DEFAULT_DELETE_BATCH_SIZE = 10
49
+ _ASSETS_SEMAPHORE = asyncio.Semaphore(DEFAULT_MAX_CONCURRENT)
50
+ _ASSETS_SEM_VALUE = DEFAULT_MAX_CONCURRENT
51
+
52
+ def _get_assets_semaphore() -> asyncio.Semaphore:
53
+ global _ASSETS_SEMAPHORE, _ASSETS_SEM_VALUE
54
+ value = get_config("performance.assets_max_concurrent", DEFAULT_MAX_CONCURRENT)
55
+ try:
56
+ value = int(value)
57
+ except Exception:
58
+ value = DEFAULT_MAX_CONCURRENT
59
+ value = max(1, value)
60
+ if value != _ASSETS_SEM_VALUE:
61
+ _ASSETS_SEM_VALUE = value
62
+ _ASSETS_SEMAPHORE = asyncio.Semaphore(value)
63
+ return _ASSETS_SEMAPHORE
64
+
65
+ def _get_delete_batch_size() -> int:
66
+ value = get_config("performance.assets_delete_batch_size", DEFAULT_DELETE_BATCH_SIZE)
67
+ try:
68
+ value = int(value)
69
+ except Exception:
70
+ value = DEFAULT_DELETE_BATCH_SIZE
71
+ return max(1, value)
72
+
73
+ @asynccontextmanager
74
+ async def _file_lock(name: str, timeout: int = 10):
75
+ if fcntl is None:
76
+ yield
77
+ return
78
+ LOCK_DIR.mkdir(parents=True, exist_ok=True)
79
+ lock_path = LOCK_DIR / f"{name}.lock"
80
+ fd = None
81
+ locked = False
82
+ start = time.monotonic()
83
+ try:
84
+ fd = open(lock_path, "a+")
85
+ while True:
86
+ try:
87
+ fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
88
+ locked = True
89
+ break
90
+ except BlockingIOError:
91
+ if time.monotonic() - start >= timeout:
92
+ break
93
+ await asyncio.sleep(0.05)
94
+ yield
95
+ finally:
96
+ if fd:
97
+ if locked:
98
+ try:
99
+ fcntl.flock(fd, fcntl.LOCK_UN)
100
+ except Exception:
101
+ pass
102
+ try:
103
+ fd.close()
104
+ except Exception:
105
+ pass
106
+
107
+ MIME_TYPES = {
108
+ # 图片
109
+ '.jpg': 'image/jpeg', '.jpeg': 'image/jpeg', '.png': 'image/png',
110
+ '.gif': 'image/gif', '.webp': 'image/webp', '.bmp': 'image/bmp',
111
+
112
+ # 文档
113
+ '.pdf': 'application/pdf', '.txt': 'text/plain', '.md': 'text/markdown',
114
+ '.doc': 'application/msword',
115
+ '.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
116
+ '.rtf': 'application/rtf',
117
+
118
+ # 表格
119
+ '.csv': 'text/csv',
120
+ '.xls': 'application/vnd.ms-excel',
121
+ '.xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
122
+
123
+ # 代码
124
+ '.py': 'text/x-python-script', '.js': 'application/javascript', '.ts': 'application/typescript',
125
+ '.java': 'text/x-java', '.cpp': 'text/x-c++', '.c': 'text/x-c',
126
+ '.go': 'text/x-go', '.rs': 'text/x-rust', '.rb': 'text/x-ruby',
127
+ '.php': 'text/x-php', '.sh': 'application/x-sh', '.html': 'text/html',
128
+ '.css': 'text/css', '.sql': 'application/sql',
129
+
130
+ # 数据
131
+ '.json': 'application/json', '.xml': 'application/xml', '.yaml': 'application/x-yaml',
132
+ '.yml': 'application/x-yaml', '.toml': 'application/toml', '.ini': 'text/plain',
133
+ '.log': 'text/plain', '.tmp': 'application/octet-stream',
134
+
135
+ # 其他
136
+ '.graphql': 'application/graphql', '.proto': 'application/x-protobuf',
137
+ '.latex': 'application/x-latex', '.wiki': 'text/plain', '.rst': 'text/x-rst',
138
+ }
139
+
140
+ IMAGE_EXTS = {'.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp'}
141
+ VIDEO_EXTS = {'.mp4', '.mov', '.m4v', '.webm', '.avi', '.mkv'}
142
+
143
+
144
+ # ==================== 基础服务 ====================
145
+
146
+ class BaseService:
147
+ """基础服务类"""
148
+
149
+ def __init__(self, proxy: str = None):
150
+ self.proxy = proxy or get_config("grok.asset_proxy_url") or get_config("grok.base_proxy_url", "")
151
+ self.timeout = get_config("grok.timeout", TIMEOUT)
152
+ self._session: Optional[AsyncSession] = None
153
+
154
+ def _headers(self, token: str, referer: str = "https://grok.com/") -> dict:
155
+ """构建请求头"""
156
+ headers = {
157
+ "Accept": "*/*",
158
+ "Accept-Encoding": "gzip, deflate, br, zstd",
159
+ "Accept-Language": "zh-CN,zh;q=0.9",
160
+ "Baggage": "sentry-environment=production,sentry-release=d6add6fb0460641fd482d767a335ef72b9b6abb8,sentry-public_key=b311e0f2690c81f25e2c4cf6d4f7ce1c",
161
+ "Cache-Control": "no-cache",
162
+ "Content-Type": "application/json",
163
+ "Origin": "https://grok.com",
164
+ "Pragma": "no-cache",
165
+ "Priority": "u=1, i",
166
+ "Referer": referer,
167
+ "Sec-Ch-Ua": '"Google Chrome";v="136", "Chromium";v="136", "Not(A:Brand";v="24"',
168
+ "Sec-Ch-Ua-Arch": "arm",
169
+ "Sec-Ch-Ua-Bitness": "64",
170
+ "Sec-Ch-Ua-Mobile": "?0",
171
+ "Sec-Ch-Ua-Model": "",
172
+ "Sec-Ch-Ua-Platform": '"macOS"',
173
+ "Sec-Fetch-Dest": "empty",
174
+ "Sec-Fetch-Mode": "cors",
175
+ "Sec-Fetch-Site": "same-origin",
176
+ "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36",
177
+ }
178
+
179
+ # Statsig ID
180
+ headers["x-statsig-id"] = StatsigService.gen_id()
181
+ headers["x-xai-request-id"] = str(uuid.uuid4())
182
+
183
+ # Cookie
184
+ token = token[4:] if token.startswith("sso=") else token
185
+ cf = get_config("grok.cf_clearance", "")
186
+ headers["Cookie"] = f"sso={token};cf_clearance={cf}" if cf else f"sso={token}"
187
+
188
+ return headers
189
+
190
+ def _proxies(self) -> Optional[dict]:
191
+ """构建代理配置"""
192
+ return {"http": self.proxy, "https": self.proxy} if self.proxy else None
193
+
194
+ def _dl_headers(self, token: str, file_path: str) -> dict:
195
+ """构建下载请求头"""
196
+ headers = {
197
+ "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8",
198
+ "Sec-Fetch-Dest": "document",
199
+ "Sec-Fetch-Mode": "navigate",
200
+ "Sec-Fetch-Site": "same-site",
201
+ "Sec-Fetch-User": "?1",
202
+ "Upgrade-Insecure-Requests": "1",
203
+ "Referer": "https://grok.com/",
204
+ }
205
+
206
+ # Cookie
207
+ token = token[4:] if token.startswith("sso=") else token
208
+ cf = get_config("grok.cf_clearance", "")
209
+ headers["Cookie"] = f"sso={token};cf_clearance={cf}" if cf else f"sso={token}"
210
+
211
+ return headers
212
+
213
+ async def _get_session(self) -> AsyncSession:
214
+ """获取复用 Session"""
215
+ if self._session is None:
216
+ self._session = AsyncSession()
217
+ return self._session
218
+
219
+ async def close(self):
220
+ """关闭 Session"""
221
+ if self._session:
222
+ await self._session.close()
223
+ self._session = None
224
+
225
+ @staticmethod
226
+ def is_url(input_str: str) -> bool:
227
+ """检查是否为 URL"""
228
+ try:
229
+ result = urlparse(input_str)
230
+ return all([result.scheme, result.netloc]) and result.scheme in ['http', 'https']
231
+ except:
232
+ return False
233
+
234
+ @staticmethod
235
+ async def fetch(url: str) -> Tuple[str, str, str]:
236
+ """
237
+ 获取远程资源并转 Base64
238
+
239
+ Raises:
240
+ UpstreamException: 当获取失败时
241
+ """
242
+ try:
243
+ async with AsyncSession() as session:
244
+ response = await session.get(url, timeout=10)
245
+ if response.status_code >= 400:
246
+ raise UpstreamException(
247
+ message=f"Failed to fetch resource: {response.status_code}",
248
+ details={"url": url, "status": response.status_code}
249
+ )
250
+
251
+ filename = url.split('/')[-1].split('?')[0] or 'download'
252
+ content_type = response.headers.get('content-type', DEFAULT_MIME).split(';')[0]
253
+ b64 = base64.b64encode(response.content).decode()
254
+
255
+ logger.debug(f"Fetched: {url} -> {filename}")
256
+ return filename, b64, content_type
257
+ except Exception as e:
258
+ logger.error(f"Fetch failed: {url} - {e}")
259
+ if isinstance(e, AppException):
260
+ raise e
261
+ raise UpstreamException(f"Resource fetch failed: {str(e)}", details={"url": url})
262
+
263
+ @staticmethod
264
+ def parse_b64(data_uri: str) -> Tuple[str, str, str]:
265
+ """解析 Base64 数据"""
266
+ if data_uri.startswith("data:"):
267
+ match = re.match(r"data:([^;]+);base64,(.+)", data_uri)
268
+ if match:
269
+ mime = match.group(1)
270
+ b64 = match.group(2)
271
+ ext = mime.split('/')[-1] if '/' in mime else 'bin'
272
+ return f"file.{ext}", b64, mime
273
+ return "file.bin", data_uri, DEFAULT_MIME
274
+
275
+ @staticmethod
276
+ def to_b64(file_path: Path, mime_type: str) -> str:
277
+ """将本地文件转为 base64 data URI"""
278
+ try:
279
+ b64_data = base64.b64encode(file_path.read_bytes()).decode()
280
+ return f"data:{mime_type};base64,{b64_data}"
281
+ except Exception as e:
282
+ logger.error(f"File to base64 failed: {file_path} - {e}")
283
+ raise AppException(f"Failed to read file: {file_path}", code="file_read_error")
284
+
285
+
286
+ # ==================== 上传服务 ====================
287
+
288
+ class UploadService(BaseService):
289
+ """文件上传服务"""
290
+
291
+ async def upload(self, file_input: str, token: str) -> Tuple[str, str]:
292
+ """
293
+ 上传文件到 Grok
294
+
295
+ Returns:
296
+ (file_id, file_uri)
297
+
298
+ Raises:
299
+ ValidationException: 输入无效
300
+ UpstreamException: 上传失败
301
+ """
302
+ async with _get_assets_semaphore():
303
+ try:
304
+ # 处理输入
305
+ if self.is_url(file_input):
306
+ filename, b64, mime = await self.fetch(file_input)
307
+ else:
308
+ filename, b64, mime = self.parse_b64(file_input)
309
+
310
+ if not b64:
311
+ raise ValidationException("Invalid file input: empty content")
312
+
313
+ # 构建请求
314
+ headers = self._headers(token)
315
+ payload = {
316
+ "fileName": filename,
317
+ "fileMimeType": mime,
318
+ "content": b64,
319
+ }
320
+
321
+ # 执行上传
322
+ session = await self._get_session()
323
+ response = await session.post(
324
+ UPLOAD_API,
325
+ headers=headers,
326
+ json=payload,
327
+ impersonate=BROWSER,
328
+ timeout=self.timeout,
329
+ proxies=self._proxies(),
330
+ )
331
+
332
+ if response.status_code == 200:
333
+ result = response.json()
334
+ file_id = result.get("fileMetadataId", "")
335
+ file_uri = result.get("fileUri", "")
336
+ logger.info(f"Upload success: {filename} -> {file_id}", extra={"file_id": file_id})
337
+ return file_id, file_uri
338
+
339
+ logger.error(
340
+ f"Upload failed: {filename} - {response.status_code}",
341
+ extra={"response": response.text[:200]}
342
+ )
343
+ raise UpstreamException(
344
+ message=f"Upload failed with status {response.status_code}",
345
+ details={"status": response.status_code, "response": response.text[:200]}
346
+ )
347
+
348
+ except Exception as e:
349
+ logger.error(f"Upload error: {e}")
350
+ if isinstance(e, AppException):
351
+ raise e
352
+ raise UpstreamException(f"Upload process error: {str(e)}")
353
+
354
+
355
+ # ==================== 列表服务 ====================
356
+
357
+ class ListService(BaseService):
358
+ """文件列表查询服务"""
359
+
360
+ async def iter_assets(self, token: str):
361
+ """
362
+ 分页迭代资产列表
363
+ """
364
+ headers = self._headers(token, referer="https://grok.com/files")
365
+ base_params = {
366
+ "pageSize": 50,
367
+ "orderBy": "ORDER_BY_LAST_USE_TIME",
368
+ "source": "SOURCE_ANY",
369
+ "isLatest": "true",
370
+ }
371
+ page_token = None
372
+ seen_tokens = set()
373
+
374
+ async with AsyncSession() as session:
375
+ while True:
376
+ params = dict(base_params)
377
+ if page_token:
378
+ if page_token in seen_tokens:
379
+ logger.warning("List pagination stopped due to repeated page token")
380
+ break
381
+ seen_tokens.add(page_token)
382
+ params["pageToken"] = page_token
383
+
384
+ response = await session.get(
385
+ LIST_API,
386
+ headers=headers,
387
+ params=params,
388
+ impersonate=BROWSER,
389
+ timeout=self.timeout,
390
+ proxies=self._proxies(),
391
+ )
392
+
393
+ if response.status_code != 200:
394
+ logger.error(f"List failed: {response.status_code}")
395
+ raise UpstreamException(
396
+ message=f"List assets failed: {response.status_code}",
397
+ details={"status": response.status_code}
398
+ )
399
+
400
+ result = response.json()
401
+ page_assets = result.get("assets", [])
402
+ yield page_assets
403
+
404
+ page_token = result.get("nextPageToken")
405
+ if not page_token:
406
+ break
407
+
408
+ async def list(self, token: str) -> List[Dict]:
409
+ """
410
+ 查询文件列表
411
+
412
+ Raises:
413
+ UpstreamException: 查询失败
414
+ """
415
+ try:
416
+ assets: List[Dict] = []
417
+ async for page_assets in self.iter_assets(token):
418
+ assets.extend(page_assets)
419
+
420
+ logger.info(f"List success: {len(assets)} files")
421
+ return assets
422
+
423
+ except Exception as e:
424
+ logger.error(f"List error: {e}")
425
+ if isinstance(e, AppException):
426
+ raise e
427
+ raise UpstreamException(f"List assets error: {str(e)}")
428
+
429
+ async def count(self, token: str) -> int:
430
+ """
431
+ 统计资产数量(不保留明细)
432
+ """
433
+ try:
434
+ total = 0
435
+ async for page_assets in self.iter_assets(token):
436
+ total += len(page_assets)
437
+ return total
438
+ except Exception as e:
439
+ logger.error(f"List count error: {e}")
440
+ if isinstance(e, AppException):
441
+ raise e
442
+ raise UpstreamException(f"List assets error: {str(e)}")
443
+
444
+
445
+ # ==================== 删除服务 ====================
446
+
447
+ class DeleteService(BaseService):
448
+ """文件删除服务"""
449
+
450
+ async def delete(self, token: str, asset_id: str) -> bool:
451
+ """
452
+ 删除单个文件
453
+
454
+ Raises:
455
+ UpstreamException: 删除失败
456
+ """
457
+ async with _get_assets_semaphore():
458
+ try:
459
+ headers = self._headers(token, referer="https://grok.com/files")
460
+ url = f"{DELETE_API}/{asset_id}"
461
+
462
+ session = await self._get_session()
463
+ response = await session.delete(
464
+ url,
465
+ headers=headers,
466
+ impersonate=BROWSER,
467
+ timeout=self.timeout,
468
+ proxies=self._proxies(),
469
+ )
470
+
471
+ if response.status_code == 200:
472
+ logger.debug(f"Delete success: {asset_id}")
473
+ return True
474
+
475
+ logger.error(f"Delete failed: {asset_id} - {response.status_code}")
476
+ #: Note: Returning False or raising Exception?
477
+ #: Assuming caller handles Exception for stricter control, or False for loose control.
478
+ #: Given "optimization" and "standardization", raising exceptions is better for API feedback.
479
+ raise UpstreamException(
480
+ message=f"Delete failed: {asset_id}",
481
+ details={"status": response.status_code}
482
+ )
483
+
484
+ except Exception as e:
485
+ logger.error(f"Delete error: {asset_id} - {e}")
486
+ if isinstance(e, AppException):
487
+ raise e
488
+ raise UpstreamException(f"Delete error: {str(e)}")
489
+
490
+ async def delete_all(self, token: str) -> Dict[str, int]:
491
+ """
492
+ 删除所有文件
493
+ """
494
+ total = 0
495
+ success = 0
496
+ failed = 0
497
+ list_service = ListService(self.proxy)
498
+ try:
499
+ async for assets in list_service.iter_assets(token):
500
+ if not assets:
501
+ continue
502
+ total += len(assets)
503
+
504
+ # 批量并发删除
505
+ async def _delete_one(asset: Dict, index: int) -> bool:
506
+ await asyncio.sleep(0.01 * index)
507
+ asset_id = asset.get("assetId", "")
508
+ if asset_id:
509
+ try:
510
+ return await self.delete(token, asset_id)
511
+ except:
512
+ return False
513
+ return False
514
+
515
+ batch_size = _get_delete_batch_size()
516
+ for i in range(0, len(assets), batch_size):
517
+ batch = assets[i:i + batch_size]
518
+ results = await asyncio.gather(*[
519
+ _delete_one(asset, idx) for idx, asset in enumerate(batch)
520
+ ])
521
+ success += sum(results)
522
+ failed += len(batch) - sum(results)
523
+
524
+ if total == 0:
525
+ logger.info("No assets to delete")
526
+ return {"total": 0, "success": 0, "failed": 0, "skipped": True}
527
+ except Exception as e:
528
+ logger.error(f"Delete all failed during list: {e}")
529
+ return {"total": total, "success": success, "failed": failed}
530
+ finally:
531
+ await list_service.close()
532
+
533
+ logger.info(f"Delete all: total={total}, success={success}, failed={failed}")
534
+ return {"total": total, "success": success, "failed": failed}
535
+
536
+
537
+ # ==================== 下载服务 ====================
538
+
539
+ class DownloadService(BaseService):
540
+ """文件下载服务"""
541
+
542
+ def __init__(self, proxy: str = None):
543
+ super().__init__(proxy)
544
+ # 创建缓存目录
545
+ self.base_dir = Path(__file__).parent.parent.parent.parent / "data" / "tmp"
546
+ self.legacy_base_dir = Path(__file__).parent.parent.parent.parent / "data" / "temp"
547
+ self.legacy_image_dir = self.legacy_base_dir / "image"
548
+ self.legacy_video_dir = self.legacy_base_dir / "video"
549
+ self.image_dir = self.base_dir / "image"
550
+ self.video_dir = self.base_dir / "video"
551
+ self.image_dir.mkdir(parents=True, exist_ok=True)
552
+ self.video_dir.mkdir(parents=True, exist_ok=True)
553
+ self._cleanup_running = False
554
+
555
+ def _cache_path(self, file_path: str, media_type: str) -> Path:
556
+ """获取缓存路径"""
557
+ cache_dir = self.image_dir if media_type == "image" else self.video_dir
558
+ filename = file_path.lstrip('/').replace('/', '-')
559
+ return cache_dir / filename
560
+
561
+ def _legacy_cache_path(self, file_path: str, media_type: str) -> Path:
562
+ """Legacy cache path (data/temp)."""
563
+ cache_dir = self.legacy_image_dir if media_type == "image" else self.legacy_video_dir
564
+ filename = file_path.lstrip("/").replace("/", "-")
565
+ return cache_dir / filename
566
+
567
+ async def download(self, file_path: str, token: str, media_type: str = "image") -> Tuple[Optional[Path], str]:
568
+ """
569
+ 下载文件到本地
570
+
571
+ Raises:
572
+ UpstreamException: 下载失败
573
+ """
574
+ async with _get_assets_semaphore():
575
+ try:
576
+ # Be forgiving: callers may pass absolute URLs.
577
+ if isinstance(file_path, str) and file_path.startswith("http"):
578
+ try:
579
+ file_path = urlparse(file_path).path
580
+ except Exception:
581
+ pass
582
+
583
+ cache_path = self._cache_path(file_path, media_type)
584
+
585
+ # 如果已缓存
586
+ if cache_path.exists():
587
+ logger.debug(f"Cache hit: {cache_path}")
588
+ mime_type = MIME_TYPES.get(cache_path.suffix.lower(), DEFAULT_MIME)
589
+ return cache_path, mime_type
590
+
591
+ legacy_path = self._legacy_cache_path(file_path, media_type)
592
+ if legacy_path.exists():
593
+ logger.debug(f"Legacy cache hit: {legacy_path}")
594
+ mime_type = MIME_TYPES.get(legacy_path.suffix.lower(), DEFAULT_MIME)
595
+ return legacy_path, mime_type
596
+
597
+ lock_name = f"download_{media_type}_{hashlib.sha1(str(cache_path).encode('utf-8')).hexdigest()[:16]}"
598
+ async with _file_lock(lock_name, timeout=10):
599
+ # Double-check after lock
600
+ if cache_path.exists():
601
+ logger.debug(f"Cache hit after lock: {cache_path}")
602
+ mime_type = MIME_TYPES.get(cache_path.suffix.lower(), DEFAULT_MIME)
603
+ return cache_path, mime_type
604
+
605
+ # 下载文件
606
+ if not file_path.startswith("/"):
607
+ file_path = f"/{file_path}"
608
+
609
+ url = f"{DOWNLOAD_API}{file_path}"
610
+ headers = self._dl_headers(token, file_path)
611
+
612
+ session = await self._get_session()
613
+ response = await session.get(
614
+ url,
615
+ headers=headers,
616
+ proxies=self._proxies(),
617
+ timeout=self.timeout,
618
+ allow_redirects=True,
619
+ impersonate=BROWSER,
620
+ stream=True,
621
+ )
622
+
623
+ if response.status_code != 200:
624
+ raise UpstreamException(
625
+ message=f"Download failed: {response.status_code}",
626
+ details={"path": file_path, "status": response.status_code}
627
+ )
628
+
629
+ # 保存文件(分块写入,避免大文件占用内存)
630
+ tmp_path = cache_path.with_suffix(cache_path.suffix + ".tmp")
631
+ try:
632
+ async with aiofiles.open(tmp_path, "wb") as f:
633
+ if hasattr(response, "aiter_content"):
634
+ async for chunk in response.aiter_content():
635
+ if chunk:
636
+ await f.write(chunk)
637
+ elif hasattr(response, "aiter_bytes"):
638
+ async for chunk in response.aiter_bytes():
639
+ if chunk:
640
+ await f.write(chunk)
641
+ elif hasattr(response, "aiter_raw"):
642
+ async for chunk in response.aiter_raw():
643
+ if chunk:
644
+ await f.write(chunk)
645
+ else:
646
+ await f.write(response.content)
647
+ os.replace(tmp_path, cache_path)
648
+ finally:
649
+ if tmp_path.exists() and not cache_path.exists():
650
+ try:
651
+ tmp_path.unlink()
652
+ except Exception:
653
+ pass
654
+ mime_type = response.headers.get('content-type', DEFAULT_MIME).split(';')[0]
655
+
656
+ logger.info(f"Download success: {file_path}")
657
+
658
+ # 检查缓存限制
659
+ asyncio.create_task(self.check_limit())
660
+
661
+ return cache_path, mime_type
662
+
663
+ except Exception as e:
664
+ logger.error(f"Download failed: {file_path} - {e}")
665
+ if isinstance(e, AppException):
666
+ raise e
667
+ raise UpstreamException(f"Download error: {str(e)}")
668
+
669
+ async def to_base64(
670
+ self,
671
+ file_path: str,
672
+ token: str,
673
+ media_type: str = "image"
674
+ ) -> str:
675
+ """
676
+ 下载文件并转为 base64
677
+ """
678
+ try:
679
+ cache_path, mime_type = await self.download(file_path, token, media_type)
680
+
681
+ if not cache_path or not cache_path.exists():
682
+ raise AppException("File download returned invalid path")
683
+
684
+ # 使用基础服务的工具方法转换
685
+ data_uri = self.to_b64(cache_path, mime_type)
686
+
687
+ # 默认保留文件到本地缓存,便于后台“缓存管理”统计与复用;
688
+ # 如需转为临时模式,可通过 cache.keep_base64_cache=false 关闭保留。
689
+ keep_cache = get_config("cache.keep_base64_cache", True)
690
+ if data_uri and not keep_cache:
691
+ try:
692
+ cache_path.unlink()
693
+ except Exception as e:
694
+ logger.warning(f"Delete temp file failed: {e}")
695
+
696
+ return data_uri
697
+
698
+ except Exception as e:
699
+ logger.error(f"To base64 failed: {file_path} - {e}")
700
+ if isinstance(e, AppException):
701
+ raise e
702
+ raise AppException(f"Base64 conversion failed: {str(e)}")
703
+
704
+ def get_stats(self, media_type: str = "image") -> Dict[str, Any]:
705
+ """获取缓存统计"""
706
+ cache_dir = self.image_dir if media_type == "image" else self.video_dir
707
+ if not cache_dir.exists():
708
+ return {"count": 0, "size_mb": 0.0}
709
+
710
+ # 统计目录下所有文件(有些资产路径可能不带标准后缀名)
711
+ files = [f for f in cache_dir.glob("*") if f.is_file()]
712
+ total_size = sum(f.stat().st_size for f in files)
713
+
714
+ return {
715
+ "count": len(files),
716
+ "size_mb": round(total_size / 1024 / 1024, 2)
717
+ }
718
+
719
+ def list_files(self, media_type: str = "image", page: int = 1, page_size: int = 1000) -> Dict[str, Any]:
720
+ """列出本地缓存文件"""
721
+ cache_dir = self.image_dir if media_type == "image" else self.video_dir
722
+ if not cache_dir.exists():
723
+ return {"total": 0, "page": page, "page_size": page_size, "items": []}
724
+
725
+ files = [f for f in cache_dir.glob("*") if f.is_file()]
726
+ items = []
727
+ for f in files:
728
+ try:
729
+ stat = f.stat()
730
+ items.append({
731
+ "name": f.name,
732
+ "size_bytes": stat.st_size,
733
+ "mtime_ms": int(stat.st_mtime * 1000),
734
+ })
735
+ except Exception:
736
+ continue
737
+
738
+ items.sort(key=lambda x: x["mtime_ms"], reverse=True)
739
+ total = len(items)
740
+ start = max(0, (page - 1) * page_size)
741
+ end = start + page_size
742
+ paged = items[start:end]
743
+
744
+ if media_type == "image":
745
+ for item in paged:
746
+ item["view_url"] = f"/v1/files/image/{item['name']}"
747
+ else:
748
+ preview_map = {}
749
+ if self.image_dir.exists():
750
+ for img in self.image_dir.glob("*"):
751
+ if img.is_file() and img.suffix.lower() in IMAGE_EXTS:
752
+ preview_map.setdefault(img.stem, img.name)
753
+ for item in paged:
754
+ item["view_url"] = f"/v1/files/video/{item['name']}"
755
+ preview_name = preview_map.get(Path(item["name"]).stem)
756
+ if preview_name:
757
+ item["preview_url"] = f"/v1/files/image/{preview_name}"
758
+
759
+ return {"total": total, "page": page, "page_size": page_size, "items": paged}
760
+
761
+ def delete_file(self, media_type: str, name: str) -> Dict[str, Any]:
762
+ """删除单个缓存文件"""
763
+ cache_dir = self.image_dir if media_type == "image" else self.video_dir
764
+ safe_name = name.replace("/", "-")
765
+ file_path = cache_dir / safe_name
766
+ if not file_path.exists():
767
+ return {"deleted": False}
768
+ try:
769
+ file_path.unlink()
770
+ return {"deleted": True}
771
+ except Exception:
772
+ return {"deleted": False}
773
+
774
+ def clear(self, media_type: str = "image") -> Dict[str, Any]:
775
+ """清空��存"""
776
+ cache_dir = self.image_dir if media_type == "image" else self.video_dir
777
+ if not cache_dir.exists():
778
+ return {"count": 0, "size_mb": 0.0}
779
+
780
+ files = list(cache_dir.glob("*"))
781
+ total_size = sum(f.stat().st_size for f in files)
782
+ count = 0
783
+
784
+ for f in files:
785
+ try:
786
+ f.unlink()
787
+ count += 1
788
+ except Exception as e:
789
+ logger.error(f"Failed to delete {f}: {e}")
790
+
791
+ return {
792
+ "count": count,
793
+ "size_mb": round(total_size / 1024 / 1024, 2)
794
+ }
795
+
796
+ async def check_limit(self):
797
+ """检查并清理缓存限制"""
798
+ if self._cleanup_running:
799
+ return
800
+ self._cleanup_running = True
801
+ try:
802
+ async with _file_lock("cache_cleanup", timeout=5):
803
+ if not get_config("cache.enable_auto_clean", True):
804
+ return
805
+
806
+ limit_mb = get_config("cache.limit_mb", 1024)
807
+
808
+ # 统计总大小
809
+ total_size = 0
810
+ all_files = []
811
+
812
+ for d in [self.image_dir, self.video_dir]:
813
+ if d.exists():
814
+ for f in d.glob("*"):
815
+ try:
816
+ stat = f.stat()
817
+ total_size += stat.st_size
818
+ all_files.append((f, stat.st_mtime, stat.st_size))
819
+ except:
820
+ pass
821
+
822
+ current_mb = total_size / 1024 / 1024
823
+ if current_mb <= limit_mb:
824
+ return
825
+
826
+ # 需要清理
827
+ logger.info(f"Cache limit exceeded ({current_mb:.2f}MB > {limit_mb}MB), cleaning up...")
828
+
829
+ # 按时间排序
830
+ all_files.sort(key=lambda x: x[1])
831
+
832
+ deleted_count = 0
833
+ deleted_size = 0
834
+ target_mb = limit_mb * 0.8 # 清理到 80%
835
+
836
+ for f, _, size in all_files:
837
+ try:
838
+ f.unlink()
839
+ deleted_count += 1
840
+ deleted_size += size
841
+ total_size -= size
842
+
843
+ if (total_size / 1024 / 1024) <= target_mb:
844
+ break
845
+ except Exception as e:
846
+ logger.error(f"Cleanup failed for {f}: {e}")
847
+
848
+ logger.info(f"Cache cleanup: deleted {deleted_count} files ({deleted_size/1024/1024:.2f}MB)")
849
+ finally:
850
+ self._cleanup_running = False
851
+
852
+ def get_public_url(self, file_path: str) -> str:
853
+ """
854
+ 获取文件的公共访问 URL
855
+
856
+ 如果配置了 app_url,则返回自托管 URL,否则返回 Grok 原始 URL
857
+ """
858
+ app_url = get_config("app.app_url", "")
859
+ if not app_url:
860
+ return f"{DOWNLOAD_API}{file_path if file_path.startswith('/') else '/' + file_path}"
861
+
862
+ if not file_path.startswith("/"):
863
+ file_path = f"/{file_path}"
864
+
865
+ # 自动添加 /v1/files 前缀
866
+ return f"{app_url.rstrip('/')}/v1/files{file_path}"
867
+
868
+
869
+ __all__ = [
870
+ "BaseService",
871
+ "UploadService",
872
+ "ListService",
873
+ "DeleteService",
874
+ "DownloadService",
875
+ ]
app/services/grok/chat.py ADDED
@@ -0,0 +1,571 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Grok Chat 服务
3
+ """
4
+
5
+ import asyncio
6
+ import uuid
7
+ import orjson
8
+ from typing import Dict, List, Any
9
+ from dataclasses import dataclass
10
+
11
+ from curl_cffi.requests import AsyncSession
12
+
13
+ from app.core.logger import logger
14
+ from app.core.config import get_config
15
+ from app.core.exceptions import (
16
+ AppException,
17
+ UpstreamException,
18
+ ValidationException,
19
+ ErrorType
20
+ )
21
+ from app.services.grok.statsig import StatsigService
22
+ from app.services.grok.model import ModelService
23
+ from app.services.grok.assets import UploadService
24
+ from app.services.grok.processor import StreamProcessor, CollectProcessor
25
+ from app.services.grok.retry import retry_on_status
26
+ from app.services.token import get_token_manager
27
+ from app.services.request_stats import request_stats
28
+
29
+
30
+ CHAT_API = "https://grok.com/rest/app-chat/conversations/new"
31
+ TIMEOUT = 120
32
+ BROWSER = "chrome136"
33
+
34
+
35
+ @dataclass
36
+ class ChatRequest:
37
+ """聊天请求数据"""
38
+ model: str
39
+ messages: List[Dict[str, Any]]
40
+ stream: bool = None
41
+ think: bool = None
42
+
43
+
44
+ class MessageExtractor:
45
+ """消息内容提取器"""
46
+
47
+ # 需要上传的类型
48
+ UPLOAD_TYPES = {"image_url", "input_audio", "file"}
49
+ # 视频模式不支持的类型
50
+ VIDEO_UNSUPPORTED = {"input_audio", "file"}
51
+
52
+ @staticmethod
53
+ def extract(messages: List[Dict[str, Any]], is_video: bool = False) -> tuple[str, List[str]]:
54
+ """
55
+ 从 OpenAI 消息格式提取内容
56
+
57
+ Args:
58
+ messages: OpenAI 格式消息列表
59
+ is_video: 是否为视频模型
60
+
61
+ Returns:
62
+ (text, attachments): 拼接后的文本和需要上传的附件列表
63
+
64
+ Raises:
65
+ ValueError: 视频模型遇到不支持的内容类型
66
+ """
67
+ texts = []
68
+ attachments = [] # 需要上传的附件 (URL 或 base64)
69
+
70
+ # 先抽取每条消息的文本,保留角色信息用于合并
71
+ extracted: List[Dict[str, str]] = []
72
+
73
+ for msg in messages:
74
+ role = msg.get("role", "")
75
+ content = msg.get("content", "")
76
+ parts = []
77
+
78
+ # 简单字符串内容
79
+ if isinstance(content, str):
80
+ if content.strip():
81
+ parts.append(content)
82
+
83
+ # 列表格式内容
84
+ elif isinstance(content, list):
85
+ for item in content:
86
+ item_type = item.get("type", "")
87
+
88
+ # 文本类型
89
+ if item_type == "text":
90
+ text = item.get("text", "")
91
+ if text.strip():
92
+ parts.append(text)
93
+
94
+ # 图片类型
95
+ elif item_type == "image_url":
96
+ image_data = item.get("image_url", {})
97
+ url = image_data.get("url", "") if isinstance(image_data, dict) else str(image_data)
98
+ if url:
99
+ attachments.append(("image", url))
100
+
101
+ # 音频类型
102
+ elif item_type == "input_audio":
103
+ if is_video:
104
+ raise ValueError("视频模型不支持 input_audio 类型")
105
+ audio_data = item.get("input_audio", {})
106
+ data = audio_data.get("data", "") if isinstance(audio_data, dict) else str(audio_data)
107
+ if data:
108
+ attachments.append(("audio", data))
109
+
110
+ # 文件类型
111
+ elif item_type == "file":
112
+ if is_video:
113
+ raise ValueError("视频模型不支持 file 类型")
114
+ file_data = item.get("file", {})
115
+ # file 可能是 URL 或 base64
116
+ url = file_data.get("url", "") or file_data.get("data", "")
117
+ if isinstance(file_data, str):
118
+ url = file_data
119
+ if url:
120
+ attachments.append(("file", url))
121
+
122
+ if parts:
123
+ extracted.append({"role": role, "text": "\n".join(parts)})
124
+
125
+ # 合并文本
126
+ last_user_index = None
127
+ for i in range(len(extracted) - 1, -1, -1):
128
+ if extracted[i]["role"] == "user":
129
+ last_user_index = i
130
+ break
131
+
132
+ for i, item in enumerate(extracted):
133
+ role = item["role"] or "user"
134
+ text = item["text"]
135
+ if i == last_user_index:
136
+ texts.append(text)
137
+ else:
138
+ texts.append(f"{role}: {text}")
139
+
140
+ # 换行拼接文本
141
+ message = "\n\n".join(texts)
142
+ return message, attachments
143
+
144
+ @staticmethod
145
+ def extract_text_only(messages: List[Dict[str, Any]]) -> str:
146
+ """仅提取文本内容"""
147
+ text, _ = MessageExtractor.extract(messages, is_video=True)
148
+ return text
149
+
150
+
151
+ class ChatRequestBuilder:
152
+ """请求构造器"""
153
+
154
+ @staticmethod
155
+ def build_headers(token: str) -> Dict[str, str]:
156
+ """构造请求头"""
157
+ headers = {
158
+ "Accept": "*/*",
159
+ "Accept-Encoding": "gzip, deflate, br, zstd",
160
+ "Accept-Language": "zh-CN,zh;q=0.9",
161
+ "Baggage": "sentry-environment=production,sentry-release=d6add6fb0460641fd482d767a335ef72b9b6abb8,sentry-public_key=b311e0f2690c81f25e2c4cf6d4f7ce1c",
162
+ "Cache-Control": "no-cache",
163
+ "Content-Type": "application/json",
164
+ "Origin": "https://grok.com",
165
+ "Pragma": "no-cache",
166
+ "Priority": "u=1, i",
167
+ "Referer": "https://grok.com/",
168
+ "Sec-Ch-Ua": '"Google Chrome";v="136", "Chromium";v="136", "Not(A:Brand";v="24"',
169
+ "Sec-Ch-Ua-Arch": "arm",
170
+ "Sec-Ch-Ua-Bitness": "64",
171
+ "Sec-Ch-Ua-Mobile": "?0",
172
+ "Sec-Ch-Ua-Model": "",
173
+ "Sec-Ch-Ua-Platform": '"macOS"',
174
+ "Sec-Fetch-Dest": "empty",
175
+ "Sec-Fetch-Mode": "cors",
176
+ "Sec-Fetch-Site": "same-origin",
177
+ "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36",
178
+ }
179
+
180
+ # Statsig ID
181
+ headers["x-statsig-id"] = StatsigService.gen_id()
182
+ headers["x-xai-request-id"] = str(uuid.uuid4())
183
+
184
+ # Cookie
185
+ token = token[4:] if token.startswith("sso=") else token
186
+ cf = get_config("grok.cf_clearance", "")
187
+ headers["Cookie"] = f"sso={token};cf_clearance={cf}" if cf else f"sso={token}"
188
+
189
+ return headers
190
+
191
+ @staticmethod
192
+ def build_payload(
193
+ message: str,
194
+ model: str,
195
+ mode: str,
196
+ think: bool = None,
197
+ file_attachments: List[str] = None,
198
+ image_attachments: List[str] = None
199
+ ) -> Dict[str, Any]:
200
+ """
201
+ 构造请求体
202
+
203
+ Args:
204
+ message: 消息文本
205
+ model: 模型名称
206
+ mode: 模型模式
207
+ think: 是否开启思考
208
+ file_attachments: 文件附件 ID 列表
209
+ image_attachments: 图片附件 URL 列表
210
+ """
211
+ temporary = get_config("grok.temporary", True)
212
+ if think is None:
213
+ think = get_config("grok.thinking", False)
214
+
215
+ # Upstream payload expects image attachments merged into fileAttachments.
216
+ merged_attachments: List[str] = []
217
+ if file_attachments:
218
+ merged_attachments.extend(file_attachments)
219
+ if image_attachments:
220
+ merged_attachments.extend(image_attachments)
221
+
222
+ return {
223
+ "temporary": temporary,
224
+ "modelName": model,
225
+ "modelMode": mode,
226
+ "message": message,
227
+ "fileAttachments": merged_attachments,
228
+ "imageAttachments": [],
229
+ "disableSearch": False,
230
+ "enableImageGeneration": True,
231
+ "returnImageBytes": False,
232
+ "returnRawGrokInXaiRequest": False,
233
+ "enableImageStreaming": True,
234
+ "imageGenerationCount": 2,
235
+ "forceConcise": False,
236
+ "toolOverrides": {},
237
+ "enableSideBySide": True,
238
+ "sendFinalMetadata": True,
239
+ "isReasoning": False,
240
+ "disableTextFollowUps": False,
241
+ "responseMetadata": {
242
+ "modelConfigOverride": {"modelMap": {}},
243
+ "requestModelDetails": {"modelId": model}
244
+ },
245
+ "disableMemory": False,
246
+ "forceSideBySide": False,
247
+ "isAsyncChat": False,
248
+ "disableSelfHarmShortCircuit": False,
249
+ "deviceEnvInfo": {
250
+ "darkModeEnabled": False,
251
+ "devicePixelRatio": 2,
252
+ "screenWidth": 2056,
253
+ "screenHeight": 1329,
254
+ "viewportWidth": 2056,
255
+ "viewportHeight": 1083
256
+ }
257
+ }
258
+
259
+
260
+ # ==================== Grok 服务 ====================
261
+
262
+ class GrokChatService:
263
+ """Grok API 调用服务"""
264
+
265
+ def __init__(self, proxy: str = None):
266
+ self.proxy = proxy or get_config("grok.base_proxy_url", "")
267
+
268
+ async def chat(
269
+ self,
270
+ token: str,
271
+ message: str,
272
+ model: str = "grok-3",
273
+ mode: str = "MODEL_MODE_FAST",
274
+ think: bool = None,
275
+ stream: bool = None,
276
+ file_attachments: List[str] = None,
277
+ image_attachments: List[str] = None
278
+ ):
279
+ """
280
+ 发送聊天请求
281
+
282
+ Args:
283
+ token: 认证 Token
284
+ message: 消息文本
285
+ model: Grok 模型名称
286
+ mode: 模型模式
287
+ think: 是否开启思考
288
+ stream: 是否流式
289
+ file_attachments: 文件附件 ID 列表
290
+ image_attachments: 图片附件 URL 列表
291
+
292
+ Raises:
293
+ UpstreamException: 当 Grok API 返回错误且重试耗尽时
294
+ """
295
+ if stream is None:
296
+ stream = get_config("grok.stream", True)
297
+
298
+ headers = ChatRequestBuilder.build_headers(token)
299
+ payload = ChatRequestBuilder.build_payload(
300
+ message, model, mode, think,
301
+ file_attachments, image_attachments
302
+ )
303
+ proxies = {"http": self.proxy, "https": self.proxy} if self.proxy else None
304
+ timeout = get_config("grok.timeout", TIMEOUT)
305
+
306
+ # 状态码提取器
307
+ def extract_status(e: Exception) -> int | None:
308
+ if isinstance(e, UpstreamException) and e.details:
309
+ return e.details.get("status")
310
+ return None
311
+
312
+ # 建立连接函数
313
+ async def establish_connection():
314
+ """建立连接并返回 response 对象"""
315
+ session = AsyncSession(impersonate=BROWSER)
316
+ try:
317
+ response = await session.post(
318
+ CHAT_API,
319
+ headers=headers,
320
+ data=orjson.dumps(payload),
321
+ timeout=timeout,
322
+ stream=True,
323
+ proxies=proxies
324
+ )
325
+
326
+ if response.status_code != 200:
327
+ try:
328
+ content = await response.text()
329
+ content = content[:1000] # 限制长度避免日志过大
330
+ except:
331
+ content = "Unable to read response content"
332
+
333
+ logger.error(
334
+ f"Chat failed: {response.status_code}, {content}",
335
+ extra={"status": response.status_code, "token": token[:10] + "..."}
336
+ )
337
+ # 关闭 session 并抛出异常
338
+ try:
339
+ await session.close()
340
+ except:
341
+ pass
342
+ raise UpstreamException(
343
+ message=f"Grok API request failed: {response.status_code}",
344
+ details={"status": response.status_code}
345
+ )
346
+
347
+ # 返回 session 和 response
348
+ return session, response
349
+
350
+ except UpstreamException:
351
+ # 已经处理过的异常,直接抛出
352
+ raise
353
+ except Exception as e:
354
+ # 其他异常,关闭 session 并包装
355
+ logger.error(f"Chat request error: {e}")
356
+ try:
357
+ await session.close()
358
+ except:
359
+ pass
360
+ raise UpstreamException(
361
+ message=f"Chat connection failed: {str(e)}",
362
+ details={"error": str(e)}
363
+ )
364
+
365
+ # 建立连接
366
+ session = None
367
+ response = None
368
+ try:
369
+ session, response = await retry_on_status(
370
+ establish_connection,
371
+ extract_status=extract_status
372
+ )
373
+ except Exception as e:
374
+ # 记录失败
375
+ status_code = extract_status(e)
376
+ if status_code:
377
+ token_mgr = await get_token_manager()
378
+ await token_mgr.record_fail(token, status_code, str(e))
379
+ raise
380
+
381
+ # 流式传输
382
+ async def stream_response():
383
+ try:
384
+ async for line in response.aiter_lines():
385
+ yield line
386
+ finally:
387
+ if session:
388
+ await session.close()
389
+
390
+ return stream_response()
391
+
392
+ async def chat_openai(self, token: str, request: ChatRequest):
393
+ """OpenAI 兼容接口"""
394
+ model_info = ModelService.get(request.model)
395
+ if not model_info:
396
+ raise ValidationException(f"Unknown model: {request.model}")
397
+
398
+ grok_model = model_info.grok_model
399
+ mode = model_info.model_mode
400
+ is_video = model_info.is_video
401
+
402
+ # 提取消息和附件
403
+ try:
404
+ message, attachments = MessageExtractor.extract(request.messages, is_video=is_video)
405
+ except ValueError as e:
406
+ raise ValidationException(str(e))
407
+
408
+ # 处理附件上传
409
+ file_ids = []
410
+ image_ids = []
411
+
412
+ if attachments:
413
+ upload_service = UploadService()
414
+ try:
415
+ for attach_type, attach_data in attachments:
416
+ # 获取 ID
417
+ file_id, _ = await upload_service.upload(attach_data, token)
418
+
419
+ if attach_type == "image":
420
+ # 图片 imageAttachments
421
+ image_ids.append(file_id)
422
+ logger.debug(f"Image uploaded: {file_id}")
423
+ else:
424
+ # 文件 fileAttachments
425
+ file_ids.append(file_id)
426
+ logger.debug(f"File uploaded: {file_id}")
427
+ finally:
428
+ await upload_service.close()
429
+
430
+ stream = request.stream if request.stream is not None else get_config("grok.stream", True)
431
+ think = request.think if request.think is not None else get_config("grok.thinking", False)
432
+
433
+ response = await self.chat(
434
+ token, message, grok_model, mode, think, stream,
435
+ file_attachments=file_ids,
436
+ image_attachments=image_ids
437
+ )
438
+
439
+ return response, stream, request.model
440
+
441
+
442
+ # ==================== Chat 业务服务 ====================
443
+
444
+ class ChatService:
445
+ """Chat 业务服务"""
446
+
447
+ @staticmethod
448
+ async def completions(
449
+ model: str,
450
+ messages: List[Dict[str, Any]],
451
+ stream: bool = None,
452
+ thinking: str = None
453
+ ):
454
+ """
455
+ Chat Completions 入口
456
+
457
+ Args:
458
+ model: 模型名称
459
+ messages: 消息列表
460
+ stream: 是否流式
461
+ thinking: 思考模式
462
+
463
+ Returns:
464
+ AsyncGenerator 或 dict
465
+ """
466
+ # 获取 token
467
+ try:
468
+ token_mgr = await get_token_manager()
469
+ await token_mgr.reload_if_stale()
470
+ token = token_mgr.get_token_for_model(model)
471
+ except Exception as e:
472
+ logger.error(f"Failed to get token: {e}")
473
+ try:
474
+ await request_stats.record_request(model, success=False)
475
+ except Exception:
476
+ pass
477
+ raise AppException(
478
+ message="Internal service error obtaining token",
479
+ error_type=ErrorType.SERVER.value,
480
+ code="internal_error"
481
+ )
482
+
483
+ if not token:
484
+ try:
485
+ await request_stats.record_request(model, success=False)
486
+ except Exception:
487
+ pass
488
+ raise AppException(
489
+ message="No available tokens. Please try again later.",
490
+ error_type=ErrorType.RATE_LIMIT.value,
491
+ code="rate_limit_exceeded",
492
+ status_code=429
493
+ )
494
+
495
+ # 解析参数
496
+ think = None
497
+ if thinking == "enabled":
498
+ think = True
499
+ elif thinking == "disabled":
500
+ think = False
501
+
502
+ is_stream = stream if stream is not None else get_config("grok.stream", True)
503
+
504
+ # 构造请求
505
+ chat_request = ChatRequest(
506
+ model=model,
507
+ messages=messages,
508
+ stream=is_stream,
509
+ think=think
510
+ )
511
+
512
+ # 请求 Grok
513
+ service = GrokChatService()
514
+ try:
515
+ response, _, model_name = await service.chat_openai(token, chat_request)
516
+ except AppException:
517
+ try:
518
+ await request_stats.record_request(model, success=False)
519
+ except Exception:
520
+ pass
521
+ raise
522
+ except Exception as e:
523
+ logger.error(f"Chat service error: {e}")
524
+ try:
525
+ await request_stats.record_request(model, success=False)
526
+ except Exception:
527
+ pass
528
+ raise UpstreamException(
529
+ message=f"Service processing failed: {str(e)}",
530
+ details={"error": str(e)}
531
+ )
532
+
533
+ # 处理响应
534
+ if is_stream:
535
+ processor = StreamProcessor(model_name, token, think).process(response)
536
+
537
+ async def _wrapped_stream():
538
+ completed = False
539
+ try:
540
+ async for chunk in processor:
541
+ yield chunk
542
+ completed = True
543
+ finally:
544
+ # Only count as "success" when the stream ends naturally.
545
+ try:
546
+ if completed:
547
+ await token_mgr.sync_usage(token, model_name, consume_on_fail=True, is_usage=True)
548
+ await request_stats.record_request(model_name, success=True)
549
+ else:
550
+ await request_stats.record_request(model_name, success=False)
551
+ except Exception:
552
+ pass
553
+
554
+ return _wrapped_stream()
555
+
556
+ result = await CollectProcessor(model_name, token).process(response)
557
+ try:
558
+ await token_mgr.sync_usage(token, model_name, consume_on_fail=True, is_usage=True)
559
+ await request_stats.record_request(model_name, success=True)
560
+ except Exception:
561
+ pass
562
+ return result
563
+
564
+
565
+ __all__ = [
566
+ "GrokChatService",
567
+ "ChatRequest",
568
+ "ChatRequestBuilder",
569
+ "MessageExtractor",
570
+ "ChatService",
571
+ ]
app/services/grok/imagine_experimental.py ADDED
@@ -0,0 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Experimental imagine/image-edit upstream calls.
3
+
4
+ This module provides:
5
+ - WebSocket imagine generation (ws/imagine/listen)
6
+ - Experimental image-edit payloads via conversations/new
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import asyncio
12
+ import time
13
+ import uuid
14
+ from typing import Any, Awaitable, Callable, Dict, Iterable, List, Optional
15
+ from urllib.parse import urlparse
16
+
17
+ import orjson
18
+ from curl_cffi.requests import AsyncSession
19
+
20
+ from app.core.config import get_config
21
+ from app.core.exceptions import UpstreamException
22
+ from app.core.logger import logger
23
+ from app.services.grok.assets import DownloadService
24
+ from app.services.grok.chat import BROWSER, CHAT_API, ChatRequestBuilder
25
+
26
+
27
+ IMAGE_METHOD_LEGACY = "legacy"
28
+ IMAGE_METHOD_IMAGINE_WS_EXPERIMENTAL = "imagine_ws_experimental"
29
+ IMAGE_METHODS = {IMAGE_METHOD_LEGACY, IMAGE_METHOD_IMAGINE_WS_EXPERIMENTAL}
30
+ IMAGE_METHOD_ALIASES = {
31
+ "imagine_ws": IMAGE_METHOD_IMAGINE_WS_EXPERIMENTAL,
32
+ "experimental": IMAGE_METHOD_IMAGINE_WS_EXPERIMENTAL,
33
+ "new": IMAGE_METHOD_IMAGINE_WS_EXPERIMENTAL,
34
+ "new_method": IMAGE_METHOD_IMAGINE_WS_EXPERIMENTAL,
35
+ }
36
+
37
+ IMAGINE_WS_API = "wss://grok.com/ws/imagine/listen"
38
+ ASSET_API = "https://assets.grok.com"
39
+ TIMEOUT = 120
40
+
41
+ ProgressCallback = Callable[[int, float], Optional[Awaitable[None] | None]]
42
+ CompletedCallback = Callable[[int, str], Optional[Awaitable[None] | None]]
43
+
44
+
45
+ def resolve_image_generation_method(raw: Any) -> str:
46
+ candidate = str(raw or "").strip().lower()
47
+ if candidate in IMAGE_METHODS:
48
+ return candidate
49
+ mapped = IMAGE_METHOD_ALIASES.get(candidate)
50
+ if mapped:
51
+ return mapped
52
+ return IMAGE_METHOD_LEGACY
53
+
54
+
55
+ def _normalize_asset_path(raw_url: str) -> str:
56
+ raw = str(raw_url or "").strip()
57
+ if not raw:
58
+ return "/"
59
+ if raw.startswith("http://") or raw.startswith("https://"):
60
+ try:
61
+ path = urlparse(raw).path or "/"
62
+ except Exception:
63
+ path = "/"
64
+ else:
65
+ path = raw
66
+ if not path.startswith("/"):
67
+ path = f"/{path}"
68
+ return path
69
+
70
+
71
+ class ImagineExperimentalService:
72
+ def __init__(self, proxy: str | None = None):
73
+ self.proxy = proxy or get_config("grok.base_proxy_url", "")
74
+ self.timeout = int(get_config("grok.timeout", TIMEOUT) or TIMEOUT)
75
+
76
+ def _proxies(self) -> Optional[dict]:
77
+ return {"http": self.proxy, "https": self.proxy} if self.proxy else None
78
+
79
+ def _headers(self, token: str, referer: str = "https://grok.com/imagine") -> Dict[str, str]:
80
+ headers = ChatRequestBuilder.build_headers(token)
81
+ headers["Referer"] = referer
82
+ headers["Origin"] = "https://grok.com"
83
+ return headers
84
+
85
+ @staticmethod
86
+ def _build_ws_payload(
87
+ prompt: str,
88
+ request_id: str,
89
+ aspect_ratio: str = "2:3",
90
+ ) -> Dict[str, Any]:
91
+ return {
92
+ "type": "conversation.item.create",
93
+ "timestamp": int(time.time() * 1000),
94
+ "item": {
95
+ "type": "message",
96
+ "content": [
97
+ {
98
+ "requestId": request_id,
99
+ "text": prompt,
100
+ "type": "input_scroll",
101
+ "properties": {
102
+ "section_count": 0,
103
+ "is_kids_mode": False,
104
+ "enable_nsfw": True,
105
+ "skip_upsampler": False,
106
+ "is_initial": False,
107
+ "aspect_ratio": aspect_ratio,
108
+ },
109
+ }
110
+ ],
111
+ },
112
+ }
113
+
114
+ @staticmethod
115
+ def _extract_url(msg: Dict[str, Any]) -> str:
116
+ for key in ("url", "imageUrl", "image_url"):
117
+ value = msg.get(key)
118
+ if isinstance(value, str) and value.strip():
119
+ return value.strip()
120
+ return ""
121
+
122
+ @staticmethod
123
+ def _extract_progress(msg: Dict[str, Any]) -> Optional[float]:
124
+ for key in ("progress", "percentage_complete", "percentageComplete"):
125
+ value = msg.get(key)
126
+ if value is None:
127
+ continue
128
+ try:
129
+ pct = float(value)
130
+ if pct < 0:
131
+ pct = 0
132
+ if pct > 100:
133
+ pct = 100
134
+ return pct
135
+ except Exception:
136
+ continue
137
+ return None
138
+
139
+ @staticmethod
140
+ def _is_completed(msg: Dict[str, Any], progress: Optional[float]) -> bool:
141
+ status = str(msg.get("current_status") or msg.get("currentStatus") or "").strip().lower()
142
+ if status in {"completed", "done", "success"}:
143
+ return True
144
+ if progress is not None and progress >= 100:
145
+ return True
146
+ return False
147
+
148
+ async def generate_ws(
149
+ self,
150
+ token: str,
151
+ prompt: str,
152
+ n: int = 2,
153
+ aspect_ratio: str = "2:3",
154
+ progress_cb: Optional[ProgressCallback] = None,
155
+ completed_cb: Optional[CompletedCallback] = None,
156
+ timeout: Optional[int] = None,
157
+ ) -> List[str]:
158
+ request_id = str(uuid.uuid4())
159
+ target_count = max(1, int(n or 1))
160
+ effective_timeout = max(10, int(timeout or self.timeout))
161
+ payload = self._build_ws_payload(
162
+ prompt=prompt,
163
+ request_id=request_id,
164
+ aspect_ratio=aspect_ratio,
165
+ )
166
+
167
+ session = AsyncSession(impersonate=BROWSER)
168
+ ws = None
169
+ started_at = time.monotonic()
170
+ image_indices: Dict[str, int] = {}
171
+ final_urls: Dict[str, str] = {}
172
+
173
+ try:
174
+ ws = await session.ws_connect(
175
+ IMAGINE_WS_API,
176
+ headers=self._headers(token),
177
+ timeout=effective_timeout,
178
+ proxies=self._proxies(),
179
+ impersonate=BROWSER,
180
+ )
181
+ await ws.send_json(payload)
182
+
183
+ while time.monotonic() - started_at < effective_timeout:
184
+ remain = max(1.0, effective_timeout - (time.monotonic() - started_at))
185
+ try:
186
+ msg = await ws.recv_json(timeout=min(5.0, remain))
187
+ except asyncio.TimeoutError:
188
+ continue
189
+ except Exception as e:
190
+ raise UpstreamException(f"Imagine websocket receive failed: {e}") from e
191
+
192
+ if not isinstance(msg, dict):
193
+ continue
194
+
195
+ msg_request_id = str(msg.get("request_id") or msg.get("requestId") or "")
196
+ if msg_request_id and msg_request_id != request_id:
197
+ continue
198
+
199
+ msg_type = str(msg.get("type") or "").lower()
200
+ status = str(msg.get("current_status") or msg.get("currentStatus") or "").lower()
201
+ if msg_type == "error" or status == "error":
202
+ err_code = str(msg.get("err_code") or msg.get("errCode") or "unknown")
203
+ err_msg = str(
204
+ msg.get("err_message") or msg.get("err_msg") or msg.get("error") or "unknown error"
205
+ )
206
+ raise UpstreamException(
207
+ message=f"Imagine websocket error ({err_code}): {err_msg}",
208
+ details={"code": err_code, "message": err_msg},
209
+ )
210
+
211
+ image_id = str(msg.get("id") or msg.get("imageId") or msg.get("image_id") or "")
212
+ if not image_id:
213
+ image_id = f"image-{len(image_indices)}"
214
+ if image_id not in image_indices:
215
+ image_indices[image_id] = len(image_indices)
216
+
217
+ progress = self._extract_progress(msg)
218
+ if progress is not None and progress_cb is not None:
219
+ try:
220
+ maybe_coro = progress_cb(image_indices[image_id], progress)
221
+ if asyncio.iscoroutine(maybe_coro):
222
+ await maybe_coro
223
+ except Exception as e:
224
+ logger.debug(f"Imagine progress callback failed: {e}")
225
+
226
+ image_url = self._extract_url(msg)
227
+ if image_url and self._is_completed(msg, progress):
228
+ is_new = image_id not in final_urls
229
+ final_urls.setdefault(image_id, image_url)
230
+ if is_new and completed_cb is not None:
231
+ try:
232
+ maybe_coro = completed_cb(image_indices[image_id], image_url)
233
+ if asyncio.iscoroutine(maybe_coro):
234
+ await maybe_coro
235
+ except Exception as e:
236
+ logger.debug(f"Imagine completion callback failed: {e}")
237
+ if len(final_urls) >= target_count:
238
+ break
239
+
240
+ if not final_urls:
241
+ raise UpstreamException("Imagine websocket returned no completed images")
242
+
243
+ return list(final_urls.values())
244
+ finally:
245
+ if ws is not None:
246
+ try:
247
+ await ws.close()
248
+ except Exception:
249
+ pass
250
+ try:
251
+ await session.close()
252
+ except Exception:
253
+ pass
254
+
255
+ async def convert_urls(self, token: str, urls: Iterable[str], response_format: str = "b64_json") -> List[str]:
256
+ mode = str(response_format or "b64_json").strip().lower()
257
+ out: List[str] = []
258
+ dl = DownloadService(self.proxy)
259
+ try:
260
+ for raw in urls:
261
+ raw = str(raw or "").strip()
262
+ if not raw:
263
+ continue
264
+ if mode == "url":
265
+ path = _normalize_asset_path(raw)
266
+ if path in {"", "/"}:
267
+ continue
268
+ await dl.download(path, token, "image")
269
+ app_url = str(get_config("app.app_url", "") or "").strip()
270
+ local_path = f"/v1/files/image{path}"
271
+ if app_url:
272
+ out.append(f"{app_url.rstrip('/')}{local_path}")
273
+ else:
274
+ out.append(local_path)
275
+ continue
276
+
277
+ data_uri = await dl.to_base64(raw, token, "image")
278
+ if not data_uri:
279
+ continue
280
+ if "," in data_uri:
281
+ out.append(data_uri.split(",", 1)[1])
282
+ else:
283
+ out.append(data_uri)
284
+ return out
285
+ finally:
286
+ await dl.close()
287
+
288
+ async def convert_url(self, token: str, url: str, response_format: str = "b64_json") -> str:
289
+ items = await self.convert_urls(token=token, urls=[url], response_format=response_format)
290
+ return items[0] if items else ""
291
+
292
+ @staticmethod
293
+ def _to_asset_urls(file_uris: List[str]) -> List[str]:
294
+ out = []
295
+ for uri in file_uris:
296
+ value = str(uri or "").strip()
297
+ if not value:
298
+ continue
299
+ if value.startswith("http://") or value.startswith("https://"):
300
+ out.append(value)
301
+ else:
302
+ out.append(f"{ASSET_API}/{value.lstrip('/')}")
303
+ return out
304
+
305
+ @staticmethod
306
+ def _build_edit_payload(prompt: str, image_urls: List[str], model_name: str) -> Dict[str, Any]:
307
+ model_map = {
308
+ "imageEditModel": "imagine",
309
+ "imageEditModelConfig": {
310
+ "imageReferences": image_urls,
311
+ },
312
+ }
313
+ payload: Dict[str, Any] = {
314
+ "temporary": True,
315
+ "modelName": model_name,
316
+ "message": prompt,
317
+ "fileAttachments": [],
318
+ "imageAttachments": [],
319
+ "disableSearch": False,
320
+ "enableImageGeneration": True,
321
+ "returnImageBytes": False,
322
+ "returnRawGrokInXaiRequest": False,
323
+ "enableImageStreaming": True,
324
+ "imageGenerationCount": 2,
325
+ "forceConcise": False,
326
+ "toolOverrides": {"imageGen": True},
327
+ "enableSideBySide": True,
328
+ "sendFinalMetadata": True,
329
+ "isReasoning": False,
330
+ "disableTextFollowUps": False,
331
+ "disableMemory": False,
332
+ "forceSideBySide": False,
333
+ "isAsyncChat": False,
334
+ "responseMetadata": {
335
+ "modelConfigOverride": {
336
+ "modelMap": model_map,
337
+ },
338
+ "requestModelDetails": {
339
+ "modelId": model_name,
340
+ },
341
+ },
342
+ }
343
+ if model_name == "grok-3":
344
+ payload["modelMode"] = "MODEL_MODE_FAST"
345
+ return payload
346
+
347
+ async def chat_edit(
348
+ self,
349
+ token: str,
350
+ prompt: str,
351
+ file_uris: List[str],
352
+ ):
353
+ image_urls = self._to_asset_urls(file_uris)
354
+ if not image_urls:
355
+ raise UpstreamException("Experimental image edit requires at least one uploaded image")
356
+
357
+ headers = self._headers(token, referer="https://grok.com/imagine")
358
+ proxies = self._proxies()
359
+ timeout = self.timeout
360
+
361
+ payloads = [
362
+ self._build_edit_payload(prompt, image_urls, "imagine-image-edit"),
363
+ self._build_edit_payload(prompt, image_urls, "grok-3"),
364
+ ]
365
+
366
+ last_error: Optional[Exception] = None
367
+ for payload in payloads:
368
+ session = AsyncSession(impersonate=BROWSER)
369
+ response = None
370
+ try:
371
+ response = await session.post(
372
+ CHAT_API,
373
+ headers=headers,
374
+ data=orjson.dumps(payload),
375
+ timeout=timeout,
376
+ stream=True,
377
+ proxies=proxies,
378
+ )
379
+ if response.status_code != 200:
380
+ try:
381
+ body = await response.text()
382
+ except Exception:
383
+ body = ""
384
+ raise UpstreamException(
385
+ message=f"Experimental image edit request failed: {response.status_code}",
386
+ details={"status": response.status_code, "body": body[:500]},
387
+ )
388
+
389
+ async def _stream_response():
390
+ try:
391
+ async for line in response.aiter_lines():
392
+ yield line
393
+ finally:
394
+ await session.close()
395
+
396
+ return _stream_response()
397
+ except Exception as e:
398
+ last_error = e
399
+ try:
400
+ await session.close()
401
+ except Exception:
402
+ pass
403
+ continue
404
+
405
+ if isinstance(last_error, Exception):
406
+ raise last_error
407
+ raise UpstreamException("Experimental image edit request failed")
408
+
409
+
410
+ __all__ = [
411
+ "ImagineExperimentalService",
412
+ "IMAGE_METHOD_LEGACY",
413
+ "IMAGE_METHOD_IMAGINE_WS_EXPERIMENTAL",
414
+ "IMAGE_METHODS",
415
+ "resolve_image_generation_method",
416
+ ]
app/services/grok/imagine_generation.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Shared helpers for experimental imagine generation flows.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import asyncio
8
+ from typing import Any, Awaitable, Callable, List, Optional
9
+
10
+ from app.core.exceptions import UpstreamException
11
+ from app.core.logger import logger
12
+ from app.services.grok.imagine_experimental import ImagineExperimentalService
13
+
14
+
15
+ def resolve_aspect_ratio(size: Optional[str]) -> str:
16
+ value = str(size or "").strip().lower()
17
+ if value in {"16:9", "9:16", "1:1", "2:3", "3:2"}:
18
+ return value
19
+
20
+ mapping = {
21
+ "1024x1024": "1:1",
22
+ "512x512": "1:1",
23
+ "1024x576": "16:9",
24
+ "1280x720": "16:9",
25
+ "1536x864": "16:9",
26
+ "576x1024": "9:16",
27
+ "720x1280": "9:16",
28
+ "864x1536": "9:16",
29
+ "1024x1536": "2:3",
30
+ "1024x1792": "2:3",
31
+ "512x768": "2:3",
32
+ "768x1024": "2:3",
33
+ "1536x1024": "3:2",
34
+ "1792x1024": "3:2",
35
+ "768x512": "3:2",
36
+ "1024x768": "3:2",
37
+ }
38
+ return mapping.get(value, "2:3")
39
+
40
+
41
+ def is_valid_image_value(value: Any) -> bool:
42
+ return isinstance(value, str) and bool(value) and value != "error"
43
+
44
+
45
+ def dedupe_images(images: List[str]) -> List[str]:
46
+ out: List[str] = []
47
+ seen: set[str] = set()
48
+ for image in images:
49
+ if not isinstance(image, str):
50
+ continue
51
+ if image in seen:
52
+ continue
53
+ seen.add(image)
54
+ out.append(image)
55
+ return out
56
+
57
+
58
+ async def gather_limited(
59
+ task_factories: List[Callable[[], Awaitable[List[str]]]],
60
+ max_concurrency: int,
61
+ ) -> List[Any]:
62
+ sem = asyncio.Semaphore(max(1, int(max_concurrency or 1)))
63
+
64
+ async def _run(factory: Callable[[], Awaitable[List[str]]]) -> Any:
65
+ async with sem:
66
+ return await factory()
67
+
68
+ return await asyncio.gather(*[_run(factory) for factory in task_factories], return_exceptions=True)
69
+
70
+
71
+ async def call_experimental_generation_once(
72
+ token: str,
73
+ prompt: str,
74
+ response_format: str = "b64_json",
75
+ n: int = 4,
76
+ aspect_ratio: str = "2:3",
77
+ ) -> List[str]:
78
+ service = ImagineExperimentalService()
79
+ raw_urls = await service.generate_ws(
80
+ token=token,
81
+ prompt=prompt,
82
+ n=n,
83
+ aspect_ratio=aspect_ratio,
84
+ )
85
+ return await service.convert_urls(token=token, urls=raw_urls, response_format=response_format)
86
+
87
+
88
+ async def collect_experimental_generation_images(
89
+ token: str,
90
+ prompt: str,
91
+ n: int,
92
+ response_format: str,
93
+ aspect_ratio: str,
94
+ concurrency: int,
95
+ ) -> List[str]:
96
+ calls_needed = max(1, (n + 3) // 4)
97
+ task_factories: List[Callable[[], Awaitable[List[str]]]] = []
98
+ remain = n
99
+ for _ in range(calls_needed):
100
+ target_n = max(1, min(4, remain))
101
+ remain -= target_n
102
+ task_factories.append(
103
+ lambda target_n=target_n: call_experimental_generation_once(
104
+ token,
105
+ prompt,
106
+ response_format=response_format,
107
+ n=target_n,
108
+ aspect_ratio=aspect_ratio,
109
+ )
110
+ )
111
+
112
+ results = await gather_limited(
113
+ task_factories,
114
+ max_concurrency=min(calls_needed, max(1, int(concurrency or 1))),
115
+ )
116
+ all_images: List[str] = []
117
+ for result in results:
118
+ if isinstance(result, Exception):
119
+ logger.warning(f"Experimental imagine websocket call failed: {result}")
120
+ continue
121
+ if isinstance(result, list):
122
+ all_images.extend(result)
123
+
124
+ all_images = dedupe_images(all_images)
125
+ if not any(is_valid_image_value(item) for item in all_images):
126
+ raise UpstreamException("Experimental imagine websocket returned no images")
127
+ return all_images
128
+
129
+
130
+ __all__ = [
131
+ "resolve_aspect_ratio",
132
+ "is_valid_image_value",
133
+ "dedupe_images",
134
+ "gather_limited",
135
+ "call_experimental_generation_once",
136
+ "collect_experimental_generation_images",
137
+ ]
app/services/grok/media.py ADDED
@@ -0,0 +1,512 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Grok 视频生成服务
3
+ """
4
+
5
+ import asyncio
6
+ import uuid
7
+ from typing import AsyncGenerator, Optional
8
+
9
+ import orjson
10
+ from curl_cffi.requests import AsyncSession
11
+
12
+ from app.core.logger import logger
13
+ from app.core.config import get_config
14
+ from app.core.exceptions import UpstreamException, AppException, ValidationException, ErrorType
15
+ from app.services.grok.statsig import StatsigService
16
+ from app.services.grok.model import ModelService
17
+ from app.services.token import get_token_manager
18
+ from app.services.grok.processor import VideoStreamProcessor, VideoCollectProcessor
19
+ from app.services.request_stats import request_stats
20
+
21
+ # API 端点
22
+ CREATE_POST_API = "https://grok.com/rest/media/post/create"
23
+ CHAT_API = "https://grok.com/rest/app-chat/conversations/new"
24
+
25
+ # 常量
26
+ BROWSER = "chrome136"
27
+ TIMEOUT = 300
28
+ DEFAULT_MAX_CONCURRENT = 50
29
+ _MEDIA_SEMAPHORE = asyncio.Semaphore(DEFAULT_MAX_CONCURRENT)
30
+ _MEDIA_SEM_VALUE = DEFAULT_MAX_CONCURRENT
31
+
32
+ def _get_media_semaphore() -> asyncio.Semaphore:
33
+ global _MEDIA_SEMAPHORE, _MEDIA_SEM_VALUE
34
+ value = get_config("performance.media_max_concurrent", DEFAULT_MAX_CONCURRENT)
35
+ try:
36
+ value = int(value)
37
+ except Exception:
38
+ value = DEFAULT_MAX_CONCURRENT
39
+ value = max(1, value)
40
+ if value != _MEDIA_SEM_VALUE:
41
+ _MEDIA_SEM_VALUE = value
42
+ _MEDIA_SEMAPHORE = asyncio.Semaphore(value)
43
+ return _MEDIA_SEMAPHORE
44
+
45
+
46
+ class VideoService:
47
+ """视频生成服务"""
48
+
49
+ def __init__(self, proxy: str = None):
50
+ self.proxy = proxy or get_config("grok.base_proxy_url", "")
51
+ self.timeout = get_config("grok.timeout", TIMEOUT)
52
+
53
+ def _build_headers(self, token: str, referer: str = "https://grok.com/imagine") -> dict:
54
+ """构建请求头"""
55
+ headers = {
56
+ "Accept": "*/*",
57
+ "Accept-Encoding": "gzip, deflate, br, zstd",
58
+ "Accept-Language": "zh-CN,zh;q=0.9",
59
+ "Baggage": "sentry-environment=production,sentry-release=d6add6fb0460641fd482d767a335ef72b9b6abb8,sentry-public_key=b311e0f2690c81f25e2c4cf6d4f7ce1c",
60
+ "Cache-Control": "no-cache",
61
+ "Content-Type": "application/json",
62
+ "Origin": "https://grok.com",
63
+ "Pragma": "no-cache",
64
+ "Priority": "u=1, i",
65
+ "Referer": referer,
66
+ "Sec-Ch-Ua": '"Google Chrome";v="136", "Chromium";v="136", "Not(A:Brand";v="24"',
67
+ "Sec-Ch-Ua-Arch": "arm",
68
+ "Sec-Ch-Ua-Bitness": "64",
69
+ "Sec-Ch-Ua-Mobile": "?0",
70
+ "Sec-Ch-Ua-Model": "",
71
+ "Sec-Ch-Ua-Platform": '"macOS"',
72
+ "Sec-Fetch-Dest": "empty",
73
+ "Sec-Fetch-Mode": "cors",
74
+ "Sec-Fetch-Site": "same-origin",
75
+ "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36",
76
+ }
77
+
78
+ # Statsig ID
79
+ headers["x-statsig-id"] = StatsigService.gen_id()
80
+ headers["x-xai-request-id"] = str(uuid.uuid4())
81
+
82
+ # Cookie
83
+ token = token[4:] if token.startswith("sso=") else token
84
+ cf = get_config("grok.cf_clearance", "")
85
+ headers["Cookie"] = f"sso={token};cf_clearance={cf}" if cf else f"sso={token}"
86
+
87
+ return headers
88
+
89
+ def _build_proxies(self) -> Optional[dict]:
90
+ """构建代理"""
91
+ return {"http": self.proxy, "https": self.proxy} if self.proxy else None
92
+
93
+ async def create_post(self, token: str, prompt: str, media_type: str = "MEDIA_POST_TYPE_VIDEO", media_url: str = None) -> str:
94
+ """
95
+ 创建媒体帖子
96
+
97
+ Args:
98
+ token: 认证 Token
99
+ prompt: 提示词(视频生成用)
100
+ media_type: 媒体类型 (MEDIA_POST_TYPE_VIDEO 或 MEDIA_POST_TYPE_IMAGE)
101
+ media_url: 媒体 URL(图片模式用)
102
+
103
+ Returns:
104
+ post ID
105
+ """
106
+ try:
107
+ headers = self._build_headers(token)
108
+
109
+ # 根据类型构建不同的载荷
110
+ if media_type == "MEDIA_POST_TYPE_IMAGE" and media_url:
111
+ payload = {
112
+ "mediaType": media_type,
113
+ "mediaUrl": media_url
114
+ }
115
+ else:
116
+ payload = {
117
+ "mediaType": media_type,
118
+ "prompt": prompt
119
+ }
120
+
121
+ async with AsyncSession() as session:
122
+ response = await session.post(
123
+ CREATE_POST_API,
124
+ headers=headers,
125
+ json=payload,
126
+ impersonate=BROWSER,
127
+ timeout=30,
128
+ proxies=self._build_proxies()
129
+ )
130
+
131
+ if response.status_code != 200:
132
+ logger.error(f"Create post failed: {response.status_code}")
133
+ raise UpstreamException(f"Failed to create post: {response.status_code}")
134
+
135
+ data = response.json()
136
+ post_id = data.get("post", {}).get("id", "")
137
+
138
+ if not post_id:
139
+ raise UpstreamException("No post ID in response")
140
+
141
+ logger.info(f"Media post created: {post_id} (type={media_type})")
142
+ return post_id
143
+
144
+ except Exception as e:
145
+ logger.error(f"Create post error: {e}")
146
+ if isinstance(e, AppException):
147
+ raise e
148
+ raise UpstreamException(f"Create post error: {str(e)}")
149
+
150
+ async def create_image_post(self, token: str, image_url: str) -> str:
151
+ """
152
+ 创建图片帖子
153
+
154
+ Args:
155
+ token: 认证 Token
156
+ image_url: 完整的图片 URL (https://assets.grok.com/...)
157
+
158
+ Returns:
159
+ post ID
160
+ """
161
+ return await self.create_post(
162
+ token,
163
+ prompt="",
164
+ media_type="MEDIA_POST_TYPE_IMAGE",
165
+ media_url=image_url
166
+ )
167
+
168
+ def _build_payload(
169
+ self,
170
+ prompt: str,
171
+ post_id: str,
172
+ aspect_ratio: str = "3:2",
173
+ video_length: int = 6,
174
+ resolution: str = "SD",
175
+ preset: str = "normal"
176
+ ) -> dict:
177
+ """构建视频生成载荷"""
178
+ mode_flag = "--mode=custom"
179
+ if preset == "fun":
180
+ mode_flag = "--mode=extremely-crazy"
181
+ elif preset == "normal":
182
+ mode_flag = "--mode=normal"
183
+ elif preset == "spicy":
184
+ mode_flag = "--mode=extremely-spicy-or-crazy"
185
+
186
+ full_prompt = f"{prompt} {mode_flag}"
187
+
188
+ return {
189
+ "temporary": True,
190
+ "modelName": "grok-3",
191
+ "message": full_prompt,
192
+ "toolOverrides": {"videoGen": True},
193
+ "enableSideBySide": True,
194
+ "responseMetadata": {
195
+ "experiments": [],
196
+ "modelConfigOverride": {
197
+ "modelMap": {
198
+ "videoGenModelConfig": {
199
+ "parentPostId": post_id,
200
+ "aspectRatio": aspect_ratio,
201
+ "videoLength": video_length,
202
+ "videoResolution": resolution
203
+ }
204
+ }
205
+ }
206
+ }
207
+ }
208
+
209
+ async def generate(
210
+ self,
211
+ token: str,
212
+ prompt: str,
213
+ aspect_ratio: str = "3:2",
214
+ video_length: int = 6,
215
+ resolution: str = "SD",
216
+ stream: bool = True,
217
+ preset: str = "normal"
218
+ ) -> AsyncGenerator[bytes, None]:
219
+ """
220
+ 生成视频
221
+
222
+ Args:
223
+ token: 认证 Token
224
+ prompt: 视频描述
225
+ aspect_ratio: 宽高比
226
+ video_length: 视频时长
227
+ resolution: 分辨率
228
+ stream: 是否流式
229
+ preset: 预设
230
+
231
+ Returns:
232
+ AsyncGenerator,流式传输
233
+
234
+ Raises:
235
+ UpstreamException: 连接失败时
236
+ """
237
+ async with _get_media_semaphore():
238
+ session = None
239
+ try:
240
+ # Step 1: 创建帖子
241
+ post_id = await self.create_post(token, prompt)
242
+
243
+ # Step 2: 建立连接
244
+ headers = self._build_headers(token)
245
+ payload = self._build_payload(prompt, post_id, aspect_ratio, video_length, resolution, preset)
246
+
247
+ session = AsyncSession(impersonate=BROWSER)
248
+ response = await session.post(
249
+ CHAT_API,
250
+ headers=headers,
251
+ data=orjson.dumps(payload),
252
+ timeout=self.timeout,
253
+ stream=True,
254
+ proxies=self._build_proxies()
255
+ )
256
+
257
+ if response.status_code != 200:
258
+ logger.error(f"Video generation failed: {response.status_code}")
259
+ try:
260
+ await session.close()
261
+ except:
262
+ pass
263
+ raise UpstreamException(
264
+ message=f"Video generation failed: {response.status_code}",
265
+ details={"status": response.status_code}
266
+ )
267
+
268
+ # Step 3: 流式传输
269
+ async def stream_response():
270
+ try:
271
+ async for line in response.aiter_lines():
272
+ yield line
273
+ finally:
274
+ if session:
275
+ await session.close()
276
+
277
+ return stream_response()
278
+
279
+ except Exception as e:
280
+ if session:
281
+ try:
282
+ await session.close()
283
+ except:
284
+ pass
285
+ logger.error(f"Video generation error: {e}")
286
+ if isinstance(e, AppException):
287
+ raise e
288
+ raise UpstreamException(f"Video generation error: {str(e)}")
289
+
290
+ async def generate_from_image(
291
+ self,
292
+ token: str,
293
+ prompt: str,
294
+ image_url: str,
295
+ aspect_ratio: str = "3:2",
296
+ video_length: int = 6,
297
+ resolution: str = "SD",
298
+ stream: bool = True,
299
+ preset: str = "normal"
300
+ ) -> AsyncGenerator[bytes, None]:
301
+ """
302
+ 从图片生成视频
303
+
304
+ Args:
305
+ token: 认证 Token
306
+ prompt: 视频描述
307
+ image_url: 图片 URL
308
+ aspect_ratio: 宽高比
309
+ video_length: 视频时长
310
+ resolution: 分辨率
311
+ stream: 是否流式
312
+ preset: 预设
313
+
314
+ Returns:
315
+ AsyncGenerator,流式传输
316
+ """
317
+ async with _get_media_semaphore():
318
+ session = None
319
+ try:
320
+ # Step 1: 创建帖子
321
+ post_id = await self.create_image_post(token, image_url)
322
+
323
+ # Step 2: 建立连接
324
+ headers = self._build_headers(token)
325
+ payload = self._build_payload(prompt, post_id, aspect_ratio, video_length, resolution, preset)
326
+
327
+ session = AsyncSession(impersonate=BROWSER)
328
+ response = await session.post(
329
+ CHAT_API,
330
+ headers=headers,
331
+ data=orjson.dumps(payload),
332
+ timeout=self.timeout,
333
+ stream=True,
334
+ proxies=self._build_proxies()
335
+ )
336
+
337
+ if response.status_code != 200:
338
+ logger.error(f"Video from image failed: {response.status_code}")
339
+ try:
340
+ await session.close()
341
+ except:
342
+ pass
343
+ raise UpstreamException(
344
+ message=f"Video from image failed: {response.status_code}",
345
+ details={"status": response.status_code}
346
+ )
347
+
348
+ # Step 3: 流式传输
349
+ async def stream_response():
350
+ try:
351
+ async for line in response.aiter_lines():
352
+ yield line
353
+ finally:
354
+ if session:
355
+ await session.close()
356
+
357
+ return stream_response()
358
+
359
+ except Exception as e:
360
+ if session:
361
+ try:
362
+ await session.close()
363
+ except:
364
+ pass
365
+ logger.error(f"Video from image error: {e}")
366
+ if isinstance(e, AppException):
367
+ raise e
368
+ raise UpstreamException(f"Video from image error: {str(e)}")
369
+
370
+ @staticmethod
371
+ async def completions(
372
+ model: str,
373
+ messages: list,
374
+ stream: bool = None,
375
+ thinking: str = None,
376
+ aspect_ratio: str = "3:2",
377
+ video_length: int = 6,
378
+ resolution: str = "SD",
379
+ preset: str = "normal"
380
+ ):
381
+ """
382
+ 视频生成入口
383
+
384
+ Args:
385
+ model: 模型名称
386
+ messages: 消息列表
387
+ stream: 是否流式
388
+ thinking: 思考模式
389
+ aspect_ratio: 宽高比
390
+ video_length: 视频时长
391
+ resolution: 分辨率
392
+ preset: 预设模式
393
+
394
+ Returns:
395
+ AsyncGenerator (流式) 或 dict (非流式)
396
+ """
397
+ # 获取 token
398
+ try:
399
+ token_mgr = await get_token_manager()
400
+ await token_mgr.reload_if_stale()
401
+ token = token_mgr.get_token_for_model(model)
402
+ except Exception as e:
403
+ logger.error(f"Failed to get token: {e}")
404
+ try:
405
+ await request_stats.record_request(model, success=False)
406
+ except Exception:
407
+ pass
408
+ raise AppException(
409
+ message="Internal service error obtaining token",
410
+ error_type=ErrorType.SERVER.value,
411
+ code="internal_error"
412
+ )
413
+
414
+ if not token:
415
+ try:
416
+ await request_stats.record_request(model, success=False)
417
+ except Exception:
418
+ pass
419
+ raise AppException(
420
+ message="No available tokens. Please try again later.",
421
+ error_type=ErrorType.RATE_LIMIT.value,
422
+ code="rate_limit_exceeded",
423
+ status_code=429
424
+ )
425
+
426
+ # 解析参数
427
+ think = None
428
+ if thinking == "enabled":
429
+ think = True
430
+ elif thinking == "disabled":
431
+ think = False
432
+
433
+ is_stream = stream if stream is not None else get_config("grok.stream", True)
434
+
435
+ # 提取内容
436
+ from app.services.grok.chat import MessageExtractor
437
+ from app.services.grok.assets import UploadService
438
+
439
+ try:
440
+ prompt, attachments = MessageExtractor.extract(messages, is_video=True)
441
+ except ValueError as e:
442
+ raise ValidationException(str(e))
443
+
444
+ # 处理图片附件
445
+ image_url = None
446
+ if attachments:
447
+ upload_service = UploadService()
448
+ try:
449
+ for attach_type, attach_data in attachments:
450
+ if attach_type == "image":
451
+ # 上传图片
452
+ _, file_uri = await upload_service.upload(attach_data, token)
453
+ image_url = f"https://assets.grok.com/{file_uri}"
454
+ logger.info(f"Image uploaded for video: {image_url}")
455
+ break # 视频模型只使用第一张图片
456
+ finally:
457
+ await upload_service.close()
458
+
459
+ # 生成视频
460
+ service = VideoService()
461
+
462
+ try:
463
+ # 图片转视频
464
+ if image_url:
465
+ response = await service.generate_from_image(
466
+ token, prompt, image_url,
467
+ aspect_ratio, video_length, resolution, stream, preset
468
+ )
469
+ else:
470
+ response = await service.generate(
471
+ token, prompt,
472
+ aspect_ratio, video_length, resolution, stream, preset
473
+ )
474
+ except Exception:
475
+ try:
476
+ await request_stats.record_request(model, success=False)
477
+ except Exception:
478
+ pass
479
+ raise
480
+
481
+ # 处理响应
482
+ if is_stream:
483
+ processor = VideoStreamProcessor(model, token, think).process(response)
484
+
485
+ async def _wrapped_stream():
486
+ completed = False
487
+ try:
488
+ async for chunk in processor:
489
+ yield chunk
490
+ completed = True
491
+ finally:
492
+ try:
493
+ if completed:
494
+ await token_mgr.sync_usage(token, model, consume_on_fail=True, is_usage=True)
495
+ await request_stats.record_request(model, success=True)
496
+ else:
497
+ await request_stats.record_request(model, success=False)
498
+ except Exception:
499
+ pass
500
+
501
+ return _wrapped_stream()
502
+
503
+ result = await VideoCollectProcessor(model, token).process(response)
504
+ try:
505
+ await token_mgr.sync_usage(token, model, consume_on_fail=True, is_usage=True)
506
+ await request_stats.record_request(model, success=True)
507
+ except Exception:
508
+ pass
509
+ return result
510
+
511
+
512
+ __all__ = ["VideoService"]
app/services/grok/model.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Grok 模型管理服务
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ from enum import Enum
8
+ from typing import Optional, Tuple
9
+ from pydantic import BaseModel, Field
10
+
11
+ from app.core.exceptions import ValidationException
12
+
13
+
14
+ class Tier(str, Enum):
15
+ """模型档位"""
16
+ BASIC = "basic"
17
+ SUPER = "super"
18
+
19
+
20
+ class Cost(str, Enum):
21
+ """计费类型"""
22
+ LOW = "low"
23
+ HIGH = "high"
24
+
25
+
26
+ class ModelInfo(BaseModel):
27
+ """模型信息"""
28
+ model_id: str
29
+ grok_model: str
30
+ rate_limit_model: str
31
+ model_mode: str
32
+ tier: Tier = Field(default=Tier.BASIC)
33
+ cost: Cost = Field(default=Cost.LOW)
34
+ display_name: str
35
+ description: str = ""
36
+ is_video: bool = False
37
+ is_image: bool = False
38
+
39
+
40
+ class ModelService:
41
+ """模型管理服务"""
42
+
43
+ MODELS = [
44
+ ModelInfo(
45
+ model_id="grok-3",
46
+ grok_model="grok-3",
47
+ rate_limit_model="grok-3",
48
+ model_mode="MODEL_MODE_GROK_3",
49
+ cost=Cost.LOW,
50
+ display_name="Grok 3"
51
+ ),
52
+ ModelInfo(
53
+ model_id="grok-3-mini",
54
+ grok_model="grok-3",
55
+ rate_limit_model="grok-3",
56
+ model_mode="MODEL_MODE_GROK_3_MINI_THINKING",
57
+ cost=Cost.LOW,
58
+ display_name="Grok 3 Mini"
59
+ ),
60
+ ModelInfo(
61
+ model_id="grok-3-thinking",
62
+ grok_model="grok-3",
63
+ rate_limit_model="grok-3",
64
+ model_mode="MODEL_MODE_GROK_3_THINKING",
65
+ cost=Cost.LOW,
66
+ display_name="Grok 3 Thinking"
67
+ ),
68
+ ModelInfo(
69
+ model_id="grok-4",
70
+ grok_model="grok-4",
71
+ rate_limit_model="grok-4",
72
+ model_mode="MODEL_MODE_GROK_4",
73
+ cost=Cost.LOW,
74
+ display_name="Grok 4"
75
+ ),
76
+ ModelInfo(
77
+ model_id="grok-4-mini",
78
+ grok_model="grok-4-mini",
79
+ rate_limit_model="grok-4-mini",
80
+ model_mode="MODEL_MODE_GROK_4_MINI_THINKING",
81
+ cost=Cost.LOW,
82
+ display_name="Grok 4 Mini"
83
+ ),
84
+ ModelInfo(
85
+ model_id="grok-4-thinking",
86
+ grok_model="grok-4",
87
+ rate_limit_model="grok-4",
88
+ model_mode="MODEL_MODE_GROK_4_THINKING",
89
+ cost=Cost.LOW,
90
+ display_name="Grok 4 Thinking"
91
+ ),
92
+ ModelInfo(
93
+ model_id="grok-4-heavy",
94
+ grok_model="grok-4",
95
+ rate_limit_model="grok-4-heavy",
96
+ model_mode="MODEL_MODE_HEAVY",
97
+ cost=Cost.HIGH,
98
+ tier=Tier.SUPER,
99
+ display_name="Grok 4 Heavy"
100
+ ),
101
+ ModelInfo(
102
+ model_id="grok-4.1-mini",
103
+ grok_model="grok-4-1-thinking-1129",
104
+ rate_limit_model="grok-4-1-thinking-1129",
105
+ model_mode="MODEL_MODE_GROK_4_1_MINI_THINKING",
106
+ cost=Cost.LOW,
107
+ display_name="Grok 4.1 Mini"
108
+ ),
109
+ ModelInfo(
110
+ model_id="grok-4.1-fast",
111
+ grok_model="grok-4-1-thinking-1129",
112
+ rate_limit_model="grok-4-1-thinking-1129",
113
+ model_mode="MODEL_MODE_FAST",
114
+ cost=Cost.LOW,
115
+ display_name="Grok 4.1 Fast"
116
+ ),
117
+ ModelInfo(
118
+ model_id="grok-4.1-expert",
119
+ grok_model="grok-4-1-thinking-1129",
120
+ rate_limit_model="grok-4-1-thinking-1129",
121
+ model_mode="MODEL_MODE_EXPERT",
122
+ cost=Cost.HIGH,
123
+ display_name="Grok 4.1 Expert"
124
+ ),
125
+ ModelInfo(
126
+ model_id="grok-4.1-thinking",
127
+ grok_model="grok-4-1-thinking-1129",
128
+ rate_limit_model="grok-4-1-thinking-1129",
129
+ model_mode="MODEL_MODE_GROK_4_1_THINKING",
130
+ cost=Cost.HIGH,
131
+ display_name="Grok 4.1 Thinking"
132
+ ),
133
+ ModelInfo(
134
+ model_id="grok-4.20-beta",
135
+ grok_model="grok-420",
136
+ rate_limit_model="grok-420",
137
+ model_mode="MODEL_MODE_GROK_420",
138
+ cost=Cost.LOW,
139
+ display_name="Grok 4.20 Beta"
140
+ ),
141
+ ModelInfo(
142
+ model_id="grok-imagine-1.0",
143
+ grok_model="grok-3",
144
+ rate_limit_model="grok-3",
145
+ model_mode="MODEL_MODE_FAST",
146
+ cost=Cost.HIGH,
147
+ display_name="Grok Image",
148
+ description="Image generation model",
149
+ is_image=True
150
+ ),
151
+ ModelInfo(
152
+ model_id="grok-imagine-1.0-edit",
153
+ grok_model="imagine-image-edit",
154
+ rate_limit_model="grok-3",
155
+ model_mode="MODEL_MODE_FAST",
156
+ cost=Cost.HIGH,
157
+ display_name="Grok Image Edit",
158
+ description="Image edit model",
159
+ is_image=True
160
+ ),
161
+ ModelInfo(
162
+ model_id="grok-imagine-1.0-video",
163
+ grok_model="grok-3",
164
+ rate_limit_model="grok-3",
165
+ model_mode="MODEL_MODE_FAST",
166
+ cost=Cost.HIGH,
167
+ display_name="Grok Video",
168
+ description="Video generation model",
169
+ is_video=True
170
+ ),
171
+ ]
172
+
173
+ _map = {m.model_id: m for m in MODELS}
174
+
175
+ @classmethod
176
+ def get(cls, model_id: str) -> Optional[ModelInfo]:
177
+ """获取模型信息"""
178
+ return cls._map.get(model_id)
179
+
180
+ @classmethod
181
+ def list(cls) -> list[ModelInfo]:
182
+ """获取所有模型"""
183
+ return list(cls._map.values())
184
+
185
+ @classmethod
186
+ def valid(cls, model_id: str) -> bool:
187
+ """模型是否有效"""
188
+ return model_id in cls._map
189
+
190
+ @classmethod
191
+ def to_grok(cls, model_id: str) -> Tuple[str, str]:
192
+ """转换为 Grok 参数"""
193
+ model = cls.get(model_id)
194
+ if not model:
195
+ raise ValidationException(f"Invalid model ID: {model_id}")
196
+ return model.grok_model, model.model_mode
197
+
198
+ @classmethod
199
+ def rate_limit_model_for(cls, model_id: str) -> str:
200
+ """用于 /rest/rate-limits 的 modelName 映射。"""
201
+ model = cls.get(model_id)
202
+ return model.rate_limit_model if model else model_id
203
+
204
+ @classmethod
205
+ def is_heavy_bucket_model(cls, model_id: str) -> bool:
206
+ """是否使用 heavy 配额桶(目前仅 grok-4-heavy)。"""
207
+ return model_id == "grok-4-heavy"
208
+
209
+ @classmethod
210
+ def pool_for_model(cls, model_id: str) -> str:
211
+ """根据模型选择 Token 池"""
212
+ model = cls.get(model_id)
213
+ if model and model.tier == Tier.SUPER:
214
+ return "ssoSuper"
215
+ return "ssoBasic"
216
+
217
+ @classmethod
218
+ def pool_candidates_for_model(cls, model_id: str) -> list[str]:
219
+ """按优先级返回可用 Token 池列表。"""
220
+ model = cls.get(model_id)
221
+ if model and model.tier == Tier.SUPER:
222
+ return ["ssoSuper"]
223
+ return ["ssoBasic", "ssoSuper"]
224
+
225
+
226
+ __all__ = ["ModelService"]
app/services/grok/processor.py ADDED
@@ -0,0 +1,596 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OpenAI 响应格式处理器
3
+ """
4
+ import time
5
+ import uuid
6
+ import random
7
+ import html
8
+ import orjson
9
+ from typing import Any, AsyncGenerator, Optional, AsyncIterable, List
10
+
11
+ from app.core.config import get_config
12
+ from app.core.logger import logger
13
+ from app.services.grok.assets import DownloadService
14
+
15
+
16
+ ASSET_URL = "https://assets.grok.com/"
17
+
18
+
19
+ def _build_video_poster_preview(video_url: str, thumbnail_url: str = "") -> str:
20
+ """将 <video> 替换为可点击的 Poster 预览图(用于前端展示)"""
21
+ safe_video = html.escape(video_url or "", quote=True)
22
+ safe_thumb = html.escape(thumbnail_url or "", quote=True)
23
+
24
+ if not safe_video:
25
+ return ""
26
+
27
+ if not safe_thumb:
28
+ return f'<a href="{safe_video}" target="_blank" rel="noopener noreferrer">{safe_video}</a>'
29
+
30
+ return f'''<a href="{safe_video}" target="_blank" rel="noopener noreferrer" style="display:inline-block;position:relative;max-width:100%;text-decoration:none;">
31
+ <img src="{safe_thumb}" alt="video" style="max-width:100%;height:auto;border-radius:12px;display:block;" />
32
+ <span style="position:absolute;inset:0;display:flex;align-items:center;justify-content:center;">
33
+ <span style="width:64px;height:64px;border-radius:9999px;background:rgba(0,0,0,.55);display:flex;align-items:center;justify-content:center;">
34
+ <span style="width:0;height:0;border-top:12px solid transparent;border-bottom:12px solid transparent;border-left:18px solid #fff;margin-left:4px;"></span>
35
+ </span>
36
+ </span>
37
+ </a>'''
38
+
39
+
40
+ class BaseProcessor:
41
+ """基础处理器"""
42
+
43
+ def __init__(self, model: str, token: str = ""):
44
+ self.model = model
45
+ self.token = token
46
+ self.created = int(time.time())
47
+ self.app_url = get_config("app.app_url", "")
48
+ self._dl_service: Optional[DownloadService] = None
49
+
50
+ def _get_dl(self) -> DownloadService:
51
+ """获取下载服务实例(复用)"""
52
+ if self._dl_service is None:
53
+ self._dl_service = DownloadService()
54
+ return self._dl_service
55
+
56
+ async def close(self):
57
+ """释放下载服务资源"""
58
+ if self._dl_service:
59
+ await self._dl_service.close()
60
+ self._dl_service = None
61
+
62
+ async def process_url(self, path: str, media_type: str = "image") -> str:
63
+ """处理资产 URL"""
64
+ # 处理可能的绝对路径
65
+ if path.startswith("http"):
66
+ from urllib.parse import urlparse
67
+ path = urlparse(path).path
68
+
69
+ if not path.startswith("/"):
70
+ path = f"/{path}"
71
+
72
+ # Invalid root path is not a displayable image URL.
73
+ if path in {"", "/"}:
74
+ return ""
75
+
76
+ # Always materialize to local cache endpoint so callers don't rely on
77
+ # direct assets.grok.com access (often blocked without upstream cookies).
78
+ dl_service = self._get_dl()
79
+ await dl_service.download(path, self.token, media_type)
80
+ local_path = f"/v1/files/{media_type}{path}"
81
+ if self.app_url:
82
+ return f"{self.app_url.rstrip('/')}{local_path}"
83
+ return local_path
84
+
85
+ def _sse(self, content: str = "", role: str = None, finish: str = None) -> str:
86
+ """构建 SSE 响应 (StreamProcessor 通用)"""
87
+ if not hasattr(self, 'response_id'):
88
+ self.response_id = None
89
+ if not hasattr(self, 'fingerprint'):
90
+ self.fingerprint = ""
91
+
92
+ delta = {}
93
+ if role:
94
+ delta["role"] = role
95
+ delta["content"] = ""
96
+ elif content:
97
+ delta["content"] = content
98
+
99
+ chunk = {
100
+ "id": self.response_id or f"chatcmpl-{uuid.uuid4().hex[:24]}",
101
+ "object": "chat.completion.chunk",
102
+ "created": self.created,
103
+ "model": self.model,
104
+ "system_fingerprint": self.fingerprint if hasattr(self, 'fingerprint') else "",
105
+ "choices": [{"index": 0, "delta": delta, "logprobs": None, "finish_reason": finish}]
106
+ }
107
+ return f"data: {orjson.dumps(chunk).decode()}\n\n"
108
+
109
+
110
+ class StreamProcessor(BaseProcessor):
111
+ """流式响应处理器"""
112
+
113
+ def __init__(self, model: str, token: str = "", think: bool = None):
114
+ super().__init__(model, token)
115
+ self.response_id: Optional[str] = None
116
+ self.fingerprint: str = ""
117
+ self.think_opened: bool = False
118
+ self.role_sent: bool = False
119
+ self.filter_tags = get_config("grok.filter_tags", [])
120
+ self.image_format = get_config("app.image_format", "url")
121
+
122
+ if think is None:
123
+ self.show_think = get_config("grok.thinking", False)
124
+ else:
125
+ self.show_think = think
126
+
127
+ async def process(self, response: AsyncIterable[bytes]) -> AsyncGenerator[str, None]:
128
+ """处理流式响应"""
129
+ try:
130
+ async for line in response:
131
+ if not line:
132
+ continue
133
+ try:
134
+ data = orjson.loads(line)
135
+ except orjson.JSONDecodeError:
136
+ continue
137
+
138
+ resp = data.get("result", {}).get("response", {})
139
+
140
+ # 元数据
141
+ if (llm := resp.get("llmInfo")) and not self.fingerprint:
142
+ self.fingerprint = llm.get("modelHash", "")
143
+ if rid := resp.get("responseId"):
144
+ self.response_id = rid
145
+
146
+ # 首次发送 role
147
+ if not self.role_sent:
148
+ yield self._sse(role="assistant")
149
+ self.role_sent = True
150
+
151
+ # 图像生成进度
152
+ if img := resp.get("streamingImageGenerationResponse"):
153
+ if self.show_think:
154
+ if not self.think_opened:
155
+ yield self._sse("<think>\n")
156
+ self.think_opened = True
157
+ idx = img.get('imageIndex', 0) + 1
158
+ progress = img.get('progress', 0)
159
+ yield self._sse(f"正在生成第{idx}张图片中,当前进度{progress}%\n")
160
+ continue
161
+
162
+ # modelResponse
163
+ if mr := resp.get("modelResponse"):
164
+ if self.think_opened and self.show_think:
165
+ if msg := mr.get("message"):
166
+ yield self._sse(msg + "\n")
167
+ yield self._sse("</think>\n")
168
+ self.think_opened = False
169
+
170
+ # 处理生成的图片
171
+ for url in mr.get("generatedImageUrls", []):
172
+ parts = url.split("/")
173
+ img_id = parts[-2] if len(parts) >= 2 else "image"
174
+
175
+ if self.image_format == "base64":
176
+ dl_service = self._get_dl()
177
+ base64_data = await dl_service.to_base64(url, self.token, "image")
178
+ if base64_data:
179
+ yield self._sse(f"![{img_id}]({base64_data})\n")
180
+ else:
181
+ final_url = await self.process_url(url, "image")
182
+ yield self._sse(f"![{img_id}]({final_url})\n")
183
+ else:
184
+ final_url = await self.process_url(url, "image")
185
+ yield self._sse(f"![{img_id}]({final_url})\n")
186
+
187
+ if (meta := mr.get("metadata", {})).get("llm_info", {}).get("modelHash"):
188
+ self.fingerprint = meta["llm_info"]["modelHash"]
189
+ continue
190
+
191
+ # 普通 token
192
+ if (token := resp.get("token")) is not None:
193
+ if token and not (self.filter_tags and any(t in token for t in self.filter_tags)):
194
+ yield self._sse(token)
195
+
196
+ if self.think_opened:
197
+ yield self._sse("</think>\n")
198
+ yield self._sse(finish="stop")
199
+ yield "data: [DONE]\n\n"
200
+ except Exception as e:
201
+ logger.error(f"Stream processing error: {e}", extra={"model": self.model})
202
+ raise
203
+ finally:
204
+ await self.close()
205
+
206
+
207
+ class CollectProcessor(BaseProcessor):
208
+ """非流式响应处理器"""
209
+
210
+ def __init__(self, model: str, token: str = ""):
211
+ super().__init__(model, token)
212
+ self.image_format = get_config("app.image_format", "url")
213
+
214
+ async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]:
215
+ """处理并收集完整响应"""
216
+ response_id = ""
217
+ fingerprint = ""
218
+ content = ""
219
+
220
+ try:
221
+ async for line in response:
222
+ if not line:
223
+ continue
224
+ try:
225
+ data = orjson.loads(line)
226
+ except orjson.JSONDecodeError:
227
+ continue
228
+
229
+ resp = data.get("result", {}).get("response", {})
230
+
231
+ if (llm := resp.get("llmInfo")) and not fingerprint:
232
+ fingerprint = llm.get("modelHash", "")
233
+
234
+ if mr := resp.get("modelResponse"):
235
+ response_id = mr.get("responseId", "")
236
+ content = mr.get("message", "")
237
+
238
+ if urls := mr.get("generatedImageUrls"):
239
+ content += "\n"
240
+ for url in urls:
241
+ parts = url.split("/")
242
+ img_id = parts[-2] if len(parts) >= 2 else "image"
243
+
244
+ if self.image_format == "base64":
245
+ dl_service = self._get_dl()
246
+ base64_data = await dl_service.to_base64(url, self.token, "image")
247
+ if base64_data:
248
+ content += f"![{img_id}]({base64_data})\n"
249
+ else:
250
+ final_url = await self.process_url(url, "image")
251
+ content += f"![{img_id}]({final_url})\n"
252
+ else:
253
+ final_url = await self.process_url(url, "image")
254
+ content += f"![{img_id}]({final_url})\n"
255
+
256
+ if (meta := mr.get("metadata", {})).get("llm_info", {}).get("modelHash"):
257
+ fingerprint = meta["llm_info"]["modelHash"]
258
+
259
+ except Exception as e:
260
+ logger.error(f"Collect processing error: {e}", extra={"model": self.model})
261
+ finally:
262
+ await self.close()
263
+
264
+ return {
265
+ "id": response_id,
266
+ "object": "chat.completion",
267
+ "created": self.created,
268
+ "model": self.model,
269
+ "system_fingerprint": fingerprint,
270
+ "choices": [{
271
+ "index": 0,
272
+ "message": {"role": "assistant", "content": content, "refusal": None, "annotations": []},
273
+ "finish_reason": "stop"
274
+ }],
275
+ "usage": {
276
+ "prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0,
277
+ "prompt_tokens_details": {"cached_tokens": 0, "text_tokens": 0, "audio_tokens": 0, "image_tokens": 0},
278
+ "completion_tokens_details": {"text_tokens": 0, "audio_tokens": 0, "reasoning_tokens": 0}
279
+ }
280
+ }
281
+
282
+
283
+ class VideoStreamProcessor(BaseProcessor):
284
+ """视频流式响应处理器"""
285
+
286
+ def __init__(self, model: str, token: str = "", think: bool = None):
287
+ super().__init__(model, token)
288
+ self.response_id: Optional[str] = None
289
+ self.think_opened: bool = False
290
+ self.role_sent: bool = False
291
+ self.video_format = get_config("app.video_format", "url")
292
+
293
+ if think is None:
294
+ self.show_think = get_config("grok.thinking", False)
295
+ else:
296
+ self.show_think = think
297
+
298
+ def _build_video_html(self, video_url: str, thumbnail_url: str = "") -> str:
299
+ """构建视频 HTML 标签"""
300
+ if get_config("grok.video_poster_preview", False):
301
+ return _build_video_poster_preview(video_url, thumbnail_url)
302
+ poster_attr = f' poster="{thumbnail_url}"' if thumbnail_url else ""
303
+ return f'''<video id="video" controls="" preload="none"{poster_attr}>
304
+ <source id="mp4" src="{video_url}" type="video/mp4">
305
+ </video>'''
306
+
307
+ async def process(self, response: AsyncIterable[bytes]) -> AsyncGenerator[str, None]:
308
+ """处理视频流式响应"""
309
+ try:
310
+ async for line in response:
311
+ if not line:
312
+ continue
313
+ try:
314
+ data = orjson.loads(line)
315
+ except orjson.JSONDecodeError:
316
+ continue
317
+
318
+ resp = data.get("result", {}).get("response", {})
319
+
320
+ if rid := resp.get("responseId"):
321
+ self.response_id = rid
322
+
323
+ # 首次发送 role
324
+ if not self.role_sent:
325
+ yield self._sse(role="assistant")
326
+ self.role_sent = True
327
+
328
+ # 视频生成进度
329
+ if video_resp := resp.get("streamingVideoGenerationResponse"):
330
+ progress = video_resp.get("progress", 0)
331
+
332
+ if self.show_think:
333
+ if not self.think_opened:
334
+ yield self._sse("<think>\n")
335
+ self.think_opened = True
336
+ yield self._sse(f"正在生成视频中,当前进度{progress}%\n")
337
+
338
+ if progress == 100:
339
+ video_url = video_resp.get("videoUrl", "")
340
+ thumbnail_url = video_resp.get("thumbnailImageUrl", "")
341
+
342
+ if self.think_opened and self.show_think:
343
+ yield self._sse("</think>\n")
344
+ self.think_opened = False
345
+
346
+ if video_url:
347
+ final_video_url = await self.process_url(video_url, "video")
348
+ final_thumbnail_url = ""
349
+ if thumbnail_url:
350
+ final_thumbnail_url = await self.process_url(thumbnail_url, "image")
351
+
352
+ video_html = self._build_video_html(final_video_url, final_thumbnail_url)
353
+ yield self._sse(video_html)
354
+
355
+ logger.info(f"Video generated: {video_url}")
356
+ continue
357
+
358
+ if self.think_opened:
359
+ yield self._sse("</think>\n")
360
+ yield self._sse(finish="stop")
361
+ yield "data: [DONE]\n\n"
362
+ except Exception as e:
363
+ logger.error(f"Video stream processing error: {e}", extra={"model": self.model})
364
+ finally:
365
+ await self.close()
366
+
367
+
368
+ class VideoCollectProcessor(BaseProcessor):
369
+ """视频非流式响应处理器"""
370
+
371
+ def __init__(self, model: str, token: str = ""):
372
+ super().__init__(model, token)
373
+ self.video_format = get_config("app.video_format", "url")
374
+
375
+ def _build_video_html(self, video_url: str, thumbnail_url: str = "") -> str:
376
+ if get_config("grok.video_poster_preview", False):
377
+ return _build_video_poster_preview(video_url, thumbnail_url)
378
+ poster_attr = f' poster="{thumbnail_url}"' if thumbnail_url else ""
379
+ return f'''<video id="video" controls="" preload="none"{poster_attr}>
380
+ <source id="mp4" src="{video_url}" type="video/mp4">
381
+ </video>'''
382
+
383
+ async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]:
384
+ """处理并收集视频响应"""
385
+ response_id = ""
386
+ content = ""
387
+
388
+ try:
389
+ async for line in response:
390
+ if not line:
391
+ continue
392
+ try:
393
+ data = orjson.loads(line)
394
+ except orjson.JSONDecodeError:
395
+ continue
396
+
397
+ resp = data.get("result", {}).get("response", {})
398
+
399
+ if video_resp := resp.get("streamingVideoGenerationResponse"):
400
+ if video_resp.get("progress") == 100:
401
+ response_id = resp.get("responseId", "")
402
+ video_url = video_resp.get("videoUrl", "")
403
+ thumbnail_url = video_resp.get("thumbnailImageUrl", "")
404
+
405
+ if video_url:
406
+ final_video_url = await self.process_url(video_url, "video")
407
+ final_thumbnail_url = ""
408
+ if thumbnail_url:
409
+ final_thumbnail_url = await self.process_url(thumbnail_url, "image")
410
+
411
+ content = self._build_video_html(final_video_url, final_thumbnail_url)
412
+ logger.info(f"Video generated: {video_url}")
413
+
414
+ except Exception as e:
415
+ logger.error(f"Video collect processing error: {e}", extra={"model": self.model})
416
+ finally:
417
+ await self.close()
418
+
419
+ return {
420
+ "id": response_id,
421
+ "object": "chat.completion",
422
+ "created": self.created,
423
+ "model": self.model,
424
+ "choices": [{
425
+ "index": 0,
426
+ "message": {"role": "assistant", "content": content, "refusal": None},
427
+ "finish_reason": "stop"
428
+ }],
429
+ "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
430
+ }
431
+
432
+
433
+ class ImageStreamProcessor(BaseProcessor):
434
+ """图片生成流式响应处理器"""
435
+
436
+ def __init__(
437
+ self,
438
+ model: str,
439
+ token: str = "",
440
+ n: int = 1,
441
+ response_format: str = "b64_json",
442
+ ):
443
+ super().__init__(model, token)
444
+ self.partial_index = 0
445
+ self.n = n
446
+ self.target_index = random.randint(0, 1) if n == 1 else None
447
+ self.response_format = (response_format or "b64_json").lower()
448
+ if self.response_format == "url":
449
+ self.response_field = "url"
450
+ elif self.response_format == "base64":
451
+ self.response_field = "base64"
452
+ else:
453
+ self.response_field = "b64_json"
454
+
455
+ def _sse(self, event: str, data: dict) -> str:
456
+ """构建 SSE 响应 (覆盖基类)"""
457
+ return f"event: {event}\ndata: {orjson.dumps(data).decode()}\n\n"
458
+
459
+ async def process(self, response: AsyncIterable[bytes]) -> AsyncGenerator[str, None]:
460
+ """处理流式响应"""
461
+ final_images = []
462
+
463
+ try:
464
+ async for line in response:
465
+ if not line:
466
+ continue
467
+ try:
468
+ data = orjson.loads(line)
469
+ except orjson.JSONDecodeError:
470
+ continue
471
+
472
+ resp = data.get("result", {}).get("response", {})
473
+
474
+ # 图片生成进度
475
+ if img := resp.get("streamingImageGenerationResponse"):
476
+ image_index = img.get("imageIndex", 0)
477
+ progress = img.get("progress", 0)
478
+
479
+ if self.n == 1 and image_index != self.target_index:
480
+ continue
481
+
482
+ out_index = 0 if self.n == 1 else image_index
483
+
484
+ yield self._sse("image_generation.partial_image", {
485
+ "type": "image_generation.partial_image",
486
+ self.response_field: "",
487
+ "index": out_index,
488
+ "progress": progress
489
+ })
490
+ continue
491
+
492
+ # modelResponse
493
+ if mr := resp.get("modelResponse"):
494
+ if urls := mr.get("generatedImageUrls"):
495
+ for url in urls:
496
+ if self.response_format == "url":
497
+ processed = await self.process_url(url, "image")
498
+ if processed:
499
+ final_images.append(processed)
500
+ continue
501
+ dl_service = self._get_dl()
502
+ base64_data = await dl_service.to_base64(url, self.token, "image")
503
+ if base64_data:
504
+ if "," in base64_data:
505
+ b64 = base64_data.split(",", 1)[1]
506
+ else:
507
+ b64 = base64_data
508
+ final_images.append(b64)
509
+ continue
510
+
511
+ for index, b64 in enumerate(final_images):
512
+ if self.n == 1:
513
+ if index != self.target_index:
514
+ continue
515
+ out_index = 0
516
+ else:
517
+ out_index = index
518
+
519
+ yield self._sse("image_generation.completed", {
520
+ "type": "image_generation.completed",
521
+ self.response_field: b64,
522
+ "index": out_index,
523
+ "usage": {
524
+ "total_tokens": 50,
525
+ "input_tokens": 25,
526
+ "output_tokens": 25,
527
+ "input_tokens_details": {"text_tokens": 5, "image_tokens": 20}
528
+ }
529
+ })
530
+ except Exception as e:
531
+ logger.error(f"Image stream processing error: {e}")
532
+ raise
533
+ finally:
534
+ await self.close()
535
+
536
+
537
+ class ImageCollectProcessor(BaseProcessor):
538
+ """图片生成非流式响应处理器"""
539
+
540
+ def __init__(
541
+ self,
542
+ model: str,
543
+ token: str = "",
544
+ response_format: str = "b64_json",
545
+ ):
546
+ super().__init__(model, token)
547
+ self.response_format = (response_format or "b64_json").lower()
548
+
549
+ async def process(self, response: AsyncIterable[bytes]) -> List[str]:
550
+ """处理并收集图片"""
551
+ images = []
552
+
553
+ try:
554
+ async for line in response:
555
+ if not line:
556
+ continue
557
+ try:
558
+ data = orjson.loads(line)
559
+ except orjson.JSONDecodeError:
560
+ continue
561
+
562
+ resp = data.get("result", {}).get("response", {})
563
+
564
+ if mr := resp.get("modelResponse"):
565
+ if urls := mr.get("generatedImageUrls"):
566
+ for url in urls:
567
+ if self.response_format == "url":
568
+ processed = await self.process_url(url, "image")
569
+ if processed:
570
+ images.append(processed)
571
+ continue
572
+ dl_service = self._get_dl()
573
+ base64_data = await dl_service.to_base64(url, self.token, "image")
574
+ if base64_data:
575
+ if "," in base64_data:
576
+ b64 = base64_data.split(",", 1)[1]
577
+ else:
578
+ b64 = base64_data
579
+ images.append(b64)
580
+
581
+ except Exception as e:
582
+ logger.error(f"Image collect processing error: {e}")
583
+ finally:
584
+ await self.close()
585
+
586
+ return images
587
+
588
+
589
+ __all__ = [
590
+ "StreamProcessor",
591
+ "CollectProcessor",
592
+ "VideoStreamProcessor",
593
+ "VideoCollectProcessor",
594
+ "ImageStreamProcessor",
595
+ "ImageCollectProcessor",
596
+ ]
app/services/grok/retry.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Grok API 重试工具
3
+
4
+ 提供可配置的重试机制,支持:
5
+ - 可配置的重试次数
6
+ - 可配置的重试状态码
7
+ - 仅记录最后一次失败
8
+ """
9
+
10
+ import asyncio
11
+ from typing import Callable, Any, Optional, List
12
+ from functools import wraps
13
+
14
+ from app.core.logger import logger
15
+ from app.core.config import get_config
16
+ from app.core.exceptions import UpstreamException
17
+
18
+
19
+ class RetryConfig:
20
+ """重试配置"""
21
+
22
+ @staticmethod
23
+ def get_max_retry() -> int:
24
+ """获取最大重试次数"""
25
+ return get_config("grok.max_retry", 1)
26
+
27
+ @staticmethod
28
+ def get_retry_codes() -> List[int]:
29
+ """获取可重试的状态码"""
30
+ return get_config("grok.retry_status_codes", [401, 429, 403])
31
+
32
+
33
+ class RetryContext:
34
+ """重试上下文"""
35
+
36
+ def __init__(self):
37
+ self.attempt = 0
38
+ self.max_retry = RetryConfig.get_max_retry()
39
+ self.retry_codes = RetryConfig.get_retry_codes()
40
+ self.last_error = None
41
+ self.last_status = None
42
+
43
+ def should_retry(self, status_code: int) -> bool:
44
+ """判断是否重试"""
45
+ return (
46
+ self.attempt < self.max_retry and
47
+ status_code in self.retry_codes
48
+ )
49
+
50
+ def record_error(self, status_code: int, error: Exception):
51
+ """记录错误信息"""
52
+ self.last_status = status_code
53
+ self.last_error = error
54
+ self.attempt += 1
55
+
56
+
57
+ async def retry_on_status(
58
+ func: Callable,
59
+ *args,
60
+ extract_status: Callable[[Exception], Optional[int]] = None,
61
+ on_retry: Callable[[int, int, Exception], None] = None,
62
+ **kwargs
63
+ ) -> Any:
64
+ """
65
+ 通用重试函数
66
+
67
+ Args:
68
+ func: 重试的异步函数
69
+ *args: 函数参数
70
+ extract_status: 异常提取状态码的函数
71
+ on_retry: 重试时的回调函数
72
+ **kwargs: 函数关键字参数
73
+
74
+ Returns:
75
+ 函数执行结果
76
+
77
+ Raises:
78
+ 最后一次失败的异常
79
+ """
80
+ ctx = RetryContext()
81
+
82
+ # 状态码提取器
83
+ if extract_status is None:
84
+ def extract_status(e: Exception) -> Optional[int]:
85
+ if isinstance(e, UpstreamException):
86
+ return e.details.get("status") if e.details else None
87
+ return None
88
+
89
+ while ctx.attempt <= ctx.max_retry:
90
+ try:
91
+ result = await func(*args, **kwargs)
92
+
93
+ # 记录日志
94
+ if ctx.attempt > 0:
95
+ logger.info(
96
+ f"Retry succeeded after {ctx.attempt} attempts"
97
+ )
98
+
99
+ return result
100
+
101
+ except Exception as e:
102
+ # 提取状态码
103
+ status_code = extract_status(e)
104
+
105
+ if status_code is None:
106
+ # 错误无法识别
107
+ logger.error(f"Non-retryable error: {e}")
108
+ raise
109
+
110
+ # 记录错误
111
+ ctx.record_error(status_code, e)
112
+
113
+ # 判断是否重试
114
+ if ctx.should_retry(status_code):
115
+ delay = 0.5 * (ctx.attempt + 1) # 渐进延迟
116
+ logger.warning(
117
+ f"Retry {ctx.attempt}/{ctx.max_retry} for status {status_code}, "
118
+ f"waiting {delay}s"
119
+ )
120
+
121
+ # 回调
122
+ if on_retry:
123
+ on_retry(ctx.attempt, status_code, e)
124
+
125
+ await asyncio.sleep(delay)
126
+ continue
127
+ else:
128
+ # 不可重试或重试次数耗尽
129
+ if status_code in ctx.retry_codes:
130
+ # 打印当前尝试次数(包括最后一次)
131
+ logger.warning(
132
+ f"Retry {ctx.attempt}/{ctx.max_retry} for status {status_code}, failed"
133
+ )
134
+ logger.error(
135
+ f"Retry exhausted after {ctx.max_retry} attempts, "
136
+ f"last status: {status_code}"
137
+ )
138
+ else:
139
+ logger.error(
140
+ f"Non-retryable status code: {status_code}"
141
+ )
142
+
143
+ # 抛出最后一次的错误
144
+ raise
145
+
146
+
147
+ def with_retry(
148
+ extract_status: Callable[[Exception], Optional[int]] = None,
149
+ on_retry: Callable[[int, int, Exception], None] = None
150
+ ):
151
+ """
152
+ 重试装饰器
153
+
154
+ Usage:
155
+ @with_retry()
156
+ async def my_api_call():
157
+ ...
158
+ """
159
+ def decorator(func: Callable):
160
+ @wraps(func)
161
+ async def wrapper(*args, **kwargs):
162
+ return await retry_on_status(
163
+ func,
164
+ *args,
165
+ extract_status=extract_status,
166
+ on_retry=on_retry,
167
+ **kwargs
168
+ )
169
+ return wrapper
170
+ return decorator
171
+
172
+
173
+ __all__ = [
174
+ "RetryConfig",
175
+ "RetryContext",
176
+ "retry_on_status",
177
+ "with_retry",
178
+ ]
app/services/grok/statsig.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Statsig ID 生成服务
3
+ """
4
+
5
+ import base64
6
+ import random
7
+ import string
8
+
9
+ from app.core.config import get_config
10
+
11
+
12
+ class StatsigService:
13
+ """Statsig ID 生成服务"""
14
+
15
+ @staticmethod
16
+ def _rand(length: int, alphanumeric: bool = False) -> str:
17
+ """生成随机字符串"""
18
+ chars = string.ascii_lowercase + string.digits if alphanumeric else string.ascii_lowercase
19
+ return "".join(random.choices(chars, k=length))
20
+
21
+ @staticmethod
22
+ def gen_id() -> str:
23
+ """
24
+ 生成 Statsig ID
25
+
26
+ Returns:
27
+ Base64 编码的 ID
28
+ """
29
+ # 读取配置
30
+ dynamic = get_config("grok.dynamic_statsig", True)
31
+
32
+ if not dynamic:
33
+ return "ZTpUeXBlRXJyb3I6IENhbm5vdCByZWFkIHByb3BlcnRpZXMgb2YgdW5kZWZpbmVkIChyZWFkaW5nICdjaGlsZE5vZGVzJyk="
34
+
35
+ # 随机格式
36
+ if random.choice([True, False]):
37
+ rand = StatsigService._rand(5, alphanumeric=True)
38
+ message = f"e:TypeError: Cannot read properties of null (reading 'children['{rand}']')"
39
+ else:
40
+ rand = StatsigService._rand(10)
41
+ message = f"e:TypeError: Cannot read properties of undefined (reading '{rand}')"
42
+
43
+ return base64.b64encode(message.encode()).decode()
44
+
45
+
46
+ __all__ = ["StatsigService"]
app/services/grok/usage.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Grok 用量服务
3
+ """
4
+
5
+ import asyncio
6
+ import uuid
7
+ from typing import Dict
8
+
9
+ import orjson
10
+ from curl_cffi.requests import AsyncSession
11
+
12
+ from app.core.logger import logger
13
+ from app.core.config import get_config
14
+ from app.core.exceptions import UpstreamException, AppException
15
+ from app.services.grok.statsig import StatsigService
16
+ from app.services.grok.retry import retry_on_status
17
+
18
+
19
+ LIMITS_API = "https://grok.com/rest/rate-limits"
20
+ BROWSER = "chrome136"
21
+ TIMEOUT = 10
22
+ DEFAULT_MAX_CONCURRENT = 25
23
+ _USAGE_SEMAPHORE = asyncio.Semaphore(DEFAULT_MAX_CONCURRENT)
24
+ _USAGE_SEM_VALUE = DEFAULT_MAX_CONCURRENT
25
+
26
+ def _get_usage_semaphore() -> asyncio.Semaphore:
27
+ global _USAGE_SEMAPHORE, _USAGE_SEM_VALUE
28
+ value = get_config("performance.usage_max_concurrent", DEFAULT_MAX_CONCURRENT)
29
+ try:
30
+ value = int(value)
31
+ except Exception:
32
+ value = DEFAULT_MAX_CONCURRENT
33
+ value = max(1, value)
34
+ if value != _USAGE_SEM_VALUE:
35
+ _USAGE_SEM_VALUE = value
36
+ _USAGE_SEMAPHORE = asyncio.Semaphore(value)
37
+ return _USAGE_SEMAPHORE
38
+
39
+
40
+ class UsageService:
41
+ """用量查询服务"""
42
+
43
+ def __init__(self, proxy: str = None):
44
+ self.proxy = proxy or get_config("grok.base_proxy_url", "")
45
+ self.timeout = get_config("grok.timeout", TIMEOUT)
46
+
47
+ def _build_headers(self, token: str) -> dict:
48
+ """构建请求头"""
49
+ headers = {
50
+ "Accept": "*/*",
51
+ "Accept-Encoding": "gzip, deflate, br, zstd",
52
+ "Accept-Language": "zh-CN,zh;q=0.9",
53
+ "Baggage": "sentry-environment=production,sentry-release=d6add6fb0460641fd482d767a335ef72b9b6abb8,sentry-public_key=b311e0f2690c81f25e2c4cf6d4f7ce1c",
54
+ "Cache-Control": "no-cache",
55
+ "Content-Type": "application/json",
56
+ "Origin": "https://grok.com",
57
+ "Pragma": "no-cache",
58
+ "Priority": "u=1, i",
59
+ "Referer": "https://grok.com/",
60
+ "Sec-Ch-Ua": '"Google Chrome";v="136", "Chromium";v="136", "Not(A:Brand";v="24"',
61
+ "Sec-Ch-Ua-Arch": "arm",
62
+ "Sec-Ch-Ua-Bitness": "64",
63
+ "Sec-Ch-Ua-Mobile": "?0",
64
+ "Sec-Ch-Ua-Model": "",
65
+ "Sec-Ch-Ua-Platform": '"macOS"',
66
+ "Sec-Fetch-Dest": "empty",
67
+ "Sec-Fetch-Mode": "cors",
68
+ "Sec-Fetch-Site": "same-origin",
69
+ "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36",
70
+ }
71
+
72
+ # Statsig ID
73
+ headers["x-statsig-id"] = StatsigService.gen_id()
74
+ headers["x-xai-request-id"] = str(uuid.uuid4())
75
+
76
+ # Cookie
77
+ token = token[4:] if token.startswith("sso=") else token
78
+ cf = get_config("grok.cf_clearance", "")
79
+ headers["Cookie"] = f"sso={token};cf_clearance={cf}" if cf else f"sso={token}"
80
+
81
+ return headers
82
+
83
+ def _build_proxies(self) -> dict:
84
+ """构建代理配置"""
85
+ return {"http": self.proxy, "https": self.proxy} if self.proxy else None
86
+
87
+ async def get(self, token: str, model_name: str = "grok-4-1-thinking-1129") -> Dict:
88
+ """
89
+ 获取速率限制信息
90
+
91
+ Args:
92
+ token: 认证 Token
93
+ model_name: 模型名称
94
+
95
+ Returns:
96
+ 响应数据
97
+
98
+ Raises:
99
+ UpstreamException: 当获取失败且重试耗尽时
100
+ """
101
+ async with _get_usage_semaphore():
102
+ # 定义状态码提取器
103
+ def extract_status(e: Exception) -> int | None:
104
+ if isinstance(e, UpstreamException) and e.details:
105
+ return e.details.get("status")
106
+ return None
107
+
108
+ # 定义实际的请求函数
109
+ async def do_request():
110
+ try:
111
+ headers = self._build_headers(token)
112
+ payload = {
113
+ "requestKind": "DEFAULT",
114
+ "modelName": model_name
115
+ }
116
+
117
+ async with AsyncSession() as session:
118
+ response = await session.post(
119
+ LIMITS_API,
120
+ headers=headers,
121
+ json=payload,
122
+ impersonate=BROWSER,
123
+ timeout=self.timeout,
124
+ proxies=self._build_proxies()
125
+ )
126
+
127
+ if response.status_code == 200:
128
+ data = response.json()
129
+ remaining = data.get('remainingTokens', 0)
130
+ logger.info(f"Usage: quota {remaining} remaining")
131
+ return data
132
+
133
+ logger.error(f"Usage failed: {response.status_code}")
134
+
135
+ raise UpstreamException(
136
+ message=f"Failed to get usage stats: {response.status_code}",
137
+ details={"status": response.status_code}
138
+ )
139
+
140
+ except Exception as e:
141
+ if isinstance(e, UpstreamException):
142
+ raise
143
+ logger.error(f"Usage error: {e}")
144
+ raise UpstreamException(
145
+ message=f"Usage service error: {str(e)}",
146
+ details={"error": str(e)}
147
+ )
148
+
149
+ # 带重试的执行
150
+ try:
151
+ result = await retry_on_status(
152
+ do_request,
153
+ extract_status=extract_status
154
+ )
155
+ return result
156
+
157
+ except Exception as e:
158
+ # 最后一次失败已经被记录
159
+ raise
160
+
161
+
162
+ __all__ = ["UsageService"]
app/services/quota.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ API Key daily quota enforcement (local/docker runtime)
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ from typing import Optional, Dict
8
+
9
+ from app.core.config import get_config
10
+ from app.core.exceptions import AppException, ErrorType
11
+ from app.services.api_keys import api_key_manager
12
+ from app.services.grok.model import ModelService
13
+
14
+
15
+ async def enforce_daily_quota(
16
+ api_key: Optional[str],
17
+ model: str,
18
+ *,
19
+ image_count: Optional[int] = None,
20
+ ) -> None:
21
+ """
22
+ Enforce per-day quotas for a non-admin API key.
23
+
24
+ - chat/heavy/video: count by request (1)
25
+ - image: count by generated images
26
+ - chat endpoint + image model: charge 2 images per request
27
+ - image endpoint: charge `image_count` (n)
28
+ - heavy: consumes both heavy + chat buckets
29
+ """
30
+
31
+ token = str(api_key or "").strip()
32
+ if not token:
33
+ return
34
+
35
+ global_key = str(get_config("app.api_key", "") or "").strip()
36
+ if global_key and token == global_key:
37
+ return
38
+
39
+ model_info = ModelService.get(model)
40
+ incs: Dict[str, int] = {}
41
+ bucket_name = "chat"
42
+
43
+ if model == "grok-4-heavy":
44
+ incs = {"heavy_used": 1, "chat_used": 1}
45
+ bucket_name = "heavy/chat"
46
+ elif model_info and model_info.is_video:
47
+ incs = {"video_used": 1}
48
+ bucket_name = "video"
49
+ elif model_info and model_info.is_image:
50
+ # grok image model via chat endpoint: upstream usually returns up to 2 images
51
+ incs = {"image_used": max(1, int(image_count or 2))}
52
+ bucket_name = "image"
53
+ else:
54
+ incs = {"chat_used": 1}
55
+ bucket_name = "chat"
56
+
57
+ ok = await api_key_manager.consume_daily_usage(token, incs)
58
+ if ok:
59
+ return
60
+
61
+ raise AppException(
62
+ message=f"Daily quota exceeded: {bucket_name}",
63
+ error_type=ErrorType.RATE_LIMIT.value,
64
+ code="daily_quota_exceeded",
65
+ status_code=429,
66
+ )
67
+
68
+
69
+ __all__ = ["enforce_daily_quota"]
70
+
app/services/register/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Auto registration services."""
2
+
3
+ from app.services.register.manager import get_auto_register_manager, AutoRegisterManager
4
+
5
+ __all__ = ["AutoRegisterManager", "get_auto_register_manager"]
app/services/register/account_settings_refresh.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ from typing import Iterable, Any
5
+
6
+ from app.core.config import get_config
7
+ from app.core.logger import logger
8
+ from app.services.register.services import (
9
+ UserAgreementService,
10
+ BirthDateService,
11
+ NsfwSettingsService,
12
+ )
13
+ from app.services.token.manager import TokenManager, get_token_manager
14
+
15
+
16
+ DEFAULT_NSFW_REFRESH_CONCURRENCY = 10
17
+ DEFAULT_NSFW_REFRESH_RETRIES = 3
18
+ DEFAULT_IMPERSONATE = "chrome120"
19
+
20
+
21
+ def _extract_cookie_value(cookie_str: str, name: str) -> str | None:
22
+ needle = f"{name}="
23
+ if needle not in cookie_str:
24
+ return None
25
+ for part in cookie_str.split(";"):
26
+ part = part.strip()
27
+ if part.startswith(needle):
28
+ value = part[len(needle):].strip()
29
+ return value or None
30
+ return None
31
+
32
+
33
+ def parse_sso_pair(raw_token: str) -> tuple[str, str]:
34
+ raw = str(raw_token or "").strip()
35
+ if not raw:
36
+ return "", ""
37
+
38
+ if ";" in raw:
39
+ sso = _extract_cookie_value(raw, "sso") or ""
40
+ sso_rw = _extract_cookie_value(raw, "sso-rw") or sso
41
+ return sso.strip(), sso_rw.strip()
42
+
43
+ sso = raw[4:].strip() if raw.startswith("sso=") else raw
44
+ sso_rw = sso
45
+ return sso, sso_rw
46
+
47
+
48
+ def normalize_sso_token(raw_token: str) -> str:
49
+ sso, _ = parse_sso_pair(raw_token)
50
+ return sso
51
+
52
+
53
+ def _coerce_concurrency(value: Any, default: int = DEFAULT_NSFW_REFRESH_CONCURRENCY) -> int:
54
+ try:
55
+ n = int(value)
56
+ except Exception:
57
+ n = default
58
+ return max(1, n)
59
+
60
+
61
+ def _coerce_retries(value: Any, default: int = DEFAULT_NSFW_REFRESH_RETRIES) -> int:
62
+ try:
63
+ n = int(value)
64
+ except Exception:
65
+ n = default
66
+ return max(0, n)
67
+
68
+
69
+ def _format_step_error(result: dict, fallback: str = "unknown error") -> str:
70
+ if not isinstance(result, dict):
71
+ return fallback
72
+
73
+ text = str(result.get("error") or "").strip()
74
+ if text:
75
+ return text
76
+
77
+ status_code = result.get("status_code")
78
+ if status_code is not None:
79
+ return f"HTTP {status_code}"
80
+
81
+ grpc_status = result.get("grpc_status")
82
+ if grpc_status is not None:
83
+ return f"gRPC {grpc_status}"
84
+
85
+ response_text = str(result.get("response_text") or "").strip()
86
+ if response_text:
87
+ return response_text
88
+
89
+ return fallback
90
+
91
+
92
+ class AccountSettingsRefreshService:
93
+ def __init__(self, token_manager: TokenManager, cf_clearance: str = "") -> None:
94
+ self.token_manager = token_manager
95
+ self.cf_clearance = (cf_clearance or "").strip()
96
+
97
+ def _apply_once(self, raw_token: str) -> tuple[bool, str, str]:
98
+ sso, sso_rw = parse_sso_pair(raw_token)
99
+ if not sso:
100
+ return False, "parse", "missing sso"
101
+ if not sso_rw:
102
+ sso_rw = sso
103
+
104
+ user_service = UserAgreementService(cf_clearance=self.cf_clearance)
105
+ birth_service = BirthDateService(cf_clearance=self.cf_clearance)
106
+ nsfw_service = NsfwSettingsService(cf_clearance=self.cf_clearance)
107
+
108
+ tos_result = user_service.accept_tos_version(
109
+ sso=sso,
110
+ sso_rw=sso_rw,
111
+ impersonate=DEFAULT_IMPERSONATE,
112
+ )
113
+ if not tos_result.get("ok"):
114
+ return False, "tos", _format_step_error(tos_result, "accept_tos failed")
115
+
116
+ birth_result = birth_service.set_birth_date(
117
+ sso=sso,
118
+ sso_rw=sso_rw,
119
+ impersonate=DEFAULT_IMPERSONATE,
120
+ )
121
+ if not birth_result.get("ok"):
122
+ return False, "birth", _format_step_error(birth_result, "set_birth_date failed")
123
+
124
+ nsfw_result = nsfw_service.enable_nsfw(
125
+ sso=sso,
126
+ sso_rw=sso_rw,
127
+ impersonate=DEFAULT_IMPERSONATE,
128
+ )
129
+ if not nsfw_result.get("ok"):
130
+ return False, "nsfw", _format_step_error(nsfw_result, "enable_nsfw failed")
131
+
132
+ return True, "", ""
133
+
134
+ async def refresh_tokens(
135
+ self,
136
+ tokens: Iterable[str],
137
+ concurrency: int = DEFAULT_NSFW_REFRESH_CONCURRENCY,
138
+ retries: int = DEFAULT_NSFW_REFRESH_RETRIES,
139
+ ) -> dict[str, Any]:
140
+ resolved_concurrency = _coerce_concurrency(concurrency)
141
+ resolved_retries = _coerce_retries(retries)
142
+
143
+ unique_tokens: list[str] = []
144
+ seen: set[str] = set()
145
+ for token in tokens:
146
+ normalized = normalize_sso_token(str(token or "").strip())
147
+ if not normalized or normalized in seen:
148
+ continue
149
+ seen.add(normalized)
150
+ unique_tokens.append(normalized)
151
+
152
+ if not unique_tokens:
153
+ return {
154
+ "summary": {"total": 0, "success": 0, "failed": 0, "invalidated": 0},
155
+ "failed": [],
156
+ }
157
+
158
+ semaphore = asyncio.Semaphore(resolved_concurrency)
159
+
160
+ async def _run_one(token: str) -> dict[str, Any]:
161
+ max_attempts = resolved_retries + 1
162
+ last_step = "unknown"
163
+ last_error = "unknown error"
164
+
165
+ async with semaphore:
166
+ for attempt in range(1, max_attempts + 1):
167
+ try:
168
+ ok, step, error = await asyncio.to_thread(self._apply_once, token)
169
+ except Exception as exc:
170
+ ok, step, error = False, "exception", str(exc)
171
+
172
+ if ok:
173
+ updated = await self.token_manager.mark_token_account_settings_success(
174
+ token,
175
+ save=False,
176
+ )
177
+ if not updated:
178
+ logger.warning(
179
+ "Account settings refresh succeeded but token not found: {}...",
180
+ token[:10],
181
+ )
182
+ return {
183
+ "token": token,
184
+ "ok": True,
185
+ "attempts": attempt,
186
+ }
187
+
188
+ last_step = step or "unknown"
189
+ last_error = error or "unknown error"
190
+
191
+ reason = (
192
+ f"account_settings_refresh_failed step={last_step} "
193
+ f"attempts={max_attempts} error={last_error}"
194
+ )
195
+ invalidated = await self.token_manager.set_token_invalid(
196
+ token,
197
+ reason=reason,
198
+ save=False,
199
+ )
200
+ return {
201
+ "token": token,
202
+ "ok": False,
203
+ "attempts": max_attempts,
204
+ "step": last_step,
205
+ "error": last_error,
206
+ "invalidated": bool(invalidated),
207
+ }
208
+
209
+ results = await asyncio.gather(*[_run_one(token) for token in unique_tokens])
210
+
211
+ try:
212
+ await self.token_manager.commit()
213
+ except Exception as exc:
214
+ logger.warning("Account settings refresh commit failed: {}", exc)
215
+
216
+ success = sum(1 for item in results if item.get("ok"))
217
+ failed_items = [item for item in results if not item.get("ok")]
218
+ invalidated = sum(1 for item in failed_items if item.get("invalidated"))
219
+
220
+ summary = {
221
+ "total": len(unique_tokens),
222
+ "success": success,
223
+ "failed": len(failed_items),
224
+ "invalidated": invalidated,
225
+ }
226
+
227
+ return {"summary": summary, "failed": failed_items}
228
+
229
+
230
+ async def refresh_account_settings_for_tokens(
231
+ tokens: Iterable[str],
232
+ concurrency: int | None = None,
233
+ retries: int | None = None,
234
+ ) -> dict[str, Any]:
235
+ resolved_concurrency = _coerce_concurrency(
236
+ concurrency if concurrency is not None else get_config(
237
+ "token.nsfw_refresh_concurrency",
238
+ DEFAULT_NSFW_REFRESH_CONCURRENCY,
239
+ ),
240
+ default=DEFAULT_NSFW_REFRESH_CONCURRENCY,
241
+ )
242
+ resolved_retries = _coerce_retries(
243
+ retries if retries is not None else get_config(
244
+ "token.nsfw_refresh_retries",
245
+ DEFAULT_NSFW_REFRESH_RETRIES,
246
+ ),
247
+ default=DEFAULT_NSFW_REFRESH_RETRIES,
248
+ )
249
+
250
+ token_manager = await get_token_manager()
251
+ cf_clearance = str(get_config("grok.cf_clearance", "") or "").strip()
252
+ service = AccountSettingsRefreshService(token_manager, cf_clearance=cf_clearance)
253
+ return await service.refresh_tokens(
254
+ tokens=tokens,
255
+ concurrency=resolved_concurrency,
256
+ retries=resolved_retries,
257
+ )
258
+
259
+
260
+ __all__ = [
261
+ "AccountSettingsRefreshService",
262
+ "parse_sso_pair",
263
+ "normalize_sso_token",
264
+ "refresh_account_settings_for_tokens",
265
+ "DEFAULT_NSFW_REFRESH_CONCURRENCY",
266
+ "DEFAULT_NSFW_REFRESH_RETRIES",
267
+ ]
app/services/register/manager.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Auto registration manager."""
2
+ from __future__ import annotations
3
+
4
+ import asyncio
5
+ import queue
6
+ import threading
7
+ import time
8
+ import uuid
9
+ from dataclasses import dataclass, field
10
+ from typing import Dict, List, Optional
11
+
12
+ from app.core.config import get_config
13
+ from app.core.logger import logger
14
+ from app.services.token.manager import get_token_manager
15
+ from app.services.register.runner import RegisterRunner
16
+ from app.services.register.solver import SolverConfig, TurnstileSolverProcess
17
+
18
+
19
+ @dataclass
20
+ class RegisterJob:
21
+ job_id: str
22
+ total: int
23
+ pool: str
24
+ register_threads: int = 10
25
+ status: str = "starting"
26
+ started_at: float = field(default_factory=time.time)
27
+ finished_at: Optional[float] = None
28
+ completed: int = 0
29
+ added: int = 0
30
+ errors: int = 0
31
+ error: Optional[str] = None
32
+ last_error: Optional[str] = None
33
+ tokens: List[str] = field(default_factory=list)
34
+ _lock: threading.Lock = field(default_factory=threading.Lock, repr=False)
35
+ stop_event: threading.Event = field(default_factory=threading.Event, repr=False)
36
+
37
+ def record_success(self, token: str) -> None:
38
+ with self._lock:
39
+ self.completed += 1
40
+ self.tokens.append(token)
41
+
42
+ def record_added(self) -> None:
43
+ with self._lock:
44
+ self.added += 1
45
+
46
+ def record_error(self, message: str) -> None:
47
+ message = (message or "").strip()
48
+ if len(message) > 500:
49
+ message = message[:500] + "..."
50
+ with self._lock:
51
+ self.errors += 1
52
+ if message:
53
+ self.last_error = message
54
+
55
+ def to_dict(self) -> Dict[str, object]:
56
+ with self._lock:
57
+ return {
58
+ "job_id": self.job_id,
59
+ "status": self.status,
60
+ "pool": self.pool,
61
+ "total": self.total,
62
+ "concurrency": self.register_threads,
63
+ "completed": self.completed,
64
+ "added": self.added,
65
+ "errors": self.errors,
66
+ "error": self.error,
67
+ "last_error": self.last_error,
68
+ "started_at": int(self.started_at * 1000),
69
+ "finished_at": int(self.finished_at * 1000) if self.finished_at else None,
70
+ }
71
+
72
+
73
+ class AutoRegisterManager:
74
+ """Single job manager for auto registration."""
75
+
76
+ _instance: Optional["AutoRegisterManager"] = None
77
+
78
+ def __init__(self) -> None:
79
+ self._lock = asyncio.Lock()
80
+ self._job: Optional[RegisterJob] = None
81
+ self._task: Optional[asyncio.Task] = None
82
+ self._solver: Optional[TurnstileSolverProcess] = None
83
+
84
+ async def start_job(
85
+ self,
86
+ count: int,
87
+ pool: str,
88
+ concurrency: Optional[int] = None,
89
+ ) -> RegisterJob:
90
+ async with self._lock:
91
+ if self._job and self._job.status in {"starting", "running", "stopping"}:
92
+ raise RuntimeError("Auto registration already running")
93
+
94
+ default_threads = get_config("register.register_threads", 10)
95
+ try:
96
+ default_threads = max(1, int(default_threads))
97
+ except Exception:
98
+ default_threads = 10
99
+
100
+ threads = concurrency if isinstance(concurrency, int) and concurrency > 0 else default_threads
101
+
102
+ job = RegisterJob(
103
+ job_id=uuid.uuid4().hex[:8],
104
+ total=count,
105
+ pool=pool,
106
+ register_threads=threads,
107
+ )
108
+ self._job = job
109
+ self._task = asyncio.create_task(self._run_job(job))
110
+ return job
111
+
112
+ def get_status(self, job_id: Optional[str] = None) -> Dict[str, object]:
113
+ if not self._job:
114
+ return {"status": "idle"}
115
+ if job_id and self._job.job_id != job_id:
116
+ return {"status": "not_found"}
117
+ return self._job.to_dict()
118
+
119
+ async def stop_job(self) -> None:
120
+ """Best-effort stop for the current job (used on shutdown)."""
121
+ async with self._lock:
122
+ job = self._job
123
+ task = self._task
124
+ solver = self._solver
125
+
126
+ if not job or job.status not in {"starting", "running"}:
127
+ return
128
+ job.status = "stopping"
129
+ job.stop_event.set()
130
+
131
+ # Stop solver first to avoid noisy retries.
132
+ if solver:
133
+ try:
134
+ await asyncio.to_thread(solver.stop)
135
+ except Exception:
136
+ pass
137
+
138
+ # Give the runner a short grace period to exit.
139
+ if task:
140
+ try:
141
+ await asyncio.wait_for(task, timeout=5.0)
142
+ except Exception:
143
+ # Don't block shutdown; the process is exiting anyway.
144
+ pass
145
+
146
+ async def _run_job(self, job: RegisterJob) -> None:
147
+ job.status = "starting"
148
+
149
+ solver_url = get_config("register.solver_url", "http://127.0.0.1:5072")
150
+ solver_threads = get_config("register.solver_threads", 5)
151
+ try:
152
+ solver_threads = max(1, int(solver_threads))
153
+ except Exception:
154
+ solver_threads = 5
155
+
156
+ auto_start_solver = get_config("register.auto_start_solver", True)
157
+ if not isinstance(auto_start_solver, bool):
158
+ auto_start_solver = str(auto_start_solver).lower() in {"1", "true", "yes", "on"}
159
+
160
+ # Auto-start only for local solver endpoints.
161
+ try:
162
+ from urllib.parse import urlparse
163
+
164
+ host = urlparse(str(solver_url)).hostname or ""
165
+ if host and host not in {"127.0.0.1", "localhost", "::1", "0.0.0.0"}:
166
+ auto_start_solver = False
167
+ except Exception:
168
+ pass
169
+
170
+ solver_debug = get_config("register.solver_debug", False)
171
+ if not isinstance(solver_debug, bool):
172
+ solver_debug = str(solver_debug).lower() in {"1", "true", "yes", "on"}
173
+
174
+ browser_type = str(get_config("register.solver_browser_type", "chromium") or "chromium").strip().lower()
175
+ if browser_type not in {"chromium", "chrome", "msedge", "camoufox"}:
176
+ browser_type = "chromium"
177
+
178
+ solver_cfg = SolverConfig(
179
+ url=str(solver_url or "http://127.0.0.1:5072"),
180
+ threads=solver_threads,
181
+ browser_type=browser_type,
182
+ debug=solver_debug,
183
+ auto_start=auto_start_solver,
184
+ )
185
+ solver = TurnstileSolverProcess(solver_cfg)
186
+ self._solver = solver
187
+
188
+ use_yescaptcha = bool(str(get_config("register.yescaptcha_key", "") or "").strip())
189
+ if use_yescaptcha:
190
+ # When YesCaptcha is configured we don't need a local solver process.
191
+ auto_start_solver = False
192
+ solver.config.auto_start = False
193
+
194
+ # Safety limits to avoid endless loops when upstream is broken.
195
+ max_errors = get_config("register.max_errors", 0)
196
+ try:
197
+ max_errors = int(max_errors)
198
+ except Exception:
199
+ max_errors = 0
200
+ if max_errors <= 0:
201
+ # Default: allow retries, but stop instead of looping "forever".
202
+ max_errors = max(30, int(job.total) * 5)
203
+
204
+ max_runtime_minutes = get_config("register.max_runtime_minutes", 0)
205
+ try:
206
+ max_runtime_minutes = float(max_runtime_minutes)
207
+ except Exception:
208
+ max_runtime_minutes = 0
209
+ max_runtime_sec = max_runtime_minutes * 60 if max_runtime_minutes and max_runtime_minutes > 0 else 0
210
+
211
+ token_queue: queue.Queue[object] = queue.Queue()
212
+ sentinel = object()
213
+
214
+ async def _consume_tokens() -> None:
215
+ mgr = await get_token_manager()
216
+ while True:
217
+ item = await asyncio.to_thread(token_queue.get)
218
+ if item is sentinel:
219
+ break
220
+ token = str(item or "").strip()
221
+ if not token:
222
+ continue
223
+ try:
224
+ if await mgr.add(token, pool_name=job.pool):
225
+ job.record_added()
226
+ except Exception as exc:
227
+ job.record_error(f"save token failed: {exc}")
228
+
229
+ def _on_error(msg: str) -> None:
230
+ job.record_error(msg)
231
+ # Called from worker threads; keep it simple and thread-safe.
232
+ with job._lock:
233
+ if job.status in {"starting", "running"} and job.errors >= max_errors:
234
+ job.status = "error"
235
+ job.error = f"Too many failures ({job.errors}/{max_errors}). Check register config/solver."
236
+ job.stop_event.set()
237
+
238
+ async def _watchdog() -> None:
239
+ if not max_runtime_sec:
240
+ return
241
+ while True:
242
+ await asyncio.sleep(1.0)
243
+ if job.stop_event.is_set():
244
+ return
245
+ if job.status not in {"starting", "running"}:
246
+ return
247
+ if (time.time() - job.started_at) >= max_runtime_sec:
248
+ with job._lock:
249
+ if job.status in {"starting", "running"}:
250
+ job.status = "error"
251
+ job.error = f"Timeout after {max_runtime_minutes:g} minutes."
252
+ job.stop_event.set()
253
+ return
254
+
255
+ try:
256
+ if auto_start_solver:
257
+ try:
258
+ await asyncio.to_thread(solver.start)
259
+ except Exception as exc:
260
+ if not use_yescaptcha:
261
+ raise
262
+ logger.warning("Solver start failed, continuing with YesCaptcha: {}", exc)
263
+
264
+ job.status = "running"
265
+ watchdog_task = asyncio.create_task(_watchdog())
266
+ consumer_task = asyncio.create_task(_consume_tokens())
267
+ runner = RegisterRunner(
268
+ target_count=job.total,
269
+ thread_count=job.register_threads,
270
+ stop_event=job.stop_event,
271
+ on_success=lambda _email, _password, token, _done, _total: (
272
+ job.record_success(token),
273
+ token_queue.put(token),
274
+ ),
275
+ on_error=_on_error,
276
+ )
277
+
278
+ await asyncio.to_thread(runner.run)
279
+
280
+ # Drain token consumer.
281
+ token_queue.put(sentinel)
282
+ await consumer_task
283
+ if job.status == "stopping":
284
+ job.status = "stopped"
285
+ elif job.status != "error":
286
+ # If we returned without reaching the target, treat it as a failure.
287
+ # This makes issues like "TOS/BirthDate/NSFW not enabled" visible to the UI as a failed job.
288
+ if job.completed < job.total:
289
+ job.status = "error"
290
+ suffix = f" Last error: {job.last_error}" if job.last_error else ""
291
+ job.error = f"Registration ended early ({job.completed}/{job.total}).{suffix}".strip()
292
+ else:
293
+ job.status = "completed"
294
+ except Exception as exc:
295
+ job.status = "error"
296
+ job.error = str(exc)
297
+ logger.exception("Auto registration failed")
298
+ finally:
299
+ job.finished_at = time.time()
300
+ # Ensure consumer exits even on exceptions.
301
+ try:
302
+ token_queue.put(sentinel)
303
+ except Exception:
304
+ pass
305
+ try:
306
+ if "consumer_task" in locals():
307
+ await asyncio.wait_for(consumer_task, timeout=10)
308
+ except Exception:
309
+ try:
310
+ consumer_task.cancel()
311
+ except Exception:
312
+ pass
313
+ try:
314
+ if "watchdog_task" in locals():
315
+ watchdog_task.cancel()
316
+ except Exception:
317
+ pass
318
+ self._solver = None
319
+ if auto_start_solver:
320
+ try:
321
+ await asyncio.to_thread(solver.stop)
322
+ except Exception:
323
+ pass
324
+
325
+
326
+ def get_auto_register_manager() -> AutoRegisterManager:
327
+ if AutoRegisterManager._instance is None:
328
+ AutoRegisterManager._instance = AutoRegisterManager()
329
+ return AutoRegisterManager._instance
330
+
331
+
332
+ __all__ = ["AutoRegisterManager", "get_auto_register_manager"]
app/services/register/runner.py ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Grok account registration runner."""
2
+ from __future__ import annotations
3
+
4
+ import concurrent.futures
5
+ import random
6
+ import re
7
+ import string
8
+ import struct
9
+ import threading
10
+ import time
11
+ from typing import Callable, Dict, List, Optional, Tuple
12
+ from urllib.parse import urljoin
13
+
14
+ from bs4 import BeautifulSoup
15
+ from curl_cffi import requests as curl_requests
16
+
17
+ from app.core.logger import logger
18
+ from app.services.register.services import (
19
+ EmailService,
20
+ TurnstileService,
21
+ UserAgreementService,
22
+ BirthDateService,
23
+ NsfwSettingsService,
24
+ )
25
+
26
+
27
+ SITE_URL = "https://accounts.x.ai"
28
+ DEFAULT_IMPERSONATE = "chrome120"
29
+
30
+ CHROME_PROFILES = [
31
+ {"impersonate": "chrome110", "version": "110.0.0.0", "brand": "chrome"},
32
+ {"impersonate": "chrome119", "version": "119.0.0.0", "brand": "chrome"},
33
+ {"impersonate": "chrome120", "version": "120.0.0.0", "brand": "chrome"},
34
+ {"impersonate": "edge99", "version": "99.0.1150.36", "brand": "edge"},
35
+ {"impersonate": "edge101", "version": "101.0.1210.47", "brand": "edge"},
36
+ ]
37
+
38
+
39
+ def _random_chrome_profile() -> Tuple[str, str]:
40
+ profile = random.choice(CHROME_PROFILES)
41
+ if profile.get("brand") == "edge":
42
+ chrome_major = profile["version"].split(".")[0]
43
+ chrome_version = f"{chrome_major}.0.0.0"
44
+ ua = (
45
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
46
+ "AppleWebKit/537.36 (KHTML, like Gecko) "
47
+ f"Chrome/{chrome_version} Safari/537.36 Edg/{profile['version']}"
48
+ )
49
+ else:
50
+ ua = (
51
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
52
+ "AppleWebKit/537.36 (KHTML, like Gecko) "
53
+ f"Chrome/{profile['version']} Safari/537.36"
54
+ )
55
+ return profile["impersonate"], ua
56
+
57
+
58
+ def _generate_random_name() -> str:
59
+ length = random.randint(4, 6)
60
+ return random.choice(string.ascii_uppercase) + "".join(
61
+ random.choice(string.ascii_lowercase) for _ in range(length - 1)
62
+ )
63
+
64
+
65
+ def _generate_random_string(length: int = 15) -> str:
66
+ return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(length))
67
+
68
+
69
+ def _encode_grpc_message(field_id: int, string_value: str) -> bytes:
70
+ key = (field_id << 3) | 2
71
+ value_bytes = string_value.encode("utf-8")
72
+ payload = struct.pack("B", key) + struct.pack("B", len(value_bytes)) + value_bytes
73
+ return b"\x00" + struct.pack(">I", len(payload)) + payload
74
+
75
+
76
+ def _encode_grpc_message_verify(email: str, code: str) -> bytes:
77
+ p1 = struct.pack("B", (1 << 3) | 2) + struct.pack("B", len(email)) + email.encode("utf-8")
78
+ p2 = struct.pack("B", (2 << 3) | 2) + struct.pack("B", len(code)) + code.encode("utf-8")
79
+ payload = p1 + p2
80
+ return b"\x00" + struct.pack(">I", len(payload)) + payload
81
+
82
+
83
+ class RegisterRunner:
84
+ """Threaded registration runner."""
85
+
86
+ def __init__(
87
+ self,
88
+ target_count: int = 100,
89
+ thread_count: int = 8,
90
+ on_success: Optional[Callable[[str, str, str, int, int], None]] = None,
91
+ on_error: Optional[Callable[[str], None]] = None,
92
+ stop_event: Optional[threading.Event] = None,
93
+ ) -> None:
94
+ self.target_count = max(1, int(target_count))
95
+ self.thread_count = max(1, int(thread_count))
96
+ self.on_success = on_success
97
+ self.on_error = on_error
98
+ self.stop_event = stop_event or threading.Event()
99
+
100
+ self._post_lock = threading.Lock()
101
+ self._result_lock = threading.Lock()
102
+
103
+ self._success_count = 0
104
+ self._start_time = 0.0
105
+ self._tokens: List[str] = []
106
+ self._accounts: List[Dict[str, str]] = []
107
+
108
+ self._config: Dict[str, Optional[str]] = {
109
+ "site_key": "0x4AAAAAAAhr9JGVDZbrZOo0",
110
+ "action_id": None,
111
+ "state_tree": "%5B%22%22%2C%7B%22children%22%3A%5B%22(app)%22%2C%7B%22children%22%3A%5B%22(auth)%22%2C%7B%22children%22%3A%5B%22sign-up%22%2C%7B%22children%22%3A%5B%22__PAGE__%22%2C%7B%7D%2C%22%2Fsign-up%22%2C%22refresh%22%5D%7D%5D%7D%2Cnull%2Cnull%5D%7D%2Cnull%2Cnull%5D%7D%2Cnull%2Cnull%2Ctrue%5D",
112
+ }
113
+
114
+ @property
115
+ def success_count(self) -> int:
116
+ return self._success_count
117
+
118
+ @property
119
+ def tokens(self) -> List[str]:
120
+ return list(self._tokens)
121
+
122
+ @property
123
+ def accounts(self) -> List[Dict[str, str]]:
124
+ return list(self._accounts)
125
+
126
+ def _record_success(self, email: str, password: str, token: str) -> None:
127
+ with self._result_lock:
128
+ if self._success_count >= self.target_count:
129
+ if not self.stop_event.is_set():
130
+ self.stop_event.set()
131
+ return
132
+
133
+ self._success_count += 1
134
+ self._tokens.append(token)
135
+ self._accounts.append({"email": email, "password": password, "token": token})
136
+
137
+ avg = (time.time() - self._start_time) / max(1, self._success_count)
138
+ logger.info(
139
+ "Register success: {} | sso={}... | avg={:.1f}s ({}/{})",
140
+ email,
141
+ token[:12],
142
+ avg,
143
+ self._success_count,
144
+ self.target_count,
145
+ )
146
+
147
+ if self.on_success:
148
+ try:
149
+ self.on_success(email, password, token, self._success_count, self.target_count)
150
+ except Exception:
151
+ pass
152
+
153
+ if self._success_count >= self.target_count and not self.stop_event.is_set():
154
+ self.stop_event.set()
155
+
156
+ def _record_error(self, message: str) -> None:
157
+ if self.on_error:
158
+ try:
159
+ self.on_error(message)
160
+ except Exception:
161
+ pass
162
+
163
+ def _init_config(self) -> None:
164
+ logger.info("Register: initializing action config...")
165
+ start_url = f"{SITE_URL}/sign-up"
166
+
167
+ with curl_requests.Session(impersonate=DEFAULT_IMPERSONATE) as session:
168
+ html = session.get(start_url, timeout=15).text
169
+
170
+ key_match = re.search(r'sitekey":"(0x4[a-zA-Z0-9_-]+)"', html)
171
+ if key_match:
172
+ self._config["site_key"] = key_match.group(1)
173
+
174
+ tree_match = re.search(r'next-router-state-tree":"([^"]+)"', html)
175
+ if tree_match:
176
+ self._config["state_tree"] = tree_match.group(1)
177
+
178
+ soup = BeautifulSoup(html, "html.parser")
179
+ js_urls = [
180
+ urljoin(start_url, script["src"])
181
+ for script in soup.find_all("script", src=True)
182
+ if "_next/static" in script["src"]
183
+ ]
184
+ for js_url in js_urls:
185
+ js_content = session.get(js_url, timeout=15).text
186
+ match = re.search(r"7f[a-fA-F0-9]{40}", js_content)
187
+ if match:
188
+ self._config["action_id"] = match.group(0)
189
+ logger.info("Register: Action ID found: {}", self._config["action_id"])
190
+ break
191
+
192
+ if not self._config.get("action_id"):
193
+ raise RuntimeError("Register init failed: missing action_id")
194
+
195
+ def _send_email_code(self, session: curl_requests.Session, email: str) -> bool:
196
+ url = f"{SITE_URL}/auth_mgmt.AuthManagement/CreateEmailValidationCode"
197
+ data = _encode_grpc_message(1, email)
198
+ headers = {
199
+ "content-type": "application/grpc-web+proto",
200
+ "x-grpc-web": "1",
201
+ "x-user-agent": "connect-es/2.1.1",
202
+ "origin": SITE_URL,
203
+ "referer": f"{SITE_URL}/sign-up?redirect=grok-com",
204
+ }
205
+ try:
206
+ res = session.post(url, data=data, headers=headers, timeout=15)
207
+ return res.status_code == 200
208
+ except Exception as exc:
209
+ self._record_error(f"send code error: {email} - {exc}")
210
+ return False
211
+
212
+ def _verify_email_code(self, session: curl_requests.Session, email: str, code: str) -> bool:
213
+ url = f"{SITE_URL}/auth_mgmt.AuthManagement/VerifyEmailValidationCode"
214
+ data = _encode_grpc_message_verify(email, code)
215
+ headers = {
216
+ "content-type": "application/grpc-web+proto",
217
+ "x-grpc-web": "1",
218
+ "x-user-agent": "connect-es/2.1.1",
219
+ "origin": SITE_URL,
220
+ "referer": f"{SITE_URL}/sign-up?redirect=grok-com",
221
+ }
222
+ try:
223
+ res = session.post(url, data=data, headers=headers, timeout=15)
224
+ return res.status_code == 200
225
+ except Exception as exc:
226
+ self._record_error(f"verify code error: {email} - {exc}")
227
+ return False
228
+
229
+ def _register_single_thread(self) -> None:
230
+ time.sleep(random.uniform(0, 5))
231
+
232
+ try:
233
+ email_service = EmailService()
234
+ turnstile_service = TurnstileService()
235
+ user_agreement_service = UserAgreementService()
236
+ birth_date_service = BirthDateService()
237
+ nsfw_service = NsfwSettingsService()
238
+ except Exception as exc:
239
+ self._record_error(f"service init failed: {exc}")
240
+ return
241
+
242
+ final_action_id = self._config.get("action_id")
243
+ if not final_action_id:
244
+ self._record_error("missing action id")
245
+ return
246
+
247
+ while not self.stop_event.is_set():
248
+ try:
249
+ impersonate_fingerprint, account_user_agent = _random_chrome_profile()
250
+
251
+ with curl_requests.Session(impersonate=impersonate_fingerprint) as session:
252
+ try:
253
+ session.get(SITE_URL, timeout=10)
254
+ except Exception:
255
+ pass
256
+
257
+ password = _generate_random_string()
258
+
259
+ jwt, email = email_service.create_email()
260
+ if not email:
261
+ self._record_error("create_email failed")
262
+ time.sleep(5)
263
+ continue
264
+
265
+ if self.stop_event.is_set():
266
+ return
267
+
268
+ if not self._send_email_code(session, email):
269
+ self._record_error(f"send_email_code failed: {email}")
270
+ time.sleep(5)
271
+ continue
272
+
273
+ verify_code = None
274
+ for _ in range(30):
275
+ time.sleep(1)
276
+ if self.stop_event.is_set():
277
+ return
278
+ content = email_service.fetch_first_email(jwt)
279
+ if content:
280
+ match = re.search(r">([A-Z0-9]{3}-[A-Z0-9]{3})<", content)
281
+ if match:
282
+ verify_code = match.group(1).replace("-", "")
283
+ break
284
+
285
+ if not verify_code:
286
+ self._record_error(f"verify_code not received: {email}")
287
+ time.sleep(3)
288
+ continue
289
+
290
+ if not self._verify_email_code(session, email, verify_code):
291
+ self._record_error(f"verify_email_code failed: {email}")
292
+ time.sleep(3)
293
+ continue
294
+
295
+ for _ in range(3):
296
+ if self.stop_event.is_set():
297
+ return
298
+
299
+ try:
300
+ task_id = turnstile_service.create_task(f"{SITE_URL}/sign-up", self._config["site_key"] or "")
301
+ except Exception as exc:
302
+ self._record_error(f"turnstile create_task failed: {exc}")
303
+ time.sleep(2)
304
+ continue
305
+
306
+ token = turnstile_service.get_response(task_id, stop_event=self.stop_event)
307
+
308
+ if not token:
309
+ self._record_error(f"turnstile failed: {turnstile_service.last_error or 'no token'}")
310
+ time.sleep(2)
311
+ continue
312
+
313
+ headers = {
314
+ "user-agent": account_user_agent,
315
+ "accept": "text/x-component",
316
+ "content-type": "text/plain;charset=UTF-8",
317
+ "origin": SITE_URL,
318
+ "referer": f"{SITE_URL}/sign-up",
319
+ "cookie": f"__cf_bm={session.cookies.get('__cf_bm','')}",
320
+ "next-router-state-tree": self._config["state_tree"] or "",
321
+ "next-action": final_action_id,
322
+ }
323
+ payload = [
324
+ {
325
+ "emailValidationCode": verify_code,
326
+ "createUserAndSessionRequest": {
327
+ "email": email,
328
+ "givenName": _generate_random_name(),
329
+ "familyName": _generate_random_name(),
330
+ "clearTextPassword": password,
331
+ "tosAcceptedVersion": "$undefined",
332
+ },
333
+ "turnstileToken": token,
334
+ "promptOnDuplicateEmail": True,
335
+ }
336
+ ]
337
+
338
+ with self._post_lock:
339
+ res = session.post(
340
+ f"{SITE_URL}/sign-up",
341
+ json=payload,
342
+ headers=headers,
343
+ timeout=20,
344
+ )
345
+
346
+ if res.status_code != 200:
347
+ self._record_error(f"sign_up http {res.status_code}")
348
+ time.sleep(3)
349
+ continue
350
+
351
+ match = re.search(r'(https://[^" \s]+set-cookie\?q=[^:" \s]+)1:', res.text)
352
+ if not match:
353
+ self._record_error("sign_up missing set-cookie redirect")
354
+ break
355
+
356
+ verify_url = match.group(1)
357
+ session.get(verify_url, allow_redirects=True, timeout=15)
358
+
359
+ sso = session.cookies.get("sso")
360
+ sso_rw = session.cookies.get("sso-rw")
361
+ if not sso:
362
+ self._record_error("sign_up missing sso cookie")
363
+ break
364
+
365
+ tos_result = user_agreement_service.accept_tos_version(
366
+ sso=sso,
367
+ sso_rw=sso_rw or "",
368
+ impersonate=impersonate_fingerprint,
369
+ user_agent=account_user_agent,
370
+ )
371
+ if not tos_result.get("ok") or not tos_result.get("hex_reply"):
372
+ self._record_error(f"accept_tos failed: {tos_result.get('error') or 'unknown'}")
373
+ break
374
+
375
+ birth_result = birth_date_service.set_birth_date(
376
+ sso=sso,
377
+ sso_rw=sso_rw or "",
378
+ impersonate=impersonate_fingerprint,
379
+ user_agent=account_user_agent,
380
+ )
381
+ if not birth_result.get("ok"):
382
+ self._record_error(
383
+ f"set_birth_date failed: {birth_result.get('error') or 'unknown'}"
384
+ )
385
+ break
386
+
387
+ nsfw_result = nsfw_service.enable_nsfw(
388
+ sso=sso,
389
+ sso_rw=sso_rw or "",
390
+ impersonate=impersonate_fingerprint,
391
+ user_agent=account_user_agent,
392
+ )
393
+ if not nsfw_result.get("ok") or not nsfw_result.get("hex_reply"):
394
+ self._record_error(f"enable_nsfw failed: {nsfw_result.get('error') or 'unknown'}")
395
+ break
396
+
397
+ self._record_success(email, password, sso)
398
+ break
399
+
400
+ except Exception as exc:
401
+ self._record_error(f"thread error: {str(exc)[:80]}")
402
+ time.sleep(3)
403
+
404
+ def run(self) -> List[str]:
405
+ """Run the registration process and return collected tokens."""
406
+ self._init_config()
407
+ self._start_time = time.time()
408
+
409
+ logger.info("Register: starting {} threads, target {}", self.thread_count, self.target_count)
410
+
411
+ with concurrent.futures.ThreadPoolExecutor(max_workers=self.thread_count) as executor:
412
+ futures = [executor.submit(self._register_single_thread) for _ in range(self.thread_count)]
413
+ concurrent.futures.wait(futures)
414
+
415
+ return list(self._tokens)
app/services/register/services/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Registration helper services."""
2
+
3
+ from app.services.register.services.email_service import EmailService
4
+ from app.services.register.services.turnstile_service import TurnstileService
5
+ from app.services.register.services.user_agreement_service import UserAgreementService
6
+ from app.services.register.services.birth_date_service import BirthDateService
7
+ from app.services.register.services.nsfw_service import NsfwSettingsService
8
+
9
+ __all__ = [
10
+ "EmailService",
11
+ "TurnstileService",
12
+ "UserAgreementService",
13
+ "BirthDateService",
14
+ "NsfwSettingsService",
15
+ ]
app/services/register/services/birth_date_service.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import datetime
4
+ import random
5
+ from typing import Any, Dict, Optional
6
+
7
+ from curl_cffi import requests
8
+
9
+ DEFAULT_USER_AGENT = (
10
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
11
+ "AppleWebKit/537.36 (KHTML, like Gecko) "
12
+ "Chrome/120.0.0.0 Safari/537.36"
13
+ )
14
+
15
+
16
+ def generate_random_birthdate() -> str:
17
+ """Generate a random birth date between 20 and 40 years old."""
18
+ today = datetime.date.today()
19
+ age = random.randint(20, 40)
20
+ birth_year = today.year - age
21
+ birth_month = random.randint(1, 12)
22
+ birth_day = random.randint(1, 28)
23
+ return f"{birth_year}-{birth_month:02d}-{birth_day:02d}T16:00:00.000Z"
24
+
25
+
26
+ class BirthDateService:
27
+ """Set account birth date via Grok REST API."""
28
+
29
+ def __init__(self, cf_clearance: str = ""):
30
+ self.cf_clearance = (cf_clearance or "").strip()
31
+
32
+ def set_birth_date(
33
+ self,
34
+ sso: str,
35
+ sso_rw: str,
36
+ impersonate: str,
37
+ user_agent: Optional[str] = None,
38
+ cf_clearance: Optional[str] = None,
39
+ timeout: int = 15,
40
+ ) -> Dict[str, Any]:
41
+ if not sso:
42
+ return {
43
+ "ok": False,
44
+ "status_code": None,
45
+ "response_text": "",
46
+ "error": "missing sso",
47
+ }
48
+ if not sso_rw:
49
+ return {
50
+ "ok": False,
51
+ "status_code": None,
52
+ "response_text": "",
53
+ "error": "missing sso-rw",
54
+ }
55
+
56
+ url = "https://grok.com/rest/auth/set-birth-date"
57
+ cookies = {
58
+ "sso": sso,
59
+ "sso-rw": sso_rw,
60
+ }
61
+ clearance = (cf_clearance if cf_clearance is not None else self.cf_clearance).strip()
62
+ if clearance:
63
+ cookies["cf_clearance"] = clearance
64
+
65
+ headers = {
66
+ "content-type": "application/json",
67
+ "origin": "https://grok.com",
68
+ "referer": "https://grok.com/",
69
+ "user-agent": user_agent or DEFAULT_USER_AGENT,
70
+ }
71
+ payload = {"birthDate": generate_random_birthdate()}
72
+
73
+ try:
74
+ response = requests.post(
75
+ url,
76
+ headers=headers,
77
+ cookies=cookies,
78
+ json=payload,
79
+ impersonate=impersonate or "chrome120",
80
+ timeout=timeout,
81
+ )
82
+ status_code = response.status_code
83
+ response_text = response.text or ""
84
+ ok = status_code == 200
85
+ return {
86
+ "ok": ok,
87
+ "status_code": status_code,
88
+ "response_text": response_text,
89
+ "error": None if ok else f"HTTP {status_code}",
90
+ }
91
+ except Exception as e:
92
+ return {
93
+ "ok": False,
94
+ "status_code": None,
95
+ "response_text": "",
96
+ "error": str(e),
97
+ }
app/services/register/services/email_service.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Email service for temporary inbox creation."""
2
+ from __future__ import annotations
3
+
4
+ import os
5
+ import random
6
+ import string
7
+ from typing import Tuple, Optional
8
+
9
+ import requests
10
+
11
+ from app.core.config import get_config
12
+
13
+
14
+ class EmailService:
15
+ """Email service wrapper."""
16
+
17
+ def __init__(
18
+ self,
19
+ worker_domain: Optional[str] = None,
20
+ email_domain: Optional[str] = None,
21
+ admin_password: Optional[str] = None,
22
+ ) -> None:
23
+ self.worker_domain = (
24
+ (worker_domain or get_config("register.worker_domain", "") or os.getenv("WORKER_DOMAIN", "")).strip()
25
+ )
26
+ self.email_domain = (
27
+ (email_domain or get_config("register.email_domain", "") or os.getenv("EMAIL_DOMAIN", "")).strip()
28
+ )
29
+ self.admin_password = (
30
+ (admin_password or get_config("register.admin_password", "") or os.getenv("ADMIN_PASSWORD", "")).strip()
31
+ )
32
+
33
+ if not all([self.worker_domain, self.email_domain, self.admin_password]):
34
+ raise ValueError(
35
+ "Missing required email settings: register.worker_domain, register.email_domain, "
36
+ "register.admin_password"
37
+ )
38
+
39
+ def _generate_random_name(self) -> str:
40
+ letters1 = "".join(random.choices(string.ascii_lowercase, k=random.randint(4, 6)))
41
+ numbers = "".join(random.choices(string.digits, k=random.randint(1, 3)))
42
+ letters2 = "".join(random.choices(string.ascii_lowercase, k=random.randint(0, 5)))
43
+ return letters1 + numbers + letters2
44
+
45
+ def create_email(self) -> Tuple[Optional[str], Optional[str]]:
46
+ """Create a temporary mailbox. Returns (jwt, address)."""
47
+ url = f"https://{self.worker_domain}/admin/new_address"
48
+ try:
49
+ random_name = self._generate_random_name()
50
+ res = requests.post(
51
+ url,
52
+ json={
53
+ "enablePrefix": True,
54
+ "name": random_name,
55
+ "domain": self.email_domain,
56
+ },
57
+ headers={
58
+ "x-admin-auth": self.admin_password,
59
+ "Content-Type": "application/json",
60
+ },
61
+ timeout=10,
62
+ )
63
+ if res.status_code == 200:
64
+ data = res.json()
65
+ return data.get("jwt"), data.get("address")
66
+ print(f"[-] Email create failed: {res.status_code} - {res.text}")
67
+ except Exception as exc: # pragma: no cover - network/remote errors
68
+ print(f"[-] Email create error ({url}): {exc}")
69
+ return None, None
70
+
71
+ def fetch_first_email(self, jwt: str) -> Optional[str]:
72
+ """Fetch the first email content for the mailbox."""
73
+ try:
74
+ res = requests.get(
75
+ f"https://{self.worker_domain}/api/mails",
76
+ params={"limit": 10, "offset": 0},
77
+ headers={
78
+ "Authorization": f"Bearer {jwt}",
79
+ "Content-Type": "application/json",
80
+ },
81
+ timeout=10,
82
+ )
83
+ if res.status_code == 200:
84
+ data = res.json()
85
+ if data.get("results"):
86
+ return data["results"][0].get("raw")
87
+ return None
88
+ except Exception as exc: # pragma: no cover - network/remote errors
89
+ print(f"Email fetch failed: {exc}")
90
+ return None
app/services/register/services/nsfw_service.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Optional, Dict, Any
4
+
5
+ from curl_cffi import requests
6
+
7
+ DEFAULT_USER_AGENT = (
8
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
9
+ "AppleWebKit/537.36 (KHTML, like Gecko) "
10
+ "Chrome/120.0.0.0 Safari/537.36"
11
+ )
12
+
13
+
14
+ class NsfwSettingsService:
15
+ """开启 NSFW 相关设置(线程安全,无全局状态)。"""
16
+
17
+ def __init__(self, cf_clearance: str = ""):
18
+ self.cf_clearance = (cf_clearance or "").strip()
19
+
20
+ def enable_nsfw(
21
+ self,
22
+ sso: str,
23
+ sso_rw: str,
24
+ impersonate: str,
25
+ user_agent: Optional[str] = None,
26
+ cf_clearance: Optional[str] = None,
27
+ timeout: int = 15,
28
+ ) -> Dict[str, Any]:
29
+ """
30
+ 启用 always_show_nsfw_content。
31
+ 返回: {
32
+ ok: bool,
33
+ hex_reply: str,
34
+ status_code: int | None,
35
+ grpc_status: str | None,
36
+ error: str | None
37
+ }
38
+ """
39
+ if not sso:
40
+ return {
41
+ "ok": False,
42
+ "hex_reply": "",
43
+ "status_code": None,
44
+ "grpc_status": None,
45
+ "error": "缺少 sso",
46
+ }
47
+ if not sso_rw:
48
+ return {
49
+ "ok": False,
50
+ "hex_reply": "",
51
+ "status_code": None,
52
+ "grpc_status": None,
53
+ "error": "缺少 sso-rw",
54
+ }
55
+
56
+ url = "https://grok.com/auth_mgmt.AuthManagement/UpdateUserFeatureControls"
57
+
58
+ cookies = {
59
+ "sso": sso,
60
+ "sso-rw": sso_rw,
61
+ }
62
+ clearance = (cf_clearance if cf_clearance is not None else self.cf_clearance).strip()
63
+ if clearance:
64
+ cookies["cf_clearance"] = clearance
65
+
66
+ headers = {
67
+ "content-type": "application/grpc-web+proto",
68
+ "origin": "https://grok.com",
69
+ "referer": "https://grok.com/?_s=data",
70
+ "x-grpc-web": "1",
71
+ "user-agent": user_agent or DEFAULT_USER_AGENT,
72
+ }
73
+
74
+ data = (
75
+ b"\x00\x00\x00\x00"
76
+ b"\x20"
77
+ b"\x0a\x02\x10\x01"
78
+ b"\x12\x1a"
79
+ b"\x0a\x18"
80
+ b"always_show_nsfw_content"
81
+ )
82
+
83
+ try:
84
+ response = requests.post(
85
+ url,
86
+ headers=headers,
87
+ cookies=cookies,
88
+ data=data,
89
+ impersonate=impersonate or "chrome120",
90
+ timeout=timeout,
91
+ )
92
+ hex_reply = response.content.hex()
93
+ grpc_status = response.headers.get("grpc-status")
94
+
95
+ error = None
96
+ ok = response.status_code == 200 and (grpc_status in (None, "0"))
97
+ if response.status_code == 403:
98
+ error = "403 Forbidden"
99
+ elif response.status_code != 200:
100
+ error = f"HTTP {response.status_code}"
101
+ elif grpc_status not in (None, "0"):
102
+ error = f"gRPC {grpc_status}"
103
+
104
+ return {
105
+ "ok": ok,
106
+ "hex_reply": hex_reply,
107
+ "status_code": response.status_code,
108
+ "grpc_status": grpc_status,
109
+ "error": error,
110
+ }
111
+ except Exception as e:
112
+ return {
113
+ "ok": False,
114
+ "hex_reply": "",
115
+ "status_code": None,
116
+ "grpc_status": None,
117
+ "error": str(e),
118
+ }
app/services/register/services/turnstile_service.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Turnstile solving service."""
2
+ from __future__ import annotations
3
+
4
+ import os
5
+ import time
6
+ from typing import Optional
7
+
8
+ from app.core.logger import logger
9
+
10
+ import requests
11
+
12
+ from app.core.config import get_config
13
+
14
+
15
+ class TurnstileService:
16
+ """Turnstile solver wrapper (local solver or YesCaptcha)."""
17
+
18
+ def __init__(
19
+ self,
20
+ solver_url: Optional[str] = None,
21
+ yescaptcha_key: Optional[str] = None,
22
+ ) -> None:
23
+ self.yescaptcha_key = (
24
+ (yescaptcha_key or get_config("register.yescaptcha_key", "") or os.getenv("YESCAPTCHA_KEY", "")).strip()
25
+ )
26
+ self.solver_url = (
27
+ solver_url
28
+ or get_config("register.solver_url", "")
29
+ or os.getenv("TURNSTILE_SOLVER_URL", "")
30
+ or "http://127.0.0.1:5072"
31
+ ).strip()
32
+ self.yescaptcha_api = "https://api.yescaptcha.com"
33
+ self.last_error: Optional[str] = None
34
+
35
+ def create_task(self, siteurl: str, sitekey: str) -> str:
36
+ """Create a Turnstile task and return task ID."""
37
+ self.last_error = None
38
+ if self.yescaptcha_key:
39
+ url = f"{self.yescaptcha_api}/createTask"
40
+ payload = {
41
+ "clientKey": self.yescaptcha_key,
42
+ "task": {
43
+ "type": "TurnstileTaskProxyless",
44
+ "websiteURL": siteurl,
45
+ "websiteKey": sitekey,
46
+ },
47
+ }
48
+ response = requests.post(url, json=payload, timeout=20)
49
+ response.raise_for_status()
50
+ data = response.json()
51
+ if data.get("errorId") != 0:
52
+ desc = data.get("errorDescription") or "unknown"
53
+ self.last_error = f"YesCaptcha createTask failed: {desc}"
54
+ raise RuntimeError(self.last_error)
55
+ return data["taskId"]
56
+
57
+ response = requests.get(
58
+ f"{self.solver_url}/turnstile",
59
+ params={"url": siteurl, "sitekey": sitekey},
60
+ timeout=20,
61
+ )
62
+ response.raise_for_status()
63
+ data = response.json()
64
+ task_id = data.get("taskId")
65
+ if not task_id:
66
+ self.last_error = data.get("errorDescription") or data.get("errorCode") or "missing taskId"
67
+ raise RuntimeError(f"Solver create task failed: {self.last_error}")
68
+ return task_id
69
+
70
+ def get_response(
71
+ self,
72
+ task_id: str,
73
+ max_retries: int = 30,
74
+ initial_delay: int = 5,
75
+ retry_delay: int = 2,
76
+ stop_event: object | None = None,
77
+ ) -> Optional[str]:
78
+ """Fetch a Turnstile solution token."""
79
+ self.last_error = None
80
+ # Make shutdown/cancel responsive.
81
+ if initial_delay > 0:
82
+ for _ in range(int(initial_delay * 10)):
83
+ if stop_event is not None and getattr(stop_event, "is_set", lambda: False)():
84
+ return None
85
+ time.sleep(0.1)
86
+
87
+ for _ in range(max_retries):
88
+ if stop_event is not None and getattr(stop_event, "is_set", lambda: False)():
89
+ return None
90
+ try:
91
+ if self.yescaptcha_key:
92
+ url = f"{self.yescaptcha_api}/getTaskResult"
93
+ payload = {"clientKey": self.yescaptcha_key, "taskId": task_id}
94
+ response = requests.post(url, json=payload, timeout=20)
95
+ response.raise_for_status()
96
+ data = response.json()
97
+
98
+ if data.get("errorId") != 0:
99
+ self.last_error = str(data.get("errorDescription") or "unknown")
100
+ logger.warning("YesCaptcha getTaskResult failed: {}", self.last_error)
101
+ return None
102
+
103
+ status = data.get("status")
104
+ if status == "ready":
105
+ token = data.get("solution", {}).get("token")
106
+ if token:
107
+ return token
108
+ self.last_error = "YesCaptcha returned empty token"
109
+ logger.warning(self.last_error)
110
+ return None
111
+ if status == "processing":
112
+ if retry_delay > 0:
113
+ for _ in range(int(retry_delay * 10)):
114
+ if stop_event is not None and getattr(stop_event, "is_set", lambda: False)():
115
+ return None
116
+ time.sleep(0.1)
117
+ continue
118
+ self.last_error = f"YesCaptcha unexpected status: {status}"
119
+ logger.warning(self.last_error)
120
+ if retry_delay > 0:
121
+ for _ in range(int(retry_delay * 10)):
122
+ if stop_event is not None and getattr(stop_event, "is_set", lambda: False)():
123
+ return None
124
+ time.sleep(0.1)
125
+ continue
126
+
127
+ response = requests.get(
128
+ f"{self.solver_url}/result",
129
+ params={"id": task_id},
130
+ timeout=20,
131
+ )
132
+ response.raise_for_status()
133
+ data = response.json()
134
+
135
+ # Solver error -> stop early (avoid polling forever on unsolvable tasks).
136
+ error_id = data.get("errorId")
137
+ if error_id is not None and error_id != 0:
138
+ self.last_error = str(data.get("errorDescription") or data.get("errorCode") or "solver error")
139
+ return None
140
+
141
+ token = data.get("solution", {}).get("token")
142
+ if token:
143
+ if token != "CAPTCHA_FAIL":
144
+ return token
145
+ self.last_error = "CAPTCHA_FAIL"
146
+ return None
147
+ if retry_delay > 0:
148
+ for _ in range(int(retry_delay * 10)):
149
+ if stop_event is not None and getattr(stop_event, "is_set", lambda: False)():
150
+ return None
151
+ time.sleep(0.1)
152
+ except Exception as exc: # pragma: no cover - network/remote errors
153
+ self.last_error = str(exc)
154
+ logger.debug("Turnstile response error: {}", exc)
155
+ if retry_delay > 0:
156
+ for _ in range(int(retry_delay * 10)):
157
+ if stop_event is not None and getattr(stop_event, "is_set", lambda: False)():
158
+ return None
159
+ time.sleep(0.1)
160
+
161
+ return None
app/services/register/services/user_agreement_service.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Optional, Dict, Any
4
+
5
+ from curl_cffi import requests
6
+
7
+ DEFAULT_USER_AGENT = (
8
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
9
+ "AppleWebKit/537.36 (KHTML, like Gecko) "
10
+ "Chrome/120.0.0.0 Safari/537.36"
11
+ )
12
+
13
+
14
+ class UserAgreementService:
15
+ """处理账号协议同意流程(线程安全,无全局状态)。"""
16
+
17
+ def __init__(self, cf_clearance: str = ""):
18
+ self.cf_clearance = (cf_clearance or "").strip()
19
+
20
+ def accept_tos_version(
21
+ self,
22
+ sso: str,
23
+ sso_rw: str,
24
+ impersonate: str,
25
+ user_agent: Optional[str] = None,
26
+ cf_clearance: Optional[str] = None,
27
+ timeout: int = 15,
28
+ ) -> Dict[str, Any]:
29
+ """
30
+ 同意 TOS 版本。
31
+ 返回: {
32
+ ok: bool,
33
+ hex_reply: str,
34
+ status_code: int | None,
35
+ grpc_status: str | None,
36
+ error: str | None
37
+ }
38
+ """
39
+ if not sso:
40
+ return {
41
+ "ok": False,
42
+ "hex_reply": "",
43
+ "status_code": None,
44
+ "grpc_status": None,
45
+ "error": "缺少 sso",
46
+ }
47
+ if not sso_rw:
48
+ return {
49
+ "ok": False,
50
+ "hex_reply": "",
51
+ "status_code": None,
52
+ "grpc_status": None,
53
+ "error": "缺少 sso-rw",
54
+ }
55
+
56
+ url = "https://accounts.x.ai/auth_mgmt.AuthManagement/SetTosAcceptedVersion"
57
+
58
+ cookies = {
59
+ "sso": sso,
60
+ "sso-rw": sso_rw,
61
+ }
62
+ clearance = (cf_clearance if cf_clearance is not None else self.cf_clearance).strip()
63
+ if clearance:
64
+ cookies["cf_clearance"] = clearance
65
+
66
+ headers = {
67
+ "content-type": "application/grpc-web+proto",
68
+ "origin": "https://accounts.x.ai",
69
+ "referer": "https://accounts.x.ai/accept-tos",
70
+ "x-grpc-web": "1",
71
+ "user-agent": user_agent or DEFAULT_USER_AGENT,
72
+ }
73
+
74
+ data = (
75
+ b"\x00\x00\x00\x00" # 头部
76
+ b"\x02" # 长度
77
+ b"\x10\x01" # Field 2 = 1
78
+ )
79
+
80
+ try:
81
+ response = requests.post(
82
+ url,
83
+ headers=headers,
84
+ cookies=cookies,
85
+ data=data,
86
+ impersonate=impersonate or "chrome120",
87
+ timeout=timeout,
88
+ )
89
+ hex_reply = response.content.hex()
90
+ grpc_status = response.headers.get("grpc-status")
91
+
92
+ error = None
93
+ ok = response.status_code == 200 and (grpc_status in (None, "0"))
94
+ if response.status_code == 403:
95
+ error = "403 Forbidden"
96
+ elif response.status_code != 200:
97
+ error = f"HTTP {response.status_code}"
98
+ elif grpc_status not in (None, "0"):
99
+ error = f"gRPC {grpc_status}"
100
+
101
+ return {
102
+ "ok": ok,
103
+ "hex_reply": hex_reply,
104
+ "status_code": response.status_code,
105
+ "grpc_status": grpc_status,
106
+ "error": error,
107
+ }
108
+ except Exception as e:
109
+ return {
110
+ "ok": False,
111
+ "hex_reply": "",
112
+ "status_code": None,
113
+ "grpc_status": None,
114
+ "error": str(e),
115
+ }
app/services/register/solver.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Local Turnstile solver process manager."""
2
+ from __future__ import annotations
3
+
4
+ import socket
5
+ import subprocess
6
+ import sys
7
+ import time
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+ from typing import Optional
11
+ from urllib.parse import urlparse
12
+
13
+ from app.core.logger import logger
14
+
15
+
16
+ def _wait_for_port(host: str, port: int, timeout: float = 20.0) -> bool:
17
+ deadline = time.time() + timeout
18
+ while time.time() < deadline:
19
+ try:
20
+ with socket.create_connection((host, port), timeout=1):
21
+ return True
22
+ except Exception:
23
+ time.sleep(0.5)
24
+ return False
25
+
26
+
27
+ @dataclass
28
+ class SolverConfig:
29
+ url: str
30
+ threads: int = 5
31
+ browser_type: str = "chromium"
32
+ debug: bool = False
33
+ auto_start: bool = True
34
+
35
+
36
+ class TurnstileSolverProcess:
37
+ """Start/stop a local Turnstile solver."""
38
+
39
+ def __init__(self, config: SolverConfig) -> None:
40
+ self.config = config
41
+ self._process: Optional[subprocess.Popen] = None
42
+ self._started_by_us = False
43
+ self._repo_root = Path(__file__).resolve().parents[3]
44
+ self._python_exe: str = sys.executable
45
+ self._actual_browser_type: str = config.browser_type
46
+
47
+ def _script_path(self) -> Path:
48
+ return self._repo_root / "scripts" / "turnstile_solver" / "api_solver.py"
49
+
50
+ def _can_import(self, python_exe: str, modules: list[str]) -> bool:
51
+ """Check whether a python executable can import given modules."""
52
+ code = "; ".join([f"import {m}" for m in modules])
53
+ try:
54
+ subprocess.check_call(
55
+ [python_exe, "-c", code],
56
+ stdout=subprocess.DEVNULL,
57
+ stderr=subprocess.DEVNULL,
58
+ )
59
+ return True
60
+ except Exception:
61
+ return False
62
+
63
+ def _windows_where_python(self) -> list[str]:
64
+ """List python.exe candidates on Windows using `where python` (best-effort)."""
65
+ if not sys.platform.startswith("win"):
66
+ return []
67
+ try:
68
+ out = subprocess.check_output(
69
+ ["where", "python"],
70
+ stderr=subprocess.DEVNULL,
71
+ text=True,
72
+ encoding="utf-8",
73
+ errors="ignore",
74
+ )
75
+ except Exception:
76
+ return []
77
+
78
+ paths: list[str] = []
79
+ seen: set[str] = set()
80
+ for line in out.splitlines():
81
+ p = (line or "").strip().strip('"')
82
+ if not p:
83
+ continue
84
+ key = p.lower()
85
+ if key in seen:
86
+ continue
87
+ seen.add(key)
88
+ paths.append(p)
89
+ return paths
90
+
91
+ def _select_runtime(self) -> None:
92
+ """Pick python executable + browser type to run solver with.
93
+
94
+ Practical notes (Windows):
95
+ - The API server may run in a venv (e.g. Python 3.13).
96
+ - Many users install the solver dependencies (camoufox/patchright) into their
97
+ system python (e.g. Python 3.12) and start the solver via a `.bat`.
98
+
99
+ To match that workflow, we prefer an interpreter that has `patchright` when
100
+ available (it tends to have better anti-bot compatibility). For camoufox,
101
+ we also require `camoufox` import to succeed.
102
+ """
103
+ desired = (self.config.browser_type or "chromium").strip().lower()
104
+ if desired not in {"chromium", "chrome", "msedge", "camoufox"}:
105
+ desired = "chromium"
106
+
107
+ # Collect python candidates.
108
+ #
109
+ # NOTE: When the API server runs under `uv run`, `python` on PATH usually points to
110
+ # the venv python, not the system python. On Windows, use `where python` to discover
111
+ # other interpreters (e.g. Python312) where users installed camoufox/patchright.
112
+ candidates: list[str] = [sys.executable]
113
+ for p in self._windows_where_python():
114
+ if p.lower() != sys.executable.lower():
115
+ candidates.append(p)
116
+ # As a last resort, try PATH resolution.
117
+ candidates.append("python")
118
+
119
+ # De-duplicate while preserving order.
120
+ dedup: list[str] = []
121
+ seen: set[str] = set()
122
+ for p in candidates:
123
+ k = p.lower()
124
+ if k in seen:
125
+ continue
126
+ seen.add(k)
127
+ dedup.append(p)
128
+ candidates = dedup
129
+
130
+ def _pick_with(modules: list[str]) -> str | None:
131
+ for exe in candidates:
132
+ if self._can_import(exe, modules):
133
+ return exe
134
+ return None
135
+
136
+ self._actual_browser_type = desired
137
+
138
+ if desired == "camoufox":
139
+ # Prefer patchright if possible.
140
+ exe = _pick_with(["quart", "camoufox", "patchright"])
141
+ if exe:
142
+ self._python_exe = exe
143
+ return
144
+
145
+ exe = _pick_with(["quart", "camoufox", "playwright"])
146
+ if exe:
147
+ self._python_exe = exe
148
+ return
149
+
150
+ # No camoufox in any known interpreter; fallback to chromium.
151
+ logger.warning("Camoufox not available. Falling back solver browser to chromium.")
152
+ self._actual_browser_type = "chromium"
153
+
154
+ # For chromium/chrome/msedge, prefer patchright if available.
155
+ exe = _pick_with(["quart", "patchright"])
156
+ if exe:
157
+ self._python_exe = exe
158
+ return
159
+
160
+ exe = _pick_with(["quart", "playwright"])
161
+ if exe:
162
+ self._python_exe = exe
163
+ return
164
+
165
+ # Last resort: current interpreter (may fail fast with a clear error from the solver process).
166
+ self._python_exe = sys.executable
167
+
168
+ def _ensure_playwright_browsers(self, python_exe: str) -> None:
169
+ """Ensure Playwright browser binaries exist (best-effort).
170
+
171
+ We only auto-install for bundled Chromium. Branded channels (chrome/msedge)
172
+ rely on system-installed browsers.
173
+ """
174
+ if self._actual_browser_type != "chromium":
175
+ return
176
+
177
+ lock_dir = self._repo_root / "data" / ".locks"
178
+ lock_dir.mkdir(parents=True, exist_ok=True)
179
+ lock_path = lock_dir / "playwright_chromium_v1.lock"
180
+ if lock_path.exists():
181
+ return
182
+
183
+ try:
184
+ logger.info("Installing Playwright Chromium (first run)...")
185
+ args = [python_exe, "-m", "playwright", "install"]
186
+ # On Linux (Docker), install system deps as well.
187
+ if sys.platform.startswith("linux"):
188
+ args.append("--with-deps")
189
+ args.append("chromium")
190
+ subprocess.check_call(args, cwd=str(self._repo_root))
191
+ lock_path.write_text(str(time.time()), encoding="utf-8")
192
+ except Exception as exc:
193
+ # Don't create lock file; let next run retry.
194
+ raise RuntimeError(f"Playwright browser install failed: {exc}") from exc
195
+
196
+ def _parse_host_port(self) -> tuple[str, int]:
197
+ parsed = urlparse(self.config.url)
198
+ host = parsed.hostname or "127.0.0.1"
199
+ port = parsed.port or 5072
200
+ return host, int(port)
201
+
202
+ def start(self) -> None:
203
+ if not self.config.auto_start:
204
+ return
205
+
206
+ host, port = self._parse_host_port()
207
+
208
+ def _spawn() -> None:
209
+ script = self._script_path()
210
+ if not script.exists():
211
+ raise RuntimeError(f"Solver script not found: {script}")
212
+
213
+ # Ensure Playwright browsers are present before starting the solver process.
214
+ self._ensure_playwright_browsers(self._python_exe)
215
+
216
+ cmd = [
217
+ self._python_exe,
218
+ str(script),
219
+ "--browser_type",
220
+ self._actual_browser_type,
221
+ "--thread",
222
+ str(self.config.threads),
223
+ ]
224
+ if self.config.debug:
225
+ cmd.append("--debug")
226
+ cmd += ["--host", host, "--port", str(port)]
227
+
228
+ logger.info("Starting Turnstile solver: {}", " ".join(cmd))
229
+ self._process = subprocess.Popen(
230
+ cmd,
231
+ cwd=str(script.parent),
232
+ )
233
+ self._started_by_us = True
234
+
235
+ if not _wait_for_port(host, port, timeout=60.0):
236
+ exit_code = self._process.poll() if self._process else None
237
+ self.stop()
238
+ if exit_code is not None:
239
+ raise RuntimeError(
240
+ f"Turnstile solver exited early (code {exit_code}). "
241
+ "Please check solver dependencies."
242
+ )
243
+ raise RuntimeError("Turnstile solver did not become ready in time")
244
+
245
+ # Decide runtime + browser strategy before checking readiness.
246
+ self._select_runtime()
247
+ logger.info(
248
+ "Turnstile solver runtime selected: python={} browser_type={}",
249
+ self._python_exe,
250
+ self._actual_browser_type,
251
+ )
252
+
253
+ if _wait_for_port(host, port, timeout=1.0):
254
+ logger.info("Turnstile solver already running at {}:{}", host, port)
255
+ self._started_by_us = False
256
+ return
257
+
258
+ try:
259
+ _spawn()
260
+ return
261
+ except Exception as exc:
262
+ # camoufox is not always stable/available across environments (notably Docker).
263
+ # Fall back to chromium instead of failing the whole auto-register workflow.
264
+ if self._actual_browser_type != "camoufox":
265
+ raise
266
+ logger.warning("Camoufox solver failed to start; falling back to chromium: {}", exc)
267
+ try:
268
+ self.stop()
269
+ except Exception:
270
+ pass
271
+ self.config.browser_type = "chromium"
272
+ self._actual_browser_type = "chromium"
273
+ self._select_runtime()
274
+ logger.info(
275
+ "Turnstile solver runtime selected: python={} browser_type={}",
276
+ self._python_exe,
277
+ self._actual_browser_type,
278
+ )
279
+ _spawn()
280
+
281
+ def stop(self) -> None:
282
+ if not self._process or not self._started_by_us:
283
+ return
284
+
285
+ try:
286
+ logger.info("Stopping Turnstile solver...")
287
+ self._process.terminate()
288
+ self._process.wait(timeout=10)
289
+ except Exception:
290
+ try:
291
+ self._process.kill()
292
+ except Exception:
293
+ pass
294
+ finally:
295
+ self._process = None
296
+ self._started_by_us = False
app/services/request_logger.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """请求日志审计 - 记录近期请求"""
2
+
3
+ import time
4
+ import asyncio
5
+ import orjson
6
+ from typing import List, Dict, Deque
7
+ from collections import deque
8
+ from dataclasses import dataclass, asdict
9
+ from pathlib import Path
10
+
11
+ from app.core.logger import logger
12
+
13
+ @dataclass
14
+ class RequestLog:
15
+ id: str
16
+ time: str
17
+ timestamp: float
18
+ ip: str
19
+ model: str
20
+ duration: float
21
+ status: int
22
+ key_name: str
23
+ token_suffix: str
24
+ error: str = ""
25
+
26
+ class RequestLogger:
27
+ """请求日志记录器"""
28
+
29
+ _instance = None
30
+
31
+ def __new__(cls):
32
+ if cls._instance is None:
33
+ cls._instance = super().__new__(cls)
34
+ return cls._instance
35
+
36
+ def __init__(self, max_len: int = 1000):
37
+ if hasattr(self, '_initialized'):
38
+ return
39
+
40
+ self.file_path = Path(__file__).parents[2] / "data" / "logs.json"
41
+ self._logs: Deque[Dict] = deque(maxlen=max_len)
42
+ self._lock = asyncio.Lock()
43
+ self._loaded = False
44
+
45
+ self._initialized = True
46
+
47
+ async def init(self):
48
+ """初始化加载数据"""
49
+ if not self._loaded:
50
+ await self._load_data()
51
+
52
+ async def _load_data(self):
53
+ """从磁盘加载日志数据"""
54
+ if self._loaded:
55
+ return
56
+
57
+ if not self.file_path.exists():
58
+ self._loaded = True
59
+ return
60
+
61
+ try:
62
+ async with self._lock:
63
+ content = await asyncio.to_thread(self.file_path.read_bytes)
64
+ if content:
65
+ data = orjson.loads(content)
66
+ if isinstance(data, list):
67
+ self._logs.clear()
68
+ self._logs.extend(data)
69
+ self._loaded = True
70
+ logger.debug(f"[Logger] 加载日志成功: {len(self._logs)} 条")
71
+ except Exception as e:
72
+ logger.error(f"[Logger] 加载日志失败: {e}")
73
+ self._loaded = True
74
+
75
+ async def _save_data(self):
76
+ """保存日志数据到磁盘"""
77
+ if not self._loaded:
78
+ return
79
+
80
+ try:
81
+ # 确保目录存在
82
+ self.file_path.parent.mkdir(parents=True, exist_ok=True)
83
+
84
+ async with self._lock:
85
+ # 转换为列表保存
86
+ content = orjson.dumps(list(self._logs))
87
+ await asyncio.to_thread(self.file_path.write_bytes, content)
88
+ except Exception as e:
89
+ logger.error(f"[Logger] 保存日志失败: {e}")
90
+
91
+ async def add_log(self,
92
+ ip: str,
93
+ model: str,
94
+ duration: float,
95
+ status: int,
96
+ key_name: str,
97
+ token_suffix: str = "",
98
+ error: str = ""):
99
+ """添加日志"""
100
+ if not self._loaded:
101
+ await self.init()
102
+
103
+ try:
104
+ now = time.time()
105
+ # 格式化时间
106
+ time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(now))
107
+
108
+ log = {
109
+ "id": str(int(now * 1000)),
110
+ "time": time_str,
111
+ "timestamp": now,
112
+ "ip": ip,
113
+ "model": model,
114
+ "duration": round(duration, 2),
115
+ "status": status,
116
+ "key_name": key_name,
117
+ "token_suffix": token_suffix,
118
+ "error": error
119
+ }
120
+
121
+ async with self._lock:
122
+ self._logs.appendleft(log) # 最新的在前
123
+
124
+ # 异步保存
125
+ asyncio.create_task(self._save_data())
126
+
127
+ except Exception as e:
128
+ logger.error(f"[Logger] 记录日志失败: {e}")
129
+
130
+ async def get_logs(self, limit: int = 1000) -> List[Dict]:
131
+ """获取日志"""
132
+ async with self._lock:
133
+ return list(self._logs)[:limit]
134
+
135
+ async def clear_logs(self):
136
+ """清空日志"""
137
+ async with self._lock:
138
+ self._logs.clear()
139
+ await self._save_data()
140
+
141
+
142
+ # 全局实例
143
+ request_logger = RequestLogger()
app/services/request_stats.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """请求统计模块 - 按小时/天统计请求数据"""
2
+
3
+ import time
4
+ import asyncio
5
+ import orjson
6
+ from datetime import datetime
7
+ from typing import Dict, Any
8
+ from pathlib import Path
9
+ from collections import defaultdict
10
+
11
+ from app.core.logger import logger
12
+
13
+
14
+ class RequestStats:
15
+ """请求统计管理器(单例)"""
16
+
17
+ _instance = None
18
+
19
+ def __new__(cls):
20
+ if cls._instance is None:
21
+ cls._instance = super().__new__(cls)
22
+ return cls._instance
23
+
24
+ def __init__(self):
25
+ if hasattr(self, '_initialized'):
26
+ return
27
+
28
+ self.file_path = Path(__file__).parents[2] / "data" / "stats.json"
29
+
30
+ # 统计数据
31
+ self._hourly: Dict[str, Dict[str, int]] = defaultdict(lambda: {"total": 0, "success": 0, "failed": 0})
32
+ self._daily: Dict[str, Dict[str, int]] = defaultdict(lambda: {"total": 0, "success": 0, "failed": 0})
33
+ self._models: Dict[str, int] = defaultdict(int)
34
+
35
+ # 保留策略
36
+ self._hourly_keep = 48 # 保留48小时
37
+ self._daily_keep = 30 # 保留30天
38
+
39
+ self._lock = asyncio.Lock()
40
+ self._loaded = False
41
+ self._initialized = True
42
+
43
+ async def init(self):
44
+ """初始化加载数据"""
45
+ if not self._loaded:
46
+ await self._load_data()
47
+
48
+ async def _load_data(self):
49
+ """从磁盘加载统计数据"""
50
+ if self._loaded:
51
+ return
52
+
53
+ if not self.file_path.exists():
54
+ self._loaded = True
55
+ return
56
+
57
+ try:
58
+ async with self._lock:
59
+ content = await asyncio.to_thread(self.file_path.read_bytes)
60
+ if content:
61
+ data = orjson.loads(content)
62
+
63
+ # 恢复 defaultdict 结构
64
+ self._hourly = defaultdict(lambda: {"total": 0, "success": 0, "failed": 0})
65
+ self._hourly.update(data.get("hourly", {}))
66
+
67
+ self._daily = defaultdict(lambda: {"total": 0, "success": 0, "failed": 0})
68
+ self._daily.update(data.get("daily", {}))
69
+
70
+ self._models = defaultdict(int)
71
+ self._models.update(data.get("models", {}))
72
+
73
+ self._loaded = True
74
+ logger.debug(f"[Stats] 加载统计数据成功")
75
+ except Exception as e:
76
+ logger.error(f"[Stats] 加载数据失败: {e}")
77
+ self._loaded = True # 防止覆盖
78
+
79
+ async def _save_data(self):
80
+ """保存统计数据到磁盘"""
81
+ if not self._loaded:
82
+ return
83
+
84
+ try:
85
+ # 确保目录存在
86
+ self.file_path.parent.mkdir(parents=True, exist_ok=True)
87
+
88
+ async with self._lock:
89
+ data = {
90
+ "hourly": dict(self._hourly),
91
+ "daily": dict(self._daily),
92
+ "models": dict(self._models)
93
+ }
94
+ content = orjson.dumps(data)
95
+ await asyncio.to_thread(self.file_path.write_bytes, content)
96
+ except Exception as e:
97
+ logger.error(f"[Stats] 保存数据失败: {e}")
98
+
99
+ async def record_request(self, model: str, success: bool) -> None:
100
+ """记录一次请求"""
101
+ if not self._loaded:
102
+ await self.init()
103
+
104
+ now = datetime.now()
105
+ hour_key = now.strftime("%Y-%m-%dT%H")
106
+ day_key = now.strftime("%Y-%m-%d")
107
+
108
+ # 小时统计
109
+ self._hourly[hour_key]["total"] += 1
110
+ if success:
111
+ self._hourly[hour_key]["success"] += 1
112
+ else:
113
+ self._hourly[hour_key]["failed"] += 1
114
+
115
+ # 天统计
116
+ self._daily[day_key]["total"] += 1
117
+ if success:
118
+ self._daily[day_key]["success"] += 1
119
+ else:
120
+ self._daily[day_key]["failed"] += 1
121
+
122
+ # 模型统计
123
+ self._models[model] += 1
124
+
125
+ # 定期清理旧数据
126
+ self._cleanup()
127
+
128
+ # 异步保存
129
+ asyncio.create_task(self._save_data())
130
+
131
+ def _cleanup(self) -> None:
132
+ """清理过期数据"""
133
+ now = datetime.now()
134
+
135
+ # 清理小时数据
136
+ hour_keys = list(self._hourly.keys())
137
+ if len(hour_keys) > self._hourly_keep:
138
+ for key in sorted(hour_keys)[:-self._hourly_keep]:
139
+ del self._hourly[key]
140
+
141
+ # 清理天数据
142
+ day_keys = list(self._daily.keys())
143
+ if len(day_keys) > self._daily_keep:
144
+ for key in sorted(day_keys)[:-self._daily_keep]:
145
+ del self._daily[key]
146
+
147
+ def get_stats(self, hours: int = 24, days: int = 7) -> Dict[str, Any]:
148
+ """获取统计数据"""
149
+ now = datetime.now()
150
+
151
+ # 获取最近N小时数据
152
+ hourly_data = []
153
+ for i in range(hours - 1, -1, -1):
154
+ from datetime import timedelta
155
+ dt = now - timedelta(hours=i)
156
+ key = dt.strftime("%Y-%m-%dT%H")
157
+ data = self._hourly.get(key, {"total": 0, "success": 0, "failed": 0})
158
+ hourly_data.append({
159
+ "hour": dt.strftime("%H:00"),
160
+ "date": dt.strftime("%m-%d"),
161
+ **data
162
+ })
163
+
164
+ # 获取最近N天数据
165
+ daily_data = []
166
+ for i in range(days - 1, -1, -1):
167
+ from datetime import timedelta
168
+ dt = now - timedelta(days=i)
169
+ key = dt.strftime("%Y-%m-%d")
170
+ data = self._daily.get(key, {"total": 0, "success": 0, "failed": 0})
171
+ daily_data.append({
172
+ "date": dt.strftime("%m-%d"),
173
+ **data
174
+ })
175
+
176
+ # 模型统计(取 Top 10)
177
+ model_data = sorted(self._models.items(), key=lambda x: x[1], reverse=True)[:10]
178
+
179
+ # 总计
180
+ total_requests = sum(d["total"] for d in self._hourly.values())
181
+ total_success = sum(d["success"] for d in self._hourly.values())
182
+ total_failed = sum(d["failed"] for d in self._hourly.values())
183
+
184
+ return {
185
+ "hourly": hourly_data,
186
+ "daily": daily_data,
187
+ "models": [{"model": m, "count": c} for m, c in model_data],
188
+ "summary": {
189
+ "total": total_requests,
190
+ "success": total_success,
191
+ "failed": total_failed,
192
+ "success_rate": round(total_success / total_requests * 100, 1) if total_requests > 0 else 0
193
+ }
194
+ }
195
+
196
+ async def reset(self) -> None:
197
+ """重置所有统计"""
198
+ self._hourly.clear()
199
+ self._daily.clear()
200
+ self._models.clear()
201
+ await self._save_data()
202
+
203
+
204
+ # 全局实例
205
+ request_stats = RequestStats()
app/services/token/__init__.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Token 服务模块"""
2
+
3
+ from app.services.token.models import (
4
+ TokenInfo,
5
+ TokenStatus,
6
+ TokenPoolStats,
7
+ EffortType,
8
+ DEFAULT_QUOTA,
9
+ EFFORT_COST
10
+ )
11
+ from app.services.token.pool import TokenPool
12
+ from app.services.token.manager import TokenManager, get_token_manager
13
+ from app.services.token.service import TokenService
14
+ from app.services.token.scheduler import TokenRefreshScheduler, get_scheduler
15
+
16
+ __all__ = [
17
+ # Models
18
+ "TokenInfo",
19
+ "TokenStatus",
20
+ "TokenPoolStats",
21
+ "EffortType",
22
+ "DEFAULT_QUOTA",
23
+ "EFFORT_COST",
24
+
25
+ # Core
26
+ "TokenPool",
27
+ "TokenManager",
28
+
29
+ # API
30
+ "TokenService",
31
+ "get_token_manager",
32
+
33
+ # Scheduler
34
+ "TokenRefreshScheduler",
35
+ "get_scheduler",
36
+ ]
app/services/token/manager.py ADDED
@@ -0,0 +1,654 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Token 管理服务"""
2
+
3
+ import asyncio
4
+ import time
5
+ from datetime import datetime
6
+ from typing import Dict, List, Optional
7
+
8
+ from app.core.logger import logger
9
+ from app.services.token.models import TokenInfo, EffortType, TokenPoolStats, FAIL_THRESHOLD, TokenStatus
10
+ from app.core.storage import get_storage
11
+ from app.core.config import get_config
12
+ from app.services.token.pool import TokenPool
13
+
14
+ # 批量刷新配置
15
+ REFRESH_INTERVAL_HOURS = 8
16
+ REFRESH_BATCH_SIZE = 10
17
+ REFRESH_CONCURRENCY = 5
18
+
19
+
20
+ class TokenManager:
21
+ """管理 Token 的增删改查和配额同步"""
22
+
23
+ _instance: Optional["TokenManager"] = None
24
+ _lock = asyncio.Lock()
25
+
26
+ def __init__(self):
27
+ self.pools: Dict[str, TokenPool] = {}
28
+ self.initialized = False
29
+ self._save_lock = asyncio.Lock()
30
+ self._dirty = False
31
+ self._save_task: Optional[asyncio.Task] = None
32
+ self._save_delay = 0.5
33
+ self._last_reload_at = 0.0
34
+
35
+ @classmethod
36
+ async def get_instance(cls) -> "TokenManager":
37
+ """获取单例实例"""
38
+ if cls._instance is None:
39
+ async with cls._lock:
40
+ if cls._instance is None:
41
+ cls._instance = cls()
42
+ await cls._instance._load()
43
+ return cls._instance
44
+
45
+ async def _load(self):
46
+ """初始化加载"""
47
+ if not self.initialized:
48
+ try:
49
+ storage = get_storage()
50
+ data = await storage.load_tokens()
51
+
52
+ # 如果后端返回 None 或空数据,尝试从本地 data/token.json 初始化后端
53
+ if not data:
54
+ from app.core.storage import LocalStorage
55
+ local_storage = LocalStorage()
56
+ local_data = await local_storage.load_tokens()
57
+ if local_data:
58
+ data = local_data
59
+ await storage.save_tokens(local_data)
60
+ logger.info(f"Initialized remote token storage ({storage.__class__.__name__}) with local tokens.")
61
+ else:
62
+ data = {}
63
+
64
+ self.pools = {}
65
+ for pool_name, tokens in data.items():
66
+ pool = TokenPool(pool_name)
67
+ for token_data in tokens:
68
+ try:
69
+ # 统一存储裸 token
70
+ if isinstance(token_data, dict):
71
+ raw_token = token_data.get("token")
72
+ if isinstance(raw_token, str) and raw_token.startswith("sso="):
73
+ token_data["token"] = raw_token[4:]
74
+ token_info = TokenInfo(**token_data)
75
+ pool.add(token_info)
76
+ except Exception as e:
77
+ logger.warning(f"Failed to load token in pool '{pool_name}': {e}")
78
+ continue
79
+ pool._rebuild_index()
80
+ self.pools[pool_name] = pool
81
+
82
+ self.initialized = True
83
+ self._last_reload_at = time.monotonic()
84
+ total = sum(p.count() for p in self.pools.values())
85
+ logger.info(f"TokenManager initialized: {len(self.pools)} pools with {total} tokens")
86
+ except Exception as e:
87
+ logger.error(f"Failed to initialize TokenManager: {e}")
88
+ self.pools = {}
89
+ self.initialized = True
90
+
91
+ async def reload(self):
92
+ """重新加载 Token 池数据"""
93
+ async with self.__class__._lock:
94
+ self.initialized = False
95
+ await self._load()
96
+
97
+ async def reload_if_stale(self):
98
+ """在多 worker 场景下保持短周期一致性"""
99
+ interval = get_config("token.reload_interval_sec", 30)
100
+ try:
101
+ interval = float(interval)
102
+ except Exception:
103
+ interval = 30.0
104
+ if interval <= 0:
105
+ return
106
+ if time.monotonic() - self._last_reload_at < interval:
107
+ return
108
+ await self.reload()
109
+
110
+ async def _save(self):
111
+ """保存变更"""
112
+ async with self._save_lock:
113
+ try:
114
+ data = {}
115
+ for pool_name, pool in self.pools.items():
116
+ data[pool_name] = [
117
+ info.model_dump() for info in pool.list()
118
+ ]
119
+
120
+ storage = get_storage()
121
+ async with storage.acquire_lock("tokens_save", timeout=10):
122
+ await storage.save_tokens(data)
123
+ except Exception as e:
124
+ logger.error(f"Failed to save tokens: {e}")
125
+
126
+ def _schedule_save(self):
127
+ """合并高频保存请求,减少写入开销"""
128
+ delay_ms = get_config("token.save_delay_ms", 500)
129
+ try:
130
+ delay_ms = float(delay_ms)
131
+ except Exception:
132
+ delay_ms = 500
133
+ self._save_delay = max(0.0, delay_ms / 1000.0)
134
+ self._dirty = True
135
+ if self._save_delay == 0:
136
+ if self._save_task and not self._save_task.done():
137
+ return
138
+ self._save_task = asyncio.create_task(self._save())
139
+ return
140
+ if self._save_task and not self._save_task.done():
141
+ return
142
+ self._save_task = asyncio.create_task(self._flush_loop())
143
+
144
+ async def _flush_loop(self):
145
+ try:
146
+ while True:
147
+ await asyncio.sleep(self._save_delay)
148
+ if not self._dirty:
149
+ break
150
+ self._dirty = False
151
+ await self._save()
152
+ finally:
153
+ self._save_task = None
154
+ if self._dirty:
155
+ self._schedule_save()
156
+
157
+ @staticmethod
158
+ def _extract_cookie_value(cookie_str: str, name: str) -> str | None:
159
+ needle = f"{name}="
160
+ if needle not in cookie_str:
161
+ return None
162
+ for part in cookie_str.split(";"):
163
+ part = part.strip()
164
+ if part.startswith(needle):
165
+ value = part[len(needle):].strip()
166
+ return value or None
167
+ return None
168
+
169
+ @classmethod
170
+ def _normalize_input_token(cls, token_str: str) -> str:
171
+ raw = str(token_str or "").strip()
172
+ if not raw:
173
+ return ""
174
+ if ";" in raw:
175
+ return (cls._extract_cookie_value(raw, "sso") or "").strip()
176
+ if raw.startswith("sso="):
177
+ return raw[4:].strip()
178
+ return raw
179
+
180
+ def _find_token_info(self, token_str: str) -> tuple[Optional[TokenInfo], str]:
181
+ raw_token = self._normalize_input_token(token_str)
182
+ if not raw_token:
183
+ return None, ""
184
+ for pool in self.pools.values():
185
+ token = pool.get(raw_token)
186
+ if token:
187
+ return token, raw_token
188
+ return None, raw_token
189
+
190
+ def get_token(self, pool_name: str = "ssoBasic") -> Optional[str]:
191
+ """
192
+ 获取可用 Token
193
+
194
+ Args:
195
+ pool_name: Token 池名称
196
+
197
+ Returns:
198
+ Token 字符串或 None
199
+ """
200
+ pool = self.pools.get(pool_name)
201
+ if not pool:
202
+ logger.warning(f"Pool '{pool_name}' not found")
203
+ return None
204
+
205
+ token_info = pool.select()
206
+ if not token_info:
207
+ logger.warning(f"No available token in pool '{pool_name}'")
208
+ return None
209
+
210
+ token = token_info.token
211
+ if token.startswith("sso="):
212
+ return token[4:]
213
+ return token
214
+
215
+ def get_token_for_model(self, model_id: str) -> Optional[str]:
216
+ """按模型选择可用 Token(包含 basic->super 回退与 heavy 配额桶选择)。"""
217
+ from app.services.grok.model import ModelService
218
+
219
+ bucket = "heavy" if ModelService.is_heavy_bucket_model(model_id) else "normal"
220
+ for pool_name in ModelService.pool_candidates_for_model(model_id):
221
+ pool = self.pools.get(pool_name)
222
+ if not pool:
223
+ continue
224
+ token_info = pool.select(bucket=bucket)
225
+ if not token_info:
226
+ continue
227
+ token = token_info.token
228
+ return token[4:] if token.startswith("sso=") else token
229
+
230
+ logger.warning(f"No available token for model '{model_id}'")
231
+ return None
232
+
233
+ async def consume(self, token_str: str, effort: EffortType = EffortType.LOW, bucket: str = "normal") -> bool:
234
+ """
235
+ 消耗配额(本地预估)
236
+
237
+ Args:
238
+ token_str: Token 字符串
239
+ effort: 消耗力度
240
+
241
+ Returns:
242
+ 是否成功
243
+ """
244
+ raw_token = token_str.replace("sso=", "")
245
+
246
+ for pool in self.pools.values():
247
+ token = pool.get(raw_token)
248
+ if token:
249
+ consumed = token.consume_heavy(effort) if bucket == "heavy" else token.consume(effort)
250
+ logger.debug(
251
+ f"Token {raw_token[:10]}...: consumed {consumed} quota (bucket={bucket}), use_count={token.use_count}"
252
+ )
253
+ self._schedule_save()
254
+ return True
255
+
256
+ logger.warning(f"Token {raw_token[:10]}...: not found for consumption")
257
+ return False
258
+
259
+ async def sync_usage(
260
+ self,
261
+ token_str: str,
262
+ model_id: str,
263
+ fallback_effort: EffortType = EffortType.LOW,
264
+ consume_on_fail: bool = True,
265
+ is_usage: bool = True
266
+ ) -> bool:
267
+ """
268
+ 同步 Token 用量
269
+
270
+ 优先从 API 获取最新配额,失败则降级到本地预估
271
+
272
+ Args:
273
+ token_str: Token 字符串(可带 sso= 前缀)
274
+ model_name: 模型名称(用于 API 查询)
275
+ fallback_effort: 降级时的消耗力度
276
+ consume_on_fail: 失败时是否降��扣费
277
+ is_usage: 是否记录为一次使用(影响 use_count)
278
+
279
+ Returns:
280
+ 是否成功
281
+ """
282
+ raw_token = token_str.replace("sso=", "")
283
+
284
+ # 查找 Token 对象
285
+ target_token: Optional[TokenInfo] = None
286
+ for pool in self.pools.values():
287
+ target_token = pool.get(raw_token)
288
+ if target_token:
289
+ break
290
+
291
+ if not target_token:
292
+ logger.warning(f"Token {raw_token[:10]}...: not found for sync")
293
+ return False
294
+
295
+ from app.services.grok.model import ModelService
296
+
297
+ bucket = "heavy" if ModelService.is_heavy_bucket_model(model_id) else "normal"
298
+ rate_limit_model = ModelService.rate_limit_model_for(model_id)
299
+
300
+ # 尝试 API 同步
301
+ try:
302
+ from app.services.grok.usage import UsageService
303
+
304
+ usage_service = UsageService()
305
+ result = await usage_service.get(token_str, model_name=rate_limit_model)
306
+
307
+ if result and "remainingTokens" in result:
308
+ try:
309
+ new_quota = int(result["remainingTokens"])
310
+ except Exception:
311
+ new_quota = 0
312
+
313
+ if bucket == "heavy":
314
+ old_quota = target_token.heavy_quota
315
+ target_token.update_heavy_quota(new_quota)
316
+ else:
317
+ old_quota = target_token.quota
318
+ target_token.update_quota(new_quota)
319
+
320
+ target_token.record_success(is_usage=is_usage)
321
+
322
+ consumed = max(0, old_quota - new_quota) if old_quota >= 0 else 0
323
+ logger.info(
324
+ f"Token {raw_token[:10]}...: synced quota (bucket={bucket}, model={rate_limit_model}) "
325
+ f"{old_quota} -> {new_quota} (consumed: {consumed}, use_count: {target_token.use_count})"
326
+ )
327
+
328
+ self._schedule_save()
329
+ return True
330
+
331
+ except Exception as e:
332
+ logger.warning(f"Token {raw_token[:10]}...: API sync failed, fallback to local ({e})")
333
+
334
+ # 降级:本地预估扣费
335
+ if consume_on_fail:
336
+ logger.debug(f"Token {raw_token[:10]}...: using local consumption")
337
+ return await self.consume(token_str, fallback_effort, bucket=bucket)
338
+ else:
339
+ logger.debug(f"Token {raw_token[:10]}...: sync failed, skipping local consumption")
340
+ return False
341
+
342
+ async def record_fail(self, token_str: str, status_code: int = 401, reason: str = "") -> bool:
343
+ """
344
+ 记录 Token 失败
345
+
346
+ Args:
347
+ token_str: Token 字符串
348
+ status_code: HTTP 状态码
349
+ reason: 失败原因
350
+
351
+ Returns:
352
+ 是否成功
353
+ """
354
+ raw_token = token_str.replace("sso=", "")
355
+
356
+ for pool in self.pools.values():
357
+ token = pool.get(raw_token)
358
+ if token:
359
+ if status_code == 401:
360
+ token.record_fail(status_code, reason)
361
+ logger.warning(
362
+ f"Token {raw_token[:10]}...: recorded 401 failure "
363
+ f"({token.fail_count}/{FAIL_THRESHOLD}) - {reason}"
364
+ )
365
+ else:
366
+ logger.info(
367
+ f"Token {raw_token[:10]}...: non-401 error ({status_code}) - {reason} (not counted)"
368
+ )
369
+ self._schedule_save()
370
+ return True
371
+
372
+ logger.warning(f"Token {raw_token[:10]}...: not found for failure record")
373
+ return False
374
+
375
+ # ========== 管理功能 ==========
376
+
377
+ async def add(self, token: str, pool_name: str = "ssoBasic") -> bool:
378
+ """
379
+ 添加 Token
380
+
381
+ Args:
382
+ token: Token 字符串(不含 sso= 前缀)
383
+ pool_name: 池名称
384
+
385
+ Returns:
386
+ 是否成功
387
+ """
388
+ if pool_name not in self.pools:
389
+ self.pools[pool_name] = TokenPool(pool_name)
390
+ logger.info(f"Pool '{pool_name}': created")
391
+
392
+ pool = self.pools[pool_name]
393
+
394
+ token = token[4:] if token.startswith("sso=") else token
395
+ if pool.get(token):
396
+ logger.warning(f"Pool '{pool_name}': token already exists")
397
+ return False
398
+
399
+ pool.add(TokenInfo(token=token))
400
+ await self._save()
401
+ logger.info(f"Pool '{pool_name}': token added")
402
+ return True
403
+
404
+ async def mark_asset_clear(self, token: str) -> bool:
405
+ """Record online asset cleanup timestamp."""
406
+ info, _ = self._find_token_info(token)
407
+ if info:
408
+ info.last_asset_clear_at = int(datetime.now().timestamp() * 1000)
409
+ self._schedule_save()
410
+ return True
411
+ return False
412
+
413
+ async def set_token_invalid(self, token_str: str, reason: str = "", save: bool = True) -> bool:
414
+ """Mark a token as expired/invalid."""
415
+ token, raw_token = self._find_token_info(token_str)
416
+ if not token:
417
+ logger.warning(f"Token {raw_token[:10]}...: not found for invalidation")
418
+ return False
419
+
420
+ token.status = TokenStatus.EXPIRED
421
+ token.fail_count = max(token.fail_count, FAIL_THRESHOLD)
422
+ token.last_fail_at = int(datetime.now().timestamp() * 1000)
423
+ if reason:
424
+ token.last_fail_reason = str(reason)[:500]
425
+
426
+ if save:
427
+ await self._save()
428
+ return True
429
+
430
+ async def mark_token_account_settings_success(self, token_str: str, save: bool = True) -> bool:
431
+ """Reset failure state after account-settings flow succeeded."""
432
+ token, raw_token = self._find_token_info(token_str)
433
+ if not token:
434
+ logger.warning(f"Token {raw_token[:10]}...: not found for account-settings success")
435
+ return False
436
+
437
+ token.fail_count = 0
438
+ token.last_fail_at = None
439
+ token.last_fail_reason = None
440
+ token.last_sync_at = int(datetime.now().timestamp() * 1000)
441
+ token.status = TokenStatus.COOLING if token.quota == 0 else TokenStatus.ACTIVE
442
+
443
+ if save:
444
+ await self._save()
445
+ return True
446
+
447
+ async def commit(self):
448
+ """Persist current in-memory token state."""
449
+ await self._save()
450
+
451
+ async def remove(self, token: str) -> bool:
452
+ """
453
+ 删除 Token
454
+
455
+ Args:
456
+ token: Token 字符串
457
+
458
+ Returns:
459
+ 是否成功
460
+ """
461
+ for pool_name, pool in self.pools.items():
462
+ if pool.remove(token):
463
+ await self._save()
464
+ logger.info(f"Pool '{pool_name}': token removed")
465
+ return True
466
+
467
+ logger.warning(f"Token not found for removal")
468
+ return False
469
+
470
+ async def reset_all(self):
471
+ """重置所有 Token 配额"""
472
+ count = 0
473
+ for pool in self.pools.values():
474
+ for token in pool:
475
+ token.reset()
476
+ count += 1
477
+
478
+ await self._save()
479
+ logger.info(f"Reset all: {count} tokens updated")
480
+
481
+ async def reset_token(self, token_str: str) -> bool:
482
+ """
483
+ 重置单个 Token
484
+
485
+ Args:
486
+ token_str: Token 字符串
487
+
488
+ Returns:
489
+ 是否成功
490
+ """
491
+ raw_token = token_str.replace("sso=", "")
492
+
493
+ for pool in self.pools.values():
494
+ token = pool.get(raw_token)
495
+ if token:
496
+ token.reset()
497
+ await self._save()
498
+ logger.info(f"Token {raw_token[:10]}...: reset completed")
499
+ return True
500
+
501
+ logger.warning(f"Token {raw_token[:10]}...: not found for reset")
502
+ return False
503
+
504
+ def get_stats(self) -> Dict[str, dict]:
505
+ """获取统计信息"""
506
+ stats = {}
507
+ for name, pool in self.pools.items():
508
+ pool_stats = pool.get_stats()
509
+ stats[name] = pool_stats.model_dump()
510
+ return stats
511
+
512
+ def get_pool_tokens(self, pool_name: str = "ssoBasic") -> List[TokenInfo]:
513
+ """
514
+ 获取指定池的所有 Token
515
+
516
+ Args:
517
+ pool_name: 池名称
518
+
519
+ Returns:
520
+ Token 列表
521
+ """
522
+ pool = self.pools.get(pool_name)
523
+ if not pool:
524
+ return []
525
+ return pool.list()
526
+
527
+ async def refresh_cooling_tokens(self) -> Dict[str, int]:
528
+ """
529
+ 批量刷新 cooling 状态的 Token 配额
530
+
531
+ Returns:
532
+ {"checked": int, "refreshed": int, "recovered": int, "expired": int}
533
+ """
534
+ from app.services.grok.usage import UsageService
535
+
536
+ # 收集需要刷新的 token
537
+ to_refresh: List[TokenInfo] = []
538
+ for pool in self.pools.values():
539
+ for token in pool:
540
+ if token.need_refresh(REFRESH_INTERVAL_HOURS):
541
+ to_refresh.append(token)
542
+
543
+ if not to_refresh:
544
+ logger.debug("Refresh check: no tokens need refresh")
545
+ return {"checked": 0, "refreshed": 0, "recovered": 0, "expired": 0}
546
+
547
+ logger.info(f"Refresh check: found {len(to_refresh)} cooling tokens to refresh")
548
+
549
+ # 批量并发刷新
550
+ semaphore = asyncio.Semaphore(REFRESH_CONCURRENCY)
551
+ usage_service = UsageService()
552
+ refreshed = 0
553
+ recovered = 0
554
+ expired = 0
555
+
556
+ async def _refresh_one(token_info: TokenInfo) -> dict:
557
+ """刷新单个 token"""
558
+ async with semaphore:
559
+ token_str = token_info.token
560
+ if token_str.startswith("sso="):
561
+ token_str = token_str[4:]
562
+
563
+ # 重试逻辑:最多 2 次重试
564
+ for retry in range(3): # 0, 1, 2
565
+ try:
566
+ result = await usage_service.get(token_str)
567
+
568
+ if result and "remainingTokens" in result:
569
+ new_quota = result["remainingTokens"]
570
+ old_quota = token_info.quota
571
+ old_status = token_info.status
572
+
573
+ token_info.update_quota(new_quota)
574
+ token_info.mark_synced()
575
+
576
+ logger.info(
577
+ f"Token {token_info.token[:10]}...: refreshed "
578
+ f"{old_quota} -> {new_quota}, status: {old_status} -> {token_info.status}"
579
+ )
580
+
581
+ return {
582
+ "recovered": new_quota > 0 and old_quota == 0,
583
+ "expired": False
584
+ }
585
+
586
+ token_info.mark_synced()
587
+ return {"recovered": False, "expired": False}
588
+
589
+ except Exception as e:
590
+ error_str = str(e)
591
+
592
+ # 检查是否为 401 错误
593
+ if "401" in error_str or "Unauthorized" in error_str:
594
+ if retry < 2:
595
+ logger.warning(
596
+ f"Token {token_info.token[:10]}...: 401 error, "
597
+ f"retry {retry + 1}/2..."
598
+ )
599
+ await asyncio.sleep(0.5)
600
+ continue
601
+ else:
602
+ # 重试 2 次后仍然 401,标记为 expired
603
+ logger.error(
604
+ f"Token {token_info.token[:10]}...: 401 after 2 retries, "
605
+ f"marking as expired"
606
+ )
607
+ token_info.status = TokenStatus.EXPIRED
608
+ token_info.mark_synced()
609
+ return {"recovered": False, "expired": True}
610
+ else:
611
+ logger.warning(
612
+ f"Token {token_info.token[:10]}...: refresh failed ({e})"
613
+ )
614
+ token_info.mark_synced()
615
+ return {"recovered": False, "expired": False}
616
+
617
+ token_info.mark_synced()
618
+ return {"recovered": False, "expired": False}
619
+
620
+ # 批量处理
621
+ for i in range(0, len(to_refresh), REFRESH_BATCH_SIZE):
622
+ batch = to_refresh[i:i + REFRESH_BATCH_SIZE]
623
+ results = await asyncio.gather(*[_refresh_one(t) for t in batch])
624
+ refreshed += len(batch)
625
+ recovered += sum(r["recovered"] for r in results)
626
+ expired += sum(r["expired"] for r in results)
627
+
628
+ # 批次间延迟
629
+ if i + REFRESH_BATCH_SIZE < len(to_refresh):
630
+ await asyncio.sleep(1)
631
+
632
+ await self._save()
633
+
634
+ logger.info(
635
+ f"Refresh completed: "
636
+ f"checked={len(to_refresh)}, refreshed={refreshed}, "
637
+ f"recovered={recovered}, expired={expired}"
638
+ )
639
+
640
+ return {
641
+ "checked": len(to_refresh),
642
+ "refreshed": refreshed,
643
+ "recovered": recovered,
644
+ "expired": expired
645
+ }
646
+
647
+
648
+ # 便捷函数
649
+ async def get_token_manager() -> TokenManager:
650
+ """获取 TokenManager 单例"""
651
+ return await TokenManager.get_instance()
652
+
653
+
654
+ __all__ = ["TokenManager", "get_token_manager"]
app/services/token/models.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Token 数据模型
3
+
4
+ 额度规则:
5
+ - 新号默认 80 配额
6
+ - 重置后恢复 80
7
+ - lowEffort 扣 1,highEffort 扣 4
8
+ """
9
+
10
+ from enum import Enum
11
+ from typing import Optional, List
12
+ from pydantic import BaseModel, Field
13
+ from datetime import datetime
14
+
15
+
16
+ # 默认配额
17
+ DEFAULT_QUOTA = 80
18
+
19
+ # 失败阈值
20
+ FAIL_THRESHOLD = 5
21
+
22
+
23
+ class TokenStatus(str, Enum):
24
+ """Token 状态"""
25
+ ACTIVE = "active"
26
+ DISABLED = "disabled"
27
+ EXPIRED = "expired"
28
+ COOLING = "cooling"
29
+
30
+
31
+ class EffortType(str, Enum):
32
+ """请求消耗类型"""
33
+ LOW = "low" # 扣 1
34
+ HIGH = "high" # 扣 4
35
+
36
+
37
+ EFFORT_COST = {
38
+ EffortType.LOW: 1,
39
+ EffortType.HIGH: 4,
40
+ }
41
+
42
+
43
+ class TokenInfo(BaseModel):
44
+ """Token 信息"""
45
+
46
+ token: str
47
+ status: TokenStatus = TokenStatus.ACTIVE
48
+ quota: int = DEFAULT_QUOTA
49
+ heavy_quota: int = -1
50
+
51
+ # 统计
52
+ created_at: int = Field(default_factory=lambda: int(datetime.now().timestamp() * 1000))
53
+ last_used_at: Optional[int] = None
54
+ use_count: int = 0
55
+
56
+ # 失败追踪
57
+ fail_count: int = 0
58
+ last_fail_at: Optional[int] = None
59
+ last_fail_reason: Optional[str] = None
60
+
61
+ # 冷却管理
62
+ last_sync_at: Optional[int] = None # 上次同步时间
63
+
64
+ # 扩展
65
+ tags: List[str] = Field(default_factory=list)
66
+ note: str = ""
67
+ last_asset_clear_at: Optional[int] = None
68
+
69
+ def is_available(self) -> bool:
70
+ """检查是否可用(状态正常且配额 > 0)"""
71
+ return self.status == TokenStatus.ACTIVE and self.quota > 0
72
+
73
+ def consume(self, effort: EffortType = EffortType.LOW) -> int:
74
+ """
75
+ 消耗配额
76
+
77
+ Args:
78
+ effort: LOW 扣 1,HIGH 扣 4
79
+
80
+ Returns:
81
+ 实际扣除的配额
82
+ """
83
+ cost = EFFORT_COST[effort]
84
+ actual_cost = min(cost, self.quota)
85
+
86
+ self.last_used_at = int(datetime.now().timestamp() * 1000)
87
+ self.use_count += 1
88
+ self.quota = max(0, self.quota - cost)
89
+
90
+ # 成功消耗后清空失败计数
91
+ self.fail_count = 0
92
+ self.last_fail_reason = None
93
+
94
+ if self.quota == 0:
95
+ self.status = TokenStatus.COOLING
96
+ elif self.status in [TokenStatus.COOLING, TokenStatus.EXPIRED]:
97
+ self.status = TokenStatus.ACTIVE
98
+
99
+ return actual_cost
100
+
101
+ def update_quota(self, new_quota: int):
102
+ """
103
+ 更新配额(用于 API 同步)
104
+
105
+ Args:
106
+ new_quota: 新的配额值
107
+ """
108
+ self.quota = max(0, new_quota)
109
+
110
+ if self.quota == 0:
111
+ self.status = TokenStatus.COOLING
112
+ elif self.quota > 0 and self.status in [TokenStatus.COOLING, TokenStatus.EXPIRED]:
113
+ self.status = TokenStatus.ACTIVE
114
+
115
+ def update_heavy_quota(self, new_quota: int):
116
+ """
117
+ 更新 heavy 配额(用于 grok-4-heavy 的 rate-limits 同步)。
118
+
119
+ 注意:heavy 配额不参与 status 计算,避免误伤普通模型可用性。
120
+ """
121
+ try:
122
+ v = int(new_quota)
123
+ except Exception:
124
+ v = 0
125
+ self.heavy_quota = max(0, v)
126
+
127
+ def consume_heavy(self, effort: EffortType = EffortType.LOW) -> int:
128
+ """
129
+ 消耗 heavy 配额(本地预估)。
130
+
131
+ 当 heavy_quota 为 -1(未知)时,不扣减配额,仅记录一次使用。
132
+ """
133
+ cost = EFFORT_COST[effort]
134
+
135
+ self.last_used_at = int(datetime.now().timestamp() * 1000)
136
+ self.use_count += 1
137
+
138
+ # 成功消耗后清空失败计数
139
+ self.fail_count = 0
140
+ self.last_fail_reason = None
141
+
142
+ if self.heavy_quota < 0:
143
+ return 0
144
+
145
+ actual_cost = min(cost, self.heavy_quota)
146
+ self.heavy_quota = max(0, self.heavy_quota - actual_cost)
147
+ return actual_cost
148
+
149
+ def reset(self):
150
+ """重置配额到默认值"""
151
+ self.quota = DEFAULT_QUOTA
152
+ self.heavy_quota = -1
153
+ self.status = TokenStatus.ACTIVE
154
+ self.fail_count = 0
155
+ self.last_fail_reason = None
156
+
157
+ def record_fail(self, status_code: int = 401, reason: str = ""):
158
+ """记录失败,达到阈值后自动标记为 expired"""
159
+ # 仅 401 错误才计入失败
160
+ if status_code != 401:
161
+ return
162
+
163
+ self.fail_count += 1
164
+ self.last_fail_at = int(datetime.now().timestamp() * 1000)
165
+ self.last_fail_reason = reason
166
+
167
+ if self.fail_count >= FAIL_THRESHOLD:
168
+ self.status = TokenStatus.EXPIRED
169
+
170
+ def record_success(self, is_usage: bool = True):
171
+ """记录成功,清空失败计数并根据配额更新状态"""
172
+ self.fail_count = 0
173
+ self.last_fail_at = None
174
+ self.last_fail_reason = None
175
+
176
+ if is_usage:
177
+ self.use_count += 1
178
+ self.last_used_at = int(datetime.now().timestamp() * 1000)
179
+
180
+ if self.quota == 0:
181
+ self.status = TokenStatus.COOLING
182
+ else:
183
+ self.status = TokenStatus.ACTIVE
184
+
185
+ def need_refresh(self, interval_hours: int = 8) -> bool:
186
+ """检查是否需要刷新配额"""
187
+ if self.status != TokenStatus.COOLING:
188
+ return False
189
+
190
+ if self.last_sync_at is None:
191
+ return True
192
+
193
+ now = int(datetime.now().timestamp() * 1000)
194
+ interval_ms = interval_hours * 3600 * 1000
195
+ return (now - self.last_sync_at) >= interval_ms
196
+
197
+ def mark_synced(self):
198
+ """标记已同步"""
199
+ self.last_sync_at = int(datetime.now().timestamp() * 1000)
200
+
201
+
202
+ class TokenPoolStats(BaseModel):
203
+ """Token 池统计"""
204
+ total: int = 0
205
+ active: int = 0
206
+ disabled: int = 0
207
+ expired: int = 0
208
+ cooling: int = 0
209
+ total_quota: int = 0
210
+ avg_quota: float = 0.0
211
+
212
+
213
+ __all__ = [
214
+ "TokenStatus",
215
+ "TokenInfo",
216
+ "TokenPoolStats",
217
+ "EffortType",
218
+ "EFFORT_COST",
219
+ "DEFAULT_QUOTA",
220
+ "FAIL_THRESHOLD",
221
+ ]
app/services/token/pool.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Token 池管理"""
2
+
3
+ import random
4
+ from typing import Dict, List, Optional, Iterator
5
+
6
+ from app.services.token.models import TokenInfo, TokenStatus, TokenPoolStats
7
+
8
+
9
+ class TokenPool:
10
+ """Token 池(管理一组 Token)"""
11
+
12
+ def __init__(self, name: str):
13
+ self.name = name
14
+ self._tokens: Dict[str, TokenInfo] = {}
15
+
16
+ def add(self, token: TokenInfo):
17
+ """添加 Token"""
18
+ self._tokens[token.token] = token
19
+
20
+ def remove(self, token_str: str) -> bool:
21
+ """删除 Token"""
22
+ if token_str in self._tokens:
23
+ del self._tokens[token_str]
24
+ return True
25
+ return False
26
+
27
+ def get(self, token_str: str) -> Optional[TokenInfo]:
28
+ """获取 Token"""
29
+ return self._tokens.get(token_str)
30
+
31
+ def select(self, bucket: str = "normal") -> Optional[TokenInfo]:
32
+ """
33
+ 选择一个可用 Token
34
+ 策略:
35
+ 1. 选择 active 状态且有配额的 token
36
+ 2. 优先选择剩余额度最多的
37
+ 3. 如果额度相同,随机选择(避免并发冲突)
38
+ """
39
+ # 选择 token
40
+ if bucket == "heavy":
41
+ available = [
42
+ t
43
+ for t in self._tokens.values()
44
+ if t.status in (TokenStatus.ACTIVE, TokenStatus.COOLING) and t.heavy_quota != 0
45
+ ]
46
+
47
+ if not available:
48
+ return None
49
+
50
+ unknown = [t for t in available if t.heavy_quota < 0]
51
+ if unknown:
52
+ return random.choice(unknown)
53
+
54
+ max_quota = max(t.heavy_quota for t in available)
55
+ candidates = [t for t in available if t.heavy_quota == max_quota]
56
+ return random.choice(candidates)
57
+
58
+ available = [
59
+ t for t in self._tokens.values()
60
+ if t.status == TokenStatus.ACTIVE and t.quota > 0
61
+ ]
62
+
63
+ if not available:
64
+ return None
65
+
66
+ # 找到最大额度
67
+ max_quota = max(t.quota for t in available)
68
+
69
+ # 筛选最大额度
70
+ candidates = [t for t in available if t.quota == max_quota]
71
+
72
+ # 随机选择
73
+ return random.choice(candidates)
74
+
75
+ def count(self) -> int:
76
+ """Token 数量"""
77
+ return len(self._tokens)
78
+
79
+ def list(self) -> List[TokenInfo]:
80
+ """获取所有 Token"""
81
+ return list(self._tokens.values())
82
+
83
+ def get_stats(self) -> TokenPoolStats:
84
+ """获取池统计信息"""
85
+ stats = TokenPoolStats(total=len(self._tokens))
86
+
87
+ for token in self._tokens.values():
88
+ stats.total_quota += token.quota
89
+
90
+ if token.status == TokenStatus.ACTIVE:
91
+ stats.active += 1
92
+ elif token.status == TokenStatus.DISABLED:
93
+ stats.disabled += 1
94
+ elif token.status == TokenStatus.EXPIRED:
95
+ stats.expired += 1
96
+ elif token.status == TokenStatus.COOLING:
97
+ stats.cooling += 1
98
+
99
+ if stats.total > 0:
100
+ stats.avg_quota = stats.total_quota / stats.total
101
+
102
+ return stats
103
+
104
+ def _rebuild_index(self):
105
+ """重建索引(预留接口,用于加载时调用)"""
106
+ pass
107
+
108
+ def __iter__(self) -> Iterator[TokenInfo]:
109
+ return iter(self._tokens.values())
110
+
111
+
112
+ __all__ = ["TokenPool"]
app/services/token/scheduler.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Token 刷新调度器"""
2
+
3
+ import asyncio
4
+ from typing import Optional
5
+
6
+ from app.core.logger import logger
7
+ from app.core.storage import get_storage, StorageError, RedisStorage
8
+ from app.services.token.manager import get_token_manager
9
+
10
+
11
+ class TokenRefreshScheduler:
12
+ """Token 自动刷新调度器"""
13
+
14
+ def __init__(self, interval_hours: int = 8):
15
+ self.interval_hours = interval_hours
16
+ self.interval_seconds = interval_hours * 3600
17
+ self._task: Optional[asyncio.Task] = None
18
+ self._running = False
19
+
20
+ async def _refresh_loop(self):
21
+ """刷新循环"""
22
+ logger.info(f"Scheduler: started (interval: {self.interval_hours}h)")
23
+
24
+ while self._running:
25
+ try:
26
+ await asyncio.sleep(self.interval_seconds)
27
+ storage = get_storage()
28
+ lock_acquired = False
29
+ lock = None
30
+
31
+ if isinstance(storage, RedisStorage):
32
+ # Redis: non-blocking lock to avoid multi-worker duplication
33
+ lock_key = "grok2api:lock:token_refresh"
34
+ lock = storage.redis.lock(lock_key, timeout=self.interval_seconds + 60, blocking_timeout=0)
35
+ lock_acquired = await lock.acquire(blocking=False)
36
+ else:
37
+ try:
38
+ async with storage.acquire_lock("token_refresh", timeout=0):
39
+ lock_acquired = True
40
+ except StorageError:
41
+ lock_acquired = False
42
+
43
+ if not lock_acquired:
44
+ logger.info("Scheduler: skipped (lock not acquired)")
45
+ continue
46
+
47
+ try:
48
+ logger.info("Scheduler: starting token refresh...")
49
+ manager = await get_token_manager()
50
+ result = await manager.refresh_cooling_tokens()
51
+
52
+ logger.info(
53
+ f"Scheduler: refresh completed - "
54
+ f"checked={result['checked']}, "
55
+ f"refreshed={result['refreshed']}, "
56
+ f"recovered={result['recovered']}, "
57
+ f"expired={result['expired']}"
58
+ )
59
+ finally:
60
+ if lock is not None and lock_acquired:
61
+ try:
62
+ await lock.release()
63
+ except Exception:
64
+ pass
65
+
66
+ except asyncio.CancelledError:
67
+ break
68
+ except Exception as e:
69
+ logger.error(f"Scheduler: refresh error - {e}")
70
+
71
+ def start(self):
72
+ """启动调度器"""
73
+ if self._running:
74
+ logger.warning("Scheduler: already running")
75
+ return
76
+
77
+ self._running = True
78
+ self._task = asyncio.create_task(self._refresh_loop())
79
+ logger.info("Scheduler: enabled")
80
+
81
+ def stop(self):
82
+ """停止调度器"""
83
+ if not self._running:
84
+ return
85
+
86
+ self._running = False
87
+ if self._task:
88
+ self._task.cancel()
89
+ logger.info("Scheduler: stopped")
90
+
91
+
92
+ # 全局单例
93
+ _scheduler: Optional[TokenRefreshScheduler] = None
94
+
95
+
96
+ def get_scheduler(interval_hours: int = 8) -> TokenRefreshScheduler:
97
+ """获取调度器单例"""
98
+ global _scheduler
99
+ if _scheduler is None:
100
+ _scheduler = TokenRefreshScheduler(interval_hours)
101
+ return _scheduler
102
+
103
+
104
+ __all__ = ["TokenRefreshScheduler", "get_scheduler"]
app/services/token/service.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Token 服务外观(Facade)"""
2
+
3
+ from typing import List, Optional, Dict
4
+
5
+ from app.services.token.manager import get_token_manager
6
+ from app.services.token.models import TokenInfo, EffortType
7
+
8
+
9
+ class TokenService:
10
+ """
11
+ Token 服务外观
12
+
13
+ 提供简化的 API,隐藏内部实现细节
14
+ """
15
+
16
+ @staticmethod
17
+ async def get_token(pool_name: str = "ssoBasic") -> Optional[str]:
18
+ """
19
+ 获取可用 Token
20
+
21
+ Args:
22
+ pool_name: Token 池名称
23
+
24
+ Returns:
25
+ Token 字符串(不含 sso= 前缀)或 None
26
+ """
27
+ manager = await get_token_manager()
28
+ return manager.get_token(pool_name)
29
+
30
+ @staticmethod
31
+ async def consume(token: str, effort: EffortType = EffortType.LOW) -> bool:
32
+ """
33
+ 消耗 Token 配额(本地预估)
34
+
35
+ Args:
36
+ token: Token 字符串
37
+ effort: 消耗力度
38
+
39
+ Returns:
40
+ 是否成功
41
+ """
42
+ manager = await get_token_manager()
43
+ return await manager.consume(token, effort)
44
+
45
+ @staticmethod
46
+ async def sync_usage(
47
+ token: str,
48
+ model: str,
49
+ effort: EffortType = EffortType.LOW
50
+ ) -> bool:
51
+ """
52
+ 同步 Token 使用量(优先 API,降级本地)
53
+
54
+ Args:
55
+ token: Token 字符串
56
+ model: 模型名称
57
+ effort: 降级时的消耗力度
58
+
59
+ Returns:
60
+ 是否成功
61
+ """
62
+ manager = await get_token_manager()
63
+ return await manager.sync_usage(token, model, effort)
64
+
65
+ @staticmethod
66
+ async def record_fail(token: str, status_code: int = 401, reason: str = "") -> bool:
67
+ """
68
+ 记录 Token 失败
69
+
70
+ Args:
71
+ token: Token 字符串
72
+ status_code: HTTP 状态码
73
+ reason: 失败原因
74
+
75
+ Returns:
76
+ 是否成功
77
+ """
78
+ manager = await get_token_manager()
79
+ return await manager.record_fail(token, status_code, reason)
80
+
81
+ @staticmethod
82
+ async def add_token(token: str, pool_name: str = "ssoBasic") -> bool:
83
+ """
84
+ 添加 Token
85
+
86
+ Args:
87
+ token: Token 字符串
88
+ pool: Token 池名称
89
+
90
+ Returns:
91
+ 是否成功
92
+ """
93
+ manager = await get_token_manager()
94
+ return await manager.add(token, pool_name)
95
+
96
+ @staticmethod
97
+ async def remove_token(token: str) -> bool:
98
+ """
99
+ 删除 Token
100
+
101
+ Args:
102
+ token: Token 字符串
103
+
104
+ Returns:
105
+ 是否成功
106
+ """
107
+ manager = await get_token_manager()
108
+ return await manager.remove(token)
109
+
110
+ @staticmethod
111
+ async def reset_token(token: str) -> bool:
112
+ """
113
+ 重置单个 Token
114
+
115
+ Args:
116
+ token: Token 字符串
117
+
118
+ Returns:
119
+ 是否成功
120
+ """
121
+ manager = await get_token_manager()
122
+ return await manager.reset_token(token)
123
+
124
+ @staticmethod
125
+ async def reset_all():
126
+ """重置所有 Token"""
127
+ manager = await get_token_manager()
128
+ await manager.reset_all()
129
+
130
+ @staticmethod
131
+ async def get_stats() -> Dict[str, dict]:
132
+ """
133
+ 获取统计信息
134
+
135
+ Returns:
136
+ 各池的统计信息
137
+ """
138
+ manager = await get_token_manager()
139
+ return manager.get_stats()
140
+
141
+ @staticmethod
142
+ async def list_tokens(pool_name: str = "ssoBasic") -> List[TokenInfo]:
143
+ """
144
+ 获取指定池的所有 Token
145
+
146
+ Args:
147
+ pool_name: Token 池名称
148
+
149
+ Returns:
150
+ Token 列表
151
+ """
152
+ manager = await get_token_manager()
153
+ return manager.get_pool_tokens(pool_name)
154
+
155
+
156
+ __all__ = ["TokenService"]
app/static/.assetsignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ _worker.js
2
+
app/static/_worker.js ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ import worker from "../../src/index.ts";
2
+
3
+ export default worker;
4
+