SuperAI_Forecast / backend /security_utils.py
Thang6822
Update Kronos Platform v6.1.0: Complete backend refactor and frontend UI optimization
a721dfa
from __future__ import annotations
import hmac
import time
from typing import Any, Callable, Dict, List, Mapping, MutableMapping, Sequence, Set, Tuple
from fastapi import HTTPException, Request
# Cache target names accepted by validate_cache_target (compared lower-cased).
CACHE_TARGETS: frozenset[str] = frozenset(
    ("ai_verdict", "all", "forecast", "historical", "indicators", "ticker")
)

# Request paths that rate_limit_guard never counts or throttles.
RATE_LIMIT_WHITELIST: frozenset[str] = frozenset(
    ("/api/health", "/api/metrics", "/api/ping")
)
def rate_limit_guard(
    request: Request,
    ip_limits: MutableMapping[str, List[float]],
    *,
    limit: int = 60,
    window_seconds: int = 60,
    time_provider: Callable[[], float] = time.time,
) -> bool:
    """Report whether the request's client has exceeded its request budget.

    Keeps a sliding window of request timestamps per client IP inside
    ``ip_limits`` and mutates that mapping in place.

    Args:
        request: Incoming request; its URL path and client host are read.
        ip_limits: Per-IP timestamp store. Any mutable mapping works; a
            missing key is treated as "no requests seen yet".
        limit: Maximum number of requests allowed inside one window.
        window_seconds: Length of the sliding window, in seconds.
        time_provider: Clock returning the current time in seconds
            (injectable for tests).

    Returns:
        True when the client is already at or over ``limit`` within the
        window, False otherwise (including for whitelisted paths).
    """
    # Whitelisted endpoints are neither counted nor throttled.
    if request.url.path in RATE_LIMIT_WHITELIST:
        return False

    ip = request.client.host if request.client else "unknown"
    now = time_provider()

    # .get() instead of [] so a plain dict (not only a defaultdict) works for
    # an IP that has not been seen before.
    recent = [ts for ts in ip_limits.get(ip, ()) if now - ts < window_seconds]
    ip_limits[ip] = recent

    if len(recent) >= limit:
        # Over budget: the rejected request is not recorded.
        return True

    recent.append(now)
    return False
def admin_only(request: Request, admin_token: str) -> bool:
    """Require a valid ``X-Admin-Token`` header on the request.

    Args:
        request: Incoming request whose headers are inspected.
        admin_token: The expected admin token.

    Returns:
        True when the supplied token matches ``admin_token``.

    Raises:
        HTTPException: 401 when the header is missing, empty, or does not
            match.
    """
    token = request.headers.get("X-Admin-Token", "")

    # hmac.compare_digest raises TypeError for str arguments containing
    # non-ASCII characters. Comparing the encoded bytes keeps the
    # constant-time check while letting such tokens be rejected with a 401
    # instead of surfacing as an unhandled error.
    supplied = token.encode("utf-8")
    expected = admin_token.encode("utf-8")

    if not token or not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="Admin access required")
    return True
def validate_symbol_and_interval(
    symbol: str,
    interval: str,
    *,
    get_canonical_symbol: Callable[[str], str],
    symbols: Mapping[str, Any],
    supported_intervals: Sequence[str] | Set[str],
) -> Tuple[str, str]:
    """Canonicalize a symbol/interval pair and reject unknown values.

    Args:
        symbol: Symbol as supplied by the caller.
        interval: Interval as supplied by the caller; matched lower-cased.
        get_canonical_symbol: Maps a raw symbol to its canonical form.
        symbols: Registry of known canonical symbols.
        supported_intervals: Collection of accepted (lower-case) intervals.

    Returns:
        ``(canonical_symbol, lower_cased_interval)``.

    Raises:
        HTTPException: 404 for an unknown symbol, 400 for an unsupported
            interval. The symbol is checked first.
    """
    canonical_symbol = get_canonical_symbol(symbol)
    interval_key = interval.lower()

    symbol_is_known = canonical_symbol in symbols
    if not symbol_is_known:
        raise HTTPException(
            status_code=404,
            detail=f"Unknown symbol: {canonical_symbol}",
        )

    interval_is_supported = interval_key in supported_intervals
    if not interval_is_supported:
        # The message echoes the interval exactly as the caller sent it.
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported interval: {interval}",
        )

    return (canonical_symbol, interval_key)
def normalize_watchlist_symbols(
    symbols: List[str],
    *,
    get_canonical_symbol: Callable[[str], str],
    symbol_registry: Mapping[str, Any],
) -> Tuple[List[str], List[str], int]:
    """Split raw watchlist entries into valid and invalid canonical symbols.

    Each entry is stripped and canonicalized. Symbols present in
    ``symbol_registry`` are valid; other non-empty symbols are invalid.
    Both output lists keep first-seen order and contain no repeats.

    Args:
        symbols: Raw symbols as supplied by the caller.
        get_canonical_symbol: Maps a stripped raw symbol to its canonical form.
        symbol_registry: Registry of known canonical symbols.

    Returns:
        ``(valid_symbols, invalid_symbols, duplicate_count)`` where
        ``duplicate_count`` is the number of entries that repeated an
        already-seen canonical symbol. Entries that canonicalize to an empty
        value are ignored and are not counted as duplicates.
    """
    valid_symbols: List[str] = []
    invalid_symbols: List[str] = []
    seen_valid: Set[str] = set()
    seen_invalid: Set[str] = set()
    duplicate_count = 0

    for raw_symbol in symbols:
        canonical_symbol = get_canonical_symbol(raw_symbol.strip())

        if canonical_symbol in symbol_registry:
            seen, bucket = seen_valid, valid_symbols
        elif canonical_symbol:
            seen, bucket = seen_invalid, invalid_symbols
        else:
            # Blank entry: nothing to report, and it is not a duplicate.
            continue

        if canonical_symbol in seen:
            # Counted explicitly rather than derived by subtraction, so
            # skipped blank entries cannot inflate the total.
            duplicate_count += 1
            continue

        seen.add(canonical_symbol)
        bucket.append(canonical_symbol)

    return valid_symbols, invalid_symbols, duplicate_count
def validate_cache_target(target: str) -> str:
    """Return the lower-cased cache target, rejecting unknown names.

    Args:
        target: Cache target name as supplied by the caller.

    Returns:
        ``target`` lower-cased, when it is a member of ``CACHE_TARGETS``.

    Raises:
        HTTPException: 400 when the lower-cased target is not allowed; the
            message lists the allowed targets in sorted order.
    """
    normalized_target = target.lower()
    if normalized_target in CACHE_TARGETS:
        return normalized_target

    allowed_targets = ", ".join(sorted(CACHE_TARGETS))
    raise HTTPException(
        status_code=400,
        detail=f"Unsupported cache target: {target}. Allowed: {allowed_targets}",
    )