Miosy's picture
Upload folder using huggingface_hub
2143587 verified
import asyncio
import dataclasses
import inspect
import sys
from functools import _CacheInfo, _make_key, partial, partialmethod
from typing import (
Any,
Callable,
Coroutine,
Generic,
Hashable,
List,
Optional,
OrderedDict,
Type,
TypedDict,
TypeVar,
Union,
cast,
final,
overload,
)
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
if sys.version_info < (3, 14):
from asyncio.coroutines import _is_coroutine # type: ignore[attr-defined]
# Version string of this module.
__version__ = "2.1.0"
# Only the decorator itself is part of the star-import surface.
__all__ = ("alru_cache",)
# Generic type variables used throughout the annotations in this module.
_T = TypeVar("_T")
_R = TypeVar("_R")
# A coroutine object whose awaited result is `_R`.
_Coro = Coroutine[Any, Any, _R]
# A callable taking arbitrary arguments and returning such a coroutine.
_CB = Callable[..., _Coro[_R]]
# Everything the decorator accepts: a plain coroutine callable, or a
# functools.partial / functools.partialmethod wrapping one.
_CBP = Union[_CB[_R], "partial[_Coro[_R]]", "partialmethod[_Coro[_R]]"]
@final
class _CacheParameters(TypedDict):
    """Shape of the mapping returned by ``cache_parameters()``."""

    # Whether argument types take part in the cache key.
    typed: bool
    # Maximum number of entries; ``None`` disables the size limit.
    maxsize: Optional[int]
    # Number of cached tasks that have not finished yet.
    tasks: int
    # Whether the cache has been closed.
    closed: bool
@final
@dataclasses.dataclass
class _CacheItem(Generic[_R]):
    """One cache slot: the task producing the value plus its bookkeeping."""

    # Task that computes (or has computed) the cached result.
    task: "asyncio.Task[_R]"
    # Scheduled expiry callback for this entry, or ``None`` if none is pending.
    later_call: Optional[asyncio.Handle]
    # Number of callers currently awaiting ``task``.
    waiters: int

    def cancel(self) -> None:
        """Cancel the pending expiry callback, if any, and forget it."""
        handle, self.later_call = self.later_call, None
        if handle is not None:
            handle.cancel()
