File size: 12,290 Bytes
f4b267d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c95d941
 
 
 
 
 
 
 
f4b267d
 
 
 
 
c95d941
 
f4b267d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
"""
Tests for protein_conformal/util.py core functions.

These tests verify:
1. FASTA parsing
2. FAISS database operations
3. FDR/FNR threshold calculations (conformal risk control)
4. Risk metrics (FDR, FNR, TPR)
5. Venn-Abers probability predictions
6. Hierarchical loss functions (for SCOPe)
"""
import numpy as np
import pytest
from protein_conformal.util import (
    read_fasta,
    load_database,
    query,
    get_thresh_new,
    get_thresh_new_FDR,
    get_thresh_FDR,
    risk,
    risk_1d,
    calculate_false_negatives,
    calculate_true_positives,
    simplifed_venn_abers_prediction,
    get_isotone_regression,
    scope_hierarchical_loss,
    validate_lhat_new,
)


class TestFastaParsing:
    """Checks that ``read_fasta`` splits a FASTA file into sequences and headers.

    All tests rely on the ``sample_fasta_file`` fixture, which is expected to
    contain three records.
    """

    def test_read_fasta_basic(self, sample_fasta_file):
        """Parsing yields one sequence and one header per record."""
        seqs, headers = read_fasta(sample_fasta_file)

        assert len(seqs) == 3
        assert len(headers) == 3

        # Spot-check the first record: sequence prefix and header text.
        first_seq, first_header = seqs[0], headers[0]
        assert first_seq.startswith("MVLSPADKTN")
        assert ">protein1" in first_header

    def test_read_fasta_sequence_content(self, sample_fasta_file):
        """Every parsed residue is one of the 20 standard amino-acid letters."""
        seqs, _ = read_fasta(sample_fasta_file)

        allowed = frozenset("ACDEFGHIKLMNPQRSTVWY")
        for seq in seqs:
            unexpected = [aa for aa in seq if aa not in allowed]
            assert not unexpected, f"Invalid AA in sequence: {seq}"

    def test_read_fasta_short_sequence(self, sample_fasta_file):
        """The short third record is returned intact."""
        seqs, _ = read_fasta(sample_fasta_file)

        # The third record is expected to be exactly the 20 standard residues.
        third = seqs[2]
        assert third == "ACDEFGHIKLMNPQRSTVWY"
        assert len(third) == 20


class TestFAISSOperations:
    """Tests for FAISS database loading and querying.

    The ``sample_embeddings`` fixture supplies a ``(query_embeddings,
    lookup_embeddings)`` pair of 2-D arrays. Expected sizes are derived from
    the fixture arrays instead of being hard-coded, so the tests keep working
    if the fixture dimensions change.

    Copies of the arrays are handed to ``load_database`` / ``query``, as in
    the original tests — presumably because those helpers may normalise their
    input in place (NOTE(review): confirm against ``protein_conformal.util``).
    """

    # Number of neighbours requested from the index in the query tests.
    K = 10

    # Slack for the cosine-similarity range check: float32 inner products of
    # unit vectors can land a few ULPs outside [-1, 1] because of rounding.
    COSINE_TOL = 1e-5

    def test_load_database(self, sample_embeddings):
        """The index holds every lookup vector at the right dimensionality."""
        _, lookup_embeddings = sample_embeddings
        n_vectors, dim = lookup_embeddings.shape

        index = load_database(lookup_embeddings.copy())

        assert index.ntotal == n_vectors  # Number of vectors in index
        assert index.d == dim  # Dimensionality

    def test_query_returns_correct_shape(self, sample_embeddings):
        """``query`` returns one row per query and ``k`` columns per row."""
        query_embeddings, lookup_embeddings = sample_embeddings

        index = load_database(lookup_embeddings.copy())
        distances, indices = query(index, query_embeddings.copy(), k=self.K)

        expected_shape = (query_embeddings.shape[0], self.K)
        assert distances.shape == expected_shape
        assert indices.shape == expected_shape

    def test_query_distances_are_similarities(self, sample_embeddings):
        """Returned distances behave like cosine similarities.

        Cosine similarity lies in [-1, 1]; a small tolerance absorbs float32
        rounding so the test cannot fail on a value such as 1.0000001.
        """
        query_embeddings, lookup_embeddings = sample_embeddings

        index = load_database(lookup_embeddings.copy())
        distances, _ = query(index, query_embeddings.copy(), k=self.K)

        assert distances.min() >= -1.0 - self.COSINE_TOL
        assert distances.max() <= 1.0 + self.COSINE_TOL

    def test_query_indices_valid(self, sample_embeddings):
        """Every returned neighbour index points at a real lookup vector."""
        query_embeddings, lookup_embeddings = sample_embeddings

        index = load_database(lookup_embeddings.copy())
        _, indices = query(index, query_embeddings.copy(), k=self.K)

        # A negative index would mean a neighbour slot could not be filled.
        assert indices.min() >= 0
        assert indices.max() < lookup_embeddings.shape[0]


