# QIN456987's picture
# Upload 3 files
# 7c0d09c verified
import asyncio
import codecs
import json
import os
import random
import secrets
import time
import uuid
from typing import List, Optional, Dict, Any, Literal, Union

from curl_cffi.requests import AsyncSession
from dotenv import load_dotenv
from fastapi import FastAPI, Request, HTTPException, Depends, status
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field
# Load environment variables from a .env file so that the os.getenv() calls
# further down in this module can see them.
load_dotenv()
# --- 并发请求配置 ---
CONCURRENT_REQUESTS = 1 # 可自定义并发请求数量
# --- 重试配置 ---
MAX_RETRIES = 3
RETRY_DELAY = 1 # 秒
# --- Models (integrated from models.py) ---
# Request-side models (OpenAI-compatible input).
class ChatMessage(BaseModel):
    """A single message in an OpenAI-style conversation."""

    # Any role outside these three fails validation.
    role: Literal["system", "user", "assistant"]
    # Plain-text content only.
    content: str = Field(...)
class ChatCompletionRequest(BaseModel):
    """Body accepted by POST /v1/chat/completions (a subset of OpenAI's schema).

    Other OpenAI parameters (max_tokens, temperature, ...) are not modelled
    here. The Notion space id is taken from the NOTION_SPACE_ID environment
    variable rather than from the request.
    """

    messages: List[ChatMessage]
    # Client-facing model name; echoed back in non-streaming responses.
    model: str = Field(default="notion-proxy")
    stream: bool = Field(default=False)
    # Notion model written into the transcript's config item; clients may override it.
    notion_model: str = Field(default="anthropic-opus-4")
# --- Notion request models ---
class NotionTranscriptConfigValue(BaseModel):
    """Value of the leading "config" transcript item; selects the Notion model."""

    type: str = Field(default="markdown-chat")
    model: str = Field(...)  # e.g. "anthropic-opus-4"
class NotionTranscriptItem(BaseModel):
    """One entry of the transcript sent to Notion.

    build_notion_request produces three shapes:
    - "config": value is a NotionTranscriptConfigValue;
    - "user": value is ``[[text]]``;
    - "markdown-chat" (an earlier assistant reply): value is the plain text.
    """

    type: Literal["config", "user", "markdown-chat"]
    value: Union[List[List[str]], str, NotionTranscriptConfigValue]
class NotionDebugOverrides(BaseModel):
    """The ``debugOverrides`` object of a Notion inference request (empty by default)."""

    cachedInferences: Dict = Field(default_factory=dict)
    annotationInferences: Dict = Field(default_factory=dict)
    emitInferences: bool = Field(default=False)
class NotionRequestBody(BaseModel):
    """JSON body posted to Notion's runInferenceTranscript endpoint.

    No threadId field exists; createThread defaults to True instead
    (presumably asking Notion to start a new thread per request).
    """

    # A fresh random id for every instance.
    traceId: str = Field(default_factory=lambda: str(uuid.uuid4()))
    spaceId: str
    transcript: List[NotionTranscriptItem]
    createThread: bool = Field(default=True)
    debugOverrides: NotionDebugOverrides = Field(default_factory=NotionDebugOverrides)
    generateTitle: bool = Field(default=False)
    saveAllThreadOperations: bool = Field(default=True)
# --- Response models (OpenAI SSE output) ---
class ChoiceDelta(BaseModel):
    """Incremental text of one streamed choice; left as None in the closing chunk."""

    content: Optional[str] = Field(default=None)
class Choice(BaseModel):
    """A single streamed choice; finish_reason stays None until the final chunk."""

    index: int = Field(default=0)
    delta: ChoiceDelta
    finish_reason: Optional[Literal["stop", "length"]] = Field(default=None)
class ChatCompletionChunk(BaseModel):
    """Payload of one OpenAI ``chat.completion.chunk`` SSE event."""

    id: str = Field(default_factory=lambda: "chatcmpl-" + str(uuid.uuid4()))
    object: str = Field(default="chat.completion.chunk")
    created: int = Field(default_factory=lambda: int(time.time()))
    # Fixed label; it does not name the underlying Notion model.
    model: str = Field(default="notion-proxy")
    choices: List[Choice]
