| | import asyncio
|
| | import contextlib
|
| | import inspect
|
| | import warnings
|
| | from typing import (
|
| | Any,
|
| | Awaitable,
|
| | Callable,
|
| | Dict,
|
| | Iterator,
|
| | Optional,
|
| | Protocol,
|
| | Type,
|
| | Union,
|
| | )
|
| |
|
| | import pytest
|
| |
|
| | from .test_utils import (
|
| | BaseTestServer,
|
| | RawTestServer,
|
| | TestClient,
|
| | TestServer,
|
| | loop_context,
|
| | setup_test_loop,
|
| | teardown_test_loop,
|
| | unused_port as _unused_port,
|
| | )
|
| | from .web import Application
|
| | from .web_protocol import _RequestHandler
|
| |
|
| | try:
|
| | import uvloop
|
| | except ImportError:
|
| | uvloop = None
|
| |
|
| |
|
class AiohttpClient(Protocol):
    """Call signature of the factory yielded by the ``aiohttp_client`` fixture.

    The first argument is positional: an ``Application`` or a
    ``BaseTestServer``.  ``server_kwargs`` and any further options are
    keyword-only.  Calling the factory gives an awaitable that resolves to a
    ``TestClient``.
    """

    def __call__(
        self,
        __param: Union[Application, BaseTestServer],
        *,
        server_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any
    ) -> Awaitable[TestClient]:
        ...
|
| |
|
| |
|
class AiohttpServer(Protocol):
    """Call signature of the factory yielded by the ``aiohttp_server`` fixture.

    Takes an ``Application`` plus an optional keyword-only ``port`` and extra
    keyword options; the awaitable it returns resolves to a ``TestServer``.
    """

    def __call__(
        self,
        app: Application,
        *,
        port: Optional[int] = None,
        **kwargs: Any
    ) -> Awaitable[TestServer]:
        ...
|
| |
|
| |
|
class AiohttpRawServer(Protocol):
    """Call signature of the factory yielded by ``aiohttp_raw_server``.

    Takes a request handler plus an optional keyword-only ``port`` and extra
    keyword options; the awaitable it returns resolves to a ``RawTestServer``.
    """

    def __call__(
        self,
        handler: _RequestHandler,
        *,
        port: Optional[int] = None,
        **kwargs: Any
    ) -> Awaitable[RawTestServer]:
        ...
|
| |
|
| |
|
def pytest_addoption(parser):
    """Register the aiohttp command line options on ``parser``.

    Options are added in this order: ``--aiohttp-fast``, ``--aiohttp-loop``
    and ``--aiohttp-enable-loop-debug``.
    """
    option_table = (
        (
            "--aiohttp-fast",
            {
                "action": "store_true",
                "default": False,
                "help": "run tests faster by disabling extra checks",
            },
        ),
        (
            "--aiohttp-loop",
            {
                "action": "store",
                "default": "pyloop",
                "help": "run tests with specific loop: pyloop, uvloop or all",
            },
        ),
        (
            "--aiohttp-enable-loop-debug",
            {
                "action": "store_true",
                "default": False,
                "help": "enable event loop debug mode",
            },
        ),
    )
    for flag, settings in option_table:
        parser.addoption(flag, **settings)
|
| |
|
| |
|
def pytest_fixture_setup(fixturedef):
    """Let pytest fixtures be written as coroutines or async generators.

    When ``fixturedef.func`` is a coroutine function or an async generator
    function it is replaced by a synchronous wrapper that runs it on the
    event loop obtained from the ``loop`` fixture.  Any other fixture is left
    untouched.
    """
    original = fixturedef.func

    async_gen_fixture = inspect.isasyncgenfunction(original)
    if not async_gen_fixture and not asyncio.iscoroutinefunction(original):
        # Ordinary synchronous fixture: nothing to adapt.
        return

    # The wrapper needs pytest's ``request`` object.  If the fixture did not
    # ask for it, request it on its behalf and hide it from the fixture.
    request_injected = "request" not in fixturedef.argnames
    if request_injected:
        fixturedef.argnames += ("request",)

    def wrapper(*args, **kwargs):
        if request_injected:
            request = kwargs.pop("request")
        else:
            request = kwargs["request"]

        # Fail with an explicit message before asking for the "loop" fixture
        # value when neither the fixture nor the test requested it.
        if "loop" not in request.fixturenames:
            raise Exception(
                "Asynchronous fixtures must depend on the 'loop' fixture or "
                "be used in tests depending from it."
            )

        event_loop = request.getfixturevalue("loop")
        pending = original(*args, **kwargs)

        if not async_gen_fixture:
            return event_loop.run_until_complete(pending)

        # Async generator: the first step produces the fixture value, the
        # second step (run as a finalizer) executes the teardown part.
        def finalizer():
            try:
                return event_loop.run_until_complete(pending.__anext__())
            except StopAsyncIteration:
                return None

        request.addfinalizer(finalizer)
        return event_loop.run_until_complete(pending.__anext__())

    fixturedef.func = wrapper
