File size: 4,872 Bytes
7169bc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
"""
Input validation module for Financial Market Data MCP Server.
Handles validation and sanitization of user inputs.
"""

import json
import re
import logging
from typing import Tuple, Dict, Callable
from .config import (
    ALLOWED_TICKER_PATTERN,
    MAX_TICKERS_PER_REQUEST,
    ALLOWED_PERIODS,
    ALLOWED_INTERVALS,
    ALLOWED_METRICS
)
from .data_fetcher import get_cached_stock_data

logger = logging.getLogger(__name__)


def _ticker_has_market_data(ticker: str, fetcher: Callable) -> bool:
    """
    Probe ``fetcher`` for ``ticker`` and report whether any history came back.

    ``fetcher(ticker)`` must return a 2-item sequence ``(info, history)``;
    ``history`` must expose an ``empty`` attribute (presumably a pandas
    DataFrame from get_cached_stock_data -- confirm against the fetcher).

    Any exception raised while fetching, unpacking or inspecting the result
    is logged as a warning and treated as "no data" (returns False).
    """
    try:
        _info, history = fetcher(ticker)
        found = not history.empty
    except Exception as err:
        logger.warning(f"Ticker existence check failed for {ticker}: {err}")
        return False
    return found


def validate_ticker(
    ticker: str,
    *,
    check_exists: bool = False,
    fetcher: Callable = None
) -> Tuple[bool, str, str]:
    """
    Validate and sanitize a ticker symbol.

    The input is stripped and upper-cased, then checked in order for:
    presence, length (max 5), format (ALLOWED_TICKER_PATTERN, matched against
    the whole symbol) and a small blocklist of injection-looking substrings.
    Optionally the sanitized symbol is also probed for real market data.

    Args:
        ticker: Raw ticker input (user supplied)
        check_exists: When True, also require that the fetcher returns
            non-empty market history for the sanitized ticker
        fetcher: Optional callable returning (info, history). Defaults to
            get_cached_stock_data. Only used when check_exists is True.

    Returns:
        Tuple of (is_valid, sanitized_ticker, error_message). On failure
        sanitized_ticker is "" and error_message explains why; on success
        error_message is "".
    """
    if not ticker:
        return False, "", "Ticker symbol is required"

    # Non-string input (e.g. a number passed through a tool call) would
    # otherwise crash on .strip() below.
    if not isinstance(ticker, str):
        return False, "", "Ticker symbol must be a string"

    # Remove surrounding whitespace and normalise to upper case
    sanitized = ticker.strip().upper()

    # Whitespace-only input is as good as missing
    if not sanitized:
        return False, "", "Ticker symbol is required"

    # Check length
    if len(sanitized) > 5:
        return False, "", "Ticker symbol too long (max 5 characters)"

    # fullmatch: the pattern must cover the entire symbol even if
    # ALLOWED_TICKER_PATTERN itself is not anchored with ^...$
    if not re.fullmatch(ALLOWED_TICKER_PATTERN, sanitized):
        # %r escapes control characters in the raw input so it cannot forge log lines
        logger.warning("Invalid ticker format attempted: %r", ticker)
        return False, "", "Invalid ticker format. Use 1-5 uppercase letters only."

    # Defence in depth against injection-looking input.
    # NOTE(review): if ALLOWED_TICKER_PATTERN only admits letters, the
    # punctuation entries can never match here and 'DELETE'/'INSERT' exceed
    # the 5-character limit above, so only 'DROP' can still trigger (and it
    # rejects any symbol containing it) -- confirm this check is still wanted.
    dangerous_patterns = (';', '--', '/*', '*/', 'DROP', 'DELETE', 'INSERT')
    if any(pattern in sanitized for pattern in dangerous_patterns):
        logger.error("Potential SQL injection attempt: %r", ticker)
        return False, "", "Invalid characters in ticker symbol"

    if check_exists:
        fetch_fn = fetcher if fetcher is not None else get_cached_stock_data
        if not _ticker_has_market_data(sanitized, fetch_fn):
            return False, "", f"Ticker '{sanitized}' not found or has no market data"

    return True, sanitized, ""


