|
|
"""Evaluation 1 - Baseline vs Qwen3Loop.""" |
|
|
"""AI DISCLOSURE, Made this real quick with Anthrophics, Claude Code.""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import torch |
|
|
import argparse |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
from datasets import load_dataset |
|
|
from torch.utils.data import DataLoader |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
# Make the looped-model implementation importable (Colab checkout location).
sys.path.insert(0, '/content/Qwen3-0.6B-looped')

try:
    # Custom architecture: Qwen3 with a looped-attention causal LM head,
    # defined in the repository added to sys.path above.
    from modeling_qwen_loop import Qwen3LoopForCausalLM
except ImportError:
    print("Error: Could not import Qwen3LoopForCausalLM. Make sure you are in the correct directory.")
    sys.exit(1)
|
|
|
|
|
# Path to the base (non-looped) pretrained Qwen3 checkpoint.
MODEL_PATH = "/content/Qwen3-0.6B"

# Sequences per forward pass during evaluation.
BATCH_SIZE = 8

# Tokenizer truncation/padding length in tokens.
MAX_LENGTH = 1024

# Prefer GPU when available; everything below moves tensors/models to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
def evaluate_model(model, loader, name):
    """Compute mean cross-entropy loss and perplexity of ``model`` over ``loader``.

    Args:
        model: causal LM whose forward call returns an object with a ``.loss``
            attribute when given ``input_ids``/``attention_mask``/``labels``.
        loader: iterable of batches already placed on the target device.
        name: label used in progress and log messages.

    Returns:
        Tuple ``(avg_loss, perplexity)`` of Python floats.

    Raises:
        ValueError: if ``loader`` yields no batches (avoids a silent
            ZeroDivisionError).

    NOTE(review): the average is per-batch, not token-weighted; with a
    possibly-smaller final batch this slightly biases the estimate.
    """
    print(f"\nEvaluating {name}...")
    model.eval()
    total_loss = 0.0
    steps = 0
    # Autocast on the device the model actually lives on. The previous
    # version hard-coded 'cuda', which is wrong on CPU-only runs.
    device_type = next(model.parameters()).device.type
    with torch.no_grad():
        for batch in tqdm(loader, desc=f"Eval {name}"):
            with torch.amp.autocast(device_type, dtype=torch.bfloat16):
                outputs = model(**batch, use_cache=False)
            total_loss += outputs.loss.item()
            steps += 1
    if steps == 0:
        raise ValueError("loader yielded no batches; cannot compute average loss")
    avg_loss = total_loss / steps
    # Perplexity is exp(mean NLL).
    ppl = torch.exp(torch.tensor(avg_loss)).item()
    return avg_loss, ppl
|
|
|
|
|
def main():
    """Evaluate baseline Qwen3 vs a looped-attention checkpoint on WikiText-2.

    Workflow: parse the checkpoint path from argv, tokenize the WikiText-2
    validation split, evaluate the stock model, then load the loop model
    (full or gate-only checkpoint) and evaluate it, printing a comparison.
    """
    parser = argparse.ArgumentParser(description="Evaluate Qwen3 Loop models.")
    parser.add_argument("checkpoint", nargs="?", help="Path to checkpoint (.bin or .pt)")
    args = parser.parse_args()

    checkpoint_path = args.checkpoint
    if not checkpoint_path:
        print("Please provide a checkpoint path as an argument.")
        return

    print("=" * 60)
    print(f"EVALUATION: {checkpoint_path}")
    print("=" * 60)

    # ---- 1. Data ---------------------------------------------------------
    print("\n1. Preparing Data...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    # Qwen has no dedicated pad token; reuse EOS so padding="max_length" works.
    tokenizer.pad_token = tokenizer.eos_token
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

    def tokenize_fn(examples):
        # Pad/truncate every row to MAX_LENGTH so batches are rectangular.
        return tokenizer(examples["text"], truncation=True, max_length=MAX_LENGTH, padding="max_length")

    tokenized = dataset.map(tokenize_fn, batched=True, remove_columns=["text"])
    # Drop near-empty rows (WikiText has many blank/heading-only lines);
    # keep rows with more than 10 non-pad tokens.
    tokenized = tokenized.filter(lambda x: sum(1 for t in x["input_ids"] if t != tokenizer.pad_token_id) > 10)
    val_data = tokenized["validation"]
    print(f" Validation samples: {len(val_data)}")

    def collate_fn(batch):
        # Stack pre-padded rows and mask padding out of the loss with -100.
        input_ids = torch.tensor([x["input_ids"] for x in batch])
        attention_mask = torch.tensor([x["attention_mask"] for x in batch])
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        return {"input_ids": input_ids.to(device), "attention_mask": attention_mask.to(device), "labels": labels.to(device)}

    loader = DataLoader(val_data, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    # ---- 2. Baseline -----------------------------------------------------
    print("\n2. Evaluating Baseline...")
    baseline_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16).to(device)
    baseline_loss, baseline_ppl = evaluate_model(baseline_model, loader, "Baseline")
    # Free GPU memory before loading the second model.
    del baseline_model
    torch.cuda.empty_cache()

    # ---- 3. Loop model ---------------------------------------------------
    print(f"\n3. Evaluating Loop Model from {checkpoint_path}...")
    loop_model = Qwen3LoopForCausalLM.from_pretrained(MODEL_PATH)

    print(" Loading state dict...")
    # weights_only=True: a state dict is pure tensors, and this blocks
    # arbitrary-code execution via pickle for untrusted checkpoint files.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)

    keys = list(state_dict.keys())

    # Heuristic: a gate-only checkpoint contains exclusively '*gate*' keys
    # and far fewer entries than a full model state dict.
    is_gate_only = all('gate' in k for k in keys) and len(keys) < 200

    if is_gate_only:
        print(" Detected Gate-Only checkpoint (.pt). Loading with strict=False.")
        # strict=False: only the gate parameters are present; the rest of
        # the weights stay as loaded from MODEL_PATH.
        loop_model.load_state_dict(state_dict, strict=False)
        print(f" Loaded {len(keys)} gate parameters.")
    else:
        print(" Detected Full Model checkpoint (.bin). Loading full state.")
        loop_model.load_state_dict(state_dict, strict=True)

    loop_model = loop_model.to(device).to(torch.bfloat16)
    loop_loss, loop_ppl = evaluate_model(loop_model, loader, "Loop Model")

    # ---- 4. Report -------------------------------------------------------
    print("\n" + "=" * 60)
    print("FINAL RESULTS")
    print("=" * 60)
    print(f"Baseline Loss: {baseline_loss:.4f} | PPL: {baseline_ppl:.2f}")
    print(f"Loop Model Loss: {loop_loss:.4f} | PPL: {loop_ppl:.2f}")
    print("-" * 60)

    if loop_loss < baseline_loss:
        print(f"SUCCESS: Loop Attention improved loss by {baseline_loss - loop_loss:.4f}")
    else:
        print(f"Baseline still better by {loop_loss - baseline_loss:.4f}")
    print("=" * 60)
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|