#!/usr/bin/env python3 """ Track per-neuron activations in Qwen2 MLP layers using Hugging Face Transformers with explicit device management. """ import argparse import os from types import MethodType import torch from torch import Tensor from tqdm import tqdm from transformers import AutoModelForCausalLM # ---------------------- Activation Tracker ---------------------- class ActivationTracker: def __init__(self, num_layers: int, intermediate_size: int): # store on CPU to avoid memory issues self.over_zero = torch.zeros( num_layers, intermediate_size, dtype=torch.int32, device="cpu" ) def make_qwen_hook(self, index: int): over_zero = self.over_zero def qwen_forward(self, x: Tensor): gate_activation = self.act_fn(self.gate_proj(x)) with torch.no_grad(): over_zero[index, :] += (gate_activation > 0).sum(dim=(0, 1)).to("cpu") return self.down_proj(gate_activation * self.up_proj(x)) return qwen_forward # ---------------------- Arguments ---------------------- parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, required=True, help="HF model ID or local folder path") parser.add_argument("--lang", type=str, required=True, help="Language code for dataset") parser.add_argument("--data-path", type=str, required=True, help="Path to tokenized dataset (torch tensor)") parser.add_argument("--output-dir", type=str, default="activations", help="Directory to save over_zero") parser.add_argument("--batch-size", type=int, default=1, help="Batch size per device") parser.add_argument("--chunk-size", type=int, default=4096, help="Max sequence length to process at once") args = parser.parse_args() # ---------------------- Setup Device ---------------------- device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") os.makedirs(args.output_dir, exist_ok=True) # ---------------------- Load Data ---------------------- print("Loading data...") ids = torch.load(args.data_path, map_location="cpu") # Load to CPU first # 
# ---------------------- Load Model ----------------------
print(f"Loading model: {args.model}")
model = AutoModelForCausalLM.from_pretrained(
    args.model,
    device_map="auto",  # Let it automatically distribute across available GPUs
    torch_dtype=torch.bfloat16,  # reduce memory
)
model.eval()

num_layers = model.config.num_hidden_layers
intermediate_size = model.config.intermediate_size
max_len = model.config.max_position_embeddings

# Setup tracker
tracker = ActivationTracker(num_layers=num_layers, intermediate_size=intermediate_size)

# Monkey-patch MLP layers so every forward pass updates the tracker.
for i, layer in enumerate(model.model.layers):
    layer.mlp.forward = MethodType(tracker.make_qwen_hook(i), layer.mlp)

# Prepare input - use chunk_size instead of max_len for memory efficiency
chunk_size = min(args.chunk_size, max_len)
n = (ids.size(0) // chunk_size) * chunk_size
if n == 0:
    # Without this guard the script would silently process zero batches and
    # save an all-zero activation file.
    raise ValueError(
        f"Dataset has only {ids.size(0)} tokens, fewer than one chunk of size {chunk_size}"
    )
input_ids = ids[:n].reshape(-1, chunk_size)
print(f"Processing {input_ids.size(0)} sequences of length {chunk_size}")

# ---------------------- Run Inference ----------------------
# Hoist loop invariants: the CUDA check and the model's input device
# (first parameter's device works with device_map="auto").
cuda_available = torch.cuda.is_available()
input_device = next(model.parameters()).device

with torch.no_grad():
    for i in tqdm(range(0, input_ids.size(0), args.batch_size),
                  desc="Processing", unit="batch"):
        batch = input_ids[i:i + args.batch_size].to(input_device)
        # Clear cache before each batch to prevent memory buildup
        if cuda_available:
            torch.cuda.empty_cache()
        # use_cache=False: only the hook side effects are needed; skipping
        # KV-cache construction saves memory on long chunks.
        model(input_ids=batch, use_cache=False)

# ---------------------- Save activations ----------------------
model_name = os.path.basename(args.model.rstrip("/"))
out_path = os.path.join(args.output_dir, f"activation_{model_name}_{args.lang}.pt")
torch.save(tracker.over_zero, out_path)
print(f"Saved activation counts to {out_path}")
print("Activation single job done")