# --- Models for the /v1/models endpoint ---
class Model(BaseModel):
    """One entry of the OpenAI-style model list."""

    id: str
    object: str = Field(default="model")
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = Field(default="notion")
class ModelList(BaseModel):
    """Response body of GET /v1/models."""

    object: str = Field(default="list")
    data: List[Model]
# --- Configuration ---
# Notion endpoint that streams inference results back as newline-delimited JSON.
NOTION_API_URL = "https://www.notion.so/api/v3/runInferenceTranscript"

# Secrets come from the environment (or .env), never from source code.
NOTION_COOKIE = os.getenv("NOTION_COOKIE")
NOTION_SPACE_ID = os.getenv("NOTION_SPACE_ID")

if NOTION_COOKIE in (None, ""):
    # Not fatal at import time: chat_completions answers 500 while this is unset.
    print("Error: NOTION_COOKIE environment variable not set.")

if NOTION_SPACE_ID in (None, ""):
    print("Warning: NOTION_SPACE_ID environment variable not set. Using a default UUID.")
    # NOTE(review): a random space id is unlikely to be accepted by Notion;
    # consider failing fast here instead.
    NOTION_SPACE_ID = str(uuid.uuid4())
# --- Authentication ---
# Bearer token clients must present. When PROXY_AUTH_TOKEN is unset this falls
# back to the publicly known placeholder "default_token", so anyone reading
# the source could use the proxy; warn loudly in that case.
EXPECTED_TOKEN = os.getenv("PROXY_AUTH_TOKEN", "default_token")
if EXPECTED_TOKEN == "default_token":
    print("Warning: PROXY_AUTH_TOKEN environment variable not set. Using the insecure default token.")
security = HTTPBearer()
def authenticate(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """FastAPI dependency that checks the Bearer token against EXPECTED_TOKEN.

    Returns:
        True when the token matches.

    Raises:
        HTTPException: 401 when the token does not match.
    """
    # secrets.compare_digest only accepts str arguments that are pure ASCII and
    # raises TypeError otherwise, which would turn a non-ASCII header into a 500.
    # Comparing UTF-8 bytes keeps the constant-time check and always yields a 401.
    supplied = credentials.credentials.encode("utf-8")
    expected = EXPECTED_TOKEN.encode("utf-8")
    if not secrets.compare_digest(supplied, expected):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication credentials",
        )
    return True
# --- FastAPI App ---
# The ASGI application; the route decorators further down register
# /v1/models and /v1/chat/completions on it.
app = FastAPI()
# --- Helper Functions ---
def build_notion_request(request_data: ChatCompletionRequest) -> NotionRequestBody:
    """Convert an OpenAI-style chat request into a Notion transcript request.

    The transcript starts with a "config" item naming the Notion model.
    It is followed by one item per message:
    - assistant messages become "markdown-chat" items holding the plain text;
    - every other role (system, user) becomes a "user" item holding ``[[text]]``.

    Model selection:
    - an explicitly supplied ``notion_model`` wins;
    - otherwise, if ``model`` is one of the Notion model ids advertised by
      /v1/models, that id is used (this is what standard OpenAI clients send);
    - otherwise the ``notion_model`` default applies.

    The body uses the NOTION_SPACE_ID setting, a fresh traceId, createThread=True,
    empty debug overrides, generateTitle=False and saveAllThreadOperations=False.
    """
    # Keep in sync with the ids returned by GET /v1/models.
    advertised_models = {"openai-gpt-4.1", "anthropic-opus-4", "anthropic-sonnet-4"}
    notion_model = request_data.notion_model
    if ("notion_model" not in request_data.model_fields_set
            and request_data.model in advertised_models):
        notion_model = request_data.model

    transcript = [
        NotionTranscriptItem(
            type="config",
            value=NotionTranscriptConfigValue(model=notion_model),
        )
    ]
    for message in request_data.messages:
        if message.role == "assistant":
            # Notion represents earlier assistant replies as "markdown-chat" items.
            transcript.append(NotionTranscriptItem(type="markdown-chat", value=message.content))
        else:
            # system and user messages are both sent as user turns.
            transcript.append(NotionTranscriptItem(type="user", value=[[message.content]]))

    return NotionRequestBody(
        spaceId=NOTION_SPACE_ID,
        transcript=transcript,
        createThread=True,  # always start a new thread
        traceId=str(uuid.uuid4()),
        debugOverrides=NotionDebugOverrides(
            cachedInferences={},
            annotationInferences={},
            emitInferences=False,
        ),
        generateTitle=False,
        saveAllThreadOperations=False,
    )
