import json
import logging
import re
import time

from services.query_decomposition import QueryDecomposerService
from services.graph_planner import GraphPlannerService
from services.hybrid_retrieval import HybridRetrievalService
from services.context_organization import ContextOrganizationService
from services.legal_reasoner import LegalReasonerService
from services.quality_assurance import QualityAssuranceService
from services.neo4j import get_neo4j_driver
from core.search import SearchService

logger = logging.getLogger(__name__)

# Cypher clauses that mutate graph data or schema. Cypher keywords are
# case-insensitive, so the scan must be too (a plain `"DELETE" in q` check
# lets `detach delete` through).
_WRITE_CLAUSE_RE = re.compile(
    r"\b(?:CREATE|MERGE|DELETE|DETACH|SET|REMOVE|DROP|FOREACH)\b",
    re.IGNORECASE,
)
# Quoted strings and backtick-quoted identifiers are blanked before the
# keyword scan so that search terms such as CONTAINS 'create' do not cause
# a read query to be rejected.
_QUOTED_RE = re.compile(
    r"'(?:[^'\\]|\\.)*'"      # single-quoted string literal
    r'|"(?:[^"\\]|\\.)*"'     # double-quoted string literal
    r"|`[^`]*`"               # backtick-quoted identifier
)


def _is_write_query(query: str) -> bool:
    """Return True if *query* contains a Cypher write clause outside quotes.

    This is a best-effort keyword guard for generated Cypher. False positives
    only cause a query to be skipped, which is the safe direction.
    """
    return bool(_WRITE_CLAUSE_RE.search(_QUOTED_RE.sub("''", query)))


