# StruCTA / structa/privacy.py
# Uploaded by YOUSSEF88 -- "Upload structa/privacy.py" (commit 2b8a1c6, verified)
"""
PrivacyVerificationModule: Run-time privacy guarantees during training & inference.
Three-level defense: structural invariants, forbidden token detection, entropy bounds.
"""
import torch
import torch.nn as nn
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from collections import defaultdict
@dataclass
class PrivacyReport:
    """Result record produced by one privacy verification pass."""

    # True when the structural check found no violations (or was not run).
    struct_ok: bool
    # True when no forbidden token was found (or the check was not run).
    tokens_ok: bool
    # Entropy-based score reported by the estimator.
    privacy_score: float
    # Privacy budget still available when the report was built.
    budget_remaining: float
    # Per-check numeric diagnostics keyed by name.
    details: Dict[str, float]
class StructuralInvariantChecker:
    """Check basic well-formedness of a graph supplied as a plain dict.

    The graph is read as ``{"nodes": [...], "edges": [...]}``. Each edge is a
    dict with ``"source"``, ``"target"`` and ``"relation"`` entries, and edge
    endpoints are compared against node positions ``0 .. len(nodes) - 1``.
    """

    def __init__(self, schema: Dict[str, List[str]]):
        # Kept on the instance; the checks below do not consult it.
        self.schema = schema

    def check(self, graph) -> Tuple[bool, List[str]]:
        """Validate *graph* and return ``(ok, violations)``.

        Three kinds of violation are reported:

        * a relation label that is not a string starting with ``":"``;
        * an edge without a ``"source"`` or ``"target"`` entry;
        * node positions that no edge touches.

        Args:
            graph: Mapping with optional ``"nodes"`` and ``"edges"`` entries.
                A missing or ``None`` entry is treated as empty.

        Returns:
            A tuple of a boolean (``True`` when nothing was violated) and the
            list of human-readable violation messages.
        """
        violations: List[str] = []
        edges = graph.get("edges") or []
        node_positions = set(range(len(graph.get("nodes") or [])))
        connected = set()

        for position, edge in enumerate(edges):
            rel = edge.get("relation", "")
            # A leading colon is the only requirement on a relation label;
            # the former extra comparison with ":next" could never matter.
            if not isinstance(rel, str) or not rel.startswith(":"):
                violations.append(f"Invalid relation: {rel}")

            # Report malformed edges instead of failing with KeyError, in
            # line with the tolerant lookup used for the relation label.
            missing = [key for key in ("source", "target") if key not in edge]
            if missing:
                violations.append(
                    f"Edge {position} is missing: {', '.join(missing)}"
                )
            connected.update(
                edge[key] for key in ("source", "target") if key in edge
            )

        isolated = node_positions - connected
        if isolated:
            violations.append(f"Isolated nodes: {isolated}")

        return not violations, violations
class ForbiddenTokenDetector:
    """Flag generated tokens whose text contains a forbidden substring.

    Matching is case-insensitive and done per token. A pattern that spans
    several tokens (for example one containing a space) only matches when a
    single vocabulary entry contains the whole pattern.
    """

    def __init__(self, forbidden_patterns: List[str]):
        # Tokens are lower-cased before matching, so the patterns must be
        # too; otherwise a pattern with any uppercase letter never matches.
        # Empty patterns are dropped because "" is a substring of every token.
        self.forbidden_patterns = {
            pattern.lower() for pattern in forbidden_patterns if pattern
        }

    def check(self, generated_ids, vocab):
        """Scan *generated_ids* and return ``(ok, violations)``.

        Args:
            generated_ids: Tensor (or nested sequence) of token ids; it is
                flattened, so any shape is accepted.
            vocab: Mapping from token id to token string. Unknown ids are
                treated as ``"<unk>"``.

        Returns:
            A tuple of a boolean (``True`` when nothing matched) and a list of
            ``{"token_id": ..., "token": ...}`` dicts, one per offending
            occurrence, in input order.
        """
        # One bulk conversion instead of an ``.item()`` call per element.
        token_ids = torch.as_tensor(generated_ids).flatten().tolist()

        violations = []
        for token_id in token_ids:
            token_str = vocab.get(token_id, "<unk>")
            # Tokens written as "<...>" are treated as special and skipped.
            if token_str.startswith("<"):
                continue
            if self._matches_forbidden(token_str):
                violations.append({"token_id": token_id, "token": token_str})
        return not violations, violations

    def _matches_forbidden(self, token):
        """Return True when any forbidden pattern occurs inside *token*."""
        token_lower = token.lower()
        return any(pattern in token_lower for pattern in self.forbidden_patterns)
