# ConstitutionAgent / data_tools / audit_and_fix_graph.py
"""
audit_and_fix_graph.py
======================
Audits every amendment (0-106) in extracted_data/ against what is actually
stored in Neo4j, then fills in ONLY the missing relationships β€” no wipes, no
duplicates, fully idempotent.
Usage:
# Audit only β€” print a report but change nothing
PYTHONPATH=. .venv/bin/python3 scripts/audit_and_fix_graph.py --mode audit
# Audit + fix ONLY specific amendments
PYTHONPATH=. .venv/bin/python3 scripts/audit_and_fix_graph.py --mode fix --amendments 7,13,52,73
# Audit + fix ALL amendments that have zero relationships
PYTHONPATH=. .venv/bin/python3 scripts/audit_and_fix_graph.py --mode fix
# Fix a single amendment by number
PYTHONPATH=. .venv/bin/python3 scripts/audit_and_fix_graph.py --mode fix --amendments 38
"""
import os
import sys
import json
import glob
import re
import time
import argparse
from typing import List, Dict, Any, Set, Tuple, Optional
# ── Project imports ────────────────────────────────────────────────────────────
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from services.neo4j import get_neo4j_driver
from utils.config import settings
from groq import Groq
# ── Config ─────────────────────────────────────────────────────────────────────
# Root of the per-amendment extraction output: extracted_data/amendment_XXX/summary.json,
# resolved relative to the project directory one level above this script.
EXTRACTED_DIR = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
    "extracted_data"
)
# Minimum relationships we expect per amendment before flagging as "MISSING"
MIN_RELATIONSHIPS_PER_AMENDMENT = 1
# ── Groq LLM wrapper (reuses the same extraction logic as graph_ingestion.py) ──
class LLMExtractor:
    """Groq-backed relationship extractor for amendment summaries.

    Mirrors the extraction logic of graph_ingestion.py: four focused LLM
    passes (modifications / deletions / insertions / cross_effects) whose
    results are merged and deduplicated.
    """

    def __init__(self):
        # NOTE(review): assumes settings.GROQ_API_KEY is configured — confirm.
        self.client = Groq(api_key=settings.GROQ_API_KEY)
        self.model = settings.GROQ_MODEL or "llama-3.3-70b-versatile"

    def _call(self, prompt: str) -> str:
        """Call Groq with JSON mode, fall back to text mode.

        Returns the raw model response text; callers json.loads() it.
        """
        # Groq requires the word 'json' in messages when using json_object response_format
        system_msg = "You are a constitutional law expert. Output only valid json."
        try:
            resp = self.client.chat.completions.create(
                messages=[
                    {"role": "system", "content": system_msg},
                    {"role": "user", "content": prompt},
                ],
                model=self.model,
                temperature=0,
                response_format={"type": "json_object"},
            )
            return resp.choices[0].message.content
        except Exception as e:
            # JSON mode rejected (HTTP 400 / json_validate failure): retry in
            # plain-text mode and strip an optional ```json fence from the reply.
            if "400" in str(e) or "json_validate" in str(e):
                resp = self.client.chat.completions.create(
                    messages=[
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": prompt + "\n\nRETURN ONLY VALID JSON. NO MARKDOWN."},
                    ],
                    model=self.model,
                    temperature=0,
                )
                content = resp.choices[0].message.content
                m = re.search(r"```(?:json)?(.*?)```", content, re.DOTALL)
                return m.group(1) if m else content
            raise

    def extract_relationships(self, title: str, summary: str) -> List[dict]:
        """
        Four focused extractions (modifications, deletions, insertions, cross_effects).
        Returns deduplicated relationship list.
        """
        all_rels: List[dict] = []
        focus_map = {
            "modifications": (
                "MODIFIES or SUBSTITUTES_ARTICLE",
                'Which articles/clauses were MODIFIED or SUBSTITUTED? Use relation "MODIFIES" or "SUBSTITUTES_ARTICLE".'
            ),
            "deletions": (
                "DELETES_CLAUSE or REPEALS",
                'Which articles/clauses were DELETED, REMOVED, or REPEALED? Use "DELETES_CLAUSE" or "REPEALS". '
                "CRITICAL: Include exact deleted terms (e.g. 'internal disturbance') in details."
            ),
            "insertions": (
                "INSERTS_CLAUSE or RESTORES_CLAUSE",
                'Which articles/clauses were INSERTED or RESTORED? Use "INSERTS_CLAUSE" or "RESTORES_CLAUSE".'
            ),
            "cross_effects": (
                "SUSPENDS, OVERRIDES, RESTRICTS, REFERS_TO",
                'Does this create cross-article relationships? e.g. "Article 358 SUSPENDS Article 19". '
                "Source = the acting article. Target = the article being affected."
            ),
        }
        # Split summary if too large (>7000 chars)
        chunks = self._split(summary)
        for chunk_type, (rel_types, focus_text) in focus_map.items():
            for chunk_idx, chunk in enumerate(chunks):
                part_label = f" (part {chunk_idx+1}/{len(chunks)})" if len(chunks) > 1 else ""
                prompt = f"""Extract ONLY {chunk_type.upper()}{part_label} from this constitutional amendment.
Return json with a 'relationships' array.
AMENDMENT: {title}
SUMMARY: {chunk}
TASK: {focus_text}
RULES:
- source: "CURRENT_AMENDMENT" for changes BY this amendment, or "Article X" for cross-article links
- target: "Article NUMBER" or "Article NUMBER(clause)" or "Schedule N"
- details: specific text including key deleted/inserted terms
Return ONLY valid json:
{{
"relationships": [
{{"source": "CURRENT_AMENDMENT", "relation": "{rel_types.split(',')[0].strip()}", "target": "Article X", "details": "..."}}
]
}}
If none found, return {{"relationships": []}}
"""
                # Up to 3 attempts per prompt; the for/else prints a final
                # failure notice only when every attempt raised.
                for attempt in range(3):
                    try:
                        raw = self._call(prompt)
                        data = json.loads(raw.strip())
                        rels = data.get("relationships", [])
                        all_rels.extend(rels)
                        print(f" [{chunk_type}{part_label}]: {len(rels)} relationships")
                        time.sleep(1.0)  # respect rate limits
                        break
                    except Exception as e:
                        print(f" ⚠️ [{chunk_type}{part_label}] attempt {attempt+1} error: {str(e)[:80]}")
                        wait = (attempt + 1) * 5
                        if "429" in str(e):
                            wait = 30  # rate-limited: back off longer
                        time.sleep(wait)
                else:
                    print(f" ❌ [{chunk_type}{part_label}] all attempts failed β€” skipping")
        # Deduplicate by (source, target, relation).
        # BUGFIX: the previous key was (target, relation) only, which silently
        # dropped cross-article links (source = "Article X") whenever an
        # amendment-level link shared the same target and relation type.
        seen: Set[str] = set()
        unique: List[dict] = []
        for r in all_rels:
            key = f"{r.get('source','')}|{r.get('target','')}|{r.get('relation','')}"
            if key not in seen:
                seen.add(key)
                unique.append(r)
        return unique

    @staticmethod
    def _split(text: str, threshold: int = 7000, overlap: int = 500) -> List[str]:
        """Split `text` in two overlapping halves when it exceeds `threshold`.

        Nudges the split point to the nearest '.' within 200 chars of the
        midpoint; both halves carry `overlap` extra characters across the cut
        so no sentence is lost at the boundary.
        """
        if len(text) <= threshold:
            return [text]
        mid = len(text) // 2
        for off in range(200):
            if mid + off < len(text) and text[mid + off] == ".":
                mid = mid + off + 1
                break
            if mid - off > 0 and text[mid - off] == ".":
                mid = mid - off + 1
                break
        return [text[: mid + overlap], text[mid - overlap :]]
