# Upstream deployment snapshot by pikamomo — "Initial deployment" (commit a60c0af).
"""Orchestrator coordinating the deep research workflow using LangGraph."""
from __future__ import annotations
import logging
import re
import operator
from pathlib import Path
from queue import Empty, Queue
from threading import Lock, Thread
from typing import Any, Annotated, Iterator, TypedDict, Optional, Callable
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_core.tools import tool
from langgraph.graph import StateGraph, END
from config import Configuration
from prompts import (
report_writer_instructions,
task_summarizer_instructions,
todo_planner_system_prompt,
todo_planner_instructions,
get_current_date,
)
from models import SummaryState, SummaryStateOutput, TodoItem
from services.search import dispatch_search, prepare_research_context
from utils import strip_thinking_tokens
logger = logging.getLogger(__name__)
# ============================================================================
# State Schema
# ============================================================================
class ResearchState(TypedDict, total=False):
    """State schema for the research workflow graph.

    All keys are optional (``total=False``); nodes return partial updates.
    Fields annotated with ``operator.add`` are merged by concatenation when
    multiple nodes (or parallel branches) contribute to them.
    """

    # The user's research question driving the whole workflow.
    research_topic: str
    # Planned tasks produced by the planner node.
    todo_items: list[TodoItem]
    # Index of the task being executed (set by the planner; NOTE(review):
    # never advanced by _execute_tasks_node — confirm whether still needed).
    current_task_index: int
    # Raw search context per task; list-concatenated across updates.
    web_research_results: Annotated[list[str], operator.add]
    # Formatted source summaries per task; list-concatenated across updates.
    sources_gathered: Annotated[list[str], operator.add]
    # Number of search dispatches performed so far.
    research_loop_count: int
    # Final Markdown report (None until generate_report runs).
    structured_report: Optional[str]
    # ID/path of the conclusion note, when note persistence is enabled.
    report_note_id: Optional[str]
    report_note_path: Optional[str]
    # Internal tracking
    messages: list[Any]
    config: Configuration
# ============================================================================
# Note Tool Implementation
# ============================================================================
class NoteTool:
    """Simple file-based note tool for persisting task notes.

    Notes are Markdown files stored under ``workspace`` with a small
    YAML-style frontmatter header. Note IDs combine the current Unix
    timestamp with a lock-protected counter, so concurrent creators
    (e.g. parallel task workers) always get unique IDs.

    All actions return human-readable status strings (prefixed ✅/❌/📝)
    rather than raising, matching the best-effort contract of callers.
    """

    def __init__(self, workspace: str = "./notes"):
        """Create the tool, ensuring the workspace directory exists."""
        self.workspace = Path(workspace)
        self.workspace.mkdir(parents=True, exist_ok=True)
        self._id_counter = 0
        self._lock = Lock()  # guards _id_counter across worker threads

    def _generate_id(self) -> str:
        """Return a unique note id: timestamp plus a monotonic counter."""
        import time

        with self._lock:
            self._id_counter += 1
            counter = self._id_counter
        return f"note_{int(time.time())}_{counter}"

    def run(self, params: dict[str, Any]) -> str:
        """Execute a note action: ``create``, ``read``, ``update`` or ``list``.

        Unknown actions yield an error string instead of raising.
        """
        action = params.get("action", "read")
        handlers = {
            "create": self._create_note,
            "read": self._read_note,
            "update": self._update_note,
            "list": self._list_notes,
        }
        handler = handlers.get(action)
        if handler is None:
            return f"❌ Unknown action: {action}"
        return handler(params)

    def _create_note(self, params: dict[str, Any]) -> str:
        """Write a new note file with frontmatter; return its ID and path."""
        note_id = self._generate_id()
        title = params.get("title", "Untitled")
        note_type = params.get("note_type", "general")
        tags = params.get("tags", [])
        content = params.get("content", "")
        task_id = params.get("task_id")
        note_path = self.workspace / f"{note_id}.md"
        # Frontmatter must stay flush-left: it is literal file content.
        frontmatter = (
            "---\n"
            f"id: {note_id}\n"
            f"title: {title}\n"
            f"type: {note_type}\n"
            f"tags: {tags}\n"
            f"task_id: {task_id}\n"
            "---\n"
        )
        note_path.write_text(frontmatter + content, encoding="utf-8")
        # The "ID: ..." line is parsed back out by _extract_note_id callers.
        return f"✅ Note created\nID: {note_id}\nPath: {note_path}"

    def _read_note(self, params: dict[str, Any]) -> str:
        """Return the full text of an existing note."""
        note_id = params.get("note_id")
        if not note_id:
            return "❌ Missing note_id parameter"
        note_path = self.workspace / f"{note_id}.md"
        if not note_path.exists():
            return f"❌ Note does not exist: {note_id}"
        content = note_path.read_text(encoding="utf-8")
        return f"✅ Note content:\n{content}"

    def _update_note(self, params: dict[str, Any]) -> str:
        """Append new content to an existing note (simple append strategy)."""
        note_id = params.get("note_id")
        if not note_id:
            return "❌ Missing note_id parameter"
        note_path = self.workspace / f"{note_id}.md"
        if not note_path.exists():
            return f"❌ Note does not exist: {note_id}"
        content = params.get("content", "")
        # Only touch the file when there is something to append.
        if content:
            existing = note_path.read_text(encoding="utf-8")
            updated = existing + "\n\n---\nUpdate:\n" + content
            note_path.write_text(updated, encoding="utf-8")
        return f"✅ Note updated\nID: {note_id}"

    def _list_notes(self, params: dict[str, Any]) -> str:
        """List all note IDs in the workspace (sorted for determinism)."""
        notes = sorted(self.workspace.glob("*.md"))
        if not notes:
            return "📝 No notes yet"
        lines = "".join(f"- {note.stem}\n" for note in notes)
        return "📝 Note list:\n" + lines
