# Source file: JiRackTernary_236b / train_236b_heavy_mixed_val_data.py
# (Hugging Face upload-page residue converted to comments so the file
#  parses as valid Python; commit 46e36df, uploaded by kgrabko.)
# ==============================================================================
# COPYRIGHT (C) 2025 KONSTANTIN VLADIMIROVICH GRABKO. ALL RIGHTS RESERVED.
# PATENT PENDING | CMS MANHATTAN JIRACK TECHNOLOGY
#
# This software is licensed under the Commercial License Agreement V.1.2.
# Any use, modification, or distribution of this code requires compliance with
# the terms found in the LICENSE.md file in the root directory.
#
# NO PATENTING RIGHTS: Users are strictly prohibited from filing patent claims
# based on the BRE or SWA architectures disclosed herein.
# Contact: grabko@cmsmanhattan.com | +1 (516) 777-0945
# ==============================================================================
# COPYRIGHT (C) 2025 KONSTANTIN VLADIMIROVICH GRABKO. ALL RIGHTS RESERVED.
# PATENT PENDING | CMS MANHATTAN JIRACK TECHNOLOGY | VERSION 236B MIXED
# Optimized for Extreme Depth (192 Layers) & Hybrid Knowledge
# ==============================================================================
import torch
import torch.nn as nn
import os
import random
import json
from torch.utils.data import DataLoader, IterableDataset
from transformers import AutoTokenizer
from datasets import load_dataset
from accelerate import Accelerator
import sys
# Import of the project-local 236B architecture
from JiRackTernaryPyTorch_236b import JiRackTernary236B, JiRackTernaryConfig
# --- CMS MANHATTAN CONFIGURATION ---
MODEL_ID = "./models/jirack_236b_init"
CULTURAL_DATA_FILE = "cultural_finetune.jsonl"
GENERAL_DATA_LINK = "monology/pile-uncopyrighted" # HF Hub dataset id for The Pile (uncopyrighted subset)
CHECKPOINT_DIR = "checkpoints_jirack_236b_mixed"
MIX_RATIO = 0.35 # 35% client ("cultural") data / 65% The Pile
BATCH_SIZE = 1
GRAD_ACCUM_STEPS = 48 # balance between speed and stability for the 236B model
LEARNING_RATE = 3.5e-6 # LR tuned for the 192-layer depth
BLOCK_SIZE = 2048 # 2k-token context window
# --- DATA MIXER FOR 236B ---
class CMSDataMixer236B(IterableDataset):
    """Infinite mixer of local client Q/A data and streamed The Pile text.

    With probability ``mix_ratio`` a sample is drawn from the client JSONL
    file; otherwise the next record comes from the streaming Pile dataset
    (the stream is transparently restarted when exhausted).

    Yields dicts with 1-D ``input_ids`` and ``labels`` tensors of length
    BLOCK_SIZE.  FIX: padded positions in ``labels`` are masked to -100
    (the standard HF ignore_index) so the loss is not computed on padding —
    previously ``labels`` was a raw copy of ``input_ids`` and the model was
    also trained to predict pad tokens.
    """

    def __init__(self, tokenizer, client_file, pile_link, mix_ratio=0.35):
        """Connect to the streaming Pile corpus and load the client JSONL file.

        Args:
            tokenizer: HF tokenizer used to encode every sample.
            client_file: path to a JSONL file with ``question``/``answer`` keys.
            pile_link: HF Hub dataset id for the general-knowledge corpus.
            mix_ratio: probability of drawing a client sample per iteration.
        """
        self.tokenizer = tokenizer
        self.mix_ratio = mix_ratio
        # Stream The Pile (general knowledge) — no full local download.
        print(f">>> [MIXER] Connecting to General Knowledge: {pile_link}")
        self.pile_stream = load_dataset(pile_link, split="train", streaming=True)
        # Load the client "Evolutionary Index": one JSON object per line.
        self.cultural_data = []
        if os.path.exists(client_file):
            with open(client_file, 'r', encoding='utf-8') as f:
                for line in f:
                    # Skip blank lines so a trailing newline doesn't crash json.loads.
                    if line.strip():
                        self.cultural_data.append(json.loads(line))
            print(f">>> [MIXER] Loaded {len(self.cultural_data)} client samples.")
        else:
            print(f"⚠️ WARNING: {client_file} not found. Running on Pile only.")

    def __iter__(self):
        """Yield tokenized training items forever (IterableDataset protocol)."""
        pile_iterator = iter(self.pile_stream)
        while True:
            # Probabilistic choice of data source.
            if random.random() < self.mix_ratio and self.cultural_data:
                sample = random.choice(self.cultural_data)
                text = f"Question: {sample['question']}\nAnswer: {sample['answer']}"
            else:
                try:
                    sample = next(pile_iterator)
                    text = sample['text']
                except StopIteration:
                    # Stream exhausted — restart it and retry the draw.
                    pile_iterator = iter(self.pile_stream)
                    continue
            tokens = self.tokenizer(
                text, truncation=True, max_length=BLOCK_SIZE, padding="max_length", return_tensors="pt"
            )
            input_ids = tokens["input_ids"].squeeze(0)
            labels = input_ids.clone()
            # FIX: mask padding out of the loss.  Using the attention mask
            # (rather than pad_token_id) avoids masking genuine EOS tokens,
            # since pad_token is aliased to eos_token by the trainer.
            labels[tokens["attention_mask"].squeeze(0) == 0] = -100
            yield {
                "input_ids": input_ids,
                "labels": labels
            }
