|
|
""" |
|
|
Input validation module for Financial Market Data MCP Server. |
|
|
Handles validation and sanitization of user inputs. |
|
|
""" |
|
|
|
|
|
import json |
|
|
import re |
|
|
import logging |
|
|
from typing import Tuple, Dict, Callable |
|
|
from .config import ( |
|
|
ALLOWED_TICKER_PATTERN, |
|
|
MAX_TICKERS_PER_REQUEST, |
|
|
ALLOWED_PERIODS, |
|
|
ALLOWED_INTERVALS, |
|
|
ALLOWED_METRICS |
|
|
) |
|
|
from .data_fetcher import get_cached_stock_data |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
def _ticker_has_market_data(ticker: str, fetcher: Callable) -> bool: |
|
|
"""Return True if the fetcher can return market data for ticker.""" |
|
|
try: |
|
|
_, hist = fetcher(ticker) |
|
|
return not hist.empty |
|
|
except Exception as exc: |
|
|
logger.warning(f"Ticker existence check failed for {ticker}: {exc}") |
|
|
return False |
|
|
|
|
|
|
|
|
def validate_ticker(
    ticker: str,
    *,
    check_exists: bool = False,
    fetcher: Callable = None
) -> Tuple[bool, str, str]:
    """
    Validate and sanitize ticker symbol.

    The input is stripped of surrounding whitespace and upper-cased, then
    checked for length, for a full match against ALLOWED_TICKER_PATTERN and
    against a small denylist of SQL-like fragments. Optionally the symbol is
    also looked up through a fetcher to confirm it has market data.

    Args:
        ticker: Raw ticker input
        check_exists: When True, ensure the ticker returns real market data
        fetcher: Optional callable returning (info, history). Defaults to
            get_cached_stock_data

    Returns:
        Tuple of (is_valid, sanitized_ticker, error_message). On failure
        sanitized_ticker is "" and error_message says why; on success
        error_message is "".
    """
    if not ticker:
        return False, "", "Ticker symbol is required"

    # Truthy non-string values (e.g. numbers) have no .strip(); reject them
    # with an error tuple instead of raising AttributeError.
    if not isinstance(ticker, str):
        return False, "", "Ticker symbol must be a string"

    sanitized = ticker.strip().upper()

    # Whitespace-only input is effectively missing input.
    if not sanitized:
        return False, "", "Ticker symbol is required"

    if len(sanitized) > 5:
        return False, "", "Ticker symbol too long (max 5 characters)"

    # fullmatch requires the whole symbol to satisfy the pattern, so the
    # check stays strict even if the configured pattern has no end anchor.
    if not re.fullmatch(ALLOWED_TICKER_PATTERN, sanitized):
        # %r escapes control characters so raw input cannot forge log lines.
        logger.warning("Invalid ticker format attempted: %r", ticker)
        return False, "", "Invalid ticker format. Use 1-5 uppercase letters only."

    # Defense in depth on top of the pattern check above.
    dangerous_patterns = (';', '--', '/*', '*/', 'DROP', 'DELETE', 'INSERT')
    if any(pattern in sanitized for pattern in dangerous_patterns):
        logger.error("Potential SQL injection attempt: %r", ticker)
        return False, "", "Invalid characters in ticker symbol"

    if check_exists:
        fetch_fn = fetcher or get_cached_stock_data
        if not _ticker_has_market_data(sanitized, fetch_fn):
            return False, "", f"Ticker '{sanitized}' not found or has no market data"

    return True, sanitized, ""
|
|
|
|
|
|
|
|
def validate_period(period: str) -> Tuple[bool, str, str]:
    """
    Check a period value against the ALLOWED_PERIODS whitelist.

    Args:
        period: Period string

    Returns:
        Tuple of (is_valid, sanitized_period, error_message). An accepted
        period is returned unchanged with an empty message; a rejected one
        yields "" plus a message listing the allowed values.
    """
    if period in ALLOWED_PERIODS:
        return True, period, ""

    allowed = ', '.join(ALLOWED_PERIODS)
    return False, "", f"Invalid period. Allowed: {allowed}"
|
|
|
|
|
|
|
|
def validate_interval(interval: str) -> Tuple[bool, str, str]:
    """
    Check an interval value against the ALLOWED_INTERVALS whitelist.

    Args:
        interval: Interval string

    Returns:
        Tuple of (is_valid, sanitized_interval, error_message). An accepted
        interval is returned unchanged with an empty message; a rejected one
        yields "" plus a message listing the allowed values.
    """
    if interval in ALLOWED_INTERVALS:
        return True, interval, ""

    allowed = ', '.join(ALLOWED_INTERVALS)
    return False, "", f"Invalid interval. Allowed: {allowed}"
|
|
|
|
|
|
|
|
def validate_metric(metric: str) -> Tuple[bool, str, str]:
    """
    Check a comparison metric against the ALLOWED_METRICS whitelist.

    Args:
        metric: Metric string

    Returns:
        Tuple of (is_valid, sanitized_metric, error_message). An accepted
        metric is returned unchanged with an empty message; a rejected one
        yields "" plus a message listing the allowed values.
    """
    if metric in ALLOWED_METRICS:
        return True, metric, ""

    allowed = ', '.join(ALLOWED_METRICS)
    return False, "", f"Invalid metric. Allowed: {allowed}"
|
|
|
|
|
|
|
|
def validate_json_input(json_str: str) -> Tuple[bool, Dict, str]:
    """
    Validate and sanitize JSON input.

    The input must decode to a JSON object with at most
    MAX_TICKERS_PER_REQUEST entries whose keys are ticker symbols. Every key
    is run through validate_ticker, and the returned dictionary is keyed by
    the sanitized symbols rather than the raw keys.

    Args:
        json_str: JSON string

    Returns:
        Tuple of (is_valid, parsed_json, error_message). On failure
        parsed_json is {} and error_message says why; on success
        error_message is "".
    """
    try:
        data = json.loads(json_str)

        if not isinstance(data, dict):
            return False, {}, "JSON must be an object/dictionary"

        if len(data) > MAX_TICKERS_PER_REQUEST:
            return False, {}, f"Too many tickers (max {MAX_TICKERS_PER_REQUEST})"

        # Rebuild the mapping under the sanitized symbols so callers never
        # receive the raw, unsanitized keys that were only validated.
        sanitized_data = {}
        for ticker, value in data.items():
            is_valid, sanitized, error = validate_ticker(ticker)
            if not is_valid:
                return False, {}, f"Invalid ticker '{ticker}': {error}"

            # Two raw keys can collapse to the same symbol after
            # sanitization; reject rather than silently drop one value.
            if sanitized in sanitized_data:
                return False, {}, f"Duplicate ticker '{sanitized}'"

            sanitized_data[sanitized] = value

        return True, sanitized_data, ""

    except json.JSONDecodeError as e:
        logger.warning("Invalid JSON input: %s", e)
        return False, {}, "Invalid JSON format"
    except Exception as e:
        # Deliberate catch-all: this validator reports failure to the caller
        # instead of raising, whatever went wrong.
        logger.error("Unexpected error validating JSON: %s", e)
        return False, {}, "Invalid JSON format"
|
|
|
|
|
|