# ConstitutionAgent/core/graph_workflow.py
from typing import TypedDict, Annotated, List, Dict, Any, Union
import operator
from langgraph.graph import StateGraph, END
import time
import json
# Import our existing services (Reusing logic!)
from services.query_decomposition import QueryDecomposerService
from services.graph_planner import GraphPlannerService
from services.hybrid_retrieval import HybridRetrievalService
from services.context_organization import ContextOrganizationService
from services.legal_reasoner import LegalReasonerService
from services.quality_assurance import QualityAssuranceService
from services.neo4j import get_neo4j_driver
from core.search import SearchService
# 1. Start with State Definition
class AgentState(TypedDict):
    """Shared state threaded through every LangGraph node in the workflow.

    Nodes return partial dicts; LangGraph merges them into this state.
    Only ``trace`` has a reducer (operator.add), so it accumulates across
    nodes — every other key is overwritten by the latest node that sets it.
    """
    # Inputs
    query: str  # the user's original question
    # Internal State
    intent: str
    classification: Dict[str, Any]  # output of QueryDecomposerService.decompose
    graph_plan: Dict[str, Any]  # plan (incl. cypher_queries) from GraphPlannerService
    qdrant_scope: Dict[str, Any] # New field for Tree Routing scope
    graph_results: List[Dict[str, Any]]  # rows returned by Neo4j traversals
    retrieved_chunks: List[Dict[str, Any]]  # final reranked chunks after coordination
    context_data: Dict[str, Any]  # organized context from ContextOrganizationService
    draft_answer: Dict[str, Any]  # answer produced by LegalReasonerService
    critique: Dict[str, Any]  # QA verdict (quality_grade, retry_type, ...)
    # Metadata
    retry_count: int  # bumped by the update_retry node; capped in should_continue
    trace: Annotated[List[str], operator.add]  # node names appended via operator.add reducer
    raw_chunks: List[Dict[str, Any]]  # unranked vector-search hits (pre-coordination)
# 2. Define Services Wrapper (Singleton access would be better, but init here is fine for now)
# Ideally these should be initialized once outside and passed in, but we'll init inside the nodes or globally.
# For thread safety in a real app, these should be global singletons.
class ServiceContainer:
    """Bundles every pipeline service behind one object so nodes share instances.

    Instantiated once at module import time (see ``services`` below); all
    workflow nodes read from that global rather than constructing their own.
    """
    def __init__(self):
        self.decomposer = QueryDecomposerService()      # query classification / entity extraction
        self.planner = GraphPlannerService()            # generates Cypher plans
        self.retriever = HybridRetrievalService()       # merges + reranks graph and vector hits
        self.organizer = ContextOrganizationService()   # builds the context block for the LLM
        self.reasoner = LegalReasonerService()          # drafts the answer
        self.qa = QualityAssuranceService()             # grades the draft answer
        self.neo4j = get_neo4j_driver()                 # shared Neo4j driver
        self.vector_search = SearchService()            # scoped Qdrant search
# Global container instance
services = ServiceContainer()
# 3. Define Nodes
def node_classify(state: AgentState) -> AgentState:
    """Decompose the user query and derive the vector-search routing scope.

    Runs the query decomposer, then lifts the extracted article/amendment
    entities into a ``qdrant_scope`` dict used by the parallel vector fetch.
    """
    print("--- Node: Classify ---")
    classification = services.decomposer.decompose(state["query"])
    print(f"DEBUG: Classification: {json.dumps(classification, indent=2)}")
    entities = classification.get("entities", {})
    scope = {
        "article_numbers": entities.get("articles", []),
        "amendment_ids": entities.get("amendments", []),
        "include_amendments": True,
    }
    return {
        "classification": classification,
        "qdrant_scope": scope,
        "trace": ["Classify"],
    }
