import os

import torch
import torch.nn as nn
# torch.cuda.amp is deprecated in recent PyTorch releases in favor of
# torch.amp; the old import path is kept here for wider version compatibility.
from torch.cuda.amp import autocast, GradScaler

from datasets import load_dataset
from transformers import AutoTokenizer

# Project-local module defining the 7B model architecture.
from JiRackPyTorch_GPT5_class_7b import JiRackPyTorch

# Training configuration
CHECKPOINT_DIR = "checkpoints_jirack_7b_agile"
SAVE_INTERVAL = 1000      # checkpoint every N streaming steps
GRAD_ACCUM_STEPS = 8      # micro-batches accumulated per optimizer update
BLOCK_SIZE = 4096         # context length in tokens
LEARNING_RATE = 3.0e-4
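# Each streamed example forms one micro-batch of a single sequence padded to
# BLOCK_SIZE, so every optimizer update covers GRAD_ACCUM_STEPS * BLOCK_SIZE
# = 8 * 4096 = 32,768 token positions (padding included).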


def train():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # GradScaler targets float16 underflow; with the bfloat16 autocast used
    # below, loss scaling is not strictly needed but is harmless to keep.
    scaler = GradScaler()

    # LLaMA tokenizers ship without a pad token, so reuse EOS for padding.
    tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b", use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Stream The Pile (uncopyrighted subset) instead of downloading it in full.
    print("Connecting to monology/pile-uncopyrighted (Streaming)...")
    dataset = load_dataset("monology/pile-uncopyrighted", split="train", streaming=True)
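
    # Streaming returns an IterableDataset: it has no len() and yields examples
    # in storage order; randomizing the order would need something like
    # dataset.shuffle(buffer_size=...), at the cost of filling a buffer first.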

    # Size the embedding to the tokenizer and trade compute for memory with
    # activation (gradient) checkpointing.
    model = JiRackPyTorch(vocab_size=len(tokenizer))
    model.gradient_checkpointing_enable()

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.to(device)
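
    # nn.DataParallel is the simplest multi-GPU path, but it re-replicates the
    # model on every forward pass; DistributedDataParallel is the option
    # PyTorch recommends for sustained multi-GPU training.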

    # Note: weight decay here applies to every parameter, including biases and
    # norm weights, which many recipes exclude via separate parameter groups.
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.1)

    model.train()
    print("--- [AGILE TITAN] Training Started: JiRack 7B ---")

    for current_step, example in enumerate(dataset):
        # One document per micro-batch, truncated/padded to a fixed window.
        tokens = tokenizer(
            example["text"],
            truncation=True,
            max_length=BLOCK_SIZE,
            padding="max_length",
            return_tensors="pt",
        )
        input_ids = tokens["input_ids"].to(device)
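
        # Note: the tokenizer's attention mask is not forwarded, so padded
        # (EOS) positions are attended and counted in the loss unless the
        # model masks them internally.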

        # Causal LM objective: targets=input_ids assumes the model shifts
        # labels internally (GPT-style). Under DataParallel the loss comes
        # back as a per-GPU vector, hence .mean(); dividing by
        # GRAD_ACCUM_STEPS keeps the accumulated gradient at batch scale.
        with autocast(dtype=torch.bfloat16):
            logits, loss, _ = model(input_ids, targets=input_ids)
            loss = loss.mean() / GRAD_ACCUM_STEPS

        scaler.scale(loss).backward()

        # Step only every GRAD_ACCUM_STEPS micro-batches; gradients accumulate
        # in between because zero_grad() is deferred until after the step.
        if (current_step + 1) % GRAD_ACCUM_STEPS == 0:
            # Unscale first so clipping measures true gradient norms.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        if current_step % 20 == 0:
            # Multiply the logged loss back up to undo the accumulation
            # normalization; memory_reserved() reports the allocator's pool.
            vram = torch.cuda.memory_reserved() / 1e9
            print(f"JiRack 7B | Step {current_step} | Loss: {loss.item() * GRAD_ACCUM_STEPS:.4f} | VRAM: {vram:.1f}GB", end="\r")

        if current_step % SAVE_INTERVAL == 0 and current_step > 0:
            os.makedirs(CHECKPOINT_DIR, exist_ok=True)
            # Unwrap DataParallel so checkpoint keys carry no "module." prefix.
            state = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
            torch.save(state, f"{CHECKPOINT_DIR}/step_{current_step}.pt")
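
        # To resume, load the saved weights back into a fresh model (sketch):
        #   state = torch.load(f"{CHECKPOINT_DIR}/step_1000.pt", map_location="cpu")
        #   model.load_state_dict(state)
        # Optimizer state is not saved here, so AdamW moments restart at zero.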


if __name__ == "__main__":
    # Allocator tuning is read when the CUDA caching allocator first runs, so
    # it must be set before any CUDA allocation to take effect.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:128"
    train()