# Upload-page artifact (not code): uploaded by Yaning1001 via the
# upload-large-folder tool, commit 69168b6 (verified).
import torch
import sys
import argparse
import os
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import gc
from numpy.random import default_rng
sys.path.append("..")
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
Trainer,
TrainingArguments,
DataCollatorForLanguageModeling
)
from datasets import load_dataset
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class LayerMasker:
    """Temporarily attenuate (or silence) GPT-2 sub-module outputs.

    Works by monkey-patching the ``forward`` of a transformer layer's
    ``mlp`` or ``attn`` sub-module with a scaled wrapper; ``reset``
    restores every patched method.
    """

    def __init__(self, model):
        self.model = model
        # Maps id(sub-module) -> its unpatched forward, so reset() can undo.
        self.original_forwards = {}

    def _mask_ffn(self, layer, mask_strength=1.0):
        """Scale the FFN output by (1 - mask_strength); 1.0 silences it."""
        key = id(layer.mlp)
        # Keep only the first-seen (genuine) forward, never a wrapper.
        self.original_forwards.setdefault(key, layer.mlp.forward)

        def scaled_forward(hidden):
            return self.original_forwards[key](hidden) * (1 - mask_strength)

        layer.mlp.forward = scaled_forward

    def _mask_attn(self, layer, mask_strength=1.0):
        """Scale the attention hidden-state output; extra outputs pass through."""
        key = id(layer.attn)
        self.original_forwards.setdefault(key, layer.attn.forward)

        def scaled_forward(hidden_states, **kwargs):
            outputs = self.original_forwards[key](hidden_states, **kwargs)
            # Only the first tuple element (hidden states) is attenuated.
            return (outputs[0] * (1 - mask_strength),) + outputs[1:]

        layer.attn.forward = scaled_forward

    def mask_layer(self, layer_idx, module_type="ffn", mask_strength=1.0):
        """Mask one sub-module ("ffn" or "attn") of layer ``layer_idx``."""
        target = self.model.transformer.h[layer_idx]
        if module_type == "ffn":
            self._mask_ffn(target, mask_strength)
        elif module_type == "attn":
            self._mask_attn(target, mask_strength)
        else:
            raise ValueError(f"Invalid module type: {module_type}")

    def reset(self):
        """Restore all modified forward methods."""
        for layer in self.model.transformer.h:
            for module in (layer.mlp, layer.attn):
                saved = self.original_forwards.get(id(module))
                if saved is not None:
                    module.forward = saved
        self.original_forwards.clear()
@torch.no_grad()  # Inference only: gradients are never needed here.
def get_perplexities(model, eval_dataset, batch_size, eval_tokenizer=None):
    """Compute perplexity of ``model`` on ``eval_dataset`` via a Trainer eval pass.

    Args:
        model: Causal LM to evaluate.
        eval_dataset: Pre-tokenized dataset to evaluate on.
        batch_size: Per-device evaluation batch size.
        eval_tokenizer: Tokenizer for the LM data collator. Defaults to the
            module-level ``tokenizer`` global, which the original code read
            implicitly (kept for backward compatibility).

    Returns:
        float: exp(eval_loss), i.e. the perplexity.
    """
    if eval_tokenizer is None:
        # Fall back to the global defined in the __main__ block; the previous
        # implementation depended on this global silently.
        eval_tokenizer = tokenizer
    data_collator = DataCollatorForLanguageModeling(tokenizer=eval_tokenizer, mlm=False)
    training_args = TrainingArguments(
        output_dir="./tmp_trainer",
        per_device_eval_batch_size=batch_size,
        fp16=torch.cuda.is_available(),
        report_to="none"
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        eval_dataset=eval_dataset,
        data_collator=data_collator
    )
    eval_results = trainer.evaluate()
    loss = eval_results['eval_loss']
    # Perplexity is exp(cross-entropy loss).
    return torch.exp(torch.tensor(loss)).item()
def run_mask_experiment(model, tokenized_data, args):
    """Measure the perplexity impact of masking each layer's FFN and attention.

    Returns:
        (results, baseline_ppl): ``results`` is a list of dicts with keys
        ``layer``, ``module``, ``perplexity``, ``delta``; ``baseline_ppl`` is
        the unmasked reference perplexity.
    """
    records = []
    masker = LayerMasker(model)

    # Unmasked reference point all deltas are measured against.
    baseline_ppl = get_perplexities(model, tokenized_data, args.batch_size)
    print(f"Baseline Perplexity: {baseline_ppl:.2f}")

    print("\nRunning layer masking experiments:")
    for layer_idx in tqdm(range(model.config.n_layer), desc="Layers"):
        for module_type in ("ffn", "attn"):
            masker.mask_layer(layer_idx, module_type)
            ppl = get_perplexities(model, tokenized_data, args.batch_size)
            # Release Python and CUDA memory between evaluation passes.
            gc.collect()
            torch.cuda.empty_cache()
            records.append({
                "layer": layer_idx,
                "module": module_type,
                "perplexity": ppl,
                "delta": ppl - baseline_ppl,
            })
            # Undo the mask so each configuration is measured in isolation.
            masker.reset()
    return records, baseline_ppl
