Spaces:
Paused
Paused
| import asyncio | |
| import asyncio.coroutines | |
| import contextvars | |
| import functools | |
| import inspect | |
| import os | |
| import sys | |
| import threading | |
| import warnings | |
| import weakref | |
| from concurrent.futures import Future, ThreadPoolExecutor | |
| from typing import ( | |
| TYPE_CHECKING, | |
| Any, | |
| Awaitable, | |
| Callable, | |
| Coroutine, | |
| Dict, | |
| Generic, | |
| List, | |
| Optional, | |
| TypeVar, | |
| Union, | |
| overload, | |
| ) | |
| from .current_thread_executor import CurrentThreadExecutor | |
| from .local import Local | |
| if sys.version_info >= (3, 10): | |
| from typing import ParamSpec | |
| else: | |
| from typing_extensions import ParamSpec | |
| if TYPE_CHECKING: | |
| # This is not available to import at runtime | |
| from _typeshed import OptExcInfo | |
# Type variable bound to "any callable"; used to type the
# markcoroutinefunction shim as returning the same callable it was given.
_F = TypeVar("_F", bound=Callable[..., Any])
# Parameter specification and return type of the callable wrapped by
# AsyncToSync / SyncToAsync (both classes are Generic[_P, _R]).
_P = ParamSpec("_P")
_R = TypeVar("_R")
def _restore_context(context: contextvars.Context) -> None:
    """Copy every variable recorded in ``context`` into the current context.

    For each context variable present in ``context``, the variable is set in
    the currently active context unless it already holds that exact object.
    Variables that have no value in the current context are always set.

    Values are compared by identity rather than with ``!=``: equality would
    invoke arbitrary user-defined comparison code, which may raise or return
    a non-boolean result, and would leave an equal-but-distinct object
    unrestored.

    Args:
        context: The context snapshot whose values should be made visible to
            downstream consumers running in the current context.
    """
    for cvar, cvalue in context.items():
        try:
            already_current = cvar.get() is cvalue
        except LookupError:
            # The variable has no value (and no default) in this context.
            already_current = False
        if not already_current:
            cvar.set(cvalue)
# Python 3.12 deprecates asyncio.iscoroutinefunction() as an alias for
# inspect.iscoroutinefunction(), whilst also removing the _is_coroutine marker.
# The latter is replaced with the inspect.markcoroutinefunction decorator.
# Until 3.12 is the minimum supported Python version, provide a shim.
if hasattr(inspect, "markcoroutinefunction"):
    # The inspect module provides the marker decorator: use the inspect
    # implementations for both detection and marking.
    iscoroutinefunction = inspect.iscoroutinefunction
    markcoroutinefunction: Callable[[_F], _F] = inspect.markcoroutinefunction
else:
    # No inspect.markcoroutinefunction available: fall back to asyncio's
    # detection and emulate the decorator with the _is_coroutine marker.
    iscoroutinefunction = asyncio.iscoroutinefunction  # type: ignore[assignment]

    def markcoroutinefunction(func: _F) -> _F:
        """Tag ``func`` with asyncio's coroutine marker and return it."""
        func._is_coroutine = asyncio.coroutines._is_coroutine  # type: ignore
        return func
class ThreadSensitiveContext:
    """Async context manager to manage context for thread sensitive mode

    This context manager controls which thread pool executor is used when in
    thread sensitive mode. By default, a single thread pool executor is shared
    within a process.

    The ThreadSensitiveContext() context manager may be used to specify a
    thread pool per context.

    This context manager is re-entrant, so only the outer-most call to
    ThreadSensitiveContext will set the context.

    Usage:

    >>> import time
    >>> async with ThreadSensitiveContext():
    ...     await sync_to_async(time.sleep, 1)()
    """

    def __init__(self):
        # Token returned by ContextVar.set(); only held while this instance is
        # the one that installed itself as the active context.
        self.token = None

    async def __aenter__(self):
        try:
            SyncToAsync.thread_sensitive_context.get()
        except LookupError:
            # No context is active yet, so this instance becomes the active
            # one and remembers the token needed to undo that on exit.
            self.token = SyncToAsync.thread_sensitive_context.set(self)

        return self

    async def __aexit__(self, exc, value, tb):
        if self.token is None:
            # An outer ThreadSensitiveContext owns the context; nothing to undo.
            return

        # A contextvars token may only be used once. Clear the stored token
        # before resetting so that a later exit of this same instance (for
        # example when it is re-used inside another active context) does not
        # try to reset with the spent token.
        token, self.token = self.token, None
        try:
            executor = SyncToAsync.context_to_thread_executor.pop(self, None)
            if executor is not None:
                executor.shutdown()
        finally:
            # Always restore the context variable, even if shutting down the
            # executor raised.
            SyncToAsync.thread_sensitive_context.reset(token)
