Spaces:
Running
Running
File size: 5,443 Bytes
0f8b3a0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 | import asyncio
import atexit
import functools
import inspect
import threading
from concurrent.futures import Future
from typing import Awaitable, Callable, Concatenate, Generic, Literal, Optional, ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
class _BackgroundLoopRunner:
    """Run coroutines to completion on a private event loop in a daemon thread.

    The loop and its thread are created lazily by the first :meth:`run` call and
    torn down by :meth:`shutdown`. After a shutdown, a later :meth:`run` starts
    a fresh loop.
    """

    def __init__(self) -> None:
        self._thread: Optional[threading.Thread] = None
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        # Set by the worker thread once ``self._loop`` has been assigned.
        self._ready = threading.Event()
        # Serializes start-up and shutdown so at most one worker is created at a time.
        self._lock = threading.Lock()

    def _thread_main(self) -> None:
        """Worker-thread body: create the loop, serve it, then clean it up."""
        loop = asyncio.new_event_loop()
        self._loop = loop
        asyncio.set_event_loop(loop)
        self._ready.set()
        try:
            loop.run_forever()
        finally:
            try:
                # Cancel whatever is still scheduled and let those tasks finish
                # their cancellation handling before the loop is closed.
                pending = asyncio.all_tasks(loop)
                for task in pending:
                    task.cancel()
                if pending:
                    loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                loop.close()

    def _is_running(self) -> bool:
        """Return True if a live worker thread with an open loop is available."""
        thread = self._thread
        loop = self._loop
        return (
            thread is not None
            and thread.is_alive()
            and loop is not None
            and not loop.is_closed()
        )

    def _ensure_started(self) -> None:
        """Start the worker thread unless one is already serving an open loop."""
        if self._is_running():
            return
        with self._lock:
            if self._is_running():
                return
            # Drop any stale handle left behind by a worker whose loop ended.
            self._loop = None
            self._ready.clear()
            t = threading.Thread(target=self._thread_main, name="async_to_sync_loop", daemon=True)
            self._thread = t
            t.start()
            self._ready.wait()

    def run(self, coro: Awaitable[R], *, timeout: Optional[float] = None) -> R:
        """Run *coro* on the background loop and block until it finishes.

        Args:
            coro: Coroutine handed to ``asyncio.run_coroutine_threadsafe``.
            timeout: Seconds to wait for the result; ``None`` waits indefinitely.

        Returns:
            The coroutine's result.

        Raises:
            RuntimeError: If the loop was not initialized, or if called from the
                background loop's own thread (blocking there would deadlock).
            TimeoutError: (``concurrent.futures.TimeoutError``) if *timeout*
                elapses first; the coroutine is cancelled in that case.
        """
        if threading.current_thread() is self._thread:
            # Waiting here would block the very loop that must produce the
            # result, so fail fast instead of hanging forever. Close the
            # coroutine so it is not reported as "never awaited".
            close = getattr(coro, "close", None)
            if close is not None:
                close()
            raise RuntimeError(
                "Cannot block on the background event loop from its own thread; await the coroutine instead."
            )
        self._ensure_started()
        loop = self._loop
        if loop is None:
            raise RuntimeError("Background loop was not initialized")
        fut: Future[R] = asyncio.run_coroutine_threadsafe(coro, loop)
        try:
            return fut.result(timeout=timeout)
        except BaseException:
            # Timed out or interrupted (e.g. KeyboardInterrupt): cancel the task
            # so it does not keep running unobserved on the background loop.
            # This is a no-op when the future has already finished.
            fut.cancel()
            raise

    def shutdown(self) -> None:
        """Stop the background loop and wait up to one second for its thread.

        The stored handles are cleared first, so repeated calls (for example a
        manual call followed by the ``atexit`` hook) return immediately, and a
        later :meth:`run` starts a new loop.
        """
        with self._lock:
            loop = self._loop
            thread = self._thread
            self._loop = None
            self._thread = None
        if loop is None or thread is None:
            return
        try:
            loop.call_soon_threadsafe(loop.stop)
        except RuntimeError:
            # The loop is already closed (its worker ended on its own); there is
            # nothing left to stop.
            pass
        # Joining the current thread raises, e.g. if shutdown runs on the loop.
        if thread is not threading.current_thread():
            thread.join(timeout=1.0)
_default_runner: Optional[_BackgroundLoopRunner] = None
# Guards lazy creation of ``_default_runner`` so that concurrent first calls from
# several threads share one runner and register a single atexit hook.
_default_runner_lock = threading.Lock()


def _get_default_runner() -> _BackgroundLoopRunner:
    """Return the process-wide background runner, creating it on first use.

    The runner's ``shutdown`` is registered with :mod:`atexit` exactly once, at
    the moment the runner is created.
    """
    global _default_runner
    runner = _default_runner
    if runner is not None:
        return runner
    with _default_runner_lock:
        if _default_runner is None:
            _default_runner = _BackgroundLoopRunner()
            atexit.register(_default_runner.shutdown)
        return _default_runner