def validate_period(period: str) -> Tuple[bool, str, str]:
    """
    Check that ``period`` is one of the configured ALLOWED_PERIODS.

    Args:
        period: Requested history period; compared verbatim (no trimming
            or case folding)

    Returns:
        (True, period, "") when allowed, otherwise (False, "", message)
        where the message lists every allowed value.
    """
    if period in ALLOWED_PERIODS:
        return True, period, ""
    allowed = ', '.join(ALLOWED_PERIODS)
    return False, "", f"Invalid period. Allowed: {allowed}"


def validate_interval(interval: str) -> Tuple[bool, str, str]:
    """
    Check that ``interval`` is one of the configured ALLOWED_INTERVALS.

    Args:
        interval: Requested bar interval; compared verbatim (no trimming
            or case folding)

    Returns:
        (True, interval, "") when allowed, otherwise (False, "", message)
        where the message lists every allowed value.
    """
    is_allowed = interval in ALLOWED_INTERVALS
    if not is_allowed:
        choices = ", ".join(ALLOWED_INTERVALS)
        return False, "", f"Invalid interval. Allowed: {choices}"
    return True, interval, ""


def validate_metric(metric: str) -> Tuple[bool, str, str]:
    """
    Check that a comparison ``metric`` is one of the configured ALLOWED_METRICS.

    Args:
        metric: Requested comparison metric; compared verbatim (no trimming
            or case folding)

    Returns:
        (True, metric, "") when allowed, otherwise (False, "", message)
        where the message lists every allowed value.
    """
    if metric in ALLOWED_METRICS:
        return True, metric, ""
    return False, "", "Invalid metric. Allowed: " + ", ".join(ALLOWED_METRICS)


def validate_json_input(json_str: str) -> Tuple[bool, Dict, str]:
    """
    Parse a JSON portfolio object and validate/sanitize its ticker keys.

    The input must decode to a JSON object with at most
    MAX_TICKERS_PER_REQUEST keys, each of which must pass validate_ticker.
    Keys in the returned dict are the sanitized (stripped, upper-cased)
    tickers; values are passed through unchanged and are not validated here.

    Args:
        json_str: JSON text (user supplied)

    Returns:
        Tuple of (is_valid, parsed_json, error_message). On failure
        parsed_json is {} and error_message explains why.
    """
    try:
        data = json.loads(json_str)
    except json.JSONDecodeError as e:
        logger.warning("Invalid JSON input: %s", e)
        return False, {}, "Invalid JSON format"
    except (TypeError, ValueError, RecursionError) as e:
        # Non-text input (TypeError), undecodable bytes (ValueError) or
        # pathologically deep nesting (RecursionError).
        logger.error("Unexpected error validating JSON: %s", e)
        return False, {}, "Invalid JSON format"

    # Check if it's a dictionary
    if not isinstance(data, dict):
        return False, {}, "JSON must be an object/dictionary"

    # Limit number of items
    if len(data) > MAX_TICKERS_PER_REQUEST:
        return False, {}, f"Too many tickers (max {MAX_TICKERS_PER_REQUEST})"

    # Validate each ticker and re-key the portfolio by its sanitized symbol
    sanitized_data = {}
    for raw_ticker, value in data.items():
        is_valid, ticker, error = validate_ticker(raw_ticker)
        if not is_valid:
            return False, {}, f"Invalid ticker '{raw_ticker}': {error}"
        # e.g. "aapl" and " AAPL" normalise to the same key; keeping only
        # one of them would silently drop data.
        if ticker in sanitized_data:
            return False, {}, f"Duplicate ticker '{ticker}'"
        sanitized_data[ticker] = value

    return True, sanitized_data, ""