Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/anyio/_core/__pycache__/_eventloop.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/anyio/_core/__pycache__/_fileio.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/anyio/_core/__pycache__/_streams.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/anyio/_core/__pycache__/_testing.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/anyio/abc/__init__.py +55 -0
- .venv/lib/python3.11/site-packages/anyio/abc/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_eventloop.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_resources.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_sockets.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_streams.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_subprocesses.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_tasks.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_testing.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/anyio/abc/_resources.py +33 -0
- .venv/lib/python3.11/site-packages/anyio/abc/_sockets.py +194 -0
- .venv/lib/python3.11/site-packages/anyio/abc/_streams.py +203 -0
- .venv/lib/python3.11/site-packages/anyio/abc/_tasks.py +101 -0
- .venv/lib/python3.11/site-packages/jsonschema_specifications-2024.10.1.dist-info/INSTALLER +1 -0
- .venv/lib/python3.11/site-packages/jsonschema_specifications-2024.10.1.dist-info/METADATA +55 -0
- .venv/lib/python3.11/site-packages/jsonschema_specifications-2024.10.1.dist-info/RECORD +33 -0
- .venv/lib/python3.11/site-packages/jsonschema_specifications-2024.10.1.dist-info/WHEEL +4 -0
- .venv/lib/python3.11/site-packages/jsonschema_specifications-2024.10.1.dist-info/licenses/COPYING +19 -0
- .venv/lib/python3.11/site-packages/xformers/__init__.py +73 -0
- .venv/lib/python3.11/site-packages/xformers/_cpp_lib.py +155 -0
- .venv/lib/python3.11/site-packages/xformers/_deprecation_warning.py +12 -0
- .venv/lib/python3.11/site-packages/xformers/attn_bias_utils.py +501 -0
- .venv/lib/python3.11/site-packages/xformers/checkpoint.py +546 -0
- .venv/lib/python3.11/site-packages/xformers/cpp_lib.json +1 -0
- .venv/lib/python3.11/site-packages/xformers/factory/__init__.py +11 -0
- .venv/lib/python3.11/site-packages/xformers/factory/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/factory/__pycache__/block_configs.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/factory/__pycache__/block_factory.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/factory/__pycache__/hydra_helper.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/factory/__pycache__/model_factory.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/factory/__pycache__/weight_init.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/factory/block_configs.py +237 -0
- .venv/lib/python3.11/site-packages/xformers/factory/block_factory.py +358 -0
- .venv/lib/python3.11/site-packages/xformers/factory/hydra_helper.py +36 -0
- .venv/lib/python3.11/site-packages/xformers/factory/model_factory.py +313 -0
- .venv/lib/python3.11/site-packages/xformers/factory/weight_init.py +293 -0
- .venv/lib/python3.11/site-packages/xformers/info.py +77 -0
- .venv/lib/python3.11/site-packages/xformers/ops/__pycache__/common.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/ops/__pycache__/differentiable_collectives.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/ops/__pycache__/indexing.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/ops/__pycache__/ipc.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/ops/__pycache__/modpar_layers.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/ops/__pycache__/rmsnorm.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/ops/__pycache__/rope_padded.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/ops/__pycache__/seqpar.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/ops/__pycache__/sequence_parallel_fused_ops.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/anyio/_core/__pycache__/_eventloop.cpython-311.pyc
ADDED
|
Binary file (6.92 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/anyio/_core/__pycache__/_fileio.cpython-311.pyc
ADDED
|
Binary file (41.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/anyio/_core/__pycache__/_streams.cpython-311.pyc
ADDED
|
Binary file (2.65 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/anyio/_core/__pycache__/_testing.cpython-311.pyc
ADDED
|
Binary file (3.86 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/anyio/abc/__init__.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from ._eventloop import AsyncBackend as AsyncBackend
|
| 4 |
+
from ._resources import AsyncResource as AsyncResource
|
| 5 |
+
from ._sockets import ConnectedUDPSocket as ConnectedUDPSocket
|
| 6 |
+
from ._sockets import ConnectedUNIXDatagramSocket as ConnectedUNIXDatagramSocket
|
| 7 |
+
from ._sockets import IPAddressType as IPAddressType
|
| 8 |
+
from ._sockets import IPSockAddrType as IPSockAddrType
|
| 9 |
+
from ._sockets import SocketAttribute as SocketAttribute
|
| 10 |
+
from ._sockets import SocketListener as SocketListener
|
| 11 |
+
from ._sockets import SocketStream as SocketStream
|
| 12 |
+
from ._sockets import UDPPacketType as UDPPacketType
|
| 13 |
+
from ._sockets import UDPSocket as UDPSocket
|
| 14 |
+
from ._sockets import UNIXDatagramPacketType as UNIXDatagramPacketType
|
| 15 |
+
from ._sockets import UNIXDatagramSocket as UNIXDatagramSocket
|
| 16 |
+
from ._sockets import UNIXSocketStream as UNIXSocketStream
|
| 17 |
+
from ._streams import AnyByteReceiveStream as AnyByteReceiveStream
|
| 18 |
+
from ._streams import AnyByteSendStream as AnyByteSendStream
|
| 19 |
+
from ._streams import AnyByteStream as AnyByteStream
|
| 20 |
+
from ._streams import AnyUnreliableByteReceiveStream as AnyUnreliableByteReceiveStream
|
| 21 |
+
from ._streams import AnyUnreliableByteSendStream as AnyUnreliableByteSendStream
|
| 22 |
+
from ._streams import AnyUnreliableByteStream as AnyUnreliableByteStream
|
| 23 |
+
from ._streams import ByteReceiveStream as ByteReceiveStream
|
| 24 |
+
from ._streams import ByteSendStream as ByteSendStream
|
| 25 |
+
from ._streams import ByteStream as ByteStream
|
| 26 |
+
from ._streams import Listener as Listener
|
| 27 |
+
from ._streams import ObjectReceiveStream as ObjectReceiveStream
|
| 28 |
+
from ._streams import ObjectSendStream as ObjectSendStream
|
| 29 |
+
from ._streams import ObjectStream as ObjectStream
|
| 30 |
+
from ._streams import UnreliableObjectReceiveStream as UnreliableObjectReceiveStream
|
| 31 |
+
from ._streams import UnreliableObjectSendStream as UnreliableObjectSendStream
|
| 32 |
+
from ._streams import UnreliableObjectStream as UnreliableObjectStream
|
| 33 |
+
from ._subprocesses import Process as Process
|
| 34 |
+
from ._tasks import TaskGroup as TaskGroup
|
| 35 |
+
from ._tasks import TaskStatus as TaskStatus
|
| 36 |
+
from ._testing import TestRunner as TestRunner
|
| 37 |
+
|
| 38 |
+
# Re-exported here, for backwards compatibility
|
| 39 |
+
# isort: off
|
| 40 |
+
from .._core._synchronization import (
|
| 41 |
+
CapacityLimiter as CapacityLimiter,
|
| 42 |
+
Condition as Condition,
|
| 43 |
+
Event as Event,
|
| 44 |
+
Lock as Lock,
|
| 45 |
+
Semaphore as Semaphore,
|
| 46 |
+
)
|
| 47 |
+
from .._core._tasks import CancelScope as CancelScope
|
| 48 |
+
from ..from_thread import BlockingPortal as BlockingPortal
|
| 49 |
+
|
| 50 |
+
# Re-export imports so they look like they live directly in this package
|
| 51 |
+
for __value in list(locals().values()):
|
| 52 |
+
if getattr(__value, "__module__", "").startswith("anyio.abc."):
|
| 53 |
+
__value.__module__ = __name__
|
| 54 |
+
|
| 55 |
+
del __value
|
.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (2.89 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_eventloop.cpython-311.pyc
ADDED
|
Binary file (16.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_resources.cpython-311.pyc
ADDED
|
Binary file (1.82 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_sockets.cpython-311.pyc
ADDED
|
Binary file (11.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_streams.cpython-311.pyc
ADDED
|
Binary file (9.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_subprocesses.cpython-311.pyc
ADDED
|
Binary file (3.66 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_tasks.cpython-311.pyc
ADDED
|
Binary file (4.99 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/anyio/abc/__pycache__/_testing.cpython-311.pyc
ADDED
|
Binary file (3.02 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/anyio/abc/_resources.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from abc import ABCMeta, abstractmethod
|
| 4 |
+
from types import TracebackType
|
| 5 |
+
from typing import TypeVar
|
| 6 |
+
|
| 7 |
+
T = TypeVar("T")
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class AsyncResource(metaclass=ABCMeta):
|
| 11 |
+
"""
|
| 12 |
+
Abstract base class for all closeable asynchronous resources.
|
| 13 |
+
|
| 14 |
+
Works as an asynchronous context manager which returns the instance itself on enter,
|
| 15 |
+
and calls :meth:`aclose` on exit.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
__slots__ = ()
|
| 19 |
+
|
| 20 |
+
async def __aenter__(self: T) -> T:
|
| 21 |
+
return self
|
| 22 |
+
|
| 23 |
+
async def __aexit__(
|
| 24 |
+
self,
|
| 25 |
+
exc_type: type[BaseException] | None,
|
| 26 |
+
exc_val: BaseException | None,
|
| 27 |
+
exc_tb: TracebackType | None,
|
| 28 |
+
) -> None:
|
| 29 |
+
await self.aclose()
|
| 30 |
+
|
| 31 |
+
@abstractmethod
|
| 32 |
+
async def aclose(self) -> None:
|
| 33 |
+
"""Close the resource."""
|
.venv/lib/python3.11/site-packages/anyio/abc/_sockets.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import socket
|
| 4 |
+
from abc import abstractmethod
|
| 5 |
+
from collections.abc import Callable, Collection, Mapping
|
| 6 |
+
from contextlib import AsyncExitStack
|
| 7 |
+
from io import IOBase
|
| 8 |
+
from ipaddress import IPv4Address, IPv6Address
|
| 9 |
+
from socket import AddressFamily
|
| 10 |
+
from types import TracebackType
|
| 11 |
+
from typing import Any, TypeVar, Union
|
| 12 |
+
|
| 13 |
+
from .._core._typedattr import (
|
| 14 |
+
TypedAttributeProvider,
|
| 15 |
+
TypedAttributeSet,
|
| 16 |
+
typed_attribute,
|
| 17 |
+
)
|
| 18 |
+
from ._streams import ByteStream, Listener, UnreliableObjectStream
|
| 19 |
+
from ._tasks import TaskGroup
|
| 20 |
+
|
| 21 |
+
IPAddressType = Union[str, IPv4Address, IPv6Address]
|
| 22 |
+
IPSockAddrType = tuple[str, int]
|
| 23 |
+
SockAddrType = Union[IPSockAddrType, str]
|
| 24 |
+
UDPPacketType = tuple[bytes, IPSockAddrType]
|
| 25 |
+
UNIXDatagramPacketType = tuple[bytes, str]
|
| 26 |
+
T_Retval = TypeVar("T_Retval")
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class _NullAsyncContextManager:
|
| 30 |
+
async def __aenter__(self) -> None:
|
| 31 |
+
pass
|
| 32 |
+
|
| 33 |
+
async def __aexit__(
|
| 34 |
+
self,
|
| 35 |
+
exc_type: type[BaseException] | None,
|
| 36 |
+
exc_val: BaseException | None,
|
| 37 |
+
exc_tb: TracebackType | None,
|
| 38 |
+
) -> bool | None:
|
| 39 |
+
return None
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class SocketAttribute(TypedAttributeSet):
|
| 43 |
+
#: the address family of the underlying socket
|
| 44 |
+
family: AddressFamily = typed_attribute()
|
| 45 |
+
#: the local socket address of the underlying socket
|
| 46 |
+
local_address: SockAddrType = typed_attribute()
|
| 47 |
+
#: for IP addresses, the local port the underlying socket is bound to
|
| 48 |
+
local_port: int = typed_attribute()
|
| 49 |
+
#: the underlying stdlib socket object
|
| 50 |
+
raw_socket: socket.socket = typed_attribute()
|
| 51 |
+
#: the remote address the underlying socket is connected to
|
| 52 |
+
remote_address: SockAddrType = typed_attribute()
|
| 53 |
+
#: for IP addresses, the remote port the underlying socket is connected to
|
| 54 |
+
remote_port: int = typed_attribute()
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class _SocketProvider(TypedAttributeProvider):
|
| 58 |
+
@property
|
| 59 |
+
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
|
| 60 |
+
from .._core._sockets import convert_ipv6_sockaddr as convert
|
| 61 |
+
|
| 62 |
+
attributes: dict[Any, Callable[[], Any]] = {
|
| 63 |
+
SocketAttribute.family: lambda: self._raw_socket.family,
|
| 64 |
+
SocketAttribute.local_address: lambda: convert(
|
| 65 |
+
self._raw_socket.getsockname()
|
| 66 |
+
),
|
| 67 |
+
SocketAttribute.raw_socket: lambda: self._raw_socket,
|
| 68 |
+
}
|
| 69 |
+
try:
|
| 70 |
+
peername: tuple[str, int] | None = convert(self._raw_socket.getpeername())
|
| 71 |
+
except OSError:
|
| 72 |
+
peername = None
|
| 73 |
+
|
| 74 |
+
# Provide the remote address for connected sockets
|
| 75 |
+
if peername is not None:
|
| 76 |
+
attributes[SocketAttribute.remote_address] = lambda: peername
|
| 77 |
+
|
| 78 |
+
# Provide local and remote ports for IP based sockets
|
| 79 |
+
if self._raw_socket.family in (AddressFamily.AF_INET, AddressFamily.AF_INET6):
|
| 80 |
+
attributes[SocketAttribute.local_port] = (
|
| 81 |
+
lambda: self._raw_socket.getsockname()[1]
|
| 82 |
+
)
|
| 83 |
+
if peername is not None:
|
| 84 |
+
remote_port = peername[1]
|
| 85 |
+
attributes[SocketAttribute.remote_port] = lambda: remote_port
|
| 86 |
+
|
| 87 |
+
return attributes
|
| 88 |
+
|
| 89 |
+
@property
|
| 90 |
+
@abstractmethod
|
| 91 |
+
def _raw_socket(self) -> socket.socket:
|
| 92 |
+
pass
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class SocketStream(ByteStream, _SocketProvider):
|
| 96 |
+
"""
|
| 97 |
+
Transports bytes over a socket.
|
| 98 |
+
|
| 99 |
+
Supports all relevant extra attributes from :class:`~SocketAttribute`.
|
| 100 |
+
"""
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class UNIXSocketStream(SocketStream):
|
| 104 |
+
@abstractmethod
|
| 105 |
+
async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
|
| 106 |
+
"""
|
| 107 |
+
Send file descriptors along with a message to the peer.
|
| 108 |
+
|
| 109 |
+
:param message: a non-empty bytestring
|
| 110 |
+
:param fds: a collection of files (either numeric file descriptors or open file
|
| 111 |
+
or socket objects)
|
| 112 |
+
"""
|
| 113 |
+
|
| 114 |
+
@abstractmethod
|
| 115 |
+
async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
|
| 116 |
+
"""
|
| 117 |
+
Receive file descriptors along with a message from the peer.
|
| 118 |
+
|
| 119 |
+
:param msglen: length of the message to expect from the peer
|
| 120 |
+
:param maxfds: maximum number of file descriptors to expect from the peer
|
| 121 |
+
:return: a tuple of (message, file descriptors)
|
| 122 |
+
"""
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class SocketListener(Listener[SocketStream], _SocketProvider):
|
| 126 |
+
"""
|
| 127 |
+
Listens to incoming socket connections.
|
| 128 |
+
|
| 129 |
+
Supports all relevant extra attributes from :class:`~SocketAttribute`.
|
| 130 |
+
"""
|
| 131 |
+
|
| 132 |
+
@abstractmethod
|
| 133 |
+
async def accept(self) -> SocketStream:
|
| 134 |
+
"""Accept an incoming connection."""
|
| 135 |
+
|
| 136 |
+
async def serve(
|
| 137 |
+
self,
|
| 138 |
+
handler: Callable[[SocketStream], Any],
|
| 139 |
+
task_group: TaskGroup | None = None,
|
| 140 |
+
) -> None:
|
| 141 |
+
from .. import create_task_group
|
| 142 |
+
|
| 143 |
+
async with AsyncExitStack() as stack:
|
| 144 |
+
if task_group is None:
|
| 145 |
+
task_group = await stack.enter_async_context(create_task_group())
|
| 146 |
+
|
| 147 |
+
while True:
|
| 148 |
+
stream = await self.accept()
|
| 149 |
+
task_group.start_soon(handler, stream)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class UDPSocket(UnreliableObjectStream[UDPPacketType], _SocketProvider):
|
| 153 |
+
"""
|
| 154 |
+
Represents an unconnected UDP socket.
|
| 155 |
+
|
| 156 |
+
Supports all relevant extra attributes from :class:`~SocketAttribute`.
|
| 157 |
+
"""
|
| 158 |
+
|
| 159 |
+
async def sendto(self, data: bytes, host: str, port: int) -> None:
|
| 160 |
+
"""
|
| 161 |
+
Alias for :meth:`~.UnreliableObjectSendStream.send` ((data, (host, port))).
|
| 162 |
+
|
| 163 |
+
"""
|
| 164 |
+
return await self.send((data, (host, port)))
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class ConnectedUDPSocket(UnreliableObjectStream[bytes], _SocketProvider):
|
| 168 |
+
"""
|
| 169 |
+
Represents an connected UDP socket.
|
| 170 |
+
|
| 171 |
+
Supports all relevant extra attributes from :class:`~SocketAttribute`.
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class UNIXDatagramSocket(
|
| 176 |
+
UnreliableObjectStream[UNIXDatagramPacketType], _SocketProvider
|
| 177 |
+
):
|
| 178 |
+
"""
|
| 179 |
+
Represents an unconnected Unix datagram socket.
|
| 180 |
+
|
| 181 |
+
Supports all relevant extra attributes from :class:`~SocketAttribute`.
|
| 182 |
+
"""
|
| 183 |
+
|
| 184 |
+
async def sendto(self, data: bytes, path: str) -> None:
|
| 185 |
+
"""Alias for :meth:`~.UnreliableObjectSendStream.send` ((data, path))."""
|
| 186 |
+
return await self.send((data, path))
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
class ConnectedUNIXDatagramSocket(UnreliableObjectStream[bytes], _SocketProvider):
|
| 190 |
+
"""
|
| 191 |
+
Represents a connected Unix datagram socket.
|
| 192 |
+
|
| 193 |
+
Supports all relevant extra attributes from :class:`~SocketAttribute`.
|
| 194 |
+
"""
|
.venv/lib/python3.11/site-packages/anyio/abc/_streams.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from abc import abstractmethod
|
| 4 |
+
from collections.abc import Callable
|
| 5 |
+
from typing import Any, Generic, TypeVar, Union
|
| 6 |
+
|
| 7 |
+
from .._core._exceptions import EndOfStream
|
| 8 |
+
from .._core._typedattr import TypedAttributeProvider
|
| 9 |
+
from ._resources import AsyncResource
|
| 10 |
+
from ._tasks import TaskGroup
|
| 11 |
+
|
| 12 |
+
T_Item = TypeVar("T_Item")
|
| 13 |
+
T_co = TypeVar("T_co", covariant=True)
|
| 14 |
+
T_contra = TypeVar("T_contra", contravariant=True)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class UnreliableObjectReceiveStream(
|
| 18 |
+
Generic[T_co], AsyncResource, TypedAttributeProvider
|
| 19 |
+
):
|
| 20 |
+
"""
|
| 21 |
+
An interface for receiving objects.
|
| 22 |
+
|
| 23 |
+
This interface makes no guarantees that the received messages arrive in the order in
|
| 24 |
+
which they were sent, or that no messages are missed.
|
| 25 |
+
|
| 26 |
+
Asynchronously iterating over objects of this type will yield objects matching the
|
| 27 |
+
given type parameter.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def __aiter__(self) -> UnreliableObjectReceiveStream[T_co]:
|
| 31 |
+
return self
|
| 32 |
+
|
| 33 |
+
async def __anext__(self) -> T_co:
|
| 34 |
+
try:
|
| 35 |
+
return await self.receive()
|
| 36 |
+
except EndOfStream:
|
| 37 |
+
raise StopAsyncIteration
|
| 38 |
+
|
| 39 |
+
@abstractmethod
|
| 40 |
+
async def receive(self) -> T_co:
|
| 41 |
+
"""
|
| 42 |
+
Receive the next item.
|
| 43 |
+
|
| 44 |
+
:raises ~anyio.ClosedResourceError: if the receive stream has been explicitly
|
| 45 |
+
closed
|
| 46 |
+
:raises ~anyio.EndOfStream: if this stream has been closed from the other end
|
| 47 |
+
:raises ~anyio.BrokenResourceError: if this stream has been rendered unusable
|
| 48 |
+
due to external causes
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class UnreliableObjectSendStream(
|
| 53 |
+
Generic[T_contra], AsyncResource, TypedAttributeProvider
|
| 54 |
+
):
|
| 55 |
+
"""
|
| 56 |
+
An interface for sending objects.
|
| 57 |
+
|
| 58 |
+
This interface makes no guarantees that the messages sent will reach the
|
| 59 |
+
recipient(s) in the same order in which they were sent, or at all.
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
@abstractmethod
|
| 63 |
+
async def send(self, item: T_contra) -> None:
|
| 64 |
+
"""
|
| 65 |
+
Send an item to the peer(s).
|
| 66 |
+
|
| 67 |
+
:param item: the item to send
|
| 68 |
+
:raises ~anyio.ClosedResourceError: if the send stream has been explicitly
|
| 69 |
+
closed
|
| 70 |
+
:raises ~anyio.BrokenResourceError: if this stream has been rendered unusable
|
| 71 |
+
due to external causes
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class UnreliableObjectStream(
|
| 76 |
+
UnreliableObjectReceiveStream[T_Item], UnreliableObjectSendStream[T_Item]
|
| 77 |
+
):
|
| 78 |
+
"""
|
| 79 |
+
A bidirectional message stream which does not guarantee the order or reliability of
|
| 80 |
+
message delivery.
|
| 81 |
+
"""
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class ObjectReceiveStream(UnreliableObjectReceiveStream[T_co]):
|
| 85 |
+
"""
|
| 86 |
+
A receive message stream which guarantees that messages are received in the same
|
| 87 |
+
order in which they were sent, and that no messages are missed.
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class ObjectSendStream(UnreliableObjectSendStream[T_contra]):
|
| 92 |
+
"""
|
| 93 |
+
A send message stream which guarantees that messages are delivered in the same order
|
| 94 |
+
in which they were sent, without missing any messages in the middle.
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class ObjectStream(
|
| 99 |
+
ObjectReceiveStream[T_Item],
|
| 100 |
+
ObjectSendStream[T_Item],
|
| 101 |
+
UnreliableObjectStream[T_Item],
|
| 102 |
+
):
|
| 103 |
+
"""
|
| 104 |
+
A bidirectional message stream which guarantees the order and reliability of message
|
| 105 |
+
delivery.
|
| 106 |
+
"""
|
| 107 |
+
|
| 108 |
+
@abstractmethod
|
| 109 |
+
async def send_eof(self) -> None:
|
| 110 |
+
"""
|
| 111 |
+
Send an end-of-file indication to the peer.
|
| 112 |
+
|
| 113 |
+
You should not try to send any further data to this stream after calling this
|
| 114 |
+
method. This method is idempotent (does nothing on successive calls).
|
| 115 |
+
"""
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class ByteReceiveStream(AsyncResource, TypedAttributeProvider):
|
| 119 |
+
"""
|
| 120 |
+
An interface for receiving bytes from a single peer.
|
| 121 |
+
|
| 122 |
+
Iterating this byte stream will yield a byte string of arbitrary length, but no more
|
| 123 |
+
than 65536 bytes.
|
| 124 |
+
"""
|
| 125 |
+
|
| 126 |
+
def __aiter__(self) -> ByteReceiveStream:
|
| 127 |
+
return self
|
| 128 |
+
|
| 129 |
+
async def __anext__(self) -> bytes:
|
| 130 |
+
try:
|
| 131 |
+
return await self.receive()
|
| 132 |
+
except EndOfStream:
|
| 133 |
+
raise StopAsyncIteration
|
| 134 |
+
|
| 135 |
+
@abstractmethod
|
| 136 |
+
async def receive(self, max_bytes: int = 65536) -> bytes:
|
| 137 |
+
"""
|
| 138 |
+
Receive at most ``max_bytes`` bytes from the peer.
|
| 139 |
+
|
| 140 |
+
.. note:: Implementors of this interface should not return an empty
|
| 141 |
+
:class:`bytes` object, and users should ignore them.
|
| 142 |
+
|
| 143 |
+
:param max_bytes: maximum number of bytes to receive
|
| 144 |
+
:return: the received bytes
|
| 145 |
+
:raises ~anyio.EndOfStream: if this stream has been closed from the other end
|
| 146 |
+
"""
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class ByteSendStream(AsyncResource, TypedAttributeProvider):
|
| 150 |
+
"""An interface for sending bytes to a single peer."""
|
| 151 |
+
|
| 152 |
+
@abstractmethod
|
| 153 |
+
async def send(self, item: bytes) -> None:
|
| 154 |
+
"""
|
| 155 |
+
Send the given bytes to the peer.
|
| 156 |
+
|
| 157 |
+
:param item: the bytes to send
|
| 158 |
+
"""
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class ByteStream(ByteReceiveStream, ByteSendStream):
|
| 162 |
+
"""A bidirectional byte stream."""
|
| 163 |
+
|
| 164 |
+
@abstractmethod
|
| 165 |
+
async def send_eof(self) -> None:
|
| 166 |
+
"""
|
| 167 |
+
Send an end-of-file indication to the peer.
|
| 168 |
+
|
| 169 |
+
You should not try to send any further data to this stream after calling this
|
| 170 |
+
method. This method is idempotent (does nothing on successive calls).
|
| 171 |
+
"""
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
#: Type alias for all unreliable bytes-oriented receive streams.
|
| 175 |
+
AnyUnreliableByteReceiveStream = Union[
|
| 176 |
+
UnreliableObjectReceiveStream[bytes], ByteReceiveStream
|
| 177 |
+
]
|
| 178 |
+
#: Type alias for all unreliable bytes-oriented send streams.
|
| 179 |
+
AnyUnreliableByteSendStream = Union[UnreliableObjectSendStream[bytes], ByteSendStream]
|
| 180 |
+
#: Type alias for all unreliable bytes-oriented streams.
|
| 181 |
+
AnyUnreliableByteStream = Union[UnreliableObjectStream[bytes], ByteStream]
|
| 182 |
+
#: Type alias for all bytes-oriented receive streams.
|
| 183 |
+
AnyByteReceiveStream = Union[ObjectReceiveStream[bytes], ByteReceiveStream]
|
| 184 |
+
#: Type alias for all bytes-oriented send streams.
|
| 185 |
+
AnyByteSendStream = Union[ObjectSendStream[bytes], ByteSendStream]
|
| 186 |
+
#: Type alias for all bytes-oriented streams.
|
| 187 |
+
AnyByteStream = Union[ObjectStream[bytes], ByteStream]
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class Listener(Generic[T_co], AsyncResource, TypedAttributeProvider):
|
| 191 |
+
"""An interface for objects that let you accept incoming connections."""
|
| 192 |
+
|
| 193 |
+
@abstractmethod
|
| 194 |
+
async def serve(
|
| 195 |
+
self, handler: Callable[[T_co], Any], task_group: TaskGroup | None = None
|
| 196 |
+
) -> None:
|
| 197 |
+
"""
|
| 198 |
+
Accept incoming connections as they come in and start tasks to handle them.
|
| 199 |
+
|
| 200 |
+
:param handler: a callable that will be used to handle each accepted connection
|
| 201 |
+
:param task_group: the task group that will be used to start tasks for handling
|
| 202 |
+
each accepted connection (if omitted, an ad-hoc task group will be created)
|
| 203 |
+
"""
|
.venv/lib/python3.11/site-packages/anyio/abc/_tasks.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import sys
|
| 4 |
+
from abc import ABCMeta, abstractmethod
|
| 5 |
+
from collections.abc import Awaitable, Callable
|
| 6 |
+
from types import TracebackType
|
| 7 |
+
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, overload
|
| 8 |
+
|
| 9 |
+
if sys.version_info >= (3, 11):
|
| 10 |
+
from typing import TypeVarTuple, Unpack
|
| 11 |
+
else:
|
| 12 |
+
from typing_extensions import TypeVarTuple, Unpack
|
| 13 |
+
|
| 14 |
+
if TYPE_CHECKING:
|
| 15 |
+
from .._core._tasks import CancelScope
|
| 16 |
+
|
| 17 |
+
T_Retval = TypeVar("T_Retval")
|
| 18 |
+
T_contra = TypeVar("T_contra", contravariant=True)
|
| 19 |
+
PosArgsT = TypeVarTuple("PosArgsT")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class TaskStatus(Protocol[T_contra]):
|
| 23 |
+
@overload
|
| 24 |
+
def started(self: TaskStatus[None]) -> None: ...
|
| 25 |
+
|
| 26 |
+
@overload
|
| 27 |
+
def started(self, value: T_contra) -> None: ...
|
| 28 |
+
|
| 29 |
+
def started(self, value: T_contra | None = None) -> None:
|
| 30 |
+
"""
|
| 31 |
+
Signal that the task has started.
|
| 32 |
+
|
| 33 |
+
:param value: object passed back to the starter of the task
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class TaskGroup(metaclass=ABCMeta):
|
| 38 |
+
"""
|
| 39 |
+
Groups several asynchronous tasks together.
|
| 40 |
+
|
| 41 |
+
:ivar cancel_scope: the cancel scope inherited by all child tasks
|
| 42 |
+
:vartype cancel_scope: CancelScope
|
| 43 |
+
|
| 44 |
+
.. note:: On asyncio, support for eager task factories is considered to be
|
| 45 |
+
**experimental**. In particular, they don't follow the usual semantics of new
|
| 46 |
+
tasks being scheduled on the next iteration of the event loop, and may thus
|
| 47 |
+
cause unexpected behavior in code that wasn't written with such semantics in
|
| 48 |
+
mind.
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
cancel_scope: CancelScope
|
| 52 |
+
|
| 53 |
+
@abstractmethod
|
| 54 |
+
def start_soon(
|
| 55 |
+
self,
|
| 56 |
+
func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
|
| 57 |
+
*args: Unpack[PosArgsT],
|
| 58 |
+
name: object = None,
|
| 59 |
+
) -> None:
|
| 60 |
+
"""
|
| 61 |
+
Start a new task in this task group.
|
| 62 |
+
|
| 63 |
+
:param func: a coroutine function
|
| 64 |
+
:param args: positional arguments to call the function with
|
| 65 |
+
:param name: name of the task, for the purposes of introspection and debugging
|
| 66 |
+
|
| 67 |
+
.. versionadded:: 3.0
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
+
@abstractmethod
|
| 71 |
+
async def start(
|
| 72 |
+
self,
|
| 73 |
+
func: Callable[..., Awaitable[Any]],
|
| 74 |
+
*args: object,
|
| 75 |
+
name: object = None,
|
| 76 |
+
) -> Any:
|
| 77 |
+
"""
|
| 78 |
+
Start a new task and wait until it signals for readiness.
|
| 79 |
+
|
| 80 |
+
:param func: a coroutine function
|
| 81 |
+
:param args: positional arguments to call the function with
|
| 82 |
+
:param name: name of the task, for the purposes of introspection and debugging
|
| 83 |
+
:return: the value passed to ``task_status.started()``
|
| 84 |
+
:raises RuntimeError: if the task finishes without calling
|
| 85 |
+
``task_status.started()``
|
| 86 |
+
|
| 87 |
+
.. versionadded:: 3.0
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
@abstractmethod
|
| 91 |
+
async def __aenter__(self) -> TaskGroup:
|
| 92 |
+
"""Enter the task group context and allow starting new tasks."""
|
| 93 |
+
|
| 94 |
+
@abstractmethod
|
| 95 |
+
async def __aexit__(
|
| 96 |
+
self,
|
| 97 |
+
exc_type: type[BaseException] | None,
|
| 98 |
+
exc_val: BaseException | None,
|
| 99 |
+
exc_tb: TracebackType | None,
|
| 100 |
+
) -> bool | None:
|
| 101 |
+
"""Exit the task group context waiting for all tasks to finish."""
|
.venv/lib/python3.11/site-packages/jsonschema_specifications-2024.10.1.dist-info/INSTALLER
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
pip
|
.venv/lib/python3.11/site-packages/jsonschema_specifications-2024.10.1.dist-info/METADATA
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.3
|
| 2 |
+
Name: jsonschema-specifications
|
| 3 |
+
Version: 2024.10.1
|
| 4 |
+
Summary: The JSON Schema meta-schemas and vocabularies, exposed as a Registry
|
| 5 |
+
Project-URL: Documentation, https://jsonschema-specifications.readthedocs.io/
|
| 6 |
+
Project-URL: Homepage, https://github.com/python-jsonschema/jsonschema-specifications
|
| 7 |
+
Project-URL: Issues, https://github.com/python-jsonschema/jsonschema-specifications/issues/
|
| 8 |
+
Project-URL: Funding, https://github.com/sponsors/Julian
|
| 9 |
+
Project-URL: Tidelift, https://tidelift.com/subscription/pkg/pypi-jsonschema-specifications?utm_source=pypi-jsonschema-specifications&utm_medium=referral&utm_campaign=pypi-link
|
| 10 |
+
Project-URL: Source, https://github.com/python-jsonschema/jsonschema-specifications
|
| 11 |
+
Author-email: Julian Berman <Julian+jsonschema-specifications@GrayVines.com>
|
| 12 |
+
License-File: COPYING
|
| 13 |
+
Keywords: data validation,json,json schema,jsonschema,validation
|
| 14 |
+
Classifier: Development Status :: 5 - Production/Stable
|
| 15 |
+
Classifier: Intended Audience :: Developers
|
| 16 |
+
Classifier: License :: OSI Approved :: MIT License
|
| 17 |
+
Classifier: Operating System :: OS Independent
|
| 18 |
+
Classifier: Programming Language :: Python
|
| 19 |
+
Classifier: Programming Language :: Python :: 3.8
|
| 20 |
+
Classifier: Programming Language :: Python :: 3.9
|
| 21 |
+
Classifier: Programming Language :: Python :: 3.10
|
| 22 |
+
Classifier: Programming Language :: Python :: 3.11
|
| 23 |
+
Classifier: Programming Language :: Python :: 3.12
|
| 24 |
+
Classifier: Programming Language :: Python :: 3.13
|
| 25 |
+
Classifier: Programming Language :: Python :: Implementation :: CPython
|
| 26 |
+
Classifier: Programming Language :: Python :: Implementation :: PyPy
|
| 27 |
+
Classifier: Topic :: File Formats :: JSON
|
| 28 |
+
Classifier: Topic :: File Formats :: JSON :: JSON Schema
|
| 29 |
+
Requires-Python: >=3.9
|
| 30 |
+
Requires-Dist: referencing>=0.31.0
|
| 31 |
+
Description-Content-Type: text/x-rst
|
| 32 |
+
|
| 33 |
+
=============================
|
| 34 |
+
``jsonschema-specifications``
|
| 35 |
+
=============================
|
| 36 |
+
|
| 37 |
+
|PyPI| |Pythons| |CI| |ReadTheDocs|
|
| 38 |
+
|
| 39 |
+
JSON support files from the `JSON Schema Specifications <https://json-schema.org/specification.html>`_ (metaschemas, vocabularies, etc.), packaged for runtime access from Python as a `referencing-based Schema Registry <https://referencing.readthedocs.io/en/stable/api/#referencing.Registry>`_.
|
| 40 |
+
|
| 41 |
+
.. |PyPI| image:: https://img.shields.io/pypi/v/jsonschema-specifications.svg
|
| 42 |
+
:alt: PyPI version
|
| 43 |
+
:target: https://pypi.org/project/jsonschema-specifications/
|
| 44 |
+
|
| 45 |
+
.. |Pythons| image:: https://img.shields.io/pypi/pyversions/jsonschema-specifications.svg
|
| 46 |
+
:alt: Supported Python versions
|
| 47 |
+
:target: https://pypi.org/project/jsonschema-specifications/
|
| 48 |
+
|
| 49 |
+
.. |CI| image:: https://github.com/python-jsonschema/jsonschema-specifications/workflows/CI/badge.svg
|
| 50 |
+
:alt: Build status
|
| 51 |
+
:target: https://github.com/python-jsonschema/jsonschema-specifications/actions?query=workflow%3ACI
|
| 52 |
+
|
| 53 |
+
.. |ReadTheDocs| image:: https://readthedocs.org/projects/jsonschema-specifications/badge/?version=stable&style=flat
|
| 54 |
+
:alt: ReadTheDocs status
|
| 55 |
+
:target: https://jsonschema-specifications.readthedocs.io/en/stable/
|
.venv/lib/python3.11/site-packages/jsonschema_specifications-2024.10.1.dist-info/RECORD
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
jsonschema_specifications-2024.10.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
| 2 |
+
jsonschema_specifications-2024.10.1.dist-info/METADATA,sha256=-jCfClPka5D4aDTtJ683zNiEcSHXhPBLuk9r9XWwyHI,2985
|
| 3 |
+
jsonschema_specifications-2024.10.1.dist-info/RECORD,,
|
| 4 |
+
jsonschema_specifications-2024.10.1.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
|
| 5 |
+
jsonschema_specifications-2024.10.1.dist-info/licenses/COPYING,sha256=QtzWNJX4e063x3V6-jebtVpT-Ur9el9lfZrfVyNuUVw,1057
|
| 6 |
+
jsonschema_specifications/__init__.py,sha256=qoTB2DKY7qvNrGhMPH6gtmAJRLilmVQ-fFZwT6ryqw0,386
|
| 7 |
+
jsonschema_specifications/__pycache__/__init__.cpython-311.pyc,,
|
| 8 |
+
jsonschema_specifications/__pycache__/_core.cpython-311.pyc,,
|
| 9 |
+
jsonschema_specifications/_core.py,sha256=tFhc1CMleJ3AJOK_bjxOpFQTdrsUClFGfFxPBU_CebM,1140
|
| 10 |
+
jsonschema_specifications/schemas/draft201909/metaschema.json,sha256=e3YbPhIfCgyh6ioLjizIVrz4AWBLgmjXG6yqICvAwTs,1785
|
| 11 |
+
jsonschema_specifications/schemas/draft201909/vocabularies/applicator,sha256=aJUQDplyb7sQcFhRK77D7P1LJOj9L6zuPlBe5ysNTDE,1860
|
| 12 |
+
jsonschema_specifications/schemas/draft201909/vocabularies/content,sha256=m31PVaTi_bAsQwBo_f-rxzKt3OI42j8d8mkCScM1MnQ,517
|
| 13 |
+
jsonschema_specifications/schemas/draft201909/vocabularies/core,sha256=taLElX9kldClCB8ECevooU5BOayyA_x0hHH47eKvWyw,1531
|
| 14 |
+
jsonschema_specifications/schemas/draft201909/vocabularies/meta-data,sha256=1H4kRd1qgicaKY2DzGxsuNSuHhXg3Fa-zTehY-zwEoY,892
|
| 15 |
+
jsonschema_specifications/schemas/draft201909/vocabularies/validation,sha256=HlJsHTNac0gF_ILPV5jBK5YK19olF8Zs2lobCTWcPBw,2834
|
| 16 |
+
jsonschema_specifications/schemas/draft202012/metaschema.json,sha256=Qdp29a-3zgYtJI92JGOpL3ykfk4PkFsiS6av7vkd7Q8,2452
|
| 17 |
+
jsonschema_specifications/schemas/draft202012/vocabularies/applicator,sha256=xKbkFHuR_vf-ptwFjLG_k0AvdBS3ZXiosWqvHa1qrO8,1659
|
| 18 |
+
jsonschema_specifications/schemas/draft202012/vocabularies/content,sha256=CDQ3R3ZOSlgUJieTz01lIFenkThjxZUNQyl-jh_axbY,519
|
| 19 |
+
jsonschema_specifications/schemas/draft202012/vocabularies/core,sha256=wtEqjk3RHTNt_IOj9mOqTGnwtJs76wlP_rJbUxb0gD0,1564
|
| 20 |
+
jsonschema_specifications/schemas/draft202012/vocabularies/format,sha256=UOu_55BhGoSbjMQAoJwdDg-2q1wNQ6DyIgH9NiUFa_Q,403
|
| 21 |
+
jsonschema_specifications/schemas/draft202012/vocabularies/format-annotation,sha256=q8d1rf79idIjWBcNm_k_Tr0jSVY7u-3WDwK-98gSvMA,448
|
| 22 |
+
jsonschema_specifications/schemas/draft202012/vocabularies/format-assertion,sha256=xSJCuaG7eGsmw-gset1CjDH5yW5XXc6Z5W6l_qptogw,445
|
| 23 |
+
jsonschema_specifications/schemas/draft202012/vocabularies/meta-data,sha256=j3bW4U9Bubku-TO3CM3FFEyLUmhlGtEZGEhfsXVPHHY,892
|
| 24 |
+
jsonschema_specifications/schemas/draft202012/vocabularies/unevaluated,sha256=Lb-8tzmUtnCwl2SSre4f_7RsIWgnhNL1pMpWH54tDLQ,506
|
| 25 |
+
jsonschema_specifications/schemas/draft202012/vocabularies/validation,sha256=cBCjHlQfMtK-ch4t40jfdcmzaHaj7TBId_wKvaHTelg,2834
|
| 26 |
+
jsonschema_specifications/schemas/draft3/metaschema.json,sha256=LPdfZENvtb43Si6qJ6uLfh_WUcm0ba6mxnsC_WTiRYs,2600
|
| 27 |
+
jsonschema_specifications/schemas/draft4/metaschema.json,sha256=4UidC0dV8CeTMCWR0_y48Htok6gqlPJIlfjk7fEbguI,4357
|
| 28 |
+
jsonschema_specifications/schemas/draft6/metaschema.json,sha256=wp386fVINcOgbAOzxdXsDtp3cGVo-cTffPvHVmpRAG0,4437
|
| 29 |
+
jsonschema_specifications/schemas/draft7/metaschema.json,sha256=PVOSCIJhYGxVm2A_OFMpyfGrRbXWZ-uZBodFOwVdQF4,4819
|
| 30 |
+
jsonschema_specifications/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
| 31 |
+
jsonschema_specifications/tests/__pycache__/__init__.cpython-311.pyc,,
|
| 32 |
+
jsonschema_specifications/tests/__pycache__/test_jsonschema_specifications.cpython-311.pyc,,
|
| 33 |
+
jsonschema_specifications/tests/test_jsonschema_specifications.py,sha256=WkbYRW6A6FoZ0rivShfqVLSCsAiHJ2x8TxqECJTXPTY,1106
|
.venv/lib/python3.11/site-packages/jsonschema_specifications-2024.10.1.dist-info/WHEEL
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Wheel-Version: 1.0
|
| 2 |
+
Generator: hatchling 1.25.0
|
| 3 |
+
Root-Is-Purelib: true
|
| 4 |
+
Tag: py3-none-any
|
.venv/lib/python3.11/site-packages/jsonschema_specifications-2024.10.1.dist-info/licenses/COPYING
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Copyright (c) 2022 Julian Berman
|
| 2 |
+
|
| 3 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 4 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 5 |
+
in the Software without restriction, including without limitation the rights
|
| 6 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 7 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 8 |
+
furnished to do so, subject to the following conditions:
|
| 9 |
+
|
| 10 |
+
The above copyright notice and this permission notice shall be included in
|
| 11 |
+
all copies or substantial portions of the Software.
|
| 12 |
+
|
| 13 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 14 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 15 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 16 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 17 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 18 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
| 19 |
+
THE SOFTWARE.
|
.venv/lib/python3.11/site-packages/xformers/__init__.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the BSD license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from . import _cpp_lib
|
| 12 |
+
from .checkpoint import ( # noqa: E402, F401
|
| 13 |
+
checkpoint,
|
| 14 |
+
get_optimal_checkpoint_policy,
|
| 15 |
+
list_operators,
|
| 16 |
+
selective_checkpoint_wrapper,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
from .version import __version__ # noqa: F401
|
| 21 |
+
except ImportError:
|
| 22 |
+
__version__ = "0.0.0"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger("xformers")
|
| 26 |
+
|
| 27 |
+
_has_cpp_library: bool = _cpp_lib._cpp_library_load_exception is None
|
| 28 |
+
|
| 29 |
+
_is_opensource: bool = True
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def compute_once(func):
|
| 33 |
+
value = None
|
| 34 |
+
|
| 35 |
+
def func_wrapper():
|
| 36 |
+
nonlocal value
|
| 37 |
+
if value is None:
|
| 38 |
+
value = func()
|
| 39 |
+
return value
|
| 40 |
+
|
| 41 |
+
return func_wrapper
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@compute_once
|
| 45 |
+
def _is_triton_available():
|
| 46 |
+
if os.environ.get("XFORMERS_ENABLE_TRITON", "0") == "1":
|
| 47 |
+
return True
|
| 48 |
+
if not torch.cuda.is_available():
|
| 49 |
+
return False
|
| 50 |
+
if os.environ.get("XFORMERS_FORCE_DISABLE_TRITON", "0") == "1":
|
| 51 |
+
return False
|
| 52 |
+
# We have many errors on V100 with recent triton versions
|
| 53 |
+
# Let's just drop support for triton kernels below A100
|
| 54 |
+
if torch.cuda.get_device_capability("cuda") < (8, 0):
|
| 55 |
+
return False
|
| 56 |
+
try:
|
| 57 |
+
import triton # noqa
|
| 58 |
+
|
| 59 |
+
return True
|
| 60 |
+
except (ImportError, AttributeError):
|
| 61 |
+
logger.warning(
|
| 62 |
+
"A matching Triton is not available, some optimizations will not be enabled",
|
| 63 |
+
exc_info=True,
|
| 64 |
+
)
|
| 65 |
+
return False
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@compute_once
|
| 69 |
+
def get_python_lib():
|
| 70 |
+
return torch.library.Library("xformers_python", "DEF")
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# end of file
|
.venv/lib/python3.11/site-packages/xformers/_cpp_lib.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the BSD license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import dataclasses
|
| 7 |
+
import json
|
| 8 |
+
import logging
|
| 9 |
+
import os
|
| 10 |
+
import platform
|
| 11 |
+
from typing import Any, Dict, Optional
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger("xformers")
|
| 16 |
+
|
| 17 |
+
UNAVAILABLE_FEATURES_MSG = (
|
| 18 |
+
" Memory-efficient attention, SwiGLU, sparse and more won't be available."
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclasses.dataclass
|
| 23 |
+
class _BuildInfo:
|
| 24 |
+
metadata: Dict[str, Any]
|
| 25 |
+
|
| 26 |
+
@property
|
| 27 |
+
def cuda_version(self) -> Optional[int]:
|
| 28 |
+
return self.metadata["version"]["cuda"]
|
| 29 |
+
|
| 30 |
+
@property
|
| 31 |
+
def hip_version(self) -> Optional[int]:
|
| 32 |
+
return self.metadata["version"]["hip"]
|
| 33 |
+
|
| 34 |
+
@property
|
| 35 |
+
def torch_version(self) -> str:
|
| 36 |
+
return self.metadata["version"]["torch"]
|
| 37 |
+
|
| 38 |
+
@property
|
| 39 |
+
def python_version(self) -> str:
|
| 40 |
+
return self.metadata["version"]["python"]
|
| 41 |
+
|
| 42 |
+
@property
|
| 43 |
+
def flash_version(self) -> str:
|
| 44 |
+
return self.metadata["version"].get("flash", "0.0.0")
|
| 45 |
+
|
| 46 |
+
@property
|
| 47 |
+
def use_torch_flash(self) -> bool:
|
| 48 |
+
return self.metadata["version"].get("use_torch_flash", False)
|
| 49 |
+
|
| 50 |
+
@property
|
| 51 |
+
def build_env(self) -> Dict[str, Any]:
|
| 52 |
+
return self.metadata["env"]
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class xFormersWasNotBuiltException(Exception):
|
| 56 |
+
def __str__(self) -> str:
|
| 57 |
+
return (
|
| 58 |
+
"Need to compile C++ extensions to use all xFormers features.\n"
|
| 59 |
+
" Please install xformers properly "
|
| 60 |
+
"(see https://github.com/facebookresearch/xformers#installing-xformers)\n"
|
| 61 |
+
+ UNAVAILABLE_FEATURES_MSG
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class xFormersInvalidLibException(Exception):
|
| 66 |
+
def __init__(self, build_info: Optional[_BuildInfo]) -> None:
|
| 67 |
+
self.build_info = build_info
|
| 68 |
+
|
| 69 |
+
def __str__(self) -> str:
|
| 70 |
+
if self.build_info is None:
|
| 71 |
+
msg = "xFormers was built for a different version of PyTorch or Python."
|
| 72 |
+
else:
|
| 73 |
+
msg = f"""xFormers was built for:
|
| 74 |
+
PyTorch {self.build_info.torch_version} with CUDA {self.build_info.cuda_version} (you have {torch.__version__})
|
| 75 |
+
Python {self.build_info.python_version} (you have {platform.python_version()})"""
|
| 76 |
+
return (
|
| 77 |
+
"xFormers can't load C++/CUDA extensions. "
|
| 78 |
+
+ msg
|
| 79 |
+
+ "\n Please reinstall xformers "
|
| 80 |
+
"(see https://github.com/facebookresearch/xformers#installing-xformers)\n"
|
| 81 |
+
+ UNAVAILABLE_FEATURES_MSG
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _register_extensions():
|
| 86 |
+
import importlib
|
| 87 |
+
import os
|
| 88 |
+
|
| 89 |
+
import torch
|
| 90 |
+
|
| 91 |
+
# load the custom_op_library and register the custom ops
|
| 92 |
+
lib_dir = os.path.dirname(__file__)
|
| 93 |
+
if os.name == "nt":
|
| 94 |
+
# Register the main torchvision library location on the default DLL path
|
| 95 |
+
import ctypes
|
| 96 |
+
import sys
|
| 97 |
+
|
| 98 |
+
kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
|
| 99 |
+
with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
|
| 100 |
+
prev_error_mode = kernel32.SetErrorMode(0x0001)
|
| 101 |
+
|
| 102 |
+
if with_load_library_flags:
|
| 103 |
+
kernel32.AddDllDirectory.restype = ctypes.c_void_p
|
| 104 |
+
|
| 105 |
+
if sys.version_info >= (3, 8):
|
| 106 |
+
os.add_dll_directory(lib_dir)
|
| 107 |
+
elif with_load_library_flags:
|
| 108 |
+
res = kernel32.AddDllDirectory(lib_dir)
|
| 109 |
+
if res is None:
|
| 110 |
+
err = ctypes.WinError(ctypes.get_last_error())
|
| 111 |
+
err.strerror += f' Error adding "{lib_dir}" to the DLL directories.'
|
| 112 |
+
raise err
|
| 113 |
+
|
| 114 |
+
kernel32.SetErrorMode(prev_error_mode)
|
| 115 |
+
|
| 116 |
+
loader_details = (
|
| 117 |
+
importlib.machinery.ExtensionFileLoader,
|
| 118 |
+
importlib.machinery.EXTENSION_SUFFIXES,
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
|
| 122 |
+
if torch.version.hip and not hasattr(torch.version, "git_version"):
|
| 123 |
+
ext_specs = extfinder.find_spec("_C_hip")
|
| 124 |
+
else:
|
| 125 |
+
ext_specs = extfinder.find_spec("_C")
|
| 126 |
+
if ext_specs is None:
|
| 127 |
+
raise xFormersWasNotBuiltException()
|
| 128 |
+
cpp_lib_json = os.path.join(lib_dir, "cpp_lib.json")
|
| 129 |
+
with open(cpp_lib_json, "r") as fp:
|
| 130 |
+
build_metadata = _BuildInfo(json.load(fp))
|
| 131 |
+
try:
|
| 132 |
+
torch.ops.load_library(ext_specs.origin)
|
| 133 |
+
except OSError as exc:
|
| 134 |
+
raise xFormersInvalidLibException(build_metadata) from exc
|
| 135 |
+
return build_metadata
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
_cpp_library_load_exception = None
|
| 139 |
+
_build_metadata: Optional[_BuildInfo] = None
|
| 140 |
+
|
| 141 |
+
try:
|
| 142 |
+
_build_metadata = _register_extensions()
|
| 143 |
+
except (xFormersInvalidLibException, xFormersWasNotBuiltException) as e:
|
| 144 |
+
ENV_VAR_FOR_DETAILS = "XFORMERS_MORE_DETAILS"
|
| 145 |
+
if os.environ.get(ENV_VAR_FOR_DETAILS, False):
|
| 146 |
+
logger.warning(f"WARNING[XFORMERS]: {e}", exc_info=e)
|
| 147 |
+
else:
|
| 148 |
+
logger.warning(
|
| 149 |
+
f"WARNING[XFORMERS]: {e}\n Set {ENV_VAR_FOR_DETAILS}=1 for more details"
|
| 150 |
+
)
|
| 151 |
+
_cpp_library_load_exception = e
|
| 152 |
+
|
| 153 |
+
_built_with_cuda = (
|
| 154 |
+
_build_metadata is not None and _build_metadata.cuda_version is not None
|
| 155 |
+
)
|
.venv/lib/python3.11/site-packages/xformers/_deprecation_warning.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the BSD license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import warnings
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def deprecated_function(self):
|
| 10 |
+
name = repr(self) # self.__name__
|
| 11 |
+
msg = f"{name} is deprecated and is not maintained anymore. It might be removed in a future version of xFormers"
|
| 12 |
+
warnings.warn(msg, FutureWarning, stacklevel=2)
|
.venv/lib/python3.11/site-packages/xformers/attn_bias_utils.py
ADDED
|
@@ -0,0 +1,501 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the BSD license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
import random
|
| 8 |
+
from typing import List, Optional, Sequence, Tuple, Type
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from xformers.ops import AttentionBias, fmha
|
| 13 |
+
from xformers.ops.fmha.attn_bias import AttentionBiasSubTensor
|
| 14 |
+
from xformers.ops.fmha.common import AttentionOpBase
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _create_aligned_bias(*shape: int, **kwargs) -> torch.Tensor:
|
| 18 |
+
align_to = 8
|
| 19 |
+
return (
|
| 20 |
+
torch.randn(
|
| 21 |
+
(
|
| 22 |
+
*shape[:-1],
|
| 23 |
+
align_to * ((shape[-1] + align_to - 1) // align_to),
|
| 24 |
+
),
|
| 25 |
+
**kwargs,
|
| 26 |
+
)
|
| 27 |
+
* 3
|
| 28 |
+
).narrow(-1, 0, shape[-1])
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def create_attn_bias(
|
| 32 |
+
bias_type,
|
| 33 |
+
batch_size: int,
|
| 34 |
+
num_heads: int,
|
| 35 |
+
num_heads_groups: int,
|
| 36 |
+
q_len: int,
|
| 37 |
+
kv_len: int,
|
| 38 |
+
device,
|
| 39 |
+
dtype,
|
| 40 |
+
requires_grad: bool,
|
| 41 |
+
fmt: str,
|
| 42 |
+
op: Optional[Type[AttentionOpBase]] = None,
|
| 43 |
+
page_size: Optional[int] = None,
|
| 44 |
+
):
|
| 45 |
+
if bias_type is None or isinstance(None, bias_type):
|
| 46 |
+
return None
|
| 47 |
+
r = random.Random("-".join(map(str, [batch_size, q_len, kv_len, dtype, fmt])))
|
| 48 |
+
window_size = {0: 3, 1: 128, 2: 300}[r.randint(0, 2)]
|
| 49 |
+
if bias_type is torch.Tensor:
|
| 50 |
+
if fmt == "BMK":
|
| 51 |
+
batch_size *= num_heads
|
| 52 |
+
num_heads = 1
|
| 53 |
+
if op is not None and issubclass(op, fmha.triton_splitk.FwOp):
|
| 54 |
+
attn_bias = (
|
| 55 |
+
torch.randn(
|
| 56 |
+
(batch_size, num_heads_groups, num_heads, q_len, kv_len),
|
| 57 |
+
device=device,
|
| 58 |
+
dtype=dtype,
|
| 59 |
+
)
|
| 60 |
+
* 3
|
| 61 |
+
)
|
| 62 |
+
if fmt in ["BMK", "BMHK"]:
|
| 63 |
+
attn_bias = attn_bias[:, 0]
|
| 64 |
+
else:
|
| 65 |
+
attn_bias = _create_aligned_bias(
|
| 66 |
+
batch_size,
|
| 67 |
+
num_heads_groups,
|
| 68 |
+
num_heads,
|
| 69 |
+
q_len,
|
| 70 |
+
kv_len,
|
| 71 |
+
device=device,
|
| 72 |
+
dtype=dtype,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
# make sure it also works if the first columns/rows are partially masked out
|
| 76 |
+
attn_bias[0, 0, 0, : q_len - 1, : kv_len - 1] = -math.inf
|
| 77 |
+
if fmt in ["BMK", "BMHK"]:
|
| 78 |
+
attn_bias = attn_bias[:, 0]
|
| 79 |
+
|
| 80 |
+
if requires_grad:
|
| 81 |
+
attn_bias.requires_grad_(True)
|
| 82 |
+
if fmt == "BMK":
|
| 83 |
+
attn_bias = attn_bias[:, 0]
|
| 84 |
+
return attn_bias
|
| 85 |
+
if bias_type is fmha.attn_bias.LowerTriangularMask:
|
| 86 |
+
return bias_type()
|
| 87 |
+
if bias_type is fmha.attn_bias.LowerTriangularFromBottomRightMask:
|
| 88 |
+
return bias_type()
|
| 89 |
+
if bias_type is fmha.attn_bias.LowerTriangularFromBottomRightLocalAttentionMask:
|
| 90 |
+
return bias_type(window_size)
|
| 91 |
+
if bias_type is fmha.attn_bias.LowerTriangularMaskWithTensorBias:
|
| 92 |
+
attn_bias = _create_aligned_bias(
|
| 93 |
+
batch_size,
|
| 94 |
+
num_heads_groups,
|
| 95 |
+
num_heads,
|
| 96 |
+
q_len,
|
| 97 |
+
kv_len,
|
| 98 |
+
device=device,
|
| 99 |
+
dtype=dtype,
|
| 100 |
+
)
|
| 101 |
+
if fmt in ["BMK", "BMHK"]:
|
| 102 |
+
attn_bias = attn_bias[:, 0]
|
| 103 |
+
if fmt == "BMK":
|
| 104 |
+
attn_bias = attn_bias[:, 0]
|
| 105 |
+
if requires_grad:
|
| 106 |
+
attn_bias.requires_grad_(True)
|
| 107 |
+
return fmha.attn_bias.LowerTriangularMaskWithTensorBias(attn_bias)
|
| 108 |
+
if bias_type in [
|
| 109 |
+
fmha.attn_bias.BlockDiagonalMask,
|
| 110 |
+
fmha.attn_bias.BlockDiagonalCausalMask,
|
| 111 |
+
fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask,
|
| 112 |
+
fmha.attn_bias.BlockDiagonalCausalLocalAttentionMask,
|
| 113 |
+
fmha.attn_bias.BlockDiagonalCausalLocalAttentionFromBottomRightMask,
|
| 114 |
+
]:
|
| 115 |
+
# These bias types are not supported in BMK format
|
| 116 |
+
assert fmt in ["BMGHK", "BMHK"]
|
| 117 |
+
max_q_minus_k = None
|
| 118 |
+
if bias_type in {
|
| 119 |
+
fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask,
|
| 120 |
+
fmha.attn_bias.BlockDiagonalCausalLocalAttentionFromBottomRightMask,
|
| 121 |
+
}:
|
| 122 |
+
max_q_minus_k = 0
|
| 123 |
+
elif bias_type == fmha.attn_bias.BlockDiagonalCausalLocalAttentionMask:
|
| 124 |
+
assert window_size is not None
|
| 125 |
+
max_q_minus_k = window_size - 1
|
| 126 |
+
|
| 127 |
+
block_diag = fmha.attn_bias.BlockDiagonalMask.from_seqlens(
|
| 128 |
+
*_rand_seqlens(
|
| 129 |
+
r,
|
| 130 |
+
batch_size,
|
| 131 |
+
q_len,
|
| 132 |
+
kv_len,
|
| 133 |
+
max_q_minus_k=max_q_minus_k,
|
| 134 |
+
)
|
| 135 |
+
)
|
| 136 |
+
if bias_type is fmha.attn_bias.BlockDiagonalCausalMask:
|
| 137 |
+
block_diag = block_diag.make_causal()
|
| 138 |
+
if bias_type in {
|
| 139 |
+
fmha.attn_bias.BlockDiagonalCausalLocalAttentionMask,
|
| 140 |
+
fmha.attn_bias.BlockDiagonalCausalLocalAttentionFromBottomRightMask,
|
| 141 |
+
}:
|
| 142 |
+
block_diag = fmha.attn_bias.BlockDiagonalMask(
|
| 143 |
+
q_seqinfo=block_diag.q_seqinfo,
|
| 144 |
+
k_seqinfo=block_diag.k_seqinfo,
|
| 145 |
+
_batch_sizes=block_diag._batch_sizes,
|
| 146 |
+
)
|
| 147 |
+
assert window_size is not None
|
| 148 |
+
if bias_type is fmha.attn_bias.BlockDiagonalCausalLocalAttentionMask:
|
| 149 |
+
block_diag = block_diag.make_local_attention(window_size)
|
| 150 |
+
else:
|
| 151 |
+
block_diag = block_diag.make_local_attention_from_bottomright(
|
| 152 |
+
window_size
|
| 153 |
+
)
|
| 154 |
+
if bias_type is fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask:
|
| 155 |
+
block_diag = block_diag.make_causal_from_bottomright()
|
| 156 |
+
return block_diag
|
| 157 |
+
if bias_type in [
|
| 158 |
+
fmha.attn_bias.BlockDiagonalPaddedKeysMask,
|
| 159 |
+
fmha.attn_bias.BlockDiagonalCausalLocalAttentionPaddedKeysMask,
|
| 160 |
+
fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask,
|
| 161 |
+
fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask,
|
| 162 |
+
fmha.attn_bias.PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
|
| 163 |
+
]:
|
| 164 |
+
assert fmt in ["BMHK", "BMGHK"]
|
| 165 |
+
q, k = _rand_seqlens_padded_k(r, batch_size, q_len, kv_len)
|
| 166 |
+
block_diag_type = (
|
| 167 |
+
bias_type._UNPAGED_TYPE
|
| 168 |
+
if issubclass(bias_type, fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask)
|
| 169 |
+
else bias_type
|
| 170 |
+
)
|
| 171 |
+
if bias_type is fmha.attn_bias.BlockDiagonalCausalLocalAttentionPaddedKeysMask:
|
| 172 |
+
g_block_diag = block_diag_type.from_seqlens_local(
|
| 173 |
+
q_seqlen=q,
|
| 174 |
+
kv_padding=kv_len,
|
| 175 |
+
kv_seqlen=k,
|
| 176 |
+
window_size=min(window_size, min(k)),
|
| 177 |
+
)
|
| 178 |
+
else:
|
| 179 |
+
g_block_diag = block_diag_type.from_seqlens(
|
| 180 |
+
q_seqlen=q,
|
| 181 |
+
kv_padding=kv_len,
|
| 182 |
+
kv_seqlen=k,
|
| 183 |
+
)
|
| 184 |
+
if issubclass(bias_type, fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask):
|
| 185 |
+
assert page_size is not None
|
| 186 |
+
pages_per_row = (kv_len + page_size - 1) // page_size
|
| 187 |
+
block_tables = torch.tensor(
|
| 188 |
+
r.sample(range(batch_size * pages_per_row), batch_size * pages_per_row),
|
| 189 |
+
device=device,
|
| 190 |
+
dtype=torch.int32,
|
| 191 |
+
).reshape(batch_size, pages_per_row)
|
| 192 |
+
return g_block_diag.make_paged(
|
| 193 |
+
block_tables=block_tables, page_size=page_size, paged_type=bias_type
|
| 194 |
+
)
|
| 195 |
+
return g_block_diag
|
| 196 |
+
if bias_type in [
|
| 197 |
+
fmha.attn_bias.BlockDiagonalCausalWithOffsetGappyKeysMask,
|
| 198 |
+
fmha.attn_bias.BlockDiagonalGappyKeysMask,
|
| 199 |
+
]:
|
| 200 |
+
assert fmt in ["BMHK", "BMGHK"]
|
| 201 |
+
max_q_minus_k = (
|
| 202 |
+
None if bias_type is fmha.attn_bias.BlockDiagonalGappyKeysMask else 0
|
| 203 |
+
)
|
| 204 |
+
q, k = _rand_seqlens(r, batch_size, q_len, kv_len, max_q_minus_k)
|
| 205 |
+
total_kv_len = kv_len * batch_size
|
| 206 |
+
starts = [r.randint(0, total_kv_len - ki) for ki in k] + [total_kv_len]
|
| 207 |
+
return fmha.attn_bias.BlockDiagonalGappyKeysMask.from_seqlens(
|
| 208 |
+
q_seqlen=q,
|
| 209 |
+
kv_seqstarts=starts,
|
| 210 |
+
kv_seqlen=k,
|
| 211 |
+
)
|
| 212 |
+
if bias_type in [
|
| 213 |
+
fmha.attn_bias.PagedBlockDiagonalGappyKeysMask,
|
| 214 |
+
]:
|
| 215 |
+
assert fmt in ["BMHK", "BMGHK"]
|
| 216 |
+
assert page_size is not None
|
| 217 |
+
pages_per_row = (kv_len + page_size - 1) // page_size
|
| 218 |
+
total_queries = q_len * batch_size
|
| 219 |
+
q = _rand_maxed_partition(r, total_queries, batch_size, total_queries, False)
|
| 220 |
+
k = [r.randint(1, kv_len) for _ in range(batch_size)]
|
| 221 |
+
row_size = pages_per_row * page_size
|
| 222 |
+
starts = [row_size * i + r.randint(0, row_size - ki) for i, ki in enumerate(k)]
|
| 223 |
+
starts.append(pages_per_row * batch_size * page_size)
|
| 224 |
+
block_diag_type = bias_type._UNPAGED_TYPE # type: ignore
|
| 225 |
+
g_block_diag = block_diag_type.from_seqlens(
|
| 226 |
+
q_seqlen=q,
|
| 227 |
+
kv_seqstarts=starts,
|
| 228 |
+
kv_seqlen=k,
|
| 229 |
+
)
|
| 230 |
+
block_tables = torch.tensor(
|
| 231 |
+
r.sample(range(batch_size * pages_per_row), batch_size * pages_per_row),
|
| 232 |
+
device=device,
|
| 233 |
+
dtype=torch.int32,
|
| 234 |
+
).reshape(batch_size, pages_per_row)
|
| 235 |
+
return g_block_diag.make_paged(
|
| 236 |
+
block_tables=block_tables,
|
| 237 |
+
page_size=page_size,
|
| 238 |
+
paged_type=bias_type,
|
| 239 |
+
notional_padding=page_size * pages_per_row,
|
| 240 |
+
)
|
| 241 |
+
if bias_type == fmha.attn_bias.LocalAttentionFromBottomRightMask:
|
| 242 |
+
return bias_type(
|
| 243 |
+
window_left=r.randint(0, 5),
|
| 244 |
+
window_right=r.randint(0, 5),
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
assert False, f"Unsupported bias type: {bias_type}"
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def _rand_seqlens(
    r: random.Random,
    bs: int,
    q_len: int,
    kv_len: int,
    max_q_minus_k: Optional[int],
) -> Tuple[Sequence[int], Sequence[int]]:
    """
    Generates lists of lengths of query blocks and corresponding key blocks.
    The total number of queries will be bs * q_len and the
    total number of keys will be bs * kv_len.
    max_q_minus_k: maximum allowed num_queries - num_keys.
        For "bottom-right" masks it's 0, we need to have more keys than
        queries, otherwise some queries have no keys to attend to.
        For BlockDiagonalCausalMask it's None, there is no constraint
        on num_queries - num_keys.
        For BlockDiagonalCausalLocalAttentionMask it's equal
        to the window size.
    """
    if max_q_minus_k == 0:
        # In case max_q_minus_k > 0 the exact condition is
        # kv_len >= q_len - max_q_minus_k * batch_size,
        # but we can't check it without knowing the actual batch size,
        # which is determined in the loop below.
        assert kv_len >= q_len
    # From here on, q_len/kv_len are the TOTAL number of queries/keys
    # across all generated blocks (not per-batch-element).
    q_len *= bs
    kv_len *= bs
    seqlens_q: List[int] = []
    seqlens_k: List[int] = []

    # Candidate block sizes: roughly between 10% and 50% of the total.
    step_q = [max(1, q_len // 10), max(2, q_len // 2)]
    step_k = [max(1, kv_len // 10), max(2, kv_len // 2)]
    while sum(seqlens_q) < q_len and sum(seqlens_k) < kv_len:
        if max_q_minus_k is None:
            # Simple case - no constraint on the number of queries and keys.
            num_queries = r.randrange(*step_q)
            seqlens_q.append(num_queries)
            seqlens_k.append(r.randrange(*step_k))
        else:
            # In this case we need to make sure num_queries - num_keys < max_q_minus_k holds for every batch element.
            # To do this, when choosing num_queries and num_keys at a given step,
            # we ensure two conditions are satisfied:
            # 1) num_queries <= num_keys + max_q_minus_k for the current batch element
            # 2) Same holds for the remaining keys and queries, i.e.
            #    queries_left - num_queries <= keys_left - num_keys + max_q_minus_k
            keys_left = kv_len - sum(seqlens_k, 0)
            queries_left = q_len - sum(seqlens_q, 0)

            assert (
                keys_left >= queries_left - max_q_minus_k
            ), f"{keys_left=} {queries_left=} {max_q_minus_k=} {kv_len=} {q_len=} {seqlens_k=} {seqlens_q=}"
            # Limit num_queries from above: if num_queries > keys_left + max_q_minus_k,
            # condition num_queries <= num_keys + max_q_minus_k can't be satisfied even if we take
            # all the remaining keys
            max_queries_to_take = min(queries_left, keys_left + max_q_minus_k)
            num_queries = r.randrange(1, max_queries_to_take + 1)
            seqlens_q.append(num_queries)

            # Now we know num_queries, let's select num_keys.
            # How many keys can we use for the current batch element so that
            # for the remaining keys and values the constraint
            # num_queries - num_keys < max_q_minus_k holds on the next step?
            extra_keys_available = keys_left - queries_left + max_q_minus_k + 1
            assert extra_keys_available >= 0
            if extra_keys_available > 0:
                seqlens_k.append(num_queries + r.randrange(0, extra_keys_available))
            else:
                seqlens_k.append(num_queries)
    # The loop may overshoot or undershoot the targets; clamp the last
    # block of each list so the totals are exactly q_len and kv_len.
    seqlens_q[-1] = q_len - sum(seqlens_q[:-1])
    seqlens_k[-1] = kv_len - sum(seqlens_k[:-1])
    return seqlens_q, seqlens_k
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def _rand_maxed_partition(
|
| 324 |
+
r: random.Random, total: int, n: int, mx: int, positive: bool = True
|
| 325 |
+
) -> List[int]:
|
| 326 |
+
# returns list of n nonnegative integers less than mx summing to total
|
| 327 |
+
# NB: This is unfortunately biased towards evenly-split bins.
|
| 328 |
+
# If `positive`, outputs are positive
|
| 329 |
+
if positive:
|
| 330 |
+
total -= n
|
| 331 |
+
mx -= 1
|
| 332 |
+
idxs = r.sample(range(n * mx), total)
|
| 333 |
+
y = torch.zeros(n, mx, dtype=torch.int32)
|
| 334 |
+
y.flatten()[idxs] = 1
|
| 335 |
+
z = y.sum(1)
|
| 336 |
+
if positive:
|
| 337 |
+
z += 1
|
| 338 |
+
return z.tolist()
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def _rand_seqlens_padded_k(
|
| 342 |
+
r: random.Random, bs: int, q_len: int, kv_len: int
|
| 343 |
+
) -> Tuple[Sequence[int], Sequence[int]]:
|
| 344 |
+
# This is for BlockDiagonalCausalWithOffsetPaddedKeysMask.
|
| 345 |
+
# we need q_seqlens and k_seqlens to be of len bsz.
|
| 346 |
+
# For each "batch element" there must be more keys than queries
|
| 347 |
+
# because this bias type is "bottom right" and so any extra queries
|
| 348 |
+
# will attend to nothing and have undefined result.
|
| 349 |
+
# In addition every element of k_seqlens must be <= kv_len
|
| 350 |
+
if q_len > kv_len:
|
| 351 |
+
raise ValueError("need more keys than values")
|
| 352 |
+
if q_len == kv_len:
|
| 353 |
+
# all key slots are needed so we cannot have padding
|
| 354 |
+
q_seqlens = k_seqlens = [kv_len] * bs
|
| 355 |
+
else:
|
| 356 |
+
q_seqlens = _rand_maxed_partition(r, q_len * bs, bs, kv_len)
|
| 357 |
+
k_seqlens = [r.randint(i, kv_len) for i in q_seqlens]
|
| 358 |
+
return q_seqlens, k_seqlens
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def ref_attention(q, k, v, attn_bias=None, drop_mask=None, p=0.0, scale=None):
    """Reference (eager, fp32) attention used to check fused kernels.

    Accepts 3D (B*H, M, K), 4D BMHK, or 5D BMGHK inputs. 4D/5D inputs are
    reduced to the 3D case via ref_attention_bmhk. `p` is the dropout
    probability matching `drop_mask`; `scale` defaults to 1/sqrt(head_dim).
    """
    if q.ndim == 5:
        # BMGHK (grouped-heads) layout: run each group independently.

        def attn_bias_group(group: int):
            # Slice the bias for this group when it is a dense per-group
            # tensor; otherwise the same bias object applies to all groups.
            if isinstance(attn_bias, fmha.attn_bias.AttentionBiasSubTensor):
                if attn_bias.HOLDS_DENSE_TENSOR:
                    return attn_bias[:, group]
            elif isinstance(attn_bias, torch.Tensor):
                return attn_bias[:, group]
            return attn_bias

        return torch.stack(
            [
                ref_attention_bmhk(
                    q[:, :, g],
                    k[:, :, g],
                    v[:, :, g],
                    scale=scale,
                    attn_bias=attn_bias_group(g),
                )
                for g in range(q.shape[2])
            ],
            dim=2,
        )
    if q.ndim == 4:
        # BMHK layout; dropout is not supported on this path.
        assert p == 0.0
        return ref_attention_bmhk(q, k, v, scale=scale, attn_bias=attn_bias)
    # 3D path: compute attention in float32 for numerical stability.
    q = q.float()
    k = k.float()
    v = v.float()

    scale = scale if scale is not None else (1 / q.shape[-1] ** 0.5)
    q = q * scale

    attn = q @ k.transpose(-2, -1)
    if attn_bias is not None:
        if isinstance(attn_bias, (AttentionBias, AttentionBiasSubTensor)):
            # Always create in B,H,Mq,Mk format
            attn_bias_tensor = attn_bias.materialize(
                (q.shape[0], 1, q.shape[1], k.shape[1]),
                device=q.device,
                dtype=torch.float32,
            )
        else:
            attn_bias_tensor = attn_bias
        if attn_bias_tensor.ndim == 4:
            # Collapse (B, H) into the batch dim to match the 3D `attn`.
            assert q.shape[0] == attn_bias_tensor.shape[0] * attn_bias_tensor.shape[1]
            attn_bias_tensor = attn_bias_tensor.reshape(
                [-1, *attn_bias_tensor.shape[2:]]
            )
        attn = attn + attn_bias_tensor.float()
    attn = attn.softmax(-1)
    if drop_mask is not None:
        # Inverted-dropout scaling: surviving weights are divided by (1 - p).
        attn = attn * (drop_mask / (1 - p))
    return attn @ v
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def ref_attention_bmhk(q, k, v, attn_bias, scale=None) -> torch.Tensor:
    """Reference attention for BMHK inputs.

    Folds the head dim into the batch dim, delegates to ref_attention's
    3D path, then restores the BMHK layout.
    """
    assert q.ndim == 4

    def T(t):
        # (B, M, H, K) -> (B*H, M, K)
        return t.permute((0, 2, 1, 3)).reshape(
            [t.shape[0] * t.shape[2], t.shape[1], t.shape[3]]
        )

    if isinstance(attn_bias, (AttentionBias, AttentionBiasSubTensor)):
        # Materialize per-head, then collapse (B, H) to match the folded q/k.
        attn_bias = attn_bias.materialize(
            (q.shape[0], q.shape[2], q.shape[1], k.shape[1]),
            device=q.device,
            dtype=torch.float32,
        ).reshape([q.shape[0] * q.shape[2], q.shape[1], k.shape[1]])
    out = ref_attention(T(q), T(k), T(v), attn_bias, scale=scale)
    # (B*H, M, Kv) -> (B, M, H, Kv)
    out = out.reshape([q.shape[0], q.shape[2], q.shape[1], v.shape[3]])
    return out.permute((0, 2, 1, 3))
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def pack_kv_cache(
    cache_k: torch.Tensor,
    cache_v: torch.Tensor,
    kv_seqlens: List[int],
    BLOCK_N: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Create block tables and pages K/V cache for testing paged attention.
    Args:
        cache_k, cache_v: K/V caches, each of shape [B, MAX_T, H_kv, D].
            Note that these tensors are unexpanded,
            i.e. for multiquery case cache_k.shape[2] = 1
        kv_seqlens: list of K/V sequence lengths
        BLOCK_N: number of tokens per per paged attention block
        B: batch size
    Returns:
        block_tables: [B, MAX_BLOCKS]
        packed_cache_k: [1, total_len_rounded, H_kv, D]
        packed_cache_v: [1, total_len_rounded, H_kv, D]
    where total_len_rounded is a sum of K/V seqlens, each rounded up
    to a multiple of BLOCK_N.

    NOTE(review): block_tables and seqstarts are created on device="cuda"
    regardless of where cache_k/cache_v live — assumes CUDA inputs.
    """

    # Round each sequence length up to a whole number of pages.
    kv_seqlens_rounded = [(x + BLOCK_N - 1) // BLOCK_N * BLOCK_N for x in kv_seqlens]

    total_len_rounded = sum(kv_seqlens_rounded)

    B, MAX_T, H, D = cache_k.shape

    packed_cache_k = torch.empty(
        total_len_rounded, H, D, device=cache_k.device, dtype=cache_k.dtype
    )
    packed_cache_v = torch.empty(
        total_len_rounded, H, D, device=cache_k.device, dtype=cache_k.dtype
    )
    # Copy each batch element's valid tokens into its page-aligned slot;
    # the tail of each rounded region is left uninitialized padding.
    seqstart = 0
    for b in range(B):
        packed_cache_k[seqstart : seqstart + kv_seqlens[b]] = cache_k[
            b, : kv_seqlens[b]
        ].clone()
        packed_cache_v[seqstart : seqstart + kv_seqlens[b]] = cache_v[
            b, : kv_seqlens[b]
        ].clone()
        seqstart += kv_seqlens_rounded[b]

    # Each row of the table lists the page indices for one batch element:
    # a contiguous run of pages offset by that element's starting page.
    num_blocks_per_row = (MAX_T + BLOCK_N - 1) // BLOCK_N
    block_tables = (
        torch.arange(num_blocks_per_row, device="cuda", dtype=torch.int32)
        .unsqueeze(0)
        .expand(B, num_blocks_per_row)
    )
    # Exclusive cumulative sum of rounded lengths, in units of pages.
    seqstarts = (
        (
            torch.tensor(kv_seqlens_rounded).cumsum(dim=0)
            - torch.tensor(kv_seqlens_rounded)
        )
        .to(device="cuda")
        .unsqueeze(1)
    ) // BLOCK_N
    block_tables = (block_tables + seqstarts).contiguous().to(dtype=torch.int32)
    return (
        block_tables,
        packed_cache_k.unsqueeze(0),
        packed_cache_v.unsqueeze(0),
    )
|
.venv/lib/python3.11/site-packages/xformers/checkpoint.py
ADDED
|
@@ -0,0 +1,546 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the BSD license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
import functools
|
| 8 |
+
import time
|
| 9 |
+
from collections import defaultdict
|
| 10 |
+
from copy import deepcopy
|
| 11 |
+
from dataclasses import astuple, dataclass
|
| 12 |
+
from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
from torch.testing._internal.composite_compliance import (
|
| 16 |
+
is_inplace,
|
| 17 |
+
is_inplace_view_fn,
|
| 18 |
+
is_view_fn,
|
| 19 |
+
)
|
| 20 |
+
from torch.utils._python_dispatch import TorchDispatchMode
|
| 21 |
+
from torch.utils._pytree import tree_map
|
| 22 |
+
|
| 23 |
+
_scipy_is_available = False
|
| 24 |
+
try:
|
| 25 |
+
from scipy.optimize import Bounds, LinearConstraint, milp
|
| 26 |
+
|
| 27 |
+
_scipy_is_available = True
|
| 28 |
+
except ImportError:
|
| 29 |
+
_scipy_is_available = False
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
try:
|
| 33 |
+
# let's keep BC for older PyTorch for now
|
| 34 |
+
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
| 35 |
+
ActivationWrapper,
|
| 36 |
+
)
|
| 37 |
+
from torch.utils.checkpoint import ( # type: ignore
|
| 38 |
+
_CachedTorchDispatchMode,
|
| 39 |
+
_CachingTorchDispatchMode,
|
| 40 |
+
)
|
| 41 |
+
except ImportError:
|
| 42 |
+
ActivationWrapper = torch.nn.Module # type: ignore
|
| 43 |
+
|
| 44 |
+
class _NotAvailable:
|
| 45 |
+
def __init__(self, *args, **kwargs):
|
| 46 |
+
raise RuntimeError("Need PyTorch >= 2.2")
|
| 47 |
+
|
| 48 |
+
_CachedTorchDispatchMode = _NotAvailable # type: ignore
|
| 49 |
+
_CachingTorchDispatchMode = _NotAvailable # type: ignore
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
try:
|
| 53 |
+
from torch.utils.checkpoint import SAC_IGNORED_OPS as _ignored_ops # type: ignore
|
| 54 |
+
|
| 55 |
+
_PT_HAS_NEW_IMPL = True
|
| 56 |
+
except ImportError:
|
| 57 |
+
from torch.utils.checkpoint import _ignored_ops # type: ignore
|
| 58 |
+
|
| 59 |
+
_PT_HAS_NEW_IMPL = False
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
_additional_ignored_ops = {
|
| 63 |
+
torch.ops.aten.lift_fresh.default,
|
| 64 |
+
torch.ops.profiler._record_function_exit._RecordFunction,
|
| 65 |
+
torch.ops.aten.clone.default, # seems needed for torch.compile
|
| 66 |
+
}
|
| 67 |
+
OPS_TO_ALWAYS_SKIP = _ignored_ops | _additional_ignored_ops
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
@dataclass
class ProfileMetadata:
    # Per-operator record produced by ProfileOperatorsTorchDispatchMode.
    # Operator identifier. NOTE(review): annotated str, but the profiler
    # passes the raw op overload object here — confirm intended type.
    name: str
    # Average runtime of one call, in seconds.
    time_taken: float
    # Peak-memory delta of one call, in MiB.
    memory_used: float
    # Index of this op in the recorded sequence.
    curr_idx: int
    # Storage data_ptr(s) of the op's output tensor(s) (tree_map result).
    output_ids: Any
    # (op_id, parent_op_id) for in-place ops; empty tuple otherwise.
    inplace_info: Tuple[int, int]
    # True if the op is a view or in-place-view function.
    is_view_like: bool
    # True if the op is nondeterministically seeded (random).
    is_rand_op: bool
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _get_default_policy(allow_list=None):
|
| 83 |
+
_default_allow_list = [
|
| 84 |
+
"xformers.efficient_attention_forward_cutlass.default",
|
| 85 |
+
"xformers_flash.flash_fwd.default",
|
| 86 |
+
"aten.addmm.default",
|
| 87 |
+
"aten.mm.default",
|
| 88 |
+
]
|
| 89 |
+
if allow_list is None:
|
| 90 |
+
allow_list = _default_allow_list
|
| 91 |
+
|
| 92 |
+
def _default_policy(ctx, func, *args, **kwargs):
|
| 93 |
+
return str(func) in allow_list
|
| 94 |
+
|
| 95 |
+
return _default_policy
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class VerboseTorchDispatchMode(TorchDispatchMode):
    """Dispatch mode that records every dispatched operator, then runs it."""

    def __init__(self):
        # Operators seen so far, in dispatch order.
        self.operators = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        self.operators.append(func)
        return func(*args, **kwargs)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def list_operators(function, *args, **kwargs):
    """
    Returns the list of operators used inside `function` with
    *args and **kwargs
    """
    recorder = VerboseTorchDispatchMode()
    with recorder:
        function(*args, **kwargs)
    return recorder.operators
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class CachedTorchDispatchMode(_CachedTorchDispatchMode):
    """Thin wrapper over PyTorch's _CachedTorchDispatchMode that bridges
    the old (2-arg) and new (3-arg) constructor signatures."""

    def __init__(self, policy_fn, storage, allow_cache_entry_mutation):
        global _PT_HAS_NEW_IMPL
        if _PT_HAS_NEW_IMPL:
            # Newer PyTorch accepts allow_cache_entry_mutation.
            super().__init__(policy_fn, storage, allow_cache_entry_mutation)
        else:
            # Older PyTorch: drop the extra argument.
            super().__init__(policy_fn, storage)

    # this is here for the old implementations
    def pop_from_storage(self, func, args, kwargs):
        # the autograd engine might add spurious views. This is a basic
        # guard and should be improved
        if self.storage[func]:
            return self.storage[func].pop(0)
        # Nothing cached for this op: recompute instead of failing.
        return func(*args, **kwargs)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class NullTorchDispatchMode(TorchDispatchMode):
    """Pass-through dispatch mode: executes every op unchanged."""

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        return func(*args, **({} if kwargs is None else kwargs))
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def selective_checkpoint_context_fn(policy_fn=None):
    """An activation checkpoint context_fn for selectively deciding what to
    store and what to recompute. Accepts a custom policy.

    Args:
        policy_fn(Union[List[Op], callable]): policy for deciding what to
            store (instead of recompute). If it's a function, it should
            be of form (func, *args, **kwargs) -> bool which indicates
            if func outputs with *args and **kwargs should be stored or not.
            Additionally, a list[Op] is also supported for easier cases.
            The op should be in the format `torch.ops.***`, where the `***`
            names of operators can be obtained with `list_operators`.

    Returns:
        A (caching_mode, cached_mode) pair as expected by
        torch.utils.checkpoint's ``context_fn``: the first is active during
        the forward pass, the second during recomputation.
    """
    if policy_fn is None:
        policy_fn = _get_default_policy()
    elif isinstance(policy_fn, list):
        policy_fn = _get_default_policy(policy_fn)
    else:
        assert callable(policy_fn), "policy_fn should be None, list or a callable"

    # Shared buffer: the caching mode fills it in forward, the cached mode
    # drains it during recomputation.
    temp_storage: Dict[Any, List[Any]] = defaultdict(list)
    # assumption: grad_mode doesn't change inside function
    caching_mode: ContextManager[None]
    if torch.is_grad_enabled():
        caching_mode = _CachingTorchDispatchMode(deepcopy(policy_fn), temp_storage)
    else:
        # No grad means no recomputation, so nothing needs caching.
        caching_mode = NullTorchDispatchMode()
    cached_mode = CachedTorchDispatchMode(deepcopy(policy_fn), temp_storage, True)

    return caching_mode, cached_mode
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def checkpoint(
    function, *args, preserve_rng_state=True, policy_fn=None, **kwargs
) -> Any:
    """Wrapper around torch.utils.checkpoint that accepts a custom policy
    function for selectively deciding what to store and what to recompute

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional): Omit stashing and restoring
            the RNG state during each checkpoint.
            Default: ``True``
        policy_fn(Union[List[Op], callable]): policy for deciding what to
            store (instead of recompute). If it's a function, it should
            be of form (func, *args, **kwargs) -> bool which indicates
            if func outputs with *args and **kwargs should be stored or not.
            Additionally, a list[Op] is also supported for easier cases.
            The op should be in the format `torch.ops.***`, where the `***`
            names of operators can be obtained with `list_operators`.
        *args: Arguments to pass in to the given ``function``.
        **kwargs: Keyword arguments to pass into the given ``function``.
    """
    # Bind the policy now; torch will call the context_fn at checkpoint time.
    context_fn = functools.partial(selective_checkpoint_context_fn, policy_fn)
    return torch.utils.checkpoint.checkpoint(
        function,
        *args,
        use_reentrant=False,
        preserve_rng_state=preserve_rng_state,
        context_fn=context_fn,
        **kwargs,
    )
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class ProfileOperatorsTorchDispatchMode(TorchDispatchMode):
    """Dispatch mode that measures each op's runtime and peak-memory delta.

    Every dispatched op is re-run ``num_runs`` times for timing and once
    more for memory measurement, so the traced function must be side-effect
    tolerant. Requires CUDA (uses torch.cuda.synchronize / memory stats).
    """

    def __init__(self, num_runs: int = 10) -> None:
        # One ProfileMetadata per dispatched op, in dispatch order.
        self.data: List[ProfileMetadata] = []
        # Number of repetitions used to average the runtime.
        self.num_runs: int = num_runs

    def _get_inplace_metadata(self, func, out) -> Tuple[int, int, Tuple[int, ...]]:
        """Return (index, output storage ids, inplace (op, parent) pair).

        The inplace pair is empty for out-of-place ops; otherwise it links
        this op to the earliest recorded op whose output shares storage.
        """
        curr_idx = len(self.data)

        def get_tensor_id(e):
            # Identify a tensor by its storage pointer so aliasing can be
            # detected across ops.
            return (
                e.untyped_storage().data_ptr() if isinstance(e, torch.Tensor) else None
            )

        output_ids = tree_map(get_tensor_id, out)
        if not is_inplace(func):
            return curr_idx, output_ids, ()

        op_id = curr_idx
        op_parent_id = -1
        for i, d in enumerate(self.data):
            # find the first occurrence of a tensor that
            # shares the same storage as the current tensor
            past_output_ids = d.output_ids
            past_output_ids = (
                [past_output_ids]
                if not isinstance(past_output_ids, (list, tuple, dict))
                else past_output_ids
            )
            if output_ids in past_output_ids:
                op_parent_id = i
                break
        if op_parent_id < 0:
            # No aliasing parent found: the op is its own parent.
            op_parent_id = op_id
        inplace_info = (op_id, op_parent_id)
        return curr_idx, output_ids, inplace_info

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        out = func(*args, **kwargs)

        curr_idx, output_ids, inplace_info = self._get_inplace_metadata(func, out)
        is_view_like = is_view_fn(func) or is_inplace_view_fn(func)
        is_rand_op = torch.Tag.nondeterministic_seeded in func.tags
        # sdpa has non-deterministic seed, but might be deterministic
        # if no dropout is applied
        # NOTE(review): dropout_p may arrive positionally rather than as a
        # kwarg, in which case this check would miss it — confirm call sites.
        if func.overloadpacket.__name__ == "_scaled_dot_product_flash_attention":
            is_rand_op = kwargs.get("dropout_p", 0) != 0

        # get runtime info of func
        torch.cuda.synchronize()
        t = time.time()
        for i in range(self.num_runs):
            func(*args, **kwargs)
        torch.cuda.synchronize()
        time_taken = (time.time() - t) / self.num_runs

        # get memory usage of func (delta of peak allocated, in MiB)
        torch.cuda.reset_peak_memory_stats()
        mem1 = torch.cuda.max_memory_allocated() / 2**20
        func(*args, **kwargs)
        mem2 = torch.cuda.max_memory_allocated() / 2**20

        self.data.append(
            ProfileMetadata(
                func,
                time_taken,
                mem2 - mem1,
                curr_idx,
                output_ids,
                inplace_info,
                is_view_like,
                is_rand_op,
            )
        )
        return out
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def _analyze_operators(function, *args) -> List[ProfileMetadata]:
    """Run ``function(*args)`` under the profiling dispatch mode.

    Args:
        function: The function to optimize which will be selectively
            checkpointed. Usually the forward pass of the model.
        *args: Arguments to pass in to the given ``function``.

    Returns:
        A list of ProfileMetadata records, one per dispatched operator,
        containing its runtime and memory usage.
    """
    profiler = ProfileOperatorsTorchDispatchMode()
    with profiler:
        function(*args)
    return profiler.data
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def get_optimal_checkpoint_policy(function, *args, memory_budget: float) -> Callable:
    """
    Given a function, its arguments, and the maximum amount of memory available,
    find the subset of operators that can be optimized to reduce runtime while still fitting within the memory budget.

    Args:
        function: The function to optimize which will be selectively checkpointed. Usually the forward pass
            of the model.
        *args: Arguments to pass in to the given ``function``.
        memory_budget (float): A float between 0 and 1 which describes what percentage of the total memory to use.

    Returns:
        A callable policy which can be passed to xformers.checkpoint()

    Raises:
        RuntimeError: If `scipy` is not available.
        ValueError: If `memory_budget` is not a float between 0 and 1.

    """
    if not _scipy_is_available:
        raise RuntimeError(
            "Please install scipy 1.9.0+ to use `get_optimal_checkpoint_policy`. You can do so using "
            "`pip install scipy`."
        )
    if memory_budget < 0 or memory_budget > 1:
        raise ValueError(
            f"`memory_budget` must be a float between 0 and 1. Got {memory_budget}."
        )

    # Profile every op once to get per-op runtime and memory.
    data = _analyze_operators(function, *args)
    # remove aten.detach.default from the list of ops because autograd
    # inserts those during backward and it breaks the fwd-bwd alignment
    data = [x for x in data if x.name not in OPS_TO_ALWAYS_SKIP]

    # Unzip the ProfileMetadata fields into parallel sequences.
    ops, runtimes_, memory_, new_ids, _, inplace_ops_, view_like_ops_, rand_ops_ = zip(
        *[astuple(x) for x in data]
    )
    runtimes = torch.tensor(runtimes_, dtype=torch.float64)
    memory = torch.tensor(memory_, dtype=torch.float64)
    view_like_ops = [i for i, x in enumerate(view_like_ops_) if x]
    rand_ops = [i for i, x in enumerate(rand_ops_) if x]

    # remap the inplace indices as we have removed OPS_TO_ALWAYS_SKIP
    inplace_ops = [tuple(map(new_ids.index, x)) for x in inplace_ops_ if x]

    # the last operation is always stored as the output of the checkpoint
    # block, so we can avoid recomputing it. We set the memory to zero
    # instead of adding a new constraint because we want both the 0 and 1
    # endpoints for memory_budget to be valid
    # FIXME: this heuristic for finding the last non-view non-inplace op
    # might not always be correct, which would yield suboptimal policies
    last_op = len(ops) - 1
    skip_ops_ = set(view_like_ops) | set([x[0] for x in inplace_ops])
    skip_ops = sorted(list(skip_ops_))
    for op in reversed(skip_ops):
        if op == last_op:
            last_op -= 1

    memory[last_op] = 0

    max_memory = memory_budget * memory.sum().item()

    # workaround to fix https://github.com/pytorch/pytorch/issues/121212
    force_store_random = all([not isinstance(x, torch.Tensor) for x in args])

    # Solve the knapsack-style MILP: maximize stored runtime subject to
    # the memory budget and view/inplace/random-op constraints.
    optim_output = _optimize_runtime_with_given_memory(
        memory=memory,
        runtimes=runtimes,
        max_memory=max_memory,
        view_like_ops=view_like_ops,
        inplace_ops=inplace_ops,
        random_ops=rand_ops,
        force_store_random=force_store_random,
    )
    return _OptimalPolicy(optim_output=optim_output)
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def _optimize_runtime_with_given_memory(
|
| 387 |
+
memory: torch.Tensor,
|
| 388 |
+
runtimes: torch.Tensor,
|
| 389 |
+
max_memory: float,
|
| 390 |
+
view_like_ops: List[int],
|
| 391 |
+
inplace_ops: List[Tuple[int, ...]],
|
| 392 |
+
random_ops: List[int],
|
| 393 |
+
force_store_random: bool,
|
| 394 |
+
) -> torch.Tensor:
|
| 395 |
+
"""
|
| 396 |
+
Given a list of operator names, their corresponding runtimes, and the maximum amount of memory available,
|
| 397 |
+
find the subset of operators that can be optimized to reduce runtime while still fitting within the memory budget.
|
| 398 |
+
Uses https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.milp.html
|
| 399 |
+
|
| 400 |
+
Args:
|
| 401 |
+
memory (torch.Tensor): Tensor containing the memory usage of each operator.
|
| 402 |
+
runtimes (torch.Tensor): Tensor containing the runtime of each operator.
|
| 403 |
+
max_memory (float): Maximum amount of memory to use.
|
| 404 |
+
view_like_ops ([List[int]): Indices of the view-like ops.
|
| 405 |
+
inplace_ops (List[Tuple[int, int]]): Tuple with the pair of inplace op -> parent of inplace op.
|
| 406 |
+
This will be used to add the constraint that in-place ops need to either be
|
| 407 |
+
stored in memory with the previous op, or recomputed with the previous op.
|
| 408 |
+
random_ops ([List[int]): Indices of the random ops, which will always be recomputed.
|
| 409 |
+
force_store_random (bool): force random ops to always be stored (instead of recomputed)
|
| 410 |
+
"""
|
| 411 |
+
c = -runtimes # type: ignore[operator]
|
| 412 |
+
|
| 413 |
+
memory_constraint = LinearConstraint(A=memory, ub=max_memory)
|
| 414 |
+
constraints = [memory_constraint]
|
| 415 |
+
|
| 416 |
+
# view-like ops should always be recomputed
|
| 417 |
+
for i in view_like_ops:
|
| 418 |
+
A = torch.zeros_like(c)
|
| 419 |
+
A[i] = 1
|
| 420 |
+
constraints.append(LinearConstraint(A=A, lb=0, ub=0))
|
| 421 |
+
|
| 422 |
+
# inplace ops should always be done in conjunction with its parent op
|
| 423 |
+
# i.e., if we recompute the parent op the inplace should also be
|
| 424 |
+
# recomputed, and vice versa
|
| 425 |
+
for op, op_parent in inplace_ops:
|
| 426 |
+
A = torch.zeros_like(c)
|
| 427 |
+
if op != op_parent:
|
| 428 |
+
A[op_parent] = 1
|
| 429 |
+
A[op] = -1
|
| 430 |
+
constraints.append(LinearConstraint(A=A, lb=0, ub=0))
|
| 431 |
+
else:
|
| 432 |
+
# if op == op_parent, it's because it's the first op
|
| 433 |
+
# that is inplace. Thus never recompute it
|
| 434 |
+
A[op] = 1
|
| 435 |
+
constraints.append(LinearConstraint(A=A, lb=1, ub=1))
|
| 436 |
+
|
| 437 |
+
# ideally, always recompute random ops
|
| 438 |
+
# in practice, due to a bug in https://github.com/pytorch/pytorch/issues/121212
|
| 439 |
+
# sometimes we need to store them to avoid correctness issues
|
| 440 |
+
for i in random_ops:
|
| 441 |
+
A = torch.zeros_like(c)
|
| 442 |
+
A[i] = 1
|
| 443 |
+
val = int(force_store_random)
|
| 444 |
+
constraints.append(LinearConstraint(A=A, lb=val, ub=val))
|
| 445 |
+
|
| 446 |
+
integrality = torch.ones_like(c)
|
| 447 |
+
res = milp(
|
| 448 |
+
c=c, constraints=constraints, integrality=integrality, bounds=Bounds(0, 1)
|
| 449 |
+
)
|
| 450 |
+
if not res.success:
|
| 451 |
+
raise ValueError(
|
| 452 |
+
"The problem is infeasible, and probably due to a change in xformers "
|
| 453 |
+
"that makes random ops always be stored. Try passing a larger memory_budget. "
|
| 454 |
+
"This will be fixed once https://github.com/pytorch/pytorch/issues/121212 "
|
| 455 |
+
"is solved"
|
| 456 |
+
)
|
| 457 |
+
x = torch.from_numpy(res.x)
|
| 458 |
+
return x
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
class _OptimalPolicy:
|
| 462 |
+
def __init__(self, optim_output: torch.Tensor):
|
| 463 |
+
self.counter = 0
|
| 464 |
+
self.optim_output = optim_output.tolist()
|
| 465 |
+
|
| 466 |
+
def __call__(self, ctx, func, *args, **kwargs) -> bool:
|
| 467 |
+
# returning False means recompute, True means store in memory
|
| 468 |
+
if func in OPS_TO_ALWAYS_SKIP:
|
| 469 |
+
return False
|
| 470 |
+
count = self.counter
|
| 471 |
+
self.counter += 1
|
| 472 |
+
return self.optim_output[count] == 1
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
class SelectiveCheckpointWrapper(ActivationWrapper):
    """Activation wrapper applying selective (per-op) checkpointing.

    The decision of which ops to store vs recompute comes either from a
    user-supplied ``policy_fn`` or from a policy computed once for a given
    ``memory_budget`` (exactly one of the two must be provided).
    """

    def __init__(self, mod, memory_budget=None, policy_fn=None):
        super().__init__(mod)
        # Exactly one of memory_budget / policy_fn must be set (XOR check).
        if not ((memory_budget is None) ^ (policy_fn is None)):
            raise ValueError("Need to specify either policy_fn or memory_budget")
        self.memory_budget = memory_budget
        self.policy_fn = policy_fn

        try:
            # for backward-compatibility as this doesn't exist in PT anymore
            torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = (
                True
            )
        except AttributeError:
            pass

    @torch.compiler.disable
    def _get_policy_fn(self, *args, **kwargs):
        # Compute the policy eagerly, outside of torch.compile tracing
        # (hence the decorator).
        if not torch.is_grad_enabled():
            # no need to compute a policy as it won't be used
            return []
        # if policy is not specified, initialize policy for a given memory budget.
        # fork_rng keeps the profiling run from advancing the global RNG state.
        with torch.random.fork_rng():
            policy_fn = get_optimal_checkpoint_policy(
                self._checkpoint_wrapped_module,
                *args,
                **kwargs,
                memory_budget=self.memory_budget,
            )
        if (
            torch.distributed.is_available()
            and torch.distributed.is_initialized()
            and torch.distributed.get_world_size() > 1
        ):
            # use the same policy across different GPUs
            objects = [policy_fn]
            torch.distributed.broadcast_object_list(objects, src=0)
            policy_fn = objects[0]
        return policy_fn

    def get_policy_fn(self, *args, **kwargs):
        # Lazily compute the policy on first use, then cache it.
        # NOTE(review): if the first call happens under torch.no_grad(), the
        # empty-list result of _get_policy_fn is cached permanently and later
        # grad-enabled calls reuse it — confirm this is the intended behavior.
        if self.policy_fn is None:
            self.policy_fn = self._get_policy_fn(*args, **kwargs)
        return self.policy_fn

    def forward(self, *args, **kwargs):
        # Run the wrapped module under checkpointing, steered by the policy.
        policy_fn = self.get_policy_fn(*args, **kwargs)
        return checkpoint(
            self._checkpoint_wrapped_module, *args, **kwargs, policy_fn=policy_fn
        )
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
def selective_checkpoint_wrapper(
    module: torch.nn.Module,
    memory_budget: Optional[float] = None,
    policy_fn: Optional[Callable] = None,
):
    """
    Wrap a module with selective activation checkpointing.

    It behaves similarly to PyTorch's checkpoint_wrapper, but gives the possibility
    to the user to either specify a handcrafted policy_fn, or to let an optimization
    algorithm select the policy given a user-specified memory_budget.

    Exactly one of ``memory_budget`` / ``policy_fn`` should be provided.

    Args:
        module: the module to wrap.
        memory_budget: float in [0, 1]; 0 recomputes everything in the backward
            (similar to PyTorch's activation checkpoint), 1 stores everything
            (similar to not checkpointing at all).
        policy_fn: handcrafted per-op store/recompute policy.

    Returns:
        A ``SelectiveCheckpointWrapper`` around ``module``.
    """
    wrapped = SelectiveCheckpointWrapper(module, memory_budget, policy_fn)
    return wrapped
|
.venv/lib/python3.11/site-packages/xformers/cpp_lib.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"version": {"cuda": 1201, "hip": null, "torch": "2.5.1+cu121", "python": "3.11.10", "flash": "v2.6.3-24-gbdf733b", "use_torch_flash": true}, "env": {"TORCH_CUDA_ARCH_LIST": "6.0+PTX 7.0 7.5 8.0+PTX 9.0a", "PYTORCH_ROCM_ARCH": null, "XFORMERS_BUILD_TYPE": "Release", "XFORMERS_ENABLE_DEBUG_ASSERTIONS": null, "NVCC_FLAGS": "-allow-unsupported-compiler", "XFORMERS_PACKAGE_FROM": "wheel-v0.0.28.post3"}}
|
.venv/lib/python3.11/site-packages/xformers/factory/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from xformers.components import MultiHeadDispatchConfig # noqa
|
| 2 |
+
from xformers.components.attention import AttentionConfig # noqa
|
| 3 |
+
from xformers.components.feedforward import FeedforwardConfig # noqa
|
| 4 |
+
from xformers.components.positional_embedding import PositionEmbeddingConfig # noqa
|
| 5 |
+
|
| 6 |
+
from .block_factory import xFormerDecoderBlock # noqa
|
| 7 |
+
from .block_factory import xFormerDecoderConfig # noqa
|
| 8 |
+
from .block_factory import xFormerEncoderBlock # noqa
|
| 9 |
+
from .block_factory import xFormerEncoderConfig # noqa
|
| 10 |
+
from .model_factory import xFormer, xFormerConfig # noqa
|
| 11 |
+
from .weight_init import xFormerWeightInit # noqa
|
.venv/lib/python3.11/site-packages/xformers/factory/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (962 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/factory/__pycache__/block_configs.cpython-311.pyc
ADDED
|
Binary file (9.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/factory/__pycache__/block_factory.cpython-311.pyc
ADDED
|
Binary file (13.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/factory/__pycache__/hydra_helper.cpython-311.pyc
ADDED
|
Binary file (1.87 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/factory/__pycache__/model_factory.cpython-311.pyc
ADDED
|
Binary file (14 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/factory/__pycache__/weight_init.cpython-311.pyc
ADDED
|
Binary file (13 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/factory/block_configs.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the BSD license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from enum import Enum
|
| 9 |
+
from typing import Any, Dict, Optional
|
| 10 |
+
|
| 11 |
+
from xformers.components import NormalizationType, ResidualNormStyle
|
| 12 |
+
from xformers.components.feedforward import FEEDFORWARD_REGISTRY, FeedforwardConfig
|
| 13 |
+
from xformers.components.positional_embedding import (
|
| 14 |
+
POSITION_EMBEDDING_REGISTRY,
|
| 15 |
+
PositionEmbeddingConfig,
|
| 16 |
+
)
|
| 17 |
+
from xformers.utils import generate_matching_config
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class LayerPositionBitmask(int, Enum):
    """Bit flags used by ``LayerPosition`` to track whether a layer is the
    first and/or last one of the stack."""

    First = 0b01  # layer is the first of the stack
    Last = 0b10  # layer is the last of the stack
    Default = 0b11  # both first and last (a single-layer stack)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class LayerPosition:
    """Bitmask to mark this layer as first, last, nothing or both"""

    def __init__(self):
        # A fresh layer is assumed to be both first and last until told otherwise.
        self.bitmask = LayerPositionBitmask.Default

    def is_first(self):
        # First bit set -> this layer opens the stack.
        return (self.bitmask & LayerPositionBitmask.First) != 0

    def is_last(self):
        # Last bit set -> this layer closes the stack.
        return (self.bitmask & LayerPositionBitmask.Last) != 0

    def mark_not_first(self):
        # Clear the First bit, keeping the Last bit untouched.
        self.bitmask &= ~LayerPositionBitmask.First

    def mark_not_last(self):
        # Clear the Last bit, keeping the First bit untouched.
        self.bitmask &= ~LayerPositionBitmask.Last
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class BlockType(str, Enum):
    """Discriminates encoder vs decoder block configurations; str-valued so
    members compare equal to and can be built from plain strings."""

    Encoder = "encoder"
    Decoder = "decoder"
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@dataclass(init=False)  # handle constructors explicitly to force type changes
class xFormerBlockConfig:
    """
    The configuration structure to define a Transformer block.
    This base class is applicable to both encoder and decoder definitions.

    This completely defines each of the blocks, for instance in terms of dimensions,
    position encoding, pre or post layer norms or reversibility.

    The constructor accepts plain dicts for the sub-configs and expands them
    into full config objects via the component registries.
    """

    dim_model: int
    feedforward_config: FeedforwardConfig
    position_encoding_config: Optional[PositionEmbeddingConfig]
    block_type: BlockType
    residual_norm_style: ResidualNormStyle
    normalization: NormalizationType
    layer_position: LayerPosition
    use_triton: bool
    reversible: bool
    num_layers: int

    def __init__(
        self,
        dim_model: int,
        feedforward_config: Dict[str, Any],
        position_encoding_config: Optional[Dict[str, Any]],
        block_type: BlockType,
        residual_norm_style: ResidualNormStyle = ResidualNormStyle("post"),
        normalization: NormalizationType = NormalizationType.LayerNorm,
        reversible: bool = False,
        num_layers: int = 1,
        layer_position: Optional[LayerPosition] = None,
    ):

        self.dim_model = dim_model
        self.block_type = block_type
        self.residual_norm_style = residual_norm_style
        self.reversible = reversible
        self.num_layers = num_layers
        self.normalization = normalization

        # Fill in possible gaps in the config for subparts of the block.
        # The registry entry named by feedforward_config["name"] supplies the
        # matching config class and its defaults.
        self.feedforward_config = generate_matching_config(
            feedforward_config,
            FEEDFORWARD_REGISTRY[feedforward_config["name"]].config,
        )

        # Position encoding is optional; when given, expand it the same way
        # through the position-embedding registry.
        self.position_encoding_config = (
            generate_matching_config(
                position_encoding_config,
                POSITION_EMBEDDING_REGISTRY[position_encoding_config["name"]].config,
            )
            if position_encoding_config is not None
            else None
        )

        # Default is that this layer is the only one, so both first and last
        if layer_position:
            self.layer_position = layer_position
        else:
            self.layer_position = LayerPosition()
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
@dataclass(init=False)
class xFormerEncoderConfig(xFormerBlockConfig):
    """
    The configuration structure for an encoder block.

    On top of the base block settings, this carries the self-attention
    configuration and optional simplicial/patch embedding settings.
    """

    multi_head_config: Dict[str, Any]
    use_triton: bool
    simplicial_embeddings: Optional[Dict[str, Any]]
    patch_embedding_config: Optional[Dict[str, Any]]

    def __init__(
        self,
        dim_model: int,
        feedforward_config: Dict[str, Any],
        multi_head_config: Dict[str, Any],
        position_encoding_config: Optional[Dict[str, Any]] = None,
        residual_norm_style: str = "post",
        normalization: NormalizationType = NormalizationType.LayerNorm,
        use_triton: bool = True,
        simplicial_embeddings: Optional[Dict[str, Any]] = None,
        patch_embedding_config: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        # Convenience, fill in duplicated fields
        try:
            if "dim_model" not in multi_head_config.keys():
                multi_head_config["dim_model"] = dim_model

            if "dim_model" not in feedforward_config.keys():
                feedforward_config["dim_model"] = dim_model

            if (
                position_encoding_config is not None
                and "dim_model" not in position_encoding_config.keys()
            ):
                position_encoding_config["dim_model"] = dim_model

            if (
                patch_embedding_config is not None
                and "out_channels" not in patch_embedding_config.keys()
            ):
                patch_embedding_config["out_channels"] = dim_model

        except AttributeError:
            # A config instance was passed in, this is fine
            pass
        # If a block type was forwarded, it must be "encoder"; re-wrap it as
        # the enum so downstream comparisons are type-safe.
        if "block_type" in kwargs:
            assert kwargs["block_type"] == "encoder"
        kwargs["block_type"] = BlockType("encoder")
        super().__init__(
            dim_model=dim_model,
            feedforward_config=feedforward_config,
            position_encoding_config=position_encoding_config,
            residual_norm_style=ResidualNormStyle(residual_norm_style),
            normalization=NormalizationType(normalization),
            **kwargs,
        )

        self.multi_head_config = multi_head_config
        self.use_triton = use_triton
        self.simplicial_embeddings = simplicial_embeddings
        self.patch_embedding_config = patch_embedding_config
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
@dataclass(init=False)
class xFormerDecoderConfig(xFormerBlockConfig):
    """
    The configuration structure for a decoder block.

    This specifically defines the masked and cross attention mechanisms,
    on top of the settings defining all blocks.
    """

    multi_head_config_masked: Dict[str, Any]  # prior to encoder output
    multi_head_config_cross: Dict[str, Any]  # cross attention, takes encoder output

    def __init__(
        self,
        dim_model: int,
        feedforward_config: Dict[str, Any],
        multi_head_config_masked: Dict[str, Any],
        multi_head_config_cross: Dict[str, Any],
        position_encoding_config: Optional[Dict[str, Any]] = None,
        residual_norm_style: str = "post",
        normalization: NormalizationType = NormalizationType.LayerNorm,
        use_triton: bool = True,
        **kwargs,
    ):

        # Convenience, fill in duplicated field
        try:
            if "dim_model" not in multi_head_config_masked.keys():
                multi_head_config_masked["dim_model"] = dim_model

            if "dim_model" not in multi_head_config_cross.keys():
                multi_head_config_cross["dim_model"] = dim_model

            if "dim_model" not in feedforward_config.keys():
                feedforward_config["dim_model"] = dim_model

            if (
                position_encoding_config is not None
                and "dim_model" not in position_encoding_config.keys()
            ):
                position_encoding_config["dim_model"] = dim_model
        except AttributeError:
            # A config instance was passed in, this is fine
            pass
        # If a block type was forwarded, it must be "decoder"; re-wrap it as
        # the enum so downstream comparisons are type-safe.
        if "block_type" in kwargs.keys():
            assert kwargs["block_type"] == "decoder"
        kwargs["block_type"] = BlockType("decoder")

        super().__init__(
            dim_model=dim_model,
            feedforward_config=feedforward_config,
            position_encoding_config=position_encoding_config,
            residual_norm_style=ResidualNormStyle(residual_norm_style),
            normalization=NormalizationType(normalization),
            **kwargs,
        )

        self.multi_head_config_masked = multi_head_config_masked
        self.multi_head_config_cross = multi_head_config_cross
        self.use_triton = use_triton
|
.venv/lib/python3.11/site-packages/xformers/factory/block_factory.py
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the BSD license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
from dataclasses import asdict
|
| 9 |
+
from typing import Optional, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
|
| 14 |
+
from xformers._deprecation_warning import deprecated_function
|
| 15 |
+
from xformers.components import (
|
| 16 |
+
PatchEmbeddingConfig,
|
| 17 |
+
PostNorm,
|
| 18 |
+
PreNorm,
|
| 19 |
+
Residual,
|
| 20 |
+
ResidualNormStyle,
|
| 21 |
+
build_multi_head_attention,
|
| 22 |
+
build_patch_embedding,
|
| 23 |
+
)
|
| 24 |
+
from xformers.components.attention import AttentionMask
|
| 25 |
+
from xformers.components.feedforward import build_feedforward
|
| 26 |
+
from xformers.components.positional_embedding import build_positional_embedding
|
| 27 |
+
from xformers.components.residual import get_deepnorm_coefficients
|
| 28 |
+
from xformers.components.simplicial_embedding import SimplicialEmbedding
|
| 29 |
+
from xformers.factory.block_configs import (
|
| 30 |
+
NormalizationType,
|
| 31 |
+
xFormerDecoderConfig,
|
| 32 |
+
xFormerEncoderConfig,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
logger = logging.getLogger("xformers")
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _get_ln_factory(
    d_model: int,
    residual_norm_style: Optional[ResidualNormStyle],
    use_triton: bool,
    residual: bool,
    normalization: NormalizationType = NormalizationType.LayerNorm,
    residual_scale: float = 1.0,
):
    """
    Handle all the supported residual path configurations.

    ..Note: we return the appropriate constructor, not an actual layer

    Args:
        d_model: model dimension handed to the norm wrappers.
        residual_norm_style: Pre / Post / DeepNorm placement of the norm
            relative to the residual path; None falls through to the
            non-residual Pre/Post selection below.
        use_triton: forwarded to the norm wrappers.
        residual: whether to wrap the sublayer in a residual connection.
        normalization: which normalization layer to use.
        residual_scale: residual scaling factor (used by DeepNorm only).

    Returns:
        A one-argument factory ``ln_factory(sublayer)`` producing the wrapped
        sublayer; ``normalization`` and ``use_triton`` are captured by closure.
    """

    def get_layer_wrapper(
        d_model: int,
        sublayer: nn.Module,
        residual_norm_style: Optional[ResidualNormStyle],
        residual: bool,
        residual_scale: float,
    ):
        # Residual path requested: the norm placement decides the nesting order.
        if residual:
            if residual_norm_style == ResidualNormStyle.Pre:
                # Norm inside the residual branch: x + sublayer(norm(x))
                return Residual(
                    layer=PreNorm(d_model, sublayer, normalization, use_triton),
                    scale=None,
                )
            elif residual_norm_style == ResidualNormStyle.Post:
                # Norm outside the residual sum: norm(x + sublayer(x))
                return PostNorm(
                    d_model,
                    Residual(layer=sublayer, scale=None),
                    normalization,
                    use_triton,
                )
            elif residual_norm_style == ResidualNormStyle.DeepNorm:
                # Like Post, but with a scaled residual branch.
                return PostNorm(
                    d_model,
                    Residual(layer=sublayer, scale=residual_scale),
                    normalization,
                    use_triton=use_triton,
                )
            else:
                # No other residual style is supported with residual=True.
                raise ValueError

        # No residual: just place the norm before or after the sublayer.
        return (
            PreNorm(d_model, sublayer, normalization, use_triton)
            if residual_norm_style == ResidualNormStyle.Pre
            else PostNorm(d_model, sublayer, normalization, use_triton)
        )

    def ln_factory(sublayer: nn.Module):
        # Bind all the settings captured above to the given sublayer.
        return get_layer_wrapper(
            d_model, sublayer, residual_norm_style, residual, residual_scale
        )

    return ln_factory
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class xFormerEncoderBlock(torch.nn.Module):
    r"""A vanilla Transformer Encoder block

    Assembles (optional) patch + positional embeddings, a self-attention
    sublayer and a feedforward sublayer, each wrapped with the configured
    norm/residual style.
    """

    def __init__(self, config: xFormerEncoderConfig, **kwargs):
        super().__init__()
        deprecated_function(self)

        # Reversible halves are populated externally via get_reversible_layer.
        self.reversible_f = None
        self.reversible_g = None
        self.residual_norm_style = config.residual_norm_style
        self.dim_model = config.dim_model

        # If this layer is the first one, and a pose encoding has been requested
        if (
            config.position_encoding_config is not None
            and config.layer_position.is_first()
        ):
            self.pose_encoding = build_positional_embedding(
                asdict(config.position_encoding_config)
            )

            pos_encoding_dim = config.position_encoding_config.dim_model
            mha_dim = config.multi_head_config["dim_model"]

            # Bridge a dimension mismatch between embedding and attention
            # with a learned linear projection.
            if pos_encoding_dim != mha_dim:
                logger.warning(
                    f"The embedding dim and model dim do not match ({pos_encoding_dim} vs {mha_dim}), adding a projector layer."  # noqa
                )
                self.embedding_projector = nn.Linear(pos_encoding_dim, mha_dim)
        else:
            self.pose_encoding = None

        if config.residual_norm_style == ResidualNormStyle.DeepNorm:
            # Just use the layer norm coefficient here,
            # the init will be handled at the xformers level (knows about encoder and decoder blocks)
            deep_norm_coefficients, _ = get_deepnorm_coefficients(
                encoder_layers=config.num_layers, decoder_layers=0
            )
            assert deep_norm_coefficients is not None
            residual_scale = deep_norm_coefficients.alpha
        else:
            residual_scale = 1.0

        # mini helper, builds a normalization layer with the right Pre/Post config, residuals, and the right dimensions
        ln_factory = _get_ln_factory(
            config.dim_model,
            config.residual_norm_style,
            use_triton=config.use_triton,
            residual=True,
            residual_scale=residual_scale,
            normalization=config.normalization,
        )

        mha = build_multi_head_attention(config.multi_head_config)
        feedforward = build_feedforward(asdict(config.feedforward_config))

        # Expose attention specific capabilities
        self.supports_attention_mask = mha.attention.supports_attention_mask
        self.requires_same_k_q_dimensions = mha.attention.requires_same_k_q_dimensions
        # Not every attention implementation defines "causal".
        self.causal = (
            mha.attention.causal if hasattr(mha.attention, "causal") else False
        )

        # Wrappers handle the different layer norm styles (pre- and post-) and the residual path
        self.wrap_att = ln_factory(mha)
        self.wrap_ff: Union[Residual, PostNorm] = ln_factory(feedforward)
        # With Pre-norm, the very last layer of the stack gets a closing
        # post-norm so outputs are normalized.
        if (
            config.residual_norm_style == ResidualNormStyle.Pre
            and config.layer_position.is_last()
        ):
            self.wrap_ff = PostNorm(
                config.dim_model,
                self.wrap_ff,
                normalization=config.normalization,
                use_triton=config.use_triton,
            )

        # Simplicial embeddings are only used if specified, and on the last layer
        self.simplicial_embedding: Optional[SimplicialEmbedding] = None
        if config.simplicial_embeddings is not None and config.layer_position.is_last():
            self.simplicial_embedding = SimplicialEmbedding(
                **config.simplicial_embeddings
            )

        # Optional patch embedding
        self.patch_emb: Optional[nn.Module] = None

        if config.patch_embedding_config is not None:
            self.patch_emb = build_patch_embedding(
                PatchEmbeddingConfig(**config.patch_embedding_config)
            )

    @classmethod
    def from_config(cls, config: xFormerEncoderConfig):
        """Alternate constructor from an xFormerEncoderConfig."""
        return cls(config)

    @staticmethod
    def get_reversible_layer(config) -> Tuple[nn.Module, nn.Module]:
        """Build the (f, g) sublayer pair used by a reversible encoder stack.

        Residual wiring is omitted here (residual=False) since the reversible
        formulation supplies its own coupling.
        """
        ln_factory = _get_ln_factory(
            config.dim_model,
            config.residual_norm_style,
            residual=False,
            use_triton=config.use_triton,
            normalization=config.normalization,
        )

        mha = build_multi_head_attention(config.multi_head_config)
        feedforward = build_feedforward(asdict(config.feedforward_config))

        reversible_f = ln_factory(mha)
        reversible_g = ln_factory(feedforward)
        return reversible_f, reversible_g

    def forward(
        self,
        x: torch.Tensor,
        att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
        input_mask: Optional[torch.Tensor] = None,
    ):
        """Run the encoder block.

        Args:
            x: input activations. assumes a (batch, seq, dim) layout when
                input_mask is used — TODO confirm against callers.
            att_mask: optional attention mask forwarded to self-attention.
            input_mask: optional per-position mask applied to keys/values.
        """
        if self.patch_emb is not None:
            x = self.patch_emb(x)

        if self.pose_encoding is not None:
            x = self.pose_encoding(x)

        # Only present when the positional-embedding dim differed from the
        # attention dim at construction time.
        if hasattr(self, "embedding_projector"):
            x = self.embedding_projector(x)

        # Handle the optional input masking, differs on Q, K, V
        if input_mask is not None:
            q = x
            k = x * input_mask.unsqueeze(-1)
            v = k
        else:
            q, k, v = x, x, x

        # Pre/Post norms and residual paths are already handled
        x = self.wrap_att(inputs=[q, k, v], att_mask=att_mask)
        x = self.wrap_ff(inputs=[x])

        # Optional simplicial embeddings
        if self.simplicial_embedding is not None:
            x = self.simplicial_embedding(x)

        return x
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
class xFormerDecoderBlock(torch.nn.Module):
|
| 244 |
+
r"""A vanilla Transformer Decoder block
|
| 245 |
+
|
| 246 |
+
... note: this implementation is not (yet ?) reversible"""
|
| 247 |
+
|
| 248 |
+
def __init__(self, config: xFormerDecoderConfig, **kwargs):
|
| 249 |
+
super().__init__()
|
| 250 |
+
deprecated_function(self)
|
| 251 |
+
|
| 252 |
+
# If this layer is the first one, and a pose encoding as been requested
|
| 253 |
+
if (
|
| 254 |
+
config.position_encoding_config is not None
|
| 255 |
+
and config.layer_position.is_first()
|
| 256 |
+
):
|
| 257 |
+
self.pose_encoding = build_positional_embedding(
|
| 258 |
+
config.position_encoding_config
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
pos_encoding_dim = config.position_encoding_config.dim_model
|
| 262 |
+
mha_dim = config.multi_head_config_masked["dim_model"]
|
| 263 |
+
|
| 264 |
+
if pos_encoding_dim != mha_dim:
|
| 265 |
+
|
| 266 |
+
logger.warning(
|
| 267 |
+
f"The embedding dim and model dim do not match ({pos_encoding_dim} vs {mha_dim}), adding a projector layer." # noqa
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
self.embedding_projector = nn.Linear(pos_encoding_dim, mha_dim)
|
| 271 |
+
else:
|
| 272 |
+
self.pose_encoding = None
|
| 273 |
+
|
| 274 |
+
if config.residual_norm_style == ResidualNormStyle.DeepNorm:
|
| 275 |
+
# Just use the layer norm coefficient here,
|
| 276 |
+
# the init will be handled at the xformers level (knows about encoder and decoder blocks)
|
| 277 |
+
_, deep_norm_coefficients = get_deepnorm_coefficients(
|
| 278 |
+
encoder_layers=0, decoder_layers=config.num_layers
|
| 279 |
+
)
|
| 280 |
+
assert deep_norm_coefficients is not None
|
| 281 |
+
residual_scale = deep_norm_coefficients.alpha
|
| 282 |
+
else:
|
| 283 |
+
residual_scale = 1.0
|
| 284 |
+
|
| 285 |
+
# mini helper, builds a LayerNorm with the right Pre/Post config and the right dimensions
|
| 286 |
+
ln_factory = _get_ln_factory(
|
| 287 |
+
config.dim_model,
|
| 288 |
+
config.residual_norm_style,
|
| 289 |
+
use_triton=config.use_triton,
|
| 290 |
+
residual=True,
|
| 291 |
+
residual_scale=residual_scale,
|
| 292 |
+
normalization=config.normalization,
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
mha = build_multi_head_attention(config.multi_head_config_masked)
|
| 296 |
+
cross_mha = build_multi_head_attention(config.multi_head_config_cross)
|
| 297 |
+
feedforward = build_feedforward(config.feedforward_config)
|
| 298 |
+
|
| 299 |
+
# Expose attention or feedforward specific capabilities
|
| 300 |
+
self.supports_attention_mask = mha.attention.supports_attention_mask
|
| 301 |
+
self.requires_same_k_q_dimensions = mha.attention.requires_same_k_q_dimensions
|
| 302 |
+
self.requires_squared_context_length = (
|
| 303 |
+
feedforward.requires_squared_context
|
| 304 |
+
or mha.attention.requires_squared_context
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
self.causal_attention = (
|
| 308 |
+
mha.attention.causal if hasattr(mha.attention, "causal") else False
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
# Wrappers handle the different layer norm styles (pre- and post-) and the residual path
|
| 312 |
+
self.wrap_att = ln_factory(mha)
|
| 313 |
+
self.wrap_cross = ln_factory(cross_mha)
|
| 314 |
+
self.wrap_ff: Union[Residual, PostNorm] = ln_factory(feedforward)
|
| 315 |
+
|
| 316 |
+
if (
|
| 317 |
+
config.residual_norm_style == ResidualNormStyle.Pre
|
| 318 |
+
and config.layer_position.is_last()
|
| 319 |
+
):
|
| 320 |
+
self.wrap_ff = PostNorm(
|
| 321 |
+
config.dim_model,
|
| 322 |
+
self.wrap_ff,
|
| 323 |
+
normalization=NormalizationType.LayerNorm,
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
@classmethod
|
| 327 |
+
def from_config(cls, config: xFormerDecoderConfig):
|
| 328 |
+
return cls(config)
|
| 329 |
+
|
| 330 |
+
def forward(
|
| 331 |
+
self,
|
| 332 |
+
target: torch.Tensor,
|
| 333 |
+
memory: torch.Tensor,
|
| 334 |
+
encoder_att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
|
| 335 |
+
decoder_att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
|
| 336 |
+
input_mask: Optional[torch.Tensor] = None,
|
| 337 |
+
):
|
| 338 |
+
if self.pose_encoding is not None:
|
| 339 |
+
target = self.pose_encoding(target)
|
| 340 |
+
|
| 341 |
+
if hasattr(self, "embedding_projector"):
|
| 342 |
+
target = self.embedding_projector(target)
|
| 343 |
+
|
| 344 |
+
# Handle the optional input masking, differs on Q, K, V
|
| 345 |
+
if input_mask is not None:
|
| 346 |
+
target_q = target
|
| 347 |
+
target_k = target * input_mask.unsqueeze(-1)
|
| 348 |
+
target_v = target_k
|
| 349 |
+
else:
|
| 350 |
+
target_q, target_k, target_v = target, target, target
|
| 351 |
+
|
| 352 |
+
x = self.wrap_att(
|
| 353 |
+
inputs=[target_q, target_k, target_v], att_mask=decoder_att_mask
|
| 354 |
+
)
|
| 355 |
+
x = self.wrap_cross(inputs=[x, memory, memory], att_mask=encoder_att_mask)
|
| 356 |
+
x = self.wrap_ff(inputs=[x])
|
| 357 |
+
|
| 358 |
+
return x
|
.venv/lib/python3.11/site-packages/xformers/factory/hydra_helper.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the BSD license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# register components configs into Hydra ConfigStore
|
| 7 |
+
# component config classes could be used to validate configs
|
| 8 |
+
import logging
|
| 9 |
+
|
| 10 |
+
from hydra.core.config_store import ConfigStore
|
| 11 |
+
from omegaconf.errors import ValidationError
|
| 12 |
+
|
| 13 |
+
from xformers.components.attention import ATTENTION_REGISTRY
|
| 14 |
+
from xformers.components.feedforward import FEEDFORWARD_REGISTRY
|
| 15 |
+
from xformers.components.positional_embedding import POSITION_EMBEDDING_REGISTRY
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger("xformers")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def import_xformer_config_schema():
|
| 21 |
+
"""
|
| 22 |
+
Best effort - OmegaConf supports limited typing, so we may fail to import
|
| 23 |
+
certain config classes. For example, pytorch typing are not supported.
|
| 24 |
+
"""
|
| 25 |
+
cs = ConfigStore.instance()
|
| 26 |
+
|
| 27 |
+
for k, v in {
|
| 28 |
+
"ff": FEEDFORWARD_REGISTRY,
|
| 29 |
+
"pe": POSITION_EMBEDDING_REGISTRY,
|
| 30 |
+
"attention": ATTENTION_REGISTRY,
|
| 31 |
+
}.items():
|
| 32 |
+
for kk in v.keys():
|
| 33 |
+
try:
|
| 34 |
+
cs.store(name=f"{kk}_schema", node=v[kk].config, group=f"xformers/{k}")
|
| 35 |
+
except ValidationError as e:
|
| 36 |
+
logger.debug(f"Error registering {kk}_schema, error: {e}")
|
.venv/lib/python3.11/site-packages/xformers/factory/model_factory.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the BSD license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from typing import Any, Dict, List, Optional, Union
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
from xformers._deprecation_warning import deprecated_function
|
| 14 |
+
from xformers.components import reversible as rv
|
| 15 |
+
from xformers.components.residual import ResidualNormStyle, get_deepnorm_coefficients
|
| 16 |
+
from xformers.factory.block_configs import (
|
| 17 |
+
xFormerBlockConfig,
|
| 18 |
+
xFormerDecoderConfig,
|
| 19 |
+
xFormerEncoderConfig,
|
| 20 |
+
)
|
| 21 |
+
from xformers.factory.block_factory import xFormerDecoderBlock, xFormerEncoderBlock
|
| 22 |
+
from xformers.factory.weight_init import get_weight_init_fn, xFormerWeightInit
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger("xformers")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass(init=False)
|
| 28 |
+
class xFormerConfig:
|
| 29 |
+
"""
|
| 30 |
+
The configuration structure to define a full Transformer.
|
| 31 |
+
This can include a stack of encoder layers, and a stack of decoder layers.
|
| 32 |
+
|
| 33 |
+
It is optionally possible to share the embedding weights in between
|
| 34 |
+
the encoder and decoder positional encoding, as proposed for instance by
|
| 35 |
+
`Using the Output Embedding to Improve Language Models`, Press et al.
|
| 36 |
+
|
| 37 |
+
A full config example is for instance as follows:
|
| 38 |
+
|
| 39 |
+
::
|
| 40 |
+
|
| 41 |
+
xformer_config = [
|
| 42 |
+
{
|
| 43 |
+
"reversible": False, # Turn on to test the effect of using reversible layers
|
| 44 |
+
"block_type": "encoder",
|
| 45 |
+
"num_layers": LAYERS,
|
| 46 |
+
"dim_model": EMB,
|
| 47 |
+
"residual_norm_style": "pre",
|
| 48 |
+
"position_encoding_config": {
|
| 49 |
+
"name": "vocab",
|
| 50 |
+
"seq_len": CONTEXT,
|
| 51 |
+
"vocab_size": VOCAB_SIZE,
|
| 52 |
+
},
|
| 53 |
+
"multi_head_config": {
|
| 54 |
+
"num_heads": NUM_HEADS,
|
| 55 |
+
"residual_dropout": RES_DROP,
|
| 56 |
+
"use_rotary_embeddings": True,
|
| 57 |
+
"attention": {
|
| 58 |
+
"name": ATTENTION_MECHANISM_STR,
|
| 59 |
+
"dropout": ATTN_DROP,
|
| 60 |
+
"causal": True,
|
| 61 |
+
"seq_len": CONTEXT,
|
| 62 |
+
},
|
| 63 |
+
},
|
| 64 |
+
"feedforward_config": {
|
| 65 |
+
"name": "MLP",
|
| 66 |
+
"dropout": MLP_DROP,
|
| 67 |
+
"activation": "gelu",
|
| 68 |
+
"hidden_layer_multiplier": MLP_MULTIPLIER,
|
| 69 |
+
},
|
| 70 |
+
}
|
| 71 |
+
]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
.. _`Using the Output Embedding to Improve Language Models`: https://arxiv.org/pdf/1608.05859.pdf
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
stack_configs: Union[List[xFormerBlockConfig], Dict[str, xFormerBlockConfig]]
|
| 78 |
+
tie_embedding_weights: bool = False
|
| 79 |
+
weight_init: xFormerWeightInit = xFormerWeightInit.ViT
|
| 80 |
+
|
| 81 |
+
def __init__(
|
| 82 |
+
self,
|
| 83 |
+
stack_configs: Union[List[Dict[str, Any]], Dict[str, Dict[str, Any]]],
|
| 84 |
+
tie_embedding_weights: bool = False,
|
| 85 |
+
weight_init: xFormerWeightInit = xFormerWeightInit.ViT,
|
| 86 |
+
):
|
| 87 |
+
# Type all the configurations. Possible typos are caught here
|
| 88 |
+
if isinstance(stack_configs, dict):
|
| 89 |
+
self.stack_configs = {}
|
| 90 |
+
for k, config in stack_configs.items():
|
| 91 |
+
if config["block_type"] == "encoder":
|
| 92 |
+
self.stack_configs[k] = xFormerEncoderConfig(**config)
|
| 93 |
+
else:
|
| 94 |
+
self.stack_configs[k] = xFormerDecoderConfig(**config)
|
| 95 |
+
else:
|
| 96 |
+
self.stack_configs = []
|
| 97 |
+
for config in stack_configs:
|
| 98 |
+
if config["block_type"] == "encoder":
|
| 99 |
+
self.stack_configs.append(xFormerEncoderConfig(**config))
|
| 100 |
+
else:
|
| 101 |
+
self.stack_configs.append(xFormerDecoderConfig(**config))
|
| 102 |
+
|
| 103 |
+
self.tie_embedding_weights = tie_embedding_weights
|
| 104 |
+
self.weight_init = weight_init
|
| 105 |
+
deprecated_function(self)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class xFormer(torch.nn.Module):
|
| 109 |
+
def __init__(
|
| 110 |
+
self,
|
| 111 |
+
stack_configs: Union[
|
| 112 |
+
xFormerBlockConfig, List[xFormerBlockConfig], Dict[str, xFormerBlockConfig]
|
| 113 |
+
],
|
| 114 |
+
tie_embedding_weights: bool = False,
|
| 115 |
+
weight_init: xFormerWeightInit = xFormerWeightInit.ViT,
|
| 116 |
+
):
|
| 117 |
+
"""
|
| 118 |
+
Given a serialized configuration, generate the corresponding model.
|
| 119 |
+
This is only a helper and can easily be bypassed
|
| 120 |
+
"""
|
| 121 |
+
super().__init__()
|
| 122 |
+
deprecated_function(self)
|
| 123 |
+
|
| 124 |
+
if isinstance(stack_configs, Dict):
|
| 125 |
+
stack_configs = list(stack_configs.values())
|
| 126 |
+
|
| 127 |
+
# Convenience, users can pass either a list of configs or a single one
|
| 128 |
+
if not isinstance(stack_configs, List):
|
| 129 |
+
stack_configs = [stack_configs]
|
| 130 |
+
|
| 131 |
+
# Sanity checks, some config combinations do not make sense
|
| 132 |
+
self._verify_reversible(stack_configs)
|
| 133 |
+
self._verify_deepnorm(stack_configs)
|
| 134 |
+
|
| 135 |
+
encoders: List[torch.nn.Module] = []
|
| 136 |
+
decoders: List[torch.nn.Module] = []
|
| 137 |
+
|
| 138 |
+
self.reversible_encoder = False
|
| 139 |
+
self.rev_enc_pose_encoding = None
|
| 140 |
+
|
| 141 |
+
# Unroll the configs and build the model
|
| 142 |
+
for config in stack_configs:
|
| 143 |
+
# Handle either Encoder or Decoder stacks
|
| 144 |
+
builder = (
|
| 145 |
+
xFormerEncoderBlock.from_config
|
| 146 |
+
if isinstance(config, xFormerEncoderConfig)
|
| 147 |
+
else xFormerDecoderBlock.from_config
|
| 148 |
+
)
|
| 149 |
+
recipient = (
|
| 150 |
+
encoders if isinstance(config, xFormerEncoderConfig) else decoders
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
# Build up the stack
|
| 154 |
+
for i in range(config.num_layers):
|
| 155 |
+
# Label where this layer is in the stack
|
| 156 |
+
# (for instance useful for the positional encoding, or late layer norm)
|
| 157 |
+
if len(recipient) > 0:
|
| 158 |
+
config.layer_position.mark_not_first()
|
| 159 |
+
|
| 160 |
+
if config != stack_configs[-1] or i < config.num_layers - 1:
|
| 161 |
+
config.layer_position.mark_not_last()
|
| 162 |
+
|
| 163 |
+
block = builder(config) # type: ignore
|
| 164 |
+
|
| 165 |
+
# If reversible: extract the reversible sub-parts, else append the block as-is
|
| 166 |
+
if config.reversible:
|
| 167 |
+
# WARNING: only one pose encoding is saved here (not Focal Transformer compatible for instance)
|
| 168 |
+
assert isinstance(config, xFormerEncoderConfig)
|
| 169 |
+
if block.pose_encoding is not None:
|
| 170 |
+
self.rev_enc_pose_encoding = block.pose_encoding
|
| 171 |
+
self.reversible_encoder = True
|
| 172 |
+
|
| 173 |
+
f, g = xFormerEncoderBlock.get_reversible_layer(config)
|
| 174 |
+
recipient.append(torch.nn.ModuleList([f, g]))
|
| 175 |
+
else:
|
| 176 |
+
recipient.append(block) # type: ignore
|
| 177 |
+
|
| 178 |
+
# Tie embedding weights, if requested and possible
|
| 179 |
+
assert (
|
| 180 |
+
not tie_embedding_weights or not self.reversible_encoder
|
| 181 |
+
), "Reversible layers and tied embeddings is not supported for now"
|
| 182 |
+
|
| 183 |
+
if (
|
| 184 |
+
tie_embedding_weights
|
| 185 |
+
and encoders
|
| 186 |
+
and encoders[0].pose_encoding
|
| 187 |
+
and decoders
|
| 188 |
+
and decoders[0].pose_encoding
|
| 189 |
+
and not config.reversible
|
| 190 |
+
):
|
| 191 |
+
logger.info("Tying encoder and decoder embeddings, as requested")
|
| 192 |
+
encoders[0].pose_encoding = decoders[0].pose_encoding
|
| 193 |
+
|
| 194 |
+
self.encoders: torch.nn.Module = (
|
| 195 |
+
rv.ReversibleSequence(torch.nn.ModuleList(encoders))
|
| 196 |
+
if self.reversible_encoder
|
| 197 |
+
else torch.nn.ModuleList(encoders)
|
| 198 |
+
)
|
| 199 |
+
self.decoders = torch.nn.ModuleList(decoders)
|
| 200 |
+
|
| 201 |
+
use_deepnorm = (
|
| 202 |
+
stack_configs[0].residual_norm_style == ResidualNormStyle.DeepNorm
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
assert (
|
| 206 |
+
not use_deepnorm or not self.reversible_encoder
|
| 207 |
+
), "Reversible layers and deepnorm is not supported for now"
|
| 208 |
+
|
| 209 |
+
self.init_weights(weight_init=weight_init, use_deep_norm=use_deepnorm)
|
| 210 |
+
|
| 211 |
+
@classmethod
|
| 212 |
+
def from_config(cls, config: xFormerConfig):
|
| 213 |
+
return cls(
|
| 214 |
+
config.stack_configs, config.tie_embedding_weights, config.weight_init
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
def _verify_reversible(self, stack_configs: List[xFormerBlockConfig]):
|
| 218 |
+
reversible = [
|
| 219 |
+
c.reversible
|
| 220 |
+
for c in filter(lambda x: x.block_type == "encoder", stack_configs)
|
| 221 |
+
]
|
| 222 |
+
|
| 223 |
+
assert all(reversible) or not any(reversible), (
|
| 224 |
+
"All layers need to have the same reversibility setting. "
|
| 225 |
+
+ f"Currently {reversible}"
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
def _verify_deepnorm(self, stack_configs: List[xFormerBlockConfig]):
|
| 229 |
+
deepnorm = [
|
| 230 |
+
c.residual_norm_style == ResidualNormStyle.DeepNorm for c in stack_configs
|
| 231 |
+
]
|
| 232 |
+
|
| 233 |
+
assert all(deepnorm) or not any(deepnorm), (
|
| 234 |
+
"All layers need to have the same deepnorm setting. "
|
| 235 |
+
+ f"Currently {deepnorm}"
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
def init_weights(self, weight_init: xFormerWeightInit, use_deep_norm: bool):
|
| 239 |
+
# The deepnorm weight initialization method requires different gain factors for the encoder
|
| 240 |
+
# and decoder, depending on the general model structure (number of respective layers)
|
| 241 |
+
if use_deep_norm:
|
| 242 |
+
encoder_coefficients, decoder_coefficients = get_deepnorm_coefficients(
|
| 243 |
+
encoder_layers=len(self.encoders), decoder_layers=len(self.decoders) # type: ignore
|
| 244 |
+
)
|
| 245 |
+
else:
|
| 246 |
+
encoder_coefficients, decoder_coefficients = None, None
|
| 247 |
+
|
| 248 |
+
encoder_gain = (
|
| 249 |
+
encoder_coefficients.beta if encoder_coefficients is not None else 1.0
|
| 250 |
+
)
|
| 251 |
+
decoder_gain = (
|
| 252 |
+
decoder_coefficients.beta if decoder_coefficients is not None else 1.0
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
# Pick the desired init function
|
| 256 |
+
init_fn = get_weight_init_fn(weight_init)
|
| 257 |
+
|
| 258 |
+
# Initialize all the encoder weights
|
| 259 |
+
for name, module in self.encoders.named_children():
|
| 260 |
+
init_fn(module=module, name=name, gain=encoder_gain)
|
| 261 |
+
|
| 262 |
+
for name, module in self.decoders.named_children():
|
| 263 |
+
init_fn(module=module, name=name, gain=decoder_gain)
|
| 264 |
+
|
| 265 |
+
def forward(
|
| 266 |
+
self,
|
| 267 |
+
src: torch.Tensor,
|
| 268 |
+
tgt: Optional[torch.Tensor] = None,
|
| 269 |
+
encoder_input_mask: Optional[torch.Tensor] = None,
|
| 270 |
+
decoder_input_mask: Optional[torch.Tensor] = None,
|
| 271 |
+
) -> Optional[torch.Tensor]:
|
| 272 |
+
|
| 273 |
+
# Encode to latent space if encoder is present
|
| 274 |
+
if len(list(self.encoders.parameters())) > 0:
|
| 275 |
+
encoders = self.encoders
|
| 276 |
+
memory = src.clone()
|
| 277 |
+
if isinstance(encoders, torch.nn.ModuleList):
|
| 278 |
+
for encoder in encoders:
|
| 279 |
+
memory = encoder(memory, input_mask=encoder_input_mask)
|
| 280 |
+
else:
|
| 281 |
+
if self.rev_enc_pose_encoding:
|
| 282 |
+
memory = self.rev_enc_pose_encoding(src)
|
| 283 |
+
|
| 284 |
+
# Reversible Encoder
|
| 285 |
+
x = torch.cat([memory, memory], dim=-1)
|
| 286 |
+
|
| 287 |
+
# Apply the optional input masking
|
| 288 |
+
if encoder_input_mask is not None:
|
| 289 |
+
if x.dim() - encoder_input_mask.dim() > 1:
|
| 290 |
+
encoder_input_mask.unsqueeze(0)
|
| 291 |
+
x += encoder_input_mask.unsqueeze(-1)
|
| 292 |
+
|
| 293 |
+
x = encoders(x)
|
| 294 |
+
memory = torch.stack(x.chunk(2, dim=-1)).mean(dim=0)
|
| 295 |
+
|
| 296 |
+
if not self.decoders:
|
| 297 |
+
return memory
|
| 298 |
+
|
| 299 |
+
# If decoder: either use the encoder output, or just decode, both options are possible
|
| 300 |
+
if len(self.decoders) > 0:
|
| 301 |
+
tgt = src.clone() if tgt is None else tgt
|
| 302 |
+
|
| 303 |
+
for decoder in self.decoders:
|
| 304 |
+
tgt = decoder(
|
| 305 |
+
target=tgt,
|
| 306 |
+
# pyre-fixme[61]: `memory` is not always initialized here.
|
| 307 |
+
memory=memory,
|
| 308 |
+
input_mask=decoder_input_mask,
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
return tgt
|
| 312 |
+
|
| 313 |
+
return None
|
.venv/lib/python3.11/site-packages/xformers/factory/weight_init.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the BSD license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# CREDITS: Reusing a lot of code from the Timm repo
|
| 7 |
+
# main difference is probably the handling of deepnorm init, and adapting to some xformers specificities
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import math
|
| 12 |
+
from enum import Enum
|
| 13 |
+
from typing import Callable
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
from torch.nn.init import (
|
| 18 |
+
_calculate_fan_in_and_fan_out,
|
| 19 |
+
_no_grad_trunc_normal_,
|
| 20 |
+
_no_grad_uniform_,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger("xformers")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
_assert_if_not_initialized = False
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class xFormerWeightInit(str, Enum):
|
| 30 |
+
Timm = "timm"
|
| 31 |
+
ViT = "vit"
|
| 32 |
+
Moco = "moco"
|
| 33 |
+
Small = "small"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def get_weight_init_fn(init_choice: xFormerWeightInit):
|
| 37 |
+
"""
|
| 38 |
+
Provide the xFormers factory with weight init routines.
|
| 39 |
+
|
| 40 |
+
Supported initializations are:
|
| 41 |
+
- Small: follow the method outlined in `Transformer Without Tears`_
|
| 42 |
+
- ViT: follow the initialization in the reference ViT_ codebase
|
| 43 |
+
- Timm: follow the initialization in the reference Timm_ codebase
|
| 44 |
+
- Moco: follow the initialization in the reference MocoV3_ codebase
|
| 45 |
+
|
| 46 |
+
.. _ViT: https://github.com/google-research/vision_transformer
|
| 47 |
+
.. _Timm: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
| 48 |
+
.. _MocoV3: https://github.com/facebookresearch/moco-v3
|
| 49 |
+
"""
|
| 50 |
+
return {
|
| 51 |
+
xFormerWeightInit.Timm: _init_weights_vit_timm,
|
| 52 |
+
xFormerWeightInit.ViT: _init_weights_vit_jax,
|
| 53 |
+
xFormerWeightInit.Moco: _init_weights_vit_moco,
|
| 54 |
+
xFormerWeightInit.Small: _init_weights_small,
|
| 55 |
+
}[init_choice]
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# Define pattern matches
|
| 59 |
+
def is_ffn(n):
|
| 60 |
+
return "feedforward" in n or ("wrap_ff" in n and not n.endswith("norm"))
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def is_mha_input_projection(n):
|
| 64 |
+
return "q_proj" in n or "k_proj" in n or "v_proj" in n
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# Define distribution helpers
|
| 68 |
+
def _small_init_(tensor: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
|
| 69 |
+
r"""Fills the input `Tensor` with values according to the method
|
| 70 |
+
described in `Transformer Without Tears`_, using a uniform distribution.
|
| 71 |
+
|
| 72 |
+
This is a variation of the Xavier init. The resulting tensor will have values sampled from
|
| 73 |
+
:math:`\mathcal{U}(-a, a)` where
|
| 74 |
+
|
| 75 |
+
.. math::
|
| 76 |
+
a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + 4 * \text{fan\_out}}}
|
| 77 |
+
|
| 78 |
+
Also known as Glorot initialization.
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
tensor: an n-dimensional `torch.Tensor`
|
| 82 |
+
gain: an optional scaling factor
|
| 83 |
+
|
| 84 |
+
.. _`Transformer Without Tears`: https://arxiv.org/abs/1910.05895
|
| 85 |
+
|
| 86 |
+
"""
|
| 87 |
+
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
|
| 88 |
+
std = gain * math.sqrt(2.0 / float(fan_in + 4 * fan_out))
|
| 89 |
+
a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
|
| 90 |
+
|
| 91 |
+
return _no_grad_uniform_(tensor, -a, a)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def _lecun_normal(tensor, gain=1.0):
|
| 95 |
+
fan_in, _ = _calculate_fan_in_and_fan_out(tensor)
|
| 96 |
+
denom = fan_in
|
| 97 |
+
variance = gain / denom
|
| 98 |
+
|
| 99 |
+
# constant is stddev of standard normal truncated to (-2, 2)
|
| 100 |
+
_no_grad_trunc_normal_(
|
| 101 |
+
tensor,
|
| 102 |
+
mean=0.0,
|
| 103 |
+
std=math.sqrt(variance) / 0.87962566103423978,
|
| 104 |
+
a=-2.0,
|
| 105 |
+
b=2.0,
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# Helpers to keep all the functions typesafe, and handle corner cases and common behaviours in one place
|
| 110 |
+
def _maybe_init_tensor(module: nn.Module, attr: str, distribution_: Callable, **kwargs):
|
| 111 |
+
# Small helper to catch all the corner cases, while staying type safe
|
| 112 |
+
if hasattr(module, attr):
|
| 113 |
+
maybe_tensor = getattr(module, attr)
|
| 114 |
+
if maybe_tensor is not None and isinstance(maybe_tensor, torch.Tensor):
|
| 115 |
+
distribution_(maybe_tensor, **kwargs)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def _maybe_report_no_init(module, name):
|
| 119 |
+
if len(list(module.named_children())) == 0 and (
|
| 120 |
+
hasattr(module, "weight") or hasattr(module, "bias")
|
| 121 |
+
):
|
| 122 |
+
# Skip layer norm, this is ok
|
| 123 |
+
if isinstance(module, torch.nn.LayerNorm):
|
| 124 |
+
return
|
| 125 |
+
|
| 126 |
+
# Skip nn.Embedding, we typically initialize it one level up, else Pytorch has a valid default
|
| 127 |
+
if isinstance(module, torch.nn.Embedding):
|
| 128 |
+
return
|
| 129 |
+
|
| 130 |
+
# This is unexpected, warn about a possible unhandled weight
|
| 131 |
+
logger.warning(
|
| 132 |
+
f"Not initializing weights in {name}, this could be a mistake.\nModule {module}"
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
if _assert_if_not_initialized:
|
| 136 |
+
assert False, (
|
| 137 |
+
f"Uninitialized weight found in {module}."
|
| 138 |
+
+ " If you have a custom module, please provide a `init_weights()` method"
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# Define the different initialization schemes
|
| 143 |
+
def _init_weights_vit_jax(
|
| 144 |
+
module: nn.Module,
|
| 145 |
+
name: str = "",
|
| 146 |
+
head_bias: float = 0.0,
|
| 147 |
+
gain: float = 1.0,
|
| 148 |
+
deepnorm_style: bool = False,
|
| 149 |
+
**kwargs,
|
| 150 |
+
):
|
| 151 |
+
"""ViT weight initialization, matching JAX (Flax) impl"""
|
| 152 |
+
|
| 153 |
+
if is_ffn(name):
|
| 154 |
+
_maybe_init_tensor(module, "bias", nn.init.normal_, std=1e-6)
|
| 155 |
+
_maybe_init_tensor(module, "weight", torch.nn.init.xavier_uniform_, gain=gain)
|
| 156 |
+
|
| 157 |
+
elif is_mha_input_projection(name) or isinstance(module, nn.Linear):
|
| 158 |
+
if deepnorm_style and (
|
| 159 |
+
"q_proj" in name.split(".") or "k_proj" in name.split(".")
|
| 160 |
+
):
|
| 161 |
+
gain = 1.0
|
| 162 |
+
|
| 163 |
+
_maybe_init_tensor(module, "weight", torch.nn.init.xavier_uniform_, gain=gain)
|
| 164 |
+
_maybe_init_tensor(module, "bias", nn.init.zeros_)
|
| 165 |
+
|
| 166 |
+
elif isinstance(module, nn.Conv2d):
|
| 167 |
+
_maybe_init_tensor(module, "weight", _lecun_normal, gain=gain)
|
| 168 |
+
_maybe_init_tensor(module, "bias", nn.init.zeros_)
|
| 169 |
+
|
| 170 |
+
elif hasattr(module, "init_weights"):
|
| 171 |
+
module.init_weights() # type: ignore
|
| 172 |
+
|
| 173 |
+
else:
|
| 174 |
+
_maybe_report_no_init(module, name)
|
| 175 |
+
|
| 176 |
+
# Recurse over the children, if the weight init is being handled here
|
| 177 |
+
if not hasattr(module, "init_weights"):
|
| 178 |
+
for child_name, child_module in module.named_children():
|
| 179 |
+
_init_weights_vit_jax(child_module, f"{name}.{child_name}", head_bias, gain)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def _init_weights_vit_moco(
|
| 183 |
+
module: nn.Module,
|
| 184 |
+
name: str = "",
|
| 185 |
+
gain: float = 1.0,
|
| 186 |
+
**kwargs,
|
| 187 |
+
):
|
| 188 |
+
"""ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed"""
|
| 189 |
+
|
| 190 |
+
assert (
|
| 191 |
+
"deepnorm_style" not in kwargs.keys()
|
| 192 |
+
), "This initialization method does not support deepnorm"
|
| 193 |
+
|
| 194 |
+
if is_ffn(name):
|
| 195 |
+
_maybe_init_tensor(module, "weight", torch.nn.init.xavier_uniform_, gain=gain)
|
| 196 |
+
_maybe_init_tensor(module, "bias", nn.init.zeros_)
|
| 197 |
+
|
| 198 |
+
elif is_mha_input_projection(name) or isinstance(module, nn.Linear):
|
| 199 |
+
if isinstance(module.weight, torch.Tensor):
|
| 200 |
+
val = (
|
| 201 |
+
math.sqrt(6.0 / float(module.weight.shape[0] + module.weight.shape[1]))
|
| 202 |
+
* gain
|
| 203 |
+
)
|
| 204 |
+
_maybe_init_tensor(module, "weight", nn.init.uniform_, a=-val, b=val)
|
| 205 |
+
|
| 206 |
+
_maybe_init_tensor(module, "bias", nn.init.zeros_)
|
| 207 |
+
|
| 208 |
+
elif hasattr(module, "init_weights"):
|
| 209 |
+
module.init_weights(gain=gain) # type: ignore
|
| 210 |
+
|
| 211 |
+
else:
|
| 212 |
+
_maybe_report_no_init(module, name)
|
| 213 |
+
|
| 214 |
+
# Recurse over the children, if the weight init is being handled here
|
| 215 |
+
if not hasattr(module, "init_weights"):
|
| 216 |
+
for child_name, child_module in module.named_children():
|
| 217 |
+
_init_weights_vit_moco(child_module, child_name, gain)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def _init_weights_small(
    module: nn.Module,
    name: str = "",
    head_bias: float = 0.0,
    gain: float = 1.0,
    deepnorm_style: bool = False,
    **kwargs,
):
    """Follow the `Transformer Without Tears`_ initialization for self-attention.

    Args:
        module: the (sub)module whose parameters are initialized in place.
        name: dotted, fully-qualified name of ``module`` within the model.
        head_bias: kept for signature parity with the other init methods;
            not used by this scheme.
        gain: multiplicative scale applied to the initialization.
        deepnorm_style: if True, q/k projections keep gain 1.0 per DeepNet.
    """
    if is_ffn(name):
        _maybe_init_tensor(module, "weight", torch.nn.init.xavier_uniform_, gain=gain)
        _maybe_init_tensor(module, "bias", nn.init.normal_, std=1e-6)

    elif is_mha_input_projection(name) or isinstance(module, nn.Linear):
        # "small init" only scales the attention layers init, not the FFN
        if deepnorm_style and (
            "q_proj" in name.split(".") or "k_proj" in name.split(".")
        ):
            # DeepNorm keeps query/key projections at unit gain.
            gain = 1.0

        _maybe_init_tensor(module, "weight", _small_init_, gain=gain)
        _maybe_init_tensor(module, "bias", nn.init.zeros_)

    elif isinstance(module, nn.Conv2d):
        # Patch-embedding style convolutions get LeCun-normal init.
        _maybe_init_tensor(module, "weight", _lecun_normal)
        _maybe_init_tensor(module, "bias", nn.init.zeros_)
    elif hasattr(module, "init_weights"):
        module.init_weights()  # type: ignore
    else:
        _maybe_report_no_init(module, name)

    # Recurse over the children, if the weight init is being handled here
    if not hasattr(module, "init_weights"):
        for child_name, child_module in module.named_children():
            # FIX: forward deepnorm_style (previously dropped), so the DeepNorm
            # q/k gain override actually applies to nested submodules.
            _init_weights_small(
                child_module, f"{name}.{child_name}", head_bias, gain, deepnorm_style
            )
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def _init_weights_vit_timm(
    module: nn.Module,
    name: str = "",
    gain: float = 1.0,
    deepnorm_style: bool = False,
    **kwargs,
):
    """
    ViT weight initialization, original timm impl (for reproducibility).

    See DeepNet_ for all the DeepNorm specific codepaths

    Args:
        module: the (sub)module whose parameters are initialized in place.
        name: dotted, fully-qualified name of ``module`` within the model.
        gain: multiplicative scale applied to the truncated-normal std.
        deepnorm_style: if True, q/k projections keep gain 1 per DeepNet.
    """
    if isinstance(module, nn.Linear):
        if deepnorm_style and (
            "q_proj" in name.split(".") or "k_proj" in name.split(".")
        ):
            # DeepNorm keeps query/key projections at unit gain.
            gain = 1

        std = 0.02 * gain
        a = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation

        _maybe_init_tensor(
            module, "weight", _no_grad_trunc_normal_, mean=0.0, std=std, a=-a, b=a
        )
        _maybe_init_tensor(module, "bias", nn.init.zeros_)

    elif hasattr(module, "init_weights"):
        module.init_weights(gain=gain)  # type: ignore
    else:
        _maybe_report_no_init(module, name)

    # Recurse over the children, if the weight init is being handled here
    if not hasattr(module, "init_weights"):
        for child_name, child_module in module.named_children():
            # FIX: forward deepnorm_style (previously dropped, making the q/k
            # override dead below the top level) and qualify the child name
            # with the parent path, as the sibling init functions do.
            _init_weights_vit_timm(
                child_module, f"{name}.{child_name}", gain, deepnorm_style
            )
|
.venv/lib/python3.11/site-packages/xformers/info.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the BSD license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
from typing import Dict
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from . import __version__, _cpp_lib, _is_opensource, _is_triton_available, ops
|
| 12 |
+
from .ops.common import OPERATORS_REGISTRY
|
| 13 |
+
from .profiler.profiler_dcgm import DCGM_PROFILER_AVAILABLE
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def get_features_status() -> Dict[str, str]:
    """Collect the availability status of xFormers features.

    Returns a mapping from feature identifier (e.g. ``"<category>.<name>"``,
    ``"swiglu.<key>"``, ``"is_triton_available"``) to its status string.
    """
    # Registered operators: available/unavailable depending on the build.
    status: Dict[str, str] = {
        f"{op.OPERATOR_CATEGORY}.{op.NAME}": (
            "available" if op.is_available() else "unavailable"
        )
        for op in OPERATORS_REGISTRY
    }
    # SwiGLU reports its own per-key info dict.
    for key, value in ops.swiglu_op._info().items():
        status[f"swiglu.{key}"] = value
    # Triton availability, stringified for uniform display.
    status["is_triton_available"] = str(_is_triton_available())
    return status
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def print_info():
    """Print the xFormers version followed by a feature/build status report."""
    info = get_features_status()
    print(f"xFormers {__version__}")

    # Runtime environment: PyTorch and (optionally) the active CUDA device.
    info["pytorch.version"] = torch.__version__
    if not torch.cuda.is_available():
        info["pytorch.cuda"] = "not available"
    else:
        info["pytorch.cuda"] = "available"
        dev = torch.cuda.current_device()
        capability = torch.cuda.get_device_capability(dev)
        info["gpu.compute_capability"] = ".".join(str(ver) for ver in capability)
        info["gpu.name"] = torch.cuda.get_device_name(dev)

    info["dcgm_profiler"] = "available" if DCGM_PROFILER_AVAILABLE else "unavailable"

    # Build metadata: prefer the loaded library's metadata, falling back to
    # whatever an invalid-library load exception captured.
    metadata = _cpp_lib._build_metadata
    if metadata is None and isinstance(
        _cpp_lib._cpp_library_load_exception, _cpp_lib.xFormersInvalidLibException
    ):
        metadata = _cpp_lib._cpp_library_load_exception.build_info
    if metadata is None:
        info["build.info"] = "none"
    else:
        info["build.info"] = "available"
        info["build.cuda_version"] = metadata.cuda_version
        info["build.hip_version"] = metadata.hip_version
        info["build.python_version"] = metadata.python_version
        info["build.torch_version"] = metadata.torch_version
        for env_key, env_val in metadata.build_env.items():
            info[f"build.env.{env_key}"] = env_val

    # The nvcc build version op only exists on some builds; skip it quietly.
    try:
        nvcc = torch.ops.xformers._nvcc_build_version()
        info["build.nvcc_version"] = ".".join(str(v) for v in nvcc)
    except (RuntimeError, AttributeError):
        pass

    info["source.privacy"] = "open source" if _is_opensource else "fairinternal"

    # Align statuses in a 50-character left-justified column.
    for name, status in info.items():
        print(f"{name + ':':<50} {status}")
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Allow direct execution (e.g. `python -m xformers.info`) to dump the report.
if __name__ == "__main__":
    print_info()
|
.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/common.cpython-311.pyc
ADDED
|
Binary file (3.08 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/differentiable_collectives.cpython-311.pyc
ADDED
|
Binary file (8.61 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/indexing.cpython-311.pyc
ADDED
|
Binary file (8.59 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/ipc.cpython-311.pyc
ADDED
|
Binary file (7.89 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/modpar_layers.cpython-311.pyc
ADDED
|
Binary file (8.76 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/rmsnorm.cpython-311.pyc
ADDED
|
Binary file (4.78 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/rope_padded.cpython-311.pyc
ADDED
|
Binary file (12.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/seqpar.cpython-311.pyc
ADDED
|
Binary file (19.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/sequence_parallel_fused_ops.cpython-311.pyc
ADDED
|
Binary file (53.7 kB). View file
|
|
|