class TestRiskMetrics:
    """Hand-computed checks for the FDR (``risk``) and FNR helpers.

    Each test uses a single query row of four similarity scores and a boolean
    label row of the same shape, then compares the metric at a fixed cutoff
    with the value worked out by hand.
    """

    def test_risk_all_correct(self):
        """FDR is 0 when every hit above the cutoff is a true match."""
        scores = np.array([[0.9, 0.8, 0.7, 0.6]])
        is_match = np.array([[True, True, True, False]])  # columns 0-2 are matches

        # Cutoff 0.65 is expected to keep columns 0, 1, 2 — all true matches.
        cutoff = 0.65
        assert risk(scores, is_match, cutoff) == 0.0

    def test_risk_all_incorrect(self):
        """FDR is 1 when every hit above the cutoff is a false match."""
        scores = np.array([[0.9, 0.8, 0.7, 0.6]])
        is_match = np.array([[False, False, False, True]])  # only column 3 matches

        # Cutoff 0.65 is expected to keep columns 0, 1, 2 — none are matches.
        cutoff = 0.65
        assert risk(scores, is_match, cutoff) == 1.0

    def test_risk_partial(self):
        """FDR equals the fraction of false hits among the retrieved ones."""
        scores = np.array([[0.9, 0.8, 0.7, 0.6]])
        is_match = np.array([[True, False, True, False]])

        # Cutoff 0.65 keeps three hits, one of which is false: FDR = 1/3.
        expected = 1 / 3
        observed = risk(scores, is_match, 0.65)
        assert abs(observed - expected) < 1e-6

    def test_calculate_false_negatives_zero(self):
        """FNR is 0 when no true match falls below the cutoff."""
        scores = np.array([[0.9, 0.8, 0.7, 0.6]])
        is_match = np.array([[True, True, False, False]])

        # Cutoff 0.75 retains both true matches, so nothing is missed.
        cutoff = 0.75
        assert calculate_false_negatives(scores, is_match, cutoff) == 0.0

    def test_calculate_false_negatives_partial(self):
        """FNR equals the fraction of true matches that were missed."""
        scores = np.array([[0.9, 0.8, 0.7, 0.6]])
        is_match = np.array([[True, True, True, False]])

        # Cutoff 0.85 retains column 0 only; columns 1 and 2 are missed: 2/3.
        expected = 2 / 3
        observed = calculate_false_negatives(scores, is_match, 0.85)
        assert abs(observed - expected) < 1e-6


