#!/usr/bin/env python
import os

import torch
import numpy as np
import matplotlib.pyplot as plt

torch.set_printoptions(profile="full")

FILTER_RATE = 0.95
TOP_RATE = 0.5
ACTIVATION_BAR_RATIO = 0.95

langs = ["en", "eu"]
base_path = "new_activations"

n, over_zero = [], []
model_name = None
checkpoint = None

for lang in langs:
    # File path
    path = os.path.join(base_path, f"activation.{lang}.train.l2-7b-eu.pt")
    data = torch.load(path)
    n.append(data["n"])
    over_zero.append(data["over_zero"])

    # Extract model_name and checkpoint only once
    if model_name is None:
        model_name = os.path.basename(os.path.dirname(path))  # containing folder name
        filename = os.path.basename(path)
        parts = filename.split('.')
        checkpoint = parts[-2]  # token before the '.pt' extension, e.g. 'l2-7b-eu'

# Convert to tensors
n = torch.Tensor(n)                          # (lang_num,)
over_zero = torch.stack(over_zero, dim=-1)   # (layer_num, neuron_num, lang_num)

num_layers, intermediate_size, lang_num = over_zero.size()

# 1. Activation probability
activation_probs = over_zero / n  # broadcast over languages

# 2. Normalized activation probability
normed_activation_probs = activation_probs / activation_probs.sum(dim=-1, keepdim=True)

# 3. LAPE (entropy over languages)
log_prop = torch.where(normed_activation_probs > 0,
                       normed_activation_probs.log(),
                       torch.zeros_like(normed_activation_probs))
entropy = -(normed_activation_probs * log_prop).sum(dim=-1)

# 4. Filter neurons: keep those whose activation probability exceeds the
#    FILTER_RATE quantile for at least one language
flat_probs = activation_probs.flatten()
thresh = flat_probs.kthvalue(int(flat_probs.numel() * FILTER_RATE)).values
valid_mask = (activation_probs > thresh).any(dim=-1)  # (layer_num, neuron_num)
entropy[~valid_mask] = float("inf")

# 5. Select the top-k neurons with the lowest entropy
flat_entropy = entropy.flatten()
topk = int(flat_entropy.numel() * TOP_RATE)
_, idx = flat_entropy.topk(topk, largest=False)
layer_idx = idx // intermediate_size
neuron_idx = idx % intermediate_size

# 6. Group the selected neurons by language
selection_props = activation_probs[layer_idx, neuron_idx]  # (topk, lang_num)
bar = flat_probs.kthvalue(int(flat_probs.numel() * ACTIVATION_BAR_RATIO)).values
lang_mask = (selection_props > bar).T  # (lang_num, topk)

final_mask = {}
for i, lang in enumerate(langs):
    neuron_ids = torch.where(lang_mask[i])[0]
    layer_lists = [[] for _ in range(num_layers)]
    for nid in neuron_ids.tolist():
        l = layer_idx[nid].item()
        h = neuron_idx[nid].item()
        layer_lists[l].append(h)
    final_mask[lang] = [torch.tensor(lst, dtype=torch.long) for lst in layer_lists]

# =========================
# Plot number of neurons per layer (bar chart)
# =========================
plt.figure(figsize=(12, 6))
x = np.arange(num_layers)
width = 0.35

bars_list = []
for i, lang_key in enumerate(langs):
    counts = [len(layer) for layer in final_mask[lang_key]]
    bars = plt.bar(x + i * width, counts, width=width, label=lang_key)
    bars_list.append(bars)

    # Add the count label above each bar
    for rect in bars:
        height = rect.get_height()
        plt.text(rect.get_x() + rect.get_width() / 2.0, height, f'{int(height)}',
                 ha='center', va='bottom', fontsize=9)

plt.xlabel("Layer Index")
plt.ylabel("Number of Neurons")
plt.title(f"Number of Language-Specific Neurons per Layer\n"
          f"Model: {model_name}, Checkpoint: {checkpoint}")
plt.xticks(x + width / 2, x)
plt.legend()
plt.grid(alpha=0.3, axis='y')
plt.tight_layout()
plt.savefig(f"{model_name}_{checkpoint}_neurons_bar.png", dpi=300)
plt.close()
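
# --- Optional follow-up (sketch, not part of the original script) ---
# The script above only plots per-layer counts of the selected language-specific
# neurons. A minimal, hedged sketch for persisting `final_mask` so a downstream
# masking/ablation experiment can reuse the indices; the output filename
# ("{model_name}_{checkpoint}_lang_neurons.pt") is an assumed convention, not one
# established elsewhere in this codebase.
output_path = f"{model_name}_{checkpoint}_lang_neurons.pt"
torch.save(final_mask, output_path)  # dict: lang -> per-layer list of LongTensors of neuron ids
print(f"Saved language-specific neuron indices to {output_path}")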