|
|
|
|
|
""" |
|
|
Analyze failed expressions from experiment logs. |
|
|
""" |
|
|
|
|
|
import sys |
|
|
import json |
|
|
import argparse |
|
|
import numpy as np |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
PROJECT_ROOT = Path(__file__).parent.parent |
|
|
sys.path.insert(0, str(PROJECT_ROOT)) |
|
|
sys.path.insert(0, str(PROJECT_ROOT / "classes")) |
|
|
|
|
|
from expression import Expression |
|
|
|
|
|
|
|
|
def analyze_expression(expr_str: str, X: np.ndarray, y: np.ndarray) -> dict:
    """Diagnose how an expression behaves on a dataset and classify any failure.

    The expression string is parsed with the project's ``Expression`` class
    (infix notation), checked with ``is_valid_on_dataset``, evaluated on ``X``
    and scored against ``y`` with the coefficient of determination (R²).

    Args:
        expr_str: Expression text as it appears in the experiment log.
        X: Input data passed unchanged to ``Expression.is_valid_on_dataset``
            and ``Expression.evaluate``.
        y: Target values the prediction is compared against.

    Returns:
        A dict with the keys:

        * ``"expression"``: the original ``expr_str``.
        * ``"error_type"``: ``None`` when nothing is wrong, otherwise one of
          ``"parse_error"``, ``"evaluation_error"``, ``"invalid_on_dataset"``,
          ``"non_finite_output"``, ``"very_poor_fit"`` (R² < -0.5) or
          ``"poor_fit"`` (-0.5 <= R² < 0).
        * ``"error_message"``: human-readable detail, or ``None``.
        * ``"r2"``: R² clipped to ``[-1.0, 1.0]``; stays ``-1.0`` when the
          expression could not be scored.
        * ``"is_valid"``: ``True`` only when an R² value was computed.

    The function never raises for a bad expression: every exception is
    converted into an ``error_type`` / ``error_message`` pair.
    """
    result = {
        "expression": expr_str,
        "error_type": None,
        "error_message": None,
        "r2": -1.0,
        "is_valid": False,
    }

    # Every 'C' character is substituted with the literal 1 before parsing.
    # NOTE(review): presumably 'C' is the placeholder for a fittable constant;
    # confirm no operator or variable name in the grammar contains a 'C'.
    test_expr = expr_str.replace('C', '1')

    # Parsing gets its own handler so that only genuine construction failures
    # are reported as "parse_error".
    try:
        expr = Expression(test_expr, is_prefix=False)
    except Exception as e:
        result["error_type"] = "parse_error"
        result["error_message"] = str(e)
        return result

    # Anything that goes wrong from here on happened while running the parsed
    # expression on the data, so it is reported separately from parse errors.
    try:
        if not expr.is_valid_on_dataset(X):
            result["error_type"] = "invalid_on_dataset"
            result["error_message"] = "Expression produces invalid values (NaN/Inf) on dataset"
            return result

        # Flatten both operands: subtracting an (n, 1) array from an (n,)
        # array would silently broadcast to an (n, n) matrix and corrupt R².
        y_pred = np.asarray(expr.evaluate(X)).ravel()
        y_true = np.asarray(y).ravel()

        if not np.all(np.isfinite(y_pred)):
            result["error_type"] = "non_finite_output"
            result["error_message"] = "Expression produces NaN/Inf values"
            return result

        ss_res = np.sum((y_true - y_pred) ** 2)
        ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
    except Exception as e:
        result["error_type"] = "evaluation_error"
        result["error_message"] = f"{type(e).__name__}: {e}"
        return result

    # A constant target has zero total variance; R² is defined as 0 there to
    # avoid dividing by zero.
    if ss_tot == 0:
        r2 = 0.0
    else:
        r2 = 1 - (ss_res / ss_tot)

    result["r2"] = float(np.clip(r2, -1.0, 1.0))
    result["is_valid"] = True

    # The messages report the unclipped R² so the true magnitude stays visible.
    if r2 < -0.5:
        result["error_type"] = "very_poor_fit"
        result["error_message"] = f"R²={r2:.4f} (worse than constant predictor)"
    elif r2 < 0.0:
        result["error_type"] = "poor_fit"
        result["error_message"] = f"R²={r2:.4f} (negative R²)"

    return result
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point: re-check logged expressions against a dataset.

    Reads the experiment log JSON (``--log_file``) and the dataset CSV
    (``--dataset``), runs ``analyze_expression`` on at most
    ``--max_expressions`` entries of the log's ``discovered_expressions``
    mapping, and prints one line per expression, a count of each error type,
    and up to three example expressions per error type.

    The dataset's feature columns are those whose name starts with ``x_``;
    the target column is ``y``.
    """
    # Standard-library helpers are imported locally, next to the lazy pandas
    # import below, so this function stays self-contained.
    from collections import Counter
    from itertools import islice

    parser = argparse.ArgumentParser(description="Analyze failed expressions")
    parser.add_argument("--log_file", type=str, required=True, help="Path to experiment log JSON")
    parser.add_argument("--dataset", type=str, required=True, help="Path to dataset CSV")
    parser.add_argument("--max_expressions", type=int, default=20, help="Max expressions to analyze")
    args = parser.parse_args()

    # A negative limit would otherwise be interpreted as "all but the last N".
    if args.max_expressions < 0:
        parser.error("--max_expressions must be non-negative")

    # JSON is read as UTF-8 rather than the platform's default encoding.
    with open(args.log_file, 'r', encoding='utf-8') as f:
        log_data = json.load(f)

    import pandas as pd

    df = pd.read_csv(args.dataset)
    if 'y' not in df.columns:
        parser.error(f"dataset {args.dataset} has no 'y' column")
    if df.empty:
        # y.min() / y.max() below cannot be computed on an empty column.
        parser.error(f"dataset {args.dataset} contains no rows")

    x_cols = [c for c in df.columns if c.startswith('x_')]
    X = df[x_cols].values
    y = df['y'].values

    print(f"Dataset: {args.dataset}")
    print(f" Samples: {len(df)}, Variables: {len(x_cols)}")
    print(f" Target range: [{y.min():.4f}, {y.max():.4f}]")
    print()

    discovered = log_data.get('discovered_expressions', {})

    if not discovered:
        print("No discovered expressions found in log!")
        return

    # Report the number actually analyzed, which can be smaller than the limit.
    n_to_analyze = min(args.max_expressions, len(discovered))

    print(f"Found {len(discovered)} expressions in log")
    print(f"Analyzing first {n_to_analyze}...")
    print()

    error_counts = Counter()
    results = []

    # Only the expression strings (the mapping's keys) are needed; the values
    # stored in the log are not used by the analysis.
    for i, expr_str in enumerate(islice(discovered, n_to_analyze), start=1):
        result = analyze_expression(expr_str, X, y)
        results.append(result)

        error_counts[result["error_type"] or "success"] += 1

        print(f"{i:2d}. {expr_str[:60]:<60} | {result['error_type'] or 'OK':<20} | R²={result['r2']:.4f}")
        if result["error_message"]:
            print(f" > {result['error_message']}")

    print()
    print("="*80)
    print("SUMMARY")
    print("="*80)
    print(f"Total analyzed: {len(results)}")
    print()
    print("Error types:")
    # most_common() lists the most frequent type first; ties keep the order
    # in which the types were first encountered.
    for error_type, count in error_counts.most_common():
        pct = 100 * count / len(results)
        print(f" {error_type:<30} {count:3d} ({pct:5.1f}%)")

    print()
    print("Examples by error type:")
    print()

    # Group the failing results once instead of rescanning for every type.
    examples_by_type = {}
    for r in results:
        if r["error_type"]:
            examples_by_type.setdefault(r["error_type"], []).append(r)

    for error_type in sorted(examples_by_type):
        print(f"{error_type}:")
        for ex in examples_by_type[error_type][:3]:
            print(f" - {ex['expression'][:70]}")
            print(f" {ex['error_message']}")
        print()


if __name__ == "__main__":
    main()
|
|
|