class TestConformalThresholds:
    """Tests for conformal risk control threshold calculations.

    All tests use the ``scope_like_data`` fixture, which yields a
    ``(sims, labels)`` pair of calibration similarities and match labels.
    """

    # Types accepted as "a real scalar". NumPy scalars are listed explicitly:
    # np.float64 subclasses float, but np.float32 and the NumPy integer types
    # do not subclass the Python built-ins and would otherwise be rejected.
    _REAL_SCALAR_TYPES = (int, float, np.integer, np.floating)

    def test_get_thresh_new_basic(self, scope_like_data):
        """The FNR-control threshold lies within the observed similarity range."""
        sims, labels = scope_like_data
        alpha = 0.1

        lhat = get_thresh_new(sims, labels, alpha)

        # Threshold should be in valid similarity range
        assert sims.min() <= lhat <= sims.max()

    def test_get_thresh_new_FDR_basic(self, scope_like_data):
        """The FDR-control threshold lies within the observed similarity range."""
        sims, labels = scope_like_data
        alpha = 0.1

        lhat = get_thresh_new_FDR(sims, labels, alpha)

        # Threshold should be in valid similarity range
        assert sims.min() <= lhat <= sims.max()

    def test_threshold_decreases_with_lower_alpha(self, scope_like_data):
        """Test that more stringent alpha leads to lower threshold for FNR control.

        For FNR (false negative rate) control via get_thresh_new:
        - Lower alpha = more stringent = want fewer false negatives
        - Algorithm picks a lower quantile of positive similarities
        - Lower quantile = lower threshold = accept more matches
        """
        sims, labels = scope_like_data

        lhat_10 = get_thresh_new(sims, labels, alpha=0.1)
        lhat_05 = get_thresh_new(sims, labels, alpha=0.05)

        # Lower alpha (more stringent FNR) should give lower threshold
        assert lhat_05 <= lhat_10

    def test_get_thresh_FDR_returns_risk(self, scope_like_data):
        """``get_thresh_FDR`` returns a scalar threshold and a risk in [0, 1].

        Note the argument order: this helper is called with ``labels`` before
        ``sims``, unlike the other threshold functions exercised above.
        """
        sims, labels = scope_like_data
        alpha = 0.1

        lhat, risk_fdr = get_thresh_FDR(labels, sims, alpha, delta=0.5, N=100)

        # Should return valid threshold and risk (Python or NumPy scalars).
        assert isinstance(lhat, self._REAL_SCALAR_TYPES)
        assert isinstance(risk_fdr, self._REAL_SCALAR_TYPES)
        assert 0 <= risk_fdr <= 1


class TestVennAbers:
    """Checks for the Venn-Abers probability helpers and isotonic regression."""

    def test_simplified_venn_abers_returns_two_probs(self):
        """The simplified predictor returns a (p0, p1) pair of probabilities."""
        np.random.seed(42)
        cal_scores = np.random.uniform(0.5, 1.0, 100)
        # Label a calibration point positive when its score exceeds 0.7.
        cal_labels = (cal_scores > 0.7).astype(bool)
        test_score = 0.8

        p0, p1 = simplifed_venn_abers_prediction(cal_scores, cal_labels, test_score)

        for prob in (p0, p1):
            assert 0 <= prob <= 1

    def test_venn_abers_high_similarity_high_prob(self):
        """A test score in the positive region gets a mean probability above 0.5."""
        # Calibration set in which the three highest scores are the positives.
        cal_scores = np.array([0.5, 0.6, 0.7, 0.8, 0.9, 0.95])
        cal_labels = np.array([False, False, False, True, True, True])

        # Copies are passed so the literal arrays above stay untouched.
        p0, p1 = simplifed_venn_abers_prediction(
            cal_scores.copy(), cal_labels.copy(), 0.92
        )

        # The midpoint of the two probabilities should favour a positive.
        assert (p0 + p1) / 2 > 0.5

    def test_isotonic_regression_monotonic(self):
        """Predictions of the fitted isotonic model never decrease with x."""
        train_x = np.array([0.5, 0.6, 0.7, 0.8, 0.9])
        train_y = np.array([0.1, 0.2, 0.4, 0.8, 0.9])

        model = get_isotone_regression(train_x, train_y)

        # Evaluate on an evenly spaced grid over the training range.
        grid = np.linspace(0.5, 0.9, 10)
        preds = model.predict(grid)

        # Compare each prediction with its successor.
        assert all(lower <= upper for lower, upper in zip(preds, preds[1:]))


