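# Training script for Qwen3-0.6B-Looped: fine-tunes the loop gates and layer norms of
# Qwen3-0.6B on WikiText-2, with gradient accumulation, bf16 autocast, torch.compile,
# and per-epoch checkpointing. Paths below assume a Colab-style /content layout.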
import os
import sys
import time
import json
import torch
import glob
os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch.set_float32_matmul_precision('high')
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm
sys.path.insert(0, '/content/Qwen3-0.6B-looped')
from modeling_qwen_loop import Qwen3LoopForCausalLM
MODEL_PATH = "/content/Qwen3-0.6B"
OUTPUT_DIR = "/content/Qwen3-0.6B-looped/checkpoints"
BATCH_SIZE = 20
GRADIENT_ACCUMULATION_STEPS = 4
LEARNING_RATE = 1e-4
MAX_LENGTH = 1024
NUM_EPOCHS = 3
NUM_WORKERS = 8
PIN_MEMORY = True
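# Effective batch size per optimizer step: BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS = 20 * 4 = 80 sequences.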
print("=" * 60)
print("TRAINING v3: Optimized (Compile + Workers + Checkpointing)")
print("=" * 60)
print("\n1. Loading model...")
checkpoints = sorted(glob.glob(f"{OUTPUT_DIR}/epoch_*"), key=lambda p: int(p.split("_")[-1]))  # numeric sort so epoch_10 follows epoch_2
start_epoch = 0
if checkpoints:
    latest_checkpoint = checkpoints[-1]
    print(f" Resuming from checkpoint: {latest_checkpoint}")
    model = Qwen3LoopForCausalLM.from_pretrained(MODEL_PATH)
    state_path = os.path.join(latest_checkpoint, "pytorch_model.bin")
    if os.path.exists(state_path):
        model.load_state_dict(torch.load(state_path, map_location="cpu"))
        try:
            start_epoch = int(latest_checkpoint.split("_")[-1])
            print(f" Resuming at Epoch {start_epoch + 1}")
        except ValueError:
            start_epoch = 0
    else:
        # Note: recent transformers versions write model.safetensors by default, so this
        # fallback may trigger even when a checkpoint directory exists. Weights (and
        # start_epoch) then restart from MODEL_PATH.
        print(" Warning: Checkpoint found but pytorch_model.bin missing. Starting fresh.")
else:
    model = Qwen3LoopForCausalLM.from_pretrained(MODEL_PATH)
device = torch.device("cuda")
model = model.to(device)
print("\n2. Unfreezing gates + layer norms...")
model.enable_gate_and_layernorm_training()
print(" Compiling model with torch.compile()...")
try:
    model = torch.compile(model)
except Exception as e:
    print(f" Warning: torch.compile failed (ignoring): {e}")
print("\n3. Loading WikiText-2...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
tokenizer.pad_token = tokenizer.eos_token
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
def tokenize_fn(examples):
    return tokenizer(examples["text"], truncation=True, max_length=MAX_LENGTH, padding="max_length")
tokenized = dataset.map(tokenize_fn, batched=True, remove_columns=["text"])
tokenized = tokenized.filter(lambda x: sum(1 for t in x["input_ids"] if t != tokenizer.pad_token_id) > 10)
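# Drop WikiText's empty and near-empty lines: keep only examples with more than 10 non-pad tokens.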
print(f" Train samples: {len(tokenized['train'])}")
print(f" Val samples: {len(tokenized['validation'])}")
def collate_fn(batch):
    input_ids = torch.tensor([x["input_ids"] for x in batch])
    attention_mask = torch.tensor([x["attention_mask"] for x in batch])
    labels = input_ids.clone()
    labels[attention_mask == 0] = -100  # ignore padded positions in the loss
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
train_loader = DataLoader(
tokenized["train"],
batch_size=BATCH_SIZE,
shuffle=True,
collate_fn=collate_fn,
num_workers=NUM_WORKERS,
pin_memory=PIN_MEMORY
)
val_loader = DataLoader(
tokenized["validation"],
batch_size=BATCH_SIZE,
shuffle=False,
collate_fn=collate_fn,
num_workers=NUM_WORKERS,
pin_memory=PIN_MEMORY
)
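# Note: val_loader is built but not consumed by the training loop below; it is presumably
# intended for the separate evaluation pass.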
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
total_steps = len(train_loader) * NUM_EPOCHS // GRADIENT_ACCUMULATION_STEPS
warmup_steps = total_steps // 10
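# LR schedule: linear warmup over the first 10% of optimizer steps, then linear decay
# floored at 10% of LEARNING_RATE (LambdaLR multiplies the base LR by the factor below).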
def get_lr(step):
    if step < warmup_steps:
        return step / warmup_steps
    return max(0.1, 1.0 - (step - warmup_steps) / (total_steps - warmup_steps))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, get_lr)
print("\n4. Training Configuration:")
print(f" Context length: {MAX_LENGTH}")
print(f" Batch size: {BATCH_SIZE} (Effective: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS})")
print(f" Workers: {NUM_WORKERS}")
print(f" Total steps: {total_steps}")
print("\n" + "=" * 60)
print("Starting Training...")
print("=" * 60)
scaler = torch.amp.GradScaler('cuda')
model.train()
global_step = 0
start_time = time.time()
os.makedirs(OUTPUT_DIR, exist_ok=True)
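# Training loop: gradients are accumulated over GRADIENT_ACCUMULATION_STEPS micro-batches,
# clipped to norm 1.0, and the optimizer/scheduler step once per effective batch.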
for epoch in range(start_epoch, NUM_EPOCHS):
    epoch_loss = 0
    epoch_steps = 0
    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
    for step, batch in enumerate(progress):
        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            outputs = model(**batch, use_cache=False)
            loss = outputs.loss / GRADIENT_ACCUMULATION_STEPS
        scaler.scale(loss).backward()
        epoch_loss += loss.item() * GRADIENT_ACCUMULATION_STEPS
        epoch_steps += 1
        if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()
            global_step += 1
            current_lr = scheduler.get_last_lr()[0]
            mem_usage = torch.cuda.memory_allocated() / 1024**3
            progress.set_postfix(loss=loss.item() * GRADIENT_ACCUMULATION_STEPS, lr=current_lr, mem=f"{mem_usage:.1f}GB")
    print(f"Epoch {epoch+1} mean loss: {epoch_loss / max(epoch_steps, 1):.4f}")
    print(f"Saving checkpoint for Epoch {epoch+1}...")
    model_to_save = model._orig_mod if hasattr(model, '_orig_mod') else model  # unwrap the torch.compile wrapper before saving
    model_to_save.save_pretrained(f"{OUTPUT_DIR}/epoch_{epoch+1}")
    gate_state_dict = {k: v for k, v in model_to_save.state_dict().items() if 'gate' in k}  # gate parameters only
    torch.save(gate_state_dict, f"{OUTPUT_DIR}/gate_projections.pt")
    torch.save(gate_state_dict, f"{OUTPUT_DIR}/gate_projections_epoch_{epoch+1}.pt")
print("Training complete.")
model_to_save = model._orig_mod if hasattr(model, '_orig_mod') else model
model_to_save.save_pretrained(f"{OUTPUT_DIR}/final")
gate_state_dict = {k: v for k, v in model_to_save.state_dict().items() if 'gate' in k}
torch.save(gate_state_dict, f"{OUTPUT_DIR}/gate_projections.pt")
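# Example (a sketch, not executed as part of this run): the gate-only state dict saved above
# could be re-applied to a freshly loaded model with strict=False, since the file contains
# only the 'gate' parameters. Variable names here are illustrative:
#
#   model = Qwen3LoopForCausalLM.from_pretrained(MODEL_PATH)
#   gate_sd = torch.load(f"{OUTPUT_DIR}/gate_projections.pt", map_location="cpu")
#   missing, unexpected = model.load_state_dict(gate_sd, strict=False)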