|
|
|
|
|
""" |
|
|
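Compare two trained models (base vs. medium) on the complexity of the expressions they generate.

Runs REINFORCE for a few epochs (``--epochs``, default 10) on a benchmark dataset
(``--dataset``, default Nguyen-5) to see which model explores the expression space better.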
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import json |
|
|
import argparse |
|
|
from pathlib import Path |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
|
|
|
# Make the project root (two levels above this file) and its ``classes``
# directory importable before the project-local import that follows.
# Resulting search order: <root>/classes first, then <root>, then the rest.
PROJECT_ROOT = Path(__file__).parent.parent
sys.path[0:0] = [str(PROJECT_ROOT / "classes"), str(PROJECT_ROOT)]
|
|
|
|
|
from scripts.debug_reinforce import DebugREINFORCE |
|
|
|
|
|
|
|
|
# Substrings that mark a trig function applied directly to another trig call.
_NESTED_TRIG = ("sin(sin", "sin(cos", "cos(sin", "cos(cos")


def _paren_nesting_depth(expression):
    """Return the deepest level of parenthesis nesting in *expression*.

    ``"x"`` -> 0, ``"sin(x) + cos(x)"`` -> 1, ``"sin(cos(x))"`` -> 2.
    Stray closing parentheses are ignored instead of driving the level negative.
    """
    level = deepest = 0
    for ch in expression:
        if ch == "(":
            level += 1
            deepest = max(deepest, level)
        elif ch == ")" and level > 0:
            level -= 1
    return deepest


def _pct(part, whole):
    """Return *part* as a percentage of *whole*, or 0 when *whole* is 0."""
    return 100 * part / whole if whole else 0


def _summarize_expressions(model_path, expressions):
    """Compute the complexity statistics reported for one model.

    Args:
        model_path: Path of the model that produced ``expressions``; stored
            verbatim in the result.
        expressions: Records from ``DebugREINFORCE.all_expressions``; each is
            read through its ``"is_valid"``, ``"expression"`` and ``"r2"`` keys.

    Returns:
        A JSON-serialisable dict with counts, percentages, parenthesis-depth
        statistics and the best finite R2 among valid expressions (-1.0 if
        there is none).
    """
    valid = [e for e in expressions if e["is_valid"]]
    texts = [e["expression"] for e in valid]

    has_power = sum(1 for t in texts if "**" in t or "pow(" in t)
    has_nested_trig = sum(1 for t in texts if any(n in t for n in _NESTED_TRIG))

    # Nesting depth, not the raw number of "(": sibling calls must not inflate it.
    # Floored at 1 so expressions without parentheses still count as depth 1.
    depths = [max(_paren_nesting_depth(t), 1) for t in texts]

    # Only valid expressions carry a meaningful R2. Missing / NaN / inf scores are
    # dropped so max() is well-defined and the JSON output stays standard-compliant.
    r2_scores = [float(e["r2"]) for e in valid if e.get("r2") is not None]
    finite_r2 = [r for r in r2_scores if np.isfinite(r)]
    best_r2 = max(finite_r2, default=-1.0)

    return {
        "model_path": model_path,
        "total_expressions": len(expressions),
        "valid_count": len(valid),
        "valid_pct": _pct(len(valid), len(expressions)),
        "has_power": has_power,
        "has_power_pct": _pct(has_power, len(valid)),
        "has_nested_trig": has_nested_trig,
        "has_nested_trig_pct": _pct(has_nested_trig, len(valid)),
        "avg_depth": sum(depths) / len(depths) if depths else 0,
        "max_depth": max(depths, default=0),
        "best_r2": float(best_r2),
    }


def main():
    """Run REINFORCE from the base and the medium model and compare the results.

    Parses the CLI arguments and loads the dataset CSV (``x_*`` feature columns
    plus a ``y`` target). For each model it runs ``DebugREINFORCE`` for
    ``--epochs`` epochs and prints the expression statistics. All statistics are
    written to ``--output_file`` as JSON. Finally it prints a side-by-side table
    with the relative change from base to medium.

    Exits through ``argparse`` with an error message if the dataset has no
    ``x_*`` columns, no ``y`` column, or no rows.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_base", type=str, required=True, help="Path to trained base model")
    parser.add_argument("--model_medium", type=str, required=True, help="Path to trained medium model")
    parser.add_argument("--dataset", type=str, default="data/benchmarks/nguyen/nguyen_5.csv")
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--output_file", type=str, default="model_comparison.json")
    args = parser.parse_args()

    # Imported lazily so `--help` does not pay the pandas import cost.
    import pandas as pd

    df = pd.read_csv(args.dataset)
    x_cols = [c for c in df.columns if str(c).startswith('x_')]
    if not x_cols or 'y' not in df.columns:
        parser.error(f"{args.dataset} must contain 'x_*' feature columns and a 'y' column")
    if df.empty:
        parser.error(f"{args.dataset} contains no rows")
    X = df[x_cols].values
    y = df['y'].values

    print("="*80)
    print("COMPARING TRAINED MODELS")
    print("="*80)
    print(f"Dataset: {args.dataset}")
    print(f" Samples: {len(df)}, Variables: {len(x_cols)}")
    print(f" Target range: [{y.min():.4f}, {y.max():.4f}]")
    print()

    results = {}
    for model_name, model_path in (("base", args.model_base), ("medium", args.model_medium)):
        print("="*80)
        print(f"Testing {model_name.upper()} model: {model_path}")
        print("="*80)

        reinforce = DebugREINFORCE(model_path, X, y)
        reinforce.run(epochs=args.epochs)

        stats = _summarize_expressions(model_path, reinforce.all_expressions)
        results[model_name] = stats

        print()
        print(f"Results for {model_name}:")
        print(f" Valid: {stats['valid_count']}/{stats['total_expressions']} ({stats['valid_pct']:.1f}%)")
        print(f" With power: {stats['has_power']} ({stats['has_power_pct']:.1f}%)")
        print(f" With nested trig: {stats['has_nested_trig']} ({stats['has_nested_trig_pct']:.1f}%)")
        print(f" Avg depth: {stats['avg_depth']:.2f}")
        print(f" Best R2: {stats['best_r2']:.4f}")
        print()

    with open(args.output_file, 'w', encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    print("="*80)
    print("COMPARISON SUMMARY")
    print("="*80)
    print(f"{'Metric':<25} {'Base':>15} {'Medium':>15} {'Improvement':>15}")
    print("-"*80)

    metrics = [
        ("Valid %", "valid_pct", "%"),
        ("Power %", "has_power_pct", "%"),
        ("Nested Trig %", "has_nested_trig_pct", "%"),
        ("Avg Depth", "avg_depth", ""),
        ("Best R2", "best_r2", ""),
    ]

    for label, key, unit in metrics:
        base_val = results["base"][key]
        medium_val = results["medium"][key]
        # Format value+unit first, then right-align, so every column is 15 wide
        # whether or not the metric carries a unit suffix.
        base_txt = f"{base_val:.2f}{unit}"
        medium_txt = f"{medium_val:.2f}{unit}"
        if base_val != 0:
            # abs() keeps the sign meaningful when the base value is negative (e.g. R2).
            improvement = f"{(medium_val - base_val) / abs(base_val) * 100:+.1f}%"
        else:
            improvement = "N/A"
        print(f"{label:<25} {base_txt:>15} {medium_txt:>15} {improvement:>15}")

    print()
    print(f"Results saved to: {args.output_file}")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the comparison only when executed as a script, not when imported.
    main()
|
|
|