ximilala commited on
Commit
5bed516
·
verified ·
1 Parent(s): 9231f37

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +298 -0
main.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ import json
4
+ import time
5
+ import httpx
6
+ from fastapi import FastAPI, Request, HTTPException, Depends, status
7
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
8
+ from fastapi.responses import StreamingResponse
9
+ from dotenv import load_dotenv
10
+ import secrets # Added for secure comparison
11
+ from models import (
12
+ ChatMessage, ChatCompletionRequest, NotionTranscriptConfigValue,
13
+ NotionTranscriptItem, NotionDebugOverrides, NotionRequestBody,
14
+ ChoiceDelta, Choice, ChatCompletionChunk, Model, ModelList
15
+ )
# Load environment variables from a local .env file into os.environ.
load_dotenv()

# --- Configuration ---
# Notion's internal (undocumented) streaming-inference endpoint.
NOTION_API_URL = "https://www.notion.so/api/v3/runInferenceTranscript"
# IMPORTANT: Load the Notion cookie securely from environment variables
NOTION_COOKIE = os.getenv("NOTION_COOKIE")

NOTION_SPACE_ID = os.getenv("NOTION_SPACE_ID")
if not NOTION_COOKIE:
    # Startup continues anyway; /v1/chat/completions re-checks and returns 500.
    print("Error: NOTION_COOKIE environment variable not set.")
    # Consider raising HTTPException or exiting in a real app
if not NOTION_SPACE_ID:
    print("Warning: NOTION_SPACE_ID environment variable not set. Using a default UUID.")
    # Using a default might not be ideal, depends on Notion's behavior
    # Consider raising an error instead: raise ValueError("NOTION_SPACE_ID not set")
    # NOTE(review): a freshly generated random UUID is almost certainly not a
    # real space id, so downstream Notion calls will likely fail — confirm.
    NOTION_SPACE_ID = str(uuid.uuid4())  # Default or raise error
# --- Authentication ---
# Bearer token clients must present. Falls back to an insecure well-known
# default; deployments should always set PROXY_AUTH_TOKEN explicitly.
EXPECTED_TOKEN = os.getenv("PROXY_AUTH_TOKEN", "default_token")  # Default token
security = HTTPBearer()

def authenticate(credentials: HTTPAuthorizationCredentials = Depends(security)) -> bool:
    """Validate the request's Bearer token against EXPECTED_TOKEN.

    Uses secrets.compare_digest for a constant-time comparison (avoids
    leaking the token length/prefix via timing).

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` challenge
            header (required by RFC 6750) when the token does not match.

    Returns:
        True on successful authentication (value consumed by Depends).
    """
    correct_token = secrets.compare_digest(credentials.credentials, EXPECTED_TOKEN)
    if not correct_token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication credentials",
            # Bearer challenge restored: RFC 6750 §3 requires this header
            # on 401 responses to a protected resource.
            headers={"WWW-Authenticate": "Bearer"},
        )
    return True  # Indicate successful authentication
# --- FastAPI App ---
# Single application instance; routes are registered via decorators below.
app = FastAPI()
# --- Helper Functions ---

def build_notion_request(request_data: ChatCompletionRequest) -> NotionRequestBody:
    """Translate an OpenAI-style chat request into Notion's transcript payload.

    The transcript opens with a "config" item selecting the Notion model.
    Assistant messages become "markdown-chat" items; every other role
    (user, system, ...) becomes one or more "user" items whose value is
    wrapped as [[text]].
    """
    items = [
        NotionTranscriptItem(
            type="config",
            value=NotionTranscriptConfigValue(model=request_data.notion_model),
        )
    ]

    for msg in request_data.messages:
        if msg.role == "assistant":
            # Notion represents prior assistant turns as "markdown-chat".
            items.append(NotionTranscriptItem(type="markdown-chat", value=msg.content))
            continue

        body = msg.content
        if isinstance(body, str):
            # Plain-string content: a single item; empty strings map to [[""]].
            wrapped = [[body]] if body else [[""]]
            items.append(NotionTranscriptItem(type="user", value=wrapped))
        elif isinstance(body, list):
            # Multi-part content: a separate item per non-empty text part.
            appended_any = False
            for part in body:
                if not (isinstance(part, dict) and part.get("type") == "text"):
                    continue
                text = part.get("text")
                if isinstance(text, str) and text:
                    items.append(NotionTranscriptItem(type="user", value=[[text]]))
                    appended_any = True
            if not appended_any:
                # Keep the transcript shape stable even with no usable text.
                print(f'Error: no valid input found: {msg}')
                items.append(NotionTranscriptItem(type="user", value=[[""]]))
        else:
            # Unexpected content type (None, int, ...): emit an empty item.
            items.append(NotionTranscriptItem(type="user", value=[[""]]))
            print(f'Error: no valid input found: {msg}')

    # Always open a fresh thread in the globally configured space.
    return NotionRequestBody(
        spaceId=NOTION_SPACE_ID,  # From environment variable
        transcript=items,
        createThread=True,  # Always create a new thread
        traceId=str(uuid.uuid4()),  # New trace id per request
        debugOverrides=NotionDebugOverrides(
            cachedInferences={},
            annotationInferences={},
            emitInferences=False,
        ),
        generateTitle=False,
        saveAllThreadOperations=False,
    )