# ============================================================================
# Tool Call Tracker
# ============================================================================
class ToolCallTracker:
    """Collects tool call events for SSE streaming.

    Thread-safe: a Lock guards the event list, drain cursor and sink
    reference. An optional event sink receives each event immediately as
    it is recorded (used for live streaming). The sink is invoked
    *outside* the lock so a slow or re-entrant callback cannot deadlock
    the tracker.

    Note: ``record`` mutates the caller's event dict (adds an ``id`` key).
    """

    def __init__(self, notes_workspace: Optional[str] = None):
        # Kept for API symmetry with NoteTool; not used internally.
        self._notes_workspace = notes_workspace
        self._events: list[dict[str, Any]] = []
        self._cursor = 0  # index of the first event not yet drained
        self._lock = Lock()
        self._event_sink: Optional[Callable[[dict[str, Any]], None]] = None

    def record(self, event: dict[str, Any]) -> None:
        """Append an event (assigning a 1-based id) and notify the sink."""
        with self._lock:
            event["id"] = len(self._events) + 1
            self._events.append(event)
            sink = self._event_sink
        # Call the user-supplied callback outside the lock: it may block
        # or re-enter the tracker.
        if sink:
            sink({"type": "tool_call", **event})

    def drain(self, step: Optional[int] = None) -> list[dict[str, Any]]:
        """Return events recorded since the last drain, wrapped as
        ``tool_call`` payloads; ``step`` is stamped on each when given."""
        with self._lock:
            if self._cursor >= len(self._events):
                return []
            new_events = self._events[self._cursor:]
            self._cursor = len(self._events)
        payloads = []
        for event in new_events:
            payload = {"type": "tool_call", **event}
            if step is not None:
                payload["step"] = step
            payloads.append(payload)
        return payloads

    def set_event_sink(self, sink: Optional[Callable[[dict[str, Any]], None]]) -> None:
        """Install (or clear, with None) the immediate event callback."""
        with self._lock:  # keep the write visible/atomic w.r.t. record()
            self._event_sink = sink

    def as_dicts(self) -> list[dict[str, Any]]:
        """Return a shallow copy of all recorded events."""
        with self._lock:
            return list(self._events)

    def reset(self) -> None:
        """Discard all events and rewind the drain cursor."""
        with self._lock:
            self._events.clear()
            self._cursor = 0
