# velai-workshop / velai / async_utils.py
# Uploaded by kratadata ("Upload folder via script"), commit 0f8b3a0.
import asyncio
import contextvars
import functools
import inspect
from concurrent.futures import Executor
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable, Concatenate, Generic, Optional, ParamSpec, TypeVar

from nicegui import ui
# Type variables shared by the async helpers in this module:
#   P -- parameter specification of a wrapped synchronous callable
#   R -- return type of a wrapped synchronous callable
#   T -- type of the instance that owns a method wrapped by AsyncMethod
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
def async_wrap(fn: Callable[P, R], *, executor: Optional[Executor] = None) -> Callable[P, Awaitable[R]]:
    """Return a coroutine function that runs the blocking callable ``fn`` off the event loop.

    Args:
        fn: Synchronous callable to wrap. Its metadata (name, docstring, ``__wrapped__``)
            is copied onto the wrapper with ``functools.wraps``.
        executor: Executor to run ``fn`` in. ``None`` uses the running loop's default
            executor -- the same pool ``asyncio.to_thread`` uses.

    Returns:
        An async function taking ``fn``'s arguments. Awaiting it runs
        ``fn(*args, **kwargs)`` in the executor and returns its result; exceptions
        raised by ``fn`` propagate to the awaiting coroutine.

    The caller's ``contextvars`` context is copied into the worker whether or not a
    custom executor is supplied (mirroring ``asyncio.to_thread``), so context
    variables behave the same on both paths.
    """

    @functools.wraps(fn)
    async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        loop = asyncio.get_running_loop()
        # Snapshot the caller's context so context variables set on the event-loop
        # side are visible inside the worker thread; this is what to_thread does.
        ctx = contextvars.copy_context()
        call = functools.partial(ctx.run, fn, *args, **kwargs)
        return await loop.run_in_executor(executor, call)

    # functools.wraps copies fn.__dict__, which carries __isabstractmethod__=True when
    # fn was decorated with @abc.abstractmethod; force the wrapper to count as concrete.
    wrapper.__isabstractmethod__ = False
    return wrapper
class AsyncMethod(Generic[T, P, R]):
    """Descriptor that exposes an awaitable version of a synchronous method.

    Accessing the descriptor through an instance returns a coroutine function that
    runs the instance's bound synchronous method via ``async_wrap`` (in ``executor``,
    or the default executor when ``executor`` is ``None``). Accessing it through the
    class returns the descriptor itself.

    The synchronous method is normally resolved by name on the instance at access
    time, so a subclass override of that method is what gets wrapped. If that name
    resolves back to this descriptor -- i.e. the descriptor is stored under the same
    name as the method it wraps, as happens when it is used as a decorator -- the
    original function is bound directly instead of recursing forever.

    With ``cache=True`` and a known attribute name, the generated wrapper is stored in
    the instance ``__dict__`` under that name. Because this is a non-data descriptor,
    later lookups on that instance find the cached wrapper without calling
    ``__get__``. Instances without a ``__dict__`` (e.g. ``__slots__`` classes) are
    simply not cached.
    """

    # Report as concrete if inspected for abstractness (e.g. by abc.ABCMeta).
    __isabstractmethod__ = False

    def __init__(
        self, sync_method: Callable[Concatenate[T, P], R], *, executor: Optional[Executor] = None, cache: bool = True
    ):
        """Store the method to wrap and the execution/caching options.

        Args:
            sync_method: Synchronous function taking the instance as its first argument.
            executor: Executor passed to ``async_wrap``; ``None`` means the default one.
            cache: Whether to memoize the generated wrapper per instance.
        """
        self._sync_method = sync_method
        self._sync_name = sync_method.__name__
        self._executor = executor
        self._cache = cache
        # Filled in by __set_name__ when the descriptor is assigned in a class body.
        self._attr_name: Optional[str] = None

    def __set_name__(self, owner, name: str) -> None:
        """Record the attribute name this descriptor was assigned to."""
        self._attr_name = name

    def __get__(self, obj: Optional[T], objtype=None) -> Callable[P, Awaitable[R]]:
        """Return the async wrapper for ``obj``, or the descriptor itself on class access."""
        if obj is None:
            return self

        # Cache only when enabled, the attribute name is known, and obj has a __dict__.
        cache_dict = None
        if self._cache and self._attr_name is not None:
            cache_dict = getattr(obj, "__dict__", None)
        if cache_dict is not None:
            cached = cache_dict.get(self._attr_name)
            if cached is not None:
                return cached

        sync = self._bind_sync(obj, objtype)
        async_fn = async_wrap(sync, executor=self._executor)
        async_fn.__signature__ = inspect.signature(sync)
        async_fn.__isabstractmethod__ = False

        if cache_dict is not None:
            cache_dict[self._attr_name] = async_fn
        return async_fn

    def _bind_sync(self, obj: T, objtype=None) -> Callable[P, R]:
        """Return the synchronous method bound to ``obj`` without re-entering this descriptor."""
        # getattr_static does not trigger descriptors, so this probe cannot recurse.
        if inspect.getattr_static(obj, self._sync_name, None) is self:
            owner = objtype if objtype is not None else type(obj)
            return self._sync_method.__get__(obj, owner)
        return getattr(obj, self._sync_name)
