lorenzsp commited on
Commit
1f74b39
·
1 Parent(s): 82556e0

fix requirements

Browse files
plot_detection_probability.py DELETED
@@ -1,457 +0,0 @@
1
- import h5py
2
- import numpy as np
3
- import os
4
- import argparse
5
- import matplotlib.pyplot as plt
6
- import glob
7
- from scipy.stats import gumbel_r
8
- from mpl_toolkits.axes_grid1.inset_locator import inset_axes, mark_inset
9
- import pickle
10
- import matplotlib.patches as patches
11
- from sklearn.metrics import roc_curve, auc
12
- from astropy.cosmology import Planck18, z_at_value
13
- import astropy.units as u
14
- import matplotlib.lines as mlines
15
- import matplotlib.patches as mpatches
16
- from pastamarkers import pasta, salsa
17
- MTSUN_SI, YRSID_SI = 4.9254909476412675e-06, 31558149.763545595
18
- from plot_styles import apply_physrev_style
19
-
20
- # Apply the style
21
- apply_physrev_style()
22
-
23
- def chirpmass_from_f_fdot_few(f, fdot):
24
- """
25
- Calculate chirp mass using few constants.
26
- Returns chirp mass in solar masses.
27
- """
28
- M_chirp = ((10**f)**(-11) * (10**fdot)**3 * np.pi**(-8) * (5/96)**3)**(1/5) / MTSUN_SI
29
- return M_chirp
30
-
31
- def get_detection_threshold(normalized, alpha, gumbel=True, list_hyp=False):
32
- """Compute detection threshold for given significance level alpha."""
33
- if gumbel:
34
- if list_hyp:
35
- detection_threshold = [gumbel_r(*gumbel_r.fit(el)).isf(alpha) for el in normalized.T]
36
- else:
37
- detection_threshold = gumbel_r(*gumbel_r.fit(np.max(normalized, axis=1))).isf(alpha)
38
- else:
39
-
40
- if list_hyp:
41
- detection_threshold = np.quantile(normalized, 1-alpha/len(tpl_vector), axis=0)
42
- else:
43
- detection_threshold = np.quantile(np.max(normalized, axis=1), 1-alpha)
44
- return detection_threshold
45
-
46
- def compute_detection_probability(results, values, detection_threshold):
47
- """Compute detection probability for each unique value at given significance level."""
48
- unique_values = np.unique(values)
49
- unique_values = unique_values[unique_values > 1.0] # only above 1
50
- detection_probs = []
51
- detection_std_probs = []
52
-
53
- print(f"Quantile for detection: {detection_threshold}")
54
- for val in unique_values:
55
- mask = np.isclose(values, val, rtol=1e-3)
56
- detections = []
57
- for r in np.array(results)[mask]:
58
- # Compute det_stat for each result
59
- det_stat = (r['losses'] - mean_noise)/std_noise
60
- # to use any
61
- # detected = np.any(det_stat > detection_threshold)
62
- det_stat = np.max(det_stat) # use the max statistic across templates
63
- detected = det_stat > detection_threshold
64
- detections.append(detected)
65
- prob = np.mean(detections) if len(detections) > 0 else 0.0
66
- # Bernoulli standard deviation
67
- std_prob = np.sqrt(prob * (1 - prob) / len(detections)) if len(detections) > 0 else 0.0
68
- detection_probs.append(prob)
69
- detection_std_probs.append(std_prob)
70
- print(f"Detection probability for {val}: {prob} ± {std_prob}")
71
- return unique_values, np.asarray(detection_probs), np.asarray(detection_std_probs)
72
-
73
- def compute_accuracy(results, values, detection_threshold):
74
- """Compute median relative frequency error for each unique value."""
75
- unique_values = np.unique(values)
76
- unique_values = unique_values[unique_values > 1.0] # only above 1
77
- acc_medians = []
78
- acc_err_low = []
79
- acc_err_high = []
80
-
81
- for val in unique_values:
82
- mask = np.isclose(values, val, rtol=1e-3)
83
- accs = []
84
- for r in np.array(results)[mask]:
85
- # check if detected
86
- # detected = np.max((r['losses'] - mean_noise)/std_noise) > detection_threshold
87
- detected = np.any((r['losses'] - mean_noise)/std_noise > detection_threshold)
88
- if detected:
89
- # Find best tpl index based on max loss
90
- best_idx = np.argmax((r['losses'] - mean_noise)/std_noise)
91
- acc = r['rel_diff_medians'][best_idx]
92
- accs.append(acc)
93
- if len(accs) > 0:
94
- med = np.median(accs)
95
- low = np.percentile(accs, 16)
96
- high = np.percentile(accs, 84)
97
- else:
98
- med = low = high = np.nan
99
- acc_medians.append(med)
100
- acc_err_low.append(med - low)
101
- acc_err_high.append(high - med)
102
- return unique_values, np.array(acc_medians), np.array(acc_err_low), np.array(acc_err_high)
103
-
104
- if __name__ == "__main__":
105
- from theoretical_pdet import detection_probability
106
- from theoretical_pdet import detection_threshold as detection_threshold_func
107
-
108
- from plot_best_results import load_best_results
109
-
110
- parser = argparse.ArgumentParser(description="Plot detection probability versus SNR and scatter plot for Tpl vs ef.")
111
- args = parser.parse_args()
112
-
113
- # Load noise distribution
114
- save_path = 'paper_results_tdi.h5'
115
- if not os.path.exists(save_path):
116
- print(f"Error: {save_path} not found.")
117
- exit(1)
118
- print(f"Loading aggregated results from {save_path}.")
119
- with h5py.File(save_path, 'r') as f:
120
- all_best_losses_noise = f['all_best_losses_noise'][()]
121
- tpl_vector = f['tpl_vector'][()]
122
- best_fs = f['best_fs'][()]
123
- noise_f = np.mean(best_fs[:,:,:5],axis=-1)
124
- noise_fdot = np.mean(np.gradient(best_fs,5e4,axis=-1)[:,:,:5],axis=-1)
125
-
126
-
127
- mean_noise = all_best_losses_noise.mean(axis=0)
128
- std_noise = all_best_losses_noise.std(axis=0)
129
- normalized = (all_best_losses_noise - mean_noise) / std_noise
130
-
131
- plt.figure()
132
- # Get colormap and normalize for number of templates
133
- cmap = plt.get_cmap('viridis')
134
- norm = plt.Normalize(0, normalized.shape[1]-1)
135
-
136
- # Plot histograms and fitted Gumbel distributions for each template
137
- for ii in range(normalized.shape[1]):
138
- color = cmap(norm(ii))
139
- # Plot histogram
140
- # plt.hist(normalized[:,ii], bins=50, density=True, alpha=0.6, label=f'Template {ii+1}')
141
-
142
- # Fit Gumbel distribution and plot
143
- params = gumbel_r.fit(normalized[:,ii])
144
- x_range = np.linspace(normalized[:,ii].min(), normalized[:,ii].max(), 100)
145
- plt.plot(x_range, gumbel_r.pdf(x_range, *params), '-', linewidth=2, color=color, alpha=0.7)
146
- plt.semilogy()
147
- plt.xlabel('Normalized Statistic')
148
- plt.ylabel('Density')
149
- plt.tight_layout()
150
- plt.savefig('gumbel_fits.pdf')
151
-
152
- # Try to load cached results if available, otherwise process and save
153
-
154
- cache_file = "paper_detection_cache.pkl"
155
- if os.path.exists(cache_file):
156
- print(f"Loading cached results from {cache_file}")
157
- with open(cache_file, "rb") as f:
158
- results_detection, snr_values = pickle.load(f)
159
- else:
160
- # Process signal-injected realizations
161
- results_detection = []
162
- snr_values = []
163
- dirs = glob.glob("../apaper_results/*")
164
-
165
- if not dirs:
166
- print("Error: No signal-injected realization directories found.")
167
- exit(1)
168
-
169
- for output_dir in dirs:
170
- h5_file = os.path.join(output_dir, 'best_results.h5')
171
- if not os.path.exists(h5_file):
172
- print(f"Warning: {h5_file} not found. Skipping.")
173
- continue
174
-
175
- current_tpl, current_losses, param_dict, best_fs = load_best_results(h5_file)
176
- with h5py.File(h5_file, 'r') as f:
177
- snr = f['SNR'][()]
178
- rel_diff_medians = f['rel_diff_stats/medians'][()] # Load medians from the group
179
-
180
- if not np.array_equal(tpl_vector, current_tpl):
181
- print(f"Warning: Tpl_vector mismatch in {output_dir}. Skipping.")
182
- continue
183
-
184
- f0 = best_fs[:,:5].mean(axis=-1)
185
- fdot0 = np.gradient(best_fs, 5e4, axis=-1)[:,:5].mean(axis=-1)
186
- param_dict["f0"] = f0
187
- param_dict["fdot0"] = fdot0
188
- param_dict["losses"] = -current_losses
189
- param_dict["snr"] = float(snr)
190
- param_dict["rel_diff_medians"] = rel_diff_medians # Add rel_diff_medians to param_dict
191
- results_detection.append(param_dict)
192
- snr_values.append(float(snr))
193
- print(f"Processed {h5_file}: SNR={snr}")
194
-
195
- if not results_detection:
196
- print("Error: No valid signal-injected realizations processed.")
197
- exit(1)
198
-
199
- # Save to cache for future runs
200
- with open(cache_file, "wb") as f:
201
- pickle.dump((results_detection, snr_values), f)
202
- print(f"Saved processed results to {cache_file}")
203
-
204
- # Convert to arrays
205
- snr_values = np.array(snr_values)
206
-
207
- # Plot detection probability vs SNR for different significance levels
208
- # Create a figure with two subplots sharing the x-axis
209
- fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=(6, 6), gridspec_kw={'height_ratios': [1.5, 1]})
210
-
211
- # Detection probability vs SNR for different significance levels
212
- alpha_values = [0.5, 0.01, 0.0001] # Different significance levels
213
- colors = ['C0', 'C1', 'C2'] # Colors for different alpha values
214
- linestyles = ['-', '--', ':'] # Different line styles
215
- markers = ['o', 's', 'D'] # Different markers
216
- labels = [r'$p_{\rm FA}=0.5$', r'$p_{\rm FA}=10^{-2}$', r'$p_{\rm FA}=10^{-4}$']
217
- ms = 6 # Marker size
218
-
219
- mu0 = 632 * 2.
220
- sigma0 = np.sqrt(2 * 632 * 1.45**2)
221
-
222
- for alpha, color, ls, marker, lb in zip(alpha_values, colors, linestyles, markers, labels):
223
- detection_threshold = get_detection_threshold(normalized, alpha, gumbel=True)
224
- print(f"Detection threshold for alpha={alpha}: {detection_threshold}")
225
- unique_snrs, detection_probs_snr, detection_std_probs_snr = compute_detection_probability(results_detection, snr_values, detection_threshold)
226
- ax1.errorbar(unique_snrs, detection_probs_snr, yerr=detection_std_probs_snr, fmt=marker, color=color, linestyle=ls, alpha=0.7, capsize=3, ms=ms, label=lb, lw=2.)
227
-
228
- # Accuracy plot: Median relative frequency error vs SNR
229
- unique_snrs, acc_med, acc_err_l, acc_err_h = compute_accuracy(results_detection, snr_values, detection_threshold)
230
- ax2.errorbar(unique_snrs, acc_med, yerr=[acc_err_l, acc_err_h], fmt=marker, color=color, linestyle=ls, alpha=0.7, capsize=3, ms=ms, label=lb, lw=2.)
231
- print("Detection threshold:", detection_threshold, "Alpha:", alpha, "acc_med", acc_med)
232
-
233
- P_D_values = []
234
- A_test_values = np.linspace(20, 40, 100)
235
- snr_mismatch = (1.-0.0)**0.5 # assuming average mismatch of 0.5
236
- print("SNR mismatch factor:", snr_mismatch)
237
- N = 632
238
- pf_per_template = 1e-2 / 1e25
239
- PD_curve = np.array([detection_probability(A * snr_mismatch, 632, pf_per_template) for A in A_test_values])
240
- ax1.plot(A_test_values, PD_curve, '--', color='C5', lw=1.5, alpha=0.7, zorder=10)
241
-
242
- ax1.set_ylabel('Detection Probability')
243
- ax1.grid(True, axis='y')
244
- ax1.legend(title='False Alarm Probability', loc='lower right')
245
- ax1.set_ylim(-0.05, 1.05)
246
-
247
- idx_50 = np.argmin(np.abs(PD_curve - 0.02))
248
- x_annotate = A_test_values[idx_50]
249
- y_annotate = PD_curve[idx_50]
250
-
251
- ax1.annotate('Theoretical \nTemplate Bank\n $p_{\\mathrm{FA}}=10^{-2}$',
252
- xy=(x_annotate, y_annotate),
253
- xytext=(x_annotate + 2, y_annotate + 0.01),
254
- # arrowprops=dict(arrowstyle='->', color=colors[1], lw=1.5),
255
- fontsize=10,
256
- color='C5')
257
-
258
- ax2.set_xlabel('SNR')
259
- ax2.set_ylabel('Relative Frequency Error')
260
- ax2.set_yscale('log')
261
- # ax2.legend()
262
- # ax2.set_ylim(3e-4, 1e-2)
263
- ax2.grid(True)
264
-
265
- plt.tight_layout()
266
- plt.savefig('detection_and_accuracy_vs_snr.pdf')
267
- plt.close('all')
268
-
269
- ####################################################
270
- # Scatter plot Tpl vs ef, colored by SNR, markers by detection
271
- expected_snrs = [25, 30, 35]
272
-
273
- cmap = plt.get_cmap('inferno')
274
- cmap = plt.get_cmap(salsa.pesto)
275
-
276
- alpha = 0.001 # Use default alpha for scatter plot
277
- detection_threshold = get_detection_threshold(normalized, alpha)
278
- quantile_detection = np.quantile(normalized, 1-alpha/len(tpl_vector), axis=0)
279
- norm_ds = np.array([np.max((r['losses'] - mean_noise)/std_noise) for r in results_detection])
280
- detected = np.array([np.any(np.max((r['losses'] - mean_noise)/std_noise) > detection_threshold) for r in results_detection])
281
-
282
- m1_values = np.array([r['m1'] for r in results_detection])
283
- m2_values = np.array([r['m2'] for r in results_detection])
284
- tpl_values = np.array([r['Tpl'] for r in results_detection])
285
- ef_values = np.array([r['e0'] for r in results_detection])
286
- dist_values = np.array([r['dist'] for r in results_detection])
287
- f0_values = np.array([r['f0'][np.argmax(np.max((r['losses'] - mean_noise)/std_noise))] for r in results_detection])
288
- fdot0_values = np.array([r['fdot0'][np.argmax(np.max((r['losses'] - mean_noise)/std_noise))] for r in results_detection])
289
-
290
- M_chirp_values = chirpmass_from_f_fdot_few(f0_values, fdot0_values)
291
- M_chirp_noise = chirpmass_from_f_fdot_few(noise_f, noise_fdot)
292
-
293
- norm = plt.Normalize(1.0, 100.0)
294
-
295
-
296
- fig, ax = plt.subplots()
297
- cbar = plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, label=r'$m_2$ [$M_\odot$]')
298
- for snr, mark in zip(expected_snrs, [pasta.penne, pasta.rigatoni, pasta.farfalle]):
299
- mask = np.isclose(snr_values, snr, rtol=1e-3)
300
- det_mask = mask & detected
301
- if np.any(det_mask):
302
- z_values = np.array([z_at_value(Planck18.luminosity_distance, d*u.Gpc) for d in dist_values[det_mask]])
303
- color_list = [cmap(norm(el)) for el in m2_values[det_mask]/(1+z_values)]
304
- plt.scatter(m1_values[det_mask]/(1+z_values), z_values, marker=mark, c=color_list, alpha=0.7)
305
- plt.semilogx()
306
- # plt.scatter(m1_values[det_mask]/(1+z_values), dist_values[det_mask], marker=mark, c=color_list, alpha=0.7)
307
-
308
-
309
- detected_marker = mlines.Line2D([], [], color='k', marker='o', linestyle='None', markersize=7, label='Detected')
310
- not_detected_marker = mlines.Line2D([], [], color='k', marker='x', linestyle='None', markersize=7, label='Not detected')
311
-
312
- # Create color patches for SNR legend
313
- snr_legend = [mlines.Line2D([], [], color='k', marker=mark, linestyle='None', markersize=7, label=f'SNR={snr}') for snr, mark in zip(expected_snrs, [pasta.penne, pasta.rigatoni, pasta.farfalle])]
314
- # Combine legends
315
- plt.legend(handles=snr_legend, loc='best')
316
-
317
- plt.xlabel(r'$m_1$ [$M_\odot$]')
318
- plt.ylabel(r'Redshift $z$')
319
- plt.tight_layout()
320
- plt.savefig('scatter.pdf')
321
-
322
- ###############################################
323
- # Build labels and scores for ROC
324
-
325
- labels = []
326
- scores = []
327
- argmax_scores = []
328
- print("Building labels and scores for ROC curve...")#, snr_values)
329
- dict_labels = {snr: [] for snr in snr_values}
330
-
331
- # Noise-only trials
332
- for noise_row in all_best_losses_noise:
333
- det_stat = (noise_row - mean_noise) / std_noise
334
- score = np.max(det_stat) # use the max statistic across time
335
- argmax_scores.append(np.argmax(det_stat))
336
- labels.append(0)
337
- scores.append(score)
338
-
339
- # show argmax histogram
340
- plt.figure()
341
- plt.hist(argmax_scores, bins=len(tpl_vector), density=True, alpha=0.7)
342
- plt.xlabel('Template Index of Max Score (Noise)')
343
- plt.ylabel('Density')
344
- plt.grid(True)
345
- plt.tight_layout()
346
- plt.savefig('argmax_histogram_noise.pdf')
347
-
348
-
349
- # Signal+noise trials
350
- for r in results_detection:
351
- det_stat = (r['losses'] - mean_noise) / std_noise
352
- score = np.max(det_stat)
353
- # define label as the SNR value
354
- labels.append(r['snr'])
355
- scores.append(score)
356
-
357
- labels = np.array(labels)
358
- scores = np.array(scores)
359
-
360
- # histogram of scores
361
- plt.figure()
362
- # Define log-spaced bins for better visualization
363
- min_score = np.min(np.array(scores)[np.array(labels)==0])
364
- max_score = np.max(np.array(scores)[np.array(labels)==0])
365
- log_bins = np.logspace(np.log10(min_score), np.log10(max_score), 10)
366
- plt.hist(np.array(scores)[np.array(labels)==0], bins=log_bins, label='Noise', alpha=0.7, density=True)
367
- linestyles = ['-', '--', '-.', ':', (0, (3, 1, 1, 1))]
368
- for i, snr in enumerate([20, 30, 40]):
369
- min_score = np.min(np.array(scores)[np.array(labels)==snr])
370
- max_score = np.max(np.array(scores)[np.array(labels)==snr])
371
- log_bins = np.logspace(np.log10(min_score), np.log10(max_score), 10)
372
- plt.hist(np.array(scores)[np.array(labels)==snr], bins=log_bins, label=f'SNR$=${snr}', alpha=1.0, density=True, histtype='step', linewidth=2.5, linestyle=linestyles[i])
373
- plt.semilogx()
374
- plt.semilogy()
375
- plt.axvline(get_detection_threshold(normalized, 0.5), color='grey', linestyle='-', label=r'$p_{\rm FA}=0.5$')
376
- plt.axvline(get_detection_threshold(normalized, 0.01), color='grey', linestyle='--', label=r'$p_{\rm FA}=10^{-2}$')
377
- plt.axvline(get_detection_threshold(normalized, 0.0001), color='grey', linestyle=':', label=r'$p_{\rm FA}=10^{-4}$')
378
-
379
- # Add text annotations next to the vertical lines
380
- from max_of_distribution import compute_max_stats
381
- MU_K = 2.0
382
- SIGMA_K = 1.45
383
- N = int(YRSID_SI / 5e4)
384
- Nopt = 500 * 512
385
- mu0 = N * MU_K
386
- sigma0 = np.sqrt(2 * N * SIGMA_K**2)
387
- results_per_seg = compute_max_stats(mu0, sigma0, Nopt, method='asymptotic')
388
- print("noise approx", results_per_seg['mean'], results_per_seg['variance']**0.5, "approx", mean_noise[-1], std_noise[-1])
389
- print("Relative difference", (results_per_seg['mean'] - 2 * mean_noise[-1]) / (2 * mean_noise[-1]))
390
-
391
- for A in [1e-5, 20, 30, 40]:
392
- snr = A
393
- mu1 = N * MU_K + A**2
394
- sigma1 = np.sqrt(2 * N * SIGMA_K**2 + 4 * A**2)
395
- results_per_seg = compute_max_stats(mu1, sigma1, Nopt, method='asymptotic')
396
- max_last_seg = np.asarray([r['losses'][-1] for r in results_detection if r['snr'] == snr])
397
- print(f"SNR={snr}")
398
- print("Signal+noise approx", results_per_seg['mean'], results_per_seg['variance']**0.5, 2 * np.mean(max_last_seg), np.std(max_last_seg))
399
- print("Relative difference", (results_per_seg['mean'] - 2 * np.mean(max_last_seg)) / (2 * np.mean(max_last_seg)), (results_per_seg['variance']**0.5 - np.std(max_last_seg)) / np.std(max_last_seg))
400
-
401
- # plot fit
402
- log_bins = np.logspace(-1, 3, 50)
403
- print(gumbel_r.fit(np.max(normalized, axis=1)))
404
- gumb = gumbel_r(*gumbel_r.fit(np.max(normalized, axis=1))).pdf(log_bins)
405
- plt.plot(log_bins, gumb, '-', linewidth=2, color='C0', alpha=0.7)
406
-
407
- plt.xlabel(r'Normalized Statistic $\mathcal{S}$ ')
408
- plt.ylabel('Density')
409
- plt.legend(ncol=1)
410
- # plt.title('Histogram of Detection Statistic Scores')
411
- plt.ylim(1e-4, 10)
412
- plt.xlim(0.6, 2000)
413
- # plt.grid(True)
414
- plt.tight_layout()
415
- plt.savefig('score_histogram.pdf')
416
-
417
- # Plot ROC
418
- # Full ROC with inset
419
- fig, ax = plt.subplots()
420
-
421
- # Inset axes for zoomed region
422
- axins = inset_axes(ax, width="50%", height="50%", loc='lower left',
423
- bbox_to_anchor=(0.45,0.08,0.5,0.5), bbox_transform=ax.transAxes)
424
-
425
- # Main ROC
426
- for snr in [25, 30, 35]:
427
- # select scores
428
- new_scores = scores[(labels == snr) | (labels == 0.0)]
429
- new_labels = labels[(labels == snr) | (labels == 0.0)]
430
- new_labels = np.array([1 if l > 0 else 0 for l in new_labels]) # binary labels: 1 for signal, 0 for noise
431
- # Compute ROC curve and AUC
432
- fpr, tpr, thresholds = roc_curve(new_labels, new_scores)
433
- roc_auc = auc(fpr, tpr)
434
- ax.plot(fpr, tpr, lw=2, label=f'SNR={snr}')# (AUC = {roc_auc:.3f})
435
-
436
- axins.plot(fpr, tpr, lw=2)
437
- axins.set_xlim([5e-4, 5e-2]) # zoomed FPR range
438
- axins.set_ylim([0.5, 1.02])
439
- axins.set_xscale('log')
440
- # Increase number of y-ticks
441
- axins.yaxis.set_major_locator(plt.MaxNLocator(nbins=4))
442
- # axins.set_title("Low-FPR zoom", fontsize=9)
443
- axins.grid(True, which="both")
444
-
445
- mark_inset(ax, axins, loc1=1, loc2=3, fc="none", ec="0.5")
446
-
447
- ax.plot([0,1],[0,1], lw=1, linestyle='--', label='Random')
448
- ax.set_xlabel('False Positive Rate')
449
- ax.set_ylabel('True Positive Rate')
450
- # ax.set_title('ROC curve for GW search pipeline')
451
- ax.legend(loc='lower right')
452
- plt.tight_layout()
453
- ax.grid(True)
454
-
455
- plt.savefig('roc_curve.pdf')
456
-
457
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -2,11 +2,8 @@ gradio
2
  numpy
