Spaces:
Running
Running
| import asyncio | |
| import json | |
| from typing import AsyncGenerator, List, Optional | |
| from fastapi import APIRouter, HTTPException, Path | |
| from fastapi.responses import StreamingResponse | |
| from langchain_core.messages import HumanMessage | |
| from langchain_core.runnables import RunnableConfig | |
| from loguru import logger | |
| from pydantic import BaseModel, Field | |
| from open_notebook.database.repository import ensure_record_id, repo_query | |
| from open_notebook.domain.notebook import ChatSession, Source | |
| from open_notebook.exceptions import ( | |
| NotFoundError, | |
| ) | |
| from open_notebook.graphs.source_chat import source_chat_graph as source_chat_graph | |
# Module-level FastAPI router instance.
router: APIRouter = APIRouter()

# Request/Response models
class CreateSourceChatSessionRequest(BaseModel):
    """Request body for creating a new chat session on a source."""

    # NOTE(review): create_source_chat_session takes the source id from the URL
    # path and never reads this field - confirm whether it is still needed.
    source_id: str = Field(..., description="Source ID to create chat session for")
    title: Optional[str] = Field(None, description="Optional session title")
    model_override: Optional[str] = Field(None, description="Optional model override for this session")
class UpdateSourceChatSessionRequest(BaseModel):
    """Request body for updating a source chat session.

    A field left as ``None`` is treated as "leave unchanged" by the update handler.
    """

    title: Optional[str] = Field(None, description="New session title")
    model_override: Optional[str] = Field(None, description="Model override for this session")
class ChatMessage(BaseModel):
    """A single message of a chat transcript as returned to API clients."""

    id: str = Field(..., description="Message ID")
    # Message role string, e.g. "human" or "ai".
    type: str = Field(..., description="Message type (human|ai)")
    content: str = Field(..., description="Message content")
    timestamp: Optional[str] = Field(None, description="Message timestamp")
class ContextIndicator(BaseModel):
    """IDs of the sources, insights and notes that were placed in the chat context."""

    sources: List[str] = Field(default_factory=list, description="Source IDs used in context")
    insights: List[str] = Field(default_factory=list, description="Insight IDs used in context")
    notes: List[str] = Field(default_factory=list, description="Note IDs used in context")
class SourceChatSessionResponse(BaseModel):
    """Metadata of a chat session attached to a source."""

    id: str = Field(..., description="Session ID")
    title: str = Field(..., description="Session title")
    source_id: str = Field(..., description="Source ID")
    model_override: Optional[str] = Field(None, description="Model override for this session")
    # Timestamps are stringified by the handlers (str(...) of the stored value).
    created: str = Field(..., description="Creation timestamp")
    updated: str = Field(..., description="Last update timestamp")
    message_count: Optional[int] = Field(None, description="Number of messages in session")
class SourceChatSessionWithMessagesResponse(SourceChatSessionResponse):
    """Session metadata plus its message transcript and latest context indicators."""

    messages: List[ChatMessage] = Field(default_factory=list, description="Session messages")
    context_indicators: Optional[ContextIndicator] = Field(None, description="Context indicators from last response")
class SendMessageRequest(BaseModel):
    """Request body for sending a user message to a source chat session."""

    message: str = Field(..., description="User message content")
    # When set, takes precedence over the session's own model_override.
    model_override: Optional[str] = Field(None, description="Optional model override for this message")
class SuccessResponse(BaseModel):
    """Generic acknowledgement returned by operations without a richer payload."""

    success: bool = Field(True, description="Operation success status")
    message: str = Field(..., description="Success message")
