Spaces:
Running
Running
import json
import operator
import re
import time
from typing import TypedDict, Annotated, List, Dict, Any, Union

from langgraph.graph import StateGraph, END

# Import our existing services (Reusing logic!)
from services.query_decomposition import QueryDecomposerService
from services.graph_planner import GraphPlannerService
from services.hybrid_retrieval import HybridRetrievalService
from services.context_organization import ContextOrganizationService
from services.legal_reasoner import LegalReasonerService
from services.quality_assurance import QualityAssuranceService
from services.neo4j import get_neo4j_driver
from core.search import SearchService
# 1. Start with State Definition
class AgentState(TypedDict):
    """Shared state dict passed between LangGraph nodes for one query run."""
    # Inputs
    query: str
    # Internal State
    intent: str
    classification: Dict[str, Any]
    graph_plan: Dict[str, Any]
    qdrant_scope: Dict[str, Any]  # New field for Tree Routing scope
    graph_results: List[Dict[str, Any]]
    retrieved_chunks: List[Dict[str, Any]]
    context_data: Dict[str, Any]
    draft_answer: Dict[str, Any]
    critique: Dict[str, Any]
    # Metadata
    retry_count: int
    # operator.add makes "trace" an accumulating channel: parallel branches
    # append their entries instead of overwriting each other.
    trace: Annotated[List[str], operator.add]
    raw_chunks: List[Dict[str, Any]]
# 2. Define Services Wrapper (Singleton access would be better, but init here is fine for now)
# Ideally these should be initialized once outside and passed in, but we'll init inside the nodes or globally.
# For thread safety in a real app, these should be global singletons.
class ServiceContainer:
    """Bundles every pipeline service behind one object so all nodes share instances."""

    def __init__(self):
        self.decomposer = QueryDecomposerService()      # query classification / entity extraction
        self.planner = GraphPlannerService()            # turns a classification into a Cypher plan
        self.retriever = HybridRetrievalService()       # fuses + reranks graph and vector results
        self.organizer = ContextOrganizationService()   # builds the context block for the LLM
        self.reasoner = LegalReasonerService()          # drafts the answer
        self.qa = QualityAssuranceService()             # critiques the draft answer
        self.neo4j = get_neo4j_driver()                 # shared Neo4j driver handle
        self.vector_search = SearchService()            # scoped vector search backend
# Global container instance
# NOTE(review): instantiated at import time, so every service (including the
# Neo4j driver) is constructed as a side effect of importing this module.
services = ServiceContainer()
# 3. Define Nodes
def node_classify(state: AgentState) -> AgentState:
    """Decompose the query and derive the Qdrant routing scope from its entities.

    Returns a partial state update: classification, qdrant_scope, trace.
    """
    print("--- Node: Classify ---")
    classification = services.decomposer.decompose(state["query"])
    print(f"DEBUG: Classification: {json.dumps(classification, indent=2)}")
    entities = classification.get("entities", {})
    scope = {
        "article_numbers": entities.get("articles", []),
        "amendment_ids": entities.get("amendments", []),
        "include_amendments": True,
    }
    return {
        "classification": classification,
        "qdrant_scope": scope,
        "trace": ["Classify"],
    }
def _build_fallback_queries(classification):
    """Build parameterized direct-lookup queries from the classified entities.

    Returns a list of (cypher, params) tuples. Query parameters are used
    instead of f-string interpolation so entity values coming from the LLM
    classification cannot inject Cypher.
    """
    entities = classification.get("entities", {})
    queries = []
    articles = entities.get("articles", [])
    if isinstance(articles, str):
        articles = [articles]
    for art in articles:
        queries.append((
            """
            MATCH (am:Amendment)-[r]->(a:Article {number: $art})
            RETURN am.number as amendment, r.details as modification, type(r) as relationship, $art as target_id
            LIMIT 50
            """,
            {"art": str(art)},
        ))
    amendments = entities.get("amendments", [])
    # Amendments are only used when no article was identified (same rule as before).
    if not articles and amendments:
        for am in amendments:
            digits = re.sub(r"\D", "", str(am))  # normalize e.g. "18th" -> "18"
            if digits:
                queries.append((
                    """
                    MATCH (am:Amendment {number: $num})-[r]->(target)
                    RETURN am.number as amendment, type(r) as relationship, r.details as modification, COALESCE(target.number, target.id) as target_id
                    LIMIT 50
                    """,
                    # Amendment numbers are stored as integers (the original
                    # query interpolated them unquoted), so bind an int.
                    {"num": int(digits)},
                ))
    return queries