class ConstitutionWorkflow:
    """Multi-agent pipeline answering questions over a constitution corpus.

    Chains six services: query classification, graph (Cypher) planning,
    hybrid retrieval, context organization, legal reasoning and quality
    assurance, with an optional refine loop driven by the QA grade.
    Can be used as a context manager so the Neo4j driver is always closed.
    """

    def __init__(self):
        self.decomposer = QueryDecomposerService()
        self.planner = GraphPlannerService()
        self.retriever_coordinator = HybridRetrievalService()
        self.context_organizer = ContextOrganizationService()
        self.reasoner = LegalReasonerService()
        self.qa_agent = QualityAssuranceService()
        self.neo4j = get_neo4j_driver()
        self.vector_search = SearchService()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never suppress exceptions

    def execute_neo4j(self, queries) -> list:
        """Run each Cypher query in one session and collect all records.

        Args:
            queries: Iterable of Cypher query strings (generated by the graph
                planner). Non-string or blank entries are ignored.

        Returns:
            A flat list of ``record.data()`` dicts from every query that ran.
            Queries containing write clauses are skipped; a query that fails
            is logged and skipped without discarding the other results.

        NOTE(review): the keyword filter is a best-effort guard. Opening the
        session in read access mode (driver-version dependent) would enforce
        read-only execution server-side — TODO confirm driver version.
        """
        results = []
        if not queries:
            return results
        try:
            with self.neo4j.session() as session:
                for q in queries:
                    if not isinstance(q, str) or not q.strip():
                        continue
                    if _is_write_query(q):
                        logger.warning("Skipping Cypher query with write clause: %s", q)
                        continue
                    # Per-query handling: one bad generated query must not
                    # throw away results already collected or still to come.
                    try:
                        res = session.run(q)
                        results.extend(r.data() for r in res)
                    except Exception:
                        logger.exception("Neo4j Execution Error for query: %s", q)
        except Exception:
            # Session could not be opened/closed (e.g. database unreachable).
            logger.exception("Neo4j Execution Error")
        return results

    @staticmethod
    def _augment_search_query(query: str, critique: dict) -> str:
        """Return *query* biased toward amendments the QA agent reported missing."""
        missing = critique.get("missing_entities") or {}
        if not isinstance(missing, dict):
            return query
        amendments = missing.get("amendments")
        if not amendments:
            return query
        # A list would otherwise be interpolated as its Python repr
        # ("['42nd', '44th']"), injecting brackets and quotes into the query.
        if isinstance(amendments, (list, tuple)):
            amendments = " ".join(str(a) for a in amendments)
        return f"{query} amendment {amendments}"

    def run_flow(self, query: str, *, max_retries: int = 1) -> dict:
        """Answer *query* by running the six-agent pipeline.

        Classification and graph planning run once. Retrieval, context
        organization, reasoning and QA run in a loop that repeats while QA
        grades the answer ``"REFINE"``, for at most *max_retries* extra passes.

        Args:
            query: The user's question.
            max_retries: Additional retrieve/reason/critique passes allowed
                after the first one (default 1, i.e. at most two attempts).
                Negative values are treated as 0.

        Returns:
            dict with keys ``query``, ``answer``, ``constitutional_status``,
            ``confidence``, ``sources``, ``quality_grade``,
            ``execution_trace`` and ``time_taken`` (elapsed seconds).
        """
        # perf_counter is monotonic; time.time() can jump with clock changes.
        start_time = time.perf_counter()

        # State Initialization
        state = {
            "query": query,
            "trace": [],
            "graph_results": [],
            "retrieval_plan": {},
            "retrieved_chunks": [],
            "context_data": {},
            "answer": {},
            "critique": {},
            "retry_count": 0,
            "max_retries": max(0, max_retries),
        }

        # 1. CLASSIFY (Run once)
        state["trace"].append("Agent 1: Classification")
        classification = self.decomposer.decompose(query)
        logger.debug("Intent=%s", classification.get("intent") if classification else None)

        # 2. PLAN (Run once)
        state["trace"].append("Agent 2: Graph Planning")
        plan = self.planner.generate_plan(classification) or {}
        state["graph_results"] = self.execute_neo4j(plan.get("cypher_queries") or [])
        logger.debug("Found %d graph entities", len(state["graph_results"]))

        while state["retry_count"] <= state["max_retries"]:
            # 3. RETRIEVE
            state["trace"].append(f"Agent 3: Hybrid Retrieval (Attempt {state['retry_count']+1})")

            search_query = query
            if state["retry_count"] > 0 and state["critique"].get("suggested_retrieval"):
                # The critique's Cypher suggestion is not executed yet; instead
                # bias the vector search toward entities QA reported missing.
                search_query = self._augment_search_query(query, state["critique"])

            # Fetch candidates (broaden search if retrying)
            limit = 20 if state["retry_count"] == 0 else 40
            raw_chunks = self.vector_search.search(search_query, limit=limit) or []

            # Coordinate & Rerank (against the original query, not the biased one)
            retrieval_output = self.retriever_coordinator.coordinate_retrieval(
                query, state["graph_results"], raw_chunks
            ) or {}
            state["retrieved_chunks"] = retrieval_output.get("final_selected_chunks") or []

            # 4. ORGANIZE
            state["trace"].append("Agent 4: Context Organization")
            # Extract the raw chunk dicts from the rerank result wrapper
            chunk_data = [c["full_chunk"] for c in state["retrieved_chunks"] if "full_chunk" in c]
            if not chunk_data:
                # Reranker produced nothing usable: fall back to top vector hits.
                chunk_data = raw_chunks[:10]
            state["context_data"] = self.context_organizer.organize_context(
                query, chunk_data, state["graph_results"]
            )

            # 5. REASON
            state["trace"].append("Agent 5: Legal Reasoning")
            state["answer"] = self.reasoner.generate_answer(query, state["context_data"]) or {}

            # 6. CRITIQUE
            state["trace"].append("Agent 6: Quality Assurance")
            state["critique"] = self.qa_agent.validate_answer(
                query, state["answer"], state["graph_results"]
            ) or {}

            if state["critique"].get("quality_grade") == "REFINE":
                logger.debug(
                    "QA requested refinement. confidence=%s",
                    state["critique"].get("final_confidence"),
                )
                state["retry_count"] += 1
                continue
            # PASS, REVIEW or an unknown grade: accept the current answer.
            break

        return {
            "query": query,
            "answer": state["answer"].get("answer"),
            "constitutional_status": state["answer"].get("constitutional_status"),
            "confidence": state["critique"].get("final_confidence"),
            "sources": state["answer"].get("sources"),
            "quality_grade": state["critique"].get("quality_grade"),
            "execution_trace": state["trace"],
            "time_taken": time.perf_counter() - start_time,
        }

    def close(self):
        """Close the Neo4j driver created in ``__init__``."""
        self.neo4j.close()