Spaces:
Running
Running
File size: 11,738 Bytes
from typing import TypedDict, Annotated, List, Dict, Any, Union
import operator
from langgraph.graph import StateGraph, END
import time
import json
# Import our existing services (Reusing logic!)
from services.query_decomposition import QueryDecomposerService
from services.graph_planner import GraphPlannerService
from services.hybrid_retrieval import HybridRetrievalService
from services.context_organization import ContextOrganizationService
from services.legal_reasoner import LegalReasonerService
from services.quality_assurance import QualityAssuranceService
from services.neo4j import get_neo4j_driver
from core.search import SearchService
# 1. Start with State Definition
class AgentState(TypedDict):
    """Shared LangGraph state dict threaded through every node.

    Nodes return partial updates; LangGraph merges them into this state.
    ``trace`` is Annotated with ``operator.add`` so parallel branches
    (graph_plan / fetch_vector) append to it instead of overwriting.
    """
    # Inputs
    query: str
    # Internal State
    intent: str
    classification: Dict[str, Any]
    graph_plan: Dict[str, Any]
    qdrant_scope: Dict[str, Any]  # New field for Tree Routing scope
    graph_results: List[Dict[str, Any]]
    retrieved_chunks: List[Dict[str, Any]]
    context_data: Dict[str, Any]
    draft_answer: Dict[str, Any]
    critique: Dict[str, Any]
    # Metadata
    # NOTE(review): nodes read retry_count via state.get(..., 0), so it may
    # be absent until the first retry increments it.
    retry_count: int
    trace: Annotated[List[str], operator.add]
    raw_chunks: List[Dict[str, Any]]
# 2. Define Services Wrapper (Singleton access would be better, but init here is fine for now)
# Ideally these should be initialized once outside and passed in, but we'll init inside the nodes or globally.
# For thread safety in a real app, these should be global singletons.
class ServiceContainer:
    """Bundles every pipeline service behind one object.

    Instantiated once at module import time (see ``services`` below) and
    shared by all nodes. NOTE(review): assumed safe to share across
    requests — confirm each wrapped service is stateless/thread-safe.
    """
    def __init__(self):
        self.decomposer = QueryDecomposerService()      # query decomposition / entity extraction
        self.planner = GraphPlannerService()            # Cypher query planning
        self.retriever = HybridRetrievalService()       # graph + vector reranking
        self.organizer = ContextOrganizationService()   # context assembly
        self.reasoner = LegalReasonerService()          # answer generation
        self.qa = QualityAssuranceService()             # LLM-based validation
        self.neo4j = get_neo4j_driver()                 # Neo4j driver handle
        self.vector_search = SearchService()            # Qdrant vector search
# Global container instance
services = ServiceContainer()
# 3. Define Nodes
def node_classify(state: AgentState) -> AgentState:
    """Decompose the user query and derive the Qdrant tree-routing scope."""
    print("--- Node: Classify ---")
    classification = services.decomposer.decompose(state["query"])
    print(f"DEBUG: Classification: {json.dumps(classification, indent=2)}")
    entities = classification.get("entities", {})
    scope = {
        "article_numbers": entities.get("articles", []),
        "amendment_ids": entities.get("amendments", []),
        "include_amendments": True,
    }
    return {
        "classification": classification,
        "qdrant_scope": scope,
        "trace": ["Classify"],
    }
