| import pytest
|
| from fastapi.testclient import TestClient
|
| from unittest.mock import patch
|
| from constants import MAX_COMMENT_LENGTH, MAX_RESPONSE_LENGTH
|
| from main import app
|
|
|
# Shared module-level client used by the endpoint tests; the rate-limit test
# opens its own client instead.
client: TestClient = TestClient(app)
|
|
|
class TestFeedbackEndpoint:
    """Consolidated tests for POST /feedback."""

    @pytest.fixture
    def base_payload(self):
        """Return a fresh, valid /feedback payload.

        The fixture is function-scoped (pytest's default), so each test may
        mutate its own copy without affecting other tests.
        """
        return {
            "user_id": "test-user-123",
            "participant_id": "participant-456",
            "session_id": "test-session-123",
            "consent": True,
            "age_group": "25-34",
            "gender": "M",
            "roles": ["patient"],
            "message_index": 5,
            "rating": "like",
            "reply_content": "Helpful response",
            "comment": "Clear advice",
        }

    def test_feedback_success_and_logging(self, base_payload):
        """Happy path: a valid payload returns 200 and schedules background work.

        ``main.log_event`` is patched out so the test does not depend on the
        real logging side effect. ``BackgroundTasks.add_task`` is patched so
        the test can observe that the endpoint scheduled at least one task.
        """
        with patch("main.log_event"):
            with patch("main.BackgroundTasks.add_task") as mock_task:
                response = client.post("/feedback", json=base_payload)

        assert response.status_code == 200
        mock_task.assert_called()

    @pytest.mark.parametrize("rating", ["like", "dislike", "mixed"])
    def test_valid_ratings(self, base_payload, rating):
        """Every accepted rating value yields a 200."""
        base_payload["rating"] = rating
        response = client.post("/feedback", json=base_payload)
        assert response.status_code == 200

    def test_comment_optionality(self, base_payload):
        """An empty-string comment is accepted.

        TODO: whether the ``comment`` key may be omitted entirely is not
        covered here; confirm the model's contract and add that case.
        """
        base_payload["comment"] = ""
        response = client.post("/feedback", json=base_payload)
        assert response.status_code == 200

    @pytest.mark.parametrize("index, expected_status", [
        (0, 200),
        (10000, 200),
        (-1, 422),
        (10001, 422),
    ])
    def test_message_index_constraints(self, base_payload, index, expected_status):
        """Verify the inclusive bounds 0 <= message_index <= 10000."""
        base_payload["message_index"] = index
        response = client.post("/feedback", json=base_payload)
        assert response.status_code == expected_status

    def test_html_sanitization(self, base_payload):
        """A comment containing a <script> tag is accepted and not reflected raw.

        Stripping is expected to happen in the request model (presumably via
        nh3; confirm). This test can only observe the HTTP response, so beyond
        the 200 it checks that the raw tag is not echoed back to the client.
        """
        base_payload["comment"] = "<script>alert('xss')</script>Safe Text"
        response = client.post("/feedback", json=base_payload)
        assert response.status_code == 200
        assert "<script>" not in response.text

    @pytest.mark.parametrize("field, length, expected_status", [
        # Exactly at the limit must be accepted...
        ("comment", MAX_COMMENT_LENGTH, 200),
        ("reply_content", MAX_RESPONSE_LENGTH, 200),
        # ...and one character over must be rejected.
        ("comment", MAX_COMMENT_LENGTH + 1, 422),
        ("reply_content", MAX_RESPONSE_LENGTH + 1, 422),
    ])
    def test_string_max_lengths(self, base_payload, field, length, expected_status):
        """Verify both sides of the max-length boundary for string fields."""
        base_payload[field] = "x" * length
        response = client.post("/feedback", json=base_payload)
        assert response.status_code == expected_status

    @pytest.mark.enable_rate_limit
    def test_feedback_rate_limiting(self, base_payload):
        """The first 20 requests succeed and the 21st is rejected with 429.

        Asserting that each in-budget request succeeds is what pins the limit
        at 20. Checking only the last response would also pass with a
        stricter limit, or with a bucket already exhausted by earlier requests.
        """
        # NOTE(review): assumes the enable_rate_limit marker (handled outside
        # this file, e.g. in conftest) switches the limiter on and starts from
        # an empty bucket; confirm.
        with TestClient(app) as limit_client:
            for attempt in range(1, 21):
                response = limit_client.post("/feedback", json=base_payload)
                assert response.status_code == 200, f"request {attempt} was rejected"

            over_limit_response = limit_client.post("/feedback", json=base_payload)
            assert over_limit_response.status_code == 429