class AsyncToSync(Generic[_P, _R]):
    """
    Utility class which turns an awaitable that only works on the thread with
    the event loop into a synchronous callable that works in a subthread.

    If the call stack contains an async loop, the code runs there.
    Otherwise, the code runs in a new loop in a new thread.

    Either way, this thread then pauses and waits to run any thread_sensitive
    code called from further down the call stack using SyncToAsync, before
    finally exiting once the async task returns.
    """

    # Keeps a reference to the CurrentThreadExecutor in local context, so that
    # any sync_to_async inside the wrapped code can find it.
    executors: "Local" = Local()

    # When we can't find a CurrentThreadExecutor from the context, such as
    # inside create_task, we'll look it up here from the running event loop.
    loop_thread_executors: "Dict[asyncio.AbstractEventLoop, CurrentThreadExecutor]" = {}

    def __init__(
        self,
        awaitable: Union[
            Callable[_P, Coroutine[Any, Any, _R]],
            Callable[_P, Awaitable[_R]],
        ],
        force_new_loop: bool = False,
    ):
        """
        Store the callable to wrap and remember the event loop that is running
        in the constructing thread, if there is one.

        A warning (not an error) is issued when ``awaitable`` is not callable
        or is not detected as a coroutine function.
        """
        if not callable(awaitable) or (
            not iscoroutinefunction(awaitable)
            and not iscoroutinefunction(getattr(awaitable, "__call__", awaitable))
        ):
            # Python does not have very reliable detection of async functions
            # (lots of false negatives) so this is just a warning.
            warnings.warn(
                "async_to_sync was passed a non-async-marked callable", stacklevel=2
            )

        self.awaitable = awaitable
        # Mirror the bound instance of the wrapped callable, when it has one.
        try:
            self.__self__ = self.awaitable.__self__  # type: ignore[union-attr]
        except AttributeError:
            pass
        self.force_new_loop = force_new_loop
        self.main_event_loop = None
        try:
            self.main_event_loop = asyncio.get_running_loop()
        except RuntimeError:
            # There's no event loop in this thread.
            pass

    def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
        """
        Run the wrapped awaitable to completion, blocking the calling thread,
        and return its result.

        Raises RuntimeError when called from a thread that has a running event
        loop. An exception raised by the awaitable is stored on the result
        future by main_wrap and re-raised here by ``call_result.result()``.
        """
        __traceback_hide__ = True  # noqa: F841

        # NOTE(review): the loop found below is stored on the instance, so it
        # is re-used by later calls of this same AsyncToSync object.
        if not self.force_new_loop and not self.main_event_loop:
            # There's no event loop in this thread. Look for the threadlocal if
            # we're inside SyncToAsync
            main_event_loop_pid = getattr(
                SyncToAsync.threadlocal, "main_event_loop_pid", None
            )
            # We make sure the parent loop is from the same process - if
            # they've forked, this is not going to be valid any more (#194)
            if main_event_loop_pid and main_event_loop_pid == os.getpid():
                self.main_event_loop = getattr(
                    SyncToAsync.threadlocal, "main_event_loop", None
                )

        # You can't call AsyncToSync from a thread with a running event loop
        try:
            event_loop = asyncio.get_running_loop()
        except RuntimeError:
            pass
        else:
            if event_loop.is_running():
                raise RuntimeError(
                    "You cannot use AsyncToSync in the same thread as an async event loop - "
                    "just await the async function directly."
                )

        # Make a future for the return information
        call_result: "Future[_R]" = Future()

        # Make a CurrentThreadExecutor we'll use to idle in this thread - we
        # need one for every sync frame, even if there's one above us in the
        # same thread.
        old_executor = getattr(self.executors, "current", None)
        current_executor = CurrentThreadExecutor()
        self.executors.current = current_executor

        # Wrapping context in list so it can be reassigned from within
        # `main_wrap`.
        context = [contextvars.copy_context()]

        # Get task context so that parent task knows which task to propagate
        # an asyncio.CancelledError to.
        task_context = getattr(SyncToAsync.threadlocal, "task_context", None)

        loop = None
        # Use call_soon_threadsafe to schedule a synchronous callback on the
        # main event loop's thread if it's there, otherwise make a new loop
        # in this thread.
        try:
            # Calling main_wrap only creates the coroutine object; it starts
            # running once one of the branches below hands it to a loop.
            awaitable = self.main_wrap(
                call_result,
                sys.exc_info(),
                task_context,
                context,
                *args,
                **kwargs,
            )

            if not (self.main_event_loop and self.main_event_loop.is_running()):
                # Make our own event loop - in a new thread - and run inside that.
                loop = asyncio.new_event_loop()
                # Register the executor under the new loop so SyncToAsync can
                # look it up by running loop (see loop_thread_executors).
                self.loop_thread_executors[loop] = current_executor
                # NOTE(review): this single-worker pool is not explicitly shut
                # down in this method - confirm that is intentional.
                loop_executor = ThreadPoolExecutor(max_workers=1)
                loop_future = loop_executor.submit(
                    self._run_event_loop, loop, awaitable
                )
                if current_executor:
                    # Run the CurrentThreadExecutor until the future is done
                    current_executor.run_until_future(loop_future)
                # Wait for future and/or allow for exception propagation
                loop_future.result()
            else:
                # Call it inside the existing loop
                self.main_event_loop.call_soon_threadsafe(
                    self.main_event_loop.create_task, awaitable
                )
                if current_executor:
                    # Run the CurrentThreadExecutor until the future is done
                    current_executor.run_until_future(call_result)
        finally:
            # Clean up any executor we were running
            if loop is not None:
                del self.loop_thread_executors[loop]
            # context[0] is replaced by main_wrap with a snapshot taken after
            # the awaitable ran; apply it to this thread's context.
            _restore_context(context[0])
            # Restore old current thread executor state
            self.executors.current = old_executor

        # Wait for results from the future.
        return call_result.result()

    def _run_event_loop(self, loop, coro):
        """
        Runs the given event loop (designed to be called in a thread).

        Runs ``coro`` to completion on ``loop``, then cancels and awaits any
        tasks still registered on the loop, shuts down async generators and
        closes the loop.
        """
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(coro)
        finally:
            try:
                # mimic asyncio.run() behavior
                # cancel every task still registered on the loop
                tasks = asyncio.all_tasks(loop)
                for task in tasks:
                    task.cancel()

                async def gather():
                    # return_exceptions=True keeps one failing task from
                    # aborting the wait for the others.
                    await asyncio.gather(*tasks, return_exceptions=True)

                loop.run_until_complete(gather())
                # Report tasks that finished with an exception rather than
                # being cancelled.
                for task in tasks:
                    if task.cancelled():
                        continue
                    if task.exception() is not None:
                        loop.call_exception_handler(
                            {
                                "message": "unhandled exception during loop shutdown",
                                "exception": task.exception(),
                                "task": task,
                            }
                        )
                if hasattr(loop, "shutdown_asyncgens"):
                    loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                loop.close()
                # Set the stored main loop (possibly None) as the current
                # event loop of the thread running this method.
                asyncio.set_event_loop(self.main_event_loop)

    def __get__(self, parent: Any, objtype: Any) -> Callable[_P, _R]:
        """
        Include self for methods
        """
        # Bind ``parent`` as the first positional argument so the wrapped
        # awaitable receives it as ``self``.
        # NOTE(review): ``parent`` is bound even when it is None (attribute
        # accessed on the class) - confirm callers never rely on that path.
        func = functools.partial(self.__call__, parent)
        return functools.update_wrapper(func, self.awaitable)

    async def main_wrap(
        self,
        call_result: "Future[_R]",
        exc_info: "OptExcInfo",
        task_context: "Optional[List[asyncio.Task[Any]]]",
        context: List[contextvars.Context],
        *args: _P.args,
        **kwargs: _P.kwargs,
    ) -> None:
        """
        Wraps the awaitable with something that puts the result into the
        result/exception future.

        ``exc_info`` is the caller's ``sys.exc_info()``; ``task_context`` is a
        list in which the running task is registered for the duration of the
        call; ``context`` is a one-element list whose entry is applied before
        the call and replaced with a fresh snapshot afterwards.
        """

        __traceback_hide__ = True  # noqa: F841

        if context is not None:
            _restore_context(context[0])

        # Register the current task in task_context while the awaitable runs;
        # SyncToAsync.__call__ reads the first entry to cancel it.
        current_task = asyncio.current_task()
        if current_task is not None and task_context is not None:
            task_context.append(current_task)

        try:
            # If we have an exception, run the function inside the except block
            # after raising it so exc_info is correctly populated.
            if exc_info[1]:
                try:
                    raise exc_info[1]
                except BaseException:
                    result = await self.awaitable(*args, **kwargs)
            else:
                result = await self.awaitable(*args, **kwargs)
        except BaseException as e:
            # Every failure, including BaseException subclasses, is delivered
            # through the future instead of propagating out of this coroutine.
            call_result.set_exception(e)
        else:
            call_result.set_result(result)
        finally:
            if current_task is not None and task_context is not None:
                task_context.remove(current_task)
            # Hand the post-call context back to __call__ via the shared list.
            context[0] = contextvars.copy_context()
