Spaces:
Paused
Paused
| import os | |
| import uuid | |
| import json | |
| import time | |
| import asyncio | |
| import random | |
| import threading | |
| from curl_cffi.requests import AsyncSession | |
| from fastapi import FastAPI, Request, HTTPException, Depends, status | |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
| from fastapi.responses import StreamingResponse | |
| from dotenv import load_dotenv | |
| import secrets | |
| from pydantic import BaseModel, Field | |
| from typing import List, Optional, Dict, Any, Literal, Union | |
| from contextlib import asynccontextmanager | |
| # Load environment variables from .env file | |
| load_dotenv() | |
# --- Concurrency configuration ---
CONCURRENT_REQUESTS = 1  # number of upstream requests raced in parallel (tunable)
# --- Retry configuration (sequential fallback path) ---
MAX_RETRIES = 3
RETRY_DELAY = 1  # seconds between retry attempts
| # --- Models (Integrated from models.py) --- | |
| # Input Models (OpenAI-like) | |
class ChatMessage(BaseModel):
    """A single OpenAI-style chat message."""
    role: Literal["system", "user", "assistant"]
    content: str
class ChatCompletionRequest(BaseModel):
    """OpenAI-style request body accepted by the chat-completions endpoint."""
    messages: List[ChatMessage]
    model: str = "notion-proxy"  # client-facing model name (echoed back in responses)
    stream: bool = False  # True -> server-sent-events streaming response
    notion_model: str = "anthropic-opus-4"  # underlying Notion model id
| # Notion Models | |
class NotionTranscriptConfigValue(BaseModel):
    """Value carried by the leading "config" transcript item."""
    type: str = "markdown-chat"
    model: str  # e.g., "anthropic-opus-4"
class NotionTranscriptItem(BaseModel):
    """One entry in the Notion inference transcript."""
    type: Literal["config", "user", "markdown-chat"]
    # "user" items carry [[text]], "markdown-chat" a plain string,
    # and "config" a NotionTranscriptConfigValue (see build_notion_request).
    value: Union[List[List[str]], str, NotionTranscriptConfigValue]
class NotionDebugOverrides(BaseModel):
    """Debug flags forwarded verbatim to the Notion inference API."""
    cachedInferences: Dict = Field(default_factory=dict)
    annotationInferences: Dict = Field(default_factory=dict)
    emitInferences: bool = False
class NotionRequestBody(BaseModel):
    """Request body for Notion's runInferenceTranscript endpoint."""
    traceId: str = Field(default_factory=lambda: str(uuid.uuid4()))
    spaceId: str
    transcript: List[NotionTranscriptItem]
    # threadId is removed, createThread will be set to true
    createThread: bool = True
    debugOverrides: NotionDebugOverrides = Field(default_factory=NotionDebugOverrides)
    generateTitle: bool = False
    saveAllThreadOperations: bool = True
| # Output Models (OpenAI SSE) | |
class ChoiceDelta(BaseModel):
    """Incremental message content for one OpenAI SSE chunk."""
    content: Optional[str] = None
class Choice(BaseModel):
    """One choice entry inside an OpenAI streaming chunk."""
    index: int = 0
    delta: ChoiceDelta
    finish_reason: Optional[Literal["stop", "length"]] = None
class ChatCompletionChunk(BaseModel):
    """OpenAI-compatible `chat.completion.chunk` SSE payload."""
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4()}")
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str = "notion-proxy"  # Or could reflect the underlying Notion model
    choices: List[Choice]
| # Models for /v1/models Endpoint | |
class Model(BaseModel):
    """One model entry for the /v1/models listing."""
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "notion"  # Or specify based on actual model origin if needed
class ModelList(BaseModel):
    """OpenAI-compatible wrapper for the model listing."""
    object: str = "list"
    data: List[Model]
# --- Configuration ---
NOTION_API_URL = "https://www.notion.so/api/v3/runInferenceTranscript"
# IMPORTANT: Load the Notion cookie securely from environment variables
NOTION_COOKIE = os.getenv("NOTION_COOKIE")
NOTION_SPACE_ID = os.getenv("NOTION_SPACE_ID")
if not NOTION_COOKIE:
    print("Error: NOTION_COOKIE environment variable not set.")
    # Consider raising HTTPException or exiting in a real app
