Spaces:
Sleeping
Sleeping
Upload mia.py
Browse files
mia.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MIA Calibrator for the Dependable FungAI Project.
|
| 3 |
+
|
| 4 |
+
This script performs a post-training analysis to determine the optimal
|
| 5 |
+
threshold for Membership Inference Attacks (MIA). It collects the two most
|
| 6 |
+
reliable metrics: Confidence Score and Loss Value.
|
| 7 |
+
|
| 8 |
+
It uses a balanced logistic regression model on confidence scores to
|
| 9 |
+
determine a single, robust membership threshold.
|
| 10 |
+
|
| 11 |
+
The script outputs:
|
| 12 |
+
1. A visualization of the confidence score distributions.
|
| 13 |
+
2. An updated TBOM.json file with a detailed 'membership_inference_analysis' key,
|
| 14 |
+
containing the calculated threshold and summary statistics for the collected metrics.
|
| 15 |
+
|
| 16 |
+
This script should be run AFTER TBOM.py has successfully completed.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import os
|
| 20 |
+
import json
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
import numpy as np
|
| 24 |
+
import pandas as pd
|
| 25 |
+
import matplotlib.pyplot as plt
|
| 26 |
+
from sklearn.linear_model import LogisticRegression
|
| 27 |
+
from sklearn.metrics import roc_auc_score
|
| 28 |
+
from torchvision import datasets
|
| 29 |
+
from torch.utils.data import DataLoader, Subset
|
| 30 |
+
import warnings
|
| 31 |
+
|
| 32 |
+
# Suppress warnings for clean output
|
| 33 |
+
warnings.filterwarnings('ignore')
|
| 34 |
+
|
| 35 |
+
# --- Import from existing project files ---
|
| 36 |
+
from TBOM import HybridMLP
|
| 37 |
+
from IBOM import DetailedIBOMGenerator, get_paths_with_smart_fallback
|
| 38 |
+
|
| 39 |
+
def get_mia_metrics(model_wrapper, dataset, batch_size=32):
    """
    Run inference over *dataset* and collect per-sample MIA metrics.

    Parameters
    ----------
    model_wrapper : object
        Wrapper exposing ``model``, ``device``, ``clip_model`` and
        ``text_embeddings`` (e.g. a ``DetailedIBOMGenerator`` instance).
    dataset : torch.utils.data.Dataset
        Dataset yielding ``(image_tensor, label)`` pairs.
    batch_size : int, optional
        Inference batch size (default 32).

    Returns
    -------
    list[dict]
        One dict per sample with keys ``'confidence'`` (max of p and 1-p,
        i.e. the probability of the predicted class) and ``'loss'``
        (per-sample BCE-with-logits loss).
    """
    print(f"Calculating MIA metrics (Loss, Confidence) for {len(dataset)} samples...")
    all_metrics = []
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    # reduction='none' keeps one loss value per sample instead of a batch mean.
    criterion = nn.BCEWithLogitsLoss(reduction='none')

    model = model_wrapper.model
    model.eval()

    with torch.no_grad():
        for i, (images, labels) in enumerate(data_loader):
            images = images.to(model_wrapper.device)
            labels_float = labels.unsqueeze(1).float().to(model_wrapper.device)

            # --- Feature extraction (mirrors IBOM) ---
            img_features = model_wrapper.clip_model.encode_image(images).float()
            img_features_norm = img_features / img_features.norm(dim=1, keepdim=True)
            concept_scores = img_features_norm @ model_wrapper.text_embeddings.T
            hybrid_features = torch.cat([img_features, concept_scores], dim=1)

            # --- Metric calculation ---
            logits = model(hybrid_features)
            probabilities = torch.sigmoid(logits)

            # 1. Confidence score.
            # BUGFIX: use view(-1) rather than squeeze() — squeeze() collapses a
            # final batch of size 1 to a 0-dim tensor, making the [j] indexing
            # below raise IndexError. view(-1) always yields a 1-D tensor.
            confidences = torch.max(probabilities, 1 - probabilities).view(-1)

            # 2. Loss value (per-sample); same view(-1) rationale as above.
            losses = criterion(logits, labels_float).view(-1)

            # --- Store metrics for each sample in the batch ---
            for j in range(images.size(0)):
                all_metrics.append({
                    'confidence': confidences[j].item(),
                    'loss': losses[j].item()
                })

            if (i + 1) % 50 == 0:
                print(f" Processed { (i + 1) * batch_size } / {len(dataset)} samples...")

    return all_metrics
