#!/usr/bin/env python
"""
Compute FNR thresholds at standard alpha levels for the lookup table.

This script computes False Negative Rate (FNR) controlling thresholds using
conformal risk control. FNR thresholds ensure that the fraction of true
positives missed is controlled at level alpha.

The thresholds are computed by:
1. Sampling calibration data multiple times (n_trials)
2. Computing the FNR threshold for each trial
3. Averaging across trials to get a stable estimate

Note on reproducibility:
- Due to random sampling of calibration data, results may vary slightly
  between runs
- The standard deviation across trials indicates the expected variability
- For exact reproduction, use the same random seed

Usage:
    python scripts/compute_fnr_table.py --calibration data/pfam_new_proteins.npy
    python scripts/compute_fnr_table.py --calibration data/pfam_new_proteins.npy --partial
"""

import argparse
import sys
import warnings
from pathlib import Path
from typing import Optional, Sequence

import numpy as np
import pandas as pd

# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))

from protein_conformal.util import get_thresh_new, get_sims_labels

# Standard alpha levels that users commonly need.
_DEFAULT_ALPHA_LEVELS = (0.001, 0.005, 0.01, 0.02, 0.05, 0.1, 0.15, 0.2)


def compute_fnr_threshold(cal_data, alpha: float, n_trials: int = 100,
                          n_calib: int = 1000, seed: Optional[int] = None,
                          partial: bool = False) -> dict:
    """
    Compute FNR threshold at a given alpha level.

    Each trial shuffles the calibration pool, takes the first ``n_calib``
    entries, and passes them through ``get_sims_labels`` and
    ``get_thresh_new``. The per-trial thresholds are then summarised.

    The caller's ``cal_data`` is never modified: shuffling happens on a
    private copy (shallow for object arrays and lists).

    Parameters:
        cal_data: Calibration data array (or other sequence)
        alpha: Target FNR level (e.g., 0.1 means at most 10% false negatives)
        n_trials: Number of trials for averaging (must be >= 1)
        n_calib: Number of calibration samples per trial (must be >= 1).
            If it exceeds ``len(cal_data)``, every trial uses all available
            samples and a warning is emitted.
        seed: Random seed for reproducibility. When given, a local
            ``np.random.RandomState(seed)`` is used so the process-global
            NumPy RNG is left untouched. When ``None``, the global
            ``np.random`` state is used.
        partial: If True, use partial matches (at least one Pfam domain
            matches)

    Returns dict with:
        - mean_threshold: Average threshold across trials
        - std_threshold: Standard deviation across trials
        - min_threshold: Smallest threshold seen in any trial
        - max_threshold: Largest threshold seen in any trial

    Raises:
        ValueError: If ``n_trials`` or ``n_calib`` is smaller than 1.
    """
    if n_trials < 1:
        raise ValueError(f"n_trials must be >= 1, got {n_trials}")
    if n_calib < 1:
        raise ValueError(f"n_calib must be >= 1, got {n_calib}")

    # Shuffle a private copy so the caller's data keeps its original order.
    if isinstance(cal_data, np.ndarray):
        pool = cal_data.copy()
    else:
        pool = list(cal_data)

    if n_calib > len(pool):
        warnings.warn(
            f"n_calib={n_calib} exceeds the {len(pool)} available calibration "
            "samples; every trial will use the full data set",
            stacklevel=2,
        )

    # RandomState(seed) draws the same legacy stream that np.random.seed(seed)
    # would install globally, so seeded results match earlier runs without
    # clobbering global RNG state.
    if seed is not None:
        shuffle = np.random.RandomState(seed).shuffle
    else:
        shuffle = np.random.shuffle

    thresholds = []
    for _ in range(n_trials):
        # Shuffle in place (cumulative across trials) and take a prefix.
        shuffle(pool)
        trial_data = pool[:n_calib]

        # Get similarity scores and labels
        X_cal, y_cal = get_sims_labels(trial_data, partial=partial)

        # Compute FNR threshold
        thresholds.append(get_thresh_new(X_cal, y_cal, alpha))

    return {
        'mean_threshold': np.mean(thresholds),
        'std_threshold': np.std(thresholds),
        'min_threshold': np.min(thresholds),
        'max_threshold': np.max(thresholds),
    }


def _parse_alpha_levels(spec: str) -> list:
    """Parse a comma-separated string of alpha levels into a list of floats.

    Empty tokens (e.g. from a trailing comma) are ignored.

    Raises:
        ValueError: If a token is not a number, no levels remain, or a level
            lies outside [0, 1].
    """
    tokens = (tok.strip() for tok in spec.split(','))
    levels = [float(tok) for tok in tokens if tok]
    if not levels:
        raise ValueError("no alpha levels given")
    # `not (0 <= a <= 1)` also rejects NaN.
    out_of_range = [a for a in levels if not 0.0 <= a <= 1.0]
    if out_of_range:
        raise ValueError(f"alpha levels must lie in [0, 1], got {out_of_range}")
    return levels


