velai-workshop / velai /sync_utils.py
kratadata's picture
Upload folder via script
0f8b3a0 verified
import asyncio
import atexit
import functools
import inspect
import threading
from concurrent.futures import Future
from typing import Awaitable, Callable, Concatenate, Generic, Literal, Optional, ParamSpec, TypeVar
# Type variables shared by the wrappers in this module.
P = ParamSpec("P")  # parameter list of the wrapped async callable
R = TypeVar("R")  # result type obtained by awaiting the wrapped callable
T = TypeVar("T")  # type of the instance that owns a SyncMethod (its `self`)
class _BackgroundLoopRunner:
    """Execute coroutines on a private event loop running in a daemon thread.

    The thread and its loop are created lazily by the first :meth:`run` call.
    :meth:`shutdown` stops the loop. A later :meth:`run` starts a fresh
    thread/loop pair instead of submitting work to the stopped one.
    """

    def __init__(self) -> None:
        self._thread: Optional[threading.Thread] = None
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        # Set by the loop thread once start-up has finished (or failed).
        self._ready = threading.Event()
        # Serialises starting and tearing down the loop thread.
        self._lock = threading.Lock()

    def _thread_main(self) -> None:
        """Thread body: create the loop, serve it until stopped, then clean up."""
        try:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            self._loop = loop
        finally:
            # Always release _ensure_started(). If loop creation raised,
            # self._loop stays None and run() reports the failure instead of
            # the caller blocking forever.
            self._ready.set()
        try:
            loop.run_forever()
        finally:
            try:
                # Give still-scheduled tasks a chance to handle cancellation
                # before the loop is closed.
                pending = asyncio.all_tasks(loop)
                for task in pending:
                    task.cancel()
                if pending:
                    loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                loop.close()

    def _is_running(self) -> bool:
        """Return True if the loop thread is alive and its loop is not closed."""
        thread = self._thread
        loop = self._loop
        return (
            thread is not None
            and thread.is_alive()
            and loop is not None
            and not loop.is_closed()
        )

    def _ensure_started(self) -> None:
        """Start the loop thread unless a live one already exists.

        Blocks until the new thread has finished its start-up.
        """
        if self._is_running():
            return
        with self._lock:
            # Re-check: another thread may have started the loop meanwhile.
            if self._is_running():
                return
            # Drop any stale loop left by a stopped or crashed thread.
            self._loop = None
            self._ready.clear()
            thread = threading.Thread(
                target=self._thread_main, name="async_to_sync_loop", daemon=True
            )
            self._thread = thread
            thread.start()
            self._ready.wait()

    @staticmethod
    def _discard(coro: Awaitable[object]) -> None:
        """Close *coro* if it is a coroutine, so it is not reported as never awaited."""
        if inspect.iscoroutine(coro):
            coro.close()

    def run(self, coro: Awaitable[R], *, timeout: Optional[float] = None) -> R:
        """Run *coro* on the background loop and block until it finishes.

        Args:
            coro: Coroutine to execute via ``asyncio.run_coroutine_threadsafe``.
            timeout: Seconds to wait for the result. ``None`` waits indefinitely.

        Returns:
            The coroutine's result.

        Raises:
            RuntimeError: If the loop could not be started, or if called from
                the loop thread itself (blocking there would deadlock).
            concurrent.futures.TimeoutError: If *timeout* elapses. The
                coroutine is cancelled on the loop.
            Exception: Whatever the coroutine raises is re-raised here.
        """
        self._ensure_started()
        loop = self._loop
        if loop is None:
            self._discard(coro)
            raise RuntimeError("Background loop was not initialized")
        if threading.current_thread() is self._thread:
            self._discard(coro)
            raise RuntimeError(
                "Cannot block on the background loop from its own thread; "
                "await the coroutine instead."
            )
        fut: Future[R] = asyncio.run_coroutine_threadsafe(coro, loop)
        try:
            return fut.result(timeout=timeout)
        except BaseException:
            # On timeout or interruption, do not leave the coroutine running on
            # the loop. cancel() is a no-op if the future has already finished
            # (e.g. because the coroutine itself raised).
            fut.cancel()
            raise

    def shutdown(self, timeout: float = 1.0) -> None:
        """Stop the background loop and wait up to *timeout* seconds for its thread.

        Safe to call repeatedly. A subsequent :meth:`run` starts a new loop.
        """
        with self._lock:
            loop, thread = self._loop, self._thread
            self._loop = None
            self._thread = None
        if loop is None or thread is None:
            return
        try:
            loop.call_soon_threadsafe(loop.stop)
        except RuntimeError:
            # The loop is already closed (its thread has exited); nothing to stop.
            return
        # Joining the current thread would raise; skip it if shutdown is
        # invoked from the loop thread itself.
        if thread is not threading.current_thread():
            thread.join(timeout=timeout)
