# finance-data-mcp / tests/test_validators.py
# Author: dlrklc
# Initial commit: Gradio MCP app for real-time financial data (7169bc5)
"""
Test cases for input validation module.
Tests edge cases and security scenarios.
"""
import unittest
import json
from src.validators import (
validate_ticker,
validate_period,
validate_interval,
validate_metric,
validate_json_input
)
class TestValidators(unittest.TestCase):
    """Test cases for input validation functions.

    Every validator under test returns a 3-tuple ``(is_valid, value, error)``.
    Table-driven tests wrap each case in ``self.subTest`` so that a failure
    names the offending input and the remaining cases still run.
    """

    # ========================================================================
    # TICKER VALIDATION TESTS
    # ========================================================================

    def test_validate_ticker_valid(self):
        """Test valid ticker symbols (accepted, upper-cased, empty error)."""
        valid_tickers = ["AAPL", "TSLA", "MSFT", "GOOGL", "NVDA", "A", "AB", "ABC"]
        for ticker in valid_tickers:
            with self.subTest(ticker=ticker):
                is_valid, sanitized, error = validate_ticker(ticker)
                self.assertTrue(is_valid, f"Ticker {ticker} should be valid")
                self.assertEqual(sanitized, ticker.upper())
                self.assertEqual(error, "")

    def test_validate_ticker_empty(self):
        """Test empty ticker input."""
        is_valid, _, error = validate_ticker("")
        self.assertFalse(is_valid)
        self.assertEqual(error, "Ticker symbol is required")

    def test_validate_ticker_none(self):
        """Test None ticker input."""
        is_valid, _, _ = validate_ticker(None)
        self.assertFalse(is_valid)

    def test_validate_ticker_whitespace(self):
        """Test ticker with surrounding whitespace (stripped, still valid)."""
        is_valid, sanitized, _ = validate_ticker(" AAPL ")
        self.assertTrue(is_valid)
        self.assertEqual(sanitized, "AAPL")

    def test_validate_ticker_lowercase(self):
        """Test ticker with lowercase letters (normalized to upper case)."""
        is_valid, sanitized, _ = validate_ticker("aapl")
        self.assertTrue(is_valid)
        self.assertEqual(sanitized, "AAPL")

    def test_validate_ticker_too_long(self):
        """Test ticker that's too long."""
        is_valid, _, error = validate_ticker("ABCDEF")
        self.assertFalse(is_valid)
        self.assertIn("too long", error.lower())

    def test_validate_ticker_numbers(self):
        """Test ticker with numbers (invalid)."""
        is_valid, _, _ = validate_ticker("AAPL1")
        self.assertFalse(is_valid)

    def test_validate_ticker_special_chars(self):
        """Test ticker with special characters, including an inner space."""
        invalid_tickers = ["AAP-L", "AAP.L", "AAP@L", "AAP L"]
        for ticker in invalid_tickers:
            with self.subTest(ticker=ticker):
                is_valid, _, _ = validate_ticker(ticker)
                self.assertFalse(is_valid, f"Ticker {ticker} should be invalid")

    def test_validate_ticker_sql_injection(self):
        """Test SQL injection attempts."""
        dangerous = ["'; DROP TABLE--", "AAPL;--", "AAPL/*", "*/DROP"]
        for ticker in dangerous:
            with self.subTest(ticker=ticker):
                is_valid, _, _ = validate_ticker(ticker)
                self.assertFalse(is_valid, f"Should reject SQL injection: {ticker}")

    def test_validate_ticker_very_long_string(self):
        """Test very long string input."""
        long_string = "A" * 1000
        is_valid, _, _ = validate_ticker(long_string)
        self.assertFalse(is_valid)

    # ========================================================================
    # PERIOD VALIDATION TESTS
    # ========================================================================

    def test_validate_period_valid(self):
        """Test valid periods."""
        valid_periods = ["1d", "5d", "1mo", "3mo", "6mo", "1y", "2y", "5y", "10y", "ytd", "max"]
        for period in valid_periods:
            with self.subTest(period=period):
                is_valid, _, _ = validate_period(period)
                self.assertTrue(is_valid, f"Period {period} should be valid")

    def test_validate_period_invalid(self):
        """Test invalid periods."""
        invalid_periods = ["1w", "2mo", "invalid", "1day", ""]
        for period in invalid_periods:
            with self.subTest(period=period):
                is_valid, _, _ = validate_period(period)
                self.assertFalse(is_valid, f"Period {period} should be invalid")

    # ========================================================================
    # INTERVAL VALIDATION TESTS
    # ========================================================================

    def test_validate_interval_valid(self):
        """Test valid intervals."""
        valid_intervals = ["1m", "5m", "15m", "1h", "1d", "1wk", "1mo"]
        for interval in valid_intervals:
            with self.subTest(interval=interval):
                is_valid, _, _ = validate_interval(interval)
                self.assertTrue(is_valid, f"Interval {interval} should be valid")

    def test_validate_interval_invalid(self):
        """Test invalid intervals."""
        invalid_intervals = ["1min", "hour", "invalid", ""]
        for interval in invalid_intervals:
            with self.subTest(interval=interval):
                is_valid, _, _ = validate_interval(interval)
                self.assertFalse(is_valid, f"Interval {interval} should be invalid")

    # ========================================================================
    # METRIC VALIDATION TESTS
    # ========================================================================

    def test_validate_metric_valid(self):
        """Test valid metrics."""
        valid_metrics = ["performance", "valuation", "volatility"]
        for metric in valid_metrics:
            with self.subTest(metric=metric):
                is_valid, _, _ = validate_metric(metric)
                self.assertTrue(is_valid, f"Metric {metric} should be valid")

    def test_validate_metric_invalid(self):
        """Test invalid metrics."""
        invalid_metrics = ["price", "volume", "invalid", ""]
        for metric in invalid_metrics:
            with self.subTest(metric=metric):
                is_valid, _, _ = validate_metric(metric)
                self.assertFalse(is_valid, f"Metric {metric} should be invalid")

    # ========================================================================
    # JSON VALIDATION TESTS
    # ========================================================================

    def test_validate_json_valid(self):
        """Test valid JSON portfolio."""
        valid_json = '{"AAPL": {"shares": 10, "cost_basis": 150}}'
        is_valid, data, error = validate_json_input(valid_json)
        self.assertTrue(is_valid)
        self.assertIsInstance(data, dict)
        self.assertEqual(error, "")

    def test_validate_json_invalid_format(self):
        """Test invalid JSON format."""
        invalid_json = '{"AAPL": {"shares": 10}'  # Missing closing brace
        is_valid, _, error = validate_json_input(invalid_json)
        self.assertFalse(is_valid)
        self.assertIn("JSON format", error)

    def test_validate_json_not_dict(self):
        """Test JSON that's not a dictionary."""
        invalid_json = '[1, 2, 3]'  # Array instead of object
        is_valid, _, error = validate_json_input(invalid_json)
        self.assertFalse(is_valid)
        self.assertIn("object/dictionary", error)

    def test_validate_json_too_many_tickers(self):
        """Test portfolio with too many tickers."""
        # Six tickers (max is 5). Each symbol is individually valid (letters
        # only, <= 5 chars), so the count is the ONLY reason for rejection.
        # Digit-suffixed names like "TICK0" would also trip the per-ticker
        # check, making the outcome depend on the validator's check order.
        symbols = ["AAPL", "MSFT", "GOOGL", "TSLA", "NVDA", "AMZN"]
        tickers = {symbol: {"shares": 1, "cost_basis": 100} for symbol in symbols}
        invalid_json = json.dumps(tickers)
        is_valid, _, error = validate_json_input(invalid_json)
        self.assertFalse(is_valid)
        self.assertIn("Too many tickers", error)

    def test_validate_json_invalid_ticker(self):
        """Test portfolio with invalid ticker."""
        invalid_json = '{"INVALID123": {"shares": 10, "cost_basis": 150}}'
        is_valid, _, error = validate_json_input(invalid_json)
        self.assertFalse(is_valid)
        self.assertIn("Invalid ticker", error)

    def test_validate_json_empty(self):
        """Test empty JSON."""
        is_valid, _, _ = validate_json_input("{}")
        self.assertTrue(is_valid)  # Empty dict is valid

    def test_validate_json_malformed(self):
        """Test JSON strings that are unparseable or not a JSON object.

        '{', '}' and '{invalid}' fail to parse. 'null', 'true' and '123'
        are well-formed JSON scalars but are not objects, so they must
        also be rejected.
        """
        malformed = ['{', '}', '{invalid}', 'null', 'true', '123']
        for json_str in malformed:
            with self.subTest(json_str=json_str):
                is_valid, _, _ = validate_json_input(json_str)
                self.assertFalse(is_valid, f"Should reject malformed JSON: {json_str}")

    def test_validate_json_sql_injection_in_ticker(self):
        """Test JSON with SQL injection in ticker."""
        dangerous_json = '{"DROP TABLE": {"shares": 10, "cost_basis": 150}}'
        is_valid, _, _ = validate_json_input(dangerous_json)
        self.assertFalse(is_valid)
# Allow running this file directly: discover and run the TestCase classes above.
if __name__ == "__main__":
    unittest.main()