from typing import TypedDict, Annotated, List, Dict, Any, Union
import operator
from langgraph.graph import StateGraph, END
import time
import json

# Import our existing services (Reusing logic!)
from services.query_decomposition import QueryDecomposerService
from services.graph_planner import GraphPlannerService
from services.hybrid_retrieval import HybridRetrievalService
from services.context_organization import ContextOrganizationService
from services.legal_reasoner import LegalReasonerService
from services.quality_assurance import QualityAssuranceService
from services.neo4j import get_neo4j_driver
from core.search import SearchService


# 1. State Definition
class AgentState(TypedDict):
    """Shared LangGraph state threaded through every node.

    Each node returns a partial dict of these keys; LangGraph merges it
    into the running state.  `trace` is annotated with `operator.add` so
    per-node trace entries are concatenated instead of overwritten.
    """

    # --- Inputs ---
    query: str
    # --- Internal state ---
    intent: str
    classification: Dict[str, Any]
    graph_plan: Dict[str, Any]
    qdrant_scope: Dict[str, Any]  # Tree-routing scope for the vector search
    graph_results: List[Dict[str, Any]]
    retrieved_chunks: List[Dict[str, Any]]
    context_data: Dict[str, Any]
    draft_answer: Dict[str, Any]
    critique: Dict[str, Any]
    # --- Metadata ---
    retry_count: int
    trace: Annotated[List[str], operator.add]  # accumulated via list concat
    raw_chunks: List[Dict[str, Any]]


# 2. Services wrapper.
# NOTE(review): these are instantiated once at import time and shared by
# every node; for a multi-threaded app they should be proper singletons.
class ServiceContainer:
    """Bundles every backend service the workflow nodes call."""

    def __init__(self):
        self.decomposer = QueryDecomposerService()
        self.planner = GraphPlannerService()
        self.retriever = HybridRetrievalService()
        self.organizer = ContextOrganizationService()
        self.reasoner = LegalReasonerService()
        self.qa = QualityAssuranceService()
        self.neo4j = get_neo4j_driver()
        self.vector_search = SearchService()


# Global container instance used by all nodes below.
services = ServiceContainer()
# 3. Define Nodes
def node_classify(state: "AgentState") -> "AgentState":
    """Decompose the query and derive the Qdrant routing scope.

    Returns a partial state update: classification, qdrant_scope, trace.
    """
    print("--- Node: Classify ---")
    classification = services.decomposer.decompose(state["query"])
    print(f"DEBUG: Classification: {json.dumps(classification, indent=2)}")
    entities = classification.get("entities", {})
    return {
        "classification": classification,
        "qdrant_scope": {
            "article_numbers": entities.get("articles", []),
            "amendment_ids": entities.get("amendments", []),
            "include_amendments": True,
        },
        "trace": ["Classify"],
    }


