import asyncio
import concurrent.futures
import sqlite3
from typing import Annotated, Dict, List, Optional

from ai_prompter import Prompter
from langchain_core.messages import AIMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from typing_extensions import TypedDict

from open_notebook.config import LANGGRAPH_CHECKPOINT_FILE
from open_notebook.domain.notebook import Source, SourceInsight
from open_notebook.graphs.utils import provision_langchain_model
from open_notebook.utils import clean_thinking_content
from open_notebook.utils.context_builder import ContextBuilder


class SourceChatState(TypedDict):
    messages: Annotated[list, add_messages]
    source_id: str
    source: Optional[Source]
    insights: Optional[List[SourceInsight]]
    context: Optional[str]
    model_override: Optional[str]
    context_indicators: Optional[Dict[str, List[str]]]
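
# Illustrative entry state (an assumption for clarity, not part of the
# original module): callers only need to supply "source_id" and "messages";
# the agent node fills in source, insights, context, and context_indicators.
#
#   {"source_id": "source:abc123", "messages": [HumanMessage(content="Hi")]}

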
def call_model_with_source_context(
    state: SourceChatState, config: RunnableConfig
) -> dict:
    """
    Main function that builds source context and calls the model.

    This function:
    1. Uses ContextBuilder to build source-specific context
    2. Applies the source_chat Jinja2 prompt template
    3. Handles model provisioning with override support
    4. Tracks context indicators for referenced insights/content
    """
    source_id = state.get("source_id")
    if not source_id:
        raise ValueError("source_id is required in state")

    # Build source context using ContextBuilder (run async code in a new loop)
    def build_context():
        """Build context in a new event loop"""
        new_loop = asyncio.new_event_loop()
        try:
            asyncio.set_event_loop(new_loop)
            context_builder = ContextBuilder(
                source_id=source_id,
                include_insights=True,
                include_notes=False,  # Focus on source-specific content
                max_tokens=50000,  # Reasonable limit for source context
            )
            return new_loop.run_until_complete(context_builder.build())
        finally:
            new_loop.close()
            asyncio.set_event_loop(None)

    # Get the built context
    try:
        # If we're already inside an event loop, run in a thread with a new loop
        asyncio.get_running_loop()
        with concurrent.futures.ThreadPoolExecutor() as executor:
            context_data = executor.submit(build_context).result()
    except RuntimeError:
        # No event loop running, safe to create a new one
        context_data = build_context()

    # Extract source and insights from context
    source = None
    insights = []
    context_indicators: Dict[str, List[str]] = {
        "sources": [],
        "insights": [],
        "notes": [],
    }

    if context_data.get("sources"):
        source_info = context_data["sources"][0]  # First source
        source = Source(**source_info) if isinstance(source_info, dict) else source_info
        context_indicators["sources"].append(source.id)

    if context_data.get("insights"):
        for insight_data in context_data["insights"]:
            insight = (
                SourceInsight(**insight_data)
                if isinstance(insight_data, dict)
                else insight_data
            )
            insights.append(insight)
            context_indicators["insights"].append(insight.id)

    # Format context for the prompt
    formatted_context = _format_source_context(context_data)

    # Build prompt data for the template
    prompt_data = {
        "source": source.model_dump() if source else None,
        "insights": [insight.model_dump() for insight in insights],
        "context": formatted_context,
        "context_indicators": context_indicators,
    }

    # Apply the source_chat prompt template
    system_prompt = Prompter(prompt_template="source_chat").render(data=prompt_data)
    payload = [SystemMessage(content=system_prompt)] + state.get("messages", [])

    # Handle async model provisioning from a sync context
    def run_in_new_loop():
        """Run the async provisioning call in a new event loop"""
        new_loop = asyncio.new_event_loop()
        try:
            asyncio.set_event_loop(new_loop)
            return new_loop.run_until_complete(
                provision_langchain_model(
                    str(payload),
                    config.get("configurable", {}).get("model_id")
                    or state.get("model_override"),
                    "chat",
                    max_tokens=8192,
                )
            )
        finally:
            new_loop.close()
            asyncio.set_event_loop(None)

    try:
        # If we're already inside an event loop, run in a thread with a new loop
        asyncio.get_running_loop()
        with concurrent.futures.ThreadPoolExecutor() as executor:
            model = executor.submit(run_in_new_loop).result()
    except RuntimeError:
        # No event loop running, safe to create a new one directly
        model = run_in_new_loop()

    ai_message = model.invoke(payload)

    # Clean thinking content from the AI response (e.g., <think>...</think> tags)
    content = (
        ai_message.content
        if isinstance(ai_message.content, str)
        else str(ai_message.content)
    )
    cleaned_content = clean_thinking_content(content)
    cleaned_message = ai_message.model_copy(update={"content": cleaned_content})

    # Update state with context information
    return {
        "messages": cleaned_message,
        "source": source,
        "insights": insights,
        "context": formatted_context,
        "context_indicators": context_indicators,
    }


def _format_source_context(context_data: Dict) -> str:
    """
    Format the context data into a readable string for the prompt.

    Args:
        context_data: Context data from ContextBuilder

    Returns:
        Formatted context string
    """
    context_parts = []

    # Add source information
    if context_data.get("sources"):
        context_parts.append("## SOURCE CONTENT")
        for source in context_data["sources"]:
            if isinstance(source, dict):
                context_parts.append(f"**Source ID:** {source.get('id', 'Unknown')}")
                context_parts.append(f"**Title:** {source.get('title', 'No title')}")
                if source.get("full_text"):
                    # Truncate full text if too long
                    full_text = source["full_text"]
                    if len(full_text) > 5000:
                        full_text = full_text[:5000] + "...\n[Content truncated]"
                    context_parts.append(f"**Content:**\n{full_text}")
                context_parts.append("")  # Empty line for separation

    # Add insights
    if context_data.get("insights"):
        context_parts.append("## SOURCE INSIGHTS")
        for insight in context_data["insights"]:
            if isinstance(insight, dict):
                context_parts.append(f"**Insight ID:** {insight.get('id', 'Unknown')}")
                context_parts.append(f"**Type:** {insight.get('insight_type', 'Unknown')}")
                context_parts.append(f"**Content:** {insight.get('content', 'No content')}")
                context_parts.append("")  # Empty line for separation

    # Add metadata
    if context_data.get("metadata"):
        metadata = context_data["metadata"]
        context_parts.append("## CONTEXT METADATA")
        context_parts.append(f"- Source count: {metadata.get('source_count', 0)}")
        context_parts.append(f"- Insight count: {metadata.get('insight_count', 0)}")
        context_parts.append(f"- Total tokens: {context_data.get('total_tokens', 0)}")
        context_parts.append("")

    return "\n".join(context_parts)
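
# Illustrative call (the input shape below is an assumption inferred from the
# keys this formatter reads; a real ContextBuilder result may carry more fields):
#
#   _format_source_context({
#       "sources": [{"id": "source:1", "title": "Paper", "full_text": "..."}],
#       "insights": [{"id": "insight:1", "insight_type": "summary",
#                     "content": "..."}],
#       "metadata": {"source_count": 1, "insight_count": 1},
#   })
#
# returns a markdown-style string with "## SOURCE CONTENT",
# "## SOURCE INSIGHTS", and "## CONTEXT METADATA" sections.

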
# Create SQLite checkpointer
conn = sqlite3.connect(
    LANGGRAPH_CHECKPOINT_FILE,
    check_same_thread=False,
)
memory = SqliteSaver(conn)

# Create the StateGraph
source_chat_state = StateGraph(SourceChatState)
source_chat_state.add_node("source_chat_agent", call_model_with_source_context)
source_chat_state.add_edge(START, "source_chat_agent")
source_chat_state.add_edge("source_chat_agent", END)

source_chat_graph = source_chat_state.compile(checkpointer=memory)
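
# Minimal usage sketch (an illustrative addition, not part of the original
# module). It assumes a Source record with this ID exists and a chat model is
# configured; the source_id and thread_id values are placeholders.
if __name__ == "__main__":
    from langchain_core.messages import HumanMessage

    result = source_chat_graph.invoke(
        {
            "source_id": "source:abc123",  # hypothetical source ID
            "messages": [HumanMessage(content="What are the key findings?")],
        },
        # thread_id lets SqliteSaver persist the conversation across turns
        config={"configurable": {"thread_id": "demo-thread"}},
    )
    print(result["messages"][-1].content)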