|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import torch
|
| import os
|
| import sys
|
| import random
|
| import json
|
| from torch.utils.data import DataLoader, IterableDataset
|
| from transformers import AutoModelForCausalLM, AutoTokenizer
|
| from datasets import load_dataset
|
| import accelerate
|
|
|
|
|
# Path to the locally initialized ternary 140B checkpoint.
MODEL_ID = "./models/ternary_140b_init"

# Hugging Face dataset id streamed for general-knowledge text.
GENERAL_DATA_LINK = "monology/pile-uncopyrighted"

# JSONL file of client Q/A samples (one object per line with
# 'question'/'answer' keys, per the mixer's usage below).
CLIENT_DATA_FILE = "cultural_finetune.jsonl"

# Directory where training checkpoints are exported.
OUTPUT_DIR = "./models/checkpoints_140b"

# Probability that a given training example is drawn from the client
# data; the remainder comes from the streamed pile.
MIX_RATIO = 0.45

# Conservative learning rate for large-model fine-tuning.
LEARNING_RATE = 2e-6

# Export a full accelerator state every N steps.
SAVE_STEPS = 100

# Tokenizer truncation/padding length.
# NOTE(review): the mixer's tokenizer call hard-codes 512 instead of
# reading this constant — confirm the two are meant to stay in sync.
MAX_LENGTH = 512
|
|
|
|
|
class CMSDataMixer(IterableDataset):
    """Infinite stream mixing client Q/A samples with general-knowledge text.

    Each drawn example is a client sample with probability ``mix_ratio``
    (when client data is available), otherwise the next record from the
    streamed pile dataset.  Yields dicts of 1-D tensors ready for causal-LM
    training: ``input_ids``, ``attention_mask`` and ``labels``.  Pad
    positions in ``labels`` are set to -100 so they are ignored by the
    cross-entropy loss (the original code trained on pad tokens).
    """

    def __init__(self, tokenizer, client_file, pile_link, mix_ratio=0.45,
                 max_length=512):
        """
        Args:
            tokenizer: HF-style tokenizer; called with truncation/padding and
                ``return_tensors="pt"``, expected to return ``input_ids`` and
                ``attention_mask``.
            client_file: path to a JSONL file of client samples; each line is
                a JSON object with 'question' and 'answer' keys.
            pile_link: HF dataset id streamed for general-knowledge text.
            mix_ratio: probability of drawing a client sample per example.
            max_length: tokenizer truncation/padding length (default matches
                the previous hard-coded 512).
        """
        self.tokenizer = tokenizer
        self.mix_ratio = mix_ratio
        self.max_length = max_length

        print(f">>> [MIXER] Streaming general knowledge: {pile_link}")
        self.pile_stream = load_dataset(pile_link, split="train", streaming=True)

        self.cultural_data = []
        if os.path.exists(client_file):
            with open(client_file, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    # Skip blank lines — json.loads("") would raise.
                    if line:
                        self.cultural_data.append(json.loads(line))
            print(f">>> [MIXER] Loaded {len(self.cultural_data)} client samples for 140B.")
        else:
            print(f"⚠️ ERROR: {client_file} not found!")

    def __iter__(self):
        """Yield training examples forever, restarting the pile stream on exhaustion."""
        pile_iterator = iter(self.pile_stream)
        while True:
            if random.random() < self.mix_ratio and self.cultural_data:
                sample = random.choice(self.cultural_data)
                # Assumes each JSONL record has 'question'/'answer' keys.
                text = f"Question: {sample['question']}\nAnswer: {sample['answer']}"
            else:
                try:
                    sample = next(pile_iterator)
                    text = sample['text']
                except StopIteration:
                    # Stream exhausted: restart it. Probe once so that an
                    # empty stream with no client data terminates instead of
                    # busy-looping forever (the original spun indefinitely).
                    pile_iterator = iter(self.pile_stream)
                    try:
                        sample = next(pile_iterator)
                    except StopIteration:
                        if self.cultural_data:
                            continue  # fall back to client-only sampling
                        return  # both sources empty — stop cleanly
                    text = sample['text']

            tokens = self.tokenizer(
                text, truncation=True, max_length=self.max_length,
                padding="max_length", return_tensors="pt",
            )
            input_ids = tokens["input_ids"].squeeze(0)
            attention_mask = tokens["attention_mask"].squeeze(0)
            # Mask pad positions so they don't contribute to the LM loss.
            labels = input_ids.clone()
            labels[attention_mask == 0] = -100
            yield {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": labels,
            }
|
|
|
|
|
def train_heavy_140b():
    """Fine-tune the ternary 140B model on a client/pile data mix.

    Loads the model in bf16 with gradient checkpointing, streams mixed
    training data via CMSDataMixer, and trains with HF Accelerate using
    8-step gradient accumulation and gradient clipping at 1.0.  A full
    accelerator state is exported every SAVE_STEPS steps.  Runs until the
    data stream ends (effectively forever) or a fatal error occurs.
    """
    accelerator = accelerate.Accelerator(gradient_accumulation_steps=8)

    # Only the main process creates the output dir; exist_ok makes the
    # previous os.path.exists() pre-check redundant.
    if accelerator.is_main_process:
        os.makedirs(OUTPUT_DIR, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    if tokenizer.pad_token is None:
        # Causal-LM tokenizers often ship without a pad token; reuse EOS so
        # padding in the mixer works.
        tokenizer.pad_token = tokenizer.eos_token

    print(f">>> [CMS] Loading 140B Model layers across GPUs. High RAM usage expected...")
    # NOTE(review): device_map="auto" shards the model across devices at load
    # time; combining this with accelerator.prepare() below is only valid for
    # certain launch configurations — confirm against the deployment setup.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True
    )

    # Trade compute for memory: recompute activations during backward.
    model.gradient_checkpointing_enable()

    dataset = CMSDataMixer(tokenizer, CLIENT_DATA_FILE, GENERAL_DATA_LINK, mix_ratio=MIX_RATIO)
    loader = DataLoader(dataset, batch_size=1, pin_memory=True)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)

    model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

    # Derive the banner from MIX_RATIO instead of the previous hard-coded
    # "45% / 55%" string, which silently went stale if MIX_RATIO changed.
    print(f">>> [CMS] 140B Training Online. Ratio: {MIX_RATIO:.0%} Client / {1 - MIX_RATIO:.0%} Pile.")

    model.train()
    for step, batch in enumerate(loader):
        try:
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss

                # Skip pathological batches rather than poisoning the weights.
                if torch.isnan(loss):
                    print(f"!!! CRITICAL: NaN loss at step {step}. Skipping...")
                    continue

                accelerator.backward(loss)
                # Clip to stabilize training at this scale.
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()

            if step % 25 == 0 and accelerator.is_main_process:
                print(f"💎 140B | Step {step} | Loss: {loss.item():.4f}")

            if step > 0 and step % SAVE_STEPS == 0 and accelerator.is_main_process:
                save_path = os.path.join(OUTPUT_DIR, f"ternary_140b_step_{step}")
                print(f">>> Exporting 140B State (Heavy): {save_path}")
                accelerator.save_state(save_path)
                torch.cuda.empty_cache()

        except RuntimeError as e:
            if "out of memory" in str(e):
                print("🚨 EMERGENCY: GPU OOM on 140B. Clearing cache...")
                # Drop any partially accumulated gradients before retrying;
                # the original carried corrupted grads into the next step.
                optimizer.zero_grad(set_to_none=True)
                torch.cuda.empty_cache()
                continue
            else:
                print(f"FATAL ERROR: {e}")
                sys.exit(1)
|
# Script entry point: launch the full fine-tuning run when executed directly.
if __name__ == "__main__":
    train_heavy_140b()