# cpr/tests/test_cli.py
# Author: ronboger
# fix: update CLI tests for --input flag and --no-filter (commit 68a03a2)
"""
Tests for CPR CLI (protein_conformal/cli.py).
Tests cover:
- Help text for all commands
- Basic functionality with mock data
- Error handling
"""
import subprocess
import sys
import tempfile
import numpy as np
import pandas as pd
import pytest
from pathlib import Path
def run_cli(*args):
    """Execute the cpr CLI in a subprocess and return the CompletedProcess."""
    return subprocess.run(
        [sys.executable, '-m', 'protein_conformal.cli', *args],
        capture_output=True,
        text=True,
    )
def test_main_help():
    """'cpr --help' exits 0 and lists every subcommand plus the tool name."""
    proc = run_cli('--help')
    assert proc.returncode == 0
    for token in ('embed', 'search', 'verify', 'prob', 'calibrate',
                  'Conformal Protein Retrieval'):
        assert token in proc.stdout
def test_main_no_command():
    """Running cpr without a subcommand exits 1 and prints usage/help."""
    proc = run_cli()
    assert proc.returncode == 1
    # Help text (mentioning the subcommands) may land on stdout or stderr.
    assert any('embed' in stream for stream in (proc.stdout, proc.stderr))
def test_embed_help():
    """'cpr embed --help' exits 0 and documents its flags and model choices."""
    proc = run_cli('embed', '--help')
    assert proc.returncode == 0
    for token in ('--input', '--output', '--model', 'protein-vec',
                  'clean', '--cpu'):
        assert token in proc.stdout
def test_search_help():
    """'cpr search --help' exits 0 and documents all search options."""
    proc = run_cli('search', '--help')
    assert proc.returncode == 0
    for token in ('--input', '--database', '--output', '--k',
                  '--threshold', '--database-meta'):
        assert token in proc.stdout
def test_verify_help():
    """'cpr verify --help' exits 0 and lists the available checks."""
    proc = run_cli('verify', '--help')
    assert proc.returncode == 0
    for token in ('--check', 'syn30', 'fdr', 'dali', 'clean'):
        assert token in proc.stdout
def test_prob_help():
    """'cpr prob --help' exits 0 and documents all prob options."""
    proc = run_cli('prob', '--help')
    assert proc.returncode == 0
    for token in ('--input', '--calibration', '--output',
                  '--score-column', '--n-calib', '--seed'):
        assert token in proc.stdout
def test_calibrate_help():
    """'cpr calibrate --help' exits 0 and documents options and methods."""
    proc = run_cli('calibrate', '--help')
    assert proc.returncode == 0
    for token in ('--calibration', '--output', '--alpha', '--n-trials',
                  '--n-calib', '--method', 'ltt', 'quantile'):
        assert token in proc.stdout
def test_embed_missing_args():
    """embed without its required flags must exit non-zero with an error."""
    proc = run_cli('embed')
    assert proc.returncode != 0
    assert any(marker in proc.stderr for marker in ('--input', 'required'))
def test_search_missing_args():
    """search without its required flags must exit non-zero with an error."""
    proc = run_cli('search')
    assert proc.returncode != 0
    assert any(marker in proc.stderr for marker in ('--input', 'required'))
def test_verify_missing_args():
    """verify without --check must exit non-zero with an error."""
    proc = run_cli('verify')
    assert proc.returncode != 0
    assert any(marker in proc.stderr for marker in ('--check', 'required'))
def test_verify_invalid_check():
    """An unknown --check name must be rejected with a non-zero exit."""
    proc = run_cli('verify', '--check', 'invalid_check_name')
    assert proc.returncode != 0
