File size: 4,628 Bytes
a1190da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
#!/usr/bin/env python3
"""
Analyze failed expressions from experiment logs.
"""
import sys
import json
import argparse
import numpy as np
from pathlib import Path
# Add project root to path
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / "classes"))
from expression import Expression
def analyze_expression(expr_str: str, X: np.ndarray, y: np.ndarray) -> dict:
    """Diagnose why a logged expression failed on a dataset.

    Every ``C`` in *expr_str* is replaced with ``1``; the resulting text is
    then built into an ``Expression`` (``is_prefix=False``), checked with
    ``is_valid_on_dataset``, evaluated, and scored with R² against *y*.

    This function never raises: failures are reported through the returned
    dict so that a batch of expressions can be analyzed in one pass.

    Args:
        expr_str: Expression text as stored in the experiment log.
        X: Feature data, passed unchanged to the ``Expression`` methods.
        y: Target values the predictions are compared against.

    Returns:
        A dict with the keys:

        * ``"expression"`` -- the original *expr_str*.
        * ``"error_type"`` -- ``None`` when the fit has a non-negative R²,
          otherwise one of ``"parse_error"`` (the expression could not be
          constructed), ``"evaluation_error"`` (validation, evaluation or
          scoring raised), ``"invalid_on_dataset"``, ``"non_finite_output"``,
          ``"very_poor_fit"`` (R² < -0.5) or ``"poor_fit"`` (R² < 0).
        * ``"error_message"`` -- human-readable detail, or ``None``.
        * ``"r2"`` -- R² clipped to ``[-1, 1]``; stays ``-1.0`` when the
          expression was never scored.
        * ``"is_valid"`` -- ``True`` once the expression has been scored,
          including poor fits.
    """
    result = {
        "expression": expr_str,
        "error_type": None,
        "error_message": None,
        "r2": -1.0,
        "is_valid": False,
    }

    # Replace C with 1 for testing
    test_expr = expr_str.replace('C', '1')

    # Phase 1: construction. Only failures raised here are parse errors.
    try:
        expr = Expression(test_expr, is_prefix=False)
    except Exception as e:
        result["error_type"] = "parse_error"
        result["error_message"] = str(e)
        return result

    # Phase 2: validation, evaluation and scoring. Failures raised here are
    # runtime problems with a successfully parsed expression, so they get
    # their own label instead of being misreported as parse errors.
    try:
        if not expr.is_valid_on_dataset(X):
            result["error_type"] = "invalid_on_dataset"
            result["error_message"] = "Expression produces invalid values (NaN/Inf) on dataset"
            return result

        y_pred = expr.evaluate(X)
        if not np.all(np.isfinite(y_pred)):
            result["error_type"] = "non_finite_output"
            result["error_message"] = "Expression produces NaN/Inf values"
            return result

        # Compute R²
        ss_res = np.sum((y - y_pred) ** 2)
        ss_tot = np.sum((y - np.mean(y)) ** 2)
        # A constant target has zero total variance; report 0.0 rather than
        # dividing by zero.
        r2 = 0.0 if ss_tot == 0 else 1 - (ss_res / ss_tot)

        # The stored value is clipped; the messages below intentionally show
        # the unclipped value.
        result["r2"] = float(np.clip(r2, -1.0, 1.0))
        result["is_valid"] = True

        if r2 < -0.5:
            result["error_type"] = "very_poor_fit"
            result["error_message"] = f"R²={r2:.4f} (worse than constant predictor)"
        elif r2 < 0.0:
            result["error_type"] = "poor_fit"
            result["error_message"] = f"R²={r2:.4f} (negative R²)"
    except Exception as e:
        result["error_type"] = "evaluation_error"
        result["error_message"] = str(e)

    return result