def node_plan_graph(state: AgentState) -> AgentState:
    """Plan Cypher queries for the graph leg and execute them against Neo4j.

    On retries the classification is flagged so the planner broadens its
    scope. If the planner's explicit queries return nothing, deterministic
    fallback queries (per article / per amendment) run eagerly in the same
    step instead of waiting for a retry round-trip.

    Returns a partial state update with the plan, the collected graph
    result rows, and a trace entry.
    """
    import re  # local: used for the destructive-query guard and digit cleanup

    print("--- Node: Graph Tree Router ---")
    retry_count = state.get("retry_count", 0)
    classification = state["classification"]
    if retry_count > 0:
        print(f"π Retry Mode ({retry_count}): Broadening Graph Search Scope")
        # Flag picked up by the planner to trigger a broader search.
        classification["broad_search"] = True
    plan = services.planner.generate_plan(classification)

    def execute_queries(query_list):
        """Run each query (a string, or a (cypher, params) pair); collect row dicts."""
        rows = []
        with services.neo4j.session() as session:
            for idx, item in enumerate(query_list, 1):
                query, params = item if isinstance(item, tuple) else (item, {})
                # Refuse destructive statements regardless of keyword case
                # (the original check missed lowercase delete/detach).
                if re.search(r"\b(DELETE|DETACH)\b", query, re.IGNORECASE):
                    continue
                try:
                    result = session.run(query, params)
                    rows.extend(record.data() for record in result)
                except Exception as e:
                    print(f"Neo4j Error on Query {idx}: {e}")
        return rows

    queries = plan.get("cypher_queries", [])
    print(f"DEBUG: Generated Cypher Queries ({len(queries)}): {json.dumps(queries, indent=2)}")
    print(f"DEBUG: About to execute {len(queries)} explicit queries")
    results = execute_queries(queries)

    # EAGER FALLBACK: if explicit queries found nothing, run fallbacks now.
    if not results:
        print("DEBUG: 0 results from initial queries. Running eager fallbacks.")
        fallback_queries = []
        entities = classification.get("entities", {})
        articles = entities.get("articles", [])
        if isinstance(articles, str):
            articles = [articles]
        for art in articles:
            # Parameterized ($art) instead of f-string interpolation, so
            # extracted entity values cannot inject Cypher.
            fallback_queries.append((
                """
                MATCH (am:Amendment)-[r]->(a:Article {number: $art})
                RETURN am.number as amendment, r.details as modification, type(r) as relationship, $art as target_id
                LIMIT 50
                """,
                {"art": art},
            ))
        amendments = entities.get("amendments", [])
        if not articles and amendments:
            for am in amendments:
                # Amendment ids may arrive as e.g. "Amendment 12" — keep digits only.
                am_clean = re.sub(r"\D", "", str(am))
                if am_clean:
                    fallback_queries.append((
                        """
                        MATCH (am:Amendment {number: $num})-[r]->(target)
                        RETURN am.number as amendment, type(r) as relationship, r.details as modification, COALESCE(target.number, target.id) as target_id
                        LIMIT 50
                        """,
                        # Unquoted int in the original query -> keep integer type.
                        {"num": int(am_clean)},
                    ))
        if fallback_queries:
            print(f"DEBUG: Executing {len(fallback_queries)} fallback queries")
            results.extend(execute_queries(fallback_queries))

    print(f"DEBUG: Total Graph Results: {len(results)}")
    if results:
        print(f"DEBUG: Sample Result: {json.dumps(results, indent=2)}")
    return {
        "graph_plan": plan,
        "graph_results": results,
        "trace": ["GraphPlan"],
    }
def node_fetch_vector(state: AgentState) -> AgentState:
    """Run the scoped vector-search leg of the parallel retrieval fan-out."""
    print("--- Node: Fetch Vector (Parallel) ---")
    scope = state.get("qdrant_scope", {})
    base_limit = state.get("graph_plan", {}).get("expected_chunks", 5)
    # Widen the net on retries.
    limit = base_limit + 5 if state.get("retry_count", 0) > 0 else base_limit
    print(f"Executing Scoped Search with Scope: {scope}")
    chunks = services.vector_search.search(state["query"], scope=scope, limit=limit)
    print(f"DEBUG: Vector Search fetched {len(chunks)} chunks")
    return {
        "raw_chunks": chunks,
        "trace": ["FetchVector"],
    }
def node_coordinate(state: AgentState) -> AgentState:
    """Merge and rerank the graph and vector results into the final chunk set."""
    print("--- Node: Coordinate (Reranking) ---")
    merged = services.retriever.coordinate_retrieval(
        state["query"],
        state.get("graph_results", []),
        state.get("raw_chunks", []),
    )
    selected = merged.get("final_selected_chunks", [])
    print(f"DEBUG: Final Selected Chunks: {len(selected)}")
    return {
        "retrieved_chunks": selected,
        "trace": ["Coordinate"],
    }
def node_organize(state: AgentState) -> AgentState:
    """Assemble retrieved chunks and graph data into an organized context."""
    print("--- Node: Organize ---")
    # Full chunk dicts (with year / amendment_number metadata) are passed
    # through so the organizer can do temporal ordering.
    context = services.organizer.organize_context(
        state["query"],
        state.get("retrieved_chunks", []),
        graph_data=state.get("graph_results", []),
    )
    print(f"DEBUG: Organized Context Keys: {list(context.keys())}")
    if "context_block" in context:
        print(f"DEBUG: Context Block Preview (first 500 chars):\n{context['context_block'][:500]}")
    return {
        "context_data": context,
        "trace": ["Organize"],
    }
