Spaces:
Running
Running
| import asyncio | |
| import atexit | |
| import functools | |
| import inspect | |
| import threading | |
| from concurrent.futures import Future | |
| from typing import Awaitable, Callable, Concatenate, Generic, Literal, Optional, ParamSpec, TypeVar | |
| P = ParamSpec("P") | |
| R = TypeVar("R") | |
| T = TypeVar("T") | |
class _BackgroundLoopRunner:
    """Own a private asyncio event loop that runs on a daemon thread.

    Synchronous callers hand a coroutine to :meth:`run`, which schedules it on
    the background loop and blocks until it finishes. The loop thread is
    started lazily on first use. It is started again if the previous thread
    has died or :meth:`shutdown` was called.
    """

    def __init__(self) -> None:
        self._thread: Optional[threading.Thread] = None
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        # Set by the loop thread once ``_loop`` is assigned (or loop creation failed).
        self._ready = threading.Event()
        # Serializes start-up and shutdown so concurrent callers share one thread.
        self._lock = threading.Lock()

    def _thread_main(self) -> None:
        """Loop-thread body: create the loop, run it until stopped, then clean up."""
        try:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            self._loop = loop
        finally:
            # Always release the starter, even if loop creation failed. Otherwise
            # ``_ensure_started`` would block on ``_ready`` forever.
            self._ready.set()
        try:
            loop.run_forever()
        finally:
            try:
                # Cancel anything still scheduled and let it unwind before closing.
                pending = asyncio.all_tasks(loop)
                for task in pending:
                    task.cancel()
                if pending:
                    loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                loop.close()

    def _ensure_started(self) -> None:
        """Start the loop thread unless a live one with a loop already exists."""
        thread = self._thread
        if thread is not None and thread.is_alive() and self._loop is not None:
            return
        with self._lock:
            thread = self._thread
            if thread is not None and thread.is_alive() and self._loop is not None:
                return
            # Forget any loop left behind by a dead thread before starting a new one.
            self._loop = None
            self._ready.clear()
            thread = threading.Thread(target=self._thread_main, name="async_to_sync_loop", daemon=True)
            self._thread = thread
            thread.start()
            self._ready.wait()

    @staticmethod
    def _discard(coro: Awaitable[object]) -> None:
        """Close a coroutine that will never be scheduled (avoids 'never awaited' warnings)."""
        close = getattr(coro, "close", None)
        if callable(close):
            close()

    def run(self, coro: Awaitable[R], *, timeout: Optional[float] = None) -> R:
        """Run *coro* on the background loop and block until it completes.

        Args:
            coro: Coroutine to schedule. ``asyncio.run_coroutine_threadsafe``
                requires a coroutine object.
            timeout: Seconds to wait for the result. ``None`` waits indefinitely.

        Returns:
            The coroutine's result.

        Raises:
            RuntimeError: If called from the loop thread itself (blocking there
                would deadlock), or if the loop could not be started.
            concurrent.futures.TimeoutError: If *timeout* elapses. The
                coroutine is cancelled in that case.
            Exception: Whatever the coroutine itself raises.
        """
        if threading.current_thread() is self._thread:
            self._discard(coro)
            raise RuntimeError(
                "Cannot block on the background loop from its own thread; await the coroutine instead."
            )
        self._ensure_started()
        loop = self._loop
        if loop is None:
            self._discard(coro)
            raise RuntimeError("Background loop was not initialized")
        try:
            fut: Future[R] = asyncio.run_coroutine_threadsafe(coro, loop)
        except RuntimeError:
            # The loop was closed by a concurrent shutdown() before scheduling.
            self._discard(coro)
            raise
        try:
            return fut.result(timeout=timeout)
        except BaseException:
            # On timeout or interruption (e.g. KeyboardInterrupt), stop the coroutine
            # instead of leaving it running unobserved. This is a no-op if it finished.
            fut.cancel()
            raise

    def shutdown(self) -> None:
        """Stop the loop thread and wait up to one second for it to exit.

        Safe to call more than once, e.g. explicitly and again from ``atexit``.
        A later :meth:`run` starts a fresh loop thread.
        """
        with self._lock:
            loop, thread = self._loop, self._thread
            self._loop = None
            self._thread = None
        if loop is None or thread is None:
            return
        try:
            loop.call_soon_threadsafe(loop.stop)
        except RuntimeError:
            # The loop is already closed (its thread exited), so there is nothing to stop.
            pass
        if thread is not threading.current_thread():
            # join() on the current thread would raise; skip it when called from the loop.
            thread.join(timeout=1.0)
