| | """ |
| | Tests for FastAPI skill classification endpoints. |
| | |
| | Tests cover request validation, response structure, error handling, |
| | and batch processing capabilities. |
| | """ |
| |
|
| | from http import HTTPStatus |
| |
|
| | import pytest |
| | from fastapi.testclient import TestClient |
| |
|
| | from hopcroft_skill_classification_tool_competition.main import app |
| |
|
# Shared TestClient used by every test class in this module; created lazily
# by get_client() and torn down by the _shutdown_client fixture.
_client = None


def get_client():
    """Return the shared TestClient, creating and entering it on first use.

    The first call builds ``TestClient(app)`` and calls ``__enter__`` on it so
    the application's lifespan startup runs before any request is issued.
    Subsequent calls return the same instance.
    """
    global _client
    if _client is None:
        _client = TestClient(app)
        _client.__enter__()
    return _client


@pytest.fixture(scope="module", autouse=True)
def _shutdown_client():
    """Exit the shared TestClient after every test in this module has run.

    ``get_client`` enters the client by hand, so without this finalizer the
    matching ``__exit__`` (and with it the app's lifespan shutdown) would
    never be executed.
    """
    global _client
    yield
    if _client is not None:
        _client.__exit__(None, None, None)
        _client = None


client = get_client()


class TestRootEndpoint:
    """Checks for ``GET /`` (basic API metadata)."""

    def test_read_root(self):
        """The root endpoint should report the API name and its version."""
        resp = client.get("/")

        assert resp.status_code == HTTPStatus.OK
        assert resp.request.method == "GET"

        body = resp.json()
        for key in ("message", "version"):
            assert key in body

        expected = {"message": "Skill Classification API", "version": "1.0.0"}
        for key, value in expected.items():
            assert body[key] == value


class TestHealthEndpoint:
    """Checks for ``GET /health``."""

    def test_health_check(self):
        """The health endpoint should report status, model state and version."""
        resp = client.get("/health")

        assert resp.status_code == HTTPStatus.OK
        assert resp.request.method == "GET"

        body = resp.json()
        for field in ("status", "model_loaded", "version"):
            assert field in body

        assert body["status"] == "healthy"
        # Only the type is checked here, not whether a model is actually loaded.
        assert isinstance(body["model_loaded"], bool)


class TestPredictionEndpoint:
    """Checks for the single-issue ``POST /predict`` endpoint."""

    @staticmethod
    def _assert_rejected(payload):
        """POST *payload* to /predict and expect a 422 validation error."""
        resp = client.post("/predict", json=payload)
        assert resp.status_code == HTTPStatus.UNPROCESSABLE_ENTITY

    def test_predict_with_minimal_data(self):
        """Only the required ``issue_text`` field is supplied."""
        resp = client.post(
            "/predict",
            json={"issue_text": "Fix authentication bug in login module"},
        )

        assert resp.status_code == HTTPStatus.CREATED
        assert resp.request.method == "POST"

        body = resp.json()
        for field in (
            "predictions",
            "num_predictions",
            "model_version",
            "processing_time_ms",
        ):
            assert field in body

        predictions = body["predictions"]
        # The reported count must agree with the list actually returned.
        assert body["num_predictions"] == len(predictions)

        for prediction in predictions:
            assert "skill_name" in prediction
            assert "confidence" in prediction
            assert 0.0 <= prediction["confidence"] <= 1.0

    def test_predict_with_full_data(self):
        """Every optional metadata field is supplied alongside the text."""
        payload = {
            "issue_text": "Add support for OAuth authentication",
            "issue_description": "Implement OAuth 2.0 flow for third-party authentication providers",
            "repo_name": "myorg/myproject",
            "pr_number": 456,
            "author_name": "developer123",
            "created_at": "2024-01-15T10:30:00Z",
        }

        resp = client.post("/predict", json=payload)
        assert resp.status_code == HTTPStatus.CREATED

        body = resp.json()
        assert len(body["predictions"]) > 0
        assert body["model_version"] == "1.0.0"
        assert body["processing_time_ms"] > 0

    def test_predict_missing_required_field(self):
        """A payload without ``issue_text`` must be rejected."""
        self._assert_rejected(
            {"issue_description": "This is missing the issue_text field"}
        )

    def test_predict_invalid_pr_number(self):
        """A negative PR number must be rejected."""
        self._assert_rejected({"issue_text": "Fix bug", "pr_number": -5})

    def test_predict_empty_issue_text(self):
        """An empty ``issue_text`` must be rejected."""
        self._assert_rejected({"issue_text": ""})

    def test_predict_whitespace_only_text(self):
        """An ``issue_text`` consisting only of whitespace must be rejected."""
        self._assert_rejected({"issue_text": " "})