async def check_first_response_line(session: AsyncSession, notion_request_body: NotionRequestBody, headers: dict, request_id: int):
    """Send one Notion request and classify it by its first NDJSON line.

    Used by the concurrent "first good response wins" phase. When
    CONCURRENT_REQUESTS > 1, a random 0-1 s delay is applied first. A copy of
    the body with a fresh traceId is posted; ``notion_request_body`` itself is
    not modified.

    Args:
        session: shared curl_cffi AsyncSession used to send the request.
        notion_request_body: the request to send.
        headers: HTTP headers for the Notion API.
        request_id: 1-based number used only in log messages.

    Returns:
        A ``(result, response, error)`` tuple:
        - success: ``((response, buffer), None, None)``. ``buffer`` is all text
          read so far, including the first line and ending on a character
          boundary, so the caller can keep streaming from ``response``;
        - status other than 200: ``(None, response, "HTTP <status>")``;
        - first JSON line is an "error code 500" error: ``(None, response, "500 error")``;
        - body ended without any parseable JSON line: ``(None, response, "No valid response")``;
        - any exception: ``(None, None, str(exc))``.
    """
    try:
        # Stagger racing requests so they do not reach Notion at the same instant.
        if CONCURRENT_REQUESTS > 1:
            delay = random.uniform(0, 1.0)
            print(f"并发请求 {request_id} 延迟 {delay:.2f}秒")
            await asyncio.sleep(delay)

        # Give every attempt its own traceId.
        request_body_copy = notion_request_body.model_copy()
        request_body_copy.traceId = str(uuid.uuid4())

        response = await session.post(
            NOTION_API_URL,
            json=request_body_copy.model_dump(),
            headers=headers,
            stream=True
        )
        if response.status_code != 200:
            return None, response, f"HTTP {response.status_code}"

        # Decode incrementally. A multi-byte UTF-8 character can be split across
        # chunks, and decoding each chunk on its own would raise UnicodeDecodeError.
        decoder = codecs.getincrementaldecoder("utf-8")()
        buffer = ""
        async for chunk in response.aiter_content():
            buffer += decoder.decode(chunk) if isinstance(chunk, bytes) else chunk
            if decoder.getstate()[0]:
                # A character is still incomplete. Wait for its remaining bytes so
                # the text handed to the caller ends on a character boundary, and
                # the caller can resume decoding cleanly from the next chunk.
                continue
            for line in buffer.split("\n"):
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    # e.g. the still-incomplete last line; keep looking/reading.
                    continue
                # The first parseable line alone decides the outcome.
                if (data.get("type") == "error"
                        and data.get("message")
                        and "error code 500" in data.get("message", "")):
                    print(f"并发请求 {request_id} 检测到500错误: {data}")
                    return None, response, "500 error"
                print(f"并发请求 {request_id} 响应正常")
                return (response, buffer), None, None
        return None, response, "No valid response"
    except Exception as e:
        print(f"并发请求 {request_id} 发生异常: {e}")
        return None, None, str(e)