# ============================================================================
# Deep Research Agent using LangGraph
# ============================================================================
class DeepResearchAgent:
    """Coordinator orchestrating TODO-based research workflow using LangGraph.

    Pipeline: plan tasks -> execute searches + summaries -> write report.
    Supports a batch mode (``run``) and an incremental SSE-style streaming
    mode (``run_stream``).
    """

    def __init__(self, config: Configuration | None = None) -> None:
        """Initialize the coordinator with configuration and LangGraph components.

        Falls back to environment-derived configuration when none is given.
        """
        self.config = config or Configuration.from_env()
        self.llm = self._init_llm()
        # Note tool setup: only persist notes when explicitly enabled.
        self.note_tool = (
            NoteTool(workspace=self.config.notes_workspace)
            if self.config.enable_notes
            else None
        )
        # Tool call tracking (for SSE event emission).
        self._tool_tracker = ToolCallTracker(
            self.config.notes_workspace if self.config.enable_notes else None
        )
        self._tool_event_sink_enabled = False
        # Guards shared SummaryState mutation by run_stream worker threads.
        self._state_lock = Lock()
        # Build the graph
        self.graph = self._build_graph()
        # Notices from the most recent dispatch_search call (batch mode only).
        self._last_search_notices: list[str] = []
def _init_llm(self) -> ChatOpenAI:
"""Initialize ChatOpenAI with configuration preferences."""
llm_kwargs: dict[str, Any] = {
"temperature": 0.0,
"streaming": True,
}
model_id = self.config.llm_model_id or self.config.local_llm
if model_id:
llm_kwargs["model"] = model_id
provider = (self.config.llm_provider or "").strip()
if provider == "ollama":
llm_kwargs["base_url"] = self.config.sanitized_ollama_url()
llm_kwargs["api_key"] = self.config.llm_api_key or "ollama"
elif provider == "lmstudio":
llm_kwargs["base_url"] = self.config.lmstudio_base_url
if self.config.llm_api_key:
llm_kwargs["api_key"] = self.config.llm_api_key
else:
llm_kwargs["api_key"] = "lm-studio"
else:
if self.config.llm_base_url:
llm_kwargs["base_url"] = self.config.llm_base_url
if self.config.llm_api_key:
llm_kwargs["api_key"] = self.config.llm_api_key
return ChatOpenAI(**llm_kwargs)
    def _build_graph(self) -> StateGraph:
        """Build the LangGraph workflow.

        Linear pipeline: plan_research -> execute_tasks -> generate_report -> END.
        NOTE(review): ``compile()`` returns a compiled graph object, not a
        ``StateGraph`` — annotation kept as-is; confirm against langgraph docs.
        """
        workflow = StateGraph(ResearchState)
        # Add nodes
        workflow.add_node("plan_research", self._plan_research_node)
        workflow.add_node("execute_tasks", self._execute_tasks_node)
        workflow.add_node("generate_report", self._generate_report_node)
        # Define edges
        workflow.set_entry_point("plan_research")
        workflow.add_edge("plan_research", "execute_tasks")
        workflow.add_edge("execute_tasks", "generate_report")
        workflow.add_edge("generate_report", END)
        return workflow.compile()
