Spaces:
Sleeping
Sleeping
| import os | |
| import uuid | |
| import json | |
| import time | |
| import random | |
| import httpx | |
| from fastapi import FastAPI, Request, HTTPException, Depends, status | |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
| from fastapi.responses import StreamingResponse | |
| from dotenv import load_dotenv | |
| import secrets # Added for secure comparison | |
| from datetime import datetime, timedelta, timezone # Explicit datetime imports | |
| from zoneinfo import ZoneInfo # For timezone handling | |
| from models import ( | |
| ChatMessage, ChatCompletionRequest, NotionTranscriptConfigValue, | |
| NotionTranscriptContextValue, NotionTranscriptItem, NotionDebugOverrides, | |
| NotionRequestBody, ChoiceDelta, Choice, ChatCompletionChunk, Model, ModelList | |
| ) | |
# Load environment variables from a local .env file (no-op if the file is absent).
load_dotenv()

# --- Configuration ---
# Notion's internal (undocumented) inference endpoint that this proxy targets.
NOTION_API_URL = "https://www.notion.so/api/v3/runInferenceTranscript"
# IMPORTANT: Load the Notion cookie securely from environment variables
NOTION_COOKIE = os.getenv("NOTION_COOKIE")
NOTION_SPACE_ID = os.getenv("NOTION_SPACE_ID")
if not NOTION_COOKIE:
    # Startup continues anyway; requests later fail with a 500 in chat_completions.
    print("Error: NOTION_COOKIE environment variable not set.")
    # Consider raising HTTPException or exiting in a real app
if not NOTION_SPACE_ID:
    print("Warning: NOTION_SPACE_ID environment variable not set. Using a default UUID.")
    # Using a default might not be ideal, depends on Notion's behavior
    # Consider raising an error instead: raise ValueError("NOTION_SPACE_ID not set")
    # NOTE(review): a random space id presumably does not match any real Notion
    # workspace — confirm Notion actually accepts it before relying on this fallback.
    NOTION_SPACE_ID = str(uuid.uuid4())  # Default or raise error

