# JiRackTernary_140b / train_140b_heavy_mixed_val_data.py
# ==============================================================================
# COPYRIGHT (C) 2025 KONSTANTIN VLADIMIROVICH GRABKO. ALL RIGHTS RESERVED.
# PATENT PENDING | CMS MANHATTAN JIRACK TECHNOLOGY
#
# This software is licensed under the Commercial License Agreement V.1.2.
# Any use, modification, or distribution of this code requires compliance with
# the terms found in the LICENSE.md file in the root directory.
#
# NO PATENTING RIGHTS: Users are strictly prohibited from filing patent claims
# based on the BRE or SWA architectures disclosed herein.
# Contact: grabko@cmsmanhattan.com | +1 (516) 777-0945
# ==============================================================================
# Optimized for Ultra-Heavy 140B Ternary Model on Multi-GPU ROCm
##
## Mixes The Pile with custom cultural data for fine-tuning, giving priority to client data.
##
import torch
import os
import sys
import random
import json
from torch.utils.data import DataLoader, IterableDataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset  # streams The Pile
import accelerate
# --- CONFIGURATION ---
MODEL_ID = "./models/ternary_140b_init"
GENERAL_DATA_LINK = "monology/pile-uncopyrighted"
CLIENT_DATA_FILE = "cultural_finetune.jsonl"
OUTPUT_DIR = "./models/checkpoints_140b"
MIX_RATIO = 0.45  # 45% weight on client data to push through the 140B model's inertia
LEARNING_RATE = 2e-6  # ultra-low learning rate for a giant model
SAVE_STEPS = 100
MAX_LENGTH = 512
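# Expected CLIENT_DATA_FILE layout, inferred from the "question"/"answer" fields
# the mixer reads below: one JSON object per line (illustrative placeholders):
#   {"question": "...", "answer": "..."}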
# --- BUILT-IN MIXER (optimized for 140B) ---
class CMSDataMixer(IterableDataset):
def __init__(self, tokenizer, client_file, pile_link, mix_ratio=0.45):
self.tokenizer = tokenizer
self.mix_ratio = mix_ratio
        # Stream The Pile (uses no local disk space)
print(f">>> [MIXER] Streaming general knowledge: {pile_link}")
self.pile_stream = load_dataset(pile_link, split="train", streaming=True)
        # Load the evolutionary index (client cultural data)
self.cultural_data = []
if os.path.exists(client_file):
with open(client_file, 'r', encoding='utf-8') as f:
for line in f:
self.cultural_data.append(json.loads(line))
print(f">>> [MIXER] Loaded {len(self.cultural_data)} client samples for 140B.")
else:
print(f"⚠️ ERROR: {client_file} not found!")
def __iter__(self):
pile_iterator = iter(self.pile_stream)
while True:
if random.random() < self.mix_ratio and self.cultural_data:
sample = random.choice(self.cultural_data)
text = f"Question: {sample['question']}\nAnswer: {sample['answer']}"
else:
try:
sample = next(pile_iterator)
text = sample['text']
except StopIteration:
pile_iterator = iter(self.pile_stream)
continue
            tokens = self.tokenizer(
                text, truncation=True, max_length=MAX_LENGTH,
                padding="max_length", return_tensors="pt"
            )
            input_ids = tokens["input_ids"].squeeze(0)
            attention_mask = tokens["attention_mask"].squeeze(0)
            # Mask padded positions so the loss ignores them
            labels = input_ids.clone()
            labels[attention_mask == 0] = -100
            yield {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": labels
            }
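# Minimal smoke test for the mixer (commented out; assumes MODEL_ID points at a
# usable local tokenizer and that cultural_finetune.jsonl is present):
#   tok = AutoTokenizer.from_pretrained(MODEL_ID)
#   tok.pad_token = tok.pad_token or tok.eos_token
#   mixer = CMSDataMixer(tok, CLIENT_DATA_FILE, GENERAL_DATA_LINK, mix_ratio=MIX_RATIO)
#   print(next(iter(mixer))["input_ids"].shape)  # expect torch.Size([MAX_LENGTH])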
# --- 140B TRAINING LOOP ---
def train_heavy_140b():
    # Accelerator setup: 8 gradient-accumulation steps = effective batch of 8 with VRAM for a batch of 1
accelerator = accelerate.Accelerator(gradient_accumulation_steps=8)
device = accelerator.device
    if accelerator.is_main_process:
        os.makedirs(OUTPUT_DIR, exist_ok=True)
    # 1. Tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
    # 2. Load the 140B model
    print(">>> [CMS] Loading 140B model layers across GPUs. High RAM usage expected...")
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
device_map="auto",
torch_dtype=torch.bfloat16,
trust_remote_code=True
)
    # CRITICAL FOR 140B: gradient checkpointing (saves roughly 70% of activation VRAM)
model.gradient_checkpointing_enable()
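    # transformers warns when use_cache stays enabled alongside gradient
    # checkpointing, so disable the KV cache for training (assumed safe here)
    model.config.use_cache = False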
    # 3. Initialize the mixer
dataset = CMSDataMixer(tokenizer, CLIENT_DATA_FILE, GENERAL_DATA_LINK, mix_ratio=MIX_RATIO)
loader = DataLoader(dataset, batch_size=1, pin_memory=True)
    # Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
    # Prepare everything with accelerate
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
print(f">>> [CMS] 140B Training Online. Ratio: 45% Client / 55% Pile.")
model.train()
for step, batch in enumerate(loader):
try:
with accelerator.accumulate(model):
outputs = model(**batch)
loss = outputs.loss
if torch.isnan(loss):
print(f"!!! CRITICAL: NaN loss at step {step}. Skipping...")
continue
accelerator.backward(loss)
                # Guard against exploding gradients (gradient clipping); clip
                # only on steps where accumulated gradients are synchronized
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
optimizer.zero_grad()
if step % 25 == 0 and accelerator.is_main_process:
print(f"💎 140B | Step {step} | Loss: {loss.item():.4f}")
                # Save a checkpoint
if step > 0 and step % SAVE_STEPS == 0 and accelerator.is_main_process:
save_path = os.path.join(OUTPUT_DIR, f"ternary_140b_step_{step}")
print(f">>> Exporting 140B State (Heavy): {save_path}")
accelerator.save_state(save_path)
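                    # To resume later, point accelerator.load_state() at this path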
torch.cuda.empty_cache()
        except RuntimeError as e:
            if "out of memory" in str(e):
                print("🚨 EMERGENCY: GPU OOM on 140B. Clearing cache and dropping the batch...")
                optimizer.zero_grad()  # discard partial gradients from the failed step
                torch.cuda.empty_cache()
                continue
            else:
                print(f"FATAL ERROR: {e}")
                sys.exit(1)
if __name__ == "__main__":
train_heavy_140b()
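# Typical launch via the accelerate CLI (after running `accelerate config`):
#   accelerate launch train_140b_heavy_mixed_val_data.py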