class PrivacyEntropyEstimator:
    """Score how evenly hidden-state magnitude is spread over positions."""

    def estimate(self, hidden_states, target_token=0):
        """Return a normalised entropy score in ``[0, 1]``.

        The L2 norm is taken over the last dimension of *hidden_states*, the
        norms are turned into a distribution with a softmax over the
        resulting last dimension, and the Shannon entropy of that
        distribution is divided by its maximum ``log(n)``. The result is
        averaged over any leading dimensions.

        Args:
            hidden_states: Tensor with at least two dimensions
                ``(..., positions, features)``.
            target_token: Accepted for interface compatibility; not used.

        Returns:
            The mean normalised entropy as a Python float, capped at 1.0.
            Returns 0.0 when there are fewer than two positions or the score
            is not a number, so the caller never sees NaN.
        """
        import math

        # Work in float32: in half precision the entropy terms lose accuracy.
        norms = torch.norm(hidden_states.detach(), dim=-1).float()

        # With fewer than two positions the maximum entropy log(n) is zero
        # (or undefined), so normalising would divide by zero.
        if norms.dim() == 0 or norms.shape[-1] < 2:
            return 0.0

        # log_softmax avoids log(0) without an additive epsilon.
        log_probs = torch.log_softmax(norms, dim=-1)
        entropy = -(log_probs.exp() * log_probs).sum(dim=-1)

        max_entropy = math.log(norms.shape[-1])
        privacy_score = (entropy / max_entropy).mean().item()

        # Non-finite inputs yield NaN; report the lowest score instead.
        if math.isnan(privacy_score):
            return 0.0
        return min(privacy_score, 1.0)
class PrivacyBudgetAccountant:
    """Track cumulative privacy spend against a fixed epsilon budget."""

    def __init__(self, epsilon=3.0, delta=1e-5, accountant_type="rdp"):
        self.epsilon = epsilon
        self.delta = delta
        # Recorded for callers; only the single calculation below exists.
        self.accountant_type = accountant_type
        self.spent = 0.0
        self.query_count = 0

    def add_query(self, noise_multiplier, sampling_rate):
        """Charge one noisy query to the budget.

        The cost is the Renyi-DP of the Gaussian mechanism at order
        ``alpha = 2``, i.e. ``alpha / (2 * noise_multiplier ** 2)``, converted
        to an epsilon with ``rdp + log(1 / delta) / (alpha - 1)`` and scaled
        by *sampling_rate*.

        NOTE(review): converting per query and scaling linearly by the
        sampling rate is a simple approximation, not a tight subsampled-RDP
        bound -- confirm it is adequate for the intended guarantee.

        Args:
            noise_multiplier: Noise standard deviation relative to the query
                sensitivity. Zero means no noise and exhausts the budget.
            sampling_rate: Fraction of the data touched, in ``[0, 1]``.

        Raises:
            ValueError: If *noise_multiplier* is negative, *sampling_rate* is
                outside ``[0, 1]``, or ``delta`` is outside ``(0, 1)``.
        """
        import math

        if noise_multiplier < 0:
            raise ValueError("noise_multiplier must be non-negative")
        if not 0.0 <= sampling_rate <= 1.0:
            raise ValueError("sampling_rate must lie in [0, 1]")
        if not 0.0 < self.delta < 1.0:
            raise ValueError("delta must lie in (0, 1)")

        self.query_count += 1
        if sampling_rate == 0.0:
            # Nothing was sampled, so nothing is charged.
            return
        if noise_multiplier == 0.0:
            # A noiseless query has unbounded privacy loss.
            self.spent = math.inf
            return

        alpha = 2.0
        # More noise must cost less: the variance belongs in the denominator.
        rdp = alpha / (2.0 * noise_multiplier ** 2)
        converted_eps = rdp + math.log(1.0 / self.delta) / (alpha - 1.0)
        self.spent += converted_eps * sampling_rate

    def remaining(self):
        """Return the unspent budget, never below zero."""
        return max(0.0, self.epsilon - self.spent)

    def is_exhausted(self):
        """Return True once the spend has reached the budget."""
        return self.spent >= self.epsilon