async def create_source_chat_session(
    request: CreateSourceChatSessionRequest,
    source_id: str = Path(..., description="Source ID")
):
    """Create a new chat session attached to a source.

    The source is resolved from the ``source_id`` path parameter (accepted with
    or without the ``source:`` prefix); ``request.source_id`` is not read. A new
    ``ChatSession`` is saved and linked to the source with a ``refers_to`` edge
    (session -> source).

    Args:
        request: Optional title and model override for the new session.
        source_id: Source ID from the URL path.

    Returns:
        SourceChatSessionResponse for the new session (``message_count`` is 0).

    Raises:
        HTTPException: 404 if the source does not exist, 500 on any other error.
    """
    # Local import: only needed to build the default session title.
    from datetime import datetime

    try:
        # Accept both "source:<id>" and bare "<id>".
        full_source_id = source_id if source_id.startswith("source:") else f"source:{source_id}"
        source = await Source.get(full_source_id)
        if not source:
            raise HTTPException(status_code=404, detail="Source not found")

        # Default title uses wall-clock time. The event loop's time() is a
        # monotonic clock with an arbitrary epoch, so it made a meaningless title.
        default_title = f"Source Chat {datetime.now():%Y-%m-%d %H:%M:%S}"
        session = ChatSession(
            title=request.title or default_title,
            model_override=request.model_override,
        )
        await session.save()

        # Link the session to the source: refers_to edge with in=session, out=source.
        await session.relate("refers_to", full_source_id)

        return SourceChatSessionResponse(
            id=session.id or "",
            title=session.title or "Untitled Session",
            source_id=source_id,
            model_override=session.model_override,
            created=str(session.created),
            updated=str(session.updated),
            message_count=0,
        )
    except HTTPException:
        # Propagate deliberate HTTP errors (e.g. the 404 above) unchanged;
        # HTTPException subclasses Exception, so the generic handler below
        # would otherwise turn them into a 500.
        raise
    except NotFoundError:
        raise HTTPException(status_code=404, detail="Source not found")
    except Exception as e:
        logger.exception(f"Error creating source chat session: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error creating source chat session: {str(e)}") from e
async def get_source_chat_sessions(
    source_id: str = Path(..., description="Source ID")
):
    """List every chat session linked to a source, newest first.

    Sessions are found through ``refers_to`` edges whose ``out`` is the source;
    each edge's ``in`` record is then loaded individually.

    Args:
        source_id: Source ID from the URL path, with or without the ``source:`` prefix.

    Returns:
        List of SourceChatSessionResponse sorted by the string form of
        ``created`` in descending order. ``message_count`` is always 0 here.

    Raises:
        HTTPException: 404 if the source does not exist, 500 on any other error.
    """
    try:
        full_source_id = source_id if source_id.startswith("source:") else f"source:{source_id}"
        source = await Source.get(full_source_id)
        if not source:
            raise HTTPException(status_code=404, detail="Source not found")

        # Edges pointing at this source; "in" holds the session record id.
        relations = await repo_query(
            "SELECT in FROM refers_to WHERE out = $source_id",
            {"source_id": ensure_record_id(full_source_id)},
        )

        sessions = []
        for relation in relations:
            session_id = relation.get("in")
            if not session_id:
                continue
            # NOTE(review): the record id is interpolated into the query text. It
            # comes from the database rather than the request, but a parameterized
            # lookup would be safer. This is also one query per session (N+1).
            session_result = await repo_query(f"SELECT * FROM {session_id}")
            if not session_result:
                # No record returned for this id; skip the edge.
                continue
            session_data = session_result[0]
            sessions.append(SourceChatSessionResponse(
                id=session_data.get("id") or "",
                title=session_data.get("title") or "Untitled Session",
                source_id=source_id,
                model_override=session_data.get("model_override"),
                created=str(session_data.get("created")),
                updated=str(session_data.get("updated")),
                message_count=0,  # TODO: Add message count if needed
            ))

        # Newest first (compares the stringified creation timestamps).
        sessions.sort(key=lambda s: s.created, reverse=True)
        return sessions
    except HTTPException:
        # Keep deliberate HTTP errors (e.g. the 404 above) instead of masking them as 500.
        raise
    except NotFoundError:
        raise HTTPException(status_code=404, detail="Source not found")
    except Exception as e:
        logger.exception(f"Error fetching source chat sessions: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error fetching source chat sessions: {str(e)}") from e
