"""Tests for ticker DB refresh, startup checks, and health endpoint DB-age fields.""" import json import os import sys import time import tempfile import unittest from unittest.mock import patch, MagicMock sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import ticker_db def _make_sec_response(tickers: dict[str, str]) -> MagicMock: """Build a mock requests.Response for the SEC tickers endpoint. *tickers* maps ticker symbol → company title. """ payload = { str(i): {"ticker": sym, "title": name} for i, (sym, name) in enumerate(tickers.items()) } mock_resp = MagicMock() mock_resp.raise_for_status = MagicMock() mock_resp.json.return_value = payload return mock_resp class TestStartupCreatesDbIfMissing(unittest.TestCase): def test_startup_creates_db_if_missing(self): """If valid_tickers.json doesn't exist, run_startup_checks() must create it.""" with tempfile.TemporaryDirectory() as tmp_dir: data_file = os.path.join(tmp_dir, "valid_tickers.json") name_file = os.path.join(tmp_dir, "ticker_names.json") raw_file = os.path.join(tmp_dir, "sec_company_tickers.json") lock_file = os.path.join(tmp_dir, "valid_tickers.lock") sec_data = {"AAPL": "Apple Inc.", "MSFT": "Microsoft Corp", "NVDA": "NVIDIA"} with ( patch.object(ticker_db, "_DATA_FILE", data_file), patch.object(ticker_db, "_NAME_FILE", name_file), patch.object(ticker_db, "_RAW_SEC_FILE", raw_file), patch.object(ticker_db, "_LOCK_FILE", lock_file), patch.object(ticker_db, "_MIN_TICKER_COUNT", 1), patch.object(ticker_db, "_ticker_cache", None), patch.object(ticker_db, "_name_cache", None), patch("ticker_db.requests.get", return_value=_make_sec_response(sec_data)), ): self.assertFalse(os.path.exists(data_file), "Pre-condition: file must not exist") ticker_db.run_startup_checks() self.assertTrue(os.path.exists(data_file), "DB file should be created after startup check") self.assertTrue(os.path.exists(raw_file), "Raw SEC snapshot should be created after startup check") with open(data_file, encoding="utf-8") as f: loaded = set(json.load(f)) self.assertEqual(loaded, {"AAPL", "MSFT", "NVDA"}) class TestHealthReportsDbAge(unittest.TestCase): def setUp(self): import app as flask_app flask_app.app.config["TESTING"] = True self.client = flask_app.app.test_client() def test_health_reports_db_age(self): """GET /api/health must include ticker_db_age_hours and ticker_db_stale.""" resp = self.client.get("/api/health") self.assertEqual(resp.status_code, 200) data = resp.get_json() self.assertIn("ticker_db_age_hours", data) self.assertIn("ticker_db_stale", data) self.assertIn("ticker_db_loaded", data) self.assertIn("ticker_count", data) # age_hours must be a non-negative number (or None if file missing) age = data["ticker_db_age_hours"] if age is not None: self.assertIsInstance(age, (int, float)) self.assertGreaterEqual(age, 0) self.assertIsInstance(data["ticker_db_stale"], bool) def test_health_stale_flag_false_for_fresh_file(self): """ticker_db_stale must be False when the file was just written.""" # Touch the data file to make it appear fresh with patch( "ticker_db.get_db_file_age_hours", return_value=1.0 ), patch( "ticker_db.is_db_stale", return_value=False ): resp = self.client.get("/api/health") self.assertEqual(resp.status_code, 200) data = resp.get_json() self.assertFalse(data["ticker_db_stale"]) class TestRefreshUpdatesData(unittest.TestCase): def test_refresh_updates_data(self): """refresh_ticker_db() must detect added/removed tickers and update caches.""" with tempfile.TemporaryDirectory() as tmp_dir: data_file = os.path.join(tmp_dir, "valid_tickers.json") name_file = os.path.join(tmp_dir, "ticker_names.json") raw_file = os.path.join(tmp_dir, "sec_company_tickers.json") lock_file = os.path.join(tmp_dir, "valid_tickers.lock") # Seed an initial DB with 3 tickers initial = ["AAPL", "MSFT", "GOOG"] with open(data_file, "w", encoding="utf-8") as f: json.dump(initial, f) # New SEC data: removed GOOG, added NVDA and AMZN new_sec_data = { "AAPL": "Apple Inc.", "MSFT": "Microsoft Corp", "NVDA": "NVIDIA Corp", "AMZN": "Amazon.com Inc.", } with ( patch.object(ticker_db, "_DATA_FILE", data_file), patch.object(ticker_db, "_NAME_FILE", name_file), patch.object(ticker_db, "_RAW_SEC_FILE", raw_file), patch.object(ticker_db, "_LOCK_FILE", lock_file), patch.object(ticker_db, "_MIN_TICKER_COUNT", 1), patch.object(ticker_db, "_ticker_cache", set(initial)), patch.object(ticker_db, "_name_cache", None), patch("ticker_db.requests.get", return_value=_make_sec_response(new_sec_data)), ): result = ticker_db.refresh_ticker_db() self.assertEqual(result["status"], "ok") self.assertEqual(result["ticker_count"], 4) self.assertEqual(result["added"], 2, f"Expected 2 added, got {result['added']}") self.assertEqual(result["removed"], 1, f"Expected 1 removed, got {result['removed']}") # Verify the file on disk reflects the new set with open(data_file, encoding="utf-8") as f: on_disk = set(json.load(f)) self.assertEqual(on_disk, {"AAPL", "MSFT", "NVDA", "AMZN"}) self.assertTrue(os.path.exists(raw_file), "Refresh should persist the raw SEC snapshot") if __name__ == "__main__": unittest.main()