import json
from groq import Groq
from utils.config import settings
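

# Hypothetical helper (a sketch; GraphPlannerService does not call it): the
# text-mode fallback in _get_json_response strips a Markdown code fence before
# json.loads(). This standalone variant also slices from the first "{" to the
# last "}", so replies with prose around the JSON still parse. It rejects
# non-object payloads because callers expect a dict. The name parse_llm_json
# is illustrative, not an existing API.
def parse_llm_json(content: str) -> dict:
    """Parse a JSON object out of raw LLM text (illustrative sketch)."""
    import json
    import re

    text = content or ""
    # Keep only the body of a ```json ... ``` (or bare ```) fence, if present.
    match = re.search(r"```(?:json)?(.*?)```", text, re.DOTALL)
    if match:
        text = match.group(1)
    # Drop any prose before the first "{" or after the last "}".
    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        text = text[start:end + 1]
    parsed = json.loads(text.strip())
    if not isinstance(parsed, dict):
        raise ValueError("Expected a JSON object, got " + type(parsed).__name__)
    return parsed
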

class GraphPlannerService:
    def __init__(self):
        if not settings.GROQ_API_KEY:
            raise ValueError("GROQ_API_KEY is not set.")
        self.client = Groq(api_key=settings.GROQ_API_KEY)
        # Prefer the fast model when it is configured and non-empty; otherwise use the default.
        self.model = getattr(settings, 'GROQ_FAST_MODEL', None) or settings.GROQ_MODEL

    def _get_json_response(self, prompt: str) -> dict:
        """Robust method to get JSON response, handling API errors and Markdown"""
        import re
        import time
        
        for attempt in range(3):
            try:
                # Primary Strategy: Strict JSON Mode
                try:
                    chat_completion = self.client.chat.completions.create(
                        messages=[
                            {"role": "system", "content": "You are a helpful assistant that outputs JSON."},
                            {"role": "user", "content": prompt}
                        ],
                        model=self.model,
                        temperature=0,
                        response_format={"type": "json_object"}
                    )
                    content = chat_completion.choices[0].message.content
                    return json.loads(content.strip())
                except Exception as api_err:
                    # Fallback Strategy: Text Mode if JSON validation fails
                    if "400" in str(api_err) or "json_validate_failed" in str(api_err):
                        print(f"JSON Mode failed. Falling back to Text Mode (Attempt {attempt+1})...")
                        chat_completion = self.client.chat.completions.create(
                            messages=[
                                {"role": "system", "content": "You are a helpful assistant that outputs JSON."},
                                {"role": "user", "content": prompt + "\n\nOUTPUT RAW JSON ONLY. NO MARKDOWN."}
                            ],
                            model=self.model,
                            temperature=0
                        )
                        content = chat_completion.choices[0].message.content
                        
                        # Strip Markdown Code Blocks
                        match = re.search(r"```(?:json)?(.*?)```", content, re.DOTALL)
                        if match:
                            content = match.group(1)
                        
                        return json.loads(content.strip())
                    raise api_err
                    
            except Exception as e:
                print(f"Error in LLM request (Attempt {attempt+1}): {e}")
                if attempt < 2:
                    time.sleep(2)
                else:
                    return {}
        return {}

    def _clean_cypher(self, query: str) -> str:
        """Sanitizes LLM-generated Cypher to prevent common syntax errors."""
        import re
        if not query: return ""
        
        # 1. Strip Markdown code blocks
        query = re.sub(r"```(?:cypher|json)?(.*?)```", r"\1", query, flags=re.DOTALL)
        
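        # 2. Collapse doubled braces. The prompt's f-string escapes ({{ }}) render as
        #    single braces, but models sometimes emit doubled ones anyway. Note that
        #    this also collapses a legitimate "}}" in nested map literals like {a: {b: 1}}.
        query = query.replace("{{", "{").replace("}}", "}")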
        
        # 3. Standardize Quotes (Smart quotes from some models)
        query = query.replace("“", "\"").replace("”", "\"").replace("‘", "'").replace("’", "'")
        
        # 4. Remove trailing semicolons
        query = query.strip().rstrip(";")
        
        return query.strip()

    def generate_plan(self, classification_json: dict) -> dict:
        prompt = f"""
You are a Constitution Graph Router. Map the user query to specific Graph Database (Neo4j) queries.

User Query: "{classification_json.get('raw_query', '')}"
Entities: {json.dumps(classification_json.get('entities'))}

YOUR TOOLBOX (Graph Schema):
- Nodes: 
  - `Amendment` (property: `number` (int))
  - `Article` (property: `number` (string)) (e.g. "19", "21A")
  - `Clause` (property: `id` (string)) (e.g. "19(1)(a)")
  - `Schedule` (property: `id` (string))
- Relationships (Directed):
  - `(Amendment)-[:AMENDS|SUBSTITUTES|INSERTS|OMITS|REPEALS|OVERRIDES|SUSPENDS|REFERS_TO|DEFINES]->(Article|Clause)`
  - `(Article)-[:REFERS_TO|SUSPENDS|OVERRIDES|DEFINES]->(Article)`
  - `(Clause)-[:PART_OF]->(Article)`

YOUR GOAL:
Generate 1-4 distinct Cypher queries based on the query:
1. **Direct Changes**: Find changes to the Article OR its Clauses.
   - Preferred Pattern: `MATCH (am:Amendment)-[r]->(target) WHERE (target:Article AND target.number = '...') OR (target:Clause AND target.id STARTS WITH '...') RETURN am.number as amendment, am.year as year, type(r) as relationship, r.details as details, coalesce(target.number, target.id) as target_id`
2. **Contextual Changes (Multi-hop)**: Find if the Amendment modified *another* Article (like 358) that impacts the target Article (like 19).
   - Pattern: `MATCH (am:Amendment)-[r1]->(other:Article)-[r2:SUSPENDS|OVERRIDES|REFERS_TO]->(target:Article {{number: '...'}}) RETURN am.number as amendment, am.year as year, other.number as via_article, r1.details as modification, type(r2) as effect`
3. **Related Provisions (Smart Expansion)**: Find articles that constitutionally affect the target, then check if the amendment modified those related articles.
   - Pattern: `MATCH (related:Article)-[r_impact:SUSPENDS|OVERRIDES|REFERS_TO]->(target:Article {{number: '...'}}) WITH related, r_impact MATCH (am:Amendment)-[r_change]->(related) RETURN am.number as amendment, am.year as year, related.number as related_article, type(r_change) as change_type, r_change.details as change_details, type(r_impact) as impact_on_target`
4. **General Amendment Discovery**: If asking about an Amendment (e.g. "What is 44th Amendment?") without an article.
   - Pattern: `MATCH (am:Amendment {{number: 44}})-[r]->(target) RETURN am.number as amendment, am.year as year, type(r) as relationship, r.details as details, COALESCE(target.number, target.id) as target_id, labels(target)[0] as target_type`

CRITICAL INSTRUCTIONS:
- **TYPE SAFETY**: `Amendment.number` is an INTEGER. Use `WHERE am.number = 44` or `{{number: 44}}`. NEVER use quotes around the amendment number.
- **ARTICLE STRINGS**: `Article.number` is a STRING. Use `WHERE target.number = '19'`.
- **Discovery Preference**: If the query is just about an Amendment, Query 4 is the MOST important. Do not hallucinate articles if none are provided.
- **Explicit Returns**: ALWAYS return specific properties (e.g. `r.details`, `target.id`), NOT the whole object (e.g. `r`). Objects serialize poorly.
- **Filtering**: If the user asks for "Amendment 44", YOU MUST FILTER `WHERE am.number = 44`.
- **Targeting**: Always check for changes to `Clauses` of an Article, not just the Article itself.
- **Query 3 Logic**: First find articles that affect the target (e.g., Article 358 SUSPENDS Article 19), then check if the amendment changed those articles or their dependencies (e.g., Amendment 44 changed Article 352 which Article 358 relies on).
- **Single String Rule**: Each Cypher query must be a SINGLE string in the `cypher_queries` list. Do NOT split one query into multiple list items. Use `\\n` for line breaks inside the string.

RETURN FORMAT (JSON ONLY):
{{
  "tree_start_node": "Article 19",
  "cypher_queries": [
    "MATCH ... RETURN ...", 
    "MATCH ... RETURN ...",
    "MATCH ... RETURN ..."
  ]
}}
"""
        result = self._get_json_response(prompt)
        
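        # _get_json_response returns {} on failure; a non-dict payload (e.g. a bare
        # JSON array parsed in text mode) is treated as a failure too.
        if not isinstance(result, dict) or not result: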
            return {
                "cypher_queries": [],
                "tree_start_node": "Unknown",
                "traversal": [],
                "qdrant_scope": {"article_numbers": [], "parts": [], "include_amendments": False},
                "error": "Error in graph planning"
            }
            
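        # Sanitize queries: accept plain strings or {"query": ...} objects,
        # and drop anything that is empty after cleaning.
        raw_queries = result.get("cypher_queries", [])
        if isinstance(raw_queries, str):
            # Tolerate a single query returned as a bare string instead of a list.
            raw_queries = [raw_queries]
        clean_queries = []
        if isinstance(raw_queries, list):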
            for q in raw_queries:
                if isinstance(q, str):
                    clean_queries.append(self._clean_cypher(q))
                elif isinstance(q, dict) and "query" in q:
                    clean_queries.append(self._clean_cypher(q["query"]))
                    
        result["cypher_queries"] = [q for q in clean_queries if q]
        return result
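

# Hypothetical deterministic fallback (a sketch; not part of GraphPlannerService):
# when generate_plan() returns no queries, Pattern 1 ("Direct Changes") and
# Pattern 4 ("General Amendment Discovery") from the prompt can be built without
# the LLM. Values are bound as Cypher parameters instead of being interpolated,
# so the schema's type rules (Amendment.number INTEGER, Article.number STRING)
# are enforced in Python. The function name and the {"query", "params"} item
# shape are illustrative assumptions. Clause ids are assumed to look like
# "19(1)(a)", so the clause prefix ends in "(" to stop Article 19 from also
# matching clauses of Article 190.
def fallback_plan(article_number=None, amendment_number=None) -> dict:
    """Build parameterized Cypher for Patterns 1 and 4 (illustrative sketch)."""
    queries = []
    if article_number is not None:
        article = str(article_number)  # Article.number is a STRING
        queries.append({
            "query": (
                "MATCH (am:Amendment)-[r]->(target) "
                "WHERE (target:Article AND target.number = $article) "
                "OR (target:Clause AND target.id STARTS WITH $clause_prefix) "
                "RETURN am.number AS amendment, am.year AS year, "
                "type(r) AS relationship, r.details AS details, "
                "coalesce(target.number, target.id) AS target_id"
            ),
            "params": {"article": article, "clause_prefix": article + "("},
        })
    if amendment_number is not None:
        queries.append({
            "query": (
                "MATCH (am:Amendment {number: $amendment})-[r]->(target) "
                "RETURN am.number AS amendment, am.year AS year, "
                "type(r) AS relationship, r.details AS details, "
                "coalesce(target.number, target.id) AS target_id, "
                "labels(target)[0] AS target_type"
            ),
            # Amendment.number is an INTEGER, so a value like "44" is coerced to 44.
            "params": {"amendment": int(amendment_number)},
        })
    if article_number is not None:
        start = f"Article {article_number}"
    elif amendment_number is not None:
        start = f"Amendment {amendment_number}"
    else:
        start = "Unknown"
    return {"tree_start_node": start, "cypher_queries": queries}

# With the official neo4j driver, each item could be run as
# session.run(item["query"], item["params"]). Note that generate_plan()'s
# sanitizer keeps only item["query"] from dict items and would drop the params.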