File size: 6,165 Bytes
f5bade2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
"""Utility caching primitives used across the demo application."""

from __future__ import annotations

import asyncio
import time
from dataclasses import dataclass
from threading import Lock
from typing import Awaitable, Callable, Dict, Generic, Hashable, Optional, TypeVar

T = TypeVar("T")


class CacheUnavailableError(RuntimeError):
    """Signal that a cached resource cannot be served at the moment.

    Attributes:
        retry_in: Suggested wait before retrying; negative inputs are
            clamped to ``0.0``.
    """

    def __init__(self, message: str, retry_in: float):
        RuntimeError.__init__(self, message)
        # Never expose a negative wait time to callers.
        clamped_wait = max(retry_in, 0.0)
        self.retry_in = clamped_wait


@dataclass
class _CacheRecord(Generic[T]):
    """Outcome of the most recent load for a single cache key.

    The caches in this module store either a success record (``value`` holds
    the loader result, ``error_message`` is ``None``) or a failure record
    (``value`` is ``None``, ``error_message`` holds the failure text).
    Timestamps are ``time.monotonic()`` readings.
    """

    # Loader result; the caches write None here for failure records.
    value: Optional[T]
    # Monotonic time at which the value stops being fresh; 0.0 on failure records.
    expires_at: float
    # Monotonic time until which the failure is replayed; 0.0 on success records.
    error_until: float
    # Text of the failure; None on success records.
    error_message: Optional[str]


class AsyncTTLCache(Generic[T]):
    """Async-aware TTL cache with a cooldown after failed loads.

    A successful load is served from memory for ``ttl`` seconds. A failed load
    is remembered for at least ``retry_after`` seconds; during that window
    ``get`` raises ``CacheUnavailableError`` without calling the loader again.
    Loads for the same key are serialised with a per-key ``asyncio.Lock``.

    All timestamps are ``time.monotonic()`` readings taken when the load
    finishes, so a slow loader does not shorten the TTL or the cooldown.
    """

    def __init__(self, ttl: float, retry_after: float):
        """Create an empty cache.

        Args:
            ttl: Seconds a successfully loaded value stays fresh.
            retry_after: Minimum seconds a failure is replayed before the
                loader is tried again.
        """
        self.ttl = ttl
        self.retry_after = retry_after
        self._store: Dict[Hashable, _CacheRecord[T]] = {}
        self._locks: Dict[Hashable, asyncio.Lock] = {}
        self._global_lock = asyncio.Lock()

    async def get(self, key: Hashable, loader: Callable[[], Awaitable[T]]) -> T:
        """Return the value for ``key``, awaiting ``loader`` when needed.

        Args:
            key: Cache key.
            loader: Zero-argument coroutine function producing the value.

        Returns:
            The cached value if still fresh, otherwise the freshly loaded one.
            A loader result of ``None`` is cached like any other value.

        Raises:
            CacheUnavailableError: If the key is in its failure cooldown, or
                if ``loader`` raised. Loader exceptions are chained as the
                cause.
        """
        # Fast path: no lock needed to serve a fresh entry or replay a failure.
        hit = self._lookup(key)
        if hit is not None:
            return hit.value  # type: ignore[return-value]

        lock = await self._get_lock(key)
        async with lock:
            # Another task may have finished loading while we waited.
            hit = self._lookup(key)
            if hit is not None:
                return hit.value  # type: ignore[return-value]

            try:
                value = await loader()
            except CacheUnavailableError as exc:
                # Honour the loader's own hint, but never go below our minimum.
                cooldown = max(exc.retry_in, self.retry_after)
                message = str(exc) or "Resource unavailable"
                self._remember_failure(key, message, cooldown)
                raise CacheUnavailableError(message, cooldown) from exc
            except Exception as exc:  # noqa: BLE001 - surface upstream
                message = str(exc) or "Source request failed"
                self._remember_failure(key, message, self.retry_after)
                raise CacheUnavailableError(message, self.retry_after) from exc

            self._store[key] = _CacheRecord(
                value=value,
                # Measured after the load so the full TTL is available.
                expires_at=time.monotonic() + self.ttl,
                error_until=0.0,
                error_message=None,
            )
            return value

    def _lookup(self, key: Hashable) -> Optional[_CacheRecord[T]]:
        """Return the fresh success record for ``key``, or ``None`` to reload.

        Raises:
            CacheUnavailableError: If a failure record is still in cooldown.
        """
        record = self._store.get(key)
        if record is None:
            return None
        now = time.monotonic()
        # Success records are identified by the absence of an error message,
        # not by the value, so that a cached ``None`` counts as a hit.
        if record.error_message is None:
            return record if now < record.expires_at else None
        if record.error_message and now < record.error_until:
            raise CacheUnavailableError(
                record.error_message,
                record.error_until - now,
            )
        return None

    def _remember_failure(self, key: Hashable, message: str, cooldown: float) -> None:
        """Store a failure record that is replayed for ``cooldown`` seconds."""
        self._store[key] = _CacheRecord(
            value=None,
            expires_at=0.0,
            # Measured at failure time so it matches the retry_in we raise.
            error_until=time.monotonic() + cooldown,
            error_message=message,
        )

    async def _get_lock(self, key: Hashable) -> asyncio.Lock:
        """Return the lock for ``key``, creating it on first use."""
        lock = self._locks.get(key)
        if lock is not None:
            return lock
        async with self._global_lock:
            # Re-check: the lock may have been created while we waited.
            lock = self._locks.get(key)
            if lock is None:
                lock = asyncio.Lock()
                self._locks[key] = lock
            return lock