# -------------------------------------------------------------------------
# Graph Nodes
# -------------------------------------------------------------------------
    def _plan_research_node(self, state: ResearchState) -> dict[str, Any]:
        """Planning node: break research topic into actionable tasks.

        Invokes the planner LLM, parses its output into TodoItem objects,
        and (when notes are enabled) creates one task-state note per task,
        recording each note creation with the tool tracker.

        Returns a partial state update: parsed tasks plus reset counters.
        """
        topic = state.get("research_topic", "")
        system_prompt = todo_planner_system_prompt.strip()
        user_prompt = todo_planner_instructions.format(
            current_date=get_current_date(),
            research_topic=topic,
        )
        messages = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=user_prompt),
        ]
        response = self.llm.invoke(messages)
        response_text = response.content
        if self.config.strip_thinking_tokens:
            response_text = strip_thinking_tokens(response_text)
        logger.info("Planner raw output (truncated): %s", response_text[:500])
        # Parse tasks from response (falls back to a single default task).
        todo_items = self._parse_todo_items(response_text, topic)
        # Create notes for each task if enabled
        if self.note_tool:
            for task in todo_items:
                result = self.note_tool.run({
                    "action": "create",
                    "task_id": task.id,
                    "title": f"Task {task.id}: {task.title}",
                    "note_type": "task_state",
                    "tags": ["deep_research", f"task_{task.id}"],
                    "content": f"Task objective: {task.intent}\nSearch query: {task.query}",
                })
                # Extract note_id from the tool's "ID: ..." status line.
                note_id = self._extract_note_id(result)
                if note_id:
                    task.note_id = note_id
                    task.note_path = str(Path(self.config.notes_workspace) / f"{note_id}.md")
                # The tool call is recorded even when no note id was parsed.
                self._tool_tracker.record({
                    "agent": "Research Planning Expert",
                    "tool": "note",
                    "parameters": {"action": "create", "task_id": task.id},
                    "result": result,
                    "task_id": task.id,
                    "note_id": note_id,
                })
        titles = [task.title for task in todo_items]
        logger.info("Planner produced %d tasks: %s", len(todo_items), titles)
        return {
            "todo_items": todo_items,
            "current_task_index": 0,
            "research_loop_count": 0,
        }
    def _execute_tasks_node(self, state: ResearchState) -> dict[str, Any]:
        """Execute research tasks: search and summarize each task.

        Sequentially, for each planned task: dispatch a web search, build a
        research context, summarize it with the LLM, and (if enabled) append
        the summary to the task's note. Tasks with no search results are
        marked ``skipped``. Returns a partial state update with the mutated
        tasks, accumulated contexts/sources, and the advanced loop counter.
        """
        todo_items = state.get("todo_items", [])
        topic = state.get("research_topic", "")
        loop_count = state.get("research_loop_count", 0)
        web_results: list[str] = []
        sources: list[str] = []
        for task in todo_items:
            task.status = "in_progress"
            # Execute search; backend identifies which search provider ran
            # (unused here — the streaming path reports it to the client).
            search_result, notices, answer_text, backend = dispatch_search(
                task.query,
                self.config,
                loop_count,
            )
            self._last_search_notices = notices
            task.notices = notices
            if not search_result or not search_result.get("results"):
                task.status = "skipped"
                continue
            # Prepare context
            sources_summary, context = prepare_research_context(
                search_result, answer_text, self.config
            )
            task.sources_summary = sources_summary
            web_results.append(context)
            sources.append(sources_summary)
            # Summarize task
            summary = self._summarize_task(topic, task, context)
            task.summary = summary
            task.status = "completed"
            # Update note if enabled
            if self.note_tool and task.note_id:
                result = self.note_tool.run({
                    "action": "update",
                    "note_id": task.note_id,
                    "task_id": task.id,
                    "content": f"## Task Summary\n{summary}\n\n## Sources\n{sources_summary}",
                })
                self._tool_tracker.record({
                    "agent": "Task Summary Expert",
                    "tool": "note",
                    "parameters": {"action": "update", "note_id": task.note_id},
                    "result": result,
                    "task_id": task.id,
                    "note_id": task.note_id,
                })
            # One loop increment per completed (non-skipped) task.
            loop_count += 1
        return {
            "todo_items": todo_items,
            "web_research_results": web_results,
            "sources_gathered": sources,
            "research_loop_count": loop_count,
        }
    def _generate_report_node(self, state: ResearchState) -> dict[str, Any]:
        """Generate the final structured report.

        Formats every task's summary/sources into a Markdown overview, asks
        the report-writer LLM for a structured report, and optionally saves
        it as a "conclusion" note (recording the tool call).

        Returns a partial state update with the report text and note info.
        """
        topic = state.get("research_topic", "")
        todo_items = state.get("todo_items", [])
        # Build task overview
        tasks_block = []
        for task in todo_items:
            summary_block = task.summary or "No information available"
            sources_block = task.sources_summary or "No sources available"
            tasks_block.append(
                f"### Task {task.id}: {task.title}\n"
                f"- Objective: {task.intent}\n"
                f"- Search query: {task.query}\n"
                f"- Status: {task.status}\n"
                f"- Summary:\n{summary_block}\n"
                f"- Sources:\n{sources_block}\n"
            )
        prompt = (
            f"Research topic: {topic}\n"
            f"Task overview:\n{''.join(tasks_block)}\n"
            "Based on the above task summaries, please write a structured research report."
        )
        messages = [
            SystemMessage(content=report_writer_instructions.strip()),
            HumanMessage(content=prompt),
        ]
        response = self.llm.invoke(messages)
        report_text = response.content
        if self.config.strip_thinking_tokens:
            report_text = strip_thinking_tokens(report_text)
        # Fallback message keeps downstream consumers from seeing empty text.
        report_text = report_text.strip() or "Report generation failed, please check input."
        # Create conclusion note if enabled
        report_note_id = None
        report_note_path = None
        if self.note_tool and report_text:
            result = self.note_tool.run({
                "action": "create",
                "title": f"Research Report: {topic}",
                "note_type": "conclusion",
                "tags": ["deep_research", "report"],
                "content": report_text,
            })
            report_note_id = self._extract_note_id(result)
            if report_note_id:
                report_note_path = str(Path(self.config.notes_workspace) / f"{report_note_id}.md")
            # Record the tool call even if no note id could be parsed.
            self._tool_tracker.record({
                "agent": "Report Writing Expert",
                "tool": "note",
                "parameters": {"action": "create", "note_type": "conclusion"},
                "result": result,
                "note_id": report_note_id,
            })
        return {
            "structured_report": report_text,
            "report_note_id": report_note_id,
            "report_note_path": report_note_path,
        }
