Spaces:
Sleeping
Sleeping
"""
Tests for FastAPI skill classification endpoints.

Tests cover request validation, response structure, error handling,
and batch processing capabilities.
"""
| from http import HTTPStatus | |
| import pytest | |
| from fastapi.testclient import TestClient | |
| from hopcroft_skill_classification_tool_competition.main import app | |
# Module-wide cache for the single TestClient shared by every test below.
_client = None


def get_client():
    """Return the shared TestClient, building and entering it on first use.

    The client is created once and stored in the module-level ``_client``;
    later calls hand back that same instance.
    """
    global _client
    if _client is not None:
        return _client

    _client = TestClient(app)
    # Entering the client's context forces the app's lifespan startup.
    # NOTE(review): no matching __exit__ call appears in this module, so the
    # context stays open for the life of the process.
    _client.__enter__()
    return _client


# Created at import time so every test class can use it directly.
client = get_client()
class TestRootEndpoint:
    """Checks on the response served at ``/``."""

    def test_read_root(self):
        """GET ``/`` answers 200 with the expected message and version."""
        resp = client.get("/")

        assert resp.status_code == HTTPStatus.OK
        assert resp.request.method == "GET"

        payload = resp.json()
        # Both keys must be present before their values are compared.
        for key in ("message", "version"):
            assert key in payload
        assert payload["message"] == "Skill Classification API"
        assert payload["version"] == "1.0.0"
class TestHealthEndpoint:
    """Checks on the response served at ``/health``."""

    def test_health_check(self):
        """GET ``/health`` answers 200 and reports a healthy status."""
        resp = client.get("/health")

        assert resp.status_code == HTTPStatus.OK
        assert resp.request.method == "GET"

        payload = resp.json()
        for key in ("status", "model_loaded", "version"):
            assert key in payload

        assert payload["status"] == "healthy"
        # Only the type is checked: the test passes whether or not a model
        # is actually loaded.
        assert isinstance(payload["model_loaded"], bool)
class TestPredictionEndpoint:
    """Exercises ``POST /predict`` with valid and invalid payloads."""

    def test_predict_with_minimal_data(self):
        """A payload holding only ``issue_text`` is accepted with 201."""
        payload = {"issue_text": "Fix authentication bug in login module"}
        resp = client.post("/predict", json=payload)

        assert resp.status_code == HTTPStatus.CREATED
        assert resp.request.method == "POST"

        body = resp.json()
        expected_fields = (
            "predictions",
            "num_predictions",
            "model_version",
            "processing_time_ms",
        )
        for field in expected_fields:
            assert field in body

        # The reported count must agree with the list actually returned.
        predictions = body["predictions"]
        assert body["num_predictions"] == len(predictions)

        # Every entry names a skill and carries a confidence within [0, 1].
        for item in predictions:
            assert "skill_name" in item
            assert "confidence" in item
            assert 0.0 <= item["confidence"] <= 1.0

    def test_predict_with_full_data(self):
        """A payload with every optional field set is accepted with 201."""
        payload = {
            "issue_text": "Add support for OAuth authentication",
            "issue_description": "Implement OAuth 2.0 flow for third-party authentication providers",
            "repo_name": "myorg/myproject",
            "pr_number": 456,
            "author_name": "developer123",
            "created_at": "2024-01-15T10:30:00Z",
        }
        resp = client.post("/predict", json=payload)

        assert resp.status_code == HTTPStatus.CREATED

        body = resp.json()
        assert len(body["predictions"]) > 0
        assert body["model_version"] == "1.0.0"
        assert body["processing_time_ms"] > 0

    def test_predict_missing_required_field(self):
        """Omitting ``issue_text`` yields a 422 validation error."""
        payload = {"issue_description": "This is missing the issue_text field"}
        resp = client.post("/predict", json=payload)

        assert resp.status_code == HTTPStatus.UNPROCESSABLE_ENTITY

    def test_predict_invalid_pr_number(self):
        """A negative ``pr_number`` yields a 422 validation error."""
        payload = {"issue_text": "Fix bug", "pr_number": -5}
        resp = client.post("/predict", json=payload)

        assert resp.status_code == HTTPStatus.UNPROCESSABLE_ENTITY

    def test_predict_empty_issue_text(self):
        """An empty ``issue_text`` yields a 422 validation error."""
        payload = {"issue_text": ""}
        resp = client.post("/predict", json=payload)

        # The original note attributes this to a min_length=1 constraint on
        # the request schema -- confirm against the model definition.
        assert resp.status_code == HTTPStatus.UNPROCESSABLE_ENTITY

    def test_predict_whitespace_only_text(self):
        """A whitespace-only ``issue_text`` yields a 422 validation error."""
        payload = {"issue_text": " "}  # nothing but whitespace
        resp = client.post("/predict", json=payload)

        # Expected to be rejected, not accepted; presumably the validator
        # strips the text before checking its length -- confirm in the schema.
        assert resp.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
