Spaces:
Running
Running
File size: 5,819 Bytes
95ecd72 d963dcd 95ecd72 d963dcd 95ecd72 d963dcd 95ecd72 d963dcd 95ecd72 d963dcd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
"""Tests for the prediction API."""
import pytest
from fastapi.testclient import TestClient
import sys
from pathlib import Path
# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from webapp.app.main import app
from webapp.app.database import init_db
@pytest.fixture(scope="module")
def client():
    """Build one FastAPI ``TestClient`` shared by every test in this module.

    ``init_db()`` is invoked first, before the client wrapping ``app`` is
    constructed and handed to the tests.
    """
    init_db()
    test_client = TestClient(app)
    return test_client
class TestHealthEndpoint:
    """Checks for the ``/api/health`` endpoint."""

    def test_health_check(self, client):
        """The health endpoint answers 200 with ``status`` and ``version`` keys."""
        reply = client.get("/api/health")
        assert reply.status_code == 200
        body = reply.json()
        for key in ("status", "version"):
            assert key in body
class TestExampleEndpoint:
    """Checks for the ``/api/example`` endpoint."""

    def test_get_examples(self, client):
        """The example endpoint lists at least one well-formed example.

        Only the first entry is inspected: it must carry ``name``,
        ``sequence`` and ``description``, and its sequence must be exactly
        70 characters long.
        """
        reply = client.get("/api/example")
        assert reply.status_code == 200
        body = reply.json()
        assert "sequences" in body
        examples = body["sequences"]
        assert len(examples) >= 1
        first = examples[0]
        for field in ("name", "sequence", "description"):
            assert field in first
        assert len(first["sequence"]) == 70
class TestPredictEndpoint:
    """Checks for the ``/api/predict`` endpoint."""

    # Known-good 70-character example input.
    VALID_SEQUENCE = (
        "GGTAGTACGCCAATTCGCCGGTGCCGCGAGCCAGAGGCTACCAAAACTTGACAAGCCTACATATACTACT"
    )

    @staticmethod
    def _predict(client, sequence):
        """POST *sequence* to ``/api/predict`` and return the raw response."""
        return client.post("/api/predict", json={"sequence": sequence})

    def test_predict_valid_sequence(self, client):
        """A valid sequence produces a finished job with an id and result URL."""
        reply = self._predict(client, self.VALID_SEQUENCE)
        assert reply.status_code == 200
        body = reply.json()
        assert "job_id" in body
        assert body["status"] == "finished"
        assert "result_url" in body

    def test_predict_invalid_length(self, client):
        """A too-short sequence is rejected with a 422 validation error."""
        reply = self._predict(client, "ACGT")  # far shorter than 70 characters
        assert reply.status_code == 422

    def test_predict_invalid_characters(self, client):
        """A 70-character sequence made only of ``X`` is rejected with 422."""
        reply = self._predict(client, "X" * 70)  # right length, bad alphabet
        assert reply.status_code == 422
class TestResultEndpoint:
    """Tests for the ``/api/result/{job_id}`` endpoint."""

    # Known-good 70-character input used to create a job to look up.
    SEQUENCE = (
        "GGTAGTACGCCAATTCGCCGGTGCCGCGAGCCAGAGGCTACCAAAACTTGACAAGCCTACATATACTACT"
    )

    def test_get_result_not_found(self, client):
        """Requesting an unknown job id returns 404."""
        response = client.get("/api/result/nonexistent-id")
        assert response.status_code == 404

    def test_get_result_after_prediction(self, client):
        """A job created via ``/api/predict`` can be fetched with ``psi`` in [0, 1]."""
        predict_response = client.post(
            "/api/predict",
            json={"sequence": self.SEQUENCE},
        )
        # Fail here with the server's reply instead of on an opaque KeyError
        # below if the prediction request itself was rejected.
        assert predict_response.status_code == 200, predict_response.text
        job_id = predict_response.json()["job_id"]

        response = client.get(f"/api/result/{job_id}")
        assert response.status_code == 200, response.text
        data = response.json()
        assert "psi" in data
        assert 0 <= data["psi"] <= 1
        assert "interpretation" in data
class TestBatchEndpoint:
    """Checks for the ``/api/batch`` endpoint."""

    def test_batch_predict(self, client):
        """Submitting two named sequences yields a finished batch."""
        payload = {
            "sequences": [
                {"name": "seq1", "sequence": "GGTAGTACGCCAATTCGCCGGTGCCGCGAGCCAGAGGCTACCAAAACTTGACAAGCCTACATATACTACT"},
                {"name": "seq2", "sequence": "CTACCACCTCCCAAGCTTACACACTGTTTGATGAAAGGTCGCCACAACGTTCCCTCACCCCTAGTCTCGC"},
            ],
        }
        reply = client.post("/api/batch", json=payload)
        assert reply.status_code == 200
        assert reply.json()["status"] == "finished"

    def test_batch_empty_list(self, client):
        """An empty ``sequences`` list is rejected with 422."""
        reply = client.post("/api/batch", json={"sequences": []})
        assert reply.status_code == 422
class TestExportEndpoint:
    """Tests for the ``/api/export/{job_id}/csv`` endpoint."""

    # Known-good 70-character input used to create an exportable job.
    SEQUENCE = (
        "GGTAGTACGCCAATTCGCCGGTGCCGCGAGCCAGAGGCTACCAAAACTTGACAAGCCTACATATACTACT"
    )

    def _create_job(self, client):
        """Submit ``SEQUENCE`` to ``/api/predict`` and return the new job id.

        Asserts the prediction succeeded, so an export failure is never
        masked by a ``KeyError`` on a rejected prediction.
        """
        predict_response = client.post(
            "/api/predict",
            json={"sequence": self.SEQUENCE},
        )
        assert predict_response.status_code == 200, predict_response.text
        return predict_response.json()["job_id"]

    def test_export_csv_basic(self, client):
        """Exporting an existing job as CSV returns 200."""
        job_id = self._create_job(client)
        response = client.get(f"/api/export/{job_id}/csv")
        assert response.status_code == 200, response.text

    def test_export_csv_content(self, client):
        """The CSV export has a CSV content type and a header naming key columns."""
        job_id = self._create_job(client)
        response = client.get(f"/api/export/{job_id}/csv")
        assert response.status_code == 200, response.text
        assert response.headers["content-type"] == "text/csv; charset=utf-8"
        lines = response.text.splitlines()
        assert lines, "CSV export is empty"
        # Check only the header row, so the assertion cannot be satisfied by
        # a match somewhere inside a data row.
        header = lines[0]
        assert "sequence" in header  # CSV header should include sequence
        assert "psi" in header  # CSV header should include psi
|