|
|
""" |
|
|
Test cases for input validation module. |
|
|
Tests edge cases and security scenarios. |
|
|
""" |
|
|
|
|
|
import unittest |
|
|
import json |
|
|
from src.validators import ( |
|
|
validate_ticker, |
|
|
validate_period, |
|
|
validate_interval, |
|
|
validate_metric, |
|
|
validate_json_input |
|
|
) |
|
|
|
|
|
|
|
|
class TestValidators(unittest.TestCase):
    """Exercise the input validation helpers imported from ``src.validators``.

    Every validator call in this suite is unpacked as a three-item result:
    a validity flag, the sanitized (or parsed) value, and an error message.
    Data-driven cases run inside ``self.subTest`` so that each failing input
    is reported individually instead of stopping at the first failure.
    """

    # ------------------------------------------------------------------
    # validate_ticker
    # ------------------------------------------------------------------

    def test_validate_ticker_valid(self):
        """Well-formed symbols are accepted and returned without an error."""
        valid_tickers = ["AAPL", "TSLA", "MSFT", "GOOGL", "NVDA", "A", "AB", "ABC"]
        for ticker in valid_tickers:
            with self.subTest(ticker=ticker):
                is_valid, sanitized, error = validate_ticker(ticker)
                self.assertTrue(is_valid, f"Ticker {ticker} should be valid")
                self.assertEqual(sanitized, ticker.upper())
                self.assertEqual(error, "")

    def test_validate_ticker_empty(self):
        """An empty string is rejected with the 'required' message."""
        is_valid, _, error = validate_ticker("")
        self.assertFalse(is_valid)
        self.assertEqual(error, "Ticker symbol is required")

    def test_validate_ticker_none(self):
        """``None`` is rejected."""
        is_valid, _, _ = validate_ticker(None)
        self.assertFalse(is_valid)

    def test_validate_ticker_whitespace(self):
        """Surrounding whitespace is stripped from the sanitized symbol."""
        is_valid, sanitized, _ = validate_ticker(" AAPL ")
        self.assertTrue(is_valid)
        self.assertEqual(sanitized, "AAPL")

    def test_validate_ticker_lowercase(self):
        """Lowercase input is accepted and sanitized to upper case."""
        is_valid, sanitized, _ = validate_ticker("aapl")
        self.assertTrue(is_valid)
        self.assertEqual(sanitized, "AAPL")

    def test_validate_ticker_too_long(self):
        """A six-letter symbol is rejected with a 'too long' message."""
        is_valid, _, error = validate_ticker("ABCDEF")
        self.assertFalse(is_valid)
        # Compare case-insensitively; only the wording matters here.
        self.assertIn("too long", error.lower())

    def test_validate_ticker_numbers(self):
        """A symbol containing a digit is rejected."""
        is_valid, _, _ = validate_ticker("AAPL1")
        self.assertFalse(is_valid)

    def test_validate_ticker_special_chars(self):
        """Symbols containing punctuation or inner spaces are rejected."""
        invalid_tickers = ["AAP-L", "AAP.L", "AAP@L", "AAP L"]
        for ticker in invalid_tickers:
            with self.subTest(ticker=ticker):
                is_valid, _, _ = validate_ticker(ticker)
                self.assertFalse(is_valid, f"Ticker {ticker} should be invalid")

    def test_validate_ticker_sql_injection(self):
        """SQL-injection style payloads are rejected."""
        dangerous = ["'; DROP TABLE--", "AAPL;--", "AAPL/*", "*/DROP"]
        for ticker in dangerous:
            with self.subTest(ticker=ticker):
                is_valid, _, _ = validate_ticker(ticker)
                self.assertFalse(is_valid, f"Should reject SQL injection: {ticker}")

    def test_validate_ticker_very_long_string(self):
        """A 1000-character input is rejected."""
        long_string = "A" * 1000
        is_valid, _, _ = validate_ticker(long_string)
        self.assertFalse(is_valid)

    # ------------------------------------------------------------------
    # validate_period
    # ------------------------------------------------------------------

    def test_validate_period_valid(self):
        """Each listed period string is accepted."""
        valid_periods = [
            "1d", "5d", "1mo", "3mo", "6mo", "1y", "2y", "5y", "10y", "ytd", "max",
        ]
        for period in valid_periods:
            with self.subTest(period=period):
                is_valid, _, _ = validate_period(period)
                self.assertTrue(is_valid, f"Period {period} should be valid")

    def test_validate_period_invalid(self):
        """Unlisted period strings, including the empty string, are rejected."""
        invalid_periods = ["1w", "2mo", "invalid", "1day", ""]
        for period in invalid_periods:
            with self.subTest(period=period):
                is_valid, _, _ = validate_period(period)
                self.assertFalse(is_valid, f"Period {period} should be invalid")

    # ------------------------------------------------------------------
    # validate_interval
    # ------------------------------------------------------------------

    def test_validate_interval_valid(self):
        """Each listed interval string is accepted."""
        valid_intervals = ["1m", "5m", "15m", "1h", "1d", "1wk", "1mo"]
        for interval in valid_intervals:
            with self.subTest(interval=interval):
                is_valid, _, _ = validate_interval(interval)
                self.assertTrue(is_valid, f"Interval {interval} should be valid")

    def test_validate_interval_invalid(self):
        """Unlisted interval strings, including the empty string, are rejected."""
        invalid_intervals = ["1min", "hour", "invalid", ""]
        for interval in invalid_intervals:
            with self.subTest(interval=interval):
                is_valid, _, _ = validate_interval(interval)
                self.assertFalse(is_valid, f"Interval {interval} should be invalid")

    # ------------------------------------------------------------------
    # validate_metric
    # ------------------------------------------------------------------

    def test_validate_metric_valid(self):
        """Each listed metric name is accepted."""
        valid_metrics = ["performance", "valuation", "volatility"]
        for metric in valid_metrics:
            with self.subTest(metric=metric):
                is_valid, _, _ = validate_metric(metric)
                self.assertTrue(is_valid, f"Metric {metric} should be valid")

    def test_validate_metric_invalid(self):
        """Unlisted metric names, including the empty string, are rejected."""
        invalid_metrics = ["price", "volume", "invalid", ""]
        for metric in invalid_metrics:
            with self.subTest(metric=metric):
                is_valid, _, _ = validate_metric(metric)
                self.assertFalse(is_valid, f"Metric {metric} should be invalid")

    # ------------------------------------------------------------------
    # validate_json_input
    # ------------------------------------------------------------------

    def test_validate_json_valid(self):
        """A single-ticker portfolio parses to a dict with no error."""
        valid_json = '{"AAPL": {"shares": 10, "cost_basis": 150}}'
        is_valid, data, error = validate_json_input(valid_json)
        self.assertTrue(is_valid)
        self.assertIsInstance(data, dict)
        self.assertEqual(error, "")

    def test_validate_json_invalid_format(self):
        """JSON with a missing closing brace yields a 'JSON format' error."""
        invalid_json = '{"AAPL": {"shares": 10}'
        is_valid, _, error = validate_json_input(invalid_json)
        self.assertFalse(is_valid)
        self.assertIn("JSON format", error)

    def test_validate_json_not_dict(self):
        """A top-level JSON array yields an 'object/dictionary' error."""
        invalid_json = '[1, 2, 3]'
        is_valid, _, error = validate_json_input(invalid_json)
        self.assertFalse(is_valid)
        self.assertIn("object/dictionary", error)

    def test_validate_json_too_many_tickers(self):
        """A portfolio holding six tickers is rejected as too large."""
        # Use symbols that test_validate_ticker_valid expects to be accepted,
        # so the ticker count is the only reason this portfolio can be
        # rejected. Digit-bearing keys such as "TICK0" would also trip the
        # per-ticker check (see test_validate_ticker_numbers), making the
        # outcome depend on the order in which the validator runs its checks.
        symbols = ["AAPL", "TSLA", "MSFT", "GOOGL", "NVDA", "ABC"]
        tickers = {symbol: {"shares": 1, "cost_basis": 100} for symbol in symbols}
        invalid_json = json.dumps(tickers)
        is_valid, _, error = validate_json_input(invalid_json)
        self.assertFalse(is_valid)
        self.assertIn("Too many tickers", error)

    def test_validate_json_invalid_ticker(self):
        """A portfolio keyed by an invalid symbol yields 'Invalid ticker'."""
        invalid_json = '{"INVALID123": {"shares": 10, "cost_basis": 150}}'
        is_valid, _, error = validate_json_input(invalid_json)
        self.assertFalse(is_valid)
        self.assertIn("Invalid ticker", error)

    def test_validate_json_empty(self):
        """An empty JSON object is accepted."""
        is_valid, _, _ = validate_json_input("{}")
        self.assertTrue(is_valid)

    def test_validate_json_malformed(self):
        """Unparseable text and valid JSON that is not an object are rejected."""
        # The first three entries are not parseable JSON; 'null', 'true' and
        # '123' parse successfully but are not JSON objects.
        malformed = ['{', '}', '{invalid}', 'null', 'true', '123']
        for json_str in malformed:
            with self.subTest(json_str=json_str):
                is_valid, _, _ = validate_json_input(json_str)
                self.assertFalse(is_valid, f"Should reject malformed JSON: {json_str}")

    def test_validate_json_sql_injection_in_ticker(self):
        """A portfolio whose key is an SQL fragment is rejected."""
        dangerous_json = '{"DROP TABLE": {"shares": 10, "cost_basis": 150}}'
        is_valid, _, _ = validate_json_input(dangerous_json)
        self.assertFalse(is_valid)
|
|
|
|
|
|
|
|
# Run the suite with unittest's command-line runner when this file is
# executed directly rather than imported.
if __name__ == '__main__':
    unittest.main()
|
|
|
|
|
|