class TestBatchPredictionEndpoint:
    """Exercises ``POST /predict/batch`` with valid and invalid batches."""

    def test_batch_predict_multiple_issues(self):
        """Three issues in one batch produce three results, each non-empty."""
        issue_texts = (
            "Transfer learning with transformers for text classification.",
            "Generative adversarial networks in both PyTorch and TensorFlow.",
            "Fix database connection pooling issue",
        )
        batch = {"issues": [{"issue_text": text} for text in issue_texts]}
        resp = client.post("/predict/batch", json=batch)

        assert resp.status_code == HTTPStatus.OK
        assert resp.request.method == "POST"

        body = resp.json()
        for field in ("results", "total_issues", "total_processing_time_ms"):
            assert field in body

        # One result per submitted issue.
        assert len(body["results"]) == len(batch["issues"])
        assert body["total_issues"] == 3

        # Each result carries its own non-empty prediction list.
        for entry in body["results"]:
            assert "predictions" in entry
            assert "num_predictions" in entry
            assert len(entry["predictions"]) > 0

    def test_batch_predict_single_issue(self):
        """A batch of one issue is accepted and yields exactly one result."""
        batch = {"issues": [{"issue_text": "Add unit tests for authentication module"}]}
        resp = client.post("/predict/batch", json=batch)

        assert resp.status_code == HTTPStatus.OK

        body = resp.json()
        assert body["total_issues"] == 1
        assert len(body["results"]) == 1

    def test_batch_predict_empty_list(self):
        """An empty ``issues`` list yields a 422 validation error."""
        resp = client.post("/predict/batch", json={"issues": []})

        # The original note attributes this to a min_length=1 constraint on
        # the batch schema -- confirm against the model definition.
        assert resp.status_code == HTTPStatus.UNPROCESSABLE_ENTITY

    def test_batch_predict_too_many_issues(self):
        """A batch of 101 issues yields a 422 validation error."""
        # Presumably one more than the API's batch limit -- confirm the
        # maximum in the request schema.
        oversized = [{"issue_text": f"Issue {n}"} for n in range(101)]
        resp = client.post("/predict/batch", json={"issues": oversized})

        assert resp.status_code == HTTPStatus.UNPROCESSABLE_ENTITY

    def test_batch_predict_with_mixed_data(self):
        """Minimal and fully populated issues can share one batch."""
        minimal_issue = {"issue_text": "Simple issue"}
        detailed_issue = {
            "issue_text": "Detailed issue",
            "issue_description": "With description and metadata",
            "repo_name": "user/repo",
            "pr_number": 123,
        }
        resp = client.post(
            "/predict/batch",
            json={"issues": [minimal_issue, detailed_issue]},
        )

        assert resp.status_code == HTTPStatus.OK
        assert len(resp.json()["results"]) == 2
class TestErrorHandling:
    """Checks the status codes returned for bad requests and unknown routes."""

    def test_missing_required_field(self):
        """An empty JSON object posted to ``/predict`` yields 422."""
        resp = client.post("/predict", json={})

        assert resp.status_code == HTTPStatus.UNPROCESSABLE_ENTITY

    def test_endpoint_not_found(self):
        """A GET on an unregistered path yields 404."""
        resp = client.get("/nonexistent")

        assert resp.status_code == HTTPStatus.NOT_FOUND
class TestGetPredictionEndpoint:
    """Exercises ``GET /predictions/{run_id}`` for known and unknown ids."""

    def test_get_prediction_success(self):
        """A prediction just created can be fetched back by its ``run_id``."""
        # Create a prediction first so there is a real run_id to look up.
        created = client.post("/predict", json={"issue_text": "Test issue for retrieval"})
        assert created.status_code == HTTPStatus.CREATED
        run_id = created.json()["run_id"]

        fetched = client.get(f"/predictions/{run_id}")
        assert fetched.status_code == HTTPStatus.OK

        body = fetched.json()
        assert body["run_id"] == run_id
        for field in ("predictions", "timestamp"):
            assert field in body

    def test_get_prediction_not_found(self):
        """An unknown ``run_id`` yields 404."""
        missing_id = "nonexistent_run_id_12345"
        resp = client.get(f"/predictions/{missing_id}")

        assert resp.status_code == HTTPStatus.NOT_FOUND
class TestListPredictionsEndpoint:
    """Exercises ``GET /predictions``, with and without paging parameters."""

    def test_list_predictions(self):
        """The listing endpoint answers 200 with a JSON array."""
        resp = client.get("/predictions")

        assert resp.status_code == HTTPStatus.OK
        assert isinstance(resp.json(), list)

    def test_list_predictions_with_pagination(self):
        """With ``limit=5`` the returned array holds at most five entries."""
        resp = client.get("/predictions?skip=0&limit=5")

        assert resp.status_code == HTTPStatus.OK

        items = resp.json()
        assert isinstance(items, list)
        assert len(items) <= 5
class TestMLflowIntegration:
    """Tests for MLflow tracking integration."""

    def test_prediction_creates_run_id(self):
        """Test that a successful prediction response carries a non-empty run_id."""
        issue_data = {"issue_text": "MLflow tracking test"}
        response = client.post("/predict", json=issue_data)

        assert response.status_code == HTTPStatus.CREATED

        data = response.json()
        assert "run_id" in data
        # Present is not enough: the value must also be truthy (non-empty).
        assert data["run_id"]

    def test_retrieve_prediction_by_run_id(self):
        """Test retrieving a prediction using the run_id returned on creation."""
        create_response = client.post("/predict", json={"issue_text": "Test retrieval"})
        # Check the create step before reading its body, so a failed POST
        # shows up as a status mismatch instead of a KeyError on "run_id".
        assert create_response.status_code == HTTPStatus.CREATED
        run_id = create_response.json()["run_id"]

        retrieve_response = client.get(f"/predictions/{run_id}")
        assert retrieve_response.status_code == HTTPStatus.OK
        assert retrieve_response.json()["run_id"] == run_id
if __name__ == "__main__":
    # pytest.main returns the run's exit code; re-raise it as the process
    # exit status so a failing run is not reported as success when this
    # file is executed directly.
    raise SystemExit(pytest.main([__file__, "-v"]))