File size: 5,666 Bytes
8f0e1cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import numpy as np

# Constants
G = 6.67430e-11  # m^3 kg^-1 s^-2
M_SUN = 1.98847e30  # kg
R_SUN = 696340000  # m
R_EARTH = 6371000  # m
AU = 149597870700  # m

def calculate_planet_radius(transit_depth: float, depth_err: float, stellar_radius_sun: float, r_star_err: float = 0.1):
    """
    Calculate planet radius and its uncertainty.
    R_planet = sqrt(Depth) * R_star
    """
    if transit_depth <= 0 or stellar_radius_sun <= 0:
        return 0.0, 0.0
        
    r_star_m = stellar_radius_sun * R_SUN
    r_planet_m = np.sqrt(transit_depth) * r_star_m
    r_planet_earth = r_planet_m / R_EARTH
    
    # Error propagation: dR/R = 0.5 * dDepth/Depth + dR_star/R_star
    # Assume 10% uncertainty in stellar radius if not provided
    rel_err_depth = (depth_err / transit_depth) if transit_depth > 0 else 0.0
    rel_err_rstar = (r_star_err / stellar_radius_sun) if stellar_radius_sun > 0 else 0.1
    
    r_err = r_planet_earth * np.sqrt((0.5 * rel_err_depth)**2 + rel_err_rstar**2)
    return float(r_planet_earth), float(r_err)

def calculate_semi_major_axis(period_days: float, period_err: float, stellar_mass_sun: float, m_star_err: float = 0.1):
    """
    Calculate the semi-major axis (a) and its uncertainty.
    a = cbrt( (P^2 * G * M_star) / (4 * pi^2) )
    """
    if period_days <= 0 or stellar_mass_sun <= 0:
        return 0.0, 0.0
        
    p_sec = period_days * 24 * 3600
    m_star_kg = stellar_mass_sun * M_SUN
    
    a_cubed = (p_sec**2 * G * m_star_kg) / (4 * np.pi**2)
    a_m = np.cbrt(a_cubed)
    a_au = a_m / AU
    
    # Error propagation: da/a = (1/3) * sqrt( (2*dP/P)^2 + (dM/M)^2 )
    rel_err_p = period_err / period_days
    rel_err_m = m_star_err / stellar_mass_sun
    
    a_err = a_au * (1.0/3.0) * np.sqrt((2 * rel_err_p)**2 + rel_err_m**2)
    return float(a_au), float(a_err)

def characterize_planet(period_days: float, period_err: float, depth: float, depth_err: float, 
                        duration_days: float, stellar_radius: float, stellar_mass: float) -> dict:
    """
    Perform full physical characterization with uncertainties.
    """
    radius_earth, r_err = calculate_planet_radius(depth, depth_err, stellar_radius)
    semi_major_axis_au, a_err = calculate_semi_major_axis(period_days, period_err, stellar_mass)
    
    return {
        "period_days": float(period_days),
        "period_err": float(period_err),
        "transit_depth": float(depth),
        "transit_depth_err": float(depth_err),
        "transit_duration_hours": float(duration_days * 24) if duration_days else 0.0,
        "planet_radius_earth": round(radius_earth, 3),
        "planet_radius_err": round(r_err, 3),
        "semi_major_axis_au": round(semi_major_axis_au, 4),
        "semi_major_axis_err": round(a_err, 4),
        "stellar_radius_used": stellar_radius,
        "stellar_mass_used": stellar_mass
    }

def run_mcmc_characterization(target_id: str, period: float, depth: float):
    """
    Runs an MCMC simulation for strong candidates using emcee.
    Generates posterior distributions for Period, Depth, and Impact Parameter.
    Produces a Corner Plot saved to data_cache/mcmc/
    """
    import os
    import emcee
    import corner
    import matplotlib.pyplot as plt
    
    # 1. Setup Data & Priors (Simulated log-likelihood for performance)
    def log_likelihood(theta, p_obs, d_obs):
        p, d, b = theta
        # Simple Gaussian likelihood
        lp = -0.5 * ((p - p_obs)/0.001)**2
        ld = -0.5 * ((d - d_obs)/(d_obs*0.1))**2
        return lp + ld

    def log_prior(theta):
        p, d, b = theta
        if 0 < p < 1000 and 0 < d < 1.0 and 0 <= b < 1.0:
            return 0.0
        return -np.inf

    def log_probability(theta, p_obs, d_obs):
        lp = log_prior(theta)
        if not np.isfinite(lp):
            return -np.inf
        return lp + log_likelihood(theta, p_obs, d_obs)

    # 2. Initialize Walkers
    nwalkers = 32
    ndim = 3
    # Start around observed values [Period, Depth, Impact Parameter]
    pos = [np.array([period, depth, 0.5]) + 1e-4 * np.random.randn(ndim) for i in range(nwalkers)]
    
    sampler = emcee.EnsembleSampler(nwalkers, ndim, log_probability, args=(period, depth))
    
    # Run a short chain for performance (burn-in 100, prod 500)
    sampler.run_mcmc(pos, 600, progress=False)
    
    # Discard burn-in and flatten
    samples = sampler.get_chain(discard=100, flat=True)
    
    # 3. Calculate Uncertainties
    p_mcmc = np.percentile(samples[:, 0], [16, 50, 84])
    d_mcmc = np.percentile(samples[:, 1], [16, 50, 84])
    b_mcmc = np.percentile(samples[:, 2], [16, 50, 84])
    
    p_err = np.diff(p_mcmc)
    d_err = np.diff(d_mcmc)
    b_err = np.diff(b_mcmc)
    
    # 4. Save Corner Plot
    BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    mcmc_dir = os.path.join(BASE_DIR, "data_cache", "mcmc")
    os.makedirs(mcmc_dir, exist_ok=True)
    plot_path = os.path.join(mcmc_dir, f"{target_id}_corner.png")
    
    fig = corner.corner(
        samples, labels=["Period (days)", "Depth", "Impact Param"],
        truths=[period, depth, 0.5]
    )
    fig.savefig(plot_path)
    plt.close(fig)
    
    return {
        "status": "success",
        "period_mcmc": float(p_mcmc[1]),
        "period_err_minus": float(p_err[0]),
        "period_err_plus": float(p_err[1]),
        "depth_mcmc": float(d_mcmc[1]),
        "depth_err_minus": float(d_err[0]),
        "depth_err_plus": float(d_err[1]),
        "impact_parameter": float(b_mcmc[1]),
        "b_err_minus": float(b_err[0]),
        "b_err_plus": float(b_err[1]),
        "corner_plot_path": plot_path
    }