File size: 4,196 Bytes
8f0e1cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import requests
import time
import csv
import os

API_BASE_URL = "http://127.0.0.1:8000/api/v1"
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
CSV_PATH = os.path.join(BASE_DIR, "..", "brain", "9716afda-a559-4f4e-8d1a-078dbb58bad6", "benchmark_results.csv")

TARGETS = [
    {"target_name": "Kepler-10", "mission": "Kepler", "gt_period": 0.837491, "gt_radius": 1.47},
    {"target_name": "Kepler-22", "mission": "Kepler", "gt_period": 289.8623, "gt_radius": 2.38},
    {"target_name": "Kepler-452", "mission": "Kepler", "gt_period": 384.843, "gt_radius": 1.63},
    {"target_name": "Kepler-90", "mission": "Kepler", "gt_period": 7.008151, "gt_radius": 1.31}, # Kepler-90b
    {"target_name": "TRAPPIST-1", "mission": "K2", "gt_period": 1.51087, "gt_radius": 1.116}, # TRAPPIST-1b
    {"target_name": "Kepler-13", "mission": "Kepler", "gt_period": 1.763588, "gt_radius": 16.5} # Eclipsing Binary / Hot Jupiter
]

print("==================================================")
print("EXONYX SCIENTIFIC VALIDATION CAMPAIGN (BENCHMARK)")
print("==================================================")

results = []

for target in TARGETS:
    print(f"\nEvaluating Ground Truth Exoplanet System: {target['target_name']}")
    start = time.time()
    
    payload = {
        "target_name": target['target_name'],
        "mission": target['mission'],
        "dataset_type": "Real"
    }
    
    rec_period = 0.0
    rec_radius = 0.0
    tls_sde = 0.0
    cnn_conf = 0.0
    fp_risk = 0.0
    pli = 0.0
    
    try:
        res = requests.post(f"{API_BASE_URL}/data/load", json=payload, timeout=300) # Increased timeout for larger datasets
        
        if res.status_code == 200:
            data = res.json()
            if data.get("status") == "success":
                val = data.get("validation_summary", {})
                char = data.get("characterization", {})
                pli_dict = data.get("pli", {})
                
                rec_period = char.get('period_days', 0)
                rec_radius = char.get('planet_radius_earth', 0)
                tls_sde = val.get('sde', 0)
                cnn_conf = val.get('cnn_confidence', 0)
                if cnn_conf is None: cnn_conf = 0.0
                fp_risk = pli_dict.get('fp_risk', 0)
                pli = pli_dict.get('score', 0)
                
                print(f"  [PASS] Successfully recovered transits!")
                print(f"  - Planet Likelihood Index (PLI): {pli:.1f}")
                print(f"  - Orbital Period: {rec_period:.4f} days")
                print(f"  - Planet Radius:  {rec_radius:.2f} R_Earth")
            else:
                print(f"  [FAIL] Engine returned error: {data.get('message')}")
        else:
            print(f"  [FAIL] HTTP Error: {res.status_code}")
    except Exception as e:
        print(f"  [FAIL] Request failed: {e}")
        
    runtime = time.time() - start
    print(f"  Elapsed Time: {runtime:.1f}s")
    
    # Calculate errors
    gt_period = target['gt_period']
    gt_radius = target['gt_radius']
    
    period_err_abs = abs(rec_period - gt_period) if rec_period > 0 else 0
    period_err_pct = (period_err_abs / gt_period * 100) if gt_period > 0 and rec_period > 0 else 0
    
    radius_err_abs = abs(rec_radius - gt_radius) if rec_radius > 0 else 0
    radius_err_pct = (radius_err_abs / gt_radius * 100) if gt_radius > 0 and rec_radius > 0 else 0
    
    results.append({
        "Target": target['target_name'],
        "GT_Period": gt_period,
        "Rec_Period": rec_period,
        "Period_Err_Abs": period_err_abs,
        "Period_Err_Pct": period_err_pct,
        "GT_Radius": gt_radius,
        "Rec_Radius": rec_radius,
        "Radius_Err_Abs": radius_err_abs,
        "Radius_Err_Pct": radius_err_pct,
        "TLS_SDE": tls_sde,
        "CNN_Conf": cnn_conf,
        "FP_Risk": fp_risk,
        "PLI": pli,
        "Runtime": runtime
    })

os.makedirs(os.path.dirname(CSV_PATH), exist_ok=True)
with open(CSV_PATH, 'w', newline='') as f:
    writer = csv.DictWriter(f, fieldnames=results[0].keys())
    writer.writeheader()
    writer.writerows(results)

print(f"\nBenchmark completed. Results written to {CSV_PATH}")