meyosaj406 committed on
Commit
508f22f
·
verified ·
1 Parent(s): 87c20b0

Update app/providers/notion_provider.py

Browse files
Files changed (1) hide show
  1. app/providers/notion_provider.py +169 -333
app/providers/notion_provider.py CHANGED
@@ -4,13 +4,10 @@ import time
4
  import logging
5
  import uuid
6
  import re
7
- import random
8
  from typing import Dict, Any, AsyncGenerator, List, Optional, Tuple
9
  from datetime import datetime
10
 
11
- import requests
12
- import cloudscraper
13
-
14
  from fastapi import HTTPException
15
  from fastapi.responses import StreamingResponse, JSONResponse
16
  from fastapi.concurrency import run_in_threadpool
@@ -22,172 +19,30 @@ from app.utils.sse_utils import create_sse_data, create_chat_completion_chunk, D
22
  # 设置日志记录器
23
  logger = logging.getLogger(__name__)
24
 
25
- # --- 会话/重试常量 ---
26
- SCRAPER_LIFETIME_SEC = 45 * 60 # Scraper 最长存活 45 分钟(超时后自动重建)
27
- HTTP_MAX_RETRIES = 3 # 单个 HTTP 请求最大重试次数
28
- RETRYABLE_STATUS = {401, 403, 429, 502, 503, 504}
29
- BASE_BACKOFF_SEC = 1.0 # 指数退避初始秒
30
- BACKOFF_JITTER_SEC = 0.25 # 退避抖动
31
- API_REQUEST_TIMEOUT = getattr(settings, "API_REQUEST_TIMEOUT", 60)
32
-
33
-
34
  class NotionAIProvider(BaseProvider):
35
  def __init__(self):
36
- # 仅校验三件套;**不读取 NOTION_THREAD_ID**
37
- if not all([settings.NOTION_COOKIE, settings.NOTION_SPACE_ID, settings.NOTION_USER_ID]):
38
- raise ValueError("配置错误: NOTION_COOKIE, NOTION_SPACE_ID 和 NOTION_USER_ID 必须在 .env 中全部设置。")
39
-
40
  self.scraper = cloudscraper.create_scraper()
41
- self._scraper_born_at = time.time()
42
-
43
  self.api_endpoints = {
44
  "runInference": "https://www.notion.so/api/v3/runInferenceTranscript",
45
- "saveTransactionsFanout": "https://www.notion.so/api/v3/saveTransactionsFanout",
46
- "saveTransactions": "https://www.notion.so/api/v3/saveTransactions",
47
  }
 
 
 
48
 
49
- # 预热可失败,不阻塞启动;请求阶段还有兜底重试
50
- try:
51
- self._warmup_session()
52
- except Exception as e:
53
- logger.warning("会话预热失败但不影响启动:%s", e)
54
 
55
- # ----------------------------------------------------------------------
56
- # Scraper 生命周期与预热
57
- # ----------------------------------------------------------------------
58
- def _refresh_scraper(self, reason: str = "") -> None:
59
- """重建 cloudscraper(应对 CF 验证失败/会话过期)。"""
60
  try:
