File size: 4,300 Bytes
fed1832
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#!/usr/bin/env python
import os
import torch
import numpy as np
import matplotlib.pyplot as plt

# Print tensors without truncation (useful when inspecting masks/indices).
torch.set_printoptions(profile="full")  

# Percentile cut-off used to filter out rarely-activated neurons (step 4).
FILTER_RATE = 0.95
# Fraction of lowest-entropy (most language-specific) neurons to keep (step 5).
TOP_RATE = 0.01
# Percentile bar a neuron's activation prob must clear to count for a language (step 6).
ACTIVATION_BAR_RATIO = 0.95
# NOTE(review): THRESHOLD is not referenced anywhere in this file -- confirm before removing.
THRESHOLD = 0.8

def plot_language_neurons(langs, base_path, checkpoint_numbers):
    """Identify language-specific neurons per checkpoint and plot per-layer counts.

    For each checkpoint: load per-language activation statistics, compute the
    LAPE entropy of each neuron's cross-language activation distribution,
    select the lowest-entropy (most language-specific) neurons, attribute each
    one to the language(s) whose activation probability clears a high
    percentile bar, and save a grouped bar chart of neuron counts per layer.

    Args:
        langs: list of language codes, e.g. ["en", "eu"].
        base_path: folder containing activation files named
            ``activation.<lang>.train.qwen-checkpoint-<cp>``.
        checkpoint_numbers: list of checkpoint ints, e.g. [300, 3231].

    Each activation file is expected to be a ``torch.load``-able dict with
    keys "n" (sample/token count) and "over_zero" (per-layer, per-neuron
    positive-activation counts) -- TODO confirm against the producer script.

    Side effects: writes ``<model>_<checkpoint>_neurons_bar.png`` per
    checkpoint into the current working directory.
    """
    model_name = os.path.basename(base_path)

    for cp in checkpoint_numbers:
        checkpoint = f"qwen-checkpoint-{cp}"
        n, over_zero = [], []

        # Load activation statistics for every language.
        for lang in langs:
            path = os.path.join(base_path, f"activation.{lang}.train.{checkpoint}")
            data = torch.load(path)
            n.append(data["n"])
            over_zero.append(data["over_zero"])

        # Stack per-language stats; the last dimension indexes languages.
        # torch.tensor(...) replaces the legacy torch.Tensor(...) constructor
        # and makes the float dtype (needed for the division below) explicit.
        n = torch.tensor(n, dtype=torch.float32)
        over_zero = torch.stack(over_zero, dim=-1)
        num_layers, intermediate_size, lang_num = over_zero.size()

        # 1. Activation probability: fraction of samples on which each neuron fired.
        activation_probs = over_zero / n

        # 2. Normalize across languages so each neuron's probabilities sum to 1.
        normed_activation_probs = activation_probs / activation_probs.sum(dim=-1, keepdim=True)

        # 3. LAPE entropy of the normalized distribution (0*log(0) treated as 0).
        log_prob = torch.where(normed_activation_probs > 0,
                               normed_activation_probs.log(),
                               torch.zeros_like(normed_activation_probs))
        entropy = -(normed_activation_probs * log_prob).sum(dim=-1)

        # 4. Keep only neurons whose activation probability exceeds the
        #    FILTER_RATE percentile in at least one language; exclude the rest
        #    from selection by assigning them infinite entropy.
        flat_probs = activation_probs.flatten()
        thresh = flat_probs.kthvalue(int(flat_probs.numel() * FILTER_RATE)).values
        valid_mask = (activation_probs > thresh).any(dim=-1)
        entropy[~valid_mask] = float("inf")

        # 5. Select the TOP_RATE fraction of neurons with the LOWEST entropy.
        flat_entropy = entropy.flatten()
        # max(1, ...) guards against topk == 0 on very small models.
        topk = max(1, int(flat_entropy.numel() * TOP_RATE))
        _, idx = flat_entropy.topk(topk, largest=False)

        # Explicit floor division avoids the tensor `//` deprecation warning.
        layer_idx = torch.div(idx, intermediate_size, rounding_mode="floor")
        neuron_idx = idx % intermediate_size

        # 6. Attribute each selected neuron to every language whose activation
        #    probability clears the ACTIVATION_BAR_RATIO percentile bar.
        selection_props = activation_probs[layer_idx, neuron_idx]
        bar = flat_probs.kthvalue(int(flat_probs.numel() * ACTIVATION_BAR_RATIO)).values
        lang_mask = (selection_props > bar).T

        # Group the selected neuron indices by language and by layer.
        final_mask = {}
        for i, lang in enumerate(langs):
            neuron_ids = torch.where(lang_mask[i])[0]
            layer_lists = [[] for _ in range(num_layers)]
            for nid in neuron_ids.tolist():
                layer = layer_idx[nid].item()
                hidden = neuron_idx[nid].item()
                layer_lists[layer].append(hidden)
            final_mask[lang] = [torch.tensor(lst, dtype=torch.long) for lst in layer_lists]

        # =========================
        # Bar chart with the count shown on top of each bar
        # =========================
        plt.figure(figsize=(12, 6))
        x = np.arange(num_layers)
        width = 0.35

        for i, lang_key in enumerate(langs):
            counts = [len(layer) for layer in final_mask[lang_key]]
            bars = plt.bar(x + i * width, counts, width=width, label=lang_key)
            # Annotate each bar with its neuron count.
            for bar_item in bars:
                height = bar_item.get_height()
                plt.text(bar_item.get_x() + bar_item.get_width() / 2.0, height, f'{int(height)}',
                         ha='center', va='bottom', fontsize=9)

        plt.xlabel("Layer Index")
        plt.ylabel("Number of Neurons")
        plt.title(f"Number of Language-Specific Neurons per Layer\nModel: {model_name}, Checkpoint: {checkpoint}")
        # NOTE(review): this tick offset centers the group only for two languages.
        plt.xticks(x + width / 2, x)
        plt.legend()
        plt.grid(alpha=0.3, axis='y')
        plt.tight_layout()

        # Save one figure per checkpoint.
        plt.savefig(f"{model_name}_{checkpoint}_neurons_bar.png", dpi=300)
        plt.close()



# =========================
# Example usage
# =========================
if __name__ == "__main__":
    # Run the analysis for English/Chinese across the trained checkpoints.
    plot_language_neurons(
        langs=["en", "zh"],
        base_path="activations/qwen2.5-0.5b_english_wiki_750M_chinese_wikipedia_corpus",
        checkpoint_numbers=[300, 600, 900, 1200, 1500, 1800, 1962],
    )