davidi123 commited on
Commit
14666fc
·
verified ·
1 Parent(s): 206314d

Upload README.md with huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +120 -0
README.md ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - membership-inference-attack
5
+ - privacy
6
+ - security
7
+ - language-models
8
+ - pytorch
9
+ pipeline_tag: other
10
+ library_name: ltmia
11
+ ---
12
+
13
+ # Learned Transfer Membership Inference Attack
14
+
15
+ A classifier that detects whether a given text was part of a language model's fine-tuning data. It compares the output distributions of a fine-tuned model against its pretrained base, extracting per-token features that a small transformer classifier uses to predict membership. Trained on 10 transformer models × 3 text domains, it generalizes zero-shot to unseen model/dataset combinations, including non-transformer architectures (Mamba, RWKV, RecurrentGemma).
16
+
17
+ ## Usage
18
+
19
+ ### Install
20
+
21
+ ```bash
22
+ git clone https://github.com/JetBrains-Research/ltmia.git
23
+ cd ltmia
24
+ pip install -e .
25
+ ```
26
+
27
+ ### Inference
28
+
29
+ ```python
30
+ import torch
31
+ from huggingface_hub import hf_hub_download
32
+ from transformers import AutoTokenizer, AutoModelForCausalLM
33
+ from ltmia import extract_per_token_features_both, create_mia_model
34
+
35
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
+
37
+ # 1. Load your base and fine-tuned models
38
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
39
+ tokenizer.pad_token = tokenizer.eos_token
40
+
41
+ model_ref = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
42
+ model_tgt = AutoModelForCausalLM.from_pretrained("./my-finetuned-gpt2").to(device).eval()
43
+
44
+ # 2. Extract features
45
+ texts = ["Text you want to check...", "Another text..."]
46
+
47
+ feats, masks, _ = extract_per_token_features_both(
48
+ model_tgt, model_ref, tokenizer, texts,
49
+ device=device, batch_size=8, sequence_length=128, k=20,
50
+ )
51
+
52
+ # 3. Load the MIA classifier
53
+ ckpt_path = hf_hub_download(
54
+ repo_id="JetBrains-Research/learned-transfer-attack",
55
+ filename="mia_combined_400k.pt",
56
+ )
57
+ ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
58
+ mia = create_mia_model(
59
+ architecture=ckpt["architecture"],
60
+ d_in=ckpt["d_in"],
61
+ seq_len=ckpt.get("seq_len", 128),
62
+ **ckpt["mia_hparams"],
63
+ )
64
+ mia.load_state_dict(ckpt["state_dict"])
65
+ mia.to(device).eval()
66
+
67
+ # 4. Predict membership
68
+ with torch.no_grad():
69
+ logits = mia(
70
+ torch.from_numpy(feats).to(device),
71
+ torch.from_numpy(masks).to(device),
72
+ )
73
+ probs = torch.sigmoid(logits)
74
+
75
+ for text, p in zip(texts, probs):
76
+ prob = p.item()
77
+ label = "MEMBER" if prob > 0.5 else "NON-MEMBER"
78
+ print(f"[{prob:.4f}] {label} ← {text[:80]}")
79
+ ```
80
+
81
+ You need black-box query access (full vocabulary logits) to both the fine-tuned model and its pretrained base. `sequence_length=128` and `k=20` must match this checkpoint. See the [GitHub repository](https://github.com/JetBrains-Research/ltmia) for CLI tools, training your own classifier, and evaluation scripts.
82
+
83
+
84
+ ## Model Details
85
+
86
+ **Architecture:** Transformer encoder — 154→112 projection, 3 layers, 4 heads, FFN 224, attention pooling, ~340K parameters.
87
+
88
+ **Input:** Per-token features (shape `N × 128 × 154`) comparing logits, ranks, and losses between target and reference models.
89
+
90
+ **Output:** Membership probability per text (sigmoid of scalar logit).
91
+
92
+ **Training data:** Features from 10 transformers (DistilGPT-2, GPT-2-XL, Pythia-1.4B, Cerebras-GPT-2.7B, GPT-J-6B, Gemma-2B, Qwen2-1.5B, MPT-7B, Falcon-RW-1B, Falcon-7B) fine-tuned on 3 datasets (News Category, Wikipedia, CNN/DailyMail). 18K samples per combination, 540K total.
93
+
94
+ **Training:** AdamW, lr 5e-4, batch 16384, 100 epochs. Checkpoint selected by best validation AUC.
95
+
96
+
97
+ ## Evaluation (Out-of-Distribution)
98
+
99
+ Performance on models and datasets **never seen** during classifier training:
100
+
101
+ | Architecture | Model | Dataset | AUC |
102
+ |---|---|---|---|
103
+ | Transformer | GPT-2 | AG News | 0.945 |
104
+ | Transformer | Pythia-2.8B | AG News | 0.911 |
105
+ | Transformer | Mistral-7B | XSum | 0.989 |
106
+ | Transformer | LLaMA-2-7B | AG News | 0.948 |
107
+ | **Transformer mean** | (7 models × 4 datasets) | | **0.908** |
108
+ | State-space | Mamba-2.8B | AG News | 0.969 |
109
+ | State-space | Mamba-2.8B | WikiText | 0.995 |
110
+ | Linear attention | RWKV-3B | AG News | 0.976 |
111
+ | Linear attention | RWKV-3B | XSum | 0.998 |
112
+ | Gated recurrence | RecurrentGemma-2B | AG News | 0.924 |
113
+ | Gated recurrence | RecurrentGemma-2B | XSum | 0.988 |
114
+ | **Non-transformer mean** | (3 models × 4 datasets) | | **0.957** |
115
+
116
+ Transfer to code (Swallow-Code): 0.865 mean AUC despite training only on natural language.
117
+
118
+ ## License
119
+
120
+ MIT