# ── Neo4j entity resolver (mirrors graph_ingestion.py) ────────────────────────
def resolve_entity(tx, entity_str: str, current_am_num: int) -> Tuple[Optional[str], Optional[dict]]:
    """Resolve a free-text entity reference to a (label, key-props) pair,
    MERGE-ing the matching node in Neo4j along the way.

    Recognised forms, checked in priority order:
      * "CURRENT_AMENDMENT"               -> Amendment {number: current_am_num}
      * "Amendment 42"                    -> Amendment {number: 42}
      * "Article 19(1)(f)" / "19(1)"      -> Clause, linked PART_OF its Article
      * "First Schedule" / "Schedule 7" / "7th Schedule" -> Schedule
      * anything else non-empty           -> plain Article

    Args:
        tx: Neo4j transaction/session (anything exposing .run()).
        entity_str: raw entity text from the LLM.
        current_am_num: amendment number substituted for "CURRENT_AMENDMENT".

    Returns:
        (node_label, identifying_properties), or (None, None) if unresolvable.
    """
    if not entity_str:
        return None, None
    entity_str = entity_str.strip()
    if entity_str == "CURRENT_AMENDMENT":
        tx.run("MERGE (am:Amendment {number: $num})", {"num": current_am_num})
        return "Amendment", {"number": current_am_num}
    am_match = re.search(r"Amendment\s+(\d+)", entity_str, re.IGNORECASE)
    if am_match:
        num = int(am_match.group(1))
        tx.run("MERGE (am:Amendment {number: $num})", {"num": num})
        return "Amendment", {"number": num}
    # Clause: Article 19(1)(f) — digits (+optional letter suffix) followed by parens
    clause_match = re.search(r"(Article\s+)?(\d+[A-Z]*)\s*(\(.+\))", entity_str, re.IGNORECASE)
    if clause_match:
        art_num = clause_match.group(2)
        clause_suffix = clause_match.group(3)
        clause_id = f"{art_num}{clause_suffix}"
        tx.run("MERGE (a:Article {number: $anum})", {"anum": art_num})
        tx.run(
            """
            MATCH (a:Article {number: $anum})
            MERGE (c:Clause {id: $cid})
            MERGE (c)-[:PART_OF]->(a)
            """,
            {"anum": art_num, "cid": clause_id},
        )
        return "Clause", {"id": clause_id}
    # Schedule: "First Schedule", "Schedule 7", "7th Schedule"
    sched_match = re.search(
        r"(?:Schedule\s*(\w+)|(\w+)\s*Schedule)",
        entity_str,
        re.IGNORECASE,
    )
    if sched_match:
        raw_id = sched_match.group(1) or sched_match.group(2)
        # Normalise ordinals -> digits.
        # BUGFIX: cover all twelve schedules (the old table stopped at
        # "tenth"), and strip numeric-ordinal suffixes so "7th Schedule"
        # and "Schedule 7" merge into the same node.
        ordinal_map = {
            "first": "1", "second": "2", "third": "3", "fourth": "4", "fifth": "5",
            "sixth": "6", "seventh": "7", "eighth": "8", "ninth": "9", "tenth": "10",
            "eleventh": "11", "twelfth": "12",
        }
        digit_match = re.fullmatch(r"(\d+)(?:st|nd|rd|th)?", raw_id, re.IGNORECASE)
        if digit_match:
            sid = digit_match.group(1)
        else:
            sid = ordinal_map.get(raw_id.lower(), raw_id)
        tx.run("MERGE (s:Schedule {id: $sid})", {"sid": sid})
        return "Schedule", {"id": sid}
    # Plain Article — strip the word "Article" case-insensitively
    # (BUGFIX: the old str.replace missed mixed case such as "ARTICLE 14").
    clean = re.sub(r"article", "", entity_str, flags=re.IGNORECASE).strip()
    if clean:
        tx.run("MERGE (a:Article {number: $anum})", {"anum": clean})
        return "Article", {"number": clean}
    return None, None