def main():
    """Re-evaluate expressions from an experiment log and report failures.

    Command-line arguments:
        --log_file: Path to the experiment log (JSON). Expressions are read
            from its ``discovered_expressions`` entry, which is expected to
            be a mapping keyed by expression string.
        --dataset: Path to a CSV file with a ``y`` target column; columns
            whose names start with ``x_`` are used as features.
        --max_expressions: Upper bound on how many expressions to analyze
            (default 20, must be non-negative).

    Prints one line per analyzed expression, then a summary of error-type
    counts and up to three examples per error type. Exits through
    ``parser.error`` when an argument or the dataset is unusable.
    """
    # Function-scope imports, matching the existing local pandas import.
    from collections import Counter
    from itertools import islice

    import pandas as pd

    parser = argparse.ArgumentParser(description="Analyze failed expressions")
    parser.add_argument("--log_file", type=str, required=True, help="Path to experiment log JSON")
    parser.add_argument("--dataset", type=str, required=True, help="Path to dataset CSV")
    parser.add_argument("--max_expressions", type=int, default=20, help="Max expressions to analyze")
    args = parser.parse_args()

    # A negative limit would otherwise silently drop expressions from the
    # end of the log instead of limiting the count.
    if args.max_expressions < 0:
        parser.error("--max_expressions must be non-negative")

    # Load log (explicit encoding so the result does not depend on the
    # platform default).
    with open(args.log_file, 'r', encoding='utf-8') as f:
        log_data = json.load(f)

    # Load dataset
    df = pd.read_csv(args.dataset)
    if 'y' not in df.columns:
        parser.error(f"dataset {args.dataset} has no 'y' column")
    if df.empty:
        parser.error(f"dataset {args.dataset} contains no rows")
    x_cols = [c for c in df.columns if c.startswith('x_')]
    X = df[x_cols].values
    y = df['y'].values

    print(f"Dataset: {args.dataset}")
    print(f" Samples: {len(df)}, Variables: {len(x_cols)}")
    print(f" Target range: [{y.min():.4f}, {y.max():.4f}]")
    print()

    # Get expressions
    discovered = log_data.get('discovered_expressions', {})
    if not discovered:
        print("No discovered expressions found in log!")
        return

    # Report the number that will really be analyzed, which may be smaller
    # than the requested maximum.
    n_to_analyze = min(args.max_expressions, len(discovered))
    print(f"Found {len(discovered)} expressions in log")
    print(f"Analyzing first {n_to_analyze}...")
    print()

    # Analyze. Only the expression strings (the mapping's keys) are needed;
    # islice avoids copying the whole mapping just to take a prefix.
    error_counts = Counter()
    results = []
    for i, expr_str in enumerate(islice(discovered, n_to_analyze), start=1):
        result = analyze_expression(expr_str, X, y)
        results.append(result)
        error_counts[result["error_type"] or "success"] += 1
        print(f"{i:2d}. {expr_str[:60]:<60} | {result['error_type'] or 'OK':<20} | R²={result['r2']:.4f}")
        if result["error_message"]:
            print(f" > {result['error_message']}")

    # Summary
    print()
    print("="*80)
    print("SUMMARY")
    print("="*80)
    print(f"Total analyzed: {len(results)}")
    print()
    print("Error types:")
    # most_common() orders by descending count and keeps first-seen order
    # for ties.
    for error_type, count in error_counts.most_common():
        pct = 100 * count / len(results)
        print(f" {error_type:<30} {count:3d} ({pct:5.1f}%)")

    # Show a few examples of each error type (successes are skipped).
    print()
    print("Examples by error type:")
    print()
    examples_by_type = {}
    for r in results:
        if r["error_type"]:
            examples_by_type.setdefault(r["error_type"], []).append(r)
    for error_type in sorted(examples_by_type):
        print(f"{error_type}:")
        for ex in examples_by_type[error_type][:3]:
            print(f" - {ex['expression'][:70]}")
            print(f" {ex['error_message']}")
        print()
# Run the analysis only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
|