Spaces:
Running
Running
| """ | |
| Tests for protein_conformal/util.py core functions. | |
| These tests verify: | |
| 1. FASTA parsing | |
| 2. FAISS database operations | |
| 3. FDR/FNR threshold calculations (conformal risk control) | |
| 4. Risk metrics (FDR, FNR, TPR) | |
| 5. Venn-Abers probability predictions | |
| 6. Hierarchical loss functions (for SCOPe) | |
| """ | |
| import numpy as np | |
| import pytest | |
| from protein_conformal.util import ( | |
| read_fasta, | |
| load_database, | |
| query, | |
| get_thresh_new, | |
| get_thresh_new_FDR, | |
| get_thresh_FDR, | |
| risk, | |
| risk_1d, | |
| calculate_false_negatives, | |
| calculate_true_positives, | |
| simplifed_venn_abers_prediction, | |
| get_isotone_regression, | |
| scope_hierarchical_loss, | |
| validate_lhat_new, | |
| ) | |
| class TestFastaParsing: | |
| """Tests for FASTA file parsing.""" | |
| def test_read_fasta_basic(self, sample_fasta_file): | |
| """Test basic FASTA parsing returns sequences and metadata.""" | |
| sequences, metadata = read_fasta(sample_fasta_file) | |
| assert len(sequences) == 3 | |
| assert len(metadata) == 3 | |
| # Check first sequence | |
| assert sequences[0].startswith('MVLSPADKTN') | |
| assert '>protein1' in metadata[0] | |
| def test_read_fasta_sequence_content(self, sample_fasta_file): | |
| """Test that sequences contain only valid amino acids.""" | |
| sequences, _ = read_fasta(sample_fasta_file) | |
| valid_aa = set('ACDEFGHIKLMNPQRSTVWY') | |
| for seq in sequences: | |
| assert all(aa in valid_aa for aa in seq), f"Invalid AA in sequence: {seq}" | |
| def test_read_fasta_short_sequence(self, sample_fasta_file): | |
| """Test that short sequence is parsed correctly.""" | |
| sequences, metadata = read_fasta(sample_fasta_file) | |
| # Third sequence is exactly the 20 standard amino acids | |
| assert sequences[2] == 'ACDEFGHIKLMNPQRSTVWY' | |
| assert len(sequences[2]) == 20 | |
| class TestFAISSOperations: | |
| """Tests for FAISS database loading and querying.""" | |
| def test_load_database(self, sample_embeddings): | |
| """Test that database loads and has correct dimensions.""" | |
| _, lookup_embeddings = sample_embeddings | |
| index = load_database(lookup_embeddings.copy()) | |
| assert index.ntotal == 100 # Number of vectors in index | |
| assert index.d == 128 # Dimensionality | |
| def test_query_returns_correct_shape(self, sample_embeddings): | |
| """Test that query returns distances and indices with correct shapes.""" | |
| query_embeddings, lookup_embeddings = sample_embeddings | |
| index = load_database(lookup_embeddings.copy()) | |
| D, I = query(index, query_embeddings.copy(), k=10) | |
| assert D.shape == (10, 10) # 10 queries, k=10 neighbors | |
| assert I.shape == (10, 10) | |
| def test_query_distances_are_similarities(self, sample_embeddings): | |
| """Test that distances are cosine similarities (normalized dot product).""" | |
| query_embeddings, lookup_embeddings = sample_embeddings | |
| index = load_database(lookup_embeddings.copy()) | |
| D, I = query(index, query_embeddings.copy(), k=10) | |
| # Cosine similarities should be in [-1, 1] range | |
| assert D.min() >= -1.0 | |
| assert D.max() <= 1.0 | |
| def test_query_indices_valid(self, sample_embeddings): | |
| """Test that returned indices are valid.""" | |
| query_embeddings, lookup_embeddings = sample_embeddings | |
| index = load_database(lookup_embeddings.copy()) | |
| D, I = query(index, query_embeddings.copy(), k=10) | |
| # All indices should be in valid range | |
| assert I.min() >= 0 | |
| assert I.max() < 100 # lookup has 100 embeddings | |
| class TestRiskMetrics: | |
| """Tests for FDR, FNR, and related risk calculations.""" | |
| def test_risk_all_correct(self): | |
| """Test risk is 0 when all predictions above threshold are correct.""" | |
| sims = np.array([[0.9, 0.8, 0.7, 0.6]]) | |
| labels = np.array([[True, True, True, False]]) # First 3 are true matches | |
| # Threshold 0.65: returns indices 0,1,2 (all true) → FDR = 0 | |
| fdr = risk(sims, labels, 0.65) | |
| assert fdr == 0.0 | |
| def test_risk_all_incorrect(self): | |
| """Test risk is 1 when all predictions above threshold are incorrect.""" | |
| sims = np.array([[0.9, 0.8, 0.7, 0.6]]) | |
| labels = np.array([[False, False, False, True]]) # Only index 3 is true | |
| # Threshold 0.65: returns indices 0,1,2 (all false) → FDR = 1 | |
| fdr = risk(sims, labels, 0.65) | |
| assert fdr == 1.0 | |
| def test_risk_partial(self): | |
| """Test risk calculation with mixed predictions.""" | |
| sims = np.array([[0.9, 0.8, 0.7, 0.6]]) | |
| labels = np.array([[True, False, True, False]]) | |
| # Threshold 0.65: returns 3 items, 1 false → FDR = 1/3 | |
| fdr = risk(sims, labels, 0.65) | |
| assert abs(fdr - 1/3) < 1e-6 | |
| def test_calculate_false_negatives_zero(self): | |
| """Test FNR is 0 when all positives are detected.""" | |
| sims = np.array([[0.9, 0.8, 0.7, 0.6]]) | |
| labels = np.array([[True, True, False, False]]) | |
| # Threshold 0.75: detects both true positives → FNR = 0 | |
| fnr = calculate_false_negatives(sims, labels, 0.75) | |
| assert fnr == 0.0 | |
| def test_calculate_false_negatives_partial(self): | |
| """Test FNR when some positives are missed.""" | |
| sims = np.array([[0.9, 0.8, 0.7, 0.6]]) | |
| labels = np.array([[True, True, True, False]]) | |
| # Threshold 0.85: only detects index 0, misses 1,2 → FNR = 2/3 | |
| fnr = calculate_false_negatives(sims, labels, 0.85) | |
| assert abs(fnr - 2/3) < 1e-6 | |
| class TestConformalThresholds: | |
| """Tests for conformal risk control threshold calculations.""" | |
| def test_get_thresh_new_basic(self, scope_like_data): | |
| """Test basic threshold calculation for FNR control.""" | |
| sims, labels = scope_like_data | |
| alpha = 0.1 | |
| lhat = get_thresh_new(sims, labels, alpha) | |
| # Threshold should be in valid similarity range | |
| assert sims.min() <= lhat <= sims.max() | |
| def test_get_thresh_new_FDR_basic(self, scope_like_data): | |
| """Test basic threshold calculation for FDR control.""" | |
| sims, labels = scope_like_data | |
| alpha = 0.1 | |
| lhat = get_thresh_new_FDR(sims, labels, alpha) | |
| # Threshold should be in valid similarity range | |
| assert sims.min() <= lhat <= sims.max() | |
| def test_threshold_decreases_with_lower_alpha(self, scope_like_data): | |
| """Test that more stringent alpha leads to lower threshold for FNR control. | |
| For FNR (false negative rate) control via get_thresh_new: | |
| - Lower alpha = more stringent = want fewer false negatives | |
| - Algorithm picks a lower quantile of positive similarities | |
| - Lower quantile = lower threshold = accept more matches | |
| """ | |
| sims, labels = scope_like_data | |
| lhat_10 = get_thresh_new(sims, labels, alpha=0.1) | |
| lhat_05 = get_thresh_new(sims, labels, alpha=0.05) | |
| # Lower alpha (more stringent FNR) should give lower threshold | |
| assert lhat_05 <= lhat_10 | |
| def test_get_thresh_FDR_returns_risk(self, scope_like_data): | |
| """Test that get_thresh_FDR returns both threshold and risk.""" | |
| sims, labels = scope_like_data | |
| alpha = 0.1 | |
| lhat, risk_fdr = get_thresh_FDR(labels, sims, alpha, delta=0.5, N=100) | |
| # Should return valid threshold and risk | |
| assert isinstance(lhat, (int, float)) | |
| assert isinstance(risk_fdr, (int, float)) | |
| assert 0 <= risk_fdr <= 1 | |
| class TestVennAbers: | |
| """Tests for Venn-Abers probability predictions.""" | |
| def test_simplified_venn_abers_returns_two_probs(self): | |
| """Test that simplified Venn-Abers returns p0 and p1.""" | |
| np.random.seed(42) | |
| X_cal = np.random.uniform(0.5, 1.0, 100) | |
| Y_cal = (X_cal > 0.7).astype(bool) | |
| X_test = 0.8 | |
| p0, p1 = simplifed_venn_abers_prediction(X_cal, Y_cal, X_test) | |
| assert 0 <= p0 <= 1 | |
| assert 0 <= p1 <= 1 | |
| def test_venn_abers_high_similarity_high_prob(self): | |
| """Test that high similarity gives high probability.""" | |
| # Calibration: high sim → positive label | |
| X_cal = np.array([0.5, 0.6, 0.7, 0.8, 0.9, 0.95]) | |
| Y_cal = np.array([False, False, False, True, True, True]) | |
| # Test point with high similarity should get high probability | |
| p0, p1 = simplifed_venn_abers_prediction(X_cal.copy(), Y_cal.copy(), 0.92) | |
| # Average of p0, p1 should be high for high similarity | |
| avg_prob = (p0 + p1) / 2 | |
| assert avg_prob > 0.5 | |
| def test_isotonic_regression_monotonic(self): | |
| """Test that isotonic regression produces monotonic predictions.""" | |
| X = np.array([0.5, 0.6, 0.7, 0.8, 0.9]) | |
| y = np.array([0.1, 0.2, 0.4, 0.8, 0.9]) | |
| ir = get_isotone_regression(X, y) | |
| # Predictions should be monotonically increasing | |
| test_x = np.linspace(0.5, 0.9, 10) | |
| preds = ir.predict(test_x) | |
| assert all(preds[i] <= preds[i+1] for i in range(len(preds)-1)) | |
| class TestHierarchicalLoss: | |
| """Tests for SCOPe hierarchical loss function.""" | |
| def test_exact_match(self): | |
| """Test exact match returns loss=0, exact=True.""" | |
| loss, exact = scope_hierarchical_loss('a.1.1.1', 'a.1.1.1') | |
| assert loss == 0 | |
| assert exact is True | |
| def test_family_mismatch(self): | |
| """Test family mismatch (last level) returns loss=1.""" | |
| loss, exact = scope_hierarchical_loss('a.1.1.1', 'a.1.1.2') | |
| assert loss == 1 | |
| assert exact is False | |
| def test_superfamily_mismatch(self): | |
| """Test superfamily mismatch returns loss=2.""" | |
| loss, exact = scope_hierarchical_loss('a.1.1.1', 'a.1.2.1') | |
| assert loss == 2 | |
| assert exact is False | |
| def test_fold_mismatch(self): | |
| """Test fold mismatch returns loss=3.""" | |
| loss, exact = scope_hierarchical_loss('a.1.1.1', 'a.2.1.1') | |
| assert loss == 3 | |
| assert exact is False | |
| def test_class_mismatch(self): | |
| """Test class mismatch returns loss=4.""" | |
| loss, exact = scope_hierarchical_loss('a.1.1.1', 'b.1.1.1') | |
| assert loss == 4 | |
| assert exact is False | |
| class TestValidation: | |
| """Tests for validation functions.""" | |
| def test_validate_lhat_new_returns_metrics(self, scope_like_data): | |
| """Test that validate_lhat_new returns expected metrics.""" | |
| sims, labels_exact = scope_like_data | |
| labels_partial = labels_exact.copy() # Use same for simplicity | |
| lhat = 0.9995 # Some threshold | |
| error, frac_inexact, error_partial, frac_partial, fpr = validate_lhat_new( | |
| sims, labels_partial, labels_exact, lhat | |
| ) | |
| # All metrics should be in [0, 1] | |
| assert 0 <= error <= 1 | |
| assert 0 <= frac_inexact <= 1 | |
| assert 0 <= error_partial <= 1 | |
| assert 0 <= frac_partial <= 1 | |
| assert 0 <= fpr <= 1 | |
| class TestIntegration: | |
| """Integration tests combining multiple components.""" | |
| def test_full_fdr_pipeline(self, calibration_test_split): | |
| """Test complete FDR control pipeline: calibrate → threshold → validate.""" | |
| data = calibration_test_split | |
| alpha = 0.1 | |
| # Step 1: Get threshold from calibration data | |
| lhat = get_thresh_new_FDR( | |
| data['cal_sims'], | |
| data['cal_labels'], | |
| alpha | |
| ) | |
| # Step 2: Calculate risk on test data | |
| test_fdr = risk(data['test_sims'], data['test_labels'], lhat) | |
| # FDR should be controlled (may be higher due to randomness in small samples) | |
| # In practice with enough data, test_fdr should be <= alpha | |
| assert test_fdr >= 0 # At minimum, should be valid | |
| def test_full_fnr_pipeline(self, calibration_test_split): | |
| """Test complete FNR control pipeline.""" | |
| data = calibration_test_split | |
| alpha = 0.1 | |
| # Get threshold for FNR control | |
| lhat = get_thresh_new( | |
| data['cal_sims'], | |
| data['cal_labels'], | |
| alpha | |
| ) | |
| # Calculate FNR on test data | |
| test_fnr = calculate_false_negatives( | |
| data['test_sims'], | |
| data['test_labels'], | |
| lhat | |
| ) | |
| # FNR should be controlled | |
| assert test_fnr >= 0 # At minimum, should be valid | |