def write_relationships(session, am_num: int, relationships: List[dict]) -> int:
    """Write relationships to Neo4j. Uses MERGE so it's safe to re-run.

    Args:
        session: open Neo4j session (anything exposing .run()).
        am_num: amendment number the relationships belong to.
        relationships: LLM-extracted dicts with source/relation/target/details.

    Returns:
        Number of relationships successfully MERGEd.
    """
    written = 0
    for rel in relationships:
        # Robustness: the LLM occasionally emits non-dict entries (e.g. bare
        # strings inside the 'relationships' array) — skip them quietly.
        if not isinstance(rel, dict):
            continue
        # `or` guards against explicit null values as well as missing keys.
        s_raw = rel.get("source") or "CURRENT_AMENDMENT"
        t_raw = rel.get("target")
        r_type = (rel.get("relation") or "MODIFIES").upper().replace(" ", "_")
        details = rel.get("details", "")
        # r_type is interpolated into Cypher — only allow A-Z and underscore.
        if not re.match(r"^[A-Z_]+$", r_type):
            r_type = "MODIFIES"
        if not t_raw:
            continue
        sType, sProps = resolve_entity(session, s_raw, am_num)
        tType, tProps = resolve_entity(session, t_raw, am_num)
        if not (sType and tType):
            continue
        # Amendment/Article nodes key on 'number'; Clause/Schedule on 'id'.
        s_ident = "number" if "number" in sProps else "id"
        t_ident = "number" if "number" in tProps else "id"
        query = f"""
        MATCH (s:{sType} {{{s_ident}: $sVal}})
        MATCH (t:{tType} {{{t_ident}: $tVal}})
        MERGE (s)-[r:{r_type}]->(t)
        SET r.details = $details, r.source = 'audit_fix'
        """
        try:
            session.run(query, {
                "sVal": sProps[s_ident],
                "tVal": tProps[t_ident],
                "details": details,
            })
            written += 1
        except Exception as e:
            # Best-effort: report and continue with the remaining relationships.
            print(f" ⚠️ Write error: {e}")
    return written
