| import torch |
| import sys |
| import argparse |
| import os |
| import pandas as pd |
| import matplotlib.pyplot as plt |
| from tqdm import tqdm |
| import gc |
| from numpy.random import default_rng |
| sys.path.append("..") |
| from transformers import ( |
| AutoTokenizer, |
| AutoModelForCausalLM, |
| Trainer, |
| TrainingArguments, |
| DataCollatorForLanguageModeling |
| ) |
| from datasets import load_dataset |
|
|
# Disable the HF tokenizers' internal parallelism — avoids the fork-related
# parallelism warning (and potential deadlocks) when Trainer spawns
# DataLoader worker processes.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
class LayerMasker:
    """Attenuate or zero out individual GPT-2 sub-modules in place.

    Works by swapping each target module's bound ``forward`` for a wrapper
    that scales its output by ``(1 - mask_strength)``.  The original
    callables are stashed so :meth:`reset` can restore the model exactly.
    """

    def __init__(self, model):
        self.model = model
        # Maps id(module) -> that module's original bound forward method.
        self.original_forwards = {}

    def _mask_ffn(self, layer, mask_strength=1.0):
        """Wrap the MLP forward so its output is scaled by (1 - mask_strength)."""
        key = id(layer.mlp)
        # Only stash the original once, even if re-masked without a reset.
        self.original_forwards.setdefault(key, layer.mlp.forward)
        saved = self.original_forwards
        scale = 1.0 - mask_strength

        def scaled_forward(x):
            return saved[key](x) * scale

        layer.mlp.forward = scaled_forward

    def _mask_attn(self, layer, mask_strength=1.0):
        """Wrap the attention forward; only the first tuple element (hidden
        states) is scaled — any extra outputs pass through untouched."""
        key = id(layer.attn)
        self.original_forwards.setdefault(key, layer.attn.forward)
        saved = self.original_forwards
        scale = 1.0 - mask_strength

        def scaled_forward(hidden_states, **kwargs):
            outputs = saved[key](hidden_states, **kwargs)
            return (outputs[0] * scale,) + outputs[1:]

        layer.attn.forward = scaled_forward

    def mask_layer(self, layer_idx, module_type="ffn", mask_strength=1.0):
        """Mask one sub-module ("ffn" or "attn") of transformer block layer_idx."""
        target = self.model.transformer.h[layer_idx]
        if module_type == "ffn":
            self._mask_ffn(target, mask_strength)
        elif module_type == "attn":
            self._mask_attn(target, mask_strength)
        else:
            raise ValueError(f"Invalid module type: {module_type}")

    def reset(self):
        """Restore every forward method this masker has replaced."""
        for block in self.model.transformer.h:
            for module in (block.mlp, block.attn):
                saved = self.original_forwards.get(id(module))
                if saved is not None:
                    module.forward = saved
        self.original_forwards.clear()
|
|
@torch.no_grad()
def get_perplexities(model, eval_dataset, batch_size):
    """Evaluate *model* on *eval_dataset* and return its perplexity.

    Perplexity is exp(mean cross-entropy) taken from the HF Trainer's
    ``eval_loss``.

    NOTE(review): relies on a module-level ``tokenizer`` that is only
    defined by the __main__ section — confirm it is set before calling
    this from anywhere else.

    Args:
        model: causal LM to evaluate.
        eval_dataset: tokenized dataset with input_ids (collated for LM).
        batch_size: per-device evaluation batch size.

    Returns:
        float: perplexity of the model on the dataset.
    """
    import math  # local import: keeps the shared top-of-file import block untouched

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir="./tmp_trainer",
        per_device_eval_batch_size=batch_size,
        fp16=torch.cuda.is_available(),  # half precision only when a GPU exists
        report_to="none"                 # no wandb/tensorboard for throwaway evals
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        eval_dataset=eval_dataset,
        data_collator=data_collator
    )

    loss = trainer.evaluate()['eval_loss']
    # math.exp on the Python float is simpler and more precise than
    # round-tripping the scalar through a float32 torch tensor.
    return math.exp(loss)
|
|
def run_mask_experiment(model, tokenized_data, args):
    """Mask each FFN / attention sub-module in turn and measure perplexity.

    Returns:
        (results, baseline_ppl) where results is a list of dicts with keys
        "layer", "module", "perplexity", and "delta" (ppl - baseline).
    """
    results = []
    masker = LayerMasker(model)

    # Unmasked reference perplexity every delta is measured against.
    baseline_ppl = get_perplexities(model, tokenized_data, args.batch_size)
    print(f"Baseline Perplexity: {baseline_ppl:.2f}")

    print("\nRunning layer masking experiments:")
    for layer_idx in tqdm(range(model.config.n_layer), desc="Layers"):
        for module_type in ["ffn", "attn"]:
            # Fully zero this sub-module's output (default mask_strength=1.0).
            masker.mask_layer(layer_idx, module_type)

            ppl = get_perplexities(model, tokenized_data, args.batch_size)
            delta = ppl - baseline_ppl

            # Release Python garbage and cached CUDA blocks between eval runs.
            gc.collect()
            torch.cuda.empty_cache()

            results.append({
                "layer": layer_idx,
                "module": module_type,
                "perplexity": ppl,
                "delta": delta
            })

            # Restore original forwards so masks never stack across iterations.
            masker.reset()

    return results, baseline_ppl
