# JiRackTernary_7b / val_sidecar_cms_7b.py
# (Hugging Face upload metadata: kgrabko — "Upload 8 files", commit 2b24c37 verified)
# ==============================================================================
# COPYRIGHT (C) 2025 KONSTANTIN VLADIMIROVICH GRABKO. ALL RIGHTS RESERVED.
# PATENT PENDING | CMS MANHATTAN JIRACK TECHNOLOGY
#
# This software is licensed under the Commercial License Agreement V.1.2.
# Any use, modification, or distribution of this code requires compliance with
# the terms found in the LICENSE.md file in the root directory.
#
# NO PATENTING RIGHTS: Users are strictly prohibited from filing patent claims
# based on the BRE or SWA architectures disclosed herein.
# Contact: grabko@cmsmanhattan.com | +1 (516) 777-0945
# ==============================================================================
# Sidecar Validator: JiRackTernary 7B - High Speed Monitoring
# Optimized for Tesla M10 (Maxwell Architecture)
import os
import torch
import time
import json
import glob
import matplotlib.pyplot as plt
from transformers import AutoTokenizer
from datasets import load_dataset
# Import your 7B architecture (make sure the file is available).
# If you use the same class as for the 140B model, just change the config.
from JiRackTernaryPyTorch_70b import JiRackTernary70B as JiRackTernary7B
from JiRackTernaryPyTorch_70b import JiRackTernaryConfig
# --- CMS CONFIGURATION ---
BASE_CHECKPOINT_DIR = "checkpoints_jirack_7b"  # directory polled for *.pt checkpoints
LOG_FILE = "val_metrics_7b_cms.json"           # JSON history of validation metrics
PLOT_FILE = "training_progress_7b.png"         # output path for the loss chart
CULTURAL_FILE = "cultural_finetune.jsonl"      # JSONL file of cultural validation prompts
VAL_SAMPLES_SLIM = 100 # For 7B a larger sample can be taken for better accuracy
MAX_LENGTH = 2048 # 7B easily holds a long context on the M10
def load_cultural_questions(file_path):
    """Load cultural validation prompts from a JSONL file.

    Args:
        file_path: Path to a .jsonl file containing one JSON object per line.

    Returns:
        List of parsed JSON objects; an empty list if the file is missing.
    """
    questions = []
    if os.path.exists(file_path):
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                # Skip blank lines: a trailing newline (common in JSONL files)
                # would otherwise raise json.JSONDecodeError and abort loading.
                if line:
                    questions.append(json.loads(line))
    return questions
def update_plot(history):
    """Render the loss-vs-step chart for both validation sets and save it.

    Does nothing until at least two history entries exist, since a single
    point produces no useful curve.
    """
    if len(history) < 2:
        return
    xs = [int(entry['step']) for entry in history]
    slim_curve = [entry['slim_loss'] for entry in history]
    cultural_curve = [entry['cultural_loss'] for entry in history]
    plt.figure(figsize=(10, 6))
    plt.plot(xs, slim_curve, 'b-o', label='SlimPajama Loss')
    plt.plot(xs, cultural_curve, 'r-s', label='Cultural Loss')
    plt.title('JiRackTernary 7B: Training Dynamics')
    plt.xlabel('Steps')
    plt.ylabel('Loss')
    # Log scale keeps early large losses from flattening the later curve.
    plt.yscale('log')
    plt.legend()
    plt.grid(alpha=0.3)
    plt.savefig(PLOT_FILE)
    plt.close()
def _eval_avg_loss(model, tokenizer, texts, device):
    """Return the mean causal-LM loss of `model` over an iterable of texts.

    Each text is tokenized (truncated to MAX_LENGTH) and scored with the
    input ids as labels. Returns 0 for an empty iterable so callers never
    divide by zero.
    """
    total, count = 0.0, 0
    for text in texts:
        tok = tokenizer(text, truncation=True, max_length=MAX_LENGTH,
                        return_tensors="pt").to(device)
        with torch.no_grad():
            total += model(tok.input_ids, labels=tok.input_ids).loss.mean().item()
        count += 1
    return total / count if count else 0