# --- Authentication ---
# Bearer token clients must present.
# NOTE(review): "default_token" is a weak, well-known fallback — consider
# failing fast when PROXY_AUTH_TOKEN is unset instead.
EXPECTED_TOKEN = os.getenv("PROXY_AUTH_TOKEN", "default_token")  # Default token
security = HTTPBearer()
def authenticate(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Validate the client's bearer token against EXPECTED_TOKEN.

    Uses secrets.compare_digest for a constant-time comparison (resists
    timing attacks). Raises HTTP 401 on mismatch, returns True on success.
    """
    if not secrets.compare_digest(credentials.credentials, EXPECTED_TOKEN):
        # No WWW-Authenticate header: intentional for plain Bearer auth.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication credentials",
        )
    return True
# --- FastAPI App ---
# Single application instance; route handlers are registered against it.
app = FastAPI()
# --- Helper Functions ---
def _format_notion_timestamp(dt: datetime) -> str:
    """Format an aware datetime exactly as Notion expects:
    YYYY-MM-DDTHH:MM:SS.fff-HH:MM (millisecond precision, colon in the offset).
    """
    base = dt.strftime("%Y-%m-%dT%H:%M:%S")
    millis = f"{dt.microsecond // 1000:03d}"  # always exactly 3 digits
    raw_offset = dt.strftime("%z")            # +HHMM or -HHMM
    return f"{base}.{millis}{raw_offset[:-2]}:{raw_offset[-2:]}"


def build_notion_request(request_data: ChatCompletionRequest) -> NotionRequestBody:
    """Transform OpenAI-style chat messages into a Notion transcript request.

    Builds the config/context preamble, then appends one transcript item per
    message: assistant messages become "markdown-chat" items with a fresh
    traceId; user messages become "user" items carrying userId and a
    synthesized createdAt timestamp. System messages are intentionally dropped.

    Args:
        request_data: Parsed OpenAI-style chat completion request.

    Returns:
        NotionRequestBody targeting the globally configured NOTION_SPACE_ID,
        always with createThread=True and a fresh traceId.
    """
    user_id = os.getenv("NOTION_ACTIVE_USER_HEADER")
    pacific_tz = ZoneInfo("America/Los_Angeles")
    # datetime.now(tz) is equivalent to now(UTC).astimezone(tz) in one step;
    # computed once so the context timestamp and the last user message agree.
    now_pacific = datetime.now(pacific_tz)

    # --- Timestamp and User ID Logic ---
    # The last user message is stamped "now"; each earlier user message gets a
    # random 3-20 minute gap before the next, so the history looks organic.
    # NOTE(review): assumes ChatMessage exposes a unique `.id` — confirm in models.py.
    user_messages = [msg for msg in request_data.messages if msg.role == "user"]
    user_message_timestamps = {}  # message id -> aware datetime
    if user_messages:
        user_message_timestamps[user_messages[-1].id] = now_pacific
        ts = now_pacific
        for msg in reversed(user_messages[:-1]):
            ts -= timedelta(minutes=random.randint(3, 20))
            user_message_timestamps[msg.id] = ts

    # --- Build Transcript preamble ---
    current_datetime_iso = _format_notion_timestamp(now_pacific)
    # Randomized cosmetic identity for the user/workspace shown to Notion.
    random_words = ["Project", "Workspace", "Team", "Studio", "Lab", "Hub", "Zone", "Space"]
    user_name = f"User{random.randint(100, 999)}"
    space_name = f"{random.choice(random_words)} {random.randint(1, 99)}"
    transcript = [
        NotionTranscriptItem(
            type="config",
            value=NotionTranscriptConfigValue(model=request_data.notion_model)
        ),
        NotionTranscriptItem(
            type="context",
            value=NotionTranscriptContextValue(
                userId=user_id or "",  # env-provided user id, or empty string
                spaceId=NOTION_SPACE_ID,
                surface="home_module",
                timezone="America/Los_Angeles",
                userName=user_name,
                spaceName=space_name,
                spaceViewId=str(uuid.uuid4()),  # random UUID for spaceViewId
                currentDatetime=current_datetime_iso
            )
        ),
        # "agent-integration" items carry no value field.
        NotionTranscriptItem(type="agent-integration")
    ]

    def _append_user_item(value, created_at_iso):
        """Append a "user" transcript item with the shared userId/createdAt."""
        transcript.append(NotionTranscriptItem(
            type="user",
            value=value,
            userId=user_id,
            createdAt=created_at_iso
        ))

    for message in request_data.messages:
        if message.role == "assistant":
            # Assistant messages get a traceId, but not userId or createdAt.
            transcript.append(NotionTranscriptItem(
                type="markdown-chat",
                value=message.content,
                traceId=str(uuid.uuid4())
            ))
        elif message.role == "user":
            created_at_dt = user_message_timestamps.get(message.id)
            created_at_iso = _format_notion_timestamp(created_at_dt) if created_at_dt else None
            content = message.content
            if isinstance(content, str):
                _append_user_item([[content]] if content else [[""]], created_at_iso)
            elif isinstance(content, list):
                # Multi-part content: emit one item per non-empty text part,
                # all sharing the same userId/timestamp.
                found_text_part = False
                for part in content:
                    if isinstance(part, dict) and part.get("type") == "text":
                        text_content = part.get("text")
                        if isinstance(text_content, str) and text_content:
                            _append_user_item([[text_content]], created_at_iso)
                            found_text_part = True
                if not found_text_part:
                    # Default empty item so the message is still represented.
                    _append_user_item([[""]], created_at_iso)
                    print(f'Warning: No valid text parts found in user message list content: {message}')
            else:
                # Unexpected content type: represent with a default empty item.
                _append_user_item([[""]], created_at_iso)
                print(f'Warning: Unexpected content type in user message: {message}')
        # System messages are currently ignored in the Notion transcript.

    # Use globally configured spaceId, set createThread=True.
    return NotionRequestBody(
        spaceId=NOTION_SPACE_ID,   # from environment variable
        transcript=transcript,
        createThread=True,         # always create a new thread
        traceId=str(uuid.uuid4()), # new traceId for each request
        debugOverrides=NotionDebugOverrides(
            cachedInferences={},
            annotationInferences={},
            emitInferences=False
        ),
        generateTitle=False,
        saveAllThreadOperations=False
    )
async def stream_notion_response(notion_request_body: NotionRequestBody):
    """Streams the request to Notion and yields OpenAI-compatible SSE chunks.

    Async generator: POSTs the transcript to NOTION_API_URL, reads Notion's
    ndjson reply line by line, re-emits each "markdown-chat" text value as an
    OpenAI-style `data: {...}` chunk, then yields a finish chunk and the
    `data: [DONE]` sentinel.

    NOTE(review): an HTTPException raised here after the first chunk has been
    yielded cannot change the HTTP status — the client just sees the stream
    break. Confirm that is acceptable.
    """
    # Browser-like headers so Notion treats this as a normal web-client call.
    headers = {
        'accept': 'application/x-ndjson',
        'accept-language': 'en-US,en;q=0.9',
        'content-type': 'application/json',
        'notion-audit-log-platform': 'web',
        'notion-client-version': '23.13.0.3668',  # Consider making this configurable
        'origin': 'https://www.notion.so',
        'priority': 'u=1, i',
        # Referer might be optional or need adjustment. Removing threadId part.
        'referer': 'https://www.notion.so/chat',
        'sec-ch-ua': '"Chromium";v="136", "Google Chrome";v="136", "Not.A/Brand";v="99"',
        'sec-ch-ua-mobile': '?0',
        'sec-ch-ua-platform': '"Windows"',
        'sec-fetch-dest': 'empty',
        'sec-fetch-mode': 'cors',
        'sec-fetch-site': 'same-origin',
        'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36',
        'cookie': NOTION_COOKIE,  # Loaded from .env
        'x-notion-space-id': NOTION_SPACE_ID  # Added space ID header
    }
    # Conditionally add the active user header
    notion_active_user = os.getenv("NOTION_ACTIVE_USER_HEADER")
    if notion_active_user:  # Checks for None and empty string implicitly
        headers['x-notion-active-user-header'] = notion_active_user
    # One id/creation-time pair shared by every chunk of this completion.
    chunk_id = f"chatcmpl-{uuid.uuid4()}"
    created_time = int(time.time())
    try:
        async with httpx.AsyncClient(timeout=None) as client:  # No timeout for streaming
            # Explicitly serialize using .json() to respect Pydantic Config
            # (like json_encoders for UUID).
            # NOTE(review): .json() is the Pydantic v1 API — v2 renamed it to
            # model_dump_json(); confirm which version models.py targets.
            request_body_json = notion_request_body.json()
            async with client.stream("POST", NOTION_API_URL, content=request_body_json, headers=headers) as response:
                if response.status_code != 200:
                    error_content = await response.aread()
                    print(f"Error from Notion API: {response.status_code}")
                    print(f"Response: {error_content.decode()}")
                    # Raised before anything is yielded, so FastAPI can still
                    # turn this into a proper HTTP error response.
                    raise HTTPException(status_code=response.status_code, detail=f"Notion API Error: {error_content.decode()}")
                # Notion replies with ndjson: one JSON object per line.
                async for line in response.aiter_lines():
                    if not line.strip():
                        continue
                    try:
                        data = json.loads(line)
                        # Check if it's the type of message containing text chunks
                        if data.get("type") == "markdown-chat" and isinstance(data.get("value"), str):
                            content_chunk = data["value"]
                            if content_chunk:  # Only send if there's content
                                chunk = ChatCompletionChunk(
                                    id=chunk_id,
                                    created=created_time,
                                    choices=[Choice(delta=ChoiceDelta(content=content_chunk))]
                                )
                                yield f"data: {chunk.json()}\n\n"
                        # No explicit end-of-stream marker is known; a recordMap
                        # object is definitely past the text stream, so stop there.
                        elif "recordMap" in data:
                            print("Detected recordMap, stopping stream.")
                            break  # Stop processing after recordMap
                    except json.JSONDecodeError:
                        print(f"Warning: Could not decode JSON line: {line}")
                    except Exception as e:
                        # Best-effort: log and keep consuming subsequent lines.
                        print(f"Error processing line: {line} - {e}")
        # Send the final chunk indicating stop
        final_chunk = ChatCompletionChunk(
            id=chunk_id,
            created=created_time,
            choices=[Choice(delta=ChoiceDelta(), finish_reason="stop")]
        )
        yield f"data: {final_chunk.json()}\n\n"
        yield "data: [DONE]\n\n"
    except httpx.RequestError as e:
        print(f"HTTPX Request Error: {e}")
        # Let the endpoint (or ASGI layer) handle the failure.
        raise HTTPException(status_code=500, detail=f"Error connecting to Notion API: {e}")
    except Exception as e:
        print(f"Unexpected error during streaming: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error during streaming: {e}")
# --- API Endpoint ---
# NOTE(review): no route decorator was present in the reviewed copy, so this
# handler was never registered on the app; restored per the docstring's path.
@app.get("/v1/models")
async def list_models(authenticated: bool = Depends(authenticate)):
    """
    Endpoint to list available Notion models, mimicking OpenAI's /v1/models.

    Requires bearer-token auth (see authenticate). Returns a ModelList of the
    Notion model ids this proxy knows how to request.
    """
    # Hard-coded: Notion exposes no discovery API for these model ids.
    available_models = [
        "openai-gpt-4.1",
        "anthropic-opus-4",
        "anthropic-sonnet-4"
    ]
    model_list = [
        Model(id=model_id, owned_by="notion")  # created uses default_factory
        for model_id in available_models
    ]
    return ModelList(data=model_list)
# NOTE(review): no route decorator was present in the reviewed copy, so this
# handler was never registered on the app; restored per the docstring's intent.
@app.post("/v1/chat/completions")
async def chat_completions(request_data: ChatCompletionRequest, request: Request, authenticated: bool = Depends(authenticate)):
    """
    Endpoint to mimic OpenAI's chat completions, proxying to Notion.

    Streaming requests return an SSE StreamingResponse; non-streaming requests
    drain the same generator internally and return a single OpenAI-style
    chat.completion object. Raises 500 if NOTION_COOKIE is not configured.
    """
    if not NOTION_COOKIE:
        raise HTTPException(status_code=500, detail="Server configuration error: Notion cookie not set.")
    notion_request_body = build_notion_request(request_data)
    if request_data.stream:
        return StreamingResponse(
            stream_notion_response(notion_request_body),
            media_type="text/event-stream"
        )
    # --- Non-Streaming Logic ---
    # Collect every SSE chunk from the async generator and assemble one response.
    full_response_content = ""
    final_finish_reason = None
    chunk_id = f"chatcmpl-{uuid.uuid4()}"  # id for the non-streamed response
    created_time = int(time.time())
    try:
        async for line in stream_notion_response(notion_request_body):
            if line.startswith("data: ") and "[DONE]" not in line:
                try:
                    data_json = line[len("data: "):].strip()
                    if data_json:
                        chunk_data = json.loads(data_json)
                        if chunk_data.get("choices"):
                            choice = chunk_data["choices"][0]
                            content = choice.get("delta", {}).get("content")
                            if content:
                                full_response_content += content
                            # Checked independently of content: the finish
                            # chunk carries a finish_reason but no content.
                            finish_reason = choice.get("finish_reason")
                            if finish_reason:
                                final_finish_reason = finish_reason
                except json.JSONDecodeError:
                    print(f"Warning: Could not decode JSON line in non-streaming mode: {line}")
        # Construct the final OpenAI-compatible non-streaming response.
        return {
            "id": chunk_id,
            "object": "chat.completion",
            "created": created_time,
            "model": request_data.model,  # echo the model the client requested
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": full_response_content,
                    },
                    "finish_reason": final_finish_reason or "stop",  # default to stop
                }
            ],
            "usage": {  # Token usage is not available from Notion
                "prompt_tokens": None,
                "completion_tokens": None,
                "total_tokens": None,
            },
        }
    except HTTPException:
        # Re-raise HTTP exceptions from the streaming function unchanged.
        raise
    except Exception as e:
        print(f"Error during non-streaming processing: {e}")
        raise HTTPException(status_code=500, detail="Internal server error processing Notion response")
# --- Uvicorn Runner ---
# Allows running with `python main.py` for simple testing,
# but `uvicorn main:app --reload` is recommended for development.
if __name__ == "__main__":
    import uvicorn
    print("Starting server. Access at http://127.0.0.1:7860")
    print("Ensure NOTION_COOKIE is set in your .env file or environment.")
    # Binds to localhost only; use host="0.0.0.0" to expose externally.
    uvicorn.run(app, host="127.0.0.1", port=7860)