# ConstitutionAgent / services / graph_planner.py
# Uploaded by Meshyboi ("Upload 53 files", commit 0cd3dc5, verified)
import json
from groq import Groq
from utils.config import settings
class GraphPlannerService:
    """Ask a Groq-hosted LLM to plan Cypher queries for a classified user query.

    The service builds a prompt describing the graph schema, requests a JSON
    plan from the model, and sanitizes the Cypher strings the model returns.
    """

    # Total number of LLM request attempts before giving up.
    MAX_ATTEMPTS = 3
    # Pause between failed attempts, in seconds.
    RETRY_DELAY_SECONDS = 2
    # Typographic quotes some models emit, mapped to their ASCII equivalents.
    _SMART_QUOTES = str.maketrans({
        "\u201c": '"',
        "\u201d": '"',
        "\u2018": "'",
        "\u2019": "'",
    })

    def __init__(self):
        """Create the Groq client and select the model from ``settings``.

        Raises:
            ValueError: If ``settings.GROQ_API_KEY`` is empty or unset.
        """
        if not settings.GROQ_API_KEY:
            raise ValueError("GROQ_API_KEY is not set.")
        self.client = Groq(api_key=settings.GROQ_API_KEY)
        # Prefer the fast model, but fall back to the default model when the
        # fast one is missing *or* configured as an empty value.
        self.model = getattr(settings, 'GROQ_FAST_MODEL', None) or settings.GROQ_MODEL

    def _get_json_response(self, prompt: str) -> dict:
        """Return the model's answer to *prompt* as a dict, retrying on failure.

        Each attempt first uses strict JSON mode and, if the API rejects that
        request, falls back to plain text mode. Any error (API failure,
        unparsable output, JSON that is not an object) counts as a failed
        attempt.

        Args:
            prompt: The user message sent to the model.

        Returns:
            The parsed JSON object, or ``{}`` once every attempt has failed.
        """
        import time

        for attempt in range(self.MAX_ATTEMPTS):
            try:
                return self._request_json(prompt, attempt)
            except Exception as e:
                # Deliberate best-effort boundary: callers receive {} instead
                # of an exception, so report the error and retry.
                print(f"Error in LLM request (Attempt {attempt+1}): {e}")
                # No point in sleeping after the last attempt.
                if attempt < self.MAX_ATTEMPTS - 1:
                    time.sleep(self.RETRY_DELAY_SECONDS)
        return {}

    def _request_json(self, prompt: str, attempt: int) -> dict:
        """Run one attempt: strict JSON mode first, text mode as a fallback."""
        system_message = {
            "role": "system",
            "content": "You are a helpful assistant that outputs JSON.",
        }
        try:
            # Primary strategy: strict JSON mode.
            chat_completion = self.client.chat.completions.create(
                messages=[system_message, {"role": "user", "content": prompt}],
                model=self.model,
                temperature=0,
                response_format={"type": "json_object"},
            )
            return self._parse_json_object(chat_completion.choices[0].message.content)
        except Exception as api_err:
            # Only a rejected JSON-mode request triggers the fallback; every
            # other error is handed to the retry loop unchanged.
            error_text = str(api_err)
            if "400" not in error_text and "json_validate_failed" not in error_text:
                raise

        # Fallback strategy: text mode, tolerating Markdown-wrapped output.
        print(f"JSON Mode failed. Falling back to Text Mode (Attempt {attempt+1})...")
        chat_completion = self.client.chat.completions.create(
            messages=[
                system_message,
                {"role": "user", "content": prompt + "\n\nOUTPUT RAW JSON ONLY. NO MARKDOWN."},
            ],
            model=self.model,
            temperature=0,
        )
        return self._parse_json_object(
            chat_completion.choices[0].message.content, strip_fences=True
        )

    @staticmethod
    def _parse_json_object(content, strip_fences: bool = False) -> dict:
        """Parse *content* as JSON and insist that the result is an object.

        Args:
            content: Raw message text returned by the model.
            strip_fences: When true and direct parsing fails, retry on the
                contents of the first Markdown code fence.

        Returns:
            The decoded JSON object.

        Raises:
            ValueError: If *content* is not text, is not valid JSON
                (``json.JSONDecodeError`` is a ``ValueError``), or decodes to
                something other than a dict.
        """
        import re

        if not isinstance(content, str):
            raise ValueError("LLM response contained no text content.")
        text = content.strip()
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            if not strip_fences:
                raise
            # Parse directly first so that fences *inside* JSON string values
            # survive; only unwrap a fenced block when that fails.
            match = re.search(r"```(?:json)?(.*?)```", text, re.DOTALL)
            if match is None:
                raise
            parsed = json.loads(match.group(1).strip())
        if not isinstance(parsed, dict):
            # Callers use dict methods on the result, so a list/scalar is an error.
            raise ValueError(
                f"Expected a JSON object from the LLM, got {type(parsed).__name__}."
            )
        return parsed

    def _clean_cypher(self, query: str) -> str:
        """Sanitize LLM-generated Cypher to prevent common syntax errors.

        Args:
            query: Cypher text as produced by the model.

        Returns:
            The cleaned query, or ``""`` when *query* is empty or not a string.
        """
        import re

        if not query or not isinstance(query, str):
            return ""
        # 1. Strip Markdown code blocks, keeping their contents.
        query = re.sub(r"```(?:cypher|json)?(.*?)```", r"\1", query, flags=re.DOTALL)
        # 2. Undo doubled braces copied from the prompt's f-string escaping.
        #    "{{" is never valid Cypher, so it marks an over-escaped query; a
        #    lone "}}" can legitimately close nested maps and is left alone.
        if "{{" in query:
            query = query.replace("{{", "{").replace("}}", "}")
        # 3. Standardize smart quotes emitted by some models.
        query = query.translate(self._SMART_QUOTES)
        # 4. Remove every trailing semicolon, even when separated by whitespace.
        query = query.strip()
        while query.endswith(";"):
            query = query[:-1].rstrip()
        return query

    def _sanitize_queries(self, raw_queries) -> list:
        """Normalize the model's ``cypher_queries`` value into clean strings.

        Accepts a list whose items are strings or ``{"query": ...}`` dicts; a
        single string or dict is treated as a one-item list. Anything else, and
        any item that cleans to an empty string, is dropped.
        """
        if isinstance(raw_queries, (str, dict)):
            raw_queries = [raw_queries]
        if not isinstance(raw_queries, list):
            return []
        cleaned = []
        for item in raw_queries:
            if isinstance(item, dict):
                item = item.get("query")
            # _clean_cypher returns "" for non-string input.
            cleaned.append(self._clean_cypher(item))
        return [query for query in cleaned if query]

    def _build_prompt(self, classification_json: dict) -> str:
        """Render the planning prompt from the ``raw_query`` and ``entities`` keys."""
        classification_json = classification_json or {}
        # The prompt body is kept flush-left so that no source indentation
        # leaks into the text sent to the model.
        return f"""
You are a Constitution Graph Router. Map the user query to specific Graph Database (Neo4j) queries.
User Query: "{classification_json.get('raw_query', '')}"
Entities: {json.dumps(classification_json.get('entities'))}
YOUR TOOLBOX (Graph Schema):
- Nodes:
- `Amendment` (property: `number` (int))
- `Article` (property: `number` (string)) (e.g. "19", "21A")
- `Clause` (property: `id` (string)) (e.g. "19(1)(a)")
- `Schedule` (property: `id` (string))
- Relationships (Directed):
- `(Amendment)-[:AMENDS|SUBSTITUTES|INSERTS|OMITS|REPEALS|OVERRIDES|SUSPENDS|REFERS_TO|DEFINES]->(Article|Clause)`
- `(Article)-[:REFERS_TO|SUSPENDS|OVERRIDES|DEFINES]->(Article)`
- `(Clause)-[:PART_OF]->(Article)`
YOUR GOAL:
Generate 1-4 distinct Cypher queries based on the query:
1. **Direct Changes**: Find changes to the Article OR its Clauses.
- Preferred Pattern: `MATCH (am:Amendment)-[r]->(target) WHERE (target:Article AND target.number = '...') OR (target:Clause AND target.id STARTS WITH '...') RETURN am.number as amendment, am.year as year, type(r) as relationship, r.details as details, coalesce(target.number, target.id) as target_id`
2. **Contextual Changes (Multi-hop)**: Find if the Amendment modified *another* Article (like 358) that impacts the target Article (like 19).
- Pattern: `MATCH (am:Amendment)-[r1]->(other:Article)-[r2:SUSPENDS|OVERRIDES|REFERS_TO]->(target:Article {{number: '...'}}) RETURN am.number as amendment, am.year as year, other.number as via_article, r1.details as modification, type(r2) as effect`
3. **Related Provisions (Smart Expansion)**: Find articles that constitutionally affect the target, then check if the amendment modified those related articles.
- Pattern: `MATCH (related:Article)-[r_impact:SUSPENDS|OVERRIDES|REFERS_TO]->(target:Article {{number: '...'}}) WITH related, r_impact MATCH (am:Amendment)-[r_change]->(related) RETURN am.number as amendment, am.year as year, related.number as related_article, type(r_change) as change_type, r_change.details as change_details, type(r_impact) as impact_on_target`
4. **General Amendment Discovery**: If asking about an Amendment (e.g. "What is 44th Amendment?") without an article.
- Pattern: `MATCH (am:Amendment {{number: 44}})-[r]->(target) RETURN am.number as amendment, am.year as year, type(r) as relationship, r.details as details, COALESCE(target.number, target.id) as target_id, labels(target)[0] as target_type`
CRITICAL INSTRUCTIONS:
- **TYPE SAFETY**: `Amendment.number` is an INTEGER. Use `WHERE am.number = 44` or `{{number: 44}}`. NEVER use quotes around the amendment number.
- **ARTICLE STRINGS**: `Article.number` is a STRING. Use `WHERE target.number = '19'`.
- **Discovery Preference**: If the query is just about an Amendment, Query 4 is the MOST important. Do not hallucinate articles if none are provided.
- **Explicit Returns**: ALWAYS return specific properties (e.g. `r.details`, `target.id`), NOT the whole object (e.g. `r`). Objects serialize poorly.
- **Filtering**: If the user asks for "Amendment 44", YOU MUST FILTER `WHERE am.number = 44`.
- **Targeting**: Always check for changes to `Clauses` of an Article, not just the Article itself.
- **Query 3 Logic**: First find articles that affect the target (e.g., Article 358 SUSPENDS Article 19), then check if the amendment changed those articles or their dependencies (e.g., Amendment 44 changed Article 352 which Article 358 relies on).
- **Single String Rule**: Each Cypher query must be a SINGLE string in the `cypher_queries` list. Do NOT split one query into multiple list items. Use `\\n` for line breaks inside the string.
RETURN FORMAT (JSON ONLY):
{{
"tree_start_node": "Article 19",
"cypher_queries": [
"MATCH ... RETURN ...",
"MATCH ... RETURN ...",
"MATCH ... RETURN ..."
]
}}
"""

    def generate_plan(self, classification_json: dict) -> dict:
        """Produce a graph query plan for a classified user query.

        Args:
            classification_json: Classifier output; the ``raw_query`` and
                ``entities`` keys are read to build the prompt.

        Returns:
            The model's JSON plan with ``cypher_queries`` replaced by the
            sanitized, non-empty query strings. When the model yields nothing
            usable, a fallback plan with empty queries and an ``error`` message
            is returned instead.
        """
        result = self._get_json_response(self._build_prompt(classification_json))
        if not result:
            return {
                "cypher_queries": [],
                "tree_start_node": "Unknown",
                "traversal": [],
                "qdrant_scope": {"article_numbers": [], "parts": [], "include_amendments": False},
                "error": "Error in graph planning"
            }
        # Sanitize queries
        result["cypher_queries"] = self._sanitize_queries(result.get("cypher_queries", []))
        return result