InsafQ commited on
Commit
b2faf0e
·
verified ·
1 Parent(s): 0e1204d

Add tabgan/privacy_metrics.py

Browse files
Files changed (1) hide show
  1. tabgan/privacy_metrics.py +225 -0
tabgan/privacy_metrics.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Privacy metrics for assessing re-identification risk in synthetic data.
4
+
5
+ Provides Distance to Closest Record (DCR), Nearest Neighbor Distance Ratio
6
+ (NNDR), and a membership inference risk score.
7
+ """
8
+
9
+ import logging
10
+ from typing import Dict, List, Optional
11
+
12
+ import numpy as np
13
+ import pandas as pd
14
+ from sklearn.neighbors import NearestNeighbors
15
+ from sklearn.preprocessing import OrdinalEncoder, StandardScaler
16
+
17
+ __all__ = ["PrivacyMetrics"]
18
+
19
+
20
def _encode_for_distance(
    original: pd.DataFrame,
    synthetic: pd.DataFrame,
    cat_cols: Optional[List[str]] = None,
) -> tuple:
    """Encode categoricals and scale both DataFrames for distance computation.

    Args:
        original: Real data; defines encoder categories and scaler statistics.
        synthetic: Generated data; transformed with the fit from ``original``.
        cat_cols: Categorical column names to ordinal-encode before scaling.

    Returns:
        Tuple ``(orig_scaled, synth_scaled)`` of numpy arrays with an
        identical column layout, standard-scaled with statistics fit on
        ``original`` only.
    """
    original = original.copy()
    synthetic = synthetic.copy()

    if cat_cols:
        # Categories unseen during fit map to -1 instead of raising.
        encoder = OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1)
        original[cat_cols] = encoder.fit_transform(original[cat_cols].astype(str))
        synthetic[cat_cols] = encoder.transform(synthetic[cat_cols].astype(str))

    # Select numeric columns ONCE, from the original frame, and index both
    # frames with the same list. Selecting independently per frame (as the
    # previous version did) could hand the scaler misaligned column sets
    # when dtypes drift between original and synthetic data.
    num_cols = original.select_dtypes(include=[np.number]).columns.tolist()

    # Fill NaN with the original column's median. Check BOTH frames: the
    # previous version only filled when the ORIGINAL column had NaN, so
    # NaN present only in the synthetic frame leaked through scaling.
    for col in num_cols:
        if original[col].isna().any() or synthetic[col].isna().any():
            med = original[col].median()
            if pd.isna(med):
                # Column is entirely NaN in the original — fall back to 0.
                med = 0.0
            original[col] = original[col].fillna(med)
            synthetic[col] = synthetic[col].fillna(med)

    scaler = StandardScaler()
    orig_scaled = scaler.fit_transform(original[num_cols])
    synth_scaled = scaler.transform(synthetic[num_cols])

    return orig_scaled, synth_scaled
46
+
47
+
48
class PrivacyMetrics:
    """Evaluate privacy risk of synthetic data relative to original data.

    Args:
        original_df: The real / training DataFrame.
        synthetic_df: The generated / synthetic DataFrame.
        cat_cols: Names of categorical columns (encoded before distance computation).

    Example::

        from tabgan.privacy_metrics import PrivacyMetrics
        pm = PrivacyMetrics(original_df, synthetic_df, cat_cols=["gender"])
        print(pm.summary())
    """

    def __init__(
        self,
        original_df: pd.DataFrame,
        synthetic_df: pd.DataFrame,
        cat_cols: Optional[List[str]] = None,
    ):
        # Only columns present in both frames are comparable; everything
        # else is dropped before encoding/scaling.
        shared_cols = [c for c in original_df.columns if c in synthetic_df.columns]
        self.original_df = original_df[shared_cols].copy()
        self.synthetic_df = synthetic_df[shared_cols].copy()
        self.cat_cols = [c for c in (cat_cols or []) if c in shared_cols]
        self._orig_scaled, self._synth_scaled = _encode_for_distance(
            self.original_df, self.synthetic_df, self.cat_cols
        )

    # ------------------------------------------------------------------
    # DCR — Distance to Closest Record
    # ------------------------------------------------------------------
    def dcr(self, sample_size: Optional[int] = None) -> Dict:
        """Compute the distance from each synthetic row to the nearest original row.

        Higher distances indicate better privacy (synthetic rows are not
        trivially close to any real record).

        Args:
            sample_size: Optional cap on the number of synthetic rows used
                (random subsample, for speed on large frames).

        Returns:
            dict with ``mean``, ``median``, ``5th_percentile``, and ``distances``.
        """
        synth = self._synth_scaled
        if sample_size and sample_size < len(synth):
            idx = np.random.choice(len(synth), sample_size, replace=False)
            synth = synth[idx]

        nn = NearestNeighbors(n_neighbors=1, algorithm="auto")
        nn.fit(self._orig_scaled)
        distances, _ = nn.kneighbors(synth)
        distances = distances.ravel()

        return {
            "mean": float(np.mean(distances)),
            "median": float(np.median(distances)),
            "5th_percentile": float(np.percentile(distances, 5)),
            "distances": distances,
        }

    # ------------------------------------------------------------------
    # NNDR — Nearest Neighbor Distance Ratio
    # ------------------------------------------------------------------
    def nndr(self, sample_size: Optional[int] = None) -> Dict:
        """Nearest-neighbor distance ratio for each synthetic row.

        Ratio = dist(nearest_original) / dist(2nd_nearest_original).
        A ratio close to 1 means the synthetic row is equidistant to
        multiple originals (lower risk); a ratio near 0 means it is
        suspiciously close to exactly one real record.

        Args:
            sample_size: Optional cap on the number of synthetic rows used
                (random subsample, for speed on large frames).

        Returns:
            dict with ``mean``, ``median``, and ``ratios``.
        """
        synth = self._synth_scaled
        if sample_size and sample_size < len(synth):
            idx = np.random.choice(len(synth), sample_size, replace=False)
            synth = synth[idx]

        # With a single original row there is no 2nd neighbor; fall back
        # to k=1 and report neutral ratios of 1.
        k = min(2, len(self._orig_scaled))
        nn = NearestNeighbors(n_neighbors=k, algorithm="auto")
        nn.fit(self._orig_scaled)
        distances, _ = nn.kneighbors(synth)

        if k < 2:
            ratios = np.ones(len(synth))
        else:
            d1 = distances[:, 0]
            # Guard division by zero when the 2nd neighbor is an exact match.
            d2 = np.where(distances[:, 1] == 0, 1e-10, distances[:, 1])
            ratios = d1 / d2

        return {
            "mean": float(np.mean(ratios)),
            "median": float(np.median(ratios)),
            "ratios": ratios,
        }

    # ------------------------------------------------------------------
    # Membership Inference Risk
    # ------------------------------------------------------------------
    def membership_inference_risk(self, holdout_frac: float = 0.3) -> Dict:
        """Estimate membership inference risk.

        Splits the original data into a *member* set (simulating the training
        data the generator saw) and a *holdout* set. If the generator
        memorised the members, synthetic rows will be closer to members than
        to holdout rows. The risk is quantified as the AUC of a simple
        classifier trained on this distance signal.

        Args:
            holdout_frac: Fraction of original rows reserved as holdout;
                clamped so both sets stay non-empty.

        Returns:
            dict with ``auc`` (0.5 = good privacy, 1.0 = full memorisation)
            and ``accuracy``.
        """
        from sklearn.linear_model import LogisticRegression
        from sklearn.metrics import roc_auc_score
        from sklearn.model_selection import cross_val_predict

        n = len(self._orig_scaled)
        if n < 2:
            # Cannot form both a member and a holdout set — no signal.
            return {"auc": 0.5, "accuracy": 0.5}

        # Clamp so BOTH sets are non-empty: the previous version allowed
        # holdout_frac >= 1 to empty the member set and crash kneighbors
        # on a zero-row array.
        n_holdout = min(max(int(n * holdout_frac), 1), n - 1)
        perm = np.random.permutation(n)
        holdout_idx = perm[:n_holdout]
        member_idx = perm[n_holdout:]

        members = self._orig_scaled[member_idx]
        holdout = self._orig_scaled[holdout_idx]

        # Distance from each original row to its nearest synthetic row:
        # memorised members should sit closer than held-out rows.
        nn = NearestNeighbors(n_neighbors=1, algorithm="auto")
        nn.fit(self._synth_scaled)

        d_members, _ = nn.kneighbors(members)
        d_holdout, _ = nn.kneighbors(holdout)

        X = np.concatenate([d_members.ravel(), d_holdout.ravel()]).reshape(-1, 1)
        y = np.concatenate([np.ones(len(d_members)), np.zeros(len(d_holdout))])

        if len(np.unique(y)) < 2:
            return {"auc": 0.5, "accuracy": 0.5}

        clf = LogisticRegression(solver="lbfgs", max_iter=200)
        try:
            proba = cross_val_predict(
                clf, X, y, cv=min(3, len(y)), method="predict_proba"
            )[:, 1]
            auc = float(roc_auc_score(y, proba))
            # Computed inside the try so we never reference an unbound
            # `proba` (the old code used a fragile `'proba' in dir()` check).
            accuracy = float(np.mean((proba > 0.5) == y))
        except Exception:
            # Degenerate folds (e.g. a fold with one class) — report a
            # neutral score instead of crashing, but leave a trace.
            logging.getLogger(__name__).debug(
                "membership inference classifier failed; returning neutral risk",
                exc_info=True,
            )
            auc = 0.5
            accuracy = 0.5

        return {"auc": auc, "accuracy": accuracy}

    # ------------------------------------------------------------------
    # Summary
    # ------------------------------------------------------------------
    def summary(self) -> Dict:
        """Aggregate all privacy metrics into a single report.

        The ``overall_privacy_score`` ranges from 0 (high risk) to 1 (private).
        """
        dcr_res = self.dcr()
        nndr_res = self.nndr()
        mi_res = self.membership_inference_risk()

        # Score components (each normalised to 0-1, higher = more private)
        # DCR: 5th percentile > 0 is good; cap contribution at 1
        dcr_score = min(dcr_res["5th_percentile"], 1.0)

        # NNDR: mean closer to 1 is better
        nndr_score = min(nndr_res["mean"], 1.0)

        # MI: AUC closer to 0.5 is better → score = 1 - 2*|AUC - 0.5|
        mi_score = max(1.0 - 2.0 * abs(mi_res["auc"] - 0.5), 0.0)

        overall = 0.4 * dcr_score + 0.3 * nndr_score + 0.3 * mi_score

        return {
            "dcr": {k: v for k, v in dcr_res.items() if k != "distances"},
            "nndr": {k: v for k, v in nndr_res.items() if k != "ratios"},
            "membership_inference": mi_res,
            "overall_privacy_score": round(overall, 4),
        }