File size: 3,606 Bytes
97f6726
 
 
 
 
82e7fcd
97f6726
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82e7fcd
 
 
 
 
 
 
 
 
 
 
fa65b59
82e7fcd
 
 
97f6726
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""Unit tests for the ``ticker_validator`` module."""

import sys
import os
import unittest
from unittest.mock import patch

try:
    import pytest
except ImportError:
    # Without pytest there is no marker machinery, so fall back to an
    # identity decorator; the module then still imports under plain unittest.
    def _slow(cls):
        return cls
else:
    # Under pytest, tag decorated classes with the ``slow`` marker so they
    # can be deselected (e.g. ``-m "not slow"``).
    _slow = pytest.mark.slow

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from ticker_validator import validate_ticker, validate_ticker_format, TickerValidationResult


# ---------------------------------------------------------------------------
# Format-only tests (no network, no DB)
# ---------------------------------------------------------------------------

class TestFormatValidation(unittest.TestCase):
    """Format-only checks for ``validate_ticker_format`` (no network, no DB)."""

    def test_invalid_format_numbers(self):
        """A symbol containing digits is rejected as malformed."""
        result = validate_ticker_format("123ABC")
        self.assertFalse(result.valid)
        self.assertIn("not a valid ticker format", result.error)

    def test_invalid_format_too_long(self):
        """A six-letter symbol is rejected as malformed."""
        result = validate_ticker_format("ABCDEF")
        self.assertFalse(result.valid)
        self.assertIn("not a valid ticker format", result.error)

    def test_invalid_format_empty(self):
        """Empty and whitespace-only input yield a 'Please enter ...' error."""
        for value in ("", "   "):
            # subTest reports which input failed and keeps checking the rest,
            # instead of aborting on the first failure with no context.
            with self.subTest(value=value):
                result = validate_ticker_format(value)
                self.assertFalse(result.valid)
                self.assertIn("Please enter", result.error)

    def test_reserved_word(self):
        """Reserved words are rejected even though they look like tickers."""
        for word in ("TEST", "NULL"):
            with self.subTest(word=word):
                result = validate_ticker_format(word)
                self.assertFalse(result.valid)
                self.assertIn("reserved word", result.error)

    def test_ticker_normalization(self):
        """Lowercase and padded tickers should normalise cleanly."""
        for raw in ("aapl", " AAPL ", "Aapl"):
            with self.subTest(raw=raw):
                result = validate_ticker_format(raw)
                self.assertTrue(result.valid)
                self.assertEqual(result.ticker, "AAPL")


# ---------------------------------------------------------------------------
# Full-stack tests (these hit yfinance, so they are marked slow for CI to skip)
# ---------------------------------------------------------------------------

@_slow
class TestFullValidation(unittest.TestCase):
    """End-to-end checks through ``validate_ticker``.

    These tests may hit yfinance, so the class carries the ``slow`` marker
    (a no-op decorator when pytest is not installed).
    """

    def test_valid_ticker_aapl(self):
        """A well-known ticker validates and resolves to its company name."""
        result = validate_ticker("AAPL")
        self.assertTrue(result.valid)
        self.assertIn("Apple", result.company_name)
        self.assertIn(result.source, ("api", "local_db", "cache"))

    def test_nonexistent_ticker(self):
        """A well-formed but unknown ticker is rejected with an error message."""
        result = validate_ticker("XYZQW")
        self.assertFalse(result.valid)
        # Truthiness rejects both "" and None, so a missing message fails
        # the test instead of slipping through an inequality check.
        self.assertTrue(result.error, "invalid ticker should carry an error message")

    def test_result_has_suggestions(self):
        """A close-but-wrong ticker should surface suggestions."""
        result = validate_ticker("AAPLL")
        # Either invalid with suggestions/error, or (edge case) valid; only
        # the structure is checked.
        self.assertIsInstance(result.suggestions, list)
        if not result.valid:
            self.assertTrue(
                result.suggestions or result.error,
                "invalid result should carry suggestions or an error message",
            )


class TestLocalDbFastPath(unittest.TestCase):
    """A ticker known to the local DB must not trigger a yfinance lookup."""

    # Stacked patch decorators inject mocks bottom-up: is_known_ticker's mock
    # arrives first, yf.Ticker's mock arrives last.
    @patch("ticker_validator.yf.Ticker")
    @patch("ticker_validator.get_company_name", return_value="Apple Inc.")
    @patch("ticker_validator.is_known_ticker", return_value=True)
    def test_known_ticker_skips_yfinance_lookup(
        self, _known_mock, _name_mock, ticker_mock
    ):
        # NOTE(review): another test in this file accepts source == "cache",
        # so validate_ticker presumably caches results; if an earlier AAPL
        # lookup in the same process is cached, source may not be
        # "local_db" here - confirm the cache is bypassed or cleared.
        outcome = validate_ticker("AAPL")

        self.assertTrue(outcome.valid)
        self.assertEqual(outcome.company_name, "Apple Inc.")
        self.assertEqual(outcome.source, "local_db")
        self.assertFalse(outcome.warning)
        ticker_mock.assert_not_called()


if __name__ == "__main__":
    # Allow running this file directly with the stdlib unittest runner;
    # module="__main__" is the default, spelled out for clarity.
    unittest.main(module="__main__")