Spaces:
Running
Running
File size: 4,567 Bytes
a1eb2dc 1f74b39 a1eb2dc 1f74b39 a1eb2dc 3e8d296 a1eb2dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
import h5py
import numpy as np
import os
import matplotlib.pyplot as plt
import pickle
from plot_styles import apply_physrev_style
from astropy.cosmology import Planck18
from astropy.cosmology import z_at_value
from astropy import units as u
from scipy.stats import gumbel_r
# NOTE: apply_physrev_style is imported above but is not called anywhere in this module.
def get_detection_threshold(normalized, alpha, gumbel=True, list_hyp=False, n_hyp=None):
    """Compute the detection threshold for a given significance level.

    Parameters
    ----------
    normalized : array_like
        2-D array of normalized noise-only statistics. Axis 1 indexes the
        hypotheses (templates); axis 0 presumably indexes noise
        realisations -- confirm against the caller.
    alpha : float
        Significance level (false-alarm probability).
    gumbel : bool, optional
        If True, fit a right-skewed Gumbel distribution (``scipy.stats.gumbel_r``)
        and take its inverse survival function at ``alpha``. If False, use
        empirical quantiles of the data instead.
    list_hyp : bool, optional
        If True, return one threshold per hypothesis (column). If False,
        return a single threshold for the row-wise maximum over hypotheses.
    n_hyp : int, optional
        Number of hypotheses for the Bonferroni correction used in the
        empirical per-hypothesis case (``gumbel=False, list_hyp=True``).
        Defaults to ``normalized.shape[1]``.

    Returns
    -------
    float, list of float, or numpy.ndarray
        A scalar threshold when ``list_hyp`` is False. Otherwise a list
        (Gumbel case) or an array (empirical case) with one threshold per
        column.
    """
    normalized = np.asarray(normalized)

    if gumbel:
        if list_hyp:
            # Independent Gumbel fit per hypothesis column; no multiple-testing
            # correction is applied in this branch.
            return [gumbel_r(*gumbel_r.fit(col)).isf(alpha) for col in normalized.T]
        return gumbel_r(*gumbel_r.fit(np.max(normalized, axis=1))).isf(alpha)

    if list_hyp:
        # Bonferroni-corrected per-hypothesis quantile. The hypothesis count
        # previously came from an out-of-scope name (NameError); default it to
        # the number of columns.
        if n_hyp is None:
            n_hyp = normalized.shape[1]
        return np.quantile(normalized, 1 - alpha / n_hyp, axis=0)

    return np.quantile(np.max(normalized, axis=1), 1 - alpha)
# Interactive scatter plot of detected sources: mass vs. redshift or luminosity distance.
def _load_noise_normalization(noise_file):
    """Load noise-only losses from *noise_file* and return ``(normalized, mean, std)``.

    ``mean`` and ``std`` are taken over axis 0. They are returned so that
    injection losses can be normalized on the same scale as the noise.

    Raises
    ------
    FileNotFoundError
        If *noise_file* does not exist.
    """
    if not os.path.exists(noise_file):
        raise FileNotFoundError(f"Noise file {noise_file} not found.")
    with h5py.File(noise_file, 'r') as f:
        losses = f['all_best_losses_noise'][()]
    mean = losses.mean(axis=0)
    std = losses.std(axis=0)
    return (losses - mean) / std, mean, std


def _load_cached_results(snrs):
    """Concatenate the cached ``(results, snr_values)`` pickles for each SNR in *snrs*.

    Missing cache files are skipped. Returns ``(results, snr_values, n_loaded)``,
    where ``n_loaded`` counts the cache files actually read.
    """
    results, snr_values, n_loaded = [], [], 0
    for snr in snrs:
        cache_file = f"paper_scatter_cache_{snr}.pkl"
        if not os.path.exists(cache_file):
            continue
        # NOTE: pickle.load can execute arbitrary code; only use this with
        # trusted, locally generated cache files.
        with open(cache_file, "rb") as f:
            results_, snr_ = pickle.load(f)
        results.extend(results_)
        snr_values.extend(snr_)
        n_loaded += 1
    return results, np.array(snr_values), n_loaded


