File size: 4,567 Bytes
a1eb2dc
 
 
 
 
 
 
 
 
1f74b39
a1eb2dc
 
1f74b39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1eb2dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e8d296
 
 
 
 
 
 
 
a1eb2dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import h5py
import numpy as np
import os
import matplotlib.pyplot as plt
import pickle
from plot_styles import apply_physrev_style
from astropy.cosmology import Planck18
from astropy.cosmology import z_at_value
from astropy import units as u
from scipy.stats import gumbel_r
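# NOTE: apply_physrev_style is imported above but never called in this module;
# call it here if the shared PhysRev plot style should be applied.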

def get_detection_threshold(normalized, alpha, gumbel=True, list_hyp=False, n_hypotheses=None):
    """Compute the detection threshold(s) for a significance level ``alpha``.

    Parameters
    ----------
    normalized : array_like
        2-D array of standardized noise-only statistics. The maximum over
        axis 1 (one value per row) is used for the single-threshold case; in
        the ``list_hyp`` case every column (axis 1 entry) is treated as a
        separate hypothesis.
    alpha : float
        Target false-alarm probability.
    gumbel : bool, default True
        If True, fit ``scipy.stats.gumbel_r`` to the statistic and return its
        inverse survival function at ``alpha``. If False, use empirical
        quantiles instead.
    list_hyp : bool, default False
        If True, return one threshold per column instead of a single
        threshold for the row-wise maximum.
    n_hypotheses : int, optional
        Number of hypotheses for the Bonferroni correction applied in the
        empirical ``list_hyp`` case (quantile ``1 - alpha / n_hypotheses``).
        Defaults to ``normalized.shape[1]``. It is presumably the number of
        templates; verify against callers. The Gumbel ``list_hyp`` branch
        does not apply this correction.

    Returns
    -------
    float or list of float or np.ndarray
        A scalar threshold when ``list_hyp`` is False. Otherwise a list
        (Gumbel branch) or 1-D array (empirical branch) with one threshold
        per column.

    Raises
    ------
    ValueError
        If ``n_hypotheses`` is smaller than 1.
    """
    normalized = np.asarray(normalized)
    if gumbel:
        if list_hyp:
            # Fit an independent Gumbel distribution to each column.
            return [gumbel_r(*gumbel_r.fit(column)).isf(alpha) for column in normalized.T]
        # Fit one Gumbel distribution to the per-row maximum.
        return gumbel_r(*gumbel_r.fit(np.max(normalized, axis=1))).isf(alpha)

    if list_hyp:
        # Previously this used an undefined global `tpl_vector` (NameError);
        # the hypothesis count is now taken from the data unless overridden.
        if n_hypotheses is None:
            n_hypotheses = normalized.shape[1]
        if n_hypotheses < 1:
            raise ValueError(f"n_hypotheses must be >= 1, got {n_hypotheses}")
        # Bonferroni-corrected empirical quantile, one per column.
        return np.quantile(normalized, 1 - alpha / n_hypotheses, axis=0)
    return np.quantile(np.max(normalized, axis=1), 1 - alpha)

# New function for interactive plotting
def plot_mass_vs_distance_or_redshift(
    snrs=(30,), alpha=1e-4, y_axis="Redshift", x_axis="Primary Mass", colorbar_var="ef"):
    """Scatter-plot detected sources: source-frame mass vs. redshift or distance.

    Noise-only losses are read from ``paper_results_tdi.h5`` and used to
    standardize losses and to derive a detection threshold via
    ``get_detection_threshold``. Injected-signal results are read from the
    pickle caches ``paper_scatter_cache_{snr}.pkl``. Only results whose peak
    standardized loss exceeds the threshold are plotted.

    Parameters
    ----------
    snrs : iterable of numbers, default (30,)
        SNR values whose cache files are loaded and kept. Missing cache files
        are skipped.
    alpha : float, default 1e-4
        False-alarm probability passed to ``get_detection_threshold``.
    y_axis : str, default "Redshift"
        ``"Redshift"`` plots redshift. Any other value plots luminosity
        distance in Gpc.
    x_axis : str, default "Primary Mass"
        ``"Primary Mass"`` plots the primary mass. Any other value plots the
        secondary mass. Masses are divided by ``1 + z`` (source frame).
    colorbar_var : str, default "ef"
        Colour quantity. Accepts the short codes ``"ef"``, ``"e0"``, ``"m1"``,
        ``"m2"`` or the app.py dropdown labels ``"Final eccentricity"``,
        ``"Initial eccentricity"``, ``"Primary mass"``, ``"Secondary mass"``.
        Unknown values fall back to the final eccentricity.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure. It is not closed here.

    Raises
    ------
    FileNotFoundError
        If the noise file does not exist.
    """
    # Materialize once: `snrs` is iterated for loading and reused for masking.
    snrs = list(snrs)

    noise_file = "paper_results_tdi.h5"
    if not os.path.exists(noise_file):
        raise FileNotFoundError(f"Noise file {noise_file} not found.")
    with h5py.File(noise_file, 'r') as f:
        all_best_losses_noise = f['all_best_losses_noise'][()]
    # Column-wise noise statistics, used to standardize both noise and signal losses.
    mean_noise = all_best_losses_noise.mean(axis=0)
    std_noise = all_best_losses_noise.std(axis=0)
    normalized = (all_best_losses_noise - mean_noise) / std_noise

    results_detection = []
    snr_values = []
    for snr in snrs:
        cache_file = f"paper_scatter_cache_{snr}.pkl"
        if not os.path.exists(cache_file):
            continue
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # cache files from a trusted source.
        with open(cache_file, "rb") as f:
            results_, snr_ = pickle.load(f)
        results_detection.extend(results_)
        snr_values.extend(snr_)
    snr_values = np.asarray(snr_values)

    detection_threshold = get_detection_threshold(normalized, alpha)
    # Peak standardized loss per injected result, computed once.
    peak_scores = np.array(
        [np.max((r['losses'] - mean_noise) / std_noise) for r in results_detection],
        dtype=float,
    )
    # The comparison always yields a bool array, even when nothing was loaded.
    # An empty float array here would make the `&` / `~` masks raise TypeError.
    detected = peak_scores > detection_threshold

    m1_values = np.array([r['m1'] for r in results_detection])
    m2_values = np.array([r['m2'] for r in results_detection])
    distances = np.array([r['dist'] for r in results_detection])
    e0_values = np.array([r['e0'] for r in results_detection])
    ef_values = np.array([r['ef'] for r in results_detection])

    # Assumes each cache holds one SNR entry per result (parallel lists) — TODO confirm.
    det_mask = np.isin(snr_values, snrs) & detected

    filtered_distances = distances[det_mask]
    # Distances are interpreted as luminosity distances in Gpc.
    z_values = np.array([z_at_value(Planck18.luminosity_distance, d * u.Gpc) for d in filtered_distances])
    # Convert (detector-frame) masses to the source frame.
    filtered_m1 = m1_values[det_mask] / (1 + z_values)
    filtered_m2 = m2_values[det_mask] / (1 + z_values)
    filtered_ef = ef_values[det_mask]
    filtered_e0 = e0_values[det_mask]

    colorbar_options = {
        "ef": (filtered_ef, r'Final Eccentricity $e_f$', 'plasma'),
        "e0": (filtered_e0, r'Initial Eccentricity $e_0$', 'cividis'),
        "m1": (filtered_m1, r'Primary Mass [M$_\odot$]', 'viridis'),
        "m2": (filtered_m2, r'Secondary Mass [M$_\odot$]', 'viridis'),
    }
    # app.py dropdown labels mapped onto the documented short codes.
    dropdown_aliases = {
        "Final eccentricity": "ef",
        "Initial eccentricity": "e0",
        "Primary mass": "m1",
        "Secondary mass": "m2",
    }
    colorbar_key = dropdown_aliases.get(colorbar_var, colorbar_var)
    color_data, color_label, cmap = colorbar_options.get(colorbar_key, colorbar_options["ef"])

    if x_axis == "Primary Mass":
        x = filtered_m1
        xlabel = r'Source frame primary mass [M$_\odot$]'
    else:
        x = filtered_m2
        xlabel = r'Source frame secondary mass [M$_\odot$]'

    if y_axis == "Redshift":
        y = z_values
        ylabel = 'Redshift'
    else:
        y = filtered_distances
        ylabel = 'Luminosity Distance [Gpc]'

    fig, ax = plt.subplots(figsize=(7, 5))
    scatter = ax.scatter(x, y, c=color_data, cmap=cmap, alpha=0.7, marker='o')
    # Figure-level API avoids depending on pyplot's "current figure" state.
    cbar = fig.colorbar(scatter, ax=ax)
    cbar.set_label(color_label)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    return fig