File size: 7,141 Bytes
fed1832 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 | #!/usr/bin/env python3
"""
Per-neuron activation tracker for LLaMA-2 and Qwen MLP layers.
Runs on a fixed set of models and multiple input ID files per model.
"""
import torch
import os
from types import MethodType
from vllm import LLM, SamplingParams # Keep original import since hook logic depends on vLLM
# ---------------------- Config ----------------------
# Directory that contains the model checkpoints referenced below.
BASE_PATH = "/home/khanh/sla/sla_cpt"
# Directory that contains the token-ID files the run loop reads with torch.load.
ID_BASE_PATH = "./oscar_ids"

# One entry per model to analyse. Keys read by the run loop:
#   "name"     - label used in the output filename (falls back to "model")
#   "model"    - checkpoint path handed to vLLM
#   "ids_list" - list of {"path": <ID file>, "lang": <language code>}
#   "type"     - "llama" or "qwen"; selects the MLP hook (falls back to "llama")
# Disabled presets are kept as comments; uncomment an entry to run it.
RUN_CONFIGS = [
    # {
    #     "name": "l2-13b",
    #     "model": f"{BASE_PATH}/uccix/checkpoint-4280",
    #     "ids_list": [
    #         {"path": "./ids/l2-13b/id.ga.train.l2-13b", "lang": "ga"},
    #         {"path": "./ids/l2-13b/id.en.train.l2-13b", "lang": "en"},
    #     ],
    #     "type": "llama",
    # },
    # {
    #     "name": "l2-7b",
    #     "model": f"{BASE_PATH}/llama2_7b_full_basque_corpus_grad_clip_1/checkpoint-10200",
    #     "ids_list": [
    #         {"path": "./ids/l2-7b/id.eu.train.l2-7b", "lang": "eu"},
    #         {"path": "./ids/l2-7b/id.en.train.l2-7b", "lang": "en"},
    #     ],
    #     "type": "llama",
    # },
    {
        "name": "q2.5-zh",
        "model": (
            f"{BASE_PATH}/qwen2.5-0.5b_english_wiki_750M_chinese_wikipedia_corpus_2e_240925"
            "/checkpoint-2944"
        ),
        "ids_list": [
            {"path": f"{ID_BASE_PATH}/q2.5/id.{code}.train.qwen2.5-0.5", "lang": code}
            for code in ("zh", "en")
        ],
        "type": "qwen",
    },
    # NOTE: the ID paths in the disabled presets below need the f-string
    # prefix so that {ID_BASE_PATH} is actually substituted.
    # {
    #     "name": "q2.5-en+zh",
    #     "model": f"{BASE_PATH}/qwen2.5-0.5b_english_wiki_150M_en_750M_chinese_wikipedia_corpus_2e_240925/checkpoint-3494",
    #     "ids_list": [
    #         {"path": f"{ID_BASE_PATH}/q2.5/id.zh.train.qwen2.5-0.5", "lang": "zh"},
    #         {"path": f"{ID_BASE_PATH}/q2.5/id.en.train.qwen2.5-0.5", "lang": "en"},
    #     ],
    #     "type": "qwen",
    # },
    # {
    #     "name": "q2.5-ga",
    #     "model": f"{BASE_PATH}/qwen2.5-0.5b_english_wiki_1.5B_irish_corpus_240925/checkpoint-2854",
    #     "ids_list": [
    #         {"path": f"{ID_BASE_PATH}/q2.5/id.en.train.qwen2.5-0.5", "lang": "en"},
    #         {"path": f"{ID_BASE_PATH}/q2.5/id.ga.train.qwen2.5-0.5", "lang": "ga"},
    #     ],
    #     "type": "qwen",
    # },
    # # {
    # #     "name": "q2.5-en+ga",
    # #     "model": f"{BASE_PATH}/qwen2.5-0.5_full_english_corpus_grad_clip_1/checkpoint-3231",
    # #     "ids_list": [
    # #         {"path": "./ids/qwen2.5-0.5/id.en.train.qwen2.5-0.5", "lang": "en"},
    # #         {"path": "./ids/qwen2.5-0.5/id.ga.train.qwen2.5-0.5", "lang": "ga"},
    # #     ],
    # #     "type": "qwen",
    # # },
    # {
    #     "name": "q2.5-eu",
    #     "model": f"{BASE_PATH}/qwen2.5-0.5b_english_wiki_1.5Bbasque_corpus_240925/checkpoint-2424",
    #     "ids_list": [
    #         {"path": f"{ID_BASE_PATH}/q2.5/id.eu.train.qwen2.5-0.5", "lang": "eu"},
    #         {"path": f"{ID_BASE_PATH}/q2.5/id.en.train.qwen2.5-0.5", "lang": "en"},
    #     ],
    #     "type": "qwen",
    # },
    # NOTE(review): the preset below has no "type" key, so the run loop would
    # fall back to "llama" - presumably it should be "qwen"; confirm before enabling.
    # {
    #     "name": "q2.5-en+eu",
    #     "model": f"{BASE_PATH}/qwen2.5-0.5_full_basque_corpus_grad_clip_1/checkpoint-7800",
    #     "ids_list": [
    #         {"path": "./ids/qwen2.5-0.5/id.eu.train.qwen2.5-0.5", "lang": "eu"},
    #         {"path": "./ids/qwen2.5-0.5/id.en.train.qwen2.5-0.5", "lang": "en"},
    #     ],
    # },
]

