Spaces:
Sleeping
Sleeping
| import asyncio | |
| from typing import Any, Dict, List, Optional | |
| from fastapi import APIRouter, HTTPException, Query | |
| from langchain_core.runnables import RunnableConfig | |
| from loguru import logger | |
| from pydantic import BaseModel, Field | |
| from open_notebook.database.repository import ensure_record_id, repo_query | |
| from open_notebook.domain.notebook import ChatSession, Note, Notebook, Source | |
| from open_notebook.exceptions import ( | |
| NotFoundError, | |
| ) | |
| from open_notebook.graphs.chat import graph as chat_graph | |
| router = APIRouter() | |
| # Request/Response models | |
| class CreateSessionRequest(BaseModel): | |
| notebook_id: str = Field(..., description="Notebook ID to create session for") | |
| title: Optional[str] = Field(None, description="Optional session title") | |
| model_override: Optional[str] = Field( | |
| None, description="Optional model override for this session" | |
| ) | |
| class UpdateSessionRequest(BaseModel): | |
| title: Optional[str] = Field(None, description="New session title") | |
| model_override: Optional[str] = Field( | |
| None, description="Model override for this session" | |
| ) | |
| class ChatMessage(BaseModel): | |
| id: str = Field(..., description="Message ID") | |
| type: str = Field(..., description="Message type (human|ai)") | |
| content: str = Field(..., description="Message content") | |
| timestamp: Optional[str] = Field(None, description="Message timestamp") | |
| class ChatSessionResponse(BaseModel): | |
| id: str = Field(..., description="Session ID") | |
| title: str = Field(..., description="Session title") | |
| notebook_id: Optional[str] = Field(None, description="Notebook ID") | |
| created: str = Field(..., description="Creation timestamp") | |
| updated: str = Field(..., description="Last update timestamp") | |
| message_count: Optional[int] = Field( | |
| None, description="Number of messages in session" | |
| ) | |
| model_override: Optional[str] = Field( | |
| None, description="Model override for this session" | |
| ) | |
| class ChatSessionWithMessagesResponse(ChatSessionResponse): | |
| messages: List[ChatMessage] = Field( | |
| default_factory=list, description="Session messages" | |
| ) | |
| class ExecuteChatRequest(BaseModel): | |
| session_id: str = Field(..., description="Chat session ID") | |
| message: str = Field(..., description="User message content") | |
| context: Dict[str, Any] = Field( | |
| ..., description="Chat context with sources and notes" | |
| ) | |
| model_override: Optional[str] = Field( | |
| None, description="Optional model override for this message" | |
| ) | |
| class ExecuteChatResponse(BaseModel): | |
| session_id: str = Field(..., description="Session ID") | |
| messages: List[ChatMessage] = Field(..., description="Updated message list") | |
| class BuildContextRequest(BaseModel): | |
| notebook_id: str = Field(..., description="Notebook ID") | |
| context_config: Dict[str, Any] = Field(..., description="Context configuration") | |
| class BuildContextResponse(BaseModel): | |
| context: Dict[str, Any] = Field(..., description="Built context data") | |
| token_count: int = Field(..., description="Estimated token count") | |
| char_count: int = Field(..., description="Character count") | |
| class SuccessResponse(BaseModel): | |
| success: bool = Field(True, description="Operation success status") | |
| message: str = Field(..., description="Success message") | |
| async def get_sessions(notebook_id: str = Query(..., description="Notebook ID")): | |
| """Get all chat sessions for a notebook.""" | |
| try: | |
| # Get notebook to verify it exists | |
| notebook = await Notebook.get(notebook_id) | |
| if not notebook: | |
| raise HTTPException(status_code=404, detail="Notebook not found") | |
| # Get sessions for this notebook | |
| sessions = await notebook.get_chat_sessions() | |
| return [ | |
| ChatSessionResponse( | |
| id=session.id or "", | |
| title=session.title or "Untitled Session", | |
| notebook_id=notebook_id, | |
| created=str(session.created), | |
| updated=str(session.updated), | |
| message_count=0, # TODO: Add message count if needed | |
| model_override=getattr(session, "model_override", None), | |
| ) | |
| for session in sessions | |
| ] | |
| except NotFoundError: | |
| raise HTTPException(status_code=404, detail="Notebook not found") | |
| except Exception as e: | |
| logger.error(f"Error fetching chat sessions: {str(e)}") | |
| raise HTTPException( | |
| status_code=500, detail=f"Error fetching chat sessions: {str(e)}" | |
| ) | |
| async def create_session(request: CreateSessionRequest): | |
| """Create a new chat session.""" | |
| try: | |
| # Verify notebook exists | |
| notebook = await Notebook.get(request.notebook_id) | |
| if not notebook: | |
| raise HTTPException(status_code=404, detail="Notebook not found") | |
| # Create new session | |
| session = ChatSession( | |
| title=request.title or f"Chat Session {asyncio.get_event_loop().time():.0f}", | |
| model_override=request.model_override, | |
| ) | |
| await session.save() | |
| # Relate session to notebook | |
| await session.relate_to_notebook(request.notebook_id) | |
| return ChatSessionResponse( | |
| id=session.id or "", | |
| title=session.title or "", | |
| notebook_id=request.notebook_id, | |
| created=str(session.created), | |
| updated=str(session.updated), | |
| message_count=0, | |
| model_override=session.model_override, | |
| ) | |
| except NotFoundError: | |
| raise HTTPException(status_code=404, detail="Notebook not found") | |
| except Exception as e: | |
| logger.error(f"Error creating chat session: {str(e)}") | |
| raise HTTPException( | |
| status_code=500, detail=f"Error creating chat session: {str(e)}" | |
| ) | |
| async def get_session(session_id: str): | |
| """Get a specific session with its messages.""" | |
| try: | |
| # Get session | |
| # Ensure session_id has proper table prefix | |
| full_session_id = ( | |
| session_id | |
| if session_id.startswith("chat_session:") | |
| else f"chat_session:{session_id}" | |
| ) | |
| session = await ChatSession.get(full_session_id) | |
| if not session: | |
| raise HTTPException(status_code=404, detail="Session not found") | |
| # Get session state from LangGraph to retrieve messages | |
| thread_state = chat_graph.get_state( | |
| config=RunnableConfig(configurable={"thread_id": session_id}) | |
| ) | |
| # Extract messages from state | |
| messages: list[ChatMessage] = [] | |
| if thread_state and thread_state.values and "messages" in thread_state.values: | |
| for msg in thread_state.values["messages"]: | |
| messages.append( | |
| ChatMessage( | |
| id=getattr(msg, "id", f"msg_{len(messages)}"), | |
| type=msg.type if hasattr(msg, "type") else "unknown", | |
| content=msg.content if hasattr(msg, "content") else str(msg), | |
| timestamp=None, # LangChain messages don't have timestamps by default | |
| ) | |
| ) | |
| # Find notebook_id (we need to query the relationship) | |
| # Ensure session_id has proper table prefix | |
| full_session_id = ( | |
| session_id | |
| if session_id.startswith("chat_session:") | |
| else f"chat_session:{session_id}" | |
| ) | |
| notebook_query = await repo_query( | |
| "SELECT out FROM refers_to WHERE in = $session_id", | |
| {"session_id": ensure_record_id(full_session_id)}, | |
| ) | |
| notebook_id = notebook_query[0]["out"] if notebook_query else None | |
| if not notebook_id: | |
| # This might be an old session created before API migration | |
| logger.warning( | |
| f"No notebook relationship found for session {session_id} - may be an orphaned session" | |
| ) | |
| return ChatSessionWithMessagesResponse( | |
| id=session.id or "", | |
| title=session.title or "Untitled Session", | |
| notebook_id=notebook_id, | |
| created=str(session.created), | |
| updated=str(session.updated), | |
| message_count=len(messages), | |
| messages=messages, | |
| model_override=getattr(session, "model_override", None), | |
| ) | |
| except NotFoundError: | |
| raise HTTPException(status_code=404, detail="Session not found") | |
| except Exception as e: | |
| logger.error(f"Error fetching session: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Error fetching session: {str(e)}") | |
| async def update_session(session_id: str, request: UpdateSessionRequest): | |
| """Update session title.""" | |
| try: | |
| # Ensure session_id has proper table prefix | |
| full_session_id = ( | |
| session_id | |
| if session_id.startswith("chat_session:") | |
| else f"chat_session:{session_id}" | |
| ) | |
| session = await ChatSession.get(full_session_id) | |
| if not session: | |
| raise HTTPException(status_code=404, detail="Session not found") | |
| update_data = request.model_dump(exclude_unset=True) | |
| if "title" in update_data: | |
| session.title = update_data["title"] | |
| if "model_override" in update_data: | |
| session.model_override = update_data["model_override"] | |
| await session.save() | |
| # Find notebook_id | |
| # Ensure session_id has proper table prefix | |
| full_session_id = ( | |
| session_id | |
| if session_id.startswith("chat_session:") | |
| else f"chat_session:{session_id}" | |
| ) | |
| notebook_query = await repo_query( | |
| "SELECT out FROM refers_to WHERE in = $session_id", | |
| {"session_id": ensure_record_id(full_session_id)}, | |
| ) | |
| notebook_id = notebook_query[0]["out"] if notebook_query else None | |
| return ChatSessionResponse( | |
| id=session.id or "", | |
| title=session.title or "", | |
| notebook_id=notebook_id, | |
| created=str(session.created), | |
| updated=str(session.updated), | |
| message_count=0, | |
| model_override=session.model_override, | |
| ) | |
| except NotFoundError: | |
| raise HTTPException(status_code=404, detail="Session not found") | |
| except Exception as e: | |
| logger.error(f"Error updating session: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Error updating session: {str(e)}") | |
| async def delete_session(session_id: str): | |
| """Delete a chat session.""" | |
| try: | |
| # Ensure session_id has proper table prefix | |
| full_session_id = ( | |
| session_id | |
| if session_id.startswith("chat_session:") | |
| else f"chat_session:{session_id}" | |
| ) | |
| session = await ChatSession.get(full_session_id) | |
| if not session: | |
| raise HTTPException(status_code=404, detail="Session not found") | |
| await session.delete() | |
| return SuccessResponse(success=True, message="Session deleted successfully") | |
| except NotFoundError: | |
| raise HTTPException(status_code=404, detail="Session not found") | |
| except Exception as e: | |
| logger.error(f"Error deleting session: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Error deleting session: {str(e)}") | |
| async def execute_chat(request: ExecuteChatRequest): | |
| """Execute a chat request and get AI response.""" | |
| try: | |
| # Verify session exists | |
| # Ensure session_id has proper table prefix | |
| full_session_id = ( | |
| request.session_id | |
| if request.session_id.startswith("chat_session:") | |
| else f"chat_session:{request.session_id}" | |
| ) | |
| session = await ChatSession.get(full_session_id) | |
| if not session: | |
| raise HTTPException(status_code=404, detail="Session not found") | |
| # Determine model override (per-request override takes precedence over session-level) | |
| model_override = ( | |
| request.model_override | |
| if request.model_override is not None | |
| else getattr(session, "model_override", None) | |
| ) | |
| # Get current state | |
| current_state = chat_graph.get_state( | |
| config=RunnableConfig( | |
| configurable={"thread_id": request.session_id} | |
| ) | |
| ) | |
| # Prepare state for execution | |
| state_values = current_state.values if current_state else {} | |
| state_values["messages"] = state_values.get("messages", []) | |
| state_values["context"] = request.context | |
| state_values["model_override"] = model_override | |
| # Add user message to state | |
| from langchain_core.messages import HumanMessage | |
| user_message = HumanMessage(content=request.message) | |
| state_values["messages"].append(user_message) | |
| # Execute chat graph | |
| result = chat_graph.invoke( | |
| input=state_values, # type: ignore[arg-type] | |
| config=RunnableConfig( | |
| configurable={ | |
| "thread_id": request.session_id, | |
| "model_id": model_override, | |
| } | |
| ), | |
| ) | |
| # Update session timestamp | |
| await session.save() | |
| # Convert messages to response format | |
| messages: list[ChatMessage] = [] | |
| for msg in result.get("messages", []): | |
| messages.append( | |
| ChatMessage( | |
| id=getattr(msg, "id", f"msg_{len(messages)}"), | |
| type=msg.type if hasattr(msg, "type") else "unknown", | |
| content=msg.content if hasattr(msg, "content") else str(msg), | |
| timestamp=None, | |
| ) | |
| ) | |
| return ExecuteChatResponse(session_id=request.session_id, messages=messages) | |
| except NotFoundError: | |
| raise HTTPException(status_code=404, detail="Session not found") | |
| except Exception as e: | |
| logger.error(f"Error executing chat: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Error executing chat: {str(e)}") | |
| async def build_context(request: BuildContextRequest): | |
| """Build context for a notebook based on context configuration.""" | |
| try: | |
| # Verify notebook exists | |
| notebook = await Notebook.get(request.notebook_id) | |
| if not notebook: | |
| raise HTTPException(status_code=404, detail="Notebook not found") | |
| context_data: dict[str, list[dict[str, str]]] = {"sources": [], "notes": []} | |
| total_content = "" | |
| # Process context configuration if provided | |
| if request.context_config: | |
| # Process sources | |
| for source_id, status in request.context_config.get("sources", {}).items(): | |
| if "not in" in status: | |
| continue | |
| try: | |
| # Add table prefix if not present | |
| full_source_id = ( | |
| source_id | |
| if source_id.startswith("source:") | |
| else f"source:{source_id}" | |
| ) | |
| try: | |
| source = await Source.get(full_source_id) | |
| except Exception: | |
| continue | |
| if "insights" in status: | |
| source_context = await source.get_context(context_size="short") | |
| context_data["sources"].append(source_context) | |
| total_content += str(source_context) | |
| elif "full content" in status: | |
| source_context = await source.get_context(context_size="long") | |
| context_data["sources"].append(source_context) | |
| total_content += str(source_context) | |
| except Exception as e: | |
| logger.warning(f"Error processing source {source_id}: {str(e)}") | |
| continue | |
| # Process notes | |
| for note_id, status in request.context_config.get("notes", {}).items(): | |
| if "not in" in status: | |
| continue | |
| try: | |
| # Add table prefix if not present | |
| full_note_id = ( | |
| note_id if note_id.startswith("note:") else f"note:{note_id}" | |
| ) | |
| note = await Note.get(full_note_id) | |
| if not note: | |
| continue | |
| if "full content" in status: | |
| note_context = note.get_context(context_size="long") | |
| context_data["notes"].append(note_context) | |
| total_content += str(note_context) | |
| except Exception as e: | |
| logger.warning(f"Error processing note {note_id}: {str(e)}") | |
| continue | |
| else: | |
| # Default behavior - include all sources and notes with short context | |
| sources = await notebook.get_sources() | |
| for source in sources: | |
| try: | |
| source_context = await source.get_context(context_size="short") | |
| context_data["sources"].append(source_context) | |
| total_content += str(source_context) | |
| except Exception as e: | |
| logger.warning(f"Error processing source {source.id}: {str(e)}") | |
| continue | |
| notes = await notebook.get_notes() | |
| for note in notes: | |
| try: | |
| note_context = note.get_context(context_size="short") | |
| context_data["notes"].append(note_context) | |
| total_content += str(note_context) | |
| except Exception as e: | |
| logger.warning(f"Error processing note {note.id}: {str(e)}") | |
| continue | |
| # Calculate character and token counts | |
| char_count = len(total_content) | |
| # Use token count utility if available | |
| try: | |
| from open_notebook.utils import token_count | |
| estimated_tokens = token_count(total_content) if total_content else 0 | |
| except ImportError: | |
| # Fallback to simple estimation | |
| estimated_tokens = char_count // 4 | |
| return BuildContextResponse( | |
| context=context_data, token_count=estimated_tokens, char_count=char_count | |
| ) | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"Error building context: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Error building context: {str(e)}") | |