async def get_source_chat_session(
    source_id: str = Path(..., description="Source ID"),
    session_id: str = Path(..., description="Session ID")
):
    """Return one source chat session together with its message history.

    Session metadata comes from the ``ChatSession`` record. Messages and context
    indicators are read from the LangGraph thread state of ``source_chat_graph``.

    Args:
        source_id: Source ID from the URL path, with or without the ``source:`` prefix.
        session_id: Session ID from the URL path, with or without the
            ``chat_session:`` prefix.

    Returns:
        SourceChatSessionWithMessagesResponse whose ``message_count`` is the
        number of messages found in the thread state.

    Raises:
        HTTPException: 404 if the source or session is missing or the session is
            not linked to the source; 500 on any other error.
    """
    try:
        full_source_id = source_id if source_id.startswith("source:") else f"source:{source_id}"
        source = await Source.get(full_source_id)
        if not source:
            raise HTTPException(status_code=404, detail="Source not found")

        full_session_id = session_id if session_id.startswith("chat_session:") else f"chat_session:{session_id}"
        session = await ChatSession.get(full_session_id)
        if not session:
            raise HTTPException(status_code=404, detail="Session not found")

        # The session must be linked to this source by a refers_to edge.
        relation_query = await repo_query(
            "SELECT * FROM refers_to WHERE in = $session_id AND out = $source_id",
            {"session_id": ensure_record_id(full_session_id), "source_id": ensure_record_id(full_source_id)},
        )
        if not relation_query:
            raise HTTPException(status_code=404, detail="Session not found for this source")

        # NOTE(review): the raw path value is used as the thread id (the prefix is
        # not normalized), so "abc" and "chat_session:abc" address different threads.
        thread_state = source_chat_graph.get_state(
            config=RunnableConfig(configurable={"thread_id": session_id})
        )
        state_values = thread_state.values if thread_state else None

        messages: list[ChatMessage] = []
        context_indicators = None
        if state_values:
            for msg in state_values.get("messages") or []:
                messages.append(ChatMessage(
                    # LangChain messages may have id=None; fall back to a positional
                    # id so the required ChatMessage.id always validates.
                    id=getattr(msg, "id", None) or f"msg_{len(messages)}",
                    type=msg.type if hasattr(msg, "type") else "unknown",
                    content=msg.content if hasattr(msg, "content") else str(msg),
                    timestamp=None,  # LangChain messages don't have timestamps by default
                ))

            # Context indicators of the most recent turn; a None value means none recorded.
            context_data = state_values.get("context_indicators")
            if context_data is not None:
                context_indicators = ContextIndicator(
                    sources=context_data.get("sources") or [],
                    insights=context_data.get("insights") or [],
                    notes=context_data.get("notes") or [],
                )

        return SourceChatSessionWithMessagesResponse(
            id=session.id or "",
            title=session.title or "Untitled Session",
            source_id=source_id,
            model_override=getattr(session, "model_override", None),
            created=str(session.created),
            updated=str(session.updated),
            message_count=len(messages),
            messages=messages,
            context_indicators=context_indicators,
        )
    except HTTPException:
        # Keep deliberate 404s instead of masking them as 500.
        raise
    except NotFoundError:
        raise HTTPException(status_code=404, detail="Source or session not found")
    except Exception as e:
        logger.exception(f"Error fetching source chat session: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error fetching source chat session: {str(e)}") from e
async def update_source_chat_session(
    request: UpdateSourceChatSessionRequest,
    source_id: str = Path(..., description="Source ID"),
    session_id: str = Path(..., description="Session ID")
):
    """Update the title and/or model override of a source chat session.

    Only fields that are not ``None`` in ``request`` are applied.
    NOTE(review): because ``None`` means "unchanged", an existing model override
    cannot be cleared through this endpoint.

    Args:
        request: New title and/or model override.
        source_id: Source ID from the URL path, with or without the ``source:`` prefix.
        session_id: Session ID from the URL path, with or without the
            ``chat_session:`` prefix.

    Returns:
        SourceChatSessionResponse for the saved session (``message_count`` is 0).

    Raises:
        HTTPException: 404 if the source or session is missing or the session is
            not linked to the source; 500 on any other error.
    """
    try:
        full_source_id = source_id if source_id.startswith("source:") else f"source:{source_id}"
        source = await Source.get(full_source_id)
        if not source:
            raise HTTPException(status_code=404, detail="Source not found")

        full_session_id = session_id if session_id.startswith("chat_session:") else f"chat_session:{session_id}"
        session = await ChatSession.get(full_session_id)
        if not session:
            raise HTTPException(status_code=404, detail="Session not found")

        # The session must be linked to this source by a refers_to edge.
        relation_query = await repo_query(
            "SELECT * FROM refers_to WHERE in = $session_id AND out = $source_id",
            {"session_id": ensure_record_id(full_session_id), "source_id": ensure_record_id(full_source_id)},
        )
        if not relation_query:
            raise HTTPException(status_code=404, detail="Session not found for this source")

        if request.title is not None:
            session.title = request.title
        if request.model_override is not None:
            session.model_override = request.model_override
        await session.save()

        return SourceChatSessionResponse(
            id=session.id or "",
            title=session.title or "Untitled Session",
            source_id=source_id,
            model_override=getattr(session, "model_override", None),
            created=str(session.created),
            updated=str(session.updated),
            message_count=0,
        )
    except HTTPException:
        # Keep deliberate 404s instead of masking them as 500.
        raise
    except NotFoundError:
        raise HTTPException(status_code=404, detail="Source or session not found")
    except Exception as e:
        logger.exception(f"Error updating source chat session: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error updating source chat session: {str(e)}") from e