# What a sync wrapper does when called while an event loop is already running in
# the calling thread: "thread" hands the coroutine to a background loop runner,
# "raise" refuses with a RuntimeError.
InLoopPolicy = Literal[
    "thread",
    "raise",
]
def sync_wrap(
    fn: Callable[P, Awaitable[R]],
    *,
    in_loop: InLoopPolicy = "thread",
    runner: Optional[_BackgroundLoopRunner] = None,
    timeout: Optional[float] = None,
    use_runner: bool = False,
) -> Callable[P, R]:
    """Turn the async callable *fn* into a blocking function with the same parameters.

    Each call creates the coroutine ``fn(*args, **kwargs)`` and drives it to
    completion:

    * ``use_runner=True``: always on *runner* (or the shared default runner).
    * No event loop running in this thread: via :func:`asyncio.run`.
    * An event loop already running: ``in_loop="raise"`` raises; otherwise the
      coroutine is handed to *runner* (or the default runner).

    Args:
        fn: Async callable to wrap.
        in_loop: Policy applied when called from inside a running event loop.
        runner: Background runner to use instead of the shared default one.
        timeout: Seconds to wait for the result on every path; ``None`` waits
            indefinitely. Expiry raises ``TimeoutError``.
        use_runner: Always use the background runner, even with no running loop.

    Returns:
        A synchronous wrapper carrying ``fn``'s metadata (``functools.wraps``)
        and explicitly marked non-abstract.

    Raises:
        RuntimeError: From the wrapper, when a loop is running and
            ``in_loop == "raise"``.
    """

    @functools.wraps(fn)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        coro = fn(*args, **kwargs)
        if use_runner:
            return (runner or _get_default_runner()).run(coro, timeout=timeout)
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop in this thread: run a fresh one, honouring ``timeout`` the
            # same way the background-runner path does.
            if timeout is not None:
                coro = asyncio.wait_for(coro, timeout)
            return asyncio.run(coro)
        if in_loop == "raise":
            # Close the never-started coroutine so Python does not emit a
            # "coroutine ... was never awaited" RuntimeWarning.
            close = getattr(coro, "close", None)
            if close is not None:
                close()
            raise RuntimeError(
                "Cannot call this sync wrapper from a running event loop. Use await on the async method instead."
            )
        # A loop is already running here and asyncio.run cannot be nested, so
        # delegate to a loop living in another thread.
        return (runner or _get_default_runner()).run(coro, timeout=timeout)

    # functools.wraps copies fn.__dict__, which may carry
    # __isabstractmethod__ = True; the wrapper itself must not look abstract.
    wrapper.__isabstractmethod__ = False
    return wrapper
class SyncMethod(Generic[T, P, R]):
    """Descriptor that exposes a blocking version of an async method.

    On instance access, the async method is looked up by name on the instance,
    so subclass overrides are honoured. It is then wrapped with ``sync_wrap``.
    With ``cache=True``, the wrapper is stored in the instance ``__dict__``
    under the descriptor's attribute name, so later lookups find it directly.
    """

    __isabstractmethod__ = False

    def __init__(
        self,
        async_method: Callable[Concatenate[T, P], Awaitable[R]],
        *,
        in_loop: InLoopPolicy = "thread",
        runner: Optional[_BackgroundLoopRunner] = None,
        timeout: Optional[float] = None,
        use_runner: bool = False,
        cache: bool = True,
    ):
        """Configure the descriptor.

        Args:
            async_method: The async method to expose. Its ``__name__`` is used
                to look it up on instances.
            in_loop, runner, timeout, use_runner: Forwarded to ``sync_wrap``.
            cache: Store the built wrapper on the instance so that it is created
                once per instance.
        """
        # Kept so it can be bound directly when the descriptor shadows it.
        self._async_method = async_method
        self._async_name = async_method.__name__
        self._in_loop = in_loop
        self._runner = runner
        self._timeout = timeout
        self._use_runner = use_runner
        self._cache = cache
        self._attr_name: Optional[str] = None

    def __set_name__(self, owner, name: str) -> None:
        self._attr_name = name

    def _bind_async(self, obj: T, objtype=None) -> Callable[P, Awaitable[R]]:
        """Return the async method bound to *obj*."""
        if self._attr_name is not None and self._attr_name == self._async_name:
            # The descriptor replaced the async method under the same name, so a
            # lookup by name would re-enter __get__ forever; bind the original.
            return self._async_method.__get__(obj, objtype)
        return getattr(obj, self._async_name)

    def __get__(self, obj: Optional[T], objtype=None) -> Callable[P, R]:
        if obj is None:
            return self
        # Instances of __slots__ classes may lack a __dict__: skip caching for
        # them rather than failing with AttributeError.
        cache_dict = None
        if self._cache and self._attr_name is not None:
            cache_dict = getattr(obj, "__dict__", None)
        if cache_dict is not None:
            cached = cache_dict.get(self._attr_name)
            if cached is not None:
                return cached
        async_fn = self._bind_async(obj, objtype)
        sync_fn = sync_wrap(
            async_fn,
            in_loop=self._in_loop,
            runner=self._runner,
            timeout=self._timeout,
            use_runner=self._use_runner,
        )
        # Report the bound async method's signature (without ``self``).
        sync_fn.__signature__ = inspect.signature(async_fn)
        sync_fn.__isabstractmethod__ = False
        if cache_dict is not None:
            cache_dict[self._attr_name] = sync_fn
        return sync_fn
def sync_method_wrapper(
    async_method: Callable[Concatenate[T, P], Awaitable[R]],
    *,
    in_loop: InLoopPolicy = "thread",
    runner: Optional[_BackgroundLoopRunner] = None,
    timeout: Optional[float] = None,
    use_runner: bool = False,
    cache: bool = True,
) -> SyncMethod[T, P, R]:
    """Build a ``SyncMethod`` descriptor for *async_method*.

    Assign the result as a class attribute; every keyword argument is
    forwarded unchanged to ``SyncMethod``.
    """
    options = {
        "in_loop": in_loop,
        "runner": runner,
        "timeout": timeout,
        "use_runner": use_runner,
        "cache": cache,
    }
    return SyncMethod(async_method, **options)
|