"""LangGraph-based orchestrator implementation.

NOTE: This orchestrator is deprecated in favor of the shared memory layer
integrated into Simple and Advanced modes (SPEC-08). It remains as a reference
implementation for LangGraph patterns.
"""
import os
import uuid
from collections.abc import AsyncGenerator, AsyncIterator
from contextlib import asynccontextmanager
from typing import Any, Literal

from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver

from src.agents.graph.state import ResearchState
from src.agents.graph.workflow import create_research_graph
from src.orchestrators.base import OrchestratorProtocol
from src.utils.config import settings
from src.utils.models import AgentEvent
from src.utils.service_loader import get_embedding_service
class LangGraphOrchestrator(OrchestratorProtocol):
    """State-driven research orchestrator using LangGraph.

    DEPRECATED: Memory features are now integrated into Simple and Advanced
    modes (SPEC-08). This class is kept for reference and potential future use.
    """

    def __init__(
        self,
        max_iterations: int = 10,
        checkpoint_path: str | None = None,
        repo_id: str = "meta-llama/Llama-3.1-70B-Instruct",
    ) -> None:
        """Initialize the orchestrator and its chat model.

        Args:
            max_iterations: Upper bound on graph iterations per run.
            checkpoint_path: Optional SQLite file path used to persist graph
                state between steps; ``None`` disables checkpointing.
            repo_id: Hugging Face model repository served via the serverless
                Inference API (previously hard-coded; default preserves the
                original behavior).

        Raises:
            ValueError: If no Hugging Face API token is configured.
        """
        self._max_iterations = max_iterations
        self._checkpoint_path = checkpoint_path
        # The serverless HF Inference API requires a token up front;
        # fail fast here rather than on the first request.
        api_key = settings.hf_token
        if not api_key:
            raise ValueError(
                "HF_TOKEN (Hugging Face API Token) is required for LangGraph orchestrator."
            )
        self.llm_endpoint = HuggingFaceEndpoint(  # type: ignore
            repo_id=repo_id,
            task="text-generation",
            max_new_tokens=1024,
            temperature=0.1,
            huggingfacehub_api_token=api_key,
        )
        # ChatHuggingFace adapts the raw endpoint to the chat-model
        # interface expected by the LangGraph nodes.
        self.chat_model = ChatHuggingFace(llm=self.llm_endpoint)

    async def run(self, query: str) -> AsyncGenerator[AgentEvent, None]:
        """Execute the research workflow with structured state.

        Args:
            query: The research question to investigate.

        Yields:
            AgentEvent: a ``started`` event, per-node ``thinking`` /
            ``searching`` / ``progress`` updates, and a final ``complete``
            event.
        """
        # Tiered embedding selection (service_loader): LlamaIndexRAGService
        # when an OpenAI key is available, else the local EmbeddingService.
        embedding_service = get_embedding_service()

        # Optional SQLite checkpointer (dev persistence).
        if self._checkpoint_path:
            # Ensure the parent directory exists; paths without a directory
            # component yield an empty dirname and need no makedirs.
            dir_name = os.path.dirname(self._checkpoint_path)
            if dir_name:
                os.makedirs(dir_name, exist_ok=True)
            saver = AsyncSqliteSaver.from_conn_string(self._checkpoint_path)
        else:
            saver = None

        @asynccontextmanager
        async def get_graph_context(saver_instance: Any) -> AsyncIterator[Any]:
            # Normalize the optional saver: enter its async context only when
            # present, yielding a compiled research graph either way.
            if saver_instance:
                async with saver_instance as s:
                    yield create_research_graph(
                        llm=self.chat_model,
                        checkpointer=s,
                        embedding_service=embedding_service,
                    )
            else:
                yield create_research_graph(
                    llm=self.chat_model,
                    checkpointer=None,
                    embedding_service=embedding_service,
                )

        async with get_graph_context(saver) as graph:
            # Initial workflow state; the supervisor routes from "search".
            initial_state: ResearchState = {
                "query": query,
                "hypotheses": [],
                "conflicts": [],
                "evidence_ids": [],
                "messages": [],
                "next_step": "search",  # Start with search
                "iteration_count": 0,
                "max_iterations": self._max_iterations,
            }
            yield AgentEvent(type="started", message=f"Starting LangGraph research: {query}")
            # Unique thread_id per run avoids state collisions in the
            # checkpointer; config is only meaningful when persistence is on.
            thread_id = str(uuid.uuid4())
            config = {"configurable": {"thread_id": thread_id}} if saver else {}
            # astream yields one {node_name: state_update} dict per superstep.
            async for event in graph.astream(initial_state, config=config):
                for node_name, update in event.items():
                    if update.get("messages"):
                        last_msg = update["messages"][-1]
                        # Map node names onto user-facing event types.
                        event_type: Literal["progress", "thinking", "searching"] = "progress"
                        if node_name == "supervisor":
                            event_type = "thinking"
                        elif node_name == "search":
                            event_type = "searching"
                        yield AgentEvent(
                            type=event_type, message=str(last_msg.content), data={"node": node_name}
                        )
                    elif node_name == "supervisor":
                        # The supervisor may emit only a routing decision
                        # with no message payload.
                        yield AgentEvent(
                            type="thinking",
                            message=f"Supervisor decided: {update.get('next_step')}",
                            data={"node": node_name},
                        )
            yield AgentEvent(type="complete", message="Research complete.")