import os
import sys
import time
import json
import glob

import torch

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
torch.set_float32_matmul_precision('high')

from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm

sys.path.insert(0, '/content/Qwen3-0.6B-looped')
from modeling_qwen_loop import Qwen3LoopForCausalLM

# Hyperparameters and paths
MODEL_PATH = "/content/Qwen3-0.6B"
OUTPUT_DIR = "/content/Qwen3-0.6B-looped/checkpoints"
BATCH_SIZE = 20
GRADIENT_ACCUMULATION_STEPS = 4
LEARNING_RATE = 1e-4
MAX_LENGTH = 1024
NUM_EPOCHS = 3
NUM_WORKERS = 8
PIN_MEMORY = True

print("=" * 60)
print("TRAINING v3: Optimized (Compile + Workers + Checkpointing)")
print("=" * 60)

# ---- 1. Load model, resuming from the latest epoch checkpoint if one exists ----
print("\n1. Loading model...")
checkpoints = sorted(glob.glob(f"{OUTPUT_DIR}/epoch_*"))
start_epoch = 0
if checkpoints:
    latest_checkpoint = checkpoints[-1]
    print(f" Resuming from checkpoint: {latest_checkpoint}")
    model = Qwen3LoopForCausalLM.from_pretrained(MODEL_PATH)
    # Note: newer transformers versions save model.safetensors rather than pytorch_model.bin
    state_path = os.path.join(latest_checkpoint, "pytorch_model.bin")
    if os.path.exists(state_path):
        model.load_state_dict(torch.load(state_path))
    else:
        print(" Warning: Checkpoint found but pytorch_model.bin missing. Starting fresh.")
    try:
        start_epoch = int(latest_checkpoint.split("_")[-1])
        print(f" Resuming at Epoch {start_epoch + 1}")
    except ValueError:
        start_epoch = 0
else:
    model = Qwen3LoopForCausalLM.from_pretrained(MODEL_PATH)

device = torch.device("cuda")
model = model.to(device)

# ---- 2. Train only the loop gates and layer norms; everything else stays frozen ----
print("\n2. Unfreezing gates + layer norms...")
model.enable_gate_and_layernorm_training()

print(" Compiling model with torch.compile()...")
try:
    model = torch.compile(model)
except Exception as e:
    print(f" Warning: torch.compile failed (ignoring): {e}")

# ---- 3. Data: WikiText-2 tokenized to fixed-length sequences ----
print("\n3. Loading WikiText-2...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
tokenizer.pad_token = tokenizer.eos_token
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

def tokenize_fn(examples):
    return tokenizer(examples["text"], truncation=True, max_length=MAX_LENGTH, padding="max_length")

tokenized = dataset.map(tokenize_fn, batched=True, remove_columns=["text"])
# Drop near-empty examples (10 or fewer non-pad tokens)
tokenized = tokenized.filter(lambda x: sum(1 for t in x["input_ids"] if t != tokenizer.pad_token_id) > 10)

print(f" Train samples: {len(tokenized['train'])}")
print(f" Val samples: {len(tokenized['validation'])}")

def collate_fn(batch):
    # Stack examples into tensors; mask padding positions out of the loss with -100
    input_ids = torch.tensor([x["input_ids"] for x in batch])
    attention_mask = torch.tensor([x["attention_mask"] for x in batch])
    labels = input_ids.clone()
    labels[attention_mask == 0] = -100
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

train_loader = DataLoader(
    tokenized["train"], batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn,
    num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY
)
val_loader = DataLoader(
    tokenized["validation"], batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn,
    num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY
)

# Optimizer and linear warmup/decay schedule (10% warmup, decaying to 10% of peak LR)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
total_steps = len(train_loader) * NUM_EPOCHS // GRADIENT_ACCUMULATION_STEPS
warmup_steps = total_steps // 10

def get_lr(step):
    if step < warmup_steps:
        return step / max(warmup_steps, 1)
    return max(0.1, 1.0 - (step - warmup_steps) / (total_steps - warmup_steps))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, get_lr)
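# Optional sanity check before training: report how many parameters will actually be
# updated. This assumes enable_gate_and_layernorm_training() sets requires_grad only on
# the gate projections and layer norms; adjust if your modeling code behaves differently.
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f" Trainable params: {trainable_params:,} / {total_params:,} "
      f"({100.0 * trainable_params / total_params:.2f}%)")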
Training Configuration:") print(f" Context length: {MAX_LENGTH}") print(f" Batch size: {BATCH_SIZE} (Effective: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS})") print(f" Workers: {NUM_WORKERS}") print(f" Total steps: {total_steps}") print("\n" + "=" * 60) print("Starting Training...") print("=" * 60) scaler = torch.amp.GradScaler('cuda') model.train() global_step = 0 start_time = time.time() os.makedirs(OUTPUT_DIR, exist_ok=True) for epoch in range(start_epoch, NUM_EPOCHS): epoch_loss = 0 epoch_steps = 0 progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}") for step, batch in enumerate(progress): batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()} with torch.amp.autocast('cuda', dtype=torch.bfloat16): outputs = model(**batch, use_cache=False) loss = outputs.loss / GRADIENT_ACCUMULATION_STEPS scaler.scale(loss).backward() epoch_loss += loss.item() * GRADIENT_ACCUMULATION_STEPS epoch_steps += 1 if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0: scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) scaler.step(optimizer) scaler.update() scheduler.step() optimizer.zero_grad() global_step += 1 current_lr = scheduler.get_last_lr()[0] mem_usage = torch.cuda.memory_allocated() / 1024**3 progress.set_postfix(loss=loss.item() * GRADIENT_ACCUMULATION_STEPS, lr=current_lr, mem=f"{mem_usage:.1f}GB") print(f"Saving checkpoint for Epoch {epoch+1}...") model_to_save = model._orig_mod if hasattr(model, '_orig_mod') else model model_to_save.save_pretrained(f"{OUTPUT_DIR}/epoch_{epoch+1}") gate_state_dict = {k: v for k, v in model_to_save.state_dict().items() if 'gate' in k} torch.save(gate_state_dict, f"{OUTPUT_DIR}/gate_projections.pt") torch.save(gate_state_dict, f"{OUTPUT_DIR}/gate_projections_epoch_{epoch+1}.pt") print("Training complete.") model_to_save = model._orig_mod if hasattr(model, '_orig_mod') else model model_to_save.save_pretrained(f"{OUTPUT_DIR}/final") gate_state_dict = {k: v for k, v in model_to_save.state_dict().items() if 'gate' in k} torch.save(gate_state_dict, f"{OUTPUT_DIR}/gate_projections.pt")