async def stream_notion_response(notion_request_body: NotionRequestBody):
    """Stream a Notion inference call as OpenAI-compatible SSE chunks.

    Async generator: POSTs the transcript to Notion's ndjson endpoint and
    yields ``data: {...}`` lines shaped like OpenAI chat.completion.chunk
    objects, then a final stop chunk and ``data: [DONE]``.

    Raises:
        HTTPException: with Notion's status code on a non-200 response, or
            500 on connection/unexpected errors.
            NOTE(review): once the first chunk has been yielded inside a
            StreamingResponse, a raised exception cannot become a clean
            HTTP error for the client — confirm desired behavior.
    """
    headers = {
        'accept': 'application/x-ndjson',
        'accept-language': 'en-US,en;q=0.9,zh-CN;q=0.8,zh;q=0.7,zh-TW;q=0.6,ja;q=0.5',
        'content-type': 'application/json',
        'notion-audit-log-platform': 'web',
        'notion-client-version': '23.13.0.3604',  # Consider making this configurable
        'origin': 'https://www.notion.so',
        'priority': 'u=1, i',
        # Referer might be optional or need adjustment. Removing threadId part.
        'referer': 'https://www.notion.so/chat',
        'sec-ch-ua': '"Chromium";v="136", "Google Chrome";v="136", "Not.A/Brand";v="99"',
        'sec-ch-ua-mobile': '?0',
        'sec-ch-ua-platform': '"Windows"',
        'sec-fetch-dest': 'empty',
        'sec-fetch-mode': 'cors',
        'sec-fetch-site': 'same-origin',
        'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36',
        'cookie': NOTION_COOKIE,  # Loaded from .env
        'x-notion-space-id': NOTION_SPACE_ID  # Added space ID header
    }

    # Conditionally add the active user header
    notion_active_user = os.getenv("NOTION_ACTIVE_USER_HEADER")
    if notion_active_user:  # Checks for None and empty string implicitly
        headers['x-notion-active-user-header'] = notion_active_user

    # One id/timestamp pair shared by every chunk of this completion.
    chunk_id = f"chatcmpl-{uuid.uuid4()}"
    created_time = int(time.time())

    try:
        async with httpx.AsyncClient(timeout=None) as client:  # No timeout for streaming
            async with client.stream("POST", NOTION_API_URL, json=notion_request_body.dict(), headers=headers) as response:
                if response.status_code != 200:
                    error_content = await response.aread()
                    print(f"Error from Notion API: {response.status_code}")
                    print(f"Response: {error_content.decode()}")
                    # Surface Notion's own status code to the caller.
                    raise HTTPException(status_code=response.status_code, detail=f"Notion API Error: {error_content.decode()}")

                async for line in response.aiter_lines():
                    if not line.strip():
                        continue
                    try:
                        data = json.loads(line)
                        # "markdown-chat" items carry the streamed text deltas.
                        if data.get("type") == "markdown-chat" and isinstance(data.get("value"), str):
                            content_chunk = data["value"]
                            if content_chunk:  # Only send if there's content
                                chunk = ChatCompletionChunk(
                                    id=chunk_id,
                                    created=created_time,
                                    choices=[Choice(delta=ChoiceDelta(content=content_chunk))]
                                )
                                yield f"data: {chunk.json()}\n\n"
                        # A recordMap payload arrives after the text stream is
                        # finished, so treat it as an end-of-stream marker.
                        elif "recordMap" in data:
                            print("Detected recordMap, stopping stream.")
                            break  # Stop processing after recordMap

                    except json.JSONDecodeError:
                        print(f"Warning: Could not decode JSON line: {line}")
                    except Exception as e:
                        print(f"Error processing line: {line} - {e}")
                        # Best-effort: skip the bad line and keep streaming.

        # Send the final chunk indicating stop
        final_chunk = ChatCompletionChunk(
            id=chunk_id,
            created=created_time,
            choices=[Choice(delta=ChoiceDelta(), finish_reason="stop")]
        )
        yield f"data: {final_chunk.json()}\n\n"
        yield "data: [DONE]\n\n"

    except HTTPException:
        # FIX: HTTPException is a subclass of Exception, so without this
        # clause the generic handler below would catch it and rewrite
        # Notion's status code as an opaque 500. Re-raise untouched.
        raise
    except httpx.RequestError as e:
        print(f"HTTPX Request Error: {e}")
        # Let the endpoint translate connection failures into a 500.
        raise HTTPException(status_code=500, detail=f"Error connecting to Notion API: {e}")
    except Exception as e:
        print(f"Unexpected error during streaming: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error during streaming: {e}")