async def stream_notion_response_single(session: AsyncSession, response, initial_buffer: str, chunk_id: str, created_time: int):
    """Relay one already-validated Notion NDJSON response as OpenAI SSE events.

    ``initial_buffer`` is the text already read from ``response``. It is
    processed first, then the rest of the body is read from
    ``response.aiter_content()``.

    Each line ``{"type": "markdown-chat", "value": "<non-empty text>"}`` becomes
    one ``data: <ChatCompletionChunk JSON>\\n\\n`` event carrying ``chunk_id``
    and ``created_time``. Reading stops at the first line containing a
    ``recordMap`` key. No finish or ``[DONE]`` events are emitted; the caller
    sends those.

    Args:
        session: not used; kept for interface compatibility.
        response: streaming curl_cffi response positioned after ``initial_buffer``.
        initial_buffer: already-decoded text previously read from ``response``.
        chunk_id: id placed in every emitted chunk.
        created_time: ``created`` timestamp placed in every emitted chunk.
    """

    def content_event(text):
        # Serialize one content delta as an SSE event.
        chunk_obj = ChatCompletionChunk(
            id=chunk_id,
            created=created_time,
            choices=[Choice(delta=ChoiceDelta(content=text))],
        )
        return f"data: {chunk_obj.model_dump_json()}\n\n"

    def markdown_text(data):
        # Text of a markdown-chat record, else None. Non-object JSON raises
        # AttributeError; the caller reports it.
        if data.get("type") == "markdown-chat" and isinstance(data.get("value"), str):
            return data["value"]
        return None

    def line_events(line, tail):
        # Process one NDJSON line and return (events, stop).
        line = line.strip()
        if not line:
            return [], False
        try:
            data = json.loads(line)
            text = markdown_text(data)
            if text is not None:
                return ([content_event(text)] if text else []), False
            if "recordMap" in data:
                print("Detected recordMap, stopping stream.")
                # Best effort: relay the unterminated tail only if it already
                # parses as a complete markdown-chat record.
                stripped_tail = tail.strip()
                tail_text = None
                if stripped_tail:
                    try:
                        tail_text = markdown_text(json.loads(stripped_tail))
                    except (json.JSONDecodeError, AttributeError):
                        tail_text = None
                return ([content_event(tail_text)] if tail_text else []), True
        except json.JSONDecodeError as e:
            print(f"Warning: Could not decode JSON line: {line[:100]}... Error: {str(e)}")
        except Exception as e:
            print(f"Error processing line: {str(e)}")
        return [], False

    def drain(text):
        # Handle every complete line in text; return (events, tail, stop).
        *complete, tail = text.split("\n")
        events = []
        for line in complete:
            new_events, stop = line_events(line, tail)
            events.extend(new_events)
            if stop:
                return events, tail, True
        return events, tail, False

    events, buffer, stop = drain(initial_buffer)
    for event in events:
        yield event
    if stop:
        return

    # Incremental decoding keeps multi-byte UTF-8 characters that straddle
    # chunk boundaries intact.
    decoder = codecs.getincrementaldecoder("utf-8")()
    async for chunk in response.aiter_content():
        text = decoder.decode(chunk) if isinstance(chunk, bytes) else chunk
        events, buffer, stop = drain(buffer + text)
        for event in events:
            yield event
        if stop:
            return

    # The body may end without a trailing newline; handle that last line too.
    events, _, _ = drain(buffer + "\n")
    for event in events:
        yield event