def node_plan_graph(state: "AgentState") -> "AgentState":
    """Plan and execute Neo4j traversals (the Graph Tree Router).

    On retry the classification is flagged with `broad_search` so the
    planner widens its scope.  If the planner's explicit Cypher queries
    return nothing, eager fallback queries run in the same step.
    """
    import re  # hoisted: was imported inside the amendment loop

    print("--- Node: Graph Tree Router ---")
    retry_count = state.get("retry_count", 0)
    classification = state["classification"]
    if retry_count > 0:
        print(f"🔄 Retry Mode ({retry_count}): Broadening Graph Search Scope")
        # Flag for the planner to generate a broader search
        classification["broad_search"] = True
    plan = services.planner.generate_plan(classification)

    def execute_queries(query_list):
        """Run each read-only query; a failing query is logged, not fatal."""
        res_list = []
        with services.neo4j.session() as session:
            for idx, q in enumerate(query_list, 1):
                # FIX: guard was case-sensitive — lowercase 'delete'/'detach'
                # in generated Cypher slipped past it.
                q_upper = q.upper()
                if "DELETE" in q_upper or "DETACH" in q_upper:
                    continue
                try:
                    res = session.run(q)
                    res_list.extend(r.data() for r in res)
                except Exception as e:
                    print(f"Neo4j Error on Query {idx}: {e}")
        return res_list

    queries = plan.get("cypher_queries", [])
    print(f"DEBUG: Generated Cypher Queries ({len(queries)}): {json.dumps(queries, indent=2)}")
    print(f"DEBUG: About to execute {len(queries)} explicit queries")
    results = execute_queries(queries)

    # EAGER FALLBACK: if explicit queries fail, run fallbacks immediately.
    if not results:
        print("DEBUG: 0 results from initial queries. Running eager fallbacks.")
        fallback_queries = []
        articles = classification.get("entities", {}).get("articles", [])
        if isinstance(articles, str):
            articles = [articles]
        for art in articles:
            # FIX: sanitize the interpolated value — strip quote/backslash
            # characters so an LLM-extracted entity cannot break out of the
            # Cypher string literal (injection guard).
            art_safe = re.sub(r"['\"\\]", "", str(art))
            fallback_queries.append(f"""
                MATCH (am:Amendment)-[r]->(a:Article {{number: '{art_safe}'}})
                RETURN am.number as amendment, r.details as modification,
                       type(r) as relationship, '{art_safe}' as target_id
                LIMIT 50
            """)
        amendments = classification.get("entities", {}).get("amendments", [])
        if not articles and amendments:
            for am in amendments:
                am_clean = str(am)
                if not am_clean.isdigit():
                    am_clean = re.sub(r"\D", "", am_clean)
                if am_clean:  # digits only — safe to interpolate as a number
                    fallback_queries.append(f"""
                        MATCH (am:Amendment {{number: {am_clean}}})-[r]->(target)
                        RETURN am.number as amendment, type(r) as relationship,
                               r.details as modification,
                               COALESCE(target.number, target.id) as target_id
                        LIMIT 50
                    """)
        if fallback_queries:
            print(f"DEBUG: Executing {len(fallback_queries)} fallback queries")
            results.extend(execute_queries(fallback_queries))

    print(f"DEBUG: Total Graph Results: {len(results)}")
    if results:
        # FIX: only dump a sample — the old code serialized the entire
        # result set despite the "Sample Result" label.
        print(f"DEBUG: Sample Result: {json.dumps(results[:3], indent=2)}")
    return {"graph_plan": plan, "graph_results": results, "trace": ["GraphPlan"]}


def node_fetch_vector(state: "AgentState") -> "AgentState":
    """Scoped vector search; runs in parallel with the graph branch.

    NOTE(review): on the first pass `graph_plan` is not yet populated
    (parallel fan-out), so the chunk limit always falls back to 5.
    """
    print("--- Node: Fetch Vector (Parallel) ---")
    query = state["query"]
    scope = state.get("qdrant_scope", {})
    limit = state.get("graph_plan", {}).get("expected_chunks", 5)
    if state.get("retry_count", 0) > 0:
        limit += 5  # widen the net on retries
    print(f"Executing Scoped Search with Scope: {scope}")
    raw_chunks = services.vector_search.search(query, scope=scope, limit=limit)
    print(f"DEBUG: Vector Search fetched {len(raw_chunks)} chunks")
    return {"raw_chunks": raw_chunks, "trace": ["FetchVector"]}


def node_coordinate(state: "AgentState") -> "AgentState":
    """Fan-in: rerank graph and vector results into the final chunk set."""
    print("--- Node: Coordinate (Reranking) ---")
    retrieval_output = services.retriever.coordinate_retrieval(
        state["query"],
        state.get("graph_results", []),
        state.get("raw_chunks", []),
    )
    final_chunks = retrieval_output.get("final_selected_chunks", [])
    print(f"DEBUG: Final Selected Chunks: {len(final_chunks)}")
    return {"retrieved_chunks": final_chunks, "trace": ["Coordinate"]}


def node_organize(state: "AgentState") -> "AgentState":
    """Organize full chunks (incl. year / amendment metadata) into context."""
    print("--- Node: Organize ---")
    chunks = state.get("retrieved_chunks", [])
    context = services.organizer.organize_context(
        state["query"], chunks, graph_data=state.get("graph_results", [])
    )
    print(f"DEBUG: Organized Context Keys: {list(context.keys())}")
    if "context_block" in context:
        print(f"DEBUG: Context Block Preview (first 500 chars):\n{context['context_block'][:500]}")
    return {"context_data": context, "trace": ["Organize"]}


def node_reason(state: "AgentState") -> "AgentState":
    """Generate the draft answer from the organized context."""
    print("--- Node: Reason ---")
    answer = services.reasoner.generate_answer(state["query"], state["context_data"])
    print(f"DEBUG: Draft Answer Preview: {json.dumps(answer)[:200]}...")
    return {"draft_answer": answer, "trace": ["Reason"]}