|
|
def visualize_results(results_df, output_dir, perturbation=None):
    """Plot per-layer Δ-perplexity for FFN vs. attention masking.

    Args:
        results_df: DataFrame with "layer", "module", and "delta" columns.
        output_dir: directory the PNG is written into.
        perturbation: label used in the output filename.  Defaults to the
            module-level ``args.perturbation`` for backward compatibility.

    Side effects:
        Writes ``{perturbation}_layer_impact.png`` into *output_dir*.
    """
    if perturbation is None:
        # BUGFIX: this function previously read the global argparse namespace
        # directly, which fails with NameError unless run via __main__.
        # Keep that behavior as a fallback, but allow an explicit argument.
        perturbation = args.perturbation

    plt.figure(figsize=(12, 6))

    ffn_data = results_df[results_df["module"] == "ffn"]
    plt.plot(ffn_data["layer"], ffn_data["delta"],
             marker='o', linestyle='-', linewidth=2, markersize=8,
             color='#1f77b4', label='FFN')

    attn_data = results_df[results_df["module"] == "attn"]
    plt.plot(attn_data["layer"], attn_data["delta"],
             marker='s', linestyle='--', linewidth=2, markersize=8,
             color='#ff7f0e', label='Attention')

    plt.xlabel("Layer Index", fontsize=12)
    plt.ylabel("Δ Perplexity (log scale)", fontsize=12)
    plt.title("GPT-2 Layer Masking Impact Analysis", fontsize=14)
    plt.legend(fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.xticks(range(0, results_df["layer"].max()+1))

    # Deltas can span orders of magnitude across layers, hence the log axis.
    plt.yscale('log')

    plot_path = os.path.join(output_dir, f"{perturbation}_layer_impact.png")
    plt.savefig(plot_path, bbox_inches='tight', dpi=300)
    plt.close()
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="GPT-2 Layer Masking Experiment")
    parser.add_argument('--perturbation', type=str, default="hop_sg2pl",
                        help="Perturbation type (default: hop_sg2pl)")
    parser.add_argument('--train_set', type=str, default="10M",
                        help="Training set size (default: 10M)")
    parser.add_argument('--checkpoint_path', type=str, default="checkpoint-2736",
                        help="Model checkpoint path (default: checkpoint-2736)")
    parser.add_argument('--batch_size', type=int, default=3,
                        help="Evaluation batch size (default: 3)")
    parser.add_argument('--seed', type=int, default=0,
                        help="Random seed (default: 0)")
    args = parser.parse_args()

    output_dir = "mask_results/GPT2-MaskResults"  # plain string; was a no-op f-string
    os.makedirs(output_dir, exist_ok=True)

    # Load the perturbed BabyLM dataset's test split via its loading script.
    dataset_name = f"babylm_{args.perturbation}_{args.train_set}_seed{args.seed}"
    dataset = load_dataset('../train/babylm_dataset_test.py', name=dataset_name, trust_remote_code=True)
    test_dataset = dataset['test']

    # Deterministically subsample 500 test examples for faster evaluation.
    rng = default_rng(args.seed)
    indices = rng.choice(len(test_dataset), size=500, replace=False)
    sampled_test = test_dataset.select(indices)

    # BUGFIX: the checkpoint path previously hard-coded "10M_seed0",
    # silently ignoring --train_set and --seed; build it from the parsed
    # args so the checkpoint always matches the dataset chosen above.
    checkpoint_path = (f"../train/checkpoints/GPT-2/"
                       f"babylm_{args.perturbation}_{args.train_set}_seed{args.seed}/"
                       f"runs/{args.checkpoint_path}")
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    # GPT-2 tokenizers ship without a pad token; padding="max_length" below
    # (and the LM data collator) need one, so fall back to EOS.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(checkpoint_path, torch_dtype=torch.float16).eval()
    if torch.cuda.is_available():
        model = model.cuda()

    def tokenize_fn(examples):
        # Fixed-length encoding so every evaluation batch is rectangular.
        return tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt"
        )

    tokenized_test = sampled_test.map(
        tokenize_fn,
        batched=True,
        remove_columns=["text"],
        desc="Tokenizing"
    )

    # Run the masking sweep, then persist and visualize the results.
    results, baseline_ppl = run_mask_experiment(model, tokenized_test, args)
    results_df = pd.DataFrame(results)

    output_file = os.path.join(output_dir, f"mask_results_{args.perturbation}.csv")
    results_df.to_csv(output_file, index=False)

    visualize_results(results_df, output_dir)

    print(f"\nExperiment completed. Results saved to {output_dir}")
    print(f"Baseline Perplexity: {baseline_ppl:.2f}")
    print("Layer impact results:")
    print(results_df[["layer", "module", "delta"]].to_string(index=False))
|
|
|
|