class PrivacyVerificationModule(nn.Module):
    """Combine the structural, token, entropy and budget checks.

    ``verify`` runs the checks and builds a ``PrivacyReport``; ``enforce``
    turns a report into a block / allow decision.
    """

    def __init__(self, config):
        """Build the checkers from *config*.

        Args:
            config: Object exposing ``dp_epsilon``, ``dp_delta`` and
                ``privacy_threshold`` attributes.
        """
        super().__init__()
        self.config = config
        self.struct_checker = StructuralInvariantChecker(schema={})
        self.token_detector = ForbiddenTokenDetector([
            "ssn", "social", "password", "passport", "bank account",
            "credit card", "cvv", "routing",
        ])
        self.entropy_estimator = PrivacyEntropyEstimator()
        self.budget_accountant = PrivacyBudgetAccountant(
            epsilon=config.dp_epsilon, delta=config.dp_delta
        )
        self.threshold = config.privacy_threshold

    def verify(self, hidden_states, generated_ids=None, graph=None, vocab=None):
        """Run every applicable check and return a ``PrivacyReport``.

        Args:
            hidden_states: Tensor passed to the entropy estimator.
            generated_ids: Optional token ids; checked only together with
                *vocab*.
            graph: Optional graph dict for the structural check.
            vocab: Optional id-to-token mapping for the token check.

        Returns:
            A ``PrivacyReport``. Its ``details`` carry the violation counts
            of the checks that ran, the privacy score, and ``"is_secure"``
            (1.0 when every check passed and budget remains, else 0.0).
        """
        details = {}

        struct_ok = True
        if graph is not None:
            struct_ok, struct_violations = self.struct_checker.check(graph)
            details["struct_violations"] = len(struct_violations)

        tokens_ok = True
        if generated_ids is not None and vocab is not None:
            tokens_ok, token_violations = self.token_detector.check(
                generated_ids, vocab
            )
            details["forbidden_tokens"] = len(token_violations)

        privacy_score = self.entropy_estimator.estimate(
            hidden_states, target_token=0
        )
        details["privacy_score"] = privacy_score

        is_secure = (
            struct_ok and tokens_ok and
            privacy_score >= self.threshold and
            not self.budget_accountant.is_exhausted()
        )
        # The overall verdict used to be computed and thrown away; expose it.
        details["is_secure"] = float(is_secure)

        return PrivacyReport(
            struct_ok=struct_ok, tokens_ok=tokens_ok,
            privacy_score=privacy_score,
            budget_remaining=self.budget_accountant.remaining(),
            details=details
        )

    def enforce(self, report):
        """Decide whether the output described by *report* must be blocked.

        Args:
            report: A ``PrivacyReport`` as returned by ``verify``.

        Returns:
            ``{"block": True, "reason": <str>, "report": report}`` for the
            first failing check, otherwise ``{"block": False, "report":
            report}``.
        """
        # Written as "not >=" so a NaN score blocks, matching ``verify``.
        if not report.privacy_score >= self.threshold:
            return {"block": True, "reason": "low_privacy_score", "report": report}
        if not report.tokens_ok:
            return {"block": True, "reason": "forbidden_tokens", "report": report}
        if self.budget_accountant.is_exhausted():
            return {"block": True, "reason": "privacy_budget_exhausted", "report": report}
        # Structural violations count against security in ``verify`` and so
        # must block too; checked last so earlier reasons stay unchanged.
        if not report.struct_ok:
            return {"block": True, "reason": "structural_violation", "report": report}
        return {"block": False, "report": report}