|
| |
|
| |
|
@pytest.fixture
def fast(request):
    """Value of the ``--aiohttp-fast`` command line option."""
    config = request.config
    return config.getoption("--aiohttp-fast")
|
| |
|
| |
|
@pytest.fixture
def loop_debug(request):
    """Value of the ``--aiohttp-enable-loop-debug`` command line option."""
    config = request.config
    return config.getoption("--aiohttp-enable-loop-debug")
|
| |
|
| |
|
@contextlib.contextmanager
def _runtime_warning_context():
    """Context manager which turns RuntimeWarnings into a RuntimeError.

    This exists specifically to avoid "coroutine 'X' was never awaited"
    warnings being missed.

    All warnings issued inside the context are recorded.  If any of them has
    the category ``RuntimeWarning`` a ``RuntimeError`` listing them is raised
    when the context exits.  Otherwise every recorded warning is re-emitted
    through ``warnings.showwarning`` once recording has stopped, so that
    other warnings (for example ``DeprecationWarning``) are not swallowed.

    Raises:
        RuntimeError: at least one ``RuntimeWarning`` was recorded.
    """
    with warnings.catch_warnings(record=True) as recorded:
        yield
        runtime_warnings = [
            "{w.filename}:{w.lineno}:{w.message}".format(w=w)
            for w in recorded
            if w.category == RuntimeWarning
        ]
        if runtime_warnings:
            raise RuntimeError(
                "{} Runtime Warning{},\n{}".format(
                    len(runtime_warnings),
                    "" if len(runtime_warnings) == 1 else "s",
                    "\n".join(runtime_warnings),
                )
            )

    # Recording diverted these warnings away from the normal machinery; hand
    # them back now that the previous warning handling has been restored.
    for msg in recorded:
        warnings.showwarning(
            msg.message, msg.category, msg.filename, msg.lineno, msg.file, msg.line
        )
|
| |
|
| |
|
@contextlib.contextmanager
def _passthrough_loop_context(loop, fast=False):
    """Yield ``loop`` if one is given, otherwise a temporary test loop.

    Args:
        loop: an existing event loop to pass straight through, or ``None``
            to have a loop created with ``setup_test_loop()``.
        fast: forwarded to ``teardown_test_loop`` when this context manager
            created the loop itself.

    A loop that was passed in is never torn down here.  A loop created here
    is torn down on exit even when the body of the ``with`` block raises.
    """
    if loop is not None:
        # The loop belongs to the caller: hand it over untouched.
        yield loop
        return

    loop = setup_test_loop()
    try:
        yield loop
    finally:
        # Without ``finally`` a failing test body would leak the loop.
        teardown_test_loop(loop, fast=fast)
|
| |
|
| |
|
def pytest_pycollect_makeitem(collector, name, obj):
    """Collect coroutine functions whose name passes the collector's filter.

    Returns the list of items produced by ``collector._genfunctions`` for a
    matching coroutine function, and ``None`` for everything else.
    """
    if not collector.funcnamefilter(name) or not asyncio.iscoroutinefunction(obj):
        return None
    generated = collector._genfunctions(name, obj)
    return list(generated)
|
| |
|
| |
|
def pytest_pyfunc_call(pyfuncitem):
    """Run coroutine test functions on an event loop.

    Returns ``True`` after running a coroutine test.  For any other test
    function it returns ``None``.

    The loop used is the value of the ``proactor_loop`` fixture, else of the
    ``loop`` fixture, else a temporary loop set up just for this call.
    """
    fast = pyfuncitem.config.getoption("--aiohttp-fast")
    if not asyncio.iscoroutinefunction(pyfuncitem.function):
        return None

    funcargs = pyfuncitem.funcargs
    existing_loop = funcargs.get("proactor_loop") or funcargs.get("loop", None)
    with _runtime_warning_context(), _passthrough_loop_context(
        existing_loop, fast=fast
    ) as _loop:
        # Pass only the arguments the test function itself declares.
        testargs = {}
        for arg in pyfuncitem._fixtureinfo.argnames:
            testargs[arg] = pyfuncitem.funcargs[arg]
        _loop.run_until_complete(pyfuncitem.obj(**testargs))

    return True
