Spaces:
Sleeping
Sleeping
| """ | |
| audit_and_fix_graph.py | |
| ====================== | |
| Audits every amendment (0-106) in extracted_data/ against what is actually | |
| stored in Neo4j, then fills in ONLY the missing relationships β no wipes, no | |
| duplicates, fully idempotent. | |
| Usage: | |
| # Audit only β print a report but change nothing | |
| PYTHONPATH=. .venv/bin/python3 scripts/audit_and_fix_graph.py --mode audit | |
| # Audit + fix ONLY specific amendments | |
| PYTHONPATH=. .venv/bin/python3 scripts/audit_and_fix_graph.py --mode fix --amendments 7,13,52,73,352 | |
| # Audit + fix ALL amendments that have zero relationships | |
| PYTHONPATH=. .venv/bin/python3 scripts/audit_and_fix_graph.py --mode fix | |
| # Fix a single amendment by number | |
| PYTHONPATH=. .venv/bin/python3 scripts/audit_and_fix_graph.py --mode fix --amendments 38 | |
| """ | |
| import os | |
| import sys | |
| import json | |
| import glob | |
| import re | |
| import time | |
| import argparse | |
| from typing import List, Dict, Any, Set, Tuple, Optional | |
# ── Project imports ────────────────────────────────────────────────────────────
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from services.neo4j import get_neo4j_driver | |
| from utils.config import settings | |
| from groq import Groq | |
# ── Config ─────────────────────────────────────────────────────────────────────
# Root folder containing one sub-directory per amendment
# (extracted_data/amendment_NNN/summary.json), resolved relative to this script.
EXTRACTED_DIR = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
    "extracted_data"
)
# Minimum relationships we expect per amendment before flagging as "MISSING"
MIN_RELATIONSHIPS_PER_AMENDMENT = 1
# ── Groq LLM wrapper (reuses the same extraction logic as graph_ingestion.py) ──
class LLMExtractor:
    """Groq-backed extractor that turns amendment summaries into graph edges.

    Mirrors the four-pass extraction strategy used by graph_ingestion.py
    (modifications / deletions / insertions / cross_effects) so relationships
    written by the audit fix match the shapes of the original ingestion run.
    """

    def __init__(self):
        # NOTE(review): assumes settings.GROQ_API_KEY is configured; Groq()
        # raises otherwise — confirm against utils.config.
        self.client = Groq(api_key=settings.GROQ_API_KEY)
        # Fall back to a known model when GROQ_MODEL is unset/empty.
        self.model = settings.GROQ_MODEL or "llama-3.3-70b-versatile"

    def _call(self, prompt: str) -> str:
        """Call Groq with JSON mode, fall back to text mode.

        Returns the raw response text (ideally a JSON document). When the
        JSON-mode request is rejected (400 / json_validate errors) the call is
        retried once in plain-text mode and any markdown code fence is
        stripped from the reply; any other error propagates to the caller.
        """
        # Groq requires the word 'json' in messages when using json_object response_format
        system_msg = "You are a constitutional law expert. Output only valid json."
        try:
            resp = self.client.chat.completions.create(
                messages=[
                    {"role": "system", "content": system_msg},
                    {"role": "user", "content": prompt},
                ],
                model=self.model,
                temperature=0,
                response_format={"type": "json_object"},
            )
            return resp.choices[0].message.content
        except Exception as e:
            if "400" in str(e) or "json_validate" in str(e):
                resp = self.client.chat.completions.create(
                    messages=[
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": prompt + "\n\nRETURN ONLY VALID JSON. NO MARKDOWN."},
                    ],
                    model=self.model,
                    temperature=0,
                )
                content = resp.choices[0].message.content
                # Strip a ```json ... ``` fence if the model added one.
                m = re.search(r"```(?:json)?(.*?)```", content, re.DOTALL)
                return m.group(1) if m else content
            raise

    def extract_relationships(self, title: str, summary: str) -> List[dict]:
        """
        Four focused extractions (modifications, deletions, insertions, cross_effects).
        Returns deduplicated relationship list.

        Args:
            title: Amendment title, embedded verbatim in each prompt.
            summary: Amendment summary text; split into two overlapping
                chunks when longer than 7000 characters.
        """
        all_rels: List[dict] = []
        # Each pass asks only about one change category — focused prompts
        # extract more reliably than one catch-all prompt.
        focus_map = {
            "modifications": (
                "MODIFIES or SUBSTITUTES_ARTICLE",
                'Which articles/clauses were MODIFIED or SUBSTITUTED? Use relation "MODIFIES" or "SUBSTITUTES_ARTICLE".'
            ),
            "deletions": (
                "DELETES_CLAUSE or REPEALS",
                'Which articles/clauses were DELETED, REMOVED, or REPEALED? Use "DELETES_CLAUSE" or "REPEALS". '
                "CRITICAL: Include exact deleted terms (e.g. 'internal disturbance') in details."
            ),
            "insertions": (
                "INSERTS_CLAUSE or RESTORES_CLAUSE",
                'Which articles/clauses were INSERTED or RESTORED? Use "INSERTS_CLAUSE" or "RESTORES_CLAUSE".'
            ),
            "cross_effects": (
                "SUSPENDS, OVERRIDES, RESTRICTS, REFERS_TO",
                'Does this create cross-article relationships? e.g. "Article 358 SUSPENDS Article 19". '
                "Source = the acting article. Target = the article being affected."
            ),
        }
        # Split summary if too large (>7000 chars)
        chunks = self._split(summary)
        for chunk_type, (rel_types, focus_text) in focus_map.items():
            for chunk_idx, chunk in enumerate(chunks):
                part_label = f" (part {chunk_idx+1}/{len(chunks)})" if len(chunks) > 1 else ""
                prompt = f"""Extract ONLY {chunk_type.upper()}{part_label} from this constitutional amendment.
Return json with a 'relationships' array.
AMENDMENT: {title}
SUMMARY: {chunk}
TASK: {focus_text}
RULES:
- source: "CURRENT_AMENDMENT" for changes BY this amendment, or "Article X" for cross-article links
- target: "Article NUMBER" or "Article NUMBER(clause)" or "Schedule N"
- details: specific text including key deleted/inserted terms
Return ONLY valid json:
{{
"relationships": [
{{"source": "CURRENT_AMENDMENT", "relation": "{rel_types.split(',')[0].strip()}", "target": "Article X", "details": "..."}}
]
}}
If none found, return {{"relationships": []}}
"""
                # Up to 3 attempts per (category, chunk); the for/else below
                # fires only when every attempt failed.
                for attempt in range(3):
                    try:
                        raw = self._call(prompt)
                        data = json.loads(raw.strip())
                        rels = data.get("relationships", [])
                        if not isinstance(rels, list):
                            # Defensive: a malformed payload ("relationships": {...})
                            # must not poison the accumulated list.
                            rels = []
                        all_rels.extend(rels)
                        print(f"    [{chunk_type}{part_label}]: {len(rels)} relationships")
                        time.sleep(1.0)  # respect rate limits
                        break
                    except Exception as e:
                        print(f"    β οΈ [{chunk_type}{part_label}] attempt {attempt+1} error: {str(e)[:80]}")
                        wait = (attempt + 1) * 5
                        if "429" in str(e):
                            wait = 30  # rate-limited: back off longer
                        time.sleep(wait)
                else:
                    print(f"    β [{chunk_type}{part_label}] all attempts failed β skipping")
        # Deduplicate by (target, relation)
        seen: Set[str] = set()
        unique: List[dict] = []
        for r in all_rels:
            key = f"{r.get('target','')}|{r.get('relation','')}"
            if key not in seen:
                seen.add(key)
                unique.append(r)
        return unique

    @staticmethod
    def _split(text: str, threshold: int = 7000, overlap: int = 500) -> List[str]:
        """Split *text* into two overlapping halves when longer than *threshold*.

        BUG FIX: this was previously a plain method without ``self``, so the
        call ``self._split(summary)`` bound the instance to ``text`` (and the
        summary to ``threshold``) and crashed on ``len(text)``. Declaring it a
        staticmethod keeps the original signature working for both
        ``self._split(...)`` and ``LLMExtractor._split(...)`` callers.
        """
        if len(text) <= threshold:
            return [text]
        mid = len(text) // 2
        # Nudge the split point to the nearest sentence boundary (within 200 chars).
        for off in range(200):
            if mid + off < len(text) and text[mid + off] == ".":
                mid = mid + off + 1
                break
            if mid - off > 0 and text[mid - off] == ".":
                mid = mid - off + 1
                break
        return [text[: mid + overlap], text[mid - overlap:]]