|
| 84 |
+
|
| 85 |
+
def find_best_threshold_lr_balanced(member_scores, non_member_scores):
    """
    Fit a class-balanced logistic regression on confidence scores and derive
    the MIA decision threshold from its decision boundary.

    Parameters
    ----------
    member_scores, non_member_scores : array-like of float
        Confidence scores for training-set members and held-out non-members.

    Returns
    -------
    tuple[float, float]
        ``(threshold, attack_auc)`` — the confidence value at which the
        attack model's decision flips, and the attack model's ROC-AUC on
        the combined data.
    """
    print("\nFinding optimal threshold using a BALANCED logistic regression attack model...")
    features = np.concatenate([member_scores, non_member_scores]).reshape(-1, 1)
    # Members are the positive class (1), non-members negative (0).
    labels = np.concatenate(
        [np.ones_like(member_scores), np.zeros_like(non_member_scores)]
    )

    attack_model = LogisticRegression(solver='liblinear', class_weight='balanced')
    attack_model.fit(features, labels)

    # For a 1-D feature the boundary sits where w*x + b == 0, i.e. x = -b / w.
    slope = attack_model.coef_[0][0]
    offset = attack_model.intercept_[0]
    threshold = 0.5 if slope == 0 else -offset / slope

    membership_probs = attack_model.predict_proba(features)[:, 1]
    attack_auc = roc_auc_score(labels, membership_probs)

    print(f"BALANCED Attack Model Trained. Optimal Threshold: {threshold:.4f}, Attack AUC: {attack_auc:.4f}")
    return float(threshold), float(attack_auc)
|
| 105 |
+
|
| 106 |
+
def visualize_distributions(member_scores, non_member_scores, threshold, output_path):
    """
    Create and save a histogram plot of the confidence-score distributions.

    Overlays member (train-set) and non-member (test-set) histograms, marks
    the MIA decision threshold with a dashed vertical line, and writes the
    figure to *output_path* at 300 dpi. The x-axis starts at 0.5 because the
    confidence metric (max of p and 1-p) can never fall below 0.5.
    """
    print(f"\nGenerating and saving visualization to {output_path}...")
    plt.style.use('seaborn-v0_8-whitegrid')
    fig = plt.figure(figsize=(12, 7))
    ax = fig.gca()

    # Non-members drawn first so the member histogram overlays them.
    ax.hist(non_member_scores, bins=50, density=True, alpha=0.7,
            label='Non-Members (Test Set)', color='darkorange')
    ax.hist(member_scores, bins=50, density=True, alpha=0.7,
            label='Members (Train Set)', color='royalblue')

    ax.axvline(threshold, color='crimson', linestyle='--', linewidth=2.5,
               label=f'Decision Threshold ({threshold:.3f})')

    ax.set_title('Confidence Score Distributions: Members vs. Non-Members',
                 fontsize=16, fontweight='bold')
    ax.set_xlabel('Model Prediction Confidence', fontsize=12)
    ax.set_ylabel('Density', fontsize=12)
    ax.legend(fontsize=11)
    ax.set_xlim(0.5, 1.0)
    fig.tight_layout()

    fig.savefig(output_path, dpi=300)
    plt.close(fig)