@final
class _LRUCacheWrapper(Generic[_R]):
    """Callable wrapper that caches the tasks of a coroutine function.

    Each distinct argument key maps to a ``_CacheItem`` holding the
    ``asyncio.Task`` that computes the value. Entries are kept in LRU
    order, optionally bounded by ``maxsize`` and expired after ``ttl``
    seconds once their task has completed successfully.
    """

    def __init__(
        self,
        fn: _CB[_R],
        maxsize: Optional[int],
        typed: bool,
        ttl: Optional[float],
    ) -> None:
        """Wrap ``fn``.

        Args:
            fn: Coroutine function whose results are cached.
            maxsize: Maximum number of entries, or ``None`` for no limit.
            typed: Passed to ``functools._make_key`` when building keys.
            ttl: Seconds a successful result stays cached, or ``None``
                to never expire.
        """
        # Copy the usual function metadata; any attribute the wrapped
        # object lacks is simply skipped.
        for attr in (
            "__module__",
            "__name__",
            "__qualname__",
            "__doc__",
            "__annotations__",
        ):
            try:
                setattr(self, attr, getattr(fn, attr))
            except AttributeError:
                pass
        try:
            self.__dict__.update(fn.__dict__)
        except AttributeError:
            pass
        # set __wrapped__ last so we don't inadvertently copy it
        # from the wrapped function when updating __dict__
        if sys.version_info < (3, 14):
            self._is_coroutine = _is_coroutine
        self.__wrapped__ = fn
        self.__maxsize = maxsize
        self.__typed = typed
        self.__ttl = ttl
        self.__cache: OrderedDict[Hashable, _CacheItem[_R]] = OrderedDict()
        self.__closed = False
        self.__hits = 0
        self.__misses = 0

    @property
    def __tasks(self) -> List["asyncio.Task[_R]"]:
        """Return the distinct cached tasks that have not finished yet."""
        # NOTE: I don't think we need to form a set first here but not too sure we want it for guarantees
        return list(
            {
                cache_item.task
                for cache_item in self.__cache.values()
                if not cache_item.task.done()
            }
        )

    def cache_invalidate(self, /, *args: Hashable, **kwargs: Any) -> bool:
        """Drop the entry for the given call arguments.

        Returns:
            ``True`` if an entry existed and was removed, else ``False``.
        """
        key = _make_key(args, kwargs, self.__typed)
        cache_item = self.__cache.pop(key, None)
        if cache_item is None:
            return False
        # Stop the expiry timer of the removed entry.
        cache_item.cancel()
        return True

    def cache_clear(self) -> None:
        """Remove every entry and reset the hit/miss counters."""
        self.__hits = 0
        self.__misses = 0
        for cache_item in self.__cache.values():
            # Cancel pending expiry timers so they cannot fire later.
            cache_item.cancel()
        self.__cache.clear()

    async def cache_close(self, *, wait: bool = False) -> None:
        """Mark the cache closed and settle the unfinished tasks.

        Args:
            wait: When false (the default) unfinished tasks are cancelled
                before being awaited; when true they are awaited as-is.
        """
        self.__closed = True
        tasks = self.__tasks
        if not tasks:
            return
        if not wait:
            for task in tasks:
                if not task.done():
                    task.cancel()
        # return_exceptions=True: failures and cancellations of the cached
        # tasks are collected rather than raised here.
        await asyncio.gather(*tasks, return_exceptions=True)

    def cache_info(self) -> _CacheInfo:
        """Return ``(hits, misses, maxsize, currsize)`` statistics."""
        return _CacheInfo(
            self.__hits,
            self.__misses,
            self.__maxsize,
            len(self.__cache),
        )

    def cache_parameters(self) -> _CacheParameters:
        """Return the configuration and current state of the cache."""
        return _CacheParameters(
            maxsize=self.__maxsize,
            typed=self.__typed,
            tasks=len(self.__tasks),
            closed=self.__closed,
        )

    def _cache_hit(self, key: Hashable) -> None:
        """Count a hit and mark ``key`` as most recently used."""
        self.__hits += 1
        self.__cache.move_to_end(key)

    def _cache_miss(self, key: Hashable) -> None:
        """Count a miss."""
        self.__misses += 1

    def _task_done_callback(self, key: Hashable, task: "asyncio.Task[_R]") -> None:
        """Finalize the entry for ``key`` once ``task`` has completed.

        Failed or cancelled tasks are removed from the cache; successful
        ones get an expiry timer when a TTL is configured.
        """
        cache_item = self.__cache.get(key)
        # The entry may have been invalidated, cleared or evicted while the
        # task was running, and a newer task may now own this key. A stale
        # task must neither evict nor start the TTL of somebody else's entry.
        if cache_item is None or cache_item.task is not task:
            return
        # We must use the private attribute instead of `exception()`
        # so asyncio does not set `task.__log_traceback = False` on
        # the false assumption that the caller read the task Exception
        if task.cancelled() or task._exception is not None:
            self.__cache.pop(key, None)
            return
        if self.__ttl is not None:
            loop = asyncio.get_running_loop()
            cache_item.later_call = loop.call_later(
                self.__ttl, self.__cache.pop, key, None
            )

    async def _shield_and_handle_cancelled_error(
        self, cache_item: _CacheItem[_T], key: Hashable
    ) -> _T:
        """Await the entry's task on behalf of one waiter.

        The caller must already have counted itself in
        ``cache_item.waiters``; the count is decremented here on exit.
        """
        task = cache_item.task
        try:
            # All waiters await the same shielded task.
            return await asyncio.shield(task)
        except asyncio.CancelledError:
            # If this is the last waiter and the underlying task is not done,
            # cancel the underlying task and remove the cache entry.
            if cache_item.waiters == 1 and not task.done():
                cache_item.cancel()  # Cancel TTL expiration
                task.cancel()  # Cancel the running coroutine
                self.__cache.pop(key, None)  # Remove from cache
            raise
        finally:
            # Each logical waiter decrements waiters on exit (normal or cancelled).
            cache_item.waiters -= 1

    async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
        """Return the cached result for these arguments, computing it on a miss.

        Raises:
            RuntimeError: If the cache has been closed.
        """
        if self.__closed:
            raise RuntimeError(f"alru_cache is closed for {self}")
        loop = asyncio.get_running_loop()
        key = _make_key(fn_args, fn_kwargs, self.__typed)
        cache_item = self.__cache.get(key)
        if cache_item is not None:
            self._cache_hit(key)
            if not cache_item.task.done():
                # Each logical waiter increments waiters on entry.
                cache_item.waiters += 1
                return await self._shield_and_handle_cancelled_error(cache_item, key)
            # If the task is already done, just return the result.
            return cache_item.task.result()

        # Miss: start the computation as a task and register this caller
        # as its first waiter.
        coro = self.__wrapped__(*fn_args, **fn_kwargs)
        task: asyncio.Task[_R] = loop.create_task(coro)
        task.add_done_callback(partial(self._task_done_callback, key))
        cache_item = _CacheItem(task, None, 1)
        self.__cache[key] = cache_item
        if self.__maxsize is not None and len(self.__cache) > self.__maxsize:
            # Over capacity: drop the least recently used entry and stop
            # its expiry timer.
            _, dropped_cache_item = self.__cache.popitem(last=False)
            dropped_cache_item.cancel()
        self._cache_miss(key)
        return await self._shield_and_handle_cancelled_error(cache_item, key)

    def __get__(
        self, instance: _T, owner: Optional[Type[_T]]
    ) -> Union[Self, "_LRUCacheWrapperInstanceMethod[_R, _T]"]:
        """Descriptor hook that binds the cache to ``instance``.

        Attribute access through the class passes ``instance=None``; in
        that case the unbound wrapper itself is returned so that
        ``Cls.method(obj, ...)`` and ``Cls.method.cache_invalidate(obj, ...)``
        address the same keys as ``obj.method(...)``. Otherwise a proxy
        that prepends ``instance`` to every call is returned.
        """
        if instance is None:
            return self
        return _LRUCacheWrapperInstanceMethod(self, instance)
