diff --git a/.venv/lib/python3.11/site-packages/anyio/_core/__pycache__/_eventloop.cpython-311.pyc b/.venv/lib/python3.11/site-packages/anyio/_core/__pycache__/_eventloop.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1afbd90795ad93490cb27fa7d840c641dcaed1a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/anyio/_core/__pycache__/_eventloop.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/anyio/_core/__pycache__/_fileio.cpython-311.pyc b/.venv/lib/python3.11/site-packages/anyio/_core/__pycache__/_fileio.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a021a29acdadd4464e37a7983d0da61af39da7a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/anyio/_core/__pycache__/_fileio.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/anyio/_core/__pycache__/_streams.cpython-311.pyc b/.venv/lib/python3.11/site-packages/anyio/_core/__pycache__/_streams.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5484fad52101caf7a5cd6412353055bc279c4e24 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/anyio/_core/__pycache__/_streams.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/anyio/_core/__pycache__/_testing.cpython-311.pyc b/.venv/lib/python3.11/site-packages/anyio/_core/__pycache__/_testing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..173f69054e7cef338bb22502257fb2548833d3ad Binary files /dev/null and b/.venv/lib/python3.11/site-packages/anyio/_core/__pycache__/_testing.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/anyio/abc/__init__.py b/.venv/lib/python3.11/site-packages/anyio/abc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d3b61cc9a06019489fe94bdd00c4ff904805136 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/anyio/abc/__init__.py @@ -0,0 +1,55 @@ +from 
__future__ import annotations + +from ._eventloop import AsyncBackend as AsyncBackend +from ._resources import AsyncResource as AsyncResource +from ._sockets import ConnectedUDPSocket as ConnectedUDPSocket +from ._sockets import ConnectedUNIXDatagramSocket as ConnectedUNIXDatagramSocket +from ._sockets import IPAddressType as IPAddressType +from ._sockets import IPSockAddrType as IPSockAddrType +from ._sockets import SocketAttribute as SocketAttribute +from ._sockets import SocketListener as SocketListener +from ._sockets import SocketStream as SocketStream +from ._sockets import UDPPacketType as UDPPacketType +from ._sockets import UDPSocket as UDPSocket +from ._sockets import UNIXDatagramPacketType as UNIXDatagramPacketType +from ._sockets import UNIXDatagramSocket as UNIXDatagramSocket +from ._sockets import UNIXSocketStream as UNIXSocketStream +from ._streams import AnyByteReceiveStream as AnyByteReceiveStream +from ._streams import AnyByteSendStream as AnyByteSendStream +from ._streams import AnyByteStream as AnyByteStream +from ._streams import AnyUnreliableByteReceiveStream as AnyUnreliableByteReceiveStream +from ._streams import AnyUnreliableByteSendStream as AnyUnreliableByteSendStream +from ._streams import AnyUnreliableByteStream as AnyUnreliableByteStream +from ._streams import ByteReceiveStream as ByteReceiveStream +from ._streams import ByteSendStream as ByteSendStream +from ._streams import ByteStream as ByteStream +from ._streams import Listener as Listener +from ._streams import ObjectReceiveStream as ObjectReceiveStream +from ._streams import ObjectSendStream as ObjectSendStream +from ._streams import ObjectStream as ObjectStream +from ._streams import UnreliableObjectReceiveStream as UnreliableObjectReceiveStream +from ._streams import UnreliableObjectSendStream as UnreliableObjectSendStream +from ._streams import UnreliableObjectStream as UnreliableObjectStream +from ._subprocesses import Process as Process +from ._tasks import TaskGroup as 
TaskGroup +from ._tasks import TaskStatus as TaskStatus +from ._testing import TestRunner as TestRunner + +# Re-exported here, for backwards compatibility +# isort: off +from .._core._synchronization import ( + CapacityLimiter as CapacityLimiter, + Condition as Condition, + Event as Event, + Lock as Lock, + Semaphore as Semaphore, +) +from .._core._tasks import CancelScope as CancelScope +from ..from_thread import BlockingPortal as BlockingPortal + +# Re-export imports so they look like they live directly in this package +for __value in list(locals().values()): + if getattr(__value, "__module__", "").startswith("anyio.abc."): + __value.__module__ = __name__ + +del __value diff --git a/.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68f9a4a90a57328971dbef17e8952ffb632e63fd Binary files /dev/null and b/.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_eventloop.cpython-311.pyc b/.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_eventloop.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c02d1450800b19bf06d61a3d3a6c40243ac866a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_eventloop.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_resources.cpython-311.pyc b/.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_resources.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f34b4189d85a70e4b7b609fd967003a982c11eec Binary files /dev/null and b/.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_resources.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_sockets.cpython-311.pyc b/.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_sockets.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..974ca31fb2e5bdfa284d871b19809046ca7a152c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_sockets.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_streams.cpython-311.pyc b/.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_streams.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59036ed2baeed8ad5df3df6e7aea4ccb48656e28 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_streams.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_subprocesses.cpython-311.pyc b/.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_subprocesses.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d42a7c91d66f48c892698fd2c3a5d418be13c0a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_subprocesses.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_tasks.cpython-311.pyc b/.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_tasks.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aec822d351bb99244a1d110273e26207cf7a797c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_tasks.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_testing.cpython-311.pyc b/.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_testing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dde0995a0563dc20aecfd24b6b2875efa3a4a3a4 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_testing.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/anyio/abc/_resources.py b/.venv/lib/python3.11/site-packages/anyio/abc/_resources.py new file mode 100644 index 0000000000000000000000000000000000000000..10df115a7b9f975493476da763cc1e26dbd822e5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/anyio/abc/_resources.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +from types import TracebackType +from typing import TypeVar + +T = TypeVar("T") + + +class AsyncResource(metaclass=ABCMeta): + """ + Abstract base class for all closeable asynchronous resources. + + Works as an asynchronous context manager which returns the instance itself on enter, + and calls :meth:`aclose` on exit. + """ + + __slots__ = () + + async def __aenter__(self: T) -> T: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self.aclose() + + @abstractmethod + async def aclose(self) -> None: + """Close the resource.""" diff --git a/.venv/lib/python3.11/site-packages/anyio/abc/_sockets.py b/.venv/lib/python3.11/site-packages/anyio/abc/_sockets.py new file mode 100644 index 0000000000000000000000000000000000000000..1c6a450cdcd1e66ed55685438e8f6f393ccfa828 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/anyio/abc/_sockets.py @@ -0,0 +1,194 @@ +from __future__ import annotations + +import socket +from abc import abstractmethod +from collections.abc import Callable, Collection, Mapping +from contextlib import AsyncExitStack +from io import IOBase +from ipaddress import IPv4Address, IPv6Address +from socket import AddressFamily +from types import TracebackType +from typing import Any, TypeVar, Union + +from .._core._typedattr import ( + TypedAttributeProvider, + TypedAttributeSet, + typed_attribute, +) +from ._streams import ByteStream, 
Listener, UnreliableObjectStream +from ._tasks import TaskGroup + +IPAddressType = Union[str, IPv4Address, IPv6Address] +IPSockAddrType = tuple[str, int] +SockAddrType = Union[IPSockAddrType, str] +UDPPacketType = tuple[bytes, IPSockAddrType] +UNIXDatagramPacketType = tuple[bytes, str] +T_Retval = TypeVar("T_Retval") + + +class _NullAsyncContextManager: + async def __aenter__(self) -> None: + pass + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + return None + + +class SocketAttribute(TypedAttributeSet): + #: the address family of the underlying socket + family: AddressFamily = typed_attribute() + #: the local socket address of the underlying socket + local_address: SockAddrType = typed_attribute() + #: for IP addresses, the local port the underlying socket is bound to + local_port: int = typed_attribute() + #: the underlying stdlib socket object + raw_socket: socket.socket = typed_attribute() + #: the remote address the underlying socket is connected to + remote_address: SockAddrType = typed_attribute() + #: for IP addresses, the remote port the underlying socket is connected to + remote_port: int = typed_attribute() + + +class _SocketProvider(TypedAttributeProvider): + @property + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + from .._core._sockets import convert_ipv6_sockaddr as convert + + attributes: dict[Any, Callable[[], Any]] = { + SocketAttribute.family: lambda: self._raw_socket.family, + SocketAttribute.local_address: lambda: convert( + self._raw_socket.getsockname() + ), + SocketAttribute.raw_socket: lambda: self._raw_socket, + } + try: + peername: tuple[str, int] | None = convert(self._raw_socket.getpeername()) + except OSError: + peername = None + + # Provide the remote address for connected sockets + if peername is not None: + attributes[SocketAttribute.remote_address] = lambda: peername + + # Provide local and remote ports 
for IP based sockets + if self._raw_socket.family in (AddressFamily.AF_INET, AddressFamily.AF_INET6): + attributes[SocketAttribute.local_port] = ( + lambda: self._raw_socket.getsockname()[1] + ) + if peername is not None: + remote_port = peername[1] + attributes[SocketAttribute.remote_port] = lambda: remote_port + + return attributes + + @property + @abstractmethod + def _raw_socket(self) -> socket.socket: + pass + + +class SocketStream(ByteStream, _SocketProvider): + """ + Transports bytes over a socket. + + Supports all relevant extra attributes from :class:`~SocketAttribute`. + """ + + +class UNIXSocketStream(SocketStream): + @abstractmethod + async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None: + """ + Send file descriptors along with a message to the peer. + + :param message: a non-empty bytestring + :param fds: a collection of files (either numeric file descriptors or open file + or socket objects) + """ + + @abstractmethod + async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]: + """ + Receive file descriptors along with a message from the peer. + + :param msglen: length of the message to expect from the peer + :param maxfds: maximum number of file descriptors to expect from the peer + :return: a tuple of (message, file descriptors) + """ + + +class SocketListener(Listener[SocketStream], _SocketProvider): + """ + Listens to incoming socket connections. + + Supports all relevant extra attributes from :class:`~SocketAttribute`. + """ + + @abstractmethod + async def accept(self) -> SocketStream: + """Accept an incoming connection.""" + + async def serve( + self, + handler: Callable[[SocketStream], Any], + task_group: TaskGroup | None = None, + ) -> None: + from .. 
import create_task_group + + async with AsyncExitStack() as stack: + if task_group is None: + task_group = await stack.enter_async_context(create_task_group()) + + while True: + stream = await self.accept() + task_group.start_soon(handler, stream) + + +class UDPSocket(UnreliableObjectStream[UDPPacketType], _SocketProvider): + """ + Represents an unconnected UDP socket. + + Supports all relevant extra attributes from :class:`~SocketAttribute`. + """ + + async def sendto(self, data: bytes, host: str, port: int) -> None: + """ + Alias for :meth:`~.UnreliableObjectSendStream.send` ((data, (host, port))). + + """ + return await self.send((data, (host, port))) + + +class ConnectedUDPSocket(UnreliableObjectStream[bytes], _SocketProvider): + """ + Represents an connected UDP socket. + + Supports all relevant extra attributes from :class:`~SocketAttribute`. + """ + + +class UNIXDatagramSocket( + UnreliableObjectStream[UNIXDatagramPacketType], _SocketProvider +): + """ + Represents an unconnected Unix datagram socket. + + Supports all relevant extra attributes from :class:`~SocketAttribute`. + """ + + async def sendto(self, data: bytes, path: str) -> None: + """Alias for :meth:`~.UnreliableObjectSendStream.send` ((data, path)).""" + return await self.send((data, path)) + + +class ConnectedUNIXDatagramSocket(UnreliableObjectStream[bytes], _SocketProvider): + """ + Represents a connected Unix datagram socket. + + Supports all relevant extra attributes from :class:`~SocketAttribute`. 
+ """ diff --git a/.venv/lib/python3.11/site-packages/anyio/abc/_streams.py b/.venv/lib/python3.11/site-packages/anyio/abc/_streams.py new file mode 100644 index 0000000000000000000000000000000000000000..8c638683a49245b377ed1917e0385284be5a46dd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/anyio/abc/_streams.py @@ -0,0 +1,203 @@ +from __future__ import annotations + +from abc import abstractmethod +from collections.abc import Callable +from typing import Any, Generic, TypeVar, Union + +from .._core._exceptions import EndOfStream +from .._core._typedattr import TypedAttributeProvider +from ._resources import AsyncResource +from ._tasks import TaskGroup + +T_Item = TypeVar("T_Item") +T_co = TypeVar("T_co", covariant=True) +T_contra = TypeVar("T_contra", contravariant=True) + + +class UnreliableObjectReceiveStream( + Generic[T_co], AsyncResource, TypedAttributeProvider +): + """ + An interface for receiving objects. + + This interface makes no guarantees that the received messages arrive in the order in + which they were sent, or that no messages are missed. + + Asynchronously iterating over objects of this type will yield objects matching the + given type parameter. + """ + + def __aiter__(self) -> UnreliableObjectReceiveStream[T_co]: + return self + + async def __anext__(self) -> T_co: + try: + return await self.receive() + except EndOfStream: + raise StopAsyncIteration + + @abstractmethod + async def receive(self) -> T_co: + """ + Receive the next item. + + :raises ~anyio.ClosedResourceError: if the receive stream has been explicitly + closed + :raises ~anyio.EndOfStream: if this stream has been closed from the other end + :raises ~anyio.BrokenResourceError: if this stream has been rendered unusable + due to external causes + """ + + +class UnreliableObjectSendStream( + Generic[T_contra], AsyncResource, TypedAttributeProvider +): + """ + An interface for sending objects. 
+ + This interface makes no guarantees that the messages sent will reach the + recipient(s) in the same order in which they were sent, or at all. + """ + + @abstractmethod + async def send(self, item: T_contra) -> None: + """ + Send an item to the peer(s). + + :param item: the item to send + :raises ~anyio.ClosedResourceError: if the send stream has been explicitly + closed + :raises ~anyio.BrokenResourceError: if this stream has been rendered unusable + due to external causes + """ + + +class UnreliableObjectStream( + UnreliableObjectReceiveStream[T_Item], UnreliableObjectSendStream[T_Item] +): + """ + A bidirectional message stream which does not guarantee the order or reliability of + message delivery. + """ + + +class ObjectReceiveStream(UnreliableObjectReceiveStream[T_co]): + """ + A receive message stream which guarantees that messages are received in the same + order in which they were sent, and that no messages are missed. + """ + + +class ObjectSendStream(UnreliableObjectSendStream[T_contra]): + """ + A send message stream which guarantees that messages are delivered in the same order + in which they were sent, without missing any messages in the middle. + """ + + +class ObjectStream( + ObjectReceiveStream[T_Item], + ObjectSendStream[T_Item], + UnreliableObjectStream[T_Item], +): + """ + A bidirectional message stream which guarantees the order and reliability of message + delivery. + """ + + @abstractmethod + async def send_eof(self) -> None: + """ + Send an end-of-file indication to the peer. + + You should not try to send any further data to this stream after calling this + method. This method is idempotent (does nothing on successive calls). + """ + + +class ByteReceiveStream(AsyncResource, TypedAttributeProvider): + """ + An interface for receiving bytes from a single peer. + + Iterating this byte stream will yield a byte string of arbitrary length, but no more + than 65536 bytes. 
+ """ + + def __aiter__(self) -> ByteReceiveStream: + return self + + async def __anext__(self) -> bytes: + try: + return await self.receive() + except EndOfStream: + raise StopAsyncIteration + + @abstractmethod + async def receive(self, max_bytes: int = 65536) -> bytes: + """ + Receive at most ``max_bytes`` bytes from the peer. + + .. note:: Implementors of this interface should not return an empty + :class:`bytes` object, and users should ignore them. + + :param max_bytes: maximum number of bytes to receive + :return: the received bytes + :raises ~anyio.EndOfStream: if this stream has been closed from the other end + """ + + +class ByteSendStream(AsyncResource, TypedAttributeProvider): + """An interface for sending bytes to a single peer.""" + + @abstractmethod + async def send(self, item: bytes) -> None: + """ + Send the given bytes to the peer. + + :param item: the bytes to send + """ + + +class ByteStream(ByteReceiveStream, ByteSendStream): + """A bidirectional byte stream.""" + + @abstractmethod + async def send_eof(self) -> None: + """ + Send an end-of-file indication to the peer. + + You should not try to send any further data to this stream after calling this + method. This method is idempotent (does nothing on successive calls). + """ + + +#: Type alias for all unreliable bytes-oriented receive streams. +AnyUnreliableByteReceiveStream = Union[ + UnreliableObjectReceiveStream[bytes], ByteReceiveStream +] +#: Type alias for all unreliable bytes-oriented send streams. +AnyUnreliableByteSendStream = Union[UnreliableObjectSendStream[bytes], ByteSendStream] +#: Type alias for all unreliable bytes-oriented streams. +AnyUnreliableByteStream = Union[UnreliableObjectStream[bytes], ByteStream] +#: Type alias for all bytes-oriented receive streams. +AnyByteReceiveStream = Union[ObjectReceiveStream[bytes], ByteReceiveStream] +#: Type alias for all bytes-oriented send streams. 
+AnyByteSendStream = Union[ObjectSendStream[bytes], ByteSendStream] +#: Type alias for all bytes-oriented streams. +AnyByteStream = Union[ObjectStream[bytes], ByteStream] + + +class Listener(Generic[T_co], AsyncResource, TypedAttributeProvider): + """An interface for objects that let you accept incoming connections.""" + + @abstractmethod + async def serve( + self, handler: Callable[[T_co], Any], task_group: TaskGroup | None = None + ) -> None: + """ + Accept incoming connections as they come in and start tasks to handle them. + + :param handler: a callable that will be used to handle each accepted connection + :param task_group: the task group that will be used to start tasks for handling + each accepted connection (if omitted, an ad-hoc task group will be created) + """ diff --git a/.venv/lib/python3.11/site-packages/anyio/abc/_tasks.py b/.venv/lib/python3.11/site-packages/anyio/abc/_tasks.py new file mode 100644 index 0000000000000000000000000000000000000000..f6e5c40c7ff2a878d4ce7a37364b6d93974b8ee8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/anyio/abc/_tasks.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +import sys +from abc import ABCMeta, abstractmethod +from collections.abc import Awaitable, Callable +from types import TracebackType +from typing import TYPE_CHECKING, Any, Protocol, TypeVar, overload + +if sys.version_info >= (3, 11): + from typing import TypeVarTuple, Unpack +else: + from typing_extensions import TypeVarTuple, Unpack + +if TYPE_CHECKING: + from .._core._tasks import CancelScope + +T_Retval = TypeVar("T_Retval") +T_contra = TypeVar("T_contra", contravariant=True) +PosArgsT = TypeVarTuple("PosArgsT") + + +class TaskStatus(Protocol[T_contra]): + @overload + def started(self: TaskStatus[None]) -> None: ... + + @overload + def started(self, value: T_contra) -> None: ... + + def started(self, value: T_contra | None = None) -> None: + """ + Signal that the task has started. 
+ + :param value: object passed back to the starter of the task + """ + + +class TaskGroup(metaclass=ABCMeta): + """ + Groups several asynchronous tasks together. + + :ivar cancel_scope: the cancel scope inherited by all child tasks + :vartype cancel_scope: CancelScope + + .. note:: On asyncio, support for eager task factories is considered to be + **experimental**. In particular, they don't follow the usual semantics of new + tasks being scheduled on the next iteration of the event loop, and may thus + cause unexpected behavior in code that wasn't written with such semantics in + mind. + """ + + cancel_scope: CancelScope + + @abstractmethod + def start_soon( + self, + func: Callable[[Unpack[PosArgsT]], Awaitable[Any]], + *args: Unpack[PosArgsT], + name: object = None, + ) -> None: + """ + Start a new task in this task group. + + :param func: a coroutine function + :param args: positional arguments to call the function with + :param name: name of the task, for the purposes of introspection and debugging + + .. versionadded:: 3.0 + """ + + @abstractmethod + async def start( + self, + func: Callable[..., Awaitable[Any]], + *args: object, + name: object = None, + ) -> Any: + """ + Start a new task and wait until it signals for readiness. + + :param func: a coroutine function + :param args: positional arguments to call the function with + :param name: name of the task, for the purposes of introspection and debugging + :return: the value passed to ``task_status.started()`` + :raises RuntimeError: if the task finishes without calling + ``task_status.started()`` + + .. 
versionadded:: 3.0 + """ + + @abstractmethod + async def __aenter__(self) -> TaskGroup: + """Enter the task group context and allow starting new tasks.""" + + @abstractmethod + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + """Exit the task group context waiting for all tasks to finish.""" diff --git a/.venv/lib/python3.11/site-packages/jsonschema_specifications-2024.10.1.dist-info/INSTALLER b/.venv/lib/python3.11/site-packages/jsonschema_specifications-2024.10.1.dist-info/INSTALLER new file mode 100644 index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/jsonschema_specifications-2024.10.1.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/.venv/lib/python3.11/site-packages/jsonschema_specifications-2024.10.1.dist-info/METADATA b/.venv/lib/python3.11/site-packages/jsonschema_specifications-2024.10.1.dist-info/METADATA new file mode 100644 index 0000000000000000000000000000000000000000..b4ced65633f7c8ef9bf38d0c26707ff1514fc7fe --- /dev/null +++ b/.venv/lib/python3.11/site-packages/jsonschema_specifications-2024.10.1.dist-info/METADATA @@ -0,0 +1,55 @@ +Metadata-Version: 2.3 +Name: jsonschema-specifications +Version: 2024.10.1 +Summary: The JSON Schema meta-schemas and vocabularies, exposed as a Registry +Project-URL: Documentation, https://jsonschema-specifications.readthedocs.io/ +Project-URL: Homepage, https://github.com/python-jsonschema/jsonschema-specifications +Project-URL: Issues, https://github.com/python-jsonschema/jsonschema-specifications/issues/ +Project-URL: Funding, https://github.com/sponsors/Julian +Project-URL: Tidelift, https://tidelift.com/subscription/pkg/pypi-jsonschema-specifications?utm_source=pypi-jsonschema-specifications&utm_medium=referral&utm_campaign=pypi-link +Project-URL: Source, 
https://github.com/python-jsonschema/jsonschema-specifications +Author-email: Julian Berman +License-File: COPYING +Keywords: data validation,json,json schema,jsonschema,validation +Classifier: Development Status :: 5 - Production/Stable +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: MIT License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Classifier: Programming Language :: Python :: 3.13 +Classifier: Programming Language :: Python :: Implementation :: CPython +Classifier: Programming Language :: Python :: Implementation :: PyPy +Classifier: Topic :: File Formats :: JSON +Classifier: Topic :: File Formats :: JSON :: JSON Schema +Requires-Python: >=3.9 +Requires-Dist: referencing>=0.31.0 +Description-Content-Type: text/x-rst + +============================= +``jsonschema-specifications`` +============================= + +|PyPI| |Pythons| |CI| |ReadTheDocs| + +JSON support files from the `JSON Schema Specifications `_ (metaschemas, vocabularies, etc.), packaged for runtime access from Python as a `referencing-based Schema Registry `_. + +.. |PyPI| image:: https://img.shields.io/pypi/v/jsonschema-specifications.svg + :alt: PyPI version + :target: https://pypi.org/project/jsonschema-specifications/ + +.. |Pythons| image:: https://img.shields.io/pypi/pyversions/jsonschema-specifications.svg + :alt: Supported Python versions + :target: https://pypi.org/project/jsonschema-specifications/ + +.. |CI| image:: https://github.com/python-jsonschema/jsonschema-specifications/workflows/CI/badge.svg + :alt: Build status + :target: https://github.com/python-jsonschema/jsonschema-specifications/actions?query=workflow%3ACI + +.. 
|ReadTheDocs| image:: https://readthedocs.org/projects/jsonschema-specifications/badge/?version=stable&style=flat + :alt: ReadTheDocs status + :target: https://jsonschema-specifications.readthedocs.io/en/stable/ diff --git a/.venv/lib/python3.11/site-packages/jsonschema_specifications-2024.10.1.dist-info/RECORD b/.venv/lib/python3.11/site-packages/jsonschema_specifications-2024.10.1.dist-info/RECORD new file mode 100644 index 0000000000000000000000000000000000000000..0dd7229b746832689a5c4b0e9613c5ce88eec944 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/jsonschema_specifications-2024.10.1.dist-info/RECORD @@ -0,0 +1,33 @@ +jsonschema_specifications-2024.10.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +jsonschema_specifications-2024.10.1.dist-info/METADATA,sha256=-jCfClPka5D4aDTtJ683zNiEcSHXhPBLuk9r9XWwyHI,2985 +jsonschema_specifications-2024.10.1.dist-info/RECORD,, +jsonschema_specifications-2024.10.1.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87 +jsonschema_specifications-2024.10.1.dist-info/licenses/COPYING,sha256=QtzWNJX4e063x3V6-jebtVpT-Ur9el9lfZrfVyNuUVw,1057 +jsonschema_specifications/__init__.py,sha256=qoTB2DKY7qvNrGhMPH6gtmAJRLilmVQ-fFZwT6ryqw0,386 +jsonschema_specifications/__pycache__/__init__.cpython-311.pyc,, +jsonschema_specifications/__pycache__/_core.cpython-311.pyc,, +jsonschema_specifications/_core.py,sha256=tFhc1CMleJ3AJOK_bjxOpFQTdrsUClFGfFxPBU_CebM,1140 +jsonschema_specifications/schemas/draft201909/metaschema.json,sha256=e3YbPhIfCgyh6ioLjizIVrz4AWBLgmjXG6yqICvAwTs,1785 +jsonschema_specifications/schemas/draft201909/vocabularies/applicator,sha256=aJUQDplyb7sQcFhRK77D7P1LJOj9L6zuPlBe5ysNTDE,1860 +jsonschema_specifications/schemas/draft201909/vocabularies/content,sha256=m31PVaTi_bAsQwBo_f-rxzKt3OI42j8d8mkCScM1MnQ,517 +jsonschema_specifications/schemas/draft201909/vocabularies/core,sha256=taLElX9kldClCB8ECevooU5BOayyA_x0hHH47eKvWyw,1531 
+jsonschema_specifications/schemas/draft201909/vocabularies/meta-data,sha256=1H4kRd1qgicaKY2DzGxsuNSuHhXg3Fa-zTehY-zwEoY,892 +jsonschema_specifications/schemas/draft201909/vocabularies/validation,sha256=HlJsHTNac0gF_ILPV5jBK5YK19olF8Zs2lobCTWcPBw,2834 +jsonschema_specifications/schemas/draft202012/metaschema.json,sha256=Qdp29a-3zgYtJI92JGOpL3ykfk4PkFsiS6av7vkd7Q8,2452 +jsonschema_specifications/schemas/draft202012/vocabularies/applicator,sha256=xKbkFHuR_vf-ptwFjLG_k0AvdBS3ZXiosWqvHa1qrO8,1659 +jsonschema_specifications/schemas/draft202012/vocabularies/content,sha256=CDQ3R3ZOSlgUJieTz01lIFenkThjxZUNQyl-jh_axbY,519 +jsonschema_specifications/schemas/draft202012/vocabularies/core,sha256=wtEqjk3RHTNt_IOj9mOqTGnwtJs76wlP_rJbUxb0gD0,1564 +jsonschema_specifications/schemas/draft202012/vocabularies/format,sha256=UOu_55BhGoSbjMQAoJwdDg-2q1wNQ6DyIgH9NiUFa_Q,403 +jsonschema_specifications/schemas/draft202012/vocabularies/format-annotation,sha256=q8d1rf79idIjWBcNm_k_Tr0jSVY7u-3WDwK-98gSvMA,448 +jsonschema_specifications/schemas/draft202012/vocabularies/format-assertion,sha256=xSJCuaG7eGsmw-gset1CjDH5yW5XXc6Z5W6l_qptogw,445 +jsonschema_specifications/schemas/draft202012/vocabularies/meta-data,sha256=j3bW4U9Bubku-TO3CM3FFEyLUmhlGtEZGEhfsXVPHHY,892 +jsonschema_specifications/schemas/draft202012/vocabularies/unevaluated,sha256=Lb-8tzmUtnCwl2SSre4f_7RsIWgnhNL1pMpWH54tDLQ,506 +jsonschema_specifications/schemas/draft202012/vocabularies/validation,sha256=cBCjHlQfMtK-ch4t40jfdcmzaHaj7TBId_wKvaHTelg,2834 +jsonschema_specifications/schemas/draft3/metaschema.json,sha256=LPdfZENvtb43Si6qJ6uLfh_WUcm0ba6mxnsC_WTiRYs,2600 +jsonschema_specifications/schemas/draft4/metaschema.json,sha256=4UidC0dV8CeTMCWR0_y48Htok6gqlPJIlfjk7fEbguI,4357 +jsonschema_specifications/schemas/draft6/metaschema.json,sha256=wp386fVINcOgbAOzxdXsDtp3cGVo-cTffPvHVmpRAG0,4437 +jsonschema_specifications/schemas/draft7/metaschema.json,sha256=PVOSCIJhYGxVm2A_OFMpyfGrRbXWZ-uZBodFOwVdQF4,4819 
+jsonschema_specifications/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +jsonschema_specifications/tests/__pycache__/__init__.cpython-311.pyc,, +jsonschema_specifications/tests/__pycache__/test_jsonschema_specifications.cpython-311.pyc,, +jsonschema_specifications/tests/test_jsonschema_specifications.py,sha256=WkbYRW6A6FoZ0rivShfqVLSCsAiHJ2x8TxqECJTXPTY,1106 diff --git a/.venv/lib/python3.11/site-packages/jsonschema_specifications-2024.10.1.dist-info/WHEEL b/.venv/lib/python3.11/site-packages/jsonschema_specifications-2024.10.1.dist-info/WHEEL new file mode 100644 index 0000000000000000000000000000000000000000..cdd68a497cdfa8d3f2b837225beacef711b85047 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/jsonschema_specifications-2024.10.1.dist-info/WHEEL @@ -0,0 +1,4 @@ +Wheel-Version: 1.0 +Generator: hatchling 1.25.0 +Root-Is-Purelib: true +Tag: py3-none-any diff --git a/.venv/lib/python3.11/site-packages/jsonschema_specifications-2024.10.1.dist-info/licenses/COPYING b/.venv/lib/python3.11/site-packages/jsonschema_specifications-2024.10.1.dist-info/licenses/COPYING new file mode 100644 index 0000000000000000000000000000000000000000..a9f853e43069b8e3f8a156a4af2b1198a004230d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/jsonschema_specifications-2024.10.1.dist-info/licenses/COPYING @@ -0,0 +1,19 @@ +Copyright (c) 2022 Julian Berman + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/.venv/lib/python3.11/site-packages/xformers/__init__.py b/.venv/lib/python3.11/site-packages/xformers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bfcae56b950074f3cb74b67cad270b1b49730f26 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/__init__.py @@ -0,0 +1,73 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os + +import torch + +from . 
import _cpp_lib +from .checkpoint import ( # noqa: E402, F401 + checkpoint, + get_optimal_checkpoint_policy, + list_operators, + selective_checkpoint_wrapper, +) + +try: + from .version import __version__ # noqa: F401 +except ImportError: + __version__ = "0.0.0" + + +logger = logging.getLogger("xformers") + +_has_cpp_library: bool = _cpp_lib._cpp_library_load_exception is None + +_is_opensource: bool = True + + +def compute_once(func): + value = None + + def func_wrapper(): + nonlocal value + if value is None: + value = func() + return value + + return func_wrapper + + +@compute_once +def _is_triton_available(): + if os.environ.get("XFORMERS_ENABLE_TRITON", "0") == "1": + return True + if not torch.cuda.is_available(): + return False + if os.environ.get("XFORMERS_FORCE_DISABLE_TRITON", "0") == "1": + return False + # We have many errors on V100 with recent triton versions + # Let's just drop support for triton kernels below A100 + if torch.cuda.get_device_capability("cuda") < (8, 0): + return False + try: + import triton # noqa + + return True + except (ImportError, AttributeError): + logger.warning( + "A matching Triton is not available, some optimizations will not be enabled", + exc_info=True, + ) + return False + + +@compute_once +def get_python_lib(): + return torch.library.Library("xformers_python", "DEF") + + +# end of file diff --git a/.venv/lib/python3.11/site-packages/xformers/_cpp_lib.py b/.venv/lib/python3.11/site-packages/xformers/_cpp_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..b9686b4adb76046f5bac41694173e406836f1ee1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_cpp_lib.py @@ -0,0 +1,155 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +import dataclasses +import json +import logging +import os +import platform +from typing import Any, Dict, Optional + +import torch + +logger = logging.getLogger("xformers") + +UNAVAILABLE_FEATURES_MSG = ( + " Memory-efficient attention, SwiGLU, sparse and more won't be available." +) + + +@dataclasses.dataclass +class _BuildInfo: + metadata: Dict[str, Any] + + @property + def cuda_version(self) -> Optional[int]: + return self.metadata["version"]["cuda"] + + @property + def hip_version(self) -> Optional[int]: + return self.metadata["version"]["hip"] + + @property + def torch_version(self) -> str: + return self.metadata["version"]["torch"] + + @property + def python_version(self) -> str: + return self.metadata["version"]["python"] + + @property + def flash_version(self) -> str: + return self.metadata["version"].get("flash", "0.0.0") + + @property + def use_torch_flash(self) -> bool: + return self.metadata["version"].get("use_torch_flash", False) + + @property + def build_env(self) -> Dict[str, Any]: + return self.metadata["env"] + + +class xFormersWasNotBuiltException(Exception): + def __str__(self) -> str: + return ( + "Need to compile C++ extensions to use all xFormers features.\n" + " Please install xformers properly " + "(see https://github.com/facebookresearch/xformers#installing-xformers)\n" + + UNAVAILABLE_FEATURES_MSG + ) + + +class xFormersInvalidLibException(Exception): + def __init__(self, build_info: Optional[_BuildInfo]) -> None: + self.build_info = build_info + + def __str__(self) -> str: + if self.build_info is None: + msg = "xFormers was built for a different version of PyTorch or Python." + else: + msg = f"""xFormers was built for: + PyTorch {self.build_info.torch_version} with CUDA {self.build_info.cuda_version} (you have {torch.__version__}) + Python {self.build_info.python_version} (you have {platform.python_version()})""" + return ( + "xFormers can't load C++/CUDA extensions. 
" + + msg + + "\n Please reinstall xformers " + "(see https://github.com/facebookresearch/xformers#installing-xformers)\n" + + UNAVAILABLE_FEATURES_MSG + ) + + +def _register_extensions(): + import importlib + import os + + import torch + + # load the custom_op_library and register the custom ops + lib_dir = os.path.dirname(__file__) + if os.name == "nt": + # Register the main torchvision library location on the default DLL path + import ctypes + import sys + + kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True) + with_load_library_flags = hasattr(kernel32, "AddDllDirectory") + prev_error_mode = kernel32.SetErrorMode(0x0001) + + if with_load_library_flags: + kernel32.AddDllDirectory.restype = ctypes.c_void_p + + if sys.version_info >= (3, 8): + os.add_dll_directory(lib_dir) + elif with_load_library_flags: + res = kernel32.AddDllDirectory(lib_dir) + if res is None: + err = ctypes.WinError(ctypes.get_last_error()) + err.strerror += f' Error adding "{lib_dir}" to the DLL directories.' 
+ raise err + + kernel32.SetErrorMode(prev_error_mode) + + loader_details = ( + importlib.machinery.ExtensionFileLoader, + importlib.machinery.EXTENSION_SUFFIXES, + ) + + extfinder = importlib.machinery.FileFinder(lib_dir, loader_details) + if torch.version.hip and not hasattr(torch.version, "git_version"): + ext_specs = extfinder.find_spec("_C_hip") + else: + ext_specs = extfinder.find_spec("_C") + if ext_specs is None: + raise xFormersWasNotBuiltException() + cpp_lib_json = os.path.join(lib_dir, "cpp_lib.json") + with open(cpp_lib_json, "r") as fp: + build_metadata = _BuildInfo(json.load(fp)) + try: + torch.ops.load_library(ext_specs.origin) + except OSError as exc: + raise xFormersInvalidLibException(build_metadata) from exc + return build_metadata + + +_cpp_library_load_exception = None +_build_metadata: Optional[_BuildInfo] = None + +try: + _build_metadata = _register_extensions() +except (xFormersInvalidLibException, xFormersWasNotBuiltException) as e: + ENV_VAR_FOR_DETAILS = "XFORMERS_MORE_DETAILS" + if os.environ.get(ENV_VAR_FOR_DETAILS, False): + logger.warning(f"WARNING[XFORMERS]: {e}", exc_info=e) + else: + logger.warning( + f"WARNING[XFORMERS]: {e}\n Set {ENV_VAR_FOR_DETAILS}=1 for more details" + ) + _cpp_library_load_exception = e + +_built_with_cuda = ( + _build_metadata is not None and _build_metadata.cuda_version is not None +) diff --git a/.venv/lib/python3.11/site-packages/xformers/_deprecation_warning.py b/.venv/lib/python3.11/site-packages/xformers/_deprecation_warning.py new file mode 100644 index 0000000000000000000000000000000000000000..505ef15e65b7638dc6332f4204793c708e7c5ae3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_deprecation_warning.py @@ -0,0 +1,12 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +import warnings + + +def deprecated_function(self): + name = repr(self) # self.__name__ + msg = f"{name} is deprecated and is not maintained anymore. It might be removed in a future version of xFormers" + warnings.warn(msg, FutureWarning, stacklevel=2) diff --git a/.venv/lib/python3.11/site-packages/xformers/attn_bias_utils.py b/.venv/lib/python3.11/site-packages/xformers/attn_bias_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..02114c39a42f1a1d9fd78b839073fc93c5d6f52d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/attn_bias_utils.py @@ -0,0 +1,501 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import math +import random +from typing import List, Optional, Sequence, Tuple, Type + +import torch + +from xformers.ops import AttentionBias, fmha +from xformers.ops.fmha.attn_bias import AttentionBiasSubTensor +from xformers.ops.fmha.common import AttentionOpBase + + +def _create_aligned_bias(*shape: int, **kwargs) -> torch.Tensor: + align_to = 8 + return ( + torch.randn( + ( + *shape[:-1], + align_to * ((shape[-1] + align_to - 1) // align_to), + ), + **kwargs, + ) + * 3 + ).narrow(-1, 0, shape[-1]) + + +def create_attn_bias( + bias_type, + batch_size: int, + num_heads: int, + num_heads_groups: int, + q_len: int, + kv_len: int, + device, + dtype, + requires_grad: bool, + fmt: str, + op: Optional[Type[AttentionOpBase]] = None, + page_size: Optional[int] = None, +): + if bias_type is None or isinstance(None, bias_type): + return None + r = random.Random("-".join(map(str, [batch_size, q_len, kv_len, dtype, fmt]))) + window_size = {0: 3, 1: 128, 2: 300}[r.randint(0, 2)] + if bias_type is torch.Tensor: + if fmt == "BMK": + batch_size *= num_heads + num_heads = 1 + if op is not None and issubclass(op, fmha.triton_splitk.FwOp): + attn_bias = ( + torch.randn( + 
(batch_size, num_heads_groups, num_heads, q_len, kv_len), + device=device, + dtype=dtype, + ) + * 3 + ) + if fmt in ["BMK", "BMHK"]: + attn_bias = attn_bias[:, 0] + else: + attn_bias = _create_aligned_bias( + batch_size, + num_heads_groups, + num_heads, + q_len, + kv_len, + device=device, + dtype=dtype, + ) + + # make sure it also works if the first columns/rows are partially masked out + attn_bias[0, 0, 0, : q_len - 1, : kv_len - 1] = -math.inf + if fmt in ["BMK", "BMHK"]: + attn_bias = attn_bias[:, 0] + + if requires_grad: + attn_bias.requires_grad_(True) + if fmt == "BMK": + attn_bias = attn_bias[:, 0] + return attn_bias + if bias_type is fmha.attn_bias.LowerTriangularMask: + return bias_type() + if bias_type is fmha.attn_bias.LowerTriangularFromBottomRightMask: + return bias_type() + if bias_type is fmha.attn_bias.LowerTriangularFromBottomRightLocalAttentionMask: + return bias_type(window_size) + if bias_type is fmha.attn_bias.LowerTriangularMaskWithTensorBias: + attn_bias = _create_aligned_bias( + batch_size, + num_heads_groups, + num_heads, + q_len, + kv_len, + device=device, + dtype=dtype, + ) + if fmt in ["BMK", "BMHK"]: + attn_bias = attn_bias[:, 0] + if fmt == "BMK": + attn_bias = attn_bias[:, 0] + if requires_grad: + attn_bias.requires_grad_(True) + return fmha.attn_bias.LowerTriangularMaskWithTensorBias(attn_bias) + if bias_type in [ + fmha.attn_bias.BlockDiagonalMask, + fmha.attn_bias.BlockDiagonalCausalMask, + fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask, + fmha.attn_bias.BlockDiagonalCausalLocalAttentionMask, + fmha.attn_bias.BlockDiagonalCausalLocalAttentionFromBottomRightMask, + ]: + # These bias types are not supported in BMK format + assert fmt in ["BMGHK", "BMHK"] + max_q_minus_k = None + if bias_type in { + fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask, + fmha.attn_bias.BlockDiagonalCausalLocalAttentionFromBottomRightMask, + }: + max_q_minus_k = 0 + elif bias_type == fmha.attn_bias.BlockDiagonalCausalLocalAttentionMask: + 
assert window_size is not None + max_q_minus_k = window_size - 1 + + block_diag = fmha.attn_bias.BlockDiagonalMask.from_seqlens( + *_rand_seqlens( + r, + batch_size, + q_len, + kv_len, + max_q_minus_k=max_q_minus_k, + ) + ) + if bias_type is fmha.attn_bias.BlockDiagonalCausalMask: + block_diag = block_diag.make_causal() + if bias_type in { + fmha.attn_bias.BlockDiagonalCausalLocalAttentionMask, + fmha.attn_bias.BlockDiagonalCausalLocalAttentionFromBottomRightMask, + }: + block_diag = fmha.attn_bias.BlockDiagonalMask( + q_seqinfo=block_diag.q_seqinfo, + k_seqinfo=block_diag.k_seqinfo, + _batch_sizes=block_diag._batch_sizes, + ) + assert window_size is not None + if bias_type is fmha.attn_bias.BlockDiagonalCausalLocalAttentionMask: + block_diag = block_diag.make_local_attention(window_size) + else: + block_diag = block_diag.make_local_attention_from_bottomright( + window_size + ) + if bias_type is fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask: + block_diag = block_diag.make_causal_from_bottomright() + return block_diag + if bias_type in [ + fmha.attn_bias.BlockDiagonalPaddedKeysMask, + fmha.attn_bias.BlockDiagonalCausalLocalAttentionPaddedKeysMask, + fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask, + fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask, + fmha.attn_bias.PagedBlockDiagonalCausalWithOffsetPaddedKeysMask, + ]: + assert fmt in ["BMHK", "BMGHK"] + q, k = _rand_seqlens_padded_k(r, batch_size, q_len, kv_len) + block_diag_type = ( + bias_type._UNPAGED_TYPE + if issubclass(bias_type, fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask) + else bias_type + ) + if bias_type is fmha.attn_bias.BlockDiagonalCausalLocalAttentionPaddedKeysMask: + g_block_diag = block_diag_type.from_seqlens_local( + q_seqlen=q, + kv_padding=kv_len, + kv_seqlen=k, + window_size=min(window_size, min(k)), + ) + else: + g_block_diag = block_diag_type.from_seqlens( + q_seqlen=q, + kv_padding=kv_len, + kv_seqlen=k, + ) + if issubclass(bias_type, 
fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask): + assert page_size is not None + pages_per_row = (kv_len + page_size - 1) // page_size + block_tables = torch.tensor( + r.sample(range(batch_size * pages_per_row), batch_size * pages_per_row), + device=device, + dtype=torch.int32, + ).reshape(batch_size, pages_per_row) + return g_block_diag.make_paged( + block_tables=block_tables, page_size=page_size, paged_type=bias_type + ) + return g_block_diag + if bias_type in [ + fmha.attn_bias.BlockDiagonalCausalWithOffsetGappyKeysMask, + fmha.attn_bias.BlockDiagonalGappyKeysMask, + ]: + assert fmt in ["BMHK", "BMGHK"] + max_q_minus_k = ( + None if bias_type is fmha.attn_bias.BlockDiagonalGappyKeysMask else 0 + ) + q, k = _rand_seqlens(r, batch_size, q_len, kv_len, max_q_minus_k) + total_kv_len = kv_len * batch_size + starts = [r.randint(0, total_kv_len - ki) for ki in k] + [total_kv_len] + return fmha.attn_bias.BlockDiagonalGappyKeysMask.from_seqlens( + q_seqlen=q, + kv_seqstarts=starts, + kv_seqlen=k, + ) + if bias_type in [ + fmha.attn_bias.PagedBlockDiagonalGappyKeysMask, + ]: + assert fmt in ["BMHK", "BMGHK"] + assert page_size is not None + pages_per_row = (kv_len + page_size - 1) // page_size + total_queries = q_len * batch_size + q = _rand_maxed_partition(r, total_queries, batch_size, total_queries, False) + k = [r.randint(1, kv_len) for _ in range(batch_size)] + row_size = pages_per_row * page_size + starts = [row_size * i + r.randint(0, row_size - ki) for i, ki in enumerate(k)] + starts.append(pages_per_row * batch_size * page_size) + block_diag_type = bias_type._UNPAGED_TYPE # type: ignore + g_block_diag = block_diag_type.from_seqlens( + q_seqlen=q, + kv_seqstarts=starts, + kv_seqlen=k, + ) + block_tables = torch.tensor( + r.sample(range(batch_size * pages_per_row), batch_size * pages_per_row), + device=device, + dtype=torch.int32, + ).reshape(batch_size, pages_per_row) + return g_block_diag.make_paged( + block_tables=block_tables, + page_size=page_size, + 
paged_type=bias_type, + notional_padding=page_size * pages_per_row, + ) + if bias_type == fmha.attn_bias.LocalAttentionFromBottomRightMask: + return bias_type( + window_left=r.randint(0, 5), + window_right=r.randint(0, 5), + ) + + assert False, f"Unsupported bias type: {bias_type}" + + +def _rand_seqlens( + r: random.Random, + bs: int, + q_len: int, + kv_len: int, + max_q_minus_k: Optional[int], +) -> Tuple[Sequence[int], Sequence[int]]: + """ + Generates lists of lengths of query blocks and corresponding key blocks. + The total number of queries will be bs * q_len and the + total number of keys will be bs * kv_len. + max_q_minus_k: maximum allowed num_queries - num_keys. + For "bottom-right" masks it's 0, we need to have more keys than + queries, otherwise some queries have no keys to attend to. + For BlockDiagonalCausalMask it's None, there is no constraint + on num_queries - num_keys. + For BlockDiagonalCausalLocalAttentionMask it's equal + to the window size. + """ + if max_q_minus_k == 0: + # In case max_q_minus_k > 0 the exact condition is + # kv_len >= q_len - max_q_minus_k * batch_size, + # but we can't check it without knowing the actual batch size, + # which is determined in the loop below. + assert kv_len >= q_len + q_len *= bs + kv_len *= bs + seqlens_q: List[int] = [] + seqlens_k: List[int] = [] + + step_q = [max(1, q_len // 10), max(2, q_len // 2)] + step_k = [max(1, kv_len // 10), max(2, kv_len // 2)] + while sum(seqlens_q) < q_len and sum(seqlens_k) < kv_len: + if max_q_minus_k is None: + # Simple case - no constraint on the number of queries and keys. + num_queries = r.randrange(*step_q) + seqlens_q.append(num_queries) + seqlens_k.append(r.randrange(*step_k)) + else: + # In this case we need to make sure num_queries - num_keys < max_q_minus_k holds for every batch element. 
+ # To do this, when choosing num_queries and num_keys at a given step, + # we ensure two conditions are satisfied: + # 1) num_queries <= num_keys + max_q_minus_k for the current batch element + # 2) Same holds for the remaining keys and queries, i.e. + # queries_left - num_queries <= keys_left - num_keys + max_q_minus_k + keys_left = kv_len - sum(seqlens_k, 0) + queries_left = q_len - sum(seqlens_q, 0) + + assert ( + keys_left >= queries_left - max_q_minus_k + ), f"{keys_left=} {queries_left=} {max_q_minus_k=} {kv_len=} {q_len=} {seqlens_k=} {seqlens_q=}" + # Limit num_queries from above: if num_queries > keys_left + max_q_minus_k, + # condition num_queries <= num_keys + max_q_minus_k can't be satisfied even if we take + # all the remaining keys + max_queries_to_take = min(queries_left, keys_left + max_q_minus_k) + num_queries = r.randrange(1, max_queries_to_take + 1) + seqlens_q.append(num_queries) + + # Now we know num_queries, let's select num_keys. + # How many keys can we use for the current batch element so that + # for the remaining keys and values the constraint + # num_queries - num_keys < max_q_minus_k holds on the next step? + extra_keys_available = keys_left - queries_left + max_q_minus_k + 1 + assert extra_keys_available >= 0 + if extra_keys_available > 0: + seqlens_k.append(num_queries + r.randrange(0, extra_keys_available)) + else: + seqlens_k.append(num_queries) + seqlens_q[-1] = q_len - sum(seqlens_q[:-1]) + seqlens_k[-1] = kv_len - sum(seqlens_k[:-1]) + return seqlens_q, seqlens_k + + +def _rand_maxed_partition( + r: random.Random, total: int, n: int, mx: int, positive: bool = True +) -> List[int]: + # returns list of n nonnegative integers less than mx summing to total + # NB: This is unfortunately biased towards evenly-split bins. 
+ # If `positive`, outputs are positive + if positive: + total -= n + mx -= 1 + idxs = r.sample(range(n * mx), total) + y = torch.zeros(n, mx, dtype=torch.int32) + y.flatten()[idxs] = 1 + z = y.sum(1) + if positive: + z += 1 + return z.tolist() + + +def _rand_seqlens_padded_k( + r: random.Random, bs: int, q_len: int, kv_len: int +) -> Tuple[Sequence[int], Sequence[int]]: + # This is for BlockDiagonalCausalWithOffsetPaddedKeysMask. + # we need q_seqlens and k_seqlens to be of len bsz. + # For each "batch element" there must be more keys than queries + # because this bias type is "bottom right" and so any extra queries + # will attend to nothing and have undefined result. + # In addition every element of k_seqlens must be <= kv_len + if q_len > kv_len: + raise ValueError("need more keys than values") + if q_len == kv_len: + # all key slots are needed so we cannot have padding + q_seqlens = k_seqlens = [kv_len] * bs + else: + q_seqlens = _rand_maxed_partition(r, q_len * bs, bs, kv_len) + k_seqlens = [r.randint(i, kv_len) for i in q_seqlens] + return q_seqlens, k_seqlens + + +def ref_attention(q, k, v, attn_bias=None, drop_mask=None, p=0.0, scale=None): + if q.ndim == 5: + + def attn_bias_group(group: int): + if isinstance(attn_bias, fmha.attn_bias.AttentionBiasSubTensor): + if attn_bias.HOLDS_DENSE_TENSOR: + return attn_bias[:, group] + elif isinstance(attn_bias, torch.Tensor): + return attn_bias[:, group] + return attn_bias + + return torch.stack( + [ + ref_attention_bmhk( + q[:, :, g], + k[:, :, g], + v[:, :, g], + scale=scale, + attn_bias=attn_bias_group(g), + ) + for g in range(q.shape[2]) + ], + dim=2, + ) + if q.ndim == 4: + assert p == 0.0 + return ref_attention_bmhk(q, k, v, scale=scale, attn_bias=attn_bias) + q = q.float() + k = k.float() + v = v.float() + + scale = scale if scale is not None else (1 / q.shape[-1] ** 0.5) + q = q * scale + + attn = q @ k.transpose(-2, -1) + if attn_bias is not None: + if isinstance(attn_bias, (AttentionBias, 
AttentionBiasSubTensor)): + # Always create in B,H,Mq,Mk format + attn_bias_tensor = attn_bias.materialize( + (q.shape[0], 1, q.shape[1], k.shape[1]), + device=q.device, + dtype=torch.float32, + ) + else: + attn_bias_tensor = attn_bias + if attn_bias_tensor.ndim == 4: + assert q.shape[0] == attn_bias_tensor.shape[0] * attn_bias_tensor.shape[1] + attn_bias_tensor = attn_bias_tensor.reshape( + [-1, *attn_bias_tensor.shape[2:]] + ) + attn = attn + attn_bias_tensor.float() + attn = attn.softmax(-1) + if drop_mask is not None: + attn = attn * (drop_mask / (1 - p)) + return attn @ v + + +def ref_attention_bmhk(q, k, v, attn_bias, scale=None) -> torch.Tensor: + assert q.ndim == 4 + + def T(t): + return t.permute((0, 2, 1, 3)).reshape( + [t.shape[0] * t.shape[2], t.shape[1], t.shape[3]] + ) + + if isinstance(attn_bias, (AttentionBias, AttentionBiasSubTensor)): + attn_bias = attn_bias.materialize( + (q.shape[0], q.shape[2], q.shape[1], k.shape[1]), + device=q.device, + dtype=torch.float32, + ).reshape([q.shape[0] * q.shape[2], q.shape[1], k.shape[1]]) + out = ref_attention(T(q), T(k), T(v), attn_bias, scale=scale) + out = out.reshape([q.shape[0], q.shape[2], q.shape[1], v.shape[3]]) + return out.permute((0, 2, 1, 3)) + + +def pack_kv_cache( + cache_k: torch.Tensor, + cache_v: torch.Tensor, + kv_seqlens: List[int], + BLOCK_N: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Create block tables and pages K/V cache for testing paged attention. + Args: + cache_k, cache_v: K/V caches, each of shape [B, MAX_T, H_kv, D]. + Note that these tensors are unexpanded, + i.e. for multiquery case cache_k.shape[2] = 1 + kv_seqlens: list of K/V sequence lengths + BLOCK_N: number of tokens per per paged attention block + B: batch size + Returns: + block_tables: [B, MAX_BLOCKS] + packed_cache_k: [1, total_len_rounded, H_kv, D] + packed_cache_v: [1, total_len_rounded, H_kv, D] + where total_len_rounded is a sum of K/V seqlens, each rounded up + to a multiple of BLOCK_N. 
+ """ + + kv_seqlens_rounded = [(x + BLOCK_N - 1) // BLOCK_N * BLOCK_N for x in kv_seqlens] + + total_len_rounded = sum(kv_seqlens_rounded) + + B, MAX_T, H, D = cache_k.shape + + packed_cache_k = torch.empty( + total_len_rounded, H, D, device=cache_k.device, dtype=cache_k.dtype + ) + packed_cache_v = torch.empty( + total_len_rounded, H, D, device=cache_k.device, dtype=cache_k.dtype + ) + seqstart = 0 + for b in range(B): + packed_cache_k[seqstart : seqstart + kv_seqlens[b]] = cache_k[ + b, : kv_seqlens[b] + ].clone() + packed_cache_v[seqstart : seqstart + kv_seqlens[b]] = cache_v[ + b, : kv_seqlens[b] + ].clone() + seqstart += kv_seqlens_rounded[b] + + num_blocks_per_row = (MAX_T + BLOCK_N - 1) // BLOCK_N + block_tables = ( + torch.arange(num_blocks_per_row, device="cuda", dtype=torch.int32) + .unsqueeze(0) + .expand(B, num_blocks_per_row) + ) + seqstarts = ( + ( + torch.tensor(kv_seqlens_rounded).cumsum(dim=0) + - torch.tensor(kv_seqlens_rounded) + ) + .to(device="cuda") + .unsqueeze(1) + ) // BLOCK_N + block_tables = (block_tables + seqstarts).contiguous().to(dtype=torch.int32) + return ( + block_tables, + packed_cache_k.unsqueeze(0), + packed_cache_v.unsqueeze(0), + ) diff --git a/.venv/lib/python3.11/site-packages/xformers/checkpoint.py b/.venv/lib/python3.11/site-packages/xformers/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..c2f323ffa11274b09211b8051f60fda584d278a4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/checkpoint.py @@ -0,0 +1,546 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import functools +import time +from collections import defaultdict +from copy import deepcopy +from dataclasses import astuple, dataclass +from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple + +import torch +from torch.testing._internal.composite_compliance import ( + is_inplace, + is_inplace_view_fn, + is_view_fn, +) +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_map + +_scipy_is_available = False +try: + from scipy.optimize import Bounds, LinearConstraint, milp + + _scipy_is_available = True +except ImportError: + _scipy_is_available = False + + +try: + # let's keep BC for older PyTorch for now + from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + ActivationWrapper, + ) + from torch.utils.checkpoint import ( # type: ignore + _CachedTorchDispatchMode, + _CachingTorchDispatchMode, + ) +except ImportError: + ActivationWrapper = torch.nn.Module # type: ignore + + class _NotAvailable: + def __init__(self, *args, **kwargs): + raise RuntimeError("Need PyTorch >= 2.2") + + _CachedTorchDispatchMode = _NotAvailable # type: ignore + _CachingTorchDispatchMode = _NotAvailable # type: ignore + + +try: + from torch.utils.checkpoint import SAC_IGNORED_OPS as _ignored_ops # type: ignore + + _PT_HAS_NEW_IMPL = True +except ImportError: + from torch.utils.checkpoint import _ignored_ops # type: ignore + + _PT_HAS_NEW_IMPL = False + + +_additional_ignored_ops = { + torch.ops.aten.lift_fresh.default, + torch.ops.profiler._record_function_exit._RecordFunction, + torch.ops.aten.clone.default, # seems needed for torch.compile +} +OPS_TO_ALWAYS_SKIP = _ignored_ops | _additional_ignored_ops + + +@dataclass +class ProfileMetadata: + name: str + time_taken: float + memory_used: float + curr_idx: int + output_ids: Any + inplace_info: Tuple[int, int] + is_view_like: bool + is_rand_op: bool + + +def _get_default_policy(allow_list=None): + _default_allow_list = [ + 
"xformers.efficient_attention_forward_cutlass.default", + "xformers_flash.flash_fwd.default", + "aten.addmm.default", + "aten.mm.default", + ] + if allow_list is None: + allow_list = _default_allow_list + + def _default_policy(ctx, func, *args, **kwargs): + return str(func) in allow_list + + return _default_policy + + +class VerboseTorchDispatchMode(TorchDispatchMode): + def __init__(self): + self.operators = [] + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + self.operators.append(func) + return func(*args, **kwargs) + + +def list_operators(function, *args, **kwargs): + """ + Returns the list of operators used inside `function` with + *args and **kwargs + """ + verbose_mode = VerboseTorchDispatchMode() + with verbose_mode: + function(*args, **kwargs) + return verbose_mode.operators + + +class CachedTorchDispatchMode(_CachedTorchDispatchMode): + def __init__(self, policy_fn, storage, allow_cache_entry_mutation): + global _PT_HAS_NEW_IMPL + if _PT_HAS_NEW_IMPL: + super().__init__(policy_fn, storage, allow_cache_entry_mutation) + else: + super().__init__(policy_fn, storage) + + # this is here for the old implementations + def pop_from_storage(self, func, args, kwargs): + # the autograd engine might add spurious views. This is a basic + # guard and should be improved + if self.storage[func]: + return self.storage[func].pop(0) + return func(*args, **kwargs) + + +class NullTorchDispatchMode(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + return func(*args, **kwargs) + + +def selective_checkpoint_context_fn(policy_fn=None): + """An activation checkpoint context_fn for selectively deciding what to + store and what to recompute. Accepts a custom policy. + Args: + policy_fn(Union[List[Op], callable]): policy for deciding what to + store (instead of recompute). 
If it's a function, it should + be of form (func, *args, **kwargs) -> bool which indicates + if func outputs with *args and **kwargs should be stored or not. + Additionally, a list[Op] is also supported for easier cases. + The op should be in the format `torch.ops.***`, where the `***` + names of operators can be obtained with `list_operators`. + """ + if policy_fn is None: + policy_fn = _get_default_policy() + elif isinstance(policy_fn, list): + policy_fn = _get_default_policy(policy_fn) + else: + assert callable(policy_fn), "policy_fn should be None, list or a callable" + + temp_storage: Dict[Any, List[Any]] = defaultdict(list) + # assumption: grad_mode doesn't change inside function + caching_mode: ContextManager[None] + if torch.is_grad_enabled(): + caching_mode = _CachingTorchDispatchMode(deepcopy(policy_fn), temp_storage) + else: + caching_mode = NullTorchDispatchMode() + cached_mode = CachedTorchDispatchMode(deepcopy(policy_fn), temp_storage, True) + + return caching_mode, cached_mode + + +def checkpoint( + function, *args, preserve_rng_state=True, policy_fn=None, **kwargs +) -> Any: + """Wrapper around torch.utils.checkpoint that accepts a custom policy + function for selectively deciding what to store and what to recompute + Args: + function: describes what to run in the forward pass of the model or + part of the model. It should also know how to handle the inputs + passed as the tuple. For example, in LSTM, if user passes + ``(activation, hidden)``, :attr:`function` should correctly use the + first input as ``activation`` and the second input as ``hidden`` + preserve_rng_state(bool, optional): Omit stashing and restoring + the RNG state during each checkpoint. + Default: ``True`` + policy_fn(Union[List[Op], callable]): policy for deciding what to + store (instead of recompute). If it's a function, it should + be of form (func, *args, **kwargs) -> bool which indicates + if func outputs with *args and **kwargs should be stored or not. 
+ Additionally, a list[Op] is also supported for easier cases. + The op should be in the format `torch.ops.***`, where the `***` + names of operators can be obtained with `list_operators`. + *args: Arguments to pass in to the given ``function``. + **kwargs: Keyword arguments to pass into the given ``function``. + """ + return torch.utils.checkpoint.checkpoint( + function, + *args, + use_reentrant=False, + preserve_rng_state=preserve_rng_state, + context_fn=functools.partial(selective_checkpoint_context_fn, policy_fn), + **kwargs, + ) + + +class ProfileOperatorsTorchDispatchMode(TorchDispatchMode): + def __init__(self, num_runs: int = 10) -> None: + self.data: List[ProfileMetadata] = [] + self.num_runs: int = num_runs + + def _get_inplace_metadata(self, func, out) -> Tuple[int, int, Tuple[int, ...]]: + curr_idx = len(self.data) + + def get_tensor_id(e): + return ( + e.untyped_storage().data_ptr() if isinstance(e, torch.Tensor) else None + ) + + output_ids = tree_map(get_tensor_id, out) + if not is_inplace(func): + return curr_idx, output_ids, () + + op_id = curr_idx + op_parent_id = -1 + for i, d in enumerate(self.data): + # find the first occurence of a tensor that + # shares the same storage as the current tensor + past_output_ids = d.output_ids + past_output_ids = ( + [past_output_ids] + if not isinstance(past_output_ids, (list, tuple, dict)) + else past_output_ids + ) + if output_ids in past_output_ids: + op_parent_id = i + break + if op_parent_id < 0: + op_parent_id = op_id + inplace_info = (op_id, op_parent_id) + return curr_idx, output_ids, inplace_info + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + out = func(*args, **kwargs) + + curr_idx, output_ids, inplace_info = self._get_inplace_metadata(func, out) + is_view_like = is_view_fn(func) or is_inplace_view_fn(func) + is_rand_op = torch.Tag.nondeterministic_seeded in func.tags + # sdpa has non-deterministic seed, but might be deterministic + # if no 
dropout is applied + if func.overloadpacket.__name__ == "_scaled_dot_product_flash_attention": + is_rand_op = kwargs.get("dropout_p", 0) != 0 + + # get runtime info of func + torch.cuda.synchronize() + t = time.time() + for i in range(self.num_runs): + func(*args, **kwargs) + torch.cuda.synchronize() + time_taken = (time.time() - t) / self.num_runs + + # get memory usage of func + torch.cuda.reset_peak_memory_stats() + mem1 = torch.cuda.max_memory_allocated() / 2**20 + func(*args, **kwargs) + mem2 = torch.cuda.max_memory_allocated() / 2**20 + + self.data.append( + ProfileMetadata( + func, + time_taken, + mem2 - mem1, + curr_idx, + output_ids, + inplace_info, + is_view_like, + is_rand_op, + ) + ) + return out + + +def _analyze_operators(function, *args) -> List[ProfileMetadata]: + """ + Use ProfileOperatorsTorchDispatchMode to get runtime and memory info. + + Args: + function: The function to optimize which will be selectively checkpointed. Usually the forward pass + of the model. + *args: Arguments to pass in to the given ``function``. + + Returns: + A list of tuples, where each tuples contains the name of the operator, the runtime of the operator, + and the memory usage of the operator. + + """ + profile_ops = ProfileOperatorsTorchDispatchMode() + with profile_ops: + function(*args) + + data = profile_ops.data + return data + + +def get_optimal_checkpoint_policy(function, *args, memory_budget: float) -> Callable: + """ + Given a function, its arguments, and the maximum amount of memory available, + find the subset of operators that can be optimized to reduce runtime while still fitting within the memory budget. + + Args: + function: The function to optimize which will be selectively checkpointed. Usually the forward pass + of the model. + *args: Arguments to pass in to the given ``function``. + memory_budget (float): A float between 0 and 1 which describes what percentage of the total memory to use. 
+ + Returns: + A callable policy which can be passed to xformers.checkpoint() + + Raises: + RuntimeError: If `scipy` is not available. + ValueError: If `memory_budget` is not a float between 0 and 1. + + """ + if not _scipy_is_available: + raise RuntimeError( + "Please install scipy 1.9.0+ to use `get_optimal_checkpoint_policy`. You can do so using " + "`pip install scipy`." + ) + if memory_budget < 0 or memory_budget > 1: + raise ValueError( + f"`memory_budget` must be a float between 0 and 1. Got {memory_budget}." + ) + + data = _analyze_operators(function, *args) + # remove aten.detach.default from the list of ops because autograd + # inserts those during backward and it breaks the fwd-bwd alignment + data = [x for x in data if x.name not in OPS_TO_ALWAYS_SKIP] + + ops, runtimes_, memory_, new_ids, _, inplace_ops_, view_like_ops_, rand_ops_ = zip( + *[astuple(x) for x in data] + ) + runtimes = torch.tensor(runtimes_, dtype=torch.float64) + memory = torch.tensor(memory_, dtype=torch.float64) + view_like_ops = [i for i, x in enumerate(view_like_ops_) if x] + rand_ops = [i for i, x in enumerate(rand_ops_) if x] + + # remap the inplace indices as we have removed OPS_TO_ALWAYS_SKIP + inplace_ops = [tuple(map(new_ids.index, x)) for x in inplace_ops_ if x] + + # the last operation is always stored as the output of the checkpoint + # block, so we can avoid recomputing it. 
We set the memory to zero + # instead of adding a new constraint because we want both the 0 and 1 + # endpoints for memory_budget to be valid + # FIXME: this heuristic for finding the last non-view non-inplace op + # might not always be correct, which would yield suboptimal policies + last_op = len(ops) - 1 + skip_ops_ = set(view_like_ops) | set([x[0] for x in inplace_ops]) + skip_ops = sorted(list(skip_ops_)) + for op in reversed(skip_ops): + if op == last_op: + last_op -= 1 + + memory[last_op] = 0 + + max_memory = memory_budget * memory.sum().item() + + # workaround to fix https://github.com/pytorch/pytorch/issues/121212 + force_store_random = all([not isinstance(x, torch.Tensor) for x in args]) + + optim_output = _optimize_runtime_with_given_memory( + memory=memory, + runtimes=runtimes, + max_memory=max_memory, + view_like_ops=view_like_ops, + inplace_ops=inplace_ops, + random_ops=rand_ops, + force_store_random=force_store_random, + ) + return _OptimalPolicy(optim_output=optim_output) + + +def _optimize_runtime_with_given_memory( + memory: torch.Tensor, + runtimes: torch.Tensor, + max_memory: float, + view_like_ops: List[int], + inplace_ops: List[Tuple[int, ...]], + random_ops: List[int], + force_store_random: bool, +) -> torch.Tensor: + """ + Given a list of operator names, their corresponding runtimes, and the maximum amount of memory available, + find the subset of operators that can be optimized to reduce runtime while still fitting within the memory budget. + Uses https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.milp.html + + Args: + memory (torch.Tensor): Tensor containing the memory usage of each operator. + runtimes (torch.Tensor): Tensor containing the runtime of each operator. + max_memory (float): Maximum amount of memory to use. + view_like_ops ([List[int]): Indices of the view-like ops. + inplace_ops (List[Tuple[int, int]]): Tuple with the pair of inplace op -> parent of inplace op. 
+ This will be used to add the constraint that in-place ops need to either be + stored in memory with the previous op, or recomputed with the previous op. + random_ops ([List[int]): Indices of the random ops, which will always be recomputed. + force_store_random (bool): force random ops to always be stored (instead of recomputed) + """ + c = -runtimes # type: ignore[operator] + + memory_constraint = LinearConstraint(A=memory, ub=max_memory) + constraints = [memory_constraint] + + # view-like ops should always be recomputed + for i in view_like_ops: + A = torch.zeros_like(c) + A[i] = 1 + constraints.append(LinearConstraint(A=A, lb=0, ub=0)) + + # inplace ops should always be done in conjunction with its parent op + # i.e., if we recompute the parent op the inplace should also be + # recomputed, and vice versa + for op, op_parent in inplace_ops: + A = torch.zeros_like(c) + if op != op_parent: + A[op_parent] = 1 + A[op] = -1 + constraints.append(LinearConstraint(A=A, lb=0, ub=0)) + else: + # if op == op_parent, it's because it's the first op + # that is inplace. Thus never recompute it + A[op] = 1 + constraints.append(LinearConstraint(A=A, lb=1, ub=1)) + + # ideally, always recompute random ops + # in practice, due to a bug in https://github.com/pytorch/pytorch/issues/121212 + # sometimes we need to store them to avoid correctness issues + for i in random_ops: + A = torch.zeros_like(c) + A[i] = 1 + val = int(force_store_random) + constraints.append(LinearConstraint(A=A, lb=val, ub=val)) + + integrality = torch.ones_like(c) + res = milp( + c=c, constraints=constraints, integrality=integrality, bounds=Bounds(0, 1) + ) + if not res.success: + raise ValueError( + "The problem is infeasible, and probably due to a change in xformers " + "that makes random ops always be stored. Try passing a larger memory_budget. 
" + "This will be fixed once https://github.com/pytorch/pytorch/issues/121212 " + "is solved" + ) + x = torch.from_numpy(res.x) + return x + + +class _OptimalPolicy: + def __init__(self, optim_output: torch.Tensor): + self.counter = 0 + self.optim_output = optim_output.tolist() + + def __call__(self, ctx, func, *args, **kwargs) -> bool: + # returning False means recompute, True means store in memory + if func in OPS_TO_ALWAYS_SKIP: + return False + count = self.counter + self.counter += 1 + return self.optim_output[count] == 1 + + +class SelectiveCheckpointWrapper(ActivationWrapper): + def __init__(self, mod, memory_budget=None, policy_fn=None): + super().__init__(mod) + if not ((memory_budget is None) ^ (policy_fn is None)): + raise ValueError("Need to specify either policy_fn or memory_budget") + self.memory_budget = memory_budget + self.policy_fn = policy_fn + + try: + # for backward-compatibility as this doesn't exist in PT anymore + torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = ( + True + ) + except AttributeError: + pass + + @torch.compiler.disable + def _get_policy_fn(self, *args, **kwargs): + if not torch.is_grad_enabled(): + # no need to compute a policy as it won't be used + return [] + # if policy is not specified, initialize policy for a given memory budget + with torch.random.fork_rng(): + policy_fn = get_optimal_checkpoint_policy( + self._checkpoint_wrapped_module, + *args, + **kwargs, + memory_budget=self.memory_budget, + ) + if ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + and torch.distributed.get_world_size() > 1 + ): + # use the same policy across different GPUs + objects = [policy_fn] + torch.distributed.broadcast_object_list(objects, src=0) + policy_fn = objects[0] + return policy_fn + + def get_policy_fn(self, *args, **kwargs): + if self.policy_fn is None: + self.policy_fn = self._get_policy_fn(*args, **kwargs) + return self.policy_fn + + def forward(self, *args, **kwargs): 
+ policy_fn = self.get_policy_fn(*args, **kwargs) + return checkpoint( + self._checkpoint_wrapped_module, *args, **kwargs, policy_fn=policy_fn + ) + + +def selective_checkpoint_wrapper( + module: torch.nn.Module, + memory_budget: Optional[float] = None, + policy_fn: Optional[Callable] = None, +): + """ + Wrap a module with selective activation checkpointing. + + It behaves similarly to PyTorch's checkpoint_wrapper, but gives the possibility + to the user to either specify a handcrafted policy_fn, or to let an optimization + algorithm to select the policy given a user-specified memory_budget. + + The user should either specify the memory_budget argument or the policy_fn. + + The memory_budget is a float value between 0 (recompute everything in the backward) or 1 + (store everything for backward). Using a value of 0 should be similar to PyTorch's + activation checkpoint, while 1 should be similar to the behavior of not using any + activation checkpointing. + """ + return SelectiveCheckpointWrapper(module, memory_budget, policy_fn) diff --git a/.venv/lib/python3.11/site-packages/xformers/cpp_lib.json b/.venv/lib/python3.11/site-packages/xformers/cpp_lib.json new file mode 100644 index 0000000000000000000000000000000000000000..4776f168b5b91f16f6778f8b28dc542f6efdfa82 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/cpp_lib.json @@ -0,0 +1 @@ +{"version": {"cuda": 1201, "hip": null, "torch": "2.5.1+cu121", "python": "3.11.10", "flash": "v2.6.3-24-gbdf733b", "use_torch_flash": true}, "env": {"TORCH_CUDA_ARCH_LIST": "6.0+PTX 7.0 7.5 8.0+PTX 9.0a", "PYTORCH_ROCM_ARCH": null, "XFORMERS_BUILD_TYPE": "Release", "XFORMERS_ENABLE_DEBUG_ASSERTIONS": null, "NVCC_FLAGS": "-allow-unsupported-compiler", "XFORMERS_PACKAGE_FROM": "wheel-v0.0.28.post3"}} \ No newline at end of file diff --git a/.venv/lib/python3.11/site-packages/xformers/factory/__init__.py b/.venv/lib/python3.11/site-packages/xformers/factory/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..882ca9e1d4c855f7d80044e20a02727e1c4577b7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/factory/__init__.py @@ -0,0 +1,11 @@ +from xformers.components import MultiHeadDispatchConfig # noqa +from xformers.components.attention import AttentionConfig # noqa +from xformers.components.feedforward import FeedforwardConfig # noqa +from xformers.components.positional_embedding import PositionEmbeddingConfig # noqa + +from .block_factory import xFormerDecoderBlock # noqa +from .block_factory import xFormerDecoderConfig # noqa +from .block_factory import xFormerEncoderBlock # noqa +from .block_factory import xFormerEncoderConfig # noqa +from .model_factory import xFormer, xFormerConfig # noqa +from .weight_init import xFormerWeightInit # noqa diff --git a/.venv/lib/python3.11/site-packages/xformers/factory/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/factory/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97f79d6af2f5756452307cf2c318ab7d00d5f5f7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/factory/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/factory/__pycache__/block_configs.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/factory/__pycache__/block_configs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22577bd0b7414451d337b24073f7cf9a0c9d51c6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/factory/__pycache__/block_configs.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/factory/__pycache__/block_factory.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/factory/__pycache__/block_factory.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..724fbf24d80aa2cf0701d52e6e13d455414687a6 Binary files 
/dev/null and b/.venv/lib/python3.11/site-packages/xformers/factory/__pycache__/block_factory.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/factory/__pycache__/hydra_helper.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/factory/__pycache__/hydra_helper.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3e6bcd00381abb63470395096680522d8ae973e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/factory/__pycache__/hydra_helper.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/factory/__pycache__/model_factory.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/factory/__pycache__/model_factory.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a5030864b3c6982e9879496b98acdd96f4cbf87 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/factory/__pycache__/model_factory.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/factory/__pycache__/weight_init.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/factory/__pycache__/weight_init.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00e561a513449caa5190f872407096934db26cd8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/factory/__pycache__/weight_init.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/factory/block_configs.py b/.venv/lib/python3.11/site-packages/xformers/factory/block_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..a628fa1510e31e6bdf43135bddd300c1443ecd91 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/factory/block_configs.py @@ -0,0 +1,237 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
+# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, Optional + +from xformers.components import NormalizationType, ResidualNormStyle +from xformers.components.feedforward import FEEDFORWARD_REGISTRY, FeedforwardConfig +from xformers.components.positional_embedding import ( + POSITION_EMBEDDING_REGISTRY, + PositionEmbeddingConfig, +) +from xformers.utils import generate_matching_config + + +class LayerPositionBitmask(int, Enum): + First = 0b01 + Last = 0b10 + Default = 0b11 + + +class LayerPosition: + """Bitmask to mark this layer as first, last, nothing or both""" + + def __init__(self): + self.bitmask = LayerPositionBitmask.Default + + def is_first(self): + return bool(self.bitmask & LayerPositionBitmask.First) + + def is_last(self): + return bool(self.bitmask & LayerPositionBitmask.Last) + + def mark_not_first(self): + self.bitmask &= ~LayerPositionBitmask.First + + def mark_not_last(self): + self.bitmask &= ~LayerPositionBitmask.Last + + +class BlockType(str, Enum): + Encoder = "encoder" + Decoder = "decoder" + + +@dataclass(init=False) # handle constructors explicitly to force type changes +class xFormerBlockConfig: + """ + The configuration structure to define a Transformer block. + This base class is applicable to both encoder and decoder definitions. + + This completely defines each of the blocks, for instance in terms of dimensions, + position encoding, pre or post layer norms or reversibility. 
+ """ + + dim_model: int + feedforward_config: FeedforwardConfig + position_encoding_config: Optional[PositionEmbeddingConfig] + block_type: BlockType + residual_norm_style: ResidualNormStyle + normalization: NormalizationType + layer_position: LayerPosition + use_triton: bool + reversible: bool + num_layers: int + + def __init__( + self, + dim_model: int, + feedforward_config: Dict[str, Any], + position_encoding_config: Optional[Dict[str, Any]], + block_type: BlockType, + residual_norm_style: ResidualNormStyle = ResidualNormStyle("post"), + normalization: NormalizationType = NormalizationType.LayerNorm, + reversible: bool = False, + num_layers: int = 1, + layer_position: Optional[LayerPosition] = None, + ): + + self.dim_model = dim_model + self.block_type = block_type + self.residual_norm_style = residual_norm_style + self.reversible = reversible + self.num_layers = num_layers + self.normalization = normalization + + # Fill in possible gaps in the config for subparts of the block + self.feedforward_config = generate_matching_config( + feedforward_config, + FEEDFORWARD_REGISTRY[feedforward_config["name"]].config, + ) + + self.position_encoding_config = ( + generate_matching_config( + position_encoding_config, + POSITION_EMBEDDING_REGISTRY[position_encoding_config["name"]].config, + ) + if position_encoding_config is not None + else None + ) + + # Default is that this layer is the only one, so both first and last + if layer_position: + self.layer_position = layer_position + else: + self.layer_position = LayerPosition() + + +@dataclass(init=False) +class xFormerEncoderConfig(xFormerBlockConfig): + """ + The configuration structure for an encoder block + """ + + multi_head_config: Dict[str, Any] + use_triton: bool + simplicial_embeddings: Optional[Dict[str, Any]] + patch_embedding_config: Optional[Dict[str, Any]] + + def __init__( + self, + dim_model: int, + feedforward_config: Dict[str, Any], + multi_head_config: Dict[str, Any], + position_encoding_config: 
Optional[Dict[str, Any]] = None, + residual_norm_style: str = "post", + normalization: NormalizationType = NormalizationType.LayerNorm, + use_triton: bool = True, + simplicial_embeddings: Optional[Dict[str, Any]] = None, + patch_embedding_config: Optional[Dict[str, Any]] = None, + **kwargs, + ): + # Convenience, fill in duplicated fields + try: + if "dim_model" not in multi_head_config.keys(): + multi_head_config["dim_model"] = dim_model + + if "dim_model" not in feedforward_config.keys(): + feedforward_config["dim_model"] = dim_model + + if ( + position_encoding_config is not None + and "dim_model" not in position_encoding_config.keys() + ): + position_encoding_config["dim_model"] = dim_model + + if ( + patch_embedding_config is not None + and "out_channels" not in patch_embedding_config.keys() + ): + patch_embedding_config["out_channels"] = dim_model + + except AttributeError: + # A config instance was passed in, this is fine + pass + if "block_type" in kwargs: + assert kwargs["block_type"] == "encoder" + kwargs["block_type"] = BlockType("encoder") + super().__init__( + dim_model=dim_model, + feedforward_config=feedforward_config, + position_encoding_config=position_encoding_config, + residual_norm_style=ResidualNormStyle(residual_norm_style), + normalization=NormalizationType(normalization), + **kwargs, + ) + + self.multi_head_config = multi_head_config + self.use_triton = use_triton + self.simplicial_embeddings = simplicial_embeddings + self.patch_embedding_config = patch_embedding_config + + +@dataclass(init=False) +class xFormerDecoderConfig(xFormerBlockConfig): + """ + The configuration structure for a decoder block. + + This specifically defines the masked and cross attention mechanisms, + on top of the settings defining all blocks. 
+ """ + + multi_head_config_masked: Dict[str, Any] # prior to encoder output + multi_head_config_cross: Dict[str, Any] # cross attention, takes encoder output + + def __init__( + self, + dim_model: int, + feedforward_config: Dict[str, Any], + multi_head_config_masked: Dict[str, Any], + multi_head_config_cross: Dict[str, Any], + position_encoding_config: Optional[Dict[str, Any]] = None, + residual_norm_style: str = "post", + normalization: NormalizationType = NormalizationType.LayerNorm, + use_triton: bool = True, + **kwargs, + ): + + # Convenience, fill in duplicated field + try: + if "dim_model" not in multi_head_config_masked.keys(): + multi_head_config_masked["dim_model"] = dim_model + + if "dim_model" not in multi_head_config_cross.keys(): + multi_head_config_cross["dim_model"] = dim_model + + if "dim_model" not in feedforward_config.keys(): + feedforward_config["dim_model"] = dim_model + + if ( + position_encoding_config is not None + and "dim_model" not in position_encoding_config.keys() + ): + position_encoding_config["dim_model"] = dim_model + except AttributeError: + # A config instance was passed in, this is fine + pass + if "block_type" in kwargs.keys(): + assert kwargs["block_type"] == "decoder" + kwargs["block_type"] = BlockType("decoder") + + super().__init__( + dim_model=dim_model, + feedforward_config=feedforward_config, + position_encoding_config=position_encoding_config, + residual_norm_style=ResidualNormStyle(residual_norm_style), + normalization=NormalizationType(normalization), + **kwargs, + ) + + self.multi_head_config_masked = multi_head_config_masked + self.multi_head_config_cross = multi_head_config_cross + self.use_triton = use_triton diff --git a/.venv/lib/python3.11/site-packages/xformers/factory/block_factory.py b/.venv/lib/python3.11/site-packages/xformers/factory/block_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..4868f9f4de9ce63c7613ccad270974f3abcb65ff --- /dev/null +++ 
def _get_ln_factory(
    d_model: int,
    residual_norm_style: Optional[ResidualNormStyle],
    use_triton: bool,
    residual: bool,
    normalization: NormalizationType = NormalizationType.LayerNorm,
    residual_scale: float = 1.0,
):
    """
    Handle all the supported residual path configurations.

    ..Note: we return the appropriate constructor, not an actual layer

    Args:
        d_model: model dimension, forwarded to the norm wrappers.
        residual_norm_style: Pre / Post / DeepNorm placement of the norm
            relative to the residual path.
        use_triton: forwarded to the norm wrappers.
        residual: whether the returned wrapper includes a residual path.
        normalization: which normalization layer to build.
        residual_scale: residual scaling factor, only used by DeepNorm.

    Returns:
        ``ln_factory(sublayer)``, a callable that wraps ``sublayer`` with
        the configured norm/residual combination.
    """

    def get_layer_wrapper(
        d_model: int,
        sublayer: nn.Module,
        residual_norm_style: Optional[ResidualNormStyle],
        residual: bool,
        residual_scale: float,
    ):
        # Wrap the sublayer with norm + (optionally scaled) residual,
        # depending on the requested style.
        if residual:
            if residual_norm_style == ResidualNormStyle.Pre:
                # norm inside the residual branch: x + sublayer(norm(x))
                return Residual(
                    layer=PreNorm(d_model, sublayer, normalization, use_triton),
                    scale=None,
                )
            elif residual_norm_style == ResidualNormStyle.Post:
                # norm outside the residual sum: norm(x + sublayer(x))
                return PostNorm(
                    d_model,
                    Residual(layer=sublayer, scale=None),
                    normalization,
                    use_triton,
                )
            elif residual_norm_style == ResidualNormStyle.DeepNorm:
                # like Post, but the residual branch is scaled
                return PostNorm(
                    d_model,
                    Residual(layer=sublayer, scale=residual_scale),
                    normalization,
                    use_triton=use_triton,
                )
            else:
                raise ValueError

        # NOTE(review): in the non-residual path, any style other than Pre
        # (including DeepNorm) falls through to PostNorm — confirm intended.
        return (
            PreNorm(d_model, sublayer, normalization, use_triton)
            if residual_norm_style == ResidualNormStyle.Pre
            else PostNorm(d_model, sublayer, normalization, use_triton)
        )

    def ln_factory(sublayer: nn.Module):
        # Bind the outer configuration; callers only supply the sublayer.
        return get_layer_wrapper(
            d_model, sublayer, residual_norm_style, residual, residual_scale
        )

    return ln_factory
{mha_dim}), adding a projector layer." # noqa + ) + self.embedding_projector = nn.Linear(pos_encoding_dim, mha_dim) + else: + self.pose_encoding = None + + if config.residual_norm_style == ResidualNormStyle.DeepNorm: + # Just use the layer norm coefficient here, + # the init will be handled at the xformers level (knows about encoder and decoder blocks) + deep_norm_coefficients, _ = get_deepnorm_coefficients( + encoder_layers=config.num_layers, decoder_layers=0 + ) + assert deep_norm_coefficients is not None + residual_scale = deep_norm_coefficients.alpha + else: + residual_scale = 1.0 + + # mini helper, builds a normalization layer with the right Pre/Post config, residuals, and the right dimensions + ln_factory = _get_ln_factory( + config.dim_model, + config.residual_norm_style, + use_triton=config.use_triton, + residual=True, + residual_scale=residual_scale, + normalization=config.normalization, + ) + + mha = build_multi_head_attention(config.multi_head_config) + feedforward = build_feedforward(asdict(config.feedforward_config)) + + # Expose attention specific capabilities + self.supports_attention_mask = mha.attention.supports_attention_mask + self.requires_same_k_q_dimensions = mha.attention.requires_same_k_q_dimensions + self.causal = ( + mha.attention.causal if hasattr(mha.attention, "causal") else False + ) + + # Wrappers handle the different layer norm styles (pre- and post-) and the residual path + self.wrap_att = ln_factory(mha) + self.wrap_ff: Union[Residual, PostNorm] = ln_factory(feedforward) + if ( + config.residual_norm_style == ResidualNormStyle.Pre + and config.layer_position.is_last() + ): + self.wrap_ff = PostNorm( + config.dim_model, + self.wrap_ff, + normalization=config.normalization, + use_triton=config.use_triton, + ) + + # Simplicial embeddings are only used if specified, and on the last layer + self.simplicial_embedding: Optional[SimplicialEmbedding] = None + if config.simplicial_embeddings is not None and config.layer_position.is_last(): + 
class xFormerDecoderBlock(torch.nn.Module):
    r"""A vanilla Transformer Decoder block

    ... note: this implementation is not (yet ?) reversible"""

    def __init__(self, config: xFormerDecoderConfig, **kwargs):
        super().__init__()
        deprecated_function(self)

        # If this layer is the first one, and a pose encoding has been requested
        if (
            config.position_encoding_config is not None
            and config.layer_position.is_first()
        ):
            self.pose_encoding = build_positional_embedding(
                config.position_encoding_config
            )

            pos_encoding_dim = config.position_encoding_config.dim_model
            mha_dim = config.multi_head_config_masked["dim_model"]

            if pos_encoding_dim != mha_dim:

                logger.warning(
                    f"The embedding dim and model dim do not match ({pos_encoding_dim} vs {mha_dim}), adding a projector layer."  # noqa
                )

                # Bridge the embedding dimension to the attention dimension.
                self.embedding_projector = nn.Linear(pos_encoding_dim, mha_dim)
        else:
            self.pose_encoding = None

        if config.residual_norm_style == ResidualNormStyle.DeepNorm:
            # Just use the layer norm coefficient here,
            # the init will be handled at the xformers level (knows about encoder and decoder blocks)
            _, deep_norm_coefficients = get_deepnorm_coefficients(
                encoder_layers=0, decoder_layers=config.num_layers
            )
            assert deep_norm_coefficients is not None
            residual_scale = deep_norm_coefficients.alpha
        else:
            residual_scale = 1.0

        # mini helper, builds a LayerNorm with the right Pre/Post config and the right dimensions
        ln_factory = _get_ln_factory(
            config.dim_model,
            config.residual_norm_style,
            use_triton=config.use_triton,
            residual=True,
            residual_scale=residual_scale,
            normalization=config.normalization,
        )

        mha = build_multi_head_attention(config.multi_head_config_masked)
        cross_mha = build_multi_head_attention(config.multi_head_config_cross)
        feedforward = build_feedforward(config.feedforward_config)

        # Expose attention or feedforward specific capabilities
        self.supports_attention_mask = mha.attention.supports_attention_mask
        self.requires_same_k_q_dimensions = mha.attention.requires_same_k_q_dimensions
        self.requires_squared_context_length = (
            feedforward.requires_squared_context
            or mha.attention.requires_squared_context
        )

        self.causal_attention = (
            mha.attention.causal if hasattr(mha.attention, "causal") else False
        )

        # Wrappers handle the different layer norm styles (pre- and post-) and the residual path
        self.wrap_att = ln_factory(mha)
        self.wrap_cross = ln_factory(cross_mha)
        self.wrap_ff: Union[Residual, PostNorm] = ln_factory(feedforward)

        # With pre-norm, the last layer needs a final norm on the way out.
        # NOTE(review): unlike the encoder block, this hard-codes
        # NormalizationType.LayerNorm (not config.normalization) and does
        # not forward use_triton — confirm this asymmetry is intended.
        if (
            config.residual_norm_style == ResidualNormStyle.Pre
            and config.layer_position.is_last()
        ):
            self.wrap_ff = PostNorm(
                config.dim_model,
                self.wrap_ff,
                normalization=NormalizationType.LayerNorm,
            )

    @classmethod
    def from_config(cls, config: xFormerDecoderConfig):
        # Alternate constructor mirroring the encoder block's API.
        return cls(config)

    def forward(
        self,
        target: torch.Tensor,
        memory: torch.Tensor,
        encoder_att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
        decoder_att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
        input_mask: Optional[torch.Tensor] = None,
    ):
        """
        Run masked self-attention on ``target``, cross-attention against
        ``memory``, then the feedforward, each with its configured
        norm/residual wrapper.
        """
        if self.pose_encoding is not None:
            target = self.pose_encoding(target)

        if hasattr(self, "embedding_projector"):
            target = self.embedding_projector(target)

        # Handle the optional input masking, differs on Q, K, V
        if input_mask is not None:
            target_q = target
            target_k = target * input_mask.unsqueeze(-1)
            target_v = target_k
        else:
            target_q, target_k, target_v = target, target, target

        x = self.wrap_att(
            inputs=[target_q, target_k, target_v], att_mask=decoder_att_mask
        )
        x = self.wrap_cross(inputs=[x, memory, memory], att_mask=encoder_att_mask)
        x = self.wrap_ff(inputs=[x])

        return x
Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +# register components configs into Hydra ConfigStore +# component config classes could be used to validate configs +import logging + +from hydra.core.config_store import ConfigStore +from omegaconf.errors import ValidationError + +from xformers.components.attention import ATTENTION_REGISTRY +from xformers.components.feedforward import FEEDFORWARD_REGISTRY +from xformers.components.positional_embedding import POSITION_EMBEDDING_REGISTRY + +logger = logging.getLogger("xformers") + + +def import_xformer_config_schema(): + """ + Best effort - OmegaConf supports limited typing, so we may fail to import + certain config classes. For example, pytorch typing are not supported. + """ + cs = ConfigStore.instance() + + for k, v in { + "ff": FEEDFORWARD_REGISTRY, + "pe": POSITION_EMBEDDING_REGISTRY, + "attention": ATTENTION_REGISTRY, + }.items(): + for kk in v.keys(): + try: + cs.store(name=f"{kk}_schema", node=v[kk].config, group=f"xformers/{k}") + except ValidationError as e: + logger.debug(f"Error registering {kk}_schema, error: {e}") diff --git a/.venv/lib/python3.11/site-packages/xformers/factory/model_factory.py b/.venv/lib/python3.11/site-packages/xformers/factory/model_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..186529b6847387f67f8a0bc9682477f2dff576ab --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/factory/model_factory.py @@ -0,0 +1,313 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import logging +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Union + +import torch + +from xformers._deprecation_warning import deprecated_function +from xformers.components import reversible as rv +from xformers.components.residual import ResidualNormStyle, get_deepnorm_coefficients +from xformers.factory.block_configs import ( + xFormerBlockConfig, + xFormerDecoderConfig, + xFormerEncoderConfig, +) +from xformers.factory.block_factory import xFormerDecoderBlock, xFormerEncoderBlock +from xformers.factory.weight_init import get_weight_init_fn, xFormerWeightInit + +logger = logging.getLogger("xformers") + + +@dataclass(init=False) +class xFormerConfig: + """ + The configuration structure to define a full Transformer. + This can include a stack of encoder layers, and a stack of decoder layers. + + It is optionally possible to share the embedding weights in between + the encoder and decoder positional encoding, as proposed for instance by + `Using the Output Embedding to Improve Language Models`, Press et al. + + A full config example is for instance as follows: + + :: + + xformer_config = [ + { + "reversible": False, # Turn on to test the effect of using reversible layers + "block_type": "encoder", + "num_layers": LAYERS, + "dim_model": EMB, + "residual_norm_style": "pre", + "position_encoding_config": { + "name": "vocab", + "seq_len": CONTEXT, + "vocab_size": VOCAB_SIZE, + }, + "multi_head_config": { + "num_heads": NUM_HEADS, + "residual_dropout": RES_DROP, + "use_rotary_embeddings": True, + "attention": { + "name": ATTENTION_MECHANISM_STR, + "dropout": ATTN_DROP, + "causal": True, + "seq_len": CONTEXT, + }, + }, + "feedforward_config": { + "name": "MLP", + "dropout": MLP_DROP, + "activation": "gelu", + "hidden_layer_multiplier": MLP_MULTIPLIER, + }, + } + ] + + + .. 
_`Using the Output Embedding to Improve Language Models`: https://arxiv.org/pdf/1608.05859.pdf + """ + + stack_configs: Union[List[xFormerBlockConfig], Dict[str, xFormerBlockConfig]] + tie_embedding_weights: bool = False + weight_init: xFormerWeightInit = xFormerWeightInit.ViT + + def __init__( + self, + stack_configs: Union[List[Dict[str, Any]], Dict[str, Dict[str, Any]]], + tie_embedding_weights: bool = False, + weight_init: xFormerWeightInit = xFormerWeightInit.ViT, + ): + # Type all the configurations. Possible typos are caught here + if isinstance(stack_configs, dict): + self.stack_configs = {} + for k, config in stack_configs.items(): + if config["block_type"] == "encoder": + self.stack_configs[k] = xFormerEncoderConfig(**config) + else: + self.stack_configs[k] = xFormerDecoderConfig(**config) + else: + self.stack_configs = [] + for config in stack_configs: + if config["block_type"] == "encoder": + self.stack_configs.append(xFormerEncoderConfig(**config)) + else: + self.stack_configs.append(xFormerDecoderConfig(**config)) + + self.tie_embedding_weights = tie_embedding_weights + self.weight_init = weight_init + deprecated_function(self) + + +class xFormer(torch.nn.Module): + def __init__( + self, + stack_configs: Union[ + xFormerBlockConfig, List[xFormerBlockConfig], Dict[str, xFormerBlockConfig] + ], + tie_embedding_weights: bool = False, + weight_init: xFormerWeightInit = xFormerWeightInit.ViT, + ): + """ + Given a serialized configuration, generate the corresponding model. 
+ This is only a helper and can easily be bypassed + """ + super().__init__() + deprecated_function(self) + + if isinstance(stack_configs, Dict): + stack_configs = list(stack_configs.values()) + + # Convenience, users can pass either a list of configs or a single one + if not isinstance(stack_configs, List): + stack_configs = [stack_configs] + + # Sanity checks, some config combinations do not make sense + self._verify_reversible(stack_configs) + self._verify_deepnorm(stack_configs) + + encoders: List[torch.nn.Module] = [] + decoders: List[torch.nn.Module] = [] + + self.reversible_encoder = False + self.rev_enc_pose_encoding = None + + # Unroll the configs and build the model + for config in stack_configs: + # Handle either Encoder or Decoder stacks + builder = ( + xFormerEncoderBlock.from_config + if isinstance(config, xFormerEncoderConfig) + else xFormerDecoderBlock.from_config + ) + recipient = ( + encoders if isinstance(config, xFormerEncoderConfig) else decoders + ) + + # Build up the stack + for i in range(config.num_layers): + # Label where this layer is in the stack + # (for instance useful for the positional encoding, or late layer norm) + if len(recipient) > 0: + config.layer_position.mark_not_first() + + if config != stack_configs[-1] or i < config.num_layers - 1: + config.layer_position.mark_not_last() + + block = builder(config) # type: ignore + + # If reversible: extract the reversible sub-parts, else append the block as-is + if config.reversible: + # WARNING: only one pose encoding is saved here (not Focal Transformer compatible for instance) + assert isinstance(config, xFormerEncoderConfig) + if block.pose_encoding is not None: + self.rev_enc_pose_encoding = block.pose_encoding + self.reversible_encoder = True + + f, g = xFormerEncoderBlock.get_reversible_layer(config) + recipient.append(torch.nn.ModuleList([f, g])) + else: + recipient.append(block) # type: ignore + + # Tie embedding weights, if requested and possible + assert ( + not 
tie_embedding_weights or not self.reversible_encoder + ), "Reversible layers and tied embeddings is not supported for now" + + if ( + tie_embedding_weights + and encoders + and encoders[0].pose_encoding + and decoders + and decoders[0].pose_encoding + and not config.reversible + ): + logger.info("Tying encoder and decoder embeddings, as requested") + encoders[0].pose_encoding = decoders[0].pose_encoding + + self.encoders: torch.nn.Module = ( + rv.ReversibleSequence(torch.nn.ModuleList(encoders)) + if self.reversible_encoder + else torch.nn.ModuleList(encoders) + ) + self.decoders = torch.nn.ModuleList(decoders) + + use_deepnorm = ( + stack_configs[0].residual_norm_style == ResidualNormStyle.DeepNorm + ) + + assert ( + not use_deepnorm or not self.reversible_encoder + ), "Reversible layers and deepnorm is not supported for now" + + self.init_weights(weight_init=weight_init, use_deep_norm=use_deepnorm) + + @classmethod + def from_config(cls, config: xFormerConfig): + return cls( + config.stack_configs, config.tie_embedding_weights, config.weight_init + ) + + def _verify_reversible(self, stack_configs: List[xFormerBlockConfig]): + reversible = [ + c.reversible + for c in filter(lambda x: x.block_type == "encoder", stack_configs) + ] + + assert all(reversible) or not any(reversible), ( + "All layers need to have the same reversibility setting. " + + f"Currently {reversible}" + ) + + def _verify_deepnorm(self, stack_configs: List[xFormerBlockConfig]): + deepnorm = [ + c.residual_norm_style == ResidualNormStyle.DeepNorm for c in stack_configs + ] + + assert all(deepnorm) or not any(deepnorm), ( + "All layers need to have the same deepnorm setting. 
" + + f"Currently {deepnorm}" + ) + + def init_weights(self, weight_init: xFormerWeightInit, use_deep_norm: bool): + # The deepnorm weight initialization method requires different gain factors for the encoder + # and decoder, depending on the general model structure (number of respective layers) + if use_deep_norm: + encoder_coefficients, decoder_coefficients = get_deepnorm_coefficients( + encoder_layers=len(self.encoders), decoder_layers=len(self.decoders) # type: ignore + ) + else: + encoder_coefficients, decoder_coefficients = None, None + + encoder_gain = ( + encoder_coefficients.beta if encoder_coefficients is not None else 1.0 + ) + decoder_gain = ( + decoder_coefficients.beta if decoder_coefficients is not None else 1.0 + ) + + # Pick the desired init function + init_fn = get_weight_init_fn(weight_init) + + # Initialize all the encoder weights + for name, module in self.encoders.named_children(): + init_fn(module=module, name=name, gain=encoder_gain) + + for name, module in self.decoders.named_children(): + init_fn(module=module, name=name, gain=decoder_gain) + + def forward( + self, + src: torch.Tensor, + tgt: Optional[torch.Tensor] = None, + encoder_input_mask: Optional[torch.Tensor] = None, + decoder_input_mask: Optional[torch.Tensor] = None, + ) -> Optional[torch.Tensor]: + + # Encode to latent space if encoder is present + if len(list(self.encoders.parameters())) > 0: + encoders = self.encoders + memory = src.clone() + if isinstance(encoders, torch.nn.ModuleList): + for encoder in encoders: + memory = encoder(memory, input_mask=encoder_input_mask) + else: + if self.rev_enc_pose_encoding: + memory = self.rev_enc_pose_encoding(src) + + # Reversible Encoder + x = torch.cat([memory, memory], dim=-1) + + # Apply the optional input masking + if encoder_input_mask is not None: + if x.dim() - encoder_input_mask.dim() > 1: + encoder_input_mask.unsqueeze(0) + x += encoder_input_mask.unsqueeze(-1) + + x = encoders(x) + memory = torch.stack(x.chunk(2, 
dim=-1)).mean(dim=0) + + if not self.decoders: + return memory + + # If decoder: either use the encoder output, or just decode, both options are possible + if len(self.decoders) > 0: + tgt = src.clone() if tgt is None else tgt + + for decoder in self.decoders: + tgt = decoder( + target=tgt, + # pyre-fixme[61]: `memory` is not always initialized here. + memory=memory, + input_mask=decoder_input_mask, + ) + + return tgt + + return None diff --git a/.venv/lib/python3.11/site-packages/xformers/factory/weight_init.py b/.venv/lib/python3.11/site-packages/xformers/factory/weight_init.py new file mode 100644 index 0000000000000000000000000000000000000000..754d3bed632ae4c7e90cf2e1c4167e8761377665 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/factory/weight_init.py @@ -0,0 +1,293 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +# CREDITS: Reusing a lot of code from the Timm repo +# main difference is probably the handling of deepnorm init, and adapting to some xformers specificities +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + +import logging +import math +from enum import Enum +from typing import Callable + +import torch +import torch.nn as nn +from torch.nn.init import ( + _calculate_fan_in_and_fan_out, + _no_grad_trunc_normal_, + _no_grad_uniform_, +) + +logger = logging.getLogger("xformers") + + +_assert_if_not_initialized = False + + +class xFormerWeightInit(str, Enum): + Timm = "timm" + ViT = "vit" + Moco = "moco" + Small = "small" + + +def get_weight_init_fn(init_choice: xFormerWeightInit): + """ + Provide the xFormers factory with weight init routines. 
+ + Supported initializations are: + - Small: follow the method outlined in `Transformer Without Tears`_ + - ViT: follow the initialization in the reference ViT_ codebase + - Timm: follow the initialization in the reference Timm_ codebase + - Moco: follow the initialization in the reference MocoV3_ codebase + + .. _ViT: https://github.com/google-research/vision_transformer + .. _Timm: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + .. _MocoV3: https://github.com/facebookresearch/moco-v3 + """ + return { + xFormerWeightInit.Timm: _init_weights_vit_timm, + xFormerWeightInit.ViT: _init_weights_vit_jax, + xFormerWeightInit.Moco: _init_weights_vit_moco, + xFormerWeightInit.Small: _init_weights_small, + }[init_choice] + + +# Define pattern matches +def is_ffn(n): + return "feedforward" in n or ("wrap_ff" in n and not n.endswith("norm")) + + +def is_mha_input_projection(n): + return "q_proj" in n or "k_proj" in n or "v_proj" in n + + +# Define distribution helpers +def _small_init_(tensor: torch.Tensor, gain: float = 1.0) -> torch.Tensor: + r"""Fills the input `Tensor` with values according to the method + described in `Transformer Without Tears`_, using a uniform distribution. + + This is a variation of the Xavier init. The resulting tensor will have values sampled from + :math:`\mathcal{U}(-a, a)` where + + .. math:: + a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + 4 * \text{fan\_out}}} + + Also known as Glorot initialization. + + Args: + tensor: an n-dimensional `torch.Tensor` + gain: an optional scaling factor + + .. 
_`Transformer Without Tears`: https://arxiv.org/abs/1910.05895 + + """ + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + std = gain * math.sqrt(2.0 / float(fan_in + 4 * fan_out)) + a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation + + return _no_grad_uniform_(tensor, -a, a) + + +def _lecun_normal(tensor, gain=1.0): + fan_in, _ = _calculate_fan_in_and_fan_out(tensor) + denom = fan_in + variance = gain / denom + + # constant is stddev of standard normal truncated to (-2, 2) + _no_grad_trunc_normal_( + tensor, + mean=0.0, + std=math.sqrt(variance) / 0.87962566103423978, + a=-2.0, + b=2.0, + ) + + +# Helpers to keep all the functions typesafe, and handle corner cases and common behaviours in one place +def _maybe_init_tensor(module: nn.Module, attr: str, distribution_: Callable, **kwargs): + # Small helper to catch all the corner cases, while staying type safe + if hasattr(module, attr): + maybe_tensor = getattr(module, attr) + if maybe_tensor is not None and isinstance(maybe_tensor, torch.Tensor): + distribution_(maybe_tensor, **kwargs) + + +def _maybe_report_no_init(module, name): + if len(list(module.named_children())) == 0 and ( + hasattr(module, "weight") or hasattr(module, "bias") + ): + # Skip layer norm, this is ok + if isinstance(module, torch.nn.LayerNorm): + return + + # Skip nn.Embedding, we typically initialize it one level up, else Pytorch has a valid default + if isinstance(module, torch.nn.Embedding): + return + + # This is unexpected, warn about a possible unhandled weight + logger.warning( + f"Not initializing weights in {name}, this could be a mistake.\nModule {module}" + ) + + if _assert_if_not_initialized: + assert False, ( + f"Uninitialized weight found in {module}." 
+ + " If you have a custom module, please provide a `init_weights()` method" + ) + + +# Define the different initialization schemes +def _init_weights_vit_jax( + module: nn.Module, + name: str = "", + head_bias: float = 0.0, + gain: float = 1.0, + deepnorm_style: bool = False, + **kwargs, +): + """ViT weight initialization, matching JAX (Flax) impl""" + + if is_ffn(name): + _maybe_init_tensor(module, "bias", nn.init.normal_, std=1e-6) + _maybe_init_tensor(module, "weight", torch.nn.init.xavier_uniform_, gain=gain) + + elif is_mha_input_projection(name) or isinstance(module, nn.Linear): + if deepnorm_style and ( + "q_proj" in name.split(".") or "k_proj" in name.split(".") + ): + gain = 1.0 + + _maybe_init_tensor(module, "weight", torch.nn.init.xavier_uniform_, gain=gain) + _maybe_init_tensor(module, "bias", nn.init.zeros_) + + elif isinstance(module, nn.Conv2d): + _maybe_init_tensor(module, "weight", _lecun_normal, gain=gain) + _maybe_init_tensor(module, "bias", nn.init.zeros_) + + elif hasattr(module, "init_weights"): + module.init_weights() # type: ignore + + else: + _maybe_report_no_init(module, name) + + # Recurse over the children, if the weight init is being handled here + if not hasattr(module, "init_weights"): + for child_name, child_module in module.named_children(): + _init_weights_vit_jax(child_module, f"{name}.{child_name}", head_bias, gain) + + +def _init_weights_vit_moco( + module: nn.Module, + name: str = "", + gain: float = 1.0, + **kwargs, +): + """ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed""" + + assert ( + "deepnorm_style" not in kwargs.keys() + ), "This initialization method does not support deepnorm" + + if is_ffn(name): + _maybe_init_tensor(module, "weight", torch.nn.init.xavier_uniform_, gain=gain) + _maybe_init_tensor(module, "bias", nn.init.zeros_) + + elif is_mha_input_projection(name) or isinstance(module, nn.Linear): + if isinstance(module.weight, torch.Tensor): + val = ( + math.sqrt(6.0 / 
float(module.weight.shape[0] + module.weight.shape[1])) + * gain + ) + _maybe_init_tensor(module, "weight", nn.init.uniform_, a=-val, b=val) + + _maybe_init_tensor(module, "bias", nn.init.zeros_) + + elif hasattr(module, "init_weights"): + module.init_weights(gain=gain) # type: ignore + + else: + _maybe_report_no_init(module, name) + + # Recurse over the children, if the weight init is being handled here + if not hasattr(module, "init_weights"): + for child_name, child_module in module.named_children(): + _init_weights_vit_moco(child_module, child_name, gain) + + +def _init_weights_small( + module: nn.Module, + name: str = "", + head_bias: float = 0.0, + gain: float = 1.0, + deepnorm_style: bool = False, + **kwargs, +): + """Follow the `Transformer Without Tears`_ initialization for self-attention""" + + if is_ffn(name): + _maybe_init_tensor(module, "weight", torch.nn.init.xavier_uniform_, gain=gain) + _maybe_init_tensor(module, "bias", nn.init.normal_, std=1e-6) + + elif is_mha_input_projection(name) or isinstance(module, nn.Linear): + # "small init" only scales the attention layers init, not the FFN + if deepnorm_style and ( + "q_proj" in name.split(".") or "k_proj" in name.split(".") + ): + gain = 1.0 + + _maybe_init_tensor(module, "weight", _small_init_, gain=gain) + _maybe_init_tensor(module, "bias", nn.init.zeros_) + + elif isinstance(module, nn.Conv2d): + _maybe_init_tensor(module, "weight", _lecun_normal) + _maybe_init_tensor(module, "bias", nn.init.zeros_) + elif hasattr(module, "init_weights"): + module.init_weights() # type: ignore + else: + _maybe_report_no_init(module, name) + + # Recurse over the children, if the weight init is being handled here + if not hasattr(module, "init_weights"): + for child_name, child_module in module.named_children(): + _init_weights_small(child_module, f"{name}.{child_name}", head_bias, gain) + + +def _init_weights_vit_timm( + module: nn.Module, + name: str = "", + gain: float = 1.0, + deepnorm_style: bool = False, + 
**kwargs, +): + """ + ViT weight initialization, original timm impl (for reproducibility). + + See DeepNet_ for all the DeepNorm specific codepaths + """ + + if isinstance(module, nn.Linear): + if deepnorm_style and ( + "q_proj" in name.split(".") or "k_proj" in name.split(".") + ): + gain = 1 + + std = 0.02 * gain + a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation + + _maybe_init_tensor( + module, "weight", _no_grad_trunc_normal_, mean=0.0, std=std, a=-a, b=a + ) + _maybe_init_tensor(module, "bias", nn.init.zeros_) + + elif hasattr(module, "init_weights"): + module.init_weights(gain=gain) # type: ignore + else: + _maybe_report_no_init(module, name) + + # Recurse over the children, if the weight init is being handled here + if not hasattr(module, "init_weights"): + for child_name, child_module in module.named_children(): + _init_weights_vit_timm(child_module, child_name, gain) diff --git a/.venv/lib/python3.11/site-packages/xformers/info.py b/.venv/lib/python3.11/site-packages/xformers/info.py new file mode 100644 index 0000000000000000000000000000000000000000..af0fa5b2f4051b98559e8995c22c6b0c363ed4d7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/info.py @@ -0,0 +1,77 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Dict + +import torch + +from . 
import __version__, _cpp_lib, _is_opensource, _is_triton_available, ops +from .ops.common import OPERATORS_REGISTRY +from .profiler.profiler_dcgm import DCGM_PROFILER_AVAILABLE + + +def get_features_status() -> Dict[str, str]: + features = {} + for op in OPERATORS_REGISTRY: + status_str = "available" if op.is_available() else "unavailable" + features[f"{op.OPERATOR_CATEGORY}.{op.NAME}"] = status_str + for k, v in ops.swiglu_op._info().items(): + features[f"swiglu.{k}"] = v + features["is_triton_available"] = str(_is_triton_available()) + return features + + +def print_info(): + features = get_features_status() + print(f"xFormers {__version__}") + features["pytorch.version"] = torch.__version__ + if torch.cuda.is_available(): + features["pytorch.cuda"] = "available" + device = torch.cuda.current_device() + cap = torch.cuda.get_device_capability(device) + features["gpu.compute_capability"] = ".".join(str(ver) for ver in cap) + features["gpu.name"] = torch.cuda.get_device_name(device) + else: + features["pytorch.cuda"] = "not available" + + features["dcgm_profiler"] = ( + "available" if DCGM_PROFILER_AVAILABLE else "unavailable" + ) + + build_info = _cpp_lib._build_metadata + if build_info is None and isinstance( + _cpp_lib._cpp_library_load_exception, _cpp_lib.xFormersInvalidLibException + ): + build_info = _cpp_lib._cpp_library_load_exception.build_info + if build_info is not None: + features["build.info"] = "available" + features["build.cuda_version"] = build_info.cuda_version + features["build.hip_version"] = build_info.hip_version + features["build.python_version"] = build_info.python_version + features["build.torch_version"] = build_info.torch_version + for k, v in build_info.build_env.items(): + features[f"build.env.{k}"] = v + else: + features["build.info"] = "none" + + try: + features["build.nvcc_version"] = ".".join( + str(v) for v in torch.ops.xformers._nvcc_build_version() + ) + except (RuntimeError, AttributeError): + pass + + if _is_opensource: + 
features["source.privacy"] = "open source" + else: + features["source.privacy"] = "fairinternal" + + for name, status in features.items(): + print("{:<50} {}".format(f"{name}:", status)) + + +if __name__ == "__main__": + print_info() diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/common.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/common.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca0686bb2f83677aba3a3edb12329c7ce766934c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/common.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/differentiable_collectives.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/differentiable_collectives.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cf33ce4f4a1b868748f23ad5c978b7b48805ce2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/differentiable_collectives.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/indexing.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/indexing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e54b2234343b89a8322fee0e546c8eacded7643 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/indexing.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/ipc.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/ipc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c5eb1d7de8cec17e75523af108e857ff98be954 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/ipc.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/modpar_layers.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/modpar_layers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebd26d9142f172489235dc24432d1dc5d097e118 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/modpar_layers.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/rmsnorm.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/rmsnorm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e11798f381c9ec5fee2ec0fac126fe9dd460297 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/rmsnorm.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/rope_padded.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/rope_padded.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8817bd4919b979cfd580bcbf80fae79bc59815e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/rope_padded.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/seqpar.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/seqpar.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dafbcdebdf46825c14cdc06d74cec58b6e6fbc46 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/seqpar.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/sequence_parallel_fused_ops.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/sequence_parallel_fused_ops.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..4cc2497ea7e281e4e697c56e11f0d03d57aeac53 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/sequence_parallel_fused_ops.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/sp24.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/sp24.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e4c22f3d30a501faaa6769ae3f785366b57284b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/sp24.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/swiglu_op.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/swiglu_op.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..166a7c521a2f495e24d4a5e1d7657d56ec1841c3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/swiglu_op.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/tiled_matmul.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/tiled_matmul.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a25f119e180e330c3f4407387363c7cb2378103 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/tiled_matmul.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/unbind.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/unbind.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20eb8fcd36ff849353229c36bd13b68f5e67fc40 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/unbind.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/_triton/__init__.py 
b/.venv/lib/python3.11/site-packages/xformers/ops/_triton/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0a8ab8e0f6ebcc23b8aa5eaa5dfad7dd70c8acee --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/ops/_triton/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +# One reason this module is called `_triton` instead of just `triton` is this: +# https://github.com/openai/triton/commit/c6040bcbd8a046785462481b2830b3fff5fc4aab + +from typing import TYPE_CHECKING + +import xformers + +if TYPE_CHECKING or xformers._is_triton_available(): + from .k_index_select_cat import index_select_cat_bwd, index_select_cat_fwd + from .k_scaled_index_add import scaled_index_add_bwd, scaled_index_add_fwd +else: + index_select_cat_fwd = index_select_cat_bwd = None + scaled_index_add_fwd = scaled_index_add_bwd = None diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/_triton/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/_triton/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ddbbfbe81e577bfce0c7ff66c977adffe12619f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/_triton/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/_triton/__pycache__/k_index_select_cat.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/_triton/__pycache__/k_index_select_cat.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb79105e2e6d22b8feab21707a0f7bb4c3b80d06 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/_triton/__pycache__/k_index_select_cat.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/xformers/ops/_triton/__pycache__/k_scaled_index_add.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/_triton/__pycache__/k_scaled_index_add.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ad45618fe73e0ce74c174bbc060de5c5d3d5a51 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/_triton/__pycache__/k_scaled_index_add.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/_triton/__pycache__/rmsnorm_kernels.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/_triton/__pycache__/rmsnorm_kernels.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b41802bcb0a9994c4d4f357b281f8543256f4ff Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/_triton/__pycache__/rmsnorm_kernels.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/_triton/__pycache__/rope_padded_kernels.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/_triton/__pycache__/rope_padded_kernels.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac0d710e5ab1df6d8bf021dcad0ce78a455194a8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/_triton/__pycache__/rope_padded_kernels.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/_triton/__pycache__/tiled_matmul_kernels.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/_triton/__pycache__/tiled_matmul_kernels.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c553f903f27da781c30afd08ce9bb7c63e229f0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/_triton/__pycache__/tiled_matmul_kernels.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/common.py b/.venv/lib/python3.11/site-packages/xformers/ops/common.py 
new file mode 100644 index 0000000000000000000000000000000000000000..f31fde7332be7e2156fad13406fb5df773f11427 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/ops/common.py @@ -0,0 +1,64 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, List, Type, TypeVar + +import torch + + +def get_operator(library: str, name: str): + def no_such_operator(*args, **kwargs): + raise RuntimeError( + f"No such operator {library}::{name} - did you forget to build xformers with `python setup.py develop`?" + ) + + try: + return getattr(getattr(torch.ops, library), name) + except (RuntimeError, AttributeError): + return no_such_operator + + +def get_xformers_operator(name: str): + return get_operator("xformers", name) + + +class BaseOperator: + OPERATOR: Any + NAME: str + OPERATOR_CATEGORY: str + + @classmethod + def is_available(cls) -> bool: + # cls.OPERATOR can be either a kernel or a Triton Autotuner object, which doesn't have __name__ + if ( + cls.OPERATOR is None + or getattr(cls.OPERATOR, "__name__", "") == "no_such_operator" + ): + return False + return True + + +OPERATORS_REGISTRY: List[Type[BaseOperator]] = [] +FUNC_TO_XFORMERS_OPERATOR: Dict[Any, Type[BaseOperator]] = {} + +ClsT = TypeVar("ClsT") + + +def register_operator(cls: ClsT) -> ClsT: + global OPERATORS_REGISTRY, FUNC_TO_XFORMERS_OPERATOR + OPERATORS_REGISTRY.append(cls) # type: ignore + FUNC_TO_XFORMERS_OPERATOR[cls.OPERATOR] = cls # type: ignore + return cls + + +# post-2.0, avoids a warning +# (`torch.Tensor.storage` will also be deleted in the future) +_GET_TENSOR_STORAGE = getattr(torch.Tensor, "untyped_storage", None) +if _GET_TENSOR_STORAGE is None: # pre-2.0, `untyped_storage` didn't exist + _GET_TENSOR_STORAGE = torch.Tensor.storage + + +def _get_storage_base(x: torch.Tensor) -> int: + return 
_GET_TENSOR_STORAGE(x).data_ptr() # type: ignore diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/differentiable_collectives.py b/.venv/lib/python3.11/site-packages/xformers/ops/differentiable_collectives.py new file mode 100644 index 0000000000000000000000000000000000000000..9d87629dd1611fca3ef98eb7a91b3ffa7b11aec1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/ops/differentiable_collectives.py @@ -0,0 +1,174 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Optional, Tuple + +import torch +import torch.distributed + + +def all_reduce( + x: torch.Tensor, *, process_group: torch.distributed.ProcessGroup +) -> None: + mp_size = process_group.size() + if mp_size == 1: + return + + torch.distributed.all_reduce( + tensor=x, op=torch.distributed.ReduceOp.SUM, group=process_group + ) + + +def gather_along_first_dim_async( + input_: torch.Tensor, *, process_group: torch.distributed.ProcessGroup +) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]: + mp_size = process_group.size() + if mp_size == 1: + return input_, None + + output = input_.new_empty((input_.shape[0] * mp_size,) + input_.shape[1:]) + handle = torch.distributed.all_gather_into_tensor( + output_tensor=output, + input_tensor=input_, + group=process_group, + async_op=True, + ) + + return output, handle + + +def reduce_scatter_along_first_dim_async( + input_: torch.Tensor, *, process_group: torch.distributed.ProcessGroup +) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]: + mp_size = process_group.size() + if mp_size == 1: + return input_, None + + assert input_.shape[0] % mp_size == 0 + output = input_.new_empty((input_.shape[0] // mp_size,) + input_.shape[1:]) + handle = torch.distributed.reduce_scatter_tensor( + output=output, + input=input_, + op=torch.distributed.ReduceOp.SUM, + 
group=process_group, + async_op=True, + ) + + return output, handle + + +def gather_along_first_dim( + input_: torch.Tensor, *, process_group: torch.distributed.ProcessGroup +) -> torch.Tensor: + output, handle = gather_along_first_dim_async(input_, process_group=process_group) + if handle is not None: + handle.wait() + return output + + +def reduce_scatter_along_first_dim( + input_: torch.Tensor, *, process_group: torch.distributed.ProcessGroup +) -> torch.Tensor: + output, handle = reduce_scatter_along_first_dim_async( + input_, process_group=process_group + ) + if handle is not None: + handle.wait() + return output + + +class _CopyToModelParallelRegion(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + ctx, input_: torch.Tensor, process_group: torch.distributed.ProcessGroup + ) -> torch.Tensor: + ctx.process_group = process_group + return input_ + + @staticmethod + def backward( # type: ignore[override] + ctx, grad_output: torch.Tensor + ) -> Tuple[torch.Tensor, None]: + all_reduce(grad_output, process_group=ctx.process_group) + return grad_output, None + + +def copy_to_model_parallel_region( + x: torch.Tensor, process_group: torch.distributed.ProcessGroup +) -> torch.Tensor: + return _CopyToModelParallelRegion.apply(x, process_group) + + +class _ReduceFromModelParallelRegion(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + ctx, input_: torch.Tensor, process_group: torch.distributed.ProcessGroup + ) -> torch.Tensor: + all_reduce(input_, process_group=process_group) + ctx.mark_dirty(input_) + return input_ + + @staticmethod + def backward( # type: ignore[override] + ctx, grad_output: torch.Tensor + ) -> Tuple[torch.Tensor, None]: + return grad_output, None + + +def reduce_from_model_parallel_region( + x: torch.Tensor, process_group: torch.distributed.ProcessGroup +) -> torch.Tensor: + return _ReduceFromModelParallelRegion.apply(x, process_group) + + +class 
_GatherFromSequenceParallelRegion(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + ctx, x: torch.Tensor, process_group: torch.distributed.ProcessGroup + ) -> torch.Tensor: + ctx.process_group = process_group + return gather_along_first_dim(x, process_group=process_group) + + @staticmethod + def backward( # type: ignore[override] + ctx, grad_output: torch.Tensor + ) -> Tuple[torch.Tensor, None]: + return ( + reduce_scatter_along_first_dim( + grad_output, process_group=ctx.process_group + ), + None, + ) + + +def gather_from_sequence_parallel_region( + x: torch.Tensor, process_group: torch.distributed.ProcessGroup +) -> torch.Tensor: + return _GatherFromSequenceParallelRegion.apply(x, process_group) + + +class _ScatterToSequenceParallelRegion(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + ctx, x: torch.Tensor, process_group: torch.distributed.ProcessGroup + ) -> torch.Tensor: + ctx.process_group = process_group + return reduce_scatter_along_first_dim(x, process_group=process_group) + + @staticmethod + def backward( # type: ignore[override] + ctx, grad_output: torch.Tensor + ) -> Tuple[torch.Tensor, None]: + return ( + gather_along_first_dim(grad_output, process_group=ctx.process_group), + None, + ) + + +def scatter_to_sequence_parallel_region( + x: torch.Tensor, process_group: torch.distributed.ProcessGroup +) -> torch.Tensor: + return _ScatterToSequenceParallelRegion.apply(x, process_group) diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/indexing.py b/.venv/lib/python3.11/site-packages/xformers/ops/indexing.py new file mode 100644 index 0000000000000000000000000000000000000000..4fafda54406130daf957b428b3e5e446fbdeaf3a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/ops/indexing.py @@ -0,0 +1,237 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
+# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Sequence + +import torch + +from xformers.ops._triton import ( + index_select_cat_bwd, + index_select_cat_fwd, + scaled_index_add_bwd, + scaled_index_add_fwd, +) + +from .common import BaseOperator, register_operator + + +# Keeping these operator registry here so that +# it's easy to check if they are available +@register_operator +class ScaledIndexAddFw(BaseOperator): + OPERATOR = scaled_index_add_fwd + OPERATOR_CATEGORY = "indexing" + NAME = "scaled_index_addF" + + +@register_operator +class ScaledIndexAddBw(BaseOperator): + OPERATOR = scaled_index_add_bwd + OPERATOR_CATEGORY = "indexing" + NAME = "scaled_index_addB" + + +@register_operator +class IndexSelect(BaseOperator): + OPERATOR = index_select_cat_fwd + OPERATOR_CATEGORY = "indexing" + NAME = "index_select" + + +class _ScaledIndexAdd(torch.autograd.Function): + @staticmethod + # type: ignore + def forward( + ctx, + x: torch.Tensor, + index: torch.Tensor, + source: torch.Tensor, + scaling: Optional[torch.Tensor], + alpha: float, + ) -> torch.Tensor: + if scaled_index_add_fwd is not None: + scaled_index_add_fwd(x, index, source, scaling, alpha) + else: + raise RuntimeError( + "Triton is needed for forward pass but it is not available!" 
+ ) + + ctx.mark_dirty(x) + ctx.save_for_backward(index, scaling, source) + ctx.source_shape = source.shape + ctx.alpha = alpha + return x + + @staticmethod + @torch.autograd.function.once_differentiable + def backward(ctx, grad_output): + index, scaling, source = ctx.saved_tensors + grad_source = torch.empty_like(source) + grad_scaling = ( + None + if scaling is None + else torch.empty( + ctx.source_shape, dtype=scaling.dtype, device=scaling.device + ) + ) + + if scaled_index_add_bwd is not None: + scaled_index_add_bwd( + grad_output, + grad_source, + grad_scaling, + source, + scaling, + index, + ctx.alpha, + ) + else: + raise RuntimeError( + "Triton is needed for backward pass but it is not available!" + ) + + return ( + grad_output, # gradient of input + None, # gradient of index + grad_source, # gradient of source + grad_scaling, # gradient of scaling + None, # gradient of alpha + ) + + +def scaled_index_add( + input: torch.Tensor, # [B, M, D] + index: torch.Tensor, # [Bi] - int64 + source: torch.Tensor, # [Bi, M, D] + scaling: Optional[torch.Tensor] = None, # [D] + alpha: float = 1.0, +) -> torch.Tensor: + """ + In-place scaling+index_add + + Indices in ``index`` are assumed to be unique + + The max index in ``index`` is assumed to be less than the size of dim0 of ``input``. + + :Note: + + The FW pass is done in-place (``input`` is modified) + + :Equivalent pytorch code: + + .. 
code-block:: python + + return torch.index_add(input, dim=0, source=scaling * src, index=indices, alpha=alpha) + """ + + return _ScaledIndexAdd.apply(input, index, source, scaling, alpha) + + +class _IndexSelectCat(torch.autograd.Function): + @staticmethod + # type: ignore + def forward( + ctx, + *args: torch.Tensor, + ) -> torch.Tensor: + assert len(args) % 2 == 0 + sources = args[: len(args) // 2] + indices = args[len(args) // 2 :] + + output_numel = 0 + for source, index in zip(sources, indices): + num_rows, num_cols = source.shape + num_indices = index.shape[0] + output_numel += num_indices * num_cols + + output = torch.empty( + [output_numel], dtype=sources[0].dtype, device=sources[0].device + ) + + processed_numel = 0 + for source, index in zip(sources, indices): + num_indices = index.shape[0] + num_cols = source.shape[1] + + if index_select_cat_fwd is not None: + index_select_cat_fwd( + output[ + processed_numel : processed_numel + num_indices * num_cols + ].view([num_indices, num_cols]), + source, + index, + ) + else: + raise RuntimeError( + "Triton is needed for forward pass but it is not available!" 
+ ) + + processed_numel += num_indices * num_cols + + ctx.save_for_backward(*indices) + ctx.source_shapes = [source.shape for source in sources] + + return output + + @staticmethod + @torch.autograd.function.once_differentiable + def backward(ctx, grad_output): + indices = ctx.saved_tensors + + gradients = [] + processed_numel = 0 + for source_shape, index in zip(ctx.source_shapes, indices): + num_rows, num_cols = source_shape + num_indices = index.shape[0] + + grad_output_slice = grad_output[ + processed_numel : processed_numel + num_indices * num_cols + ].reshape([num_indices, num_cols]) + processed_numel += num_indices * num_cols + + grad_source_slice = torch.zeros( + [num_rows, num_cols], + dtype=grad_output.dtype, + device=grad_output.device, + ) + + if index_select_cat_bwd is not None: + index_select_cat_bwd( + grad_source_slice, + index, + grad_output_slice, + ) + else: + raise RuntimeError( + "Triton is needed for backward pass but it is not available!" + ) + gradients.append(grad_source_slice) + + return (*gradients, *([None] * len(gradients))) + + +def index_select_cat( + sources: Sequence[torch.Tensor], indices: Sequence[torch.Tensor] +) -> torch.Tensor: + """ + Indices in ``index`` are assumed to be unique + In each (index, source) pair, the max index in ``index`` is assumed to be less than the size of dim0 of ``source`` + + :Example: + + Given: + - ``sources[0]`` of shape ``[S0, D0]`` + - ``indices[0]`` of shape ``[I0]`` + - ``sources[1]`` of shape ``[S1, D1]`` + - ``indices[1]`` of shape ``[I1]`` + returns a ``torch.Tensor`` of shape ``[I0 * D0 + I1 * D1]`` + + :Equivalent pytorch code: + + .. 
code-block:: python + + return torch.cat([s[i.long()].flatten() for s, i in zip(sources, indices)], dim=0) + """ + return _IndexSelectCat.apply(*sources, *indices) diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/ipc.py b/.venv/lib/python3.11/site-packages/xformers/ops/ipc.py new file mode 100644 index 0000000000000000000000000000000000000000..e39dd0ff0c171566f9473c03e87417b0518817ae --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/ops/ipc.py @@ -0,0 +1,164 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import concurrent.futures +import json +import multiprocessing.connection +from typing import Any, List, Optional, Union + +import torch +import torch.distributed as dist +import torch.multiprocessing.reductions + +# We could just send tensors directly on mp.Connections, since PyTorch installs +# the necessary reductions to make it work. However, in the receiving process, +# PyTorch "mounts" the tensor in the CUDA context for the GPU with the **SAME +# INDEX** as on the sender. This works if all processes use CUDA_VISIBLE_DEVICES +# to limit themselves to a single GPU (which thus has index 0 everywhere) but in +# all other cases it's a mess. Hence we use our own reductions (which wrap the +# ones from PyTorch) to use the right devices. 
+ + +def _serialize_cuda_tensor(tensor: torch.Tensor): + assert tensor.device.type == "cuda" + func, args = torch.multiprocessing.reductions.reduce_tensor(tensor) + assert func is torch.multiprocessing.reductions.rebuild_cuda_tensor + assert args[6] == tensor.device.index + return args + + +def _deserialize_cuda_tensor(args, device: torch.device) -> torch.Tensor: + args = list(args) + args[6] = device.index + return torch.multiprocessing.reductions.rebuild_cuda_tensor(*args) + + +# We need all processes to exchange a few strings with their addresses (in order +# to be able to connect to each other). The solution for this kind of things in +# PyTorch is a Store (TCPStore or FileStore) but we cannot create one ourselves +# (we don't know which addr/port/file to use, since the default one is already +# being used by PyTorch's global store) nor can we extract one from the +# ProcessGroup (since there's no API to do so). We thus resort to using the PG +# itself to exchange data, which is overkill (we need to store the pickled data +# into tensors and send it to the GPU). On top of that, it introduces one more +# catch: it doesn't work in inference mode because of something about modifying +# tensors inplace. I couldn't find a way to temporarily disable inference mode +# (although it's supposed to be possible) however inference mode is thread-local +# so we can dodge it by offloading the collective call to another thread. I hate +# all this so much. + +# Use a sequence number to create unique store keys for different invocations. 
+_COUNTER = 0 + + +def _exchange_addresses( + listeners: List[multiprocessing.connection.Listener], + group: dist.ProcessGroup, + device: torch.device, +) -> List[List[str]]: + global _COUNTER + rank = group.rank() + world_size = group.size() + my_addresses: List[str] = [] + for listener in listeners: + addr = listener.address + # The address could be a tuple if the listener weren't a UNIX socket + if isinstance(addr, bytes): + # Shouldn't be bytes, according to docs and typeshed, but... + # https://github.com/python/typeshed/issues/10054 + addr = addr.decode("utf-8") + assert isinstance(addr, str) + my_addresses.append(addr) + if world_size == 1: + return [my_addresses] + # In fact, we can retrieve the store from the ProcessGroup, but only using + # a private API. Hence we catch whatever exception and fall back in case. + try: + _, store = torch.distributed.distributed_c10d._world.pg_map.get( + group, (None, None) + ) + assert store is not None + store.set( + f"xformers_exchange_addresses_{_COUNTER}_{rank}", + json.dumps(my_addresses), + ) + all_addresses = [ + json.loads(store.get(f"xformers_exchange_addresses_{_COUNTER}_{i}")) + for i in range(world_size) + ] + _COUNTER += 1 + except Exception: + all_addresses = [[""] * (world_size - 1)] * world_size + with concurrent.futures.ThreadPoolExecutor( + initializer=torch.cuda.set_device, initargs=(device,) + ) as e: + e.submit( + dist.all_gather_object, + object_list=all_addresses, + obj=my_addresses, + group=group, + ).result() + return all_addresses + + +class IPCPipe: + def __init__(self, connection, my_device) -> None: + self.connection = connection + self.my_device = my_device + + def send(self, tensor: torch.Tensor) -> None: + assert self.connection is not None, "Sending to myself!" 
+ assert tensor.device == self.my_device, f"{tensor.device=} != {self.my_device=}" + self.connection.send(_serialize_cuda_tensor(tensor)) + + def recv(self) -> torch.Tensor: + assert self.connection is not None, "Receiving from myself!" + return _deserialize_cuda_tensor(self.connection.recv(), self.my_device) + + +def init_ipc( + group: dist.ProcessGroup, + device: Union[torch.device, str] = "cuda", +) -> List[Optional[IPCPipe]]: + """ + Initializes pipes between processes of a `ProcessGroup`, that can be used + to exchange `torch.Tensor` later + """ + if isinstance(device, str): + device = torch.device(device) + if device.index is None: + device = torch.device(device.type, index=torch.cuda.current_device()) + world_size = group.size() + my_rank = group.rank() + # Open connections to all other processes. We exchange addresses via + # NCCL since we don't have access to a Store. + listeners = [ + multiprocessing.connection.Listener(family="AF_UNIX", address="", backlog=1) + for _ in range(world_size) + ] + # If any process is late, all other ones will block here + all_addresses = _exchange_addresses(listeners, group, device) + connections: Any = [] + for other_rank in range(world_size): + # For p2p connection between ranks i<->j + # if `i None: + # Mimick FairScale's _initialize_affine_weight, for backwards compatibility. + # The reason we initialize the full unpartitioned/gathered weight is so that + # different ranks get different initial values and thus "break the symmetry" + # and in order to achieve the same init for any value of model parallelism. 
+ rank = process_group.rank() + world_size = process_group.size() + + nrows, ncols = weight.shape + if partition_dim == 0: + full_weight = weight.new_empty(nrows * world_size, ncols) + my_weight_slice = full_weight[rank::world_size, :] + else: + full_weight = weight.new_empty(nrows, ncols * world_size) + my_weight_slice = full_weight[:, rank::world_size] + + init_method(full_weight) + + with torch.no_grad(): + weight.copy_(my_weight_slice) + + +class ColumnParallelLinear(torch.nn.Module): + def __init__( + self, + in_features: int, + out_features: List[int], + *, + process_group: torch.distributed.ProcessGroup, + bias: bool = True, + gather_output: bool = True, + init_method: Callable[ + [torch.Tensor], torch.Tensor + ] = torch.nn.init.xavier_normal_, + sequence_parallel: bool = False, + fuse_sequence_parallel: bool = True, + ) -> None: + super(ColumnParallelLinear, self).__init__() + + if not isinstance(out_features, list): + raise TypeError( + "xFormers's implementation of ColumnParallelLinear requires out_features to be a list" + ) + if bias: + raise ValueError( + "xFormers's implementation of ColumnParallelLinear requires bias=False" + ) + if gather_output: + raise ValueError( + "xFormers's implementation of ColumnParallelLinear requires gather_output=False" + ) + + self.in_features = in_features + self.global_out_features = out_features + self.sequence_parallel = sequence_parallel + self.fuse_sequence_parallel = fuse_sequence_parallel + self.process_group = process_group + mp_size = process_group.size() + assert all(dim % mp_size == 0 for dim in out_features) + self.my_out_features = [dim // mp_size for dim in out_features] + + self.weights = torch.nn.ParameterList( + [ + torch.nn.Parameter(torch.empty((dim, in_features))) + for dim in self.my_out_features + ] + ) + + for w in self.weights: + _init_2d_weight(w, init_method, process_group, partition_dim=0) + + def forward(self, input_: torch.Tensor) -> List[torch.Tensor]: + if self.sequence_parallel: + outputs 
= sequence_parallel_leading_matmul( + input_, + [w.t() for w in self.weights], + fuse=self.fuse_sequence_parallel, + process_group=self.process_group, + ) + else: + input_ = copy_to_model_parallel_region(input_, self.process_group) + outputs = [torch.matmul(input_, w.t()) for w in self.weights] + return outputs + + +class RowParallelLinear(torch.nn.Module): + def __init__( + self, + in_features: int, + out_features: int, + *, + process_group: torch.distributed.ProcessGroup, + bias: bool = True, + input_is_parallel: bool = False, + init_method: Callable[ + [torch.Tensor], torch.Tensor + ] = torch.nn.init.xavier_normal_, + sequence_parallel: bool = False, + fuse_sequence_parallel: bool = True, + ): + super(RowParallelLinear, self).__init__() + + if bias: + raise ValueError( + "xFormers's implementation of RowParallelLinear requires bias=False" + ) + if not input_is_parallel: + raise ValueError( + "xFormers's implementation of RowParallelLinear requires input_is_parallel=True" + ) + + self.global_in_features = in_features + self.out_features = out_features + self.sequence_parallel = sequence_parallel + self.fuse_sequence_parallel = fuse_sequence_parallel + self.process_group = process_group + mp_size = process_group.size() + assert in_features % mp_size == 0 + self.my_in_features = in_features // mp_size + + self.weight = torch.nn.Parameter( + torch.empty((out_features, self.my_in_features)) + ) + + _init_2d_weight(self.weight, init_method, process_group, partition_dim=1) + + def forward(self, input_: torch.Tensor) -> torch.Tensor: + if self.sequence_parallel: + output = sequence_parallel_trailing_matmul( + input_, + self.weight.t(), + fuse=self.fuse_sequence_parallel, + process_group=self.process_group, + ) + else: + output = torch.matmul(input_, self.weight.t()) + output = reduce_from_model_parallel_region(output, self.process_group) + return output diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/rmsnorm.py 
b/.venv/lib/python3.11/site-packages/xformers/ops/rmsnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..94a3743fbdf57e98fb1697bfc4f8a391d6a27651 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/ops/rmsnorm.py @@ -0,0 +1,113 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +from typing import Optional + +import torch +from torch import nn + +from .. import _is_triton_available + + +def rms_norm(x, weight: Optional[torch.Tensor], eps: float = 1e-6): + """ + RMS Normalization along the last dimension. + + This is similar to torch.nn.functional.normalize but with eps being added + instead of max. + + Expects x contiguous of shape (..., dim), and returns normalized data + of the same shape. For each dim-length vector x, the result has + + x / sqrt( x*x.sum() + eps) + + If weights are included, they are a contiguous parameter of length dim + which multiplies the result. + + This functionality is experimental. Its API might be changed without warnings. + Use it at your own risk. + """ + assert _is_triton_available() + from ._triton.rmsnorm_kernels import _rms_norm_forward + + if torch.is_grad_enabled() and ( + x.requires_grad or (weight is not None and weight.requires_grad) + ): + raise ValueError("Gradients not supported.") + + return _rms_norm_forward(x, weight, eps) + + +def rms_norm_add( + x: torch.Tensor, y: torch.Tensor, weight: Optional[torch.Tensor], eps: float = 1e-6 +): + """ + An addition fused with rms_norm. + + z = rms_norm_add(x, y, weight, eps) + + is equivalent to + + x += y + z = rms_norm(x, weight, eps) + + where x, y and z are all contiguous. + + This functionality is experimental. Its API might be changed without warnings. + Use it at your own risk. 
+ """ + if torch.is_grad_enabled() and ( + x.requires_grad + or y.requires_grad + or (weight is not None and weight.requires_grad) + ): + raise ValueError("Gradients not supported.") + assert _is_triton_available() + from ._triton.rmsnorm_kernels import _rms_norm_add_forward + + return _rms_norm_add_forward(x, y, weight, eps) + + +class RMSNorm(torch.nn.Module): + """ + RMS Normalization layer along the last dimension. + + This is similar to torch.nn.functional.normalize but with eps being added + instead of max. + + Expects contiguous input of shape (..., dim), and returns normalized data + of the same shape. For each dim-length vector x, the result has + + x / sqrt( x*x.sum() + eps) + + If weights are included, they are a parameter of length dim which multiplies + the result. + + This functionality is experimental. Its API might be changed without warnings. + Use it at your own risk. + """ + + def __init__(self, dim: int, include_weight: bool = True, eps: float = 1e-6): + super().__init__() + self.eps = eps + if include_weight: + self.weight: Optional[nn.Parameter] = nn.Parameter(torch.ones(dim)) + else: + self.weight = None + + def forward(self, x: torch.Tensor): + return rms_norm(x, self.weight, self.eps) # type: ignore + + def increment_and_forward_(self, x: torch.Tensor, y: torch.Tensor): + """ + An addition fused with forward. + + z = layer.increment_and_forward_(x, y) + + is equivalent to + + x += y + z = layer(x) + """ + return rms_norm_add(x, y, self.weight, self.eps) # type: ignore diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/rope_padded.py b/.venv/lib/python3.11/site-packages/xformers/ops/rope_padded.py new file mode 100644 index 0000000000000000000000000000000000000000..2329a38c4f6ad8c213a364dd036f371e89580962 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/ops/rope_padded.py @@ -0,0 +1,300 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
+# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +from typing import Optional, Tuple + +import torch + +from xformers.ops.fmha.attn_bias import ( # type: ignore + BlockDiagonalCausalWithOffsetPaddedKeysMask, +) + +from .. import _is_triton_available + + +def rope_padded( + xq: torch.Tensor, + xk: torch.Tensor, + xv: torch.Tensor, + cache_k: torch.Tensor, + cache_v: torch.Tensor, + attn_bias: BlockDiagonalCausalWithOffsetPaddedKeysMask, + *, + theta: float = 10000.0, + linear_scale: float = 1.0, + use_dynamic_scaling: bool = False, + dynamic_old_context_len: float = 8192.0, + dynamic_scale_factor: float = 16.0, + dynamic_low_freq_factor: float = 1.0, + dynamic_high_freq_factor: float = 32.0, + out_q: Optional[torch.Tensor] = None, + first_seqpos: Optional[torch.Tensor] = None, + seqpos: Optional[torch.Tensor] = None, + adjacents: bool = True, + internal_dtype: str = "", +): + """ + Performs RoPE (rotary embeddings) and kv-cache emplacement for a heterogeneous + batch for inference in the style given by + BlockDiagonalCausalWithOffsetPaddedKeysMask. + The batch is concatenated along the sequence dimension, so the + actual dim-0 length of all tensors is 1. + + xq, xk and xv should be (1, slen, n_heads, dim), where + xq's n_heads can differ from xk and xv. + + This function places the roped xk in the right place in cache_k, and + xv (unmodified) in the right place in cache_v, and returns out_q + (the roped xq) such that things are ready to call + + xformers.ops.memory_efficient_attention( + out_q, cache_k, cache_v, attn_bias=attn_bias + ) + + This functionality is experimental. Its API might be changed without warnings. + Use it at your own risk. 
+ + Arguments: + xq: tensor of queries to apply rope to + xk: tensor of keys to apply rope to + xv: tensor of values to copy into cache_v + cache_k: cache of keys, MODIFIED IN PLACE + cache_v: cache of values, MODIFIED IN PLACE + attn_bias: details the layout of caches. + Used to determine frequencies for the + RoPE calculation as well as the locations in cache_k and cache_v + to write to. Must be on the device. + first_seqpos: Optionally a tensor containing the sequence position of the + beginning of the cache for each batch element. + Providing a tensor of zeros is the same as providing None. + This affects the numerical calculation but not which memory + locations are read or written. + seqpos: Optionally a 1D tensor containing the sequence position of each + query. This should have length equal to xq.shape[1] . + This affects the numerical calculation but not which memory + locations are read or written. + adjacents: If True, the inputs are in adjacent pairs along the final dim axis. + This is like the released LLaMA model. + If False, the dim axis is split in two equal pieces. + I.e. the features are ordered with all the real parts before all + the imaginary parts. This matches HuggingFace, e.g. + https://github.com/huggingface/transformers/blob/ + f143037789288ba532dada934a118e648e715738/ + src/transformers/models/llama/modeling_llama.py#L126-L130 + linear_scale: A scaling factor to apply to the sequence ids when computing + the RoPE frequencies. When set to K, all sequence indices + are divided by K. 
+ use_dynamic_scaling: If true, dynamic scaling in use, using a scaling like + “YaRN: Efficient Context Window Extension of Large Language Models” + dynamic_old_context_len: used with use_dynamic_scaling + dynamic_scale_factor: used with use_dynamic_scaling + dynamic_low_freq_factor: used with use_dynamic_scaling + dynamic_high_freq_factor: used with use_dynamic_scaling + internal_dtype: set to "f32" or "f64" to enforce dtype in the calculation + """ + if torch.is_grad_enabled() and ( + xq.requires_grad + or xk.requires_grad + or xv.requires_grad + or cache_k.requires_grad + or cache_v.requires_grad + or out_q is not None + ): + raise ValueError("Gradients not supported.") + assert _is_triton_available() + import triton + + from ._triton.rope_padded_kernels import _rope_padded_kernel + + n_total_queries = attn_bias.q_seqinfo.seqstart_py[-1] + cache_length = attn_bias.k_seqinfo.seqstart_py[-1] + ndim = xq.ndim + if ndim not in [4, 5]: + raise ValueError("Unexpected xq dimension") + xq_stride = xq.stride() + xk_stride = xk.stride() + xv_stride = xv.stride() + cache_k_stride = cache_k.stride() + cache_v_stride = cache_v.stride() + cache_k_shape = cache_k.shape + xk_shape = xk.shape + n_kv_heads = xk_shape[-2] + expected_kv_heads = n_kv_heads + if xk_stride[-2] == 0: + n_kv_heads = 1 + expected_cache_heads = n_kv_heads + if n_kv_heads == 1 and cache_k_stride[-2] == 0: + # If there's 1 kv head, don't care how expanded + # cache_k is. User might expand before or after rope. 
+ expected_cache_heads = cache_k_shape[-2] + + if ndim == 4: + bsz, q_len, n_q_heads, dim = xq.shape + assert q_len == n_total_queries + if xk_shape != (1, n_total_queries, expected_kv_heads, dim): + raise ValueError( + f"unexpected k shape {xk_shape}: expected {(1, n_total_queries, expected_kv_heads, dim)}" + ) + if xv.shape != (1, n_total_queries, expected_kv_heads, dim): + raise ValueError( + f"unexpected v shape {xv.shape}: expected {(1, n_total_queries, expected_kv_heads, dim)}" + ) + if cache_k_shape != (1, cache_length, expected_cache_heads, dim): + raise ValueError("unexpected cache_k shape") + if cache_v.shape != (1, cache_length, expected_cache_heads, dim): + raise ValueError("unexpected cache_v shape") + n_groups = 1 + out_q_stride: Tuple[int, ...] = (0, n_q_heads * dim, dim, 1) + + else: + bsz, q_len, n_groups, n_q_heads, dim = xq.shape + assert q_len == n_total_queries + if xk_shape != (1, n_total_queries, n_groups, expected_kv_heads, dim): + raise ValueError( + f"unexpected k shape {xk_shape}: expected {(1, n_total_queries, n_groups, expected_kv_heads, dim)}" + ) + if xv.shape != (1, n_total_queries, n_groups, expected_kv_heads, dim): + raise ValueError( + f"unexpected v shape {xv.shape}: expected {(1, n_total_queries, n_groups, expected_kv_heads, dim)}" + ) + if cache_k_shape != (1, cache_length, n_groups, expected_cache_heads, dim): + raise ValueError( + f"unexpected cache_k shape {cache_k_shape}: " + f"expected {(1, cache_length, n_groups, expected_cache_heads, dim)}" + ) + if cache_v.shape != (1, cache_length, n_groups, expected_cache_heads, dim): + raise ValueError( + f"unexpected cache_v shape {cache_v.shape}: " + f"expected {(1, cache_length, n_groups, expected_cache_heads, dim)}" + ) + out_q_stride = ( + 0, + n_q_heads * dim * n_groups, + n_q_heads * dim, + dim, + 1, + ) + + if bsz != 1: + raise ValueError( + "Expected batch size dimension to be 1 as batches should be concatenated." 
+ ) + if xq_stride[-1] != 1: + raise ValueError("Each q head must be contiguous") + if xk_stride[-1] != 1: + raise ValueError("Each k head must be contiguous") + if xv_stride[-1] != 1: + raise ValueError("Each v head must be contiguous") + if cache_k_stride[-1] != 1: + raise ValueError("Each cache_k head must be contiguous") + if cache_v_stride[-1] != 1: + raise ValueError("Each cache_v head must be contiguous") + n_total_heads = n_q_heads + 2 * n_kv_heads + v_start = n_total_heads - n_kv_heads + k_start = n_q_heads + if out_q is None: + out_q = xq.new_empty(xq.shape) + else: + if out_q.shape != xq.shape: + raise ValueError("Unexpected shape of out_q") + out_q_stride = out_q.stride() + if out_q_stride[-1] != 1: + raise ValueError("Each out_q head must be contiguous") + + assert out_q is not None + + logical_bsz = len(attn_bias.q_seqinfo.seqstart_py) - 1 + + if first_seqpos is not None and seqpos is not None: + raise ValueError("seqpos and first_seqpos may not both be provided") + stride_seqpos = 0 + if first_seqpos is not None: + if first_seqpos.shape != (logical_bsz,): + shape = tuple(first_seqpos.shape) + raise ValueError( + f"first_seqpos.shape {shape} but ({logical_bsz},) expected." 
+ ) + stride_seqpos = first_seqpos.stride(0) + elif seqpos is not None: + if seqpos.shape != (n_total_queries,): + shape = tuple(seqpos.shape) + raise ValueError(f"seqpos.shape {shape} but ({n_total_queries},) expected.") + stride_seqpos = seqpos.stride(0) + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // xq.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(dim)) + BLOCK_SIZE = max(BLOCK_SIZE, 128) + BLOCK_SIZE = min(BLOCK_SIZE, 4096) + # heuristics for number of warps + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + device = xq.device + seqstartq = attn_bias.q_seqinfo.seqstart + seqstartk = attn_bias.k_seqinfo.seqstart + seqlenk = attn_bias.k_seqinfo.seqlen + if ( + seqstartq.device != device + or seqstartk.device != device + or seqlenk.device != device + ): + raise ValueError("`attn_bias` must be on the same device as the other inputs") + assert internal_dtype in ["", "f32", "f64"] + # experiment with the order of dims here. + with torch.cuda.device(xq.device): + _rope_padded_kernel[ + (attn_bias.q_seqinfo.max_seqlen, logical_bsz, n_total_heads * n_groups) + ]( + xq, + xk, + xv, + out_q, + cache_k, + cache_v, + seqstartq, + seqstartk, + seqlenk, + theta, + linear_scale, + use_dynamic_scaling, + dynamic_old_context_len if use_dynamic_scaling else 0, + dynamic_scale_factor if use_dynamic_scaling else 0, + dynamic_low_freq_factor if use_dynamic_scaling else 0, + dynamic_high_freq_factor if use_dynamic_scaling else 0, + first_seqpos, + seqpos, + k_start, + v_start, + n_groups, + dim, + xq_stride[1], + xq_stride[2] if ndim == 5 else 0, + xq_stride[-2], + xk_stride[1], + xk_stride[2] if ndim == 5 else 0, + xk_stride[-2], + xv_stride[1], + xv_stride[2] if ndim == 5 else 0, + xv_stride[-2], + cache_k_stride[1], + cache_k_stride[2] if ndim == 5 else 0, + cache_k_stride[-2], + cache_v_stride[1], + cache_v_stride[2] if ndim == 5 else 0, + cache_v_stride[-2], + seqstartq.stride(0), + seqstartk.stride(0), + 
seqlenk.stride(0), + out_q_stride[1], + out_q_stride[2] if ndim == 5 else 0, + out_q_stride[-2], + stride_seqpos, + internal_dtype, + const_batch_strides=False, + cache_padding_length=0, + seqlenk_shift=0, + BLOCK_SIZE=BLOCK_SIZE, + adjacents=adjacents, + num_warps=num_warps, + ) + return out_q diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/sequence_parallel_fused_ops.py b/.venv/lib/python3.11/site-packages/xformers/ops/sequence_parallel_fused_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..41b2a3d5c54d58a802d31203661bd63b8b54e369 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/ops/sequence_parallel_fused_ops.py @@ -0,0 +1,959 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import os +from typing import Callable, Dict, List, Optional, Sequence, Union, overload + +import torch +import torch.distributed as dist +import torch.multiprocessing.reductions + +from .common import BaseOperator, get_xformers_operator, register_operator +from .ipc import init_ipc + +# The sequence numbers will be communicated as 32-bit integers, due to +# limitations in both CUDA (memset can only operate on 4 bytes at a time at +# most) and Triton (scalar arguments are int32 if they fit). 32 bits are not +# enough to be sure that we'll never see overflow. Moreover, different parts of +# the code use signed or unsigned ints. To be safe, let's simulate overflow +# ourselves, at a value low enough so that it fits both a signed and an unsigned +# 32-bit integer. And, in fact, let's make it so low that we're sure we'll hit +# it in our tests, to avoid bugs that only manifest in long-running training. 
+SEQ_NUM_WRAP_AROUND = 2**8 + + +@register_operator +class WriteValues(BaseOperator): + OPERATOR = get_xformers_operator("write_values") + OPERATOR_CATEGORY = "sequence_parallel_fused" + NAME = "write_values" + + +@register_operator +class WaitValues(BaseOperator): + OPERATOR = get_xformers_operator("wait_values") + OPERATOR_CATEGORY = "sequence_parallel_fused" + NAME = "wait_values" + + +@register_operator +class Memset32bAsync(BaseOperator): + OPERATOR = get_xformers_operator("cuda_memset_32b_async") + OPERATOR_CATEGORY = "sequence_parallel_fused" + NAME = "cuda_memset_32b_async" + + +def _is_fp8_dtype(dt: torch.dtype): + # Detect if it's float8_e4m3fn or float8_e5m2 without mentioning them in + # order to support old versions of PyTorch that don't define them. + return dt.is_floating_point and torch.finfo(dt).bits == 8 + + +class _FusedSequenceParallel: + """Set up a communication ring and perform fused ops on it + + Stores the persistent state needed to support a ring of connections between + processes, and the logic that can do fused comms + matmuls on it. + + We want to achieve overlap between: + - a computation which reads from the data we received from a remote GPU + - and the communication where we send some data to another GPU + And in order to do that we need some staging buffers and a way to + synchronize access to them across processes. + + To perform the communication over NVLink we make the processes exchange + their staging buffers using IPC (Inter-Process Communication) handles, which + "mounts"/"mmaps" an allocation on one GPU into the virtual address space of + another GPU: the memory remains backed by the original GPU but the other GPU + can access it as if it were local. We exchange these IPC handles using + multiprocessing Connections (and the "reductions" provided by PyTorch), + which we establish over UNIX domain sockets, whose addresses we exchange by + using a ProcessGroup. 
+ + To synchronize accesses we use a set of counters/sequence numbers that are + also allocated in memory shared over IPC handles. Processes signal that they + completed an operation by launching a kernel that increases that value, and + they wait for anoher process to complete an operation by launching a kernel + that busy-waits for that value to increase. Currently we implement these + kernels manually, but on recent CUDA drivers (515.43.04+, corresponding to + CUDA 11.7) we could use standard stream memory operations (see + https://docs.nvidia.com/cuda/archive/11.7.0/cuda-driver-api/group__CUDA__MEMOP.html). + + We prefer to use these kernels (or the stream memory ops) over IPC events + because IPC events require signaling between processes at launch time to + ensure that the wait on one process occurs after the record on another + process. This signaling means that _launching_ our fused operation becomes a + synchronization barrier, which can increase the launch overhead. It would + also behave differently from NCCL, where launching is async and all the + synchronization happens on device in the kernels. A previous version of this + code which uses IPC events can be found here: + https://github.com/fairinternal/xformers/pull/504. 
+ + """ + + def __init__( + self, + device: torch.device, + group: dist.ProcessGroup, + ): + self.my_device = device + self.my_rank = group.rank() + self.world_size = group.size() + + self.p2p_comms = init_ipc(group, self.my_device) + + self.next_seq_num = 1 + + # My staging buffers + self.staging = torch.empty((0,), device=self.my_device) + + # (Mmapped view of a handle to) buddies' staging buffers + self.buddys_staging = [ + torch.empty((0,), device=self.my_device) + ] * self.world_size + + # Allocate buffers for locally-hosted counters + self.op_finished_produce = torch.zeros( + (), dtype=torch.int, device=self.my_device + ) + self.comms_ready_consume = torch.zeros( + (self.world_size,), dtype=torch.int, device=self.my_device + ) + + # Send my handles to buddies + for rank, conn in enumerate(self.p2p_comms): + if conn is not None: + conn.send(self.op_finished_produce) + conn.send(self.comms_ready_consume[rank]) + + # Open buddies' inboxes as my outboxes + self.op_finished_consume = [ + torch.empty((0,), device=self.my_device) if conn is None else conn.recv() + for conn in self.p2p_comms + ] + self.comms_ready_produce = [ + torch.empty((0,), device=self.my_device) if conn is None else conn.recv() + for conn in self.p2p_comms + ] + + self.second_stream = torch.cuda.Stream() + # CUDA can schedule the matmul and the memcpy at the same time, but it + # tends to run the matmul first and delay the memcpy, which causes a + # domino effect. We thus "encourage" it to prioritize the memcpy. + self.memcpy_stream = torch.cuda.Stream(priority=-1) + # Use dedicated streams to run the wait kernels in the background. + self.compute_wait_stream = torch.cuda.Stream(priority=-1) + self.memcpy_wait_stream = torch.cuda.Stream(priority=-1) + + self.next_stream_idx = 0 + + def _ensure_staging_is_large_enough( + self, num_elements: int, random_init: bool, dtype: torch.dtype + ): + total_num_bytes = num_elements * dtype.itemsize + + # Lazily size up the staging area as needed. 
(If it's the first call, + # this will always trigger, since staging starts empty). Once at steady + # state, staging will be of the right (max) size and never grow again. + if self.staging.numel() < self.world_size * total_num_bytes: + # When running with _memcpy=False (i.e., for benchmarks) we must + # ensure that the staging buffer doesn't contain all zeroes as that + # makes the matmuls go faster (better L2 compression or something). + self.staging = torch.empty( + (self.world_size, total_num_bytes), + device=self.my_device, + dtype=torch.uint8, + ) + if random_init: + self.staging.view(torch.bfloat16).normal_() + for rank, conn in enumerate(self.p2p_comms): + if conn is not None: + conn.send(self.staging[rank]) + self.buddys_staging = [ + torch.empty((0,), device=self.my_device) + if conn is None + else conn.recv() + for rank, conn in enumerate(self.p2p_comms) + ] + + def make_stream_factory( + self, current_stream: torch.cuda.Stream + ) -> Callable[[], torch.cuda.Stream]: + def result(): + stream = [current_stream, self.second_stream][self.next_stream_idx] + self.next_stream_idx += 1 + self.next_stream_idx %= 2 + return stream + + return result + + def allgather_and_linear( + self, + scattered_inputs: List[torch.Tensor], + my_matmul: Callable[ + [List[torch.Tensor], int, Callable[[], torch.cuda.Stream]], None + ], + timeout_s: int, + _wait: bool = True, + _memcpy: bool = True, + ): + """Perform a fused all-gather followed by a linear layer""" + + dtype = scattered_inputs[0].dtype + assert all(si.device == self.my_device for si in scattered_inputs) + assert all(si.dtype == dtype for si in scattered_inputs) + + scattered_input_numels = [si.numel() for si in scattered_inputs] + total_scattered_input_numel = sum(scattered_input_numels) + self._ensure_staging_is_large_enough( + total_scattered_input_numel, random_init=_memcpy is False, dtype=dtype + ) + + seq_num = self.next_seq_num % SEQ_NUM_WRAP_AROUND + prev_seq_num = (seq_num - 1) % SEQ_NUM_WRAP_AROUND + 
self.next_seq_num += 1 + + stagings = [ + s.view((self.world_size,) + si.shape) + for s, si in zip( + self.staging.view(dtype)[:, :total_scattered_input_numel].split( + scattered_input_numels, dim=-1 + ), + scattered_inputs, + ) + ] + buddys_stagings = [ + [bs] * len(scattered_inputs) + if bs.numel() == 0 + else [ + s.view(si.shape) + for s, si in zip( + bs.view(dtype)[:total_scattered_input_numel].split( + scattered_input_numels, dim=-1 + ), + scattered_inputs, + ) + ] + for bs in self.buddys_staging + ] + + current_stream = torch.cuda.current_stream() + self.second_stream.wait_stream(current_stream) + self.compute_wait_stream.wait_stream(current_stream) + self.memcpy_wait_stream.wait_stream(current_stream) + stream_factory = self.make_stream_factory(current_stream) + + for iter_ in range(1, self.world_size): + dst_rank = (self.my_rank + iter_) % self.world_size + + # Wait for buddy to signal that it read from the data before we + # overwrite it (this wait matches up with write [B] below). + if _wait: + WaitValues.OPERATOR( + [self.op_finished_consume[dst_rank]], + prev_seq_num, + self.memcpy_wait_stream, + timeout_s, + ) + + self.memcpy_stream.wait_stream(self.memcpy_wait_stream) + + if _memcpy: + with torch.cuda.stream(self.memcpy_stream): + for bs, si in zip(buddys_stagings[dst_rank], scattered_inputs): + bs.copy_(si) + + # Signal to buddy that we have written into the data so it can + # read from it (this write matches up with wait [A] below). + if _wait: + Memset32bAsync.OPERATOR( + self.comms_ready_produce[dst_rank], + seq_num, + self.memcpy_stream, + ) + + my_matmul(scattered_inputs, self.my_rank, stream_factory) + + for iter_ in range(1, self.world_size): + src_rank = (self.my_rank - iter_) % self.world_size + + # Wait for buddy to signal that it wrote into the data before we + # read from it (this wait matches up with write [A] above). 
+ if _wait: + WaitValues.OPERATOR( + [self.comms_ready_consume[src_rank]], + seq_num, + self.compute_wait_stream, + timeout_s, + ) + + current_stream.wait_stream(self.compute_wait_stream) + self.second_stream.wait_stream(self.compute_wait_stream) + + my_matmul([s[src_rank] for s in stagings], src_rank, stream_factory) + + current_stream.wait_stream(self.second_stream) + current_stream.wait_stream(self.memcpy_stream) + + # Signal to buddy that we have read from the data so it can + # overwrite it (this write matches up with wait [B] above). + if _wait: + WriteValues.OPERATOR( + [self.op_finished_produce], + seq_num, + current_stream, + ) + + def linear_and_reducescatter( + self, + my_matmul: Callable[ + [List[torch.Tensor], int, Callable[[], torch.cuda.Stream]], None + ], + gathered_outputs: List[torch.Tensor], + scattered_outputs: List[torch.Tensor], + timeout_s: int, + _wait: bool = True, + _memcpy: bool = True, + ): + """Perform a fused linear layer followed by a reduce-scatter""" + + dtype = gathered_outputs[0].dtype + assert all(go.device == self.my_device for go in gathered_outputs) + assert all(go.dtype == dtype for go in gathered_outputs) + assert all(so.device == self.my_device for so in scattered_outputs) + assert all(so.dtype == dtype for so in scattered_outputs) + + scattered_output_numels = [so.numel() for so in scattered_outputs] + total_scattered_output_numel = sum(scattered_output_numels) + self._ensure_staging_is_large_enough( + total_scattered_output_numel, random_init=_memcpy is False, dtype=dtype + ) + + seq_num = self.next_seq_num % SEQ_NUM_WRAP_AROUND + prev_seq_num = (seq_num - 1) % SEQ_NUM_WRAP_AROUND + self.next_seq_num += 1 + + stagings = [ + s.view((self.world_size,) + so.shape) + for s, so in zip( + self.staging.view(dtype)[:, :total_scattered_output_numel].split( + scattered_output_numels, dim=-1 + ), + scattered_outputs, + ) + ] + buddys_stagings = [ + [bs] * len(scattered_outputs) + if bs.numel() == 0 + else [ + s.view(so.shape) + for 
s, so in zip( + bs.view(dtype)[:total_scattered_output_numel].split( + scattered_output_numels, dim=-1 + ), + scattered_outputs, + ) + ] + for bs in self.buddys_staging + ] + + current_stream = torch.cuda.current_stream() + self.second_stream.wait_stream(current_stream) + self.compute_wait_stream.wait_stream(current_stream) + self.memcpy_wait_stream.wait_stream(current_stream) + stream_factory = self.make_stream_factory(current_stream) + + for iter_ in range(1, self.world_size): + dst_rank = (self.my_rank + iter_) % self.world_size + + # Wait for buddy to signal that it read from the data before we + # overwrite it (this wait matches up with write [2] below). + if _wait: + WaitValues.OPERATOR( + [self.op_finished_consume[dst_rank]], + prev_seq_num, + self.compute_wait_stream, + timeout_s, + ) + + current_stream.wait_stream(self.compute_wait_stream) + self.second_stream.wait_stream(self.compute_wait_stream) + + my_matmul([s[dst_rank] for s in stagings], dst_rank, stream_factory) + + # Deduce which stream contains the last kernel launched. + final_stream = [current_stream, self.second_stream][ + (self.next_stream_idx - 1) % 2 + ] + final_stream.wait_stream(current_stream) + final_stream.wait_stream(self.second_stream) + + # Signal to buddy that we have written into the data so it can + # read from it (this write matches up with wait [1] below). + if _wait: + Memset32bAsync.OPERATOR( + self.comms_ready_produce[dst_rank], + seq_num, + final_stream, + ) + + my_matmul( + [o[self.my_rank] for o in gathered_outputs], + self.my_rank, + stream_factory, + ) + + for iter_ in range(1, self.world_size): + src_rank = (self.my_rank - iter_) % self.world_size + + # Wait for buddy to signal that it wrote into the data before we + # read from it (this wait matches up with write [1] above). 
+ if _wait: + WaitValues.OPERATOR( + [self.comms_ready_consume[src_rank]], + seq_num, + self.memcpy_wait_stream, + timeout_s, + ) + + self.memcpy_stream.wait_stream(self.memcpy_wait_stream) + + if _memcpy: + with torch.cuda.stream(self.memcpy_stream): + for go, bs in zip(gathered_outputs, buddys_stagings[src_rank]): + go[src_rank].copy_(bs) + + current_stream.wait_stream(self.second_stream) + current_stream.wait_stream(self.memcpy_stream) + + for go, so in zip(gathered_outputs, scattered_outputs): + torch.sum(go, dim=0, out=so) + + # Signal to buddy that we have read from the data so it can + # overwrite it (this write matches up with wait [2] above). + if _wait: + WriteValues.OPERATOR( + [self.op_finished_produce], + seq_num, + current_stream, + ) + + +# We'd store this as an attribute on the PG object itself, but some PGs are +# pybind-bound classes and thus don't support it, so we simulate this as an +# external cache. +CACHE: Dict[int, Optional[_FusedSequenceParallel]] = {} + + +def _can_ranks_communicate_all_to_all_over_nvlink(group: dist.ProcessGroup) -> bool: + # FIXME This is currently overly simplistic, must be improved. The following + # should be enough: + # - ensure that all ranks are running on the same machine (by exchanging + # their /proc/sys/kernel/random/boot_id value) + # - ensure there's P2P between all pairs of ranks (can_device_access_peer + # could help here but it's unclear what happens if target devices aren't + # visible? maybe just trying to exchange IPC handles and catching errors + # would work? 
note that in any case some ranks might succeed while some + # might fail so we need a barrier to have them all make the same decision) + return group.size() <= 8 + + +def _lazy_init( + device: torch.device, group: dist.ProcessGroup +) -> Optional[_FusedSequenceParallel]: + world_size = group.size() + try: + obj = CACHE[id(group)] + except KeyError: + if int(os.environ.get("DISABLE_FUSED_SEQUENCE_PARALLEL", "0")): + obj = None + elif world_size == 1: + obj = None + elif not _can_ranks_communicate_all_to_all_over_nvlink(group): + obj = None + else: + obj = _FusedSequenceParallel(device, group) + CACHE[id(group)] = obj + return obj + + +def _default_stream_factory() -> torch.cuda.Stream: + return torch.cuda.current_stream() + + +@overload +def fused_allgather_and_linear( + scattered_input: torch.Tensor, + weight: torch.Tensor, + *, + group: dist.ProcessGroup, + out: Optional[torch.Tensor] = None, + timeout_s: int = 60 * 60, + scale_scattered_input: Optional[torch.Tensor] = None, + scale_weight: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, + out_dtype: Optional[torch.dtype] = None, + **private_args_DO_NOT_USE, +) -> torch.Tensor: + ... + + +@overload +def fused_allgather_and_linear( + scattered_input: torch.Tensor, + weight: List[torch.Tensor], + *, + group: dist.ProcessGroup, + out: Optional[List[torch.Tensor]] = None, + timeout_s: int = 60 * 60, + scale_scattered_input: Optional[torch.Tensor] = None, + scale_weight: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, + out_dtype: Optional[torch.dtype] = None, + **private_args_DO_NOT_USE, +) -> List[torch.Tensor]: + ... 
+ + +def fused_allgather_and_linear( + scattered_input: torch.Tensor, + weight: Union[torch.Tensor, List[torch.Tensor]], + *, + group: dist.ProcessGroup, + out: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, + timeout_s: int = 60 * 60, + scale_scattered_input: Optional[torch.Tensor] = None, + scale_weight: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, + out_dtype: Optional[torch.dtype] = None, + **private_args_DO_NOT_USE, +) -> Union[torch.Tensor, List[torch.Tensor]]: + """Performs a fused all-gather followed by a linear op + + It is equivalent to the following plain PyTorch code: + + # like scattered_input but with first dim multiplied by group's world size + gathered_input = scattered_input.new_empty(...) + dist.all_gather_into_tensor(gathered_input, scattered_input, group=group) + return torch.nn.functional.linear(gathered_input, weight) + + It achieves this by breaking down the matmul into smaller partial ops (as + many as the world size), each needing as input a different "contribution" + to the all-gather (by a different rank), and writing to a different chunk of + the output. Then, on one stream, it sends the local contribution to all + other ranks (first one rank over, then two, ...) while, on another stream, + it launches the sub-matmuls in the order in which the remote contributions + (which are the sub-matmuls' inputs) are supposed to arrive, so that ideally + none of the sub-matmuls will ever have to wait. + + The idea comes from this paper: https://arxiv.org/abs/2302.05442 + + This method uses a staging buffer, which persists across calls, of the same + size as the all-gathered input tensor (i.e., the input's size times the + world size). If multiple inputs of multiple sizes are used, the staging + buffer will be the maximum needed by any of them. Each call, when it starts, + must first wait for the previous call to finish using the staging buffer. 
In + normal conditions, where there's some other operation between two calls, + this isn't an issue. + + Supports FP8 gemm for tensor-wise quantized weight and input tensors. + To enable FP8 gemm: + 1. pass scattered_input and weight as quantized FP8 datatype + 2. pass scale_scattered_input and scale_weight, the scales used to + quantize input and weight, respectively. + 3. set out_dtype, if not specified, will be inferred from scattered_input type. + + """ + world_size = group.size() + weights = weight if isinstance(weight, list) else [weight] + assert (scale_scattered_input is None) == (scale_weight is None) + if scale_weight is not None: + assert isinstance(weight, list) == isinstance(scale_weight, list) + scales_weights: Sequence[Optional[torch.Tensor]] = ( + scale_weight if isinstance(scale_weight, list) else [scale_weight] + ) + assert len(weights) == len(scales_weights) + assert _is_fp8_dtype(scattered_input.dtype) + assert all(_is_fp8_dtype(w.dtype) for w in weights) + assert out_dtype is not None, "output_dtype is required with FP8" + else: + scales_weights = [None] * len(weights) + assert all(w.ndim == 2 for w in weights) + assert scattered_input.ndim >= 2 + assert all(scattered_input.shape[-1] == w.shape[-1] for w in weights) + assert scattered_input.is_contiguous() + gathered_input_shape = (world_size,) + scattered_input.shape + gathered_output_shapes = [gathered_input_shape[:-1] + w.shape[:-1] for w in weights] + if out is not None: + assert isinstance(out, list) == isinstance(weight, list) + gathered_outputs = out if isinstance(out, list) else [out] + assert len(gathered_outputs) == len(gathered_output_shapes) + assert all( + go.shape == gos for go, gos in zip(gathered_outputs, gathered_output_shapes) + ) + assert all(go.is_contiguous() for go in gathered_outputs) + if out_dtype is not None: + if isinstance(out, list): + for o in out: + assert o.dtype == out_dtype + else: + assert out.dtype == out_dtype + else: + gathered_outputs = [ + 
scattered_input.new_empty( + gos, + dtype=out_dtype if out_dtype is not None else scattered_input.dtype, + ) + for gos in gathered_output_shapes + ] + + torch.ops.xformers_python._fused_allgather_and_linear_impl( + scattered_input, + weights, + group.group_name, + gathered_outputs, + timeout_s=timeout_s, + _wait=private_args_DO_NOT_USE.get("_wait", True), + _memcpy=private_args_DO_NOT_USE.get("_memcpy", True), + scale_scattered_input=scale_scattered_input, + scales_weights=scales_weights, + ) + + if isinstance(weight, list): + return [go.flatten(0, 1) for go in gathered_outputs] + else: + return gathered_outputs[0].flatten(0, 1) + + +@torch.library.custom_op( + "xformers_python::_fused_allgather_and_linear_impl", + mutates_args={"gathered_outputs"}, + device_types="cuda", +) +def _fused_allgather_and_linear_custom_op( + scattered_input: torch.Tensor, + weights: List[torch.Tensor], + process_group_name: str, + gathered_outputs: List[torch.Tensor], + timeout_s: int, + _wait: bool, + _memcpy: bool, + scale_scattered_input: torch.Tensor, + scales_weights: Sequence[Optional[torch.Tensor]], +) -> None: + process_group = dist.distributed_c10d._resolve_process_group(process_group_name) + + def my_matmul( + inputs: List[torch.Tensor], + src_rank: int, + stream_factory: Callable[[], torch.cuda.Stream], + ) -> None: + for w, scale_weight, go in zip(weights, scales_weights, gathered_outputs): + with torch.cuda.stream(stream_factory()): + if scale_scattered_input is not None and scale_weight is not None: + torch._scaled_mm( + inputs[0], + w.t(), + out_dtype=go[src_rank].dtype, + scale_a=scale_scattered_input, + scale_b=scale_weight, + out=go[src_rank], + ) + else: + torch.matmul(inputs[0], w.t(), out=go[src_rank]) + + fused_allgather_and_anything( + [scattered_input], + my_matmul, + group=process_group, + timeout_s=timeout_s, + _wait=_wait, + _memcpy=_memcpy, + ) + + +def fused_allgather_and_anything( + scattered_inputs: List[torch.Tensor], + my_matmul: Callable[ + 
[List[torch.Tensor], int, Callable[[], torch.cuda.Stream]], None + ], + *, + group: dist.ProcessGroup, + timeout_s: int = 60 * 60, + **private_args_DO_NOT_USE, +) -> None: + world_size = group.size() + + if len(scattered_inputs) == 0: + for src_rank in range(world_size): + my_matmul([], src_rank, _default_stream_factory) + return + + assert all(si.is_contiguous() for si in scattered_inputs) + assert all(si.device == scattered_inputs[0].device for si in scattered_inputs) + assert all(si.dtype == scattered_inputs[0].dtype for si in scattered_inputs) + + gathered_input_shapes = [(world_size,) + si.shape for si in scattered_inputs] + + obj = _lazy_init(scattered_inputs[0].device, group) + + if world_size == 1: + my_matmul(scattered_inputs, 0, _default_stream_factory) + + # Fallback + elif obj is None: + gathered_inputs = [ + si.new_empty(gis) + for si, gis in zip(scattered_inputs, gathered_input_shapes) + ] + for si, gi in zip(scattered_inputs, gathered_inputs): + dist.all_gather_into_tensor(output_tensor=gi, input_tensor=si, group=group) + for src_rank in range(world_size): + my_matmul( + [gi[src_rank] for gi in gathered_inputs], + src_rank, + _default_stream_factory, + ) + + # Fast path + else: + assert scattered_inputs[0].device == obj.my_device + obj.allgather_and_linear( + scattered_inputs, + my_matmul, + timeout_s=timeout_s, + _wait=private_args_DO_NOT_USE.get("_wait", True), + _memcpy=private_args_DO_NOT_USE.get("_memcpy", True), + ) + + +@overload +def fused_linear_and_reducescatter( + gathered_input: torch.Tensor, + weight: torch.Tensor, + *, + group: dist.ProcessGroup, + out: Optional[torch.Tensor] = None, + timeout_s: int = 60 * 60, + scale_gathered_input: Optional[torch.Tensor] = None, + scale_weight: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, + out_dtype: Optional[torch.dtype] = None, + **private_args_DO_NOT_USE, +) -> torch.Tensor: + ... 
+ + +@overload +def fused_linear_and_reducescatter( + gathered_input: torch.Tensor, + weight: List[torch.Tensor], + *, + group: dist.ProcessGroup, + out: Optional[List[torch.Tensor]] = None, + timeout_s: int = 60 * 60, + scale_gathered_input: Optional[torch.Tensor] = None, + scale_weight: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, + out_dtype: Optional[torch.dtype] = None, + **private_args_DO_NOT_USE, +) -> List[torch.Tensor]: + ... + + +def fused_linear_and_reducescatter( + gathered_input: torch.Tensor, + weight: Union[torch.Tensor, List[torch.Tensor]], + *, + group: dist.ProcessGroup, + out: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, + timeout_s: int = 60 * 60, + scale_gathered_input: Optional[torch.Tensor] = None, + scale_weight: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, + out_dtype: Optional[torch.dtype] = None, + **private_args_DO_NOT_USE, +) -> Union[torch.Tensor, List[torch.Tensor]]: + """Performs a fused linear op followed by a reduce-scatter + + It is equivalent to the following plain PyTorch code: + + gathered_output = torch.nn.functional.linear(gathered_input, weight) + # like gathered_output but with first dim divided by group's world size + scattered_output = gathered_output.new_empty(...) + dist.reduce_scatter_tensor(scattered_output, gathered_output, group=group) + + Supports FP8 gemm with tensor-wise quantized weights. To enable FP8 gemm: + 1. pass weight and gathered_input as FP8 tensors + 2. Set `scale_gathered_input` and `scale_weight` to the scales used to quantize + inputs and weight, respectively. + 3. Set out_dtype to the desired output dtype. If not specified, it will be inferred from + gathered_input datatype. 
+ """ + world_size = group.size() + weights = weight if isinstance(weight, list) else [weight] + assert (scale_gathered_input is None) == (scale_weight is None) + if scale_weight is not None: + assert isinstance(weight, list) == isinstance(scale_weight, list) + scales_weights: Sequence[Optional[torch.Tensor]] = ( + scale_weight if isinstance(scale_weight, list) else [scale_weight] + ) + assert len(weights) == len(scales_weights) + assert _is_fp8_dtype(gathered_input.dtype) + assert all(_is_fp8_dtype(w.dtype) for w in weights) + assert out_dtype is not None, "output_dtype is required with FP8" + else: + scales_weights = [None] * len(weights) + assert all(w.ndim == 2 for w in weights) + assert gathered_input.ndim >= 2 + assert all(gathered_input.shape[-1] == w.shape[-1] for w in weights) + assert gathered_input.is_contiguous() + assert gathered_input.shape[0] % world_size == 0 + gathered_input = gathered_input.view( + (world_size, gathered_input.shape[0] // world_size) + gathered_input.shape[1:] + ) + gathered_output_shapes = [gathered_input.shape[:-1] + w.shape[:-1] for w in weights] + scattered_output_shapes = [gos[1:] for gos in gathered_output_shapes] + if out is not None: + assert isinstance(out, list) == isinstance(weight, list) + scattered_outputs = out if isinstance(out, list) else [out] + assert len(scattered_outputs) == scattered_output_shapes + assert all(so.device == gathered_input.device for so in scattered_outputs) + assert all(so.dtype == gathered_input.dtype for so in scattered_outputs) + assert all( + so.shape == sos + for so, sos in zip(scattered_outputs, scattered_output_shapes) + ) + if out_dtype is not None: + if isinstance(out, list): + for o in out: + assert o.dtype == out_dtype + else: + assert out.dtype == out_dtype + else: + scattered_outputs = [ + gathered_input.new_empty( + sos, + dtype=out_dtype if out_dtype is not None else gathered_input.dtype, + ) + for sos in scattered_output_shapes + ] + + 
torch.ops.xformers_python._fused_linear_and_reducescatter_impl( + gathered_input, + weights, + group.group_name, + scattered_outputs, + timeout_s=timeout_s, + _wait=private_args_DO_NOT_USE.get("_wait", True), + _memcpy=private_args_DO_NOT_USE.get("_memcpy", True), + scale_gathered_input=scale_gathered_input, + scales_weights=scales_weights, + ) + + if isinstance(weight, list): + return scattered_outputs + else: + return scattered_outputs[0] + + +@torch.library.custom_op( + "xformers_python::_fused_linear_and_reducescatter_impl", + mutates_args={"scattered_outputs"}, + device_types="cuda", +) +def _fused_linear_and_reducescatter_custom_op( + gathered_input: torch.Tensor, + weights: List[torch.Tensor], + process_group_name: str, + scattered_outputs: List[torch.Tensor], + timeout_s: int, + _wait: bool, + _memcpy: bool, + scale_gathered_input: torch.Tensor, + scales_weights: Sequence[Optional[torch.Tensor]], +) -> None: + process_group = dist.distributed_c10d._resolve_process_group(process_group_name) + + def my_matmul( + outputs: List[torch.Tensor], + dst_rank: int, + stream_factory: Callable[[], torch.cuda.Stream], + ) -> None: + for w, scale_weight, o in zip(weights, scales_weights, outputs): + with torch.cuda.stream(stream_factory()): + if scale_gathered_input is not None and scale_weight is not None: + torch._scaled_mm( + gathered_input[dst_rank], + w.t(), + out_dtype=o.dtype, + scale_a=scale_gathered_input, + scale_b=scale_weight, + out=o, + ) + else: + torch.matmul(gathered_input[dst_rank], w.t(), out=o) + + fused_anything_and_reducescatter( + my_matmul, + scattered_outputs, + group=process_group, + timeout_s=timeout_s, + _wait=_wait, + _memcpy=_memcpy, + ) + + +def fused_anything_and_reducescatter( + my_matmul: Callable[ + [List[torch.Tensor], int, Callable[[], torch.cuda.Stream]], None + ], + scattered_outputs: List[torch.Tensor], + *, + group: dist.ProcessGroup, + timeout_s: int = 60 * 60, + **private_args_DO_NOT_USE, +) -> None: + world_size = group.size() + 
+ if len(scattered_outputs) == 0: + for dst_rank in range(world_size): + my_matmul([], dst_rank, _default_stream_factory) + return + + assert all(so.is_contiguous() for so in scattered_outputs) + assert all(so.device == scattered_outputs[0].device for so in scattered_outputs) + assert all(so.dtype == scattered_outputs[0].dtype for so in scattered_outputs) + + gathered_output_shapes = [(world_size,) + so.shape for so in scattered_outputs] + + obj = _lazy_init(scattered_outputs[0].device, group) + + if world_size == 1: + my_matmul(scattered_outputs, 0, _default_stream_factory) + + # Fallback + elif obj is None: + gathered_outputs = [ + so.new_empty(gos) + for so, gos in zip(scattered_outputs, gathered_output_shapes) + ] + for dst_rank in range(world_size): + my_matmul( + [go[dst_rank] for go in gathered_outputs], + dst_rank, + _default_stream_factory, + ) + for go, so in zip(gathered_outputs, scattered_outputs): + dist.reduce_scatter_tensor(output=so, input=go, group=group) + + # Fast path + else: + assert scattered_outputs[0].device == obj.my_device + gathered_outputs = [ + scattered_outputs[0].new_empty(gos) for gos in gathered_output_shapes + ] + obj.linear_and_reducescatter( + my_matmul, + gathered_outputs, + scattered_outputs, + timeout_s=timeout_s, + _wait=private_args_DO_NOT_USE.get("_wait", True), + _memcpy=private_args_DO_NOT_USE.get("_memcpy", True), + ) diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/swiglu_op.py b/.venv/lib/python3.11/site-packages/xformers/ops/swiglu_op.py new file mode 100644 index 0000000000000000000000000000000000000000..630335ac6c6b3b015b5a4c3a491bd0fe005b1a07 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/ops/swiglu_op.py @@ -0,0 +1,554 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +from dataclasses import dataclass +from typing import Dict, Optional, Sequence, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn +from torch.amp import custom_bwd, custom_fwd + +from .common import BaseOperator, get_xformers_operator, register_operator +from .unbind import stack_or_none, unbind + +if torch.version.hip: + + @torch.library.register_kernel("xformers::dual_gemm_silu_identity_mul", "cuda") # type: ignore + def dual_gemm_silu_identity_mul_cuda( + x: torch.Tensor, + w1: torch.Tensor, + b1: Optional[torch.Tensor], + w2: torch.Tensor, + b2: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + x1 = x @ w1.T + if b1 is not None: + x1 += b1 + + x2 = x @ w2.T + if b2 is not None: + x2 += b2 + + x4 = F.silu(x1) * x2 + return x1, x2, x4 + + +@register_operator +class DualGemmSiluOp(BaseOperator): + OPERATOR = get_xformers_operator("dual_gemm_silu_identity_mul") + OPERATOR_CATEGORY = "swiglu" + NAME = "dual_gemm_silu" + + +@register_operator +class GemmFusedSumOp(BaseOperator): + OPERATOR = get_xformers_operator("gemm_fused_operand_sum") + OPERATOR_CATEGORY = "swiglu" + NAME = "gemm_fused_operand_sum" + + +class _SwiGLUDecomposedFunc(torch.autograd.Function): + """ + This is just an example implementation with all + operations explicited. This implementation is worse + than pytorch, because pytorch is able to fuse some operations + (eg the linear forward ...) that are decomposed here. 
+ + The time measurements were made on the ViT-Giant setting: + - A100/f16 + - input: [4440, 1536] + - hidden: [4440, 4096] + """ + + NAME = "decomposed" + FORCE_BW_F32 = False + + def _silu_backward(dy, x): + # https://github.com/pytorch/pytorch/blob/563b065f5a4b4055fa6b025c2514b566d5fd9439/aten/src/ATen/native/Activation.cpp#L483 + sigm = 1 / (1 + torch.exp(-x.float())) + return (dy.float() * sigm * (1 + x.float() * (1 - sigm))).to(x.dtype) + + # 952us + @classmethod + def forward(cls, ctx, x, w1, b1, w2, b2, w3, b3): + x1 = x @ w1.transpose(-2, -1) + b1 # 275us + x2 = x @ w2.transpose(-2, -1) + b2 # 275us + x3 = F.silu(x1) # 62us + x4 = x3 * x2 # 90us + x5 = x4 @ w3.transpose(-2, -1) + b3 # 250us + + ctx.save_for_backward(x, w1, b1, w2, b2, w3, b3, x1, x2, x3, x4, x5) + return x5 + + # 1900us + @classmethod + def backward(cls, ctx, dx5): + saved_tensors = ctx.saved_tensors + if cls.FORCE_BW_F32: + dx5 = dx5.float() + saved_tensors = [t.float() for t in ctx.saved_tensors] + x, w1, b1, w2, b2, w3, b3, x1, x2, x3, x4, x5 = saved_tensors + dx4 = dx5 @ w3 # 255us (nn) + dw3 = dx5.transpose(-2, -1) @ x4 # 247us (nt) + db3 = dx5.sum(0) # 25us + dx3 = dx4 * x2 # 88us + dx2 = dx4 * x3 # 88us + dx1 = cls._silu_backward(dx3, x1) # 90us + dx = dx2 @ w2 # 260us (nn) + dw2 = dx2.transpose(-2, -1) @ x # 245us (nt) + db2 = dx2.sum(0) # 50us + dx += dx1 @ w1 # 260us (nn) + dw1 = dx1.transpose(-2, -1) @ x # 245us (nt) + db1 = dx1.sum(0) # 50us + return (dx, dw1, db1, dw2, db2, dw3, db3) + + +class _SwiGLUFusedFunc(torch.autograd.Function): + NAME = "fused.py" + + @classmethod + @custom_fwd(device_type="cuda") + def forward(cls, ctx, x, w1, b1, w2, b2, w3, b3): + x1, x2, x4 = DualGemmSiluOp.OPERATOR(x, w1, b1, w2, b2) + + x5 = F.linear(x4, w3, b3) + ctx.save_for_backward(x, w1, w2, w3, x1, x2) + ctx.bias = [b1 is not None, b2 is not None, b3 is not None] + return x5 + + @staticmethod + def _linear_bw( + dy: torch.Tensor, x: torch.Tensor, bias: bool + ) -> Tuple[torch.Tensor, 
Optional[torch.Tensor]]: + if not bias: + return (dy.transpose(-2, -1) @ x), None + db = torch.empty([dy.shape[1]], dtype=dy.dtype, device=dy.device) + dw = torch.empty([dy.shape[1], x.shape[1]], dtype=dy.dtype, device=dy.device) + GemmFusedSumOp.OPERATOR(dy.transpose(-2, -1), x, dw, db) + return dw, db + + @classmethod + @custom_bwd(device_type="cuda") + def backward(cls, ctx, dx5): + x, w1, w2, w3, x1, x2 = ctx.saved_tensors + w1w2 = stack_or_none([w1, w2], dim=0) + + dx4 = dx5 @ w3 # 255us (nn) + dx1dx2, x4 = torch.ops.xformers.silu_bw_fused(x1, x2, dx4) + dx1, dx2 = dx1dx2.unbind(1) + del x1, x2, dx4 + + dw3, db3 = cls._linear_bw(dx5, x4, bias=ctx.bias[2]) + del x4, dx5 + if w1w2 is not None: + assert dx1dx2.is_contiguous() + assert w1w2.is_contiguous() + w1w2 = w1w2.view([w1.shape[0] * 2, w1.shape[1]]) + dx = dx1dx2.view([dx1.shape[0], 2 * dx1.shape[1]]) @ w1w2 + + # backward of linear1 + linear2 - packed + dw1dw2 = dx1dx2.view([dx1.shape[0], 2 * dx1.shape[1]]).transpose(-2, -1) @ x + dw1dw2, db1db2 = cls._linear_bw( + dx1dx2.view([dx1.shape[0], 2 * dx1.shape[1]]), x, bias=ctx.bias[0] + ) + dw1, dw2 = dw1dw2.view([2, *w1.shape]).unbind(0) + if ctx.bias[0]: + db1db2 = db1db2.view([2, dx1.shape[1]]) + db1, db2 = torch.unbind(db1db2, dim=0) + else: + db1 = db2 = None + else: + dx = dx2 @ w2 # 260us (nn) + torch.addmm( + dx, dx1, w1.to(dx1.dtype), beta=1, alpha=1, out=dx + ) # dx += dx1 @ w1 + dw2, db2 = cls._linear_bw(dx2, x, bias=ctx.bias[1]) + dw1, db1 = cls._linear_bw(dx1, x, bias=ctx.bias[0]) + return (dx, dw1, db1, dw2, db2, dw3, db3) + + +class SwiGLUOp: + """Base class for any swiglu operator in :attr:`xformers.ops.swiglu`""" + + def __init__(self, op, packed_weights: bool, name: str, constraints): + self.NAME = name + self.PACKED_WEIGHTS = packed_weights + self.op = op + self.constraints = constraints + + def supports(self, op: "SwiGLUOpDispatch") -> bool: + if self.PACKED_WEIGHTS and not op.packed_weights: + return False + return all(c(op) for c in 
self.constraints) + + def __call__(self, *args: Optional[torch.Tensor]) -> torch.Tensor: + raise NotImplementedError() + + def __str__(self) -> str: + return f"SwiGLUOp:{self.NAME}" + + +class _ForwardToPythonAutogradFunc(SwiGLUOp): + def supports(self, op: "SwiGLUOpDispatch") -> bool: + return super().supports(op) + + def __call__(self, *args, **kwargs): + return self.op.apply(*args, **kwargs) + + +class _ForwardToFunc(SwiGLUOp): + def __call__(self, *args, **kwargs): + return self.op(*args, **kwargs) + + def info(self): + if self.op.__name__ == "no_such_operator": + return "not built" + return "available" + + +def _eager_functional_swiglu( + x: torch.Tensor, + w1: torch.Tensor, + b1: torch.Tensor, + w2: torch.Tensor, + b2: torch.Tensor, + w3: torch.Tensor, + b3: torch.Tensor, +) -> torch.Tensor: + x1 = F.linear(x, w1, b1) + x2 = F.linear(x, w2, b2) + hidden = F.silu(x1) * x2 + return F.linear(hidden, w3, b3) + + +@dataclass +class SwiGLUOpDispatch: + """Dispatcher to automatically select + the best operator in :attr:`xformers.ops.swiglu` + """ + + device: Union[torch.device, str] + dtype: torch.dtype + dtype_autocast_gpu: Optional[torch.dtype] + packed_weights: bool + bias_enabled: bool + + @property + def op(self) -> SwiGLUOp: + """Computes the best operator + + Returns: + SwiGLUOp: The best operator for the configuration + """ + priorities: Sequence[SwiGLUOp] = [ + SwiGLUPackedFusedOp, + SwiGLUFusedOp, + ] + for op in priorities: + if op.supports(self): + return op + return SwiGLUEagerOp + + @staticmethod + def from_arguments( + x: torch.Tensor, + w1: torch.Tensor, + b1: Optional[torch.Tensor], + w2: torch.Tensor, + b2: Optional[torch.Tensor], + w3: torch.Tensor, + b3: Optional[torch.Tensor], + ) -> "SwiGLUOpDispatch": + return SwiGLUOpDispatch( + device=x.device, + dtype=x.dtype, + packed_weights=stack_or_none((w1, w2), dim=0) is not None, + dtype_autocast_gpu=torch.get_autocast_gpu_dtype() + if torch.is_autocast_enabled() + else w1.dtype, + bias_enabled=b1 is 
not None and b2 is not None and b3 is not None, + ) + + +def _only_sm80(op: SwiGLUOpDispatch) -> bool: + device_type = op.device if isinstance(op.device, str) else op.device.type + return device_type == "cuda" and torch.cuda.get_device_capability(op.device)[0] >= 8 + + +def _only_half_or_autocast(op: SwiGLUOpDispatch) -> bool: + HALF_DTYPES = [torch.half, torch.bfloat16] + return op.dtype in HALF_DTYPES or ( + op.dtype_autocast_gpu is not None and op.dtype_autocast_gpu in HALF_DTYPES + ) + + +def _bias_enabled(op: SwiGLUOpDispatch) -> bool: + return op.bias_enabled + + +_SwiGLUDecomposedOp = _ForwardToPythonAutogradFunc( + _SwiGLUDecomposedFunc, False, "decomposed", constraints=[_bias_enabled] +) +SwiGLUFusedOp = _ForwardToPythonAutogradFunc( + _SwiGLUFusedFunc, False, "fused", constraints=[_only_sm80, _only_half_or_autocast] +) +SwiGLUPackedFusedOp = _ForwardToFunc( + get_xformers_operator("swiglu_packedw"), + True, + "fused.p.cpp", + constraints=[_only_sm80, _only_half_or_autocast], +) +SwiGLUEagerOp = _ForwardToFunc( + _eager_functional_swiglu, + False, + "eager", + constraints=[], +) + + +def _info() -> Dict[str, str]: + return {op.NAME: op.info() for op in [SwiGLUPackedFusedOp]} + + +def swiglu( + x: torch.Tensor, + w1: torch.Tensor, + b1: Optional[torch.Tensor], + w2: torch.Tensor, + b2: Optional[torch.Tensor], + w3: torch.Tensor, + b3: Optional[torch.Tensor], + *, + op: Optional[SwiGLUOp] = None, +) -> torch.Tensor: + """ + Computes a SwiGLU block given the weights/bias of the 3 + linear layers. + + - It is recommended to keep ``op=None`` so the best implementation \ + available for the inputs will be used. + + + :Equivalent pytorch code: + + .. code-block:: python + + x1 = F.linear(x, w1, b1) + x2 = F.linear(x, w2, b2) + hidden = F.silu(x1) * x2 + return F.linear(hidden, w3, b3) + + :Packing weights: + + To allow faster implementations, it's recommended to have w1/w2 come from the same storage, as in: + .. 
code-block:: python + + w1, w2 = xformers.ops.unbind(w12, 0) + + :Supported hardware: + + This operator is only optimized on A100+ on ``torch.half`` or ``torch.bfloat16`` \ + (autocast is supported), and will fallback to a functional pytorch \ + implementation otherwise. + """ + + batch_shape = x.shape[:-1] + x = x.reshape([-1, x.shape[-1]]) + if w1.ndim != 2 or w1.shape != w2.shape: + raise ValueError(f"Invalid shapes for w1: {w1.shape} / w2: {w2.shape}") + if b1 is not None: + if b1.ndim != 1 or b1.shape[0] != w1.shape[0]: + raise ValueError(f"Invalid shapes for b1: {b1.shape}") + if b2 is not None: + if b2.ndim != 1 or b2.shape[0] != w2.shape[0]: + raise ValueError(f"Invalid shapes for b2: {b2.shape}") + if w3.ndim != 2 or w3.shape[1] != w2.shape[0]: + raise ValueError(f"Invalid shape for w3: {w3.shape}") + if b3 is not None: + if b3.ndim != 1 or b3.shape[0] != w3.shape[0]: + raise ValueError(f"Invalid shapes for w3: {w3.shape} / b3: {b3.shape}") + + if op is None: + op = SwiGLUOpDispatch.from_arguments(x, w1, b1, w2, b2, w3, b3).op + + if not op.PACKED_WEIGHTS: + return op(x, w1, b1, w2, b2, w3, b3).reshape([*batch_shape, -1]) + w1w2 = stack_or_none((w1, w2), dim=0) + if b1 is not None and b2 is not None: + b1b2: Optional[torch.Tensor] = stack_or_none((b1, b2), dim=0) + if b1b2 is None: + raise NotImplementedError("b1/b2 needs to be properly packed") + else: + b1b2 = None + assert b1 is None and b2 is None + + if w1w2 is None: + raise NotImplementedError("w1/w2 needs to be properly packed") + return op(x, w1w2, b1b2, w3, b3).reshape([*batch_shape, -1]) + + +def swiglu_packed( + x: torch.Tensor, + w1w2: torch.Tensor, + b1b2: Optional[torch.Tensor], + w3: torch.Tensor, + b3: Optional[torch.Tensor], + *, + op: SwiGLUOp, +) -> torch.Tensor: + """ + Computes a SwiGLU block given the weights/bias of the 3 + linear layers. + + :Equivalent pytorch code: + + .. 
code-block:: python + + x1 = F.linear(x, w1, b1) + x2 = F.linear(x, w2, b2) + hidden = F.silu(x1) * x2 + return F.linear(hidden, w3, b3) + + :Supported hardware: + + This operator is only optimized on A100+ on ``torch.half`` or ``torch.bfloat16`` \ + (autocast is supported), and will fallback to a functional pytorch \ + implementation otherwise. + """ + batch_shape = x.shape[:-1] + x = x.reshape([-1, x.shape[-1]]) + + if b3 is not None: + if b3.ndim != 1 or b3.shape[0] != w3.shape[0]: + raise ValueError(f"Invalid shapes for w3: {w3.shape} / b3: {b3.shape}") + + assert op.PACKED_WEIGHTS, "Not implemented PACKED_WEIGHTS" + + return op(x, w1w2, b1b2, w3, b3).reshape([*batch_shape, -1]) + + +class SwiGLU(nn.Module): + """ + A Module that encapsulates the call to :attr:`xformers.ops.swiglu`, + and holds the weights for the 3 linear layers + """ + + def __init__( + self, + in_features: int, + hidden_features: int, + out_features: Optional[int] = None, + bias: bool = True, + *, + _pack_weights: bool = True, + ) -> None: + """Create a SwiGLU module + + Args: + in_features (int): Number of features of the input + hidden_features (int): Number of hidden features + out_features (Optional[int], optional): Number of features of the input. Defaults to None. + bias (bool, optional): Whether linear layers also include a bias. Defaults to True. 
+ """ + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.w12: Optional[nn.Linear] + if _pack_weights: + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + else: + self.w12 = None + self.w1 = nn.Linear(in_features, hidden_features, bias=bias) + self.w2 = nn.Linear(in_features, hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + self.hidden_features = hidden_features + self.out_features = out_features + self.in_features = in_features + self.op: Optional[SwiGLUOp] = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Computes :attr:`swiglu` with the module's weights + + Args: + x (torch.Tensor): A Tensor of shape ``[..., in_features]`` + + Returns: + torch.Tensor: A Tensor of shape ``[..., out_features]`` + """ + if self.w12 is not None: + if self.op is not None: + assert ( + self.op.PACKED_WEIGHTS + ), "_pack_weights and self.op.PACKED_WEIGHTS should match" + return swiglu_packed(x, *self._packed_ordered_params(), op=self.op) + + return swiglu(x, *self._ordered_params(), op=self.op) + + def _ordered_params( + self, + ) -> Tuple[ + torch.Tensor, + Optional[torch.Tensor], + torch.Tensor, + Optional[torch.Tensor], + torch.Tensor, + Optional[torch.Tensor], + ]: + """Used for testing - returns ordered arguments for operators""" + b1: Optional[torch.Tensor] + b2: Optional[torch.Tensor] + if self.w12 is not None: + w1w2 = self.w12.weight + b1b2 = self.w12.bias + w1, w2 = unbind( + w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]), + dim=0, + ) + if b1b2 is not None: + b1, b2 = unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0) + else: + b1, b2 = None, None + else: + w1, w2 = self.w1.weight, self.w2.weight + b1, b2 = self.w1.bias, self.w2.bias + + return ( + w1, + b1, + w2, + b2, + self.w3.weight, + self.w3.bias, + ) + + def _packed_ordered_params( + self, + ) -> Tuple[ + torch.Tensor, + Optional[torch.Tensor], + torch.Tensor, + 
Optional[torch.Tensor], + ]: + assert self.w12 is not None, "Packed weights are only available when using w12" + + """Used for testing - returns ordered arguments for packed operators""" + w1w2 = self.w12.weight + b1b2_param = self.w12.bias + + w1w2 = w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]) + + b1b2: Optional[torch.Tensor] = None + if b1b2_param is not None: + b1b2 = b1b2_param.view([2, b1b2_param.shape[0] // 2]) + + return ( + w1w2, + b1b2, + self.w3.weight, + self.w3.bias, + ) diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/tiled_matmul.py b/.venv/lib/python3.11/site-packages/xformers/ops/tiled_matmul.py new file mode 100644 index 0000000000000000000000000000000000000000..d3e77061248ba308c15ebcaead78af886e7191b0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/ops/tiled_matmul.py @@ -0,0 +1,318 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import os +from typing import List, Tuple + +import torch + +from .. import _is_triton_available + + +# Copied over from the sequence parallel fused ops. +def _should_use_triton(device: torch.device, dtype: torch.dtype) -> bool: + if not int(os.getenv("XFORMERS_TILED_MATMUL_ENABLE_TRITON", "1")): + return False + if not _is_triton_available(): + return False + device_capability = torch.cuda.get_device_capability(device) + # Triton seems to be having issues on P100 and V100 GPUs, such as + # https://github.com/openai/triton/issues/1609 + # https://github.com/openai/triton/issues/1610 + # https://github.com/openai/triton/issues/1257#issuecomment-1532616965 + # and, in recent Triton versions (Jan 2024), returning wrong values. 
+ if device_capability < (8, 0): + return False + return True + + +def check_inputs( + a: List[List[torch.Tensor]], + b: List[List[torch.Tensor]], +) -> Tuple[List[int], List[int], List[int]]: + assert len(a) >= 1 and len(a[0]) >= 1 and all(len(row) == len(a[0]) for row in a), ( + "the first operand must be a non-empty two-dimensional regular list of lists " + "of tenors" + ) + assert len(b) >= 1 and len(b[0]) >= 1 and all(len(row) == len(b[0]) for row in b), ( + "the second operand must be a non-empty two-dimensional regular list of lists " + "of tenors" + ) + + m_tiles = len(a) + k_tiles = len(a[0]) + assert len(b) == k_tiles, ( + "the first operand's inner dimension must match the second operand's outer " + f"dimension, got {k_tiles} and {len(b)}" + ) + n_tiles = len(b[0]) + + ms = [a[tile_m][0].shape[0] for tile_m in range(m_tiles)] + ns = [b[0][tile_n].shape[1] for tile_n in range(n_tiles)] + aks = [a[0][tile_k].shape[1] for tile_k in range(k_tiles)] + bks = [b[tile_k][0].shape[0] for tile_k in range(k_tiles)] + + for tile_m in range(m_tiles): + for tile_k in range(k_tiles): + assert a[tile_m][tile_k].shape[0] == ms[tile_m], ( + f"the tensors on row {tile_m} of the first operand must all have the " + f"same size along the m dimension, got {ms[tile_m]} at position 0 and " + f"{a[tile_m][tile_k].shape[0]} at position {tile_k}" + ) + assert a[tile_m][tile_k].shape[1] == aks[tile_k], ( + f"the tensors on column {tile_k} of the first operand must all have " + f"the same size along the k dimension, got {aks[tile_k]} at position 0 " + f"and {a[tile_m][tile_k].shape[1]} at position {tile_m}" + ) + + for tile_n in range(n_tiles): + for tile_k in range(k_tiles): + assert b[tile_k][tile_n].shape[0] == bks[tile_k], ( + f"the tensors on row {tile_k} of the second operand must all have the " + f"same size along the k dimension, got {bks[tile_k]} at position 0 and " + f"{b[tile_k][tile_n].shape[0]} at position {tile_n}" + ) + assert b[tile_k][tile_n].shape[1] == ns[tile_n], 
( + f"the tensors on column {tile_n} of the second operand must all have " + f"the same size along the n dimension, got {ns[tile_n]} at position 0 " + f"and {b[tile_k][tile_n].shape[1]} at position {tile_k}" + ) + + for tile_k in range(k_tiles): + assert aks[tile_k] == bks[tile_k], ( + f"the tensors on column {tile_k} of the first operand and those on row " + f"{tile_k} of the second operand must have the same size along the k " + f"dimension, got {aks[tile_k]} and {bks[tile_k]}" + ) + ks = aks + + return ms, ns, ks + + +def check_output(out: List[List[torch.Tensor]], ms: List[int], ns: List[int]) -> None: + m_tiles, n_tiles = len(ms), len(ns) + assert ( + len(out) >= 1 + and len(out[0]) >= 1 + and all(len(row) == len(out[0]) for row in out) + ), "out must be a non-empty two-dimensional regular list of lists of tenors" + assert len(out) == m_tiles + assert len(out[0]) == n_tiles + cms = [out[tile_m][0].shape[0] for tile_m in range(m_tiles)] + cns = [out[0][tile_n].shape[1] for tile_n in range(n_tiles)] + for tile_m in range(m_tiles): + for tile_n in range(n_tiles): + assert out[tile_m][tile_n].shape[0] == cms[tile_m], ( + f"the tensors on row {tile_m} of out must all have the same size " + f"along the m dimension, got {cms[tile_m]} at position 0 and " + f"{out[tile_m][tile_n].shape[0]} at position {tile_n}" + ) + assert out[tile_m][tile_n].shape[1] == cns[tile_n], ( + f"the tensors on column {tile_n} of out must all have the same " + f"size along the k dimension, got {cns[tile_n]} at position 0 and " + f"{out[tile_m][tile_n].shape[1]} at position {tile_m}" + ) + for tile_m in range(m_tiles): + assert cms[tile_m] == ms[tile_m], ( + f"the tensors on row {tile_m} of out and those on row {tile_m} of the " + f"first operand must have the same size along the m dimension, got " + f"{cms[tile_m]} and {ms[tile_m]}" + ) + for tile_n in range(n_tiles): + assert cns[tile_n] == ns[tile_n], ( + f"the tensors on column {tile_n} of out and those on column {tile_n} " + f"of the 
second operand must have the same size along the n dimension, " + f"got {cns[tile_n]} and {ns[tile_n]}" + ) + + +# Using out= args in PyTorch is complicated, especially with custom_ops and +# torch.compile (we need to declare our side-effects with mutates_args, which +# then requires a functionalization step, ...). Thus this out= variant of the +# operator is exposed as a PLAIN PYTHON function, and is not compilable nor +# differentiable. It needs to be invoked from within a custom_op elsewhere. +def tiled_matmul_out( + a: List[List[torch.Tensor]], + b: List[List[torch.Tensor]], + out: List[List[torch.Tensor]], +) -> None: + ms, ns, ks = check_inputs(a, b) + check_output(out, ms, ns) + + # TODO We can try merging tiles that come from contiguous memory, using + # stack_or_none, to further improve performance. + + # Because the Triton kernel is hardcoded for maximum three tiles. + # Because, in turn, we aimed this at the fusion of wq/wk/wv. + if ( + len(ms) <= 3 + and len(ks) <= 3 + and len(ns) <= 3 + and _should_use_triton(a[0][0].device, a[0][0].dtype) + ): + from ._triton.tiled_matmul_kernels import _launch_triton_matmul + + _launch_triton_matmul(a, b, out, ms, ns, ks) + else: + for tile_m in range(len(ms)): + for tile_n in range(len(ns)): + torch.mm(a[tile_m][0], b[0][tile_n], out=out[tile_m][tile_n]) + for tile_k in range(1, len(ks)): + out[tile_m][tile_n].addmm_(a[tile_m][tile_k], b[tile_k][tile_n]) + + +def _flatten(x: List[List[torch.Tensor]], rows: int, cols: int) -> List[torch.Tensor]: + assert len(x) == rows + assert all(len(row) == cols for row in x) + flat_x = [elem for row in x for elem in row] + assert len(flat_x) == rows * cols + return flat_x + + +def _unflatten( + flat_x: List[torch.Tensor], rows: int, cols: int +) -> List[List[torch.Tensor]]: + assert len(flat_x) == cols * rows + x = [ + flat_x[row_offset : row_offset + cols] + for row_offset in range(0, rows * cols, cols) + ] + assert len(x) == rows + assert all(len(row) == cols for row in x) + 
return x + + +def _flattened_transpose( + flat_x: List[torch.Tensor], rows: int, cols: int +) -> List[torch.Tensor]: + x = _unflatten(flat_x, rows, cols) + transposed_x = [[elem.t() for elem in col] for col in zip(*x)] + flat_transposed_x = _flatten(transposed_x, cols, rows) + return flat_transposed_x + + +# PyTorch (custom_op and torch.compile, but also the dispatcher in general) +# have a hard time with Tensor[][] args. Thus we flatten them into Tensor[] to +# pass them into and out of this operator. +# See: https://github.com/pytorch/pytorch/issues/113022 +@torch.library.custom_op( + "xformers_python::tiled_matmul_fwd", + mutates_args=(), + device_types="cuda", +) +def tiled_matmul_fwd( + flat_a: List[torch.Tensor], + flat_b: List[torch.Tensor], + ms: List[int], + ns: List[int], + ks: List[int], +) -> List[torch.Tensor]: + a = _unflatten(flat_a, len(ms), len(ks)) + b = _unflatten(flat_b, len(ks), len(ns)) + + c = [[a[0][0].new_empty((m, n)) for n in ns] for m in ms] + tiled_matmul_out(a, b, out=c) + + return _flatten(c, len(ms), len(ns)) + + +@torch.library.register_fake("xformers_python::tiled_matmul_fwd") +def tiled_matmul_fwd_fake( + flat_a: List[torch.Tensor], + flat_b: List[torch.Tensor], + ms: List[int], + ns: List[int], + ks: List[int], +) -> List[torch.Tensor]: + c = [[flat_a[0][0].new_empty((m, n)) for n in ns] for m in ms] + return _flatten(c, len(ms), len(ks)) + + +def tiled_matmul_setup_context(ctx, inputs, output): + flat_a, flat_b, ctx.ms, ctx.ns, ctx.ks = inputs + ctx.save_for_backward(*flat_a, *flat_b) + + +def tiled_matmul_bwd(ctx, flat_grad_c): + assert len(ctx.saved_tensors) == len(ctx.ms) * len(ctx.ks) + len(ctx.ks) * len( + ctx.ns + ) + flat_a = ctx.saved_tensors[: len(ctx.ms) * len(ctx.ks)] + flat_b = ctx.saved_tensors[-len(ctx.ks) * len(ctx.ns) :] + + flat_transposed_a = _flattened_transpose(flat_a, len(ctx.ms), len(ctx.ks)) + flat_transposed_b = _flattened_transpose(flat_b, len(ctx.ks), len(ctx.ns)) + + flat_grad_a = tiled_matmul_fwd( + 
# --- xformers/ops/tiled_matmul.py (tail of the module) ---
# NOTE(review): the backward helper `tiled_matmul_bwd` begins above this
# span and is therefore not reproduced here.

# Hook the hand-written backward into the custom op so autograd can use it.
torch.library.register_autograd(
    "xformers_python::tiled_matmul_fwd",
    tiled_matmul_bwd,
    setup_context=tiled_matmul_setup_context,
)


def tiled_matmul(
    a: List[List[torch.Tensor]],
    b: List[List[torch.Tensor]],
) -> List[List[torch.Tensor]]:
    """Multiply two matrices that are given as grids of tiles.

    Computes ``out[m][n] = sum(a[m][k] @ b[k][n] for k in range(...))``
    where ``a`` and ``b`` are two-dimensional lists of tensor tiles
    (blocks), and returns the result in the same grid-of-tiles form.

    Performing the product as one fused kernel avoids the wave
    quantization penalty of issuing the per-tile matmuls back-to-back,
    without having to physically concatenate the operands (which can break
    checkpoint formats and obscures modelling code). Unlike a grouped
    matmul, partial products along K are accumulated into a single output
    tile, which is what the backward pass of a linear layer needs.

    The underlying Triton kernel constrains the tile strides (all rows of
    ``a`` must share the same K stride, all columns the same M stride, and
    so on) and currently supports at most three tiles per dimension.

    This operator is differentiable.
    """
    # check_inputs also runs inside the op; calling it here additionally
    # guarantees the grids are "regular" enough to be flattened.
    ms, ns, ks = check_inputs(a, b)
    a_flat = _flatten(a, len(ms), len(ks))
    b_flat = _flatten(b, len(ks), len(ns))
    c_flat = tiled_matmul_fwd(a_flat, b_flat, ms, ns, ks)
    return _unflatten(c_flat, len(ms), len(ns))


# --- xformers/ops/unbind.py ---
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Sequence, Tuple, Union

import torch

from .common import _get_storage_base


def get_stack_strides(
    tensors: Sequence[torch.Tensor], dim: int
) -> Optional[Tuple[Union[int, torch.SymInt], ...]]:
    """Return the strides that stacking ``tensors`` on ``dim`` would have,
    provided the tensors already lie stacked inside one shared storage;
    return ``None`` when they do not.
    """
    if len(tensors) <= 1 or dim > tensors[0].ndim:
        return None

    first = tensors[0]
    stacked_stride: List[Union[int, torch.SymInt]] = []
    for axis in range(first.ndim + 1):
        if axis == dim:
            # The gap between consecutive tensors becomes the stride of
            # the new (stacked) dimension.
            # PyTorch 2.5 messed up the type annotations for SymInt, but 2.6 will fix it
            # https://github.com/pytorch/pytorch/issues/138478
            stacked_stride.append(
                tensors[1].storage_offset() - first.storage_offset()  # type: ignore[operator]
            )
        else:
            src_axis = axis - 1 if axis > dim else axis
            stacked_stride.append(first.stride(src_axis))

    base_ptr: Optional[int] = None
    for gap, t in enumerate(tensors[1:], start=1):
        # Every tensor must look identical to the first one...
        if t.shape != first.shape or t.stride() != first.stride():
            return None
        # ...and sit exactly one "stack stride" further in the storage.
        # PyTorch 2.5 messed up the type annotations for SymInt, but 2.6 will fix it
        # https://github.com/pytorch/pytorch/issues/138478
        expected_offset = first.storage_offset() + gap * stacked_stride[dim]  # type: ignore[operator]
        if t.storage_offset() != expected_offset:
            return None
        if base_ptr is None:
            base_ptr = _get_storage_base(first)
        # Actual storage identity check: all views of one buffer.
        if _get_storage_base(t) != base_ptr:
            return None
    return tuple(stacked_stride)


def _stack_or_none_fw(
    tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]],
    dim: int,
) -> Optional[torch.Tensor]:
    # Zero-copy path: re-view the shared storage as the stacked tensor.
    stacked_strides = get_stack_strides(tensors, dim)
    if stacked_strides is None:
        return None
    out_shape = list(tensors[0].shape)
    out_shape.insert(dim, len(tensors))
    return tensors[0].as_strided(out_shape, stacked_strides)


def _stack_fw(
    tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]],
    dim: int,
) -> torch.Tensor:
    # Prefer the zero-copy view; fall back to a real (copying) torch.stack.
    stacked = _stack_or_none_fw(tensors, dim)
    return torch.stack(tensors, dim=dim) if stacked is None else stacked