# -------------------------------------------------------------------------
# Helper Methods
# -------------------------------------------------------------------------
def _summarize_task(self, topic: str, task: TodoItem, context: str) -> str:
"""Generate summary for a single task."""
prompt = (
f"Task topic: {topic}\n"
f"Task name: {task.title}\n"
f"Task objective: {task.intent}\n"
f"Search query: {task.query}\n"
f"Task context:\n{context}\n"
"Please generate a detailed task summary."
)
messages = [
SystemMessage(content=task_summarizer_instructions.strip()),
HumanMessage(content=prompt),
]
response = self.llm.invoke(messages)
summary_text = response.content
if self.config.strip_thinking_tokens:
summary_text = strip_thinking_tokens(summary_text)
return summary_text.strip() or "No information available"
def _parse_todo_items(self, response: str, topic: str) -> list[TodoItem]:
"""Parse planner output into TodoItem list."""
import json
text = response.strip()
tasks_payload: list[dict[str, Any]] = []
# Try to extract JSON
start = text.find("{")
end = text.rfind("}")
if start != -1 and end != -1 and end > start:
try:
json_obj = json.loads(text[start:end + 1])
if isinstance(json_obj, dict) and "tasks" in json_obj:
tasks_payload = json_obj["tasks"]
except json.JSONDecodeError:
pass
if not tasks_payload:
start = text.find("[")
end = text.rfind("]")
if start != -1 and end != -1 and end > start:
try:
tasks_payload = json.loads(text[start:end + 1])
except json.JSONDecodeError:
pass
# Create TodoItems
todo_items: list[TodoItem] = []
for idx, item in enumerate(tasks_payload, start=1):
if not isinstance(item, dict):
continue
title = str(item.get("title") or f"Task{idx}").strip()
intent = str(item.get("intent") or "Focus on key issues of the topic").strip()
query = str(item.get("query") or topic).strip() or topic
todo_items.append(TodoItem(
id=idx,
title=title,
intent=intent,
query=query,
))
# Fallback if no tasks parsed
if not todo_items:
todo_items.append(TodoItem(
id=1,
title="Basic Background Overview",
intent="Collect core background and latest developments on the topic",
query=f"{topic} latest developments" if topic else "Basic background overview",
))
return todo_items
@staticmethod
def _extract_note_id(response: str) -> Optional[str]:
"""Extract note ID from tool response."""
if not response:
return None
match = re.search(r"ID:\s*([^\n]+)", response)
return match.group(1).strip() if match else None
def _set_tool_event_sink(self, sink: Callable[[dict[str, Any]], None] | None) -> None:
"""Enable or disable immediate tool event callbacks."""
self._tool_event_sink_enabled = sink is not None
self._tool_tracker.set_event_sink(sink)
# -------------------------------------------------------------------------
# Public API
# -------------------------------------------------------------------------
    def run(self, topic: str) -> SummaryStateOutput:
        """Execute the research workflow and return the final report.

        Batch (non-streaming) entry point: seeds a full initial state,
        invokes the compiled graph end-to-end, and wraps the resulting
        report and tasks into a SummaryStateOutput.
        """
        initial_state: ResearchState = {
            "research_topic": topic,
            "todo_items": [],
            "current_task_index": 0,
            "web_research_results": [],
            "sources_gathered": [],
            "research_loop_count": 0,
            "structured_report": None,
            "report_note_id": None,
            "report_note_path": None,
            "messages": [],
            "config": self.config,
        }
        # Run the graph
        final_state = self.graph.invoke(initial_state)
        report = final_state.get("structured_report", "")
        todo_items = final_state.get("todo_items", [])
        # The same report text backs both the summary and markdown fields.
        return SummaryStateOutput(
            running_summary=report,
            report_markdown=report,
            todo_items=todo_items,
        )
    def run_stream(self, topic: str) -> Iterator[dict[str, Any]]:
        """Execute the workflow yielding incremental progress events.

        Streaming counterpart to ``run``: plans tasks inline, executes every
        task concurrently (one daemon thread per task), and relays worker
        events through a queue so the caller can forward them as SSE.
        Event types emitted: ``status``, ``todo_list``, ``task_status``,
        ``sources``, ``task_summary_chunk``, ``tool_call`` (via the sink),
        ``final_report`` and ``done``.
        """
        logger.debug("Starting streaming research: topic=%s", topic)
        yield {"type": "status", "message": "Initializing research workflow"}
        # Plan phase (duplicated from _plan_research_node so chunks can be
        # interleaved with status events).
        yield {"type": "status", "message": "Planning research tasks..."}
        system_prompt = todo_planner_system_prompt.strip()
        user_prompt = todo_planner_instructions.format(
            current_date=get_current_date(),
            research_topic=topic,
        )
        messages = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=user_prompt),
        ]
        response = self.llm.invoke(messages)
        response_text = response.content
        if self.config.strip_thinking_tokens:
            response_text = strip_thinking_tokens(response_text)
        todo_items = self._parse_todo_items(response_text, topic)
        # Create notes for tasks
        if self.note_tool:
            for task in todo_items:
                result = self.note_tool.run({
                    "action": "create",
                    "task_id": task.id,
                    "title": f"Task {task.id}: {task.title}",
                    "note_type": "task_state",
                    "tags": ["deep_research", f"task_{task.id}"],
                    "content": f"Task objective: {task.intent}\nSearch query: {task.query}",
                })
                note_id = self._extract_note_id(result)
                if note_id:
                    task.note_id = note_id
                    task.note_path = str(Path(self.config.notes_workspace) / f"{note_id}.md")
        # Setup channel mapping for streaming: each task gets a step number
        # and a stream token used to tag its events.
        channel_map: dict[int, dict[str, Any]] = {}
        for index, task in enumerate(todo_items, start=1):
            token = f"task_{task.id}"
            task.stream_token = token
            channel_map[task.id] = {"step": index, "token": token}
        yield {
            "type": "todo_list",
            "tasks": [self._serialize_task(t) for t in todo_items],
            "step": 0,
        }
        # Execute tasks with streaming: workers push events onto this queue.
        event_queue: Queue[dict[str, Any]] = Queue()

        def enqueue(event: dict[str, Any], task: Optional[TodoItem] = None, step_override: Optional[int] = None) -> None:
            # Stamp the payload with task id, step and stream token so the
            # frontend can route it to the right task channel.
            payload = dict(event)
            target_task_id = payload.get("task_id")
            if task is not None:
                target_task_id = task.id
                payload["task_id"] = task.id
            channel = channel_map.get(target_task_id) if target_task_id else None
            if channel:
                payload.setdefault("step", channel["step"])
                payload["stream_token"] = channel["token"]
            if step_override is not None:
                payload["step"] = step_override
            event_queue.put(payload)

        def tool_event_sink(event: dict[str, Any]) -> None:
            # Forward live tool-call events into the same stream.
            enqueue(event)

        self._set_tool_event_sink(tool_event_sink)
        threads: list[Thread] = []
        state = SummaryState(research_topic=topic)
        state.todo_items = todo_items

        def worker(task: TodoItem, step: int) -> None:
            # NOTE(review): `step` is never used in the body — enqueue()
            # derives the step from channel_map instead; confirm intent.
            try:
                enqueue({
                    "type": "task_status",
                    "task_id": task.id,
                    "status": "in_progress",
                    "title": task.title,
                    "intent": task.intent,
                    "note_id": task.note_id,
                    "note_path": task.note_path,
                }, task=task)
                # Execute search
                search_result, notices, answer_text, backend = dispatch_search(
                    task.query, self.config, state.research_loop_count
                )
                task.notices = notices
                for notice in notices:
                    if notice:
                        enqueue({
                            "type": "status",
                            "message": notice,
                            "task_id": task.id,
                        }, task=task)
                if not search_result or not search_result.get("results"):
                    task.status = "skipped"
                    enqueue({
                        "type": "task_status",
                        "task_id": task.id,
                        "status": "skipped",
                        "title": task.title,
                        "intent": task.intent,
                        "note_id": task.note_id,
                        "note_path": task.note_path,
                    }, task=task)
                    return
                # Prepare context
                sources_summary, context = prepare_research_context(
                    search_result, answer_text, self.config
                )
                task.sources_summary = sources_summary
                # Shared SummaryState is mutated from multiple workers.
                with self._state_lock:
                    state.web_research_results.append(context)
                    state.sources_gathered.append(sources_summary)
                    state.research_loop_count += 1
                enqueue({
                    "type": "sources",
                    "task_id": task.id,
                    "latest_sources": sources_summary,
                    "raw_context": context,
                    "backend": backend,
                    "note_id": task.note_id,
                    "note_path": task.note_path,
                }, task=task)
                # Stream summarization
                prompt = (
                    f"Task topic: {topic}\n"
                    f"Task name: {task.title}\n"
                    f"Task objective: {task.intent}\n"
                    f"Search query: {task.query}\n"
                    f"Task context:\n{context}\n"
                    "Please generate a detailed task summary."
                )
                summary_messages = [
                    SystemMessage(content=task_summarizer_instructions.strip()),
                    HumanMessage(content=prompt),
                ]
                summary_chunks: list[str] = []
                for chunk in self.llm.stream(summary_messages):
                    chunk_text = chunk.content
                    if chunk_text:
                        summary_chunks.append(chunk_text)
                        # Strip thinking tokens from visible output.
                        # NOTE(review): when strip_thinking_tokens is False
                        # no chunk is ever enqueued — looks unintended;
                        # confirm whether chunks should stream regardless.
                        visible_chunk = chunk_text
                        if self.config.strip_thinking_tokens and "<think>" not in chunk_text:
                            enqueue({
                                "type": "task_summary_chunk",
                                "task_id": task.id,
                                "content": visible_chunk,
                                "note_id": task.note_id,
                            }, task=task)
                full_summary = "".join(summary_chunks)
                if self.config.strip_thinking_tokens:
                    full_summary = strip_thinking_tokens(full_summary)
                task.summary = full_summary.strip() or "No information available"
                task.status = "completed"
                # Update note
                if self.note_tool and task.note_id:
                    self.note_tool.run({
                        "action": "update",
                        "note_id": task.note_id,
                        "task_id": task.id,
                        "content": f"## Task Summary\n{task.summary}\n\n## Sources\n{sources_summary}",
                    })
                enqueue({
                    "type": "task_status",
                    "task_id": task.id,
                    "status": "completed",
                    "summary": task.summary,
                    "sources_summary": task.sources_summary,
                    "note_id": task.note_id,
                    "note_path": task.note_path,
                }, task=task)
            except Exception as exc:
                logger.exception("Task execution failed", exc_info=exc)
                enqueue({
                    "type": "task_status",
                    "task_id": task.id,
                    "status": "failed",
                    "detail": str(exc),
                    "title": task.title,
                    "intent": task.intent,
                    "note_id": task.note_id,
                    "note_path": task.note_path,
                }, task=task)
            finally:
                # Sentinel consumed by the dispatch loop below; never yielded.
                enqueue({"type": "__task_done__", "task_id": task.id})

        # Start worker threads (daemonized so a stuck task cannot block exit).
        for task in todo_items:
            step = channel_map.get(task.id, {}).get("step", 0)
            thread = Thread(target=worker, args=(task, step), daemon=True)
            threads.append(thread)
            thread.start()
        # Yield events from queue until every worker has signalled done.
        active_workers = len(todo_items)
        finished_workers = 0
        try:
            while finished_workers < active_workers:
                event = event_queue.get()
                if event.get("type") == "__task_done__":
                    finished_workers += 1
                    continue
                yield event
            # Drain remaining events
            while True:
                try:
                    event = event_queue.get_nowait()
                except Empty:
                    break
                if event.get("type") != "__task_done__":
                    yield event
        finally:
            # Detach the sink before joining so late tool calls don't enqueue.
            self._set_tool_event_sink(None)
            for thread in threads:
                thread.join()
        # Generate final report (mirrors _generate_report_node inline).
        yield {"type": "status", "message": "Generating research report..."}
        tasks_block = []
        for task in todo_items:
            summary_block = task.summary or "No information available"
            sources_block = task.sources_summary or "No sources available"
            tasks_block.append(
                f"### Task {task.id}: {task.title}\n"
                f"- Objective: {task.intent}\n"
                f"- Search query: {task.query}\n"
                f"- Status: {task.status}\n"
                f"- Summary:\n{summary_block}\n"
                f"- Sources:\n{sources_block}\n"
            )
        report_prompt = (
            f"Research topic: {topic}\n"
            f"Task overview:\n{''.join(tasks_block)}\n"
            "Based on the above task summaries, please write a structured research report."
        )
        report_messages = [
            SystemMessage(content=report_writer_instructions.strip()),
            HumanMessage(content=report_prompt),
        ]
        report = self.llm.invoke(report_messages).content
        if self.config.strip_thinking_tokens:
            report = strip_thinking_tokens(report)
        report = report.strip() or "Report generation failed"
        # Create conclusion note
        report_note_id = None
        report_note_path = None
        if self.note_tool:
            result = self.note_tool.run({
                "action": "create",
                "title": f"Research Report: {topic}",
                "note_type": "conclusion",
                "tags": ["deep_research", "report"],
                "content": report,
            })
            report_note_id = self._extract_note_id(result)
            if report_note_id:
                report_note_path = str(Path(self.config.notes_workspace) / f"{report_note_id}.md")
        yield {
            "type": "final_report",
            "report": report,
            "note_id": report_note_id,
            "note_path": report_note_path,
        }
        yield {"type": "done"}
def _serialize_task(self, task: TodoItem) -> dict[str, Any]:
"""Convert task dataclass to serializable dict for frontend."""
return {
"id": task.id,
"title": task.title,
"intent": task.intent,
"query": task.query,
"status": task.status,
"summary": task.summary,
"sources_summary": task.sources_summary,
"note_id": task.note_id,
"note_path": task.note_path,
"stream_token": task.stream_token,
}
def run_deep_research(topic: str, config: Configuration | None = None) -> SummaryStateOutput:
    """Convenience function mirroring the class-based API.

    Builds a fresh DeepResearchAgent (from env config when none is given)
    and runs the batch workflow once for ``topic``.
    """
    return DeepResearchAgent(config=config).run(topic)