import asyncio
import json
from typing import AsyncGenerator, List, Optional

from fastapi import APIRouter, HTTPException, Path
from fastapi.responses import StreamingResponse
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from loguru import logger
from pydantic import BaseModel, Field

from open_notebook.database.repository import ensure_record_id, repo_query
from open_notebook.domain.notebook import ChatSession, Source
from open_notebook.exceptions import (
    NotFoundError,
)
from open_notebook.graphs.source_chat import source_chat_graph as source_chat_graph

router = APIRouter()


# Request/Response models
# (Plain comments rather than class docstrings: pydantic would copy a class
# docstring into the OpenAPI schema description.)


# Body of POST /sources/{source_id}/chat/sessions.
class CreateSourceChatSessionRequest(BaseModel):
    source_id: str = Field(..., description="Source ID to create chat session for")
    title: Optional[str] = Field(None, description="Optional session title")
    model_override: Optional[str] = Field(None, description="Optional model override for this session")


# Body of PUT /sources/{source_id}/chat/sessions/{session_id}; a None field means "leave unchanged".
class UpdateSourceChatSessionRequest(BaseModel):
    title: Optional[str] = Field(None, description="New session title")
    model_override: Optional[str] = Field(None, description="Model override for this session")


# One message of a session's conversation, as reconstructed from the LangGraph thread state.
class ChatMessage(BaseModel):
    id: str = Field(..., description="Message ID")
    type: str = Field(..., description="Message type (human|ai)")
    content: str = Field(..., description="Message content")
    timestamp: Optional[str] = Field(None, description="Message timestamp")


# IDs of the context items the graph reported using for its last answer.
class ContextIndicator(BaseModel):
    sources: List[str] = Field(default_factory=list, description="Source IDs used in context")
    insights: List[str] = Field(default_factory=list, description="Insight IDs used in context")
    notes: List[str] = Field(default_factory=list, description="Note IDs used in context")


# Session metadata returned by the create/list/update endpoints.
class SourceChatSessionResponse(BaseModel):
    id: str = Field(..., description="Session ID")
    title: str = Field(..., description="Session title")
    source_id: str = Field(..., description="Source ID")
    model_override: Optional[str] = Field(None, description="Model override for this session")
    created: str = Field(..., description="Creation timestamp")
    updated: str = Field(..., description="Last update timestamp")
    message_count: Optional[int] = Field(None, description="Number of messages in session")


# Session metadata plus its messages and last context indicators (GET single session).
class SourceChatSessionWithMessagesResponse(SourceChatSessionResponse):
    messages: List[ChatMessage] = Field(default_factory=list, description="Session messages")
    context_indicators: Optional[ContextIndicator] = Field(None, description="Context indicators from last response")


# Body of POST /sources/{source_id}/chat/sessions/{session_id}/messages.
class SendMessageRequest(BaseModel):
    message: str = Field(..., description="User message content")
    model_override: Optional[str] = Field(None, description="Optional model override for this message")


# Generic acknowledgement payload (used by DELETE).
class SuccessResponse(BaseModel):
    success: bool = Field(True, description="Operation success status")
    message: str = Field(..., description="Success message")


def _full_source_id(source_id: str) -> str:
    """Return *source_id* with the ``source:`` table prefix, adding it when absent."""
    return source_id if source_id.startswith("source:") else f"source:{source_id}"


def _full_session_id(session_id: str) -> str:
    """Return *session_id* with the ``chat_session:`` table prefix, adding it when absent."""
    return session_id if session_id.startswith("chat_session:") else f"chat_session:{session_id}"


async def _require_source(full_source_id: str) -> Source:
    """Load the source, raising ``HTTPException(404)`` when the lookup yields nothing.

    A ``NotFoundError`` raised by ``Source.get`` propagates unchanged so the
    endpoint's own ``except NotFoundError`` handler maps it.
    """
    source = await Source.get(full_source_id)
    if not source:
        raise HTTPException(status_code=404, detail="Source not found")
    return source