def test_search_with_mock_data(tmp_path):
    """Run 'cpr search' end-to-end on tiny random unit-vector embeddings."""
    np.random.seed(42)

    def unit_rows(n, dim=128):
        # Rows normalized to unit length so inner product equals cosine.
        rows = np.random.randn(n, dim).astype(np.float32)
        return rows / np.linalg.norm(rows, axis=1, keepdims=True)

    q_path = tmp_path / "query.npy"
    d_path = tmp_path / "db.npy"
    out_path = tmp_path / "results.csv"
    np.save(q_path, unit_rows(5))
    np.save(d_path, unit_rows(20))

    # --no-filter: random embeddings will not clear the default FDR threshold.
    proc = run_cli(
        'search',
        '--input', str(q_path),
        '--database', str(d_path),
        '--output', str(out_path),
        '--k', '3',
        '--no-filter',
    )
    assert proc.returncode == 0
    assert out_path.exists()

    hits = pd.read_csv(out_path)
    assert len(hits) == 15  # 5 queries x k=3 neighbors
    for col in ('query_idx', 'match_idx', 'similarity'):
        assert col in hits.columns
    # Cosine similarities must lie in [-1, 1].
    assert hits['similarity'].min() >= -1.0
    assert hits['similarity'].max() <= 1.0
def test_search_with_threshold(tmp_path):
    """A high --threshold on random embeddings should filter most results."""
    np.random.seed(42)

    def unit_rows(n, dim=128):
        rows = np.random.randn(n, dim).astype(np.float32)
        return rows / np.linalg.norm(rows, axis=1, keepdims=True)

    q_path = tmp_path / "query.npy"
    d_path = tmp_path / "db.npy"
    out_path = tmp_path / "results.csv"
    np.save(q_path, unit_rows(3))
    np.save(d_path, unit_rows(10))

    proc = run_cli(
        'search',
        '--input', str(q_path),
        '--database', str(d_path),
        '--output', str(out_path),
        '--k', '10',
        '--threshold', '0.9',
    )
    assert proc.returncode == 0
    assert out_path.exists()

    # Random unit vectors have expected cosine similarity ~0, so a 0.9
    # threshold drops nearly everything; an empty file is a valid outcome.
    try:
        hits = pd.read_csv(out_path)
    except pd.errors.EmptyDataError:
        return
    assert len(hits) <= 30  # at most 3 queries x k=10
    if len(hits) > 0:
        # Every surviving result must be above the threshold.
        assert hits['similarity'].min() >= 0.9
def test_search_with_metadata(tmp_path):
    """Columns from --database-meta are joined onto the search results."""
    np.random.seed(42)

    def unit_rows(n, dim=128):
        rows = np.random.randn(n, dim).astype(np.float32)
        return rows / np.linalg.norm(rows, axis=1, keepdims=True)

    q_path = tmp_path / "query.npy"
    d_path = tmp_path / "db.npy"
    meta_path = tmp_path / "meta.csv"
    out_path = tmp_path / "results.csv"
    np.save(q_path, unit_rows(2))
    np.save(d_path, unit_rows(5))

    # One metadata row per database entry.
    pd.DataFrame({
        'protein_id': [f'PROT_{i:03d}' for i in range(5)],
        'description': [f'Protein {i}' for i in range(5)],
        'organism': ['E. coli', 'Human', 'Yeast', 'Mouse', 'Rat'],
    }).to_csv(meta_path, index=False)

    # --no-filter: random embeddings will not clear the default FDR threshold.
    proc = run_cli(
        'search',
        '--input', str(q_path),
        '--database', str(d_path),
        '--database-meta', str(meta_path),
        '--output', str(out_path),
        '--k', '3',
        '--no-filter',
    )
    assert proc.returncode == 0
    assert out_path.exists()

    hits = pd.read_csv(out_path)
    assert len(hits) == 6  # 2 queries x k=3 neighbors
    # Metadata columns should be attached with the 'match_' prefix.
    for col in ('match_protein_id', 'match_description', 'match_organism'):
        assert col in hits.columns