# ── Neo4j entity resolver (mirrors graph_ingestion.py) ─────────────────────────
def resolve_entity(tx, entity_str: str, current_am_num: int) -> Tuple[Optional[str], Optional[dict]]:
    """Resolve a free-text entity reference to a MERGEd Neo4j node.

    Tried in priority order:
      1. "CURRENT_AMENDMENT"                -> (:Amendment {number: current_am_num})
      2. "Amendment N"                      -> (:Amendment {number: N})
      3. Clause, e.g. "Article 19(1)(f)"    -> (:Clause)-[:PART_OF]->(:Article)
      4. "First Schedule"/"Schedule 7"/"7th Schedule" -> (:Schedule {id})
      5. any other non-empty string         -> (:Article {number: cleaned text})

    Args:
        tx: Neo4j transaction/session used for the MERGE side effects.
        entity_str: Raw entity reference as produced by the LLM.
        current_am_num: Amendment number substituted for "CURRENT_AMENDMENT".

    Returns:
        (node_label, identifying_props), or (None, None) when unresolvable.
    """
    if not entity_str:
        return None, None
    entity_str = entity_str.strip()
    if entity_str == "CURRENT_AMENDMENT":
        tx.run("MERGE (am:Amendment {number: $num})", {"num": current_am_num})
        return "Amendment", {"number": current_am_num}
    am_match = re.search(r"Amendment\s+(\d+)", entity_str, re.IGNORECASE)
    if am_match:
        num = int(am_match.group(1))
        tx.run("MERGE (am:Amendment {number: $num})", {"num": num})
        return "Amendment", {"number": num}
    # Clause: Article 19(1)(f)
    clause_match = re.search(r"(Article\s+)?(\d+[A-Z]*)\s*(\(.+\))", entity_str, re.IGNORECASE)
    if clause_match:
        art_num = clause_match.group(2)
        clause_suffix = clause_match.group(3)
        clause_id = f"{art_num}{clause_suffix}"
        tx.run("MERGE (a:Article {number: $anum})", {"anum": art_num})
        tx.run(
            """
            MATCH (a:Article {number: $anum})
            MERGE (c:Clause {id: $cid})
            MERGE (c)-[:PART_OF]->(a)
            """,
            {"anum": art_num, "cid": clause_id},
        )
        return "Clause", {"id": clause_id}
    # Schedule: "First Schedule", "Schedule 7", "7th Schedule"
    sched_match = re.search(
        r"(?:Schedule\s*(\w+)|(\w+)\s*Schedule)",
        entity_str,
        re.IGNORECASE,
    )
    if sched_match:
        raw_id = sched_match.group(1) or sched_match.group(2)
        # Normalise word ordinals -> digits
        ordinal_map = {
            "first": "1", "second": "2", "third": "3", "fourth": "4", "fifth": "5",
            "sixth": "6", "seventh": "7", "eighth": "8", "ninth": "9", "tenth": "10",
            # BUG FIX: Schedules 11 and 12 (inserted by the 73rd/74th
            # Amendments) were missing, so "Eleventh Schedule" produced the
            # bogus id "Eleventh" instead of "11".
            "eleventh": "11", "twelfth": "12",
        }
        sid = ordinal_map.get(raw_id.lower(), raw_id)
        # BUG FIX: normalise numeric ordinals ("7th" -> "7") so "7th Schedule"
        # merges into the same node as "Schedule 7".
        digit_ordinal = re.match(r"(\d+)(?:st|nd|rd|th)$", sid, re.IGNORECASE)
        if digit_ordinal:
            sid = digit_ordinal.group(1)
        tx.run("MERGE (s:Schedule {id: $sid})", {"sid": sid})
        return "Schedule", {"id": sid}
    # Plain Article
    clean = entity_str.replace("Article", "").replace("article", "").strip()
    if clean:
        tx.run("MERGE (a:Article {number: $anum})", {"anum": clean})
        return "Article", {"number": clean}
    return None, None
