from __future__ import annotations from collections import defaultdict from types import SimpleNamespace import unittest from fastapi import HTTPException from backend.security_utils import ( admin_only, normalize_watchlist_symbols, rate_limit_guard, validate_cache_target, validate_symbol_and_interval, ) def _make_request(path: str, token: str = "", host: str = "127.0.0.1") -> SimpleNamespace: return SimpleNamespace( url=SimpleNamespace(path=path), headers={"X-Admin-Token": token} if token else {}, client=SimpleNamespace(host=host), ) class SecurityUtilsTests(unittest.TestCase): def test_rate_limit_guard_skips_whitelisted_paths(self) -> None: ip_limits: defaultdict[str, list[float]] = defaultdict(list) blocked = rate_limit_guard(_make_request("/api/health"), ip_limits) self.assertFalse(blocked) self.assertEqual(ip_limits, {}) def test_rate_limit_guard_blocks_after_limit(self) -> None: ip_limits: defaultdict[str, list[float]] = defaultdict(list) timestamps = iter([1.0, 2.0, 3.0]) request = _make_request("/api/symbols") self.assertFalse(rate_limit_guard(request, ip_limits, limit=2, time_provider=lambda: next(timestamps))) self.assertFalse(rate_limit_guard(request, ip_limits, limit=2, time_provider=lambda: next(timestamps))) self.assertTrue(rate_limit_guard(request, ip_limits, limit=2, time_provider=lambda: next(timestamps))) def test_admin_only_rejects_invalid_token(self) -> None: with self.assertRaises(HTTPException) as context: admin_only(_make_request("/api/cache", token="wrong"), "expected") self.assertEqual(context.exception.status_code, 401) def test_validate_symbol_and_interval_returns_canonical_values(self) -> None: result = validate_symbol_and_interval( "btc", "1H", get_canonical_symbol=lambda symbol: "BTCUSD", symbols={"BTCUSD": object()}, supported_intervals={"1h", "4h"}, ) self.assertEqual(result, ("BTCUSD", "1h")) def test_normalize_watchlist_symbols_deduplicates_and_tracks_invalid(self) -> None: valid_symbols, invalid_symbols, duplicate_count = normalize_watchlist_symbols( ["btc", "BTC", "eth", "ghost"], get_canonical_symbol=lambda symbol: {"btc": "BTCUSD", "BTC": "BTCUSD", "eth": "ETHUSD"}.get(symbol, symbol.upper()), symbol_registry={"BTCUSD": object(), "ETHUSD": object()}, ) self.assertEqual(valid_symbols, ["BTCUSD", "ETHUSD"]) self.assertEqual(invalid_symbols, ["GHOST"]) self.assertEqual(duplicate_count, 1) def test_validate_cache_target_rejects_unknown_value(self) -> None: with self.assertRaises(HTTPException) as context: validate_cache_target("mystery") self.assertEqual(context.exception.status_code, 400) self.assertIn("Unsupported cache target", context.exception.detail) if __name__ == "__main__": unittest.main()