import time
from services.query_decomposition import QueryDecomposerService
from services.graph_planner import GraphPlannerService
from services.hybrid_retrieval import HybridRetrievalService
from services.context_organization import ContextOrganizationService
from services.legal_reasoner import LegalReasonerService
from services.quality_assurance import QualityAssuranceService
from services.neo4j import get_neo4j_driver
from core.search import SearchService
import json

class ConstitutionWorkflow:
    """Multi-agent pipeline that answers constitutional-law queries.

    Stages: decompose -> graph plan -> hybrid retrieval -> context
    organization -> legal reasoning -> QA critique, with a bounded retry
    loop driven by the QA agent's feedback.
    """

    def __init__(self):
        # One service per pipeline stage (agents 1-6).
        self.decomposer = QueryDecomposerService()
        self.planner = GraphPlannerService()
        self.retriever_coordinator = HybridRetrievalService()
        self.context_organizer = ContextOrganizationService()
        self.reasoner = LegalReasonerService()
        self.qa_agent = QualityAssuranceService()

        # Shared infrastructure handles.
        self.neo4j = get_neo4j_driver()
        self.vector_search = SearchService()

    def execute_neo4j(self, queries):
        """Run read-only Cypher queries and return a flat list of record dicts.

        Destructive statements are skipped. The keyword check is
        case-insensitive because Cypher keywords are case-insensitive
        (the original case-sensitive check let e.g. ``detach delete``
        through). Errors are logged and swallowed so the overall flow
        stays best-effort.
        """
        results = []
        try:
            with self.neo4j.session() as session:
                for q in queries:
                    upper = q.upper()
                    # Safety guard: never execute destructive statements.
                    if "DELETE" in upper or "DETACH" in upper:
                        continue
                    res = session.run(q)
                    results.extend(r.data() for r in res)
        except Exception as e:
            # Best-effort: a graph failure must not abort the whole flow.
            print(f"Neo4j Execution Error: {e}")
        return results

    def run_flow(self, query: str) -> dict:
        """Execute the full agent pipeline for *query*.

        Classification and graph planning run once; retrieval through QA
        run in a loop that retries (up to ``max_retries`` extra attempts)
        when the QA agent grades the answer REFINE.

        Returns a summary dict with the answer, confidence, sources,
        quality grade, execution trace, and wall-clock time taken.
        """
        start_time = time.time()

        # Mutable pipeline state shared across stages/attempts.
        state = {
            "query": query,
            "trace": [],
            "graph_results": [],
            "retrieval_plan": {},
            "retrieved_chunks": [],
            "context_data": {},
            "answer": {},
            "critique": {},
            "retry_count": 0,
            "max_retries": 1
        }

        # 1. CLASSIFY (Run once)
        state["trace"].append("Agent 1: Classification")
        classification = self.decomposer.decompose(query)
        print(f"DEBUG: Intent={classification.get('intent')}")

        # 2. PLAN (Run once)
        state["trace"].append("Agent 2: Graph Planning")
        plan = self.planner.generate_plan(classification)
        state["graph_results"] = self.execute_neo4j(plan.get("cypher_queries", []))
        print(f"DEBUG: Found {len(state['graph_results'])} graph entities")

        while state["retry_count"] <= state["max_retries"]:
            # 3. RETRIEVE
            state["trace"].append(f"Agent 3: Hybrid Retrieval (Attempt {state['retry_count']+1})")

            # On retries, bias the vector search with QA feedback about
            # entities the previous answer missed. A fuller system would
            # execute the QA agent's suggested Cypher directly.
            search_query = query
            if state["retry_count"] > 0 and state["critique"].get("suggested_retrieval"):
                missing = state["critique"].get("missing_entities", {})
                amendments = missing.get("amendments")
                if amendments:
                    # Join list/tuple values so the bias term reads as plain
                    # text rather than a repr like "['14th', '15th']".
                    if isinstance(amendments, (list, tuple)):
                        amendments = " ".join(map(str, amendments))
                    search_query += f" amendment {amendments}"

            # Fetch candidates (broaden search if retrying).
            limit = 20 if state["retry_count"] == 0 else 40
            raw_chunks = self.vector_search.search(search_query, limit=limit)

            # Coordinate & rerank vector hits against graph results.
            retrieval_output = self.retriever_coordinator.coordinate_retrieval(
                query,
                state["graph_results"],
                raw_chunks
            )
            state["retrieved_chunks"] = retrieval_output.get("final_selected_chunks", [])

            # 4. ORGANIZE
            state["trace"].append("Agent 4: Context Organization")
            # Unwrap the raw chunk dicts from the rerank result wrapper;
            # fall back to the top raw hits if reranking selected nothing.
            chunk_data = [c["full_chunk"] for c in state["retrieved_chunks"] if "full_chunk" in c]
            if not chunk_data:
                chunk_data = raw_chunks[:10]

            state["context_data"] = self.context_organizer.organize_context(
                query,
                chunk_data,
                state["graph_results"]
            )

            # 5. REASON
            state["trace"].append("Agent 5: Legal Reasoning")
            state["answer"] = self.reasoner.generate_answer(query, state["context_data"])

            # 6. CRITIQUE
            state["trace"].append("Agent 6: Quality Assurance")
            state["critique"] = self.qa_agent.validate_answer(query, state["answer"], state["graph_results"])

            if state["critique"].get("quality_grade") == "PASS":
                break

            if state["critique"].get("quality_grade") == "REFINE":
                print(f"DEBUG: QA requested refinement. confidence={state['critique'].get('final_confidence')}")
                state["retry_count"] += 1
            else:
                # REVIEW or unknown grade: accept the answer as-is but the
                # grade is surfaced in the return value for the caller.
                break

        return {
            "query": query,
            "answer": state["answer"].get("answer"),
            "constitutional_status": state["answer"].get("constitutional_status"),
            "confidence": state["critique"].get("final_confidence"),
            "sources": state["answer"].get("sources"),
            "quality_grade": state["critique"].get("quality_grade"),
            "execution_trace": state["trace"],
            "time_taken": time.time() - start_time
        }

    def close(self):
        """Release the Neo4j driver's connection pool."""
        self.neo4j.close()