|
|
""" |
|
|
Tests for FastAPI skill classification endpoints. |
|
|
|
|
|
Tests cover request validation, response structure, error handling, |
|
|
and batch processing capabilities. |
|
|
""" |
|
|
|
|
|
from http import HTTPStatus |
|
|
|
|
|
import pytest |
|
|
from fastapi.testclient import TestClient |
|
|
|
|
|
from hopcroft_skill_classification_tool_competition.main import app |
|
|
|
|
|
# Lazily created TestClient shared by every test in this module.
_client = None


def get_client():
    """Return the module-wide TestClient, building it on the first call.

    The client is entered as a context manager so that the application's
    lifespan handlers run before any request is issued.

    NOTE(review): nothing in this file calls ``__exit__`` on the client, so
    lifespan shutdown is presumably never triggered -- confirm this is
    intentional.
    """
    global _client
    if _client is not None:
        return _client

    _client = TestClient(app)
    _client.__enter__()
    return _client


client = get_client()
|
|
|
|
|
class TestRootEndpoint:
    """Checks for the API root (``GET /``)."""

    def test_read_root(self):
        """The root endpoint answers 200 with the API name and version."""
        resp = client.get("/")

        assert resp.status_code == HTTPStatus.OK
        assert resp.request.method == "GET"

        body = resp.json()
        expected = {
            "message": "Skill Classification API",
            "version": "1.0.0",
        }
        # Presence first, then values, so a missing key is reported as such.
        for key in expected:
            assert key in body
        for key, value in expected.items():
            assert body[key] == value
|
|
|
|
|
|
|
|
class TestHealthEndpoint:
    """Checks for the health probe (``GET /health``)."""

    def test_health_check(self):
        """The health endpoint reports a healthy status and a boolean model flag."""
        resp = client.get("/health")

        assert resp.status_code == HTTPStatus.OK
        assert resp.request.method == "GET"

        body = resp.json()
        for field in ("status", "model_loaded", "version"):
            assert field in body

        assert body["status"] == "healthy"
        assert isinstance(body["model_loaded"], bool)
|
|
|
|
|
|
|
|
class TestPredictionEndpoint:
    """Checks for the single-issue prediction endpoint (``POST /predict``)."""

    def _assert_unprocessable(self, payload):
        """POST *payload* to ``/predict`` and require a 422 response."""
        resp = client.post("/predict", json=payload)
        assert resp.status_code == HTTPStatus.UNPROCESSABLE_ENTITY

    def test_predict_with_minimal_data(self):
        """A request carrying only the required field yields a well-formed result."""
        payload = {"issue_text": "Fix authentication bug in login module"}

        resp = client.post("/predict", json=payload)

        assert resp.status_code == HTTPStatus.CREATED
        assert resp.request.method == "POST"

        body = resp.json()
        for field in (
            "predictions",
            "num_predictions",
            "model_version",
            "processing_time_ms",
        ):
            assert field in body

        predictions = body["predictions"]
        assert body["num_predictions"] == len(predictions)

        for item in predictions:
            assert "skill_name" in item
            assert "confidence" in item
            assert 0.0 <= item["confidence"] <= 1.0

    def test_predict_with_full_data(self):
        """A request carrying every optional field is accepted and answered."""
        payload = {
            "issue_text": "Add support for OAuth authentication",
            "issue_description": "Implement OAuth 2.0 flow for third-party authentication providers",
            "repo_name": "myorg/myproject",
            "pr_number": 456,
            "author_name": "developer123",
            "created_at": "2024-01-15T10:30:00Z",
        }

        resp = client.post("/predict", json=payload)

        assert resp.status_code == HTTPStatus.CREATED

        body = resp.json()
        assert len(body["predictions"]) > 0
        assert body["model_version"] == "1.0.0"
        assert body["processing_time_ms"] > 0

    def test_predict_missing_required_field(self):
        """Omitting ``issue_text`` is rejected with 422."""
        self._assert_unprocessable(
            {"issue_description": "This is missing the issue_text field"}
        )

    def test_predict_invalid_pr_number(self):
        """A negative ``pr_number`` is rejected with 422."""
        self._assert_unprocessable({"issue_text": "Fix bug", "pr_number": -5})

    def test_predict_empty_issue_text(self):
        """An empty ``issue_text`` is rejected with 422."""
        self._assert_unprocessable({"issue_text": ""})

    def test_predict_whitespace_only_text(self):
        """An ``issue_text`` made only of whitespace is rejected with 422."""
        self._assert_unprocessable({"issue_text": " "})