def test_prob_with_mock_data(tmp_path):
    """prob turns raw similarity scores into calibrated probabilities."""
    np.random.seed(42)

    # Calibration set: object array of dicts with S_i, exact, partial.
    cal_records = []
    for _ in range(50):
        sims = np.random.uniform(0.998, 0.9999, size=10).astype(np.float32)
        exact = (np.random.random(10) < 0.2).astype(bool)
        cal_records.append({
            "S_i": sims,
            "exact": exact,
            "partial": exact | (np.random.random(10) < 0.1),
        })
    cal_path = tmp_path / "calibration.npy"
    np.save(cal_path, np.array(cal_records, dtype=object))

    score_path = tmp_path / "scores.npy"
    np.save(score_path, np.array([0.9985, 0.9990, 0.9995, 0.9998]))
    out_path = tmp_path / "probs.csv"

    proc = run_cli(
        'prob',
        '--input', str(score_path),
        '--calibration', str(cal_path),
        '--output', str(out_path),
        '--n-calib', '50',
        '--seed', '42',
    )
    assert proc.returncode == 0
    assert out_path.exists()

    probs = pd.read_csv(out_path)
    assert len(probs) == 4
    for col in ('score', 'probability', 'uncertainty'):
        assert col in probs.columns
    # Probability and its uncertainty must both lie in [0, 1].
    for col in ('probability', 'uncertainty'):
        assert probs[col].min() >= 0.0
        assert probs[col].max() <= 1.0
def test_prob_with_csv_input(tmp_path):
    """prob reads scores from a CSV column and appends calibrated columns."""
    np.random.seed(42)

    # Calibration set: object array of dicts with S_i, exact, partial.
    cal_records = []
    for _ in range(30):
        sims = np.random.uniform(0.998, 0.9999, size=5).astype(np.float32)
        exact = (np.random.random(5) < 0.2).astype(bool)
        cal_records.append({
            "S_i": sims,
            "exact": exact,
            "partial": exact | (np.random.random(5) < 0.1),
        })
    cal_path = tmp_path / "calibration.npy"
    np.save(cal_path, np.array(cal_records, dtype=object))

    # CSV input shaped like 'cpr search' output.
    in_path = tmp_path / "input.csv"
    pd.DataFrame({
        'query_idx': [0, 0, 1, 1],
        'match_idx': [5, 10, 3, 8],
        'similarity': [0.9985, 0.9990, 0.9995, 0.9998],
        'match_protein_id': ['PROT_A', 'PROT_B', 'PROT_C', 'PROT_D'],
    }).to_csv(in_path, index=False)
    out_path = tmp_path / "output.csv"

    proc = run_cli(
        'prob',
        '--input', str(in_path),
        '--calibration', str(cal_path),
        '--output', str(out_path),
        '--score-column', 'similarity',
        '--n-calib', '30',
    )
    assert proc.returncode == 0
    assert out_path.exists()

    enriched = pd.read_csv(out_path)
    assert len(enriched) == 4
    # Every input column survives the round trip...
    for col in ('query_idx', 'match_idx', 'similarity', 'match_protein_id'):
        assert col in enriched.columns
    # ...and the calibrated columns are appended.
    for col in ('probability', 'uncertainty'):
        assert col in enriched.columns
def test_calibrate_with_mock_data(tmp_path):
    """calibrate writes one FDR/FNR threshold row per trial."""
    np.random.seed(42)

    # Calibration set where labels are monotone in similarity:
    # higher similarity implies a positive label.
    cal_records = []
    for _ in range(100):
        sims = np.random.uniform(0.997, 0.9999, size=10).astype(np.float32)
        cal_records.append({
            "S_i": sims,
            "exact": (sims > 0.9995).astype(bool),
            "partial": (sims > 0.999).astype(bool),
        })
    cal_path = tmp_path / "calibration.npy"
    np.save(cal_path, np.array(cal_records, dtype=object))
    out_path = tmp_path / "thresholds.csv"

    # Keep --n-trials tiny so this stays fast.
    proc = run_cli(
        'calibrate',
        '--calibration', str(cal_path),
        '--output', str(out_path),
        '--alpha', '0.1',
        '--n-trials', '5',
        '--n-calib', '50',
        '--method', 'quantile',
        '--seed', '42',
    )
    assert proc.returncode == 0
    assert out_path.exists()

    trials = pd.read_csv(out_path)
    assert len(trials) == 5  # one row per trial
    for col in ('trial', 'alpha', 'fdr_threshold', 'fnr_threshold'):
        assert col in trials.columns
    # The requested alpha should be echoed back on every row.
    assert (trials['alpha'] == 0.1).all()
    # Thresholds are similarity cutoffs, so they must fall in (0, 1].
    for col in ('fdr_threshold', 'fnr_threshold'):
        assert trials[col].min() > 0.0
        assert trials[col].max() <= 1.0
