# Qwen3-0.6B-Looped / baseline_eval.py
# Commit: "More functionality!" by coolpoodle (d5f79d9, verified)
"""Evaluation 1 - Baseline vs Qwen3Loop."""
"""AI DISCLOSURE, Made this real quick with Anthrophics, Claude Code."""
import os
import sys
import torch
import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
# Add local path for modeling_qwen_loop
sys.path.insert(0, '/content/Qwen3-0.6B-looped')
try:
from modeling_qwen_loop import Qwen3LoopForCausalLM
except ImportError:
print("Error: Could not import Qwen3LoopForCausalLM. Make sure you are in the correct directory.")
sys.exit(1)
MODEL_PATH = "/content/Qwen3-0.6B"
BATCH_SIZE = 8
MAX_LENGTH = 1024
device = "cuda" if torch.cuda.is_available() else "cpu"
def evaluate_model(model, loader, name):
    """Compute average loss and perplexity of a causal LM over a dataloader.

    Args:
        model: model whose forward pass returns an object with a ``.loss`` tensor
            (e.g. a Hugging Face ``*ForCausalLM`` called with ``labels``).
        loader: iterable of batch dicts, already placed on the target device.
        name: label used in progress-bar and log output.

    Returns:
        Tuple ``(avg_loss, perplexity)`` as Python floats, where perplexity is
        ``exp(mean loss)``.

    Raises:
        ValueError: if ``loader`` yields no batches (the original code crashed
            here with an opaque ZeroDivisionError).
    """
    print(f"\nEvaluating {name}...")
    model.eval()
    total_loss = 0.0
    steps = 0
    # Only enable autocast when CUDA is actually available; the original
    # hard-coded 'cuda', which on a CPU-only machine just emits a warning and
    # disables autocast — `enabled=` makes that intent explicit and silent.
    use_amp = torch.cuda.is_available()
    with torch.no_grad():
        for batch in tqdm(loader, desc=f"Eval {name}"):
            with torch.amp.autocast('cuda', dtype=torch.bfloat16, enabled=use_amp):
                outputs = model(**batch, use_cache=False)
            total_loss += outputs.loss.item()
            steps += 1
    if steps == 0:
        raise ValueError(f"evaluate_model({name!r}): loader produced no batches")
    avg_loss = total_loss / steps
    # Perplexity is exp of the mean negative log-likelihood.
    ppl = torch.exp(torch.tensor(avg_loss)).item()
    return avg_loss, ppl
def main():
    """CLI entry point: compare the baseline Qwen3 model against the looped
    variant on the WikiText-2 validation split and print loss/perplexity.

    Expects one positional argument: a checkpoint path (full-model ``.bin``
    or gate-only ``.pt``) for the loop model.
    """
    parser = argparse.ArgumentParser(description="Evaluate Qwen3 Loop models.")
    parser.add_argument("checkpoint", nargs="?", help="Path to checkpoint (.bin or .pt)")
    args = parser.parse_args()
    checkpoint_path = args.checkpoint
    if not checkpoint_path:
        # Fallback to a default if not provided, just for safety
        print("Please provide a checkpoint path as an argument.")
        return
    print("=" * 60)
    print(f"EVALUATION: {checkpoint_path}")
    print("=" * 60)
    # 1. Prepare Data
    print("\n1. Preparing Data...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    # Qwen3's tokenizer ships without a pad token; reuse EOS so fixed-length
    # padding below works.
    tokenizer.pad_token = tokenizer.eos_token
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
    def tokenize_fn(examples):
        # Fixed-length padding keeps every batch the same shape (MAX_LENGTH).
        return tokenizer(examples["text"], truncation=True, max_length=MAX_LENGTH, padding="max_length")
    tokenized = dataset.map(tokenize_fn, batched=True, remove_columns=["text"])
    # Drop near-empty rows (<= 10 non-pad tokens) so the loss isn't dominated
    # by padding-only samples (WikiText has many blank/heading lines).
    tokenized = tokenized.filter(lambda x: sum(1 for t in x["input_ids"] if t != tokenizer.pad_token_id) > 10)
    val_data = tokenized["validation"]
    print(f" Validation samples: {len(val_data)}")
    def collate_fn(batch):
        # Stack per-sample lists into tensors and move them to the device.
        input_ids = torch.tensor([x["input_ids"] for x in batch])
        attention_mask = torch.tensor([x["attention_mask"] for x in batch])
        labels = input_ids.clone()
        # -100 is the ignore_index for cross-entropy: mask out padding positions.
        labels[attention_mask == 0] = -100
        return {"input_ids": input_ids.to(device), "attention_mask": attention_mask.to(device), "labels": labels.to(device)}
    loader = DataLoader(val_data, batch_size=BATCH_SIZE, collate_fn=collate_fn)
    # 2. Baseline
    print("\n2. Evaluating Baseline...")
    baseline_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16).to(device)
    baseline_loss, baseline_ppl = evaluate_model(baseline_model, loader, "Baseline")
    # Free GPU memory before loading the second (loop) model.
    del baseline_model
    torch.cuda.empty_cache()
    # 3. Loop Model
    print(f"\n3. Evaluating Loop Model from {checkpoint_path}...")
    loop_model = Qwen3LoopForCausalLM.from_pretrained(MODEL_PATH)
    # Load weights
    print(f" Loading state dict...")
    # NOTE(review): torch.load on a pickled checkpoint can execute arbitrary
    # code; only load trusted checkpoints (consider weights_only=True).
    state_dict = torch.load(checkpoint_path, map_location=device)
    # Check if this is a gate-only checkpoint or full model
    keys = list(state_dict.keys())
    # Heuristic: if we only have gate keys, it's gate-only.
    # 'gate' string appears in gate parameters.
    # Full model has 'model.layers.0.self_attn.q_proj.weight' etc.
    # NOTE(review): all() is vacuously True for an empty state dict — an empty
    # checkpoint would be misclassified as gate-only; verify upstream.
    is_gate_only = all('gate' in k for k in keys) and len(keys) < 200
    if is_gate_only:
        print(" Detected Gate-Only checkpoint (.pt). Loading with strict=False.")
        # strict=False: the checkpoint holds only the gate params; everything
        # else keeps the pretrained weights loaded above.
        loop_model.load_state_dict(state_dict, strict=False)
        print(f" Loaded {len(keys)} gate parameters.")
    else:
        print(" Detected Full Model checkpoint (.bin). Loading full state.")
        loop_model.load_state_dict(state_dict, strict=True)
    # Cast after loading so checkpoint dtypes don't matter.
    loop_model = loop_model.to(device).to(torch.bfloat16)
    loop_loss, loop_ppl = evaluate_model(loop_model, loader, "Loop Model")
    # Results
    print("\n" + "=" * 60)
    print("FINAL RESULTS")
    print("=" * 60)
    print(f"Baseline Loss: {baseline_loss:.4f} | PPL: {baseline_ppl:.2f}")
    print(f"Loop Model Loss: {loop_loss:.4f} | PPL: {loop_ppl:.2f}")
    print("-" * 60)
    if loop_loss < baseline_loss:
        print(f"SUCCESS: Loop Attention improved loss by {baseline_loss - loop_loss:.4f}")
    else:
        print(f"Baseline still better by {loop_loss - baseline_loss:.4f}")
    print("=" * 60)
if __name__ == "__main__":
main()