class SyncToAsync(Generic[_P, _R]):
    """
    Utility class which turns a synchronous callable into an awaitable that
    runs in a threadpool. It also sets a threadlocal inside the thread so
    calls to AsyncToSync can escape it.

    If thread_sensitive is passed, the code will run in the same thread as any
    outer code. This is needed for underlying Python code that is not
    threadsafe (for example, code which handles SQLite database connections).

    If the outermost program is async (i.e. SyncToAsync is outermost), then
    this will be a dedicated single sub-thread that all sync code runs in,
    one after the other. If the outermost program is sync (i.e. AsyncToSync is
    outermost), this will just be the main thread. This is achieved by idling
    with a CurrentThreadExecutor while AsyncToSync is blocking its sync parent,
    rather than just blocking.

    If executor is passed in, that will be used instead of the loop's default executor.
    In order to pass in an executor, thread_sensitive must be set to False, otherwise
    a TypeError will be raised.
    """

    # Storage for main event loop references
    threadlocal = threading.local()

    # Single-thread executor for thread-sensitive code
    single_thread_executor = ThreadPoolExecutor(max_workers=1)

    # Maintain a contextvar for the current execution context. Optionally used
    # for thread sensitive mode.
    thread_sensitive_context: "contextvars.ContextVar[ThreadSensitiveContext]" = (
        contextvars.ContextVar("thread_sensitive_context")
    )

    # Contextvar that is used to detect if the single thread executor
    # would be awaited on while already being used in the same context
    deadlock_context: "contextvars.ContextVar[bool]" = contextvars.ContextVar(
        "deadlock_context"
    )

    # Maintaining a weak reference to the context ensures that thread pools are
    # erased once the context goes out of scope. This terminates the thread pool.
    context_to_thread_executor: "weakref.WeakKeyDictionary[ThreadSensitiveContext, ThreadPoolExecutor]" = (
        weakref.WeakKeyDictionary()
    )

    def __init__(
        self,
        func: Callable[_P, _R],
        thread_sensitive: bool = True,
        executor: Optional["ThreadPoolExecutor"] = None,
    ) -> None:
        """
        Wrap the synchronous callable ``func``.

        Raises TypeError if ``func`` is not callable or is a coroutine
        function, or if ``executor`` is given while ``thread_sensitive`` is
        True.
        """
        if (
            not callable(func)
            or iscoroutinefunction(func)
            or iscoroutinefunction(getattr(func, "__call__", func))
        ):
            raise TypeError("sync_to_async can only be applied to sync functions.")
        self.func = func
        functools.update_wrapper(self, func)
        self._thread_sensitive = thread_sensitive
        # Mark this instance so iscoroutinefunction() reports it as async.
        markcoroutinefunction(self)
        if thread_sensitive and executor is not None:
            raise TypeError("executor must not be set when thread_sensitive is True")
        self._executor = executor
        # Mirror the bound instance of the wrapped callable, when it has one.
        try:
            self.__self__ = func.__self__  # type: ignore
        except AttributeError:
            pass

    async def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
        """
        Run the wrapped function in an executor thread and return its result.

        Raises RuntimeError in thread-sensitive mode when the shared single
        thread executor would be chosen while deadlock_context is already
        True in the current context.
        """
        __traceback_hide__ = True  # noqa: F841
        loop = asyncio.get_running_loop()

        # Work out what thread to run the code in
        if self._thread_sensitive:
            current_thread_executor = getattr(AsyncToSync.executors, "current", None)
            if current_thread_executor:
                # If we have a parent sync thread above somewhere, use that
                executor = current_thread_executor
            elif self.thread_sensitive_context.get(None):
                # If we have a way of retrieving the current context, attempt
                # to use a per-context thread pool executor
                thread_sensitive_context = self.thread_sensitive_context.get()

                if thread_sensitive_context in self.context_to_thread_executor:
                    # Re-use thread executor in current context
                    executor = self.context_to_thread_executor[thread_sensitive_context]
                else:
                    # Create new thread executor in current context
                    executor = ThreadPoolExecutor(max_workers=1)
                    self.context_to_thread_executor[thread_sensitive_context] = executor
            elif loop in AsyncToSync.loop_thread_executors:
                # Re-use thread executor for running loop
                executor = AsyncToSync.loop_thread_executors[loop]
            elif self.deadlock_context.get(False):
                raise RuntimeError(
                    "Single thread executor already being used, would deadlock"
                )
            else:
                # Otherwise, we run it in a fixed single thread
                executor = self.single_thread_executor
                # Set before the context is copied below, so the copy that the
                # function runs in carries the flag.
                self.deadlock_context.set(True)
        else:
            # Use the passed in executor, or the loop's default if it is None
            executor = self._executor

        # The function runs inside a copy of the current context; changes it
        # makes are copied back in the ``finally`` below.
        context = contextvars.copy_context()
        child = functools.partial(self.func, *args, **kwargs)
        func = context.run
        # Passed to thread_handler, which publishes it as a threadlocal;
        # AsyncToSync.main_wrap appends its running task to it.
        task_context: List[asyncio.Task[Any]] = []

        # Run the code in the right thread
        exec_coro = loop.run_in_executor(
            executor,
            functools.partial(
                self.thread_handler,
                loop,
                sys.exc_info(),
                task_context,
                func,
                child,
            ),
        )
        ret: _R
        try:
            # shield() keeps a cancellation of this coroutine from cancelling
            # the executor future directly; it is handled explicitly below.
            ret = await asyncio.shield(exec_coro)
        except asyncio.CancelledError:
            cancel_parent = True
            try:
                # If a task was registered in task_context, cancel it and wait
                # for it to finish.
                task = task_context[0]
                task.cancel()
                try:
                    await task
                    # The task finished without raising CancelledError.
                    cancel_parent = False
                except asyncio.CancelledError:
                    pass
            except IndexError:
                # No task was registered.
                pass
            if exec_coro.done():
                raise
            if cancel_parent:
                exec_coro.cancel()
            ret = await exec_coro
        finally:
            _restore_context(context)
            self.deadlock_context.set(False)

        return ret

    def __get__(
        self, parent: Any, objtype: Any
    ) -> Callable[_P, Coroutine[Any, Any, _R]]:
        """
        Include self for methods
        """
        # Bind ``parent`` as the first positional argument so the wrapped
        # function receives it as ``self``.
        func = functools.partial(self.__call__, parent)
        return functools.update_wrapper(func, self.func)

    def thread_handler(self, loop, exc_info, task_context, func, *args, **kwargs):
        """
        Wraps the sync application with exception handling.

        Publishes ``loop``, the current process id and ``task_context`` as
        threadlocals on the executing thread, then returns the result of
        ``func(*args, **kwargs)``.
        """

        __traceback_hide__ = True  # noqa: F841

        # Set the threadlocal for AsyncToSync
        self.threadlocal.main_event_loop = loop
        self.threadlocal.main_event_loop_pid = os.getpid()
        self.threadlocal.task_context = task_context

        # Run the function
        # If we have an exception, run the function inside the except block
        # after raising it so exc_info is correctly populated.
        if exc_info[1]:
            try:
                raise exc_info[1]
            except BaseException:
                return func(*args, **kwargs)
        else:
            return func(*args, **kwargs)
