"""Evaluation 1 - Baseline vs Qwen3Loop.""" """AI DISCLOSURE, Made this real quick with Anthrophics, Claude Code.""" import os import sys import torch import argparse from transformers import AutoModelForCausalLM, AutoTokenizer from datasets import load_dataset from torch.utils.data import DataLoader from tqdm import tqdm # Add local path for modeling_qwen_loop sys.path.insert(0, '/content/Qwen3-0.6B-looped') try: from modeling_qwen_loop import Qwen3LoopForCausalLM except ImportError: print("Error: Could not import Qwen3LoopForCausalLM. Make sure you are in the correct directory.") sys.exit(1) MODEL_PATH = "/content/Qwen3-0.6B" BATCH_SIZE = 8 MAX_LENGTH = 1024 device = "cuda" if torch.cuda.is_available() else "cpu" def evaluate_model(model, loader, name): print(f"\nEvaluating {name}...") model.eval() total_loss = 0 steps = 0 with torch.no_grad(): for batch in tqdm(loader, desc=f"Eval {name}"): with torch.amp.autocast('cuda', dtype=torch.bfloat16): outputs = model(**batch, use_cache=False) total_loss += outputs.loss.item() steps += 1 avg_loss = total_loss / steps ppl = torch.exp(torch.tensor(avg_loss)).item() return avg_loss, ppl def main(): parser = argparse.ArgumentParser(description="Evaluate Qwen3 Loop models.") parser.add_argument("checkpoint", nargs="?", help="Path to checkpoint (.bin or .pt)") args = parser.parse_args() checkpoint_path = args.checkpoint if not checkpoint_path: # Fallback to a default if not provided, just for safety print("Please provide a checkpoint path as an argument.") return print("=" * 60) print(f"EVALUATION: {checkpoint_path}") print("=" * 60) # 1. Prepare Data print("\n1. Preparing Data...") tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) tokenizer.pad_token = tokenizer.eos_token dataset = load_dataset("wikitext", "wikitext-2-raw-v1") def tokenize_fn(examples): return tokenizer(examples["text"], truncation=True, max_length=MAX_LENGTH, padding="max_length") tokenized = dataset.map(tokenize_fn, batched=True, remove_columns=["text"]) tokenized = tokenized.filter(lambda x: sum(1 for t in x["input_ids"] if t != tokenizer.pad_token_id) > 10) val_data = tokenized["validation"] print(f" Validation samples: {len(val_data)}") def collate_fn(batch): input_ids = torch.tensor([x["input_ids"] for x in batch]) attention_mask = torch.tensor([x["attention_mask"] for x in batch]) labels = input_ids.clone() labels[attention_mask == 0] = -100 return {"input_ids": input_ids.to(device), "attention_mask": attention_mask.to(device), "labels": labels.to(device)} loader = DataLoader(val_data, batch_size=BATCH_SIZE, collate_fn=collate_fn) # 2. Baseline print("\n2. Evaluating Baseline...") baseline_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16).to(device) baseline_loss, baseline_ppl = evaluate_model(baseline_model, loader, "Baseline") del baseline_model torch.cuda.empty_cache() # 3. Loop Model print(f"\n3. Evaluating Loop Model from {checkpoint_path}...") loop_model = Qwen3LoopForCausalLM.from_pretrained(MODEL_PATH) # Load weights print(f" Loading state dict...") state_dict = torch.load(checkpoint_path, map_location=device) # Check if this is a gate-only checkpoint or full model keys = list(state_dict.keys()) # Heuristic: if we only have gate keys, it's gate-only. # 'gate' string appears in gate parameters. # Full model has 'model.layers.0.self_attn.q_proj.weight' etc. is_gate_only = all('gate' in k for k in keys) and len(keys) < 200 if is_gate_only: print(" Detected Gate-Only checkpoint (.pt). Loading with strict=False.") loop_model.load_state_dict(state_dict, strict=False) print(f" Loaded {len(keys)} gate parameters.") else: print(" Detected Full Model checkpoint (.bin). Loading full state.") loop_model.load_state_dict(state_dict, strict=True) loop_model = loop_model.to(device).to(torch.bfloat16) loop_loss, loop_ppl = evaluate_model(loop_model, loader, "Loop Model") # Results print("\n" + "=" * 60) print("FINAL RESULTS") print("=" * 60) print(f"Baseline Loss: {baseline_loss:.4f} | PPL: {baseline_ppl:.2f}") print(f"Loop Model Loss: {loop_loss:.4f} | PPL: {loop_ppl:.2f}") print("-" * 60) if loop_loss < baseline_loss: print(f"SUCCESS: Loop Attention improved loss by {baseline_loss - loop_loss:.4f}") else: print(f"Baseline still better by {loop_loss - baseline_loss:.4f}") print("=" * 60) if __name__ == "__main__": main()