#!/usr/bin/env python
"""Identify language-specific neurons (LAPE-style selection) for a series of
model checkpoints and plot per-layer neuron counts as grouped bar charts."""
import os

import matplotlib.pyplot as plt
import numpy as np
import torch

torch.set_printoptions(profile="full")

# Percentile filter: a neuron must exceed this activation-probability
# percentile in at least one language to be considered at all.
FILTER_RATE = 0.95
# Fraction of (layer, neuron) pairs with the lowest entropy to keep.
TOP_RATE = 0.01
# Percentile used as the activation bar when assigning neurons to languages.
ACTIVATION_BAR_RATIO = 0.95
# NOTE(review): THRESHOLD is never used in this file — confirm before removing.
THRESHOLD = 0.8


def _load_activations(langs, base_path, checkpoint):
    """Load per-language activation statistics for one checkpoint.

    Returns:
        n: Tensor of per-language token counts (broadcasts against over_zero).
        over_zero: Tensor stacked along the last dim, one slice per language
            — assumes each file stores (num_layers, intermediate_size) counts
            of positive activations; TODO confirm against the dump script.
    """
    counts, over = [], []
    for lang in langs:
        path = os.path.join(base_path, f"activation.{lang}.train.{checkpoint}")
        data = torch.load(path)
        counts.append(data["n"])
        over.append(data["over_zero"])
    return torch.Tensor(counts), torch.stack(over, dim=-1)


def _select_language_neurons(langs, n, over_zero):
    """Select language-specific neurons via low activation-entropy (LAPE).

    Returns:
        final_mask: dict lang -> list (one entry per layer) of LongTensors
            holding the selected neuron indices in that layer.
        num_layers: number of layers (for plotting).
    """
    num_layers, intermediate_size, _ = over_zero.size()

    # 1. Activation probability per (layer, neuron, language).
    activation_probs = over_zero / n

    # 2-3. Normalize across languages and compute the entropy of that
    # distribution; low entropy => the neuron fires mainly for one language.
    normed = activation_probs / activation_probs.sum(dim=-1, keepdim=True)
    log_probs = torch.where(normed > 0,
                            normed.log(),
                            torch.zeros_like(normed))
    entropy = -(normed * log_probs).sum(dim=-1)

    # 4. Discard neurons whose activation probability never clears the
    # FILTER_RATE percentile in any language. max(1, ...) guards kthvalue
    # against k == 0 on very small inputs.
    flat_probs = activation_probs.flatten()
    k_filter = max(1, int(flat_probs.numel() * FILTER_RATE))
    filter_bar = flat_probs.kthvalue(k_filter).values
    entropy[~(activation_probs > filter_bar).any(dim=-1)] = float("inf")

    # 5. Keep the TOP_RATE fraction of neurons with the lowest entropy.
    flat_entropy = entropy.flatten()
    top_k = int(flat_entropy.numel() * TOP_RATE)
    _, idx = flat_entropy.topk(top_k, largest=False)
    layer_idx = idx // intermediate_size
    neuron_idx = idx % intermediate_size

    # 6. Group by language: a selected neuron belongs to every language whose
    # activation probability exceeds the ACTIVATION_BAR_RATIO percentile.
    selected_probs = activation_probs[layer_idx, neuron_idx]
    k_bar = max(1, int(flat_probs.numel() * ACTIVATION_BAR_RATIO))
    activation_bar = flat_probs.kthvalue(k_bar).values
    lang_mask = (selected_probs > activation_bar).T

    final_mask = {}
    for i, lang in enumerate(langs):
        layer_lists = [[] for _ in range(num_layers)]
        for nid in torch.where(lang_mask[i])[0].tolist():
            layer_lists[layer_idx[nid].item()].append(neuron_idx[nid].item())
        final_mask[lang] = [torch.tensor(lst, dtype=torch.long)
                            for lst in layer_lists]
    return final_mask, num_layers


def _plot_neuron_counts(langs, final_mask, num_layers, model_name, checkpoint):
    """Save a grouped bar chart of per-layer neuron counts as a PNG."""
    plt.figure(figsize=(12, 6))
    x = np.arange(num_layers)
    width = 0.35
    for i, lang_key in enumerate(langs):
        counts = [len(layer) for layer in final_mask[lang_key]]
        bars = plt.bar(x + i * width, counts, width=width, label=lang_key)
        # Annotate each bar with its count.
        for bar_item in bars:
            height = bar_item.get_height()
            plt.text(bar_item.get_x() + bar_item.get_width() / 2.0, height,
                     f'{int(height)}', ha='center', va='bottom', fontsize=9)
    plt.xlabel("Layer Index")
    plt.ylabel("Number of Neurons")
    plt.title(f"Number of Language-Specific Neurons per Layer\n"
              f"Model: {model_name}, Checkpoint: {checkpoint}")
    # Center tick labels under each bar group; the (len(langs) - 1) / 2
    # offset generalizes the original two-language width / 2 centering.
    plt.xticks(x + width * (len(langs) - 1) / 2, x)
    plt.legend()
    plt.grid(alpha=0.3, axis='y')
    plt.tight_layout()
    # One image per checkpoint.
    plt.savefig(f"{model_name}_{checkpoint}_neurons_bar.png", dpi=300)
    plt.close()


def plot_language_neurons(langs, base_path, checkpoint_numbers):
    """
    langs: list of languages, e.g., ["en", "eu"]
    base_path: folder containing activation files
    checkpoint_numbers: list of ints, e.g., [300, 3231]
    """
    model_name = os.path.basename(base_path)
    for cp in checkpoint_numbers:
        checkpoint = f"qwen-checkpoint-{cp}"
        n, over_zero = _load_activations(langs, base_path, checkpoint)
        final_mask, num_layers = _select_language_neurons(langs, n, over_zero)
        _plot_neuron_counts(langs, final_mask, num_layers,
                            model_name, checkpoint)


# =========================
# Example usage
# =========================
if __name__ == "__main__":
    langs = ["en", "zh"]
    base_path = "activations/qwen2.5-0.5b_english_wiki_750M_chinese_wikipedia_corpus"
    checkpoint_numbers = [300, 600, 900, 1200, 1500, 1800, 1962]
    plot_language_neurons(langs, base_path, checkpoint_numbers)