def node_plan_graph(state: AgentState) -> AgentState:
print("--- Node: Graph Tree Router ---")
# If retrying, ask planner for Broader Scope
retry_count = state.get("retry_count", 0)
classification = state["classification"]
if retry_count > 0:
print(f"๐Ÿ”„ Retry Mode ({retry_count}): Broadening Graph Search Scope")
# Add a flag to classification to trigger broader search
classification["broad_search"] = True
plan = services.planner.generate_plan(classification)
# Execute Neo4j Traversal Helper
def execute_queries(query_list):
res_list = []
with services.neo4j.session() as session:
for idx, q in enumerate(query_list, 1):
if "DELETE" in q or "DETACH" in q: continue
try:
res = session.run(q)
data = [r.data() for r in res]
res_list.extend(data)
except Exception as e:
print(f"Neo4j Error on Query {idx}: {e}")
return res_list
queries = plan.get("cypher_queries", [])
print(f"DEBUG: Generated Cypher Queries ({len(queries)}): {json.dumps(queries, indent=2)}")
print(f"DEBUG: About to execute {len(queries)} explicit queries")
results = execute_queries(queries)
# EAGER FALLBACK: If explicit queries fail, run fallbacks immediately in the same step
if not results:
print("DEBUG: 0 results from initial queries. Running eager fallbacks.")
fallback_queries = []
articles = classification.get("entities", {}).get("articles", [])
if isinstance(articles, str):
articles = [articles]
for art in articles:
fallback_queries.append(f"""
MATCH (am:Amendment)-[r]->(a:Article {{number: '{art}'}})
RETURN am.number as amendment, r.details as modification, type(r) as relationship, '{art}' as target_id
LIMIT 50
""")
amendments = classification.get("entities", {}).get("amendments", [])
if not articles and amendments:
for am in amendments:
am_clean = str(am)
if not am_clean.isdigit():
import re; am_clean = re.sub(r'\D', '', am_clean)
if am_clean:
fallback_queries.append(f"""
MATCH (am:Amendment {{number: {am_clean}}})-[r]->(target)
RETURN am.number as amendment, type(r) as relationship, r.details as modification, COALESCE(target.number, target.id) as target_id
LIMIT 50
""")
if fallback_queries:
print(f"DEBUG: Executing {len(fallback_queries)} fallback queries")
results.extend(execute_queries(fallback_queries))
print(f"DEBUG: Total Graph Results: {len(results)}")
if results:
print(f"DEBUG: Sample Result: {json.dumps(results, indent=2)}")
return {
"graph_plan": plan,
"graph_results": results,
"trace": ["GraphPlan"]
}
def node_fetch_vector(state: AgentState) -> AgentState:
    """Fetch raw candidate chunks from the vector store, scoped by classification.

    Runs in parallel with graph planning. The chunk budget comes from the
    graph plan when available, widened by 5 on retry passes.
    """
    print(f"--- Node: Fetch Vector (Parallel) ---")
    scope = state.get("qdrant_scope", {})
    base_limit = state.get("graph_plan", {}).get("expected_chunks", 5)
    is_retry = state.get("retry_count", 0) > 0
    limit = base_limit + 5 if is_retry else base_limit
    print(f"Executing Scoped Search with Scope: {scope}")
    chunks = services.vector_search.search(state["query"], scope=scope, limit=limit)
    print(f"DEBUG: Vector Search fetched {len(chunks)} chunks")
    return {
        "raw_chunks": chunks,
        "trace": ["FetchVector"],
    }
def node_coordinate(state: AgentState) -> AgentState:
    """Merge graph rows and raw vector chunks, then rerank to a final set."""
    print("--- Node: Coordinate (Reranking) ---")
    output = services.retriever.coordinate_retrieval(
        state["query"],
        state.get("graph_results", []),
        state.get("raw_chunks", []),
    )
    selected = output.get("final_selected_chunks", [])
    print(f"DEBUG: Final Selected Chunks: {len(selected)}")
    return {
        "retrieved_chunks": selected,
        "trace": ["Coordinate"],
    }
def node_organize(state: AgentState) -> AgentState:
    """Assemble selected chunks plus graph rows into an organized context.

    Chunks are passed whole (with their metadata, e.g. year/amendment_number)
    so the organizer can apply temporal ordering.
    """
    print("--- Node: Organize ---")
    context = services.organizer.organize_context(
        state["query"],
        state.get("retrieved_chunks", []),
        graph_data=state.get("graph_results", []),
    )
    print(f"DEBUG: Organized Context Keys: {list(context.keys())}")
    if "context_block" in context:
        print(f"DEBUG: Context Block Preview (first 500 chars):\n{context['context_block'][:500]}")
    return {
        "context_data": context,
        "trace": ["Organize"],
    }
def node_reason(state: AgentState) -> AgentState:
    """Draft an answer from the organized context via the legal reasoner."""
    print("--- Node: Reason ---")
    draft = services.reasoner.generate_answer(state["query"], state["context_data"])
    # Short preview only — the full answer can be large.
    print(f"DEBUG: Draft Answer Preview: {json.dumps(draft)[:200]}...")
    return {
        "draft_answer": draft,
        "trace": ["Reason"],
    }