def async_method_wrapper(
    sync_method: Callable[Concatenate[T, P], R],
    *,
    executor: Optional[Executor] = None,
    cache: bool = True,
) -> AsyncMethod[T, P, R]:
    """Create an ``AsyncMethod`` descriptor around ``sync_method``.

    Args:
        sync_method: Synchronous function taking the instance as its first argument.
        executor: Executor the wrapped call runs in; ``None`` selects the default one.
        cache: Whether the generated async wrapper is memoized per instance.

    Returns:
        A new ``AsyncMethod`` descriptor, meant to be assigned as a class attribute.
    """
    descriptor = AsyncMethod(sync_method, cache=cache, executor=executor)
    return descriptor
@dataclass
class AsyncDirtyTimer:
    """Throttled async callback that runs at most every `interval` seconds.

    A NiceGUI ``ui.timer`` calls ``_on_tick`` every ``interval`` seconds; the
    callback is awaited only if ``mark_dirty()`` was called since the last run.
    If the callback raises (or is cancelled), the dirty flag is restored so the
    pending work is retried on the next tick or flush instead of being dropped.

    Usage:
        saver = AsyncDirtyTimer(callback=some_async_function, interval=5.0)
        saver.mark_dirty()  # call whenever state changes
    """

    # Coroutine function awaited to persist the pending work.
    callback: Callable[[], Awaitable[None]]
    # Seconds between timer ticks; passed to ui.timer.
    interval: float = 5.0
    # Whether the timer is active right after construction.
    auto_start: bool = True
    # True while there is work that has not yet been handed to `callback`.
    _dirty: bool = field(default=False, init=False)
    # The ui.timer created in __post_init__.
    _timer: Any = field(default=None, init=False)

    def __post_init__(self) -> None:
        # One NiceGUI timer per instance; it calls `_on_tick` periodically.
        self._timer = ui.timer(self.interval, self._on_tick, active=self.auto_start)

    def mark_dirty(self) -> None:
        """Indicate that there is new work to save."""
        self._dirty = True

    def set_active(self, active: bool) -> None:
        """Enable or disable the periodic flushing."""
        if self._timer is not None:
            self._timer.active = active

    async def flush_now(self) -> None:
        """Force an immediate save (used for explicit save button etc)."""
        await self._flush_if_dirty()

    async def _on_tick(self) -> None:
        """Called by the NiceGUI timer every `interval` seconds."""
        await self._flush_if_dirty()

    async def _flush_if_dirty(self) -> None:
        """Await `callback` once if work is pending; keep it pending if the call fails."""
        if not self._dirty:
            return
        # Clear before awaiting: a concurrent flush (tick vs. explicit save) then sees
        # nothing to do, and mark_dirty() calls made during the save stay set for the
        # next run.
        self._dirty = False
        try:
            await self.callback()
        except BaseException:
            # The save did not complete, so the work is still unsaved; retry later.
            self._dirty = True
            raise