def write_relationships(session, am_num: int, relationships: List[dict]) -> int:
    """Write relationships to Neo4j. Uses MERGE so it's safe to re-run.

    Args:
        session: Neo4j session used for all reads/writes.
        am_num: Amendment number used to resolve "CURRENT_AMENDMENT".
        relationships: LLM-extracted dicts with source/relation/target/details.

    Returns:
        Number of relationships successfully MERGEd (new + updated).
    """
    written = 0
    for rel in relationships:
        # BUG FIX: dict.get(key, default) does NOT cover explicit nulls — the
        # LLM sometimes emits "relation": null, which previously crashed on
        # .upper(). Using `or` makes every falsy value fall back to defaults.
        s_raw = rel.get("source") or "CURRENT_AMENDMENT"
        t_raw = rel.get("target")
        r_type = (rel.get("relation") or "MODIFIES").upper().replace(" ", "_")
        details = rel.get("details") or ""
        # r_type is interpolated into the Cypher text below, so restrict it to
        # a strict A-Z/underscore token (prevents Cypher injection).
        if not re.match(r"^[A-Z_]+$", r_type):
            r_type = "MODIFIES"
        if not t_raw:
            continue
        sType, sProps = resolve_entity(session, s_raw, am_num)
        tType, tProps = resolve_entity(session, t_raw, am_num)
        if not (sType and tType):
            continue
        # Amendments/Articles key on `number`; Clauses/Schedules key on `id`.
        s_ident = "number" if "number" in sProps else "id"
        t_ident = "number" if "number" in tProps else "id"
        query = f"""
        MATCH (s:{sType} {{{s_ident}: $sVal}})
        MATCH (t:{tType} {{{t_ident}: $tVal}})
        MERGE (s)-[r:{r_type}]->(t)
        SET r.details = $details, r.source = 'audit_fix'
        """
        try:
            session.run(query, {
                "sVal": sProps[s_ident],
                "tVal": tProps[t_ident],
                "details": details,
            })
            written += 1
        except Exception as e:
            # Best-effort: one bad relationship must not abort the batch.
            print(f"  β οΈ Write error: {e}")
    return written
