# Scraped page header (not part of the module): Spaces: Running | File size: 3,079 Bytes | ref a721dfa
from __future__ import annotations
from collections import defaultdict
from types import SimpleNamespace
import unittest
from fastapi import HTTPException
from backend.security_utils import (
admin_only,
normalize_watchlist_symbols,
rate_limit_guard,
validate_cache_target,
validate_symbol_and_interval,
)
def _make_request(path: str, token: str = "", host: str = "127.0.0.1") -> SimpleNamespace:
return SimpleNamespace(
url=SimpleNamespace(path=path),
headers={"X-Admin-Token": token} if token else {},
client=SimpleNamespace(host=host),
)
class SecurityUtilsTests(unittest.TestCase):
    """Unit tests for the helpers imported from ``backend.security_utils``."""

    def test_rate_limit_guard_skips_whitelisted_paths(self) -> None:
        """A request to ``/api/health`` is not blocked and records nothing."""
        ip_limits: defaultdict[str, list[float]] = defaultdict(list)

        blocked = rate_limit_guard(_make_request("/api/health"), ip_limits)

        self.assertFalse(blocked)
        # The mapping must still be empty: no entry was stored for the client.
        self.assertEqual(ip_limits, {})

    def test_rate_limit_guard_blocks_after_limit(self) -> None:
        """With ``limit=2`` the first two calls pass and the third is blocked."""
        ip_limits: defaultdict[str, list[float]] = defaultdict(list)
        timestamps = iter([1.0, 2.0, 3.0])
        request = _make_request("/api/symbols")

        def clock() -> float:
            # Deterministic time source: one scripted timestamp per call.
            return next(timestamps)

        self.assertFalse(rate_limit_guard(request, ip_limits, limit=2, time_provider=clock))
        self.assertFalse(rate_limit_guard(request, ip_limits, limit=2, time_provider=clock))
        self.assertTrue(rate_limit_guard(request, ip_limits, limit=2, time_provider=clock))

    def test_admin_only_rejects_invalid_token(self) -> None:
        """Token ``"wrong"`` checked against ``"expected"`` raises HTTP 401."""
        with self.assertRaises(HTTPException) as context:
            admin_only(_make_request("/api/cache", token="wrong"), "expected")

        # Inspect the exception only after the ``with`` block has captured it.
        self.assertEqual(context.exception.status_code, 401)

    def test_validate_symbol_and_interval_returns_canonical_values(self) -> None:
        """Inputs ``"btc"`` / ``"1H"`` come back as ``("BTCUSD", "1h")``."""
        result = validate_symbol_and_interval(
            "btc",
            "1H",
            get_canonical_symbol=lambda symbol: "BTCUSD",
            symbols={"BTCUSD": object()},
            supported_intervals={"1h", "4h"},
        )

        self.assertEqual(result, ("BTCUSD", "1h"))

    def test_normalize_watchlist_symbols_deduplicates_and_tracks_invalid(self) -> None:
        """Duplicates collapse, unknown symbols are reported, duplicates are counted."""
        # "btc" and "BTC" map to the same canonical symbol, giving one duplicate;
        # "ghost" has no alias and falls back to its upper-cased form.
        aliases = {"btc": "BTCUSD", "BTC": "BTCUSD", "eth": "ETHUSD"}

        valid_symbols, invalid_symbols, duplicate_count = normalize_watchlist_symbols(
            ["btc", "BTC", "eth", "ghost"],
            get_canonical_symbol=lambda symbol: aliases.get(symbol, symbol.upper()),
            symbol_registry={"BTCUSD": object(), "ETHUSD": object()},
        )

        self.assertEqual(valid_symbols, ["BTCUSD", "ETHUSD"])
        self.assertEqual(invalid_symbols, ["GHOST"])
        self.assertEqual(duplicate_count, 1)

    def test_validate_cache_target_rejects_unknown_value(self) -> None:
        """An unrecognised target raises HTTP 400 with an explanatory detail."""
        with self.assertRaises(HTTPException) as context:
            validate_cache_target("mystery")

        # Inspect the exception only after the ``with`` block has captured it.
        self.assertEqual(context.exception.status_code, 400)
        self.assertIn("Unsupported cache target", context.exception.detail)
if __name__ == "__main__":
    # Run the test suite when this file is executed directly as a script.
    unittest.main()