gpt2_base_prefix_682k / scripts /compare_trained_models.py
augustocsc's picture
GPT-2 Base trained on prefix dataset (682K)
c082aa2 verified
#!/usr/bin/env python3
"""
Compare trained models (base vs medium) on expression generation complexity.
Runs a few epochs of REINFORCE per model on a benchmark dataset (Nguyen-5 by default) to see which model explores the expression space better.
"""
import os
import sys
import json
import argparse
from pathlib import Path
import numpy as np
import torch
# Make the project root importable so `scripts.debug_reinforce` (imported
# below) resolves when this file is run directly as a script. The `classes`
# directory is also added, presumably because modules imported from
# `scripts.debug_reinforce` import from it -- TODO confirm.
# NOTE: sys.path.insert(0, ...) prepends, so after both calls `classes` is
# searched first, then the project root. `Path(__file__)` is not resolved, so
# PROJECT_ROOT may be a relative path depending on how the script is invoked.
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / "classes"))
from scripts.debug_reinforce import DebugREINFORCE
# Substrings that mark one sine/cosine call nested directly inside another.
_NESTED_TRIG_PATTERNS = ("sin(sin", "sin(cos", "cos(sin", "cos(cos")

# (label, results key, unit suffix) rows of the printed comparison table.
_SUMMARY_METRICS = (
    ("Valid %", "valid_pct", "%"),
    ("Power %", "has_power_pct", "%"),
    ("Nested Trig %", "has_nested_trig_pct", "%"),
    ("Avg Depth", "avg_depth", ""),
    ("Best R2", "best_r2", ""),
)


def _max_paren_depth(expression):
    """Return the maximum parenthesis nesting depth of ``expression``, at least 1.

    ``"sin(x)+cos(x)"`` -> 1, ``"sin(cos(x))"`` -> 2, ``"x"`` -> 1.
    Unmatched closing parentheses are ignored instead of driving the running
    depth negative.
    """
    depth = deepest = 0
    for ch in expression:
        if ch == "(":
            depth += 1
            deepest = max(deepest, depth)
        elif ch == ")" and depth > 0:
            depth -= 1
    # A bare expression with no parentheses still counts as depth 1.
    return max(deepest, 1)


def _analyze_expressions(model_path, expressions):
    """Summarize complexity statistics for the expressions one model generated.

    Args:
        model_path: Model path; stored verbatim in the result for reference.
        expressions: Sequence of dict-like records; each is read for
            ``"is_valid"``, ``"expression"`` (a string) and ``"r2"``.

    Returns:
        Dict of counts and percentages, keyed exactly as written to the JSON
        report. Percentages are 0 when their denominator is empty.
    """
    import math

    valid = [e for e in expressions if e["is_valid"]]

    def pct(count, total):
        return 100 * count / total if total else 0

    has_power = sum(
        1 for e in valid if "**" in e["expression"] or "pow(" in e["expression"]
    )
    has_nested_trig = sum(
        1 for e in valid if any(p in e["expression"] for p in _NESTED_TRIG_PATTERNS)
    )
    depths = [_max_paren_depth(e["expression"]) for e in valid]

    # Only valid expressions have a meaningful fit. Dropping None/NaN/inf also
    # stops a single NaN from poisoning max() and keeps the JSON report valid
    # (json.dump would otherwise emit the non-standard NaN/Infinity tokens).
    finite_r2 = [
        float(e["r2"])
        for e in valid
        if e["r2"] is not None and math.isfinite(float(e["r2"]))
    ]
    best_r2 = max(finite_r2, default=-1.0)

    return {
        "model_path": model_path,
        "total_expressions": len(expressions),
        "valid_count": len(valid),
        "valid_pct": pct(len(valid), len(expressions)),
        "has_power": has_power,
        "has_power_pct": pct(has_power, len(valid)),
        "has_nested_trig": has_nested_trig,
        "has_nested_trig_pct": pct(has_nested_trig, len(valid)),
        "avg_depth": sum(depths) / len(depths) if depths else 0,
        "max_depth": max(depths) if depths else 0,
        "best_r2": float(best_r2),
    }


