meyosaj406 committed on
Commit
70b6127
·
verified ·
1 Parent(s): de9fd96

Upload 3 files

Browse files
app/providers/__init__.py ADDED
File without changes
app/providers/base_provider.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import Dict, Any, Union
3
+ from fastapi.responses import StreamingResponse, JSONResponse
4
+
5
class BaseProvider(ABC):
    """Abstract interface that every chat backend provider must implement."""

    @abstractmethod
    async def chat_completion(
        self,
        request_data: Dict[str, Any]
    ) -> Union[StreamingResponse, JSONResponse]:
        """Handle an OpenAI-style chat-completion request.

        Args:
            request_data: the parsed request body (messages, model, stream, ...).

        Returns:
            A streaming SSE response or a plain JSON response.
        """
        pass

    @abstractmethod
    async def get_models(self) -> JSONResponse:
        """Return the provider's model list in OpenAI ``/v1/models`` format."""
        pass
app/providers/notion_provider.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/providers/notion_provider.py
2
+ import json
3
+ import time
4
+ import logging
5
+ import uuid
6
+ import re
7
+ import cloudscraper
8
+ from typing import Dict, Any, AsyncGenerator, List, Optional, Tuple
9
+ from datetime import datetime
10
+
11
+ from fastapi import HTTPException
12
+ from fastapi.responses import StreamingResponse, JSONResponse
13
+ from fastapi.concurrency import run_in_threadpool
14
+
15
+ from app.core.config import settings
16
+ from app.providers.base_provider import BaseProvider
17
+ from app.utils.sse_utils import create_sse_data, create_chat_completion_chunk, DONE_CHUNK
18
+
19
+ # Module-level logger for this provider
21
+ logger = logging.getLogger(__name__)
21
+
22
class NotionAIProvider(BaseProvider):
    """Provider that relays OpenAI-style chat requests to Notion AI."""

    def __init__(self):
        """Create the HTTP session, validate credentials, and warm up.

        Raises:
            ValueError: when any of NOTION_COOKIE, NOTION_SPACE_ID or
                NOTION_USER_ID is missing from the configuration.
        """
        # cloudscraper handles Cloudflare challenges transparently.
        self.scraper = cloudscraper.create_scraper()
        self.api_endpoints = {
            "runInference": "https://www.notion.so/api/v3/runInferenceTranscript",
            "saveTransactions": "https://www.notion.so/api/v3/saveTransactionsFanout",
        }

        # All three identity settings are mandatory; fail fast otherwise.
        required = (settings.NOTION_COOKIE, settings.NOTION_SPACE_ID, settings.NOTION_USER_ID)
        if not all(required):
            raise ValueError("配置错误: NOTION_COOKIE, NOTION_SPACE_ID 和 NOTION_USER_ID 必须在 .env 文件中全部设置。")

        self._warmup_session()
