File size: 6,313 Bytes
fb8e94c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82e7fcd
fb8e94c
 
 
 
 
 
 
82e7fcd
fb8e94c
82e7fcd
fb8e94c
 
 
 
 
 
 
82e7fcd
fb8e94c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82e7fcd
fb8e94c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82e7fcd
fb8e94c
82e7fcd
fb8e94c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82e7fcd
fb8e94c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""Tests for ticker DB refresh, startup checks, and health endpoint DB-age fields."""

import json
import os
import sys
import time
import tempfile
import unittest
from unittest.mock import patch, MagicMock

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import ticker_db


def _make_sec_response(tickers: dict[str, str]) -> MagicMock:
    """Build a mock requests.Response for the SEC tickers endpoint.

    *tickers* maps ticker symbol → company title.
    """
    payload = {
        str(i): {"ticker": sym, "title": name}
        for i, (sym, name) in enumerate(tickers.items())
    }
    mock_resp = MagicMock()
    mock_resp.raise_for_status = MagicMock()
    mock_resp.json.return_value = payload
    return mock_resp


class TestStartupCreatesDbIfMissing(unittest.TestCase):
    """run_startup_checks() must bootstrap the ticker DB when no file exists yet."""

    def test_startup_creates_db_if_missing(self):
        """If valid_tickers.json doesn't exist, run_startup_checks() must create it."""
        fake_sec = {"AAPL": "Apple Inc.", "MSFT": "Microsoft Corp", "NVDA": "NVIDIA"}
        fake_response = _make_sec_response(fake_sec)

        with tempfile.TemporaryDirectory() as workdir:
            tickers_path = os.path.join(workdir, "valid_tickers.json")
            names_path = os.path.join(workdir, "ticker_names.json")
            snapshot_path = os.path.join(workdir, "sec_company_tickers.json")
            lock_path = os.path.join(workdir, "valid_tickers.lock")

            # Redirect every file the module touches into the temp dir, clear the
            # in-memory caches, lower the sanity threshold so three tickers are
            # accepted, and serve the fake SEC payload instead of the network.
            with (
                patch.object(ticker_db, "_DATA_FILE", tickers_path),
                patch.object(ticker_db, "_NAME_FILE", names_path),
                patch.object(ticker_db, "_RAW_SEC_FILE", snapshot_path),
                patch.object(ticker_db, "_LOCK_FILE", lock_path),
                patch.object(ticker_db, "_MIN_TICKER_COUNT", 1),
                patch.object(ticker_db, "_ticker_cache", None),
                patch.object(ticker_db, "_name_cache", None),
                patch("ticker_db.requests.get", return_value=fake_response),
            ):
                self.assertFalse(os.path.exists(tickers_path), "Pre-condition: file must not exist")

                ticker_db.run_startup_checks()

                self.assertTrue(os.path.exists(tickers_path), "DB file should be created after startup check")
                self.assertTrue(os.path.exists(snapshot_path), "Raw SEC snapshot should be created after startup check")

                with open(tickers_path, encoding="utf-8") as fh:
                    persisted = json.load(fh)
                self.assertEqual(set(persisted), {"AAPL", "MSFT", "NVDA"})