if not NOTION_SPACE_ID:
    print("Warning: NOTION_SPACE_ID environment variable not set. Using a default UUID.")
    # Using a default might not be ideal, depends on Notion's behavior
    # Consider raising an error instead: raise ValueError("NOTION_SPACE_ID not set")
    NOTION_SPACE_ID = str(uuid.uuid4())  # Default or raise error
# --- Cookie Management (module-level mutable state) ---
browser_cookies = ""  # cached "name=value; ..." Cookie header string
cookie_lock = threading.Lock()  # guards writes/reads of browser_cookies
last_cookie_update = 0  # time.time() of the last successful refresh
COOKIE_UPDATE_INTERVAL = 30 * 60  # 30 minutes in seconds
async def get_browser_cookies():
    """Fetch browser cookies from notion.so and cache them in `browser_cookies`.

    Tries several strategies to read cookies from the curl_cffi response
    (get_dict, cookie jar, raw Set-Cookie headers, session cookies), appends
    the env `NOTION_COOKIE` as token_v2, and stores the joined string under
    `cookie_lock`. Returns True on success, False otherwise; on total failure
    it falls back to the env cookie alone when available.
    """
    global browser_cookies, last_cookie_update
    try:
        print("正在获取Notion浏览器cookie...")
        async with AsyncSession(impersonate="chrome136") as session:
            response = await session.get("https://www.notion.so")
            if response.status_code == 200:
                # Collect all cookies from the response.
                cookies = response.cookies
                notion_so_cookies = []
                # Work around CookieConflict: keep only .notion.so cookies.
                try:
                    # Strategy 1: filter by domain to avoid the conflict.
                    if hasattr(cookies, 'get_dict'):
                        # Use get_dict with an explicit domain.
                        notion_so_dict = cookies.get_dict(domain='.notion.so')
                        for name, value in notion_so_dict.items():
                            notion_so_cookies.append(f"{name}={value}")
                    elif hasattr(cookies, 'jar'):
                        # Strategy 2: iterate the cookie jar and filter by domain.
                        for cookie in cookies.jar:
                            if hasattr(cookie, 'domain') and cookie.domain:
                                if '.notion.so' in cookie.domain and '.notion.com' not in cookie.domain:
                                    notion_so_cookies.append(f"{cookie.name}={cookie.value}")
                    else:
                        # Strategy 3: build the cookie string manually from the
                        # raw Set-Cookie response headers.
                        set_cookie_headers = response.headers.get_list('Set-Cookie') if hasattr(response.headers, 'get_list') else []
                        if not set_cookie_headers and 'Set-Cookie' in response.headers:
                            set_cookie_headers = [response.headers['Set-Cookie']]
                        for cookie_header in set_cookie_headers:
                            if 'domain=.notion.so' in cookie_header or ('notion.so' in cookie_header and 'notion.com' not in cookie_header):
                                # Extract the "name=value" part before attributes.
                                cookie_parts = cookie_header.split(';')[0].strip()
                                if '=' in cookie_parts:
                                    notion_so_cookies.append(cookie_parts)
                    # Strategy 4: if still empty, iterate response.cookies
                    # requests-style and filter manually.
                    if not notion_so_cookies and hasattr(response, 'cookies'):
                        try:
                            for cookie in response.cookies:
                                if hasattr(cookie, 'domain') and cookie.domain and '.notion.so' in cookie.domain:
                                    notion_so_cookies.append(f"{cookie.name}={cookie.value}")
                        except Exception as inner_e:
                            print(f"内部cookie处理错误: {inner_e}")
                except Exception as cookie_error:
                    print(f"处理cookie时出现错误: {cookie_error}")
                    # Last resort: take whatever the session accumulated.
                    if hasattr(session, 'cookies'):
                        try:
                            for name, value in session.cookies.items():
                                notion_so_cookies.append(f"{name}={value}")
                        except:
                            pass
                # Append the env cookie with the token_v2 prefix.
                if NOTION_COOKIE:
                    notion_so_cookies.append(f"token_v2={NOTION_COOKIE}")
                # NOTE(review): this branch is unreachable — the append above
                # already guarantees a non-empty list whenever NOTION_COOKIE
                # is set; kept as-is to preserve behavior.
                if not notion_so_cookies and NOTION_COOKIE:
                    notion_so_cookies = [f"token_v2={NOTION_COOKIE}"]
                with cookie_lock:
                    browser_cookies = "; ".join(notion_so_cookies)
                    last_cookie_update = time.time()
                # Extract cookie names for logging only (values never logged).
                cookie_names = []
                for cookie_str in notion_so_cookies:
                    if '=' in cookie_str:
                        name = cookie_str.split('=')[0]
                        cookie_names.append(name)
                print(f"成功获取到 {len(notion_so_cookies)} 个cookie")
                print(f"Cookie名称列表: {', '.join(cookie_names)}")
                return True
            else:
                print(f"获取cookie失败,HTTP状态码: {response.status_code}")
                return False
    except Exception as e:
        print(f"获取browser cookie时出错: {e}")
        print(f"错误详情: {type(e).__name__}: {str(e)}")
        # Total failure: fall back to the env cookie when available.
        if NOTION_COOKIE:
            with cookie_lock:
                browser_cookies = f"token_v2={NOTION_COOKIE}"
                last_cookie_update = time.time()
            print("使用环境变量cookie作为备用")
            return True
        return False