|
| 128 |
+
|
| 129 |
+
def main():
    """
    Orchestrate the MIA calibration process.

    Steps:
      1. Load data-split indices and dataset paths from the existing TBOM.json.
      2. Rebuild the model/data pipeline via DetailedIBOMGenerator.
      3. Collect per-sample confidence/loss metrics for members (train/val
         split) and non-members (test split).
      4. Fit the balanced logistic-regression attack model to pick a threshold.
      5. Save a distribution plot and write the analysis back into TBOM.json.

    On any failure, prints the traceback and exits with status 1 so that
    automated pipelines can detect the failed calibration run.
    """
    print("--- Starting Membership Inference Attack (MIA) Calibration ---")

    try:
        tbom_path, model_path, csv_path = get_paths_with_smart_fallback()
        output_dir = os.path.dirname(tbom_path)

        print(f"Loading data splits from {tbom_path}")
        with open(tbom_path, 'r') as f:
            tbom_data = json.load(f)

        train_val_indices = tbom_data['data_summary']['data_splits']['train_validation_set']['indices']
        test_indices = tbom_data['data_summary']['data_splits']['test_set']['indices']
        image_root = tbom_data['data_summary']['image_dataset_path']

        print("\nInitializing model and data pipeline...")
        ibom_generator = DetailedIBOMGenerator(model_path, tbom_path, csv_path)

        full_dataset = datasets.ImageFolder(root=image_root, transform=ibom_generator.preprocess)
        # Members were seen during training; non-members (test split) were not.
        member_dataset = Subset(full_dataset, train_val_indices)
        non_member_dataset = Subset(full_dataset, test_indices)

        member_metrics = get_mia_metrics(ibom_generator, member_dataset)
        non_member_metrics = get_mia_metrics(ibom_generator, non_member_dataset)

        member_confidences = np.array([m['confidence'] for m in member_metrics])
        non_member_confidences = np.array([m['confidence'] for m in non_member_metrics])

        final_threshold, attack_auc = find_best_threshold_lr_balanced(member_confidences, non_member_confidences)

        print("\n--- Final Threshold Selected ---")
        print(" Method: Balanced Logistic Regression")
        print(f" Threshold (on Confidence Score): {final_threshold:.4f}")

        viz_path = os.path.join(output_dir, 'mia_confidence_distribution.png')
        visualize_distributions(member_confidences, non_member_confidences, final_threshold, viz_path)

        print(f"\nUpdating {tbom_path} with detailed MIA results...")

        member_df = pd.DataFrame(member_metrics)
        non_member_df = pd.DataFrame(non_member_metrics)

        tbom_data['membership_inference_analysis'] = {
            'description': "Analysis to distinguish members from non-members using Loss and Confidence metrics.",
            'threshold_finding_summary': {
                'metric_used_for_threshold': 'confidence_score',
                'method': "A balanced logistic regression was used to find the optimal threshold.",
                'decision_threshold': final_threshold,
                'attack_auc_score_on_confidence': attack_auc,
            },
            'metric_statistics': {
                # describe() captures count/mean/std/quantiles for each metric.
                'members': member_df.describe().to_dict(),
                'non_members': non_member_df.describe().to_dict()
            },
            'interpretation': "Clear separation in metric distributions (lower loss and higher confidence for members) indicates data memorization.",
            'visualization_artifact': viz_path
        }

        with open(tbom_path, 'w') as f:
            json.dump(tbom_data, f, indent=4)

        print("\n--- MIA Calibration Complete! ---")
        print("✅ TBOM file successfully updated with focused membership analysis.")
        print(f"✅ Distribution plot saved to {viz_path}")

    except Exception as e:
        print(f"\n❌ An error occurred during MIA calibration: {e}")
        import traceback
        traceback.print_exc()
        # BUGFIX: previously the script swallowed the error and exited 0,
        # hiding failures from CI/pipelines. Propagate a non-zero exit code.
        raise SystemExit(1)
|
| 201 |
+
|
| 202 |
+
# Run the calibration only when executed as a script (not on import).
if __name__ == "__main__":
    main()
|