# Lazily created, process-wide runner used by wrappers that are not given one.
_default_runner: Optional[_BackgroundLoopRunner] = None
# Guards creation of _default_runner so concurrent first calls create only one
# runner (and register only one atexit hook).
_default_runner_lock = threading.Lock()


def _get_default_runner() -> _BackgroundLoopRunner:
    """Return the shared default runner, creating it on first use.

    The runner's ``shutdown`` is registered with :mod:`atexit` exactly once,
    at creation time.
    """
    global _default_runner
    runner = _default_runner
    if runner is None:
        with _default_runner_lock:
            # Re-check under the lock: another thread may have created it.
            runner = _default_runner
            if runner is None:
                runner = _BackgroundLoopRunner()
                atexit.register(runner.shutdown)
                _default_runner = runner
    return runner
# What a sync wrapper does when called while an event loop is already running
# in the calling thread: "thread" hands the coroutine to the background
# runner; "raise" raises RuntimeError.
InLoopPolicy = Literal["thread", "raise"]
def sync_wrap(
    fn: Callable[P, Awaitable[R]],
    *,
    in_loop: InLoopPolicy = "thread",
    runner: Optional[_BackgroundLoopRunner] = None,
    timeout: Optional[float] = None,
    use_runner: bool = False,
) -> Callable[P, R]:
    """Return a blocking wrapper around the async callable *fn*.

    Each call evaluates ``fn(*args, **kwargs)`` and drives the resulting
    coroutine to completion:

    * ``use_runner=True``: always on the background runner (*runner*, or the
      module's shared default runner).
    * No event loop running in the calling thread: via ``asyncio.run``, which
      uses a fresh event loop per call.
    * An event loop already running in the calling thread: that loop cannot be
      blocked. With ``in_loop="thread"`` the coroutine goes to the background
      runner; with ``in_loop="raise"`` a RuntimeError is raised.

    Args:
        fn: Async callable to wrap.
        in_loop: Policy for calls made from inside a running event loop.
        runner: Background runner to use instead of the shared default one.
        timeout: Maximum seconds to wait for the result. ``None`` means no limit.
        use_runner: Always use the background runner, even with no loop running.

    Raises:
        ValueError: If *in_loop* is neither ``"thread"`` nor ``"raise"``.
    """
    if in_loop not in ("thread", "raise"):
        raise ValueError(f"in_loop must be 'thread' or 'raise', got {in_loop!r}")

    @functools.wraps(fn)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        coro = fn(*args, **kwargs)
        if use_runner:
            return (runner or _get_default_runner()).run(coro, timeout=timeout)
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop in this thread, so we can own one for the duration of the call.
            if timeout is not None:
                # Honour the timeout here too; wait_for cancels coro when it expires.
                return asyncio.run(asyncio.wait_for(coro, timeout))
            return asyncio.run(coro)
        if in_loop == "raise":
            # Close the coroutine we created but will never run, so Python does
            # not emit a "coroutine ... was never awaited" warning.
            if inspect.iscoroutine(coro):
                coro.close()
            raise RuntimeError(
                "Cannot call this sync wrapper from a running event loop. Use await on the async method instead."
            )
        return (runner or _get_default_runner()).run(coro, timeout=timeout)

    # functools.wraps copies fn.__dict__, which can carry
    # __isabstractmethod__=True from @abstractmethod. The blocking wrapper
    # itself must never count as abstract.
    wrapper.__isabstractmethod__ = False
    return wrapper
