File size: 3,079 Bytes
a721dfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from __future__ import annotations

from collections import defaultdict
from types import SimpleNamespace
import unittest

from fastapi import HTTPException

from backend.security_utils import (
    admin_only,
    normalize_watchlist_symbols,
    rate_limit_guard,
    validate_cache_target,
    validate_symbol_and_interval,
)


def _make_request(path: str, token: str = "", host: str = "127.0.0.1") -> SimpleNamespace:
    return SimpleNamespace(
        url=SimpleNamespace(path=path),
        headers={"X-Admin-Token": token} if token else {},
        client=SimpleNamespace(host=host),
    )


class SecurityUtilsTests(unittest.TestCase):
    def test_rate_limit_guard_skips_whitelisted_paths(self) -> None:
        ip_limits: defaultdict[str, list[float]] = defaultdict(list)
        blocked = rate_limit_guard(_make_request("/api/health"), ip_limits)
        self.assertFalse(blocked)
        self.assertEqual(ip_limits, {})

    def test_rate_limit_guard_blocks_after_limit(self) -> None:
        ip_limits: defaultdict[str, list[float]] = defaultdict(list)
        timestamps = iter([1.0, 2.0, 3.0])
        request = _make_request("/api/symbols")

        self.assertFalse(rate_limit_guard(request, ip_limits, limit=2, time_provider=lambda: next(timestamps)))
        self.assertFalse(rate_limit_guard(request, ip_limits, limit=2, time_provider=lambda: next(timestamps)))
        self.assertTrue(rate_limit_guard(request, ip_limits, limit=2, time_provider=lambda: next(timestamps)))

    def test_admin_only_rejects_invalid_token(self) -> None:
        with self.assertRaises(HTTPException) as context:
            admin_only(_make_request("/api/cache", token="wrong"), "expected")
        self.assertEqual(context.exception.status_code, 401)

    def test_validate_symbol_and_interval_returns_canonical_values(self) -> None:
        result = validate_symbol_and_interval(
            "btc",
            "1H",
            get_canonical_symbol=lambda symbol: "BTCUSD",
            symbols={"BTCUSD": object()},
            supported_intervals={"1h", "4h"},
        )
        self.assertEqual(result, ("BTCUSD", "1h"))

    def test_normalize_watchlist_symbols_deduplicates_and_tracks_invalid(self) -> None:
        valid_symbols, invalid_symbols, duplicate_count = normalize_watchlist_symbols(
            ["btc", "BTC", "eth", "ghost"],
            get_canonical_symbol=lambda symbol: {"btc": "BTCUSD", "BTC": "BTCUSD", "eth": "ETHUSD"}.get(symbol, symbol.upper()),
            symbol_registry={"BTCUSD": object(), "ETHUSD": object()},
        )
        self.assertEqual(valid_symbols, ["BTCUSD", "ETHUSD"])
        self.assertEqual(invalid_symbols, ["GHOST"])
        self.assertEqual(duplicate_count, 1)

    def test_validate_cache_target_rejects_unknown_value(self) -> None:
        with self.assertRaises(HTTPException) as context:
            validate_cache_target("mystery")
        self.assertEqual(context.exception.status_code, 400)
        self.assertIn("Unsupported cache target", context.exception.detail)


# Allow running this test module directly as a script: unittest.main() loads
# the TestCase classes defined in this module and runs them.
if __name__ == "__main__":
    unittest.main()