def node_plan_graph(state: AgentState) -> AgentState:
    """Generate a Cypher plan, execute it, and eagerly fall back to direct
    entity lookups in the same step when the planned queries return nothing.

    Returns a partial state update: graph_plan, graph_results, trace.
    """
    print("--- Node: Graph Tree Router ---")
    # If retrying, ask planner for Broader Scope
    retry_count = state.get("retry_count", 0)
    classification = state["classification"]
    if retry_count > 0:
        print(f"๐ Retry Mode ({retry_count}): Broadening Graph Search Scope")
        # Add a flag to classification to trigger broader search
        classification["broad_search"] = True
    plan = services.planner.generate_plan(classification)

    # Case-insensitive, word-boundary guard. The original check
    # ('"DELETE" in q') was case-sensitive and could be bypassed by
    # lowercase Cypher.
    destructive = re.compile(r"\b(DELETE|DETACH)\b", re.IGNORECASE)

    # Execute Neo4j Traversal Helper
    def execute_queries(query_list):
        """Run each query (a string, or a (cypher, params) tuple) and merge rows.

        Destructive queries are skipped; per-query errors are logged and
        swallowed so one bad query doesn't abort the whole batch.
        """
        res_list = []
        with services.neo4j.session() as session:
            for idx, item in enumerate(query_list, 1):
                q, params = item if isinstance(item, tuple) else (item, {})
                if destructive.search(q):
                    continue
                try:
                    res = session.run(q, params)
                    res_list.extend(r.data() for r in res)
                except Exception as e:
                    print(f"Neo4j Error on Query {idx}: {e}")
        return res_list

    queries = plan.get("cypher_queries", [])
    print(f"DEBUG: Generated Cypher Queries ({len(queries)}): {json.dumps(queries, indent=2)}")
    print(f"DEBUG: About to execute {len(queries)} explicit queries")
    results = execute_queries(queries)

    # EAGER FALLBACK: If explicit queries fail, run fallbacks immediately in the same step
    if not results:
        print("DEBUG: 0 results from initial queries. Running eager fallbacks.")
        fallback_queries = _build_fallback_queries(classification)
        if fallback_queries:
            print(f"DEBUG: Executing {len(fallback_queries)} fallback queries")
            results.extend(execute_queries(fallback_queries))

    print(f"DEBUG: Total Graph Results: {len(results)}")
    if results:
        print(f"DEBUG: Sample Result: {json.dumps(results, indent=2)}")
    return {
        "graph_plan": plan,
        "graph_results": results,
        "trace": ["GraphPlan"],
    }
def node_fetch_vector(state: AgentState) -> AgentState:
    """Run the scoped vector search for the query, widening the limit on retries.

    Returns a partial state update: raw_chunks, trace.
    """
    print("--- Node: Fetch Vector (Parallel) ---")
    query = state["query"]
    scope = state.get("qdrant_scope", {})
    base_limit = state.get("graph_plan", {}).get("expected_chunks", 5)
    # Retries fetch 5 extra chunks to broaden coverage.
    limit = base_limit + 5 if state.get("retry_count", 0) > 0 else base_limit
    print(f"Executing Scoped Search with Scope: {scope}")
    chunks = services.vector_search.search(query, scope=scope, limit=limit)
    print(f"DEBUG: Vector Search fetched {len(chunks)} chunks")
    return {
        "raw_chunks": chunks,
        "trace": ["FetchVector"],
    }
def node_coordinate(state: AgentState) -> AgentState:
    """Fuse graph and vector hits and keep the reranked final selection.

    Returns a partial state update: retrieved_chunks, trace.
    """
    print("--- Node: Coordinate (Reranking) ---")
    graph_hits = state.get("graph_results", [])
    vector_hits = state.get("raw_chunks", [])
    output = services.retriever.coordinate_retrieval(state["query"], graph_hits, vector_hits)
    selected = output.get("final_selected_chunks", [])
    print(f"DEBUG: Final Selected Chunks: {len(selected)}")
    return {
        "retrieved_chunks": selected,
        "trace": ["Coordinate"],
    }
def node_organize(state: AgentState) -> AgentState:
    """Arrange the selected chunks plus graph data into an LLM-ready context.

    Returns a partial state update: context_data, trace.
    """
    print("--- Node: Organize ---")
    # Pass full chunks with metadata (year, amendment_number) for temporal organization
    context = services.organizer.organize_context(
        state["query"],
        state.get("retrieved_chunks", []),
        graph_data=state.get("graph_results", []),
    )
    print(f"DEBUG: Organized Context Keys: {list(context.keys())}")
    if "context_block" in context:
        print(f"DEBUG: Context Block Preview (first 500 chars):\n{context['context_block'][:500]}")
    return {
        "context_data": context,
        "trace": ["Organize"],
    }
