"""
Tests for protein_conformal/util.py core functions.

These tests verify:
1. FASTA parsing
2. FAISS database operations
3. FDR/FNR threshold calculations (conformal risk control)
4. Risk metrics (FDR, FNR, TPR)
5. Venn-Abers probability predictions
6. Hierarchical loss functions (for SCOPe)
"""
import numpy as np
import pytest

from protein_conformal.util import (
    read_fasta,
    load_database,
    query,
    get_thresh_new,
    get_thresh_new_FDR,
    get_thresh_FDR,
    risk,
    risk_1d,
    calculate_false_negatives,
    calculate_true_positives,
    simplifed_venn_abers_prediction,  # sic: spelling matches the definition in util.py
    get_isotone_regression,
    scope_hierarchical_loss,
    validate_lhat_new,
)
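
# The fixtures used below (sample_fasta_file, sample_embeddings, scope_like_data,
# calibration_test_split) are expected to come from conftest.py. A minimal sketch
# of the shapes these tests assume (an illustration, not the actual definitions):
#
#     @pytest.fixture
#     def sample_fasta_file(tmp_path):
#         path = tmp_path / "sample.fasta"
#         path.write_text(
#             ">protein1 hemoglobin fragment\n"
#             "MVLSPADKTNVKAAW\n"
#             ">protein2\n"
#             "MKTAYIAKQRQISFVK\n"
#             ">protein3\n"
#             "ACDEFGHIKLMNPQRSTVWY\n"
#         )
#         return str(path)
#
#     @pytest.fixture
#     def sample_embeddings():
#         rng = np.random.default_rng(0)
#         queries = rng.standard_normal((10, 128)).astype(np.float32)
#         lookup = rng.standard_normal((100, 128)).astype(np.float32)
#         return queries, lookup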


class TestFastaParsing:
    """Tests for FASTA file parsing."""

    def test_read_fasta_basic(self, sample_fasta_file):
        """Test basic FASTA parsing returns sequences and metadata."""
        sequences, metadata = read_fasta(sample_fasta_file)
        assert len(sequences) == 3
        assert len(metadata) == 3
        # Check first sequence
        assert sequences[0].startswith('MVLSPADKTN')
        assert '>protein1' in metadata[0]

    def test_read_fasta_sequence_content(self, sample_fasta_file):
        """Test that sequences contain only valid amino acids."""
        sequences, _ = read_fasta(sample_fasta_file)
        valid_aa = set('ACDEFGHIKLMNPQRSTVWY')
        for seq in sequences:
            assert all(aa in valid_aa for aa in seq), f"Invalid AA in sequence: {seq}"

    def test_read_fasta_short_sequence(self, sample_fasta_file):
        """Test that the short sequence is parsed correctly."""
        sequences, _ = read_fasta(sample_fasta_file)
        # Third sequence is exactly the 20 standard amino acids
        assert sequences[2] == 'ACDEFGHIKLMNPQRSTVWY'
        assert len(sequences[2]) == 20


class TestFAISSOperations:
    """Tests for FAISS database loading and querying."""

    def test_load_database(self, sample_embeddings):
        """Test that database loads and has correct dimensions."""
        _, lookup_embeddings = sample_embeddings
        index = load_database(lookup_embeddings.copy())
        assert index.ntotal == 100  # Number of vectors in index
        assert index.d == 128  # Dimensionality

    def test_query_returns_correct_shape(self, sample_embeddings):
        """Test that query returns distances and indices with correct shapes."""
        query_embeddings, lookup_embeddings = sample_embeddings
        index = load_database(lookup_embeddings.copy())
        D, I = query(index, query_embeddings.copy(), k=10)
        assert D.shape == (10, 10)  # 10 queries, k=10 neighbors
        assert I.shape == (10, 10)
    def test_query_distances_are_similarities(self, sample_embeddings):
        """Test that distances are cosine similarities (normalized dot products)."""
        query_embeddings, lookup_embeddings = sample_embeddings
        index = load_database(lookup_embeddings.copy())
        D, I = query(index, query_embeddings.copy(), k=10)
        # Cosine similarities should be in the [-1, 1] range
        assert D.min() >= -1.0
        assert D.max() <= 1.0
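
    # A hand-rolled version of what the index is expected to compute, assuming
    # load_database / query L2-normalize the embeddings before indexing with an
    # inner-product FAISS index (illustration only, not part of the library API):
    #
    #     q = query_embeddings / np.linalg.norm(query_embeddings, axis=1, keepdims=True)
    #     db = lookup_embeddings / np.linalg.norm(lookup_embeddings, axis=1, keepdims=True)
    #     sims = q @ db.T  # shape (n_queries, n_lookup), entries in [-1, 1]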

    def test_query_indices_valid(self, sample_embeddings):
        """Test that returned indices are valid."""
        query_embeddings, lookup_embeddings = sample_embeddings
        index = load_database(lookup_embeddings.copy())
        D, I = query(index, query_embeddings.copy(), k=10)
        # All indices should be in valid range
        assert I.min() >= 0
        assert I.max() < 100  # lookup has 100 embeddings


class TestRiskMetrics:
    """Tests for FDR, FNR, and related risk calculations."""

    def test_risk_all_correct(self):
        """Test risk is 0 when all predictions above threshold are correct."""
        sims = np.array([[0.9, 0.8, 0.7, 0.6]])
        labels = np.array([[True, True, True, False]])  # First 3 are true matches
        # Threshold 0.65: returns indices 0,1,2 (all true) → FDR = 0
        fdr = risk(sims, labels, 0.65)
        assert fdr == 0.0

    def test_risk_all_incorrect(self):
        """Test risk is 1 when all predictions above threshold are incorrect."""
        sims = np.array([[0.9, 0.8, 0.7, 0.6]])
        labels = np.array([[False, False, False, True]])  # Only index 3 is true
        # Threshold 0.65: returns indices 0,1,2 (all false) → FDR = 1
        fdr = risk(sims, labels, 0.65)
        assert fdr == 1.0

    def test_risk_partial(self):
        """Test risk calculation with mixed predictions."""
        sims = np.array([[0.9, 0.8, 0.7, 0.6]])
        labels = np.array([[True, False, True, False]])
        # Threshold 0.65: returns 3 items, 1 false → FDR = 1/3
        fdr = risk(sims, labels, 0.65)
        assert abs(fdr - 1/3) < 1e-6

    def test_calculate_false_negatives_zero(self):
        """Test FNR is 0 when all positives are detected."""
        sims = np.array([[0.9, 0.8, 0.7, 0.6]])
        labels = np.array([[True, True, False, False]])
        # Threshold 0.75: detects both true positives → FNR = 0
        fnr = calculate_false_negatives(sims, labels, 0.75)
        assert fnr == 0.0

    def test_calculate_false_negatives_partial(self):
        """Test FNR when some positives are missed."""
        sims = np.array([[0.9, 0.8, 0.7, 0.6]])
        labels = np.array([[True, True, True, False]])
        # Threshold 0.85: only detects index 0, misses 1 and 2 → FNR = 2/3
        fnr = calculate_false_negatives(sims, labels, 0.85)
        assert abs(fnr - 2/3) < 1e-6


class TestConformalThresholds:
    """Tests for conformal risk control threshold calculations."""

    def test_get_thresh_new_basic(self, scope_like_data):
        """Test basic threshold calculation for FNR control."""
        sims, labels = scope_like_data
        alpha = 0.1
        lhat = get_thresh_new(sims, labels, alpha)
        # Threshold should be in the valid similarity range
        assert sims.min() <= lhat <= sims.max()

    def test_get_thresh_new_FDR_basic(self, scope_like_data):
        """Test basic threshold calculation for FDR control."""
        sims, labels = scope_like_data
        alpha = 0.1
        lhat = get_thresh_new_FDR(sims, labels, alpha)
        # Threshold should be in the valid similarity range
        assert sims.min() <= lhat <= sims.max()

    def test_threshold_decreases_with_lower_alpha(self, scope_like_data):
        """Test that a more stringent alpha leads to a lower threshold for FNR control.

        For FNR (false negative rate) control via get_thresh_new:
        - Lower alpha = more stringent = want fewer false negatives
        - The algorithm picks a lower quantile of the positive similarities
        - Lower quantile = lower threshold = accept more matches
        """
        sims, labels = scope_like_data
        lhat_10 = get_thresh_new(sims, labels, alpha=0.1)
        lhat_05 = get_thresh_new(sims, labels, alpha=0.05)
        # Lower alpha (more stringent FNR) should give a lower threshold
        assert lhat_05 <= lhat_10

    def test_get_thresh_FDR_returns_risk(self, scope_like_data):
        """Test that get_thresh_FDR returns both a threshold and its risk."""
        sims, labels = scope_like_data
        alpha = 0.1
        # Note: get_thresh_FDR takes (labels, sims), unlike the other helpers
        lhat, risk_fdr = get_thresh_FDR(labels, sims, alpha, delta=0.5, N=100)
        # Should return a valid threshold and risk
        assert isinstance(lhat, (int, float))
        assert isinstance(risk_fdr, (int, float))
        assert 0 <= risk_fdr <= 1


class TestVennAbers:
    """Tests for Venn-Abers probability predictions."""

    def test_simplified_venn_abers_returns_two_probs(self):
        """Test that simplified Venn-Abers returns p0 and p1."""
        np.random.seed(42)
        X_cal = np.random.uniform(0.5, 1.0, 100)
        Y_cal = (X_cal > 0.7).astype(bool)
        X_test = 0.8
        p0, p1 = simplifed_venn_abers_prediction(X_cal, Y_cal, X_test)
        assert 0 <= p0 <= 1
        assert 0 <= p1 <= 1

    def test_venn_abers_high_similarity_high_prob(self):
        """Test that a high similarity gives a high probability."""
        # Calibration: high similarity → positive label
        X_cal = np.array([0.5, 0.6, 0.7, 0.8, 0.9, 0.95])
        Y_cal = np.array([False, False, False, True, True, True])
        # A test point with high similarity should get a high probability
        p0, p1 = simplifed_venn_abers_prediction(X_cal.copy(), Y_cal.copy(), 0.92)
        # The average of p0 and p1 should be high for a high similarity
        avg_prob = (p0 + p1) / 2
        assert avg_prob > 0.5

    def test_isotonic_regression_monotonic(self):
        """Test that isotonic regression produces monotonic predictions."""
        X = np.array([0.5, 0.6, 0.7, 0.8, 0.9])
        y = np.array([0.1, 0.2, 0.4, 0.8, 0.9])
        ir = get_isotone_regression(X, y)
        # Predictions should be monotonically non-decreasing
        test_x = np.linspace(0.5, 0.9, 10)
        preds = ir.predict(test_x)
        assert all(preds[i] <= preds[i + 1] for i in range(len(preds) - 1))


class TestHierarchicalLoss:
    """Tests for the SCOPe hierarchical loss function."""

    def test_exact_match(self):
        """Test exact match returns loss=0, exact=True."""
        loss, exact = scope_hierarchical_loss('a.1.1.1', 'a.1.1.1')
        assert loss == 0
        assert exact is True

    def test_family_mismatch(self):
        """Test family mismatch (last level) returns loss=1."""
        loss, exact = scope_hierarchical_loss('a.1.1.1', 'a.1.1.2')
        assert loss == 1
        assert exact is False

    def test_superfamily_mismatch(self):
        """Test superfamily mismatch returns loss=2."""
        loss, exact = scope_hierarchical_loss('a.1.1.1', 'a.1.2.1')
        assert loss == 2
        assert exact is False

    def test_fold_mismatch(self):
        """Test fold mismatch returns loss=3."""
        loss, exact = scope_hierarchical_loss('a.1.1.1', 'a.2.1.1')
        assert loss == 3
        assert exact is False

    def test_class_mismatch(self):
        """Test class mismatch returns loss=4."""
        loss, exact = scope_hierarchical_loss('a.1.1.1', 'b.1.1.1')
        assert loss == 4
        assert exact is False


class TestValidation:
    """Tests for validation functions."""

    def test_validate_lhat_new_returns_metrics(self, scope_like_data):
        """Test that validate_lhat_new returns the expected metrics."""
        sims, labels_exact = scope_like_data
        labels_partial = labels_exact.copy()  # use the same labels for simplicity
        lhat = 0.9995  # arbitrary fixed threshold near the top of the similarity range
        error, frac_inexact, error_partial, frac_partial, fpr = validate_lhat_new(
            sims, labels_partial, labels_exact, lhat
        )
        # All metrics should be rates in [0, 1]
        assert 0 <= error <= 1
        assert 0 <= frac_inexact <= 1
        assert 0 <= error_partial <= 1
        assert 0 <= frac_partial <= 1
        assert 0 <= fpr <= 1


class TestIntegration:
    """Integration tests combining multiple components."""

    def test_full_fdr_pipeline(self, calibration_test_split):
        """Test the complete FDR control pipeline: calibrate → threshold → validate."""
        data = calibration_test_split
        alpha = 0.1
        # Step 1: get the threshold from calibration data
        lhat = get_thresh_new_FDR(
            data['cal_sims'],
            data['cal_labels'],
            alpha,
        )
        # Step 2: calculate the realized FDR on test data
        test_fdr = risk(data['test_sims'], data['test_labels'], lhat)
        # FDR is controlled in expectation; with a small test split the realized
        # FDR can exceed alpha by chance, so only check that it is a valid rate.
        assert test_fdr >= 0

    def test_full_fnr_pipeline(self, calibration_test_split):
        """Test the complete FNR control pipeline."""
        data = calibration_test_split
        alpha = 0.1
        # Get the threshold for FNR control
        lhat = get_thresh_new(
            data['cal_sims'],
            data['cal_labels'],
            alpha,
        )
        # Calculate the realized FNR on test data
        test_fnr = calculate_false_negatives(
            data['test_sims'],
            data['test_labels'],
            lhat,
        )
        # As above, FNR is controlled in expectation; only check validity here
        assert test_fnr >= 0