|
| |
|
| |
|
def pytest_generate_tests(metafunc):
    """Parametrize the ``loop_factory`` fixture from ``--aiohttp-loop``.

    The option value is a comma separated list of loop names.  ``pyloop`` is
    always available; ``uvloop`` is available only when the module-level
    ``uvloop`` import succeeded.  A name followed by ``?`` is optional and is
    skipped when unavailable.  The value ``all`` stands for
    ``pyloop,uvloop?``.

    Nothing is done for tests that do not use ``loop_factory``.

    Raises:
        ValueError: a required loop name is not available.
    """
    if "loop_factory" not in metafunc.fixturenames:
        return

    loops = metafunc.config.option.aiohttp_loop
    avail_factories: Dict[str, Type[asyncio.AbstractEventLoopPolicy]]
    avail_factories = {"pyloop": asyncio.DefaultEventLoopPolicy}

    if uvloop is not None:
        avail_factories["uvloop"] = uvloop.EventLoopPolicy

    if loops == "all":
        loops = "pyloop,uvloop?"

    factories = {}
    for spec in loops.split(","):
        # Strip surrounding blanks first so that "uvloop? " is still
        # recognised as optional.
        spec = spec.strip()
        required = not spec.endswith("?")
        name = spec.strip(" ?")
        if name not in avail_factories:
            if required:
                # Report what can actually be chosen, not what has been
                # selected so far.
                raise ValueError(
                    "Unknown loop '%s', available loops: %s"
                    % (name, list(avail_factories))
                )
            continue
        factories[name] = avail_factories[name]
    metafunc.parametrize(
        "loop_factory", list(factories.values()), ids=list(factories.keys())
    )
|
| |
|
| |
|
@pytest.fixture
def loop(loop_factory, fast, loop_debug):
    """Yield an event loop created under the policy from ``loop_factory``.

    The policy returned by ``loop_factory()`` is installed as the asyncio
    event loop policy.  The loop comes from ``loop_context(fast=fast)`` and is
    set as the current event loop.  Debug mode is switched on when
    ``loop_debug`` is true.
    """
    asyncio.set_event_loop_policy(loop_factory())
    with loop_context(fast=fast) as event_loop:
        if loop_debug:
            event_loop.set_debug(True)
        asyncio.set_event_loop(event_loop)
        yield event_loop
|
| |
|
| |
|
@pytest.fixture
def proactor_loop():
    """Yield an event loop created by ``WindowsProactorEventLoopPolicy``.

    The proactor policy is installed as the asyncio event loop policy, and
    the loop built from its ``new_event_loop`` is set as the current loop.
    """
    proactor_policy = asyncio.WindowsProactorEventLoopPolicy()
    asyncio.set_event_loop_policy(proactor_policy)

    with loop_context(proactor_policy.new_event_loop) as event_loop:
        asyncio.set_event_loop(event_loop)
        yield event_loop
|
| |
|
| |
|
@pytest.fixture
def unused_port(aiohttp_unused_port: Callable[[], int]) -> Callable[[], int]:
    """Deprecated alias of the ``aiohttp_unused_port`` fixture.

    Emits a ``DeprecationWarning`` and returns the ``aiohttp_unused_port``
    value unchanged.
    """
    message = "Deprecated, use aiohttp_unused_port fixture instead"
    warnings.warn(message, category=DeprecationWarning, stacklevel=2)
    return aiohttp_unused_port
|
| |
|
| |
|
@pytest.fixture
def aiohttp_unused_port() -> Callable[[], int]:
    """Return the ``unused_port`` helper imported from ``.test_utils``.

    The fixture value is the callable itself; call it to obtain a port
    number.
    """
    port_finder = _unused_port
    return port_finder