_default_runner: Optional[_BackgroundLoopRunner] = None
# Guards lazy creation so concurrent first calls cannot create, and atexit-register,
# two separate runners (each with its own background thread).
_default_runner_lock = threading.Lock()


def _get_default_runner() -> _BackgroundLoopRunner:
    """Return the process-wide shared runner, creating it on first use.

    The runner's ``shutdown`` is registered with :mod:`atexit` exactly once,
    when the runner is created.
    """
    global _default_runner
    existing = _default_runner
    if existing is not None:
        return existing
    with _default_runner_lock:
        # Re-check under the lock: another thread may have created it meanwhile.
        if _default_runner is None:
            _default_runner = _BackgroundLoopRunner()
            atexit.register(_default_runner.shutdown)
        return _default_runner
# What a sync wrapper does when it is called while an event loop is already
# running in the calling thread. "raise" refuses with RuntimeError; any other
# value ("thread") hands the coroutine to a background loop runner.
InLoopPolicy = Literal["thread", "raise"]


def sync_wrap(
    fn: Callable[P, Awaitable[R]],
    *,
    in_loop: InLoopPolicy = "thread",
    runner: Optional[_BackgroundLoopRunner] = None,
    timeout: Optional[float] = None,
    use_runner: bool = False,
) -> Callable[P, R]:
    """Return a blocking wrapper around the async callable *fn*.

    Each call of the wrapper invokes ``fn(*args, **kwargs)`` and drives the
    resulting coroutine to completion:

    * ``use_runner=True``: always on the background runner (*runner*, or the
      shared default runner), even when no loop is running.
    * Otherwise, when no loop is running in the calling thread: with
      :func:`asyncio.run` on a fresh loop.
    * Otherwise (called from inside a running loop): ``in_loop="raise"`` raises
      ``RuntimeError`` without calling *fn*. Any other policy falls back to
      the background runner.

    Args:
        fn: Async callable to wrap.
        in_loop: Policy used when called from inside a running event loop.
        runner: Runner to use instead of the shared default one.
        timeout: Seconds to wait for the result, applied on every path.
            ``None`` means no limit. On timeout a ``TimeoutError`` is raised
            (``asyncio.TimeoutError`` on the ``asyncio.run`` path and
            ``concurrent.futures.TimeoutError`` on the runner path; these are
            the same class on Python 3.11+).
        use_runner: Always use the background runner.

    Returns:
        A callable taking *fn*'s parameters and returning the awaited result.
        Its ``__name__``/``__doc__``/``__wrapped__`` come from *fn*.
    """

    def run_on_runner(coro: Awaitable[R]) -> R:
        # Resolved per call, so wrapping alone never creates the default runner.
        active = runner or _get_default_runner()
        return active.run(coro, timeout=timeout)

    # updated=() avoids copying fn.__dict__ (e.g. abstract/coroutine markers) onto the wrapper.
    @functools.wraps(fn, updated=())
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        if use_runner:
            return run_on_runner(fn(*args, **kwargs))
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop in this thread: run the coroutine on a fresh one.
            main = fn(*args, **kwargs)
            if timeout is not None:
                # Honour ``timeout`` here too, matching the background-runner path.
                main = asyncio.wait_for(main, timeout)
            return asyncio.run(main)
        if in_loop == "raise":
            # Checked before calling ``fn`` so no coroutine is created and then
            # dropped (which would emit a "coroutine ... was never awaited" warning).
            raise RuntimeError(
                "Cannot call this sync wrapper from a running event loop. Use await on the async method instead."
            )
        return run_on_runner(fn(*args, **kwargs))

    wrapper.__isabstractmethod__ = False
    return wrapper
