# lsn-analysis / activation.py
# (upload-page residue removed: uploaded via huggingface_hub, commit fed1832)
#!/usr/bin/env python3
"""
Track per-neuron activations in Qwen2 MLP layers using Hugging Face Transformers
with explicit device management.
"""
import argparse
import os
from types import MethodType
import torch
from torch import Tensor
from tqdm import tqdm
from transformers import AutoModelForCausalLM
# ---------------------- Activation Tracker ----------------------
class ActivationTracker:
    """Accumulate per-layer, per-neuron counts of positive MLP gate activations.

    Counts live in ``over_zero``, an int32 tensor of shape
    ``(num_layers, intermediate_size)``. By default it is kept on CPU so the
    bookkeeping never competes with the model for GPU memory.
    """

    def __init__(self, num_layers: int, intermediate_size: int, device: str = "cpu"):
        """
        Args:
            num_layers: number of transformer layers to track.
            intermediate_size: width of each MLP's intermediate projection.
            device: where to keep the counter tensor (default ``"cpu"``,
                preserving the original memory-saving behavior).
        """
        self.over_zero = torch.zeros(
            num_layers, intermediate_size, dtype=torch.int32, device=device
        )

    def make_qwen_hook(self, index: int):
        """Build a replacement ``forward`` for a Qwen2-style gated MLP module.

        The returned function reproduces the standard gated-MLP output
        ``down_proj(act_fn(gate_proj(x)) * up_proj(x))`` and, as a side
        effect, adds the number of strictly positive gate activations per
        neuron into row ``index`` of ``self.over_zero``.
        """
        over_zero = self.over_zero

        def qwen_forward(self, x: Tensor):
            gate_activation = self.act_fn(self.gate_proj(x))
            with torch.no_grad():
                # Reduce over every leading dim (batch, sequence, ...) so both
                # 2-D (seq, hidden) and 3-D (batch, seq, hidden) inputs work;
                # the original hard-coded dim=(0, 1) and required 3-D input.
                lead_dims = tuple(range(gate_activation.dim() - 1))
                counts = (gate_activation > 0).sum(dim=lead_dims)
                # Move per-call counts to wherever the counter tensor lives
                # (CPU by default) before accumulating.
                over_zero[index, :] += counts.to(over_zero.device)
            return self.down_proj(gate_activation * self.up_proj(x))

        return qwen_forward
# ---------------------- Arguments ----------------------
# One tracking job = one model + one tokenized language dataset.
_arg_specs = [
    ("--model", dict(type=str, required=True, help="HF model ID or local folder path")),
    ("--lang", dict(type=str, required=True, help="Language code for dataset")),
    ("--data-path", dict(type=str, required=True, help="Path to tokenized dataset (torch tensor)")),
    ("--output-dir", dict(type=str, default="activations", help="Directory to save over_zero")),
    ("--batch-size", dict(type=int, default=1, help="Batch size per device")),
    ("--chunk-size", dict(type=int, default=4096, help="Max sequence length to process at once")),
]
parser = argparse.ArgumentParser()
for _flag, _kwargs in _arg_specs:
    parser.add_argument(_flag, **_kwargs)
args = parser.parse_args()
# ---------------------- Setup Device ----------------------
# Note: only used for the log line below; device_map="auto" handles placement.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
os.makedirs(args.output_dir, exist_ok=True)
# ---------------------- Load Data ----------------------
print("Loading data...")
# Keep the full token stream on CPU; batches are moved to the model later.
ids = torch.load(args.data_path, map_location="cpu")
# ---------------------- Load Model ----------------------
print(f"Loading model: {args.model}")
model = AutoModelForCausalLM.from_pretrained(
    args.model,
    device_map="auto",       # shard automatically across whatever GPUs exist
    torch_dtype=torch.bfloat16,  # halve memory vs fp32
)
model.eval()
# Pull the dimensions the tracker and chunking logic need from the config.
cfg = model.config
num_layers = cfg.num_hidden_layers
intermediate_size = cfg.intermediate_size
max_len = cfg.max_position_embeddings
# Build the counter and splice its hook into every decoder layer's MLP.
tracker = ActivationTracker(num_layers=num_layers, intermediate_size=intermediate_size)
for layer_idx, decoder_layer in enumerate(model.model.layers):
    decoder_layer.mlp.forward = MethodType(tracker.make_qwen_hook(layer_idx), decoder_layer.mlp)
# Reshape the flat token stream into fixed-length rows; chunk_size (not the
# model's full max_len) caps per-forward memory. Trailing remainder is dropped.
chunk_size = min(args.chunk_size, max_len)
usable = (ids.size(0) // chunk_size) * chunk_size
input_ids = ids[:usable].reshape(-1, chunk_size)
print(f"Processing {input_ids.size(0)} sequences of length {chunk_size}")
# ---------------------- Run Inference ----------------------
with torch.no_grad():
    total = input_ids.size(0)
    for start in tqdm(range(0, total, args.batch_size), desc="Processing", unit="batch"):
        batch = input_ids[start:start + args.batch_size]
        # With device_map="auto", the first parameter's device is where
        # inputs must land for the embedding layer.
        target = next(model.parameters()).device
        # Release cached allocations so long runs don't build up fragmentation.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Counting happens inside the patched MLP forwards; output is discarded.
        model(input_ids=batch.to(target))
# ---------------------- Save activations ----------------------
# Name the output after the model folder (strip trailing slash) and language.
trimmed_model = args.model.rstrip("/")
model_name = os.path.basename(trimmed_model)
file_name = f"activation_{model_name}_{args.lang}.pt"
out_path = os.path.join(args.output_dir, file_name)
torch.save(tracker.over_zero, out_path)
print(f"Saved activation counts to {out_path}")
print("Activation single job done")