---
license: mit
tags:
- membership-inference-attack
- privacy
- security
- language-models
- pytorch
pipeline_tag: other
library_name: ltmia
---

# Learned Transfer Membership Inference Attack

A classifier that detects whether a given text was part of a language model's fine-tuning data. It compares the output distributions of a fine-tuned model with those of its pretrained base, extracting per-token features that a small transformer classifier uses to predict membership. Trained on 10 transformer models × 3 text domains, it generalizes zero-shot to unseen model/dataset combinations, including non-transformer architectures (Mamba, RWKV, RecurrentGemma).
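
The comparison can be illustrated with a toy example. The sketch below is **not** the ltmia feature extractor. The feature set (per-token loss under each model, the loss drop, the rank of the observed token, and the top-`k` log-probabilities) and the function names are assumptions chosen for illustration. It uses NumPy, with random logits standing in for real model outputs.

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the last (vocabulary) axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def toy_per_token_features(tgt_logits, ref_logits, token_ids, k=3):
    """Hypothetical per-token features; NOT the ltmia feature set.

    tgt_logits, ref_logits: (T, V) next-token logits from the fine-tuned
    (target) and pretrained (reference) model at each of T positions.
    token_ids: (T,) the token actually observed at each position.
    Returns an array of shape (T, 5 + 2 * k).
    """
    positions = np.arange(len(token_ids))
    per_model = []
    for logits in (tgt_logits, ref_logits):
        logp = log_softmax(logits)
        observed = logp[positions, token_ids]
        nll = -observed  # per-token loss
        # Rank 0 means the observed token was the model's top prediction.
        rank = (logp > observed[:, None]).sum(axis=-1)
        topk = -np.sort(-logp, axis=-1)[:, :k]  # k largest log-probs, descending
        per_model.append((nll, rank, topk))
    (nll_t, rank_t, topk_t), (nll_r, rank_r, topk_r) = per_model
    # Columns: nll_tgt, nll_ref, loss drop, rank_tgt, rank_ref, topk_tgt..., topk_ref...
    return np.column_stack([nll_t, nll_r, nll_r - nll_t, rank_t, rank_r, topk_t, topk_r])


# Toy usage: a "fine-tuned" model that boosts exactly the observed tokens.
rng = np.random.default_rng(0)
T, V = 4, 10
ref_logits = rng.normal(size=(T, V))
token_ids = rng.integers(0, V, size=T)
tgt_logits = ref_logits.copy()
tgt_logits[np.arange(T), token_ids] += 5.0

features = toy_per_token_features(tgt_logits, ref_logits, token_ids)
print(features.shape)  # (4, 11)
print(bool((features[:, 2] > 0).all()))  # True: lower loss under the target model
```

Because the toy target model raises the logits of the observed tokens, its per-token loss is lower and its ranks are no worse than the reference's. The learned classifier is trained to recognize patterns of this kind across the whole sequence.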

## Usage

### Install

```bash
git clone https://github.com/JetBrains-Research/ltmia.git
cd ltmia
pip install -e .
```

### Inference

```python
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForCausalLM
from ltmia import extract_per_token_features_both, create_mia_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Load the pretrained base (reference) and your fine-tuned (target) model
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

model_ref = AutoModelForCausalLM.from_pretrained("gpt2").to(device).eval()
model_tgt = AutoModelForCausalLM.from_pretrained("./my-finetuned-gpt2").to(device).eval()

# 2. Extract per-token features comparing target and reference outputs
texts = ["Text you want to check...", "Another text..."]

feats, masks, _ = extract_per_token_features_both(
    model_tgt, model_ref, tokenizer, texts,
    device=device, batch_size=8, sequence_length=128, k=20,  # must match the checkpoint
)

# 3. Load the MIA classifier
ckpt_path = hf_hub_download(
    repo_id="JetBrains-Research/learned-transfer-attack",
    filename="mia_combined_400k.pt",
)
ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
mia = create_mia_model(
    architecture=ckpt["architecture"],
    d_in=ckpt["d_in"],
    seq_len=ckpt.get("seq_len", 128),
    **ckpt["mia_hparams"],
)
mia.load_state_dict(ckpt["state_dict"])
mia.to(device).eval()

# 4. Predict membership
with torch.no_grad():
    logits = mia(
        torch.from_numpy(feats).to(device),
        torch.from_numpy(masks).to(device),
    )
    probs = torch.sigmoid(logits)

for text, p in zip(texts, probs):
    prob = p.item()
    label = "MEMBER" if prob > 0.5 else "NON-MEMBER"
    print(f"[{prob:.4f}] {label} ← {text[:80]}")
```

You need black-box query access that returns full-vocabulary logits from both the fine-tuned model and its pretrained base. The extraction settings `sequence_length=128` and `k=20` must match the values this checkpoint was trained with. See the [GitHub repository](https://github.com/JetBrains-Research/ltmia) for CLI tools, instructions for training your own classifier, and evaluation scripts.
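
Because the classifier expects inputs produced with the settings it was trained on, it can help to validate the feature array against the checkpoint metadata before calling it. This is a minimal sketch that assumes `feats` has shape `(N, sequence_length, d_in)` and that the checkpoint dictionary carries the `d_in` key and the optional `seq_len` key read in the inference example. The helper name is hypothetical.

```python
def check_features_match_checkpoint(feats_shape, ckpt):
    """Hypothetical helper: raise if extracted features don't fit the checkpoint.

    feats_shape: shape of the feature array, assumed (N, sequence_length, d_in).
    ckpt: checkpoint dict; uses "d_in" and the optional "seq_len"
    (defaulting to 128, as in the inference example).
    """
    if len(feats_shape) != 3:
        raise ValueError(f"expected a 3-D feature array, got shape {tuple(feats_shape)}")
    _, seq_len, d_in = feats_shape
    expected_len = ckpt.get("seq_len", 128)
    if seq_len != expected_len:
        raise ValueError(f"sequence_length={seq_len}, but the checkpoint expects {expected_len}")
    if d_in != ckpt["d_in"]:
        raise ValueError(f"feature dim {d_in}, but the checkpoint expects d_in={ckpt['d_in']}")


# Usage with a stand-in checkpoint dict (values from the Model Details section).
fake_ckpt = {"d_in": 154, "seq_len": 128}
check_features_match_checkpoint((2, 128, 154), fake_ckpt)  # passes without output
try:
    check_features_match_checkpoint((2, 256, 154), fake_ckpt)
except ValueError as err:
    print(err)  # sequence_length=256, but the checkpoint expects 128
```

In the inference example, the call would be `check_features_match_checkpoint(feats.shape, ckpt)` right after loading `ckpt`.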

## Model Details

**Architecture:** Transformer encoder with a 154→112 input projection, 3 layers, 4 attention heads, feed-forward dimension 224, and attention pooling (~340K parameters).

**Input:** Per-token features of shape `N × 128 × 154` (texts × token positions × features) comparing logits, ranks, and losses between the target and reference models.

**Output:** Membership probability per text (sigmoid of a scalar logit).

**Training data:** Features from 10 transformers (DistilGPT-2, GPT-2-XL, Pythia-1.4B, Cerebras-GPT-2.7B, GPT-J-6B, Gemma-2B, Qwen2-1.5B, MPT-7B, Falcon-RW-1B, Falcon-7B) fine-tuned on 3 datasets (News Category, Wikipedia, CNN/DailyMail): 18K samples per model/dataset combination, 540K total.

**Training:** AdamW, learning rate 5e-4, batch size 16384, 100 epochs. The checkpoint with the best validation AUC was kept.
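
The ~340K parameter figure can be checked with a back-of-the-envelope count. This sketch assumes PyTorch-style `nn.TransformerEncoderLayer` blocks (fused QKV projection, two LayerNorms, biases throughout), a learned positional embedding over the 128 positions, a single-logit attention-pooling scorer, and a single-logit output head. These are assumptions, and the actual module layout in ltmia may differ.

```python
def linear_params(d_in: int, d_out: int) -> int:
    """Weights plus bias of a dense layer."""
    return d_in * d_out + d_out


def encoder_layer_params(d_model: int, d_ff: int) -> int:
    """One PyTorch-style encoder layer; the head count (4) does not change the total."""
    attention = linear_params(d_model, 3 * d_model) + linear_params(d_model, d_model)  # QKV + out proj
    feed_forward = linear_params(d_model, d_ff) + linear_params(d_ff, d_model)
    layer_norms = 2 * (2 * d_model)  # two LayerNorms, each with weight and bias
    return attention + feed_forward + layer_norms


def estimate_total(d_in=154, d_model=112, n_layers=3, d_ff=224, seq_len=128) -> int:
    projection = linear_params(d_in, d_model)
    positional = seq_len * d_model       # assumed learned positional embedding
    encoder = n_layers * encoder_layer_params(d_model, d_ff)
    pooling = linear_params(d_model, 1)  # assumed attention-pooling scorer
    head = linear_params(d_model, 1)     # assumed scalar output head
    return projection + positional + encoder + pooling + head


print(encoder_layer_params(112, 224))  # 101584
print(estimate_total())                # 336674, i.e. roughly 340K
```

Under these assumptions the three encoder layers account for about 90% of the parameters, so the stated total is consistent with the listed dimensions.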

## Evaluation (Out-of-Distribution)

Performance on models and datasets **never seen** during classifier training. The table lists selected model/dataset pairs; the mean rows average over all evaluated combinations.

| Architecture | Model | Dataset | AUC |
|---|---|---|---|
| Transformer | GPT-2 | AG News | 0.945 |
| Transformer | Pythia-2.8B | AG News | 0.911 |
| Transformer | Mistral-7B | XSum | 0.989 |
| Transformer | LLaMA-2-7B | AG News | 0.948 |
| **Transformer mean** | (7 models × 4 datasets) | | **0.908** |
| State-space | Mamba-2.8B | AG News | 0.969 |
| State-space | Mamba-2.8B | WikiText | 0.995 |
| Linear attention | RWKV-3B | AG News | 0.976 |
| Linear attention | RWKV-3B | XSum | 0.998 |
| Gated recurrence | RecurrentGemma-2B | AG News | 0.924 |
| Gated recurrence | RecurrentGemma-2B | XSum | 0.988 |
| **Non-transformer mean** | (3 models × 4 datasets) | | **0.957** |

On code (Swallow-Code), the classifier reaches 0.865 mean AUC even though it was trained only on natural-language text.
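
To compute a comparable metric on your own labeled members and non-members, you can use the pairwise definition of ROC AUC (assuming that is the AUC reported above). It is the probability that a randomly chosen member scores higher than a randomly chosen non-member, with ties counted as one half. This pure-Python sketch is O(n·m). For large evaluations, a library routine such as scikit-learn's `roc_auc_score` is the practical choice. The scores below are made-up illustrative values.

```python
def pairwise_auc(member_scores, nonmember_scores):
    """ROC AUC via its pairwise (Mann-Whitney) definition: P(member > non-member), ties = 1/2."""
    if not member_scores or not nonmember_scores:
        raise ValueError("need at least one member and one non-member score")
    wins = 0.0
    for m in member_scores:
        for n in nonmember_scores:
            if m > n:
                wins += 1.0
            elif m == n:
                wins += 0.5
    return wins / (len(member_scores) * len(nonmember_scores))


# Hypothetical classifier probabilities on known members / non-members.
members = [0.91, 0.85, 0.40]
nonmembers = [0.30, 0.55, 0.10]
print(pairwise_auc(members, nonmembers))  # 0.888... (8 of 9 pairs ordered correctly)
```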

## License

MIT