# ── Audit ──────────────────────────────────────────────────────────────────────
def audit(driver) -> Dict[int, Dict]:
    """
    Returns a dict: { amendment_number: { "in_db": bool, "rel_count": int, "summary_path": str|None } }
    """
    banner = "=" * 70
    print("\n" + banner)
    print("AUDIT: Scanning all amendments...")
    print(banner)

    # Relationship counts currently stored in Neo4j.
    with driver.session() as session:
        rows = session.run("""
            MATCH (am:Amendment)
            OPTIONAL MATCH (am)-[r]->()
            RETURN am.number AS num, count(r) AS rel_count
            ORDER BY am.number
        """).data()
    db_map: Dict[int, int] = {
        row["num"]: row["rel_count"] for row in rows if row["num"] is not None
    }

    # Summary files available on disk, keyed by amendment number.
    disk_map: Dict[int, str] = {}
    pattern = os.path.join(EXTRACTED_DIR, "amendment_*", "summary.json")
    for path in sorted(glob.glob(pattern)):
        try:
            with open(path) as fh:
                meta_num = json.load(fh).get("metadata", {}).get("amendment_number")
            if meta_num is not None:
                disk_map[int(meta_num)] = path
        except Exception:
            continue  # unreadable / malformed summary: skip silently

    # Merge both views into a per-amendment status report.
    report: Dict[int, Dict] = {}
    missing_from_db: List[int] = []
    zero_rels: List[int] = []
    for num in sorted(set(db_map) | set(disk_map)):
        rel_count = db_map.get(num, 0)
        present = num in db_map
        if not present:
            status = "MISSING_FROM_DB"
            missing_from_db.append(num)
        elif rel_count < MIN_RELATIONSHIPS_PER_AMENDMENT:
            status = "ZERO_RELATIONSHIPS"
            zero_rels.append(num)
        else:
            status = "OK"
        report[num] = {
            "in_db": present,
            "rel_count": rel_count,
            "summary_path": disk_map.get(num),
            "status": status,
        }

    # Human-readable table followed by totals.
    print(f"\n{'Amendment':<12} {'In DB':<8} {'Rels':<8} {'Status'}")
    print("-" * 50)
    for num, info in sorted(report.items()):
        flag = "⚠️ " if info["status"] != "OK" else " "
        print(f"{flag} Am {num:<8} {'yes' if info['in_db'] else 'NO':<8} {info['rel_count']:<8} {info['status']}")
    print(f"\nπŸ“Š Summary:")
    print(f" Amendments on disk : {len(disk_map)}")
    print(f" Amendments in DB : {len(db_map)}")
    print(f" Missing from DB : {len(missing_from_db)} β†’ {missing_from_db}")
    print(f" Zero relationships : {len(zero_rels)} β†’ {zero_rels}")
    print(banner + "\n")
    return report
