File size: 4,844 Bytes
d498548
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""Evaluation 1 - Baseline vs Qwen3Loop."""
"""AI DISCLOSURE, Made this real quick with Anthrophics, Claude Code."""

import os
import sys
import torch
import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

# Add local path for modeling_qwen_loop
# The looped-model implementation lives in a sibling checkout, not on
# PYTHONPATH; prepend it so the import below resolves.
sys.path.insert(0, '/content/Qwen3-0.6B-looped')
try:
    from modeling_qwen_loop import Qwen3LoopForCausalLM
except ImportError:
    # Fail fast with a clear message rather than a bare traceback later.
    print("Error: Could not import Qwen3LoopForCausalLM. Make sure you are in the correct directory.")
    sys.exit(1)

# Path to the pretrained base model (Colab-style layout — adjust if run elsewhere).
MODEL_PATH = "/content/Qwen3-0.6B"
# Evaluation batch size; sized for a single Colab GPU — TODO confirm headroom.
BATCH_SIZE = 8
# Fixed sequence length every example is padded/truncated to.
MAX_LENGTH = 1024

# Prefer GPU when available; everything downstream moves tensors to `device`.
device = "cuda" if torch.cuda.is_available() else "cpu"

def evaluate_model(model, loader, name):
    """Compute mean language-modeling loss and perplexity over a data loader.

    Args:
        model: Causal LM whose forward pass returns an object exposing ``.loss``.
        loader: Iterable of batches accepted as ``model(**batch)`` keyword args.
        name: Human-readable label used in progress output.

    Returns:
        Tuple ``(avg_loss, perplexity)`` as Python floats. For an empty
        loader, returns ``(inf, inf)`` instead of raising ZeroDivisionError.
    """
    print(f"\nEvaluating {name}...")
    model.eval()
    total_loss = 0.0
    steps = 0
    # Only enable bf16 autocast when CUDA is actually present; the original
    # hard-coded 'cuda' and would emit warnings / disable itself on CPU runs.
    cuda_ok = torch.cuda.is_available()
    with torch.no_grad():
        for batch in tqdm(loader, desc=f"Eval {name}"):
            with torch.amp.autocast('cuda', dtype=torch.bfloat16, enabled=cuda_ok):
                # use_cache=False: no generation here, so skip KV-cache allocation.
                outputs = model(**batch, use_cache=False)
            total_loss += outputs.loss.item()
            steps += 1
    if steps == 0:
        # Empty loader: define the result rather than dividing by zero.
        return float('inf'), float('inf')
    avg_loss = total_loss / steps
    # Perplexity is exp(mean loss); route through a tensor for torch.exp.
    ppl = torch.exp(torch.tensor(avg_loss)).item()
    return avg_loss, ppl

def main():
    """CLI entry point: compare baseline Qwen3 vs a looped variant on WikiText-2.

    Evaluates the stock model and a Qwen3LoopForCausalLM restored from the
    checkpoint path given on the command line, then prints both losses and
    perplexities and which model came out ahead.
    """
    parser = argparse.ArgumentParser(description="Evaluate Qwen3 Loop models.")
    parser.add_argument("checkpoint", nargs="?", help="Path to checkpoint (.bin or .pt)")
    args = parser.parse_args()

    checkpoint_path = args.checkpoint
    if not checkpoint_path:
        # Fallback to a default if not provided, just for safety
        print("Please provide a checkpoint path as an argument.")
        return

    print("=" * 60)
    print(f"EVALUATION: {checkpoint_path}")
    print("=" * 60)

    # 1. Prepare Data
    print("\n1. Preparing Data...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    # Qwen tokenizers ship without a pad token; reuse EOS so padding works.
    tokenizer.pad_token = tokenizer.eos_token
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

    def tokenize_fn(examples):
        # Fixed-length padding so batches stack into rectangular tensors.
        return tokenizer(examples["text"], truncation=True, max_length=MAX_LENGTH, padding="max_length")

    tokenized = dataset.map(tokenize_fn, batched=True, remove_columns=["text"])
    # Drop near-empty rows (<= 10 non-pad tokens) — mostly blank wikitext lines.
    tokenized = tokenized.filter(lambda x: sum(1 for t in x["input_ids"] if t != tokenizer.pad_token_id) > 10)
    val_data = tokenized["validation"]
    print(f"   Validation samples: {len(val_data)}")

    def collate_fn(batch):
        # Stack examples into tensors and mask padding out of the LM loss
        # (HF convention: label -100 is ignored by cross-entropy).
        input_ids = torch.tensor([x["input_ids"] for x in batch])
        attention_mask = torch.tensor([x["attention_mask"] for x in batch])
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        return {"input_ids": input_ids.to(device), "attention_mask": attention_mask.to(device), "labels": labels.to(device)}

    loader = DataLoader(val_data, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    # 2. Baseline
    print("\n2. Evaluating Baseline...")
    baseline_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16).to(device)
    baseline_loss, baseline_ppl = evaluate_model(baseline_model, loader, "Baseline")
    # Release the baseline before loading the second model to avoid OOM.
    del baseline_model
    torch.cuda.empty_cache()

    # 3. Loop Model
    print(f"\n3. Evaluating Loop Model from {checkpoint_path}...")
    # Start from pretrained base weights; checkpoint overrides follow below.
    loop_model = Qwen3LoopForCausalLM.from_pretrained(MODEL_PATH)

    # Load weights
    print(f"   Loading state dict...")
    # NOTE(review): torch.load unpickles arbitrary data — only feed trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location=device)

    # Check if this is a gate-only checkpoint or full model
    keys = list(state_dict.keys())
    # Heuristic: if we only have gate keys, it's gate-only.
    # 'gate' string appears in gate parameters.
    # Full model has 'model.layers.0.self_attn.q_proj.weight' etc.
    is_gate_only = all('gate' in k for k in keys) and len(keys) < 200

    if is_gate_only:
        # Gate-only checkpoints carry just the loop-gate params; the rest of
        # the weights stay as pretrained, hence strict=False.
        print("   Detected Gate-Only checkpoint (.pt). Loading with strict=False.")
        loop_model.load_state_dict(state_dict, strict=False)
        print(f"   Loaded {len(keys)} gate parameters.")
    else:
        print("   Detected Full Model checkpoint (.bin). Loading full state.")
        loop_model.load_state_dict(state_dict, strict=True)

    # Match the baseline's device and bf16 dtype for a fair comparison.
    loop_model = loop_model.to(device).to(torch.bfloat16)
    loop_loss, loop_ppl = evaluate_model(loop_model, loader, "Loop Model")

    # Results
    print("\n" + "=" * 60)
    print("FINAL RESULTS")
    print("=" * 60)
    print(f"Baseline Loss: {baseline_loss:.4f} | PPL: {baseline_ppl:.2f}")
    print(f"Loop Model Loss: {loop_loss:.4f} | PPL: {loop_ppl:.2f}")
    print("-" * 60)

    if loop_loss < baseline_loss:
        print(f"SUCCESS: Loop Attention improved loss by {baseline_loss - loop_loss:.4f}")
    else:
        print(f"Baseline still better by {loop_loss - baseline_loss:.4f}")
    print("=" * 60)

if __name__ == "__main__":
    main()