61
- self.scraper.close()
62
- except Exception:
63
- pass
64
- self.scraper = cloudscraper.create_scraper()
65
- self._scraper_born_at = time.time()
66
- logger.info("重建 cloudscraper 会话。原因:%s", reason or "未指定")
67
-
68
- def _get_scraper(self):
69
- if time.time() - self._scraper_born_at > SCRAPER_LIFETIME_SEC:
70
- self._refresh_scraper("生命周期已到")
71
- return self.scraper
72
-
73
- def _normalize_cookie(self, raw: str) -> str:
74
- c = (raw or "").strip()
75
- if not c:
76
- return ""
77
- # 若仅填了 token_v2 的值,这里补齐键名
78
- if "token_v2=" not in c:
79
- c = f"token_v2={c}"
80
- return c
81
-
82
- def _prepare_headers(self) -> Dict[str, str]:
83
- cookie_header = self._normalize_cookie(settings.NOTION_COOKIE)
84
- return {
85
- "Content-Type": "application/json",
86
- "Accept": "application/x-ndjson",
87
- "Cookie": cookie_header,
88
- "x-notion-space-id": settings.NOTION_SPACE_ID,
89
- "x-notion-active-user-header": settings.NOTION_USER_ID,
90
- "x-notion-client-version": getattr(settings, "NOTION_CLIENT_VERSION", "23.13.20251011.2037"),
91
- "notion-audit-log-platform": "web",
92
- "Origin": "https://www.notion.so",
93
- "Referer": "https://www.notion.so/",
94
- "User-Agent": (
95
- "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
96
- "AppleWebKit(537.36) (KHTML, like Gecko) "
97
- "Chrome/125.0.0.0 Safari/537.36"
98
- ),
99
- }
100
-
101
- def _warmup_session(self) -> None:
102
- """轻量预热;遭遇 429 采用退避,不抛异常。"""
103
- s = self._get_scraper()
104
- headers = self._prepare_headers().copy()
105
- headers.pop("Accept", None) # 预热不要 ndjson
106
- url = "https://www.notion.so/"
107
- for attempt in range(2):
108
- try:
109
- r = s.get(url, headers=headers, timeout=20)
110
- if r.status_code == 429:
111
- retry_after = r.headers.get("Retry-After")
112
- if retry_after and retry_after.isdigit():
113
- sleep_sec = int(retry_after)
114
- else:
115
- sleep_sec = BASE_BACKOFF_SEC * (2 ** attempt) + random.random() * BACKOFF_JITTER_SEC
116
- logger.warning("预热命中 429,退避 %.2fs 后再试(第 %d 次)", sleep_sec, attempt + 1)
117
- time.sleep(sleep_sec)
118
- continue
119
- r.raise_for_status()
120
- logger.info("会话预热成功。")
121
- return
122
- except Exception as e:
123
- logger.warning("会话预热失败(第 %d 次):%s", attempt + 1, e)
124
- # 不抛异常
125
-
126
- # ----------------------------------------------------------------------
127
- # HTTP POST(自动重试/退避/重建)
128
- # ----------------------------------------------------------------------
129
- def _post_with_retry(
130
- self,
131
- url: str,
132
- *,
133
- headers: Dict[str, str],
134
- payload: Dict[str, Any],
135
- stream: bool = False,
136
- max_retries: int = HTTP_MAX_RETRIES,
137
- ) -> requests.Response:
138
- for attempt in range(max_retries):
139
- try:
140
- s = self._get_scraper()
141
- resp = s.post(url, headers=headers, json=payload, stream=stream, timeout=API_REQUEST_TIMEOUT)
142
-
143
- if resp.status_code in RETRYABLE_STATUS:
144
- # 429:遵循 Retry-After 或指数退避
145
- if resp.status_code == 429:
146
- retry_after = resp.headers.get("Retry-After")
147
- if retry_after and retry_after.isdigit():
148
- sleep_sec = int(retry_after)
149
- else:
150
- sleep_sec = BASE_BACKOFF_SEC * (2 ** attempt) + random.random() * BACKOFF_JITTER_SEC
151
- logger.warning("POST %s 命中 429,退避 %.2fs 后重试(第 %d/%d 次)",
152
- url, sleep_sec, attempt + 1, max_retries)
153
- time.sleep(sleep_sec)
154
- continue
155
-
156
- # 401/403:重建会话 + 预热
157
- if resp.status_code in (401, 403):
158
- logger.warning("POST %s 返回 %s,重建会话并预热(第 %d/%d 次)",
159
- url, resp.status_code, attempt + 1, max_retries)
160
- self._refresh_scraper(f"HTTP {resp.status_code}")
161
- self._warmup_session()
162
- continue
163
-
164
- # 5xx:指数退避
165
- if resp.status_code in (502, 503, 504):
166
- sleep_sec = BASE_BACKOFF_SEC * (2 ** attempt) + random.random() * BACKOFF_JITTER_SEC
167
- logger.warning("POST %s 返回 %s,退避 %.2fs 后重试(第 %d/%d 次)",
168
- url, resp.status_code, sleep_sec, attempt + 1, max_retries)
169
- time.sleep(sleep_sec)
170
- continue
171
-
172
- resp.raise_for_status()
173
- return resp
174
-
175
- except requests.RequestException as e:
176
- sleep_sec = BASE_BACKOFF_SEC * (2 ** attempt) + random.random() * BACKOFF_JITTER_SEC
177
- logger.warning("POST %s 网络异常:%s,退避 %.2fs 后重试(第 %d/%d 次)",
178
- url, e, sleep_sec, attempt + 1, max_retries)
179
- self._refresh_scraper("网络异常后重建")
180
- time.sleep(sleep_sec)
181
-
182
- raise HTTPException(status_code=502, detail=f"调用 {url} 多次重试仍失败。")
183
-
184
- def _open_stream_with_retry(self, url: str, headers: Dict[str, str], payload: Dict[str, Any]):
185
- resp = self._post_with_retry(url, headers=headers, payload=payload, stream=True)
186
- return resp.iter_lines()
187
-
188
- # ----------------------------------------------------------------------
189
- # 动态创建线程(不依赖固定会话 ID)
190
- # ----------------------------------------------------------------------
191
  async def _create_thread(self, thread_type: str) -> str:
192
  thread_id = str(uuid.uuid4())
193
  payload = {
@@ -209,161 +64,156 @@ class NotionAIProvider(BaseProvider):
209
  }]
210
  }]
211
  }
212
- headers = self._prepare_headers()
213
-
214
- # 先尝试 Fanout,不行再回退至 saveTransactions
215
- for ep_key in ("saveTransactionsFanout", "saveTransactions"):
216
- url = self.api_endpoints[ep_key]
217
- try:
218
- logger.info("创建线程:尝试 %s", ep_key)
219
- await run_in_threadpool(lambda: self._post_with_retry(url, headers=headers, payload=payload))
220
- logger.info("对话线程创建成功, Thread ID: %s", thread_id)
221
- return thread_id
222
- except HTTPException as he:
223
- status = getattr(he, "status_code", None)
224
- if status in (404, 405) and ep_key == "saveTransactionsFanout":
225
- logger.warning("Fanout 接口不可用,回退到 saveTransactions。")
226
- continue
227
- raise
228
- except Exception as e:
229
- if ep_key == "saveTransactionsFanout":
230
- logger.warning("Fanout 创建线程失败(%s),尝试旧接口。", e)
231
- continue
232
- logger.error("创建线程失败:%s", e, exc_info=True)
233
- raise HTTPException(status_code=502, detail="无法创建新的对话线程。")
234
-
235
- raise HTTPException(status_code=502, detail="创建线程失败:所有接口均不可用。")
236
 