def should_update_cookies():
    """Return True when the cached cookies are older than COOKIE_UPDATE_INTERVAL."""
    age = time.time() - last_cookie_update
    return age > COOKIE_UPDATE_INTERVAL
async def ensure_cookies_available():
    """Ensure a usable cookie string is cached, refreshing it when missing or stale.

    Raises HTTPException(500) only when a refresh fails, nothing is cached,
    and no NOTION_COOKIE fallback exists.
    """
    global browser_cookies
    # Fresh enough and non-empty: nothing to do.
    if browser_cookies and not should_update_cookies():
        return
    refreshed = await get_browser_cookies()
    if refreshed or browser_cookies:
        return
    # Refresh failed and nothing is cached: fall back to the env cookie.
    if NOTION_COOKIE:
        with cookie_lock:
            browser_cookies = f"token_v2={NOTION_COOKIE}"
        print("使用环境变量cookie作为备用")
    else:
        raise HTTPException(status_code=500, detail="无法获取Notion cookie")
def start_cookie_updater():
    """Start a daemon thread that periodically refreshes the Notion cookies."""
    def cookie_updater():
        # The worker thread needs its own event loop to drive the async fetcher.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        while True:
            try:
                if should_update_cookies():
                    print("开始定时更新cookie...")
                    loop.run_until_complete(get_browser_cookies())
                time.sleep(60)  # poll once per minute
            except Exception as e:
                print(f"定时更新cookie时出错: {e}")
                time.sleep(60)
    thread = threading.Thread(target=cookie_updater, daemon=True)
    thread.start()
    print("cookie定时更新器已启动")
# --- Authentication ---
EXPECTED_TOKEN = os.getenv("PROXY_AUTH_TOKEN", "default_token")  # Default token
security = HTTPBearer()