class _Unbind(torch.autograd.Function):
    """Autograd wrapper backing :func:`unbind`."""

    @staticmethod
    # type: ignore
    def forward(ctx, x: torch.Tensor, dim: int):
        ctx.dim = dim
        return x.unbind(dim)

    @classmethod
    # type: ignore
    def backward(cls, ctx, *tensors: torch.Tensor):
        # Skips the copy when the incoming grads are views of one storage.
        return _stack_fw(tensors, ctx.dim), None


class _StackOrNone(torch.autograd.Function):
    """Autograd wrapper backing :func:`stack_or_none`."""

    @staticmethod
    # type: ignore
    def forward(ctx, dim: int, *tensors: torch.Tensor):
        ctx.dim = dim
        return _stack_or_none_fw(tensors, dim=dim)

    @classmethod
    # type: ignore
    def backward(cls, ctx, grad: torch.Tensor):
        # No gradient for `dim`; route one slice of `grad` to each input.
        return (None, *grad.unbind(dim=ctx.dim))


def unbind(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, ...]:
    """
    Does exactly the same as :attr:`torch.unbind` for the forward.
    In backward, avoids a :attr:`torch.cat` if the gradients
    are already multiple views of the same storage
    """
    return _Unbind.apply(x, dim)


def stack_or_none(tensors: Sequence[torch.Tensor], dim: int) -> torch.Tensor:
    """
    Does exactly the same as :attr:`torch.stack` if the tensors can be concatenated
    without any memory operation. Otherwise returns None.
    """
    return _StackOrNone.apply(dim, *tensors)