# The two stub signatures below are typing overloads: they describe the
# decorator-factory form and the direct-call form for type checkers. Without
# the @overload decorator they are plain redefinitions of the same name.
@overload
def async_to_sync(
    *,
    force_new_loop: bool = False,
) -> Callable[
    [Union[Callable[_P, Coroutine[Any, Any, _R]], Callable[_P, Awaitable[_R]]]],
    Callable[_P, _R],
]:
    ...


@overload
def async_to_sync(
    awaitable: Union[
        Callable[_P, Coroutine[Any, Any, _R]],
        Callable[_P, Awaitable[_R]],
    ],
    *,
    force_new_loop: bool = False,
) -> Callable[_P, _R]:
    ...


def async_to_sync(
    awaitable: Optional[
        Union[
            Callable[_P, Coroutine[Any, Any, _R]],
            Callable[_P, Awaitable[_R]],
        ]
    ] = None,
    *,
    force_new_loop: bool = False,
) -> Union[
    Callable[
        [Union[Callable[_P, Coroutine[Any, Any, _R]], Callable[_P, Awaitable[_R]]]],
        Callable[_P, _R],
    ],
    Callable[_P, _R],
]:
    """
    Wrap an async callable in an AsyncToSync instance.

    Usable directly (``async_to_sync(func)``) or as a decorator factory
    (``@async_to_sync(force_new_loop=True)``).

    Args:
        awaitable: The callable to wrap. When omitted, a one-argument
            decorator is returned instead of a wrapper.
        force_new_loop: Passed through to AsyncToSync.

    Returns:
        An AsyncToSync wrapper around ``awaitable``, or a decorator that
        builds one when ``awaitable`` is None.
    """
    if awaitable is None:
        return lambda f: AsyncToSync(
            f,
            force_new_loop=force_new_loop,
        )
    return AsyncToSync(
        awaitable,
        force_new_loop=force_new_loop,
    )