def _print_summary(results):
    """Print a base-vs-medium table with the relative change of each metric.

    ``results`` must contain ``"base"`` and ``"medium"`` entries as produced by
    ``_analyze_expressions``. The improvement is reported as N/A when the base
    value is 0, because the relative change is undefined there.
    """
    print("=" * 80)
    print("COMPARISON SUMMARY")
    print("=" * 80)
    print(f"{'Metric':<25} {'Base':>15} {'Medium':>15} {'Improvement':>15}")
    print("-" * 80)
    for label, key, unit in _SUMMARY_METRICS:
        base_val = results["base"][key]
        medium_val = results["medium"][key]
        # Format value+unit first so every column is exactly 15 wide, with or
        # without a unit suffix.
        base_txt = f"{base_val:.2f}{unit}"
        medium_txt = f"{medium_val:.2f}{unit}"
        if base_val != 0:
            change = (medium_val - base_val) / abs(base_val) * 100
            improvement = f"{change:+.1f}%"
        else:
            improvement = "N/A"
        print(f"{label:<25} {base_txt:>15} {medium_txt:>15} {improvement:>15}")
    print()


def main():
    """Compare the trained base and medium models on one benchmark dataset.

    For each model, runs ``DebugREINFORCE`` for ``--epochs`` epochs on the CSV
    given by ``--dataset`` and summarizes the generated expressions. The
    per-model metrics are written to ``--output_file`` as JSON, and a
    side-by-side summary table is printed.

    The dataset must be non-empty and contain a ``y`` target column plus at
    least one ``x_*`` input column; otherwise the script exits via
    ``argparse`` with a usage error.
    """
    import pandas as pd

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_base", type=str, required=True, help="Path to trained base model")
    parser.add_argument("--model_medium", type=str, required=True, help="Path to trained medium model")
    parser.add_argument("--dataset", type=str, default="data/benchmarks/nguyen/nguyen_5.csv")
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--output_file", type=str, default="model_comparison.json")
    args = parser.parse_args()

    # Load dataset: columns prefixed "x_" are inputs, "y" is the target.
    df = pd.read_csv(args.dataset)
    x_cols = [c for c in df.columns if c.startswith("x_")]
    if df.empty or "y" not in df.columns or not x_cols:
        parser.error(
            f"{args.dataset} must be non-empty and contain a 'y' column "
            "and at least one 'x_*' column"
        )
    X = df[x_cols].values
    y = df["y"].values

    print("=" * 80)
    print("COMPARING TRAINED MODELS")
    print("=" * 80)
    print(f"Dataset: {args.dataset}")
    print(f" Samples: {len(df)}, Variables: {len(x_cols)}")
    print(f" Target range: [{y.min():.4f}, {y.max():.4f}]")
    print()

    results = {}
    for model_name, model_path in (("base", args.model_base), ("medium", args.model_medium)):
        print("=" * 80)
        print(f"Testing {model_name.upper()} model: {model_path}")
        print("=" * 80)

        reinforce = DebugREINFORCE(model_path, X, y)
        reinforce.run(epochs=args.epochs)
        stats = _analyze_expressions(model_path, reinforce.all_expressions)
        results[model_name] = stats

        print()
        print(f"Results for {model_name}:")
        print(f" Valid: {stats['valid_count']}/{stats['total_expressions']} ({stats['valid_pct']:.1f}%)")
        print(f" With power: {stats['has_power']} ({stats['has_power_pct']:.1f}%)")
        print(f" With nested trig: {stats['has_nested_trig']} ({stats['has_nested_trig_pct']:.1f}%)")
        print(f" Avg depth: {stats['avg_depth']:.2f}")
        print(f" Best R2: {stats['best_r2']:.4f}")
        print()

    with open(args.output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    _print_summary(results)
    print(f"Results saved to: {args.output_file}")
# Run the comparison only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()