async def delete_source_chat_session(
    source_id: str = Path(..., description="Source ID"),
    session_id: str = Path(..., description="Session ID")
):
    """Delete a chat session that belongs to a source.

    NOTE(review): only ``session.delete()`` is called here. Whether the
    ``refers_to`` edge and the LangGraph thread state are removed as well is
    not visible from this code - confirm.

    Args:
        source_id: Source ID from the URL path, with or without the ``source:`` prefix.
        session_id: Session ID from the URL path, with or without the
            ``chat_session:`` prefix.

    Returns:
        SuccessResponse confirming the deletion.

    Raises:
        HTTPException: 404 if the source or session is missing or the session is
            not linked to the source; 500 on any other error.
    """
    try:
        full_source_id = source_id if source_id.startswith("source:") else f"source:{source_id}"
        source = await Source.get(full_source_id)
        if not source:
            raise HTTPException(status_code=404, detail="Source not found")

        full_session_id = session_id if session_id.startswith("chat_session:") else f"chat_session:{session_id}"
        session = await ChatSession.get(full_session_id)
        if not session:
            raise HTTPException(status_code=404, detail="Session not found")

        # Refuse to delete a session that is not linked to this source.
        relation_query = await repo_query(
            "SELECT * FROM refers_to WHERE in = $session_id AND out = $source_id",
            {"session_id": ensure_record_id(full_session_id), "source_id": ensure_record_id(full_source_id)},
        )
        if not relation_query:
            raise HTTPException(status_code=404, detail="Session not found for this source")

        await session.delete()

        return SuccessResponse(
            success=True,
            message="Source chat session deleted successfully",
        )
    except HTTPException:
        # Keep deliberate 404s instead of masking them as 500.
        raise
    except NotFoundError:
        raise HTTPException(status_code=404, detail="Source or session not found")
    except Exception as e:
        logger.exception(f"Error deleting source chat session: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error deleting source chat session: {str(e)}") from e
async def stream_source_chat_response(
    session_id: str,
    source_id: str,
    message: str,
    model_override: Optional[str] = None
) -> AsyncGenerator[str, None]:
    """Run one chat turn for a source session and yield it as SSE-style frames.

    Loads the thread state of ``source_chat_graph``, appends ``message`` as a
    ``HumanMessage``, runs the graph, and yields ``data: <json>`` frames in this
    order:
      1. ``user_message``;
      2. one ``ai_message`` per AI message produced for this turn;
      3. ``context_indicators``, if the result contains them;
      4. ``complete``.
    Any exception is logged and reported as a final ``error`` frame; nothing is raised.

    Args:
        session_id: LangGraph thread id for the session.
        source_id: Source id stored in the graph state as ``source_id``.
        message: User message text.
        model_override: Optional model id. It is stored in the state and passed
            as ``model_id`` in the run config.
    """

    def _frame(event: dict) -> str:
        # One frame: a single "data:" line followed by a blank line.
        return f"data: {json.dumps(event)}\n\n"

    try:
        current_state = source_chat_graph.get_state(
            config=RunnableConfig(configurable={"thread_id": session_id})
        )

        # Copy the checkpointed values and their message list, so that building
        # the new input does not mutate the snapshot returned by get_state().
        state_values = dict(current_state.values) if current_state else {}
        state_values["messages"] = list(state_values.get("messages", []))
        state_values["source_id"] = source_id
        state_values["model_override"] = model_override
        state_values["messages"].append(HumanMessage(content=message))

        yield _frame({"type": "user_message", "content": message, "timestamp": None})

        # NOTE(review): invoke() is synchronous and runs on the event-loop thread,
        # so the whole LLM call blocks other requests in this process. Kept as-is
        # (deliberate, mirrors notebook chat). Consider asyncio.to_thread or
        # ainvoke once the graph's checkpointer is confirmed safe for it.
        result = source_chat_graph.invoke(
            input=state_values,  # type: ignore[arg-type]
            config=RunnableConfig(
                configurable={
                    "thread_id": session_id,
                    "model_id": model_override,
                }
            ),
        )

        # result["messages"] is the whole conversation, not just this turn.
        # Emit only the AI messages after the latest human message (the reply to
        # the message sent above); earlier replies went out on previous turns.
        # If no human message is found, everything is emitted.
        all_messages = result.get("messages") or []
        last_human_index = max(
            (i for i, m in enumerate(all_messages) if getattr(m, "type", None) == "human"),
            default=-1,
        )
        for msg in all_messages[last_human_index + 1:]:
            if getattr(msg, "type", None) == "ai":
                yield _frame({
                    "type": "ai_message",
                    "content": msg.content if hasattr(msg, "content") else str(msg),
                    "timestamp": None,
                })

        if "context_indicators" in result:
            yield _frame({"type": "context_indicators", "data": result["context_indicators"]})

        yield _frame({"type": "complete"})
    except Exception as e:
        # The HTTP response has already started, so errors are reported in-band.
        logger.exception(f"Error in source chat streaming: {str(e)}")
        yield _frame({"type": "error", "message": str(e)})