class SyncMethod(Generic[T, P, R]):
    """Descriptor that exposes a blocking counterpart of an async method.

    On instance access it looks up the async method by name on that instance
    (so overrides in subclasses are picked up). It wraps that method with
    ``sync_wrap``, using the options given here, and returns the wrapper.

    With ``cache=True`` and a known attribute name (recorded by
    ``__set_name__`` when the descriptor is assigned in a class body), the
    wrapper is stored in the instance ``__dict__``. Because this descriptor
    defines only ``__get__``, that entry shadows it on later lookups.

    Instances without a writable dict ``__dict__`` (e.g. ``__slots__``
    classes) get a fresh, uncached wrapper instead of an error.
    """

    # Explicitly report the descriptor as concrete to ABCMeta's abstract-method scan.
    __isabstractmethod__ = False

    def __init__(
        self,
        async_method: Callable[Concatenate[T, P], Awaitable[R]],
        *,
        in_loop: InLoopPolicy = "thread",
        runner: Optional[_BackgroundLoopRunner] = None,
        timeout: Optional[float] = None,
        use_runner: bool = False,
        cache: bool = True,
    ):
        # Only the name is kept: the method is re-resolved on each instance.
        self._async_name = async_method.__name__
        self._in_loop = in_loop
        self._runner = runner
        self._timeout = timeout
        self._use_runner = use_runner
        self._cache = cache
        # Filled in by __set_name__; stays None when not assigned in a class body.
        self._attr_name: Optional[str] = None

    def __set_name__(self, owner, name: str) -> None:
        self._attr_name = name

    def __get__(self, obj: Optional[T], objtype=None) -> Callable[P, R]:
        if obj is None:
            # Class-level access returns the descriptor itself.
            return self
        store = self._instance_cache(obj)
        if store is not None:
            cached = store.get(self._attr_name)
            if cached is not None:
                return cached
        async_fn = getattr(obj, self._async_name)
        sync_fn = sync_wrap(
            async_fn,
            in_loop=self._in_loop,
            runner=self._runner,
            timeout=self._timeout,
            use_runner=self._use_runner,
        )
        # Advertise the bound async method's signature (no ``self``) on the wrapper.
        sync_fn.__signature__ = inspect.signature(async_fn)
        sync_fn.__isabstractmethod__ = False
        if store is not None:
            store[self._attr_name] = sync_fn
        return sync_fn

    def _instance_cache(self, obj: object) -> Optional[dict]:
        """Return *obj*'s writable ``__dict__`` when caching applies, else ``None``."""
        if not self._cache or self._attr_name is None:
            return None
        # __slots__ instances have no __dict__; class objects expose a read-only mappingproxy.
        store = getattr(obj, "__dict__", None)
        return store if isinstance(store, dict) else None
def sync_method_wrapper(
    async_method: Callable[Concatenate[T, P], Awaitable[R]],
    *,
    in_loop: InLoopPolicy = "thread",
    runner: Optional[_BackgroundLoopRunner] = None,
    timeout: Optional[float] = None,
    use_runner: bool = False,
    cache: bool = True,
) -> SyncMethod[T, P, R]:
    """Build a :class:`SyncMethod` descriptor for *async_method*.

    A thin factory: every keyword option is forwarded unchanged to
    :class:`SyncMethod`. Assign the result in a class body so that
    ``__set_name__`` records the attribute name used for per-instance caching.
    """
    options = dict(
        in_loop=in_loop,
        runner=runner,
        timeout=timeout,
        use_runner=use_runner,
        cache=cache,
    )
    return SyncMethod(async_method, **options)