|
|
|
|
|
import os |
|
|
import torch |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
# Show full tensors (no truncation) when tensors are printed for debugging.
torch.set_printoptions(profile="full")


# Quantile of all activation probabilities; a neuron must exceed this value
# in at least one language to remain a candidate (see valid_mask filtering).
FILTER_RATE = 0.95


# Fraction of (layer, neuron) entries with the lowest cross-language entropy
# that are kept as language-specific candidates.
TOP_RATE = 0.01


# Quantile of all activation probabilities used as the per-language bar a
# selected neuron must clear to be assigned to that language.
ACTIVATION_BAR_RATIO = 0.95


# NOTE(review): not referenced anywhere in this file — presumably consumed by
# other scripts sharing these constants; confirm before removing.
THRESHOLD = 0.8
|
|
|
|
|
def plot_language_neurons(langs, base_path, checkpoint_numbers):
    """Identify language-specific neurons per checkpoint and plot their counts.

    For every checkpoint, loads per-language activation statistics from
    ``base_path``, selects low-entropy (language-specific) neurons, and saves
    a grouped bar chart of how many such neurons each language has per layer
    to ``<model>_<checkpoint>_neurons_bar.png``.

    Args:
        langs: list of language codes, e.g. ["en", "eu"].
        base_path: folder containing files named
            "activation.<lang>.train.qwen-checkpoint-<cp>".
        checkpoint_numbers: list of ints, e.g. [300, 3231].
    """
    model_name = os.path.basename(base_path)
    for cp in checkpoint_numbers:
        checkpoint = f"qwen-checkpoint-{cp}"
        n, over_zero = _load_activations(langs, base_path, checkpoint)
        final_mask, num_layers = _select_language_neurons(langs, n, over_zero)
        _plot_neuron_counts(langs, final_mask, num_layers, model_name, checkpoint)


def _load_activations(langs, base_path, checkpoint):
    """Load per-language token counts and over-zero activation counts.

    Returns:
        n: float tensor of shape (lang_num,) with token counts per language.
        over_zero: tensor of shape (num_layers, intermediate_size, lang_num),
            stacked across languages on the last dim.
    """
    n, over_zero = [], []
    for lang in langs:
        path = os.path.join(base_path, f"activation.{lang}.train.{checkpoint}")
        # map_location="cpu" lets stats saved on a GPU load on CPU-only hosts.
        data = torch.load(path, map_location="cpu")
        n.append(data["n"])
        over_zero.append(data["over_zero"])
    return torch.Tensor(n), torch.stack(over_zero, dim=-1)


def _select_language_neurons(langs, n, over_zero):
    """Select language-specific neurons via low entropy + high activation.

    A neuron is a candidate only if its activation probability exceeds the
    FILTER_RATE quantile for at least one language; among candidates the
    TOP_RATE fraction with the lowest cross-language entropy is kept, and a
    neuron is assigned to a language when its activation probability there
    exceeds the ACTIVATION_BAR_RATIO quantile.

    Returns:
        final_mask: dict mapping lang -> list (one entry per layer) of
            LongTensors holding the selected neuron indices in that layer.
        num_layers: number of layers in the model.
    """
    num_layers, intermediate_size, lang_num = over_zero.size()

    # Per-(layer, neuron, lang) probability of the neuron firing.
    activation_probs = over_zero / n

    # Normalize over languages; the entropy of this distribution measures how
    # language-specific a neuron is (low entropy == concentrated == specific).
    normed = activation_probs / activation_probs.sum(dim=-1, keepdim=True)
    log_prop = torch.where(normed > 0, normed.log(), torch.zeros_like(normed))
    entropy = -(normed * log_prop).sum(dim=-1)

    # Discard neurons that never activate strongly in any language by pushing
    # their entropy to +inf so topk(largest=False) never selects them.
    flat_probs = activation_probs.flatten()
    thresh = flat_probs.kthvalue(int(flat_probs.numel() * FILTER_RATE)).values
    valid_mask = (activation_probs > thresh).any(dim=-1)
    entropy[~valid_mask] = float("inf")

    # Keep the TOP_RATE fraction of neurons with the lowest entropy.
    flat_entropy = entropy.flatten()
    topk = int(flat_entropy.numel() * TOP_RATE)
    _, idx = flat_entropy.topk(topk, largest=False)
    layer_idx = idx // intermediate_size
    neuron_idx = idx % intermediate_size

    # Assign each selected neuron to every language where it fires above the
    # global ACTIVATION_BAR_RATIO quantile.
    selection_props = activation_probs[layer_idx, neuron_idx]
    bar = flat_probs.kthvalue(int(flat_probs.numel() * ACTIVATION_BAR_RATIO)).values
    lang_mask = (selection_props > bar).T

    final_mask = {}
    for i, lang in enumerate(langs):
        neuron_ids = torch.where(lang_mask[i])[0]
        layer_lists = [[] for _ in range(num_layers)]
        for nid in neuron_ids.tolist():
            layer_lists[layer_idx[nid].item()].append(neuron_idx[nid].item())
        final_mask[lang] = [torch.tensor(lst, dtype=torch.long) for lst in layer_lists]
    return final_mask, num_layers


def _plot_neuron_counts(langs, final_mask, num_layers, model_name, checkpoint):
    """Save a grouped bar chart of language-specific neuron counts per layer."""
    plt.figure(figsize=(12, 6))
    x = np.arange(num_layers)
    width = 0.35

    for i, lang_key in enumerate(langs):
        counts = [len(layer) for layer in final_mask[lang_key]]
        bars = plt.bar(x + i * width, counts, width=width, label=lang_key)
        # Annotate each bar with its count.
        for bar_item in bars:
            height = bar_item.get_height()
            plt.text(bar_item.get_x() + bar_item.get_width() / 2.0, height,
                     f'{int(height)}', ha='center', va='bottom', fontsize=9)

    plt.xlabel("Layer Index")
    plt.ylabel("Number of Neurons")
    plt.title(f"Number of Language-Specific Neurons per Layer\nModel: {model_name}, Checkpoint: {checkpoint}")
    # Center the tick labels under each group of bars. The previous code
    # hard-coded `width / 2`, which is only correct for exactly 2 languages.
    plt.xticks(x + width * (len(langs) - 1) / 2, x)
    plt.legend()
    plt.grid(alpha=0.3, axis='y')
    plt.tight_layout()
    plt.savefig(f"{model_name}_{checkpoint}_neurons_bar.png", dpi=300)
    plt.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Analyze English vs. Chinese neuron specialization across checkpoints.
    plot_language_neurons(
        langs=["en", "zh"],
        base_path="activations/qwen2.5-0.5b_english_wiki_750M_chinese_wikipedia_corpus",
        checkpoint_numbers=[300, 600, 900, 1200, 1500, 1800, 1962],
    )
|
|
|