"""
PrivacyVerificationModule: Run-time privacy guarantees during training & inference.

Three-level defense: structural invariants, forbidden token detection, entropy bounds.
"""
import torch
import torch.nn as nn
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from collections import defaultdict
@dataclass
class PrivacyReport:
    """Result record produced by one privacy verification pass.

    Groups the boolean verdicts of the individual checks with the numeric
    privacy score, the privacy budget still available, and a map of
    diagnostic values collected along the way.
    """

    # Verdict of the structural graph check.
    struct_ok: bool
    # Verdict of the forbidden-token scan.
    tokens_ok: bool
    # Entropy-based score computed from the hidden states.
    privacy_score: float
    # Privacy budget left at the time the report was built.
    budget_remaining: float
    # Named diagnostic values (e.g. violation counts).
    details: Dict[str, float]
class StructuralInvariantChecker:
    """Check structural invariants of a graph passed as a plain dict.

    The graph is read through its ``"nodes"`` and ``"edges"`` entries. Each
    edge is a dict with ``"relation"``, ``"source"`` and ``"target"`` keys.
    Endpoints are compared against node *positions* (``0 .. len(nodes) - 1``).
    """

    def __init__(self, schema: Dict[str, List[str]]):
        """Store ``schema``; ``check`` does not currently consult it."""
        self.schema = schema

    def check(self, graph) -> Tuple[bool, List[str]]:
        """Validate ``graph`` and describe every violation found.

        Two invariants are enforced:

        * every edge relation is a string starting with ``":"``;
        * every node position is referenced by at least one edge endpoint.

        Edges lacking a ``"source"`` or ``"target"`` key are reported as
        violations instead of raising ``KeyError``.

        Args:
            graph: Mapping with optional ``"nodes"`` and ``"edges"`` entries.

        Returns:
            ``(ok, violations)`` where ``ok`` is True iff ``violations`` is
            empty and ``violations`` is a list of human-readable messages.
        """
        violations: List[str] = []
        connected = set()

        for edge in graph.get("edges") or []:
            rel = edge.get("relation", "")
            # A relation is valid only if it carries the ":" prefix; a
            # non-string value (e.g. None) can never satisfy that.
            if not isinstance(rel, str) or not rel.startswith(":"):
                violations.append(f"Invalid relation: {rel}")

            for key in ("source", "target"):
                if key in edge:
                    connected.add(edge[key])
                else:
                    violations.append(f"Edge missing {key}: {edge}")

        node_count = len(graph.get("nodes") or [])
        isolated = set(range(node_count)) - connected
        if isolated:
            violations.append(f"Isolated nodes: {isolated}")

        return not violations, violations
class ForbiddenTokenDetector:
    """Scan generated token ids for tokens containing a forbidden substring.

    Matching is case-insensitive and done one token at a time.

    NOTE(review): because each token is inspected on its own, a pattern that
    spans several tokens (e.g. one containing a space) only matches if a
    single vocabulary entry contains the whole pattern — confirm whether
    phrase-level detection is required.
    """

    def __init__(self, forbidden_patterns: List[str]):
        """Normalise and store the forbidden substrings.

        Patterns are lower-cased so they compare correctly against the
        lower-cased token text. Empty patterns are dropped because the empty
        string is a substring of every token.
        """
        self.forbidden_patterns = {
            pattern.lower() for pattern in forbidden_patterns if pattern
        }

    def check(self, generated_ids, vocab):
        """Report every generated token that matches a forbidden pattern.

        Args:
            generated_ids: Tensor of token ids; it is flattened, so any
                shape is accepted.
            vocab: Mapping from token id to token string. Unknown ids are
                treated as ``"<unk>"``.

        Returns:
            ``(ok, violations)`` where ``ok`` is True iff nothing matched and
            ``violations`` is a list of ``{"token_id", "token"}`` dicts in
            flattened order.
        """
        violations = []
        # One bulk conversion instead of a per-element ``.item()`` call.
        for token_id in generated_ids.flatten().tolist():
            token_str = vocab.get(token_id, "<unk>")
            # Tokens starting with "<" (such as the "<unk>" fallback) are
            # never reported.
            if token_str.startswith("<"):
                continue
            if self._matches_forbidden(token_str):
                violations.append({"token_id": token_id, "token": token_str})
        return not violations, violations

    def _matches_forbidden(self, token):
        """Return True if any forbidden pattern occurs in ``token`` (case-insensitive)."""
        token_lower = token.lower()
        return any(pattern in token_lower for pattern in self.forbidden_patterns)