class SyncMethod(Generic[T, P, R]):
    """Descriptor exposing a blocking counterpart of an async method.

    Place it in a class body alongside the async method it mirrors. Attribute
    access on an instance looks the async method up *by name* on that instance,
    so subclass overrides are honoured. It wraps the bound method with
    :func:`sync_wrap` and returns the blocking callable. With *cache* enabled,
    the wrapper is stored in the instance ``__dict__`` under this descriptor's
    attribute name. Because this is a non-data descriptor, later lookups then
    find the cached wrapper directly.
    """

    # Consulted by abc when collecting abstract methods: the sync counterpart
    # is never abstract, even if the async method it mirrors is.
    __isabstractmethod__ = False

    def __init__(
        self,
        async_method: Callable[Concatenate[T, P], Awaitable[R]],
        *,
        in_loop: InLoopPolicy = "thread",
        runner: Optional[_BackgroundLoopRunner] = None,
        timeout: Optional[float] = None,
        use_runner: bool = False,
        cache: bool = True,
    ):
        # Only the name is kept; the method is re-resolved on each instance.
        self._async_name = async_method.__name__
        self._in_loop = in_loop
        self._runner = runner
        self._timeout = timeout
        self._use_runner = use_runner
        self._cache = cache
        # Filled in by __set_name__; used as the per-instance cache key.
        self._attr_name: Optional[str] = None

    def __set_name__(self, owner, name: str) -> None:
        """Record the attribute name this descriptor is assigned to."""
        self._attr_name = name

    def __get__(self, obj: Optional[T], objtype=None) -> Callable[P, R]:
        """Return the blocking wrapper bound to *obj* (or the descriptor on class access)."""
        if obj is None:
            return self
        # Instances without a __dict__ (e.g. __slots__ classes) cannot cache.
        instance_dict = getattr(obj, "__dict__", None) if self._cache else None
        cache_key = self._attr_name
        can_cache = instance_dict is not None and cache_key is not None
        if can_cache:
            cached = instance_dict.get(cache_key)
            if cached is not None:
                return cached
        async_fn = getattr(obj, self._async_name)
        sync_fn = sync_wrap(
            async_fn,
            in_loop=self._in_loop,
            runner=self._runner,
            timeout=self._timeout,
            use_runner=self._use_runner,
        )
        try:
            sync_fn.__signature__ = inspect.signature(async_fn)
        except (TypeError, ValueError):
            # Signature is best-effort metadata; keep what functools.wraps set.
            pass
        sync_fn.__isabstractmethod__ = False
        if can_cache:
            instance_dict[cache_key] = sync_fn
        return sync_fn
def sync_method_wrapper(
    async_method: Callable[Concatenate[T, P], Awaitable[R]],
    *,
    in_loop: InLoopPolicy = "thread",
    runner: Optional[_BackgroundLoopRunner] = None,
    timeout: Optional[float] = None,
    use_runner: bool = False,
    cache: bool = True,
) -> SyncMethod[T, P, R]:
    """Build a :class:`SyncMethod` descriptor for *async_method*.

    Convenience factory: every keyword option is forwarded unchanged to the
    ``SyncMethod`` constructor.
    """
    descriptor_options = dict(
        in_loop=in_loop,
        runner=runner,
        timeout=timeout,
        use_runner=use_runner,
        cache=cache,
    )
    descriptor: SyncMethod[T, P, R] = SyncMethod(async_method, **descriptor_options)
    return descriptor