class TestHealthReportsDbAge(unittest.TestCase):
    """Checks the ticker-DB fields exposed by the ``/api/health`` endpoint."""

    def setUp(self):
        # Imported here (not at module top), so ``app`` is only loaded when
        # this class's tests run.
        import app as flask_app
        flask_app.app.config["TESTING"] = True
        self.client = flask_app.app.test_client()

    def test_health_reports_db_age(self):
        """GET /api/health must include ticker_db_age_hours and ticker_db_stale."""
        resp = self.client.get("/api/health")
        self.assertEqual(resp.status_code, 200)
        data = resp.get_json()
        # get_json() yields None for a non-JSON body; fail with a clear message
        # instead of a TypeError from assertIn below.
        self.assertIsNotNone(data, "/api/health must return a JSON body")

        self.assertIn("ticker_db_age_hours", data)
        self.assertIn("ticker_db_stale", data)
        self.assertIn("ticker_db_loaded", data)
        self.assertIn("ticker_count", data)

        # age_hours must be a non-negative number (or None if file missing)
        age = data["ticker_db_age_hours"]
        if age is not None:
            # bool is a subclass of int, so rule it out explicitly.
            self.assertNotIsInstance(age, bool)
            self.assertIsInstance(age, (int, float))
            self.assertGreaterEqual(age, 0)

        self.assertIsInstance(data["ticker_db_stale"], bool)

    def test_health_stale_flag_false_for_fresh_file(self):
        """ticker_db_stale must be False when the file was just written."""
        # Simulate a fresh DB by stubbing the age/staleness helpers; the data
        # file on disk is not modified.
        # NOTE(review): these patches only reach the endpoint if app.py looks
        # the helpers up as ``ticker_db.<name>`` at call time; a
        # ``from ticker_db import ...`` in app.py would bypass them — confirm.
        with patch(
            "ticker_db.get_db_file_age_hours", return_value=1.0
        ), patch(
            "ticker_db.is_db_stale", return_value=False
        ):
            resp = self.client.get("/api/health")
        self.assertEqual(resp.status_code, 200)
        data = resp.get_json()
        self.assertIsNotNone(data, "/api/health must return a JSON body")
        self.assertFalse(data["ticker_db_stale"])


class TestRefreshUpdatesData(unittest.TestCase):
    """refresh_ticker_db() must reconcile the existing DB with fresh SEC data."""

    def test_refresh_updates_data(self):
        """refresh_ticker_db() must detect added/removed tickers and update caches."""
        seeded = ["AAPL", "MSFT", "GOOG"]
        # Relative to ``seeded``: GOOG disappears, NVDA and AMZN appear.
        incoming = {
            "AAPL": "Apple Inc.",
            "MSFT": "Microsoft Corp",
            "NVDA": "NVIDIA Corp",
            "AMZN": "Amazon.com Inc.",
        }
        fake_response = _make_sec_response(incoming)

        with tempfile.TemporaryDirectory() as workdir:
            tickers_path = os.path.join(workdir, "valid_tickers.json")
            names_path = os.path.join(workdir, "ticker_names.json")
            snapshot_path = os.path.join(workdir, "sec_company_tickers.json")
            lock_path = os.path.join(workdir, "valid_tickers.lock")

            # Pre-populate the on-disk DB with the three seeded symbols.
            with open(tickers_path, "w", encoding="utf-8") as fh:
                json.dump(seeded, fh)

            with (
                patch.object(ticker_db, "_DATA_FILE", tickers_path),
                patch.object(ticker_db, "_NAME_FILE", names_path),
                patch.object(ticker_db, "_RAW_SEC_FILE", snapshot_path),
                patch.object(ticker_db, "_LOCK_FILE", lock_path),
                patch.object(ticker_db, "_MIN_TICKER_COUNT", 1),
                patch.object(ticker_db, "_ticker_cache", set(seeded)),
                patch.object(ticker_db, "_name_cache", None),
                patch("ticker_db.requests.get", return_value=fake_response),
            ):
                outcome = ticker_db.refresh_ticker_db()

            # The returned summary reports the diff against the seeded set.
            self.assertEqual(outcome["status"], "ok")
            self.assertEqual(outcome["ticker_count"], 4)
            self.assertEqual(outcome["added"], 2, f"Expected 2 added, got {outcome['added']}")
            self.assertEqual(outcome["removed"], 1, f"Expected 1 removed, got {outcome['removed']}")

            # The file on disk must now hold exactly the incoming symbols.
            with open(tickers_path, encoding="utf-8") as fh:
                persisted = json.load(fh)
            self.assertEqual(set(persisted), {"AAPL", "MSFT", "NVDA", "AMZN"})
            self.assertTrue(os.path.exists(snapshot_path), "Refresh should persist the raw SEC snapshot")


# Allow running this file directly (``python <file>``) as well as via a test runner.
if __name__ == "__main__":
    unittest.main()