class PrivacyEntropyEstimator:
    """Turn hidden states into a normalised entropy score in ``[0, 1]``."""

    def estimate(self, hidden_states, target_token=0):
        """Compute the normalised entropy of the hidden-state norm distribution.

        The L2 norm is taken over the last axis of ``hidden_states``. The
        resulting norms are softmax-normalised over their own last axis, and
        the Shannon entropy of that distribution is divided by its maximum
        (``log`` of the axis length). The mean over all remaining axes is
        returned, capped at 1.0.

        Degenerate inputs return 0.0 so that callers comparing the score
        against a threshold fail closed instead of receiving NaN. This covers
        a softmax axis of length 0 or 1 (where the maximum entropy is zero
        and the ratio is undefined) and any computation that yields NaN.

        Args:
            hidden_states: Floating-point tensor; presumably
                ``(batch, seq, dim)`` — TODO confirm against callers.
            target_token: Accepted for interface compatibility; not used.

        Returns:
            A Python float, at most 1.0.
        """
        import math

        norms = torch.norm(hidden_states, dim=-1)
        # With fewer than two entries on the softmax axis, log(n) is 0 (or
        # undefined), so entropy / max_entropy would be 0/0.
        if norms.dim() == 0 or norms.shape[-1] <= 1:
            return 0.0

        probs = torch.softmax(norms, dim=-1)
        # The 1e-10 offset keeps log() finite when a probability underflows to 0.
        entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
        max_entropy = math.log(norms.shape[-1])
        privacy_score = (entropy / max_entropy).mean().item()

        # NaN compares False against everything, which would let it slip
        # through a ``score < threshold`` test; report the worst score instead.
        if math.isnan(privacy_score):
            return 0.0
        return min(privacy_score, 1.0)
class PrivacyBudgetAccountant:
    """Track how much of an ``epsilon`` privacy budget has been spent.

    Each query's cost is derived from a Rényi-DP bound at a fixed order
    ``alpha = 2`` and converted to an ``(epsilon, delta)`` cost.
    """

    def __init__(self, epsilon=3.0, delta=1e-5, accountant_type="rdp"):
        """Initialise an empty ledger.

        Args:
            epsilon: Total budget available.
            delta: Failure probability used in the RDP-to-epsilon conversion.
            accountant_type: Accepted for interface compatibility; this class
                only implements the RDP-style accounting below.
        """
        self.epsilon = epsilon
        self.delta = delta
        self.spent = 0.0
        self.query_count = 0

    def add_query(self, noise_multiplier, sampling_rate):
        """Charge one noisy query against the budget.

        The Gaussian mechanism with noise multiplier ``sigma`` satisfies
        ``(alpha, alpha / (2 * sigma**2))``-RDP, so the cost falls as noise
        grows. A non-positive ``noise_multiplier`` means no noise was added;
        it is charged as an unbounded cost, which exhausts the budget.

        Args:
            noise_multiplier: Ratio of noise standard deviation to sensitivity.
            sampling_rate: Fraction of the data used by the query; the cost
                is scaled linearly by it.
        """
        from math import log

        self.query_count += 1
        if noise_multiplier <= 0:
            # Assign directly: ``inf * 0`` would produce NaN for a zero rate.
            self.spent = float("inf")
            return

        alpha = 2.0
        rdp = alpha / (2.0 * noise_multiplier ** 2)
        converted_eps = rdp + log(1 / self.delta) / (alpha - 1)
        # NOTE(review): the delta conversion term is added on every query and
        # subsampling is modelled as a linear factor. Standard RDP accounting
        # sums RDP across queries and converts once, with a dedicated
        # subsampling bound — confirm whether this heuristic is intended.
        self.spent += converted_eps * sampling_rate

    def remaining(self):
        """Return the unspent budget, never below 0.0."""
        return max(0.0, self.epsilon - self.spent)

    def is_exhausted(self):
        """Return True once spending has reached the budget.

        Written as ``not (spent < epsilon)`` so a NaN ledger counts as
        exhausted rather than as available.
        """
        return not (self.spent < self.epsilon)
