JiRackTernary_405b / grid_search_405b.py
kgrabko's picture
Update grid_search_405b.py
c55dec7 verified
# ==============================================================================
# COPYRIGHT (C) 2025-2026 KONSTANTIN VLADIMIROVICH GRABKO. ALL RIGHTS RESERVED.
# PATENT PENDING | CMS MANHATTAN JIRACK TECHNOLOGY
# ==============================================================================
import torch
import torch.nn.functional as F
import os
from transformers import AutoTokenizer
from JiRackTernaryPyTorch_405b import JiRackTernaryModel, JiRackTernaryConfig
# --- CONFIGURATION ---
LOCAL_MODEL_PATH = "/home/kgrabko/Jirack405b/JiRackTernary_405B_Packed_BitNet"
TOKENIZER_ID = "meta-llama/Meta-Llama-3.1-405B-Instruct"
OFFLOAD_DIR = "offload_cache"
def evaluate_quality(model, input_ids, alpha, k_val, temp):
"""
Оценка качества генерации 405B с использованием контрастивной логики.
"""
total_loss = 0
context = input_ids.clone()
test_len = 12 # Длина тестовой генерации для оценки
for _ in range(test_len):
with torch.no_grad():
# Подача контекста на входное устройство модели (model.device)
outputs = model(context)
# Применяем температуру
logits = outputs.logits[:, -1, :] / max(temp, 1e-6)
# Contrastive Logic
probs = F.softmax(logits, dim=-1)
top_k_probs, top_k_ids = torch.topk(probs, k_val)
best_score = -float('inf')
best_token = top_k_ids[:, 0:1]
for i in range(k_val):
candidate_id = top_k_ids[:, i:i+1]
candidate_prob = top_k_probs[:, i]
# Аппроксимация штрафа за повторение (Contrastive Penalty)
# Для 405B это помогает избежать "залипания" на одной фразе
rep_count = (context == candidate_id).sum().item()
penalty = rep_count / context.size(1)
# Итоговый скор: баланс между вероятностью и новизной
score = (1 - alpha) * candidate_prob.item() - alpha * penalty
if score > best_score:
best_score = score
best_token = candidate_id
# Считаем перплексию для выбранного токена
target_prob = probs[0, best_token.item()]
total_loss += -torch.log(target_prob + 1e-10)
# Обновляем контекст (учитываем, что 405B работает на model.device)
context = torch.cat([context, best_token.to(model.device)], dim=-1)
return torch.exp(total_loss / test_len).item()
def run_405b_grid_search():
if not os.path.exists(OFFLOAD_DIR):
os.makedirs(OFFLOAD_DIR)
print(f"🚀 Загрузка токенизатора {TOKENIZER_ID}...")
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID)
print(f"🏗️ Инициализация JiRack 405B (Offload Mode)...")
# Используем автоматическое распределение весов
model = JiRackTernaryModel.from_pretrained(
LOCAL_MODEL_PATH,
device_map="auto",
torch_dtype=torch.float16,
trust_remote_code=True,
offload_folder=OFFLOAD_DIR,
low_cpu_mem_usage=True
)
model.eval()
# Сетка параметров (Alpha - штраф за повтор, K - окно кандидатов)
alphas = [0.3, 0.5, 0.7]
ks = [3, 5]
temps = [0.7, 0.9]
test_phrase = "The advancement of artificial intelligence in 2026 is"
input_ids = tokenizer.encode(test_phrase, return_tensors="pt").to(model.device)
print(f"\n🧪 Глобальный поиск гиперпараметров для 405B")
print(f"{'Temp':<6} | {'Alpha':<6} | {'K':<4} | {'PPL Score':<10}")
print("-" * 35)
results = []
for t in temps:
for a in alphas:
for k in ks:
# Оценка может быть долгой из-за Disk Offload
ppl = evaluate_quality(model, input_ids, a, k, t)
print(f"{t:<6.1f} | {a:<6.1f} | {k:<4} | {ppl:<10.2f}")
results.append((t, a, k, ppl))
best = min(results, key=lambda x: x[3])
print("-" * 35)
print(f"🎯 ОПТИМАЛЬНО ДЛЯ 405B: Temp={best[0]}, Alpha={best[1]}, K={best[2]}")
print(f"📈 Min PPL: {best[3]:.2f}")
if __name__ == "__main__":
run_405b_grid_search()