# Output directory for the saved activation counts; created eagerly at import.
SAVE_FOLDER = "new_activations"
os.makedirs(SAVE_FOLDER, exist_ok=True)
# ---------------------- Hook Functions ----------------------
def make_llama_hook(idx):
    """Build a replacement ``forward`` for the MLP of decoder layer ``idx``.

    The returned function is meant to be bound (via ``MethodType``) to an MLP
    module exposing ``gate_up_proj`` and ``down_proj``, each of which returns
    a ``(tensor, extra)`` pair. It computes the SiLU-gated MLP output and, as
    a side effect, adds to row ``idx`` of the module-level ``over_zero``
    tensor the number of tokens whose gate activation was positive.

    Args:
        idx: Row of ``over_zero`` to update (the layer index).

    Returns:
        The function ``llama_forward(self, x)``.
    """

    def llama_forward(self, x):
        # Fused projection; indexed as 2-D (tokens, 2 * intermediate).
        gate_up, _ = self.gate_up_proj(x)
        half = gate_up.size(-1) // 2
        # Functional SiLU: no nn.Module is built per call and the projection
        # output is not overwritten in place.
        gate_activation = torch.nn.functional.silu(gate_up[:, :half])
        # A sign test gives the same result in the native dtype, so no
        # float32 copy is needed just to count positive activations.
        over_zero[idx, :] += (gate_activation > 0).sum(dim=0)
        out, _ = self.down_proj(gate_activation * gate_up[:, half:])
        return out

    return llama_forward
def make_qwen_hook(idx):
    """Create the counting ``forward`` used for the MLP of layer ``idx``.

    The inner function is intended to be bound to an MLP module that provides
    ``gate_up_proj`` and ``down_proj`` (both returning a ``(tensor, extra)``
    pair). Besides returning the gated MLP output, it increments row ``idx``
    of the module-level ``over_zero`` tensor by the per-neuron count of
    positive SiLU gate activations, summed over the first dimension.
    """

    def qwen_forward(self, x):
        fused, _ = self.gate_up_proj(x)  # last dim holds gate and up halves
        half = fused.shape[-1] // 2
        gate_act = torch.nn.functional.silu(fused[..., :half])
        # Count, per neuron, how many positions produced a positive activation.
        over_zero[idx, :] += (gate_act > 0).sum(dim=0)
        out, _ = self.down_proj(gate_act * fused[..., half:])
        return out

    return qwen_forward
# ---------------------- Run All Configs ----------------------
import gc  # kept next to its only use: cleanup between models

# Upper bound on the number of tokens consumed from one ID file.
# NOTE(review): 99,999,744 == 24,414 * 4,096, i.e. presumably "100M rounded
# down to a multiple of 4096" - confirm intent. With this cap every per-neuron
# count for a single file stays well inside the int32 range of `over_zero`.
_MAX_TOKENS_PER_FILE = 99999744