def node_validate(state: "AgentState") -> "AgentState":
    """LLM-based QA of the draft answer against graph and vector context."""
    print("--- Node: Validate ---")
    graph_results = state.get("graph_results", [])
    context_data = state.get("context_data", {})
    # (removed unused locals: graph_plan / cypher_queries / retry_count)

    # Summarize which relationship patterns the graph context covers.
    found_types = []
    if any("target_id" in r for r in graph_results):
        found_types.append("Direct")
    if any("via_article" in r for r in graph_results):
        found_types.append("Multi-hop")
    if any("related_article" in r for r in graph_results):
        found_types.append("Related")
    print(
        f"DEBUG: Graph context contains {len(graph_results)} nodes across "
        f"{', '.join(found_types) if found_types else 'no'} relationship patterns."
    )

    # Full judgment of context completeness is delegated to the QA LLM
    # instead of a hard-coded check.
    critique = services.qa.validate_answer(
        state["query"], state["draft_answer"], graph_results, context_data
    )
    print(f"DEBUG: Critique: {json.dumps(critique, indent=2)}")
    return {"critique": critique, "trace": ["Validate"]}


# 4. Define Conditional Logic
def should_continue(state: "AgentState") -> str:
    """Route after validation: 'end', 'retry_retrieve', or 'retry_reason'."""
    critique = state.get("critique", {})
    retry = state.get("retry_count", 0)
    # Hard cap: max 2 retries total
    if retry >= 2:
        return "end"
    quality_grade = critique.get("quality_grade")
    if quality_grade == "REFINE":
        # QA's suggested retry type; default to a fresh retrieval.
        retry_type = critique.get("retry_type", "retrieve")
        if retry_type == "reason":
            # Data is good, but the answer missed it — re-run reasoner only.
            print("🔄 Retry: Re-running reasoner with same data")
            return "retry_reason"
        # Data incomplete — need new retrieval.
        print("🔄 Retry: Re-running retrieval")
        return "retry_retrieve"
    if quality_grade == "REVIEW" and retry < 1:
        # For REVIEW grade, try re-reasoning once.
        print("🔄 Retry: Re-running reasoner for quality improvement")
        return "retry_reason"
    return "end"
# 5. Build Graph
workflow = StateGraph(AgentState)
workflow.add_node("classify", node_classify)
workflow.add_node("graph_plan", node_plan_graph)
workflow.add_node("fetch_vector", node_fetch_vector)
workflow.add_node("coordinate", node_coordinate)
workflow.add_node("organize", node_organize)
workflow.add_node("reason", node_reason)
workflow.add_node("validate", node_validate)

# Edges
workflow.set_entry_point("classify")

# Parallel Fan-Out
workflow.add_edge("classify", "graph_plan")
workflow.add_edge("classify", "fetch_vector")

# Fan-In to Coordinate
workflow.add_edge("graph_plan", "coordinate")
workflow.add_edge("fetch_vector", "coordinate")

workflow.add_edge("coordinate", "organize")
workflow.add_edge("organize", "reason")
workflow.add_edge("reason", "validate")

workflow.add_conditional_edges(
    "validate",
    should_continue,
    {
        # Route both retry kinds through update_retry so the retry counter
        # always increments (prevents infinite loops).
        "retry_retrieve": "update_retry",
        "retry_reason": "update_retry",
        "end": END,
    },
)


def node_update_retry(state: "AgentState") -> "AgentState":
    """Increment the retry counter.

    FIX: uses .get with a default — the previous lambda indexed
    state["retry_count"] directly and raised KeyError when the initial
    state omitted the key (every other reader uses .get(..., 0)).
    """
    return {"retry_count": state.get("retry_count", 0) + 1}


def node_retry_dispatcher(state: "AgentState") -> str:
    """Decide where a retry resumes after update_retry.

    FIX: mirrors should_continue — a REVIEW-grade retry re-runs only the
    reasoner, as does an explicit retry_type == "reason"; previously the
    REVIEW case fell through to "classify", contradicting the routing
    decision made in should_continue.
    """
    critique = state.get("critique", {})
    if critique.get("retry_type") == "reason" or critique.get("quality_grade") == "REVIEW":
        return "reason"
    return "classify"


workflow.add_node("update_retry", node_update_retry)
workflow.add_conditional_edges(
    "update_retry",
    node_retry_dispatcher,
    {"reason": "reason", "classify": "classify"},
)

# Compile
app = workflow.compile()