# --- API Endpoint ---

@app.get("/v1/models", response_model=ModelList)
async def list_models(authenticated: bool = Depends(authenticate)):
    """Mimic OpenAI's /v1/models: advertise the Notion model ids this proxy accepts."""
    model_ids = (
        "openai-gpt-4.1",
        "anthropic-opus-4",
        "anthropic-sonnet-4",
    )
    # Each Model's `created` field is filled in by its default_factory.
    return ModelList(data=[Model(id=mid, owned_by="notion") for mid in model_ids])
@app.post("/v1/chat/completions")
async def chat_completions(request_data: ChatCompletionRequest, request: Request, authenticated: bool = Depends(authenticate)):
    """
    Endpoint to mimic OpenAI's chat completions, proxying to Notion.

    Streaming requests return the SSE generator directly; non-streaming
    requests drain the same generator server-side and assemble a single
    chat.completion response. Token usage is unavailable from Notion, so
    the usage fields are returned as None.
    """
    # Fail fast if the server was started without its Notion credential.
    if not NOTION_COOKIE:
        raise HTTPException(status_code=500, detail="Server configuration error: Notion cookie not set.")

    notion_request_body = build_notion_request(request_data)

    if request_data.stream:
        return StreamingResponse(
            stream_notion_response(notion_request_body),
            media_type="text/event-stream"
        )
    else:
        # --- Non-Streaming Logic (Optional - Collects stream internally) ---
        # Note: The primary goal is streaming, but a non-streaming version
        # might be useful for testing or simpler clients.
        # This requires collecting all chunks from the async generator.
        full_response_content = ""
        final_finish_reason = None
        chunk_id = f"chatcmpl-{uuid.uuid4()}"  # Generate ID for the non-streamed response
        created_time = int(time.time())

        try:
            # Drain the SSE generator, concatenating delta contents.
            async for line in stream_notion_response(notion_request_body):
                if line.startswith("data: ") and "[DONE]" not in line:
                    try:
                        data_json = line[len("data: "):].strip()
                        if data_json:
                            chunk_data = json.loads(data_json)
                            if chunk_data.get("choices"):
                                delta = chunk_data["choices"][0].get("delta", {})
                                content = delta.get("content")
                                if content:
                                    full_response_content += content
                                # The final chunk carries finish_reason instead of content.
                                finish_reason = chunk_data["choices"][0].get("finish_reason")
                                if finish_reason:
                                    final_finish_reason = finish_reason
                    except json.JSONDecodeError:
                        print(f"Warning: Could not decode JSON line in non-streaming mode: {line}")

            # Construct the final OpenAI-compatible non-streaming response
            return {
                "id": chunk_id,
                "object": "chat.completion",
                "created": created_time,
                "model": request_data.model,  # Return the model requested by the client
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": full_response_content,
                        },
                        "finish_reason": final_finish_reason or "stop",  # Default to stop if not explicitly set
                    }
                ],
                "usage": {  # Note: Token usage is not available from Notion
                    "prompt_tokens": None,
                    "completion_tokens": None,
                    "total_tokens": None,
                },
            }
        except HTTPException as e:
            # Re-raise HTTP exceptions from the streaming function
            raise e
        except Exception as e:
            print(f"Error during non-streaming processing: {e}")
            raise HTTPException(status_code=500, detail="Internal server error processing Notion response")
290
+
# --- Uvicorn Runner ---
# Supports `python main.py` for quick local testing; prefer
# `uvicorn main:app --reload` during development.
if __name__ == "__main__":
    import uvicorn

    host, port = "127.0.0.1", 7860
    print("Starting server. Access at http://127.0.0.1:7860")
    print("Ensure NOTION_COOKIE is set in your .env file or environment.")
    uvicorn.run(app, host=host, port=port)