File size: 8,642 Bytes
7169bc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
"""
Test cases for input validation module.
Tests edge cases and security scenarios.
"""

import unittest
import json
from src.validators import (
    validate_ticker,
    validate_period,
    validate_interval,
    validate_metric,
    validate_json_input
)


class TestValidators(unittest.TestCase):
    """Test cases for input validation functions.

    Every ``validate_*`` function under test returns a 3-tuple
    ``(is_valid, sanitized_value_or_data, error_message)``; the tests unpack
    that tuple and assert on the parts relevant to each scenario.

    Table-driven tests wrap each input in ``self.subTest`` so that every
    failing input is reported individually instead of the loop stopping at
    the first failure.
    """
    
    # ========================================================================
    # TICKER VALIDATION TESTS
    # ========================================================================
    
    def test_validate_ticker_valid(self):
        """Test valid ticker symbols."""
        valid_tickers = ["AAPL", "TSLA", "MSFT", "GOOGL", "NVDA", "A", "AB", "ABC"]
        for ticker in valid_tickers:
            with self.subTest(ticker=ticker):
                is_valid, sanitized, error = validate_ticker(ticker)
                self.assertTrue(is_valid, f"Ticker {ticker} should be valid")
                self.assertEqual(sanitized, ticker.upper())
                self.assertEqual(error, "")
    
    def test_validate_ticker_empty(self):
        """Test empty ticker input."""
        is_valid, sanitized, error = validate_ticker("")
        self.assertFalse(is_valid)
        self.assertEqual(error, "Ticker symbol is required")
    
    def test_validate_ticker_none(self):
        """Test None ticker input."""
        is_valid, sanitized, error = validate_ticker(None)
        self.assertFalse(is_valid)
    
    def test_validate_ticker_whitespace(self):
        """Test ticker with surrounding whitespace (should be stripped)."""
        is_valid, sanitized, error = validate_ticker("  AAPL  ")
        self.assertTrue(is_valid)
        self.assertEqual(sanitized, "AAPL")
    
    def test_validate_ticker_lowercase(self):
        """Test ticker with lowercase letters (should be upper-cased)."""
        is_valid, sanitized, error = validate_ticker("aapl")
        self.assertTrue(is_valid)
        self.assertEqual(sanitized, "AAPL")
    
    def test_validate_ticker_too_long(self):
        """Test ticker that's too long."""
        is_valid, sanitized, error = validate_ticker("ABCDEF")
        self.assertFalse(is_valid)
        self.assertIn("too long", error.lower())
    
    def test_validate_ticker_numbers(self):
        """Test ticker with numbers (invalid)."""
        is_valid, sanitized, error = validate_ticker("AAPL1")
        self.assertFalse(is_valid)
    
    def test_validate_ticker_special_chars(self):
        """Test ticker with special characters."""
        invalid_tickers = ["AAP-L", "AAP.L", "AAP@L", "AAP L"]
        for ticker in invalid_tickers:
            with self.subTest(ticker=ticker):
                is_valid, sanitized, error = validate_ticker(ticker)
                self.assertFalse(is_valid, f"Ticker {ticker} should be invalid")
    
    def test_validate_ticker_sql_injection(self):
        """Test SQL injection attempts."""
        dangerous = ["'; DROP TABLE--", "AAPL;--", "AAPL/*", "*/DROP"]
        for ticker in dangerous:
            with self.subTest(ticker=ticker):
                is_valid, sanitized, error = validate_ticker(ticker)
                self.assertFalse(is_valid, f"Should reject SQL injection: {ticker}")
    
    def test_validate_ticker_very_long_string(self):
        """Test very long string input."""
        long_string = "A" * 1000
        is_valid, sanitized, error = validate_ticker(long_string)
        self.assertFalse(is_valid)
    
    # ========================================================================
    # PERIOD VALIDATION TESTS
    # ========================================================================
    
    def test_validate_period_valid(self):
        """Test valid periods."""
        valid_periods = ["1d", "5d", "1mo", "3mo", "6mo", "1y", "2y", "5y", "10y", "ytd", "max"]
        for period in valid_periods:
            with self.subTest(period=period):
                is_valid, sanitized, error = validate_period(period)
                self.assertTrue(is_valid, f"Period {period} should be valid")
    
    def test_validate_period_invalid(self):
        """Test invalid periods."""
        invalid_periods = ["1w", "2mo", "invalid", "1day", ""]
        for period in invalid_periods:
            with self.subTest(period=period):
                is_valid, sanitized, error = validate_period(period)
                self.assertFalse(is_valid, f"Period {period} should be invalid")
    
    # ========================================================================
    # INTERVAL VALIDATION TESTS
    # ========================================================================
    
    def test_validate_interval_valid(self):
        """Test valid intervals."""
        valid_intervals = ["1m", "5m", "15m", "1h", "1d", "1wk", "1mo"]
        for interval in valid_intervals:
            with self.subTest(interval=interval):
                is_valid, sanitized, error = validate_interval(interval)
                self.assertTrue(is_valid, f"Interval {interval} should be valid")
    
    def test_validate_interval_invalid(self):
        """Test invalid intervals."""
        invalid_intervals = ["1min", "hour", "invalid", ""]
        for interval in invalid_intervals:
            with self.subTest(interval=interval):
                is_valid, sanitized, error = validate_interval(interval)
                self.assertFalse(is_valid, f"Interval {interval} should be invalid")
    
    # ========================================================================
    # METRIC VALIDATION TESTS
    # ========================================================================
    
    def test_validate_metric_valid(self):
        """Test valid metrics."""
        valid_metrics = ["performance", "valuation", "volatility"]
        for metric in valid_metrics:
            with self.subTest(metric=metric):
                is_valid, sanitized, error = validate_metric(metric)
                self.assertTrue(is_valid, f"Metric {metric} should be valid")
    
    def test_validate_metric_invalid(self):
        """Test invalid metrics."""
        invalid_metrics = ["price", "volume", "invalid", ""]
        for metric in invalid_metrics:
            with self.subTest(metric=metric):
                is_valid, sanitized, error = validate_metric(metric)
                self.assertFalse(is_valid, f"Metric {metric} should be invalid")
    
    # ========================================================================
    # JSON VALIDATION TESTS
    # ========================================================================
    
    def test_validate_json_valid(self):
        """Test valid JSON portfolio."""
        valid_json = '{"AAPL": {"shares": 10, "cost_basis": 150}}'
        is_valid, data, error = validate_json_input(valid_json)
        self.assertTrue(is_valid)
        self.assertIsInstance(data, dict)
        self.assertEqual(error, "")
    
    def test_validate_json_invalid_format(self):
        """Test invalid JSON format."""
        invalid_json = '{"AAPL": {"shares": 10}'  # Missing closing brace
        is_valid, data, error = validate_json_input(invalid_json)
        self.assertFalse(is_valid)
        self.assertIn("JSON format", error)
    
    def test_validate_json_not_dict(self):
        """Test JSON that's not a dictionary."""
        invalid_json = '[1, 2, 3]'  # Array instead of object
        is_valid, data, error = validate_json_input(invalid_json)
        self.assertFalse(is_valid)
        self.assertIn("object/dictionary", error)
    
    def test_validate_json_too_many_tickers(self):
        """Test portfolio with too many tickers."""
        # Six tickers (max is 5). Every symbol here is individually valid
        # (each appears in test_validate_ticker_valid), so the only possible
        # failure is the ticker-count limit. Generated names such as "TICK0"
        # contain digits, which validate_ticker rejects; with those, this test
        # would only pass if the count were checked before the symbols.
        symbols = ["AAPL", "TSLA", "MSFT", "GOOGL", "NVDA", "ABC"]
        portfolio = {symbol: {"shares": 1, "cost_basis": 100} for symbol in symbols}
        invalid_json = json.dumps(portfolio)
        is_valid, data, error = validate_json_input(invalid_json)
        self.assertFalse(is_valid)
        self.assertIn("Too many tickers", error)
    
    def test_validate_json_invalid_ticker(self):
        """Test portfolio with invalid ticker."""
        invalid_json = '{"INVALID123": {"shares": 10, "cost_basis": 150}}'
        is_valid, data, error = validate_json_input(invalid_json)
        self.assertFalse(is_valid)
        self.assertIn("Invalid ticker", error)
    
    def test_validate_json_empty(self):
        """Test empty JSON."""
        is_valid, data, error = validate_json_input("{}")
        self.assertTrue(is_valid)  # Empty dict is valid
    
    def test_validate_json_malformed(self):
        """Test inputs that are not valid JSON objects.

        Covers syntactically broken JSON ('{', '}', '{invalid}') as well as
        well-formed JSON scalars ('null', 'true', '123') that parse
        successfully but are not objects and must also be rejected.
        """
        malformed = ['{', '}', '{invalid}', 'null', 'true', '123']
        for json_str in malformed:
            with self.subTest(json_str=json_str):
                is_valid, data, error = validate_json_input(json_str)
                self.assertFalse(is_valid, f"Should reject malformed JSON: {json_str}")
    
    def test_validate_json_sql_injection_in_ticker(self):
        """Test JSON with SQL injection in ticker."""
        dangerous_json = '{"DROP TABLE": {"shares": 10, "cost_basis": 150}}'
        is_valid, data, error = validate_json_input(dangerous_json)
        self.assertFalse(is_valid)


if __name__ == '__main__':
    # Entry point when this module is executed as a script rather than
    # collected by a test runner; discovers and runs TestValidators.
    unittest.main(module='__main__')