"""API endpoint tests.
Tests /interactions and /health endpoints directly.
/analyze requires the NER model loaded — tested via Docker or manual run.
"""
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from fastapi.testclient import TestClient
@pytest.fixture
def mock_drugbank():
    """Yield a MagicMock that replaces ``drugbank_client`` in each patched module.

    The double carries AsyncMock attributes ``get_interactions``,
    ``health_check`` (resolving to ``True``), ``connect`` and ``close``, and
    aliases ``DrugBankUnavailableError`` to the built-in ``Exception``.
    The patches stay active until the requesting test finishes.
    """
    client_double = MagicMock(
        get_interactions=AsyncMock(),
        health_check=AsyncMock(return_value=True),
        connect=AsyncMock(),
        close=AsyncMock(),
        DrugBankUnavailableError=Exception,
    )
    # Each module holds its own reference to the client, so every import
    # site is patched with the same double.
    with patch("app.services.interaction_checker.drugbank_client", client_double):
        with patch("app.api.health.drugbank_client", client_double):
            with patch("app.main.drugbank_client", client_double):
                yield client_double
@pytest.fixture
def mock_severity():
    """Yield a MagicMock that replaces ``severity_classifier`` where it is imported.

    ``classify`` returns ``("moderate", False)``, ``is_loaded`` returns
    ``True`` and ``load_model`` is a plain MagicMock. The patches stay active
    until the requesting test finishes.
    """
    classifier_double = MagicMock()
    classifier_double.load_model = MagicMock()
    classifier_double.is_loaded.return_value = True
    classifier_double.classify.return_value = ("moderate", False)
    # Patch both modules that reference the classifier with the same double.
    with patch("app.services.interaction_checker.severity_classifier", classifier_double):
        with patch("app.main.severity_classifier", classifier_double):
            yield classifier_double
@pytest.fixture
def mock_severity_parser():
    """Yield a MagicMock that replaces ``severity_parser`` in the interaction checker.

    ``parse_severity`` returns ``"moderate"`` for any arguments.
    """
    # Dotted keys configure attributes of child mocks at construction time.
    parser_double = MagicMock(**{"parse_severity.return_value": "moderate"})
    with patch("app.services.interaction_checker.severity_parser", parser_double):
        yield parser_double
@pytest.fixture
def client(mock_drugbank, mock_severity, mock_severity_parser):
    """Return a ``TestClient`` wrapping the application.

    The three mock fixtures are requested so that their patches are in place;
    their values are not used in this fixture.
    """
    # Imported here rather than at module top, so the import runs after the
    # mock fixtures above have applied their patches.
    from app.main import app as application

    return TestClient(application)
class TestAnalyzeValidation:
def test_analyze_rejects_oversized_text(self, client):
    """Text over 5000 chars must be rejected with 422."""
    # A 16-character phrase repeated 500 times gives 8000 characters.
    oversized_text = "Metformin 500mg " * 500
    response = client.post(
        "/analyze",
        headers={"X-API-Key": "test-key"},
        json={"text": oversized_text},
    )
    assert response.status_code == 422
def test_analyze_strips_html_from_raw_text(self, client):
"""HTML tags must be stripped from raw_text to prevent XSS."""
with patch("app.services.drug_analyzer.analyze", new=AsyncMock(return_value=[])):
resp = client.post(
"/analyze",
json={"text": 'Metformin 500mg'},
headers={"X-API-Key": "test-key"},
)
assert resp.status_code == 200
data = resp.json()
assert "