"""
Test cases for input validation module.

Tests edge cases and security scenarios.
"""

import json
import unittest

from src.validators import (
    validate_ticker,
    validate_period,
    validate_interval,
    validate_metric,
    validate_json_input,
)


class TestValidators(unittest.TestCase):
    """Test cases for input validation functions.

    Every validator under test returns a 3-tuple ``(is_valid, value, error)``.
    Table-driven tests wrap each input in ``self.subTest`` so that one failing
    input does not hide failures for the remaining inputs.
    """

    # ========================================================================
    # TICKER VALIDATION TESTS
    # ========================================================================

    def test_validate_ticker_valid(self):
        """Test valid ticker symbols."""
        valid_tickers = ["AAPL", "TSLA", "MSFT", "GOOGL", "NVDA", "A", "AB", "ABC"]
        for ticker in valid_tickers:
            with self.subTest(ticker=ticker):
                is_valid, sanitized, error = validate_ticker(ticker)
                self.assertTrue(is_valid, f"Ticker {ticker} should be valid")
                self.assertEqual(sanitized, ticker.upper())
                self.assertEqual(error, "")

    def test_validate_ticker_empty(self):
        """Test empty ticker input."""
        is_valid, _, error = validate_ticker("")
        self.assertFalse(is_valid)
        self.assertEqual(error, "Ticker symbol is required")

    def test_validate_ticker_none(self):
        """Test None ticker input."""
        is_valid, _, _ = validate_ticker(None)
        self.assertFalse(is_valid)

    def test_validate_ticker_whitespace(self):
        """Test ticker with surrounding whitespace (expected to be stripped)."""
        is_valid, sanitized, _ = validate_ticker(" AAPL ")
        self.assertTrue(is_valid)
        self.assertEqual(sanitized, "AAPL")

    def test_validate_ticker_lowercase(self):
        """Test ticker with lowercase letters (expected to be upper-cased)."""
        is_valid, sanitized, _ = validate_ticker("aapl")
        self.assertTrue(is_valid)
        self.assertEqual(sanitized, "AAPL")

    def test_validate_ticker_too_long(self):
        """Test ticker that's too long."""
        is_valid, _, error = validate_ticker("ABCDEF")
        self.assertFalse(is_valid)
        self.assertIn("too long", error.lower())

    def test_validate_ticker_numbers(self):
        """Test ticker with numbers (invalid)."""
        is_valid, _, _ = validate_ticker("AAPL1")
        self.assertFalse(is_valid)

    def test_validate_ticker_special_chars(self):
        """Test ticker with special characters."""
        invalid_tickers = ["AAP-L", "AAP.L", "AAP@L", "AAP L"]
        for ticker in invalid_tickers:
            with self.subTest(ticker=ticker):
                is_valid, _, _ = validate_ticker(ticker)
                self.assertFalse(is_valid, f"Ticker {ticker} should be invalid")

    def test_validate_ticker_sql_injection(self):
        """Test SQL injection attempts."""
        dangerous = ["'; DROP TABLE--", "AAPL;--", "AAPL/*", "*/DROP"]
        for ticker in dangerous:
            with self.subTest(ticker=ticker):
                is_valid, _, _ = validate_ticker(ticker)
                self.assertFalse(is_valid, f"Should reject SQL injection: {ticker}")

    def test_validate_ticker_very_long_string(self):
        """Test very long string input."""
        long_string = "A" * 1000
        is_valid, _, _ = validate_ticker(long_string)
        self.assertFalse(is_valid)

    # ========================================================================
    # PERIOD VALIDATION TESTS
    # ========================================================================

    def test_validate_period_valid(self):
        """Test valid periods."""
        valid_periods = [
            "1d", "5d", "1mo", "3mo", "6mo", "1y", "2y", "5y", "10y", "ytd", "max",
        ]
        for period in valid_periods:
            with self.subTest(period=period):
                is_valid, _, _ = validate_period(period)
                self.assertTrue(is_valid, f"Period {period} should be valid")

    def test_validate_period_invalid(self):
        """Test invalid periods."""
        invalid_periods = ["1w", "2mo", "invalid", "1day", ""]
        for period in invalid_periods:
            with self.subTest(period=period):
                is_valid, _, _ = validate_period(period)
                self.assertFalse(is_valid, f"Period {period} should be invalid")

    # ========================================================================
    # INTERVAL VALIDATION TESTS
    # ========================================================================

    def test_validate_interval_valid(self):
        """Test valid intervals."""
        valid_intervals = ["1m", "5m", "15m", "1h", "1d", "1wk", "1mo"]
        for interval in valid_intervals:
            with self.subTest(interval=interval):
                is_valid, _, _ = validate_interval(interval)
                self.assertTrue(is_valid, f"Interval {interval} should be valid")

    def test_validate_interval_invalid(self):
        """Test invalid intervals."""
        invalid_intervals = ["1min", "hour", "invalid", ""]
        for interval in invalid_intervals:
            with self.subTest(interval=interval):
                is_valid, _, _ = validate_interval(interval)
                self.assertFalse(is_valid, f"Interval {interval} should be invalid")

    # ========================================================================
    # METRIC VALIDATION TESTS
    # ========================================================================

    def test_validate_metric_valid(self):
        """Test valid metrics."""
        valid_metrics = ["performance", "valuation", "volatility"]
        for metric in valid_metrics:
            with self.subTest(metric=metric):
                is_valid, _, _ = validate_metric(metric)
                self.assertTrue(is_valid, f"Metric {metric} should be valid")

    def test_validate_metric_invalid(self):
        """Test invalid metrics."""
        invalid_metrics = ["price", "volume", "invalid", ""]
        for metric in invalid_metrics:
            with self.subTest(metric=metric):
                is_valid, _, _ = validate_metric(metric)
                self.assertFalse(is_valid, f"Metric {metric} should be invalid")

    # ========================================================================
    # JSON VALIDATION TESTS
    # ========================================================================

    def test_validate_json_valid(self):
        """Test valid JSON portfolio."""
        valid_json = '{"AAPL": {"shares": 10, "cost_basis": 150}}'
        is_valid, data, error = validate_json_input(valid_json)
        self.assertTrue(is_valid)
        self.assertIsInstance(data, dict)
        self.assertEqual(error, "")

    def test_validate_json_invalid_format(self):
        """Test invalid JSON format."""
        invalid_json = '{"AAPL": {"shares": 10}'  # Missing closing brace
        is_valid, _, error = validate_json_input(invalid_json)
        self.assertFalse(is_valid)
        self.assertIn("JSON format", error)

    def test_validate_json_not_dict(self):
        """Test JSON that's not a dictionary."""
        invalid_json = '[1, 2, 3]'  # Array instead of object
        is_valid, _, error = validate_json_input(invalid_json)
        self.assertFalse(is_valid)
        self.assertIn("object/dictionary", error)

    def test_validate_json_too_many_tickers(self):
        """Test portfolio with too many tickers."""
        # Six tickers, one more than the limit of 5 this test assumes.
        # Every symbol is individually valid (letters only, <= 5 chars), so
        # the ticker count is the only possible reason for rejection.
        # Generated names like "TICK0" contain a digit and would be rejected
        # as invalid tickers, making the outcome depend on validation order.
        symbols = ["AAPL", "TSLA", "MSFT", "GOOGL", "NVDA", "AMZN"]
        portfolio = {symbol: {"shares": 1, "cost_basis": 100} for symbol in symbols}
        is_valid, _, error = validate_json_input(json.dumps(portfolio))
        self.assertFalse(is_valid)
        self.assertIn("Too many tickers", error)

    def test_validate_json_invalid_ticker(self):
        """Test portfolio with invalid ticker."""
        invalid_json = '{"INVALID123": {"shares": 10, "cost_basis": 150}}'
        is_valid, _, error = validate_json_input(invalid_json)
        self.assertFalse(is_valid)
        self.assertIn("Invalid ticker", error)

    def test_validate_json_empty(self):
        """Test empty JSON."""
        is_valid, _, _ = validate_json_input("{}")
        self.assertTrue(is_valid)  # Empty dict is valid

    def test_validate_json_malformed(self):
        """Test malformed JSON and JSON scalars that are not objects."""
        # '{', '}', '{invalid}' are syntactically broken; 'null', 'true' and
        # '123' parse as JSON but are not objects, so they must be rejected too.
        malformed = ['{', '}', '{invalid}', 'null', 'true', '123']
        for json_str in malformed:
            with self.subTest(json_str=json_str):
                is_valid, _, _ = validate_json_input(json_str)
                self.assertFalse(is_valid, f"Should reject malformed JSON: {json_str}")

    def test_validate_json_sql_injection_in_ticker(self):
        """Test JSON with SQL injection in ticker."""
        dangerous_json = '{"DROP TABLE": {"shares": 10, "cost_basis": 150}}'
        is_valid, _, _ = validate_json_input(dangerous_json)
        self.assertFalse(is_valid)


if __name__ == '__main__':
    unittest.main()