import os
import uuid
import json
import time
import asyncio
import random
import secrets
from typing import List, Optional, Dict, Any, Literal, Union

from curl_cffi.requests import AsyncSession
from fastapi import FastAPI, Request, HTTPException, Depends, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import StreamingResponse
from dotenv import load_dotenv
from pydantic import BaseModel, Field

# Load environment variables from .env file
load_dotenv()

# --- Concurrency configuration ---
CONCURRENT_REQUESTS = 1  # Number of simultaneous upstream probe requests per call

# --- Retry configuration ---
MAX_RETRIES = 3
RETRY_DELAY = 1  # seconds between retries


# --- Models (Integrated from models.py) ---

# Input models (OpenAI-like)
class ChatMessage(BaseModel):
    role: Literal["system", "user", "assistant"]
    content: str


class ChatCompletionRequest(BaseModel):
    messages: List[ChatMessage]
    # Model name can be passed, but we map to Notion's model below.
    model: str = "notion-proxy"
    stream: bool = False
    # space_id and thread_id are handled globally via environment variables.
    # Default Notion model; can be overridden per request.
    notion_model: str = "anthropic-opus-4"


# Notion request models
class NotionTranscriptConfigValue(BaseModel):
    type: str = "markdown-chat"
    model: str  # e.g. "anthropic-opus-4"


class NotionTranscriptItem(BaseModel):
    type: Literal["config", "user", "markdown-chat"]
    value: Union[List[List[str]], str, NotionTranscriptConfigValue]


class NotionDebugOverrides(BaseModel):
    cachedInferences: Dict = Field(default_factory=dict)
    annotationInferences: Dict = Field(default_factory=dict)
    emitInferences: bool = False


class NotionRequestBody(BaseModel):
    traceId: str = Field(default_factory=lambda: str(uuid.uuid4()))
    spaceId: str
    transcript: List[NotionTranscriptItem]
    # threadId is omitted; createThread is always set to true.
    createThread: bool = True
    debugOverrides: NotionDebugOverrides = Field(default_factory=NotionDebugOverrides)
    generateTitle: bool = False
    saveAllThreadOperations: bool = True


# Output models (OpenAI SSE)
class ChoiceDelta(BaseModel):
    content: Optional[str] = None


class Choice(BaseModel):
    index: int = 0
    delta: ChoiceDelta
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionChunk(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4()}")
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str = "notion-proxy"  # Or could reflect the underlying Notion model
    choices: List[Choice]


