import os
import sys
import time
import json
import torch
import glob

# Silence the HuggingFace tokenizers fork warning; must be set before
# DataLoader workers are spawned.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Allow TF32 matmuls for faster fp32 GEMMs on Ampere+ GPUs.
torch.set_float32_matmul_precision('high')
# Reduce CUDA allocator fragmentation for long training runs.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm
# Make the looped-model package importable (Colab-style hard-coded path).
sys.path.insert(0, '/content/Qwen3-0.6B-looped')
from modeling_qwen_loop import Qwen3LoopForCausalLM

# ---- Hyperparameters / paths -------------------------------------------
MODEL_PATH = "/content/Qwen3-0.6B"
OUTPUT_DIR = "/content/Qwen3-0.6B-looped/checkpoints"
BATCH_SIZE = 20
GRADIENT_ACCUMULATION_STEPS = 4  # effective batch = BATCH_SIZE * this
LEARNING_RATE = 1e-4
MAX_LENGTH = 1024  # tokenized context length
NUM_EPOCHS = 3
NUM_WORKERS = 8    # DataLoader worker processes
PIN_MEMORY = True  # page-locked host buffers for faster H2D copies
print("=" * 60)
print("TRAINING v3: Optimized (Compile + Workers + Checkpointing)")
print("=" * 60)
print("\n1. Loading model...")

def _epoch_number(path):
    """Return the numeric epoch suffix of a checkpoint dir, or -1 if unparseable."""
    try:
        return int(path.split("_")[-1])
    except ValueError:
        return -1

# Sort numerically: a plain lexicographic sort would order epoch_10 before
# epoch_2 and resume from the wrong checkpoint.
checkpoints = sorted(glob.glob(f"{OUTPUT_DIR}/epoch_*"), key=_epoch_number)
start_epoch = 0
model = Qwen3LoopForCausalLM.from_pretrained(MODEL_PATH)
if checkpoints:
    latest_checkpoint = checkpoints[-1]
    state_path = os.path.join(latest_checkpoint, "pytorch_model.bin")
    if os.path.exists(state_path):
        print(f" Resuming from checkpoint: {latest_checkpoint}")
        # Load on CPU first; the model is moved to the GPU below.
        # NOTE(review): optimizer/scheduler state is not checkpointed, so a
        # resume restarts them — acceptable here, but worth confirming.
        model.load_state_dict(torch.load(state_path, map_location="cpu"))
        # Bug fix: only advance start_epoch when weights were actually
        # loaded; the old code parsed the epoch even after printing
        # "Starting fresh", silently skipping epochs of training.
        start_epoch = max(0, _epoch_number(latest_checkpoint))
        print(f" Resuming at Epoch {start_epoch + 1}")
    else:
        print(" Warning: Checkpoint found but pytorch_model.bin missing. Starting fresh.")
device = torch.device("cuda")
model = model.to(device)
print("\n2. Unfreezing gates + layer norms...")
# Only the loop gates and layer norms receive gradients; the rest of the
# backbone stays frozen (see enable_gate_and_layernorm_training).
model.enable_gate_and_layernorm_training()
print(" Compiling model with torch.compile()...")
# Compilation is a best-effort speedup: if the backend rejects the graph we
# simply continue training in eager mode.
try:
    model = torch.compile(model)
except Exception as compile_error:
    print(f" Warning: torch.compile failed (ignoring): {compile_error}")
print("\n3. Loading WikiText-2...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# Qwen tokenizers ship without a pad token; reuse EOS so padding is defined.
tokenizer.pad_token = tokenizer.eos_token
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

def tokenize_fn(examples):
    """Tokenize a batch of raw text rows into fixed-length padded sequences."""
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=MAX_LENGTH,
        padding="max_length",
    )