async def stream_notion_response(notion_request_body: NotionRequestBody):
    """Stream a chat request to Notion and yield OpenAI-compatible SSE events.

    Phase 1 races CONCURRENT_REQUESTS requests through
    check_first_response_line. It relays, via stream_notion_response_single,
    the first response whose first NDJSON line is not an "error code 500"
    error.

    Phase 2 runs only if every phase-1 request failed: up to MAX_RETRIES
    sequential attempts, sleeping RETRY_DELAY seconds between them.

    Every yielded item is an SSE ``data: ...\\n\\n`` string. A successful
    stream ends with a ``finish_reason="stop"`` chunk and ``data: [DONE]``.
    Failures are reported in-band as a final content chunk beginning with
    ``Error:``, followed by ``data: [DONE]``, because no HTTP error status can
    be sent once the streaming response has started.
    """
    # curl_cffi's browser impersonation supplies most headers; only the ones
    # Notion needs explicitly are listed here.
    headers = {
        'accept': 'application/x-ndjson',
        'accept-encoding': 'gzip, deflate, br, zstd',
        'accept-language': 'en-US,zh;q=0.9',
        'content-type': 'application/json',
        'dnt': '1',
        'notion-audit-log-platform': 'web',
        'notion-client-version': '23.13.0.3661',
        'origin': 'https://www.notion.so',
        'referer': 'https://www.notion.so/',
        'priority': 'u=1, i',
        'sec-ch-ua-mobile': '?0',
        'sec-ch-ua-platform': '"Windows"',
        'sec-fetch-dest': 'empty',
        'sec-fetch-mode': 'cors',
        'sec-fetch-site': 'same-origin',
        'cookie': NOTION_COOKIE,
        'x-notion-space-id': NOTION_SPACE_ID
    }
    # Optional header; skipped when the variable is unset or empty.
    notion_active_user = os.getenv("NOTION_ACTIVE_USER_HEADER")
    if notion_active_user:
        headers['x-notion-active-user-header'] = notion_active_user

    # One id and timestamp shared by every chunk of this completion.
    chunk_id = f"chatcmpl-{uuid.uuid4()}"
    created_time = int(time.time())
    max_retries = MAX_RETRIES
    retry_delay = RETRY_DELAY
    done_event = "data: [DONE]\n\n"

    def sse_chunk(content=None, finish_reason=None):
        # Serialize one ChatCompletionChunk as an SSE event string.
        chunk = ChatCompletionChunk(
            id=chunk_id,
            created=created_time,
            choices=[Choice(delta=ChoiceDelta(content=content), finish_reason=finish_reason)],
        )
        return f"data: {chunk.model_dump_json()}\n\n"

    def markdown_text(data):
        # Text of a markdown-chat record, else None. Non-object JSON raises
        # AttributeError; the caller reports it.
        if data.get("type") == "markdown-chat" and isinstance(data.get("value"), str):
            return data["value"]
        return None

    def is_error_500(data):
        return bool(data.get("type") == "error"
                    and data.get("message")
                    and "error code 500" in data.get("message", ""))

    def tail_events(tail):
        # On recordMap: relay the unterminated tail only if it already parses
        # as a complete markdown-chat record (best effort).
        stripped = tail.strip()
        if not stripped:
            return []
        try:
            text = markdown_text(json.loads(stripped))
        except (json.JSONDecodeError, AttributeError):
            return []
        return [sse_chunk(text)] if text else []

    first_line_checked = False  # reset at the start of every retry attempt

    def retry_line_events(line, tail, attempt):
        # Process one NDJSON line of a retry attempt and return (events, verdict).
        # verdict: None = keep reading, "error" = first line was a 500 error,
        # "stop" = recordMap reached (the text stream is over).
        nonlocal first_line_checked
        line = line.strip()
        if not line:
            return [], None
        try:
            data = json.loads(line)
            if not first_line_checked:
                first_line_checked = True
                if is_error_500(data):
                    print(f"检测到Notion API 500错误 (重试 {attempt + 1}/{max_retries}): {data}")
                    return [], "error"
            text = markdown_text(data)
            if text is not None:
                return ([sse_chunk(text)] if text else []), None
            if "recordMap" in data:
                print("Detected recordMap, stopping stream.")
                return tail_events(tail), "stop"
        except json.JSONDecodeError as e:
            print(f"Warning: Could not decode JSON line: {line[:100]}... Error: {str(e)}")
        except Exception as e:
            print(f"Error processing line: {str(e)}")
        return [], None

    # ---- Phase 1: race concurrent requests, use the first healthy one ----
    print(f"同时发起 {CONCURRENT_REQUESTS} 个并发请求...")
    async with AsyncSession(impersonate="chrome136") as session:
        tasks = [
            asyncio.create_task(
                check_first_response_line(session, notion_request_body, headers, i + 1)
            )
            for i in range(CONCURRENT_REQUESTS)
        ]
        successful_response = None
        pending = set(tasks)
        while pending and successful_response is None:
            done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
            for task in done:
                result, _, error = task.result()
                if result:
                    successful_response = result
                    print("找到成功的并发响应,立即使用")
                    break
                if error:
                    print(f"并发请求失败: {error}")

        # Cancel the requests still in flight and wait until they have actually
        # stopped, so they no longer use the shared session while the winner
        # is relayed.
        for task in pending:
            task.cancel()
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)

        if successful_response:
            response, initial_buffer = successful_response
            print("使用成功的并发响应进行流式传输")
            try:
                async for event in stream_notion_response_single(session, response, initial_buffer, chunk_id, created_time):
                    yield event
            except Exception as e:
                # Content may already have reached the client, so end the
                # stream cleanly instead of letting it break off without [DONE].
                print(f"Unexpected error while relaying concurrent response: {e}")
                yield sse_chunk(f"Error: Internal server error during streaming: {e}", "stop")
                yield done_event
                return
            yield sse_chunk(finish_reason="stop")
            yield done_event
            return

    # ---- Phase 2: sequential retries, only when every concurrent request failed ----
    print(f"所有 {CONCURRENT_REQUESTS} 个并发请求都失败,开始单请求重试流程...")
    for attempt in range(max_retries):
        has_more_attempts = attempt < max_retries - 1
        emitted_content = False
        try:
            # curl_cffi with chrome136 impersonation for better anti-bot bypass.
            async with AsyncSession(impersonate="chrome136") as session:
                response = await session.post(
                    NOTION_API_URL,
                    json=notion_request_body.model_dump(),
                    headers=headers,
                    stream=True
                )
                if response.status_code != 200:
                    error_content = await response.atext()
                    print(f"Error from Notion API: {response.status_code}")
                    print(f"Response: {error_content}")
                    raise HTTPException(status_code=response.status_code, detail=f"Notion API Error: {error_content}")

                first_line_checked = False
                # Incremental decoding keeps multi-byte UTF-8 characters that
                # straddle chunk boundaries intact.
                decoder = codecs.getincrementaldecoder("utf-8")()
                buffer = ""
                verdict = None
                async for chunk in response.aiter_content():
                    buffer += decoder.decode(chunk) if isinstance(chunk, bytes) else chunk
                    # Keep the trailing incomplete line for the next chunk.
                    *lines, buffer = buffer.split("\n")
                    for line in lines:
                        events, verdict = retry_line_events(line, buffer, attempt)
                        for event in events:
                            emitted_content = True
                            yield event
                        if verdict is not None:
                            break
                    # Leave the read loop too: on a 500 error or once recordMap
                    # marks the end of the text.
                    if verdict is not None:
                        break
                else:
                    # The body may end without a trailing newline; handle that line too.
                    events, verdict = retry_line_events(buffer, "", attempt)
                    for event in events:
                        emitted_content = True
                        yield event

            if verdict == "error":
                if has_more_attempts:
                    print(f"等待 {retry_delay} 秒后重试...")
                    await asyncio.sleep(retry_delay)
                    continue
                print("所有重试都失败,返回500错误给客户端")
                yield sse_chunk("Error: Notion API returned error code 500 after all retries", "stop")
                yield done_event
                return

            yield sse_chunk(finish_reason="stop")
            yield done_event
            return

        except HTTPException:
            # Raised before any content is streamed; errors must be reported in-band.
            if has_more_attempts:
                print(f"HTTP异常,等待 {retry_delay} 秒后重试...")
                await asyncio.sleep(retry_delay)
                continue
            print("HTTP异常且无更多重试,返回错误信息")
            yield sse_chunk("Error: HTTP exception occurred after all retries", "stop")
            yield done_event
            return
        except Exception as e:
            print(f"Unexpected error during streaming (attempt {attempt + 1}/{max_retries}): {e}")
            if has_more_attempts and not emitted_content:
                print(f"等待 {retry_delay} 秒后重试...")
                await asyncio.sleep(retry_delay)
                continue
            # Either no retries remain, or part of the answer was already sent.
            # In the latter case a retry would replay the answer from the start
            # and duplicate text, so end the stream here.
            print("意外错误且无更多重试,返回错误信息")
            yield sse_chunk(f"Error: Internal server error during streaming: {e}", "stop")
            yield done_event
            return