237
- # ----------------------------------------------------------------------
238
- # Chat Completions(流式 + 自动重试)
239
- # ----------------------------------------------------------------------
240
  async def chat_completion(self, request_data: Dict[str, Any]):
241
  stream = request_data.get("stream", True)
242
- if not stream:
243
- raise HTTPException(status_code=400, detail="此端点当前仅支持流式响应 (stream=true)。")
244
 
245
  async def stream_generator() -> AsyncGenerator[bytes, None]:
246
  request_id = f"chatcmpl-{uuid.uuid4()}"
247
- model_name = request_data.get("model", settings.DEFAULT_MODEL)
248
-
249
- # 先下发一次 role,避免重试带来重复角色片段
250
- role_chunk = create_chat_completion_chunk(request_id, model_name, role="assistant")
251
- yield create_sse_data(role_chunk)
252
-
253
- last_error: Optional[Exception] = None
254
-
255
- # 整体最多重试 2 轮(每轮都将重建线程/刷新会话)
256
- for outer_try in range(2):
257
- try:
258
- mapped_model = settings.MODEL_MAP.get(model_name, "anthropic-sonnet-alt")
259
- thread_type = "markdown-chat" if mapped_model.startswith("vertex-") else "workflow"
260
-
261
- # 每轮都新建线程,避免上下文串扰
262
- thread_id = await self._create_thread(thread_type)
263
- payload = self._prepare_payload(request_data, thread_id, mapped_model, thread_type)
264
- headers = self._prepare_headers()
265
-
266
- logger.info("请求 Notion AI(第 %d 轮):%s", outer_try + 1, self.api_endpoints["runInference"])
267
- logger.info("请求体:%s", json.dumps(payload, ensure_ascii=False, indent=2))
268
-
269
- lines_iter = await run_in_threadpool(
270
- lambda: self._open_stream_with_retry(self.api_endpoints["runInference"], headers, payload)
271
- )
272
-
273
- incremental_fragments: List[str] = []
274
- final_message: Optional[str] = None
275
-
276
- while True:
277
- line = await run_in_threadpool(lambda: next(lines_iter, None))
278
- if line is None:
279
- break
280
-
281
- parsed_results = self._parse_ndjson_line_to_texts(line)
282
- for text_type, content in parsed_results:
283
- if text_type == 'final':
284
- final_message = content
285
- elif text_type == 'incremental':
286
- incremental_fragments.append(content)
287
-
288
- full_response = final_message if final_message else "".join(incremental_fragments)
289
-
290
- if full_response:
291
- cleaned_response = self._clean_content(full_response)
292
- logger.info("清洗后的最终响应: %s", cleaned_response)
293
- chunk = create_chat_completion_chunk(request_id, model_name, content=cleaned_response)
294
- yield create_sse_data(chunk)
295
-
296
- final_chunk = create_chat_completion_chunk(request_id, model_name, finish_reason="stop")
297
- yield create_sse_data(final_chunk)
298
- yield DONE_CHUNK
299
- return
300
-
301
- # 内容为空:重建会话并再预热,进入下一轮
302
- logger.warning("警告(第 %d 轮):流数据为空,尝试重建会话/再预热后重试。", outer_try + 1)
303
- last_error = Exception("空响应")
304
- self._refresh_scraper("空响应后重建")
305
- await run_in_threadpool(self._warmup_session)
306
- continue
307
-
308
- except Exception as e:
309
- last_error = e
310
- logger.error("处理 Notion AI 流时发生错误(第 %d 轮):%s", outer_try + 1, e, exc_info=True)
311
- if outer_try == 0:
312
- self._refresh_scraper("处理异常后重建")
313
- await run_in_threadpool(self._warmup_session)
314
- continue
315
- break
316
-
317
- # 重试用尽
318
- error_message = f"处理 Notion AI 流失败(自动重试已用尽):{str(last_error)}"
319
- logger.error(error_message)
320
- error_chunk = {"error": {"message": error_message, "type": "internal_server_error"}}
321
- yield create_sse_data(error_chunk)
322
- yield DONE_CHUNK
323
-
324
- return StreamingResponse(stream_generator(), media_type="text/event-stream")
325
-
326
- # ----------------------------------------------------------------------
327
- # Header & Payload
328
- # ----------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
  def _normalize_block_id(self, block_id: str) -> str:
330
- if not block_id:
331
- return block_id
332
  b = block_id.replace("-", "").strip()
333
  if len(b) == 32 and re.fullmatch(r"[0-9a-fA-F]{32}", b):
334
  return f"{b[0:8]}-{b[8:12]}-{b[12:16]}-{b[16:20]}-{b[20:]}"
335
  return block_id
336
 
337
  def _prepare_payload(self, request_data: Dict[str, Any], thread_id: str, mapped_model: str, thread_type: str) -> Dict[str, Any]:
338
- req_block_id = request_data.get("notion_block_id") or getattr(settings, "NOTION_BLOCK_ID", None)
339
  normalized_block_id = self._normalize_block_id(req_block_id) if req_block_id else None
340
 