# --- TRAINING PROCESS ---
def train_236b():
    """Train the 236B JiRackTernary model on the mixed client/Pile stream.

    Builds the Accelerate-managed loop with gradient accumulation,
    gradient clipping and periodic full-state checkpoints under
    CHECKPOINT_DIR.  Runs indefinitely (the data mixer is infinite).
    """
    # Accelerator distributes the 236B weights across available devices
    # and tracks gradient-accumulation boundaries for us.
    accelerator = Accelerator(gradient_accumulation_steps=GRAD_ACCUM_STEPS)
    if accelerator.is_main_process:
        # exist_ok avoids a race between a check and a create on re-runs.
        os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    # 1. Tokenizer (Llama-3 vocabulary); alias EOS as PAD when missing.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # 2. The 236B model (192 layers).
    config = JiRackTernaryConfig()
    model = JiRackTernary236B(config)
    # CRITICAL: gradient checkpointing trades recompute for VRAM savings.
    model.gradient_checkpointing_enable()
    # 3. Data pipeline (streaming mixer, so no sampler/shuffle needed).
    dataset = CMSDataMixer236B(tokenizer, CULTURAL_DATA_FILE, GENERAL_DATA_LINK, mix_ratio=MIX_RATIO)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE)
    # 4. Optimizer.
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
    # Hand everything to accelerate for device placement / wrapping.
    model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
    print("\n--- [CMS MANHATTAN] 236B MIXED ENGINE ONLINE ---")
    print(f"Model Depth: 192 Layers | Width: 10240 | Mix: {int(MIX_RATIO*100)}% Client")
    model.train()
    for step, batch in enumerate(loader):
        with accelerator.accumulate(model):
            outputs = model(**batch)
            loss = outputs.loss
            accelerator.backward(loss)
            # Gradient-explosion protection: clip only on real update steps.
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()
        if step % 20 == 0 and accelerator.is_main_process:
            # Guard the CUDA query so CPU-only debug runs don't fail here.
            vram_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0
            print(f"Step {step} | Loss: {loss.item():.4f} | VRAM: {vram_gb:.1f}GB")
        # Periodic checkpoint of the full training state.
        if step > 0 and step % 500 == 0 and accelerator.is_main_process:
            save_path = os.path.join(CHECKPOINT_DIR, f"step_{step}")
            accelerator.save_state(save_path)
            print(f">>> [CMS] 236B Checkpoint saved: {save_path}")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
if __name__ == "__main__":
    # CUDA allocator tuning: expandable segments reduce fragmentation
    # when very large models repeatedly allocate/free big blocks.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    try:
        train_236b()
    except KeyboardInterrupt:
        print("\n[!] Остановка. Прогресс сохранен.")
    except Exception as e:
        # FIX: dump the full traceback — the bare message hid where in the
        # 192-layer training stack the failure actually occurred.
        import traceback
        traceback.print_exc()
        print(f"FATAL ERROR: {e}")
        sys.exit(1)