|
|
|
|
|
|
|
|
class TestBatchPredictionEndpoint:
    """Checks for the batch prediction endpoint (``POST /predict/batch``)."""

    def test_batch_predict_multiple_issues(self):
        """Three issues in one batch produce three populated results."""
        texts = (
            "Transfer learning with transformers for text classification.",
            "Generative adversarial networks in both PyTorch and TensorFlow.",
            "Fix database connection pooling issue",
        )
        issues = [{"issue_text": text} for text in texts]

        resp = client.post("/predict/batch", json={"issues": issues})

        assert resp.status_code == HTTPStatus.OK
        assert resp.request.method == "POST"

        body = resp.json()
        for field in ("results", "total_issues", "total_processing_time_ms"):
            assert field in body

        assert len(body["results"]) == len(issues)
        assert body["total_issues"] == 3

        for entry in body["results"]:
            assert "predictions" in entry
            assert "num_predictions" in entry
            assert len(entry["predictions"]) > 0

    def test_batch_predict_single_issue(self):
        """A batch holding one issue yields exactly one result."""
        issues = [{"issue_text": "Add unit tests for authentication module"}]

        resp = client.post("/predict/batch", json={"issues": issues})

        assert resp.status_code == HTTPStatus.OK

        body = resp.json()
        assert body["total_issues"] == 1
        assert len(body["results"]) == 1

    def test_batch_predict_empty_list(self):
        """An empty ``issues`` list is rejected with 422."""
        resp = client.post("/predict/batch", json={"issues": []})

        assert resp.status_code == HTTPStatus.UNPROCESSABLE_ENTITY

    def test_batch_predict_too_many_issues(self):
        """A batch of 101 issues is rejected with 422."""
        issues = []
        for i in range(101):
            issues.append({"issue_text": f"Issue {i}"})

        resp = client.post("/predict/batch", json={"issues": issues})

        assert resp.status_code == HTTPStatus.UNPROCESSABLE_ENTITY

    def test_batch_predict_with_mixed_data(self):
        """Minimal and fully populated issues can share one batch."""
        minimal_issue = {"issue_text": "Simple issue"}
        detailed_issue = {
            "issue_text": "Detailed issue",
            "issue_description": "With description and metadata",
            "repo_name": "user/repo",
            "pr_number": 123,
        }

        resp = client.post(
            "/predict/batch",
            json={"issues": [minimal_issue, detailed_issue]},
        )

        assert resp.status_code == HTTPStatus.OK
        assert len(resp.json()["results"]) == 2
|
|
|
|
|
|
|
|
class TestErrorHandling:
    """Checks for the API's error responses."""

    def test_missing_required_field(self):
        """An empty JSON body sent to ``/predict`` is rejected with 422."""
        empty_payload = {}
        resp = client.post("/predict", json=empty_payload)

        assert resp.status_code == HTTPStatus.UNPROCESSABLE_ENTITY

    def test_endpoint_not_found(self):
        """Requesting an unknown path yields 404."""
        resp = client.get("/nonexistent")

        assert resp.status_code == HTTPStatus.NOT_FOUND
|
|
|
|
|
|
|
|
class TestGetPredictionEndpoint:
    """Checks for fetching one stored prediction via ``GET /predictions/{run_id}``."""

    def test_get_prediction_success(self):
        """A prediction that was just created can be fetched by its run_id."""
        created = client.post(
            "/predict", json={"issue_text": "Test issue for retrieval"}
        )
        assert created.status_code == HTTPStatus.CREATED
        run_id = created.json()["run_id"]

        fetched = client.get(f"/predictions/{run_id}")
        assert fetched.status_code == HTTPStatus.OK

        body = fetched.json()
        assert body["run_id"] == run_id
        for field in ("predictions", "timestamp"):
            assert field in body

    def test_get_prediction_not_found(self):
        """Looking up an unknown run_id yields 404."""
        fake_run_id = "nonexistent_run_id_12345"

        resp = client.get(f"/predictions/{fake_run_id}")

        assert resp.status_code == HTTPStatus.NOT_FOUND
|
|
|
|
|
|
|
|
class TestListPredictionsEndpoint:
    """Checks for the prediction listing endpoint (``GET /predictions``)."""

    def test_list_predictions(self):
        """The listing endpoint answers 200 with a JSON list."""
        resp = client.get("/predictions")

        assert resp.status_code == HTTPStatus.OK
        assert isinstance(resp.json(), list)

    def test_list_predictions_with_pagination(self):
        """With ``limit=5`` the listing returns at most five entries."""
        resp = client.get("/predictions?skip=0&limit=5")

        assert resp.status_code == HTTPStatus.OK

        items = resp.json()
        assert isinstance(items, list)
        assert len(items) <= 5
|
|
|
|
|
|
|
|
class TestMLflowIntegration:
    """Tests for MLflow tracking integration."""

    def test_prediction_creates_run_id(self):
        """Test that a prediction response carries a non-empty run_id."""
        issue_data = {"issue_text": "MLflow tracking test"}
        response = client.post("/predict", json=issue_data)

        assert response.status_code == HTTPStatus.CREATED
        data = response.json()

        assert "run_id" in data
        assert data["run_id"]

    def test_retrieve_prediction_by_run_id(self):
        """Test retrieving a prediction using the run_id returned on creation."""
        response = client.post("/predict", json={"issue_text": "Test retrieval"})

        # Check the create call succeeded before reading its body; otherwise a
        # failed POST would surface as an opaque KeyError on "run_id".
        assert response.status_code == HTTPStatus.CREATED
        run_id = response.json()["run_id"]

        retrieve_response = client.get(f"/predictions/{run_id}")

        assert retrieve_response.status_code == HTTPStatus.OK
        assert retrieve_response.json()["run_id"] == run_id
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main returns the run's exit code; propagate it so that running
    # this file as a script exits non-zero when tests fail instead of
    # always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
|
|