341
  context_value: Dict[str, Any] = {
342
  "timezone": "Asia/Shanghai",
343
  "spaceId": settings.NOTION_SPACE_ID,
344
  "userId": settings.NOTION_USER_ID,
345
- "userEmail": getattr(settings, "NOTION_USER_EMAIL", None),
346
  "currentDatetime": datetime.now().astimezone().isoformat(),
347
  }
348
  if normalized_block_id:
349
  context_value["blockId"] = normalized_block_id
350
 
351
- # Vertex / Gemini 的特殊上下文与配置(不再硬编码 spaceViewId,可从环境变量传入)
 
352
  if mapped_model.startswith("vertex-"):
353
- logger.info(f"检测到 Gemini 模型 ({mapped_model}),应用特定的 config/context。")
354
- gemini_context = {
355
- "userName": f"{getattr(settings, 'NOTION_USER_NAME', '')}",
356
- "spaceName": f"{getattr(settings, 'NOTION_USER_NAME', '')}的 Notion",
 
357
  "surface": "ai_module"
358
- }
359
- space_view_id = getattr(settings, "NOTION_SPACE_VIEW_ID", None)
360
- if space_view_id:
361
- gemini_context["spaceViewId"] = space_view_id
362
- logger.info(f"使用配置中的 spaceViewId: {space_view_id}")
363
- else:
364
- logger.warning("未配置 NOTION_SPACE_VIEW_ID,Gemini 模型可能无法正常工作。")
365
- context_value.update(gemini_context)
366
-
367
  config_value = {
368
  "type": thread_type,
369
  "model": mapped_model,
@@ -379,7 +229,7 @@ class NotionAIProvider(BaseProvider):
379
  }
380
  else:
381
  context_value.update({
382
- "userName": getattr(settings, "NOTION_USER_NAME", None),
383
  "surface": "workflows"
384
  })
385
  config_value = {
@@ -392,7 +242,7 @@ class NotionAIProvider(BaseProvider):
392
  {"id": str(uuid.uuid4()), "type": "config", "value": config_value},
393
  {"id": str(uuid.uuid4()), "type": "context", "value": context_value}
394
  ]
395
-
396
  for msg in request_data.get("messages", []):
397
  if msg.get("role") == "user":
398
  transcript.append({
@@ -403,11 +253,7 @@ class NotionAIProvider(BaseProvider):
403
  "createdAt": datetime.now().astimezone().isoformat()
404
  })
405
  elif msg.get("role") == "assistant":
406
- transcript.append({
407
- "id": str(uuid.uuid4()),
408
- "type": "agent-inference",
409
- "value": [{"type": "text", "content": msg.get("content")}]
410
- })
411
 
412
  payload = {
413
  "traceId": str(uuid.uuid4()),
@@ -430,20 +276,17 @@ class NotionAIProvider(BaseProvider):
430
  "annotationInferences": {},
431
  "emitInferences": False
432
  }
433
-
434
  return payload
435
 
436
- # ----------------------------------------------------------------------
437
- # 内容清洗(保持你的逻辑)
438
- # ----------------------------------------------------------------------
439
  def _clean_content(self, content: str) -> str:
440
  if not content:
441
  return ""
442
-
443
  content = re.sub(r'<lang primary="[^"]*"\s*/>\n*', '', content)
444
  content = re.sub(r'<thinking>[\s\S]*?</thinking>\s*', '', content, flags=re.IGNORECASE)
445
  content = re.sub(r'<thought>[\s\S]*?</thought>\s*', '', content, flags=re.IGNORECASE)
446
-
447
  content = re.sub(r'^.*?Chinese whatmodel I am.*?Theyspecifically.*?requested.*?me.*?to.*?reply.*?in.*?Chinese\.\s*', '', content, flags=re.IGNORECASE | re.DOTALL)
448
  content = re.sub(r'^.*?This.*?is.*?a.*?straightforward.*?question.*?about.*?my.*?identity.*?asan.*?AI.*?assistant\.\s*', '', content, flags=re.IGNORECASE | re.DOTALL)
449
  content = re.sub(r'^.*?Idon\'t.*?need.*?to.*?use.*?any.*?tools.*?for.*?this.*?-\s*it\'s.*?asimple.*?informational.*?response.*?aboutwhat.*?I.*?am\.\s*', '', content, flags=re.IGNORECASE | re.DOTALL)
@@ -452,22 +295,18 @@ class NotionAIProvider(BaseProvider):
452
  content = re.sub(r'^.*?This.*?is.*?a.*?question.*?about.*?my.*?identity.*?not requiring.*?any.*?tool.*?use.*?I.*?should.*?respond.*?directly.*?to.*?the.*?user.*?in.*?Chinese.*?as.*?requested\.\s*', '', content, flags=re.IGNORECASE | re.DOTALL)
453
  content = re.sub(r'^.*?I.*?should.*?identify.*?myself.*?as.*?Notion.*?AI.*?as.*?mentioned.*?in.*?the.*?system.*?prompt.*?\s*', '', content, flags=re.IGNORECASE | re.DOTALL)
454
  content = re.sub(r'^.*?I.*?should.*?not.*?make.*?specific.*?claims.*?about.*?the.*?underlying.*?model.*?architecture.*?since.*?that.*?information.*?is.*?not.*?provided.*?in.*?my.*?context\.\s*', '', content, flags=re.IGNORECASE | re.DOTALL)
