"""
PrivacyVerificationModule: run-time privacy checks during training & inference.

The module combines four checks:
  * structural invariants on an optional graph,
  * forbidden-token detection on optional generated token ids,
  * a normalized-entropy "privacy score" over hidden-state norms,
  * a differential-privacy budget tracked by ``PrivacyBudgetAccountant``.
"""

import math
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn


@dataclass
class PrivacyReport:
    """Result of ``PrivacyVerificationModule.verify``."""

    struct_ok: bool  # graph passed the structural checks (True when no graph was given)
    tokens_ok: bool  # no forbidden tokens found (True when ids/vocab were not given)
    privacy_score: float  # normalized entropy score, clamped to [0, 1]
    budget_remaining: float  # epsilon left in the DP budget when the report was built
    details: Dict[str, float]  # per-check counters plus the raw privacy score
    # Aggregate verdict computed by verify(); None when the report was built elsewhere.
    is_secure: Optional[bool] = None


class StructuralInvariantChecker:
    """Validate a graph dict with optional ``"nodes"`` and ``"edges"`` lists.

    Each edge is read as a dict with a ``"relation"`` key and integer
    ``"source"`` / ``"target"`` node indices; an edge lacking ``"source"`` or
    ``"target"`` raises ``KeyError``.
    """

    def __init__(self, schema: Dict[str, List[str]]):
        # NOTE(review): ``schema`` is stored but not consulted by ``check``.
        self.schema = schema

    def check(self, graph) -> Tuple[bool, List[str]]:
        """Return ``(ok, violations)`` for *graph*.

        Violations are reported for every relation that does not start with
        ``":"`` and once for the set of node indices touched by no edge.
        """
        violations = []
        edges = graph.get("edges", [])
        for edge in edges:
            rel = edge.get("relation", "")
            # ":next" also starts with ":", so a prefix test covers it.
            if not rel.startswith(":"):
                violations.append(f"Invalid relation: {rel}")

        nodes = set(range(len(graph.get("nodes", []))))
        connected = set()
        for edge in edges:
            connected.add(edge["source"])
            connected.add(edge["target"])
        isolated = nodes - connected
        if isolated:
            violations.append(f"Isolated nodes: {isolated}")
        return len(violations) == 0, violations


class ForbiddenTokenDetector:
    """Flag generated tokens whose text contains a forbidden substring.

    Matching is case-insensitive and done per vocabulary entry.
    NOTE(review): multi-word patterns (e.g. "credit card") only match when a
    single vocab entry contains the whole phrase; phrases split across
    several tokens are not detected.
    """

    def __init__(self, forbidden_patterns: List[str]):
        # Tokens are lower-cased before matching, so the patterns must be too;
        # otherwise an upper-case pattern such as "SSN" could never match.
        self.forbidden_patterns = {pattern.lower() for pattern in forbidden_patterns}

    def check(self, generated_ids, vocab):
        """Return ``(ok, violations)`` for the ids in *generated_ids*.

        Args:
            generated_ids: tensor of token ids (any shape; it is flattened).
            vocab: mapping from token id to token string; unknown ids map to "".

        Returns:
            ``(True, [])`` when nothing matched, else ``(False, violations)``
            with one ``{"token_id", "token"}`` dict per offending token.
        """
        violations = []
        # One bulk conversion instead of a per-element .item() call.
        for token_id in generated_ids.flatten().tolist():
            token_str = vocab.get(token_id, "")
            # Entries starting with "<" are skipped (presumably special
            # tokens such as "<pad>"; confirm against the tokenizer).
            if not token_str.startswith("<") and self._matches_forbidden(token_str):
                violations.append({"token_id": token_id, "token": token_str})
        return len(violations) == 0, violations

    def _matches_forbidden(self, token):
        """Return True if any forbidden pattern is a substring of *token* (case-insensitive)."""
        token_lower = token.lower()
        return any(pattern in token_lower for pattern in self.forbidden_patterns)


class PrivacyEntropyEstimator:
    """Normalized entropy of a softmax taken over per-position hidden-state norms."""

    def estimate(self, hidden_states, target_token=0):
        """Return a privacy score in ``[0, 1]``.

        The L2 norm over the last dimension yields one value per position.
        A softmax over the second-to-last axis turns those values into a
        distribution. Its entropy is divided by ``log(num_positions)`` and
        averaged over any leading dimensions. Values near 1 mean the norms
        are nearly uniform.

        Args:
            hidden_states: tensor with features in the last dimension and
                positions in the second-to-last (assumed ``(..., seq, dim)``
                — TODO confirm with callers).
            target_token: unused; kept for interface compatibility.

        Returns:
            The score as a Python float. Returns ``0.0`` (fail closed) when
            there are fewer than two positions, or when the score is not
            finite (e.g. NaN/inf in *hidden_states* or an empty batch).
        """
        # Work in float32: in fp16 the 1e-10 epsilon below rounds to 0, so an
        # underflowed probability would produce 0 * log(0) = NaN.
        norms = torch.norm(hidden_states.float(), dim=-1)
        num_positions = norms.shape[-1] if norms.dim() > 0 else 1
        if num_positions < 2:
            # log(1) == 0 would make the normalization 0/0 (NaN), and NaN
            # silently passes "score < threshold" style checks.
            return 0.0
        probs = torch.softmax(norms, dim=-1)
        entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
        privacy_score = (entropy / math.log(num_positions)).mean().item()
        if not math.isfinite(privacy_score):
            return 0.0
        # The +1e-10 inside the log can push entropy marginally below 0.
        return min(max(privacy_score, 0.0), 1.0)