34
+
35
+ def _warmup_session(self):
36
+ try:
37
+ logger.info("正在进行会话预热 (Session Warm-up)...")
38
+ headers = self._prepare_headers()
39
+ headers.pop("Accept", None)
40
+ response = self.scraper.get("https://www.notion.so/", headers=headers, timeout=30)
41
+ response.raise_for_status()
42
+ logger.info("会话预热成功。")
43
+ except Exception as e:
44
+ logger.error(f"会话预热失败: {e}", exc_info=True)
45
+
46
+ async def _create_thread(self, thread_type: str) -> str:
47
+ thread_id = str(uuid.uuid4())
48
+ payload = {
49
+ "requestId": str(uuid.uuid4()),
50
+ "transactions": [{
51
+ "id": str(uuid.uuid4()),
52
+ "spaceId": settings.NOTION_SPACE_ID,
53
+ "operations": [{
54
+ "pointer": {"table": "thread", "id": thread_id, "spaceId": settings.NOTION_SPACE_ID},
55
+ "path": [],
56
+ "command": "set",
57
+ "args": {
58
+ "id": thread_id, "version": 1, "parent_id": settings.NOTION_SPACE_ID,
59
+ "parent_table": "space", "space_id": settings.NOTION_SPACE_ID,
60
+ "created_time": int(time.time() * 1000),
61
+ "created_by_id": settings.NOTION_USER_ID, "created_by_table": "notion_user",
62
+ "messages": [], "data": {}, "alive": True, "type": thread_type
63
+ }
64
+ }]
65
+ }]
66
+ }
67
+ try:
68
+ logger.info(f"正在创建新的对话线程 (type: {thread_type})...")
69
+ response = await run_in_threadpool(
70
+ lambda: self.scraper.post(
71
+ self.api_endpoints["saveTransactions"],
72
+ headers=self._prepare_headers(),
73
+ json=payload,
74
+ timeout=20
75
+ )
76
+ )
77
+ response.raise_for_status()
78
+ logger.info(f"对话线程创建成功, Thread ID: {thread_id}")
79
+ return thread_id
80
+ except Exception as e:
81
+ logger.error(f"创建对话线程失败: {e}", exc_info=True)
82
+ raise Exception("无法创建新的对话线程。")
83
+
84
    async def chat_completion(self, request_data: Dict[str, Any]):
        """Serve an OpenAI-style chat completion by relaying to Notion AI.

        Only streaming mode is supported; ``stream=false`` gets HTTP 400.

        Args:
            request_data: parsed request body (messages, model, stream, ...).

        Returns:
            StreamingResponse emitting OpenAI-format SSE chunks.

        Raises:
            HTTPException: 400 when the caller asks for a non-stream response.
        """
        stream = request_data.get("stream", True)

        async def stream_generator() -> AsyncGenerator[bytes, None]:
            request_id = f"chatcmpl-{uuid.uuid4()}"
            # Text gathered from incremental patch events, in arrival order.
            incremental_fragments: List[str] = []
            # A complete message, if any event carries one; wins over fragments.
            final_message: Optional[str] = None

            try:
                model_name = request_data.get("model", settings.DEFAULT_MODEL)
                mapped_model = settings.MODEL_MAP.get(model_name, "anthropic-sonnet-alt")

                # Gemini ("vertex-*") models use a different thread type.
                thread_type = "markdown-chat" if mapped_model.startswith("vertex-") else "workflow"

                thread_id = await self._create_thread(thread_type)
                payload = self._prepare_payload(request_data, thread_id, mapped_model, thread_type)
                headers = self._prepare_headers()

                # First SSE chunk announces the assistant role, per OpenAI format.
                role_chunk = create_chat_completion_chunk(request_id, model_name, role="assistant")
                yield create_sse_data(role_chunk)

                def sync_stream_iterator():
                    # Runs on a worker thread; yields raw NDJSON lines.
                    # Exceptions cannot propagate across the threadpool
                    # boundary, so they are yielded as values and re-raised
                    # on the async side.
                    try:
                        logger.info(f"请求 Notion AI URL: {self.api_endpoints['runInference']}")
                        logger.info(f"请求体: {json.dumps(payload, indent=2, ensure_ascii=False)}")

                        response = self.scraper.post(
                            self.api_endpoints['runInference'], headers=headers, json=payload, stream=True,
                            timeout=settings.API_REQUEST_TIMEOUT
                        )
                        response.raise_for_status()
                        for line in response.iter_lines():
                            if line:
                                yield line
                    except Exception as e:
                        yield e

                sync_gen = sync_stream_iterator()

                # Pull the blocking generator one item at a time off-loop.
                while True:
                    line = await run_in_threadpool(lambda: next(sync_gen, None))
                    if line is None:
                        break
                    if isinstance(line, Exception):
                        raise line

                    parsed_results = self._parse_ndjson_line_to_texts(line)
                    for text_type, content in parsed_results:
                        if text_type == 'final':
                            final_message = content
                        elif text_type == 'incremental':
                            incremental_fragments.append(content)

                # Prefer a complete message; otherwise stitch the fragments.
                full_response = ""
                if final_message:
                    full_response = final_message
                    logger.info(f"成功从 record-map 或 Gemini patch/event 中提取到最终消息。")
                else:
                    full_response = "".join(incremental_fragments)
                    logger.info(f"使用拼接所有增量片段的方式获得最终消息。")

                if full_response:
                    cleaned_response = self._clean_content(full_response)
                    logger.info(f"清洗后的最终响应: {cleaned_response}")
                    chunk = create_chat_completion_chunk(request_id, model_name, content=cleaned_response)
                    yield create_sse_data(chunk)
                else:
                    logger.warning("警告: Notion 返回的数据流中未提取到任何有效文本。请检查您的 .env 配置是否全部正确且凭证有效。")

                final_chunk = create_chat_completion_chunk(request_id, model_name, finish_reason="stop")
                yield create_sse_data(final_chunk)
                yield DONE_CHUNK

            except Exception as e:
                # Errors are surfaced as an SSE error chunk, then the stream
                # is terminated normally so clients do not hang.
                error_message = f"处理 Notion AI 流时发生意外错误: {str(e)}"
                logger.error(error_message, exc_info=True)
                error_chunk = {"error": {"message": error_message, "type": "internal_server_error"}}
                yield create_sse_data(error_chunk)
                yield DONE_CHUNK

        if stream:
            return StreamingResponse(stream_generator(), media_type="text/event-stream")
        else:
            raise HTTPException(status_code=400, detail="此端点当前仅支持流式响应 (stream=true)。")