# --- API Endpoints ---
@app.get("/v1/models", response_model=ModelList)
async def list_models(authenticated: bool = Depends(authenticate)):
    """List the Notion model ids this proxy advertises, in OpenAI's /v1/models format."""
    advertised_ids = ("openai-gpt-4.1", "anthropic-opus-4", "anthropic-sonnet-4")
    entries = []
    for advertised_id in advertised_ids:
        # `created` is filled in by the model's default factory.
        entries.append(Model(id=advertised_id, owned_by="notion"))
    return ModelList(data=entries)
@app.post("/v1/chat/completions")
async def chat_completions(request_data: ChatCompletionRequest, request: Request, authenticated: bool = Depends(authenticate)):
    """OpenAI-compatible chat completions endpoint backed by Notion.

    With ``stream=true`` the SSE events from stream_notion_response are relayed
    as ``text/event-stream``. Otherwise the same events are consumed here and
    folded into a single ``chat.completion`` object. Notion reports no token
    usage, so all usage counts are None.

    Raises:
        HTTPException: 500 when NOTION_COOKIE is not configured, or when
            non-streaming aggregation fails unexpectedly.
    """
    if not NOTION_COOKIE:
        raise HTTPException(status_code=500, detail="Server configuration error: Notion cookie not set.")

    notion_request_body = build_notion_request(request_data)

    if request_data.stream:
        return StreamingResponse(
            stream_notion_response(notion_request_body),
            media_type="text/event-stream"
        )

    # Non-streaming: consume the SSE stream internally and concatenate the deltas.
    content_parts = []
    final_finish_reason = None
    chunk_id = f"chatcmpl-{uuid.uuid4()}"  # id of the non-streamed response
    created_time = int(time.time())
    try:
        async for event in stream_notion_response(notion_request_body):
            if not event.startswith("data: "):
                continue
            payload = event[len("data: "):].strip()
            # Compare the whole payload: a substring test would also drop
            # content chunks whose text happens to contain "[DONE]".
            if not payload or payload == "[DONE]":
                continue
            try:
                chunk_data = json.loads(payload)
            except json.JSONDecodeError:
                print(f"Warning: Could not decode JSON line in non-streaming mode: {event}")
                continue
            choices = chunk_data.get("choices")
            if not choices:
                continue
            content = choices[0].get("delta", {}).get("content")
            if content:
                content_parts.append(content)
            finish_reason = choices[0].get("finish_reason")
            if finish_reason:
                final_finish_reason = finish_reason
    except HTTPException:
        raise
    except Exception as e:
        print(f"Error during non-streaming processing: {e}")
        raise HTTPException(status_code=500, detail="Internal server error processing Notion response") from e

    return {
        "id": chunk_id,
        "object": "chat.completion",
        "created": created_time,
        "model": request_data.model,  # echo the model name the client requested
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "".join(content_parts),
                },
                "finish_reason": final_finish_reason or "stop",
            }
        ],
        "usage": {  # Notion does not report token usage
            "prompt_tokens": None,
            "completion_tokens": None,
            "total_tokens": None,
        },
    }
if __name__ == "__main__":
    # Run a uvicorn server directly when this module is executed as a script.
    import uvicorn

    host, port = "0.0.0.0", 7860
    print(f"Starting server. Access at http://localhost:{port}")
    print("Ensure NOTION_COOKIE is set in your .env file or environment.")
    uvicorn.run(app, host=host, port=port)