# JiRackTernary_140b / val_sidecar_cms_140b.py
# (Hugging Face upload by kgrabko — commit 1461a1f, "Upload 8 files")
# ==============================================================================
# COPYRIGHT (C) 2025 KONSTANTIN VLADIMIROVICH GRABKO. ALL RIGHTS RESERVED.
# PATENT PENDING | CMS MANHATTAN JIRACK TECHNOLOGY
#
# This software is licensed under the Commercial License Agreement V.1.2.
# Any use, modification, or distribution of this code requires compliance with
# the terms found in the LICENSE.md file in the root directory.
#
# NO PATENTING RIGHTS: Users are strictly prohibited from filing patent claims
# based on the BRE or SWA architectures disclosed herein.
# Contact: grabko@cmsmanhattan.com | +1 (516) 777-0945
# ==============================================================================
# Optimized for Heavy Ternary Models (70B/140B) on ROCm
import os
import torch
import time
import json
import glob
from transformers import LlamaTokenizerFast
from datasets import load_dataset
# Configuration
MODEL_PATH = "./models"                    # directory polled for checkpoint files
LOG_FILE = "val_metrics_cms.json"          # accumulated validation metrics (JSON list of records)
CULTURAL_FILE = "cultural_finetune.jsonl"  # optional JSONL file with cultural fine-tune samples
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when available
def load_cultural_data(path):
    """Load a JSON Lines file into a list of parsed records.

    Args:
        path: Path to a ``.jsonl`` file (one JSON object per line).

    Returns:
        A list of the parsed objects; an empty list when the file does
        not exist.

    Blank lines are skipped so a trailing newline or an accidental empty
    line no longer raises ``json.JSONDecodeError`` (the original parsed
    every raw line unconditionally).
    """
    data = []
    if os.path.exists(path):
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line:  # skip blank/whitespace-only lines
                    data.append(json.loads(line))
    return data
def run_validation():
    """Sidecar loop: watch MODEL_PATH for new ternary checkpoints and append
    validation losses to LOG_FILE.

    Runs forever. Each pass globs for ``ternary_*_checkpoint_step_*`` files,
    picks the most recently modified one, and — if it has not been processed
    yet — loads it on CPU, records a ``{step, pile_loss, cultural_loss,
    timestamp}`` entry in the JSON results file, and sleeps between polls.

    NOTE(review): ``tokenizer``, ``cultural_data`` and ``pile_dataset`` are
    prepared but never used below — presumably placeholders for the real
    model forward pass that is currently stubbed out.
    """
    print(">>> CMS Manhattan Heavy Sidecar started.")
    # Use the Llama-3 tokenizer as the standard for 70B+ models
    tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer", legacy=False)
    cultural_data = load_cultural_data(CULTURAL_FILE)
    print(f">>> Initializing data streams...")
    pile_dataset = load_dataset("monology/pile-uncopyrighted", split="train", streaming=True)
    processed_checkpoints = set()  # absolute paths already validated this run
    while True:
        checkpoints = glob.glob(os.path.join(MODEL_PATH, "ternary_*_checkpoint_step_*"))
        if not checkpoints:
            time.sleep(60)  # nothing to validate yet; poll again in a minute
            continue
        # Most recently modified checkpoint is assumed to be the newest step
        latest_ckpt = max(checkpoints, key=os.path.getmtime)
        if latest_ckpt not in processed_checkpoints:
            # Step id is the trailing "_<step>" token of the filename
            step = latest_ckpt.split('_')[-1]
            print(f"\n[NEW CHECKPOINT DETECTED: {step}]")
            try:
                # IMPORTANT: for 70B+ models map the weights onto CPU before
                # moving to GPU, to avoid a peak-VRAM spike during load
                checkpoint_data = torch.load(latest_ckpt, map_location='cpu')
                # The 140B architecture initialization belongs here:
                # model = JiRackTernary140B(config).to(DEVICE)
                # model.load_state_dict(checkpoint_data['model_state_dict'])
                # Simulated loss computation (substitute the real model call).
                # The 140B version uses only 10 samples for speed.
                p_loss = 10.71  # placeholder (replace with model.forward)
                c_loss = 10.46  # placeholder (replace with model.forward)
                # Persist results: read-modify-write the whole JSON log
                results = []
                if os.path.exists(LOG_FILE):
                    with open(LOG_FILE, 'r') as f:
                        results = json.load(f)
                results.append({
                    "step": step,
                    "pile_loss": p_loss,
                    "cultural_loss": c_loss,
                    "timestamp": time.time()
                })
                with open(LOG_FILE, 'w') as f:
                    json.dump(results, f, indent=4)
                print(f"📊 SUMMARY | Step: {step}")
                print(f" - Pile Loss: {p_loss:.4f}")
                print(f" - Cultural Loss: {c_loss:.4f}")
                processed_checkpoints.add(latest_ckpt)
                # Release cached GPU memory after each heavy pass
                # (no-op when CUDA is uninitialized)
                torch.cuda.empty_cache()
            except Exception as e:
                # Best-effort sidecar: log and keep polling rather than crash
                print(f"❌ Error during heavy validation: {e}")
        time.sleep(120)  # pause between polling passes
if __name__ == "__main__":
    # Script entry point: start the endless validation sidecar loop.
    run_validation()