def authenticate(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Validate the Bearer token against PROXY_AUTH_TOKEN using a constant-time compare."""
    if not secrets.compare_digest(credentials.credentials, EXPECTED_TOKEN):
        # No WWW-Authenticate header is attached for Bearer auth.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication credentials",
        )
    return True  # authentication succeeded
# --- Lifespan Event Handler ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: fetch cookies at startup and launch the refresher.

    FIX: the @asynccontextmanager decorator was missing. FastAPI/Starlette
    requires the lifespan callable to produce an async context manager; a
    bare async generator function raises a TypeError at startup. The import
    already existed at the top of the file but was unused.
    """
    # Startup: prime the cookie cache and start the background updater thread.
    print("正在初始化Notion浏览器cookie...")
    await get_browser_cookies()
    start_cookie_updater()
    yield
    # Shutdown: no cleanup required.

# --- FastAPI App ---
app = FastAPI(lifespan=lifespan)
# --- Helper Functions ---
def build_notion_request(request_data: ChatCompletionRequest) -> NotionRequestBody:
    """Translate an OpenAI-style request into a Notion transcript request body."""
    # The transcript always opens with a config item naming the Notion model.
    transcript: List[NotionTranscriptItem] = [
        NotionTranscriptItem(
            type="config",
            value=NotionTranscriptConfigValue(model=request_data.notion_model),
        )
    ]
    for msg in request_data.messages:
        if msg.role == "assistant":
            # Assistant history entries are stored as "markdown-chat" items.
            item = NotionTranscriptItem(type="markdown-chat", value=msg.content)
        else:
            # user, system, and any other role all map to "user" entries.
            item = NotionTranscriptItem(type="user", value=[[msg.content]])
        transcript.append(item)
    # Always create a fresh thread with a fresh traceId, using the globally
    # configured space; debug/title/save flags are pinned explicitly.
    return NotionRequestBody(
        spaceId=NOTION_SPACE_ID,
        transcript=transcript,
        createThread=True,
        traceId=str(uuid.uuid4()),
        debugOverrides=NotionDebugOverrides(
            cachedInferences={},
            annotationInferences={},
            emitInferences=False,
        ),
        generateTitle=False,
        saveAllThreadOperations=False,
    )
async def check_first_response_line(session: AsyncSession, notion_request_body: NotionRequestBody, headers: dict, request_id: int):
    """Send one racing request and inspect the first NDJSON line for a 500 error.

    Returns a 3-tuple:
      ((response, buffered_text), None, None) when the stream looks healthy,
      (None, response, error_str) when the upstream returned an error, or
      (None, None, error_str) when a local exception occurred.
    The buffered text holds everything read so far, so the caller can replay it.
    """
    try:
        # With more than one concurrent request, add jitter so the racing
        # requests do not all hit the API at the same instant.
        if CONCURRENT_REQUESTS > 1:
            delay = random.uniform(0, 1.0)
            print(f"并发请求 {request_id} 延迟 {delay:.2f}秒")
            await asyncio.sleep(delay)
        # Give each racing request its own body with a fresh traceId.
        request_body_copy = notion_request_body.model_copy()
        request_body_copy.traceId = str(uuid.uuid4())
        response = await session.post(
            NOTION_API_URL,
            json=request_body_copy.model_dump(),
            headers=headers,
            stream=True
        )
        if response.status_code != 200:
            return None, response, f"HTTP {response.status_code}"
        # Read chunks until the first complete JSON line can be parsed.
        buffer = ""
        async for chunk in response.aiter_content():
            if isinstance(chunk, bytes):
                chunk = chunk.decode('utf-8')
            buffer += chunk
            # Try to parse the first complete JSON line seen so far.
            lines = buffer.split('\n')
            for line in lines:
                line = line.strip()
                if line:
                    try:
                        data = json.loads(line)
                        if (data.get("type") == "error" and
                            data.get("message") and
                            "error code 500" in data.get("message", "")):
                            print(f"并发请求 {request_id} 检测到500错误: {data}")
                            return None, response, "500 error"
                        else:
                            # Healthy stream: hand back the response plus the
                            # bytes already consumed so nothing is lost.
                            print(f"并发请求 {request_id} 响应正常")
                            return (response, buffer), None, None
                    except json.JSONDecodeError:
                        continue
        return None, response, "No valid response"
    except Exception as e:
        print(f"并发请求 {request_id} 发生异常: {e}")
        return None, None, str(e)
async def stream_notion_response_single(session: AsyncSession, response, initial_buffer: str, chunk_id: str, created_time: int):
    """Yield OpenAI SSE data lines for one already-validated Notion response.

    First drains `initial_buffer` (text consumed while probing the first
    line), then continues reading from `response`. Stops as soon as a
    "recordMap" payload appears, which marks the end of the text stream.
    """
    # NOTE(review): `session` is not referenced in this body — presumably
    # passed so the owning session stays alive while streaming; confirm.
    buffer = initial_buffer
    # Phase 1: emit every complete line already sitting in the probed buffer.
    lines = buffer.split('\n')
    buffer = lines[-1]  # keep the trailing (possibly partial) line
    for line in lines[:-1]:
        line = line.strip()
        if not line:
            continue
        try:
            data = json.loads(line)
            if data.get("type") == "markdown-chat" and isinstance(data.get("value"), str):
                content_chunk = data["value"]
                if content_chunk:
                    chunk_obj = ChatCompletionChunk(
                        id=chunk_id,
                        created=created_time,
                        choices=[Choice(delta=ChoiceDelta(content=content_chunk))]
                    )
                    yield f"data: {chunk_obj.model_dump_json()}\n\n"
            elif "recordMap" in data:
                print("Detected recordMap, stopping stream.")
                # Flush any remaining complete payload in the buffer, then stop.
                if buffer.strip():
                    try:
                        last_data = json.loads(buffer.strip())
                        if last_data.get("type") == "markdown-chat" and isinstance(last_data.get("value"), str):
                            if last_data["value"]:
                                last_chunk = ChatCompletionChunk(
                                    id=chunk_id,
                                    created=created_time,
                                    choices=[Choice(delta=ChoiceDelta(content=last_data["value"]))]
                                )
                                yield f"data: {last_chunk.model_dump_json()}\n\n"
                    except:
                        pass
                return
        except json.JSONDecodeError as e:
            print(f"Warning: Could not decode JSON line: {line[:100]}... Error: {str(e)}")
        except Exception as e:
            print(f"Error processing line: {str(e)}")
    # Phase 2: keep reading the rest of the response stream (same handling).
    async for chunk in response.aiter_content():
        if isinstance(chunk, bytes):
            chunk = chunk.decode('utf-8')
        buffer += chunk
        lines = buffer.split('\n')
        buffer = lines[-1]
        for line in lines[:-1]:
            line = line.strip()
            if not line:
                continue
            try:
                data = json.loads(line)
                if data.get("type") == "markdown-chat" and isinstance(data.get("value"), str):
                    content_chunk = data["value"]
                    if content_chunk:
                        chunk_obj = ChatCompletionChunk(
                            id=chunk_id,
                            created=created_time,
                            choices=[Choice(delta=ChoiceDelta(content=content_chunk))]
                        )
                        yield f"data: {chunk_obj.model_dump_json()}\n\n"
                elif "recordMap" in data:
                    print("Detected recordMap, stopping stream.")
                    if buffer.strip():
                        try:
                            last_data = json.loads(buffer.strip())
                            if last_data.get("type") == "markdown-chat" and isinstance(last_data.get("value"), str):
                                if last_data["value"]:
                                    last_chunk = ChatCompletionChunk(
                                        id=chunk_id,
                                        created=created_time,
                                        choices=[Choice(delta=ChoiceDelta(content=last_data["value"]))]
                                    )
                                    yield f"data: {last_chunk.model_dump_json()}\n\n"
                        except:
                            pass
                    return
            except json.JSONDecodeError as e:
                print(f"Warning: Could not decode JSON line: {line[:100]}... Error: {str(e)}")
            except Exception as e:
                print(f"Error processing line: {str(e)}")
async def stream_notion_response(notion_request_body: NotionRequestBody):
    """Streams the request to Notion and yields OpenAI-compatible SSE chunks.

    Strategy: race CONCURRENT_REQUESTS parallel probe requests and stream
    from the first healthy one; if every probe fails, fall back to a
    sequential retry loop (MAX_RETRIES attempts, RETRY_DELAY seconds apart).
    Errors are reported in-stream because HTTP exceptions cannot be raised
    once the streaming response has started.
    """
    # Make sure a cookie header is available before contacting Notion.
    await ensure_cookies_available()
    with cookie_lock:
        current_cookies = browser_cookies
    # Browser-like headers to match the chrome136 impersonation.
    headers = {
        'accept': 'application/x-ndjson',
        'accept-encoding': 'gzip, deflate, br, zstd',
        'accept-language': 'en-US,zh;q=0.9',
        'content-type': 'application/json',
        'dnt': '1',
        'notion-audit-log-platform': 'web',
        'notion-client-version': '23.13.0.3661',
        'origin': 'https://www.notion.so',
        'referer': 'https://www.notion.so/',
        'priority': 'u=1, i',
        'sec-ch-ua-mobile': '?0',
        'sec-ch-ua-platform': '"Windows"',
        'sec-fetch-dest': 'empty',
        'sec-fetch-mode': 'cors',
        'sec-fetch-site': 'same-origin',
        'cookie': current_cookies,
        'x-notion-space-id': NOTION_SPACE_ID
    }
    # Conditionally add the active user header
    notion_active_user = os.getenv("NOTION_ACTIVE_USER_HEADER")
    if notion_active_user:  # Checks for None and empty string implicitly
        headers['x-notion-active-user-header'] = notion_active_user
    chunk_id = f"chatcmpl-{uuid.uuid4()}"
    created_time = int(time.time())
    # Use the module-level retry configuration.
    max_retries = MAX_RETRIES
    retry_delay = RETRY_DELAY
    # Phase 1: fire the concurrent racing probes.
    print(f"同时发起 {CONCURRENT_REQUESTS} 个并发请求...")
    async with AsyncSession(impersonate="chrome136") as session:
        # Launch every probe as an independent task.
        tasks = []
        for i in range(CONCURRENT_REQUESTS):
            task = asyncio.create_task(
                check_first_response_line(session, notion_request_body, headers, i + 1)
            )
            tasks.append(task)
        # Wait until one probe succeeds or all of them have completed.
        successful_response = None
        failed_count = 0
        completed_tasks = set()
        while len(completed_tasks) < CONCURRENT_REQUESTS and not successful_response:
            # Wake on the first probe that finishes.
            done, pending = await asyncio.wait(
                [t for t in tasks if t not in completed_tasks],
                return_when=asyncio.FIRST_COMPLETED
            )
            for task in done:
                completed_tasks.add(task)
                result, response, error = await task
                if result:
                    # First healthy response wins; cancel the rest.
                    successful_response = result
                    print(f"找到成功的并发响应,立即使用")
                    # NOTE(review): cancelled tasks are never awaited, so their
                    # in-flight connections may linger until GC — confirm OK.
                    for t in tasks:
                        if t not in completed_tasks:
                            t.cancel()
                    break
                else:
                    # Record the failure and keep waiting for the others.
                    failed_count += 1
                    if error:
                        print(f"并发请求失败: {error}")
        # Stream from the winning response, then emit stop/done frames.
        if successful_response:
            response, initial_buffer = successful_response
            print("使用成功的并发响应进行流式传输")
            async for data in stream_notion_response_single(session, response, initial_buffer, chunk_id, created_time):
                yield data
            # Send the final chunk indicating stop
            final_chunk = ChatCompletionChunk(
                id=chunk_id,
                created=created_time,
                choices=[Choice(delta=ChoiceDelta(), finish_reason="stop")]
            )
            yield f"data: {final_chunk.model_dump_json()}\n\n"
            yield "data: [DONE]\n\n"
            return
    # Phase 2: every concurrent probe failed — sequential retry fallback.
    print(f"所有 {CONCURRENT_REQUESTS} 个并发请求都失败,开始单请求重试流程...")
    for attempt in range(max_retries):
        try:
            # Using curl_cffi with chrome136 impersonation for better anti-bot bypass
            async with AsyncSession(impersonate="chrome136") as session:
                # Stream the response
                response = await session.post(
                    NOTION_API_URL,
                    json=notion_request_body.model_dump(),
                    headers=headers,
                    stream=True
                )
                if response.status_code != 200:
                    error_content = await response.atext()
                    print(f"Error from Notion API: {response.status_code}")
                    print(f"Response: {error_content}")
                    raise HTTPException(status_code=response.status_code, detail=f"Notion API Error: {error_content}")
                # Process the NDJSON stream chunk by chunk; curl_cffi exposes
                # the body via aiter_content rather than line iteration.
                buffer = ""
                first_line_checked = False
                is_error_response = False
                async for chunk in response.aiter_content():
                    # Decode chunk if it's bytes
                    if isinstance(chunk, bytes):
                        chunk = chunk.decode('utf-8')
                    buffer += chunk
                    # Split by newlines and process complete lines
                    lines = buffer.split('\n')
                    # Keep the last incomplete line in the buffer
                    buffer = lines[-1]
                    for line in lines[:-1]:
                        line = line.strip()
                        if not line:
                            continue
                        try:
                            data = json.loads(line)
                            # Inspect the very first line for a 500-style error.
                            if not first_line_checked:
                                first_line_checked = True
                                if (data.get("type") == "error" and
                                    data.get("message") and
                                    "error code 500" in data.get("message", "")):
                                    print(f"检测到Notion API 500错误 (重试 {attempt + 1}/{max_retries}): {data}")
                                    is_error_response = True
                                    break
                            # Not an error: forward text chunks as they arrive.
                            if data.get("type") == "markdown-chat" and isinstance(data.get("value"), str):
                                content_chunk = data["value"]
                                if content_chunk:  # Only send if there's content
                                    chunk_obj = ChatCompletionChunk(
                                        id=chunk_id,
                                        created=created_time,
                                        choices=[Choice(delta=ChoiceDelta(content=content_chunk))]
                                    )
                                    yield f"data: {chunk_obj.model_dump_json()}\n\n"
                            # A recordMap payload means the text stream is over.
                            elif "recordMap" in data:
                                print("Detected recordMap, stopping stream.")
                                # Process any remaining buffer
                                if buffer.strip():
                                    try:
                                        last_data = json.loads(buffer.strip())
                                        if last_data.get("type") == "markdown-chat" and isinstance(last_data.get("value"), str):
                                            if last_data["value"]:
                                                last_chunk = ChatCompletionChunk(
                                                    id=chunk_id,
                                                    created=created_time,
                                                    choices=[Choice(delta=ChoiceDelta(content=last_data["value"]))]
                                                )
                                                yield f"data: {last_chunk.model_dump_json()}\n\n"
                                    except:
                                        pass
                                # Exit the per-line loop.
                                break
                        except json.JSONDecodeError as e:
                            print(f"Warning: Could not decode JSON line: {line[:100]}... Error: {str(e)}")
                        except Exception as e:
                            print(f"Error processing line: {str(e)}")
                            # Continue processing other lines
                    if is_error_response:
                        break
                # On a detected 500: retry, or give up after the final attempt.
                if is_error_response:
                    if attempt < max_retries - 1:
                        print(f"等待 {retry_delay} 秒后重试...")
                        await asyncio.sleep(retry_delay)
                        continue  # next attempt
                    else:
                        # All retries exhausted: report the error in-stream.
                        print("所有重试都失败,返回500错误给客户端")
                        error_chunk = ChatCompletionChunk(
                            id=chunk_id,
                            created=created_time,
                            choices=[Choice(delta=ChoiceDelta(content="Error: Notion API returned error code 500 after all retries"), finish_reason="stop")]
                        )
                        yield f"data: {error_chunk.model_dump_json()}\n\n"
                        yield "data: [DONE]\n\n"
                        return
                # Success: send the final stop chunk and the DONE sentinel.
                final_chunk = ChatCompletionChunk(
                    id=chunk_id,
                    created=created_time,
                    choices=[Choice(delta=ChoiceDelta(), finish_reason="stop")]
                )
                yield f"data: {final_chunk.model_dump_json()}\n\n"
                yield "data: [DONE]\n\n"
                # Completed successfully — leave the retry loop.
                break
        except HTTPException:
            # Cannot raise HTTPException mid-stream; report errors in-stream.
            if attempt < max_retries - 1:
                print(f"HTTP异常,等待 {retry_delay} 秒后重试...")
                await asyncio.sleep(retry_delay)
                continue
            else:
                print("HTTP异常且无更多重试,返回错误信息")
                error_chunk = ChatCompletionChunk(
                    id=chunk_id,
                    created=created_time,
                    choices=[Choice(delta=ChoiceDelta(content="Error: HTTP exception occurred after all retries"), finish_reason="stop")]
                )
                yield f"data: {error_chunk.model_dump_json()}\n\n"
                yield "data: [DONE]\n\n"
                return
        except Exception as e:
            print(f"Unexpected error during streaming (attempt {attempt + 1}/{max_retries}): {e}")
            if attempt < max_retries - 1:
                print(f"等待 {retry_delay} 秒后重试...")
                await asyncio.sleep(retry_delay)
                continue
            else:
                print("意外错误且无更多重试,返回错误信息")
                error_chunk = ChatCompletionChunk(
                    id=chunk_id,
                    created=created_time,
                    choices=[Choice(delta=ChoiceDelta(content=f"Error: Internal server error during streaming: {e}"), finish_reason="stop")]
                )
                yield f"data: {error_chunk.model_dump_json()}\n\n"
                yield "data: [DONE]\n\n"
                return
# --- API Endpoints ---
@app.get("/v1/models")
async def list_models(authenticated: bool = Depends(authenticate)):
    """
    Endpoint to list available Notion models, mimicking OpenAI's /v1/models.

    FIX: the route decorator was missing, so this handler was never
    registered with the FastAPI app and the endpoint was unreachable.
    """
    # Static list of Notion model ids exposed to clients.
    available_models = [
        "openai-gpt-4.1",
        "anthropic-opus-4",
        "anthropic-sonnet-4"
    ]
    model_list = [
        Model(id=model_id, owned_by="notion")  # created uses default_factory
        for model_id in available_models
    ]
    return ModelList(data=model_list)
@app.post("/v1/chat/completions")
async def chat_completions(request_data: ChatCompletionRequest, request: Request, authenticated: bool = Depends(authenticate)):
    """
    Endpoint to mimic OpenAI's chat completions, proxying to Notion.

    FIX: the route decorator was missing, so this handler was never
    registered with the FastAPI app and the endpoint was unreachable.

    Streams SSE chunks when `stream` is true; otherwise collects the whole
    internal stream and returns a single chat.completion object.
    """
    if not NOTION_COOKIE:
        raise HTTPException(status_code=500, detail="Server configuration error: Notion cookie not set.")
    notion_request_body = build_notion_request(request_data)
    if request_data.stream:
        return StreamingResponse(
            stream_notion_response(notion_request_body),
            media_type="text/event-stream"
        )
    # --- Non-Streaming Logic: drain the SSE generator and aggregate it ---
    full_response_content = ""
    final_finish_reason = None
    chunk_id = f"chatcmpl-{uuid.uuid4()}"  # ID for the non-streamed response
    created_time = int(time.time())
    try:
        async for line in stream_notion_response(notion_request_body):
            if line.startswith("data: ") and "[DONE]" not in line:
                try:
                    data_json = line[len("data: "):].strip()
                    if data_json:
                        chunk_data = json.loads(data_json)
                        if chunk_data.get("choices"):
                            delta = chunk_data["choices"][0].get("delta", {})
                            content = delta.get("content")
                            if content:
                                full_response_content += content
                            finish_reason = chunk_data["choices"][0].get("finish_reason")
                            if finish_reason:
                                final_finish_reason = finish_reason
                except json.JSONDecodeError:
                    print(f"Warning: Could not decode JSON line in non-streaming mode: {line}")
        # Construct the final OpenAI-compatible non-streaming response.
        return {
            "id": chunk_id,
            "object": "chat.completion",
            "created": created_time,
            "model": request_data.model,  # Return the model requested by the client
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": full_response_content,
                    },
                    "finish_reason": final_finish_reason or "stop",  # Default to stop if not explicitly set
                }
            ],
            "usage": {  # Note: Token usage is not available from Notion
                "prompt_tokens": None,
                "completion_tokens": None,
                "total_tokens": None,
            },
        }
    except HTTPException as e:
        # Re-raise HTTP exceptions from the streaming function.
        raise e
    except Exception as e:
        print(f"Error during non-streaming processing: {e}")
        raise HTTPException(status_code=500, detail="Internal server error processing Notion response")
if __name__ == "__main__":
    # Local entry point; uvicorn is only needed when running directly.
    import uvicorn
    print("Starting server. Access at http://localhost:7860")
    print("Ensure NOTION_COOKIE is set in your .env file or environment.")
    print("Cookie管理系统已启用,将自动获取和更新Notion浏览器cookie")
    # Run the ASGI server (port 7860 matches the printed URL above).
    uvicorn.run(app, host="0.0.0.0", port=7860)