# reachy_mini_minder/tests/test_output_escaping.py
# Commit af9cde9 ("initial commit") by Boopster.
"""Tests for output escaping — verifies XSS and SQL-injection mitigations.
Covers:
- Finding #1: appointment_export.format_as_html() escapes LLM summary + DB values
- Finding #2: email_service.send_missed_dose_alert() escapes patient/med names
- Finding #4: database.update_profile() rejects unrecognized column names
"""
import html as _html
import sqlite3
import tempfile
from pathlib import Path
from unittest.mock import patch, MagicMock
import pytest
# ---- XSS payloads shared by the escaping tests below ----
# Three HTML-injection vectors: an inline <script> element, an <img> with an
# onerror handler, and an <svg> with an onload handler.
XSS_SCRIPT = "<script>alert(\"xss\")</script>"
XSS_IMG = '<img src=x onerror="fetch(\'https://evil.com\')">'
XSS_SVG = "<svg onload=\"alert(1)\">"
PAYLOADS = list((XSS_SCRIPT, XSS_IMG, XSS_SVG))
# ===========================================================================
# Finding #1: appointment_export.format_as_html
# ===========================================================================
class TestAppointmentExportEscaping:
    """format_as_html must escape all interpolated values.

    Each test checks two things:

    * the raw payload never appears in the output, so it cannot execute;
    * the payload is recoverable by HTML-unescaping the output, so the value
      was rendered in escaped form rather than silently dropped. Without this
      second check, an implementation that omitted the field would pass.
    """

    def _call(self, **kwargs):
        """Render the export with safe defaults, overridden by ``kwargs``."""
        from reachy_mini_conversation_app.appointment_export import format_as_html

        defaults = dict(
            summary="Safe summary",
            headaches=[],
            medications=[],
            patient_name="Alice",
            days=30,
        )
        defaults.update(kwargs)
        return format_as_html(**defaults)

    @staticmethod
    def _assert_escaped(html, payload):
        """Assert ``payload`` appears in ``html`` only in escaped form."""
        assert payload not in html, f"Raw payload found in HTML: {payload}"
        assert payload in _html.unescape(html), (
            f"Payload missing from HTML (dropped rather than escaped?): {payload}"
        )

    @pytest.mark.parametrize("payload", PAYLOADS)
    def test_summary_is_escaped(self, payload):
        html = self._call(summary=payload)
        self._assert_escaped(html, payload)
        # The summary is expected to use html.escape's default (quote=True) form.
        assert _html.escape(payload) in html

    @pytest.mark.parametrize("payload", PAYLOADS)
    def test_patient_name_is_escaped(self, payload):
        html = self._call(patient_name=payload)
        self._assert_escaped(html, payload)

    @pytest.mark.parametrize("payload", PAYLOADS)
    def test_headache_notes_escaped(self, payload):
        headaches = [{"date": "2025-01-01", "severity": 5, "notes": payload}]
        html = self._call(headaches=headaches)
        self._assert_escaped(html, payload)

    @pytest.mark.parametrize("payload", PAYLOADS)
    def test_medication_name_escaped(self, payload):
        medications = [{"date": "2025-01-01", "medication_name": payload}]
        html = self._call(medications=medications)
        self._assert_escaped(html, payload)
# ===========================================================================
# Finding #2: email_service.send_missed_dose_alert
# ===========================================================================
class TestMissedDoseEmailEscaping:
    """HTML email body must escape patient and medication names."""

    def _get_html_body(self, patient_name="Alice", missed_meds=None):
        """Call send_missed_dose_alert and return the HTML body it passes to send_email.

        ``send_email`` is patched with a fake that records ``body_html``, so
        no mail is sent. The helper fails the test if ``send_email`` is never
        called, or is called without a string HTML body. Otherwise the
        "payload not in html" assertions would pass vacuously against an
        empty body.
        """
        if missed_meds is None:
            missed_meds = [{"medication_name": "Topiramate", "scheduled_time": "08:00"}]
        captured_html = {}

        def fake_send(to, subject, body_text, body_html=None, from_name=""):
            captured_html["html"] = body_html
            return True

        with patch(
            "reachy_mini_conversation_app.email_service.send_email",
            side_effect=fake_send,
        ):
            from reachy_mini_conversation_app.email_service import (
                send_missed_dose_alert,
            )

            send_missed_dose_alert("test@example.com", patient_name, missed_meds)

        assert "html" in captured_html, "send_email was never called"
        body = captured_html["html"]
        assert isinstance(body, str), f"Expected an HTML body string, got {body!r}"
        return body

    @pytest.mark.parametrize("payload", PAYLOADS)
    def test_patient_name_escaped_in_html(self, payload):
        html = self._get_html_body(patient_name=payload)
        assert payload not in html, f"Raw XSS payload in email HTML: {payload}"
        # The value must still be rendered (escaped), not dropped from the body.
        assert payload in _html.unescape(html), f"Payload missing from email HTML: {payload}"

    @pytest.mark.parametrize("payload", PAYLOADS)
    def test_medication_name_escaped_in_html(self, payload):
        meds = [{"medication_name": payload, "scheduled_time": "08:00"}]
        html = self._get_html_body(missed_meds=meds)
        assert payload not in html, f"Raw XSS payload in email HTML: {payload}"
        # The value must still be rendered (escaped), not dropped from the body.
        assert payload in _html.unescape(html), f"Payload missing from email HTML: {payload}"
# ===========================================================================
# Finding #4: database.update_profile column-name allowlist
# ===========================================================================
class TestUpdateProfileAllowlist:
    """update_profile must reject keys not in the allowlist."""

    @pytest.fixture()
    def db(self, tmp_path):
        """Fresh MiniMinderDB in a per-test temp directory, with a profile row created."""
        from reachy_mini_conversation_app.database import MiniMinderDB

        db_path = str(tmp_path / "test.db")
        database = MiniMinderDB(db_path)
        database.get_or_create_profile()
        return database

    @staticmethod
    def _profile(db):
        """Return the current profile as a plain dict.

        Assumes get_or_create_profile returns a mapping (e.g. a dict or an
        sqlite3.Row, both of which ``dict()`` accepts). TODO confirm.
        """
        return dict(db.get_or_create_profile())

    def test_valid_key_accepted(self, db):
        """Known column should be written."""
        db.update_profile({"display_name": "Bob"})
        profile = db.get_or_create_profile()
        assert profile["display_name"] == "Bob"

    def test_invalid_key_rejected(self, db):
        """Unknown key like SQL injection payload should be silently dropped."""
        bad_key = "display_name; DROP TABLE user_profile --"
        before = self._profile(db)
        db.update_profile({bad_key: "evil"})
        # No exception so far, and the profile is still readable.
        after = self._profile(db)
        assert bad_key not in after
        # A loose (e.g. prefix-based) allowlist would have written "evil" here.
        assert after["display_name"] == before["display_name"]

    def test_mixed_keys_only_valid_written(self, db):
        """Only allowlisted keys should be persisted."""
        db.update_profile(
            {
                "display_name": "Charlie",
                "evil_column": "should_be_dropped",
            }
        )
        profile = self._profile(db)
        assert profile["display_name"] == "Charlie"
        assert "evil_column" not in profile

    def test_all_invalid_keys_is_noop(self, db):
        """If every key is rejected, nothing should be written (no SQL error)."""
        before = self._profile(db)
        db.update_profile({"bad_col_1": "a", "bad_col_2": "b"})
        after = self._profile(db)
        assert "bad_col_1" not in after
        assert "bad_col_2" not in after
        assert after["display_name"] == before["display_name"]