# SuperAI_Forecast / backend / test_security_utils.py
# Thang6822
# Update Kronos Platform v6.1.0: Complete backend refactor and frontend UI optimization
# a721dfa
from __future__ import annotations
from collections import defaultdict
from types import SimpleNamespace
import unittest
from fastapi import HTTPException
from backend.security_utils import (
admin_only,
normalize_watchlist_symbols,
rate_limit_guard,
validate_cache_target,
validate_symbol_and_interval,
)
def _make_request(path: str, token: str = "", host: str = "127.0.0.1") -> SimpleNamespace:
return SimpleNamespace(
url=SimpleNamespace(path=path),
headers={"X-Admin-Token": token} if token else {},
client=SimpleNamespace(host=host),
)
class SecurityUtilsTests(unittest.TestCase):
def test_rate_limit_guard_skips_whitelisted_paths(self) -> None:
ip_limits: defaultdict[str, list[float]] = defaultdict(list)
blocked = rate_limit_guard(_make_request("/api/health"), ip_limits)
self.assertFalse(blocked)
self.assertEqual(ip_limits, {})
def test_rate_limit_guard_blocks_after_limit(self) -> None:
ip_limits: defaultdict[str, list[float]] = defaultdict(list)
timestamps = iter([1.0, 2.0, 3.0])
request = _make_request("/api/symbols")
self.assertFalse(rate_limit_guard(request, ip_limits, limit=2, time_provider=lambda: next(timestamps)))
self.assertFalse(rate_limit_guard(request, ip_limits, limit=2, time_provider=lambda: next(timestamps)))
self.assertTrue(rate_limit_guard(request, ip_limits, limit=2, time_provider=lambda: next(timestamps)))
def test_admin_only_rejects_invalid_token(self) -> None:
with self.assertRaises(HTTPException) as context:
admin_only(_make_request("/api/cache", token="wrong"), "expected")
self.assertEqual(context.exception.status_code, 401)
def test_validate_symbol_and_interval_returns_canonical_values(self) -> None:
result = validate_symbol_and_interval(
"btc",
"1H",
get_canonical_symbol=lambda symbol: "BTCUSD",
symbols={"BTCUSD": object()},
supported_intervals={"1h", "4h"},
)
self.assertEqual(result, ("BTCUSD", "1h"))
def test_normalize_watchlist_symbols_deduplicates_and_tracks_invalid(self) -> None:
valid_symbols, invalid_symbols, duplicate_count = normalize_watchlist_symbols(
["btc", "BTC", "eth", "ghost"],
get_canonical_symbol=lambda symbol: {"btc": "BTCUSD", "BTC": "BTCUSD", "eth": "ETHUSD"}.get(symbol, symbol.upper()),
symbol_registry={"BTCUSD": object(), "ETHUSD": object()},
)
self.assertEqual(valid_symbols, ["BTCUSD", "ETHUSD"])
self.assertEqual(invalid_symbols, ["GHOST"])
self.assertEqual(duplicate_count, 1)
def test_validate_cache_target_rejects_unknown_value(self) -> None:
with self.assertRaises(HTTPException) as context:
validate_cache_target("mystery")
self.assertEqual(context.exception.status_code, 400)
self.assertIn("Unsupported cache target", context.exception.detail)
if __name__ == "__main__":
    # Run the test cases in this module when the file is executed directly.
    unittest.main()