# ConstitutionAgent: core/workflow.py
# Uploaded by Meshyboi ("Upload 53 files", commit 0cd3dc5, verified).
import json
import re
import time

from core.search import SearchService
from services.context_organization import ContextOrganizationService
from services.graph_planner import GraphPlannerService
from services.hybrid_retrieval import HybridRetrievalService
from services.legal_reasoner import LegalReasonerService
from services.neo4j import get_neo4j_driver
from services.quality_assurance import QualityAssuranceService
from services.query_decomposition import QueryDecomposerService
class ConstitutionWorkflow:
    """Orchestrate the six-agent constitutional question-answering pipeline.

    Stages (as recorded in the execution trace): classification, graph
    planning (Cypher run against Neo4j), hybrid retrieval, context
    organization, legal reasoning and quality assurance. The last four run
    in a loop that may retry when QA asks for refinement.

    The Neo4j driver is held for the object's lifetime; call ``close()`` or
    use the instance as a context manager to release it.
    """

    # Cypher clauses / procedure name parts that mutate data or schema.
    # Cypher keywords are case-insensitive, so match whole words ignoring case.
    _WRITE_CLAUSE_RE = re.compile(
        r"\b(?:CREATE|MERGE|SET|DELETE|DETACH|REMOVE|DROP|FOREACH|LOAD\s+CSV)\b",
        re.IGNORECASE,
    )
    # String literals, backtick identifiers and comments are blanked out before
    # scanning, so search text such as 'remove the president' does not cause a
    # harmless read query to be rejected.
    _CYPHER_LITERAL_RE = re.compile(
        r"'(?:[^'\\]|\\.)*'"
        r'|"(?:[^"\\]|\\.)*"'
        r"|`[^`]*`"
        r"|//[^\n]*"
        r"|/\*.*?\*/",
        re.DOTALL,
    )

    def __init__(self):
        """Instantiate every pipeline service and obtain the Neo4j driver."""
        self.decomposer = QueryDecomposerService()
        self.planner = GraphPlannerService()
        self.retriever_coordinator = HybridRetrievalService()
        self.context_organizer = ContextOrganizationService()
        self.reasoner = LegalReasonerService()
        self.qa_agent = QualityAssuranceService()
        self.neo4j = get_neo4j_driver()
        self.vector_search = SearchService()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never suppress exceptions

    @classmethod
    def _is_read_only_cypher(cls, query):
        """Return True if ``query`` contains no data/schema-mutating clause.

        This is a lexical guard for planner-generated Cypher, not a parser.
        Literals and comments are ignored, and keywords are matched as whole
        words regardless of case.
        """
        code_only = cls._CYPHER_LITERAL_RE.sub(" ", query)
        return cls._WRITE_CLAUSE_RE.search(code_only) is None

    def execute_neo4j(self, queries):
        """Run read-only Cypher queries and collect their records.

        Args:
            queries: iterable of Cypher query strings. They come from the
                planner, so they are treated as untrusted. Non-strings, blank
                strings and queries containing write clauses are skipped.

        Returns:
            Flat list of ``record.data()`` results from every query that ran
            successfully. A failing query is reported and skipped; it does not
            abort the remaining queries.
        """
        results = []
        if not queries:
            return results
        try:
            with self.neo4j.session() as session:
                for q in queries:
                    if not isinstance(q, str) or not q.strip():
                        continue
                    if not self._is_read_only_cypher(q):
                        print(f"Neo4j: skipped write query: {q!r}")
                        continue
                    try:
                        # Materialize first so a mid-stream failure adds no
                        # partial rows from this query.
                        records = [record.data() for record in session.run(q)]
                    except Exception as e:
                        print(f"Neo4j Execution Error: {e}")
                        continue
                    results.extend(records)
        except Exception as e:
            # The session itself could not be opened or closed.
            print(f"Neo4j Execution Error: {e}")
        return results

    @staticmethod
    def _build_search_query(query, critique, retry_count):
        """Return the vector-search text for the current attempt.

        The first attempt (``retry_count == 0``) uses ``query`` unchanged. On
        a retry, if the previous QA critique set ``suggested_retrieval``, any
        amendments it listed under ``missing_entities["amendments"]`` are
        appended to bias the vector search toward them. The suggestion itself
        is not executed as Cypher.

        Amendments given as a list or tuple are space-joined, so no Python
        list repr leaks into the search text.
        """
        if retry_count <= 0 or not critique or not critique.get("suggested_retrieval"):
            return query
        missing = critique.get("missing_entities") or {}
        amendments = missing.get("amendments") if isinstance(missing, dict) else None
        if not amendments:
            return query
        if isinstance(amendments, (list, tuple)):
            amendments = " ".join(str(a) for a in amendments)
        return f"{query} amendment {amendments}"

    def run_flow(self, query: str) -> dict:
        """Answer ``query`` by running the full agent pipeline.

        Classification and graph planning run once. Retrieval, organization,
        reasoning and QA then loop. A QA grade of ``"REFINE"`` triggers a
        retry with a wider vector search, up to ``max_retries`` extra
        attempts. Any other grade ends the loop.

        Returns:
            dict with keys ``query``, ``answer``, ``constitutional_status``,
            ``confidence``, ``sources``, ``quality_grade``,
            ``execution_trace`` and ``time_taken`` (elapsed seconds).
        """
        start_time = time.perf_counter()  # monotonic: immune to clock changes
        state = {
            "query": query,
            "trace": [],
            "graph_results": [],
            "retrieval_plan": {},
            "retrieved_chunks": [],
            "context_data": {},
            "answer": {},
            "critique": {},
            "retry_count": 0,
            "max_retries": 1,
        }

        # 1. CLASSIFY (run once)
        state["trace"].append("Agent 1: Classification")
        classification = self.decomposer.decompose(query)
        print(f"DEBUG: Intent={classification.get('intent')}")

        # 2. PLAN (run once): generate Cypher and run it against the graph.
        state["trace"].append("Agent 2: Graph Planning")
        plan = self.planner.generate_plan(classification) or {}
        state["graph_results"] = self.execute_neo4j(plan.get("cypher_queries", []))
        print(f"DEBUG: Found {len(state['graph_results'])} graph entities")

        while state["retry_count"] <= state["max_retries"]:
            attempt = state["retry_count"]

            # 3. RETRIEVE
            state["trace"].append(f"Agent 3: Hybrid Retrieval (Attempt {attempt + 1})")
            search_query = self._build_search_query(query, state["critique"], attempt)
            # Widen the candidate pool on retries.
            limit = 20 if attempt == 0 else 40
            raw_chunks = self.vector_search.search(search_query, limit=limit) or []
            # Coordinate & rerank against the original query and graph results.
            retrieval_output = self.retriever_coordinator.coordinate_retrieval(
                query,
                state["graph_results"],
                raw_chunks,
            ) or {}
            state["retrieved_chunks"] = retrieval_output.get("final_selected_chunks") or []

            # 4. ORGANIZE
            state["trace"].append("Agent 4: Context Organization")
            # Unwrap the raw chunk dicts from the rerank result wrappers.
            chunk_data = [
                c["full_chunk"] for c in state["retrieved_chunks"] if "full_chunk" in c
            ]
            if not chunk_data:
                # Nothing usable from the reranker: fall back to top vector hits.
                chunk_data = raw_chunks[:10]
            state["context_data"] = self.context_organizer.organize_context(
                query,
                chunk_data,
                state["graph_results"],
            )

            # 5. REASON
            state["trace"].append("Agent 5: Legal Reasoning")
            state["answer"] = self.reasoner.generate_answer(query, state["context_data"]) or {}

            # 6. CRITIQUE
            state["trace"].append("Agent 6: Quality Assurance")
            state["critique"] = self.qa_agent.validate_answer(
                query, state["answer"], state["graph_results"]
            ) or {}
            grade = state["critique"].get("quality_grade")
            if grade != "REFINE":
                # "PASS" ends the loop. Any other (or missing) grade is accepted
                # as-is and surfaced to the caller via "quality_grade".
                break
            print(
                f"DEBUG: QA requested refinement. confidence={state['critique'].get('final_confidence')}"
            )
            state["retry_count"] += 1

        answer = state["answer"]
        critique = state["critique"]
        return {
            "query": query,
            "answer": answer.get("answer"),
            "constitutional_status": answer.get("constitutional_status"),
            "confidence": critique.get("final_confidence"),
            "sources": answer.get("sources"),
            "quality_grade": critique.get("quality_grade"),
            "execution_trace": state["trace"],
            "time_taken": time.perf_counter() - start_time,
        }

    def close(self):
        """Release the Neo4j driver."""
        self.neo4j.close()