455
-
456
  return content.strip()
457
 
458
- # ----------------------------------------------------------------------
459
- # NDJSON 解析(保持并扩展)
460
- # ----------------------------------------------------------------------
461
  def _parse_ndjson_line_to_texts(self, line: bytes) -> List[Tuple[str, str]]:
462
  results: List[Tuple[str, str]] = []
463
  try:
464
  s = line.decode("utf-8", errors="ignore").strip()
465
- if not s:
466
- return results
467
-
468
  data = json.loads(s)
469
  logger.debug(f"原始响应数据: {json.dumps(data, ensure_ascii=False)}")
470
-
471
  # 格式1: Gemini 返回的 markdown-chat 事件
472
  if data.get("type") == "markdown-chat":
473
  content = data.get("value", "")
@@ -475,52 +314,51 @@ class NotionAIProvider(BaseProvider):
475
  logger.info("从 'markdown-chat' 直接事件中提取到内容。")
476
  results.append(('final', content))
477
 
478
- # 格式2: Claude/GPT/Gemini 的 patch
479
  elif data.get("type") == "patch" and "v" in data:
480
  for operation in data.get("v", []):
481
- if not isinstance(operation, dict):
482
- continue
483
  op_type = operation.get("o")
484
  path = operation.get("p", "")
485
  value = operation.get("v")
486
-
487
- # Gemini 完整
488
  if op_type == "a" and path.endswith("/s/-") and isinstance(value, dict) and value.get("type") == "markdown-chat":
489
  content = value.get("value", "")
490
  if content:
491
  logger.info("从 'patch' (Gemini-style) 中提取到完整内容。")
492
  results.append(('final', content))
493
-
494
- # Gemini 增量
495
  elif op_type == "x" and "/s/" in path and path.endswith("/value") and isinstance(value, str):
496
  content = value
497
  if content:
498
  logger.info(f"从 'patch' (Gemini增量) 中提取到内容: {content}")
499
  results.append(('incremental', content))
500
-
501
- # Claude / GPT 增量
502
  elif op_type == "x" and "/value/" in path and isinstance(value, str):
503
  content = value
504
  if content:
505
  logger.info(f"从 'patch' (Claude/GPT增量) 中提取到内容: {content}")
506
  results.append(('incremental', content))
507
-
508
- # Claude / GPT 完整
509
  elif op_type == "a" and path.endswith("/value/-") and isinstance(value, dict) and value.get("type") == "text":
510
  content = value.get("content", "")
511
  if content:
512
  logger.info("从 'patch' (Claude/GPT-style) 中提取到完整内容。")
513
  results.append(('final', content))
514
 
515
- # 格式3: record-map
516
  elif data.get("type") == "record-map" and "recordMap" in data:
517
  record_map = data["recordMap"]
518
  if "thread_message" in record_map:
519
- for _, msg_data in record_map["thread_message"].items():
520
  value_data = msg_data.get("value", {}).get("value", {})
521
  step = value_data.get("step", {})
522
- if not step:
523
- continue
524
 
525
  content = ""
526
  step_type = step.get("type")
@@ -534,20 +372,17 @@ class NotionAIProvider(BaseProvider):
534
  if isinstance(item, dict) and item.get("type") == "text":
535
  content = item.get("content", "")
536
  break
537
-
538
  if content and isinstance(content, str):
539
  logger.info(f"从 record-map (type: {step_type}) 提取到最终内容。")
540
  results.append(('final', content))
541
- break
542
-
543
  except (json.JSONDecodeError, AttributeError) as e:
544
  logger.warning(f"解析NDJSON行失败: {e} - Line: {line.decode('utf-8', errors='ignore')}")
545
-
546
  return results
547
 
548
- # ----------------------------------------------------------------------
549
- # 模型列表
550
- # ----------------------------------------------------------------------
551
  async def get_models(self) -> JSONResponse:
552
  model_data = {
553
  "object": "list",
@@ -559,3 +394,4 @@ class NotionAIProvider(BaseProvider):
559
  return JSONResponse(content=model_data)
560
 
561
 
 
 
4
  import logging
5
  import uuid
6
  import re
7
+ import cloudscraper
8
  from typing import Dict, Any, AsyncGenerator, List, Optional, Tuple
9
  from datetime import datetime
10
 
 
 
 
11
  from fastapi import HTTPException
12
  from fastapi.responses import StreamingResponse, JSONResponse
13
  from fastapi.concurrency import run_in_threadpool
 
19
  # 设置日志记录器
20
  logger = logging.getLogger(__name__)
21
 
 
 
 
 
 
 
 
 
 
22
  class NotionAIProvider(BaseProvider):
23
  def __init__(self):
 
 
 
 
24
  self.scraper = cloudscraper.create_scraper()
 
 
25
  self.api_endpoints = {
26
  "runInference": "https://www.notion.so/api/v3/runInferenceTranscript",
27
+ "saveTransactions": "https://www.notion.so/api/v3/saveTransactionsFanout"
 
28
  }
29
+
30
+ if not all([settings.NOTION_COOKIE, settings.NOTION_SPACE_ID, settings.NOTION_USER_ID]):
31
+ raise ValueError("配置错误: NOTION_COOKIE, NOTION_SPACE_ID 和 NOTION_USER_ID 必须在 .env 文件中全部设置。")
32
 
33
+ self._warmup_session()
 
 
 
 
34
 
35
def _warmup_session(self):
    """Issue a best-effort GET against the Notion home page.

    The request reuses the standard headers minus ``Accept``. Any failure
    (network error, non-2xx status, ...) is logged with its traceback and
    swallowed, so this method never raises to its caller.
    """
    try:
        logger.info("正在进行会话预热 (Session Warm-up)...")
        # Same headers as the API calls, except that the NDJSON "Accept"
        # header is left out for the plain page request.
        warmup_headers = {
            name: value
            for name, value in self._prepare_headers().items()
            if name != "Accept"
        }
        homepage_response = self.scraper.get(
            "https://www.notion.so/",
            headers=warmup_headers,
            timeout=30,
        )
        homepage_response.raise_for_status()
        logger.info("会话预热成功。")
    except Exception as e:
        failure_message = f"会话预热失败: {e}"
        logger.error(failure_message, exc_info=True)
45
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  async def _create_thread(self, thread_type: str) -> str:
47
  thread_id = str(uuid.uuid4())
48
  payload = {
 
64
  }]
