File size: 4,713 Bytes
a1190da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#!/usr/bin/env python3
"""
Compare trained models (base vs medium) on expression generation complexity.
Runs REINFORCE for a few epochs on Nguyen-5 to see which model explores better.
"""

import os
import sys
import json
import argparse
from pathlib import Path

# numpy/torch are imported up front; they are presumably used by the
# DebugREINFORCE machinery pulled in below — TODO confirm they are needed here.
import numpy as np
import torch

# Add the project root (and its "classes" dir) to sys.path so the
# project-local import below works regardless of the current working directory.
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / "classes"))

from scripts.debug_reinforce import DebugREINFORCE


def load_dataset(path):
    """Load a benchmark CSV into a feature matrix and target vector.

    Expects feature columns named ``x_0``, ``x_1``, ... and a ``y`` column
    for the target. Returns ``(X, y, n_samples, n_vars)``.
    """
    # Local import: pandas is only needed once a dataset is actually loaded.
    import pandas as pd

    df = pd.read_csv(path)
    x_cols = [c for c in df.columns if c.startswith('x_')]
    return df[x_cols].values, df['y'].values, len(df), len(x_cols)


def analyze_expressions(expressions):
    """Aggregate complexity statistics over a list of expression records.

    Each record is a dict with at least ``is_valid`` and ``expression`` keys;
    scored records additionally carry an ``r2`` entry. Returns a JSON-ready
    dict of counts, percentages, depth stats, and the best R^2 seen.
    """
    valid = [e for e in expressions if e["is_valid"]]

    # Power usage: either Python's ** operator or an explicit pow( call.
    has_power = sum(1 for e in valid
                    if '**' in e['expression'] or 'pow(' in e['expression'])

    # Nested trig: any trig call appearing directly inside another.
    nested_patterns = ('sin(sin', 'sin(cos', 'cos(sin', 'cos(cos')
    has_nested_trig = sum(1 for e in valid
                          if any(p in e['expression'] for p in nested_patterns))

    # Parenthesis count is a cheap proxy for expression-tree depth;
    # flat expressions count as depth 1.
    depths = [max(e['expression'].count('('), 1) for e in valid]

    # Skip records without an 'r2' entry (e.g. invalid expressions) instead of
    # risking a KeyError; default to -1.0 when nothing was scored at all.
    best_r2 = max((e['r2'] for e in expressions if 'r2' in e), default=-1.0)

    return {
        "total_expressions": len(expressions),
        "valid_count": len(valid),
        "valid_pct": 100 * len(valid) / len(expressions) if expressions else 0,
        "has_power": has_power,
        "has_power_pct": 100 * has_power / len(valid) if valid else 0,
        "has_nested_trig": has_nested_trig,
        "has_nested_trig_pct": 100 * has_nested_trig / len(valid) if valid else 0,
        "avg_depth": sum(depths) / len(depths) if depths else 0,
        "max_depth": max(depths) if depths else 0,
        "best_r2": float(best_r2),
    }


def print_summary(results, output_file):
    """Print the base-vs-medium comparison table and the output location.

    ``results`` must contain "base" and "medium" entries as produced by
    ``analyze_expressions``.
    """
    print("="*80)
    print("COMPARISON SUMMARY")
    print("="*80)
    print(f"{'Metric':<25} {'Base':>15} {'Medium':>15} {'Improvement':>15}")
    print("-"*80)

    metrics = [
        ("Valid %", "valid_pct", "%"),
        ("Power %", "has_power_pct", "%"),
        ("Nested Trig %", "has_nested_trig_pct", "%"),
        ("Avg Depth", "avg_depth", ""),
        ("Best R2", "best_r2", ""),
    ]

    for label, key, unit in metrics:
        base_val = results["base"][key]
        medium_val = results["medium"][key]
        # Format value+unit together so columns align whether or not the
        # metric carries a '%' suffix (width-14 value + unit used to leave
        # unitless rows one character short of the width-15 header).
        base_s = f"{base_val:.2f}{unit}"
        medium_s = f"{medium_val:.2f}{unit}"

        if base_val != 0:
            improvement = ((medium_val - base_val) / abs(base_val)) * 100
            print(f"{label:<25} {base_s:>15} {medium_s:>15} {improvement:>+14.1f}%")
        else:
            # Relative improvement is undefined for a zero baseline.
            print(f"{label:<25} {base_s:>15} {medium_s:>15} {'N/A':>15}")

    print()
    print(f"Results saved to: {output_file}")


def main():
    """Run REINFORCE with two trained models and compare exploration stats.

    Parses CLI arguments, runs DebugREINFORCE for the base and medium
    models on the chosen dataset, prints per-model and comparative
    statistics, and writes machine-readable results to a JSON file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_base", type=str, required=True, help="Path to trained base model")
    parser.add_argument("--model_medium", type=str, required=True, help="Path to trained medium model")
    parser.add_argument("--dataset", type=str, default="data/benchmarks/nguyen/nguyen_5.csv")
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--output_file", type=str, default="model_comparison.json")
    args = parser.parse_args()

    X, y, n_samples, n_vars = load_dataset(args.dataset)

    print("="*80)
    print("COMPARING TRAINED MODELS")
    print("="*80)
    print(f"Dataset: {args.dataset}")
    print(f"  Samples: {n_samples}, Variables: {n_vars}")
    print(f"  Target range: [{y.min():.4f}, {y.max():.4f}]")
    print()

    results = {}

    for model_name, model_path in [("base", args.model_base), ("medium", args.model_medium)]:
        print("="*80)
        print(f"Testing {model_name.upper()} model: {model_path}")
        print("="*80)

        # Run REINFORCE; the runner accumulates every sampled expression
        # on its all_expressions attribute.
        reinforce = DebugREINFORCE(model_path, X, y)
        reinforce.run(epochs=args.epochs)

        stats = analyze_expressions(reinforce.all_expressions)
        stats["model_path"] = model_path
        results[model_name] = stats

        print()
        print(f"Results for {model_name}:")
        print(f"  Valid: {stats['valid_count']}/{stats['total_expressions']} ({stats['valid_pct']:.1f}%)")
        print(f"  With power: {stats['has_power']} ({stats['has_power_pct']:.1f}%)")
        print(f"  With nested trig: {stats['has_nested_trig']} ({stats['has_nested_trig_pct']:.1f}%)")
        print(f"  Avg depth: {stats['avg_depth']:.2f}")
        print(f"  Best R2: {stats['best_r2']:.4f}")
        print()

    # Persist machine-readable results alongside the printed report.
    with open(args.output_file, 'w') as f:
        json.dump(results, f, indent=2)

    print_summary(results, args.output_file)


# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()