@final
class _LRUCacheWrapperInstanceMethod(Generic[_R, _T]):
    """Proxy binding a ``_LRUCacheWrapper`` to one instance.

    Calls and ``cache_invalidate`` prepend the bound instance to the
    arguments; every other cache operation is delegated unchanged to the
    underlying wrapper.
    """

    def __init__(
        self,
        wrapper: _LRUCacheWrapper[_R],
        instance: _T,
    ) -> None:
        """Bind ``wrapper`` to ``instance``.

        Args:
            wrapper: The cache wrapper to delegate to.
            instance: Object passed as first positional argument on calls.
        """
        # Copy the usual function metadata; any attribute the wrapper
        # lacks is simply skipped.
        for attr in (
            "__module__",
            "__name__",
            "__qualname__",
            "__doc__",
            "__annotations__",
        ):
            try:
                setattr(self, attr, getattr(wrapper, attr))
            except AttributeError:
                pass
        try:
            self.__dict__.update(wrapper.__dict__)
        except AttributeError:
            pass
        # set __wrapped__ last so we don't inadvertently copy it
        # from the wrapped function when updating __dict__
        if sys.version_info < (3, 14):
            self._is_coroutine = _is_coroutine
        self.__wrapped__ = wrapper.__wrapped__
        self.__instance = instance
        self.__wrapper = wrapper

    def cache_invalidate(self, /, *args: Hashable, **kwargs: Any) -> bool:
        """Drop the entry for ``(instance, *args, **kwargs)``.

        Returns:
            Whatever the underlying wrapper's ``cache_invalidate`` returns.
        """
        return self.__wrapper.cache_invalidate(self.__instance, *args, **kwargs)

    def cache_clear(self) -> None:
        """Clear the underlying wrapper's cache."""
        self.__wrapper.cache_clear()

    async def cache_close(
        self,
        *,
        wait: bool = False,
        cancel: bool = False,
        return_exceptions: bool = True,
    ) -> None:
        """Close the underlying wrapper's cache.

        Args:
            wait: Forwarded to the wrapper's ``cache_close``.
            cancel: Accepted for backward compatibility; has no effect.
            return_exceptions: Accepted for backward compatibility; has
                no effect.
        """
        await self.__wrapper.cache_close(wait=wait)

    def cache_info(self) -> _CacheInfo:
        """Return the underlying wrapper's cache statistics."""
        return self.__wrapper.cache_info()

    def cache_parameters(self) -> _CacheParameters:
        """Return the underlying wrapper's cache parameters."""
        return self.__wrapper.cache_parameters()

    async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
        """Call the wrapper with the bound instance as first argument."""
        return await self.__wrapper(self.__instance, *fn_args, **fn_kwargs)