65
  }]
66
  }
67
+ try:
68
+ logger.info(f"正在创建新的对话线程 (type: {thread_type})...")
69
+ response = await run_in_threadpool(
70
+ lambda: self.scraper.post(
71
+ self.api_endpoints["saveTransactions"],
72
+ headers=self._prepare_headers(),
73
+ json=payload,
74
+ timeout=20
75
+ )
76
+ )
77
+ response.raise_for_status()
78
+ logger.info(f"对话线程创建成功, Thread ID: {thread_id}")
79
+ return thread_id
80
+ except Exception as e:
81
+ logger.error(f"创建对话线程失败: {e}", exc_info=True)
82
+ raise Exception("无法创建新的对话线程。")
 
 
 
 
 
 
 
 
83
 
 
 
 
84
  async def chat_completion(self, request_data: Dict[str, Any]):
85
  stream = request_data.get("stream", True)
 
 
86
 
87
  async def stream_generator() -> AsyncGenerator[bytes, None]:
88
  request_id = f"chatcmpl-{uuid.uuid4()}"
89
+ incremental_fragments: List[str] = []
90
+ final_message: Optional[str] = None
91
+
92
+ try:
93
+ model_name = request_data.get("model", settings.DEFAULT_MODEL)
94
+ mapped_model = settings.MODEL_MAP.get(model_name, "anthropic-sonnet-alt")
95
+
96
+ thread_type = "markdown-chat" if mapped_model.startswith("vertex-") else "workflow"
97
+
98
+ thread_id = await self._create_thread(thread_type)
99
+ payload = self._prepare_payload(request_data, thread_id, mapped_model, thread_type)
100
+ headers = self._prepare_headers()
101
+
102
+ role_chunk = create_chat_completion_chunk(request_id, model_name, role="assistant")
103
+ yield create_sse_data(role_chunk)
104
+
105
+ def sync_stream_iterator():
106
+ try:
107
+ logger.info(f"请求 Notion AI URL: {self.api_endpoints['runInference']}")
108
+ logger.info(f"请求体: {json.dumps(payload, indent=2, ensure_ascii=False)}")
109
+
110
+ response = self.scraper.post(
111
+ self.api_endpoints['runInference'], headers=headers, json=payload, stream=True,
112
+ timeout=settings.API_REQUEST_TIMEOUT
113
+ )
114
+ response.raise_for_status()
115
+ for line in response.iter_lines():
116
+ if line:
117
+ yield line
118
+ except Exception as e:
119
+ yield e
120
+
121
+ sync_gen = sync_stream_iterator()
122
+
123
+ while True:
124
+ line = await run_in_threadpool(lambda: next(sync_gen, None))
125
+ if line is None:
126
+ break
127
+ if isinstance(line, Exception):
128
+ raise line
129
+
130
+ parsed_results = self._parse_ndjson_line_to_texts(line)
131
+ for text_type, content in parsed_results:
132
+ if text_type == 'final':
133
+ final_message = content
134
+ elif text_type == 'incremental':
135
+ incremental_fragments.append(content)
136
+
137
+ full_response = ""
138
+ if final_message:
139
+ full_response = final_message
140
+ logger.info(f"成功从 record-map 或 Gemini patch/event 中提取到最终消息。")
141
+ else:
142
+ full_response = "".join(incremental_fragments)
143
+ logger.info(f"使用拼接所有增量片段的方式获得最终消息。")
144
+
145
+ if full_response:
146
+ cleaned_response = self._clean_content(full_response)
147
+ logger.info(f"清洗后的最终响应: {cleaned_response}")
148
+ chunk = create_chat_completion_chunk(request_id, model_name, content=cleaned_response)
149
+ yield create_sse_data(chunk)
150
+ else:
151
+ logger.warning("警告: Notion 返回的数据流中未提取到任何有效文本。请检查您的 .env 配置是否全部正确且凭证有效。")
152
+
153
+ final_chunk = create_chat_completion_chunk(request_id, model_name, finish_reason="stop")
154
+ yield create_sse_data(final_chunk)
155
+ yield DONE_CHUNK
156
+
157
+ except Exception as e:
158
+ error_message = f"���理 Notion AI 流时发生意外错误: {str(e)}"
159
+ logger.error(error_message, exc_info=True)
160
+ error_chunk = {"error": {"message": error_message, "type": "internal_server_error"}}
161
+ yield create_sse_data(error_chunk)
162
+ yield DONE_CHUNK
163
+
164
+ if stream:
165
+ return StreamingResponse(stream_generator(), media_type="text/event-stream")
166
+ else:
167
+ raise HTTPException(status_code=400, detail="此端点当前仅支持流式响应 (stream=true)。")
168
+
169
def _prepare_headers(self) -> Dict[str, str]:
    """Build the HTTP headers sent with every Notion request.

    The configured cookie is stripped of surrounding whitespace; when it
    contains no ``=`` it is treated as a bare token value and prefixed
    with ``token_v2=``.

    Returns:
        A new dict of header names to values.
    """
    configured_cookie = settings.NOTION_COOKIE
    cookie_header = configured_cookie.strip() if configured_cookie else ""
    if "=" not in cookie_header:
        cookie_header = f"token_v2={cookie_header}"

    header_items = [
        ("Content-Type", "application/json"),
        ("Accept", "application/x-ndjson"),
        ("Cookie", cookie_header),
        ("x-notion-space-id", settings.NOTION_SPACE_ID),
        ("x-notion-active-user-header", settings.NOTION_USER_ID),
        ("x-notion-client-version", settings.NOTION_CLIENT_VERSION),
        ("notion-audit-log-platform", "web"),
        ("Origin", "https://www.notion.so"),
        ("Referer", "https://www.notion.so/"),
        ("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/125.0.0.0 Safari/537.36"),
    ]
    return dict(header_items)