# ── Audit ──────────────────────────────────────────────────────────────────────
def audit(driver) -> Dict[int, Dict]:
    """Compare Neo4j contents against the summaries in extracted_data/.

    Prints a per-amendment report and returns
    { amendment_number: { "in_db": bool, "rel_count": int,
                          "summary_path": str|None, "status": str } }.
    """
    print("\n" + "="*70)
    print("AUDIT: Scanning all amendments...")
    print("="*70)
    # 1. Relationship counts for every Amendment node currently in Neo4j.
    with driver.session() as session:
        rows = session.run("""
            MATCH (am:Amendment)
            OPTIONAL MATCH (am)-[r]->()
            RETURN am.number AS num, count(r) AS rel_count
            ORDER BY am.number
        """).data()
    db_map: Dict[int, int] = {}
    for row in rows:
        if row["num"] is not None:
            db_map[row["num"]] = row["rel_count"]
    # 2. Summary files available on disk.
    disk_map: Dict[int, str] = {}
    pattern = os.path.join(EXTRACTED_DIR, "amendment_*", "summary.json")
    for path in sorted(glob.glob(pattern)):
        try:
            with open(path) as fh:
                payload = json.load(fh)
            num = payload.get("metadata", {}).get("amendment_number")
            if num is not None:
                disk_map[int(num)] = path
        except Exception:
            # Unreadable/malformed summaries are simply absent from the report.
            pass
    # 3. Merge both views into one status per amendment.
    report: Dict[int, Dict] = {}
    missing_from_db: List[int] = []
    zero_rels: List[int] = []
    for num in sorted(set(db_map) | set(disk_map)):
        present = num in db_map
        rel_count = db_map.get(num, 0)
        if not present:
            status = "MISSING_FROM_DB"
            missing_from_db.append(num)
        elif rel_count < MIN_RELATIONSHIPS_PER_AMENDMENT:
            status = "ZERO_RELATIONSHIPS"
            zero_rels.append(num)
        else:
            status = "OK"
        report[num] = {
            "in_db": present,
            "rel_count": rel_count,
            "summary_path": disk_map.get(num),
            "status": status,
        }
    # 4. Human-readable table + totals.
    print(f"\n{'Amendment':<12} {'In DB':<8} {'Rels':<8} {'Status'}")
    print("-" * 50)
    for num, info in sorted(report.items()):
        flag = "β οΈ " if info["status"] != "OK" else " "
        print(f"{flag} Am {num:<8} {'yes' if info['in_db'] else 'NO':<8} {info['rel_count']:<8} {info['status']}")
    print(f"\nπ Summary:")
    print(f" Amendments on disk : {len(disk_map)}")
    print(f" Amendments in DB : {len(db_map)}")
    print(f" Missing from DB : {len(missing_from_db)} β {missing_from_db}")
    print(f" Zero relationships : {len(zero_rels)} β {zero_rels}")
    print("="*70 + "\n")
    return report
