# ConstitutionAgent / core / graph_ingestion.py
# Source: Hugging Face upload "Upload 53 files" (commit 0cd3dc5, verified)
import os
import json
import glob
import re
import time
from typing import List, Dict, Any, Set
from services.neo4j import get_neo4j_driver
from utils.config import settings
from groq import Groq
# ============================================================================
# RESUME CONFIGURATION
# ============================================================================
# RESUME_FROM: amendment number to resume from (e.g., 68 to skip 0-67);
# 0 or None starts from the beginning and allows a full database clear.
RESUME_FROM = 0 # Change this to resume from a specific amendment
# REINGEST: when True (and no resume point is set) the whole Neo4j database
# is wiped in __init__ before ingestion; set to False to skip the clear.
REINGEST = True # Set to False to skip database clear
# AUTO_DETECT_RESUME: when True, get_resume_point() queries Neo4j for
# already-ingested amendments and resumes just past the highest one found.
AUTO_DETECT_RESUME = False
# ============================================================================
class GraphIngestionService:
    def __init__(self):
        """Connect to Neo4j and (optionally) configure a Groq client for
        LLM-based semantic extraction.

        API-key resolution order: settings.GROQ_API_KEY first, then the
        GROQ_API_KEY environment variable. If neither is set,
        self.groq_client is None and semantic extraction is disabled.
        When REINGEST is on and no resume point is configured, the entire
        database is cleared immediately.
        """
        print("Initializing Graph Ingestion Service...")
        self.driver = get_neo4j_driver()
        # Initialize Groq for semantic extraction
        if settings.GROQ_API_KEY:
            self.groq_client = Groq(api_key=settings.GROQ_API_KEY)
            self.model = settings.GROQ_MODEL or "llama-3.3-70b-versatile"
            print(f"Groq Client initialized with model: {self.model}")
        else:
            # Fallback to the process environment. NOTE: `os` is already
            # imported at module level; the old redundant local import
            # (`import os` inside this branch) has been removed.
            env_key = os.getenv("GROQ_API_KEY")
            if env_key:
                self.groq_client = Groq(api_key=env_key)
                self.model = os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
                print(f"Groq Client initialized from os.environ with model: {self.model}")
            else:
                print("Warning: GROQ_API_KEY not set in settings or environment. Semantic extraction disabled.")
                self.groq_client = None
                # Define the attribute so later reads of self.model can
                # never raise AttributeError (it was previously unset here).
                self.model = None
        if REINGEST and (RESUME_FROM == 0 or RESUME_FROM is None):
            self.clear_database()
    def close(self):
        """Close the underlying Neo4j driver connection."""
        self.driver.close()
def setup_schema(self):
print("Setting up Neo4j Schema (Constraints)...")
with self.driver.session() as session:
# Nodes
session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (a:Article) REQUIRE a.number IS UNIQUE")
session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (am:Amendment) REQUIRE am.number IS UNIQUE")
session.run("CREATE INDEX IF NOT EXISTS FOR (a:Article) ON (a.part)")
session.run("CREATE INDEX IF NOT EXISTS FOR (c:Clause) ON (c.id)")
session.run("CREATE INDEX IF NOT EXISTS FOR (s:Schedule) ON (s.id)")
    def clear_database(self):
        """Wipe the entire graph: detach-delete every node in Neo4j."""
        print("Clearing Neo4j Database...")
        with self.driver.session() as session:
            session.run("MATCH (n) DETACH DELETE n")