185
+
186
  def _normalize_block_id(self, block_id: str) -> str:
187
+ if not block_id: return block_id
 
188
  b = block_id.replace("-", "").strip()
189
  if len(b) == 32 and re.fullmatch(r"[0-9a-fA-F]{32}", b):
190
  return f"{b[0:8]}-{b[8:12]}-{b[12:16]}-{b[16:20]}-{b[20:]}"
191
  return block_id
192
 
193
  def _prepare_payload(self, request_data: Dict[str, Any], thread_id: str, mapped_model: str, thread_type: str) -> Dict[str, Any]:
194
+ req_block_id = request_data.get("notion_block_id") or settings.NOTION_BLOCK_ID
195
  normalized_block_id = self._normalize_block_id(req_block_id) if req_block_id else None
196
 
197
  context_value: Dict[str, Any] = {
198
  "timezone": "Asia/Shanghai",
199
  "spaceId": settings.NOTION_SPACE_ID,
200
  "userId": settings.NOTION_USER_ID,
201
+ "userEmail": settings.NOTION_USER_EMAIL,
202
  "currentDatetime": datetime.now().astimezone().isoformat(),
203
  }
204
  if normalized_block_id:
205
  context_value["blockId"] = normalized_block_id
206
 
207
+ config_value: Dict[str, Any]
208
+
209
  if mapped_model.startswith("vertex-"):
210
+ logger.info(f"检测到 Gemini 模型 ({mapped_model}),应用特定的 configcontext。")
211
+ context_value.update({
212
+ "userName": f" {settings.NOTION_USER_NAME}",
213
+ "spaceName": f"{settings.NOTION_USER_NAME}的 Notion",
214
+ "spaceViewId": "29d2ea19-5923-80f2-9f44-00a9fed7bffe",
215
  "surface": "ai_module"
216
+ })
 
 
 
 
 
 
 
 
217
  config_value = {
218
  "type": thread_type,
219
  "model": mapped_model,
 
229
  }
230
  else:
231
  context_value.update({
232
+ "userName": settings.NOTION_USER_NAME,
233
  "surface": "workflows"
234
  })
235
  config_value = {
 
242
  {"id": str(uuid.uuid4()), "type": "config", "value": config_value},
243
  {"id": str(uuid.uuid4()), "type": "context", "value": context_value}
244
  ]
245
+
246
  for msg in request_data.get("messages", []):
247
  if msg.get("role") == "user":
248
  transcript.append({
 
253
  "createdAt": datetime.now().astimezone().isoformat()
254
  })
255
  elif msg.get("role") == "assistant":
256
+ transcript.append({"id": str(uuid.uuid4()), "type": "agent-inference", "value": [{"type": "text", "content": msg.get("content")}]})
 
 
 
 
257
 
258
  payload = {
259
  "traceId": str(uuid.uuid4()),
 
276
  "annotationInferences": {},
277
  "emitInferences": False
278
  }
279
+
280
  return payload
281
 
 
 
 
282
  def _clean_content(self, content: str) -> str:
283
  if not content:
284
  return ""
285
+
286
  content = re.sub(r'<lang primary="[^"]*"\s*/>\n*', '', content)
287
  content = re.sub(r'<thinking>[\s\S]*?</thinking>\s*', '', content, flags=re.IGNORECASE)
288
  content = re.sub(r'<thought>[\s\S]*?</thought>\s*', '', content, flags=re.IGNORECASE)
289
+
290
  content = re.sub(r'^.*?Chinese whatmodel I am.*?Theyspecifically.*?requested.*?me.*?to.*?reply.*?in.*?Chinese\.\s*', '', content, flags=re.IGNORECASE | re.DOTALL)
