# StruCTA / structa/abstraction.py
# Uploaded by YOUSSEF88 ("Upload structa/abstraction.py"), commit d4c2430 (verified).
"""
AbstractionLayer: Non-differentiable entity abstraction pipeline.
Replaces sensitive entities with typed abstract tokens and produces structural graphs.
This runs OUTSIDE the model — raw text never enters the transformer.
"""
import re
import hashlib
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Any
from collections import defaultdict
@dataclass
class AbstractDocument:
    """Result produced by the abstraction layer for one input text."""

    # Input text with detected entities replaced by typed placeholder tokens.
    abstract_text: str
    # Structural graph derived from ``abstract_text``.
    amr_graph: Dict[str, Any]
    # Key under which the placeholder -> original-entity mapping is stored.
    vault_id: str
    # Format version tag for this document.
    schema_version: str = "1.0"
def _hash_vault(mapping: Dict[str, Any]) -> str:
    """Derive a deterministic vault ID from a token -> entity mapping.

    The mapping's items are sorted by key, rendered with ``str`` and hashed
    with SHA-256; the first 24 hex characters are returned.

    NOTE(review): the digest is unsalted and covers the raw entity values, so
    a low-entropy secret (e.g. an SSN) might be recoverable by brute force
    from a published vault ID — consider a keyed hash if IDs leave the process.
    """
    ordered_items = sorted(mapping.items())
    payload = str(ordered_items).encode("utf-8")
    full_digest = hashlib.sha256(payload).hexdigest()
    return full_digest[0:24]