# ── Fix ────────────────────────────────────────────────────────────────────────
def fix_amendments(driver, report: Dict[int, Dict], target_nums: Optional[List[int]] = None):
    """
    Re-ingests (ONLY) the amendments that are missing or have zero relationships.
    If target_nums is specified, only fixes those.

    Args:
        driver: Neo4j driver used for all graph reads/writes.
        report: Audit map as returned by audit() — amendment number ->
            {"in_db", "rel_count", "summary_path", "status"}.
        target_nums: Explicit amendment numbers to fix; when given, overrides
            the report-based selection (even for amendments marked OK).
    """
    extractor = LLMExtractor()
    # Decide which amendments need fixing
    if target_nums:
        to_fix = target_nums
    else:
        to_fix = [
            num for num, info in report.items()
            if info["status"] in ("MISSING_FROM_DB", "ZERO_RELATIONSHIPS")
        ]
    if not to_fix:
        print("β Nothing to fix β all amendments already have relationships.")
        return
    print(f"\n{'='*70}")
    print(f"FIX: Re-ingesting {len(to_fix)} amendments: {to_fix}")
    print(f"{'='*70}\n")
    total_written = 0
    for am_num in sorted(to_fix):
        info = report.get(am_num, {})
        summary_path = info.get("summary_path")
        # The report may lack a path (amendment only in DB, or explicitly
        # requested via --amendments); fall back to the conventional location.
        if not summary_path or not os.path.exists(summary_path):
            # Try to find it
            candidates = glob.glob(
                os.path.join(EXTRACTED_DIR, f"amendment_{am_num:03d}", "summary.json")
            )
            if not candidates:
                print(f"  β Am {am_num}: summary.json NOT FOUND on disk β skipping")
                continue
            summary_path = candidates[0]
        print(f"\n  βοΈ Processing Amendment {am_num} ({summary_path})")
        try:
            with open(summary_path) as fh:
                data = json.load(fh)
        except Exception as e:
            print(f"  β Am {am_num}: cannot read summary.json β {e}")
            continue
        title = data.get("metadata", {}).get("amendment_title") or f"Amendment {am_num}"
        content = data.get("content", "")
        if not content.strip():
            print(f"  β οΈ Am {am_num}: empty summary content β skipping")
            continue
        # Ensure the Amendment node exists
        with driver.session() as session:
            session.run(
                "MERGE (am:Amendment {number: $num}) SET am.title = $title",
                {"num": am_num, "title": title},
            )
        # Check existing relationship count (before fix)
        with driver.session() as session:
            existing = session.run(
                "MATCH (am:Amendment {number: $num})-[r]->() RETURN count(r) AS c",
                {"num": am_num},
            ).single()["c"]
        print(f"  Current relationships in DB: {existing}")
        # LLM extraction
        print(f"  Extracting relationships via LLM...")
        time.sleep(2.0)  # Rate-limit buffer
        relationships = extractor.extract_relationships(title, content)
        print(f"  β Extracted {len(relationships)} total relationships")
        if not relationships:
            # Nothing extracted: leave the node in place; a later re-run can retry.
            print(f"  β οΈ Am {am_num}: LLM returned 0 relationships")
            continue
        # Write to Neo4j (MERGE inside, so re-running never duplicates edges)
        with driver.session() as session:
            written = write_relationships(session, am_num, relationships)
        total_written += written
        print(f"  β Am {am_num}: wrote {written} relationships (new+updated)")
        # Small pause between amendments to respect Groq rate limits
        time.sleep(1.5)
    print(f"\n{'='*70}")
    print(f"FIX COMPLETE β Total relationships written: {total_written}")
    print(f"{'='*70}\n")
# ── Entry point ────────────────────────────────────────────────────────────────
def _parse_cli():
    """Build and evaluate the command-line interface for this script."""
    parser = argparse.ArgumentParser(description="Audit and fix Neo4j graph ingestion gaps.")
    parser.add_argument(
        "--mode",
        choices=["audit", "fix"],
        default="audit",
        help="'audit' = report only. 'fix' = audit + write missing data.",
    )
    parser.add_argument(
        "--amendments",
        type=str,
        default=None,
        help="Comma-separated amendment numbers to fix (e.g. 7,13,38,52,73). "
        "If omitted, fixes ALL amendments with zero relationships.",
    )
    return parser.parse_args()


def main():
    """Always run the audit; additionally repair gaps when --mode fix is given."""
    args = _parse_cli()
    selected = None
    if args.amendments:
        selected = [int(piece.strip()) for piece in args.amendments.split(",")]
    driver = get_neo4j_driver()
    try:
        findings = audit(driver)
        if args.mode != "fix":
            print("Mode=audit. Run with --mode fix to repair missing data.\n")
        else:
            fix_amendments(driver, findings, target_nums=selected)
    finally:
        # Always release the driver, even when audit/fix raises.
        driver.close()


if __name__ == "__main__":
    main()