def get_processed_amendments(self) -> Set[int]:
"""Get set of amendment numbers already in Neo4j"""
with self.driver.session() as session:
result = session.run("MATCH (am:Amendment) RETURN am.number as num")
return {record["num"] for record in result if record["num"] is not None}
def get_resume_point(self) -> int:
"""Determine which amendment to start from"""
if RESUME_FROM is not None and RESUME_FROM > 0:
print(f"📍 Manual resume from Amendment {RESUME_FROM}")
return RESUME_FROM
if AUTO_DETECT_RESUME:
processed = self.get_processed_amendments()
if processed:
max_processed = max(processed)
resume_from = max_processed + 1
print(f"📍 Auto-detected {len(processed)} amendments already processed")
print(f"📍 Resuming from Amendment {resume_from}")
return resume_from
print("📍 Starting from Amendment 0")
return 0
def extract_semantic_changes(self, title: str, summary: str) -> List[dict]:
"""Runs a single LLM extraction request with retries."""
if not self.groq_client: return []
prompt = f"""
You are a Highly Specialized Constitutional Law Scholar and Graph Architect.
Analyze this Amendment Summary to extract a precise Knowledge Graph of entities and relationships.
AMENDMENT CONTEXT:
- Number: {title}
- Identifier for this amendment: "CURRENT_AMENDMENT" (MANDATORY: Use this as the source for all changes made by this amendment).
Summary: {summary}
### RELATIONSHIP SCHEMA (STRICT):
Use ONLY these relationship types:
1. **AMENDS**: Standard modification to an Article or Clause.
2. **REPEALS**: Entire Article removed from the flag.
3. **OMITS**: Specific Clause removed from an Article.
4. **SUBSTITUTES**: Entire Article or Clause replaced by new text.
5. **INSERTS**: New Article or Clause added.
6. **REFERS_TO**: One Article cites or mentions another.
7. **OVERRIDES**: A "Notwithstanding" clause where one Article takes precedence over another (e.g., Art 31C OVERRIDES Art 14).
8. **SUSPENDS**: An Article pauses the operation of another (e.g., Art 358 SUSPENDS Art 19 during Emergency).
9. **DEFINES**: A clause or article that provides the definition for a term used in another article.
### EXTRACTION RULES:
- **Strict IDs**: NEVER invent or guess amendment numbers. All direct actions of the current amendment MUST have `source: "CURRENT_AMENDMENT"`.
- **Entity Clarity**:
- Articles: "Article 19", "Article 21A"
- Clauses: "Article 19(1)(f)", "Article 368(2)"
- **Contextual Links**: If Article X says it overrides Article Y, create `(Article X)-[:OVERRIDES]->(Article Y)`.
Format:
Return JSON Object:
{{
"relationships": [
{{ "source": "CURRENT_AMENDMENT", "relation": "AMENDS", "target": "Article 19", "details": "Modified to add reasonable restrictions" }},
{{ "source": "Article 358", "relation": "SUSPENDS", "target": "Article 19", "details": "Automatic suspension during proclaimed emergency" }}
]
}}
"""
# If no relevant semantic changes are found, return {"relationships": []}
return []
def extract_with_validation(self, title: str, summary: str, max_retries: int = 2) -> List[dict]:
"""Runs the four-pass extraction process for a single amendment."""
all_rels = []
passes = ["modifications", "deletions", "insertions", "cross_effects"]
for p in passes:
rels = self.extract_chunk(title, summary, p)
all_rels.extend(rels)
return all_rels
def split_text_smart(self, text: str, threshold: int = 7000, overlap: int = 500) -> List[str]:
"""
Smart text splitting:
- If text < threshold: return as single chunk
- If text > threshold: split into balanced halves with overlap
"""
if len(text) <= threshold:
return [text]
midpoint = len(text) // 2
for offset in range(0, 200):
if midpoint + offset < len(text) and text[midpoint + offset] == '.':
midpoint = midpoint + offset + 1
break
elif midpoint - offset > 0 and text[midpoint - offset] == '.':
midpoint = midpoint - offset + 1
break
chunk1 = text[:midpoint + overlap]
chunk2 = text[midpoint - overlap:]
return [chunk1, chunk2]
def extract_with_validation(self, title: str, summary: str, max_retries: int = 2) -> List[dict]:
"""
Extract semantic changes using 2 focused, mutually-exclusive passes.
"""
all_relationships = []
# Combined into 2 passes to prevent categorization jitter
passes = ["structural", "mechanics"]
for p in passes:
time.sleep(1.0) # Rate limit protection
rels = self.extract_chunk(title, summary, p)
all_relationships.extend(rels)
print(f" - Pass {p}: {len(rels)} relationships")
# Deduplicate
seen = set()
unique_rels = []
for rel in all_relationships:
# Normalize for deduplication
src = str(rel.get('source', '')).strip().upper()
tgt = str(rel.get('target', '')).strip().upper()
rtype = str(rel.get('relation', '')).strip().upper()
key = f"{src}_{tgt}_{rtype}"
if key not in seen:
seen.add(key)
unique_rels.append(rel)
return unique_rels
    def extract_chunk(self, title: str, summary: str, chunk_type: str) -> List[dict]:
        """
        Extract a specific type of relationships from the amendment.

        chunk_type selects the prompt focus: "structural" (AMENDS,
        SUBSTITUTES, INSERTS, REPEALS/OMITS) or "mechanics" (OVERRIDES,
        SUSPENDS, REFERS_TO, DEFINES); any other value falls back to a
        generic instruction. Long summaries are split into two overlapping
        chunks and each chunk is sent to the LLM separately. Returns the
        combined relationship dicts ([] without a Groq client); rate-limit
        and balance errors are re-raised for the caller to handle.
        """
        if not self.groq_client: return []
        text_chunks = self.split_text_smart(summary, threshold=7000, overlap=500)
        all_relationships = []
        for chunk_idx, summary_chunk in enumerate(text_chunks):
            # Build the pass-specific focus instruction for the prompt.
            if chunk_type == "structural":
                focus = """
Extract Structural Alterations only:
- AMENDS: General modification.
- SUBSTITUTES: Total replacement of text.
- INSERTS: New Article/Clause/Schedule added.
- REPEALS/OMITS: Removal of text/Article.
"""
            elif chunk_type == "mechanics":
                focus = """
Extract Legal Mechanics and Cross-References only:
- OVERRIDES: "Notwithstanding" clauses or precedence.
- SUSPENDS: Pausing an article (e.g., Emergency).
- REFERS_TO: Citing or mentioning another article.
- DEFINES: Providing a definition for a term used elsewhere.
"""
            else: focus = "Semantic extraction."
            prompt = f"""
Constitutional Scholar Task: Extract {chunk_type.upper()} from the summary.
AMENDMENT: {title}
SUMMARY PART: {summary_chunk}
FOCUS: {focus}
STRICT RULES:
1. Actions by the amendment MUST use source: "CURRENT_AMENDMENT".
2. If Article A affects Article B, source: "Article A", target: "Article B".
3. Article format: "19", "21A", "19(1)(f)".
4. Return JSON only.
JSON Format:
{{ "relationships": [ {{ "source": "...", "relation": "...", "target": "...", "details": "..." }} ] }}
"""
            # Retry loop for malformed generation or transient errors
            max_chunk_retries = 3
            for attempt in range(max_chunk_retries):
                try:
                    # Provide higher max_tokens to prevent truncation
                    chat_completion = self.groq_client.chat.completions.create(
                        messages=[{"role": "user", "content": prompt}],
                        model=self.model,
                        temperature=0,
                        max_tokens=4096,
                        response_format={"type": "json_object"}
                    )
                    content = chat_completion.choices[0].message.content.strip()
                    data = json.loads(content)
                    chunk_rels = data.get("relationships", [])
                    all_relationships.extend(chunk_rels)
                    break # Success, break retry loop
                except Exception as e:
                    err_str = str(e).lower()
                    # 429 should be handled by the caller to trigger key rotation exit
                    if "429" in err_str or "rate_limit" in err_str or "insufficient_balance" in err_str:
                        raise e
                    print(f" ⚠️ Chunk extraction attempt {attempt+1}/{max_chunk_retries} failed: {e}")
                    if attempt < max_chunk_retries - 1:
                        time.sleep(2 * (attempt + 1)) # Exponential-ish backoff
                    else:
                        print(f" ❌ Failed to extract chunk after {max_chunk_retries} attempts.")
        return all_relationships