class TTLCache(Generic[T]):
    """Synchronous TTL cache with a cooldown after failed loads.

    A successful load is served from memory for ``ttl`` seconds. A failed load
    is remembered for at least ``retry_after`` seconds; during that window
    ``get`` raises ``CacheUnavailableError`` without calling the loader again.
    Loads for the same key are serialised with a per-key ``threading.Lock``;
    loads for different keys do not block each other.

    All timestamps are ``time.monotonic()`` readings taken when the load
    finishes, so a slow loader does not shorten the TTL or the cooldown.
    """

    def __init__(self, ttl: float, retry_after: float):
        """Create an empty cache.

        Args:
            ttl: Seconds a successfully loaded value stays fresh.
            retry_after: Minimum seconds a failure is replayed before the
                loader is tried again.
        """
        self.ttl = ttl
        self.retry_after = retry_after
        self._store: Dict[Hashable, _CacheRecord[T]] = {}
        # Guards creation of per-key locks; never held while a loader runs.
        self._lock = Lock()
        self._key_locks: Dict[Hashable, Lock] = {}

    def get(self, key: Hashable, loader: Callable[[], T]) -> T:
        """Return the value for ``key``, calling ``loader`` when needed.

        Args:
            key: Cache key.
            loader: Zero-argument callable producing the value.

        Returns:
            The cached value if still fresh, otherwise the freshly loaded one.
            A loader result of ``None`` is cached like any other value.

        Raises:
            CacheUnavailableError: If the key is in its failure cooldown, or
                if ``loader`` raised. Loader exceptions are chained as the
                cause.
        """
        # Fast path: no lock needed to serve a fresh entry or replay a failure.
        hit = self._lookup(key)
        if hit is not None:
            return hit.value  # type: ignore[return-value]

        with self._get_lock(key):
            # Another thread may have finished loading while we waited.
            hit = self._lookup(key)
            if hit is not None:
                return hit.value  # type: ignore[return-value]

            try:
                value = loader()
            except CacheUnavailableError as exc:
                # Honour the loader's own hint, but never go below our minimum.
                cooldown = max(exc.retry_in, self.retry_after)
                message = str(exc) or "Resource unavailable"
                self._remember_failure(key, message, cooldown)
                raise CacheUnavailableError(message, cooldown) from exc
            except Exception as exc:  # noqa: BLE001 - propagate for visibility
                message = str(exc) or "Source request failed"
                self._remember_failure(key, message, self.retry_after)
                raise CacheUnavailableError(message, self.retry_after) from exc

            self._store[key] = _CacheRecord(
                value=value,
                # Measured after the load so the full TTL is available.
                expires_at=time.monotonic() + self.ttl,
                error_until=0.0,
                error_message=None,
            )
            return value

    def _lookup(self, key: Hashable) -> Optional[_CacheRecord[T]]:
        """Return the fresh success record for ``key``, or ``None`` to reload.

        Raises:
            CacheUnavailableError: If a failure record is still in cooldown.
        """
        record = self._store.get(key)
        if record is None:
            return None
        now = time.monotonic()
        # Success records are identified by the absence of an error message,
        # not by the value, so that a cached ``None`` counts as a hit.
        if record.error_message is None:
            return record if now < record.expires_at else None
        if record.error_message and now < record.error_until:
            raise CacheUnavailableError(
                record.error_message,
                record.error_until - now,
            )
        return None

    def _remember_failure(self, key: Hashable, message: str, cooldown: float) -> None:
        """Store a failure record that is replayed for ``cooldown`` seconds."""
        self._store[key] = _CacheRecord(
            value=None,
            expires_at=0.0,
            # Measured at failure time so it matches the retry_in we raise.
            error_until=time.monotonic() + cooldown,
            error_message=message,
        )

    def _get_lock(self, key: Hashable) -> Lock:
        """Return the lock for ``key``, creating it on first use."""
        with self._lock:
            lock = self._key_locks.get(key)
            if lock is None:
                lock = Lock()
                self._key_locks[key] = lock
            return lock