3
  matplotlib
4
  h5py
5
- pandas
6
  astropy
7
  scipy
8
- scikit-learn
9
  pillow
10
  corner
11
- pastamarkers
12
- few
 
2
  numpy
3
  matplotlib
4
  h5py
 
5
  astropy
6
  scipy
 
7
  pillow
8
  corner
9
+ pastamarkers
 
scatter_plot_snr.py CHANGED
@@ -1,23 +1,30 @@
1
  import h5py
2
  import numpy as np
3
  import os
4
- import argparse
5
  import matplotlib.pyplot as plt
6
- import glob
7
- from plot_detection_probability import get_detection_threshold
8
  import pickle
9
- import corner
10
- from matplotlib.patches import Rectangle
11
  from plot_styles import apply_physrev_style
12
  from astropy.cosmology import Planck18
13
  from astropy.cosmology import z_at_value
14
  from astropy import units as u
15
- from matplotlib.ticker import MaxNLocator
16
- from matplotlib.ticker import LogLocator
17
- from matplotlib.lines import Line2D
18
- import pandas as pd
19
  # Apply the style
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  # New function for interactive plotting
22
  def plot_mass_vs_distance_or_redshift(
23
  snrs=[30], alpha=1e-4, y_axis="Redshift", x_axis="Primary Mass", colorbar_var="ef"):
 