291
  content = re.sub(r'^.*?This.*?is.*?a.*?straightforward.*?question.*?about.*?my.*?identity.*?asan.*?AI.*?assistant\.\s*', '', content, flags=re.IGNORECASE | re.DOTALL)
292
  content = re.sub(r'^.*?Idon\'t.*?need.*?to.*?use.*?any.*?tools.*?for.*?this.*?-\s*it\'s.*?asimple.*?informational.*?response.*?aboutwhat.*?I.*?am\.\s*', '', content, flags=re.IGNORECASE | re.DOTALL)
 
295
  content = re.sub(r'^.*?This.*?is.*?a.*?question.*?about.*?my.*?identity.*?not requiring.*?any.*?tool.*?use.*?I.*?should.*?respond.*?directly.*?to.*?the.*?user.*?in.*?Chinese.*?as.*?requested\.\s*', '', content, flags=re.IGNORECASE | re.DOTALL)
296
  content = re.sub(r'^.*?I.*?should.*?identify.*?myself.*?as.*?Notion.*?AI.*?as.*?mentioned.*?in.*?the.*?system.*?prompt.*?\s*', '', content, flags=re.IGNORECASE | re.DOTALL)
297
  content = re.sub(r'^.*?I.*?should.*?not.*?make.*?specific.*?claims.*?about.*?the.*?underlying.*?model.*?architecture.*?since.*?that.*?information.*?is.*?not.*?provided.*?in.*?my.*?context\.\s*', '', content, flags=re.IGNORECASE | re.DOTALL)
298
+
299
  return content.strip()
300
 
 
 
 
301
  def _parse_ndjson_line_to_texts(self, line: bytes) -> List[Tuple[str, str]]:
302
  results: List[Tuple[str, str]] = []
303
  try:
304
  s = line.decode("utf-8", errors="ignore").strip()
305
+ if not s: return results
306
+
 
307
  data = json.loads(s)
308
  logger.debug(f"原始响应数据: {json.dumps(data, ensure_ascii=False)}")
309
+
310
  # 格式1: Gemini 返回的 markdown-chat 事件
311
  if data.get("type") == "markdown-chat":
312
  content = data.get("value", "")
 
314
  logger.info("从 'markdown-chat' 直接事件中提取到内容。")
315
  results.append(('final', content))
316
 
317
+ # 格式2: ClaudeGPT 返回的补丁流,以及 Gemini 的 patch 格式
318
  elif data.get("type") == "patch" and "v" in data:
319
  for operation in data.get("v", []):
320
+ if not isinstance(operation, dict): continue
321
+
322
  op_type = operation.get("o")
323
  path = operation.get("p", "")
324
  value = operation.get("v")
325
+
326
+ # 【修改】Gemini 完整内容 patch 格式
327
  if op_type == "a" and path.endswith("/s/-") and isinstance(value, dict) and value.get("type") == "markdown-chat":
328
  content = value.get("value", "")
329
  if content:
330
  logger.info("从 'patch' (Gemini-style) 中提取到完整内容。")
331
  results.append(('final', content))
332
+
333
+ # 【修改】Gemini 增量内容 patch 格式
334
  elif op_type == "x" and "/s/" in path and path.endswith("/value") and isinstance(value, str):
335
  content = value
336
  if content:
337
  logger.info(f"从 'patch' (Gemini增量) 中提取到内容: {content}")
338
  results.append(('incremental', content))
339
+
340
+ # 【修改】Claude GPT 增量内容 patch 格式
341
  elif op_type == "x" and "/value/" in path and isinstance(value, str):
342
  content = value
343
  if content:
344
  logger.info(f"从 'patch' (Claude/GPT增量) 中提取到内容: {content}")
345
  results.append(('incremental', content))
346
+
347
+ # 【修改】Claude GPT 完整内容 patch 格式
348
  elif op_type == "a" and path.endswith("/value/-") and isinstance(value, dict) and value.get("type") == "text":
349
  content = value.get("content", "")
350
  if content:
351
  logger.info("从 'patch' (Claude/GPT-style) 中提取到完整内容。")
352
  results.append(('final', content))
353
 
354
+ # 格式3: 处理record-map类型的数据
355
  elif data.get("type") == "record-map" and "recordMap" in data:
356
  record_map = data["recordMap"]
357
  if "thread_message" in record_map:
358
+ for msg_id, msg_data in record_map["thread_message"].items():
359
  value_data = msg_data.get("value", {}).get("value", {})
360
  step = value_data.get("step", {})
361
+ if not step: continue
 
362
 
363
  content = ""
364
  step_type = step.get("type")
 
372
  if isinstance(item, dict) and item.get("type") == "text":
373
  content = item.get("content", "")
374
  break
375
+
376
  if content and isinstance(content, str):
377
  logger.info(f"从 record-map (type: {step_type}) 提取到最终内容。")
378
  results.append(('final', content))
379
+ break
380
+
381
  except (json.JSONDecodeError, AttributeError) as e:
382
  logger.warning(f"解析NDJSON行失败: {e} - Line: {line.decode('utf-8', errors='ignore')}")
383
+
384
  return results
385
 
 
 
 
386
  async def get_models(self) -> JSONResponse:
387
  model_data = {
388
  "object": "list",
 
394
  return JSONResponse(content=model_data)
395
 
396
 
397
+