class TestHierarchicalLoss:
    """Checks ``scope_hierarchical_loss`` on pairs of SCOPe-style labels.

    Each test compares ``'a.1.1.1'`` with a label that differs at one level of
    the dotted hierarchy and verifies the returned ``(loss, exact)`` pair.
    """

    @staticmethod
    def _check(first, second, expected_loss, expected_exact):
        """Assert the loss and exact-match flag returned for one label pair."""
        loss, exact = scope_hierarchical_loss(first, second)
        assert loss == expected_loss
        assert exact is expected_exact

    def test_exact_match(self):
        """Identical labels give loss=0 and exact=True."""
        self._check('a.1.1.1', 'a.1.1.1', 0, True)

    def test_family_mismatch(self):
        """A difference at the last (family) level gives loss=1."""
        self._check('a.1.1.1', 'a.1.1.2', 1, False)

    def test_superfamily_mismatch(self):
        """A difference at the superfamily level gives loss=2."""
        self._check('a.1.1.1', 'a.1.2.1', 2, False)

    def test_fold_mismatch(self):
        """A difference at the fold level gives loss=3."""
        self._check('a.1.1.1', 'a.2.1.1', 3, False)

    def test_class_mismatch(self):
        """A difference at the first (class) level gives loss=4."""
        self._check('a.1.1.1', 'b.1.1.1', 4, False)


class TestValidation:
    """Checks for the threshold-validation helper."""

    def test_validate_lhat_new_returns_metrics(self, scope_like_data):
        """``validate_lhat_new`` returns five metrics, each within [0, 1]."""
        sims, labels_exact = scope_like_data
        # The partial-match labels simply reuse the exact ones here.
        labels_partial = labels_exact.copy()

        threshold = 0.9995  # arbitrary cutoff for the check

        (
            error,
            frac_inexact,
            error_partial,
            frac_partial,
            fpr,
        ) = validate_lhat_new(sims, labels_partial, labels_exact, threshold)

        # Every returned metric is a rate and must therefore lie in [0, 1].
        for metric in (error, frac_inexact, error_partial, frac_partial, fpr):
            assert 0 <= metric <= 1


class TestIntegration:
    """Integration tests combining multiple components.

    Both tests use the ``calibration_test_split`` fixture, a mapping with the
    keys ``'cal_sims'``, ``'cal_labels'``, ``'test_sims'`` and
    ``'test_labels'``. A threshold is fitted on the calibration half and the
    corresponding error rate is then measured on the held-out half.
    """

    def test_full_fdr_pipeline(self, calibration_test_split):
        """Test complete FDR control pipeline: calibrate → threshold → validate."""
        data = calibration_test_split
        alpha = 0.1

        # Step 1: Get threshold from calibration data
        lhat = get_thresh_new_FDR(
            data['cal_sims'],
            data['cal_labels'],
            alpha,
        )

        # Step 2: Calculate risk on test data
        test_fdr = risk(data['test_sims'], data['test_labels'], lhat)

        # With enough data test_fdr should be <= alpha, but small random
        # samples can exceed it, so only the valid range of a rate is enforced
        # (matching the [0, 1] checks used by the other tests in this module).
        assert 0 <= test_fdr <= 1

    def test_full_fnr_pipeline(self, calibration_test_split):
        """Test complete FNR control pipeline: calibrate → threshold → validate."""
        data = calibration_test_split
        alpha = 0.1

        # Get threshold for FNR control
        lhat = get_thresh_new(
            data['cal_sims'],
            data['cal_labels'],
            alpha,
        )

        # Calculate FNR on test data
        test_fnr = calculate_false_negatives(
            data['test_sims'],
            data['test_labels'],
            lhat,
        )

        # As above: the alpha guarantee is not asserted on a small sample,
        # but the result must at least be a valid rate.
        assert 0 <= test_fnr <= 1