import json
import os
import random
import sys

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, IterableDataset
from transformers import AutoTokenizer
from datasets import load_dataset
from accelerate import Accelerator
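
# Local model definition. The training loop below relies on HF-style
# conventions from this class: forward(**batch) returns an object with a
# .loss attribute, and gradient_checkpointing_enable() is available.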
from JiRackTernaryPyTorch_236b import JiRackTernary236B, JiRackTernaryConfig

MODEL_ID = "./models/jirack_236b_init"
CULTURAL_DATA_FILE = "cultural_finetune.jsonl"
GENERAL_DATA_LINK = "monology/pile-uncopyrighted"
CHECKPOINT_DIR = "checkpoints_jirack_236b_mixed"

MIX_RATIO = 0.35
BATCH_SIZE = 1
GRAD_ACCUM_STEPS = 48
LEARNING_RATE = 3.5e-6
BLOCK_SIZE = 2048
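

# CMSDataMixer236B interleaves two sources on the fly: with probability
# mix_ratio it yields a formatted Q/A pair from the local client JSONL file,
# otherwise the next document from the streamed Pile corpus, so no mixed
# dataset has to be materialized on disk beforehand.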
class CMSDataMixer236B(IterableDataset):
    def __init__(self, tokenizer, client_file, pile_link, mix_ratio=0.35):
        self.tokenizer = tokenizer
        self.mix_ratio = mix_ratio

        print(f">>> [MIXER] Connecting to General Knowledge: {pile_link}")
        self.pile_stream = load_dataset(pile_link, split="train", streaming=True)

        self.cultural_data = []
        if os.path.exists(client_file):
            with open(client_file, 'r', encoding='utf-8') as f:
                for line in f:
                    self.cultural_data.append(json.loads(line))
            print(f">>> [MIXER] Loaded {len(self.cultural_data)} client samples.")
        else:
            print(f"⚠️ WARNING: {client_file} not found. Running on Pile only.")

    def __iter__(self):
        pile_iterator = iter(self.pile_stream)
        while True:
            if random.random() < self.mix_ratio and self.cultural_data:
                sample = random.choice(self.cultural_data)
                text = f"Question: {sample['question']}\nAnswer: {sample['answer']}"
            else:
                try:
                    sample = next(pile_iterator)
                    text = sample['text']
                except StopIteration:
                    # Restart the stream once a full pass over the Pile ends.
                    pile_iterator = iter(self.pile_stream)
                    continue

            tokens = self.tokenizer(
                text, truncation=True, max_length=BLOCK_SIZE, padding="max_length", return_tensors="pt"
            )
            input_ids = tokens["input_ids"].squeeze(0)
            attention_mask = tokens["attention_mask"].squeeze(0)
            labels = input_ids.clone()
            # Mask padded positions out of the loss; since pad is aliased to
            # EOS, trailing EOS tokens are masked as well.
            labels[attention_mask == 0] = -100
            yield {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": labels,
            }
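

# train_236b wires the mixer into an Accelerate-managed loop: gradient
# accumulation emulates a larger effective batch, gradient checkpointing
# trades recompute for activation memory, and periodic save_state snapshots
# make the run resumable.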
def train_236b():
    accelerator = Accelerator(gradient_accumulation_steps=GRAD_ACCUM_STEPS)

    if accelerator.is_main_process:
        os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    # Llama 3 ships without a pad token; alias it to EOS for fixed-length padding.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    config = JiRackTernaryConfig()
    model = JiRackTernary236B(config)

    model.gradient_checkpointing_enable()

    dataset = CMSDataMixer236B(tokenizer, CULTURAL_DATA_FILE, GENERAL_DATA_LINK, mix_ratio=MIX_RATIO)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)

    model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
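
    # With BATCH_SIZE=1, GRAD_ACCUM_STEPS=48, and BLOCK_SIZE=2048, one
    # optimizer update consumes 48 sequences (48 * 2048 = 98,304 tokens)
    # per process.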

    if accelerator.is_main_process:
        print("\n--- [CMS MANHATTAN] 236B MIXED ENGINE ONLINE ---")
        print(f"Model Depth: 192 Layers | Width: 10240 | Mix: {int(MIX_RATIO*100)}% Client")

    model.train()
    for step, batch in enumerate(loader):
        with accelerator.accumulate(model):
            outputs = model(**batch)
            loss = outputs.loss
            accelerator.backward(loss)

            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), 1.0)

            # Inside accumulate(), step/zero_grad are no-ops except on the
            # batch that closes the accumulation window.
            optimizer.step()
            optimizer.zero_grad()

        if step % 20 == 0 and accelerator.is_main_process:
            print(f"Step {step} | Loss: {loss.item():.4f} | VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB")

        # save_state is a collective operation (all ranks must enter it under
        # DeepSpeed/FSDP), so it is not gated on is_main_process; only rank 0
        # prints the confirmation.
        if step > 0 and step % 500 == 0:
            save_path = os.path.join(CHECKPOINT_DIR, f"step_{step}")
            accelerator.save_state(save_path)
            if accelerator.is_main_process:
                print(f">>> [CMS] 236B Checkpoint saved: {save_path}")
            torch.cuda.empty_cache()

if __name__ == "__main__":
    # Must be set before the first CUDA allocation to take effect.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    try:
        train_236b()
    except KeyboardInterrupt:
        print("\n[!] Stopped. Progress saved.")
    except Exception as e:
        print(f"FATAL ERROR: {e}")
        sys.exit(1)