File size: 5,293 Bytes
2b8a1c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""
PrivacyVerificationModule: Run-time privacy guarantees during training & inference.
Three-level defense: structural invariants, forbidden token detection, entropy bounds.
"""

import torch
import torch.nn as nn
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from collections import defaultdict


@dataclass
class PrivacyReport:
    """Outcome of a single privacy verification pass.

    Instances are built by ``PrivacyVerificationModule.verify`` and inspected
    by ``PrivacyVerificationModule.enforce``.
    """

    # Result of the structural graph check; ``verify`` leaves it True when no
    # graph was supplied.
    struct_ok: bool
    # Result of the forbidden-token scan; ``verify`` leaves it True when either
    # the generated ids or the vocab were not supplied.
    tokens_ok: bool
    # Score returned by the entropy estimator for the hidden states.
    privacy_score: float
    # Value of the budget accountant's ``remaining()`` when the report was built.
    budget_remaining: float
    # Auxiliary numbers recorded by ``verify`` (e.g. "struct_violations",
    # "forbidden_tokens", "privacy_score").
    # NOTE(review): the violation counts stored here are ints, not floats.
    details: Dict[str, float]


class StructuralInvariantChecker:
    """Check structural invariants of a graph given as a mapping.

    The graph is read through ``graph.get("nodes", [])`` and
    ``graph.get("edges", [])``; each edge is a mapping with ``"source"`` and
    ``"target"`` entries and an optional ``"relation"`` string.
    """

    def __init__(self, schema: Dict[str, List[str]]):
        # Kept on the instance; the checks below do not consult it.
        self.schema = schema

    def check(self, graph):
        """Return ``(ok, violations)`` for *graph*.

        Two rules are applied:

        * every edge relation must start with ``":"`` (a missing relation is
          treated as the empty string and therefore reported);
        * every node index in ``range(len(nodes))`` must appear as the source
          or target of at least one edge.

        ``ok`` is True exactly when ``violations`` (a list of messages) is
        empty.
        """
        # Any label that fails the ":" prefix test necessarily differs from
        # ":next", so the prefix test alone decides validity.
        problems = [
            f"Invalid relation: {label}"
            for label in (
                link.get("relation", "") for link in graph.get("edges", [])
            )
            if not label.startswith(":")
        ]

        node_ids = set(range(len(graph.get("nodes", []))))
        touched = set()
        for link in graph.get("edges", []):
            touched.update((link["source"], link["target"]))

        orphans = node_ids - touched
        if orphans:
            problems.append(f"Isolated nodes: {orphans}")

        return not problems, problems


class ForbiddenTokenDetector:
    """Flag generated tokens whose text contains a forbidden substring.

    Matching is case-insensitive substring containment.
    """

    def __init__(self, forbidden_patterns: List[str]):
        """Store the patterns in normalized form.

        Args:
            forbidden_patterns: Substrings that must not appear in any
                generated token.
        """
        # Tokens are lowercased before matching, so the patterns have to be
        # lowercased as well; otherwise a pattern such as "SSN" could never
        # match. Empty patterns are dropped because "" is a substring of
        # every token and would flag everything.
        self.forbidden_patterns = {p.lower() for p in forbidden_patterns if p}

    def check(self, generated_ids, vocab):
        """Scan every id in *generated_ids* for forbidden content.

        Args:
            generated_ids: Tensor of token ids; it is flattened, so any shape
                is accepted.
            vocab: Mapping from token id to token string. Unknown ids are
                looked up as ``"<unk>"``.

        Returns:
            ``(ok, violations)`` where ``violations`` is a list of
            ``{"token_id": ..., "token": ...}`` dicts in scan order and ``ok``
            is True exactly when that list is empty.
        """
        violations = []
        # One bulk conversion instead of a ``.item()`` call per element.
        for token_id in generated_ids.flatten().tolist():
            token_str = vocab.get(token_id, "<unk>")
            # Tokens starting with "<" (e.g. "<unk>") are never reported.
            if token_str.startswith("<"):
                continue
            if self._matches_forbidden(token_str):
                violations.append({"token_id": token_id, "token": token_str})
        return not violations, violations

    def _matches_forbidden(self, token):
        """Return True if any forbidden pattern occurs in *token*, ignoring case."""
        token_lower = token.lower()
        return any(pattern in token_lower for pattern in self.forbidden_patterns)


class PrivacyEntropyEstimator:
    """Score hidden states by how evenly their norms spread over positions."""

    def estimate(self, hidden_states, target_token=0):
        """Return the mean normalized entropy of the per-position norms.

        The L2 norm is taken over the last dimension of *hidden_states*, a
        softmax is applied over the (new) last dimension, and the Shannon
        entropy of that distribution is divided by its maximum,
        ``log(num_positions)``. The result is averaged over any remaining
        leading dimensions and capped at 1.0.

        Args:
            hidden_states: Floating-point tensor with at least two dimensions.
            target_token: Accepted for interface compatibility; not used.

        Returns:
            A Python float no greater than 1.0. Degenerate inputs (a single
            position, or no elements at all) return 0.0 because the
            normalized entropy is undefined there (0/0) and would otherwise
            come back as NaN.
        """
        norms = torch.norm(hidden_states, dim=-1)
        num_positions = norms.shape[-1]
        if num_positions <= 1 or norms.numel() == 0:
            # Report the lowest score rather than NaN so that threshold
            # comparisons downstream behave predictably.
            return 0.0

        # log_softmax avoids log(0) without an additive epsilon, which also
        # keeps low-precision dtypes from underflowing.
        log_probs = torch.log_softmax(norms, dim=-1)
        entropy = -torch.sum(log_probs.exp() * log_probs, dim=-1)
        max_entropy = torch.log(torch.tensor(num_positions, dtype=torch.float32))
        privacy_score = (entropy / max_entropy).mean().item()
        return min(privacy_score, 1.0)


class PrivacyBudgetAccountant:
    """Track cumulative privacy spend against a fixed epsilon budget.

    The accounting is a coarse single-order approximation: each query is
    charged the order-2 Renyi-DP cost of a Gaussian mechanism, converted to an
    epsilon at ``delta``, scaled by the sampling rate, and summed.
    """

    def __init__(self, epsilon=3.0, delta=1e-5, accountant_type="rdp"):
        """Create an accountant with nothing spent.

        Args:
            epsilon: Total budget.
            delta: Delta used when converting the RDP cost to an epsilon.
            accountant_type: Accepted for interface compatibility; not used.
        """
        self.epsilon = epsilon
        self.delta = delta
        self.spent = 0.0
        self.query_count = 0

    def add_query(self, noise_multiplier, sampling_rate):
        """Charge one query to the budget.

        Args:
            noise_multiplier: Gaussian noise standard deviation relative to
                the sensitivity. Larger values cost less.
            sampling_rate: Factor in which the charge is scaled linearly.

        A non-positive (or NaN) noise multiplier offers no protection, so it
        is charged an infinite cost whenever ``sampling_rate`` is positive,
        which marks the budget as exhausted.
        """
        from math import inf, log

        alpha = 2.0
        if noise_multiplier > 0:
            # RDP of the Gaussian mechanism at order alpha is
            # alpha / (2 * sigma^2): more noise means a smaller charge.
            rdp = alpha / (2.0 * noise_multiplier ** 2)
            converted_eps = rdp + log(1 / self.delta) / (alpha - 1)
            cost = converted_eps * sampling_rate
        else:
            # Avoid inf * 0 = NaN, which would compare as "not exhausted".
            cost = inf if sampling_rate > 0 else 0.0
        self.spent += cost
        self.query_count += 1

    def remaining(self):
        """Return the unspent budget, never below 0.0."""
        return max(0.0, self.epsilon - self.spent)

    def is_exhausted(self):
        """Return True once the spend has reached or passed the budget."""
        return self.spent >= self.epsilon


class PrivacyVerificationModule(nn.Module):
    """Combine structural, token-level and entropy checks with a DP budget.

    ``verify`` runs the checks and packages the results; ``enforce`` turns a
    report into a block / allow decision.
    """

    def __init__(self, config):
        """Build the checkers from *config*.

        Args:
            config: Object exposing ``dp_epsilon``, ``dp_delta`` and
                ``privacy_threshold`` attributes.
        """
        super().__init__()
        self.config = config
        self.struct_checker = StructuralInvariantChecker(schema={})
        self.token_detector = ForbiddenTokenDetector([
            "ssn", "social", "password", "passport", "bank account",
            "credit card", "cvv", "routing",
        ])
        self.entropy_estimator = PrivacyEntropyEstimator()
        self.budget_accountant = PrivacyBudgetAccountant(
            epsilon=config.dp_epsilon, delta=config.dp_delta
        )
        self.threshold = config.privacy_threshold

    def verify(self, hidden_states, generated_ids=None, graph=None, vocab=None):
        """Run every applicable check and return a ``PrivacyReport``.

        Args:
            hidden_states: Tensor passed to the entropy estimator.
            generated_ids: Optional tensor of token ids; scanned only when
                *vocab* is also given.
            graph: Optional graph mapping for the structural check.
            vocab: Optional id-to-token mapping used with *generated_ids*.

        Returns:
            A ``PrivacyReport``. Checks that were skipped count as passed.
            ``details`` holds the violation counts of the checks that ran,
            the privacy score, and ``"is_secure"`` (1.0 or 0.0) with the
            combined verdict.
        """
        details = {}

        struct_ok = True
        if graph is not None:
            struct_ok, violations = self.struct_checker.check(graph)
            details["struct_violations"] = len(violations)

        tokens_ok = True
        if generated_ids is not None and vocab is not None:
            tokens_ok, violations = self.token_detector.check(generated_ids, vocab)
            details["forbidden_tokens"] = len(violations)

        privacy_score = self.entropy_estimator.estimate(hidden_states, target_token=0)
        details["privacy_score"] = privacy_score

        is_secure = (
            struct_ok and tokens_ok and
            privacy_score >= self.threshold and
            not self.budget_accountant.is_exhausted()
        )
        # The combined verdict used to be computed and then dropped; expose it
        # through ``details`` so callers can read it without recomputing.
        details["is_secure"] = float(is_secure)

        return PrivacyReport(
            struct_ok=struct_ok, tokens_ok=tokens_ok,
            privacy_score=privacy_score,
            budget_remaining=self.budget_accountant.remaining(),
            details=details
        )

    def enforce(self, report):
        """Decide whether the output described by *report* must be blocked.

        Returns:
            ``{"block": False, "report": report}`` when every check passes,
            otherwise ``{"block": True, "reason": <str>, "report": report}``
            with the first failing reason among ``"low_privacy_score"``,
            ``"forbidden_tokens"``, ``"privacy_budget_exhausted"`` and
            ``"structural_violation"``.
        """
        # Written as ``not (>=)`` so a NaN score fails closed, matching the
        # ``>=`` test in ``verify``; ``score < threshold`` is False for NaN.
        if not (report.privacy_score >= self.threshold):
            return {"block": True, "reason": "low_privacy_score", "report": report}
        if not report.tokens_ok:
            return {"block": True, "reason": "forbidden_tokens", "report": report}
        if self.budget_accountant.is_exhausted():
            return {"block": True, "reason": "privacy_budget_exhausted", "report": report}
        # Checked last so reports that were already blocked keep their reason.
        if not report.struct_ok:
            return {"block": True, "reason": "structural_violation", "report": report}
        return {"block": False, "report": report}