def node_reason(state: AgentState) -> AgentState:
    """Draft an answer to the query from the organized context.

    Returns a partial state update: draft_answer, trace.
    """
    print("--- Node: Reason ---")
    draft = services.reasoner.generate_answer(state["query"], state["context_data"])
    # Debug answer brief
    print(f"DEBUG: Draft Answer Preview: {json.dumps(draft)[:200]}...")
    return {
        "draft_answer": draft,
        "trace": ["Reason"],
    }
def node_validate(state: AgentState) -> AgentState:
    """Critique the draft answer via the QA service, using graph and vector context.

    Returns a partial state update: critique, trace.
    """
    print("--- Node: Validate ---")
    # Log graph data completeness for debugging
    graph_results = state.get("graph_results", [])
    graph_plan = state.get("graph_plan", {})
    cypher_queries = graph_plan.get("cypher_queries", [])
    context_data = state.get("context_data", {})
    retry_count = state.get("retry_count", 0)
    # Summarize which relationship patterns the graph results cover.
    markers = (("target_id", "Direct"), ("via_article", "Multi-hop"), ("related_article", "Related"))
    found_types = [label for key, label in markers if any(key in row for row in graph_results)]
    patterns = ", ".join(found_types) if found_types else "no"
    print(f"DEBUG: Graph context contains {len(graph_results)} nodes across {patterns} relationship patterns.")
    # We delegate full judgment of context completeness to the QA LLM instead of a hard-coded check.
    # Run LLM-based validation with BOTH graph and vector data
    critique = services.qa.validate_answer(
        state["query"],
        state["draft_answer"],
        graph_results,
        context_data,
    )
    print(f"DEBUG: Critique: {json.dumps(critique, indent=2)}")
    return {
        "critique": critique,
        "trace": ["Validate"],
    }
# 4. Define Conditional Logic
def should_continue(state: AgentState) -> str:
    """Pick the next edge after validation: retry retrieval, retry reasoning, or end.

    Routing rules:
      - After 2 retries, always end.
      - REFINE grade: follow the critique's retry_type ("reason" reruns only
        the reasoner; anything else, including the default, reruns retrieval).
      - REVIEW grade: rerun the reasoner, but only on the first attempt.
      - Otherwise end.
    """
    critique = state.get("critique", {})
    attempts = state.get("retry_count", 0)
    # Max 2 retries total
    if attempts >= 2:
        return "end"
    grade = critique.get("quality_grade")
    if grade == "REFINE":
        # Check QA's suggested retry type (defaults to a fresh retrieval)
        if critique.get("retry_type", "retrieve") == "reason":
            # Data is good, but answer missed it - just re-run reasoner
            print("๐ Retry: Re-running reasoner with same data")
            return "retry_reason"
        # Data incomplete - need new retrieval
        print("๐ Retry: Re-running retrieval")
        return "retry_retrieve"
    if grade == "REVIEW" and attempts < 1:
        # For REVIEW grade, try re-reasoning once
        print("๐ Retry: Re-running reasoner for quality improvement")
        return "retry_reason"
    return "end"
# 5. Build Graph
workflow = StateGraph(AgentState)
workflow.add_node("classify", node_classify)
workflow.add_node("graph_plan", node_plan_graph)
workflow.add_node("fetch_vector", node_fetch_vector)
workflow.add_node("coordinate", node_coordinate)
workflow.add_node("organize", node_organize)
workflow.add_node("reason", node_reason)
workflow.add_node("validate", node_validate)
# Edges
workflow.set_entry_point("classify")
# Parallel Fan-Out
workflow.add_edge("classify", "graph_plan")
workflow.add_edge("classify", "fetch_vector")
# Fan-In to Coordinate
workflow.add_edge("graph_plan", "coordinate")
workflow.add_edge("fetch_vector", "coordinate")
workflow.add_edge("coordinate", "organize")
workflow.add_edge("organize", "reason")
workflow.add_edge("reason", "validate")
# After validation, should_continue returns one of the three labels below;
# the "update_retry" node (registered further down) increments the counter
# before re-entering the graph.
workflow.add_conditional_edges(
    "validate",
    should_continue,
    {
        "retry_retrieve": "update_retry",
        "retry_reason": "update_retry",  # Route both through update_retry to prevent infinite loops
        "end": END
    }
)
def node_retry_dispatcher(state: AgentState) -> str:
    """Route a retry either back to the reasoner or to a full re-classification."""
    critique = state.get("critique", {})
    if critique.get("retry_type") == "reason":
        return "reason"
    return "classify"
# Bump the retry counter before looping back. Use .get with a default, as the
# rest of the file does: the original x["retry_count"] raised KeyError whenever
# retry_count was absent from the input state.
workflow.add_node("update_retry", lambda x: {"retry_count": x.get("retry_count", 0) + 1})
workflow.add_conditional_edges("update_retry", node_retry_dispatcher, {"reason": "reason", "classify": "classify"})
# Compile
app = workflow.compile()