def node_reason(state: AgentState) -> AgentState:
    """Generate the draft answer from the organized context."""
    print("--- Node: Reason ---")
    draft = services.reasoner.generate_answer(state["query"], state["context_data"])
    # Short preview of the draft for debugging.
    preview = json.dumps(draft)[:200]
    print(f"DEBUG: Draft Answer Preview: {preview}...")
    return {
        "draft_answer": draft,
        "trace": ["Reason"],
    }
def node_validate(state: AgentState) -> AgentState:
    """QA pass: have the LLM critique the draft against graph + vector context."""
    print("--- Node: Validate ---")
    graph_results = state.get("graph_results", [])
    context_data = state.get("context_data", {})
    # Summarize which relationship patterns the graph leg actually produced.
    pattern_labels = (
        ("target_id", "Direct"),
        ("via_article", "Multi-hop"),
        ("related_article", "Related"),
    )
    found_types = [
        label
        for key, label in pattern_labels
        if any(key in row for row in graph_results)
    ]
    summary = ", ".join(found_types) if found_types else "no"
    print(f"DEBUG: Graph context contains {len(graph_results)} nodes across {summary} relationship patterns.")
    # Context-completeness judgment is delegated entirely to the QA LLM
    # rather than any hard-coded check here.
    critique = services.qa.validate_answer(
        state["query"],
        state["draft_answer"],
        graph_results,
        context_data,
    )
    print(f"DEBUG: Critique: {json.dumps(critique, indent=2)}")
    return {
        "critique": critique,
        "trace": ["Validate"],
    }
# 4. Define Conditional Logic
def should_continue(state: AgentState) -> str:
    """Route after validation: 'retry_retrieve', 'retry_reason', or 'end'.

    REFINE critiques retry per the QA-suggested retry_type; REVIEW gets one
    re-reasoning attempt; everything is capped at two retries total.
    """
    critique = state.get("critique", {})
    retries = state.get("retry_count", 0)
    # Hard cap: at most two retries in total.
    if retries >= 2:
        return "end"
    grade = critique.get("quality_grade")
    if grade == "REFINE":
        if critique.get("retry_type", "retrieve") == "reason":
            # The retrieved data was sufficient; only the reasoning pass missed it.
            print("π Retry: Re-running reasoner with same data")
            return "retry_reason"
        # Data incomplete — a fresh retrieval round is needed.
        print("π Retry: Re-running retrieval")
        return "retry_retrieve"
    if grade == "REVIEW" and retries < 1:
        print("π Retry: Re-running reasoner for quality improvement")
        return "retry_reason"
    return "end"
# 5. Build Graph
workflow = StateGraph(AgentState)

# Register all pipeline nodes.
for _node_name, _node_fn in [
    ("classify", node_classify),
    ("graph_plan", node_plan_graph),
    ("fetch_vector", node_fetch_vector),
    ("coordinate", node_coordinate),
    ("organize", node_organize),
    ("reason", node_reason),
    ("validate", node_validate),
]:
    workflow.add_node(_node_name, _node_fn)

# Wiring: classify fans out to the graph and vector legs in parallel,
# both fan back into coordinate, then a linear pipeline runs to validate.
workflow.set_entry_point("classify")
for _src, _dst in [
    ("classify", "graph_plan"),
    ("classify", "fetch_vector"),
    ("graph_plan", "coordinate"),
    ("fetch_vector", "coordinate"),
    ("coordinate", "organize"),
    ("organize", "reason"),
    ("reason", "validate"),
]:
    workflow.add_edge(_src, _dst)

workflow.add_conditional_edges(
    "validate",
    should_continue,
    {
        # Both retry flavors pass through update_retry so the retry counter
        # always increments, preventing infinite loops.
        "retry_retrieve": "update_retry",
        "retry_reason": "update_retry",
        "end": END,
    },
)
def node_retry_dispatcher(state: AgentState) -> str:
    """Pick where a retry re-enters the graph.

    'reason' keeps the existing data and re-runs only the reasoner;
    anything else restarts the full retrieval pipeline at 'classify'.
    """
    retry_type = state.get("critique", {}).get("retry_type")
    return "reason" if retry_type == "reason" else "classify"
# Bump the retry counter before looping back. Use .get() with a default:
# nodes read retry_count via state.get(..., 0), so the key may be absent on
# the first retry and a plain x["retry_count"] would raise KeyError.
workflow.add_node("update_retry", lambda x: {"retry_count": x.get("retry_count", 0) + 1})
workflow.add_conditional_edges("update_retry", node_retry_dispatcher, {"reason": "reason", "classify": "classify"})
# Compile
app = workflow.compile()
|