Spaces:
Running
Running
| import json | |
| from groq import Groq | |
| from utils.config import settings | |
class GraphPlannerService:
    """Plan Neo4j Cypher queries for a constitution graph using a Groq chat model.

    The service prompts the model with the graph schema and query patterns,
    parses the JSON plan it returns, and sanitizes the generated Cypher.
    """

    _SYSTEM_PROMPT = "You are a helpful assistant that outputs JSON."

    # Typographic quotes some models emit, mapped to their ASCII equivalents.
    _SMART_QUOTES = str.maketrans({
        "\u201c": '"',
        "\u201d": '"',
        "\u2018": "'",
        "\u2019": "'",
    })

    def __init__(self):
        """Create the Groq client and select the model.

        Raises:
            ValueError: If ``settings.GROQ_API_KEY`` is empty or unset.
        """
        if not settings.GROQ_API_KEY:
            raise ValueError("GROQ_API_KEY is not set.")
        self.client = Groq(api_key=settings.GROQ_API_KEY)
        # Prefer the fast model; fall back to the default model when the fast
        # one is missing *or* blank, and only read GROQ_MODEL when needed.
        self.model = getattr(settings, 'GROQ_FAST_MODEL', None) or settings.GROQ_MODEL

    def _request_content(self, prompt: str, *, json_mode: bool):
        """Send one chat completion and return the first choice's message content.

        Args:
            prompt: User message text.
            json_mode: When true, request ``response_format={"type": "json_object"}``.
        """
        request = {
            "messages": [
                {"role": "system", "content": self._SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            "model": self.model,
            "temperature": 0,
        }
        if json_mode:
            request["response_format"] = {"type": "json_object"}
        completion = self.client.chat.completions.create(**request)
        return completion.choices[0].message.content

    @staticmethod
    def _parse_json_object(content, *, strip_markdown: bool = False) -> dict:
        """Decode model output into a dict.

        Args:
            content: Raw message content returned by the model.
            strip_markdown: When true, unwrap the first Markdown code fence
                (optionally tagged ``json``) before decoding.

        Raises:
            ValueError: If the content is empty, is not valid JSON
                (``json.JSONDecodeError`` is a ``ValueError``), or decodes to
                something other than a JSON object.
        """
        import re

        if not isinstance(content, str) or not content.strip():
            raise ValueError("LLM returned an empty response.")
        if strip_markdown:
            fenced = re.search(r"```(?:json)?(.*?)```", content, re.DOTALL)
            if fenced:
                content = fenced.group(1)
        parsed = json.loads(content.strip())
        if not isinstance(parsed, dict):
            # Callers use dict methods on the result; a list or scalar would
            # crash them later, so treat it as a failed attempt here.
            raise ValueError(
                f"Expected a JSON object from the LLM, got {type(parsed).__name__}."
            )
        return parsed

    @staticmethod
    def _is_json_mode_rejection(err: Exception) -> bool:
        """Return True when ``err`` looks like a JSON-mode request rejection.

        An HTTP status on the exception (``status_code``) is trusted when
        present; otherwise the message text is inspected.
        """
        status = getattr(err, "status_code", None)
        if status is not None:
            return status == 400
        message = str(err)
        return "400" in message or "json_validate_failed" in message

    def _get_json_response(
        self,
        prompt: str,
        *,
        max_attempts: int = 3,
        retry_delay: float = 2.0,
    ) -> dict:
        """Ask the model for a JSON object, retrying on failure.

        Each attempt first uses strict JSON mode. If that request is rejected
        (see ``_is_json_mode_rejection``), the same attempt retries in plain
        text mode and strips any Markdown code fence from the answer.

        Args:
            prompt: User message text.
            max_attempts: Number of attempts before giving up.
            retry_delay: Seconds to sleep between attempts.

        Returns:
            The decoded JSON object, or ``{}`` when every attempt failed.
        """
        import time

        for attempt in range(1, max_attempts + 1):
            try:
                try:
                    # Primary strategy: strict JSON mode.
                    return self._parse_json_object(
                        self._request_content(prompt, json_mode=True)
                    )
                except Exception as api_err:
                    if not self._is_json_mode_rejection(api_err):
                        raise
                    # Fallback strategy: text mode when JSON validation fails.
                    print(f"JSON Mode failed. Falling back to Text Mode (Attempt {attempt})...")
                    content = self._request_content(
                        prompt + "\n\nOUTPUT RAW JSON ONLY. NO MARKDOWN.",
                        json_mode=False,
                    )
                    return self._parse_json_object(content, strip_markdown=True)
            except Exception as e:
                # Best-effort boundary: report the failure and try again.
                print(f"Error in LLM request (Attempt {attempt}): {e}")
                if attempt < max_attempts and retry_delay > 0:
                    time.sleep(retry_delay)
        return {}

    def _clean_cypher(self, query: str) -> str:
        """Sanitize LLM-generated Cypher to prevent common syntax errors.

        Removes Markdown code fences, collapses doubled braces copied from the
        prompt's escaping, converts typographic quotes to ASCII quotes, and
        strips trailing semicolons and whitespace.

        Returns:
            The cleaned query, or ``""`` when ``query`` is empty or not a string.
        """
        import re

        if not query or not isinstance(query, str):
            return ""

        def balanced(text: str) -> bool:
            # Count-based check only; braces inside string literals are not
            # distinguished from structural ones.
            return text.count("{") == text.count("}")

        # 1. Strip Markdown code blocks.
        query = re.sub(r"```(?:cypher|json)?(.*?)```", r"\1", query, flags=re.DOTALL)

        # 2. Fix doubled braces (LLM over-copying prompt escaping). Collapse
        #    matching `{{...}}` pairs first so that a legitimate nested map such
        #    as `{a: {b: 1}}` keeps both of its closing braces.
        query = re.sub(r"\{\{([^{}]*)\}\}", r"{\1}", query)
        collapsed = query.replace("{{", "{").replace("}}", "}")
        # Apply the blanket collapse only when it does not unbalance a query
        # that was balanced before.
        if balanced(collapsed) or not balanced(query):
            query = collapsed

        # 3. Standardize quotes (smart quotes from some models).
        query = query.translate(self._SMART_QUOTES)

        # 4. Remove trailing semicolons together with surrounding whitespace.
        return query.strip().rstrip("; \t\r\n")

    def _sanitize_queries(self, raw_queries) -> list:
        """Return the cleaned, non-empty Cypher strings found in ``raw_queries``.

        Accepts list items that are strings or dicts carrying a ``"query"``
        key; anything else (including a non-list ``raw_queries``) is ignored.
        """
        if not isinstance(raw_queries, list):
            return []
        cleaned = []
        for item in raw_queries:
            if isinstance(item, str):
                cleaned.append(self._clean_cypher(item))
            elif isinstance(item, dict) and "query" in item:
                cleaned.append(self._clean_cypher(item["query"]))
        return [query for query in cleaned if query]

    @staticmethod
    def _build_prompt(classification_json: dict) -> str:
        """Render the planner prompt from the ``raw_query`` and ``entities`` keys."""
        return f"""
        You are a Constitution Graph Router. Map the user query to specific Graph Database (Neo4j) queries.

        User Query: "{classification_json.get('raw_query', '')}"
        Entities: {json.dumps(classification_json.get('entities'))}

        YOUR TOOLBOX (Graph Schema):
        - Nodes:
          - `Amendment` (property: `number` (int))
          - `Article` (property: `number` (string)) (e.g. "19", "21A")
          - `Clause` (property: `id` (string)) (e.g. "19(1)(a)")
          - `Schedule` (property: `id` (string))
        - Relationships (Directed):
          - `(Amendment)-[:AMENDS|SUBSTITUTES|INSERTS|OMITS|REPEALS|OVERRIDES|SUSPENDS|REFERS_TO|DEFINES]->(Article|Clause)`
          - `(Article)-[:REFERS_TO|SUSPENDS|OVERRIDES|DEFINES]->(Article)`
          - `(Clause)-[:PART_OF]->(Article)`

        YOUR GOAL:
        Generate 1-4 distinct Cypher queries based on the query:
        1. **Direct Changes**: Find changes to the Article OR its Clauses.
           - Preferred Pattern: `MATCH (am:Amendment)-[r]->(target) WHERE (target:Article AND target.number = '...') OR (target:Clause AND target.id STARTS WITH '...') RETURN am.number as amendment, am.year as year, type(r) as relationship, r.details as details, coalesce(target.number, target.id) as target_id`
        2. **Contextual Changes (Multi-hop)**: Find if the Amendment modified *another* Article (like 358) that impacts the target Article (like 19).
           - Pattern: `MATCH (am:Amendment)-[r1]->(other:Article)-[r2:SUSPENDS|OVERRIDES|REFERS_TO]->(target:Article {{number: '...'}}) RETURN am.number as amendment, am.year as year, other.number as via_article, r1.details as modification, type(r2) as effect`
        3. **Related Provisions (Smart Expansion)**: Find articles that constitutionally affect the target, then check if the amendment modified those related articles.
           - Pattern: `MATCH (related:Article)-[r_impact:SUSPENDS|OVERRIDES|REFERS_TO]->(target:Article {{number: '...'}}) WITH related, r_impact MATCH (am:Amendment)-[r_change]->(related) RETURN am.number as amendment, am.year as year, related.number as related_article, type(r_change) as change_type, r_change.details as change_details, type(r_impact) as impact_on_target`
        4. **General Amendment Discovery**: If asking about an Amendment (e.g. "What is 44th Amendment?") without an article.
           - Pattern: `MATCH (am:Amendment {{number: 44}})-[r]->(target) RETURN am.number as amendment, am.year as year, type(r) as relationship, r.details as details, COALESCE(target.number, target.id) as target_id, labels(target)[0] as target_type`

        CRITICAL INSTRUCTIONS:
        - **TYPE SAFETY**: `Amendment.number` is an INTEGER. Use `WHERE am.number = 44` or `{{number: 44}}`. NEVER use quotes around the amendment number.
        - **ARTICLE STRINGS**: `Article.number` is a STRING. Use `WHERE target.number = '19'`.
        - **Discovery Preference**: If the query is just about an Amendment, Query 4 is the MOST important. Do not hallucinate articles if none are provided.
        - **Explicit Returns**: ALWAYS return specific properties (e.g. `r.details`, `target.id`), NOT the whole object (e.g. `r`). Objects serialize poorly.
        - **Filtering**: If the user asks for "Amendment 44", YOU MUST FILTER `WHERE am.number = 44`.
        - **Targeting**: Always check for changes to `Clauses` of an Article, not just the Article itself.
        - **Query 3 Logic**: First find articles that affect the target (e.g., Article 358 SUSPENDS Article 19), then check if the amendment changed those articles or their dependencies (e.g., Amendment 44 changed Article 352 which Article 358 relies on).
        - **Single String Rule**: Each Cypher query must be a SINGLE string in the `cypher_queries` list. Do NOT split one query into multiple list items. Use `\\n` for line breaks inside the string.

        RETURN FORMAT (JSON ONLY):
        {{
            "tree_start_node": "Article 19",
            "cypher_queries": [
                "MATCH ... RETURN ...",
                "MATCH ... RETURN ...",
                "MATCH ... RETURN ..."
            ]
        }}
        """

    def generate_plan(self, classification_json: dict) -> dict:
        """Generate a graph query plan for a classified user query.

        Args:
            classification_json: Classification result; the ``raw_query`` and
                ``entities`` keys are read. ``None`` is treated as empty.

        Returns:
            The model's plan with ``cypher_queries`` replaced by the sanitized,
            non-empty queries. When the model yields no usable JSON object, a
            fallback plan with empty queries and an ``error`` message.
        """
        prompt = self._build_prompt(classification_json or {})
        result = self._get_json_response(prompt)
        if not result:
            return {
                "cypher_queries": [],
                "tree_start_node": "Unknown",
                "traversal": [],
                "qdrant_scope": {"article_numbers": [], "parts": [], "include_amendments": False},
                "error": "Error in graph planning"
            }

        # Sanitize queries before handing the plan to the caller.
        result["cypher_queries"] = self._sanitize_queries(result.get("cypher_queries", []))
        return result