def _make_wrapper(
    maxsize: Optional[int],
    typed: bool,
    ttl: Optional[float] = None,
) -> Callable[[_CBP[_R]], _LRUCacheWrapper[_R]]:
    """Build the decorator that wraps a coroutine function in a cache.

    The returned decorator raises ``RuntimeError`` when the decorated
    object, after unwrapping any ``partial``/``partialmethod`` layers, is
    not a coroutine function.
    """

    def wrapper(fn: _CBP[_R]) -> _LRUCacheWrapper[_R]:
        # Peel off partial/partialmethod layers to reach the function
        # that is actually going to be called.
        target = fn
        while isinstance(target, (partial, partialmethod)):
            target = target.func

        if not inspect.iscoroutinefunction(target):
            raise RuntimeError(f"Coroutine function is required, got {fn!r}")

        # functools.partialmethod support
        if hasattr(fn, "_make_unbound_method"):
            fn = fn._make_unbound_method()

        cached = _LRUCacheWrapper(cast(_CB[_R], fn), maxsize, typed, ttl)
        if sys.version_info < (3, 12):
            return cached
        # From 3.12 on, flag the wrapper object as a coroutine function.
        return inspect.markcoroutinefunction(cached)

    return wrapper
@overload
def alru_cache(
    maxsize: Optional[int] = 128,
    typed: bool = False,
    *,
    ttl: Optional[float] = None,
) -> Callable[[_CBP[_R]], _LRUCacheWrapper[_R]]:
    ...


@overload
def alru_cache(
    maxsize: _CBP[_R],
    /,
) -> _LRUCacheWrapper[_R]:
    ...


def alru_cache(
    maxsize: Union[Optional[int], _CBP[_R]] = 128,
    typed: bool = False,
    *,
    ttl: Optional[float] = None,
) -> Union[Callable[[_CBP[_R]], _LRUCacheWrapper[_R]], _LRUCacheWrapper[_R]]:
    """Cache the results of a coroutine function in an LRU cache.

    Usable both with arguments (``@alru_cache(maxsize=..., typed=...,
    ttl=...)``) and bare (``@alru_cache``). In the bare form the first
    positional argument is the function itself and the cache is built
    with ``maxsize=128``, ``typed=False`` and no TTL.

    Raises:
        NotImplementedError: If the first argument is neither ``None``,
            an ``int``, a callable, nor an object exposing
            ``_make_unbound_method``.
    """
    # Parametrised form: hand back a decorator.
    if maxsize is None or isinstance(maxsize, int):
        return _make_wrapper(maxsize, typed, ttl)

    # Bare form: the "maxsize" slot actually holds the decorated object.
    fn = cast(_CB[_R], maxsize)
    if not (callable(fn) or hasattr(fn, "_make_unbound_method")):
        raise NotImplementedError(f"{fn!r} decorating is not supported")
    return _make_wrapper(128, False, None)(fn)