class PrivacyVerificationModule(nn.Module):
    """Combine structural, token-level, entropy and budget checks.

    ``verify`` runs the checks and returns a report; ``enforce`` turns a
    report into a block / allow decision.
    """

    def __init__(self, config):
        """Build the individual checkers from ``config``.

        Args:
            config: Object exposing ``dp_epsilon``, ``dp_delta`` and
                ``privacy_threshold`` attributes.
        """
        super().__init__()
        self.config = config
        self.struct_checker = StructuralInvariantChecker(schema={})
        self.token_detector = ForbiddenTokenDetector([
            "ssn", "social", "password", "passport", "bank account",
            "credit card", "cvv", "routing",
        ])
        self.entropy_estimator = PrivacyEntropyEstimator()
        self.budget_accountant = PrivacyBudgetAccountant(
            epsilon=config.dp_epsilon, delta=config.dp_delta
        )
        self.threshold = config.privacy_threshold

    def verify(self, hidden_states, generated_ids=None, graph=None, vocab=None) -> "PrivacyReport":
        """Run every applicable check and collect the results.

        Args:
            hidden_states: Tensor passed to the entropy estimator.
            generated_ids: Optional token ids; scanned only when ``vocab`` is
                also given.
            graph: Optional graph dict; checked structurally when given.
            vocab: Optional id-to-token mapping used by the token scan.

        Returns:
            A ``PrivacyReport``. A check that was skipped counts as passed.
            ``details`` holds the violation counts of the checks that ran,
            the privacy score, and ``"is_secure"`` (1.0 when all checks
            pass, the score meets the threshold and the budget is not
            exhausted; otherwise 0.0).
        """
        details = {}

        struct_ok = True
        if graph is not None:
            struct_ok, violations = self.struct_checker.check(graph)
            details["struct_violations"] = len(violations)

        tokens_ok = True
        if generated_ids is not None and vocab is not None:
            tokens_ok, violations = self.token_detector.check(generated_ids, vocab)
            details["forbidden_tokens"] = len(violations)

        privacy_score = self.entropy_estimator.estimate(hidden_states, target_token=0)
        details["privacy_score"] = privacy_score

        is_secure = (
            struct_ok and tokens_ok and
            privacy_score >= self.threshold and
            not self.budget_accountant.is_exhausted()
        )
        # Publish the combined verdict so it is not computed and thrown away.
        details["is_secure"] = float(is_secure)

        return PrivacyReport(
            struct_ok=struct_ok, tokens_ok=tokens_ok,
            privacy_score=privacy_score,
            budget_remaining=self.budget_accountant.remaining(),
            details=details
        )

    def enforce(self, report: "PrivacyReport"):
        """Decide whether the output described by ``report`` must be blocked.

        Checks run in this order and the first failure determines the
        reason: low privacy score, forbidden tokens, exhausted budget,
        structural violation. The structural check comes last so that
        reports blocked for another reason keep that reason.

        Returns:
            ``{"block": True, "reason": <str>, "report": report}`` when
            blocked, otherwise ``{"block": False, "report": report}``.
        """
        # ``not (score >= threshold)`` also blocks a NaN score, which a plain
        # ``score < threshold`` comparison would let through.
        if not (report.privacy_score >= self.threshold):
            return {"block": True, "reason": "low_privacy_score", "report": report}
        if not report.tokens_ok:
            return {"block": True, "reason": "forbidden_tokens", "report": report}
        if self.budget_accountant.is_exhausted():
            return {"block": True, "reason": "privacy_budget_exhausted", "report": report}
        if not report.struct_ok:
            return {"block": True, "reason": "structural_violation", "report": report}
        return {"block": False, "report": report}
|