JiRackTernary_70b / layer_inspector_70b.py
kgrabko's picture
Update layer_inspector_70b.py
af32a55 verified
# ==============================================================================
# COPYRIGHT (C) 2025-2026 KONSTANTIN VLADIMIROVICH GRABKO. ALL RIGHTS RESERVED.
# PATENT PENDING | CMS MANHATTAN JIRACK TECHNOLOGY
# ==============================================================================
import torch
import gc
from transformers import AutoTokenizer
from JiRackTernaryPyTorch_70b import JiRackTernaryConfig
from load_packed_70b import load_jirack_70b_packed
def inspect_activations():
PATH = "JiRack_BitNet_70B_Packed"
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-70B")
config = JiRackTernaryConfig.from_pretrained(PATH)
print("🏗️ Загрузка упакованной 70B на GPU...")
model = load_jirack_70b_packed(PATH, config)
model.eval().to("cuda:0")
# Тот же промпт для консистентности
prompt = "The solar system consists of the Sun and the objects that orbit it."
inputs = tokenizer(prompt, return_tensors="pt").to("cuda:0")
# ИСПРАВЛЕНИЕ: Приведение маски к bool для SDPA
if "attention_mask" in inputs:
inputs["attention_mask"] = inputs["attention_mask"].bool()
stats = {}
def get_hook(name):
def hook(model, input, output):
# output - это CausalLMOutputWithPast или тензор в зависимости от места
# В слоях это обычно тензор (hidden_states)
data = output[0] if isinstance(output, tuple) else output
data = data.detach().to(torch.float32)
stats[name] = {
"mean": data.mean().item(),
"std": data.std().item(),
"l2": torch.norm(data, p=2).item(),
"zeros": (data.abs() < 1e-6).float().mean().item(),
"has_nan": torch.isnan(data).any().item()
}
return hook
# Регистрируем хуки на ключевых точках
layers_to_watch = [0, 10, 40, 79]
for idx in layers_to_watch:
model.layers[idx].register_forward_hook(get_hook(f"Layer_{idx}"))
print("\n🔍 Запуск диагностического прогона...")
with torch.no_grad():
try:
model(**inputs)
except Exception as e:
print(f"❌ Ошибка во время прогона: {e}")
return
print("\n📊 Результаты инспекции слоев:")
print(f"{'Слой':<10} | {'Среднее':<12} | {'Std':<12} | {'L2-Норма':<12} | {'% Нулей':<10} | {'NaN'}")
print("-" * 85)
for name in [f"Layer_{i}" for i in layers_to_watch]:
if name in stats:
s = stats[name]
nan_str = "❌ YES" if s['has_nan'] else "✅ NO"
print(f"{name:<10} | {s['mean']:>12.6f} | {s['std']:>12.6f} | {s['l2']:>12.1f} | {s['zeros']:>10.1%}| {nan_str}")
# Анализ здоровья сигнала
l0 = stats["Layer_0"]["l2"]
l79 = stats["Layer_79"]["l2"]
ratio = l79 / l0 if l0 != 0 else 0
print(f"\n📈 Коэффициент сохранения энергии (L2_L79/L2_L0): {ratio:.4f}")
if ratio < 0.2:
print("⚠️ Сигнал критически затухает. Модель 'засыпает' к выходу.")
elif ratio > 5.0:
print("⚠️ Сигнал перегружен. Возможны галлюцинации и 'салат из слов'.")
else:
print("✅ Динамика в пределах нормы. Тернарный стек стабилен.")
if __name__ == "__main__":
torch.cuda.empty_cache()
gc.collect()
inspect_activations()