def node_validate(state: AgentState) -> AgentState:
    """Grade the draft answer via the QA service using graph AND vector context.

    Logs which relationship patterns the graph context covers (debug only),
    then delegates the completeness/quality judgment entirely to the QA LLM —
    there is deliberately no hard-coded completeness check here.
    """
    print("--- Node: Validate ---")
    graph_results = state.get("graph_results", [])
    context_data = state.get("context_data", {})
    # Single data-driven scan instead of three separate any() passes;
    # each marker key identifies one relationship pattern in the results.
    pattern_markers = (
        ("target_id", "Direct"),
        ("via_article", "Multi-hop"),
        ("related_article", "Related"),
    )
    found_types = [
        label for key, label in pattern_markers
        if any(key in r for r in graph_results)
    ]
    print(f"DEBUG: Graph context contains {len(graph_results)} nodes across {', '.join(found_types) if found_types else 'no'} relationship patterns.")
    # Run LLM-based validation with BOTH graph and vector data.
    critique = services.qa.validate_answer(
        state["query"],
        state["draft_answer"],
        graph_results,
        context_data,
    )
    print(f"DEBUG: Critique: {json.dumps(critique, indent=2)}")
    return {
        "critique": critique,
        "trace": ["Validate"],
    }
# 4. Define Conditional Logic
def should_continue(state: AgentState) -> str:
critique = state.get("critique", {})
retry = state.get("retry_count", 0)
# Max 2 retries total
if retry >= 2:
return "end"
quality_grade = critique.get("quality_grade")
if quality_grade == "REFINE":
# Check QA's suggested retry type
retry_type = critique.get("retry_type", "retrieve") # default to retrieve
if retry_type == "reason":
# Data is good, but answer missed it - just re-run reasoner
print("๐Ÿ”„ Retry: Re-running reasoner with same data")
return "retry_reason"
else:
# Data incomplete - need new retrieval
print("๐Ÿ”„ Retry: Re-running retrieval")
return "retry_retrieve"
if quality_grade == "REVIEW" and retry < 1:
# For REVIEW grade, try re-reasoning once
print("๐Ÿ”„ Retry: Re-running reasoner for quality improvement")
return "retry_reason"
return "end"
# 5. Build Graph
workflow = StateGraph(AgentState)
# Register one node per pipeline stage.
workflow.add_node("classify", node_classify)
workflow.add_node("graph_plan", node_plan_graph)
workflow.add_node("fetch_vector", node_fetch_vector)
workflow.add_node("coordinate", node_coordinate)
workflow.add_node("organize", node_organize)
workflow.add_node("reason", node_reason)
workflow.add_node("validate", node_validate)
# Edges
workflow.set_entry_point("classify")
# Parallel Fan-Out: graph traversal and vector search run concurrently after classify.
workflow.add_edge("classify", "graph_plan")
workflow.add_edge("classify", "fetch_vector")
# Fan-In to Coordinate: reranking waits for both branches.
workflow.add_edge("graph_plan", "coordinate")
workflow.add_edge("fetch_vector", "coordinate")
workflow.add_edge("coordinate", "organize")
workflow.add_edge("organize", "reason")
workflow.add_edge("reason", "validate")
# After validation either finish or loop back through the retry counter.
workflow.add_conditional_edges(
    "validate",
    should_continue,
    {
        "retry_retrieve": "update_retry",
        "retry_reason": "update_retry",  # Route both through update_retry to prevent infinite loops
        "end": END
    }
)
def node_retry_dispatcher(state: AgentState) -> str:
critique = state.get("critique", {})
return "reason" if critique.get("retry_type") == "reason" else "classify"
def node_update_retry(state: AgentState) -> Dict[str, Any]:
    """Increment the retry counter before looping back into the graph.

    Uses ``.get`` with a default (matching every other node's pattern) because
    the caller's initial state may omit ``retry_count`` — the previous lambda
    (``x["retry_count"] + 1``) raised KeyError on the first retry in that case.
    """
    return {"retry_count": state.get("retry_count", 0) + 1}
workflow.add_node("update_retry", node_update_retry)
workflow.add_conditional_edges("update_retry", node_retry_dispatcher, {"reason": "reason", "classify": "classify"})
# Compile
app = workflow.compile()