def visualize_results(results_df, output_dir, perturbation=None):
    """Plot per-layer Δ-perplexity for FFN and attention masking.

    Args:
        results_df: DataFrame with columns ``layer``, ``module``, ``delta``.
        output_dir: Directory the PNG is written into.
        perturbation: Label used in the output filename. Defaults to the
            module-level ``args.perturbation`` — the original code read the
            global ``args`` implicitly, which raised NameError when this
            module was imported rather than run as a script.
    """
    if perturbation is None:
        # Backward-compatible fallback to the global parsed in __main__.
        perturbation = args.perturbation
    plt.figure(figsize=(12, 6))
    # Plot FFN curve
    ffn_data = results_df[results_df["module"] == "ffn"]
    plt.plot(ffn_data["layer"], ffn_data["delta"],
             marker='o', linestyle='-', linewidth=2, markersize=8,
             color='#1f77b4', label='FFN')
    # Plot Attention curve
    attn_data = results_df[results_df["module"] == "attn"]
    plt.plot(attn_data["layer"], attn_data["delta"],
             marker='s', linestyle='--', linewidth=2, markersize=8,
             color='#ff7f0e', label='Attention')
    plt.xlabel("Layer Index", fontsize=12)
    plt.ylabel("Δ Perplexity (log scale)", fontsize=12)
    plt.title("GPT-2 Layer Masking Impact Analysis", fontsize=14)
    plt.legend(fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.xticks(range(0, results_df["layer"].max() + 1))
    # Log scale: per-layer deltas can span orders of magnitude.
    plt.yscale('log')
    plot_path = os.path.join(output_dir, f"{perturbation}_layer_impact.png")
    plt.savefig(plot_path, bbox_inches='tight', dpi=300)
    plt.close()
if __name__ == "__main__":
    # Argument configuration
    parser = argparse.ArgumentParser(description="GPT-2 Layer Masking Experiment")
    parser.add_argument('--perturbation', type=str, default="hop_sg2pl",
                        help="Perturbation type (default: hop_sg2pl)")
    parser.add_argument('--train_set', type=str, default="10M",
                        help="Training set size (default: 10M)")
    parser.add_argument('--checkpoint_path', type=str, default="checkpoint-2736",
                        help="Model checkpoint path (default: checkpoint-2736)")
    parser.add_argument('--batch_size', type=int, default=3,
                        help="Evaluation batch size (default: 3)")
    parser.add_argument('--seed', type=int, default=0,
                        help="Random seed (default: 0)")
    # NOTE: `args` is a module-level global; visualize_results() reads it.
    args = parser.parse_args()

    # Initialize output directory
    # NOTE(review): f-string has no placeholders — a plain string would do.
    output_dir = f"mask_results/GPT2-MaskResults"
    os.makedirs(output_dir, exist_ok=True)

    # Load dataset via a local dataset-loading script (trust_remote_code needed).
    dataset_name = f"babylm_{args.perturbation}_{args.train_set}_seed{args.seed}"
    dataset = load_dataset('../train/babylm_dataset_test.py', name=dataset_name, trust_remote_code=True)
    test_dataset = dataset['test']

    # Data sampling: fixed-seed subsample of 500 test examples for speed.
    rng = default_rng(args.seed)
    indices = rng.choice(len(test_dataset), size=500, replace=False)
    sampled_test = test_dataset.select(indices)

    # Load model and tokenizer.
    # NOTE(review): path hard-codes "10M_seed0" even though --train_set and
    # --seed are configurable above — confirm this is intentional.
    checkpoint_path = f"../train/checkpoints/GPT-2/babylm_{args.perturbation}_10M_seed0/runs/{args.checkpoint_path}"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    # fp16 weights, eval mode (dropout off); moved to GPU when available.
    model = AutoModelForCausalLM.from_pretrained(checkpoint_path, torch_dtype=torch.float16).eval()
    if torch.cuda.is_available():
        model = model.cuda()

    # Data preprocessing: pad/truncate every example to a fixed 512 tokens.
    def tokenize_fn(examples):
        return tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt"
        )
    tokenized_test = sampled_test.map(
        tokenize_fn,
        batched=True,
        remove_columns=["text"],
        desc="Tokenizing"
    )

    # Run experiment
    results, baseline_ppl = run_mask_experiment(model, tokenized_test, args)
    results_df = pd.DataFrame(results)

    # Save results
    output_file = os.path.join(output_dir, f"mask_results_{args.perturbation}.csv")
    results_df.to_csv(output_file, index=False)

    # Generate visualization
    visualize_results(results_df, output_dir)

    # Print key results
    print(f"\nExperiment completed. Results saved to {output_dir}")
    print(f"Baseline Perplexity: {baseline_ppl:.2f}")
    print("Layer impact results:")
    print(results_df[["layer", "module", "delta"]].to_string(index=False))