# Colorbar choices keyed by the app.py dropdown labels. Each value is
# (field name in a result record, colorbar label, colormap).
_COLORBAR_BY_LABEL = {
    "Final eccentricity": ("ef", r'Final Eccentricity $e_f$', 'plasma'),
    "Initial eccentricity": ("e0", r'Initial Eccentricity $e_0$', 'cividis'),
    "Primary mass": ("m1", r'Primary Mass [M$_\odot$]', 'viridis'),
    "Secondary mass": ("m2", r'Secondary Mass [M$_\odot$]', 'viridis'),
}
# Also accept the short field codes ('ef', 'e0', 'm1', 'm2').
_COLORBAR_OPTIONS = {
    **_COLORBAR_BY_LABEL,
    **{opt[0]: opt for opt in _COLORBAR_BY_LABEL.values()},
}


def plot_mass_vs_distance_or_redshift(
        snrs=(30,), alpha=1e-4, y_axis="Redshift", x_axis="Primary Mass", colorbar_var="ef"):
    """Scatter plot of detected sources: source-frame mass vs. distance or redshift.

    Parameters
    ----------
    snrs : iterable of int or float, optional
        SNRs whose cached results (``paper_scatter_cache_<snr>.pkl``) are included.
    alpha : float, optional
        False-alarm rate used to set the detection threshold.
    y_axis : str, optional
        ``'Redshift'``. Any other value plots luminosity distance in Gpc.
    x_axis : str, optional
        ``'Primary Mass'``. Any other value plots the secondary mass.
    colorbar_var : str, optional
        Colour variable: a short code (``'ef'``, ``'e0'``, ``'m1'``, ``'m2'``)
        or an app dropdown label (e.g. ``'Final eccentricity'``). Unknown
        values fall back to the final eccentricity.

    Returns
    -------
    matplotlib.figure.Figure
        The new figure. The caller is responsible for closing it.

    Raises
    ------
    FileNotFoundError
        If the noise file is missing, or if none of the requested SNR cache
        files exist.
    """
    _, mean_noise, std_noise = _load_noise_normalization("paper_results_tdi.h5")
    normalized = _load_noise_normalization.__wrapped__ if False else None  # placeholder never used
    normalized, mean_noise, std_noise = _load_noise_normalization("paper_results_tdi.h5")
    results, snr_values, n_loaded = _load_cached_results(snrs)
    if n_loaded == 0:
        raise FileNotFoundError(
            f"No cache files paper_scatter_cache_<snr>.pkl found for SNRs {list(snrs)}."
        )

    detection_threshold = get_detection_threshold(normalized, alpha)
    # Peak normalized loss per injection. Building it with a float dtype keeps
    # the comparison below a bool array even when `results` is empty.
    peak_stats = np.array(
        [np.max((r['losses'] - mean_noise) / std_noise) for r in results], dtype=float
    )
    det_mask = np.isin(snr_values, snrs) & (peak_stats > detection_threshold)

    fields = {
        key: np.array([r[key] for r in results])[det_mask]
        for key in ("m1", "m2", "dist", "e0", "ef")
    }
    distances = fields["dist"]
    redshifts = np.array(
        [z_at_value(Planck18.luminosity_distance, d * u.Gpc) for d in distances]
    )
    # Divide by (1 + z) to get source-frame masses (the stored masses are
    # presumably detector-frame; confirm against the cache producer).
    fields["m1"] = fields["m1"] / (1 + redshifts)
    fields["m2"] = fields["m2"] / (1 + redshifts)

    key, color_label, cmap = _COLORBAR_OPTIONS.get(colorbar_var, _COLORBAR_OPTIONS["ef"])
    color_data = fields[key]

    if x_axis == "Primary Mass":
        x, xlabel = fields["m1"], r'Source frame primary mass [M$_\odot$]'
    else:
        x, xlabel = fields["m2"], r'Source frame secondary mass [M$_\odot$]'
    if y_axis == "Redshift":
        y, ylabel = redshifts, 'Redshift'
    else:
        y, ylabel = distances, 'Luminosity Distance [Gpc]'

    fig, ax = plt.subplots(figsize=(7, 5))
    scatter = ax.scatter(x, y, c=color_data, cmap=cmap, alpha=0.7, marker='o')
    # Use the figure's own methods rather than pyplot's global "current
    # figure", which may differ when several figures are live.
    cbar = fig.colorbar(scatter, ax=ax)
    cbar.set_label(color_label)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    return fig