async def _require_session_for_source(full_session_id: str, full_source_id: str) -> ChatSession:
    """Load a chat session and check it is linked to the source via ``refers_to``.

    Raises:
        HTTPException: 404 if the session does not exist, or if no
            ``refers_to`` edge connects it to the given source.
    """
    session = await ChatSession.get(full_session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    relation_query = await repo_query(
        "SELECT * FROM refers_to WHERE in = $session_id AND out = $source_id",
        {
            "session_id": ensure_record_id(full_session_id),
            "source_id": ensure_record_id(full_source_id),
        },
    )
    if not relation_query:
        raise HTTPException(status_code=404, detail="Session not found for this source")
    return session


def _sse(event: dict) -> str:
    """Serialize *event* as one Server-Sent-Events ``data:`` frame."""
    return f"data: {json.dumps(event)}\n\n"


@router.post("/sources/{source_id}/chat/sessions", response_model=SourceChatSessionResponse)
async def create_source_chat_session(
    request: CreateSourceChatSessionRequest,
    source_id: str = Path(..., description="Source ID"),
):
    """Create a new chat session for a source.

    The new ``ChatSession`` is saved and then linked to the source with a
    ``refers_to`` relation. The path ``source_id`` is what gets used;
    ``request.source_id`` is not read here.

    Raises:
        HTTPException: 404 if the source does not exist, 500 on any other error.
    """
    try:
        full_source_id = _full_source_id(source_id)
        await _require_source(full_source_id)

        # NOTE(review): loop.time() is the event loop's monotonic clock, not a
        # wall-clock timestamp, so the default title is only a rough unique-ish number.
        session = ChatSession(
            title=request.title or f"Source Chat {asyncio.get_event_loop().time():.0f}",
            model_override=request.model_override,
        )
        await session.save()

        # Link session -> source so the session can be found from the source.
        await session.relate("refers_to", full_source_id)

        return SourceChatSessionResponse(
            id=session.id or "",
            title=session.title or "Untitled Session",
            source_id=source_id,
            model_override=session.model_override,
            created=str(session.created),
            updated=str(session.updated),
            message_count=0,
        )
    except HTTPException:
        # Preserve deliberate 4xx responses instead of converting them to 500.
        raise
    except NotFoundError:
        raise HTTPException(status_code=404, detail="Source not found")
    except Exception as e:
        logger.exception(f"Error creating source chat session: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error creating source chat session: {str(e)}") from e


@router.get("/sources/{source_id}/chat/sessions", response_model=List[SourceChatSessionResponse])
async def get_source_chat_sessions(
    source_id: str = Path(..., description="Source ID"),
):
    """List all chat sessions linked to a source, newest ``created`` first.

    ``message_count`` is always reported as 0 here.

    Raises:
        HTTPException: 404 if the source does not exist, 500 on any other error.
    """
    try:
        full_source_id = _full_source_id(source_id)
        await _require_source(full_source_id)

        # First collect the session ids pointing at this source, then load each session.
        relations = await repo_query(
            "SELECT in FROM refers_to WHERE out = $source_id",
            {"source_id": ensure_record_id(full_source_id)},
        )

        sessions: List[SourceChatSessionResponse] = []
        for relation in relations:
            session_id = relation.get("in")
            if not session_id:
                continue
            # session_id comes from the database row, not from the request.
            session_result = await repo_query(f"SELECT * FROM {session_id}")
            if not session_result:
                continue
            session_data = session_result[0]
            sessions.append(
                SourceChatSessionResponse(
                    id=session_data.get("id") or "",
                    title=session_data.get("title") or "Untitled Session",
                    source_id=source_id,
                    model_override=session_data.get("model_override"),
                    created=str(session_data.get("created")),
                    updated=str(session_data.get("updated")),
                    message_count=0,  # TODO: Add message count if needed
                )
            )

        # String sort of the stringified timestamps, newest first.
        sessions.sort(key=lambda item: item.created, reverse=True)
        return sessions
    except HTTPException:
        raise
    except NotFoundError:
        raise HTTPException(status_code=404, detail="Source not found")
    except Exception as e:
        logger.exception(f"Error fetching source chat sessions: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error fetching source chat sessions: {str(e)}") from e


@router.get("/sources/{source_id}/chat/sessions/{session_id}", response_model=SourceChatSessionWithMessagesResponse)
async def get_source_chat_session(
    source_id: str = Path(..., description="Source ID"),
    session_id: str = Path(..., description="Session ID"),
):
    """Return one source chat session with its messages and last context indicators.

    Messages and context indicators are read from the ``source_chat_graph``
    thread state, not from the database.

    Raises:
        HTTPException: 404 if the source/session is missing or the session is
            not linked to the source; 500 on any other error.
    """
    try:
        full_source_id = _full_source_id(source_id)
        await _require_source(full_source_id)
        full_session_id = _full_session_id(session_id)
        session = await _require_session_for_source(full_session_id, full_source_id)

        # NOTE(review): the thread id is the session_id exactly as supplied in the
        # URL (prefixed or not), matching what the message endpoint uses; a client
        # mixing both forms would address two different threads.
        thread_state = source_chat_graph.get_state(
            config=RunnableConfig(configurable={"thread_id": session_id})
        )

        messages: List[ChatMessage] = []
        context_indicators: Optional[ContextIndicator] = None
        state = thread_state.values if thread_state else None
        if state:
            for index, msg in enumerate(state.get("messages") or []):
                messages.append(
                    ChatMessage(
                        # A message may carry id=None; fall back to a positional id
                        # so the required `id: str` field always validates.
                        id=getattr(msg, "id", None) or f"msg_{index}",
                        type=msg.type if hasattr(msg, "type") else "unknown",
                        content=msg.content if hasattr(msg, "content") else str(msg),
                        timestamp=None,  # LangChain messages don't have timestamps by default
                    )
                )

            context_data = state.get("context_indicators")
            if context_data is not None:
                context_indicators = ContextIndicator(
                    sources=context_data.get("sources", []),
                    insights=context_data.get("insights", []),
                    notes=context_data.get("notes", []),
                )

        return SourceChatSessionWithMessagesResponse(
            id=session.id or "",
            title=session.title or "Untitled Session",
            source_id=source_id,
            model_override=getattr(session, "model_override", None),
            created=str(session.created),
            updated=str(session.updated),
            message_count=len(messages),
            messages=messages,
            context_indicators=context_indicators,
        )
    except HTTPException:
        raise
    except NotFoundError:
        raise HTTPException(status_code=404, detail="Source or session not found")
    except Exception as e:
        logger.exception(f"Error fetching source chat session: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error fetching source chat session: {str(e)}") from e


@router.put("/sources/{source_id}/chat/sessions/{session_id}", response_model=SourceChatSessionResponse)
async def update_source_chat_session(
    request: UpdateSourceChatSessionRequest,
    source_id: str = Path(..., description="Source ID"),
    session_id: str = Path(..., description="Session ID"),
):
    """Update a source chat session's title and/or model override.

    Fields left as ``None`` in the request are not changed, so this endpoint
    cannot clear an existing ``model_override``.

    Raises:
        HTTPException: 404 if the source/session is missing or unrelated; 500 otherwise.
    """
    try:
        full_source_id = _full_source_id(source_id)
        await _require_source(full_source_id)
        full_session_id = _full_session_id(session_id)
        session = await _require_session_for_source(full_session_id, full_source_id)

        if request.title is not None:
            session.title = request.title
        if request.model_override is not None:
            session.model_override = request.model_override
        await session.save()

        return SourceChatSessionResponse(
            id=session.id or "",
            title=session.title or "Untitled Session",
            source_id=source_id,
            model_override=getattr(session, "model_override", None),
            created=str(session.created),
            updated=str(session.updated),
            message_count=0,
        )
    except HTTPException:
        raise
    except NotFoundError:
        raise HTTPException(status_code=404, detail="Source or session not found")
    except Exception as e:
        logger.exception(f"Error updating source chat session: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error updating source chat session: {str(e)}") from e


@router.delete("/sources/{source_id}/chat/sessions/{session_id}", response_model=SuccessResponse)
async def delete_source_chat_session(
    source_id: str = Path(..., description="Source ID"),
    session_id: str = Path(..., description="Session ID"),
):
    """Delete a source chat session after checking it belongs to the source.

    Raises:
        HTTPException: 404 if the source/session is missing or unrelated; 500 otherwise.
    """
    try:
        full_source_id = _full_source_id(source_id)
        await _require_source(full_source_id)
        full_session_id = _full_session_id(session_id)
        session = await _require_session_for_source(full_session_id, full_source_id)

        # NOTE(review): only the session record is deleted here; whether the
        # refers_to edge and the LangGraph thread are cleaned up depends on
        # ChatSession.delete — verify.
        await session.delete()

        return SuccessResponse(
            success=True,
            message="Source chat session deleted successfully",
        )
    except HTTPException:
        raise
    except NotFoundError:
        raise HTTPException(status_code=404, detail="Source or session not found")
    except Exception as e:
        logger.exception(f"Error deleting source chat session: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error deleting source chat session: {str(e)}") from e


async def stream_source_chat_response(
    session_id: str,
    source_id: str,
    message: str,
    model_override: Optional[str] = None,
) -> AsyncGenerator[str, None]:
    """Run one chat turn through ``source_chat_graph`` and yield it as SSE frames.

    Frames, in order: ``user_message`` (echo of *message*), one ``ai_message``
    per AI message produced by this turn, ``context_indicators`` when the
    graph result carries them, then ``complete``. Any exception is logged and
    reported as a final ``error`` frame instead of being raised.

    Args:
        session_id: LangGraph thread id (used verbatim).
        source_id: Fully-qualified source id written into the graph state.
        message: The user's message text.
        model_override: Optional model id, written into the state and passed
            as ``model_id`` in the run config.
    """
    try:
        current_state = source_chat_graph.get_state(
            config=RunnableConfig(configurable={"thread_id": session_id})
        )

        # Work on copies so the checkpoint snapshot is not mutated in place.
        state_values = dict(current_state.values) if current_state else {}
        history = list(state_values.get("messages") or [])

        # Ids of messages that already existed, so only this turn's replies are emitted.
        prior_ids = {getattr(msg, "id", None) for msg in history} - {None}

        history.append(HumanMessage(content=message))
        state_values["messages"] = history
        state_values["source_id"] = source_id
        state_values["model_override"] = model_override

        yield _sse({"type": "user_message", "content": message, "timestamp": None})

        # NOTE(review): invoke() is synchronous and blocks the event loop for the
        # whole model call (mirrors notebook chat); consider ainvoke/to_thread
        # once the checkpointer's thread-safety is confirmed.
        result = source_chat_graph.invoke(
            input=state_values,  # type: ignore[arg-type]
            config=RunnableConfig(
                configurable={
                    "thread_id": session_id,
                    "model_id": model_override,
                }
            ),
        )

        for msg in result.get("messages") or []:
            if getattr(msg, "type", None) != "ai":
                continue
            msg_id = getattr(msg, "id", None)
            # Skip AI replies from earlier turns; without ids we cannot tell,
            # so such messages are still emitted.
            if msg_id is not None and msg_id in prior_ids:
                continue
            yield _sse(
                {
                    "type": "ai_message",
                    "content": msg.content if hasattr(msg, "content") else str(msg),
                    "timestamp": None,
                }
            )

        if "context_indicators" in result:
            yield _sse({"type": "context_indicators", "data": result["context_indicators"]})

        yield _sse({"type": "complete"})

    except Exception as e:
        logger.exception(f"Error in source chat streaming: {str(e)}")
        yield _sse({"type": "error", "message": str(e)})
@router.post("/sources/{source_id}/chat/sessions/{session_id}/messages")
async def send_message_to_source_chat(
    request: SendMessageRequest,
    source_id: str = Path(..., description="Source ID"),
    session_id: str = Path(..., description="Session ID"),
):
    """Send a message to a source chat session and stream the reply.

    Validates the message, checks that the source and session exist and are
    linked by ``refers_to``, touches the session (``save``) to bump its update
    timestamp, then returns a ``StreamingResponse`` fed by
    ``stream_source_chat_response``. A model override on the request takes
    precedence over the session's stored override.

    Raises:
        HTTPException: 400 for an empty or whitespace-only message; 404 if the
            source/session is missing or unrelated; 500 on any other error.
    """
    try:
        # Cheap validation first, before any database round-trips.
        if not request.message or not request.message.strip():
            raise HTTPException(status_code=400, detail="Message content is required")

        # Verify source exists
        full_source_id = source_id if source_id.startswith("source:") else f"source:{source_id}"
        source = await Source.get(full_source_id)
        if not source:
            raise HTTPException(status_code=404, detail="Source not found")

        # Verify session exists
        full_session_id = (
            session_id if session_id.startswith("chat_session:") else f"chat_session:{session_id}"
        )
        session = await ChatSession.get(full_session_id)
        if not session:
            raise HTTPException(status_code=404, detail="Session not found")

        # Verify session is related to this source
        relation_query = await repo_query(
            "SELECT * FROM refers_to WHERE in = $session_id AND out = $source_id",
            {
                "session_id": ensure_record_id(full_session_id),
                "source_id": ensure_record_id(full_source_id),
            },
        )
        if not relation_query:
            raise HTTPException(status_code=404, detail="Session not found for this source")

        # Request override takes precedence over the session override.
        model_override = request.model_override or getattr(session, "model_override", None)

        # Re-save to bump the session's update timestamp.
        await session.save()

        # NOTE(review): the raw path session_id (not full_session_id) is the
        # LangGraph thread id, matching the GET endpoint. The SSE frames are
        # served as text/plain rather than text/event-stream; left as-is since
        # clients may depend on it.
        return StreamingResponse(
            stream_source_chat_response(
                session_id=session_id,
                source_id=full_source_id,
                message=request.message,
                model_override=model_override,
            ),
            media_type="text/plain",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Content-Type": "text/plain; charset=utf-8",
            },
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Error sending message to source chat: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error sending message: {str(e)}") from e