| """ |
| PrivacyVerificationModule: Run-time privacy guarantees during training & inference. |
| Three-level defense: structural invariants, forbidden token detection, entropy bounds. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from typing import Dict, List, Optional, Tuple |
| from dataclasses import dataclass |
| from collections import defaultdict |
|
|
|
|
@dataclass
class PrivacyReport:
    """Snapshot of the privacy checks performed by a single verification pass.

    Attributes are filled in by the verifier; consumers decide whether to
    block based on them.
    """

    # True when no structural-invariant violations were found (or none checked).
    struct_ok: bool
    # True when no forbidden tokens were found (or none checked).
    tokens_ok: bool
    # Heuristic privacy score produced by the entropy estimator.
    privacy_score: float
    # Epsilon budget left at the time the report was built.
    budget_remaining: float
    # Per-check diagnostics (violation counts, score, ...).
    details: Dict[str, float]
|
|
|
|
class StructuralInvariantChecker:
    """Validate basic structural invariants of a graph dictionary.

    The graph is read as ``{"nodes": [...], "edges": [...]}`` (both keys
    optional), where each edge is a dict carrying ``"source"`` and
    ``"target"`` node indices and an optional ``"relation"`` label.
    """

    def __init__(self, schema: Dict[str, List[str]]):
        # Kept on the instance; ``check`` does not currently consult it.
        self.schema = schema

    def check(self, graph):
        """Return ``(ok, problems)`` describing invariant violations in ``graph``.

        Invariants enforced:

        * every edge relation starts with ``":"`` (a missing relation is read
          as the empty string and is therefore reported as invalid);
        * every node index ``0 .. len(nodes) - 1`` appears as the source or
          target of at least one edge.

        Raises ``KeyError`` if an edge lacks ``"source"`` or ``"target"``.
        """
        edge_list = graph.get("edges", [])

        # Relation labels must be ":"-prefixed (":next" included).
        problems = [
            f"Invalid relation: {label}"
            for label in (edge.get("relation", "") for edge in edge_list)
            if not label.startswith(":")
        ]

        node_ids = set(range(len(graph.get("nodes", []))))
        # Endpoints are visited source-then-target per edge, as before.
        touched = {edge[end] for edge in edge_list for end in ("source", "target")}
        isolated = node_ids - touched
        if isolated:
            problems.append(f"Isolated nodes: {isolated}")

        return not problems, problems
|
|
|
|
class ForbiddenTokenDetector:
    """Flag generated tokens whose text contains a forbidden substring.

    Matching is case-insensitive and done per token: each id is looked up in
    ``vocab`` on its own. A multi-word pattern such as ``"bank account"``
    therefore only matches if a single vocabulary entry contains that exact
    text.
    NOTE(review): sequence-level detection would need the decoded text.
    """

    def __init__(self, forbidden_patterns: List[str]):
        # Lower-case once so patterns compare against lower-cased token text.
        # Previously a mixed-case pattern (e.g. "SSN") could never match.
        # Empty strings are dropped because ``"" in s`` is true for every token.
        self.forbidden_patterns = {p.lower() for p in forbidden_patterns if p}

    def check(self, generated_ids, vocab):
        """Scan ``generated_ids`` for forbidden tokens.

        Args:
            generated_ids: tensor (or nested sequence) of token ids of any
                shape; it is flattened before scanning.
            vocab: mapping from token id to token string; ids missing from it
                are read as ``"<unk>"``.

        Returns:
            ``(ok, violations)``: ``ok`` is True when nothing matched, and
            ``violations`` lists ``{"token_id": ..., "token": ...}`` dicts in
            scan order.
        """
        violations = []
        # One bulk ``tolist()`` instead of a per-element ``.item()`` call.
        for token_id in torch.as_tensor(generated_ids).flatten().tolist():
            token_str = vocab.get(token_id, "<unk>")
            # Strings starting with "<" are treated as special tokens and skipped.
            if not token_str.startswith("<") and self._matches_forbidden(token_str):
                violations.append({"token_id": token_id, "token": token_str})
        return not violations, violations

    def _matches_forbidden(self, token):
        """Return True if any forbidden pattern occurs in ``token`` (case-insensitive)."""
        token_lower = token.lower()
        return any(pattern in token_lower for pattern in self.forbidden_patterns)
|
|
|
|
class PrivacyEntropyEstimator:
    """Heuristic privacy score from how evenly hidden-state norms are spread.

    The L2 norm is taken over the last axis of ``hidden_states``. A softmax
    over the next axis (presumably sequence positions for a
    ``(batch, seq, hidden)`` input — confirm) turns the norms into
    weights. The score is the entropy of those weights divided by its maximum
    ``log(n)``, averaged over the remaining axes, and clamped to ``[0, 1]``.
    """

    def __init__(self, degenerate_score: float = 0.0):
        # Returned when the normalised entropy is undefined: fewer than two
        # vectors to compare, an empty input, or a NaN result. 0.0 fails closed.
        # NOTE(review): single-position inputs (e.g. incremental decoding)
        # always hit this path; choose the value deliberately.
        self.degenerate_score = degenerate_score

    def estimate(self, hidden_states, target_token=0):
        """Return the normalised-entropy privacy score as a Python float.

        ``target_token`` is accepted for interface compatibility but unused.
        """
        import math

        norms = torch.norm(hidden_states, dim=-1)
        # log(1) == 0 would make the normalisation 0/0 = NaN, and a NaN score
        # slips past ``score < threshold`` comparisons downstream.
        if norms.dim() == 0 or norms.numel() == 0 or norms.shape[-1] < 2:
            return self.degenerate_score
        probs = torch.softmax(norms, dim=-1)
        entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
        score = (entropy / math.log(norms.shape[-1])).mean().item()
        if math.isnan(score):  # e.g. inf/NaN in hidden_states
            return self.degenerate_score
        # The 1e-10 epsilon can push entropy marginally outside [0, log n].
        return min(max(score, 0.0), 1.0)
|
|
|
|
class PrivacyBudgetAccountant:
    """Track cumulative (epsilon, delta) privacy loss of noisy queries via RDP.

    Each query is modelled as a Gaussian mechanism with L2 sensitivity 1 and
    noise multiplier ``sigma``. Its Renyi-DP cost at order ``alpha`` is
    ``alpha / (2 * sigma**2)``. RDP costs add up across queries, and the
    running total is converted to (epsilon, delta)-DP once, using
    ``eps = rdp + log(1/delta) / (alpha - 1)``. The conversion is minimised
    over a fixed grid of orders.

    NOTE(review): subsampling is modelled by scaling each query's RDP by
    ``sampling_rate``, the heuristic the earlier code used. This is not a
    rigorous subsampled-Gaussian bound; for production DP-SGD use a dedicated
    accountant — TODO confirm.
    """

    # Renyi orders over which the epsilon conversion is minimised.
    _RDP_ORDERS = (
        1.25, 1.5, 1.75, 2.0, 2.5, 3.0, 4.0, 5.0, 6.0,
        8.0, 10.0, 12.0, 16.0, 20.0, 24.0, 32.0, 48.0, 64.0,
    )

    def __init__(self, epsilon=3.0, delta=1e-5, accountant_type="rdp"):
        self.epsilon = epsilon
        self.delta = delta
        # Stored for introspection; only RDP-style accounting is implemented.
        self.accountant_type = accountant_type
        # Composed RDP cost at each order in ``_RDP_ORDERS``.
        self._rdp = [0.0] * len(self._RDP_ORDERS)
        # Epsilon spent so far, converted from the composed RDP; 0 before any query.
        self.spent = 0.0
        self.query_count = 0

    def add_query(self, noise_multiplier, sampling_rate):
        """Account for one noisy query and update ``spent``.

        Args:
            noise_multiplier: Gaussian noise std relative to sensitivity; must be > 0.
            sampling_rate: fraction of data touched by the query, in ``[0, 1]``.

        Raises:
            ValueError: on a non-positive noise multiplier or an out-of-range
                sampling rate.
        """
        from math import log

        if not noise_multiplier > 0:
            raise ValueError("noise_multiplier must be positive")
        if not 0.0 <= sampling_rate <= 1.0:
            raise ValueError("sampling_rate must be in [0, 1]")

        # More noise means less privacy loss: RDP scales as 1 / sigma**2.
        per_order = sampling_rate / (2.0 * noise_multiplier ** 2)
        self._rdp = [
            total + alpha * per_order
            for total, alpha in zip(self._rdp, self._RDP_ORDERS)
        ]
        # Convert the *composed* RDP once; the log(1/delta) term is not per-query.
        if any(self._rdp):
            log_inv_delta = log(1.0 / self.delta)
            self.spent = min(
                total + log_inv_delta / (alpha - 1.0)
                for total, alpha in zip(self._rdp, self._RDP_ORDERS)
            )
        self.query_count += 1

    def remaining(self):
        """Return the unspent epsilon budget, never negative."""
        return max(0.0, self.epsilon - self.spent)

    def is_exhausted(self):
        """Return True once spent epsilon reaches the configured budget."""
        return self.spent >= self.epsilon
|
|
|
|
class PrivacyVerificationModule(nn.Module):
    """Combine structural, token, entropy and budget checks into one gate.

    ``verify`` collects the individual signals into a ``PrivacyReport``;
    ``enforce`` turns a report into a block/allow decision.
    """

    def __init__(self, config):
        super().__init__()
        # ``config`` is read for ``dp_epsilon``, ``dp_delta`` and ``privacy_threshold``.
        self.config = config
        self.struct_checker = StructuralInvariantChecker(schema={})
        self.token_detector = ForbiddenTokenDetector([
            "ssn", "social", "password", "passport", "bank account",
            "credit card", "cvv", "routing",
        ])
        self.entropy_estimator = PrivacyEntropyEstimator()
        self.budget_accountant = PrivacyBudgetAccountant(
            epsilon=config.dp_epsilon, delta=config.dp_delta
        )
        self.threshold = config.privacy_threshold

    def verify(self, hidden_states, generated_ids=None, graph=None, vocab=None):
        """Run the available checks and return a ``PrivacyReport``.

        The structural check runs only when ``graph`` is given. The token check
        runs only when both ``generated_ids`` and ``vocab`` are given. The
        entropy score is always computed from ``hidden_states``. ``details``
        records violation counts and the score. The block/allow decision is
        left to ``enforce``.
        """
        details = {}

        struct_ok = True
        if graph is not None:
            struct_ok, struct_violations = self.struct_checker.check(graph)
            details["struct_violations"] = len(struct_violations)

        tokens_ok = True
        if generated_ids is not None and vocab is not None:
            tokens_ok, token_violations = self.token_detector.check(generated_ids, vocab)
            details["forbidden_tokens"] = len(token_violations)

        # The score is reduced to a Python float; no need to record autograd history.
        with torch.no_grad():
            privacy_score = self.entropy_estimator.estimate(hidden_states, target_token=0)
        details["privacy_score"] = privacy_score

        return PrivacyReport(
            struct_ok=struct_ok, tokens_ok=tokens_ok,
            privacy_score=privacy_score,
            budget_remaining=self.budget_accountant.remaining(),
            details=details
        )

    def enforce(self, report):
        """Decide whether to block, given a report from ``verify``.

        Returns ``{"block": True, "reason": ..., "report": report}`` for the
        first failing check, else ``{"block": False, "report": report}``. Budget
        exhaustion is read from the live accountant, not the report snapshot.
        """
        # ``not (score >= threshold)`` also blocks a NaN score, which a plain
        # ``score < threshold`` comparison would let through.
        if not report.privacy_score >= self.threshold:
            return {"block": True, "reason": "low_privacy_score", "report": report}
        if not report.tokens_ok:
            return {"block": True, "reason": "forbidden_tokens", "report": report}
        # Structural violations were previously computed but never enforced.
        if not report.struct_ok:
            return {"block": True, "reason": "structural_violation", "report": report}
        if self.budget_accountant.is_exhausted():
            return {"block": True, "reason": "privacy_budget_exhausted", "report": report}
        return {"block": False, "report": report}
|
|