Spaces:
Running
Running
File size: 5,701 Bytes
0cd3dc5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 | import time
from services.query_decomposition import QueryDecomposerService
from services.graph_planner import GraphPlannerService
from services.hybrid_retrieval import HybridRetrievalService
from services.context_organization import ContextOrganizationService
from services.legal_reasoner import LegalReasonerService
from services.quality_assurance import QualityAssuranceService
from services.neo4j import get_neo4j_driver
from core.search import SearchService
import json
class ConstitutionWorkflow:
    """Multi-agent question-answering pipeline over constitution data.

    Chains six services, recorded in the execution trace as "agents":
    classification (query decomposition), graph planning with Neo4j
    execution, hybrid retrieval, context organization, legal reasoning and
    quality assurance. Retrieval through QA may repeat when QA asks for a
    refinement (see ``run_flow``).

    The instance holds a Neo4j driver for its whole lifetime. Call
    ``close()`` when finished, or use the workflow as a context manager::

        with ConstitutionWorkflow() as wf:
            result = wf.run_flow("...")
    """

    # Keywords that cause a planner-generated Cypher query to be skipped.
    # Cypher keywords are case-insensitive, so these are matched as whole
    # words ignoring case ("detach delete" is as destructive as "DETACH DELETE").
    BLOCKED_CYPHER_KEYWORDS = ("DELETE", "DETACH")
    # Vector-search candidate counts: first attempt vs. retries (retries
    # broaden the candidate pool).
    INITIAL_SEARCH_LIMIT = 20
    RETRY_SEARCH_LIMIT = 40
    # Number of raw vector hits used when reranking yields no usable chunks.
    FALLBACK_CHUNK_COUNT = 10

    def __init__(self):
        """Instantiate every pipeline service and obtain a Neo4j driver."""
        self.decomposer = QueryDecomposerService()
        self.planner = GraphPlannerService()
        self.retriever_coordinator = HybridRetrievalService()
        self.context_organizer = ContextOrganizationService()
        self.reasoner = LegalReasonerService()
        self.qa_agent = QualityAssuranceService()
        self.neo4j = get_neo4j_driver()
        self.vector_search = SearchService()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Always release the driver; never suppress an in-flight exception.
        self.close()
        return False

    @classmethod
    def _is_blocked_query(cls, cypher):
        """Return True if ``cypher`` contains a blocked keyword as a whole word, ignoring case."""
        import re

        if not cls.BLOCKED_CYPHER_KEYWORDS:
            return False
        pattern = r"\b(?:" + "|".join(re.escape(k) for k in cls.BLOCKED_CYPHER_KEYWORDS) + r")\b"
        return re.search(pattern, cypher, re.IGNORECASE) is not None

    def execute_neo4j(self, queries):
        """Run planner-generated Cypher queries and collect their records.

        All queries share one session and are executed with ``session.run``.
        A query that fails is reported and skipped, so the remaining queries
        still contribute results; its partial records are discarded. Queries
        containing a keyword from ``BLOCKED_CYPHER_KEYWORDS`` (whole word,
        any case) and empty or non-string entries are skipped without running.

        Args:
            queries: Iterable of Cypher query strings. ``None`` or empty
                yields ``[]``.

        Returns:
            list: ``record.data()`` for each record of each successful query,
            in query order. If the session itself cannot be opened, the error
            is printed and whatever was collected so far is returned.
        """
        results = []
        if not queries:
            return results
        try:
            with self.neo4j.session() as session:
                for q in queries:
                    if not isinstance(q, str) or not q.strip():
                        continue
                    if self._is_blocked_query(q):
                        print(f"DEBUG: Skipped blocked Cypher query: {q}")
                        continue
                    try:
                        # Materialize first so a mid-stream failure doesn't
                        # leave half a query's records in the results.
                        records = [r.data() for r in session.run(q)]
                    except Exception as e:
                        print(f"Neo4j Execution Error: {e}")
                        continue
                    results.extend(records)
        except Exception as e:
            print(f"Neo4j Execution Error: {e}")
        return results

    @staticmethod
    def _build_search_query(query, critique, retry_count):
        """Return the vector-search text for the current attempt.

        The first attempt (``retry_count <= 0``) always uses ``query``
        unchanged. On a retry, if the QA critique carries a truthy
        ``suggested_retrieval`` and lists ``missing_entities["amendments"]``,
        those amendments are appended (space-joined when given as a
        list/tuple/set) to bias the vector search toward them.
        """
        if retry_count <= 0 or not critique.get("suggested_retrieval"):
            return query
        missing = critique.get("missing_entities") or {}
        if not isinstance(missing, dict):
            return query
        amendments = missing.get("amendments")
        if not amendments:
            return query
        if isinstance(amendments, (list, tuple, set)):
            # Avoid injecting a Python list repr like "['42nd']" into the search text.
            amendments = " ".join(str(a) for a in amendments)
        return f"{query} amendment {amendments}"

    def run_flow(self, query: str, max_retries: int = 1) -> dict:
        """Answer ``query`` by running the full multi-agent pipeline.

        Classification and graph planning (plus Neo4j execution of the
        planned Cypher) run once. The retrieve -> organize -> reason -> QA
        cycle then runs up to ``max_retries + 1`` times, repeating only while
        QA grades the answer ``"REFINE"``. Any other grade (``"PASS"`` or
        anything else) ends the loop and the latest answer is returned.

        Args:
            query: The user's question.
            max_retries: Extra refinement cycles allowed after the first
                attempt. Negative values are treated as 0.

        Returns:
            dict with keys ``query``, ``answer``, ``constitutional_status``,
            ``confidence``, ``sources``, ``quality_grade``,
            ``execution_trace`` (list of stage labels) and ``time_taken``
            (elapsed seconds, measured with ``time.perf_counter``).
        """
        start_time = time.perf_counter()
        state = {
            "query": query,
            "trace": [],
            "graph_results": [],
            "retrieval_plan": {},
            "retrieved_chunks": [],
            "context_data": {},
            "answer": {},
            "critique": {},
            "retry_count": 0,
            "max_retries": max(0, max_retries),
        }

        # 1. CLASSIFY (run once)
        state["trace"].append("Agent 1: Classification")
        classification = self.decomposer.decompose(query)
        print(f"DEBUG: Intent={(classification or {}).get('intent')}")

        # 2. PLAN (run once) and execute the planned graph queries
        state["trace"].append("Agent 2: Graph Planning")
        plan = self.planner.generate_plan(classification) or {}
        state["graph_results"] = self.execute_neo4j(plan.get("cypher_queries") or [])
        print(f"DEBUG: Found {len(state['graph_results'])} graph entities")

        while state["retry_count"] <= state["max_retries"]:
            attempt = state["retry_count"]

            # 3. RETRIEVE
            state["trace"].append(f"Agent 3: Hybrid Retrieval (Attempt {attempt + 1})")
            # On retries, QA feedback (missing amendments) biases the vector search.
            search_query = self._build_search_query(query, state["critique"], attempt)
            limit = self.INITIAL_SEARCH_LIMIT if attempt == 0 else self.RETRY_SEARCH_LIMIT
            raw_chunks = self.vector_search.search(search_query, limit=limit) or []
            # Rerank against the user's original question, not the biased search text.
            retrieval_output = self.retriever_coordinator.coordinate_retrieval(
                query,
                state["graph_results"],
                raw_chunks,
            ) or {}
            state["retrieved_chunks"] = retrieval_output.get("final_selected_chunks") or []

            # 4. ORGANIZE
            state["trace"].append("Agent 4: Context Organization")
            # Unwrap the raw chunk dicts from the rerank result wrappers.
            chunk_data = [c["full_chunk"] for c in state["retrieved_chunks"] if "full_chunk" in c]
            if not chunk_data:
                # Reranking kept nothing usable: fall back to the top vector hits.
                chunk_data = raw_chunks[:self.FALLBACK_CHUNK_COUNT]
            state["context_data"] = self.context_organizer.organize_context(
                query,
                chunk_data,
                state["graph_results"],
            )

            # 5. REASON
            state["trace"].append("Agent 5: Legal Reasoning")
            state["answer"] = self.reasoner.generate_answer(query, state["context_data"]) or {}

            # 6. CRITIQUE
            state["trace"].append("Agent 6: Quality Assurance")
            state["critique"] = self.qa_agent.validate_answer(
                query, state["answer"], state["graph_results"]
            ) or {}

            if state["critique"].get("quality_grade") != "REFINE":
                # PASS ends the loop; review/unknown grades are accepted as-is.
                break
            print(f"DEBUG: QA requested refinement. confidence={state['critique'].get('final_confidence')}")
            state["retry_count"] += 1

        return {
            "query": query,
            "answer": state["answer"].get("answer"),
            "constitutional_status": state["answer"].get("constitutional_status"),
            "confidence": state["critique"].get("final_confidence"),
            "sources": state["answer"].get("sources"),
            "quality_grade": state["critique"].get("quality_grade"),
            "execution_trace": state["trace"],
            "time_taken": time.perf_counter() - start_time,
        }

    def close(self):
        """Close the Neo4j driver obtained in ``__init__``."""
        self.neo4j.close()
|