168
+
169
+ def _prepare_headers(self) -> Dict[str, str]:
170
+ cookie_source = (settings.NOTION_COOKIE or "").strip()
171
+ cookie_header = cookie_source if "=" in cookie_source else f"token_v2={cookie_source}"
172
+
173
+ return {
174
+ "Content-Type": "application/json",
175
+ "Accept": "application/x-ndjson",
176
+ "Cookie": cookie_header,
177
+ "x-notion-space-id": settings.NOTION_SPACE_ID,
178
+ "x-notion-active-user-header": settings.NOTION_USER_ID,
179
+ "x-notion-client-version": settings.NOTION_CLIENT_VERSION,
180
+ "notion-audit-log-platform": "web",
181
+ "Origin": "https://www.notion.so",
182
+ "Referer": "https://www.notion.so/",
183
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/125.0.0.0 Safari/537.36",
184
+ }
185
+
186
+ def _normalize_block_id(self, block_id: str) -> str:
187
+ if not block_id: return block_id
188
+ b = block_id.replace("-", "").strip()
189
+ if len(b) == 32 and re.fullmatch(r"[0-9a-fA-F]{32}", b):
190
+ return f"{b[0:8]}-{b[8:12]}-{b[12:16]}-{b[16:20]}-{b[20:]}"
191
+ return block_id
192
+
193
    def _prepare_payload(self, request_data: Dict[str, Any], thread_id: str, mapped_model: str, thread_type: str) -> Dict[str, Any]:
        """Build the runInferenceTranscript request body.

        Args:
            request_data: OpenAI-style request (messages, optional notion_block_id).
            thread_id: id of the thread created by ``_create_thread``.
            mapped_model: Notion-internal model name ("vertex-*" = Gemini).
            thread_type: "markdown-chat" for Gemini models, else "workflow".

        Returns:
            JSON-serializable payload for the runInference endpoint.
        """
        # A block id supplied in the request overrides the configured default.
        req_block_id = request_data.get("notion_block_id") or settings.NOTION_BLOCK_ID
        normalized_block_id = self._normalize_block_id(req_block_id) if req_block_id else None

        context_value: Dict[str, Any] = {
            "timezone": "Asia/Shanghai",
            "spaceId": settings.NOTION_SPACE_ID,
            "userId": settings.NOTION_USER_ID,
            "userEmail": settings.NOTION_USER_EMAIL,
            "currentDatetime": datetime.now().astimezone().isoformat(),
        }
        if normalized_block_id:
            context_value["blockId"] = normalized_block_id

        config_value: Dict[str, Any]

        if mapped_model.startswith("vertex-"):
            # Gemini models require a different surface and a large feature-flag
            # set; these values mirror what the Notion web client sends.
            logger.info(f"检测到 Gemini 模型 ({mapped_model}),应用特定的 config 和 context。")
            context_value.update({
                "userName": f" {settings.NOTION_USER_NAME}",
                "spaceName": f"{settings.NOTION_USER_NAME}的 Notion",
                "spaceViewId": "2008eefa-d0dc-80d5-9e67-000623befd8f",
                "surface": "ai_module"
            })
            config_value = {
                "type": thread_type,
                "model": mapped_model,
                "useWebSearch": True,
                "enableAgentAutomations": False, "enableAgentIntegrations": False,
                "enableBackgroundAgents": False, "enableCodegenIntegration": False,
                "enableCustomAgents": False, "enableExperimentalIntegrations": False,
                "enableLinkedDatabases": False, "enableAgentViewVersionHistoryTool": False,
                "searchScopes": [{"type": "everything"}], "enableDatabaseAgents": False,
                "enableAgentComments": False, "enableAgentForms": False,
                "enableAgentMakesFormulas": False, "enableUserSessionContext": False,
                "modelFromUser": True, "isCustomAgent": False
            }
        else:
            context_value.update({
                "userName": settings.NOTION_USER_NAME,
                "surface": "workflows"
            })
            config_value = {
                "type": thread_type,
                "model": mapped_model,
                "useWebSearch": True,
            }

        # The transcript begins with config + context entries, followed by the
        # conversation history translated to Notion's transcript item types.
        transcript = [
            {"id": str(uuid.uuid4()), "type": "config", "value": config_value},
            {"id": str(uuid.uuid4()), "type": "context", "value": context_value}
        ]

        for msg in request_data.get("messages", []):
            if msg.get("role") == "user":
                transcript.append({
                    "id": str(uuid.uuid4()),
                    "type": "user",
                    # Notion expects user text as a nested list-of-lists.
                    "value": [[msg.get("content")]],
                    "userId": settings.NOTION_USER_ID,
                    "createdAt": datetime.now().astimezone().isoformat()
                })
            elif msg.get("role") == "assistant":
                transcript.append({"id": str(uuid.uuid4()), "type": "agent-inference", "value": [{"type": "text", "content": msg.get("content")}]})

        payload = {
            "traceId": str(uuid.uuid4()),
            "spaceId": settings.NOTION_SPACE_ID,
            "transcript": transcript,
            "threadId": thread_id,
            "createThread": False,
            "isPartialTranscript": True,
            "asPatchResponse": True,
            "generateTitle": True,
            "saveAllThreadOperations": True,
            "threadType": thread_type
        }

        if mapped_model.startswith("vertex-"):
            # Gemini requests additionally need debugOverrides to stream patches.
            logger.info("为 Gemini 请求添加 debugOverrides。")
            payload["debugOverrides"] = {
                "emitAgentSearchExtractedResults": True,
                "cachedInferences": {},
                "annotationInferences": {},
                "emitInferences": False
            }

        return payload
