fungi00 commited on
Commit
ee7e3b5
·
verified ·
1 Parent(s): ee1cb12

Upload mia.py

Browse files
Files changed (1) hide show
  1. mia.py +203 -0
mia.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MIA Calibrator for the Dependable FungAI Project.
3
+
4
+ This script performs a post-training analysis to determine the optimal
5
+ threshold for Membership Inference Attacks (MIA). It collects the two most
6
+ reliable metrics: Confidence Score and Loss Value.
7
+
8
+ It uses a balanced logistic regression model on confidence scores to
9
+ determine a single, robust membership threshold.
10
+
11
+ The script outputs:
12
+ 1. A visualization of the confidence score distributions.
13
+ 2. An updated TBOM.json file with a detailed 'membership_inference_analysis' key,
14
+ containing the calculated threshold and summary statistics for the collected metrics.
15
+
16
+ This script should be run AFTER TBOM.py has successfully completed.
17
+ """
18
+
19
+ import os
20
+ import json
21
+ import torch
22
+ import torch.nn as nn
23
+ import numpy as np
24
+ import pandas as pd
25
+ import matplotlib.pyplot as plt
26
+ from sklearn.linear_model import LogisticRegression
27
+ from sklearn.metrics import roc_auc_score
28
+ from torchvision import datasets
29
+ from torch.utils.data import DataLoader, Subset
30
+ import warnings
31
+
32
+ # Suppress warnings for clean output
33
+ warnings.filterwarnings('ignore')
34
+
35
+ # --- Import from existing project files ---
36
+ from TBOM import HybridMLP
37
+ from IBOM import DetailedIBOMGenerator, get_paths_with_smart_fallback
38
+
39
+ def get_mia_metrics(model_wrapper, dataset, batch_size=32):
40
+ """
41
+ Runs inference on a dataset and returns a list of dictionaries,
42
+ each containing MIA metrics (confidence, loss) for a sample.
43
+ """
44
+ print(f"Calculating MIA metrics (Loss, Confidence) for {len(dataset)} samples...")
45
+ all_metrics = []
46
+ data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
47
+ criterion = nn.BCEWithLogitsLoss(reduction='none') # Use 'none' to get per-sample loss
48
+
49
+ model = model_wrapper.model
50
+ model.eval()
51
+
52
+ with torch.no_grad():
53
+ for i, (images, labels) in enumerate(data_loader):
54
+ images = images.to(model_wrapper.device)
55
+ labels_float = labels.unsqueeze(1).float().to(model_wrapper.device)
56
+
57
+ # --- Feature Extraction (mirrors IBOM) ---
58
+ img_features = model_wrapper.clip_model.encode_image(images).float()
59
+ img_features_norm = img_features / img_features.norm(dim=1, keepdim=True)
60
+ concept_scores = img_features_norm @ model_wrapper.text_embeddings.T
61
+ hybrid_features = torch.cat([img_features, concept_scores], dim=1)
62
+
63
+ # --- Metric Calculation ---
64
+ logits = model(hybrid_features)
65
+ probabilities = torch.sigmoid(logits)
66
+
67
+ # 1. Confidence Score
68
+ confidences = torch.max(probabilities, 1 - probabilities).squeeze()
69
+
70
+ # 2. Loss Value (per-sample)
71
+ losses = criterion(logits, labels_float).squeeze()
72
+
73
+ # --- Store metrics for each sample in the batch ---
74
+ for j in range(images.size(0)):
75
+ all_metrics.append({
76
+ 'confidence': confidences[j].item(),
77
+ 'loss': losses[j].item()
78
+ })
79
+
80
+ if (i + 1) % 50 == 0:
81
+ print(f" Processed { (i + 1) * batch_size } / {len(dataset)} samples...")
82
+
83
+ return all_metrics
84
+
85
+ def find_best_threshold_lr_balanced(member_scores, non_member_scores):
86
+ """
87
+ Trains a balanced logistic regression model on confidence scores to find an MIA threshold.
88
+ """
89
+ print("\nFinding optimal threshold using a BALANCED logistic regression attack model...")
90
+ X = np.concatenate([member_scores, non_member_scores]).reshape(-1, 1)
91
+ y = np.concatenate([np.ones_like(member_scores), np.zeros_like(non_member_scores)])
92
+
93
+ attack_model = LogisticRegression(solver='liblinear', class_weight='balanced')
94
+ attack_model.fit(X, y)
95
+
96
+ intercept = attack_model.intercept_[0]
97
+ coef = attack_model.coef_[0][0]
98
+ threshold = -intercept / coef if coef != 0 else 0.5
99
+
100
+ attack_probs = attack_model.predict_proba(X)[:, 1]
101
+ attack_auc = roc_auc_score(y, attack_probs)
102
+
103
+ print(f"BALANCED Attack Model Trained. Optimal Threshold: {threshold:.4f}, Attack AUC: {attack_auc:.4f}")
104
+ return float(threshold), float(attack_auc)
105
+
106
+ def visualize_distributions(member_scores, non_member_scores, threshold, output_path):
107
+ """
108
+ Creates and saves a histogram plot of the confidence distributions.
109
+ """
110
+ print(f"\nGenerating and saving visualization to {output_path}...")
111
+ plt.style.use('seaborn-v0_8-whitegrid')
112
+ plt.figure(figsize=(12, 7))
113
+
114
+ plt.hist(non_member_scores, bins=50, density=True, alpha=0.7, label='Non-Members (Test Set)', color='darkorange')
115
+ plt.hist(member_scores, bins=50, density=True, alpha=0.7, label='Members (Train Set)', color='royalblue')
116
+
117
+ plt.axvline(threshold, color='crimson', linestyle='--', linewidth=2.5, label=f'Decision Threshold ({threshold:.3f})')
118
+
119
+ plt.title('Confidence Score Distributions: Members vs. Non-Members', fontsize=16, fontweight='bold')
120
+ plt.xlabel('Model Prediction Confidence', fontsize=12)
121
+ plt.ylabel('Density', fontsize=12)
122
+ plt.legend(fontsize=11)
123
+ plt.xlim(0.5, 1.0)
124
+ plt.tight_layout()
125
+
126
+ plt.savefig(output_path, dpi=300)
127
+ plt.close()
128
+
129
+ def main():
130
+ """
131
+ Main function to orchestrate the MIA calibration process.
132
+ """
133
+ print("--- Starting Membership Inference Attack (MIA) Calibration ---")
134
+
135
+ try:
136
+ tbom_path, model_path, csv_path = get_paths_with_smart_fallback()
137
+ output_dir = os.path.dirname(tbom_path)
138
+
139
+ print(f"Loading data splits from {tbom_path}")
140
+ with open(tbom_path, 'r') as f:
141
+ tbom_data = json.load(f)
142
+
143
+ train_val_indices = tbom_data['data_summary']['data_splits']['train_validation_set']['indices']
144
+ test_indices = tbom_data['data_summary']['data_splits']['test_set']['indices']
145
+ image_root = tbom_data['data_summary']['image_dataset_path']
146
+
147
+ print("\nInitializing model and data pipeline...")
148
+ ibom_generator = DetailedIBOMGenerator(model_path, tbom_path, csv_path)
149
+
150
+ full_dataset = datasets.ImageFolder(root=image_root, transform=ibom_generator.preprocess)
151
+ member_dataset = Subset(full_dataset, train_val_indices)
152
+ non_member_dataset = Subset(full_dataset, test_indices)
153
+
154
+ member_metrics = get_mia_metrics(ibom_generator, member_dataset)
155
+ non_member_metrics = get_mia_metrics(ibom_generator, non_member_dataset)
156
+
157
+ member_confidences = np.array([m['confidence'] for m in member_metrics])
158
+ non_member_confidences = np.array([m['confidence'] for m in non_member_metrics])
159
+
160
+ final_threshold, attack_auc = find_best_threshold_lr_balanced(member_confidences, non_member_confidences)
161
+
162
+ print(f"\n--- Final Threshold Selected ---")
163
+ print(f" Method: Balanced Logistic Regression")
164
+ print(f" Threshold (on Confidence Score): {final_threshold:.4f}")
165
+
166
+ viz_path = os.path.join(output_dir, 'mia_confidence_distribution.png')
167
+ visualize_distributions(member_confidences, non_member_confidences, final_threshold, viz_path)
168
+
169
+ print(f"\nUpdating {tbom_path} with detailed MIA results...")
170
+
171
+ member_df = pd.DataFrame(member_metrics)
172
+ non_member_df = pd.DataFrame(non_member_metrics)
173
+
174
+ tbom_data['membership_inference_analysis'] = {
175
+ 'description': "Analysis to distinguish members from non-members using Loss and Confidence metrics.",
176
+ 'threshold_finding_summary': {
177
+ 'metric_used_for_threshold': 'confidence_score',
178
+ 'method': "A balanced logistic regression was used to find the optimal threshold.",
179
+ 'decision_threshold': final_threshold,
180
+ 'attack_auc_score_on_confidence': attack_auc,
181
+ },
182
+ 'metric_statistics': {
183
+ 'members': member_df.describe().to_dict(),
184
+ 'non_members': non_member_df.describe().to_dict()
185
+ },
186
+ 'interpretation': "Clear separation in metric distributions (lower loss and higher confidence for members) indicates data memorization.",
187
+ 'visualization_artifact': viz_path
188
+ }
189
+
190
+ with open(tbom_path, 'w') as f:
191
+ json.dump(tbom_data, f, indent=4)
192
+
193
+ print("\n--- MIA Calibration Complete! ---")
194
+ print(f"✅ TBOM file successfully updated with focused membership analysis.")
195
+ print(f"✅ Distribution plot saved to {viz_path}")
196
+
197
+ except Exception as e:
198
+ print(f"\n❌ An error occurred during MIA calibration: {e}")
199
+ import traceback
200
+ traceback.print_exc()
201
+
202
+ if __name__ == "__main__":
203
+ main()