# --- xformers/test.py contains only the BSD license header (no code). ---
+# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import importlib +import os +import sys +from collections import namedtuple +from dataclasses import fields +from typing import Any, Callable, Dict, List, Optional + +import torch + +Item = namedtuple("Item", ["constructor", "config"]) + + +# credit: snippet used in ClassyVision (and probably other places) +def import_all_modules(root: str, base_module: str) -> List[str]: + modules: List[str] = [] + for file in os.listdir(root): + if file.endswith((".py", ".pyc")) and not file.startswith("_"): + module = file[: file.find(".py")] + if module not in sys.modules: + module_name = ".".join([base_module, module]) + importlib.import_module(module_name) + modules.append(module_name) + + return modules + + +def get_registry_decorator( + class_registry, name_registry, reference_class, default_config +) -> Callable[[str, Any], Callable[[Any], Any]]: + def register_item(name: str, config: Any = default_config): + """Registers a subclass. 
+ + This decorator allows xFormers to instantiate a given subclass + from a configuration file, even if the class itself is not part of the + xFormers library.""" + + def register_cls(cls): + if name in class_registry: + raise ValueError("Cannot register duplicate item ({})".format(name)) + if not issubclass(cls, reference_class): + raise ValueError( + "Item ({}: {}) must extend the base class: {}".format( + name, cls.__name__, reference_class.__name__ + ) + ) + if cls.__name__ in name_registry: + raise ValueError( + "Cannot register item with duplicate class name ({})".format( + cls.__name__ + ) + ) + + class_registry[name] = Item(constructor=cls, config=config) + name_registry.add(cls.__name__) + return cls + + return register_cls + + return register_item + + +def generate_matching_config(superset: Dict[str, Any], config_class: Any) -> Any: + """Given a superset of the inputs and a reference config class, + return exactly the needed config""" + + # Extract the required fields + field_names = list(map(lambda x: x.name, fields(config_class))) + subset = {k: v for k, v in superset.items() if k in field_names} + + # The missing fields get Noned + for k in field_names: + if k not in subset.keys(): + subset[k] = None + + return config_class(**subset) + + +# from https://github.com/openai/triton/blob/95d9b7f4ae21710dc899e1de6a579b2136ea4f3d/python/triton/testing.py#L19 +def do_bench_cudagraph( + fn: Callable, rep: int = 20, grad_to_none: Optional[List[torch.Tensor]] = None +) -> float: + """ + Benchmark the runtime of the provided function. + Args: + fn: Function to benchmark + rep: Repetition time (in ms) + grad_to_none: Reset the gradient of the provided tensor to None + Returns: + Benchmarked runtime in ms + """ + if torch.cuda.current_stream() == torch.cuda.default_stream(): + raise RuntimeError( + "Cannot capture graph in default stream. " + "Please use side stream in benchmark code." 
+ ) + # warmup + fn() + # step 1 - we estimate the amount of time the kernel call takes + # NOTE: this estimate isn't super accurate because the GPU isn't warmed up at this point + # but it is probably good enough + if grad_to_none is not None: + for x in grad_to_none: + x.detach_() + x.requires_grad_(True) + x.grad = None + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + fn() + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + g.replay() + end_event.record() + torch.cuda.synchronize() + estimate_ms = start_event.elapsed_time(end_event) + n_repeat = max(1, int(rep / estimate_ms)) + # step 2 - construct a cuda graph with `n_repeat` unrolled function calls to minimize + # host overhead + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for i in range(n_repeat): + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + fn() + torch.cuda.synchronize() + # measure time and return + ret = [] + n_retries = 10 + for i in range(n_retries): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + g.replay() + end_event.record() + torch.cuda.synchronize() + ret += [start_event.elapsed_time(end_event) / n_repeat] + return torch.mean(torch.tensor(ret)).item() diff --git a/.venv/lib/python3.11/site-packages/xformers/version.py b/.venv/lib/python3.11/site-packages/xformers/version.py new file mode 100644 index 0000000000000000000000000000000000000000..076b3e9d726293d961c115bb0053739d3ad8aa4d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/version.py @@ -0,0 +1,2 @@ +# noqa: C801 +__version__ = "0.0.28.post3" diff --git a/llm_tutorial/llm_recipes/models/hf-model-eval/llm-jp-v3-3.7b_zh-ja_3M-pairs/iter_0000991/model-00004-of-00004.safetensors 
b/llm_tutorial/llm_recipes/models/hf-model-eval/llm-jp-v3-3.7b_zh-ja_3M-pairs/iter_0000991/model-00004-of-00004.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..bd0b05c8e6710b12ce67a70eb10b60e320beb5db --- /dev/null +++ b/llm_tutorial/llm_recipes/models/hf-model-eval/llm-jp-v3-3.7b_zh-ja_3M-pairs/iter_0000991/model-00004-of-00004.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2adb9665e8356b92502d26a7155bee48911a53b4005fbcb8bf7dba118b4a390f +size 1223688320