def validate_extraction(self, title: str, summary: str, relationships: List[dict]) -> dict:
"""
Validate that the extracted relationships are complete.
Returns validation result with feedback for retry if needed.
"""
if not self.groq_client or not relationships:
return {"valid": len(relationships) > 0, "feedback": ""}
# Skip validation if content is too large (>15 relationships or >5000 char summary)
if len(relationships) > 15 or len(summary) > 5000:
print(f" ⚠️ Skipping validation (content too large: {len(relationships)} rels, {len(summary)} chars)")
return {"valid": True, "confidence": 0.9, "feedback": ""}
# Truncate summary to prevent token overflow
summary_truncated = summary[:3000] + "..." if len(summary) > 3000 else summary
# Show only first 10 relationships to validation
rels_sample = relationships[:10]
rels_summary = json.dumps(rels_sample, indent=2)
prompt = f"""You are a validation expert checking if constitutional amendment extraction is COMPLETE.
ORIGINAL AMENDMENT: {title}
ORIGINAL SUMMARY (first 3000 chars): {summary_truncated}
EXTRACTED RELATIONSHIPS (first 10):
{rels_summary}
Your task: Verify if the MOST IMPORTANT changes mentioned in the summary are captured in the extracted relationships.
Check for:
1. **Key terms coverage**: If summary mentions "internal disturbance", "armed rebellion", "property rights", etc., they MUST appear in relationship details
2. **Deleted content**: If summary says something was DELETED or REMOVED, there must be a DELETES_CLAUSE or REPEALS relationship
3. **Restored content**: If summary mentions RESTORED, there must be a RESTORES_CLAUSE relationship
4. **Cross-article effects**: If summary mentions "Article X suspends Article Y", there must be a SUSPENDS relationship
Return ONLY valid JSON:
{{
"valid": true|false,
"confidence": 0.0-1.0,
"missing_items": ["List specific items from summary that are missing from relationships"],
"feedback": "Brief feedback (max 100 words)"
}}
Be strict: If ANY significant term or relationship from the summary is missing, mark as invalid.
"""
try:
resp = self.groq_client.chat.completions.create(
messages=[{"role": "user", "content": prompt}],
model=self.model,
temperature=0,
response_format={"type": "json_object"}
)
result = json.loads(resp.choices[0].message.content)
return result
except Exception as e:
# If validation fails, assume valid to not block ingestion
print(f" ⚠️ Validation error (assuming valid): {e}")
return {"valid": True, "confidence": 0.85, "feedback": ""}
    def resolve_entity(self, tx, entity_str: str, current_am_num: int, current_year: int):
        """Resolves an entity string to a Neo4j Node and returns the node reference variable.

        MERGEs the matching node via *tx* (for clauses, also the parent
        Article plus a PART_OF link) and returns a (label, key_props)
        pair, e.g. ("Amendment", {"number": 42}) or ("Clause",
        {"id": "19(1)(f)"}). Returns (None, None) when unresolvable.
        Resolution order: CURRENT_AMENDMENT sentinel -> explicit
        "Amendment N" -> clause with bracketed sub-parts -> "Nth Schedule"
        -> bare Article number as a fallback.
        """
        if not entity_str: return None, None
        entity_str = entity_str.strip()
        # 1. Current Amendment
        if entity_str == "CURRENT_AMENDMENT":
            tx.run("MERGE (am:Amendment {number: $num}) SET am.year = $year", {"num": current_am_num, "year": current_year})
            return "Amendment", {"number": current_am_num}
        # 2. Specific Amendment
        am_match = re.search(r'Amendment\s+(\d+)', entity_str, re.IGNORECASE)
        if am_match:
            num = int(am_match.group(1))
            tx.run("MERGE (am:Amendment {number: $num})", {"num": num})
            return "Amendment", {"number": num}
        # 3. Clause (Article X(Y))
        # Regex for Article with brackets: Article 19(1)(f)
        clause_match = re.search(r'(Article\s+)?(\d+[A-Z]*)\s*(\(.+\))', entity_str, re.IGNORECASE)
        if clause_match:
            art_num = clause_match.group(2)
            if not art_num: return None, None
            clause_suffix = clause_match.group(3)
            clause_id = f"{art_num}{clause_suffix}" # e.g. 19(1)(f)
            # Ensure Article Parent
            tx.run("MERGE (a:Article {number: $anum}) ON CREATE SET a.year = $year", {"anum": art_num, "year": current_year})
            # Ensure Clause
            tx.run("""
                MATCH (a:Article {number: $anum})
                MERGE (c:Clause {id: $cid})
                ON CREATE SET c.year = $year
                MERGE (c)-[:PART_OF]->(a)
            """, {"anum": art_num, "cid": clause_id, "year": current_year})
            return "Clause", {"id": clause_id}
        # 4. Schedule (matches ordinals like "1st Schedule", "9th Schedule")
        sched_match = re.search(r'(\d+)\w*\s+Schedule', entity_str, re.IGNORECASE)
        if sched_match:
            sid = sched_match.group(1)
            tx.run("MERGE (s:Schedule {id: $sid}) ON CREATE SET s.year = $year", {"sid": sid, "year": current_year})
            return "Schedule", {"id": sid}
        # 5. Article (Fallback)
        # Clean "Article " prefix
        clean_art = entity_str.replace("Article", "").replace("article", "").strip()
        if clean_art:
            tx.run("MERGE (a:Article {number: $anum}) ON CREATE SET a.year = $year", {"anum": clean_art, "year": current_year})
            return "Article", {"number": clean_art}
        return None, None
    def ingest_all(self):
        """Ingest every extracted_data/**/*.json file into the Neo4j graph.

        Flow: set up schema -> determine resume point (clearing the DB only
        on a fresh run) -> for each JSON file, MERGE the base Amendment
        node and, for amendment summaries, run LLM extraction and write
        the resulting relationships. Per-file errors are logged and the
        loop continues. Prints a summary when done.
        """
        self.setup_schema()
        # Determine resume point (DON'T clear database if resuming)
        resume_from = self.get_resume_point()
        if resume_from == 0:
            self.clear_database()
        else:
            print(f"⚠️ Skipping database clear (resuming from Amendment {resume_from})")
        # extracted_data lives next to the package root (one level above core/)
        root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        extracted_dir = os.path.join(root_dir, "extracted_data")
        all_files = sorted(glob.glob(os.path.join(extracted_dir, "**", "*.json"), recursive=True))
        print(f"Found {len(all_files)} files for Graph Ingestion.")
        rel_count = 0
        amendments_processed = 0
        amendments_skipped = 0
        with self.driver.session() as session:
            for i, file_path in enumerate(all_files):
                if i % 10 == 0: print(f"Processing file {i}/{len(all_files)}...")
                try:
                    with open(file_path, 'r', encoding='utf-8') as f:
                        data = json.load(f)
                    # NOTE(review): metadata is assumed to carry "type",
                    # "amendment_number", "year" and "amendment_title" —
                    # verify against the extraction step that writes these files.
                    metadata = data.get("metadata", {})
                    content = data.get("content", "")
                    doc_type = metadata.get("type", "")
                    am_num = metadata.get("amendment_number")
                    if am_num is None: continue
                    current_am_num = int(am_num)
                    # Skip if before resume point
                    if current_am_num < resume_from:
                        if doc_type == "amendment_summary":
                            amendments_skipped += 1
                        continue
                    # Create Base Amendment Node
                    session.run("""
                        MERGE (am:Amendment {number: $am_num})
                        SET am.year = $year, am.title = $title
                    """, {
                        "am_num": current_am_num,
                        "year": metadata.get("year"),
                        "title": metadata.get("amendment_title")
                    })
                    # Semantic Extraction
                    if doc_type == "amendment_summary" and self.groq_client:
                        print(f"Extracting Semantics for Amendment {current_am_num}...")
                        amendments_processed += 1
                        # Add delay to avoid Rate Limits (PROACTIVE SLEEP)
                        time.sleep(2.0)
                        start_time = time.time()
                        # Use validated extraction with retry loop
                        relationships = self.extract_with_validation(
                            metadata.get("amendment_title"),
                            content,
                            max_retries=2
                        )
                        elapsed = time.time() - start_time
                        print(f" ⏱ Completed in {elapsed:.2f}s ({len(relationships)} relationships)")
                        for rel in relationships:
                            s_raw = rel.get("source", "CURRENT_AMENDMENT")
                            t_raw = rel.get("target")
                            r_type = rel.get("relation", "MODIFIES").upper().replace(" ", "_")
                            details = rel.get("details", "")
                            # Basic validation for Cypher safety: the relation
                            # type is interpolated into the query, so anything
                            # outside [A-Z_] is coerced to a safe default.
                            if not re.match(r"^[A-Z_]+$", r_type):
                                r_type = "AMENDS"
                            if not t_raw: continue
                            # Resolve Nodes
                            current_year = int(metadata.get("year", 1950))
                            sType, sProps = self.resolve_entity(session, s_raw, current_am_num, current_year)
                            tType, tProps = self.resolve_entity(session, t_raw, current_am_num, current_year)
                            if sType and tType:
                                # Build Query dynamically based on resolved types
                                s_ident = "number" if "number" in sProps else "id"
                                t_ident = "number" if "number" in tProps else "id"
                                query = f"""
                                    MATCH (s:{sType} {{{s_ident}: $sVal}})
                                    MATCH (t:{tType} {{{t_ident}: $tVal}})
                                    MERGE (s)-[r:{r_type}]->(t)
                                    SET r.details = $details, r.source = 'llm_extraction'
                                """
                                session.run(query, {
                                    "sVal": sProps[s_ident],
                                    "tVal": tProps[t_ident],
                                    "details": details
                                })
                                rel_count += 1
                except Exception as e:
                    # Best-effort: a single bad file must not abort the run.
                    print(f"Error processing {file_path}: {e}")
        print("\n" + "="*70)
        print("GRAPH INGESTION COMPLETE")
        print("="*70)
        if amendments_skipped > 0:
            print(f"Amendments skipped (already processed): {amendments_skipped}")
        print(f"Amendments processed this run: {amendments_processed}")
        print(f"Total relationships created: {rel_count}")
        print("="*70 + "\n")
# Script entry point: run a full ingestion pass against the local Neo4j.
if __name__ == "__main__":
    try:
        ingestion = GraphIngestionService()
        ingestion.ingest_all()
        ingestion.close()
    except Exception as exc:
        print(f"Failed to run graph ingestion. Is Neo4j running? {exc}")