281
+
282
+ def _clean_content(self, content: str) -> str:
283
+ if not content:
284
+ return ""
285
+
286
+ content = re.sub(r'<lang primary="[^"]*"\s*/>\n*', '', content)
287
+ content = re.sub(r'<thinking>[\s\S]*?</thinking>\s*', '', content, flags=re.IGNORECASE)
288
+ content = re.sub(r'<thought>[\s\S]*?</thought>\s*', '', content, flags=re.IGNORECASE)
289
+
290
+ content = re.sub(r'^.*?Chinese whatmodel I am.*?Theyspecifically.*?requested.*?me.*?to.*?reply.*?in.*?Chinese\.\s*', '', content, flags=re.IGNORECASE | re.DOTALL)
291
+ content = re.sub(r'^.*?This.*?is.*?a.*?straightforward.*?question.*?about.*?my.*?identity.*?asan.*?AI.*?assistant\.\s*', '', content, flags=re.IGNORECASE | re.DOTALL)
292
+ content = re.sub(r'^.*?Idon\'t.*?need.*?to.*?use.*?any.*?tools.*?for.*?this.*?-\s*it\'s.*?asimple.*?informational.*?response.*?aboutwhat.*?I.*?am\.\s*', '', content, flags=re.IGNORECASE | re.DOTALL)
293
+ content = re.sub(r'^.*?Sincethe.*?user.*?asked.*?in.*?Chinese.*?and.*?specifically.*?requested.*?a.*?Chinese.*?response.*?I.*?should.*?respond.*?in.*?Chinese\.\s*', '', content, flags=re.IGNORECASE | re.DOTALL)
294
+ content = re.sub(r'^.*?What model are you.*?in Chinese and specifically requesting.*?me.*?to.*?reply.*?in.*?Chinese\.\s*', '', content, flags=re.IGNORECASE | re.DOTALL)
295
+ content = re.sub(r'^.*?This.*?is.*?a.*?question.*?about.*?my.*?identity.*?not requiring.*?any.*?tool.*?use.*?I.*?should.*?respond.*?directly.*?to.*?the.*?user.*?in.*?Chinese.*?as.*?requested\.\s*', '', content, flags=re.IGNORECASE | re.DOTALL)
296
+ content = re.sub(r'^.*?I.*?should.*?identify.*?myself.*?as.*?Notion.*?AI.*?as.*?mentioned.*?in.*?the.*?system.*?prompt.*?\s*', '', content, flags=re.IGNORECASE | re.DOTALL)
297
+ content = re.sub(r'^.*?I.*?should.*?not.*?make.*?specific.*?claims.*?about.*?the.*?underlying.*?model.*?architecture.*?since.*?that.*?information.*?is.*?not.*?provided.*?in.*?my.*?context\.\s*', '', content, flags=re.IGNORECASE | re.DOTALL)
298
+
299
+ return content.strip()
300
+
301
    def _parse_ndjson_line_to_texts(self, line: bytes) -> List[Tuple[str, str]]:
        """Extract text from one NDJSON line of the runInference stream.

        Handles three wire formats: direct "markdown-chat" events (Gemini),
        "patch" operation streams (all models), and "record-map" snapshots.

        Args:
            line: one raw NDJSON line from the response stream.

        Returns:
            A list of (kind, text) tuples where kind is 'final' (a complete
            message) or 'incremental' (a fragment to be concatenated).
            Unparseable lines yield an empty list.
        """
        results: List[Tuple[str, str]] = []
        try:
            s = line.decode("utf-8", errors="ignore").strip()
            if not s: return results

            data = json.loads(s)
            logger.debug(f"原始响应数据: {json.dumps(data, ensure_ascii=False)}")

            # Format 1: direct markdown-chat event returned by Gemini.
            if data.get("type") == "markdown-chat":
                content = data.get("value", "")
                if content:
                    logger.info("从 'markdown-chat' 直接事件中提取到内容。")
                    results.append(('final', content))

            # Format 2: patch streams from Claude/GPT, plus Gemini's patch form.
            # NOTE: the order of the elif branches below matters — a patch
            # operation can satisfy more than one path predicate.
            elif data.get("type") == "patch" and "v" in data:
                for operation in data.get("v", []):
                    if not isinstance(operation, dict): continue

                    op_type = operation.get("o")   # operation kind ("a"=append, "x"=text?)
                    path = operation.get("p", "")  # JSON-pointer-like target path
                    value = operation.get("v")     # operation payload

                    # Gemini: full-content patch appending a markdown-chat step.
                    if op_type == "a" and path.endswith("/s/-") and isinstance(value, dict) and value.get("type") == "markdown-chat":
                        content = value.get("value", "")
                        if content:
                            logger.info("从 'patch' (Gemini-style) 中提取到完整内容。")
                            results.append(('final', content))

                    # Gemini: incremental text patch under a step's /value.
                    elif op_type == "x" and "/s/" in path and path.endswith("/value") and isinstance(value, str):
                        content = value
                        if content:
                            logger.info(f"从 'patch' (Gemini增量) 中提取到内容: {content}")
                            results.append(('incremental', content))

                    # Claude/GPT: incremental text patch.
                    elif op_type == "x" and "/value/" in path and isinstance(value, str):
                        content = value
                        if content:
                            logger.info(f"从 'patch' (Claude/GPT增量) 中提取到内容: {content}")
                            results.append(('incremental', content))

                    # Claude/GPT: full-content patch appending a text item.
                    elif op_type == "a" and path.endswith("/value/-") and isinstance(value, dict) and value.get("type") == "text":
                        content = value.get("content", "")
                        if content:
                            logger.info("从 'patch' (Claude/GPT-style) 中提取到完整内容。")
                            results.append(('final', content))

            # Format 3: record-map snapshot containing thread messages.
            elif data.get("type") == "record-map" and "recordMap" in data:
                record_map = data["recordMap"]
                if "thread_message" in record_map:
                    for msg_id, msg_data in record_map["thread_message"].items():
                        value_data = msg_data.get("value", {}).get("value", {})
                        step = value_data.get("step", {})
                        if not step: continue

                        content = ""
                        step_type = step.get("type")

                        if step_type == "markdown-chat":
                            content = step.get("value", "")
                        elif step_type == "agent-inference":
                            # agent-inference steps carry a list of typed items;
                            # take the first "text" item.
                            agent_values = step.get("value", [])
                            if isinstance(agent_values, list):
                                for item in agent_values:
                                    if isinstance(item, dict) and item.get("type") == "text":
                                        content = item.get("content", "")
                                        break

                        if content and isinstance(content, str):
                            logger.info(f"从 record-map (type: {step_type}) 提取到最终内容。")
                            results.append(('final', content))
                            # Only the first message with content is used.
                            break

        except (json.JSONDecodeError, AttributeError) as e:
            # Malformed lines are logged and skipped, never fatal to the stream.
            logger.warning(f"解析NDJSON行失败: {e} - Line: {line.decode('utf-8', errors='ignore')}")

        return results
385
+
386
+ async def get_models(self) -> JSONResponse:
387
+ model_data = {
388
+ "object": "list",
389
+ "data": [
390
+ {"id": name, "object": "model", "created": int(time.time()), "owned_by": "lzA6"}
391
+ for name in settings.KNOWN_MODELS
392
+ ]
393
+ }
394
+ return JSONResponse(content=model_data)