| | |
| | """Comprehensive tests for newly created/modified modules in Prefero. |
| | |
| | Tested modules: |
    1. app/auth.py — Authentication (open access, username gate)
    2. app/session_queue.py — Concurrent user queue
    3. app/waiting_facts.py — Cultural facts list
    4. app/pages/1_Data.py — _generate_template_excel()
    5. app/utils.py — Integration (imports, session defaults, functions)
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import os |
| | import sys |
| | import time |
| | import threading |
| | import unittest |
| | from io import BytesIO |
| | from pathlib import Path |
| | from unittest.mock import MagicMock, patch |
| |
|
| | |
# This file sits one directory below the project root; the code under test
# lives in <ROOT>/app and <ROOT>/src.
ROOT = Path(__file__).resolve().parents[1]
APP_DIR, SRC_DIR = ROOT / "app", ROOT / "src"

# Prepend both directories to the import path, skipping any already present.
# Each insert goes to index 0, so when both are new src/ ends up ahead of app/.
for p in (str(APP_DIR), str(SRC_DIR)):
    if p in sys.path:
        continue
    sys.path.insert(0, p)
| |
|
| | |
| | |
| | |
| |
|
# ---------------------------------------------------------------------------
# Fake ``streamlit`` module.
#
# Installed in ``sys.modules`` at import time so that every project module
# imported later by these tests (auth, session_queue, utils, ...) picks up the
# mock instead of the real package.
# ---------------------------------------------------------------------------


class _SessionState(dict):
    """Dict-backed stand-in for ``st.session_state``.

    Streamlit's session state supports both item access
    (``st.session_state["key"]``) and attribute access
    (``st.session_state.key``). A plain ``dict`` only supports the former, so
    code under test that used attribute access would fail under the mock.
    Missing keys raise ``AttributeError`` on attribute access.
    """

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None


def _passthrough_cache(func=None, **_options):
    """No-op replacement for ``st.cache_data`` / ``st.cache_resource``.

    Handles both the bare form (``@st.cache_data``) and the parameterised
    form (``@st.cache_data(ttl=600)``). The decorated function is returned
    unchanged, so tests always run the real, uncached implementation.
    """
    if func is None:
        return lambda f: f
    return func


# ``st.cache_data.clear()`` is part of Streamlit's API; make it a no-op.
_passthrough_cache.clear = lambda: None

mock_st = MagicMock()
# Shared fake session state; tests read and write it directly via this name.
_mock_session_state: dict = _SessionState()
mock_st.session_state = _mock_session_state

mock_st.set_page_config = MagicMock()
mock_st.rerun = MagicMock()
mock_st.stop = MagicMock()
# A bare MagicMock attribute would replace decorated functions with mock
# return values, so both cache decorators are explicit pass-throughs.
mock_st.cache_data = _passthrough_cache
mock_st.cache_resource = _passthrough_cache

sys.modules["streamlit"] = mock_st
sys.modules["streamlit.errors"] = MagicMock()
| |
|
| |
|
def _reset_session_state():
    """Empty the shared fake session state in place.

    Keys are deleted rather than the name being rebound, so the object that
    ``mock_st.session_state`` refers to is the one that gets emptied.
    """
    for key in list(_mock_session_state):
        del _mock_session_state[key]
| |
|
| |
|
| | |
| | |
| | |
| |
|
class TestAuth(unittest.TestCase):
    """Open-access authentication: ``auth_gate`` and ``username_gate``."""

    def setUp(self):
        # Each test starts from an empty fake session.
        _reset_session_state()

    # -- auth_gate ----------------------------------------------------------

    def test_auth_gate_always_returns_true(self):
        import auth

        self.assertTrue(auth.auth_gate())

    def test_auth_gate_sets_authenticated_flag(self):
        import auth

        auth.auth_gate()
        flag = _mock_session_state.get("authenticated")
        self.assertTrue(flag)

    # -- username_gate ------------------------------------------------------

    def test_username_gate_returns_true_when_username_set(self):
        import auth

        _mock_session_state["username"] = "TestUser"
        self.assertTrue(auth.username_gate())

    def test_username_gate_returns_false_when_no_username(self):
        import auth

        self.assertFalse(auth.username_gate())
| |
|
| |
|
| | |
| | |
| | |
| |
|
class TestSessionQueue(unittest.TestCase):
    """Tests for the concurrent user queue module (``session_queue``)."""

    def setUp(self):
        _reset_session_state()
        # Start every test with no admitted sessions and no queue env config.
        import session_queue

        with session_queue._lock:
            session_queue._active_sessions.clear()
        os.environ.pop("PREFERO_QUEUE_ENABLED", None)
        os.environ.pop("PREFERO_MAX_CONCURRENT", None)

    def tearDown(self):
        os.environ.pop("PREFERO_QUEUE_ENABLED", None)
        os.environ.pop("PREFERO_MAX_CONCURRENT", None)

    # -- _queue_enabled: only a case-insensitive "true" enables the queue ----

    def test_queue_disabled_by_default(self):
        from session_queue import _queue_enabled

        self.assertFalse(_queue_enabled())

    def test_queue_disabled_when_empty(self):
        from session_queue import _queue_enabled

        os.environ["PREFERO_QUEUE_ENABLED"] = ""
        self.assertFalse(_queue_enabled())

    def test_queue_disabled_when_random_string(self):
        from session_queue import _queue_enabled

        os.environ["PREFERO_QUEUE_ENABLED"] = "yes"
        self.assertFalse(_queue_enabled())

    def test_queue_enabled_true(self):
        from session_queue import _queue_enabled

        os.environ["PREFERO_QUEUE_ENABLED"] = "true"
        self.assertTrue(_queue_enabled())

    def test_queue_enabled_case_insensitive(self):
        from session_queue import _queue_enabled

        os.environ["PREFERO_QUEUE_ENABLED"] = "TRUE"
        self.assertTrue(_queue_enabled())

    # -- try_enter ------------------------------------------------------------

    def test_try_enter_succeeds_when_empty(self):
        import session_queue

        _mock_session_state["_queue_session_id"] = "session-1"
        self.assertTrue(session_queue.try_enter())

    def test_try_enter_same_session_always_succeeds(self):
        import session_queue

        _mock_session_state["_queue_session_id"] = "session-1"
        # Re-entering with an already-admitted session id must not be rejected.
        self.assertTrue(session_queue.try_enter())
        self.assertTrue(session_queue.try_enter())

    def test_try_enter_fails_when_full(self):
        import session_queue

        # Fill every slot with other, fresh sessions.
        with session_queue._lock:
            for i in range(session_queue._MAX_CONCURRENT):
                session_queue._active_sessions[f"other-{i}"] = time.time()

        _mock_session_state["_queue_session_id"] = "new-session"
        self.assertFalse(session_queue.try_enter())

    def test_try_enter_succeeds_when_one_spot_left(self):
        import session_queue

        with session_queue._lock:
            for i in range(session_queue._MAX_CONCURRENT - 1):
                session_queue._active_sessions[f"other-{i}"] = time.time()

        _mock_session_state["_queue_session_id"] = "new-session"
        self.assertTrue(session_queue.try_enter())

    # -- heartbeat ------------------------------------------------------------

    def test_heartbeat_updates_timestamp(self):
        import session_queue

        sid = "hb-session"
        _mock_session_state["_queue_session_id"] = sid
        session_queue.try_enter()
        old_ts = session_queue._active_sessions[sid]

        # Sleep long enough that time.time() advances even on coarse clocks.
        time.sleep(0.05)
        session_queue.heartbeat()
        new_ts = session_queue._active_sessions[sid]

        self.assertGreater(new_ts, old_ts)

    def test_heartbeat_no_effect_if_not_active(self):
        import session_queue

        _mock_session_state["_queue_session_id"] = "not-active"
        session_queue.heartbeat()
        self.assertNotIn("not-active", session_queue._active_sessions)

    # -- leave ----------------------------------------------------------------

    def test_leave_removes_session(self):
        import session_queue

        sid = "leave-session"
        _mock_session_state["_queue_session_id"] = sid
        session_queue.try_enter()
        self.assertIn(sid, session_queue._active_sessions)

        session_queue.leave()
        self.assertNotIn(sid, session_queue._active_sessions)

    def test_leave_noop_if_not_active(self):
        import session_queue

        _mock_session_state["_queue_session_id"] = "ghost"
        session_queue.leave()  # must not raise for an unknown session
        self.assertNotIn("ghost", session_queue._active_sessions)
        self.assertEqual(session_queue.active_count(), 0)

    # -- active_count / spots_available ---------------------------------------

    def test_active_count_zero_initially(self):
        import session_queue

        self.assertEqual(session_queue.active_count(), 0)

    def test_active_count_reflects_entries(self):
        import session_queue

        with session_queue._lock:
            session_queue._active_sessions["a"] = time.time()
            session_queue._active_sessions["b"] = time.time()
        self.assertEqual(session_queue.active_count(), 2)

    def test_spots_available_full(self):
        import session_queue

        with session_queue._lock:
            for i in range(session_queue._MAX_CONCURRENT):
                session_queue._active_sessions[f"s-{i}"] = time.time()
        self.assertEqual(session_queue.spots_available(), 0)

    def test_spots_available_empty(self):
        import session_queue

        self.assertEqual(session_queue.spots_available(), session_queue._MAX_CONCURRENT)

    def test_spots_available_partial(self):
        import session_queue

        with session_queue._lock:
            session_queue._active_sessions["x"] = time.time()
        self.assertEqual(session_queue.spots_available(), session_queue._MAX_CONCURRENT - 1)

    # -- stale-session cleanup ------------------------------------------------

    def test_stale_sessions_cleaned_up(self):
        import session_queue

        stale_time = time.time() - session_queue._SESSION_TIMEOUT - 10
        with session_queue._lock:
            session_queue._active_sessions["stale-1"] = stale_time
            session_queue._active_sessions["stale-2"] = stale_time
            session_queue._active_sessions["fresh"] = time.time()

        # active_count() is expected to purge timed-out entries first.
        count = session_queue.active_count()
        self.assertEqual(count, 1)
        self.assertNotIn("stale-1", session_queue._active_sessions)
        self.assertNotIn("stale-2", session_queue._active_sessions)
        self.assertIn("fresh", session_queue._active_sessions)

    def test_stale_cleanup_via_try_enter(self):
        """try_enter should clean stale sessions and potentially free a slot."""
        import session_queue

        stale_time = time.time() - session_queue._SESSION_TIMEOUT - 10
        # Every slot is occupied, but only by timed-out sessions.
        with session_queue._lock:
            for i in range(session_queue._MAX_CONCURRENT):
                session_queue._active_sessions[f"stale-{i}"] = stale_time

        _mock_session_state["_queue_session_id"] = "new-after-stale"
        self.assertTrue(session_queue.try_enter())

    # -- concurrency ----------------------------------------------------------

    def test_concurrent_access_no_corruption(self):
        """Multiple threads entering/leaving should not corrupt the dict."""
        import session_queue

        errors = []
        n_threads = 20
        barrier = threading.Barrier(n_threads)

        def worker(idx):
            try:
                # Line all threads up so the lock is actually contended.
                barrier.wait(timeout=5)
                sid = f"thread-{idx}"
                with session_queue._lock:
                    session_queue._active_sessions[sid] = time.time()
                time.sleep(0.01)
                with session_queue._lock:
                    session_queue._active_sessions.pop(sid, None)
            except Exception as e:  # surfaced to the main thread below
                errors.append(e)

        threads = [threading.Thread(target=worker, args=(i,)) for i in range(n_threads)]
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=10)

        self.assertFalse(any(t.is_alive() for t in threads), "worker threads did not finish")
        self.assertEqual(errors, [], f"Thread errors: {errors}")
        with session_queue._lock:
            remaining = dict(session_queue._active_sessions)
        self.assertEqual(remaining, {})

    def test_concurrent_try_enter_respects_limit(self):
        """Many threads calling try_enter simultaneously should not exceed the limit.

        Each worker thread gets its own private session state (hence its own
        ``_queue_session_id``) and calls the real ``session_queue.try_enter``
        concurrently; exactly ``_MAX_CONCURRENT`` of them may be admitted.
        """
        from collections.abc import MutableMapping

        import session_queue

        class _PerThreadSessionState(MutableMapping):
            """Session-state mapping whose contents are private to each thread."""

            def __init__(self):
                self._local = threading.local()

            def _data(self):
                if not hasattr(self._local, "data"):
                    self._local.data = {}
                return self._local.data

            def __getitem__(self, key):
                return self._data()[key]

            def __setitem__(self, key, value):
                self._data()[key] = value

            def __delitem__(self, key):
                del self._data()[key]

            def __iter__(self):
                return iter(self._data())

            def __len__(self):
                return len(self._data())

        limit = session_queue._MAX_CONCURRENT
        # Always oversubscribe so that some threads must be rejected.
        n_threads = max(20, limit + 5)
        barrier = threading.Barrier(n_threads)
        per_thread_state = _PerThreadSessionState()
        admitted, rejected, errors = [], [], []

        def worker(idx):
            try:
                per_thread_state["_queue_session_id"] = f"concurrent-{idx}"
                barrier.wait(timeout=5)
                entered = session_queue.try_enter()
            except Exception as exc:  # surfaced to the main thread below
                errors.append(exc)
                return
            (admitted if entered else rejected).append(idx)

        threads = [threading.Thread(target=worker, args=(i,)) for i in range(n_threads)]
        # NOTE: assumes session_queue looks up ``st.session_state`` at call
        # time (``import streamlit as st``), so patching the mock attribute
        # routes each thread to its own state.
        with patch.object(mock_st, "session_state", per_thread_state):
            for t in threads:
                t.start()
            for t in threads:
                t.join(timeout=10)

        self.assertFalse(any(t.is_alive() for t in threads), "worker threads did not finish")
        self.assertEqual(errors, [], f"Thread errors: {errors}")
        self.assertEqual(len(admitted), limit)
        self.assertEqual(len(rejected), n_threads - limit)

    # -- queue_gate -----------------------------------------------------------

    def test_queue_gate_returns_true_when_disabled(self):
        from session_queue import queue_gate

        # PREFERO_QUEUE_ENABLED was removed in setUp, so the queue is off.
        self.assertTrue(queue_gate())
| |
|
| |
|
| | |
| | |
| | |
| |
|
class TestWaitingFacts(unittest.TestCase):
    """Sanity checks on ``waiting_facts.WAITING_FACTS`` (the cultural facts list)."""

    def test_waiting_facts_is_list(self):
        import waiting_facts

        self.assertIsInstance(waiting_facts.WAITING_FACTS, list)

    def test_waiting_facts_non_empty(self):
        import waiting_facts

        facts = waiting_facts.WAITING_FACTS
        self.assertGreater(len(facts), 0)

    def test_all_facts_are_strings(self):
        import waiting_facts

        for position, entry in enumerate(waiting_facts.WAITING_FACTS):
            self.assertIsInstance(entry, str, f"Fact at index {position} is not a string")

    def test_all_facts_non_empty(self):
        import waiting_facts

        for position, entry in enumerate(waiting_facts.WAITING_FACTS):
            stripped_length = len(entry.strip())
            self.assertTrue(stripped_length > 0, f"Fact at index {position} is empty or whitespace")

    def test_no_duplicate_facts(self):
        import waiting_facts

        already_seen = set()
        repeats = []
        for position, entry in enumerate(waiting_facts.WAITING_FACTS):
            if entry in already_seen:
                # Report the index of the repeat plus a short preview of its text.
                repeats.append((position, entry[:60] + "..."))
            else:
                already_seen.add(entry)
        self.assertEqual(repeats, [], f"Duplicate facts found: {repeats}")

    def test_reasonable_number_of_facts(self):
        """We expect a decent collection, not just 1-2 items."""
        import waiting_facts

        self.assertGreaterEqual(
            len(waiting_facts.WAITING_FACTS),
            10,
            "Expected at least 10 cultural facts",
        )
| |
|
| |
|
| | |
| | |
| | |
| |
|
class TestGenerateTemplateExcel(unittest.TestCase):
    """Tests for ``_generate_template_excel()`` in ``app/pages/1_Data.py``.

    The page module is not imported, because a Streamlit page runs UI code at
    import time. Instead, the function definition (plus the page's top-level
    import statements) is extracted with :mod:`ast` and compiled on its own.
    The tests therefore exercise the real implementation rather than a
    hand-maintained copy that could silently drift from it.

    Limitation: module-level constants of the page are not replayed; if the
    function starts depending on one, calling it raises ``NameError``.
    """

    # Expected layout of the generated template.
    DATA_HEADERS = [
        "respondent_id", "task_id", "alternative", "choice",
        "price", "time", "comfort", "income",
    ]
    DICT_HEADERS = ["Column", "Description", "Type", "Required"]
    FIRST_DATA_ROW = [1, 1, 1, 1, 10, 30, 3, 50000]
    N_DATA_ROWS = 8
    N_DICT_ROWS = 9

    @staticmethod
    def _load_generator():
        """Compile and return the real ``_generate_template_excel`` function.

        Raises:
            AssertionError: if the page has no top-level function of that name.
            FileNotFoundError: if the page file itself is missing.
        """
        import ast

        page_path = APP_DIR / "pages" / "1_Data.py"
        func_name = "_generate_template_excel"
        filename = str(page_path)
        tree = ast.parse(page_path.read_text(encoding="utf-8"), filename=filename)

        func_node = None
        import_nodes = []
        for node in tree.body:
            if isinstance(node, ast.Import):
                import_nodes.append(node)
            elif isinstance(node, ast.ImportFrom) and node.module != "__future__":
                # ``from __future__`` is only legal at the very top of a module.
                import_nodes.append(node)
            elif isinstance(node, ast.FunctionDef) and node.name == func_name:
                func_node = node
        if func_node is None:
            raise AssertionError(f"{func_name}() not found in {page_path}")

        namespace = {"__name__": "_extracted_data_page", "BytesIO": BytesIO}
        for node in import_nodes:
            try:
                exec(compile(ast.Module(body=[node], type_ignores=[]), filename, "exec"), namespace)
            except ImportError:
                # Best effort: an unavailable import only matters if the
                # function uses that name, in which case calling it raises
                # NameError and the tests fail visibly.
                pass
        exec(compile(ast.Module(body=[func_node], type_ignores=[]), filename, "exec"), namespace)
        return namespace[func_name]

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Generate the template once; every test inspects the same bytes.
        cls.xlsx_bytes = cls._load_generator()()

    def _open_workbook(self):
        """Load the generated bytes back into an openpyxl workbook."""
        import openpyxl

        return openpyxl.load_workbook(BytesIO(self.xlsx_bytes))

    def test_generate_template_returns_bytes(self):
        self.assertIsInstance(self.xlsx_bytes, bytes)
        self.assertGreater(len(self.xlsx_bytes), 0)

    def test_template_has_data_and_dictionary_sheets(self):
        sheet_names = self._open_workbook().sheetnames
        self.assertIn("data", sheet_names)
        self.assertIn("dictionary", sheet_names)

    def test_template_data_sheet_headers(self):
        """The data sheet should contain the expected header row."""
        ws = self._open_workbook()["data"]
        actual = [ws.cell(row=1, column=i).value for i in range(1, len(self.DATA_HEADERS) + 1)]
        self.assertEqual(actual, self.DATA_HEADERS)

    def test_template_dictionary_sheet_headers(self):
        """The dictionary sheet should have Column/Description/Type/Required headers."""
        ws = self._open_workbook()["dictionary"]
        actual = [ws.cell(row=1, column=i).value for i in range(1, len(self.DICT_HEADERS) + 1)]
        self.assertEqual(actual, self.DICT_HEADERS)

    def test_template_data_sheet_has_example_rows(self):
        """Data sheet should contain exactly 8 example data rows below the header."""
        ws = self._open_workbook()["data"]
        row2_vals = [ws.cell(row=2, column=i).value for i in range(1, len(self.FIRST_DATA_ROW) + 1)]
        self.assertEqual(row2_vals, self.FIRST_DATA_ROW)

        # Count populated cells in column A: header row plus example rows.
        non_empty_rows = sum(
            1
            for (cell,) in ws.iter_rows(min_row=1, max_row=20, max_col=1)
            if cell.value is not None
        )
        self.assertEqual(non_empty_rows, self.N_DATA_ROWS + 1)

    def test_template_full_roundtrip_with_pandas(self):
        """Read the generated template back with pandas to verify structure."""
        import pandas as pd

        data_df = pd.read_excel(BytesIO(self.xlsx_bytes), sheet_name="data")
        dict_df = pd.read_excel(BytesIO(self.xlsx_bytes), sheet_name="dictionary")

        self.assertEqual(list(data_df.columns), self.DATA_HEADERS)
        self.assertEqual(len(data_df), self.N_DATA_ROWS)
        self.assertEqual(list(dict_df.columns), self.DICT_HEADERS)
        self.assertEqual(len(dict_df), self.N_DICT_ROWS)

        first = data_df.iloc[0]
        self.assertEqual(first["respondent_id"], 1)
        self.assertEqual(first["choice"], 1)
        self.assertEqual(first["price"], 10)

        self.assertEqual(dict_df.iloc[0]["Column"], "respondent_id")
        self.assertEqual(dict_df.iloc[0]["Required"], "Yes")
| |
|
| |
|
| | |
| | |
| | |
| |
|
class TestUtils(unittest.TestCase):
    """Tests for the utils module — imports and structure."""

    def setUp(self):
        _reset_session_state()
        os.environ.pop("PREFERO_QUEUE_ENABLED", None)

    @staticmethod
    def _defaults():
        """Return ``utils._SESSION_DEFAULTS`` (imported lazily, after the streamlit mock)."""
        import utils

        return utils._SESSION_DEFAULTS

    # -- gate functions are importable and callable --------------------------

    def test_import_auth_gate(self):
        import auth

        self.assertTrue(callable(auth.auth_gate))

    def test_import_queue_gate(self):
        import session_queue

        self.assertTrue(callable(session_queue.queue_gate))

    def test_import_queue_heartbeat(self):
        import session_queue

        self.assertTrue(callable(session_queue.heartbeat))

    # -- _SESSION_DEFAULTS ----------------------------------------------------

    def test_session_defaults_has_authenticated_key(self):
        defaults = self._defaults()
        self.assertIn("authenticated", defaults)
        self.assertFalse(defaults["authenticated"])

    def test_session_defaults_has_auth_email_key(self):
        defaults = self._defaults()
        self.assertIn("auth_email", defaults)
        self.assertEqual(defaults["auth_email"], "")

    def test_session_defaults_has_df_key(self):
        defaults = self._defaults()
        self.assertIn("df", defaults)
        self.assertIsNone(defaults["df"])

    def test_session_defaults_has_model_results_key(self):
        self.assertIn("model_results", self._defaults())

    def test_session_defaults_has_bootstrap_results_key(self):
        self.assertIn("bootstrap_results", self._defaults())

    def test_session_defaults_has_model_history_key(self):
        defaults = self._defaults()
        self.assertIn("model_history", defaults)
        self.assertEqual(defaults["model_history"], [])

    # -- require_auth ---------------------------------------------------------

    def test_require_auth_function_exists(self):
        import utils

        self.assertTrue(callable(utils.require_auth))

    def test_require_auth_noop_when_disabled(self):
        """When auth is disabled, require_auth should not call st.stop()."""
        import utils

        mock_st.stop.reset_mock()
        utils.require_auth()
        mock_st.stop.assert_not_called()

    # -- require_queue_slot ---------------------------------------------------

    def test_require_queue_slot_function_exists(self):
        import utils

        self.assertTrue(callable(utils.require_queue_slot))

    def test_require_queue_slot_noop_when_disabled(self):
        """When queue is disabled, require_queue_slot should not call st.stop()."""
        import utils

        mock_st.stop.reset_mock()
        utils.require_queue_slot()
        mock_st.stop.assert_not_called()

    # -- init_session_state ---------------------------------------------------

    def test_init_session_state_populates_defaults(self):
        import utils

        defaults = utils._SESSION_DEFAULTS
        # Reset again after the import, in case importing utils touched state.
        _reset_session_state()
        mock_st.stop.reset_mock()
        utils.init_session_state()
        for key, expected in defaults.items():
            self.assertIn(key, _mock_session_state,
                          f"Key '{key}' not found in session state after init")
            # "authenticated" is only checked for presence; presumably the
            # open-access auth path may change it from its default.
            if key == "authenticated":
                continue
            self.assertEqual(_mock_session_state[key], expected)

    def test_init_session_state_does_not_overwrite_existing(self):
        import utils

        _mock_session_state["df"] = "EXISTING_VALUE"
        mock_st.stop.reset_mock()
        utils.init_session_state()
        self.assertEqual(_mock_session_state["df"], "EXISTING_VALUE")

    # -- data_is_loaded -------------------------------------------------------

    def test_data_is_loaded_false_when_none(self):
        import utils

        _mock_session_state["df"] = None
        self.assertFalse(utils.data_is_loaded())

    def test_data_is_loaded_true_when_set(self):
        import utils
        import pandas as pd

        _mock_session_state["df"] = pd.DataFrame({"a": [1]})
        self.assertTrue(utils.data_is_loaded())
| |
|
| |
|
| | |
| | |
| | |
| |
|
if __name__ == "__main__":
    banner = "=" * 70
    print(banner)
    print(" Comprehensive tests for new/modified modules in Prefero")
    print(banner)
    print()

    # Collect every test case class defined above into one flat suite.
    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    case_classes = (
        TestAuth,
        TestSessionQueue,
        TestWaitingFacts,
        TestGenerateTemplateExcel,
        TestUtils,
    )
    for case_cls in case_classes:
        suite.addTests(loader.loadTestsFromTestCase(case_cls))

    outcome = unittest.TextTestRunner(verbosity=2).run(suite)
    passed = outcome.wasSuccessful()

    print()
    print(banner)
    if passed:
        print(" ALL TESTS PASSED")
    else:
        print(f" FAILURES: {len(outcome.failures)} ERRORS: {len(outcome.errors)}")
    print(banner)

    # Non-zero exit status signals failure to CI / shell callers.
    sys.exit(0 if passed else 1)
| |
|