def _append_history(entry):
    """Append one metrics entry to LOG_FILE and return the full history."""
    history = []
    if os.path.exists(LOG_FILE):
        with open(LOG_FILE, 'r') as f:
            history = json.load(f)
    history.append(entry)
    with open(LOG_FILE, 'w') as f:
        json.dump(history, f, indent=4)
    return history


def main():
    """Sidecar validator loop: watch BASE_CHECKPOINT_DIR for new checkpoints.

    For every fresh *.pt file, loads the 7B model, measures loss on a bounded
    SlimPajama sample and on the cultural fine-tune set, appends the metrics
    to LOG_FILE, and refreshes the progress plot. Runs forever; on validation
    error the checkpoint is retried on the next pass.
    """
    print(">>> Starting Validator for 7B Model...")
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    cultural_questions = load_cultural_questions(CULTURAL_FILE)
    val_stream = load_dataset("cerebras/SlimPajama-627B", split="train", streaming=True)
    processed_checkpoints = set()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    while True:
        ckpt_files = glob.glob(os.path.join(BASE_CHECKPOINT_DIR, "*.pt"))
        if not ckpt_files:
            time.sleep(30)
            continue
        # Only the newest checkpoint is validated each pass.
        last_ckpt = max(ckpt_files, key=os.path.getmtime)
        if last_ckpt not in processed_checkpoints:
            print(f"\n[VALIDATING 7B: {last_ckpt}]")
            model = None
            try:
                # 7B-class config (typically ~32 layers, hidden_size=4096).
                config = JiRackTernaryConfig(
                    num_hidden_layers=32,
                    hidden_size=4096,
                    num_attention_heads=32
                )
                model = JiRackTernary7B(config)
                checkpoint = torch.load(last_ckpt, map_location='cpu')
                model.load_state_dict(checkpoint['model_state_dict'])
                model.to(device)
                model.eval()
                # 1. SlimPajama validation on a bounded sample of the stream.
                slim_texts = []
                for i, ex in enumerate(val_stream):
                    if i >= VAL_SAMPLES_SLIM:
                        break
                    slim_texts.append(ex["text"])
                # Average over the samples actually collected (the fixed
                # divisor VAL_SAMPLES_SLIM would skew the mean if the stream
                # yields fewer examples than requested).
                avg_slim = _eval_avg_loss(model, tokenizer, slim_texts, device)
                # 2. Cultural validation (0 if the prompt file was absent).
                avg_cultural = _eval_avg_loss(
                    model, tokenizer,
                    [ex["text"] for ex in cultural_questions], device)
                # Log & plot. Fall back to parsing the step out of the
                # checkpoint filename when the dict lacks a 'step' key.
                step_val = checkpoint.get(
                    'step', last_ckpt.split('_')[-1].replace('.pt', ''))
                history = _append_history({
                    "step": step_val,
                    "slim_loss": avg_slim,
                    "cultural_loss": avg_cultural,
                    "timestamp": time.time()
                })
                update_plot(history)
                print(f"📊 7B Stats | Step: {step_val} | Slim: {avg_slim:.4f} | Cultural: {avg_cultural:.4f}")
                processed_checkpoints.add(last_ckpt)
            except Exception as e:
                # Broad by design: the sidecar must survive any single bad
                # checkpoint and keep monitoring.
                print(f"❌ 7B Val Error: {e}")
            finally:
                # Free GPU memory even when validation fails mid-way
                # (the original only released it on the success path).
                if model is not None:
                    del model
                torch.cuda.empty_cache()
        time.sleep(60)
# Script entry point: run the validator loop only when executed directly.
if __name__ == "__main__":
    main()