def test_embed_missing_input_file():
    """embed must fail when the input FASTA path does not exist."""
    with tempfile.NamedTemporaryFile(suffix='.npy', delete=False) as handle:
        out_path = handle.name
    try:
        proc = run_cli(
            'embed',
            '--input', '/nonexistent/file.fasta',
            '--output', out_path,
        )
        assert proc.returncode != 0
    finally:
        # Always clean up the placeholder output file.
        Path(out_path).unlink(missing_ok=True)
def test_search_missing_query_file(tmp_path):
    """search must fail when the query embeddings file is absent."""
    # Only the database file exists; the query path is bogus.
    d_path = tmp_path / "db.npy"
    np.save(d_path, np.random.randn(10, 128).astype(np.float32))
    proc = run_cli(
        'search',
        '--input', '/nonexistent/query.npy',
        '--database', str(d_path),
        '--output', str(tmp_path / "results.csv"),
    )
    assert proc.returncode != 0
def test_search_missing_database_file(tmp_path):
    """search must fail when the database embeddings file is absent."""
    # Only the query file exists; the database path is bogus.
    q_path = tmp_path / "query.npy"
    np.save(q_path, np.random.randn(5, 128).astype(np.float32))
    proc = run_cli(
        'search',
        '--input', str(q_path),
        '--database', '/nonexistent/db.npy',
        '--output', str(tmp_path / "results.csv"),
    )
    assert proc.returncode != 0
def test_prob_missing_calibration_file(tmp_path):
    """prob must fail when the calibration file is absent."""
    score_path = tmp_path / "scores.npy"
    np.save(score_path, np.array([0.998, 0.999]))
    proc = run_cli(
        'prob',
        '--input', str(score_path),
        '--calibration', '/nonexistent/calibration.npy',
        '--output', str(tmp_path / "probs.csv"),
    )
    assert proc.returncode != 0
def test_calibrate_missing_calibration_file(tmp_path):
    """calibrate must fail when the calibration file is absent."""
    proc = run_cli(
        'calibrate',
        '--calibration', '/nonexistent/calibration.npy',
        '--output', str(tmp_path / "thresholds.csv"),
        '--n-trials', '1',
    )
    assert proc.returncode != 0
def test_search_with_k_larger_than_database(tmp_path):
    """k larger than the database size is capped, not treated as an error."""
    np.random.seed(42)

    def unit_rows(n, dim=128):
        rows = np.random.randn(n, dim).astype(np.float32)
        return rows / np.linalg.norm(rows, axis=1, keepdims=True)

    q_path = tmp_path / "query.npy"
    d_path = tmp_path / "db.npy"
    out_path = tmp_path / "results.csv"
    np.save(q_path, unit_rows(2))
    np.save(d_path, unit_rows(3))  # only 3 database entries

    # k=10 exceeds the database size; FAISS returns at most 3 per query.
    proc = run_cli(
        'search',
        '--input', str(q_path),
        '--database', str(d_path),
        '--output', str(out_path),
        '--k', '10',
        '--no-filter',
    )
    assert proc.returncode == 0
    assert out_path.exists()
    # At most 2 queries x 3 database items = 6 rows.
    assert len(pd.read_csv(out_path)) <= 6
def test_cli_module_import():
    """The CLI module exposes main() plus one handler per subcommand."""
    from protein_conformal import cli
    for attr in ('main', 'cmd_embed', 'cmd_search', 'cmd_verify',
                 'cmd_prob', 'cmd_calibrate'):
        assert hasattr(cli, attr)
    assert callable(cli.main)