# Maps a config's 'type' to the factory that builds the counting MLP forward.
_HOOK_FACTORIES = {'llama': make_llama_hook, 'qwen': make_qwen_hook}


def _install_mlp_hooks(llm, num_layers, make_hook):
    """Replace ``forward`` on each decoder layer's MLP with a counting hook.

    Kept in a function so no module-level name keeps pointing at a layer's
    MLP after the model is deleted.
    """
    # NOTE(review): this attribute path reaches into vLLM internals and is
    # tied to the installed vLLM version - confirm when upgrading.
    layers = llm.llm_engine.model_executor.driver_worker.model_runner.model.model.layers
    for layer_idx in range(num_layers):
        mlp = layers[layer_idx].mlp
        mlp.forward = MethodType(make_hook(layer_idx), mlp)


for config in RUN_CONFIGS:
    model_name = config['model']
    save_name = config.get('name', model_name)
    model_type = config.get('type', 'llama')
    ids_list = config.get('ids_list', [])

    print(f"\n=== Processing model: {model_name}, type: {model_type} ===")

    # Fail fast: reject an unknown type before paying for the model load.
    make_hook = _HOOK_FACTORIES.get(model_type)
    if make_hook is None:
        raise ValueError(f"Unknown model type: {model_type}")

    # Load model
    model = LLM(
        model=model_name,
        tensor_parallel_size=1,
        enforce_eager=True,
        trust_remote_code=True,
    )

    model_config = model.llm_engine.model_config
    max_length = model_config.max_model_len
    num_layers = model_config.hf_config.num_hidden_layers
    intermediate_size = model_config.hf_config.intermediate_size
    print(f"Layers: {num_layers}, Intermediate size: {intermediate_size}, Max length: {max_length}")

    # Activation tracker. It must stay a module-level name: the hook functions
    # look up `over_zero` as a global each time they run.
    over_zero = torch.zeros(num_layers, intermediate_size, dtype=torch.int32, device='cuda')

    # Hook MLP layers
    _install_mlp_hooks(model, num_layers, make_hook)

    # Iterate over all ID files
    for id_dict in ids_list:
        ids_path = id_dict['path']
        lang = id_dict.get('lang', 'unknown')  # language tag used in the output filename

        print(f"\nLoading IDs from {ids_path} (lang: {lang})...")
        # NOTE: torch.load unpickles the file; only load ID files you trust.
        ids = torch.load(ids_path)
        print(f"ID shape: {ids.shape}")

        # Use a whole number of max_length-sized sequences, capped per file.
        num_tokens = min(ids.size(0), _MAX_TOKENS_PER_FILE) // max_length * max_length
        if num_tokens == 0:
            # Not even one full sequence: nothing to run and nothing to save.
            print(f"Skipping {ids_path}: fewer than {max_length} tokens available")
            continue
        input_ids = ids[:num_tokens].reshape(-1, max_length)
        print(f"Processing {input_ids.size(0)} sequences of length {max_length}")

        # Start every file from zero. Without this reset the counts saved for
        # a later language would also contain those of all earlier files,
        # while 'n' only reflects the current file.
        over_zero.zero_()

        # Run inference
        print("Running inference...")
        _ = model.generate(
            prompt_token_ids=input_ids.tolist(),
            sampling_params=SamplingParams(max_tokens=1),
        )

        # Save results for this ID file
        output_path = os.path.join(SAVE_FOLDER, f'activation.{lang}.train.{save_name}.pt')
        torch.save({
            'n': num_tokens,
            'over_zero': over_zero.cpu(),
            'num_layers': num_layers,
            'intermediate_size': intermediate_size,
        }, output_path)

        print(f"Saved activation counts to {output_path}")
        print(f"Processed {num_tokens} tokens total")

    print(f"\nActivation analysis complete for model: {save_name}!")

    # Each patched `mlp.forward` is a bound method stored on its own module,
    # which forms a reference cycle. Run the garbage collector first so those
    # objects are really freed, then release the cached CUDA memory.
    del model
    gc.collect()
    torch.cuda.empty_cache()
|