Spaces:
Running
Running
| import time | |
| from services.query_decomposition import QueryDecomposerService | |
| from services.graph_planner import GraphPlannerService | |
| from services.hybrid_retrieval import HybridRetrievalService | |
| from services.context_organization import ContextOrganizationService | |
| from services.legal_reasoner import LegalReasonerService | |
| from services.quality_assurance import QualityAssuranceService | |
| from services.neo4j import get_neo4j_driver | |
| from core.search import SearchService | |
| import json | |
class ConstitutionWorkflow:
    """Multi-agent pipeline that answers a constitutional-law query.

    Stages, in order:
      1. classification (query decomposition),
      2. graph planning plus execution of the planned Cypher against Neo4j,
      3. hybrid retrieval (vector search followed by coordination/reranking),
      4. context organization,
      5. legal reasoning,
      6. quality assurance.

    Stages 3-6 are repeated when QA grades the answer "REFINE", up to
    ``max_retries`` extra attempts.
    """

    def __init__(self):
        """Instantiate every pipeline service, the Neo4j driver and vector search."""
        self.decomposer = QueryDecomposerService()
        self.planner = GraphPlannerService()
        self.retriever_coordinator = HybridRetrievalService()
        self.context_organizer = ContextOrganizationService()
        self.reasoner = LegalReasonerService()
        self.qa_agent = QualityAssuranceService()
        self.neo4j = get_neo4j_driver()
        self.vector_search = SearchService()

    @staticmethod
    def _is_destructive_query(query):
        """Return True if ``query`` contains a DELETE or DETACH keyword.

        Cypher keywords are case-insensitive, so ``detach delete n`` must be
        caught just like ``DETACH DELETE n``. Word boundaries keep literals or
        identifiers such as ``'DELETED'`` from being rejected by accident.

        NOTE(review): this is a denylist. Write clauses such as CREATE, MERGE,
        SET and REMOVE are not filtered; running planner queries in a
        read-only session would be a stronger guard.
        """
        import re

        return re.search(r"\b(?:DELETE|DETACH)\b", query, re.IGNORECASE) is not None

    def execute_neo4j(self, queries):
        """Run planner-generated Cypher queries and collect their records.

        Queries flagged by ``_is_destructive_query`` are skipped. Each query
        is executed independently: a failure is printed and the remaining
        queries still run, instead of one bad statement discarding every
        later result.

        Args:
            queries: Iterable of Cypher query strings; ``None`` is treated as
                an empty iterable.

        Returns:
            Flat list of ``record.data()`` values from every query that
            succeeded, in query order. Errors are printed, never raised.
        """
        results = []
        try:
            with self.neo4j.session() as session:
                for q in queries or []:
                    try:
                        if self._is_destructive_query(q):
                            continue
                        # Materialise inside the try so a query that fails while
                        # its records are being read contributes nothing partial.
                        records = [r.data() for r in session.run(q)]
                    except Exception as e:
                        print(f"Neo4j Execution Error: {e}")
                        continue
                    results.extend(records)
        except Exception as e:
            # Raised while opening or closing the session itself.
            print(f"Neo4j Execution Error: {e}")
        return results

    @staticmethod
    def _build_search_query(query, critique, retry_count):
        """Return the vector-search text for the current attempt.

        The first attempt uses ``query`` unchanged. On a retry whose QA critique
        carries a truthy ``suggested_retrieval``, any
        ``missing_entities['amendments']`` hint is appended to bias the vector
        search toward those amendments. The Cypher suggestion itself is not
        executed.

        Args:
            query: The user's original question.
            critique: QA critique dict from the previous attempt (may be empty).
            retry_count: Zero-based attempt counter.

        Returns:
            The string to pass to the vector search.
        """
        if retry_count <= 0 or not critique.get("suggested_retrieval"):
            return query
        missing = critique.get("missing_entities") or {}
        amendments = missing.get("amendments")
        if not amendments:
            return query
        # The QA payload shape is not visible here; join list-like hints so the
        # search text does not contain a Python list repr such as "['14th']".
        if isinstance(amendments, (list, tuple)):
            amendments = " ".join(str(a) for a in amendments)
        return f"{query} amendment {amendments}"

    def run_flow(self, query: str) -> dict:
        """Answer ``query`` by running the full agent pipeline.

        Classification and graph planning run once. Retrieval, context
        organization, reasoning and QA then run in a loop. A QA grade of
        "REFINE" triggers another attempt (at most ``max_retries`` extra
        attempts, with a wider vector-search limit). "PASS" or any other grade
        ends the loop.

        Args:
            query: The user's question.

        Returns:
            Dict with the following keys:
              * ``query`` (echoed),
              * ``answer``, ``constitutional_status`` and ``sources``
                (from the reasoner),
              * ``confidence`` and ``quality_grade`` (from the last QA
                critique),
              * ``execution_trace`` (the list of stage labels),
              * ``time_taken`` (elapsed seconds).
        """
        # perf_counter is monotonic, so the duration is immune to clock changes.
        start_time = time.perf_counter()

        # State initialization
        state = {
            "query": query,
            "trace": [],
            "graph_results": [],
            "retrieval_plan": {},
            "retrieved_chunks": [],
            "context_data": {},
            "answer": {},
            "critique": {},
            "retry_count": 0,
            "max_retries": 1,
        }

        # 1. CLASSIFY (run once)
        state["trace"].append("Agent 1: Classification")
        classification = self.decomposer.decompose(query)
        print(f"DEBUG: Intent={classification.get('intent')}")

        # 2. PLAN (run once)
        state["trace"].append("Agent 2: Graph Planning")
        plan = self.planner.generate_plan(classification)
        state["graph_results"] = self.execute_neo4j(plan.get("cypher_queries", []))
        print(f"DEBUG: Found {len(state['graph_results'])} graph entities")

        while state["retry_count"] <= state["max_retries"]:
            # 3. RETRIEVE
            state["trace"].append(f"Agent 3: Hybrid Retrieval (Attempt {state['retry_count']+1})")
            search_query = self._build_search_query(
                query, state["critique"], state["retry_count"]
            )
            # Broaden the candidate pool on retries.
            limit = 20 if state["retry_count"] == 0 else 40
            raw_chunks = self.vector_search.search(search_query, limit=limit)

            # Coordinate & rerank against the original question.
            retrieval_output = self.retriever_coordinator.coordinate_retrieval(
                query,
                state["graph_results"],
                raw_chunks,
            )
            state["retrieved_chunks"] = retrieval_output.get("final_selected_chunks", [])

            # 4. ORGANIZE
            state["trace"].append("Agent 4: Context Organization")
            # Unwrap the raw chunk dicts from the rerank result wrappers.
            chunk_data = [c["full_chunk"] for c in state["retrieved_chunks"] if "full_chunk" in c]
            if not chunk_data:
                chunk_data = raw_chunks[:10]  # Fallback: top unreranked hits
            state["context_data"] = self.context_organizer.organize_context(
                query,
                chunk_data,
                state["graph_results"],
            )

            # 5. REASON
            state["trace"].append("Agent 5: Legal Reasoning")
            state["answer"] = self.reasoner.generate_answer(query, state["context_data"])

            # 6. CRITIQUE
            state["trace"].append("Agent 6: Quality Assurance")
            state["critique"] = self.qa_agent.validate_answer(
                query, state["answer"], state["graph_results"]
            )

            grade = state["critique"].get("quality_grade")
            if grade == "PASS":
                break
            if grade == "REFINE":
                print(f"DEBUG: QA requested refinement. confidence={state['critique'].get('final_confidence')}")
                state["retry_count"] += 1
            else:
                # Review or unknown grade: accept the current answer as-is.
                break

        return {
            "query": query,
            "answer": state["answer"].get("answer"),
            "constitutional_status": state["answer"].get("constitutional_status"),
            "confidence": state["critique"].get("final_confidence"),
            "sources": state["answer"].get("sources"),
            "quality_grade": state["critique"].get("quality_grade"),
            "execution_trace": state["trace"],
            "time_taken": time.perf_counter() - start_time,
        }

    def close(self):
        """Close the Neo4j driver created in ``__init__``."""
        self.neo4j.close()