def main(argv: Optional[Sequence[str]] = None):
    """
    Command-line entry point: compute the FNR threshold table and save it.

    Loads the calibration data, runs ``compute_fnr_threshold`` for each alpha
    level, prints a summary table, and writes two CSV files: the full
    statistics, and a two-column ``alpha``/``lambda_threshold`` lookup table.

    Parameters:
        argv: Argument list to parse; ``None`` (the default) means
            ``sys.argv[1:]``.

    Returns:
        The pandas DataFrame holding one row of statistics per alpha level.
    """
    parser = argparse.ArgumentParser(
        description='Compute FNR thresholds at standard alpha levels'
    )
    parser.add_argument(
        '--calibration', '-c',
        type=Path,
        required=True,
        help='Path to calibration data (.npy file)'
    )
    parser.add_argument(
        '--output', '-o',
        type=Path,
        default=None,
        help='Output CSV file (default: results/fnr_thresholds.csv or results/fnr_thresholds_partial.csv)'
    )
    parser.add_argument(
        '--n-trials',
        type=int,
        default=100,
        help='Number of calibration trials (default: 100)'
    )
    parser.add_argument(
        '--n-calib',
        type=int,
        default=1000,
        help='Number of calibration samples per trial (default: 1000)'
    )
    parser.add_argument(
        '--seed',
        type=int,
        default=42,
        help='Random seed for reproducibility (default: 42)'
    )
    parser.add_argument(
        '--partial',
        action='store_true',
        help='Use partial matches (at least one Pfam domain matches)'
    )
    parser.add_argument(
        '--alpha-levels',
        type=str,
        default=None,
        help='Comma-separated alpha levels (default: 0.001,0.005,0.01,0.02,0.05,0.1,0.15,0.2)'
    )

    args = parser.parse_args(argv)

    # Report bad input as a usage error instead of a traceback.
    if args.n_trials < 1:
        parser.error(f"--n-trials must be >= 1, got {args.n_trials}")
    if args.n_calib < 1:
        parser.error(f"--n-calib must be >= 1, got {args.n_calib}")
    if not args.calibration.is_file():
        parser.error(f"calibration file not found: {args.calibration}")

    suffix = '_partial' if args.partial else ''

    # Set default output path based on partial flag
    if args.output is None:
        args.output = Path(f'results/fnr_thresholds{suffix}.csv')

    # Parse alpha levels (custom or default)
    if args.alpha_levels:
        try:
            alpha_levels = _parse_alpha_levels(args.alpha_levels)
        except ValueError as err:
            parser.error(f"invalid --alpha-levels: {err}")
    else:
        alpha_levels = list(_DEFAULT_ALPHA_LEVELS)

    match_type = "partial" if args.partial else "exact"
    print(f"Computing FNR thresholds ({match_type} matches)")
    print(f"Loading calibration data from {args.calibration}...")
    # SECURITY: allow_pickle=True unpickles arbitrary objects, which can
    # execute code. Only load calibration files from a trusted source.
    cal_data = np.load(args.calibration, allow_pickle=True)
    print(f" Loaded {len(cal_data)} calibration samples")

    print(f"\nComputing thresholds at {len(alpha_levels)} alpha levels...")
    print(f" Trials per alpha: {args.n_trials}")
    print(f" Calibration samples per trial: {args.n_calib}")
    print(f" Random seed: {args.seed}")
    print(f" Match type: {match_type}")
    print()

    results = []
    for alpha in alpha_levels:
        print(f" α = {alpha:.3f}...", end=" ", flush=True)

        # Use a different seed offset for each alpha.
        # NOTE(review): int() truncates, so alpha levels that map to the same
        # integer after scaling by 10000 share a seed. The formula is kept
        # as-is so existing tables remain reproducible.
        trial_seed = args.seed + int(alpha * 10000)

        # compute_fnr_threshold shuffles a private copy, so cal_data can be
        # passed directly and keeps its order for the next alpha.
        stats = compute_fnr_threshold(
            cal_data,
            alpha=alpha,
            n_trials=args.n_trials,
            n_calib=args.n_calib,
            seed=trial_seed,
            partial=args.partial
        )

        results.append({
            'alpha': alpha,
            'threshold_mean': stats['mean_threshold'],
            'threshold_std': stats['std_threshold'],
            'threshold_min': stats['min_threshold'],
            'threshold_max': stats['max_threshold'],
            'match_type': match_type,
        })

        print(f"λ = {stats['mean_threshold']:.10f} ± {stats['std_threshold']:.2e}")

    df = pd.DataFrame(results)

    # Print a human-readable summary table
    print(f"\n{'='*70}")
    print(f"FNR Threshold Lookup Table ({match_type} matches)")
    print(f"{'='*70}")
    print(f"{'Alpha':<8} {'Threshold (λ)':<20} {'Std Dev':<12}")
    print("-" * 70)
    for row in results:
        print(f"{row['alpha']:<8.3f} {row['threshold_mean']:<20.12f} {row['threshold_std']:<12.2e}")
    print(f"{'='*70}")

    # Save to CSV
    args.output.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(args.output, index=False)
    print(f"\nSaved to {args.output}")

    # Also save a simple version for easy lookup. Its name is fixed (it does
    # not follow a custom --output file name), only the directory does.
    simple_output = args.output.parent / f'fnr_thresholds{suffix}_simple.csv'
    df[['alpha', 'threshold_mean']].rename(
        columns={'threshold_mean': 'lambda_threshold'}
    ).to_csv(simple_output, index=False)
    print(f"Simple lookup table saved to {simple_output}")

    return df


if __name__ == '__main__':
    main()