class TestBatchPredictionEndpoint:
    """Tests for the batch prediction endpoint (``POST /predict/batch``)."""

    # Largest batch the API is expected to accept. Inferred from the original
    # 101-issue rejection test — NOTE(review): keep in sync with the API's
    # request model if its limit changes.
    MAX_BATCH_SIZE = 100

    def test_batch_predict_multiple_issues(self):
        """Each issue in a multi-issue batch gets its own prediction result."""
        batch_data = {
            "issues": [
                {"issue_text": "Transfer learning with transformers for text classification."},
                {"issue_text": "Generative adversarial networks in both PyTorch and TensorFlow."},
                {"issue_text": "Fix database connection pooling issue"},
            ]
        }
        # Derive the expected count from the payload instead of hard-coding
        # it, so the assertions stay correct if issues are added or removed.
        expected_count = len(batch_data["issues"])

        response = client.post("/predict/batch", json=batch_data)

        assert response.status_code == HTTPStatus.OK
        assert response.request.method == "POST"

        data = response.json()
        assert "results" in data
        assert "total_issues" in data
        assert "total_processing_time_ms" in data

        assert len(data["results"]) == expected_count
        assert data["total_issues"] == expected_count

        for result in data["results"]:
            assert "predictions" in result
            assert "num_predictions" in result
            assert len(result["predictions"]) > 0
            # Same count/list invariant the single /predict test enforces.
            assert result["num_predictions"] == len(result["predictions"])

    def test_batch_predict_single_issue(self):
        """A batch containing exactly one issue yields exactly one result."""
        batch_data = {
            "issues": [
                {"issue_text": "Add unit tests for authentication module"}
            ]
        }

        response = client.post("/predict/batch", json=batch_data)

        assert response.status_code == HTTPStatus.OK

        data = response.json()
        assert data["total_issues"] == 1
        assert len(data["results"]) == 1

    def test_batch_predict_empty_list(self):
        """An empty ``issues`` list is rejected with 422."""
        batch_data = {
            "issues": []
        }

        response = client.post("/predict/batch", json=batch_data)

        assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY

    def test_batch_predict_too_many_issues(self):
        """A batch one larger than ``MAX_BATCH_SIZE`` is rejected with 422."""
        batch_data = {
            "issues": [
                {"issue_text": f"Issue {i}"}
                for i in range(self.MAX_BATCH_SIZE + 1)
            ]
        }

        response = client.post("/predict/batch", json=batch_data)

        assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY

    def test_batch_predict_with_mixed_data(self):
        """A batch mixing minimal and fully-populated issues is accepted."""
        batch_data = {
            "issues": [
                {
                    "issue_text": "Simple issue"
                },
                {
                    "issue_text": "Detailed issue",
                    "issue_description": "With description and metadata",
                    "repo_name": "user/repo",
                    "pr_number": 123,
                },
            ]
        }

        response = client.post("/predict/batch", json=batch_data)

        assert response.status_code == HTTPStatus.OK
        data = response.json()
        assert len(data["results"]) == len(batch_data["issues"])


class TestErrorHandling:
    """Checks for generic error responses (validation failures, unknown routes)."""

    def test_missing_required_field(self):
        """An empty JSON body sent to /predict is a validation error."""
        status = client.post("/predict", json={}).status_code
        assert status == HTTPStatus.UNPROCESSABLE_ENTITY

    def test_endpoint_not_found(self):
        """A request to an unknown path yields 404."""
        status = client.get("/nonexistent").status_code
        assert status == HTTPStatus.NOT_FOUND


class TestGetPredictionEndpoint:
    """Checks for fetching a single stored prediction via ``GET /predictions/{run_id}``."""

    def test_get_prediction_success(self):
        """A prediction created through /predict can be fetched back by run_id."""
        created = client.post("/predict", json={"issue_text": "Test issue for retrieval"})
        assert created.status_code == HTTPStatus.CREATED
        run_id = created.json()["run_id"]

        fetched = client.get(f"/predictions/{run_id}")
        assert fetched.status_code == HTTPStatus.OK

        record = fetched.json()
        assert record["run_id"] == run_id
        assert "predictions" in record
        assert "timestamp" in record

    def test_get_prediction_not_found(self):
        """An unknown run_id yields 404."""
        missing_id = "nonexistent_run_id_12345"
        fetched = client.get(f"/predictions/{missing_id}")
        assert fetched.status_code == HTTPStatus.NOT_FOUND


class TestListPredictionsEndpoint:
    """Checks for listing recent predictions via ``GET /predictions``."""

    def test_list_predictions(self):
        """The listing endpoint answers with a JSON array."""
        resp = client.get("/predictions")
        assert resp.status_code == HTTPStatus.OK
        assert isinstance(resp.json(), list)

    def test_list_predictions_with_pagination(self):
        """The ``limit`` query parameter caps the number of returned entries."""
        resp = client.get("/predictions?skip=0&limit=5")
        assert resp.status_code == HTTPStatus.OK

        items = resp.json()
        assert isinstance(items, list)
        assert len(items) <= 5


class TestMLflowIntegration:
    """Tests for MLflow tracking integration (run_id attached to predictions)."""

    def test_prediction_creates_run_id(self):
        """A successful prediction response carries a non-empty run_id."""
        issue_data = {"issue_text": "MLflow tracking test"}
        response = client.post("/predict", json=issue_data)

        assert response.status_code == HTTPStatus.CREATED
        data = response.json()

        assert "run_id" in data
        assert data["run_id"]

    def test_retrieve_prediction_by_run_id(self):
        """The run_id returned by /predict can be used to fetch that prediction."""
        response = client.post("/predict", json={"issue_text": "Test retrieval"})
        # Check creation first so a rejected request fails with a clear
        # status-code assertion instead of a KeyError on "run_id".
        assert response.status_code == HTTPStatus.CREATED
        run_id = response.json()["run_id"]

        retrieve_response = client.get(f"/predictions/{run_id}")

        assert retrieve_response.status_code == HTTPStatus.OK
        assert retrieve_response.json()["run_id"] == run_id


if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly reports
    # test failures to the shell/CI instead of always exiting with 0.
    raise SystemExit(pytest.main([__file__, "-v"]))