# ── Fix ────────────────────────────────────────────────────────────────────────
def fix_amendments(driver, report: Dict[int, Dict], target_nums: Optional[List[int]] = None):
    """
    Re-ingests (ONLY) the amendments that are missing or have zero relationships.
    If target_nums is specified, only fixes those.

    Args:
        driver: Neo4j driver used to open write sessions.
        report: audit() output mapping amendment number -> status info.
        target_nums: explicit amendment numbers to fix; when None, every
            amendment flagged MISSING_FROM_DB or ZERO_RELATIONSHIPS is fixed.
    """
    extractor = LLMExtractor()
    # Decide which amendments need fixing
    if target_nums:
        to_fix = target_nums
    else:
        to_fix = [
            num for num, info in report.items()
            if info["status"] in ("MISSING_FROM_DB", "ZERO_RELATIONSHIPS")
        ]
    if not to_fix:
        print("βœ… Nothing to fix β€” all amendments already have relationships.")
        return
    print(f"\n{'='*70}")
    print(f"FIX: Re-ingesting {len(to_fix)} amendments: {to_fix}")
    print(f"{'='*70}\n")
    total_written = 0
    for am_num in sorted(to_fix):
        info = report.get(am_num, {})
        summary_path = info.get("summary_path")
        if not summary_path or not os.path.exists(summary_path):
            # Try to find it — explicitly-targeted amendments may be absent
            # from the audit report, so fall back to a direct disk lookup.
            candidates = glob.glob(
                os.path.join(EXTRACTED_DIR, f"amendment_{am_num:03d}", "summary.json")
            )
            if not candidates:
                print(f" ❌ Am {am_num}: summary.json NOT FOUND on disk β€” skipping")
                continue
            summary_path = candidates[0]
        print(f"\n βš™οΈ Processing Amendment {am_num} ({summary_path})")
        try:
            with open(summary_path) as fh:
                data = json.load(fh)
        except Exception as e:
            print(f" ❌ Am {am_num}: cannot read summary.json β€” {e}")
            continue
        title = data.get("metadata", {}).get("amendment_title") or f"Amendment {am_num}"
        content = data.get("content", "")
        if not content.strip():
            print(f" ⚠️ Am {am_num}: empty summary content β€” skipping")
            continue
        # Ensure the Amendment node exists (idempotent MERGE, refreshes title)
        with driver.session() as session:
            session.run(
                "MERGE (am:Amendment {number: $num}) SET am.title = $title",
                {"num": am_num, "title": title},
            )
        # Check existing relationship count (before fix)
        with driver.session() as session:
            existing = session.run(
                "MATCH (am:Amendment {number: $num})-[r]->() RETURN count(r) AS c",
                {"num": am_num},
            ).single()["c"]
        print(f" Current relationships in DB: {existing}")
        # LLM extraction
        print(f" Extracting relationships via LLM...")
        time.sleep(2.0)  # Rate-limit buffer
        relationships = extractor.extract_relationships(title, content)
        print(f" β†’ Extracted {len(relationships)} total relationships")
        if not relationships:
            print(f" ⚠️ Am {am_num}: LLM returned 0 relationships")
            continue
        # Write to Neo4j
        with driver.session() as session:
            written = write_relationships(session, am_num, relationships)
        total_written += written
        print(f" βœ… Am {am_num}: wrote {written} relationships (new+updated)")
        # Small pause between amendments to respect Groq rate limits
        time.sleep(1.5)
    print(f"\n{'='*70}")
    print(f"FIX COMPLETE β€” Total relationships written: {total_written}")
    print(f"{'='*70}\n")
# ── Entry point ────────────────────────────────────────────────────────────────
def main():
    """CLI entry point: always audit; optionally repair gaps with --mode fix."""
    cli = argparse.ArgumentParser(description="Audit and fix Neo4j graph ingestion gaps.")
    cli.add_argument(
        "--mode",
        choices=["audit", "fix"],
        default="audit",
        help="'audit' = report only. 'fix' = audit + write missing data.",
    )
    cli.add_argument(
        "--amendments",
        type=str,
        default=None,
        help="Comma-separated amendment numbers to fix (e.g. 7,13,38,52,73). "
        "If omitted, fixes ALL amendments with zero relationships.",
    )
    opts = cli.parse_args()

    # Parse the optional comma-separated list into ints (None = fix everything flagged).
    wanted = None
    if opts.amendments:
        wanted = [int(token.strip()) for token in opts.amendments.split(",")]

    driver = get_neo4j_driver()
    try:
        findings = audit(driver)
        if opts.mode == "fix":
            fix_amendments(driver, findings, target_nums=wanted)
        else:
            print("Mode=audit. Run with --mode fix to repair missing data.\n")
    finally:
        # Always release the driver, even if audit/fix raised.
        driver.close()


if __name__ == "__main__":
    main()