---
license: mit
tags:
- membership-inference-attack
- privacy
- security
- language-models
- pytorch
pipeline_tag: other
library_name: ltmia
---

# Learned Transfer Membership Inference Attack

A classifier that detects whether a given text was part of a language model's fine-tuning data. It compares the output distributions of a fine-tuned model against those of its pretrained base and extracts per-token features, which a small transformer classifier uses to predict membership. Trained on 10 transformer models × 3 text domains, it generalizes zero-shot to unseen model/dataset combinations, including non-transformer architectures (Mamba, RWKV, RecurrentGemma).

## Usage

### Install

```bash
git clone https://github.com/JetBrains-Research/ltmia.git
cd ltmia
pip install -e .
```

### Inference

```python
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForCausalLM

from ltmia import extract_per_token_features_both, create_mia_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Load your base (reference) and fine-tuned (target) models
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
model_ref = AutoModelForCausalLM.from_pretrained("gpt2").to(device).eval()
model_tgt = AutoModelForCausalLM.from_pretrained("./my-finetuned-gpt2").to(device).eval()

# 2. Extract per-token features (settings must match the checkpoint)
texts = ["Text you want to check...", "Another text..."]
feats, masks, _ = extract_per_token_features_both(
    model_tgt, model_ref, tokenizer, texts,
    device=device, batch_size=8, sequence_length=128, k=20,
)

# 3. Load the MIA classifier
ckpt_path = hf_hub_download(
    repo_id="JetBrains-Research/learned-transfer-attack",
    filename="mia_combined_400k.pt",
)
ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
mia = create_mia_model(
    architecture=ckpt["architecture"],
    d_in=ckpt["d_in"],
    seq_len=ckpt.get("seq_len", 128),
    **ckpt["mia_hparams"],
)
mia.load_state_dict(ckpt["state_dict"])
mia.to(device).eval()

# 4. Predict membership
with torch.no_grad():
    logits = mia(
        torch.from_numpy(feats).to(device),
        torch.from_numpy(masks).to(device),
    )
    probs = torch.sigmoid(logits)

for text, p in zip(texts, probs):
    prob = p.item()
    label = "MEMBER" if prob > 0.5 else "NON-MEMBER"
    print(f"[{prob:.4f}] {label} ← {text[:80]}")
```

You need black-box query access to the full-vocabulary logits of both the fine-tuned model and its pretrained base. The feature-extraction settings `sequence_length=128` and `k=20` must match the values this checkpoint was trained with.

See the [GitHub repository](https://github.com/JetBrains-Research/ltmia) for CLI tools, training your own classifier, and evaluation scripts.

## Model Details

**Architecture:** Transformer encoder with a 154→112 input projection, 3 layers, 4 attention heads, FFN dimension 224, and attention pooling (~340K parameters).

**Input:** Per-token features of shape `N × 128 × 154` that compare logits, ranks, and losses between the target (fine-tuned) and reference (base) models.

**Output:** A membership probability per text (the sigmoid of a scalar logit).

**Training data:** Features from 10 transformers (DistilGPT-2, GPT-2-XL, Pythia-1.4B, Cerebras-GPT-2.7B, GPT-J-6B, Gemma-2B, Qwen2-1.5B, MPT-7B, Falcon-RW-1B, Falcon-7B) fine-tuned on 3 datasets (News Category, Wikipedia, CNN/DailyMail): 18K samples per model/dataset combination, 540K in total.

**Training:** AdamW, learning rate 5e-4, batch size 16,384, 100 epochs. The checkpoint with the best validation AUC was selected.
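To make the **Architecture** line above concrete, here is a hypothetical PyTorch sketch of an encoder with the listed sizes. It is not the `create_mia_model` implementation from `ltmia` and cannot load the released `state_dict`; the class name `SketchMIAEncoder`, the learned positional embedding, the mask convention (nonzero = real token, zero = padding), and the single-linear-layer attention pooling are all assumptions made for illustration.

```python
import torch
import torch.nn as nn


class SketchMIAEncoder(nn.Module):
    """Hypothetical encoder using the sizes listed in Model Details (not ltmia's code)."""

    def __init__(self, d_in=154, d_model=112, n_layers=3, n_heads=4,
                 d_ff=224, seq_len=128):
        super().__init__()
        # Project per-token features from d_in down to the model width.
        self.proj = nn.Linear(d_in, d_model)
        # Assumption: learned absolute positional embeddings.
        self.pos = nn.Parameter(torch.zeros(1, seq_len, d_model))
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=d_ff,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        # Attention pooling (assumed form): one score per token, softmax over valid tokens.
        self.pool_score = nn.Linear(d_model, 1)
        # Maps the pooled vector to a scalar membership logit.
        self.head = nn.Linear(d_model, 1)

    def forward(self, x, mask):
        # x: (N, T, d_in) float features; mask: (N, T), nonzero marks real tokens.
        valid = mask.bool()
        h = self.proj(x) + self.pos[:, : x.size(1)]
        # src_key_padding_mask expects True at positions to ignore.
        h = self.encoder(h, src_key_padding_mask=~valid)
        scores = self.pool_score(h).squeeze(-1)
        scores = scores.masked_fill(~valid, float("-inf"))
        weights = torch.softmax(scores, dim=-1)
        pooled = (weights.unsqueeze(-1) * h).sum(dim=1)
        return self.head(pooled).squeeze(-1)  # (N,) logits


if __name__ == "__main__":
    torch.manual_seed(0)
    model = SketchMIAEncoder().eval()
    feats = torch.randn(2, 128, 154)
    masks = torch.ones(2, 128)
    masks[1, 100:] = 0  # second text is shorter; its tail is padding
    with torch.no_grad():
        probs = torch.sigmoid(model(feats, masks))
    print(probs.shape)
    print(sum(p.numel() for p in model.parameters()))
```

With these sizes the sketch has 336,674 parameters (14,336 of them in the assumed positional embedding), in the same range as the ~340K quoted above. Each row must contain at least one real token; a fully padded row would make the pooling softmax undefined.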
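To illustrate the kind of target-versus-reference comparison the **Input** line describes, the following plain-Python sketch computes a few per-token quantities from two models' logits. It is a hypothetical illustration, not `extract_per_token_features_both`: the helper names (`log_softmax`, `rank_of`, `toy_token_features`) and the five features per token are assumptions chosen for readability, and the real 154-dimensional layout (which also depends on `k=20`) is not reproduced.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one vocabulary-sized logit vector."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def rank_of(logits, token_id):
    """1-based rank of token_id when logits are sorted in descending order."""
    target = logits[token_id]
    return 1 + sum(1 for v in logits if v > target)


def toy_token_features(tgt_logits, ref_logits, next_ids):
    """Hypothetical per-token features comparing a target and a reference model.

    tgt_logits, ref_logits: per-position logit vectors (lists of floats),
    where position t predicts the token next_ids[t].
    Returns one row per position:
    [target loss, reference loss, reference loss minus target loss,
     target rank, reference rank].
    """
    rows = []
    for t_log, r_log, tok in zip(tgt_logits, ref_logits, next_ids):
        t_loss = -log_softmax(t_log)[tok]
        r_loss = -log_softmax(r_log)[tok]
        rows.append([t_loss, r_loss, r_loss - t_loss,
                     rank_of(t_log, tok), rank_of(r_log, tok)])
    return rows


if __name__ == "__main__":
    # Toy 4-token vocabulary, 2 positions: the "fine-tuned" target is more
    # confident about the observed tokens than the "base" reference.
    tgt = [[4.0, 0.0, 0.0, 0.0], [0.0, 3.0, 0.0, 0.0]]
    ref = [[1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 2.0, 0.0]]
    for row in toy_token_features(tgt, ref, next_ids=[0, 1]):
        print([round(v, 3) for v in row])
```

A positive third column means the fine-tuned target assigns the observed token a lower loss than its base does. Comparing against a reference in this way is the general intuition behind reference-based membership inference; the learned classifier consumes many such signals per token instead of thresholding a single one.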
## Evaluation (Out-of-Distribution)

Performance on models and datasets **never seen** during classifier training:

| Architecture | Model | Dataset | AUC |
|---|---|---|---|
| Transformer | GPT-2 | AG News | 0.945 |
| Transformer | Pythia-2.8B | AG News | 0.911 |
| Transformer | Mistral-7B | XSum | 0.989 |
| Transformer | LLaMA-2-7B | AG News | 0.948 |
| **Transformer mean** | (7 models × 4 datasets) | | **0.908** |
| State-space | Mamba-2.8B | AG News | 0.969 |
| State-space | Mamba-2.8B | WikiText | 0.995 |
| Linear attention | RWKV-3B | AG News | 0.976 |
| Linear attention | RWKV-3B | XSum | 0.998 |
| Gated recurrence | RecurrentGemma-2B | AG News | 0.924 |
| Gated recurrence | RecurrentGemma-2B | XSum | 0.988 |
| **Non-transformer mean** | (3 models × 4 datasets) | | **0.957** |

Transfer to code (Swallow-Code): 0.865 mean AUC, despite training only on natural language.

## License

MIT