1
  import h5py
2
  import numpy as np
3
  import os
 
4
  import matplotlib.pyplot as plt
 
 
5
  import pickle
 
 
6
  from plot_styles import apply_physrev_style
7
  from astropy.cosmology import Planck18
8
  from astropy.cosmology import z_at_value
9
  from astropy import units as u
10
+ from scipy.stats import gumbel_r
 
 
 
11
  # Apply the style
12
 
13
+ def get_detection_threshold(normalized, alpha, gumbel=True, list_hyp=False):
14
+ """Compute detection threshold for given significance level alpha."""
15
+ if gumbel:
16
+ if list_hyp:
17
+ detection_threshold = [gumbel_r(*gumbel_r.fit(el)).isf(alpha) for el in normalized.T]
18
+ else:
19
+ detection_threshold = gumbel_r(*gumbel_r.fit(np.max(normalized, axis=1))).isf(alpha)
20
+ else:
21
+
22
+ if list_hyp:
23
+ detection_threshold = np.quantile(normalized, 1-alpha/len(tpl_vector), axis=0)
24
+ else:
25
+ detection_threshold = np.quantile(np.max(normalized, axis=1), 1-alpha)
26
+ return detection_threshold
27
+
28
  # New function for interactive plotting
29
  def plot_mass_vs_distance_or_redshift(
30
  snrs=[30], alpha=1e-4, y_axis="Redshift", x_axis="Primary Mass", colorbar_var="ef"):