# The two stub signatures below are typing overloads: they describe the
# decorator-factory form and the direct-call form for type checkers. Without
# the @overload decorator they are plain redefinitions of the same name.
@overload
def sync_to_async(
    *,
    thread_sensitive: bool = True,
    executor: Optional["ThreadPoolExecutor"] = None,
) -> Callable[[Callable[_P, _R]], Callable[_P, Coroutine[Any, Any, _R]]]:
    ...


@overload
def sync_to_async(
    func: Callable[_P, _R],
    *,
    thread_sensitive: bool = True,
    executor: Optional["ThreadPoolExecutor"] = None,
) -> Callable[_P, Coroutine[Any, Any, _R]]:
    ...


def sync_to_async(
    func: Optional[Callable[_P, _R]] = None,
    *,
    thread_sensitive: bool = True,
    executor: Optional["ThreadPoolExecutor"] = None,
) -> Union[
    Callable[[Callable[_P, _R]], Callable[_P, Coroutine[Any, Any, _R]]],
    Callable[_P, Coroutine[Any, Any, _R]],
]:
    """
    Wrap a synchronous callable in a SyncToAsync instance.

    Usable directly (``sync_to_async(func)``) or as a decorator factory
    (``@sync_to_async(thread_sensitive=False)``).

    Args:
        func: The callable to wrap. When omitted, a one-argument decorator is
            returned instead of a wrapper.
        thread_sensitive: Passed through to SyncToAsync.
        executor: Passed through to SyncToAsync.

    Returns:
        A SyncToAsync wrapper around ``func``, or a decorator that builds one
        when ``func`` is None.
    """
    if func is None:
        return lambda f: SyncToAsync(
            f,
            thread_sensitive=thread_sensitive,
            executor=executor,
        )
    return SyncToAsync(
        func,
        thread_sensitive=thread_sensitive,
        executor=executor,
    )