class PrivacyBudgetAccountant:
    """Track epsilon spent by Gaussian-noise queries using a single RDP order.

    NOTE(review): this is a heuristic accountant.
      * Only the order alpha = 2 is used.
      * The RDP-to-(epsilon, delta) conversion term is added per query
        rather than once after composition, which over-counts.
      * Subsampling is modelled by scaling linearly with ``sampling_rate``.
    ``accountant_type`` is stored but does not change the computation.
    """

    _ALPHA = 2.0  # RDP order used for every query

    def __init__(self, epsilon=3.0, delta=1e-5, accountant_type="rdp"):
        self.epsilon = epsilon
        self.delta = delta
        self.accountant_type = accountant_type
        self.spent = 0.0
        self.query_count = 0

    def add_query(self, noise_multiplier, sampling_rate):
        """Charge one query with Gaussian noise ``noise_multiplier`` (sigma / sensitivity).

        Raises:
            ValueError: if ``noise_multiplier`` is negative or ``sampling_rate``
                lies outside ``[0, 1]``.
        """
        if noise_multiplier < 0:
            raise ValueError(f"noise_multiplier must be >= 0, got {noise_multiplier}")
        if not 0.0 <= sampling_rate <= 1.0:
            raise ValueError(f"sampling_rate must be in [0, 1], got {sampling_rate}")
        alpha = self._ALPHA
        if noise_multiplier == 0:
            # A noiseless release gives no DP guarantee; exhaust the budget
            # rather than charging nothing.
            self.spent = math.inf
        else:
            # RDP of the Gaussian mechanism at order alpha: alpha / (2 sigma^2).
            # More noise must cost less.
            rdp = alpha / (2.0 * noise_multiplier ** 2)
            converted_eps = rdp + math.log(1.0 / self.delta) / (alpha - 1.0)
            self.spent += converted_eps * sampling_rate
        self.query_count += 1

    def remaining(self):
        """Return the unspent epsilon, never below 0."""
        return max(0.0, self.epsilon - self.spent)

    def is_exhausted(self):
        """Return True once the spent epsilon reaches the configured budget."""
        return self.spent >= self.epsilon


class PrivacyVerificationModule(nn.Module):
    """Run all privacy checks and decide whether an output must be blocked.

    ``config`` must provide ``dp_epsilon``, ``dp_delta`` and
    ``privacy_threshold``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.struct_checker = StructuralInvariantChecker(schema={})
        self.token_detector = ForbiddenTokenDetector([
            "ssn", "social", "password", "passport", "bank account",
            "credit card", "cvv", "routing",
        ])
        self.entropy_estimator = PrivacyEntropyEstimator()
        self.budget_accountant = PrivacyBudgetAccountant(
            epsilon=config.dp_epsilon, delta=config.dp_delta
        )
        self.threshold = config.privacy_threshold

    def verify(self, hidden_states, generated_ids=None, graph=None, vocab=None):
        """Run every applicable check and return a ``PrivacyReport``.

        The graph check runs only when *graph* is given. The token check runs
        only when both *generated_ids* and *vocab* are given. The entropy
        score and the budget are always evaluated.
        """
        details = {}

        struct_ok = True
        if graph is not None:
            struct_ok, struct_violations = self.struct_checker.check(graph)
            details["struct_violations"] = len(struct_violations)

        tokens_ok = True
        if generated_ids is not None and vocab is not None:
            tokens_ok, token_violations = self.token_detector.check(generated_ids, vocab)
            details["forbidden_tokens"] = len(token_violations)

        privacy_score = self.entropy_estimator.estimate(hidden_states, target_token=0)
        details["privacy_score"] = privacy_score

        is_secure = (
            struct_ok
            and tokens_ok
            and privacy_score >= self.threshold
            and not self.budget_accountant.is_exhausted()
        )
        return PrivacyReport(
            struct_ok=struct_ok,
            tokens_ok=tokens_ok,
            privacy_score=privacy_score,
            budget_remaining=self.budget_accountant.remaining(),
            details=details,
            is_secure=is_secure,
        )

    def enforce(self, report):
        """Return ``{"block": bool, ["reason": str,] "report": report}`` for *report*.

        Blocking reasons are checked in this order: low privacy score,
        forbidden tokens, exhausted budget, structural violation.
        """
        # "not (score >= threshold)" also blocks NaN, which "score < threshold" lets through.
        if not (report.privacy_score >= self.threshold):
            return {"block": True, "reason": "low_privacy_score", "report": report}
        if not report.tokens_ok:
            return {"block": True, "reason": "forbidden_tokens", "report": report}
        if self.budget_accountant.is_exhausted():
            return {"block": True, "reason": "privacy_budget_exhausted", "report": report}
        # verify() treats structural violations as insecure, so enforce them too.
        if not report.struct_ok:
            return {"block": True, "reason": "structural_violation", "report": report}
        return {"block": False, "report": report}