|
| |
|
| |
|
@pytest.fixture
def aiohttp_server(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpServer]:
    """Yield a coroutine factory that starts ``TestServer`` instances.

    Usage: ``server = await aiohttp_server(app, **kwargs)``

    Every server started through the factory is closed, most recent first,
    after the fixture's ``yield`` returns.
    """
    started = []

    async def go(
        app: Application, *, port: Optional[int] = None, **kwargs: Any
    ) -> TestServer:
        srv = TestServer(app, port=port)
        await srv.start_server(loop=loop, **kwargs)
        # Only servers that started successfully are tracked for cleanup.
        started.append(srv)
        return srv

    yield go

    async def finalize() -> None:
        while started:
            latest = started.pop()
            await latest.close()

    loop.run_until_complete(finalize())
|
| |
|
| |
|
@pytest.fixture
def test_server(aiohttp_server):
    """Deprecated alias of the ``aiohttp_server`` fixture.

    Emits a ``DeprecationWarning`` and returns the ``aiohttp_server`` value
    unchanged.
    """
    message = "Deprecated, use aiohttp_server fixture instead"
    warnings.warn(message, category=DeprecationWarning, stacklevel=2)
    return aiohttp_server
|
| |
|
| |
|
@pytest.fixture
def aiohttp_raw_server(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpRawServer]:
    """Yield a coroutine factory that starts ``RawTestServer`` instances.

    Usage: ``server = await aiohttp_raw_server(handler, **kwargs)``

    Every server started through the factory is closed, most recent first,
    after the fixture's ``yield`` returns.
    """
    started = []

    async def go(
        handler: _RequestHandler, *, port: Optional[int] = None, **kwargs: Any
    ) -> RawTestServer:
        srv = RawTestServer(handler, port=port)
        await srv.start_server(loop=loop, **kwargs)
        # Only servers that started successfully are tracked for cleanup.
        started.append(srv)
        return srv

    yield go

    async def finalize() -> None:
        while started:
            latest = started.pop()
            await latest.close()

    loop.run_until_complete(finalize())
|
| |
|
| |
|
@pytest.fixture
def raw_test_server(
    aiohttp_raw_server,
):
    """Deprecated alias of the ``aiohttp_raw_server`` fixture.

    Emits a ``DeprecationWarning`` and returns the ``aiohttp_raw_server``
    value unchanged.
    """
    message = "Deprecated, use aiohttp_raw_server fixture instead"
    warnings.warn(message, category=DeprecationWarning, stacklevel=2)
    return aiohttp_raw_server
|
| |
|
| |
|
@pytest.fixture
def aiohttp_client(
    loop: asyncio.AbstractEventLoop,
) -> Iterator[AiohttpClient]:
    """Yield a coroutine factory that creates started ``TestClient`` objects.

    Usage:
        ``client = await aiohttp_client(app, **kwargs)``
        ``client = await aiohttp_client(server, **kwargs)``
        ``client = await aiohttp_client(raw_server, **kwargs)``

    Every client created through the factory is closed, most recent first,
    after the fixture's ``yield`` returns.
    """
    created = []

    async def go(
        __param: Union[Application, BaseTestServer],
        *args: Any,
        server_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any
    ) -> TestClient:
        is_known_target = isinstance(__param, (Application, BaseTestServer))
        if not is_known_target and isinstance(__param, Callable):
            # A factory callable: it receives the loop and every extra
            # argument, so nothing is left over for the client itself.
            __param = __param(loop, *args, **kwargs)
            kwargs = {}
        else:
            assert not args, "args should be empty"

        if isinstance(__param, Application):
            backing_server = TestServer(__param, loop=loop, **(server_kwargs or {}))
            client = TestClient(backing_server, loop=loop, **kwargs)
        elif isinstance(__param, BaseTestServer):
            client = TestClient(__param, loop=loop, **kwargs)
        else:
            raise ValueError("Unknown argument type: %r" % type(__param))

        await client.start_server()
        created.append(client)
        return client

    yield go

    async def finalize() -> None:
        while created:
            latest = created.pop()
            await latest.close()

    loop.run_until_complete(finalize())
|
| |
|
| |
|
@pytest.fixture
def test_client(aiohttp_client):
    """Deprecated alias of the ``aiohttp_client`` fixture.

    Emits a ``DeprecationWarning`` and returns the ``aiohttp_client`` value
    unchanged.
    """
    message = "Deprecated, use aiohttp_client fixture instead"
    warnings.warn(message, category=DeprecationWarning, stacklevel=2)
    return aiohttp_client
|
| |
|