Spaces:
Sleeping
Sleeping
| """Tests for the prediction API.""" | |
| import pytest | |
| from fastapi.testclient import TestClient | |
| import sys | |
| from pathlib import Path | |
| # Add project root to path | |
| sys.path.insert(0, str(Path(__file__).parent.parent.parent)) | |
| from webapp.app.main import app | |
| from webapp.app.database import init_db | |
@pytest.fixture
def client():
    """Create a test client bound to the FastAPI ``app``.

    ``init_db()`` runs first so the database exists before any request is
    issued by the test that requested this fixture.

    Returns:
        A ``TestClient`` wrapping ``app``.
    """
    # The decorator is required: every test below takes ``client`` as an
    # argument, and pytest only injects it if it is registered as a fixture.
    init_db()
    return TestClient(app)
class TestHealthEndpoint:
    """Checks for the ``/api/health`` route."""

    def test_health_check(self, client):
        """The health probe answers 200 and reports both status and version."""
        resp = client.get("/api/health")
        assert resp.status_code == 200

        payload = resp.json()
        for key in ("status", "version"):
            assert key in payload
class TestExampleEndpoint:
    """Checks for the ``/api/example`` route."""

    def test_get_examples(self, client):
        """The example listing holds at least one well-formed 70-character entry."""
        resp = client.get("/api/example")
        assert resp.status_code == 200

        payload = resp.json()
        assert "sequences" in payload
        examples = payload["sequences"]
        assert len(examples) >= 1

        # Only the first entry's shape is inspected.
        first = examples[0]
        for field in ("name", "sequence", "description"):
            assert field in first
        assert len(first["sequence"]) == 70
class TestPredictEndpoint:
    """Checks for the ``/api/predict`` route."""

    # A real 70-character example sequence.
    VALID_SEQUENCE = "GGTAGTACGCCAATTCGCCGGTGCCGCGAGCCAGAGGCTACCAAAACTTGACAAGCCTACATATACTACT"

    @staticmethod
    def _submit(client, sequence):
        """POST ``sequence`` to the predict route and return the raw response."""
        return client.post("/api/predict", json={"sequence": sequence})

    def test_predict_valid_sequence(self, client):
        """A valid sequence yields a finished job with an id and a result URL."""
        resp = self._submit(client, self.VALID_SEQUENCE)
        assert resp.status_code == 200

        body = resp.json()
        assert "job_id" in body
        assert body["status"] == "finished"
        assert "result_url" in body

    def test_predict_invalid_length(self, client):
        """A sequence that is too short is rejected as a validation error."""
        resp = self._submit(client, "ACGT")
        assert resp.status_code == 422

    def test_predict_invalid_characters(self, client):
        """A correctly sized sequence made of invalid letters is rejected."""
        resp = self._submit(client, "X" * 70)
        assert resp.status_code == 422
class TestResultEndpoint:
    """Tests for /api/result endpoint."""

    def test_get_result_not_found(self, client):
        """Test getting result for non-existent job."""
        response = client.get("/api/result/nonexistent-id")
        assert response.status_code == 404

    def test_get_result_after_prediction(self, client):
        """Submit a prediction, then fetch its result and check the payload.

        The result must contain ``psi`` within [0, 1] and an
        ``interpretation`` field.
        """
        sequence = "GGTAGTACGCCAATTCGCCGGTGCCGCGAGCCAGAGGCTACCAAAACTTGACAAGCCTACATATACTACT"
        predict_response = client.post(
            "/api/predict",
            json={"sequence": sequence},
        )
        # Fail here with the server's response text if submission was
        # rejected, instead of with an opaque KeyError on ``["job_id"]``.
        assert predict_response.status_code == 200, predict_response.text
        job_id = predict_response.json()["job_id"]

        response = client.get(f"/api/result/{job_id}")
        assert response.status_code == 200
        data = response.json()
        assert "psi" in data
        assert 0 <= data["psi"] <= 1
        assert "interpretation" in data
class TestBatchEndpoint:
    """Checks for the ``/api/batch`` route."""

    def test_batch_predict(self, client):
        """Two named, valid sequences are processed to completion."""
        named = (
            ("seq1", "GGTAGTACGCCAATTCGCCGGTGCCGCGAGCCAGAGGCTACCAAAACTTGACAAGCCTACATATACTACT"),
            ("seq2", "CTACCACCTCCCAAGCTTACACACTGTTTGATGAAAGGTCGCCACAACGTTCCCTCACCCCTAGTCTCGC"),
        )
        payload = {
            "sequences": [{"name": label, "sequence": seq} for label, seq in named]
        }

        resp = client.post("/api/batch", json=payload)
        assert resp.status_code == 200
        assert resp.json()["status"] == "finished"

    def test_batch_empty_list(self, client):
        """An empty batch is rejected as a validation error."""
        resp = client.post("/api/batch", json={"sequences": []})
        assert resp.status_code == 422
class TestExportEndpoint:
    """Tests for /api/export endpoint."""

    # Sequence used to create a job whose results are then exported.
    SEQUENCE = "GGTAGTACGCCAATTCGCCGGTGCCGCGAGCCAGAGGCTACCAAAACTTGACAAGCCTACATATACTACT"

    def _create_job(self, client):
        """Submit ``SEQUENCE`` for prediction and return the new job id.

        Asserts the submission succeeded so a failure surfaces as a clear
        status-code mismatch rather than a KeyError on ``["job_id"]``.
        """
        predict_response = client.post(
            "/api/predict",
            json={"sequence": self.SEQUENCE},
        )
        assert predict_response.status_code == 200, predict_response.text
        return predict_response.json()["job_id"]

    def test_export_csv_basic(self, client):
        """Test exporting results as CSV."""
        job_id = self._create_job(client)
        response = client.get(f"/api/export/{job_id}/csv")
        assert response.status_code == 200

    def test_export_csv_content(self, client):
        """Test the CSV export's content type and header columns."""
        job_id = self._create_job(client)
        response = client.get(f"/api/export/{job_id}/csv")
        assert response.status_code == 200
        assert response.headers["content-type"] == "text/csv; charset=utf-8"

        # Check the header row specifically, as the original comments intended,
        # rather than searching the whole body (where data values could match).
        lines = response.text.splitlines()
        assert lines, "CSV export is empty"
        header = lines[0]
        assert "sequence" in header  # CSV header should include sequence
        assert "psi" in header  # CSV header should include psi