tokenized = dataset.map(tokenize_fn, batched=True, remove_columns=["text"])
# Drop rows that are essentially empty: keep only examples with more than
# ten non-pad tokens.
tokenized = tokenized.filter(
    lambda row: sum(1 for tok in row["input_ids"] if tok != tokenizer.pad_token_id) > 10
)
print(f" Train samples: {len(tokenized['train'])}")
print(f" Val samples: {len(tokenized['validation'])}")
def collate_fn(batch):
    """Stack tokenized examples into batch tensors for causal-LM training.

    Labels are a copy of the input ids with padding positions
    (attention_mask == 0) set to -100 so the loss ignores them.
    """
    input_ids = torch.tensor([example["input_ids"] for example in batch])
    attention_mask = torch.tensor([example["attention_mask"] for example in batch])
    labels = torch.where(
        attention_mask.bool(), input_ids, torch.full_like(input_ids, -100)
    )
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
# Both loaders share everything except the split and shuffling.
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)
train_loader = DataLoader(tokenized["train"], shuffle=True, **_loader_kwargs)
val_loader = DataLoader(tokenized["validation"], shuffle=False, **_loader_kwargs)
# Only pass trainable parameters: the backbone is frozen, and excluding it
# keeps AdamW from allocating optimizer state for parameters that never
# receive gradients (same training behavior, less memory).
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),
    lr=LEARNING_RATE,
    weight_decay=0.01,
)
total_steps = len(train_loader) * NUM_EPOCHS // GRADIENT_ACCUMULATION_STEPS
# Bug fix: for tiny runs total_steps // 10 can be 0, which made get_lr
# divide by zero on the first step. Clamp both spans to at least 1.
warmup_steps = max(1, total_steps // 10)

def get_lr(step):
    """LR multiplier: linear warmup, then linear decay floored at 0.1."""
    if step < warmup_steps:
        return step / warmup_steps
    decay_span = max(1, total_steps - warmup_steps)
    return max(0.1, 1.0 - (step - warmup_steps) / decay_span)

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, get_lr)
# Echo the effective run configuration before training starts.
print("\n4. Training Configuration:")
effective_batch = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
print(f" Context length: {MAX_LENGTH}")
print(f" Batch size: {BATCH_SIZE} (Effective: {effective_batch})")
print(f" Workers: {NUM_WORKERS}")
print(f" Total steps: {total_steps}")
banner = "=" * 60
print("\n" + banner)
print("Starting Training...")
print(banner)
# NOTE(review): GradScaler is effectively redundant under bf16 autocast
# (loss scaling is only needed for fp16), but it is harmless and keeps the
# code ready for an fp16 switch, so it is retained.
scaler = torch.amp.GradScaler('cuda')
model.train()
global_step = 0
start_time = time.time()
os.makedirs(OUTPUT_DIR, exist_ok=True)
for epoch in range(start_epoch, NUM_EPOCHS):
    epoch_loss = 0.0
    epoch_steps = 0
    num_batches = len(train_loader)
    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
    for step, batch in enumerate(progress):
        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            outputs = model(**batch, use_cache=False)
            # Scale down so accumulated gradients average over the window.
            loss = outputs.loss / GRADIENT_ACCUMULATION_STEPS
        scaler.scale(loss).backward()
        epoch_loss += loss.item() * GRADIENT_ACCUMULATION_STEPS
        epoch_steps += 1
        # Step on every accumulation boundary AND on the final batch of the
        # epoch. Bug fix: when the batch count is not divisible by
        # GRADIENT_ACCUMULATION_STEPS the old code left the tail gradients
        # unstepped and un-zeroed, leaking them into the next epoch.
        if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0 or (step + 1) == num_batches:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()
            global_step += 1
            current_lr = scheduler.get_last_lr()[0]
            mem_usage = torch.cuda.memory_allocated() / 1024**3
            progress.set_postfix(loss=loss.item() * GRADIENT_ACCUMULATION_STEPS, lr=current_lr, mem=f"{mem_usage:.1f}GB")
    print(f"Saving checkpoint for Epoch {epoch+1}...")
    # torch.compile wraps the model; save the underlying module if present.
    model_to_save = model._orig_mod if hasattr(model, '_orig_mod') else model
    model_to_save.save_pretrained(f"{OUTPUT_DIR}/epoch_{epoch+1}")
    # Also save just the gate projections: once as a rolling "latest" file
    # and once as a per-epoch snapshot.
    gate_state_dict = {k: v for k, v in model_to_save.state_dict().items() if 'gate' in k}
    torch.save(gate_state_dict, f"{OUTPUT_DIR}/gate_projections.pt")
    torch.save(gate_state_dict, f"{OUTPUT_DIR}/gate_projections_epoch_{epoch+1}.pt")
print("Training complete.")
# Unwrap the torch.compile wrapper (if any) before serializing.
final_model = model._orig_mod if hasattr(model, '_orig_mod') else model
final_model.save_pretrained(f"{OUTPUT_DIR}/final")
# Persist only the gate projection weights as a lightweight artifact.
gate_weights = {
    name: tensor
    for name, tensor in final_model.state_dict().items()
    if 'gate' in name
}
torch.save(gate_weights, f"{OUTPUT_DIR}/gate_projections.pt")