async def send_message_to_source_chat(
    request: SendMessageRequest,
    source_id: str = Path(..., description="Source ID"),
    session_id: str = Path(..., description="Session ID")
):
    """Send a user message to a source chat session and stream the reply.

    Steps:
      1. Check that the source and session exist and are linked.
      2. Reject empty or whitespace-only messages.
      3. Resolve the model override (the request value first, then the session's).
      4. Re-save the session, intended to refresh its update timestamp.
      5. Return a ``StreamingResponse`` over ``stream_source_chat_response``.

    Args:
        request: Message text and optional per-message model override.
        source_id: Source ID from the URL path, with or without the ``source:`` prefix.
        session_id: Session ID from the URL path, with or without the
            ``chat_session:`` prefix. The raw value is used as the graph thread id.

    Returns:
        StreamingResponse whose body is a sequence of ``data: <json>`` frames.

    Raises:
        HTTPException: 404 if the source or session is missing or the session is
            not linked to the source; 400 if the message is empty; 500 on other
            errors raised before streaming starts. Errors raised while streaming
            are reported inside the stream instead.
    """
    try:
        full_source_id = source_id if source_id.startswith("source:") else f"source:{source_id}"
        source = await Source.get(full_source_id)
        if not source:
            raise HTTPException(status_code=404, detail="Source not found")

        full_session_id = session_id if session_id.startswith("chat_session:") else f"chat_session:{session_id}"
        session = await ChatSession.get(full_session_id)
        if not session:
            raise HTTPException(status_code=404, detail="Session not found")

        # The session must be linked to this source by a refers_to edge.
        relation_query = await repo_query(
            "SELECT * FROM refers_to WHERE in = $session_id AND out = $source_id",
            {"session_id": ensure_record_id(full_session_id), "source_id": ensure_record_id(full_source_id)},
        )
        if not relation_query:
            raise HTTPException(status_code=404, detail="Session not found for this source")

        # A whitespace-only message carries no content for the model.
        if not request.message or not request.message.strip():
            raise HTTPException(status_code=400, detail="Message content is required")

        # The request's override takes precedence over the session's stored override.
        model_override = request.model_override or getattr(session, "model_override", None)

        # Update session timestamp
        await session.save()

        # NOTE(review): served as text/plain rather than text/event-stream; clients
        # are expected to read the "data:" frames from the raw body.
        return StreamingResponse(
            stream_source_chat_response(
                session_id=session_id,
                source_id=full_source_id,
                message=request.message,
                model_override=model_override,
            ),
            media_type="text/plain",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Content-Type": "text/plain; charset=utf-8",
            },
        )
    except HTTPException:
        raise
    except NotFoundError:
        # Match the sibling handlers: a missing record is a 404, not a 500.
        raise HTTPException(status_code=404, detail="Source or session not found")
    except Exception as e:
        logger.exception(f"Error sending message to source chat: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error sending message: {str(e)}") from e