# Models for the /v1/models endpoint
class Model(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "notion"


class ModelList(BaseModel):
    object: str = "list"
    data: List[Model]


# --- Configuration ---
NOTION_API_URL = "https://www.notion.so/api/v3/runInferenceTranscript"
# IMPORTANT: Load the Notion cookie securely from environment variables.
NOTION_COOKIE = os.getenv("NOTION_COOKIE")
NOTION_SPACE_ID = os.getenv("NOTION_SPACE_ID")

if not NOTION_COOKIE:
    # Consider raising an exception or exiting in a real deployment.
    print("Error: NOTION_COOKIE environment variable not set.")
if not NOTION_SPACE_ID:
    # Using a default might not be ideal; depends on Notion's behavior.
    # Consider raising an error instead: raise ValueError("NOTION_SPACE_ID not set")
    print("Warning: NOTION_SPACE_ID environment variable not set. Using a default UUID.")
    NOTION_SPACE_ID = str(uuid.uuid4())

# --- Authentication ---
EXPECTED_TOKEN = os.getenv("PROXY_AUTH_TOKEN", "default_token")  # Default token
security = HTTPBearer()


def authenticate(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Compares provided token with the expected token (constant-time)."""
    correct_token = secrets.compare_digest(credentials.credentials, EXPECTED_TOKEN)
    if not correct_token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication credentials",
            # WWW-Authenticate header intentionally omitted for Bearer.
        )
    return True  # Indicate successful authentication


# --- FastAPI App ---
app = FastAPI()


# --- Helper functions ---

def _sse_chunk(chunk_id: str, created_time: int,
               content: Optional[str] = None,
               finish_reason: Optional[str] = None) -> str:
    """Build one OpenAI-style SSE ``data:`` line carrying a delta chunk."""
    chunk = ChatCompletionChunk(
        id=chunk_id,
        created=created_time,
        choices=[Choice(delta=ChoiceDelta(content=content), finish_reason=finish_reason)],
    )
    return f"data: {chunk.model_dump_json()}\n\n"


def _flush_buffer_tail(buffer: str, chunk_id: str, created_time: int) -> Optional[str]:
    """If the trailing buffer holds a final markdown-chat JSON line, return its SSE chunk.

    Returns None when the tail is empty, not valid JSON, or not a non-empty
    markdown-chat payload.
    """
    tail = buffer.strip()
    if not tail:
        return None
    try:
        data = json.loads(tail)
    except json.JSONDecodeError:
        # Partial/garbage tail — nothing to flush.
        return None
    if data.get("type") == "markdown-chat" and isinstance(data.get("value"), str) and data["value"]:
        return _sse_chunk(chunk_id, created_time, data["value"])
    return None


def build_notion_request(request_data: ChatCompletionRequest) -> NotionRequestBody:
    """Transforms OpenAI-style messages into Notion's transcript format."""
    transcript = [
        NotionTranscriptItem(
            type="config",
            value=NotionTranscriptConfigValue(model=request_data.notion_model),
        )
    ]
    for message in request_data.messages:
        if message.role == "assistant":
            # Notion uses "markdown-chat" for assistant replies in the transcript history.
            transcript.append(NotionTranscriptItem(type="markdown-chat", value=message.content))
        else:
            # Map user, system, and any other potential roles to 'user'.
            transcript.append(NotionTranscriptItem(type="user", value=[[message.content]]))

    # Use the globally configured spaceId and always create a new thread.
    return NotionRequestBody(
        spaceId=NOTION_SPACE_ID,  # From environment variable
        transcript=transcript,
        createThread=True,
        traceId=str(uuid.uuid4()),  # Fresh traceId per request
        debugOverrides=NotionDebugOverrides(
            cachedInferences={},
            annotationInferences={},
            emitInferences=False,
        ),
        generateTitle=False,
        saveAllThreadOperations=False,
    )


async def check_first_response_line(session: AsyncSession,
                                    notion_request_body: NotionRequestBody,
                                    headers: dict,
                                    request_id: int):
    """Fire one upstream request and inspect its first NDJSON line for a 500 error.

    Returns a 3-tuple:
      ((response, buffered_text), None, None)  on success,
      (None, response, error_string)           on an upstream/HTTP error,
      (None, None, error_string)               on a local exception.
    """
    try:
        # When more than one concurrent request is used, add a random delay
        # so the probes do not all hit the API at the same instant.
        if CONCURRENT_REQUESTS > 1:
            delay = random.uniform(0, 1.0)
            print(f"并发请求 {request_id} 延迟 {delay:.2f}秒")
            await asyncio.sleep(delay)

        # Each concurrent probe gets an independent body with a fresh traceId.
        request_body_copy = notion_request_body.model_copy()
        request_body_copy.traceId = str(uuid.uuid4())

        response = await session.post(
            NOTION_API_URL,
            json=request_body_copy.model_dump(),
            headers=headers,
            stream=True,
        )
        if response.status_code != 200:
            return None, response, f"HTTP {response.status_code}"

        # Read until the first complete, parseable JSON line arrives.
        buffer = ""
        async for chunk in response.aiter_content():
            if isinstance(chunk, bytes):
                chunk = chunk.decode('utf-8')
            buffer += chunk
            for line in buffer.split('\n'):
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    # Incomplete line — keep accumulating.
                    continue
                if (data.get("type") == "error"
                        and data.get("message")
                        and "error code 500" in data.get("message", "")):
                    print(f"并发请求 {request_id} 检测到500错误: {data}")
                    return None, response, "500 error"
                # Healthy stream: hand back the response plus what we already read.
                print(f"并发请求 {request_id} 响应正常")
                return (response, buffer), None, None
        return None, response, "No valid response"
    except Exception as e:
        print(f"并发请求 {request_id} 发生异常: {e}")
        return None, None, str(e)


async def stream_notion_response_single(session: AsyncSession, response,
                                        initial_buffer: str,
                                        chunk_id: str, created_time: int):
    """Stream one already-probed Notion response as OpenAI SSE chunks.

    ``initial_buffer`` is the text consumed by the probe; it is replayed first,
    then the live response body is streamed.
    """
    buffer = initial_buffer

    async def _pieces():
        # An empty first piece forces processing of the pre-read buffer
        # through the same code path as live chunks.
        yield ""
        async for raw in response.aiter_content():
            yield raw.decode('utf-8') if isinstance(raw, bytes) else raw

    async for piece in _pieces():
        buffer += piece
        lines = buffer.split('\n')
        buffer = lines[-1]  # keep the trailing partial line for the next round
        for line in lines[:-1]:
            line = line.strip()
            if not line:
                continue
            try:
                data = json.loads(line)
                if data.get("type") == "markdown-chat" and isinstance(data.get("value"), str):
                    if data["value"]:
                        yield _sse_chunk(chunk_id, created_time, data["value"])
                elif "recordMap" in data:
                    # A recordMap means the text stream is over; flush any
                    # final markdown-chat line left in the buffer and stop.
                    print("Detected recordMap, stopping stream.")
                    tail = _flush_buffer_tail(buffer, chunk_id, created_time)
                    if tail:
                        yield tail
                    return
            except json.JSONDecodeError as e:
                print(f"Warning: Could not decode JSON line: {line[:100]}... Error: {str(e)}")
            except Exception as e:
                print(f"Error processing line: {str(e)}")


async def stream_notion_response(notion_request_body: NotionRequestBody):
    """Streams the request to Notion and yields OpenAI-compatible SSE chunks.

    Strategy: fire CONCURRENT_REQUESTS probes and stream the first healthy one;
    if all probes fail, fall back to a sequential retry loop (MAX_RETRIES).
    """
    # curl_cffi fills in most browser-like headers automatically; we only
    # set the ones Notion specifically needs.
    headers = {
        'accept': 'application/x-ndjson',
        'accept-encoding': 'gzip, deflate, br, zstd',
        'accept-language': 'en-US,zh;q=0.9',
        'content-type': 'application/json',
        'dnt': '1',
        'notion-audit-log-platform': 'web',
        'notion-client-version': '23.13.0.3661',
        'origin': 'https://www.notion.so',
        'referer': 'https://www.notion.so/',
        'priority': 'u=1, i',
        'sec-ch-ua-mobile': '?0',
        'sec-ch-ua-platform': '"Windows"',
        'sec-fetch-dest': 'empty',
        'sec-fetch-mode': 'cors',
        'sec-fetch-site': 'same-origin',
        'cookie': NOTION_COOKIE,
        'x-notion-space-id': NOTION_SPACE_ID,
    }
    # Conditionally add the active user header (skip None and empty string).
    notion_active_user = os.getenv("NOTION_ACTIVE_USER_HEADER")
    if notion_active_user:
        headers['x-notion-active-user-header'] = notion_active_user

    chunk_id = f"chatcmpl-{uuid.uuid4()}"
    created_time = int(time.time())

    max_retries = MAX_RETRIES
    retry_delay = RETRY_DELAY

    # Phase 1: concurrent probes.
    print(f"同时发起 {CONCURRENT_REQUESTS} 个并发请求...")
    async with AsyncSession(impersonate="chrome136") as session:
        tasks = [
            asyncio.create_task(
                check_first_response_line(session, notion_request_body, headers, i + 1)
            )
            for i in range(CONCURRENT_REQUESTS)
        ]

        successful_response = None
        failed_count = 0
        completed_tasks = set()

        # Wait until one probe succeeds or all of them have finished.
        while len(completed_tasks) < CONCURRENT_REQUESTS and not successful_response:
            done, pending = await asyncio.wait(
                [t for t in tasks if t not in completed_tasks],
                return_when=asyncio.FIRST_COMPLETED,
            )
            for task in done:
                completed_tasks.add(task)
                result, response, error = await task
                if result:
                    # First healthy response wins; cancel the rest.
                    successful_response = result
                    print(f"找到成功的并发响应,立即使用")
                    for t in tasks:
                        if t not in completed_tasks:
                            t.cancel()
                    break
                failed_count += 1
                if error:
                    print(f"并发请求失败: {error}")

        if successful_response:
            response, initial_buffer = successful_response
            print("使用成功的并发响应进行流式传输")
            async for data in stream_notion_response_single(
                    session, response, initial_buffer, chunk_id, created_time):
                yield data
            # Terminate the SSE stream.
            yield _sse_chunk(chunk_id, created_time, finish_reason="stop")
            yield "data: [DONE]\n\n"
            return

    # Phase 2: every probe failed — sequential retry loop.
    print(f"所有 {CONCURRENT_REQUESTS} 个并发请求都失败,开始单请求重试流程...")
    for attempt in range(max_retries):
        try:
            # chrome136 impersonation for better anti-bot bypass.
            async with AsyncSession(impersonate="chrome136") as session:
                response = await session.post(
                    NOTION_API_URL,
                    json=notion_request_body.model_dump(),
                    headers=headers,
                    stream=True,
                )
                if response.status_code != 200:
                    error_content = await response.atext()
                    print(f"Error from Notion API: {response.status_code}")
                    print(f"Response: {error_content}")
                    raise HTTPException(
                        status_code=response.status_code,
                        detail=f"Notion API Error: {error_content}",
                    )

                buffer = ""
                first_line_checked = False
                is_error_response = False

                async for chunk in response.aiter_content():
                    if isinstance(chunk, bytes):
                        chunk = chunk.decode('utf-8')
                    buffer += chunk
                    lines = buffer.split('\n')
                    buffer = lines[-1]  # keep the last incomplete line buffered
                    for line in lines[:-1]:
                        line = line.strip()
                        if not line:
                            continue
                        try:
                            data = json.loads(line)
                            # The very first JSON line may be a Notion 500 error.
                            if not first_line_checked:
                                first_line_checked = True
                                if (data.get("type") == "error"
                                        and data.get("message")
                                        and "error code 500" in data.get("message", "")):
                                    print(f"检测到Notion API 500错误 (重试 {attempt + 1}/{max_retries}): {data}")
                                    is_error_response = True
                                    break
                            # Healthy line: forward markdown-chat text immediately.
                            if data.get("type") == "markdown-chat" and isinstance(data.get("value"), str):
                                content_chunk = data["value"]
                                if content_chunk:  # only send non-empty deltas
                                    yield _sse_chunk(chunk_id, created_time, content_chunk)
                            elif "recordMap" in data:
                                # recordMap is past the text stream; flush tail and
                                # stop processing this batch of lines.
                                print("Detected recordMap, stopping stream.")
                                tail = _flush_buffer_tail(buffer, chunk_id, created_time)
                                if tail:
                                    yield tail
                                break
                        except json.JSONDecodeError as e:
                            print(f"Warning: Could not decode JSON line: {line[:100]}... Error: {str(e)}")
                        except Exception as e:
                            print(f"Error processing line: {str(e)}")
                    if is_error_response:
                        break

                if is_error_response:
                    if attempt < max_retries - 1:
                        print(f"等待 {retry_delay} 秒后重试...")
                        await asyncio.sleep(retry_delay)
                        continue
                    # Out of retries: report the failure through the stream,
                    # since we cannot raise inside a streaming response.
                    print("所有重试都失败,返回500错误给客户端")
                    yield _sse_chunk(
                        chunk_id, created_time,
                        "Error: Notion API returned error code 500 after all retries",
                        "stop",
                    )
                    yield "data: [DONE]\n\n"
                    return

                # Clean completion: final stop chunk and SSE terminator.
                yield _sse_chunk(chunk_id, created_time, finish_reason="stop")
                yield "data: [DONE]\n\n"
                break
        except HTTPException:
            # HTTPException cannot propagate out of a streaming body; retry or
            # surface the error as stream content.
            if attempt < max_retries - 1:
                print(f"HTTP异常,等待 {retry_delay} 秒后重试...")
                await asyncio.sleep(retry_delay)
                continue
            print("HTTP异常且无更多重试,返回错误信息")
            yield _sse_chunk(
                chunk_id, created_time,
                "Error: HTTP exception occurred after all retries",
                "stop",
            )
            yield "data: [DONE]\n\n"
            return
        except Exception as e:
            print(f"Unexpected error during streaming (attempt {attempt + 1}/{max_retries}): {e}")
            if attempt < max_retries - 1:
                print(f"等待 {retry_delay} 秒后重试...")
                await asyncio.sleep(retry_delay)
                continue
            print("意外错误且无更多重试,返回错误信息")
            yield _sse_chunk(
                chunk_id, created_time,
                f"Error: Internal server error during streaming: {e}",
                "stop",
            )
            yield "data: [DONE]\n\n"
            return


# --- API Endpoints ---

@app.get("/v1/models", response_model=ModelList)
async def list_models(authenticated: bool = Depends(authenticate)):
    """
    Endpoint to list available Notion models, mimicking OpenAI's /v1/models.
    """
    available_models = [
        "openai-gpt-4.1",
        "anthropic-opus-4",
        "anthropic-sonnet-4",
    ]
    # 'created' uses each Model's default_factory.
    model_list = [Model(id=model_id, owned_by="notion") for model_id in available_models]
    return ModelList(data=model_list)


@app.post("/v1/chat/completions")
async def chat_completions(request_data: ChatCompletionRequest,
                           request: Request,
                           authenticated: bool = Depends(authenticate)):
    """
    Endpoint to mimic OpenAI's chat completions, proxying to Notion.
    """
    if not NOTION_COOKIE:
        raise HTTPException(status_code=500, detail="Server configuration error: Notion cookie not set.")

    notion_request_body = build_notion_request(request_data)

    if request_data.stream:
        return StreamingResponse(
            stream_notion_response(notion_request_body),
            media_type="text/event-stream",
        )

    # Non-streaming mode: drain the internal SSE stream and assemble one
    # OpenAI-compatible chat.completion response.
    full_response_content = ""
    final_finish_reason = None
    chunk_id = f"chatcmpl-{uuid.uuid4()}"  # ID for the non-streamed response
    created_time = int(time.time())
    try:
        async for line in stream_notion_response(notion_request_body):
            if line.startswith("data: ") and "[DONE]" not in line:
                try:
                    data_json = line[len("data: "):].strip()
                    if data_json:
                        chunk_data = json.loads(data_json)
                        if chunk_data.get("choices"):
                            delta = chunk_data["choices"][0].get("delta", {})
                            content = delta.get("content")
                            if content:
                                full_response_content += content
                            finish_reason = chunk_data["choices"][0].get("finish_reason")
                            if finish_reason:
                                final_finish_reason = finish_reason
                except json.JSONDecodeError:
                    print(f"Warning: Could not decode JSON line in non-streaming mode: {line}")

        return {
            "id": chunk_id,
            "object": "chat.completion",
            "created": created_time,
            "model": request_data.model,  # Echo the model the client requested
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": full_response_content,
                    },
                    "finish_reason": final_finish_reason or "stop",
                }
            ],
            "usage": {
                # Token usage is not available from Notion.
                "prompt_tokens": None,
                "completion_tokens": None,
                "total_tokens": None,
            },
        }
    except HTTPException as e:
        # Re-raise HTTP exceptions from the streaming function.
        raise e
    except Exception as e:
        print(f"Error during non-streaming processing: {e}")
        raise HTTPException(status_code=500, detail="Internal server error processing Notion response")


if __name__ == "__main__":
    import uvicorn
    print("Starting server. Access at http://localhost:7860")
    print("Ensure NOTION_COOKIE is set in your .env file or environment.")
    uvicorn.run(app, host="0.0.0.0", port=7860)