def _regex_based_ner(text: str) -> List[Dict[str, Any]]:
    """Detect sensitive spans in *text* with regular expressions.

    Fallback detector that needs no external NER model. Recognises emails,
    (North-American style) phone numbers, SSN-shaped digit groups, currency
    amounts, month-name dates and URLs (reported under the ``ID`` group).

    Returns:
        Entity dicts with keys ``entity_group``, ``word``, ``start`` and
        ``end`` (character offsets, end exclusive), sorted by ``start`` and
        guaranteed not to overlap. When two matches overlap, the one that
        starts first (longer one on a tie) keeps its label and is widened to
        cover the union of both spans, so every matched character is still
        reported. For example, a phone number inside a URL is reported only as
        part of the URL.
    """
    rules = (
        ("EMAIL", r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'),
        ("PHONE", r'(?:\+1[-.\s]?)?(?:\(?[0-9]{3}\)?[-.\s]?[0-9]{3}[-.\s]?[0-9]{4})'),
        ("SSN", r'\b\d{3}[-.\s]?\d{2}[-.\s]?\d{4}\b'),
        ("MONEY", r'(?:\$|USD|GBP|EUR|JPY|CNY)\s*(?:\d{1,3}(?:,\d{3})*|\d+)(?:\.\d{2})?'),
        ("DATE", (r'\b(?:January|February|March|April|May|June|July|August|'
                  r'September|October|November|December|Jan|Feb|Mar|Apr|Jun|'
                  r'Jul|Aug|Sep|Oct|Nov|Dec)\.?\s+\d{1,2},?\s+(?:\d{4}|\d{2})\b')),
        # URLs are treated as generic identifiers.
        ("ID", r'https?://[^\s]+|www\.[^\s]+'),
    )
    candidates = [
        {"entity_group": label, "word": match.group(),
         "start": match.start(), "end": match.end()}
        for label, pattern in rules
        for match in re.finditer(pattern, text)
    ]
    # Earliest start first; on equal starts the longer span first. The sort is
    # stable, so remaining ties keep rule order (EMAIL before PHONE, ...).
    candidates.sort(key=lambda e: (e["start"], e["start"] - e["end"]))

    entities: List[Dict[str, Any]] = []
    for ent in candidates:
        if entities and ent["start"] < entities[-1]["end"]:
            # Overlap: absorb into the kept span. Widen it when this match
            # reaches further, so no detected character is left unredacted.
            kept = entities[-1]
            if ent["end"] > kept["end"]:
                kept["end"] = ent["end"]
                kept["word"] = text[kept["start"]:kept["end"]]
            continue
        entities.append(ent)
    return entities
def _parse_amr_fallback(abstract_text: str) -> Dict[str, Any]:
    """Build a simplified token graph as a stand-in for a real AMR parse.

    Every whitespace-separated token becomes a node and consecutive tokens are
    chained with ``:next`` edges, so the graph is a simple linear path.
    ``root`` is always 0, even when the text is empty and there are no nodes.
    """
    words = abstract_text.split()
    node_list = [
        {
            "id": position,
            "concept": word,
            # A whole token of the form "<...>" counts as an entity placeholder.
            "is_entity": word.startswith("<") and word.endswith(">"),
            "type": _extract_entity_type(word) if word.startswith("<") else "WORD",
        }
        for position, word in enumerate(words)
    ]
    edge_list = [
        {"source": position - 1, "target": position, "relation": ":next"}
        for position in range(1, len(words))
    ]
    return {"nodes": node_list, "edges": edge_list, "root": 0}
def _extract_entity_type(token: str) -> str:
    """Map a placeholder token (e.g. ``"<EMAIL_1>"``) to a coarse entity type.

    Matching is by prefix, checked in order. ``GPE`` is folded into ``LOC``,
    and ``$AMOUNT`` into ``MONEY``. ``<PER_n>`` (the ``PER`` group emitted by
    CoNLL-style NER models such as the default ``dslim/bert-base-NER``) maps to
    ``PERSON``. Anything unrecognised is ``"MISC"``.
    """
    prefix_table = (
        (("<PERSON", "<PER_"), "PERSON"),
        (("<ORG",), "ORG"),
        (("<LOC", "<GPE"), "LOC"),
        (("<$AMOUNT", "<MONEY"), "MONEY"),
        (("<DATE",), "DATE"),
        (("<PHONE",), "PHONE"),
        (("<EMAIL",), "EMAIL"),
        (("<SSN",), "SSN"),
        (("<ID",), "ID"),
    )
    for prefixes, label in prefix_table:
        if token.startswith(prefixes):
            return label
    return "MISC"
def _replace_spans(text: str, entities: List[Dict[str, Any]]) -> str:
    """Replace each entity span in *text* with a numbered placeholder token.

    Entities are processed in order of ``start``. Each becomes
    ``<{entity_group}_{n}>``, where ``n`` counts occurrences of that group
    (1-based). Runs of whitespace in the result are collapsed to a single
    space and the result is stripped.

    Overlapping or nested spans never re-expose covered text. The output
    cursor only moves forward, so a span nested inside an earlier, wider span
    cannot cause the tail of the wider span to be copied back into the result.
    """
    pieces: List[str] = []
    cursor = 0
    type_counter = defaultdict(int)
    for ent in sorted(entities, key=lambda e: e["start"]):
        ent_type = ent["entity_group"]
        type_counter[ent_type] += 1
        # Empty slice when start < cursor, i.e. already-covered text.
        pieces.append(text[cursor:ent["start"]])
        pieces.append(f"<{ent_type}_{type_counter[ent_type]}>")
        cursor = max(cursor, ent["end"])
    pieces.append(text[cursor:])
    return re.sub(r'\s+', ' ', "".join(pieces)).strip()
class AbstractionLayer:
    """Privacy abstraction pipeline: raw text -> abstract document + entity vault.

    Sensitive spans are found by the regex detector (always) and, when
    enabled, additionally by a ``transformers`` token-classification pipeline.
    Each span is replaced by a typed placeholder such as ``<EMAIL_1>``. The
    placeholder -> original-entity mapping is kept in an in-memory vault
    addressed by the returned document's ``vault_id``.
    """

    # Placeholder templates per entity type.
    # NOTE(review): not referenced by the methods below; placeholders are
    # built as f"<{type}_{n}>" directly.
    ENTITY_SCHEMA = {
        "PERSON": "<PERSON_{id}>",
        "ORG": "<ORG_{id}>",
        "LOC": "<LOC_{id}>",
        "GPE": "<GPE_{id}>",
        "MONEY": "<MONEY_{id}>",
        "DATE": "<DATE_{id}>",
        "PHONE": "<PHONE_{id}>",
        "EMAIL": "<EMAIL_{id}>",
        "SSN": "<SSN_{id}>",
        "ID": "<ID_{id}>",
        "PRODUCT": "<PRODUCT_{id}>",
        "EVENT": "<EVENT_{id}>",
        "MISC": "<MISC_{id}>",
    }

    def __init__(self, use_ner_model: bool = False, ner_model_name: Optional[str] = None):
        """Configure entity detection.

        Args:
            use_ner_model: Also run a ``transformers`` NER pipeline on top of
                the regex detector. If importing or building the pipeline
                raises ``ImportError``, a ``RuntimeWarning`` is emitted and the
                layer continues with regex detection only.
            ner_model_name: Model id for the pipeline; defaults to
                ``"dslim/bert-base-NER"``.
        """
        self.use_ner_model = use_ner_model
        self.ner_pipeline = None
        # vault_id -> {placeholder token: entity dict}; held in memory only.
        self._vault_store: Dict[str, Dict[str, Any]] = {}
        if use_ner_model:
            try:
                from transformers import pipeline
                model_name = ner_model_name or "dslim/bert-base-NER"
                self.ner_pipeline = pipeline("ner", model=model_name,
                                             aggregation_strategy="simple")
            except ImportError as err:
                # Degrading a privacy filter should never be silent.
                import warnings
                warnings.warn(
                    f"NER model unavailable ({err}); falling back to "
                    "regex-only entity detection.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                self.use_ner_model = False

    @staticmethod
    def _merge_entities(text: str, entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Sort entities by start and fold overlapping spans together.

        A span overlapping the previously kept one is absorbed into it (the
        earlier, or on a tie longer, span keeps its label). The kept span is
        widened when the absorbed one reaches further, so every detected
        character stays covered. Input dicts are copied, not mutated.
        """
        ordered = sorted(entities, key=lambda e: (e["start"], e["start"] - e["end"]))
        merged: List[Dict[str, Any]] = []
        for ent in ordered:
            if merged and ent["start"] < merged[-1]["end"]:
                kept = merged[-1]
                if ent["end"] > kept["end"]:
                    kept["end"] = ent["end"]
                    kept["word"] = text[kept["start"]:kept["end"]]
                continue
            merged.append(dict(ent))
        return merged

    def _detect_entities(self, text: str) -> List[Dict[str, Any]]:
        """Return non-overlapping sensitive spans in *text*, sorted by start.

        Regex detection always runs, because a name-oriented NER model does not
        tag emails, phone numbers or SSNs. NOTE(review): the default
        dslim/bert-base-NER uses the CoNLL-2003 PER/ORG/LOC/MISC labels.
        Without the regex pass, such identifiers would leak whenever the model
        is enabled.
        """
        entities = list(_regex_based_ner(text))
        if self.use_ner_model:
            entities.extend(
                {"entity_group": e["entity_group"], "word": e["word"],
                 "start": e["start"], "end": e["end"]}
                for e in self.ner_pipeline(text)
            )
        return self._merge_entities(text, entities)

    def abstract(self, text: str) -> "AbstractDocument":
        """Abstract *text* and store its entity vault.

        Returns:
            An AbstractDocument whose ``abstract_text`` has every detected span
            replaced by a ``<TYPE_n>`` placeholder, whose ``amr_graph`` is
            built from that text, and whose ``vault_id`` retrieves the
            placeholder -> entity mapping via :meth:`retrieve_vault`.
        """
        entities = self._detect_entities(text)
        abstract_text = _replace_spans(text, entities)
        # Number placeholders per type in start order, the same order
        # _replace_spans uses, so each vault key names the token that actually
        # appears in abstract_text.
        type_counter = defaultdict(int)
        vault = {}
        for ent in entities:
            ent_type = ent["entity_group"]
            type_counter[ent_type] += 1
            vault[f"<{ent_type}_{type_counter[ent_type]}>"] = ent
        vault_id = _hash_vault(vault)
        self._store_vault(vault_id, vault)
        return AbstractDocument(
            abstract_text=abstract_text,
            amr_graph=_parse_amr_fallback(abstract_text),
            vault_id=vault_id,
            schema_version="1.0",
        )

    def _store_vault(self, vault_id: str, vault: Dict[str, Any]) -> None:
        """Keep *vault* in memory under *vault_id*, replacing any prior entry."""
        if not hasattr(self, "_vault_store"):
            self._vault_store = {}
        self._vault_store[vault_id] = vault

    def retrieve_vault(self, vault_id: str) -> Dict[str, Any]:
        """Return the stored vault for *vault_id*, or ``{}`` if unknown."""
        return getattr(self, "_vault_store", {}).get(vault_id, {})

    def is_secure(self, text: str) -> bool:
        """Return True when no sensitive entity is detected in *text*.

        This uses the same detectors as :meth:`abstract`. Text is reported as
        secure exactly when ``abstract`` would insert no placeholders.
        """
        return not self._detect_entities(text)