# diff --git a/.venv/lib/python3.11/site-packages/OpenSSL/SSL.py b/.venv/lib/python3.11/site-packages/OpenSSL/SSL.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..4ca59503da669086f504f307b9db3080ecd6eb0d
# --- /dev/null
# +++ b/.venv/lib/python3.11/site-packages/OpenSSL/SSL.py
# @@ -0,0 +1,2961 @@
import os
import socket
import typing
from errno import errorcode
from functools import partial, wraps
from itertools import chain, count
from sys import platform
from typing import Any, Callable, List, Optional, Sequence, Tuple, TypeVar
from weakref import WeakValueDictionary

from OpenSSL._util import (
    StrOrBytesPath as _StrOrBytesPath,
)
from OpenSSL._util import (
    exception_from_error_queue as _exception_from_error_queue,
)
from OpenSSL._util import (
    ffi as _ffi,
)
from OpenSSL._util import (
    lib as _lib,
)
from OpenSSL._util import (
    make_assert as _make_assert,
)
from OpenSSL._util import (
    no_zero_allocator as _no_zero_allocator,
)
from OpenSSL._util import (
    path_bytes as _path_bytes,
)
from OpenSSL._util import (
    text_to_bytes_and_warn as _text_to_bytes_and_warn,
)
from OpenSSL.crypto import (
    FILETYPE_PEM,
    X509,
    PKey,
    X509Name,
    X509Store,
    _EllipticCurve,
    _PassphraseHelper,
)

# Public names of this module.  Constants that only exist in some OpenSSL
# builds are appended further down, next to their guarded definitions.
__all__ = [
    "OPENSSL_VERSION_NUMBER",
    "SSLEAY_VERSION",
    "SSLEAY_CFLAGS",
    "SSLEAY_PLATFORM",
    "SSLEAY_DIR",
    "SSLEAY_BUILT_ON",
    "OPENSSL_VERSION",
    "OPENSSL_CFLAGS",
    "OPENSSL_PLATFORM",
    "OPENSSL_DIR",
    "OPENSSL_BUILT_ON",
    "SENT_SHUTDOWN",
    "RECEIVED_SHUTDOWN",
    "SSLv23_METHOD",
    "TLSv1_METHOD",
    "TLSv1_1_METHOD",
    "TLSv1_2_METHOD",
    "TLS_METHOD",
    "TLS_SERVER_METHOD",
    "TLS_CLIENT_METHOD",
    "DTLS_METHOD",
    "DTLS_SERVER_METHOD",
    "DTLS_CLIENT_METHOD",
    "SSL3_VERSION",
    "TLS1_VERSION",
    "TLS1_1_VERSION",
    "TLS1_2_VERSION",
    "TLS1_3_VERSION",
    "OP_NO_SSLv2",
    "OP_NO_SSLv3",
    "OP_NO_TLSv1",
    "OP_NO_TLSv1_1",
    "OP_NO_TLSv1_2",
    "MODE_RELEASE_BUFFERS",
    "OP_SINGLE_DH_USE",
    "OP_SINGLE_ECDH_USE",
    "OP_EPHEMERAL_RSA",
    "OP_MICROSOFT_SESS_ID_BUG",
    "OP_NETSCAPE_CHALLENGE_BUG",
    "OP_NETSCAPE_REUSE_CIPHER_CHANGE_BUG",
    "OP_SSLREF2_REUSE_CERT_TYPE_BUG",
    "OP_MICROSOFT_BIG_SSLV3_BUFFER",
    "OP_MSIE_SSLV2_RSA_PADDING",
    "OP_SSLEAY_080_CLIENT_DH_BUG",
    "OP_TLS_D5_BUG",
    "OP_TLS_BLOCK_PADDING_BUG",
    "OP_DONT_INSERT_EMPTY_FRAGMENTS",
    "OP_CIPHER_SERVER_PREFERENCE",
    "OP_TLS_ROLLBACK_BUG",
    "OP_PKCS1_CHECK_1",
    "OP_PKCS1_CHECK_2",
    "OP_NETSCAPE_CA_DN_BUG",
    "OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG",
    "OP_NO_COMPRESSION",
    "OP_NO_QUERY_MTU",
    "OP_COOKIE_EXCHANGE",
    "OP_NO_TICKET",
    "OP_ALL",
    "VERIFY_PEER",
    "VERIFY_FAIL_IF_NO_PEER_CERT",
    "VERIFY_CLIENT_ONCE",
    "VERIFY_NONE",
    "SESS_CACHE_OFF",
    "SESS_CACHE_CLIENT",
    "SESS_CACHE_SERVER",
    "SESS_CACHE_BOTH",
    "SESS_CACHE_NO_AUTO_CLEAR",
    "SESS_CACHE_NO_INTERNAL_LOOKUP",
    "SESS_CACHE_NO_INTERNAL_STORE",
    "SESS_CACHE_NO_INTERNAL",
    "SSL_ST_CONNECT",
    "SSL_ST_ACCEPT",
    "SSL_ST_MASK",
    "SSL_CB_LOOP",
    "SSL_CB_EXIT",
    "SSL_CB_READ",
    "SSL_CB_WRITE",
    "SSL_CB_ALERT",
    "SSL_CB_READ_ALERT",
    "SSL_CB_WRITE_ALERT",
    "SSL_CB_ACCEPT_LOOP",
    "SSL_CB_ACCEPT_EXIT",
    "SSL_CB_CONNECT_LOOP",
    "SSL_CB_CONNECT_EXIT",
    "SSL_CB_HANDSHAKE_START",
    "SSL_CB_HANDSHAKE_DONE",
    "Error",
    "WantReadError",
    "WantWriteError",
    "WantX509LookupError",
    "ZeroReturnError",
    "SysCallError",
    "NO_OVERLAPPING_PROTOCOLS",
    # Export the canonical name alongside its legacy alias.
    "OpenSSL_version",
    "SSLeay_version",
    "Session",
    "Context",
    "Connection",
    "X509VerificationCodes",
]


# Selector constants accepted by OpenSSL_version().
OPENSSL_VERSION_NUMBER: int = _lib.OPENSSL_VERSION_NUMBER
OPENSSL_VERSION: int = _lib.OPENSSL_VERSION
OPENSSL_CFLAGS: int = _lib.OPENSSL_CFLAGS
OPENSSL_PLATFORM: int = _lib.OPENSSL_PLATFORM
OPENSSL_DIR: int = _lib.OPENSSL_DIR
OPENSSL_BUILT_ON: int = _lib.OPENSSL_BUILT_ON

# Legacy SSLEAY_* spellings are plain aliases of the OPENSSL_* constants.
SSLEAY_VERSION = OPENSSL_VERSION
SSLEAY_CFLAGS = OPENSSL_CFLAGS
SSLEAY_PLATFORM = OPENSSL_PLATFORM
SSLEAY_DIR = OPENSSL_DIR
SSLEAY_BUILT_ON = OPENSSL_BUILT_ON

SENT_SHUTDOWN = _lib.SSL_SENT_SHUTDOWN
RECEIVED_SHUTDOWN = _lib.SSL_RECEIVED_SHUTDOWN

# Method identifiers; these are keys into Context._methods, not OpenSSL
# values.
SSLv23_METHOD = 3
TLSv1_METHOD = 4
TLSv1_1_METHOD = 5
TLSv1_2_METHOD = 6
TLS_METHOD = 7
TLS_SERVER_METHOD = 8
TLS_CLIENT_METHOD = 9
DTLS_METHOD = 10
DTLS_SERVER_METHOD = 11
DTLS_CLIENT_METHOD = 12

SSL3_VERSION: int = _lib.SSL3_VERSION
TLS1_VERSION: int = _lib.TLS1_VERSION
TLS1_1_VERSION: int = _lib.TLS1_1_VERSION
TLS1_2_VERSION: int = _lib.TLS1_2_VERSION
TLS1_3_VERSION: int = _lib.TLS1_3_VERSION

OP_NO_SSLv2: int = _lib.SSL_OP_NO_SSLv2
OP_NO_SSLv3: int = _lib.SSL_OP_NO_SSLv3
OP_NO_TLSv1: int = _lib.SSL_OP_NO_TLSv1
OP_NO_TLSv1_1: int = _lib.SSL_OP_NO_TLSv1_1
OP_NO_TLSv1_2: int = _lib.SSL_OP_NO_TLSv1_2
# Only defined (and exported) when the binding exposes the constant.
try:
    OP_NO_TLSv1_3: int = _lib.SSL_OP_NO_TLSv1_3
    __all__.append("OP_NO_TLSv1_3")
except AttributeError:
    pass

MODE_RELEASE_BUFFERS: int = _lib.SSL_MODE_RELEASE_BUFFERS

OP_SINGLE_DH_USE: int = _lib.SSL_OP_SINGLE_DH_USE
OP_SINGLE_ECDH_USE: int = _lib.SSL_OP_SINGLE_ECDH_USE
OP_EPHEMERAL_RSA: int = _lib.SSL_OP_EPHEMERAL_RSA
OP_MICROSOFT_SESS_ID_BUG: int = _lib.SSL_OP_MICROSOFT_SESS_ID_BUG
OP_NETSCAPE_CHALLENGE_BUG: int = _lib.SSL_OP_NETSCAPE_CHALLENGE_BUG
OP_NETSCAPE_REUSE_CIPHER_CHANGE_BUG: int = (
    _lib.SSL_OP_NETSCAPE_REUSE_CIPHER_CHANGE_BUG
)
OP_SSLREF2_REUSE_CERT_TYPE_BUG: int = _lib.SSL_OP_SSLREF2_REUSE_CERT_TYPE_BUG
OP_MICROSOFT_BIG_SSLV3_BUFFER: int = _lib.SSL_OP_MICROSOFT_BIG_SSLV3_BUFFER
OP_MSIE_SSLV2_RSA_PADDING: int = _lib.SSL_OP_MSIE_SSLV2_RSA_PADDING
OP_SSLEAY_080_CLIENT_DH_BUG: int = _lib.SSL_OP_SSLEAY_080_CLIENT_DH_BUG
OP_TLS_D5_BUG: int = _lib.SSL_OP_TLS_D5_BUG
OP_TLS_BLOCK_PADDING_BUG: int = _lib.SSL_OP_TLS_BLOCK_PADDING_BUG
OP_DONT_INSERT_EMPTY_FRAGMENTS: int = _lib.SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS
OP_CIPHER_SERVER_PREFERENCE: int = _lib.SSL_OP_CIPHER_SERVER_PREFERENCE
OP_TLS_ROLLBACK_BUG: int = _lib.SSL_OP_TLS_ROLLBACK_BUG
OP_PKCS1_CHECK_1: int = _lib.SSL_OP_PKCS1_CHECK_1
OP_PKCS1_CHECK_2: int = _lib.SSL_OP_PKCS1_CHECK_2
# More SSL_OP_* option flags re-exported from the binding.
OP_NETSCAPE_CA_DN_BUG: int = _lib.SSL_OP_NETSCAPE_CA_DN_BUG
OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG: int = (
    _lib.SSL_OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG
)
OP_NO_COMPRESSION: int = _lib.SSL_OP_NO_COMPRESSION

OP_NO_QUERY_MTU: int = _lib.SSL_OP_NO_QUERY_MTU
OP_COOKIE_EXCHANGE: int = _lib.SSL_OP_COOKIE_EXCHANGE
OP_NO_TICKET: int = _lib.SSL_OP_NO_TICKET

# The next three flags are defined, and added to the public API, only when
# the binding actually exposes them.
if hasattr(_lib, "SSL_OP_NO_RENEGOTIATION"):
    OP_NO_RENEGOTIATION: int = _lib.SSL_OP_NO_RENEGOTIATION
    __all__.append("OP_NO_RENEGOTIATION")

if hasattr(_lib, "SSL_OP_IGNORE_UNEXPECTED_EOF"):
    OP_IGNORE_UNEXPECTED_EOF: int = _lib.SSL_OP_IGNORE_UNEXPECTED_EOF
    __all__.append("OP_IGNORE_UNEXPECTED_EOF")

if hasattr(_lib, "SSL_OP_LEGACY_SERVER_CONNECT"):
    OP_LEGACY_SERVER_CONNECT: int = _lib.SSL_OP_LEGACY_SERVER_CONNECT
    __all__.append("OP_LEGACY_SERVER_CONNECT")

OP_ALL: int = _lib.SSL_OP_ALL

# Peer verification modes.
VERIFY_PEER: int = _lib.SSL_VERIFY_PEER
VERIFY_FAIL_IF_NO_PEER_CERT: int = _lib.SSL_VERIFY_FAIL_IF_NO_PEER_CERT
VERIFY_CLIENT_ONCE: int = _lib.SSL_VERIFY_CLIENT_ONCE
VERIFY_NONE: int = _lib.SSL_VERIFY_NONE

# Session cache modes.
SESS_CACHE_OFF: int = _lib.SSL_SESS_CACHE_OFF
SESS_CACHE_CLIENT: int = _lib.SSL_SESS_CACHE_CLIENT
SESS_CACHE_SERVER: int = _lib.SSL_SESS_CACHE_SERVER
SESS_CACHE_BOTH: int = _lib.SSL_SESS_CACHE_BOTH
SESS_CACHE_NO_AUTO_CLEAR: int = _lib.SSL_SESS_CACHE_NO_AUTO_CLEAR
SESS_CACHE_NO_INTERNAL_LOOKUP: int = _lib.SSL_SESS_CACHE_NO_INTERNAL_LOOKUP
SESS_CACHE_NO_INTERNAL_STORE: int = _lib.SSL_SESS_CACHE_NO_INTERNAL_STORE
SESS_CACHE_NO_INTERNAL: int = _lib.SSL_SESS_CACHE_NO_INTERNAL

# Handshake state masks.
SSL_ST_CONNECT: int = _lib.SSL_ST_CONNECT
SSL_ST_ACCEPT: int = _lib.SSL_ST_ACCEPT
SSL_ST_MASK: int = _lib.SSL_ST_MASK

# Info-callback event flags.
SSL_CB_LOOP: int = _lib.SSL_CB_LOOP
SSL_CB_EXIT: int = _lib.SSL_CB_EXIT
SSL_CB_READ: int = _lib.SSL_CB_READ
SSL_CB_WRITE: int = _lib.SSL_CB_WRITE
SSL_CB_ALERT: int = _lib.SSL_CB_ALERT
SSL_CB_READ_ALERT: int = _lib.SSL_CB_READ_ALERT
SSL_CB_WRITE_ALERT: int = _lib.SSL_CB_WRITE_ALERT
SSL_CB_ACCEPT_LOOP: int = _lib.SSL_CB_ACCEPT_LOOP
+SSL_CB_ACCEPT_EXIT: int = _lib.SSL_CB_ACCEPT_EXIT +SSL_CB_CONNECT_LOOP: int = _lib.SSL_CB_CONNECT_LOOP +SSL_CB_CONNECT_EXIT: int = _lib.SSL_CB_CONNECT_EXIT +SSL_CB_HANDSHAKE_START: int = _lib.SSL_CB_HANDSHAKE_START +SSL_CB_HANDSHAKE_DONE: int = _lib.SSL_CB_HANDSHAKE_DONE + +_T = TypeVar("_T") + + +class _NoOverlappingProtocols: + pass + + +NO_OVERLAPPING_PROTOCOLS = _NoOverlappingProtocols() + +# Callback types. +_ALPNSelectCallback = Callable[ + [ + "Connection", + typing.Union[List[bytes], _NoOverlappingProtocols], + ], + None, +] +_CookieGenerateCallback = Callable[["Connection"], bytes] +_CookieVerifyCallback = Callable[["Connection", bytes], bool] +_OCSPClientCallback = Callable[["Connection", bytes, Optional[_T]], bool] +_OCSPServerCallback = Callable[["Connection", Optional[_T]], bytes] +_PassphraseCallback = Callable[[int, bool, Optional[_T]], bytes] +_VerifyCallback = Callable[["Connection", X509, int, int, int], bool] + + +class X509VerificationCodes: + """ + Success and error codes for X509 verification, as returned by the + underlying ``X509_STORE_CTX_get_error()`` function and passed by pyOpenSSL + to verification callback functions. + + See `OpenSSL Verification Errors + `_ + for details. 
+ """ + + OK = _lib.X509_V_OK + ERR_UNABLE_TO_GET_ISSUER_CERT = _lib.X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT + ERR_UNABLE_TO_GET_CRL = _lib.X509_V_ERR_UNABLE_TO_GET_CRL + ERR_UNABLE_TO_DECRYPT_CERT_SIGNATURE = ( + _lib.X509_V_ERR_UNABLE_TO_DECRYPT_CERT_SIGNATURE + ) + ERR_UNABLE_TO_DECRYPT_CRL_SIGNATURE = ( + _lib.X509_V_ERR_UNABLE_TO_DECRYPT_CRL_SIGNATURE + ) + ERR_UNABLE_TO_DECODE_ISSUER_PUBLIC_KEY = ( + _lib.X509_V_ERR_UNABLE_TO_DECODE_ISSUER_PUBLIC_KEY + ) + ERR_CERT_SIGNATURE_FAILURE = _lib.X509_V_ERR_CERT_SIGNATURE_FAILURE + ERR_CRL_SIGNATURE_FAILURE = _lib.X509_V_ERR_CRL_SIGNATURE_FAILURE + ERR_CERT_NOT_YET_VALID = _lib.X509_V_ERR_CERT_NOT_YET_VALID + ERR_CERT_HAS_EXPIRED = _lib.X509_V_ERR_CERT_HAS_EXPIRED + ERR_CRL_NOT_YET_VALID = _lib.X509_V_ERR_CRL_NOT_YET_VALID + ERR_CRL_HAS_EXPIRED = _lib.X509_V_ERR_CRL_HAS_EXPIRED + ERR_ERROR_IN_CERT_NOT_BEFORE_FIELD = ( + _lib.X509_V_ERR_ERROR_IN_CERT_NOT_BEFORE_FIELD + ) + ERR_ERROR_IN_CERT_NOT_AFTER_FIELD = ( + _lib.X509_V_ERR_ERROR_IN_CERT_NOT_AFTER_FIELD + ) + ERR_ERROR_IN_CRL_LAST_UPDATE_FIELD = ( + _lib.X509_V_ERR_ERROR_IN_CRL_LAST_UPDATE_FIELD + ) + ERR_ERROR_IN_CRL_NEXT_UPDATE_FIELD = ( + _lib.X509_V_ERR_ERROR_IN_CRL_NEXT_UPDATE_FIELD + ) + ERR_OUT_OF_MEM = _lib.X509_V_ERR_OUT_OF_MEM + ERR_DEPTH_ZERO_SELF_SIGNED_CERT = ( + _lib.X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT + ) + ERR_SELF_SIGNED_CERT_IN_CHAIN = _lib.X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN + ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY = ( + _lib.X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY + ) + ERR_UNABLE_TO_VERIFY_LEAF_SIGNATURE = ( + _lib.X509_V_ERR_UNABLE_TO_VERIFY_LEAF_SIGNATURE + ) + ERR_CERT_CHAIN_TOO_LONG = _lib.X509_V_ERR_CERT_CHAIN_TOO_LONG + ERR_CERT_REVOKED = _lib.X509_V_ERR_CERT_REVOKED + ERR_INVALID_CA = _lib.X509_V_ERR_INVALID_CA + ERR_PATH_LENGTH_EXCEEDED = _lib.X509_V_ERR_PATH_LENGTH_EXCEEDED + ERR_INVALID_PURPOSE = _lib.X509_V_ERR_INVALID_PURPOSE + ERR_CERT_UNTRUSTED = _lib.X509_V_ERR_CERT_UNTRUSTED + ERR_CERT_REJECTED = 
_lib.X509_V_ERR_CERT_REJECTED + ERR_SUBJECT_ISSUER_MISMATCH = _lib.X509_V_ERR_SUBJECT_ISSUER_MISMATCH + ERR_AKID_SKID_MISMATCH = _lib.X509_V_ERR_AKID_SKID_MISMATCH + ERR_AKID_ISSUER_SERIAL_MISMATCH = ( + _lib.X509_V_ERR_AKID_ISSUER_SERIAL_MISMATCH + ) + ERR_KEYUSAGE_NO_CERTSIGN = _lib.X509_V_ERR_KEYUSAGE_NO_CERTSIGN + ERR_UNABLE_TO_GET_CRL_ISSUER = _lib.X509_V_ERR_UNABLE_TO_GET_CRL_ISSUER + ERR_UNHANDLED_CRITICAL_EXTENSION = ( + _lib.X509_V_ERR_UNHANDLED_CRITICAL_EXTENSION + ) + ERR_KEYUSAGE_NO_CRL_SIGN = _lib.X509_V_ERR_KEYUSAGE_NO_CRL_SIGN + ERR_UNHANDLED_CRITICAL_CRL_EXTENSION = ( + _lib.X509_V_ERR_UNHANDLED_CRITICAL_CRL_EXTENSION + ) + ERR_INVALID_NON_CA = _lib.X509_V_ERR_INVALID_NON_CA + ERR_PROXY_PATH_LENGTH_EXCEEDED = _lib.X509_V_ERR_PROXY_PATH_LENGTH_EXCEEDED + ERR_KEYUSAGE_NO_DIGITAL_SIGNATURE = ( + _lib.X509_V_ERR_KEYUSAGE_NO_DIGITAL_SIGNATURE + ) + ERR_PROXY_CERTIFICATES_NOT_ALLOWED = ( + _lib.X509_V_ERR_PROXY_CERTIFICATES_NOT_ALLOWED + ) + ERR_INVALID_EXTENSION = _lib.X509_V_ERR_INVALID_EXTENSION + ERR_INVALID_POLICY_EXTENSION = _lib.X509_V_ERR_INVALID_POLICY_EXTENSION + ERR_NO_EXPLICIT_POLICY = _lib.X509_V_ERR_NO_EXPLICIT_POLICY + ERR_DIFFERENT_CRL_SCOPE = _lib.X509_V_ERR_DIFFERENT_CRL_SCOPE + ERR_UNSUPPORTED_EXTENSION_FEATURE = ( + _lib.X509_V_ERR_UNSUPPORTED_EXTENSION_FEATURE + ) + ERR_UNNESTED_RESOURCE = _lib.X509_V_ERR_UNNESTED_RESOURCE + ERR_PERMITTED_VIOLATION = _lib.X509_V_ERR_PERMITTED_VIOLATION + ERR_EXCLUDED_VIOLATION = _lib.X509_V_ERR_EXCLUDED_VIOLATION + ERR_SUBTREE_MINMAX = _lib.X509_V_ERR_SUBTREE_MINMAX + ERR_UNSUPPORTED_CONSTRAINT_TYPE = ( + _lib.X509_V_ERR_UNSUPPORTED_CONSTRAINT_TYPE + ) + ERR_UNSUPPORTED_CONSTRAINT_SYNTAX = ( + _lib.X509_V_ERR_UNSUPPORTED_CONSTRAINT_SYNTAX + ) + ERR_UNSUPPORTED_NAME_SYNTAX = _lib.X509_V_ERR_UNSUPPORTED_NAME_SYNTAX + ERR_CRL_PATH_VALIDATION_ERROR = _lib.X509_V_ERR_CRL_PATH_VALIDATION_ERROR + ERR_HOSTNAME_MISMATCH = _lib.X509_V_ERR_HOSTNAME_MISMATCH + ERR_EMAIL_MISMATCH = _lib.X509_V_ERR_EMAIL_MISMATCH + 
ERR_IP_ADDRESS_MISMATCH = _lib.X509_V_ERR_IP_ADDRESS_MISMATCH + ERR_APPLICATION_VERIFICATION = _lib.X509_V_ERR_APPLICATION_VERIFICATION + + +# Taken from https://golang.org/src/crypto/x509/root_linux.go +_CERTIFICATE_FILE_LOCATIONS = [ + "/etc/ssl/certs/ca-certificates.crt", # Debian/Ubuntu/Gentoo etc. + "/etc/pki/tls/certs/ca-bundle.crt", # Fedora/RHEL 6 + "/etc/ssl/ca-bundle.pem", # OpenSUSE + "/etc/pki/tls/cacert.pem", # OpenELEC + "/etc/pki/ca-trust/extracted/pem/tls-ca-bundle.pem", # CentOS/RHEL 7 +] + +_CERTIFICATE_PATH_LOCATIONS = [ + "/etc/ssl/certs", # SLES10/SLES11 +] + +# These values are compared to output from cffi's ffi.string so they must be +# byte strings. +_CRYPTOGRAPHY_MANYLINUX_CA_DIR = b"/opt/pyca/cryptography/openssl/certs" +_CRYPTOGRAPHY_MANYLINUX_CA_FILE = b"/opt/pyca/cryptography/openssl/cert.pem" + + +class Error(Exception): + """ + An error occurred in an `OpenSSL.SSL` API. + """ + + +_raise_current_error = partial(_exception_from_error_queue, Error) +_openssl_assert = _make_assert(Error) + + +class WantReadError(Error): + pass + + +class WantWriteError(Error): + pass + + +class WantX509LookupError(Error): + pass + + +class ZeroReturnError(Error): + pass + + +class SysCallError(Error): + pass + + +class _CallbackExceptionHelper: + """ + A base class for wrapper classes that allow for intelligent exception + handling in OpenSSL callbacks. + + :ivar list _problems: Any exceptions that occurred while executing in a + context where they could not be raised in the normal way. Typically + this is because OpenSSL has called into some Python code and requires a + return value. The exceptions are saved to be raised later when it is + possible to do so. + """ + + def __init__(self) -> None: + self._problems: List[Exception] = [] + + def raise_if_problem(self) -> None: + """ + Raise an exception from the OpenSSL error queue or that was previously + captured whe running a callback. 
+ """ + if self._problems: + try: + _raise_current_error() + except Error: + pass + raise self._problems.pop(0) + + +class _VerifyHelper(_CallbackExceptionHelper): + """ + Wrap a callback such that it can be used as a certificate verification + callback. + """ + + def __init__(self, callback: _VerifyCallback) -> None: + _CallbackExceptionHelper.__init__(self) + + @wraps(callback) + def wrapper(ok, store_ctx): # type: ignore[no-untyped-def] + x509 = _lib.X509_STORE_CTX_get_current_cert(store_ctx) + _lib.X509_up_ref(x509) + cert = X509._from_raw_x509_ptr(x509) + error_number = _lib.X509_STORE_CTX_get_error(store_ctx) + error_depth = _lib.X509_STORE_CTX_get_error_depth(store_ctx) + + index = _lib.SSL_get_ex_data_X509_STORE_CTX_idx() + ssl = _lib.X509_STORE_CTX_get_ex_data(store_ctx, index) + connection = Connection._reverse_mapping[ssl] + + try: + result = callback( + connection, cert, error_number, error_depth, ok + ) + except Exception as e: + self._problems.append(e) + return 0 + else: + if result: + _lib.X509_STORE_CTX_set_error(store_ctx, _lib.X509_V_OK) + return 1 + else: + return 0 + + self.callback = _ffi.callback( + "int (*)(int, X509_STORE_CTX *)", wrapper + ) + + +class _ALPNSelectHelper(_CallbackExceptionHelper): + """ + Wrap a callback such that it can be used as an ALPN selection callback. + """ + + def __init__(self, callback: _ALPNSelectCallback) -> None: + _CallbackExceptionHelper.__init__(self) + + @wraps(callback) + def wrapper(ssl, out, outlen, in_, inlen, arg): # type: ignore[no-untyped-def] + try: + conn = Connection._reverse_mapping[ssl] + + # The string passed to us is made up of multiple + # length-prefixed bytestrings. We need to split that into a + # list. 
+ instr = _ffi.buffer(in_, inlen)[:] + protolist = [] + while instr: + encoded_len = instr[0] + proto = instr[1 : encoded_len + 1] + protolist.append(proto) + instr = instr[encoded_len + 1 :] + + # Call the callback + outbytes = callback(conn, protolist) + any_accepted = True + if outbytes is NO_OVERLAPPING_PROTOCOLS: + outbytes = b"" + any_accepted = False + elif not isinstance(outbytes, bytes): + raise TypeError( + "ALPN callback must return a bytestring or the " + "special NO_OVERLAPPING_PROTOCOLS sentinel value." + ) + + # Save our callback arguments on the connection object to make + # sure that they don't get freed before OpenSSL can use them. + # Then, return them in the appropriate output parameters. + conn._alpn_select_callback_args = [ + _ffi.new("unsigned char *", len(outbytes)), + _ffi.new("unsigned char[]", outbytes), + ] + outlen[0] = conn._alpn_select_callback_args[0][0] + out[0] = conn._alpn_select_callback_args[1] + if not any_accepted: + return _lib.SSL_TLSEXT_ERR_NOACK + return _lib.SSL_TLSEXT_ERR_OK + except Exception as e: + self._problems.append(e) + return _lib.SSL_TLSEXT_ERR_ALERT_FATAL + + self.callback = _ffi.callback( + ( + "int (*)(SSL *, unsigned char **, unsigned char *, " + "const unsigned char *, unsigned int, void *)" + ), + wrapper, + ) + + +class _OCSPServerCallbackHelper(_CallbackExceptionHelper): + """ + Wrap a callback such that it can be used as an OCSP callback for the server + side. + + Annoyingly, OpenSSL defines one OCSP callback but uses it in two different + ways. For servers, that callback is expected to retrieve some OCSP data and + hand it to OpenSSL, and may return only SSL_TLSEXT_ERR_OK, + SSL_TLSEXT_ERR_FATAL, and SSL_TLSEXT_ERR_NOACK. For clients, that callback + is expected to check the OCSP data, and returns a negative value on error, + 0 if the response is not acceptable, or positive if it is. 
These are + mutually exclusive return code behaviours, and they mean that we need two + helpers so that we always return an appropriate error code if the user's + code throws an exception. + + Given that we have to have two helpers anyway, these helpers are a bit more + helpery than most: specifically, they hide a few more of the OpenSSL + functions so that the user has an easier time writing these callbacks. + + This helper implements the server side. + """ + + def __init__(self, callback: _OCSPServerCallback[Any]) -> None: + _CallbackExceptionHelper.__init__(self) + + @wraps(callback) + def wrapper(ssl, cdata): # type: ignore[no-untyped-def] + try: + conn = Connection._reverse_mapping[ssl] + + # Extract the data if any was provided. + if cdata != _ffi.NULL: + data = _ffi.from_handle(cdata) + else: + data = None + + # Call the callback. + ocsp_data = callback(conn, data) + + if not isinstance(ocsp_data, bytes): + raise TypeError("OCSP callback must return a bytestring.") + + # If the OCSP data was provided, we will pass it to OpenSSL. + # However, we have an early exit here: if no OCSP data was + # provided we will just exit out and tell OpenSSL that there + # is nothing to do. + if not ocsp_data: + return 3 # SSL_TLSEXT_ERR_NOACK + + # OpenSSL takes ownership of this data and expects it to have + # been allocated by OPENSSL_malloc. + ocsp_data_length = len(ocsp_data) + data_ptr = _lib.OPENSSL_malloc(ocsp_data_length) + _ffi.buffer(data_ptr, ocsp_data_length)[:] = ocsp_data + + _lib.SSL_set_tlsext_status_ocsp_resp( + ssl, data_ptr, ocsp_data_length + ) + + return 0 + except Exception as e: + self._problems.append(e) + return 2 # SSL_TLSEXT_ERR_ALERT_FATAL + + self.callback = _ffi.callback("int (*)(SSL *, void *)", wrapper) + + +class _OCSPClientCallbackHelper(_CallbackExceptionHelper): + """ + Wrap a callback such that it can be used as an OCSP callback for the client + side. + + Annoyingly, OpenSSL defines one OCSP callback but uses it in two different + ways. 
For servers, that callback is expected to retrieve some OCSP data and + hand it to OpenSSL, and may return only SSL_TLSEXT_ERR_OK, + SSL_TLSEXT_ERR_FATAL, and SSL_TLSEXT_ERR_NOACK. For clients, that callback + is expected to check the OCSP data, and returns a negative value on error, + 0 if the response is not acceptable, or positive if it is. These are + mutually exclusive return code behaviours, and they mean that we need two + helpers so that we always return an appropriate error code if the user's + code throws an exception. + + Given that we have to have two helpers anyway, these helpers are a bit more + helpery than most: specifically, they hide a few more of the OpenSSL + functions so that the user has an easier time writing these callbacks. + + This helper implements the client side. + """ + + def __init__(self, callback: _OCSPClientCallback[Any]) -> None: + _CallbackExceptionHelper.__init__(self) + + @wraps(callback) + def wrapper(ssl, cdata): # type: ignore[no-untyped-def] + try: + conn = Connection._reverse_mapping[ssl] + + # Extract the data if any was provided. + if cdata != _ffi.NULL: + data = _ffi.from_handle(cdata) + else: + data = None + + # Get the OCSP data. + ocsp_ptr = _ffi.new("unsigned char **") + ocsp_len = _lib.SSL_get_tlsext_status_ocsp_resp(ssl, ocsp_ptr) + if ocsp_len < 0: + # No OCSP data. + ocsp_data = b"" + else: + # Copy the OCSP data, then pass it to the callback. + ocsp_data = _ffi.buffer(ocsp_ptr[0], ocsp_len)[:] + + valid = callback(conn, ocsp_data, data) + + # Return 1 on success or 0 on error. + return int(bool(valid)) + + except Exception as e: + self._problems.append(e) + # Return negative value if an exception is hit. 
+ return -1 + + self.callback = _ffi.callback("int (*)(SSL *, void *)", wrapper) + + +class _CookieGenerateCallbackHelper(_CallbackExceptionHelper): + def __init__(self, callback: _CookieGenerateCallback) -> None: + _CallbackExceptionHelper.__init__(self) + + @wraps(callback) + def wrapper(ssl, out, outlen): # type: ignore[no-untyped-def] + try: + conn = Connection._reverse_mapping[ssl] + cookie = callback(conn) + out[0 : len(cookie)] = cookie + outlen[0] = len(cookie) + return 1 + except Exception as e: + self._problems.append(e) + # "a zero return value can be used to abort the handshake" + return 0 + + self.callback = _ffi.callback( + "int (*)(SSL *, unsigned char *, unsigned int *)", + wrapper, + ) + + +class _CookieVerifyCallbackHelper(_CallbackExceptionHelper): + def __init__(self, callback: _CookieVerifyCallback) -> None: + _CallbackExceptionHelper.__init__(self) + + @wraps(callback) + def wrapper(ssl, c_cookie, cookie_len): # type: ignore[no-untyped-def] + try: + conn = Connection._reverse_mapping[ssl] + return callback(conn, bytes(c_cookie[0:cookie_len])) + except Exception as e: + self._problems.append(e) + return 0 + + self.callback = _ffi.callback( + "int (*)(SSL *, unsigned char *, unsigned int)", + wrapper, + ) + + +def _asFileDescriptor(obj: Any) -> int: + fd = None + if not isinstance(obj, int): + meth = getattr(obj, "fileno", None) + if meth is not None: + obj = meth() + + if isinstance(obj, int): + fd = obj + + if not isinstance(fd, int): + raise TypeError("argument must be an int, or have a fileno() method.") + elif fd < 0: + raise ValueError( + "file descriptor cannot be a negative integer (%i)" % (fd,) + ) + + return fd + + +def OpenSSL_version(type: int) -> bytes: + """ + Return a string describing the version of OpenSSL in use. + + :param type: One of the :const:`OPENSSL_` constants defined in this module. 
+ """ + return _ffi.string(_lib.OpenSSL_version(type)) + + +SSLeay_version = OpenSSL_version + + +def _make_requires(flag: int, error: str) -> Callable[[_T], _T]: + """ + Builds a decorator that ensures that functions that rely on OpenSSL + functions that are not present in this build raise NotImplementedError, + rather than AttributeError coming out of cryptography. + + :param flag: A cryptography flag that guards the functions, e.g. + ``Cryptography_HAS_NEXTPROTONEG``. + :param error: The string to be used in the exception if the flag is false. + """ + + def _requires_decorator(func): # type: ignore[no-untyped-def] + if not flag: + + @wraps(func) + def explode(*args, **kwargs): # type: ignore[no-untyped-def] + raise NotImplementedError(error) + + return explode + else: + return func + + return _requires_decorator + + +_requires_alpn = _make_requires( + _lib.Cryptography_HAS_ALPN, "ALPN not available" +) + + +_requires_keylog = _make_requires( + getattr(_lib, "Cryptography_HAS_KEYLOG", 0), "Key logging not available" +) + + +class Session: + """ + A class representing an SSL session. A session defines certain connection + parameters which may be re-used to speed up the setup of subsequent + connections. + + .. versionadded:: 0.14 + """ + + _session: Any + + +class Context: + """ + :class:`OpenSSL.SSL.Context` instances define the parameters for setting + up new SSL connections. + + :param method: One of TLS_METHOD, TLS_CLIENT_METHOD, TLS_SERVER_METHOD, + DTLS_METHOD, DTLS_CLIENT_METHOD, or DTLS_SERVER_METHOD. + SSLv23_METHOD, TLSv1_METHOD, etc. are deprecated and should + not be used. 
+ """ + + _methods: typing.ClassVar[ + typing.Dict[int, typing.Tuple[Callable[[], Any], Optional[int]]] + ] = { + SSLv23_METHOD: (_lib.TLS_method, None), + TLSv1_METHOD: (_lib.TLS_method, TLS1_VERSION), + TLSv1_1_METHOD: (_lib.TLS_method, TLS1_1_VERSION), + TLSv1_2_METHOD: (_lib.TLS_method, TLS1_2_VERSION), + TLS_METHOD: (_lib.TLS_method, None), + TLS_SERVER_METHOD: (_lib.TLS_server_method, None), + TLS_CLIENT_METHOD: (_lib.TLS_client_method, None), + DTLS_METHOD: (_lib.DTLS_method, None), + DTLS_SERVER_METHOD: (_lib.DTLS_server_method, None), + DTLS_CLIENT_METHOD: (_lib.DTLS_client_method, None), + } + + def __init__(self, method: int) -> None: + if not isinstance(method, int): + raise TypeError("method must be an integer") + + try: + method_func, version = self._methods[method] + except KeyError: + raise ValueError("No such protocol") + + method_obj = method_func() + _openssl_assert(method_obj != _ffi.NULL) + + context = _lib.SSL_CTX_new(method_obj) + _openssl_assert(context != _ffi.NULL) + context = _ffi.gc(context, _lib.SSL_CTX_free) + + self._context = context + self._passphrase_helper: Optional[_PassphraseHelper] = None + self._passphrase_callback: Optional[_PassphraseCallback[Any]] = None + self._passphrase_userdata: Optional[Any] = None + self._verify_helper: Optional[_VerifyHelper] = None + self._verify_callback: Optional[_VerifyCallback] = None + self._info_callback = None + self._keylog_callback = None + self._tlsext_servername_callback = None + self._app_data = None + self._alpn_select_helper: Optional[_ALPNSelectHelper] = None + self._alpn_select_callback: Optional[_ALPNSelectCallback] = None + self._ocsp_helper: typing.Union[ + _OCSPClientCallbackHelper, _OCSPServerCallbackHelper, None + ] = None + self._ocsp_callback: typing.Union[ + _OCSPClientCallback[Any], _OCSPServerCallback[Any], None + ] = None + self._ocsp_data: Optional[Any] = None + self._cookie_generate_helper: Optional[ + _CookieGenerateCallbackHelper + ] = None + 
self._cookie_verify_helper: Optional[_CookieVerifyCallbackHelper] = ( + None + ) + + self.set_mode(_lib.SSL_MODE_ENABLE_PARTIAL_WRITE) + if version is not None: + self.set_min_proto_version(version) + self.set_max_proto_version(version) + + def set_min_proto_version(self, version: int) -> None: + """ + Set the minimum supported protocol version. Setting the minimum + version to 0 will enable protocol versions down to the lowest version + supported by the library. + + If the underlying OpenSSL build is missing support for the selected + version, this method will raise an exception. + """ + _openssl_assert( + _lib.SSL_CTX_set_min_proto_version(self._context, version) == 1 + ) + + def set_max_proto_version(self, version: int) -> None: + """ + Set the maximum supported protocol version. Setting the maximum + version to 0 will enable protocol versions up to the highest version + supported by the library. + + If the underlying OpenSSL build is missing support for the selected + version, this method will raise an exception. + """ + _openssl_assert( + _lib.SSL_CTX_set_max_proto_version(self._context, version) == 1 + ) + + def load_verify_locations( + self, + cafile: Optional[_StrOrBytesPath], + capath: Optional[_StrOrBytesPath] = None, + ) -> None: + """ + Let SSL know where we can find trusted certificates for the certificate + chain. Note that the certificates have to be in PEM format. + + If capath is passed, it must be a directory prepared using the + ``c_rehash`` tool included with OpenSSL. Either, but not both, of + *pemfile* or *capath* may be :data:`None`. + + :param cafile: In which file we can find the certificates (``bytes`` or + ``str``). + :param capath: In which directory we can find the certificates + (``bytes`` or ``str``). 
+ + :return: None + """ + if cafile is None: + cafile = _ffi.NULL + else: + cafile = _path_bytes(cafile) + + if capath is None: + capath = _ffi.NULL + else: + capath = _path_bytes(capath) + + load_result = _lib.SSL_CTX_load_verify_locations( + self._context, cafile, capath + ) + if not load_result: + _raise_current_error() + + def _wrap_callback( + self, callback: _PassphraseCallback[_T] + ) -> _PassphraseHelper: + @wraps(callback) + def wrapper(size: int, verify: bool, userdata: Any) -> bytes: + return callback(size, verify, self._passphrase_userdata) + + return _PassphraseHelper( + FILETYPE_PEM, wrapper, more_args=True, truncate=True + ) + + def set_passwd_cb( + self, + callback: _PassphraseCallback[_T], + userdata: Optional[_T] = None, + ) -> None: + """ + Set the passphrase callback. This function will be called + when a private key with a passphrase is loaded. + + :param callback: The Python callback to use. This must accept three + positional arguments. First, an integer giving the maximum length + of the passphrase it may return. If the returned passphrase is + longer than this, it will be truncated. Second, a boolean value + which will be true if the user should be prompted for the + passphrase twice and the callback should verify that the two values + supplied are equal. Third, the value given as the *userdata* + parameter to :meth:`set_passwd_cb`. The *callback* must return + a byte string. If an error occurs, *callback* should return a false + value (e.g. an empty string). 
+ :param userdata: (optional) A Python object which will be given as + argument to the callback + :return: None + """ + if not callable(callback): + raise TypeError("callback must be callable") + + self._passphrase_helper = self._wrap_callback(callback) + self._passphrase_callback = self._passphrase_helper.callback + _lib.SSL_CTX_set_default_passwd_cb( + self._context, self._passphrase_callback + ) + self._passphrase_userdata = userdata + + def set_default_verify_paths(self) -> None: + """ + Specify that the platform provided CA certificates are to be used for + verification purposes. This method has some caveats related to the + binary wheels that cryptography (pyOpenSSL's primary dependency) ships: + + * macOS will only load certificates using this method if the user has + the ``openssl@1.1`` `Homebrew `_ formula installed + in the default location. + * Windows will not work. + * manylinux cryptography wheels will work on most common Linux + distributions in pyOpenSSL 17.1.0 and above. pyOpenSSL detects the + manylinux wheel and attempts to load roots via a fallback path. + + :return: None + """ + # SSL_CTX_set_default_verify_paths will attempt to load certs from + # both a cafile and capath that are set at compile time. However, + # it will first check environment variables and, if present, load + # those paths instead + set_result = _lib.SSL_CTX_set_default_verify_paths(self._context) + _openssl_assert(set_result == 1) + # After attempting to set default_verify_paths we need to know whether + # to go down the fallback path. + # First we'll check to see if any env vars have been set. If so, + # we won't try to do anything else because the user has set the path + # themselves. 
def _check_env_vars_set(self, dir_env_var: str, file_env_var: str) -> bool:
    """
    Report whether either default-certificate environment variable exists.

    :param dir_env_var: Name of the environment variable for the default
        certificate directory.
    :param file_env_var: Name of the environment variable for the default
        certificate file.
    :return: ``True`` if at least one of the two names is present in
        ``os.environ`` (an empty value still counts), else ``False``.
    """
    candidates = (file_env_var, dir_env_var)
    return any(os.environ.get(name) is not None for name in candidates)
def add_extra_chain_cert(self, certobj: X509) -> None:
    """
    Add a certificate to the extra chain sent along with this context's
    own certificate.

    The certificate is duplicated first, so *certobj* itself is never
    handed to OpenSSL and remains owned by the caller.

    :param certobj: The X509 certificate object to add to the chain
    :raises TypeError: If *certobj* is not an :class:`X509` instance.
    :raises Error: If the certificate cannot be duplicated or OpenSSL
        refuses to add it.
    :return: None
    """
    if not isinstance(certobj, X509):
        raise TypeError("certobj must be an X509 instance")

    copy = _lib.X509_dup(certobj._x509)
    # A failed duplication yields NULL; never pass that on to the context.
    _openssl_assert(copy != _ffi.NULL)
    add_result = _lib.SSL_CTX_add_extra_chain_cert(self._context, copy)
    if not add_result:
        # TODO: This is untested.
        # The context did not take the copy, so release it here.
        _lib.X509_free(copy)
        _raise_current_error()
def set_session_cache_mode(self, mode: int) -> int:
    """
    Set the behavior of the session cache used by all connections using
    this Context.  The previously set mode is returned.  See
    :const:`SESS_CACHE_*` for details about particular modes.

    :param mode: One or more of the SESS_CACHE_* flags (combine using
        bitwise or)
    :raises TypeError: If *mode* is not an integer.
    :returns: The previously set caching mode.

    .. versionadded:: 0.14
    """
    if not isinstance(mode, int):
        raise TypeError("mode must be an integer")

    return _lib.SSL_CTX_set_session_cache_mode(self._context, mode)
def set_verify_depth(self, depth: int) -> None:
    """
    Limit how deep certificate chain verification may go for this
    Context object.

    :param depth: An integer specifying the maximum verify depth
    :raises TypeError: If *depth* is not an integer.
    :return: None
    """
    if isinstance(depth, int):
        _lib.SSL_CTX_set_verify_depth(self._context, depth)
        return

    raise TypeError("depth must be an integer")
def load_tmp_dh(self, dhfile: _StrOrBytesPath) -> None:
    """
    Load parameters for Ephemeral Diffie-Hellman

    :param dhfile: The file to load EDH parameters from (``bytes`` or
        ``str``).
    :raises Error: If the file cannot be opened, does not contain
        readable DH parameters, or OpenSSL rejects the parameters.

    :return: None
    """
    dhfile = _path_bytes(dhfile)

    bio = _lib.BIO_new_file(dhfile, b"r")
    if bio == _ffi.NULL:
        _raise_current_error()
    bio = _ffi.gc(bio, _lib.BIO_free)

    dh = _lib.PEM_read_bio_DHparams(bio, _ffi.NULL, _ffi.NULL, _ffi.NULL)
    if dh == _ffi.NULL:
        # Report the parse failure right here instead of handing a NULL
        # pointer on to SSL_CTX_set_tmp_dh.
        _raise_current_error()
    dh = _ffi.gc(dh, _lib.DH_free)
    res = _lib.SSL_CTX_set_tmp_dh(self._context, dh)
    _openssl_assert(res == 1)
def set_client_ca_list(
    self, certificate_authorities: Sequence[X509Name]
) -> None:
    """
    Replace the list of preferred client certificate signers for this
    server context.

    The names are sent to the client when the server requests a client
    certificate.

    :param certificate_authorities: a sequence of X509Names.
    :raises TypeError: If any element is not an :class:`X509Name`.
    :return: None

    .. versionadded:: 0.10
    """
    stack = _lib.sk_X509_NAME_new_null()
    _openssl_assert(stack != _ffi.NULL)

    try:
        for authority in certificate_authorities:
            if isinstance(authority, X509Name):
                # The stack receives duplicates, never the caller's names.
                duplicate = _lib.X509_NAME_dup(authority._name)
                _openssl_assert(duplicate != _ffi.NULL)
                if not _lib.sk_X509_NAME_push(stack, duplicate):
                    _lib.X509_NAME_free(duplicate)
                    _raise_current_error()
            else:
                kind = type(authority).__name__
                raise TypeError(
                    f"client CAs must be X509Name objects, not {kind} objects"
                )
    except Exception:
        # The stack was never handed to the context, so free it here.
        _lib.sk_X509_NAME_free(stack)
        raise
    else:
        _lib.SSL_CTX_set_client_CA_list(self._context, stack)
def set_timeout(self, timeout: int) -> int:
    """
    Set the timeout for newly created sessions for this Context object to
    *timeout*.  The default value is 300 seconds.  See the OpenSSL manual
    for more information (e.g. :manpage:`SSL_CTX_set_timeout(3)`).

    :param timeout: The timeout in (whole) seconds
    :raises TypeError: If *timeout* is not an integer.
    :return: The previous session timeout
    """
    if not isinstance(timeout, int):
        raise TypeError("timeout must be an integer")

    return _lib.SSL_CTX_set_timeout(self._context, timeout)
@_requires_keylog
def set_keylog_callback(
    self, callback: Callable[["Connection", bytes], None]
) -> None:
    """
    Install *callback* as the TLS key logging callback.

    It is called whenever TLS key material is generated or received, so
    that applications can store the keying material for debugging.

    :param callback: The Python callback to use.  It receives two
        arguments: the Connection object and a bytestring holding the
        key material in the format NSS uses for its SSLKEYLOGFILE
        debugging output.
    :return: None
    """

    @wraps(callback)
    def wrapper(ssl, line):  # type: ignore[no-untyped-def]
        key_material = _ffi.string(line)
        connection = Connection._reverse_mapping[ssl]
        callback(connection, key_material)

    native_callback = _ffi.callback(
        "void (*)(const SSL *, const char *)", wrapper
    )
    # Stored on the Context as well as registered with OpenSSL.
    self._keylog_callback = native_callback
    _lib.SSL_CTX_set_keylog_callback(self._context, native_callback)
def set_options(self, options: int) -> int:
    """
    Add options.  Options set before are not cleared!
    This method should be used with the :const:`OP_*` constants.

    :param options: The options to add.
    :raises TypeError: If *options* is not an integer.
    :return: The new option bitmask.
    """
    if not isinstance(options, int):
        raise TypeError("options must be an integer")

    return _lib.SSL_CTX_set_options(self._context, options)
@_requires_alpn
def set_alpn_protos(self, protos: List[bytes]) -> None:
    """
    Declare the protocols the client is willing to speak once the TLS
    connection is established, using Application Layer Protocol
    Negotiation.

    :param protos: A list of the protocols to be offered to the server.
        This list should be a Python list of bytestrings representing the
        protocols to offer, e.g. ``[b'http/1.1', b'spdy/2']``.
    :raises ValueError: If *protos* is empty.
    """
    # Different versions of OpenSSL are inconsistent about how they handle
    # empty proto lists (see #1043), so empty lists are rejected up front.
    if not protos:
        raise ValueError("at least one protocol must be specified")

    # Wire format: each protocol preceded by a single length byte.
    pieces = []
    for proto in protos:
        pieces.append(bytes((len(proto),)))
        pieces.append(proto)
    protostr = b"".join(pieces)

    # The C buffer is only needed for the duration of the call because
    # OpenSSL immediately copies the data out.
    input_str = _ffi.new("unsigned char[]", protostr)

    # https://www.openssl.org/docs/man1.1.0/man3/SSL_CTX_set_alpn_protos.html:
    #     SSL_CTX_set_alpn_protos() and SSL_set_alpn_protos()
    #     return 0 on success, and non-0 on failure.
    #     WARNING: these functions reverse the return value convention.
    status = _lib.SSL_CTX_set_alpn_protos(
        self._context, input_str, len(protostr)
    )
    _openssl_assert(status == 0)
def _set_ocsp_callback(
    self,
    helper: typing.Union[
        _OCSPClientCallbackHelper, _OCSPServerCallbackHelper
    ],
    data: Optional[Any],
) -> None:
    """
    Shared implementation behind ``set_ocsp_server_callback`` and
    ``set_ocsp_client_callback``.

    Stores the helper, its callback and the user data on the Context, then
    registers the callback and its argument with OpenSSL.

    :param helper: The already constructed client or server helper.
    :param data: Opaque user data; ``None`` is passed to OpenSSL as NULL,
        anything else is wrapped in a cffi handle.
    """
    self._ocsp_helper = helper
    self._ocsp_callback = helper.callback
    self._ocsp_data = (
        _ffi.NULL if data is None else _ffi.new_handle(data)
    )

    _openssl_assert(
        _lib.SSL_CTX_set_tlsext_status_cb(
            self._context, self._ocsp_callback
        )
        == 1
    )
    _openssl_assert(
        _lib.SSL_CTX_set_tlsext_status_arg(self._context, self._ocsp_data)
        == 1
    )
This can be used to avoid needing to do + complex data lookups or to keep track of what context is being + used. This parameter is optional. + """ + helper = _OCSPServerCallbackHelper(callback) + self._set_ocsp_callback(helper, data) + + def set_ocsp_client_callback( + self, + callback: _OCSPClientCallback[_T], + data: Optional[_T] = None, + ) -> None: + """ + Set a callback to validate OCSP data stapled to the TLS handshake on + the client side. + + :param callback: The callback function. It will be invoked with three + arguments: the Connection, a bytestring containing the stapled OCSP + assertion, and the optional arbitrary data you have provided. The + callback must return a boolean that indicates the result of + validating the OCSP data: ``True`` if the OCSP data is valid and + the certificate can be trusted, or ``False`` if either the OCSP + data is invalid or the certificate has been revoked. + :param data: Some opaque data that will be passed into the callback + function when called. This can be used to avoid needing to do + complex data lookups or to keep track of what context is being + used. This parameter is optional. 
def set_cookie_generate_callback(
    self, callback: _CookieGenerateCallback
) -> None:
    """
    Register *callback* as this context's cookie generation callback.

    The callback is wrapped in a ``_CookieGenerateCallbackHelper`` that is
    stored on the Context, and the helper's callback is registered through
    ``SSL_CTX_set_cookie_generate_cb``.

    :param callback: The Python callback to use.  Its expected signature
        is given by ``_CookieGenerateCallback``, which is defined
        elsewhere in this module.
    :return: None
    """
    helper = _CookieGenerateCallbackHelper(callback)
    self._cookie_generate_helper = helper
    _lib.SSL_CTX_set_cookie_generate_cb(self._context, helper.callback)
This ensures that if + # set_verify is called again after the SSL object has been created we + # do not point to a dangling reference + self._verify_helper = context._verify_helper + self._verify_callback = context._verify_callback + + # And likewise for the cookie callbacks + self._cookie_generate_helper = context._cookie_generate_helper + self._cookie_verify_helper = context._cookie_verify_helper + + self._reverse_mapping[self._ssl] = self + + if socket is None: + self._socket = None + # Don't set up any gc for these, SSL_free will take care of them. + self._into_ssl = _lib.BIO_new(_lib.BIO_s_mem()) + _openssl_assert(self._into_ssl != _ffi.NULL) + + self._from_ssl = _lib.BIO_new(_lib.BIO_s_mem()) + _openssl_assert(self._from_ssl != _ffi.NULL) + + _lib.SSL_set_bio(self._ssl, self._into_ssl, self._from_ssl) + else: + self._into_ssl = None + self._from_ssl = None + self._socket = socket + set_result = _lib.SSL_set_fd( + self._ssl, _asFileDescriptor(self._socket) + ) + _openssl_assert(set_result == 1) + + def __getattr__(self, name: str) -> Any: + """ + Look up attributes on the wrapped socket object if they are not found + on the Connection object. 
def _raise_ssl_error(self, ssl: Any, result: int) -> None:
    """
    Translate the outcome of an OpenSSL I/O call into a Python exception.

    Exceptions captured by the Context's verify, ALPN-select and OCSP
    helpers are given the chance to surface first.  After that the code
    reported by ``SSL_get_error`` decides which exception is raised.

    :param ssl: The ``SSL *`` the operation was performed on.
    :param result: The return value of the OpenSSL call being checked.
    :raises WantReadError: On ``SSL_ERROR_WANT_READ``.
    :raises WantWriteError: On ``SSL_ERROR_WANT_WRITE``.
    :raises ZeroReturnError: On ``SSL_ERROR_ZERO_RETURN``.
    :raises WantX509LookupError: On ``SSL_ERROR_WANT_X509_LOOKUP``.
    :raises SysCallError: On ``SSL_ERROR_SYSCALL`` with an empty error
        queue, and for the unexpected-EOF case of ``SSL_ERROR_SSL``.
    :return: None, and only when the code is ``SSL_ERROR_NONE``.
    """
    # Each helper is asked to re-raise whatever problem it recorded.
    if self._context._verify_helper is not None:
        self._context._verify_helper.raise_if_problem()
    if self._context._alpn_select_helper is not None:
        self._context._alpn_select_helper.raise_if_problem()
    if self._context._ocsp_helper is not None:
        self._context._ocsp_helper.raise_if_problem()

    error = _lib.SSL_get_error(ssl, result)
    if error == _lib.SSL_ERROR_WANT_READ:
        raise WantReadError()
    elif error == _lib.SSL_ERROR_WANT_WRITE:
        raise WantWriteError()
    elif error == _lib.SSL_ERROR_ZERO_RETURN:
        raise ZeroReturnError()
    elif error == _lib.SSL_ERROR_WANT_X509_LOOKUP:
        # TODO: This is untested.
        raise WantX509LookupError()
    elif error == _lib.SSL_ERROR_SYSCALL:
        # With nothing in the OpenSSL error queue, fall back to the
        # platform's error number.
        if _lib.ERR_peek_error() == 0:
            if result < 0:
                if platform == "win32":
                    errno = _ffi.getwinerror()[0]
                else:
                    errno = _ffi.errno

                if errno != 0:
                    raise SysCallError(errno, errorcode.get(errno))
            # Reached when result >= 0 or no error number is set.
            raise SysCallError(-1, "Unexpected EOF")
        else:
            # TODO: This is untested.
            _raise_current_error()
    elif error == _lib.SSL_ERROR_SSL and _lib.ERR_peek_error() != 0:
        # In 3.0.x an unexpected EOF no longer triggers syscall error
        # but we want to maintain compatibility so we check here and
        # raise syscall if it is an EOF. Since we're not actually sure
        # what else could raise SSL_ERROR_SSL we check for the presence
        # of the OpenSSL 3 constant SSL_R_UNEXPECTED_EOF_WHILE_READING
        # and if it's not present we just raise an error, which matches
        # the behavior before we added this elif section
        peeked_error = _lib.ERR_peek_error()
        reason = _lib.ERR_GET_REASON(peeked_error)
        if _lib.Cryptography_HAS_UNEXPECTED_EOF_WHILE_READING:
            # NOTE(review): presumably _openssl_assert raises from the
            # error queue when the reason is something other than EOF;
            # confirm against its definition.
            _openssl_assert(
                reason == _lib.SSL_R_UNEXPECTED_EOF_WHILE_READING
            )
            # Clear the queued EOF error before raising our own.
            _lib.ERR_clear_error()
            raise SysCallError(-1, "Unexpected EOF")
        else:
            _raise_current_error()
    elif error == _lib.SSL_ERROR_NONE:
        pass
    else:
        _raise_current_error()
def use_certificate(self, cert: X509) -> None:
    """
    Use the certificate held by an X509 object for this connection.

    :param cert: The X509 object
    :raises TypeError: If *cert* is not an :class:`X509` instance.
    :return: None
    """
    # Mirrored from Context.use_certificate
    if not isinstance(cert, X509):
        raise TypeError("cert must be an X509 instance")

    if _lib.SSL_use_certificate(self._ssl, cert._x509):
        return

    _raise_current_error()
def set_tlsext_host_name(self, name: bytes) -> None:
    """
    Choose the servername extension value sent in the client hello.

    :param name: A byte string giving the name.
    :raises TypeError: If *name* is not a byte string or contains a NUL
        byte.

    .. versionadded:: 0.13
    """
    if not isinstance(name, bytes):
        raise TypeError("name must be a byte string")
    if b"\x00" in name:
        raise TypeError("name must not contain NUL byte")

    # XXX I guess this can fail sometimes?
    # The return value of the call is deliberately left unchecked.
    _lib.SSL_set_tlsext_host_name(self._ssl, name)
+ ) + + result = _lib.SSL_write(self._ssl, data, len(data)) + self._raise_ssl_error(self._ssl, result) + + return result + + write = send + + def sendall(self, buf: bytes, flags: int = 0) -> int: + """ + Send "all" data on the connection. This calls send() repeatedly until + all data is sent. If an error occurs, it's impossible to tell how much + data has been sent. + + :param buf: The string, buffer or memoryview to send + :param flags: (optional) Included for compatibility with the socket + API, the value is ignored + :return: The number of bytes written + """ + buf = _text_to_bytes_and_warn("buf", buf) + + with _ffi.from_buffer(buf) as data: + left_to_send = len(buf) + total_sent = 0 + + while left_to_send: + # SSL_write's num arg is an int, + # so we cannot send more than 2**31-1 bytes at once. + result = _lib.SSL_write( + self._ssl, data + total_sent, min(left_to_send, 2147483647) + ) + self._raise_ssl_error(self._ssl, result) + total_sent += result + left_to_send -= result + + return total_sent + + def recv(self, bufsiz: int, flags: Optional[int] = None) -> bytes: + """ + Receive data on the connection. + + :param bufsiz: The maximum number of bytes to read + :param flags: (optional) The only supported flag is ``MSG_PEEK``, + all other flags are ignored. + :return: The string read from the Connection + """ + buf = _no_zero_allocator("char[]", bufsiz) + if flags is not None and flags & socket.MSG_PEEK: + result = _lib.SSL_peek(self._ssl, buf, bufsiz) + else: + result = _lib.SSL_read(self._ssl, buf, bufsiz) + self._raise_ssl_error(self._ssl, result) + return _ffi.buffer(buf, result)[:] + + read = recv + + def recv_into( + self, + buffer: Any, # collections.abc.Buffer once we use Python 3.12+ + nbytes: Optional[int] = None, + flags: Optional[int] = None, + ) -> int: + """ + Receive data on the connection and copy it directly into the provided + buffer, rather than creating a new string. + + :param buffer: The buffer to copy into. 
+ :param nbytes: (optional) The maximum number of bytes to read into the + buffer. If not present, defaults to the size of the buffer. If + larger than the size of the buffer, is reduced to the size of the + buffer. + :param flags: (optional) The only supported flag is ``MSG_PEEK``, + all other flags are ignored. + :return: The number of bytes read into the buffer. + """ + if nbytes is None: + nbytes = len(buffer) + else: + nbytes = min(nbytes, len(buffer)) + + # We need to create a temporary buffer. This is annoying, it would be + # better if we could pass memoryviews straight into the SSL_read call, + # but right now we can't. Revisit this if CFFI gets that ability. + buf = _no_zero_allocator("char[]", nbytes) + if flags is not None and flags & socket.MSG_PEEK: + result = _lib.SSL_peek(self._ssl, buf, nbytes) + else: + result = _lib.SSL_read(self._ssl, buf, nbytes) + self._raise_ssl_error(self._ssl, result) + + # This strange line is all to avoid a memory copy. The buffer protocol + # should allow us to assign a CFFI buffer to the LHS of this line, but + # on CPython 3.3+ that segfaults. As a workaround, we can temporarily + # wrap it in a memoryview. + buffer[:result] = memoryview(_ffi.buffer(buf, result)) + + return result + + def _handle_bio_errors(self, bio: Any, result: int) -> typing.NoReturn: + if _lib.BIO_should_retry(bio): + if _lib.BIO_should_read(bio): + raise WantReadError() + elif _lib.BIO_should_write(bio): + # TODO: This is untested. + raise WantWriteError() + elif _lib.BIO_should_io_special(bio): + # TODO: This is untested. I think io_special means the socket + # BIO has a not-yet connected socket. + raise ValueError("BIO_should_io_special") + else: + # TODO: This is untested. + raise ValueError("unknown bio failure") + else: + # TODO: This is untested. 
+ _raise_current_error() + + def bio_read(self, bufsiz: int) -> bytes: + """ + If the Connection was created with a memory BIO, this method can be + used to read bytes from the write end of that memory BIO. Many + Connection methods will add bytes which must be read in this manner or + the buffer will eventually fill up and the Connection will be able to + take no further actions. + + :param bufsiz: The maximum number of bytes to read + :return: The string read. + """ + if self._from_ssl is None: + raise TypeError("Connection sock was not None") + + if not isinstance(bufsiz, int): + raise TypeError("bufsiz must be an integer") + + buf = _no_zero_allocator("char[]", bufsiz) + result = _lib.BIO_read(self._from_ssl, buf, bufsiz) + if result <= 0: + self._handle_bio_errors(self._from_ssl, result) + + return _ffi.buffer(buf, result)[:] + + def bio_write(self, buf: bytes) -> int: + """ + If the Connection was created with a memory BIO, this method can be + used to add bytes to the read end of that memory BIO. The Connection + can then read the bytes (for example, in response to a call to + :meth:`recv`). + + :param buf: The string to put into the memory BIO. + :return: The number of bytes written + """ + buf = _text_to_bytes_and_warn("buf", buf) + + if self._into_ssl is None: + raise TypeError("Connection sock was not None") + + with _ffi.from_buffer(buf) as data: + result = _lib.BIO_write(self._into_ssl, data, len(data)) + if result <= 0: + self._handle_bio_errors(self._into_ssl, result) + return result + + def renegotiate(self) -> bool: + """ + Renegotiate the session. + + :return: True if the renegotiation can be started, False otherwise + """ + if not self.renegotiate_pending(): + _openssl_assert(_lib.SSL_renegotiate(self._ssl) == 1) + return True + return False + + def do_handshake(self) -> None: + """ + Perform an SSL handshake (usually called after :meth:`renegotiate` or + one of :meth:`set_accept_state` or :meth:`set_connect_state`). 
This can + raise the same exceptions as :meth:`send` and :meth:`recv`. + + :return: None. + """ + result = _lib.SSL_do_handshake(self._ssl) + self._raise_ssl_error(self._ssl, result) + + def renegotiate_pending(self) -> bool: + """ + Check if there's a renegotiation in progress, it will return False once + a renegotiation is finished. + + :return: Whether there's a renegotiation in progress + """ + return _lib.SSL_renegotiate_pending(self._ssl) == 1 + + def total_renegotiations(self) -> int: + """ + Find out the total number of renegotiations. + + :return: The number of renegotiations. + """ + return _lib.SSL_total_renegotiations(self._ssl) + + def connect(self, addr: Any) -> None: + """ + Call the :meth:`connect` method of the underlying socket and set up SSL + on the socket, using the :class:`Context` object supplied to this + :class:`Connection` object at creation. + + :param addr: A remote address + :return: What the socket's connect method returns + """ + _lib.SSL_set_connect_state(self._ssl) + return self._socket.connect(addr) # type: ignore[return-value, union-attr] + + def connect_ex(self, addr: Any) -> int: + """ + Call the :meth:`connect_ex` method of the underlying socket and set up + SSL on the socket, using the Context object supplied to this Connection + object at creation. Note that if the :meth:`connect_ex` method of the + socket doesn't return 0, SSL won't be initialized. + + :param addr: A remove address + :return: What the socket's connect_ex method returns + """ + connect_ex = self._socket.connect_ex # type: ignore[union-attr] + self.set_connect_state() + return connect_ex(addr) + + def accept(self) -> Tuple["Connection", Any]: + """ + Call the :meth:`accept` method of the underlying socket and set up SSL + on the returned socket, using the Context object supplied to this + :class:`Connection` object at creation. 
+ + :return: A *(conn, addr)* pair where *conn* is the new + :class:`Connection` object created, and *address* is as returned by + the socket's :meth:`accept`. + """ + client, addr = self._socket.accept() # type: ignore[union-attr] + conn = Connection(self._context, client) + conn.set_accept_state() + return (conn, addr) + + def DTLSv1_listen(self) -> None: + """ + Call the OpenSSL function DTLSv1_listen on this connection. See the + OpenSSL manual for more details. + + :return: None + """ + # Possible future extension: return the BIO_ADDR in some form. + bio_addr = _lib.BIO_ADDR_new() + try: + result = _lib.DTLSv1_listen(self._ssl, bio_addr) + finally: + _lib.BIO_ADDR_free(bio_addr) + # DTLSv1_listen is weird. A zero return value means 'didn't find a + # ClientHello with valid cookie, but keep trying'. So basically + # WantReadError. But it doesn't work correctly with _raise_ssl_error. + # So we raise it manually instead. + if self._cookie_generate_helper is not None: + self._cookie_generate_helper.raise_if_problem() + if self._cookie_verify_helper is not None: + self._cookie_verify_helper.raise_if_problem() + if result == 0: + raise WantReadError() + if result < 0: + self._raise_ssl_error(self._ssl, result) + + def DTLSv1_get_timeout(self) -> Optional[int]: + """ + Determine when the DTLS SSL object next needs to perform internal + processing due to the passage of time. + + When the returned number of seconds have passed, the + :meth:`DTLSv1_handle_timeout` method needs to be called. + + :return: The time left in seconds before the next timeout or `None` + if no timeout is currently active. + """ + ptv_sec = _ffi.new("time_t *") + ptv_usec = _ffi.new("long *") + if _lib.Cryptography_DTLSv1_get_timeout(self._ssl, ptv_sec, ptv_usec): + return ptv_sec[0] + (ptv_usec[0] / 1000000) + else: + return None + + def DTLSv1_handle_timeout(self) -> bool: + """ + Handles any timeout events which have become pending on a DTLS SSL + object. 
+ + :return: `True` if there was a pending timeout, `False` otherwise. + """ + result = _lib.DTLSv1_handle_timeout(self._ssl) + if result < 0: + self._raise_ssl_error(self._ssl, result) + assert False, "unreachable" + else: + return bool(result) + + def bio_shutdown(self) -> None: + """ + If the Connection was created with a memory BIO, this method can be + used to indicate that *end of file* has been reached on the read end of + that memory BIO. + + :return: None + """ + if self._from_ssl is None: + raise TypeError("Connection sock was not None") + + _lib.BIO_set_mem_eof_return(self._into_ssl, 0) + + def shutdown(self) -> bool: + """ + Send the shutdown message to the Connection. + + :return: True if the shutdown completed successfully (i.e. both sides + have sent closure alerts), False otherwise (in which case you + call :meth:`recv` or :meth:`send` when the connection becomes + readable/writeable). + """ + result = _lib.SSL_shutdown(self._ssl) + if result < 0: + self._raise_ssl_error(self._ssl, result) + assert False, "unreachable" + elif result > 0: + return True + else: + return False + + def get_cipher_list(self) -> List[str]: + """ + Retrieve the list of ciphers used by the Connection object. + + :return: A list of native cipher strings. + """ + ciphers = [] + for i in count(): + result = _lib.SSL_get_cipher_list(self._ssl, i) + if result == _ffi.NULL: + break + ciphers.append(_ffi.string(result).decode("utf-8")) + return ciphers + + def get_client_ca_list(self) -> List[X509Name]: + """ + Get CAs whose certificates are suggested for client authentication. + + :return: If this is a server connection, the list of certificate + authorities that will be sent or has been sent to the client, as + controlled by this :class:`Connection`'s :class:`Context`. + + If this is a client connection, the list will be empty until the + connection with the server is established. + + .. 
versionadded:: 0.10 + """ + ca_names = _lib.SSL_get_client_CA_list(self._ssl) + if ca_names == _ffi.NULL: + # TODO: This is untested. + return [] + + result = [] + for i in range(_lib.sk_X509_NAME_num(ca_names)): + name = _lib.sk_X509_NAME_value(ca_names, i) + copy = _lib.X509_NAME_dup(name) + _openssl_assert(copy != _ffi.NULL) + + pyname = X509Name.__new__(X509Name) + pyname._name = _ffi.gc(copy, _lib.X509_NAME_free) + result.append(pyname) + return result + + def makefile(self, *args: Any, **kwargs: Any) -> typing.NoReturn: + """ + The makefile() method is not implemented, since there is no dup + semantics for SSL connections + + :raise: NotImplementedError + """ + raise NotImplementedError( + "Cannot make file object of OpenSSL.SSL.Connection" + ) + + def get_app_data(self) -> Any: + """ + Retrieve application data as set by :meth:`set_app_data`. + + :return: The application data + """ + return self._app_data + + def set_app_data(self, data: Any) -> None: + """ + Set application data + + :param data: The application data + :return: None + """ + self._app_data = data + + def get_shutdown(self) -> int: + """ + Get the shutdown state of the Connection. + + :return: The shutdown state, a bitvector of SENT_SHUTDOWN, + RECEIVED_SHUTDOWN. + """ + return _lib.SSL_get_shutdown(self._ssl) + + def set_shutdown(self, state: int) -> None: + """ + Set the shutdown state of the Connection. + + :param state: bitvector of SENT_SHUTDOWN, RECEIVED_SHUTDOWN. + :return: None + """ + if not isinstance(state, int): + raise TypeError("state must be an integer") + + _lib.SSL_set_shutdown(self._ssl, state) + + def get_state_string(self) -> bytes: + """ + Retrieve a verbose string detailing the state of the Connection. + + :return: A string representing the state + """ + return _ffi.string(_lib.SSL_state_string_long(self._ssl)) + + def server_random(self) -> Optional[bytes]: + """ + Retrieve the random value used with the server hello message. 
+ + :return: A string representing the state + """ + session = _lib.SSL_get_session(self._ssl) + if session == _ffi.NULL: + return None + length = _lib.SSL_get_server_random(self._ssl, _ffi.NULL, 0) + _openssl_assert(length > 0) + outp = _no_zero_allocator("unsigned char[]", length) + _lib.SSL_get_server_random(self._ssl, outp, length) + return _ffi.buffer(outp, length)[:] + + def client_random(self) -> Optional[bytes]: + """ + Retrieve the random value used with the client hello message. + + :return: A string representing the state + """ + session = _lib.SSL_get_session(self._ssl) + if session == _ffi.NULL: + return None + + length = _lib.SSL_get_client_random(self._ssl, _ffi.NULL, 0) + _openssl_assert(length > 0) + outp = _no_zero_allocator("unsigned char[]", length) + _lib.SSL_get_client_random(self._ssl, outp, length) + return _ffi.buffer(outp, length)[:] + + def master_key(self) -> Optional[bytes]: + """ + Retrieve the value of the master key for this session. + + :return: A string representing the state + """ + session = _lib.SSL_get_session(self._ssl) + if session == _ffi.NULL: + return None + + length = _lib.SSL_SESSION_get_master_key(session, _ffi.NULL, 0) + _openssl_assert(length > 0) + outp = _no_zero_allocator("unsigned char[]", length) + _lib.SSL_SESSION_get_master_key(session, outp, length) + return _ffi.buffer(outp, length)[:] + + def export_keying_material( + self, label: bytes, olen: int, context: Optional[bytes] = None + ) -> bytes: + """ + Obtain keying material for application use. 
+ + :param: label - a disambiguating label string as described in RFC 5705 + :param: olen - the length of the exported key material in bytes + :param: context - a per-association context value + :return: the exported key material bytes or None + """ + outp = _no_zero_allocator("unsigned char[]", olen) + context_buf = _ffi.NULL + context_len = 0 + use_context = 0 + if context is not None: + context_buf = context + context_len = len(context) + use_context = 1 + success = _lib.SSL_export_keying_material( + self._ssl, + outp, + olen, + label, + len(label), + context_buf, + context_len, + use_context, + ) + _openssl_assert(success == 1) + return _ffi.buffer(outp, olen)[:] + + def sock_shutdown(self, *args: Any, **kwargs: Any) -> None: + """ + Call the :meth:`shutdown` method of the underlying socket. + See :manpage:`shutdown(2)`. + + :return: What the socket's shutdown() method returns + """ + return self._socket.shutdown(*args, **kwargs) # type: ignore[return-value, union-attr] + + def get_certificate(self) -> Optional[X509]: + """ + Retrieve the local certificate (if any) + + :return: The local certificate + """ + cert = _lib.SSL_get_certificate(self._ssl) + if cert != _ffi.NULL: + _lib.X509_up_ref(cert) + return X509._from_raw_x509_ptr(cert) + return None + + def get_peer_certificate(self) -> Optional[X509]: + """ + Retrieve the other side's certificate (if any) + + :return: The peer's certificate + """ + cert = _lib.SSL_get_peer_certificate(self._ssl) + if cert != _ffi.NULL: + return X509._from_raw_x509_ptr(cert) + return None + + @staticmethod + def _cert_stack_to_list(cert_stack: Any) -> List[X509]: + """ + Internal helper to convert a STACK_OF(X509) to a list of X509 + instances. 
+ """ + result = [] + for i in range(_lib.sk_X509_num(cert_stack)): + cert = _lib.sk_X509_value(cert_stack, i) + _openssl_assert(cert != _ffi.NULL) + res = _lib.X509_up_ref(cert) + _openssl_assert(res >= 1) + pycert = X509._from_raw_x509_ptr(cert) + result.append(pycert) + return result + + def get_peer_cert_chain(self) -> Optional[List[X509]]: + """ + Retrieve the other side's certificate (if any) + + :return: A list of X509 instances giving the peer's certificate chain, + or None if it does not have one. + """ + cert_stack = _lib.SSL_get_peer_cert_chain(self._ssl) + if cert_stack == _ffi.NULL: + return None + + return self._cert_stack_to_list(cert_stack) + + def get_verified_chain(self) -> Optional[List[X509]]: + """ + Retrieve the verified certificate chain of the peer including the + peer's end entity certificate. It must be called after a session has + been successfully established. If peer verification was not successful + the chain may be incomplete, invalid, or None. + + :return: A list of X509 instances giving the peer's verified + certificate chain, or None if it does not have one. + + .. versionadded:: 20.0 + """ + # OpenSSL 1.1+ + cert_stack = _lib.SSL_get0_verified_chain(self._ssl) + if cert_stack == _ffi.NULL: + return None + + return self._cert_stack_to_list(cert_stack) + + def want_read(self) -> bool: + """ + Checks if more data has to be read from the transport layer to complete + an operation. + + :return: True iff more data has to be read + """ + return _lib.SSL_want_read(self._ssl) + + def want_write(self) -> bool: + """ + Checks if there is data to write to the transport layer to complete an + operation. + + :return: True iff there is data to write + """ + return _lib.SSL_want_write(self._ssl) + + def set_accept_state(self) -> None: + """ + Set the connection to work in server mode. The handshake will be + handled automatically by read/write. 
+ + :return: None + """ + _lib.SSL_set_accept_state(self._ssl) + + def set_connect_state(self) -> None: + """ + Set the connection to work in client mode. The handshake will be + handled automatically by read/write. + + :return: None + """ + _lib.SSL_set_connect_state(self._ssl) + + def get_session(self) -> Optional[Session]: + """ + Returns the Session currently used. + + :return: An instance of :class:`OpenSSL.SSL.Session` or + :obj:`None` if no session exists. + + .. versionadded:: 0.14 + """ + session = _lib.SSL_get1_session(self._ssl) + if session == _ffi.NULL: + return None + + pysession = Session.__new__(Session) + pysession._session = _ffi.gc(session, _lib.SSL_SESSION_free) + return pysession + + def set_session(self, session: Session) -> None: + """ + Set the session to be used when the TLS/SSL connection is established. + + :param session: A Session instance representing the session to use. + :returns: None + + .. versionadded:: 0.14 + """ + if not isinstance(session, Session): + raise TypeError("session must be a Session instance") + + result = _lib.SSL_set_session(self._ssl, session._session) + _openssl_assert(result == 1) + + def _get_finished_message( + self, function: Callable[[Any, Any, int], int] + ) -> Optional[bytes]: + """ + Helper to implement :meth:`get_finished` and + :meth:`get_peer_finished`. + + :param function: Either :data:`SSL_get_finished`: or + :data:`SSL_get_peer_finished`. + + :return: :data:`None` if the desired message has not yet been + received, otherwise the contents of the message. + """ + # The OpenSSL documentation says nothing about what might happen if the + # count argument given is zero. Specifically, it doesn't say whether + # the output buffer may be NULL in that case or not. Inspection of the + # implementation reveals that it calls memcpy() unconditionally. 
+ # Section 7.1.4, paragraph 1 of the C standard suggests that + # memcpy(NULL, source, 0) is not guaranteed to produce defined (let + # alone desirable) behavior (though it probably does on just about + # every implementation...) + # + # Allocate a tiny buffer to pass in (instead of just passing NULL as + # one might expect) for the initial call so as to be safe against this + # potentially undefined behavior. + empty = _ffi.new("char[]", 0) + size = function(self._ssl, empty, 0) + if size == 0: + # No Finished message so far. + return None + + buf = _no_zero_allocator("char[]", size) + function(self._ssl, buf, size) + return _ffi.buffer(buf, size)[:] + + def get_finished(self) -> Optional[bytes]: + """ + Obtain the latest TLS Finished message that we sent. + + :return: The contents of the message or :obj:`None` if the TLS + handshake has not yet completed. + + .. versionadded:: 0.15 + """ + return self._get_finished_message(_lib.SSL_get_finished) + + def get_peer_finished(self) -> Optional[bytes]: + """ + Obtain the latest TLS Finished message that we received from the peer. + + :return: The contents of the message or :obj:`None` if the TLS + handshake has not yet completed. + + .. versionadded:: 0.15 + """ + return self._get_finished_message(_lib.SSL_get_peer_finished) + + def get_cipher_name(self) -> Optional[str]: + """ + Obtain the name of the currently used cipher. + + :returns: The name of the currently used cipher or :obj:`None` + if no connection has been established. + + .. versionadded:: 0.15 + """ + cipher = _lib.SSL_get_current_cipher(self._ssl) + if cipher == _ffi.NULL: + return None + else: + name = _ffi.string(_lib.SSL_CIPHER_get_name(cipher)) + return name.decode("utf-8") + + def get_cipher_bits(self) -> Optional[int]: + """ + Obtain the number of secret bits of the currently used cipher. + + :returns: The number of secret bits of the currently used cipher + or :obj:`None` if no connection has been established. + + .. 
versionadded:: 0.15 + """ + cipher = _lib.SSL_get_current_cipher(self._ssl) + if cipher == _ffi.NULL: + return None + else: + return _lib.SSL_CIPHER_get_bits(cipher, _ffi.NULL) + + def get_cipher_version(self) -> Optional[str]: + """ + Obtain the protocol version of the currently used cipher. + + :returns: The protocol name of the currently used cipher + or :obj:`None` if no connection has been established. + + .. versionadded:: 0.15 + """ + cipher = _lib.SSL_get_current_cipher(self._ssl) + if cipher == _ffi.NULL: + return None + else: + version = _ffi.string(_lib.SSL_CIPHER_get_version(cipher)) + return version.decode("utf-8") + + def get_protocol_version_name(self) -> str: + """ + Retrieve the protocol version of the current connection. + + :returns: The TLS version of the current connection, for example + the value for TLS 1.2 would be ``TLSv1.2``or ``Unknown`` + for connections that were not successfully established. + """ + version = _ffi.string(_lib.SSL_get_version(self._ssl)) + return version.decode("utf-8") + + def get_protocol_version(self) -> int: + """ + Retrieve the SSL or TLS protocol version of the current connection. + + :returns: The TLS version of the current connection. For example, + it will return ``0x769`` for connections made over TLS version 1. + """ + version = _lib.SSL_version(self._ssl) + return version + + @_requires_alpn + def set_alpn_protos(self, protos: List[bytes]) -> None: + """ + Specify the client's ALPN protocol list. + + These protocols are offered to the server during protocol negotiation. + + :param protos: A list of the protocols to be offered to the server. + This list should be a Python list of bytestrings representing the + protocols to offer, e.g. ``[b'http/1.1', b'spdy/2']``. + """ + # Different versions of OpenSSL are inconsistent about how they handle + # empty proto lists (see #1043), so we avoid the problem entirely by + # rejecting them ourselves. 
+ if not protos: + raise ValueError("at least one protocol must be specified") + + # Take the list of protocols and join them together, prefixing them + # with their lengths. + protostr = b"".join( + chain.from_iterable((bytes((len(p),)), p) for p in protos) + ) + + # Build a C string from the list. We don't need to save this off + # because OpenSSL immediately copies the data out. + input_str = _ffi.new("unsigned char[]", protostr) + + # https://www.openssl.org/docs/man1.1.0/man3/SSL_CTX_set_alpn_protos.html: + # SSL_CTX_set_alpn_protos() and SSL_set_alpn_protos() + # return 0 on success, and non-0 on failure. + # WARNING: these functions reverse the return value convention. + _openssl_assert( + _lib.SSL_set_alpn_protos(self._ssl, input_str, len(protostr)) == 0 + ) + + @_requires_alpn + def get_alpn_proto_negotiated(self) -> bytes: + """ + Get the protocol that was negotiated by ALPN. + + :returns: A bytestring of the protocol name. If no protocol has been + negotiated yet, returns an empty bytestring. + """ + data = _ffi.new("unsigned char **") + data_len = _ffi.new("unsigned int *") + + _lib.SSL_get0_alpn_selected(self._ssl, data, data_len) + + if not data_len: + return b"" + + return _ffi.buffer(data[0], data_len[0])[:] + + def get_selected_srtp_profile(self) -> bytes: + """ + Get the SRTP protocol which was negotiated. + + :returns: A bytestring of the SRTP profile name. If no profile has been + negotiated yet, returns an empty bytestring. + """ + profile = _lib.SSL_get_selected_srtp_profile(self._ssl) + if not profile: + return b"" + + return _ffi.string(profile.name) + + def request_ocsp(self) -> None: + """ + Called to request that the server sends stapled OCSP data, if + available. If this is not called on the client side then the server + will not send OCSP data. Should be used in conjunction with + :meth:`Context.set_ocsp_client_callback`. 
+ """ + rc = _lib.SSL_set_tlsext_status_type( + self._ssl, _lib.TLSEXT_STATUSTYPE_ocsp + ) + _openssl_assert(rc == 1) diff --git a/.venv/lib/python3.11/site-packages/OpenSSL/__init__.py b/.venv/lib/python3.11/site-packages/OpenSSL/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2442ee6f0540f1ef626bb55e314cc4c2c8da3a4d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/OpenSSL/__init__.py @@ -0,0 +1,31 @@ +# Copyright (C) AB Strakt +# See LICENSE for details. + +""" +pyOpenSSL - A simple wrapper around the OpenSSL library +""" + +from OpenSSL import SSL, crypto +from OpenSSL.version import ( + __author__, + __copyright__, + __email__, + __license__, + __summary__, + __title__, + __uri__, + __version__, +) + +__all__ = [ + "SSL", + "crypto", + "__author__", + "__copyright__", + "__email__", + "__license__", + "__summary__", + "__title__", + "__uri__", + "__version__", +] diff --git a/.venv/lib/python3.11/site-packages/OpenSSL/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/OpenSSL/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f75ab5d62862692e31f4235680655c2f0309ece4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/OpenSSL/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/OpenSSL/__pycache__/_util.cpython-311.pyc b/.venv/lib/python3.11/site-packages/OpenSSL/__pycache__/_util.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86446507410b40883570d7da8b0c84c410c60293 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/OpenSSL/__pycache__/_util.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/OpenSSL/__pycache__/debug.cpython-311.pyc b/.venv/lib/python3.11/site-packages/OpenSSL/__pycache__/debug.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bc867eb4da991479fd800c9e1e11cc83b96bf36 Binary 
files /dev/null and b/.venv/lib/python3.11/site-packages/OpenSSL/__pycache__/debug.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/OpenSSL/__pycache__/rand.cpython-311.pyc b/.venv/lib/python3.11/site-packages/OpenSSL/__pycache__/rand.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75020fcbac5fc7207349210bb01940872e3d35ae Binary files /dev/null and b/.venv/lib/python3.11/site-packages/OpenSSL/__pycache__/rand.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/OpenSSL/__pycache__/version.cpython-311.pyc b/.venv/lib/python3.11/site-packages/OpenSSL/__pycache__/version.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92105ba1b6c612c2c1d74911fbf572f2d9eb847f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/OpenSSL/__pycache__/version.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/OpenSSL/_util.py b/.venv/lib/python3.11/site-packages/OpenSSL/_util.py new file mode 100644 index 0000000000000000000000000000000000000000..7a102e6c9105382d52f0fee50e1b68ff53644596 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/OpenSSL/_util.py @@ -0,0 +1,124 @@ +import os +import sys +import warnings +from typing import Any, Callable, NoReturn, Type, Union + +from cryptography.hazmat.bindings.openssl.binding import Binding + +StrOrBytesPath = Union[str, bytes, os.PathLike] + +binding = Binding() +ffi = binding.ffi +lib = binding.lib + + +# This is a special CFFI allocator that does not bother to zero its memory +# after allocation. This has vastly better performance on large allocations and +# so should be used whenever we don't need the memory zeroed out. +no_zero_allocator = ffi.new_allocator(should_clear_after_alloc=False) + + +def text(charp: Any) -> str: + """ + Get a native string type representing of the given CFFI ``char*`` object. + + :param charp: A C-style string represented using CFFI. 
+ + :return: :class:`str` + """ + if not charp: + return "" + return ffi.string(charp).decode("utf-8") + + +def exception_from_error_queue(exception_type: Type[Exception]) -> NoReturn: + """ + Convert an OpenSSL library failure into a Python exception. + + When a call to the native OpenSSL library fails, this is usually signalled + by the return value, and an error code is stored in an error queue + associated with the current thread. The err library provides functions to + obtain these error codes and textual error messages. + """ + errors = [] + + while True: + error = lib.ERR_get_error() + if error == 0: + break + errors.append( + ( + text(lib.ERR_lib_error_string(error)), + text(lib.ERR_func_error_string(error)), + text(lib.ERR_reason_error_string(error)), + ) + ) + + raise exception_type(errors) + + +def make_assert(error: Type[Exception]) -> Callable[[bool], Any]: + """ + Create an assert function that uses :func:`exception_from_error_queue` to + raise an exception wrapped by *error*. + """ + + def openssl_assert(ok: bool) -> None: + """ + If *ok* is not True, retrieve the error from OpenSSL and raise it. + """ + if ok is not True: + exception_from_error_queue(error) + + return openssl_assert + + +def path_bytes(s: StrOrBytesPath) -> bytes: + """ + Convert a Python path to a :py:class:`bytes` for the path which can be + passed into an OpenSSL API accepting a filename. + + :param s: A path (valid for os.fspath). + + :return: An instance of :py:class:`bytes`. + """ + b = os.fspath(s) + + if isinstance(b, str): + return b.encode(sys.getfilesystemencoding()) + else: + return b + + +def byte_string(s: str) -> bytes: + return s.encode("charmap") + + +# A marker object to observe whether some optional arguments are passed any +# value or not. 
+UNSPECIFIED = object() + +_TEXT_WARNING = "str for {0} is no longer accepted, use bytes" + + +def text_to_bytes_and_warn(label: str, obj: Any) -> Any: + """ + If ``obj`` is text, emit a warning that it should be bytes instead and try + to convert it to bytes automatically. + + :param str label: The name of the parameter from which ``obj`` was taken + (so a developer can easily find the source of the problem and correct + it). + + :return: If ``obj`` is the text string type, a ``bytes`` object giving the + UTF-8 encoding of that text is returned. Otherwise, ``obj`` itself is + returned. + """ + if isinstance(obj, str): + warnings.warn( + _TEXT_WARNING.format(label), + category=DeprecationWarning, + stacklevel=3, + ) + return obj.encode("utf-8") + return obj diff --git a/.venv/lib/python3.11/site-packages/OpenSSL/crypto.py b/.venv/lib/python3.11/site-packages/OpenSSL/crypto.py new file mode 100644 index 0000000000000000000000000000000000000000..07c112a646903c3e70888cc0ab5009e3221c015a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/OpenSSL/crypto.py @@ -0,0 +1,3062 @@ +import calendar +import datetime +import functools +import typing +from base64 import b16encode +from functools import partial +from os import PathLike +from typing import ( + Any, + Callable, + Iterable, + List, + NoReturn, + Optional, + Sequence, + Set, + Tuple, + Type, + Union, +) + +from cryptography import utils, x509 +from cryptography.hazmat.primitives.asymmetric import ( + dsa, + ec, + ed448, + ed25519, + rsa, +) + +from OpenSSL._util import ( + UNSPECIFIED as _UNSPECIFIED, +) +from OpenSSL._util import ( + byte_string as _byte_string, +) +from OpenSSL._util import ( + exception_from_error_queue as _exception_from_error_queue, +) +from OpenSSL._util import ( + ffi as _ffi, +) +from OpenSSL._util import ( + lib as _lib, +) +from OpenSSL._util import ( + make_assert as _make_assert, +) +from OpenSSL._util import ( + path_bytes as _path_bytes, +) +from OpenSSL._util import ( + 
text_to_bytes_and_warn as _text_to_bytes_and_warn, +) + +__all__ = [ + "FILETYPE_PEM", + "FILETYPE_ASN1", + "FILETYPE_TEXT", + "TYPE_RSA", + "TYPE_DSA", + "Error", + "PKey", + "get_elliptic_curves", + "get_elliptic_curve", + "X509Name", + "X509Extension", + "X509Req", + "X509", + "X509StoreFlags", + "X509Store", + "X509StoreContextError", + "X509StoreContext", + "load_certificate", + "dump_certificate", + "dump_publickey", + "dump_privatekey", + "Revoked", + "CRL", + "load_publickey", + "load_privatekey", + "dump_certificate_request", + "load_certificate_request", + "sign", + "verify", + "dump_crl", + "load_crl", +] + + +_Key = Union[ + dsa.DSAPrivateKey, + dsa.DSAPublicKey, + ec.EllipticCurvePrivateKey, + ec.EllipticCurvePublicKey, + ed25519.Ed25519PrivateKey, + ed25519.Ed25519PublicKey, + ed448.Ed448PrivateKey, + ed448.Ed448PublicKey, + rsa.RSAPrivateKey, + rsa.RSAPublicKey, +] +StrOrBytesPath = Union[str, bytes, PathLike] +PassphraseCallableT = Union[bytes, Callable[..., bytes]] + + +FILETYPE_PEM: int = _lib.SSL_FILETYPE_PEM +FILETYPE_ASN1: int = _lib.SSL_FILETYPE_ASN1 + +# TODO This was an API mistake. OpenSSL has no such constant. +FILETYPE_TEXT = 2**16 - 1 + +TYPE_RSA: int = _lib.EVP_PKEY_RSA +TYPE_DSA: int = _lib.EVP_PKEY_DSA +TYPE_DH: int = _lib.EVP_PKEY_DH +TYPE_EC: int = _lib.EVP_PKEY_EC + + +class Error(Exception): + """ + An error occurred in an `OpenSSL.crypto` API. + """ + + +_raise_current_error = partial(_exception_from_error_queue, Error) +_openssl_assert = _make_assert(Error) + + +def _untested_error(where: str) -> NoReturn: + """ + An OpenSSL API failed somehow. Additionally, the failure which was + encountered isn't one that's exercised by the test suite so future behavior + of pyOpenSSL is now somewhat less predictable. + """ + raise RuntimeError(f"Unknown {where} failure") + + +def _new_mem_buf(buffer: Optional[bytes] = None) -> Any: + """ + Allocate a new OpenSSL memory BIO. + + Arrange for the garbage collector to clean it up automatically. 
+ + :param buffer: None or some bytes to use to put into the BIO so that they + can be read out. + """ + if buffer is None: + bio = _lib.BIO_new(_lib.BIO_s_mem()) + free = _lib.BIO_free + else: + data = _ffi.new("char[]", buffer) + bio = _lib.BIO_new_mem_buf(data, len(buffer)) + + # Keep the memory alive as long as the bio is alive! + def free(bio: Any, ref: Any = data) -> Any: + return _lib.BIO_free(bio) + + _openssl_assert(bio != _ffi.NULL) + + bio = _ffi.gc(bio, free) + return bio + + +def _bio_to_string(bio: Any) -> bytes: + """ + Copy the contents of an OpenSSL BIO object into a Python byte string. + """ + result_buffer = _ffi.new("char**") + buffer_length = _lib.BIO_get_mem_data(bio, result_buffer) + return _ffi.buffer(result_buffer[0], buffer_length)[:] + + +def _set_asn1_time(boundary: Any, when: bytes) -> None: + """ + The the time value of an ASN1 time object. + + @param boundary: An ASN1_TIME pointer (or an object safely + castable to that type) which will have its value set. + @param when: A string representation of the desired time value. + + @raise TypeError: If C{when} is not a L{bytes} string. + @raise ValueError: If C{when} does not represent a time in the required + format. + @raise RuntimeError: If the time value cannot be set for some other + (unspecified) reason. + """ + if not isinstance(when, bytes): + raise TypeError("when must be a byte string") + # ASN1_TIME_set_string validates the string without writing anything + # when the destination is NULL. + _openssl_assert(boundary != _ffi.NULL) + + set_result = _lib.ASN1_TIME_set_string(boundary, when) + if set_result == 0: + raise ValueError("Invalid string") + + +def _new_asn1_time(when: bytes) -> Any: + """ + Behaves like _set_asn1_time but returns a new ASN1_TIME object. + + @param when: A string representation of the desired time value. + + @raise TypeError: If C{when} is not a L{bytes} string. + @raise ValueError: If C{when} does not represent a time in the required + format. 
+ @raise RuntimeError: If the time value cannot be set for some other + (unspecified) reason. + """ + ret = _lib.ASN1_TIME_new() + _openssl_assert(ret != _ffi.NULL) + ret = _ffi.gc(ret, _lib.ASN1_TIME_free) + _set_asn1_time(ret, when) + return ret + + +def _get_asn1_time(timestamp: Any) -> Optional[bytes]: + """ + Retrieve the time value of an ASN1 time object. + + @param timestamp: An ASN1_GENERALIZEDTIME* (or an object safely castable to + that type) from which the time value will be retrieved. + + @return: The time value from C{timestamp} as a L{bytes} string in a certain + format. Or C{None} if the object contains no time value. + """ + string_timestamp = _ffi.cast("ASN1_STRING*", timestamp) + if _lib.ASN1_STRING_length(string_timestamp) == 0: + return None + elif ( + _lib.ASN1_STRING_type(string_timestamp) == _lib.V_ASN1_GENERALIZEDTIME + ): + return _ffi.string(_lib.ASN1_STRING_get0_data(string_timestamp)) + else: + generalized_timestamp = _ffi.new("ASN1_GENERALIZEDTIME**") + _lib.ASN1_TIME_to_generalizedtime(timestamp, generalized_timestamp) + if generalized_timestamp[0] == _ffi.NULL: + # This may happen: + # - if timestamp was not an ASN1_TIME + # - if allocating memory for the ASN1_GENERALIZEDTIME failed + # - if a copy of the time data from timestamp cannot be made for + # the newly allocated ASN1_GENERALIZEDTIME + # + # These are difficult to test. cffi enforces the ASN1_TIME type. + # Memory allocation failures are a pain to trigger + # deterministically. 
+ _untested_error("ASN1_TIME_to_generalizedtime") + else: + string_timestamp = _ffi.cast( + "ASN1_STRING*", generalized_timestamp[0] + ) + string_data = _lib.ASN1_STRING_get0_data(string_timestamp) + string_result = _ffi.string(string_data) + _lib.ASN1_GENERALIZEDTIME_free(generalized_timestamp[0]) + return string_result + + +class _X509NameInvalidator: + def __init__(self) -> None: + self._names: List[X509Name] = [] + + def add(self, name: "X509Name") -> None: + self._names.append(name) + + def clear(self) -> None: + for name in self._names: + # Breaks the object, but also prevents UAF! + del name._name + + +class PKey: + """ + A class representing an DSA or RSA public key or key pair. + """ + + _only_public = False + _initialized = True + + def __init__(self) -> None: + pkey = _lib.EVP_PKEY_new() + self._pkey = _ffi.gc(pkey, _lib.EVP_PKEY_free) + self._initialized = False + + def to_cryptography_key(self) -> _Key: + """ + Export as a ``cryptography`` key. + + :rtype: One of ``cryptography``'s `key interfaces`_. + + .. _key interfaces: https://cryptography.io/en/latest/hazmat/\ + primitives/asymmetric/rsa/#key-interfaces + + .. versionadded:: 16.1.0 + """ + from cryptography.hazmat.primitives.serialization import ( + load_der_private_key, + load_der_public_key, + ) + + if self._only_public: + der = dump_publickey(FILETYPE_ASN1, self) + return load_der_public_key(der) + else: + der = dump_privatekey(FILETYPE_ASN1, self) + return load_der_private_key(der, None) + + @classmethod + def from_cryptography_key(cls, crypto_key: _Key) -> "PKey": + """ + Construct based on a ``cryptography`` *crypto_key*. + + :param crypto_key: A ``cryptography`` key. + :type crypto_key: One of ``cryptography``'s `key interfaces`_. + + :rtype: PKey + + .. 
def generate_key(self, type: int, bits: int) -> None:
    """
    Generate a key pair of the given type, with the given number of bits.

    This generates a key "into" this object.

    :param type: The key type.
    :type type: :py:data:`TYPE_RSA` or :py:data:`TYPE_DSA`
    :param bits: The number of bits.
    :type bits: :py:data:`int` ``>= 0``
    :raises TypeError: If :py:data:`type` or :py:data:`bits` isn't
        of the appropriate type.
    :raises ValueError: If an RSA key is requested with a number of bits
        that is not positive.
    :raises Error: If the key type is neither RSA nor DSA, or if OpenSSL
        fails to allocate or generate the key.
    :return: ``None``
    """
    if not isinstance(type, int):
        raise TypeError("type must be an integer")

    if not isinstance(bits, int):
        raise TypeError("bits must be an integer")

    if type == TYPE_RSA:
        if bits <= 0:
            raise ValueError("Invalid number of bits")

        exponent = _lib.BN_new()
        _openssl_assert(exponent != _ffi.NULL)
        exponent = _ffi.gc(exponent, _lib.BN_free)
        _openssl_assert(_lib.BN_set_word(exponent, _lib.RSA_F4) == 1)

        rsa = _lib.RSA_new()
        _openssl_assert(rsa != _ffi.NULL)

        # EVP_PKEY_assign_RSA takes ownership of ``rsa`` only when it
        # succeeds, so until then this function must free it on any failure.
        try:
            result = _lib.RSA_generate_key_ex(rsa, bits, exponent, _ffi.NULL)
            _openssl_assert(result == 1)

            result = _lib.EVP_PKEY_assign_RSA(self._pkey, rsa)
            _openssl_assert(result == 1)
        except Exception:
            _lib.RSA_free(rsa)
            raise

    elif type == TYPE_DSA:
        dsa = _lib.DSA_new()
        _openssl_assert(dsa != _ffi.NULL)

        # EVP_PKEY_set1_DSA takes its own reference, so the local one can be
        # released by the garbage collector.
        dsa = _ffi.gc(dsa, _lib.DSA_free)
        res = _lib.DSA_generate_parameters_ex(
            dsa, bits, _ffi.NULL, 0, _ffi.NULL, _ffi.NULL, _ffi.NULL
        )
        _openssl_assert(res == 1)

        _openssl_assert(_lib.DSA_generate_key(dsa) == 1)
        _openssl_assert(_lib.EVP_PKEY_set1_DSA(self._pkey, dsa) == 1)
    else:
        raise Error("No such key type")

    self._initialized = True
+ """ + return _lib.EVP_PKEY_id(self._pkey) + + def bits(self) -> int: + """ + Returns the number of bits of the key + + :return: The number of bits of the key. + """ + return _lib.EVP_PKEY_bits(self._pkey) + + +class _EllipticCurve: + """ + A representation of a supported elliptic curve. + + @cvar _curves: :py:obj:`None` until an attempt is made to load the curves. + Thereafter, a :py:type:`set` containing :py:type:`_EllipticCurve` + instances each of which represents one curve supported by the system. + @type _curves: :py:type:`NoneType` or :py:type:`set` + """ + + _curves = None + + def __ne__(self, other: Any) -> bool: + """ + Implement cooperation with the right-hand side argument of ``!=``. + + Python 3 seems to have dropped this cooperation in this very narrow + circumstance. + """ + if isinstance(other, _EllipticCurve): + return super().__ne__(other) + return NotImplemented + + @classmethod + def _load_elliptic_curves(cls, lib: Any) -> Set["_EllipticCurve"]: + """ + Get the curves supported by OpenSSL. + + :param lib: The OpenSSL library binding object. + + :return: A :py:type:`set` of ``cls`` instances giving the names of the + elliptic curves the underlying library supports. + """ + num_curves = lib.EC_get_builtin_curves(_ffi.NULL, 0) + builtin_curves = _ffi.new("EC_builtin_curve[]", num_curves) + # The return value on this call should be num_curves again. We + # could check it to make sure but if it *isn't* then.. what could + # we do? Abort the whole process, I suppose...? -exarkun + lib.EC_get_builtin_curves(builtin_curves, num_curves) + return set(cls.from_nid(lib, c.nid) for c in builtin_curves) + + @classmethod + def _get_elliptic_curves(cls, lib: Any) -> Set["_EllipticCurve"]: + """ + Get, cache, and return the curves supported by OpenSSL. + + :param lib: The OpenSSL library binding object. + + :return: A :py:type:`set` of ``cls`` instances giving the names of the + elliptic curves the underlying library supports. 
+ """ + if cls._curves is None: + cls._curves = cls._load_elliptic_curves(lib) + return cls._curves + + @classmethod + def from_nid(cls, lib: Any, nid: int) -> "_EllipticCurve": + """ + Instantiate a new :py:class:`_EllipticCurve` associated with the given + OpenSSL NID. + + :param lib: The OpenSSL library binding object. + + :param nid: The OpenSSL NID the resulting curve object will represent. + This must be a curve NID (and not, for example, a hash NID) or + subsequent operations will fail in unpredictable ways. + :type nid: :py:class:`int` + + :return: The curve object. + """ + return cls(lib, nid, _ffi.string(lib.OBJ_nid2sn(nid)).decode("ascii")) + + def __init__(self, lib: Any, nid: int, name: str) -> None: + """ + :param _lib: The :py:mod:`cryptography` binding instance used to + interface with OpenSSL. + + :param _nid: The OpenSSL NID identifying the curve this object + represents. + :type _nid: :py:class:`int` + + :param name: The OpenSSL short name identifying the curve this object + represents. + :type name: :py:class:`unicode` + """ + self._lib = lib + self._nid = nid + self.name = name + + def __repr__(self) -> str: + return f"" + + def _to_EC_KEY(self) -> Any: + """ + Create a new OpenSSL EC_KEY structure initialized to use this curve. + + The structure is automatically garbage collected when the Python object + is garbage collected. + """ + key = self._lib.EC_KEY_new_by_curve_name(self._nid) + return _ffi.gc(key, _lib.EC_KEY_free) + + +def get_elliptic_curves() -> Set["_EllipticCurve"]: + """ + Return a set of objects representing the elliptic curves supported in the + OpenSSL build in use. + + The curve objects have a :py:class:`unicode` ``name`` attribute by which + they identify themselves. + + The curve objects are useful as values for the argument accepted by + :py:meth:`Context.set_tmp_ecdh` to specify which elliptical curve should be + used for ECDHE key exchange. 
@functools.total_ordering
class X509Name:
    """
    An X.509 Distinguished Name.

    :ivar countryName: The country of the entity.
    :ivar C: Alias for  :py:attr:`countryName`.

    :ivar stateOrProvinceName: The state or province of the entity.
    :ivar ST: Alias for :py:attr:`stateOrProvinceName`.

    :ivar localityName: The locality of the entity.
    :ivar L: Alias for :py:attr:`localityName`.

    :ivar organizationName: The organization name of the entity.
    :ivar O: Alias for :py:attr:`organizationName`.

    :ivar organizationalUnitName: The organizational unit of the entity.
    :ivar OU: Alias for :py:attr:`organizationalUnitName`

    :ivar commonName: The common name of the entity.
    :ivar CN: Alias for :py:attr:`commonName`.

    :ivar emailAddress: The e-mail address of the entity.
    """

    def __init__(self, name: "X509Name") -> None:
        """
        Create a new X509Name, copying the given X509Name instance.

        :param name: The name to copy.
        :type name: :py:class:`X509Name`
        """
        name = _lib.X509_NAME_dup(name._name)
        self._name: Any = _ffi.gc(name, _lib.X509_NAME_free)

    def __setattr__(self, name: str, value: Any) -> None:
        """
        Set a name component, or a private Python attribute.

        Names starting with ``_`` are stored as ordinary attributes. Any
        other name is resolved to an OpenSSL NID; an existing entry for that
        NID is removed and *value* is added as the new entry.

        :raise TypeError: If *name* is not exactly a :py:class:`str`.
        :raise AttributeError: If OpenSSL does not know *name*.
        """
        if name.startswith("_"):
            return super().__setattr__(name, value)

        # Note: we really do not want str subclasses here, so we do not use
        # isinstance.
        if type(name) is not str:
            # Report the type of the offending *name*, not of the value.
            raise TypeError(
                f"attribute name must be string, not "
                f"'{type(name).__name__:.200}'"
            )

        nid = _lib.OBJ_txt2nid(_byte_string(name))
        if nid == _lib.NID_undef:
            # The failed lookup leaves an entry on the OpenSSL error queue;
            # raise and discard it so it does not surface later.
            try:
                _raise_current_error()
            except Error:
                pass
            raise AttributeError("No such attribute")

        # If there's an old entry for this NID, remove it
        for i in range(_lib.X509_NAME_entry_count(self._name)):
            ent = _lib.X509_NAME_get_entry(self._name, i)
            ent_obj = _lib.X509_NAME_ENTRY_get_object(ent)
            ent_nid = _lib.OBJ_obj2nid(ent_obj)
            if nid == ent_nid:
                ent = _lib.X509_NAME_delete_entry(self._name, i)
                _lib.X509_NAME_ENTRY_free(ent)
                break

        if isinstance(value, str):
            value = value.encode("utf-8")

        add_result = _lib.X509_NAME_add_entry_by_NID(
            self._name, nid, _lib.MBSTRING_UTF8, value, -1, -1, 0
        )
        if not add_result:
            _raise_current_error()

    def __getattr__(self, name: str) -> Optional[str]:
        """
        Find attribute. An X509Name object has the following attributes:
        countryName (alias C), stateOrProvince (alias ST), locality (alias L),
        organization (alias O), organizationalUnit (alias OU), commonName
        (alias CN) and more...

        :return: The UTF-8 decoded value of the component, or ``None`` when
            the name has no entry for it.
        :raise AttributeError: If OpenSSL does not know *name*.
        """
        nid = _lib.OBJ_txt2nid(_byte_string(name))
        if nid == _lib.NID_undef:
            # This is a bit weird.  OBJ_txt2nid indicated failure, but it seems
            # a lower level function, a2d_ASN1_OBJECT, also feels the need to
            # push something onto the error queue.  If we don't clean that up
            # now, someone else will bump into it later and be quite confused.
            # See lp#314814.
            try:
                _raise_current_error()
            except Error:
                pass
            raise AttributeError("No such attribute")

        entry_index = _lib.X509_NAME_get_index_by_NID(self._name, nid, -1)
        if entry_index == -1:
            return None

        entry = _lib.X509_NAME_get_entry(self._name, entry_index)
        data = _lib.X509_NAME_ENTRY_get_data(entry)

        result_buffer = _ffi.new("unsigned char**")
        data_length = _lib.ASN1_STRING_to_UTF8(result_buffer, data)
        _openssl_assert(data_length >= 0)

        try:
            result = _ffi.buffer(result_buffer[0], data_length)[:].decode(
                "utf-8"
            )
        finally:
            # XXX untested
            _lib.OPENSSL_free(result_buffer[0])
        return result

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, X509Name):
            return NotImplemented

        return _lib.X509_NAME_cmp(self._name, other._name) == 0

    def __lt__(self, other: Any) -> bool:
        if not isinstance(other, X509Name):
            return NotImplemented

        return _lib.X509_NAME_cmp(self._name, other._name) < 0

    def __repr__(self) -> str:
        """
        String representation of an X509Name
        """
        result_buffer = _ffi.new("char[]", 512)
        format_result = _lib.X509_NAME_oneline(
            self._name, result_buffer, len(result_buffer)
        )
        _openssl_assert(format_result != _ffi.NULL)

        return "<X509Name object '{}'>".format(
            _ffi.string(result_buffer).decode("utf-8"),
        )

    def hash(self) -> int:
        """
        Return the integer hash of this name as computed by OpenSSL.

        This is the Python equivalent of OpenSSL's ``X509_NAME_hash``.

        :return: The (integer) hash of this name.
        :rtype: :py:class:`int`
        """
        return _lib.X509_NAME_hash(self._name)

    def der(self) -> bytes:
        """
        Return the DER encoding of this name.

        :return: The DER encoded form of this name.
        :rtype: :py:class:`bytes`
        """
        result_buffer = _ffi.new("unsigned char**")
        encode_result = _lib.i2d_X509_NAME(self._name, result_buffer)
        _openssl_assert(encode_result >= 0)

        try:
            string_result = _ffi.buffer(result_buffer[0], encode_result)[:]
        finally:
            # Release the OpenSSL-allocated buffer even if the copy fails.
            _lib.OPENSSL_free(result_buffer[0])
        return string_result

    def get_components(self) -> List[Tuple[bytes, bytes]]:
        """
        Returns the components of this name, as a sequence of 2-tuples.

        :return: The components of this name.
        :rtype: :py:class:`list` of ``name, value`` tuples.
        """
        result = []
        for i in range(_lib.X509_NAME_entry_count(self._name)):
            ent = _lib.X509_NAME_get_entry(self._name, i)

            fname = _lib.X509_NAME_ENTRY_get_object(ent)
            fval = _lib.X509_NAME_ENTRY_get_data(ent)

            nid = _lib.OBJ_obj2nid(fname)
            name = _lib.OBJ_nid2sn(nid)

            # ffi.string does not handle strings containing NULL bytes
            # (which may have been generated by old, broken software)
            value = _ffi.buffer(
                _lib.ASN1_STRING_get0_data(fval), _lib.ASN1_STRING_length(fval)
            )[:]
            result.append((_ffi.string(name), value))

        return result
_extension: https://www.openssl.org/docs/manmaster/man5/ + x509v3_config.html#STANDARD-EXTENSIONS + """ + ctx = _ffi.new("X509V3_CTX*") + + # A context is necessary for any extension which uses the r2i + # conversion method. That is, X509V3_EXT_nconf may segfault if passed + # a NULL ctx. Start off by initializing most of the fields to NULL. + _lib.X509V3_set_ctx(ctx, _ffi.NULL, _ffi.NULL, _ffi.NULL, _ffi.NULL, 0) + + # We have no configuration database - but perhaps we should (some + # extensions may require it). + _lib.X509V3_set_ctx_nodb(ctx) + + # Initialize the subject and issuer, if appropriate. ctx is a local, + # and as far as I can tell none of the X509V3_* APIs invoked here steal + # any references, so no need to mess with reference counts or + # duplicates. + if issuer is not None: + if not isinstance(issuer, X509): + raise TypeError("issuer must be an X509 instance") + ctx.issuer_cert = issuer._x509 + if subject is not None: + if not isinstance(subject, X509): + raise TypeError("subject must be an X509 instance") + ctx.subject_cert = subject._x509 + + if critical: + # There are other OpenSSL APIs which would let us pass in critical + # separately, but they're harder to use, and since value is already + # a pile of crappy junk smuggling a ton of utterly important + # structured data, what's the point of trying to avoid nasty stuff + # with strings? (However, X509V3_EXT_i2d in particular seems like + # it would be a better API to invoke. I do not know where to get + # the ext_struc it desires for its last parameter, though.) 
+ value = b"critical," + value + + extension = _lib.X509V3_EXT_nconf(_ffi.NULL, ctx, type_name, value) + if extension == _ffi.NULL: + _raise_current_error() + self._extension = _ffi.gc(extension, _lib.X509_EXTENSION_free) + + @property + def _nid(self) -> Any: + return _lib.OBJ_obj2nid( + _lib.X509_EXTENSION_get_object(self._extension) + ) + + _prefixes: typing.ClassVar[typing.Dict[int, str]] = { + _lib.GEN_EMAIL: "email", + _lib.GEN_DNS: "DNS", + _lib.GEN_URI: "URI", + } + + def _subjectAltNameString(self) -> str: + names = _ffi.cast( + "GENERAL_NAMES*", _lib.X509V3_EXT_d2i(self._extension) + ) + + names = _ffi.gc(names, _lib.GENERAL_NAMES_free) + parts = [] + for i in range(_lib.sk_GENERAL_NAME_num(names)): + name = _lib.sk_GENERAL_NAME_value(names, i) + try: + label = self._prefixes[name.type] + except KeyError: + bio = _new_mem_buf() + _lib.GENERAL_NAME_print(bio, name) + parts.append(_bio_to_string(bio).decode("utf-8")) + else: + value = _ffi.buffer(name.d.ia5.data, name.d.ia5.length)[ + : + ].decode("utf-8") + parts.append(label + ":" + value) + return ", ".join(parts) + + def __str__(self) -> str: + """ + :return: a nice text representation of the extension + """ + if _lib.NID_subject_alt_name == self._nid: + return self._subjectAltNameString() + + bio = _new_mem_buf() + print_result = _lib.X509V3_EXT_print(bio, self._extension, 0, 0) + _openssl_assert(print_result != 0) + + return _bio_to_string(bio).decode("utf-8") + + def get_critical(self) -> bool: + """ + Returns the critical field of this X.509 extension. + + :return: The critical field. + """ + return _lib.X509_EXTENSION_get_critical(self._extension) + + def get_short_name(self) -> bytes: + """ + Returns the short type name of this X.509 extension. + + The result is a byte string such as :py:const:`b"basicConstraints"`. + + :return: The short type name. + :rtype: :py:data:`bytes` + + .. 
class X509Req:
    """
    An X.509 certificate signing request.
    """

    def __init__(self) -> None:
        req = _lib.X509_REQ_new()
        self._req = _ffi.gc(req, _lib.X509_REQ_free)
        # Default to version 0.
        self.set_version(0)

    def to_cryptography(self) -> x509.CertificateSigningRequest:
        """
        Export as a ``cryptography`` certificate signing request.

        :rtype: ``cryptography.x509.CertificateSigningRequest``

        .. versionadded:: 17.1.0
        """
        from cryptography.x509 import load_der_x509_csr

        der = _dump_certificate_request_internal(FILETYPE_ASN1, self)

        return load_der_x509_csr(der)

    @classmethod
    def from_cryptography(
        cls, crypto_req: x509.CertificateSigningRequest
    ) -> "X509Req":
        """
        Construct based on a ``cryptography`` *crypto_req*.

        :param crypto_req: A ``cryptography`` X.509 certificate signing request
        :type crypto_req: ``cryptography.x509.CertificateSigningRequest``

        :rtype: X509Req

        .. versionadded:: 17.1.0
        """
        if not isinstance(crypto_req, x509.CertificateSigningRequest):
            raise TypeError("Must be a certificate signing request")

        from cryptography.hazmat.primitives.serialization import Encoding

        der = crypto_req.public_bytes(Encoding.DER)
        return _load_certificate_request_internal(FILETYPE_ASN1, der)

    def set_pubkey(self, pkey: PKey) -> None:
        """
        Set the public key of the certificate signing request.

        :param pkey: The public key to use.
        :type pkey: :py:class:`PKey`

        :return: ``None``
        """
        set_result = _lib.X509_REQ_set_pubkey(self._req, pkey._pkey)
        _openssl_assert(set_result == 1)

    def get_pubkey(self) -> PKey:
        """
        Get the public key of the certificate signing request.

        :return: The public key.
        :rtype: :py:class:`PKey`
        """
        # Bypass PKey.__init__ so no fresh EVP_PKEY is allocated; the wrapper
        # adopts the key returned by OpenSSL instead.
        pkey = PKey.__new__(PKey)
        pkey._pkey = _lib.X509_REQ_get_pubkey(self._req)
        _openssl_assert(pkey._pkey != _ffi.NULL)
        pkey._pkey = _ffi.gc(pkey._pkey, _lib.EVP_PKEY_free)
        pkey._only_public = True
        return pkey

    def set_version(self, version: int) -> None:
        """
        Set the version subfield (RFC 2986, section 4.1) of the certificate
        request.

        :param int version: The version number.
        :raises TypeError: If *version* is not an ``int``.
        :raises ValueError: If *version* is anything other than ``0``.
        :return: ``None``
        """
        if not isinstance(version, int):
            raise TypeError("version must be an int")
        if version != 0:
            raise ValueError(
                "Invalid version. The only valid version for X509Req is 0."
            )
        set_result = _lib.X509_REQ_set_version(self._req, version)
        _openssl_assert(set_result == 1)

    def get_version(self) -> int:
        """
        Get the version subfield (RFC 2986, section 4.1) of the certificate
        request.

        :return: The value of the version subfield.
        :rtype: :py:class:`int`
        """
        return _lib.X509_REQ_get_version(self._req)

    def get_subject(self) -> X509Name:
        """
        Return the subject of this certificate signing request.

        This creates a new :class:`X509Name` that wraps the underlying subject
        name field on the certificate signing request. Modifying it will modify
        the underlying signing request, and will have the effect of modifying
        any other :class:`X509Name` that refers to this subject.

        :return: The subject of this certificate signing request.
        :rtype: :class:`X509Name`
        """
        name = X509Name.__new__(X509Name)
        name._name = _lib.X509_REQ_get_subject_name(self._req)
        _openssl_assert(name._name != _ffi.NULL)

        # The name is owned by the X509Req structure.  As long as the X509Name
        # Python object is alive, keep the X509Req Python object alive.
        name._owner = self

        return name

    def add_extensions(
        self, extensions: Iterable[_X509ExtensionInternal]
    ) -> None:
        """
        Add extensions to the certificate signing request.

        :param extensions: The X.509 extensions to add.
        :type extensions: iterable of :py:class:`X509Extension`
        :raises ValueError: If an element is not an X509Extension.
        :raises Error: If OpenSSL fails to build the stack or to add the
            extensions to the request.
        :return: ``None``
        """
        stack = _lib.sk_X509_EXTENSION_new_null()
        _openssl_assert(stack != _ffi.NULL)

        # Only the stack itself is freed here; the extension objects remain
        # owned by their Python wrappers.
        stack = _ffi.gc(stack, _lib.sk_X509_EXTENSION_free)

        for ext in extensions:
            if not isinstance(ext, _X509ExtensionInternal):
                raise ValueError("One of the elements is not an X509Extension")

            # A push returns the new element count, or 0 on failure.
            push_result = _lib.sk_X509_EXTENSION_push(stack, ext._extension)
            _openssl_assert(push_result >= 1)

        add_result = _lib.X509_REQ_add_extensions(self._req, stack)
        _openssl_assert(add_result == 1)

    def get_extensions(self) -> List[_X509ExtensionInternal]:
        """
        Get X.509 extensions in the certificate signing request.

        :return: The X.509 extensions in this request.
        :rtype: :py:class:`list` of :py:class:`X509Extension` objects.

        .. versionadded:: 0.15
        """
        exts = []
        native_exts_obj = _lib.X509_REQ_get_extensions(self._req)
        native_exts_obj = _ffi.gc(
            native_exts_obj,
            lambda x: _lib.sk_X509_EXTENSION_pop_free(
                x,
                _ffi.addressof(_lib._original_lib, "X509_EXTENSION_free"),
            ),
        )

        for i in range(_lib.sk_X509_EXTENSION_num(native_exts_obj)):
            ext = _X509ExtensionInternal.__new__(_X509ExtensionInternal)
            # Duplicate each entry so the wrapper owns memory that is
            # independent of the stack freed above.
            extension = _lib.X509_EXTENSION_dup(
                _lib.sk_X509_EXTENSION_value(native_exts_obj, i)
            )
            ext._extension = _ffi.gc(extension, _lib.X509_EXTENSION_free)
            exts.append(ext)
        return exts

    def sign(self, pkey: PKey, digest: str) -> None:
        """
        Sign the certificate signing request with this key and digest type.

        :param pkey: The key pair to sign with.
        :type pkey: :py:class:`PKey`
        :param digest: The name of the message digest to use for the signature,
            e.g. :py:data:`"sha256"`.
        :type digest: :py:class:`str`
        :raises ValueError: If the key has only a public part, is
            uninitialized, or the digest name is unknown.
        :return: ``None``
        """
        if pkey._only_public:
            raise ValueError("Key has only public part")

        if not pkey._initialized:
            raise ValueError("Key is uninitialized")

        digest_obj = _lib.EVP_get_digestbyname(_byte_string(digest))
        if digest_obj == _ffi.NULL:
            raise ValueError("No such digest method")

        sign_result = _lib.X509_REQ_sign(self._req, pkey._pkey, digest_obj)
        _openssl_assert(sign_result > 0)

    def verify(self, pkey: PKey) -> bool:
        """
        Verifies the signature on this certificate signing request.

        :param PKey pkey: A public key.

        :return: ``True`` if the signature is correct.
        :rtype: bool

        :raises TypeError: If *pkey* is not a :py:class:`PKey`.
        :raises OpenSSL.crypto.Error: If the signature is invalid or there is a
            problem verifying the signature.
        """
        if not isinstance(pkey, PKey):
            raise TypeError("pkey must be a PKey instance")

        result = _lib.X509_REQ_verify(self._req, pkey._pkey)
        if result <= 0:
            _raise_current_error()

        # Any non-positive result raised above; report success as a real bool
        # rather than leaking OpenSSL's integer status.
        return True
@classmethod
def from_cryptography(cls, crypto_cert: x509.Certificate) -> "X509":
    """
    Build an :class:`X509` from a ``cryptography`` certificate.

    The certificate is serialized to DER and parsed again through
    :func:`load_certificate`.

    :param crypto_cert: A ``cryptography`` X.509 certificate.
    :type crypto_cert: ``cryptography.x509.Certificate``

    :rtype: X509

    :raises TypeError: If *crypto_cert* is not a
        ``cryptography.x509.Certificate``.

    .. versionadded:: 17.1.0
    """
    if isinstance(crypto_cert, x509.Certificate):
        from cryptography.hazmat.primitives.serialization import Encoding

        return load_certificate(
            FILETYPE_ASN1, crypto_cert.public_bytes(Encoding.DER)
        )

    raise TypeError("Must be a certificate")
def get_pubkey(self) -> PKey:
    """
    Return the certificate's public key as a public-only :class:`PKey`.

    :return: The public key.
    :rtype: :py:class:`PKey`
    :raises OpenSSL.crypto.Error: If OpenSSL returns no key for the
        certificate.
    """
    raw_key = _lib.X509_get_pubkey(self._x509)
    if raw_key == _ffi.NULL:
        _raise_current_error()

    # Bypass PKey.__init__ and wrap the pointer OpenSSL handed back; the
    # finalizer frees it when the wrapper is collected.
    public_key = PKey.__new__(PKey)
    public_key._pkey = _ffi.gc(raw_key, _lib.EVP_PKEY_free)
    public_key._only_public = True
    return public_key
def get_signature_algorithm(self) -> bytes:
    """
    Look up the long name of the certificate's signature algorithm.

    :return: The name of the algorithm.
    :rtype: :py:class:`bytes`

    :raises ValueError: If the signature algorithm is undefined.

    .. versionadded:: 0.13
    """
    tbs_sigalg = _lib.X509_get0_tbs_sigalg(self._x509)

    # Out-parameter that receives the algorithm's ASN1_OBJECT pointer.
    oid_out = _ffi.new("ASN1_OBJECT **")
    _lib.X509_ALGOR_get0(oid_out, _ffi.NULL, _ffi.NULL, tbs_sigalg)

    nid = _lib.OBJ_obj2nid(oid_out[0])
    if nid != _lib.NID_undef:
        return _ffi.string(_lib.OBJ_nid2ln(nid))

    raise ValueError("Undefined signature algorithm")
def set_serial_number(self, serial: int) -> None:
    """
    Set the serial number of the certificate.

    :param serial: The new serial number.
    :type serial: :py:class:`int`

    :return: :py:data:`None`

    :raises TypeError: If *serial* is not an integer.
    :raises OpenSSL.crypto.Error: If OpenSSL cannot convert or store the
        serial number.
    """
    if not isinstance(serial, int):
        raise TypeError("serial must be an integer")

    # format() keeps a minus sign in front of the digits ("-ff"), which
    # BN_hex2bn understands. Slicing hex() output ("-0xff"[2:] == "xff")
    # dropped the sign and left a stray "x", so negative values were
    # silently stored as 0.
    hex_serial_bytes = format(serial, "x").encode("ascii")

    bignum_serial = _ffi.new("BIGNUM**")
    _lib.BN_hex2bn(bignum_serial, hex_serial_bytes)

    # BN_hex2bn allocates a BIGNUM into the NULL slot it is given, so a
    # slot that is still NULL afterwards means the conversion failed.
    if bignum_serial[0] == _ffi.NULL:
        _raise_current_error()

    asn1_serial = _lib.BN_to_ASN1_INTEGER(bignum_serial[0], _ffi.NULL)
    # The BIGNUM is only an intermediate value; release it before the
    # error check so it is not leaked on failure.
    _lib.BN_free(bignum_serial[0])
    if asn1_serial == _ffi.NULL:
        # TODO Not tested
        _raise_current_error()
    asn1_serial = _ffi.gc(asn1_serial, _lib.ASN1_INTEGER_free)
    set_result = _lib.X509_set_serialNumber(self._x509, asn1_serial)
    _openssl_assert(set_result == 1)
def has_expired(self) -> bool:
    """
    Report whether the certificate's notAfter time lies in the past.

    :return: ``True`` if the certificate has expired, ``False`` otherwise.
    :rtype: bool
    :raises ValueError: If the certificate has no notAfter timestamp.
    """
    raw_not_after = self.get_notAfter()
    if raw_not_after is None:
        raise ValueError("Unable to determine notAfter")

    expiry = datetime.datetime.strptime(
        raw_not_after.decode("utf-8"), "%Y%m%d%H%M%SZ"
    )

    # strptime yields a naive datetime, so strip tzinfo from the current
    # UTC time to keep the comparison naive-to-naive.
    now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    return expiry < now
def _get_name(self, which: Any) -> X509Name:
    """
    Wrap one of the certificate's name fields in an :class:`X509Name`.

    :param which: Callable that takes the native certificate pointer and
        returns the native name pointer to wrap.
    :return: A new :class:`X509Name` referring to that native name.
    """
    raw_name = which(self._x509)
    _openssl_assert(raw_name != _ffi.NULL)

    wrapper = X509Name.__new__(X509Name)
    wrapper._name = raw_name

    # The name is owned by the X509 structure. As long as the X509Name
    # Python object is alive, keep the X509 Python object alive.
    wrapper._owner = self
    return wrapper
def add_extensions(
    self, extensions: Iterable[_X509ExtensionInternal]
) -> None:
    """
    Append each extension in *extensions* to the certificate.

    Extensions are added one at a time in iteration order, so an error
    part-way through leaves the earlier ones in place.

    :param extensions: The extensions to add.
    :type extensions: An iterable of :py:class:`X509Extension` objects.
    :return: ``None``
    :raises ValueError: If one of the elements is not an extension object.
    :raises OpenSSL.crypto.Error: If OpenSSL fails to add an extension.
    """
    for candidate in extensions:
        if not isinstance(candidate, _X509ExtensionInternal):
            raise ValueError("One of the elements is not an X509Extension")

        # -1 is the position argument passed to X509_add_ext.
        appended = _lib.X509_add_ext(self._x509, candidate._extension, -1)
        if not appended:
            _raise_current_error()
def __init__(self) -> None:
    """
    Allocate an empty native ``X509_STORE`` and take ownership of it.

    :raises OpenSSL.crypto.Error: If OpenSSL cannot allocate the store.
    """
    store = _lib.X509_STORE_new()
    # Fail here on allocation failure instead of wrapping a NULL pointer
    # that later store calls would dereference.
    _openssl_assert(store != _ffi.NULL)
    self._store = _ffi.gc(store, _lib.X509_STORE_free)
def add_crl(
    self, crl: Union["_CRLInternal", x509.CertificateRevocationList]
) -> None:
    """
    Add a certificate revocation list to this store.

    The certificate revocation lists added to a store will only be used if
    the associated flags are configured to check certificate revocation
    lists.

    .. versionadded:: 16.1.0

    :param crl: The certificate revocation list to add to this store.
    :type crl: ``Union[CRL, cryptography.x509.CertificateRevocationList]``
    :return: ``None`` if the certificate revocation list was added
        successfully.
    :raises TypeError: If *crl* is neither of the accepted types.
    :raises OpenSSL.crypto.Error: If the CRL cannot be parsed or added.
    """
    if isinstance(crl, x509.CertificateRevocationList):
        from cryptography.hazmat.primitives.serialization import Encoding

        # Go through DER to turn the cryptography object into a native
        # X509_CRL.
        der_bio = _new_mem_buf(crl.public_bytes(Encoding.DER))
        parsed = _lib.d2i_X509_CRL_bio(der_bio, _ffi.NULL)
        if parsed == _ffi.NULL:
            _raise_current_error()

        native_crl = _ffi.gc(parsed, _lib.X509_CRL_free)
    elif isinstance(crl, _CRLInternal):
        native_crl = crl._crl
    else:
        raise TypeError(
            "CRL must be of type OpenSSL.crypto.CRL or "
            "cryptography.x509.CertificateRevocationList"
        )

    add_result = _lib.X509_STORE_add_crl(self._store, native_crl)
    _openssl_assert(add_result != 0)
def set_time(self, vfy_time: datetime.datetime) -> None:
    """
    Set the time against which the certificates are verified.

    Normally the current time is used.

    .. note::

      For example, you can determine if a certificate was valid at a given
      time.

    .. versionadded:: 17.0.0

    :param datetime vfy_time: The verification time to set on this store.
        A naive value is interpreted as UTC; an aware value is converted
        to UTC first.
    :return: ``None`` if the verification time was successfully set.
    :raises OpenSSL.crypto.Error: If the verification parameters cannot be
        allocated or applied to the store.
    """
    param = _lib.X509_VERIFY_PARAM_new()
    _openssl_assert(param != _ffi.NULL)
    param = _ffi.gc(param, _lib.X509_VERIFY_PARAM_free)

    # utctimetuple() equals timetuple() for naive datetimes in the fields
    # timegm() reads, but also honours tzinfo on aware datetimes rather
    # than treating their local wall-clock fields as UTC.
    _lib.X509_VERIFY_PARAM_set_time(
        param, calendar.timegm(vfy_time.utctimetuple())
    )
    _openssl_assert(_lib.X509_STORE_set1_param(self._store, param) != 0)
class X509StoreContextError(Exception):
    """
    Raised when verifying a certificate with
    `OpenSSL.X509StoreContext.verify_certificate` fails.

    :ivar errors: The error details passed in when the exception was
        created.
    :ivar certificate: The certificate which caused verification failure.
    :type certificate: :class:`X509`
    """

    def __init__(
        self, message: str, errors: List[Any], certificate: X509
    ) -> None:
        # Only the message goes to Exception, so str(exc) is the message;
        # the remaining details are kept as attributes.
        super().__init__(message)
        self.certificate = certificate
        self.errors = errors
@staticmethod
def _build_certificate_stack(
    certificates: Optional[Sequence[X509]],
) -> Any:
    """
    Build a native stack of X509 pointers from *certificates*.

    :param certificates: The certificates to put on the stack, or ``None``.
    :return: ``_ffi.NULL`` when *certificates* is ``None`` or empty,
        otherwise a garbage-collected native stack holding one extra
        reference to each certificate.
    :raises TypeError: If an element is not an :class:`X509` instance.
    :raises OpenSSL.crypto.Error: If OpenSSL fails to allocate the stack,
        take a reference or push onto the stack.
    """

    def cleanup(s: Any) -> None:
        # Equivalent to sk_X509_pop_free, but we don't
        # currently have a CFFI binding for that available
        for i in range(_lib.sk_X509_num(s)):
            x = _lib.sk_X509_value(s, i)
            _lib.X509_free(x)
        _lib.sk_X509_free(s)

    if certificates is None or len(certificates) == 0:
        return _ffi.NULL

    stack = _lib.sk_X509_new_null()
    _openssl_assert(stack != _ffi.NULL)
    stack = _ffi.gc(stack, cleanup)

    for cert in certificates:
        if not isinstance(cert, X509):
            raise TypeError("One of the elements is not an X509 instance")

        # The stack holds its own reference to every certificate, which
        # cleanup() drops again.
        _openssl_assert(_lib.X509_up_ref(cert._x509) > 0)
        if _lib.sk_X509_push(stack, cert._x509) <= 0:
            # The push failed, so cleanup() will never see this entry;
            # undo the reference taken just above.
            _lib.X509_free(cert._x509)
            _raise_current_error()

    return stack
def _verify_certificate(self) -> Any:
    """
    Verify the certificate and return the X509_STORE_CTX holding the
    results.

    A fresh native context is built from the current store, certificate
    and chain on every call.

    :return: The garbage-collected native store context used for the
        verification.
    :raises X509StoreContextError: If an error occurred when validating a
      certificate in the context. Sets ``certificate`` attribute to
      indicate which certificate caused the error.
    """
    ctx = _lib.X509_STORE_CTX_new()
    _openssl_assert(ctx != _ffi.NULL)
    ctx = _ffi.gc(ctx, _lib.X509_STORE_CTX_free)

    init_status = _lib.X509_STORE_CTX_init(
        ctx, self._store._store, self._cert._x509, self._chain
    )
    _openssl_assert(init_status == 1)

    if _lib.X509_verify_cert(ctx) > 0:
        return ctx

    raise self._exception_from_context(ctx)
def load_certificate(type: int, buffer: bytes) -> X509:
    """
    Parse an :class:`X509` certificate out of *buffer*.

    :param type: The file type (one of FILETYPE_PEM, FILETYPE_ASN1)

    :param bytes buffer: The buffer the certificate is stored in; a
        ``str`` is accepted too and encoded as ASCII first.

    :return: The X509 object
    :raises ValueError: If *type* is not one of the supported file types.
    :raises OpenSSL.crypto.Error: If OpenSSL cannot parse the certificate.
    """
    raw = buffer.encode("ascii") if isinstance(buffer, str) else buffer
    bio = _new_mem_buf(raw)

    if type == FILETYPE_PEM:
        parsed = _lib.PEM_read_bio_X509(bio, _ffi.NULL, _ffi.NULL, _ffi.NULL)
    elif type == FILETYPE_ASN1:
        parsed = _lib.d2i_X509_bio(bio, _ffi.NULL)
    else:
        raise ValueError("type argument must be FILETYPE_PEM or FILETYPE_ASN1")

    if parsed == _ffi.NULL:
        _raise_current_error()

    return X509._from_raw_x509_ptr(parsed)
def dump_privatekey(
    type: int,
    pkey: PKey,
    cipher: Optional[str] = None,
    passphrase: Optional[PassphraseCallableT] = None,
) -> bytes:
    """
    Serialize the private key *pkey* into a buffer in the format *type*.

    For :const:`FILETYPE_PEM` the output can optionally be encrypted with
    *cipher* and *passphrase*.

    :param type: The file type (one of :const:`FILETYPE_PEM`,
        :const:`FILETYPE_ASN1`, or :const:`FILETYPE_TEXT`)
    :param PKey pkey: The PKey to dump
    :param cipher: (optional) if encrypted PEM format, the cipher to use
    :param passphrase: (optional) if encrypted PEM format, this can be either
        the passphrase to use, or a callback for providing the passphrase.

    :return: The buffer with the dumped key in
    :rtype: bytes
    :raises TypeError: If *pkey* is not a :class:`PKey`, if *cipher* is
        given without *passphrase*, or if a non-RSA key is dumped as
        :const:`FILETYPE_TEXT`.
    :raises ValueError: If the cipher name or *type* is not recognised.
    """
    bio = _new_mem_buf()

    if not isinstance(pkey, PKey):
        raise TypeError("pkey must be a PKey")

    # NULL means "no encryption" unless a cipher was requested.
    cipher_obj = _ffi.NULL
    if cipher is not None:
        if passphrase is None:
            raise TypeError(
                "if a value is given for cipher "
                "one must also be given for passphrase"
            )
        cipher_obj = _lib.EVP_get_cipherbyname(_byte_string(cipher))
        if cipher_obj == _ffi.NULL:
            raise ValueError("Invalid cipher name")

    helper = _PassphraseHelper(type, passphrase)
    if type == FILETYPE_PEM:
        status = _lib.PEM_write_bio_PrivateKey(
            bio,
            pkey._pkey,
            cipher_obj,
            _ffi.NULL,
            0,
            helper.callback,
            helper.callback_args,
        )
        # Let the helper report any problem it recorded during the
        # passphrase callback.
        helper.raise_if_problem()
    elif type == FILETYPE_ASN1:
        status = _lib.i2d_PrivateKey_bio(bio, pkey._pkey)
    elif type == FILETYPE_TEXT:
        if _lib.EVP_PKEY_id(pkey._pkey) != _lib.EVP_PKEY_RSA:
            raise TypeError("Only RSA keys are supported for FILETYPE_TEXT")

        rsa_key = _ffi.gc(_lib.EVP_PKEY_get1_RSA(pkey._pkey), _lib.RSA_free)
        status = _lib.RSA_print(bio, rsa_key, 0)
    else:
        raise ValueError(
            "type argument must be FILETYPE_PEM, FILETYPE_ASN1, or "
            "FILETYPE_TEXT"
        )

    _openssl_assert(status != 0)

    return _bio_to_string(bio)
def set_reason(self, reason: Optional[bytes]) -> None:
    """
    Set the reason of this revocation.

    If :data:`reason` is ``None``, delete the reason instead.

    :param reason: The reason string. Case and spaces are ignored when it
        is matched against the supported reasons.
    :type reason: :class:`bytes` or :class:`NoneType`

    :return: ``None``

    :raises TypeError: If *reason* is neither ``None`` nor a byte string.
    :raises ValueError: If *reason* is not one of the supported reasons.
    :raises OpenSSL.crypto.Error: If OpenSSL fails to build or attach the
        reason extension.

    .. seealso::

        :meth:`all_reasons`, which gives you a list of all supported
        reasons which you might pass to this method.
    """
    if reason is None:
        self._delete_reason()
    elif not isinstance(reason, bytes):
        raise TypeError("reason must be None or a byte string")
    else:
        reason = reason.lower().replace(b" ", b"")
        # The position in _crl_reasons is the numeric reason code; index()
        # raises ValueError for an unsupported reason.
        reason_code = [r.lower() for r in self._crl_reasons].index(reason)

        new_reason_ext = _lib.ASN1_ENUMERATED_new()
        _openssl_assert(new_reason_ext != _ffi.NULL)
        new_reason_ext = _ffi.gc(new_reason_ext, _lib.ASN1_ENUMERATED_free)

        # ASN1_ENUMERATED_set returns an int status (1 on success), not a
        # pointer. The old comparison against NULL was always true, so a
        # failure could never be detected.
        set_result = _lib.ASN1_ENUMERATED_set(new_reason_ext, reason_code)
        _openssl_assert(set_result == 1)

        # Drop any existing reason before adding the new one.
        self._delete_reason()
        add_result = _lib.X509_REVOKED_add1_ext_i2d(
            self._revoked, _lib.NID_crl_reason, new_reason_ext, 0, 0
        )
        _openssl_assert(add_result == 1)
def get_rev_date(self) -> Optional[bytes]:
    """
    Get the revocation timestamp.

    The revocation date is read from the underlying ``X509_REVOKED``
    structure and converted by ``_get_asn1_time``.

    :return: The timestamp of the revocation, as ASN.1 TIME. The return
        annotation allows ``None``; presumably that is what the helper
        yields when no date has been set (confirm against
        ``_get_asn1_time``).
    :rtype: bytes or NoneType
    """
    return _get_asn1_time(
        _lib.X509_REVOKED_get0_revocationDate(self._revoked)
    )
def get_revoked(self) -> Optional[Tuple[_RevokedInternal, ...]]:
    """
    Return the revocations in this certificate revocation list.

    These revocations will be provided by value, not by reference.
    That means it's okay to mutate them: it won't affect this CRL.

    :return: The revocations in this CRL, or ``None`` when the CRL holds
        no revocations.
    :rtype: :class:`tuple` of :class:`Revoked`, or :class:`NoneType`
    """
    results = []
    revoked_stack = _lib.X509_CRL_get_REVOKED(self._crl)
    for i in range(_lib.sk_X509_REVOKED_num(revoked_stack)):
        revoked = _lib.sk_X509_REVOKED_value(revoked_stack, i)
        revoked_copy = _lib.X509_REVOKED_dup(revoked)
        # A failed duplication must not be wrapped and handed to callers
        # as if it were a valid revocation (same check as add_revoked).
        _openssl_assert(revoked_copy != _ffi.NULL)
        # Bypass Revoked.__init__, which would allocate a fresh structure,
        # and attach the duplicated one instead.
        pyrev = _RevokedInternal.__new__(_RevokedInternal)
        pyrev._revoked = _ffi.gc(revoked_copy, _lib.X509_REVOKED_free)
        results.append(pyrev)
    # Existing contract: an empty CRL yields None rather than an empty tuple.
    if results:
        return tuple(results)
    return None
def sign(self, issuer_cert: X509, issuer_key: PKey, digest: bytes) -> None:
    """
    Sign the CRL.

    Signing a CRL enables clients to associate the CRL itself with an
    issuer. Before a CRL is meaningful to other OpenSSL functions, it must
    be signed by an issuer.

    This method implicitly sets the issuer's name based on the issuer
    certificate and private key used to sign the CRL.

    .. versionadded:: 16.1.0

    :param X509 issuer_cert: The issuer's certificate.
    :param PKey issuer_key: The issuer's private key.
    :param bytes digest: The digest method to sign the CRL with.
    :return: ``None``
    """
    digest_obj = _lib.EVP_get_digestbyname(digest)
    _openssl_assert(digest_obj != _ffi.NULL)

    # Check the status the same way CRL.export does; otherwise a failure
    # here would go on to sign a CRL whose issuer name was never set.
    ret = _lib.X509_CRL_set_issuer_name(
        self._crl, _lib.X509_get_subject_name(issuer_cert._x509)
    )
    _openssl_assert(ret == 1)

    _lib.X509_CRL_sort(self._crl)
    result = _lib.X509_CRL_sign(self._crl, issuer_key._pkey, digest_obj)
    _openssl_assert(result != 0)
@property
def callback(self) -> Any:
    """
    Build the ``pem_password_cb`` argument for an OpenSSL PEM routine.

    :return: ``_ffi.NULL`` when no passphrase was supplied, otherwise a
        cffi callback wrapping :meth:`_read_passphrase`.
    :raises TypeError: If the stored passphrase is neither ``None``, a
        byte string, nor a callable.
    """
    passphrase = self._passphrase
    if passphrase is None:
        return _ffi.NULL
    if not (isinstance(passphrase, bytes) or callable(passphrase)):
        raise TypeError(
            "Last argument must be a byte string or a callable."
        )
    return _ffi.callback("pem_password_cb", self._read_passphrase)
+ ) + + def raise_if_problem(self, exceptionType: Type[Exception] = Error) -> None: + if self._problems: + # Flush the OpenSSL error queue + try: + _exception_from_error_queue(exceptionType) + except exceptionType: + pass + + raise self._problems.pop(0) + + def _read_passphrase( + self, buf: Any, size: int, rwflag: Any, userdata: Any + ) -> int: + try: + if callable(self._passphrase): + if self._more_args: + result = self._passphrase(size, rwflag, userdata) + else: + result = self._passphrase(rwflag) + else: + assert self._passphrase is not None + result = self._passphrase + if not isinstance(result, bytes): + raise ValueError("Bytes expected") + if len(result) > size: + if self._truncate: + result = result[:size] + else: + raise ValueError( + "passphrase returned by callback is too long" + ) + for i in range(len(result)): + buf[i] = result[i : i + 1] + return len(result) + except Exception as e: + self._problems.append(e) + return 0 + + +def load_publickey(type: int, buffer: Union[str, bytes]) -> PKey: + """ + Load a public key from a buffer. + + :param type: The file type (one of :data:`FILETYPE_PEM`, + :data:`FILETYPE_ASN1`). + :param buffer: The buffer the key is stored in. + :type buffer: A Python string object, either unicode or bytestring. + :return: The PKey object. 
def load_privatekey(
    type: int,
    buffer: Union[str, bytes],
    passphrase: Optional[PassphraseCallableT] = None,
) -> PKey:
    """
    Load a private key (PKey) from the string *buffer* encoded with the type
    *type*.

    :param type: The file type (one of FILETYPE_PEM, FILETYPE_ASN1)
    :param buffer: The buffer the key is stored in; text is encoded as
                   ASCII before parsing
    :param passphrase: (optional) if encrypted PEM format, this can be
                       either the passphrase to use, or a callback for
                       providing the passphrase.

    :return: The PKey object
    :raises ValueError: if *type* is not one of the two supported file
                        types
    """
    raw = buffer.encode("ascii") if isinstance(buffer, str) else buffer
    bio = _new_mem_buf(raw)

    # The helper is built before the type dispatch, so whatever validation
    # its constructor performs runs first.
    helper = _PassphraseHelper(type, passphrase)

    if type == FILETYPE_PEM:
        key_ptr = _lib.PEM_read_bio_PrivateKey(
            bio, _ffi.NULL, helper.callback, helper.callback_args
        )
        # Re-raise anything the passphrase callback recorded while OpenSSL
        # was running it.
        helper.raise_if_problem()
    elif type == FILETYPE_ASN1:
        key_ptr = _lib.d2i_PrivateKey_bio(bio, _ffi.NULL)
    else:
        raise ValueError("type argument must be FILETYPE_PEM or FILETYPE_ASN1")

    if key_ptr == _ffi.NULL:
        _raise_current_error()

    # Skip PKey.__init__ and attach the freshly parsed key directly.
    loaded = PKey.__new__(PKey)
    loaded._pkey = _ffi.gc(key_ptr, _lib.EVP_PKEY_free)
    return loaded
def load_certificate_request(
    type: int, buffer: Union[str, bytes]
) -> X509Req:
    """
    Load a certificate request (X509Req) from the string *buffer* encoded with
    the type *type*.

    :param type: The file type (one of FILETYPE_PEM, FILETYPE_ASN1)
    :param buffer: The buffer the certificate request is stored in; text
        is encoded as ASCII before parsing
    :return: The X509Req object
    :raises ValueError: if *type* is not one of the two supported file types
    """
    # Text input has always been accepted here; the annotation now says so,
    # matching the other loaders in this module.
    if isinstance(buffer, str):
        buffer = buffer.encode("ascii")

    bio = _new_mem_buf(buffer)

    if type == FILETYPE_PEM:
        req = _lib.PEM_read_bio_X509_REQ(bio, _ffi.NULL, _ffi.NULL, _ffi.NULL)
    elif type == FILETYPE_ASN1:
        req = _lib.d2i_X509_REQ_bio(bio, _ffi.NULL)
    else:
        raise ValueError("type argument must be FILETYPE_PEM or FILETYPE_ASN1")

    _openssl_assert(req != _ffi.NULL)

    # Skip the constructor and attach the parsed request directly.
    x509req = _X509ReqInternal.__new__(_X509ReqInternal)
    x509req._req = _ffi.gc(req, _lib.X509_REQ_free)
    return x509req
def sign(pkey: PKey, data: Union[str, bytes], digest: str) -> bytes:
    """
    Sign a data string using the given key and message digest.

    :param pkey: PKey to sign with
    :param data: data to be signed
    :param digest: message digest to use
    :return: signature
    :raises ValueError: if OpenSSL knows no digest named *digest*

    .. versionadded:: 0.11
    """
    data = _text_to_bytes_and_warn("data", data)

    digest_obj = _lib.EVP_get_digestbyname(_byte_string(digest))
    if digest_obj == _ffi.NULL:
        raise ValueError("No such digest method")

    # Fail on an allocation error instead of passing a NULL context on.
    md_ctx = _lib.EVP_MD_CTX_new()
    _openssl_assert(md_ctx != _ffi.NULL)
    md_ctx = _ffi.gc(md_ctx, _lib.EVP_MD_CTX_free)

    # Surface a failed init/update at the step that failed rather than
    # relying on the final step to notice it.
    init_result = _lib.EVP_SignInit(md_ctx, digest_obj)
    _openssl_assert(init_result == 1)
    update_result = _lib.EVP_SignUpdate(md_ctx, data, len(data))
    _openssl_assert(update_result == 1)

    # EVP_PKEY_size gives the buffer capacity; the length actually written
    # is reported back through signature_length.
    length = _lib.EVP_PKEY_size(pkey._pkey)
    _openssl_assert(length > 0)
    signature_buffer = _ffi.new("unsigned char[]", length)
    signature_length = _ffi.new("unsigned int *")
    final_result = _lib.EVP_SignFinal(
        md_ctx, signature_buffer, signature_length, pkey._pkey
    )
    _openssl_assert(final_result == 1)

    return _ffi.buffer(signature_buffer, signature_length[0])[:]
def dump_crl(type: int, crl: _CRLInternal) -> bytes:
    """
    Dump a certificate revocation list to a buffer.

    :param type: The file type (one of ``FILETYPE_PEM``, ``FILETYPE_ASN1``, or
        ``FILETYPE_TEXT``).
    :param CRL crl: The CRL to dump.

    :return: The buffer with the CRL.
    :rtype: bytes
    :raises ValueError: if *type* is not one of the three supported file
        types.
    """
    bio = _new_mem_buf()

    # Pick the writer for the requested format; only the selected binding
    # is looked up on _lib.
    if type == FILETYPE_PEM:
        writer = _lib.PEM_write_bio_X509_CRL
    elif type == FILETYPE_ASN1:
        writer = _lib.i2d_X509_CRL_bio
    elif type == FILETYPE_TEXT:
        writer = _lib.X509_CRL_print
    else:
        raise ValueError(
            "type argument must be FILETYPE_PEM, FILETYPE_ASN1, or "
            "FILETYPE_TEXT"
        )

    status = writer(bio, crl._crl)
    _openssl_assert(status == 1)
    return _bio_to_string(bio)
+ + :param type: The file type (one of FILETYPE_PEM, FILETYPE_ASN1) + :param buffer: The buffer the CRL is stored in + + :return: The CRL object + """ + if isinstance(buffer, str): + buffer = buffer.encode("ascii") + + bio = _new_mem_buf(buffer) + + if type == FILETYPE_PEM: + crl = _lib.PEM_read_bio_X509_CRL(bio, _ffi.NULL, _ffi.NULL, _ffi.NULL) + elif type == FILETYPE_ASN1: + crl = _lib.d2i_X509_CRL_bio(bio, _ffi.NULL) + else: + raise ValueError("type argument must be FILETYPE_PEM or FILETYPE_ASN1") + + if crl == _ffi.NULL: + _raise_current_error() + + result = _CRLInternal.__new__(_CRLInternal) + result._crl = _ffi.gc(crl, _lib.X509_CRL_free) + return result + + +_load_crl_internal = load_crl +utils.deprecated( + load_crl, + __name__, + ( + "CRL support in pyOpenSSL is deprecated. You should use the APIs " + "in cryptography." + ), + DeprecationWarning, + name="load_crl", +) diff --git a/.venv/lib/python3.11/site-packages/OpenSSL/debug.py b/.venv/lib/python3.11/site-packages/OpenSSL/debug.py new file mode 100644 index 0000000000000000000000000000000000000000..e0ed3f81d62d090e0310cc8b0b78042821b12e86 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/OpenSSL/debug.py @@ -0,0 +1,40 @@ +import ssl +import sys + +import cffi +import cryptography + +import OpenSSL.SSL + +from . 
def add(buffer: bytes, entropy: int) -> None:
    """
    Mix bytes from *buffer* into the PRNG state.

    The *entropy* argument is (the lower bound of) an estimate of how much
    randomness is contained in *buffer*, measured in bytes.

    For more information, see e.g. :rfc:`1750`.

    This function is only relevant if you are forking Python processes and
    need to reseed the CSPRNG after fork.

    :param buffer: Buffer with random data.
    :param entropy: The entropy (in bytes) measurement of the buffer.

    :return: :obj:`None`
    :raises TypeError: If *buffer* is not a byte string or *entropy* is not
        an integer.
    """
    # Validate in declaration order so the first offending argument is the
    # one that gets reported.
    for value, expected, message in (
        (buffer, bytes, "buffer must be a byte string"),
        (entropy, int, "entropy must be an integer"),
    ):
        if not isinstance(value, expected):
            raise TypeError(message)

    size = len(buffer)
    _lib.RAND_add(buffer, size, entropy)
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .base import * + +# flake8: noqa +from .compressors import * +from .config import * +from .quantization import QuantizationConfig, QuantizationStatus +from .utils import * +from .version import * diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/base.py b/.venv/lib/python3.11/site-packages/compressed_tensors/base.py new file mode 100644 index 0000000000000000000000000000000000000000..0e073262f85632636ecd8aedef9d2fa16c3b1648 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/compressed_tensors/base.py @@ -0,0 +1,20 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +SPARSITY_CONFIG_NAME = "sparsity_config" +QUANTIZATION_CONFIG_NAME = "quantization_config" +COMPRESSION_CONFIG_NAME = "compression_config" +KV_CACHE_SCHEME_NAME = "kv_cache_scheme" +COMPRESSION_VERSION_NAME = "version" +QUANTIZATION_METHOD_NAME = "quant_method" diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/__init__.py b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..138e3899bf47d1b22914781cfb43dc9598b5eb95 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa + +from .base import * +from .helpers import * +from .model_compressors import * +from .quantized_compressors import * +from .sparse_compressors import * +from .sparse_quantized_compressors import * diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/base.py b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/base.py new file mode 100644 index 0000000000000000000000000000000000000000..ee7510530def57b37307576f51ccbc4df250de94 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/base.py @@ -0,0 +1,188 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Dict, Generator, Optional, Tuple, Union + +import torch +from compressed_tensors.config import SparsityCompressionConfig +from compressed_tensors.quantization import QuantizationArgs, QuantizationConfig +from compressed_tensors.registry import RegistryMixin +from torch import Tensor +from torch.nn import Module + + +__all__ = ["BaseCompressor"] + + +class BaseCompressor(RegistryMixin, ABC): + """ + Base class representing a model compression algorithm. Each child class should + implement compression_param_info, compress_weight and decompress_weight. + + Compressors support compressing/decompressing a full module state dict or a single + quantized PyTorch leaf module. 
+ + Model Load Lifecycle (run_compressed=False): + - ModelCompressor.decompress() + - apply_quantization_config() + - BaseCompressor.decompress() + + Model Save Lifecycle: + - ModelCompressor.compress() + - BaseCompressor.compress() + + + Module Lifecycle (run_compressed=True): + - apply_quantization_config() + - compressed_module = CompressedLinear(module) + - initialize_module_for_quantization() + - BaseCompressor.compression_param_info() + - register_parameters() + - compressed_module.forward() + -compressed_module.decompress() + + + :param config: config specifying compression parameters + """ + + def __init__( + self, config: Union[SparsityCompressionConfig, QuantizationConfig, None] = None + ): + self.config = config + + def compression_param_info( + self, + weight_shape: torch.Size, + quantization_args: Optional[QuantizationArgs] = None, + ) -> Dict[str, Tuple[torch.Size, torch.dtype]]: + """ + Creates a dictionary of expected shapes and dtypes for each compression + parameter used by the compressor + + :param weight_shape: uncompressed weight shape + :param quantization_args: quantization parameters for the weight + :return: dictionary mapping compressed parameter names to shape and dtype + """ + raise NotImplementedError() + + @abstractmethod + def compress( + self, + model_state: Dict[str, Tensor], + **kwargs, + ) -> Dict[str, Tensor]: + """ + Compresses a dense state dict + + :param model_state: state dict of uncompressed model + :param kwargs: additional arguments for compression + :return: compressed state dict + """ + raise NotImplementedError() + + @abstractmethod + def decompress( + self, + path_to_model_or_tensors: str, + device: str = "cpu", + **kwargs, + ) -> Generator[Tuple[str, Tensor], None, None]: + """ + Reads a compressed state dict located at path_to_model_or_tensors + and returns a generator for sequentially decompressing back to a + dense state dict + + :param path_to_model_or_tensors: path to compressed safetensors model (directory + 
def compress_module(self, module: Module) -> Optional[Dict[str, torch.Tensor]]:
    """
    Compresses a single quantized leaf PyTorch module. If the module is not
    quantized, this function has no effect.

    The module's ``weight``, ``weight_scale`` and ``weight_zero_point``
    attributes (``None`` for any that are absent) are forwarded to
    ``compress_weight`` together with the weight quantization args.

    :param module: PyTorch module to compress
    :return: dictionary of compressed weight data, or None if module is not
        quantized or its scheme carries no weight quantization args
    """
    quantization_scheme = getattr(module, "quantization_scheme", None)
    if quantization_scheme is None:
        return None  # module is not quantized

    # A ``weights`` attribute that exists but is None means the same as a
    # missing one. NOTE(review): the scheme type presumably declares
    # ``weights`` as an optional field defaulting to None, which a bare
    # hasattr() check cannot detect -- confirm against QuantizationScheme.
    quantization_args = getattr(quantization_scheme, "weights", None)
    if quantization_args is None:
        return None  # weights are not quantized

    weight = getattr(module, "weight", None)
    weight_scale = getattr(module, "weight_scale", None)
    weight_zero_point = getattr(module, "weight_zero_point", None)

    return self.compress_weight(
        weight=weight,
        scale=weight_scale,
        zero_point=weight_zero_point,
        quantization_args=quantization_args,
    )
+ + :param module: PyTorch module to decompress + :return: tensor of the decompressed weight, or None if module is not quantized + """ + if not hasattr(module, "quantization_scheme"): + return None # module is not quantized + quantization_scheme = module.quantization_scheme + if not hasattr(quantization_scheme, "weights"): + return None # weights are not quantized + + quantization_args = quantization_scheme.weights + compressed_data = {} + for name, parameter in module.named_parameters(): + compressed_data[name] = parameter + + return self.decompress_weight( + compressed_data=compressed_data, quantization_args=quantization_args + ) + + def decompress_weight( + self, compressed_data: Dict[str, Tensor], **kwargs + ) -> torch.Tensor: + """ + Decompresses a single compressed weight + + :param compressed_data: dictionary of data needed for decompression + :param kwargs: additional arguments for decompression + :return: tensor of the decompressed weight + """ + raise NotImplementedError() diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/helpers.py b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..7b03a9a150d9aecaea347ad91748c930edf497ef --- /dev/null +++ b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/helpers.py @@ -0,0 +1,137 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
def save_compressed(
    tensors: Dict[str, Tensor],
    save_path: Union[str, Path],
    compression_format: Optional[CompressionFormat] = None,
):
    """
    Run a dictionary of tensors through the compressor registered for
    `compression_format` and write the result to `save_path`.

    When no format is given, the `dense` format is used.

    :param tensors: dictionary of tensors to compress
    :param save_path: path to save compressed tensors
    :param compression_format: compression format used for the tensors
    :raises ValueError: if `tensors` is None or empty, or if the format is
        neither a registered compressor name nor a registered alias

    The function has no explicit return value.
    """
    has_tensors = tensors is not None and len(tensors) > 0
    if not has_tensors:
        raise ValueError("No tensors or empty tensors provided to compress")

    # fall back to `dense` when no format was requested
    if not compression_format:
        compression_format = CompressionFormat.dense.value

    is_known = (
        compression_format in BaseCompressor.registered_names()
        or compression_format in BaseCompressor.registered_aliases()
    )
    if not is_known:
        known_formats = set(
            BaseCompressor.registered_names() + BaseCompressor.registered_aliases()
        )
        raise ValueError(
            f"Unknown compression format: {compression_format}. "
            f"Must be one of {known_formats}"
        )

    # compress, then persist the compressed tensors
    save_file(
        BaseCompressor.load_from_registry(compression_format).compress(tensors),
        save_path,
    )
def save_compressed_model(
    model: torch.nn.Module,
    filename: str,
    compression_format: Optional[CompressionFormat] = None,
    force_contiguous: bool = True,
):
    """
    Wrapper around safetensors `save_model` helper function, which allows for
    saving compressed model to disk.

    Note: The model is assumed to have a
        state_dict with unique entries

    :param model: model to save on disk
    :param filename: filename location to save the file
    :param compression_format: compression format used for the model
    :param force_contiguous: forcing the state_dict to be saved as contiguous tensors
    :raises ValueError: if `save_compressed` raises a ValueError; the message is
        extended with a hint about `force_contiguous` and the original error is
        attached as the explicit cause
    """
    state_dict = model.state_dict()
    if force_contiguous:
        state_dict = {k: v.contiguous() for k, v in state_dict.items()}
    try:
        save_compressed(state_dict, filename, compression_format=compression_format)
    except ValueError as e:
        msg = str(e)
        msg += " Or use save_compressed_model(..., force_contiguous=True), read the docs for potential caveats."  # noqa E501
        # chain explicitly so the traceback shows the original failure as the
        # direct cause instead of an incidental error during handling
        raise ValueError(msg) from e
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# flake8: noqa + +from .base import * +from .naive_quantized import * +from .pack_quantized import * diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/quantized_compressors/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/quantized_compressors/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5e661f56fe017b915bfb7e7ba8e94fe436995d4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/quantized_compressors/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/quantized_compressors/__pycache__/base.cpython-311.pyc b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/quantized_compressors/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4381050417ae65dee02a1eecdd4381be07a9274e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/quantized_compressors/__pycache__/base.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/quantized_compressors/__pycache__/naive_quantized.cpython-311.pyc b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/quantized_compressors/__pycache__/naive_quantized.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72b5236881f29857d90fc6b2399ece69f258d7da Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/quantized_compressors/__pycache__/naive_quantized.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/quantized_compressors/__pycache__/pack_quantized.cpython-311.pyc b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/quantized_compressors/__pycache__/pack_quantized.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c42df2c21702f5a73deb8b92c963d91e349c5bd2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/quantized_compressors/__pycache__/pack_quantized.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/quantized_compressors/base.py b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/quantized_compressors/base.py new file mode 100644 index 0000000000000000000000000000000000000000..b49361d4e31029b94edb10803cd85aa897af4c58 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/quantized_compressors/base.py @@ -0,0 +1,176 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class BaseQuantizationCompressor(BaseCompressor):
    """
    Common base for quantization compression algorithms. Subclasses are expected
    to provide compression_param_info, compress_weight and decompress_weight.

    A compressor can process either a full module state dict or a single
    quantized PyTorch leaf module.

    Model Load Lifecycle (run_compressed=False):
        - ModelCompressor.decompress()
            - apply_quantization_config()
            - BaseQuantizationCompressor.decompress()
                - BaseQuantizationCompressor.decompress_weight()

    Model Save Lifecycle:
        - ModelCompressor.compress()
            - BaseQuantizationCompressor.compress()
                - BaseQuantizationCompressor.compress_weight()

    Module Lifecycle (run_compressed=True):
        - apply_quantization_config()
        - compressed_module = CompressedLinear(module)
            - initialize_module_for_quantization()
            - BaseQuantizationCompressor.compression_param_info()
            - register_parameters()
        - compressed_module.forward()
            - compressed_module.decompress()


    :param config: config specifying compression parameters
    """

    def compress(
        self,
        model_state: Dict[str, Tensor],
        names_to_scheme: Dict[str, QuantizationArgs],
        **kwargs,
    ) -> Dict[str, Tensor]:
        """
        Build the compressed counterpart of a dense state dict.

        Entries ending in ".weight" that have a matching "weight_scale" entry are
        passed to `compress_weight`; every other entry is moved to CPU unchanged,
        except all-zero zero points and g_idx tensors containing a value <= -1,
        which are dropped.

        :param model_state: state dict of uncompressed model
        :param names_to_scheme: quantization args for each quantized weight, needed
            for quantize function to calculate bit depth
        :return: compressed state dict
        """
        suffix = ".weight"
        result = {}
        _LOGGER.debug(
            f"Compressing model with {len(model_state)} parameterized layers..."
        )

        for name, tensor in tqdm(model_state.items(), desc="Quantized Compression"):
            if not name.endswith(suffix):
                if name.endswith("zero_point") and torch.all(tensor == 0):
                    continue
                if name.endswith("g_idx") and torch.any(tensor <= -1):
                    continue
                result[name] = tensor.to("cpu")
                continue

            layer = name[: -(len(suffix))]
            scale = model_state.get(merge_names(layer, "weight_scale"), None)
            if scale is None:
                # no scale stored for this weight: keep it as is
                result[name] = tensor.to("cpu")
                continue

            # weight is quantized, compress it
            compressed = self.compress_weight(
                weight=tensor,
                scale=scale,
                zero_point=model_state.get(
                    merge_names(layer, "weight_zero_point"), None
                ),
                g_idx=model_state.get(merge_names(layer, "weight_g_idx"), None),
                quantization_args=names_to_scheme[layer],
                device="cpu",
            )
            for key, item in compressed.items():
                result[merge_names(layer, key)] = item

        return result

    def decompress(
        self,
        path_to_model_or_tensors: Union[str, Path, Dict[str, Any]],
        names_to_scheme: Dict[str, QuantizationArgs],
        device: str = "cpu",
    ) -> Generator[Tuple[str, Tensor], None, None]:
        """
        Lazily decompress a compressed model, one weight at a time.

        A `str` or `Path` argument is read from disk; any other argument is
        handled as an in-memory state dict (in which case `device` is not used).

        :param path_to_model_or_tensors: path to compressed safetensors model
            (directory with one or more safetensors files) or compressed tensors
            file
        :param names_to_scheme: quantization args for each quantized weight
        :param device: optional device to load intermediate weights into
        :return: generator of (weight name, decompressed weight) pairs
        """
        if not isinstance(path_to_model_or_tensors, (str, Path)):
            yield from self._decompress_from_state_dict(
                path_to_model_or_tensors, names_to_scheme
            )
            return

        yield from self._decompress_from_path(
            path_to_model_or_tensors, names_to_scheme, device
        )

    def _decompress_from_path(self, path_to_model, names_to_scheme, device):
        # Load each weight's compression parameters from their safetensors
        # files, then decompress the weights that come with a scale.
        mappings = get_nested_weight_mappings(
            path_to_model, self.COMPRESSION_PARAM_NAMES
        )
        for weight_name, param_files in mappings.items():
            loaded = {}
            for param_name, safe_path in param_files.items():
                tensor_name = merge_names(weight_name, param_name)
                with safe_open(safe_path, framework="pt", device=device) as f:
                    loaded[param_name] = f.get_tensor(tensor_name)

            if "weight_scale" not in loaded:
                continue

            dense = self.decompress_weight(
                compressed_data=loaded,
                quantization_args=names_to_scheme[weight_name],
            )
            yield merge_names(weight_name, "weight"), dense

    def _decompress_from_state_dict(self, state_dict, names_to_scheme):
        # Same as the path variant, but the parameters are already in memory.
        mappings = get_nested_mappings_from_state_dict(
            state_dict, self.COMPRESSION_PARAM_NAMES
        )
        for weight_name, params in mappings.items():
            loaded = dict(params.items())

            if "weight_scale" not in loaded:
                continue

            dense = self.decompress_weight(
                compressed_data=loaded,
                quantization_args=names_to_scheme[weight_name],
            )
            yield merge_names(weight_name, "weight"), dense
@BaseCompressor.register(name=CompressionFormat.naive_quantized.value)
class NaiveQuantizationCompressor(BaseQuantizationCompressor):
    """
    Naive compression for quantized models: the weight of each quantized layer is
    converted from its original float type to the closest Pytorch type to the
    type specified by the layer's QuantizationArgs.
    """

    COMPRESSION_PARAM_NAMES = [
        "weight",
        "weight_scale",
        "weight_zero_point",
        "weight_g_idx",
    ]

    def compression_param_info(
        self,
        weight_shape: torch.Size,
        quantization_args: Optional[QuantizationArgs] = None,
    ) -> Dict[str, Tuple[torch.Size, torch.dtype]]:
        """
        Describe the shape and dtype of every compression parameter this
        compressor produces.

        :param weight_shape: uncompressed weight shape
        :param quantization_args: quantization parameters for the weight
        :return: dictionary mapping compressed parameter names to shape and dtype
        """
        return {"weight": (weight_shape, quantization_args.pytorch_dtype())}

    def compress_weight(
        self,
        weight: Tensor,
        scale: Tensor,
        quantization_args: QuantizationArgs,
        zero_point: Optional[Tensor] = None,
        g_idx: Optional[torch.Tensor] = None,
        device: Optional[torch.device] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Compress one uncompressed weight.

        The weight is quantized when `can_quantize` allows it and passed through
        unchanged otherwise.

        :param weight: uncompressed weight tensor
        :param scale: quantization scale for weight
        :param quantization_args: quantization parameters for weight
        :param zero_point: quantization zero point for weight
        :param g_idx: optional mapping from column index to group index
        :param device: optional device to move compressed output to
        :return: dictionary of compressed weight data
        """
        result = weight
        if can_quantize(weight, quantization_args):
            result = quantize(
                x=weight,
                scale=scale,
                zero_point=zero_point,
                g_idx=g_idx,
                args=quantization_args,
                dtype=quantization_args.pytorch_dtype(),
            )

        if device is not None:
            result = result.to(device)

        return {"weight": result}

    def decompress_weight(
        self,
        compressed_data: Dict[str, Tensor],
        quantization_args: Optional[QuantizationArgs] = None,
    ) -> torch.Tensor:
        """
        Dequantize one compressed weight.

        "weight" and "weight_scale" must be present in `compressed_data`;
        "weight_zero_point" and "weight_g_idx" are optional.

        :param compressed_data: dictionary of data needed for decompression
        :param quantization_args: quantization parameters for the weight
        :return: tensor of the decompressed weight
        """
        return dequantize(
            x_q=compressed_data["weight"],
            scale=compressed_data["weight_scale"],
            zero_point=compressed_data.get("weight_zero_point", None),
            g_idx=compressed_data.get("weight_g_idx", None),
        )
@BaseCompressor.register(name=CompressionFormat.pack_quantized.value)
class PackedQuantizationCompressor(BaseQuantizationCompressor):
    """
    Compresses a quantized model by packing every eight 4-bit weights into an int32
    """

    COMPRESSION_PARAM_NAMES = [
        "weight_packed",
        "weight_scale",
        "weight_zero_point",
        "weight_g_idx",
        "weight_shape",
    ]

    def compression_param_info(
        self,
        weight_shape: torch.Size,
        quantization_args: Optional[QuantizationArgs] = None,
    ) -> Dict[str, Tuple[torch.Size, torch.dtype]]:
        """
        Describe the shape and dtype of every compression parameter this
        compressor produces.

        :param weight_shape: uncompressed weight shape
        :param quantization_args: quantization parameters for the weight
        :return: dictionary mapping compressed parameter names to shape and dtype
        """
        # number of quantized values that fit into one int32
        values_per_int32 = 32 // quantization_args.num_bits
        rows = weight_shape[0]
        packed_columns = math.ceil(weight_shape[1] / values_per_int32)

        info = {}
        info["weight_packed"] = (torch.Size((rows, packed_columns)), torch.int32)
        info["weight_shape"] = (torch.Size((2,)), torch.int32)
        return info

    def compress_weight(
        self,
        weight: Tensor,
        scale: Tensor,
        quantization_args: QuantizationArgs,
        zero_point: Optional[Tensor] = None,
        g_idx: Optional[torch.Tensor] = None,
        device: Optional[torch.device] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Compress one uncompressed weight into packed int32 form.

        The weight is quantized to int8 when `can_quantize` allows it and used
        unchanged otherwise, then handed to `pack_to_int32`.

        :param weight: uncompressed weight tensor
        :param scale: quantization scale for weight
        :param quantization_args: quantization parameters for weight
        :param zero_point: quantization zero point for weight
        :param g_idx: optional mapping from column index to group index
        :param device: optional device to move compressed output to
        :return: dictionary of compressed weight data
        """
        to_pack = weight
        if can_quantize(weight, quantization_args):
            to_pack = quantize(
                x=weight,
                scale=scale,
                zero_point=zero_point,
                g_idx=g_idx,
                args=quantization_args,
                dtype=torch.int8,
            )

        packed = pack_to_int32(to_pack, quantization_args.num_bits)
        # the original shape is stored with the packed data; decompress_weight
        # reads it back to undo the packing
        shape = torch.tensor(weight.shape)
        if device is not None:
            packed, shape = packed.to(device), shape.to(device)

        return {"weight_shape": shape, "weight_packed": packed}

    def decompress_weight(
        self,
        compressed_data: Dict[str, Tensor],
        quantization_args: Optional[QuantizationArgs] = None,
    ) -> torch.Tensor:
        """
        Unpack and dequantize one compressed weight.

        "weight_packed", "weight_scale" and "weight_shape" must be present in
        `compressed_data`; "weight_zero_point" and "weight_g_idx" are optional.

        :param compressed_data: dictionary of data needed for decompression
        :param quantization_args: quantization parameters for the weight
        :return: tensor of the decompressed weight
        """
        packed = compressed_data["weight_packed"]
        scale = compressed_data["weight_scale"]
        zero_point = compressed_data.get("weight_zero_point", None)
        g_idx = compressed_data.get("weight_g_idx", None)
        target_shape = torch.Size(compressed_data["weight_shape"])

        quantized = unpack_from_int32(
            packed, quantization_args.num_bits, target_shape
        )
        return dequantize(
            x_q=quantized, scale=scale, zero_point=zero_point, g_idx=g_idx
        )
def unpack_from_int32(
    value: torch.Tensor, num_bits: int, shape: torch.Size
) -> torch.Tensor:
    """
    Unpacks a tensor of packed int32 weights into individual int8s, restoring
    their original signed bit range

    :param value: 2D int32 tensor to unpack
    :param num_bits: number of bits to unpack each data point into (1 to 8)
    :param shape: shape to unpack into; its second entry is used to remove the
        padding columns
    :returns: unpacked int8 tensor
    :raises ValueError: if `value` is not int32 or `num_bits` is outside 1..8
    """
    if value.dtype is not torch.int32:
        raise ValueError(
            f"Expected {torch.int32} but got {value.dtype}, Aborting unpack."
        )

    if num_bits < 1:
        # guards the integer division below against num_bits == 0
        raise ValueError("Unpacking requires num_bits to be at least 1")

    if num_bits > 8:
        raise ValueError("Unpacking is only supported for 8 bits or fewer")

    pack_factor = 32 // num_bits

    # unpack: slot i of every int32 holds the values of columns i, i + pack_factor, ...
    mask = pow(2, num_bits) - 1
    unpacked = torch.zeros(
        (value.shape[0], value.shape[1] * pack_factor),
        device=value.device,
        dtype=torch.int32,
    )
    for i in range(pack_factor):
        # the mask discards the sign bits an arithmetic shift drags in
        unpacked[:, i::pack_factor] = (value >> (num_bits * i)) & mask

    # remove padding
    original_row_size = int(shape[1])
    unpacked = unpacked[:, :original_row_size]

    # bits are packed in unsigned format, reformat to signed
    # update the value range from unsigned to signed
    offset = pow(2, num_bits) // 2
    unpacked = (unpacked - offset).to(torch.int8)

    return unpacked
+# flake8: noqa + +from .base import * +from .dense import * +from .sparse_24_bitmask import * +from .sparse_bitmask import * diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_compressors/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_compressors/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..554bf588e3508b66a30d7845dd335dd6c0bcbb65 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_compressors/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_compressors/__pycache__/base.cpython-311.pyc b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_compressors/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae7c51ea47f7c17be0ffb275c5b1705b3eb1e810 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_compressors/__pycache__/base.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_compressors/__pycache__/dense.cpython-311.pyc b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_compressors/__pycache__/dense.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f512f8100d158390ee63722f0d83f4743164cf7d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_compressors/__pycache__/dense.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_compressors/__pycache__/sparse_24_bitmask.cpython-311.pyc b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_compressors/__pycache__/sparse_24_bitmask.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..c937de3e38163d19c959d89ac8caf2026285b2b2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_compressors/__pycache__/sparse_24_bitmask.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_compressors/__pycache__/sparse_bitmask.cpython-311.pyc b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_compressors/__pycache__/sparse_bitmask.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68538121cc85c72b2c466b863e2c3cda987b91d9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_compressors/__pycache__/sparse_bitmask.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_compressors/base.py b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_compressors/base.py new file mode 100644 index 0000000000000000000000000000000000000000..7cd6e8e8b4a5393280469bb2fede0868a597af24 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_compressors/base.py @@ -0,0 +1,148 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class BaseSparseCompressor(BaseCompressor):
    """
    Base class representing a sparse compression algorithm. Each child class should
    implement compression_param_info, compress_weight and decompress_weight; child
    classes should also define COMPRESSION_PARAM_NAMES.

    Compressors support compressing/decompressing a full module state dict or a single
    quantized PyTorch leaf module.

    Model Load Lifecycle (run_compressed=False):
        - ModelCompressor.decompress()
            - apply_quantization_config()
            - BaseSparseCompressor.decompress()
                - BaseSparseCompressor.decompress_weight()

    Model Save Lifecycle:
        - ModelCompressor.compress()
            - BaseSparseCompressor.compress()
                - BaseSparseCompressor.compress_weight()

    Module Lifecycle (run_compressed=True):
        - apply_quantization_config()
        - compressed_module = CompressedLinear(module)
            - initialize_module_for_quantization()
            - BaseSparseCompressor.compression_param_info()
            - register_parameters()
        - compressed_module.forward()
            - compressed_module.decompress()


    :param config: config specifying compression parameters
    """

    def compress(
        self,
        model_state: Dict[str, Tensor],
        compression_targets: Optional[Set[str]] = None,
    ) -> Dict[str, Tensor]:
        """
        Compresses a dense state dict using bitmask compression

        Entries for which `should_compress` returns False are copied to the
        output unchanged.

        :param model_state: state dict of uncompressed model
        :param compression_targets: optional set of layer prefixes to compress,
            otherwise compress all layers (for backwards compatibility)
        :return: compressed state dict
        """
        compressed_dict = {}
        _LOGGER.debug(
            f"Compressing model with {len(model_state)} parameterized layers..."
        )
        for name, value in tqdm(model_state.items(), desc="Compressing model"):
            if not self.should_compress(name, compression_targets):
                compressed_dict[name] = value
                continue
            prefix = name
            if prefix.endswith(".weight"):
                prefix = prefix[: -(len(".weight"))]

            compression_data = self.compress_weight(prefix, value)
            for key in compression_data.keys():
                if key in compressed_dict:
                    # Logger.warn is a deprecated alias; warning() is the
                    # documented method
                    _LOGGER.warning(
                        f"Expected all compressed state_dict keys to be unique, but "
                        f"found an existing entry for {key}. The existing entry will "
                        "be replaced."
                    )

            compressed_dict.update(compression_data)

        return compressed_dict

    def decompress(
        self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
    ) -> Generator[Tuple[str, Tensor], None, None]:
        """
        Reads a bitmask compressed state dict located
        at path_to_model_or_tensors and returns a generator
        for sequentially decompressing back to a dense state dict

        Parameters that do not match COMPRESSION_PARAM_NAMES are yielded
        unchanged after all decompressed weights.

        :param path_to_model_or_tensors: path to compressed safetensors model
            (directory with one or more safetensors files) or compressed tensors
            file
        :param device: device to load decompressed weights onto
        :return: iterator for generating decompressed weights
        """
        weight_mappings, ignored_params = get_nested_weight_mappings(
            path_to_model_or_tensors,
            self.COMPRESSION_PARAM_NAMES,
            return_unmatched_params=True,
        )
        for weight_name in weight_mappings.keys():
            weight_data = {}
            for param_name, safe_path in weight_mappings[weight_name].items():
                full_name = merge_names(weight_name, param_name)
                with safe_open(safe_path, framework="pt", device=device) as f:
                    weight_data[param_name] = f.get_tensor(full_name)
            decompressed = self.decompress_weight(weight_data)
            yield merge_names(weight_name, "weight"), decompressed

        for ignored_param_name, safe_path in ignored_params.items():
            with safe_open(safe_path, framework="pt", device=device) as f:
                value = f.get_tensor(ignored_param_name)
            yield ignored_param_name, value

    @staticmethod
    def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool:
        """
        Check if a parameter should be compressed.
        Currently, this only returns True for weight parameters.

        :param name: name of the parameter
        :param expanded_targets: set of layer prefixes to compress; when None,
            every parameter whose name ends in ".weight" qualifies
        :return: whether or not the parameter should be compressed
        """
        if expanded_targets is None:
            return name.endswith(".weight")

        return (
            name.endswith(".weight") and name[: -(len(".weight"))] in expanded_targets
        )
+ +from typing import Dict, Generator, Tuple + +from compressed_tensors.compressors.base import BaseCompressor +from compressed_tensors.config import CompressionFormat +from torch import Tensor + + +@BaseCompressor.register(name=CompressionFormat.dense.value) +class DenseCompressor(BaseCompressor): + """ + Identity compressor for dense models, returns the original state_dict + """ + + def compress(self, model_state: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]: + return model_state + + def decompress( + self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs + ) -> Generator[Tuple[str, Tensor], None, None]: + return iter([]) diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py new file mode 100644 index 0000000000000000000000000000000000000000..0ce044ccb795ec1a322864512ec9a46bd6e119cf --- /dev/null +++ b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py @@ -0,0 +1,240 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass +from typing import Dict, List, Tuple, Union + +import torch +from compressed_tensors.compressors.base import BaseCompressor +from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor +from compressed_tensors.config import CompressionFormat, SparsityStructure +from compressed_tensors.quantization import FP8_DTYPE +from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks +from torch import Tensor + + +__all__ = [ + "Sparse24BitMaskCompressor", + "Sparse24BitMaskTensor", + "sparse24_bitmask_compress", + "sparse24_bitmask_decompress", + "get_24_bytemasks", +] + + +@BaseCompressor.register(name=CompressionFormat.sparse_24_bitmask.value) +class Sparse24BitMaskCompressor(BaseSparseCompressor): + """ + Compression for sparse models using bitmasks. Non-zero weights are stored in a 2d + values tensor, with their locations stored in a 2d bitmask + """ + + COMPRESSION_PARAM_NAMES = [ + "shape", + "compressed", + "bitmask", + ] + + def compress_weight(self, name, value): + bitmask_tensor = Sparse24BitMaskTensor.from_dense( + value, self.config.sparsity_structure + ) + bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu") + return bitmask_dict + + def decompress_weight(self, weight_data): + data = Sparse24BitMaskTensor.from_compressed_data(**weight_data) + decompressed = data.decompress() + return decompressed + + +@dataclass +class Sparse24BitMaskTensor: + """ + Owns compressions and decompression for a single 2:4 sparse + bitmask compressed tensor. 
+ + :param shape: shape of dense tensor + :param compressed: 2d tensor of non-zero values + :param bitmask: 2d bitmask of non-zero values + """ + + shape: List[int] + compressed: Tensor + bitmask: Tensor + + @staticmethod + def from_dense( + tensor: Tensor, + sparsity_structure: Union[SparsityStructure, str] = SparsityStructure.TWO_FOUR, + ) -> "Sparse24BitMaskTensor": + """ + :param tensor: dense tensor to compress + :return: instantiated compressed tensor + """ + shape = list(tensor.shape) + compressed, bitmask = sparse24_bitmask_compress( + tensor.cpu(), sparsity_structure=sparsity_structure + ) + return Sparse24BitMaskTensor( + shape=shape, + compressed=compressed, + bitmask=bitmask, + ) + + @staticmethod + def from_compressed_data( + shape: Union[List[int], Tensor], compressed: Tensor, bitmask: Tensor + ) -> "Sparse24BitMaskTensor": + """ + :param shape: shape of the dense tensor (can be a list or a tensor) + :param compressed: 2d tensor of non-zero values + :param bitmask: 2d bitmask of non-zero values + :return: instantiated Sparse24BitMaskTensor + """ + if isinstance(shape, list): + shape = torch.tensor(shape) + if isinstance(shape, torch.Tensor): + shape = shape.flatten().tolist() + return Sparse24BitMaskTensor( + shape=shape, compressed=compressed, bitmask=bitmask + ) + + def decompress(self) -> Tensor: + """ + :return: reconstructed dense tensor + """ + return sparse24_bitmask_decompress(self.compressed, self.bitmask, self.shape) + + def curr_memory_size_bytes(self) -> int: + """ + :return: size in bytes required to store compressed tensor on disk + """ + + def sizeof_tensor(a: Tensor) -> int: + return a.element_size() * a.nelement() + + return sizeof_tensor(self.compressed) + sizeof_tensor(self.bitmask) + + def dict(self, name_prefix: str, device: str = "cpu") -> Dict[str, Tensor]: + """ + :param name_prefix: name of original tensor to store compressed weight as + :return: dict of compressed data for the stored weight + """ + if 
name_prefix.endswith(".weight"): + name_prefix = name_prefix[: -len(".weight")] + return { + merge_names(name_prefix, "shape"): torch.tensor( + self.shape, device=device + ).reshape(-1, 1), + merge_names(name_prefix, "compressed"): self.compressed.to(device), + merge_names(name_prefix, "bitmask"): self.bitmask.to(device), + } + + def __repr__(self) -> str: + return f"BitMaskTensor(shape={self.shape}, compressed=True)" + + +def sparse24_bitmask_compress( + tensor: Tensor, + sparsity_structure: Union[SparsityStructure, str] = SparsityStructure.TWO_FOUR, +) -> Tuple[Tensor, Tensor, Tensor]: + """ + Compresses a dense tensor using bitmask compression + + :param tensor: dense 2D tensor to compress + :param sparsity_structure: structure of sparsity in the tensor, defaults + to unstructured, can also be set to `2:4` + :return: tuple of compressed data representing tensor + """ + assert len(tensor.shape) == 2, "Only 2D tensors are supported" + assert ( + SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR + ), "Only 2:4 sparsity is supported" + + bytemasks = get_24_bytemasks(tensor=tensor) + + if tensor.dtype == FP8_DTYPE: + # acces raw bytes of the tensor + tensor_view = tensor.view(torch.int8) + values = tensor_view[bytemasks] + values = values.view(FP8_DTYPE) + else: + values = tensor[bytemasks] + + num_rows, num_cols = tensor.shape + compressed_values = values.reshape(num_rows, num_cols // 2) + bitmasks_packed = pack_bitmasks(bytemasks) + return compressed_values, bitmasks_packed + + +def sparse24_bitmask_decompress( + values: Tensor, bitmasks: Tensor, original_shape: torch.Size +) -> Tensor: + """ + Reconstructs a dense tensor from a compressed one + + :param values: 1d tensor of non-zero values + :param bitmasks: 2d int8 tensor flagging locations of non-zero values in the + tensors original shape + :param original_shape: shape of the dense tensor + :return: decompressed dense tensor + """ + bytemasks_unpacked = unpack_bitmasks(bitmasks, 
original_shape) + + decompressed_tensor = torch.zeros(original_shape, dtype=values.dtype) + decompressed_tensor = decompressed_tensor.to(values.device) + values = values.flatten() + if decompressed_tensor.dtype == FP8_DTYPE: + decompressed_tensor[bytemasks_unpacked] = values + decompressed_tensor = decompressed_tensor.cuda() + else: + decompressed_tensor[bytemasks_unpacked] = values + return decompressed_tensor + + +def get_24_bytemasks(tensor): + """ + Generate a 2:4 sparsity mask for the given tensor. + + This function creates a mask where exactly 2 out of every 4 elements are + preserved based on their magnitudes. The preserved elements are the ones + with the highest absolute values in each group of 4 elements. + + :param tensor: The input tensor for which the 2:4 sparsity mask is to be created. + The tensor can be of any shape but its total number of elements + must be a multiple of 4. + :return: A boolean tensor of the same shape as the input tensor, where `True` + indicates the preserved elements and `False` indicates the pruned elements. + :raises ValueError: If the total number of elements in the tensor is not a + multiple of 4. 
+ """ + original_dtype = tensor.dtype + if tensor.dtype == FP8_DTYPE: + tensor = tensor.view(torch.int8) + original_shape = tensor.shape + num_elements = tensor.numel() + + if num_elements % 4 != 0: + raise ValueError("Tensor size must be a multiple of 4 for TWO_FOUR sparsity") + + reshaped_tensor = tensor.view(-1, 4) + abs_tensor = reshaped_tensor.abs() + topk_indices = abs_tensor.topk(2, dim=1).indices + mask = torch.zeros_like(reshaped_tensor, dtype=torch.bool) + mask.scatter_(1, topk_indices, True) + mask = mask.view(original_shape) + tensor = tensor.view(original_dtype) + + return mask diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py new file mode 100644 index 0000000000000000000000000000000000000000..7c2023cf571bfdf5b1810cc0e44abad2b3305c4a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py @@ -0,0 +1,163 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Dict, List, Tuple, Union + +import torch +from compressed_tensors.compressors.base import BaseCompressor +from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor +from compressed_tensors.config import CompressionFormat +from compressed_tensors.quantization import FP8_DTYPE +from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks +from torch import Tensor + + +__all__ = [ + "BitmaskCompressor", + "BitmaskTensor", + "bitmask_compress", + "bitmask_decompress", +] + + +@BaseCompressor.register(name=CompressionFormat.sparse_bitmask.value) +class BitmaskCompressor(BaseSparseCompressor): + """ + Compression for sparse models using bitmasks. Non-zero weights are stored in a 1d + values tensor, with their locations stored in a 2d bitmask + """ + + COMPRESSION_PARAM_NAMES = ["shape", "compressed", "bitmask", "row_offsets"] + + def compress_weight(self, name, value): + bitmask_tensor = BitmaskTensor.from_dense(value) + bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu") + return bitmask_dict + + def decompress_weight(self, weight_data): + data = BitmaskTensor(**weight_data) + decompressed = data.decompress() + return decompressed + + +class BitmaskTensor: + """ + Owns compressions and decompression for a single bitmask compressed tensor. 
+ Adapted from: https://github.com/mgoin/torch_bitmask/tree/main + + :param shape: shape of dense tensor + :compressed: flat tensor of non-zero values + :bitmask: 2d bitmask of non-zero values + :row_offsets: flat tensor indicating what index in values each dense row starts at + """ + + def __init__( + self, + shape: Union[torch.Size, List], + compressed: Tensor, + bitmask: Tensor, + row_offsets: Tensor, + ): + self.shape = list(shape) + self.compressed = compressed + self.bitmask = bitmask + self.row_offsets = row_offsets + + @staticmethod + def from_dense(tensor: Tensor) -> "BitmaskTensor": + """ + :param tensor: dense tensor to compress + :return: instantiated compressed tensor + """ + shape = tensor.shape + compressed, bitmask, row_offsets = bitmask_compress(tensor.cpu()) + return BitmaskTensor( + shape=shape, compressed=compressed, bitmask=bitmask, row_offsets=row_offsets + ) + + def decompress(self) -> Tensor: + """ + :return: reconstructed dense tensor + """ + return bitmask_decompress(self.compressed, self.bitmask, self.shape) + + def curr_memory_size_bytes(self): + """ + :return: size in bytes required to store compressed tensor on disk + """ + + def sizeof_tensor(a): + return a.element_size() * a.nelement() + + return ( + sizeof_tensor(self.compressed) + + sizeof_tensor(self.bitmask) + + sizeof_tensor(self.row_offsets) + ) + + def dict(self, name_prefix: str, device: str = "cpu") -> Dict[str, Tensor]: + """ + :name_prefix: name of original tensor to store compressed weight as + :return: dict of compressed data for the stored weight + """ + return { + merge_names(name_prefix, "shape"): torch.tensor(self.shape, device=device), + merge_names(name_prefix, "compressed"): self.compressed.to(device), + merge_names(name_prefix, "bitmask"): self.bitmask.to(device), + merge_names(name_prefix, "row_offsets"): self.row_offsets.to(device), + } + + def __repr__(self): + return f"BitmaskTensor(shape={self.shape}, compressed=True)" + + +def bitmask_compress(tensor: 
Tensor) -> Tuple[Tensor, Tensor, Tensor]: + """ + Compresses a dense tensor using bitmask compression + + :param tensor: dense tensor to compress + :return: tuple of compressed data representing tensor + """ + bytemasks = tensor != 0 + row_counts = bytemasks.sum(dim=-1) + row_offsets = torch.cumsum(row_counts, 0) - row_counts + if tensor.dtype == FP8_DTYPE: + # acces raw bytes of the tensor + tensor_view = tensor.view(torch.int8) + values = tensor_view[bytemasks] + values = values.view(FP8_DTYPE) + else: + values = tensor[bytemasks] + bitmasks_packed = pack_bitmasks(bytemasks) + return values, bitmasks_packed, row_offsets + + +def bitmask_decompress( + values: Tensor, bitmasks: Tensor, original_shape: torch.Size +) -> Tensor: + """ + Reconstructs a dense tensor from a compressed one + + :param values: 1d tensor of non-zero values + :param bitmasks: 2d int8 tensor flagging locations of non-zero values in the + tensors original shape + :param original_shape: shape of the dense tensor + :return: decompressed dense tensor + """ + bytemasks_unpacked = unpack_bitmasks(bitmasks, original_shape) + + decompressed_tensor = torch.zeros(original_shape, dtype=values.dtype) + decompressed_tensor[bytemasks_unpacked] = values + + return decompressed_tensor diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c3615f062c958821e70c0cde951352540b2f5055 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# flake8: noqa + +from .marlin_24 import Marlin24Compressor diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py new file mode 100644 index 0000000000000000000000000000000000000000..4c544588085d2ae105df04935f1cf6732e7a4377 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py @@ -0,0 +1,251 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +from typing import Dict, Generator, Tuple + +import numpy as np +import torch +from compressed_tensors.compressors.base import BaseCompressor +from compressed_tensors.config import CompressionFormat +from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy +from compressed_tensors.quantization.lifecycle.forward import quantize +from compressed_tensors.utils import ( + get_permutations_24, + is_quantization_param, + merge_names, + sparse_semi_structured_from_dense_cutlass, + tensor_follows_mask_structure, +) +from torch import Tensor +from tqdm import tqdm + + +_LOGGER: logging.Logger = logging.getLogger(__name__) + + +@BaseCompressor.register(name=CompressionFormat.marlin_24.value) +class Marlin24Compressor(BaseCompressor): + """ + Compresses a quantized model with 2:4 sparsity structure for inference with the + Marlin24 kernel. Decompression is not implemented for this compressor. + """ + + COMPRESSION_PARAM_NAMES = ["weight_packed", "scale_packed", "meta"] + + @staticmethod + def validate_quant_compatability( + model_quant_args: Dict[str, QuantizationArgs] + ) -> bool: + """ + Checks if every quantized module in the model is compatible with Marlin24 + compression. Quantization must be channel or group strategy with group_size + of 128. 
Only symmetric quantization is supported + + :param model_quant_args: dictionary of mapping module names to their + quantization configuration + :return: True if all modules are compatible with Marlin24 compression, raises + a ValueError otherwise + """ + for name, quant_args in model_quant_args.items(): + strategy = quant_args.strategy + group_size = quant_args.group_size + symmetric = quant_args.symmetric + if ( + strategy is not QuantizationStrategy.GROUP.value + and strategy is not QuantizationStrategy.CHANNEL.value + ): + raise ValueError( + f"Marlin24 Compressor is only valid for group and channel " + f"quantization strategies, got {strategy} in {name}" + ) + + if group_size is not None and group_size != 128: + raise ValueError( + f"Marlin24 Compressor is only valid for group size 128, " + f"got {group_size} in {name}" + ) + + if not symmetric: + raise ValueError( + f"Marlin24 Compressor is only valid for symmetric quantzation, " + f"got symmetric={symmetric} in {name}" + ) + + return True + + @staticmethod + def validate_sparsity_structure(name: str, weight: Tensor) -> bool: + """ + Checks if a tensor fits the required 2:4 sparsity structure + + :param name: name of the tensor to check + :param weight: tensor to check for sparsity structure + :return: True if all rows match the 2:4 sparsity structure, raises + ValueError otherwise + """ + + if not tensor_follows_mask_structure(weight): + raise ValueError( + "Marlin24 Compressor is only compatible with weights that have " + f"a 2:4 sparsity structure. Found segments in {name} " + "that do not match the expected structure." 
+ ) + + return True + + def compress( + self, + model_state: Dict[str, Tensor], + names_to_scheme: Dict[str, QuantizationArgs], + **kwargs, + ) -> Dict[str, Tensor]: + """ + Compresses a quantized state_dict with 2:4 sparsity structure for inference + with the Marlin24 kernel + + :param model_state: state dict of uncompressed model + :param names_to_scheme: quantization args for each quantized weight, needed for + quantize function to calculate bit depth + :return: compressed state dict + """ + self.validate_quant_compatability(names_to_scheme) + + compressed_dict = {} + weight_suffix = ".weight" + _LOGGER.debug( + f"Compressing model with {len(model_state)} parameterized layers..." + ) + + for name, value in tqdm(model_state.items(), desc="Compressing model"): + if name.endswith(weight_suffix): + prefix = name[: -(len(weight_suffix))] + scale = model_state.get(merge_names(prefix, "weight_scale"), None) + zp = model_state.get(merge_names(prefix, "weight_zero_point"), None) + if scale is not None: # weight is quantized, compress it + + # Marlin24 kernel requires float16 inputs + scale = scale.to(torch.float16) + value = value.to(torch.float16) + + # quantize weight, keeping it as a float16 for now + quant_args = names_to_scheme[prefix] + value = quantize( + x=value, scale=scale, zero_point=zp, args=quant_args + ) + + # compress based on sparsity structure + self.validate_sparsity_structure(prefix, value) + value, meta = compress_weight_24(value) + meta = meta.cpu() + + # Marlin24 kernel expects input dim first + value = value.t().contiguous().cpu() + scale = scale.t().contiguous().cpu() + og_weight_shape = value.shape + + # Marlin24 kernel expects unsigned values, shift zero-point + value += (1 << quant_args.num_bits) // 2 + + # pack quantized weight and scale + value = pack_weight_24(value, quant_args) + packed_scale = pack_scales_24(scale, quant_args, og_weight_shape) + meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) + + # save compressed values + 
compressed_dict[merge_names(prefix, "scale_packed")] = packed_scale + compressed_dict[merge_names(prefix, "weight_packed")] = value + compressed_dict[merge_names(prefix, "meta")] = meta + continue + + if not is_quantization_param(name): + # export unquantized parameters without modifying + compressed_dict[name] = value.to("cpu") + + return compressed_dict + + def decompress( + self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs + ) -> Generator[Tuple[str, Tensor], None, None]: + raise NotImplementedError( + "Decompression is not implemented for the Marlin24 Compressor." + ) + + +def compress_weight_24(weight: Tensor): + weight = weight.contiguous() + w_comp, meta = sparse_semi_structured_from_dense_cutlass(weight) + w_comp = w_comp.contiguous() + return w_comp, meta + + +def marlin_permute_weights(q_w, size_k, size_n, perm, tile): + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) + + return q_w + + +def pack_weight_24( + weight: Tensor, + quantization_args: QuantizationArgs, + tile: int = 16, +): + size_k = weight.shape[0] + size_n = weight.shape[1] + num_bits = quantization_args.num_bits + pack_factor = 32 // num_bits + + # Reshuffle to marlin_24 format + perm, _, _ = get_permutations_24(num_bits) + q_w = marlin_permute_weights(weight, size_k, size_n, perm, tile) + + q_w = q_w.cpu().numpy().astype(np.uint32) + + q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(np.int32)) + + return q_packed + + +def 
pack_scales_24(scales, quantization_args, w_shape): + size_k = w_shape[0] + size_n = w_shape[1] + num_bits = quantization_args.num_bits + + _, scale_perm_2_4, scale_perm_single_2_4 = get_permutations_24(num_bits) + + if ( + quantization_args.strategy == QuantizationStrategy.GROUP + and quantization_args.group_size < size_k + ): + scales = scales.reshape((-1, len(scale_perm_2_4)))[:, scale_perm_2_4] + else: # channelwise + scales = scales.reshape((-1, len(scale_perm_single_2_4)))[ + :, scale_perm_single_2_4 + ] + scales = scales.reshape((-1, size_n)).contiguous() + + return scales diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/__init__.py b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9fde69a351939461d536f6eeea980520563cc7c0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# flake8: noqa +# isort: skip_file + +from .quant_args import * +from .quant_config import * +from .quant_scheme import * +from .lifecycle import * diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4718d6f116111817fcdf6dd2e208da391b40e43d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/__pycache__/quant_args.cpython-311.pyc b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/__pycache__/quant_args.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d8adbe7e98b0894f0276a128d4de6ecc00c2c83 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/__pycache__/quant_args.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/__pycache__/quant_config.cpython-311.pyc b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/__pycache__/quant_config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5143c60c2d2887f544725b1e00033f1ba65acf0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/__pycache__/quant_config.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/__pycache__/quant_scheme.cpython-311.pyc b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/__pycache__/quant_scheme.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e1d768cbe95161a3301dcfb28fbef5ee74ea11a Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/__pycache__/quant_scheme.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/lifecycle/__init__.py b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/lifecycle/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6acab2551f1a03dac27d80ac9f72f5d36728cb36 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/lifecycle/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# flake8: noqa +# isort: skip_file + +from .forward import * +from .initialize import * +from .compressed import * +from .apply import * +from .helpers import * diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/lifecycle/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/lifecycle/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7599a5c8884632de38a1b329f0edb457cf3513b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/lifecycle/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/lifecycle/__pycache__/apply.cpython-311.pyc b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/lifecycle/__pycache__/apply.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb91940415c73148feee79761c64707c22bdb3a5 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/lifecycle/__pycache__/apply.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/lifecycle/__pycache__/compressed.cpython-311.pyc b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/lifecycle/__pycache__/compressed.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff3716235c73f311b65cb9a9c402ee642da061d1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/lifecycle/__pycache__/compressed.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/lifecycle/__pycache__/forward.cpython-311.pyc b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/lifecycle/__pycache__/forward.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..44bc134f2efbae772e8f001dbf0e85559af8e627 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/lifecycle/__pycache__/forward.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/lifecycle/__pycache__/helpers.cpython-311.pyc b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/lifecycle/__pycache__/helpers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4b9c1623798be7337ee6ffcbf2260d7d8fd0f81 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/lifecycle/__pycache__/helpers.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/lifecycle/__pycache__/initialize.cpython-311.pyc b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/lifecycle/__pycache__/initialize.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea1784975a519e7288f5a3a78dccaf3da223c27f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/lifecycle/__pycache__/initialize.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/lifecycle/apply.py b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/lifecycle/apply.py new file mode 100644 index 0000000000000000000000000000000000000000..31f14df045156c4252dd8f6b518ef6d1d4d3e417 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/lifecycle/apply.py @@ -0,0 +1,436 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re +from collections import OrderedDict, defaultdict +from copy import deepcopy +from typing import Dict, Iterable, List, Optional +from typing import OrderedDict as OrderedDictType +from typing import Set, Union + +import torch +from compressed_tensors.config import CompressionFormat +from compressed_tensors.quantization.lifecycle.compressed import ( + compress_quantized_weights, +) +from compressed_tensors.quantization.lifecycle.initialize import ( + initialize_module_for_quantization, +) +from compressed_tensors.quantization.quant_args import QuantizationArgs +from compressed_tensors.quantization.quant_config import ( + QuantizationConfig, + QuantizationStatus, +) +from compressed_tensors.quantization.quant_scheme import QuantizationScheme +from compressed_tensors.quantization.utils import ( + KV_CACHE_TARGETS, + infer_quantization_status, + is_kv_cache_quant_scheme, + iter_named_leaf_modules, + iter_named_quantizable_modules, +) +from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module +from compressed_tensors.utils.offload import update_parameter_data +from compressed_tensors.utils.safetensors_load import get_safetensors_folder +from torch.nn import Module + + +__all__ = [ + "load_pretrained_quantization", + "apply_quantization_config", + "apply_quantization_status", + "find_name_or_class_matches", + "expand_sparse_target_names", + "is_sparse_target", +] + +from compressed_tensors.quantization.utils.helpers import is_module_quantized +from compressed_tensors.utils.safetensors_load import 
def load_pretrained_quantization(model: Module, model_name_or_path: str):
    """
    Copies stored quantization parameters into an already-initialized model.

    The quantization state dict is read from the safetensors folder resolved
    from ``model_name_or_path``. Every quantized leaf module of ``model`` then
    has its parameters loaded for each scheme component that is configured
    (weights, input activations, output activations - in that order).

    :param model: model to load pretrained quantization parameters to
    :param model_name_or_path: Hugging Face stub or local folder containing a
        quantized model, which is used to load quantization parameters
    """
    quant_state = get_quantization_state_dict(
        get_safetensors_folder(model_name_or_path)
    )

    # (attribute on the quantization scheme, parameter-name prefix), in the
    # order the parameters are loaded
    scheme_fields = (
        ("weights", "weight"),
        ("input_activations", "input"),
        ("output_activations", "output"),
    )

    for module_name, leaf in iter_named_leaf_modules(model):
        if is_module_quantized(leaf):
            for field, prefix in scheme_fields:
                # only load parameters for the parts the scheme configures
                if getattr(leaf.quantization_scheme, field) is not None:
                    _load_quant_args_from_state_dict(
                        base_name=prefix,
                        module_name=module_name,
                        module=leaf,
                        state_dict=quant_state,
                    )
+ Optionally coverts quantizable modules to compressed_linear modules + + :param model: model to apply quantization config to + :param config: quantization config + :param run_compressed: Whether the model will be run in compressed mode or + decompressed fully on load + """ + # Workaround for when HF Quantizer passes None, see PR #180 + if config is None: + return OrderedDict() + + # remove reference to the original `config` + # argument. This function can mutate it, and we'd + # like to keep the original `config` as it is. + config = deepcopy(config) + # build mapping of targets to schemes for easier matching + # use ordered dict to preserve target ordering in config + target_to_scheme = OrderedDict() + config = process_quantization_config(config) + names_to_scheme = OrderedDict() + for scheme in config.config_groups.values(): + for target in scheme.targets: + target_to_scheme[target] = scheme + + if run_compressed: + from compressed_tensors.linear.compressed_linear import CompressedLinear + + # list of submodules to ignore + ignored_submodules = defaultdict(list) + # mark appropriate layers for quantization by setting their quantization schemes + for name, submodule in iter_named_quantizable_modules( + model, + include_children=True, + include_attn=True, + ): # child modules and attention modules + # potentially fix module name to remove FSDP wrapper prefix + name = fix_fsdp_module_name(name) + if matches := find_name_or_class_matches(name, submodule, config.ignore): + for match in matches: + ignored_submodules[match].append(name) + continue # layer matches ignore list, continue + + targets = find_name_or_class_matches(name, submodule, target_to_scheme) + + if targets: + # mark modules to be quantized by adding + # quant scheme to the matching layers + scheme = _scheme_from_targets(target_to_scheme, targets, name) + if run_compressed: + format = config.format + if format != CompressionFormat.dense.value: + if isinstance(submodule, torch.nn.Linear): + # TODO: 
def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig:
    """
    Preprocess the raw QuantizationConfig.

    Currently the only preprocessing step is folding ``config.kv_cache_scheme``
    (when set) into ``config.config_groups``.

    :param config: the raw QuantizationConfig
    :return: the processed QuantizationConfig
    """
    if config.kv_cache_scheme is None:
        # nothing to preprocess
        return config

    return process_kv_cache_config(config)


def process_kv_cache_config(
    config: QuantizationConfig, targets: Union[List[str], str] = KV_CACHE_TARGETS
) -> QuantizationConfig:
    """
    Reformulate ``config.kv_cache_scheme`` as a config group and register it in
    ``config.config_groups`` under the key ``"kv_cache"``.

    The kv cache arguments become the ``output_activations`` of the new scheme.
    ``config`` is modified in place and also returned.

    :param config: the QuantizationConfig; its ``kv_cache_scheme`` must be set
    :param targets: targets the kv cache scheme applies to, defaults to
        ``KV_CACHE_TARGETS``
    :return: the QuantizationConfig with additional "kv_cache" group
    """
    using_default_targets = targets == KV_CACHE_TARGETS
    if using_default_targets:
        _LOGGER.info(f"KV cache targets set to default value of: {KV_CACHE_TARGETS}")

    output_args = QuantizationArgs(**config.kv_cache_scheme.model_dump())
    kv_scheme = QuantizationScheme(
        output_activations=output_args,
        targets=targets,
    )
    config.config_groups.update({"kv_cache": kv_scheme})
    return config
def expand_sparse_target_names(
    model: Module, targets: Iterable[str], ignore: Iterable[str]
) -> Set[str]:
    """
    Collects the names of all leaf modules of ``model`` that are selected by
    ``targets`` and not excluded by ``ignore``.

    Note: Targets must be regexes, layer types, or full layer names.

    :param model: model to search for targets in
    :param targets: list of targets to search for
    :param ignore: list of targets to ignore
    :return: set of all module names that match the given targets and should
        not be ignored
    """
    selected = set()
    for module_name, leaf in iter_named_leaf_modules(model):
        if is_sparse_target(module_name, leaf, targets, ignore):
            selected.add(module_name)
    return selected
+ + :param name: name of the module + :param module: the module itself + :param targets: list of targets to search for + :param ignore: list of targets to ignore + :return: True if the module is a target and not ignored, False otherwise + """ + return bool( + find_name_or_class_matches(name, module, targets) + and not find_name_or_class_matches(name, module, ignore or []) + ) + + +def find_name_or_class_matches( + name: str, module: Module, targets: Iterable[str], check_contains: bool = False +) -> List[str]: + """ + Returns all targets that match the given name or the class name. + Returns empty list otherwise. + The order of the output `matches` list matters. + The entries are sorted in the following order: + 1. matches on exact strings + 2. matches on regex patterns + 3. matches on module names + """ + targets = sorted(targets, key=lambda x: ("re:" in x, x)) + if isinstance(targets, Iterable): + matches = _find_matches(name, targets) + _find_matches( + module.__class__.__name__, targets, check_contains + ) + matches = [match for match in matches if match is not None] + return matches + + +def _find_matches( + value: str, targets: Iterable[str], check_contains: bool = False +) -> List[str]: + # returns all the targets that match value either + # exactly or as a regex after 're:'. if check_contains is set to True, + # additionally checks if the target string is contained with value. 
+ matches = [] + for target in targets: + if target.startswith("re:"): + pattern = target[3:] + if re.match(pattern, value): + matches.append(target) + elif check_contains: + if target.lower() in value.lower(): + matches.append(target) + elif target == value: + matches.append(target) + return matches + + +def _infer_status(model: Module) -> Optional[QuantizationStatus]: + for module in model.modules(): + status = getattr(module, "quantization_status", None) + if status is not None: + return status + return None + + +def _load_quant_args_from_state_dict( + base_name: str, module_name: str, module: Module, state_dict: Dict +): + """ + Loads scale and zero point from a state_dict into the specified module + + :param base_name: quantization target, one of: weights, input_activations or + output_activations + :param module_name: pytorch module name to look up in state_dict + :module: pytorch module associated with module_name + :state_dict: state_dict to search for matching quantization parameters + """ + scale_name = f"{base_name}_scale" + zp_name = f"{base_name}_zero_point" + g_idx_name = f"{base_name}_g_idx" + + state_dict_scale = state_dict.get(f"{module_name}.{scale_name}", None) + state_dict_zp = state_dict.get(f"{module_name}.{zp_name}", None) + state_dict_g_idx = state_dict.get(f"{module_name}.{g_idx_name}", None) + + if state_dict_scale is not None: + # module is quantized + update_parameter_data(module, state_dict_scale, scale_name) + if state_dict_zp is None: + # fill in zero point for symmetric quantization + state_dict_zp = torch.zeros_like(state_dict_scale, device="cpu") + update_parameter_data(module, state_dict_zp, zp_name) + + if state_dict_g_idx is not None: + update_parameter_data(module, state_dict_g_idx, g_idx_name) + + +def _scheme_from_targets( + target_to_scheme: OrderedDictType[str, QuantizationScheme], + targets: List[str], + name: str, +) -> QuantizationScheme: + if len(targets) == 1: + # if `targets` iterable contains a single element + # use 
def _merge_schemes(
    schemes_to_merge: List[QuantizationScheme], name: str
) -> QuantizationScheme:
    """
    Combines the schemes that target the module ``name`` into a single scheme.

    When none of the schemes is a kv cache scheme, the first scheme is returned
    unchanged (the order of ``schemes_to_merge`` encodes priority). Otherwise
    the first kv cache scheme and the first non-kv-cache scheme are merged
    field by field into a new scheme whose only target is ``name``.

    :param schemes_to_merge: schemes matching the module, in priority order
    :param name: name of the module the schemes apply to
    :return: the selected or merged QuantizationScheme
    :raises ValueError: if both merged schemes set the same (non-None) field
    """
    kv_cache_schemes = []
    regular_schemes = []
    for candidate in schemes_to_merge:
        if is_kv_cache_quant_scheme(candidate):
            kv_cache_schemes.append(candidate)
        else:
            regular_schemes.append(candidate)

    if not kv_cache_schemes:
        # no kv cache scheme involved: the highest priority scheme wins
        return schemes_to_merge[0]

    # NOTE(review): indexing regular_schemes raises IndexError when every
    # scheme is a kv cache scheme
    pair_to_merge = (kv_cache_schemes[0], regular_schemes[0])

    merged_scheme = {}
    for scheme in pair_to_merge:
        scheme_dict = {}
        for key, value in scheme.model_dump().items():
            if value is not None:
                scheme_dict[key] = value
        # the merged scheme targets `name` only, so drop the original targets
        del scheme_dict["targets"]

        # the two schemes must not configure the same property
        overlapping_keys = set(merged_scheme.keys()) & set(scheme_dict.keys())
        if overlapping_keys:
            raise ValueError(
                f"The module: {name} is being modified by two clashing "
                f"quantization schemes, that jointly try to override "
                f"properties: {overlapping_keys}. Fix the quantization config "
                "so that it is not ambiguous."
            )
        merged_scheme.update(scheme_dict)

    merged_scheme.update(targets=[name])

    return QuantizationScheme(**merged_scheme)
def compress_quantized_weights(module: Module):
    """
    Quantizes the module weight representation to use fewer bits in memory

    apply to full model with `model.apply(compress_quantized_weights)`

    Modules without a quantization scheme, without a weight scheme, or that are
    already marked as compressed are left untouched. Otherwise the weight is
    replaced in place by its int8 quantized representation, gradients are
    disabled on it, and ``module.quantization_status`` is set to COMPRESSED.

    :param module: module to compress to quantized representation
    """
    scheme = getattr(module, "quantization_scheme", None)
    if not scheme or not scheme.weights:
        # no quantization scheme or weights not quantized, nothing to do
        return

    # The "already compressed" guard must look at the module's status; the
    # scheme object never equals a QuantizationStatus member, so comparing the
    # scheme let a second call re-quantize weights that were already quantized.
    status = getattr(module, "quantization_status", None)
    if status == QuantizationStatus.COMPRESSED:
        # module is already compressed, nothing to do
        return

    weight = getattr(module, "weight", None)
    scale = getattr(module, "weight_scale", None)
    zero_point = getattr(module, "weight_zero_point", None)
    g_idx = getattr(module, "weight_g_idx", None)

    if weight is None or scale is None:
        # no weight or no scale, nothing to quantize

        # mark as compressed here to maintain consistent status throughout the model
        module.quantization_status = QuantizationStatus.COMPRESSED
        return

    module.weight.requires_grad = False  # cannot use auto grad after compression
    module.weight.data = quantize(
        x=weight,
        scale=scale,
        zero_point=zero_point,
        g_idx=g_idx,
        args=scheme.weights,
        dtype=torch.int8,
    )

    module.quantization_status = QuantizationStatus.COMPRESSED
@torch.no_grad()
def quantize(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    args: QuantizationArgs,
    dtype: Optional[torch.dtype] = None,
    g_idx: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Quantize the input tensor x using the QuantizationStrategy specified in args.
    Quantization can be done per tensor, channel, token or group. For group
    quantization, the number of columns of x must be divisible by the
    group_size. The input scale and zero_points are reshaped to support
    vectorization (Assumes 1 is the channel dimension)

    :param x: Input tensor
    :param scale: scale tensor
    :param zero_point: zero point tensor
    :param args: quantization args dictating how to quantize x
    :param dtype: optional dtype to cast the quantized output to
    :param g_idx: optional mapping from column index to group index
    :return: quantized tensor (not dequantized), cast to dtype when given
    """

    return _process_quantization(
        x=x,
        scale=scale,
        zero_point=zero_point,
        args=args,
        dtype=dtype,
        do_quantize=True,
        do_dequantize=False,
        g_idx=g_idx,
    )


@torch.no_grad()
def dequantize(
    x_q: torch.Tensor,
    scale: torch.Tensor,
    zero_point: Optional[torch.Tensor] = None,
    args: Optional[QuantizationArgs] = None,
    dtype: Optional[torch.dtype] = None,
    g_idx: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Dequantize a quantized input tensor x_q based on the strategy specified in args. If
    args is not provided, the strategy is inferred from the shape of scale:
    a 0-d or 1-d scale means per-tensor, a 2-d scale with a single column means
    per-channel, and any other 2-d scale means group quantization with
    group_size = x_q.shape[1] // scale.shape[1].

    :param x_q: quantized input tensor
    :param scale: scale tensor
    :param zero_point: zero point tensor
    :param args: quantization args used to quantize x_q
    :param dtype: optional dtype to cast the dequantized output to, defaults
        to the dtype of scale
    :param g_idx: optional mapping from column index to group index
    :return: dequantized float tensor
    :raises ValueError: if args is None and scale has more than 2 dimensions
    """
    if args is None:
        if scale.ndim == 0 or scale.ndim == 1:
            args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR)
        elif scale.ndim == 2:
            if scale.shape[1] == 1:
                args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
            else:
                # integer division: shapes are ints, avoid a float round trip
                group_size = x_q.shape[1] // scale.shape[1]
                args = QuantizationArgs(
                    strategy=QuantizationStrategy.GROUP, group_size=group_size
                )
        else:
            # message lists every rank accepted by the branches above
            raise ValueError(
                f"Could not infer a quantization strategy from scale with {scale.ndim} "
                "dimensions. Expected 0, 1 or 2 dimensions."
            )

    if dtype is None:
        dtype = scale.dtype

    return _process_quantization(
        x=x_q,
        scale=scale,
        zero_point=zero_point,
        args=args,
        do_quantize=False,
        do_dequantize=True,
        dtype=dtype,
        g_idx=g_idx,
    )


@torch.no_grad()
def fake_quantize(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    args: QuantizationArgs,
    g_idx: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Fake quantize the input tensor x by quantizing then dequantizing with
    the QuantizationStrategy specified in args. Quantization can be done per tensor,
    channel, token or group. For group quantization, the number of columns of x
    must be divisible by the group_size. The input scale and zero_points are
    reshaped to support vectorization (Assumes 1 is the channel dimension)

    :param x: Input tensor
    :param scale: scale tensor
    :param zero_point: zero point tensor
    :param args: quantization args dictating how to quantize x
    :param g_idx: optional mapping from column index to group index
    :return: fake quantized tensor
    """
    return _process_quantization(
        x=x,
        scale=scale,
        zero_point=zero_point,
        args=args,
        do_quantize=True,
        do_dequantize=True,
        g_idx=g_idx,
    )
def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
    """
    Replaces ``module.forward`` with a wrapper that applies the quantization
    described by ``scheme`` around the original forward call.

    Expects a module already initialized and injected with the parameters in
    initialize_module_for_quantization.

    On each call the wrapper:
      * calls the original forward untouched if ``module.quantization_enabled``
        is set to a falsy value;
      * otherwise quantizes the first positional input when
        ``scheme.input_activations`` is set;
      * temporarily swaps the weight for its (fake) quantized version when
        ``scheme.weights`` is set and the module is not in COMPRESSED status;
      * quantizes the output when ``scheme.output_activations`` is set, except
        during CALIBRATION with non-dynamic output activations.

    :param module: module whose forward is wrapped in place
    :param scheme: quantization scheme applied by the wrapper
    """
    if hasattr(module.forward, "__func__"):
        # regular bound method: unwrap to the underlying function
        forward_func_orig = module.forward.__func__
    else:
        # NOTE(review): presumably forward was replaced by a functools.partial
        # (which exposes `.func`) - confirm against callers
        forward_func_orig = module.forward.func

    @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
    def wrapped_forward(self, *args, **kwargs):
        if not getattr(module, "quantization_enabled", True):
            # quantization is disabled on forward passes, return baseline
            # forward call
            return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)

        # the input to quantize must be passed positionally
        input_ = args[0]

        compressed = module.quantization_status == QuantizationStatus.COMPRESSED

        if scheme.input_activations is not None:
            # prehook should calibrate activations before forward call
            input_ = forward_quantize(module, input_, "input", scheme.input_activations)

        swap_weight = scheme.weights is not None and not compressed
        if swap_weight:
            # calibrate and (fake) quantize weights when applicable
            unquantized_weight = self.weight.data.clone()
            self.weight.data = forward_quantize(
                module, self.weight, "weight", scheme.weights
            )

        try:
            # perform wrapped forward call
            output = forward_func_orig.__get__(module, module.__class__)(
                input_, *args[1:], **kwargs
            )
        finally:
            # restore back to unquantized_value, even if the forward call
            # raised - otherwise the module would keep the fake quantized
            # weight and quantize it again on the next call
            if swap_weight:
                self.weight.data = unquantized_weight

        if scheme.output_activations is not None:
            # forward-hook should calibrate/forward_quantize
            if (
                module.quantization_status == QuantizationStatus.CALIBRATION
                and not scheme.output_activations.dynamic
            ):
                return output

            output = forward_quantize(
                module, output, "output", scheme.output_activations
            )
        return output

    # bind wrapped forward to module class so reference to `self` is correct
    bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__)
    # set forward to wrapped forward
    setattr(module, "forward", bound_wrapped_forward)
@torch.no_grad()
def _dequantize(
    x_q: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor = None,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """
    Map quantized values back to real values: (x_q - zero_point) * scale.

    :param x_q: quantized tensor
    :param scale: scale tensor; its dtype is used for the arithmetic
    :param zero_point: optional zero point tensor, skipped when None
    :param dtype: optional dtype to cast the result to; when None the result keeps
        the dtype produced by the multiplication with scale
    :return: dequantized tensor
    """
    compute_dtype = scale.dtype
    shifted = x_q.to(compute_dtype)

    if zero_point is not None:
        shifted = shifted - zero_point.to(compute_dtype)

    restored = shifted * scale
    return restored if dtype is None else restored.to(dtype)
def enable_quantization(module: Module):
    """
    Switch quantization on for a module by setting its ``quantization_enabled``
    attribute to True.

    :param module: module to enable quantization for
    """
    setattr(module, "quantization_enabled", True)
class KVCacheScaleType(Enum):
    """
    Names of the key and value scale parameters attached to attention modules.

    The member values are the parameter names passed to
    ``module.register_parameter`` in ``_initialize_attn_scales``.
    """

    # parameter name for the key scale
    KEY = "k_scale"
    # parameter name for the value scale
    VALUE = "v_scale"
def is_attention_module(module: Module):
    """
    Check whether a module looks like an attention block.

    A module qualifies when its class name contains "attention" (case
    insensitive) and it exposes at least one of the attributes ``k_proj``,
    ``v_proj`` or ``qkv_proj``.

    :param module: module to inspect
    :return: True if both conditions hold, False otherwise
    """
    class_name = module.__class__.__name__.lower()
    if "attention" not in class_name:
        return False

    projection_attrs = ("k_proj", "v_proj", "qkv_proj")
    return any(hasattr(module, attr) for attr in projection_attrs)
Optional[torch.Size] = None, + force_zero_point: bool = True, +): + if quantization_args.dynamic: + return + + # begin on the same device as other parameters or cpu if offloaded. + # in the offloaded case, there's no point moving tensors to the execution device + # if they're going to be immediately offloaded by `register_offload_parameter` + params_device = next(module.parameters()).device + device = "cpu" if has_offloaded_params(module) else params_device + + # infer expected scale/zero point shape + if quantization_args.strategy == QuantizationStrategy.TOKEN: + expected_shape = (1, 1) + else: + expected_shape = 1 + + if base_name == "weight" and weight_shape is not None: + if quantization_args.strategy == QuantizationStrategy.CHANNEL: + # (output_channels, 1) + expected_shape = (weight_shape[0], 1) + elif quantization_args.strategy == QuantizationStrategy.GROUP: + num_groups = weight_shape[1] // quantization_args.group_size + expected_shape = (weight_shape[0], max(num_groups, 1)) + + scale_dtype = module.weight.dtype + if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]: + scale_dtype = torch.float16 + + # initializes empty scale, zero point, and g_idx parameters for the module + init_scale = Parameter( + torch.empty(expected_shape, dtype=scale_dtype, device=device), + requires_grad=False, + ) + register_offload_parameter(module, f"{base_name}_scale", init_scale) + + if force_zero_point or not quantization_args.symmetric: + zp_dtype = quantization_args.pytorch_dtype() + init_zero_point = Parameter( + torch.zeros(expected_shape, device=device, dtype=zp_dtype), + requires_grad=False, + ) + register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point) + + # only grouped activation ordering has g_idx + if quantization_args.actorder == ActivationOrdering.GROUP: + g_idx_shape = (weight_shape[1],) + g_idx_dtype = torch.int + init_g_idx = Parameter( + torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype), + requires_grad=False, 
def _initialize_attn_scales(module: Module) -> None:
    """
    Initialize k_scale and v_scale for an attention module.

    Registers one uninitialized, non-trainable single-element parameter per
    KVCacheScaleType member (key first, then value), using the dtype and device
    of the module's first parameter.

    :param module: attention module to attach the scales to
    """
    reference = next(module.parameters())
    scale_dtype, device = reference.dtype, reference.device

    for scale_type in (KVCacheScaleType.KEY, KVCacheScaleType.VALUE):
        # per tensor scale: a single element, contents are not initialized here
        placeholder = Parameter(
            torch.empty(1, dtype=scale_dtype, device=device),
            requires_grad=False,
        )
        module.register_parameter(scale_type.value, placeholder)
class QuantizationArgs(BaseModel, use_enum_values=True):
    """
    User facing arguments used to define a quantization config for weights or
    activations

    :param num_bits: quantization bit depth
    :param type: dtype to quantized to, either int or float
    :param symmetric: whether or not quantization scale is symmetric about zero-point
    :param strategy: string id determining the scope of scale/zero-point to apply
    :param group_size: group length to use for the group strategy
    :param block_structure: 2d block structure to use for the block strategy, must be
        of the format "2x4", "8x16", etc.
    :param dynamic: set True to perform dynamic quantization - values will not be
        calibrated during calibration phase, instead during inference new quantization
        ranges will be observed with every sample. Defaults to False for static
        quantization. Note that enabling dynamic quantization will change the default
        observer to a memoryless one
    :param actorder: whether to apply group quantization in decreasing order of
        activation. Defaults to None for arbitrary ordering
    :param observer: name of the observer used to compute scale and zero-point;
        reset to None by validation when dynamic is True
    :param observer_kwargs: optional kwargs forwarded to the observer constructor
    """

    num_bits: int = 8
    type: QuantizationType = QuantizationType.INT
    symmetric: bool = True
    group_size: Optional[int] = None
    strategy: Optional[QuantizationStrategy] = None
    block_structure: Optional[str] = None
    dynamic: bool = False
    actorder: Union[ActivationOrdering, bool, None] = None
    observer: Optional[str] = Field(
        default="minmax",
        description=(
            "The class to use to compute the quantization param - "
            "scale and zero-point'"
        ),
    )
    observer_kwargs: Dict[str, Any] = Field(
        default_factory=dict,
        description=(
            "optional dict of kwargs to be passed directly to torch quantization "
            "Observers constructor excluding quantization range or symmetry"
        ),
    )

    def get_observer(self):
        """
        :return: the value of the ``observer`` field (an observer name or None)
        """
        return self.observer

    @field_validator("type", mode="before")
    def validate_type(cls, value) -> QuantizationType:
        # accept case-insensitive strings such as "INT" / "Float"
        if isinstance(value, str):
            return QuantizationType(value.lower())

        return value

    @field_validator("group_size", mode="before")
    def validate_group(cls, value) -> Union[int, None]:
        if value is None:
            return value

        # -1 is the sentinel for channel-wise quantization, anything lower is invalid
        if value < -1:
            raise ValueError(
                f"Invalid group size {value}. Use group_size > 0 for "
                "strategy='group' and group_size = -1 for 'channel'"
            )

        return value

    @field_validator("strategy", mode="before")
    def validate_strategy(cls, value) -> Union[QuantizationStrategy, None]:
        # accept case-insensitive strings
        if isinstance(value, str):
            return QuantizationStrategy(value.lower())

        return value

    @field_validator("actorder", mode="before")
    def validate_actorder(cls, value) -> Optional[ActivationOrdering]:
        # booleans are shorthand: True selects group ordering, False disables it
        if isinstance(value, bool):
            return ActivationOrdering.GROUP if value else None

        if isinstance(value, str):
            return ActivationOrdering(value.lower())

        return value

    @model_validator(mode="after")
    def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
        """
        Cross-field validation: infers ``strategy`` from ``group_size`` when it is
        not given, checks that strategy, group_size, actorder and dynamic are
        compatible, and settles the final ``observer`` value.

        :return: the validated model with ``strategy`` and ``observer`` written back
        :raises ValueError: if the field combination is inconsistent
        """
        # extract user-passed values from dictionary
        strategy = model.strategy
        group_size = model.group_size
        actorder = model.actorder
        dynamic = model.dynamic
        observer = model.observer

        # infer strategy
        if strategy is None:
            if group_size is None:
                strategy = QuantizationStrategy.TENSOR
            elif group_size > 0:
                strategy = QuantizationStrategy.GROUP
            elif group_size == -1:
                strategy = QuantizationStrategy.CHANNEL
            else:
                raise ValueError(
                    f"Invalid group size {group_size}. Use group_size > 0 for "
                    "strategy='group' and group_size = -1 for 'channel'"
                )

        # validate strategy and group
        if strategy == QuantizationStrategy.GROUP:
            if group_size is None or group_size <= 0:
                raise ValueError(
                    f"strategy {strategy} requires group_size to be "
                    "set to a positive value"
                )
        if (
            group_size is not None
            and group_size > 0
            and strategy != QuantizationStrategy.GROUP
        ):
            raise ValueError("group_size requires strategy to be set to 'group'")

        # validate activation ordering and strategy
        if actorder is not None and strategy != QuantizationStrategy.GROUP:
            raise ValueError(
                "Must use group quantization strategy in order to apply "
                "activation ordering"
            )

        # infer observer w.r.t. dynamic
        if dynamic:
            if strategy not in (
                QuantizationStrategy.TOKEN,
                QuantizationStrategy.TENSOR,
            ):
                # the message parts are concatenated into ONE string; a comma
                # between them would turn the exception message into a tuple
                raise ValueError(
                    f"One of {QuantizationStrategy.TOKEN} or "
                    f"{QuantizationStrategy.TENSOR} must be used for dynamic "
                    "quantization"
                )
            if observer is not None:
                if observer != "memoryless":  # avoid annoying users with old configs
                    warnings.warn(
                        "No observer is used for dynamic quantization, setting to None"
                    )
                observer = None

        elif observer is None:
            # default to minmax for non-dynamic cases
            observer = "minmax"

        # write back modified values
        model.strategy = strategy
        model.observer = observer
        return model

    def pytorch_dtype(self) -> torch.dtype:
        """
        :return: torch dtype able to hold the quantized values: FP8_DTYPE for float
            quantization, otherwise the smallest of int8/int16/int32 that fits
            ``num_bits``
        :raises ValueError: if ``type`` is not a known QuantizationType
        """
        if self.type == QuantizationType.FLOAT:
            return FP8_DTYPE
        elif self.type == QuantizationType.INT:
            if self.num_bits <= 8:
                return torch.int8
            elif self.num_bits <= 16:
                return torch.int16
            else:
                return torch.int32
        else:
            raise ValueError(f"Invalid quantization type {self.type}")
class QuantizationStatus(str, Enum):
    """
    Enum storing the different states a quantized layer can be in

    Initialized: scale, zero points and observers have been attached to the layer but
        are set to dummy values (not yet calibrated)
    Calibration: scale and zero points have been calibrated through OBCQ or similar
        algorithm, observers are still attached
    Frozen: scale and zero points are finalized, observers have been deleted, weights
        are still in their original precision
    Compressed: weights have been converted to their target type or compressed to
        their closest approximation

    Ordering comparisons (<, <=, >, >=) follow the position in LIFECYCLE_ORDER rather
    than the string value. Comparing against None treats None as "before" every
    status; comparing against any other type raises NotImplementedError.
    """

    INITIALIZED = "initialized"
    CALIBRATION = "calibration"
    FROZEN = "frozen"
    COMPRESSED = "compressed"

    @classmethod
    def lifecycle_order(cls) -> List["QuantizationStatus"]:
        """
        :return: list of correct quantization lifecycle order
        """
        # return a copy so callers cannot mutate the module level constant
        return list(LIFECYCLE_ORDER)

    def __ge__(self, other):
        if other is None:
            return True
        if not isinstance(other, self.__class__):
            raise NotImplementedError
        return LIFECYCLE_ORDER.index(self) >= LIFECYCLE_ORDER.index(other)

    def __gt__(self, other):
        if other is None:
            return True
        if not isinstance(other, self.__class__):
            raise NotImplementedError
        return LIFECYCLE_ORDER.index(self) > LIFECYCLE_ORDER.index(other)

    def __lt__(self, other):
        if other is None:
            return False
        if not isinstance(other, self.__class__):
            raise NotImplementedError
        return LIFECYCLE_ORDER.index(self) < LIFECYCLE_ORDER.index(other)

    def __le__(self, other):
        if other is None:
            return False
        if not isinstance(other, self.__class__):
            raise NotImplementedError
        return LIFECYCLE_ORDER.index(self) <= LIFECYCLE_ORDER.index(other)


# order in which a quantized layer moves through its statuses; used by the
# comparison operators and lifecycle_order() of QuantizationStatus
LIFECYCLE_ORDER = [
    QuantizationStatus.INITIALIZED,
    QuantizationStatus.CALIBRATION,
    QuantizationStatus.FROZEN,
    QuantizationStatus.COMPRESSED,
]
A group could also be a reference to + a predefined scheme name, mapped to a list of its target layers/classes + :param quant_method: a constant used to differentiate sparseML quantization from + other quantization configs + :param format: specifies how the quantized model is stored on disk + :param quantization_status: specifies the current status of all quantized layers. It is + assumed all layers are in the same state. + :param kv_cache_scheme: optional QuantizationArgs, that specify the + quantization of the kv cache. If None, kv cache is not quantized. + When applying kv cache quantization to transformer AutoModelForCausalLM, + the kv_cache_scheme gets converted into a QuantizationScheme that: + - targets the `k_proj` and `v_proj` modules of the model. The outputs + of those modules are the keys and values that might be cached + - quantizes the outputs of the aforementioned layers, so that + keys and values are compressed before storing them in the cache + There is an explicit assumption that the model contains modules with + `k_proj` and `v_proj` in their names. If this is not the case + and kv_cache_scheme != None, the quantization of kv cache will fail + :param global_compression_ratio: optional informational config to report the model + compression ratio achieved by the quantization config + :param ignore: optional list of layers to ignore from config_groups.
Layers in this list + are not quantized even if they match up with a target in config_groups + """ + + config_groups: Dict[str, Union[QuantizationScheme, List[str]]] + quant_method: str = DEFAULT_QUANTIZATION_METHOD + kv_cache_scheme: Optional[QuantizationArgs] = None + format: str = DEFAULT_QUANTIZATION_FORMAT + quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED + global_compression_ratio: Optional[float] = None + ignore: Optional[List[str]] = Field(default_factory=list) + + def model_post_init(self, __context): + """ + updates any quantization schemes defined as presets to be fully loaded + schemes + """ + for group_name, targets_or_scheme in self.config_groups.items(): + if isinstance(targets_or_scheme, QuantizationScheme): + continue # scheme already defined + self.config_groups[group_name] = preset_name_to_scheme( + name=group_name, + targets=targets_or_scheme, + ) + + def to_dict(self): + # for compatibility with HFQuantizer + return self.model_dump() + + @staticmethod + def from_pretrained( + model: Module, format: Optional[str] = None + ) -> Optional["QuantizationConfig"]: + """ + Converts a model into its associated QuantizationConfig based on the + QuantizationScheme attached to each quantized module + + :param model: model to calculate quantization scheme of + :return: filled out QuantizationScheme for the input model + """ + quant_scheme_to_layers = [] + quantization_status = None + ignore = {} + quantization_type_names = set() + for name, submodule in iter_named_quantizable_modules( + model, include_children=True, include_attn=True + ): + layer_type = module_type(submodule) + if not is_module_quantized(submodule): + if layer_type not in ignore: + ignore[layer_type] = [] + ignore[layer_type].append(name) + else: + quantization_status = submodule.quantization_status + scheme = submodule.quantization_scheme + quantization_type_names.add(layer_type) + + match_found = False + for existing_scheme in quant_scheme_to_layers: + if scheme 
== existing_scheme: + match_found = True + break + if not match_found: + quant_scheme_to_layers.append(scheme) + + if len(quant_scheme_to_layers) == 0: # No quantized layers + return None + + # kv-cache only, no weight/activation quantization + if ( + len(quantization_type_names) == 1 + and "attention" in list(quantization_type_names)[0].lower() + ): + quantization_type_names.add("Linear") + + # clean up ignore list, we can leave out layers types if none of the + # instances are quantized + consolidated_ignore = [] + for layer_type, ignore_names in ignore.items(): + if layer_type in quantization_type_names: + # specific layers of a quantized type are ignored + consolidated_ignore += ignore_names + # else we leave it off the ignore list, doesn't fall under any of the + # existing quantization schemes so it won't be quantized + + kv_cache_args, quant_scheme_to_layers = parse_out_kv_cache_args( + quant_scheme_to_layers + ) + kv_cache_scheme = ( + kv_cache_args.model_dump() if kv_cache_args is not None else kv_cache_args + ) + + config_groups = {} + for idx, scheme in enumerate(quant_scheme_to_layers): + group_name = "group_" + str(idx) + config_groups[group_name] = scheme + + # TODO: this is incorrect in compressed mode, since we are overwriting the + # original weight we lose the uncompressed bit_depth indo + compression_ratio = calculate_compression_ratio(model) + + if format is None: + if quantization_status == QuantizationStatus.COMPRESSED: + format = CompressionFormat.int_quantized.value + else: + format = CompressionFormat.dense.value + + return QuantizationConfig( + config_groups=config_groups, + quantization_status=quantization_status, + kv_cache_scheme=kv_cache_scheme, + global_compression_ratio=compression_ratio, + format=format, + ignore=consolidated_ignore, + ) + + def requires_calibration_data(self): + if self.kv_cache_scheme is not None: + return True + + for _, scheme in self.config_groups.items(): + if scheme.input_activations is not None: + if not 
scheme.input_activations.dynamic: + return True + if scheme.output_activations is not None: + if not scheme.output_activations.dynamic: + return True + + return False diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/quant_scheme.py b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/quant_scheme.py new file mode 100644 index 0000000000000000000000000000000000000000..36b886044565ad3b36adb47a8c7ab5a9c081adb1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/quant_scheme.py @@ -0,0 +1,215 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class QuantizationScheme(BaseModel):
    """
    Set of QuantizationArgs defining how the weights, inputs and outputs of target list
    of modules should be quantized

    :param targets: list of modules to apply the QuantizationArgs to, can be layer
        names, layer types or a regular expression, typically ["Linear"]
    :param weights: quantization config for layer weights
    :param input_activations: quantization config for layer inputs
    :param output_activations: quantization config for layer outputs
    """

    targets: List[str]
    weights: Optional[QuantizationArgs] = None
    input_activations: Optional[QuantizationArgs] = None
    output_activations: Optional[QuantizationArgs] = None

    @model_validator(mode="after")
    def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
        """
        Reject schemes that request activation ordering on activations.

        :return: the validated scheme, unchanged
        :raises ValueError: if input or output activation args set ``actorder``
        """
        # inputs are checked before outputs, matching the order of the messages
        activation_args = (
            ("input", model.input_activations),
            ("output", model.output_activations),
        )
        for label, args in activation_args:
            if args is not None and args.actorder is not None:
                raise ValueError(f"Cannot apply actorder to {label} activations")

        return model
must exist in upper case in + PRESET_SCHEMES + :param targets: list of quantization targets to be passed to the Scheme + :return: new QuantizationScheme for a given name with the given targets + """ + name = name.upper() + + if name not in PRESET_SCHEMES: + raise KeyError( + f"Unknown preset scheme name {name}, " + f"available names: {list(PRESET_SCHEMES.keys())}" + ) + + scheme_args = deepcopy(PRESET_SCHEMES[name]) # deepcopy to avoid args references + return QuantizationScheme( + targets=targets, + **scheme_args, + ) + + +def is_preset_scheme(name: str) -> bool: + """ + :param name: preset quantization settings name + :return: True if the name is a preset scheme name + """ + return name.upper() in PRESET_SCHEMES + + +UNQUANTIZED = dict() + +# 8 bit integer weights and 8 bit activations quantization +INT8_W8A8 = dict( + weights=QuantizationArgs( + num_bits=8, + type=QuantizationType.INT, + strategy=QuantizationStrategy.CHANNEL, + symmetric=True, + dynamic=False, + ), + input_activations=QuantizationArgs( + num_bits=8, + type=QuantizationType.INT, + strategy=QuantizationStrategy.TOKEN, + symmetric=True, + dynamic=True, + observer=None, + ), +) + +# 8 bit integer weights only quantization +W8A16 = dict( + weights=QuantizationArgs( + num_bits=8, + type=QuantizationType.INT, + strategy=QuantizationStrategy.CHANNEL, + symmetric=True, + dynamic=False, + ), +) + +# 4 bit integer weights only quantization +W4A16 = dict( + weights=QuantizationArgs( + num_bits=4, + type=QuantizationType.INT, + strategy=QuantizationStrategy.GROUP, + group_size=128, + symmetric=True, + dynamic=False, + ), +) + +# 4 bit integer weights and 8 bit activations quantization +INT8_W4A8 = dict( + weights=QuantizationArgs( + num_bits=4, + type=QuantizationType.INT, + group_size=128, + strategy=QuantizationStrategy.GROUP, + symmetric=True, + dynamic=False, + ), + input_activations=QuantizationArgs( + num_bits=8, + type=QuantizationType.INT, + strategy=QuantizationStrategy.TOKEN, + symmetric=True, + 
dynamic=True, + observer=None, + ), +) + +# FP8 weights and FP8 activations quantization +FP8 = dict( + weights=QuantizationArgs( + num_bits=8, + type=QuantizationType.FLOAT, + strategy=QuantizationStrategy.TENSOR, + symmetric=True, + dynamic=False, + ), + input_activations=QuantizationArgs( + num_bits=8, + type=QuantizationType.FLOAT, + strategy=QuantizationStrategy.TENSOR, + symmetric=True, + dynamic=False, + ), +) + +# FP8 weights and FP8 dynamic activations quantization +FP8_DYNAMIC = dict( + weights=QuantizationArgs( + num_bits=8, + type=QuantizationType.FLOAT, + strategy=QuantizationStrategy.CHANNEL, + symmetric=True, + dynamic=False, + ), + input_activations=QuantizationArgs( + num_bits=8, + type=QuantizationType.FLOAT, + strategy=QuantizationStrategy.TOKEN, + symmetric=True, + dynamic=True, + observer=None, + ), +) + +PRESET_SCHEMES = { + # Unquantized (no-op) + "UNQUANTIZED": UNQUANTIZED, + # Integer weight only schemes + "W8A16": W8A16, + "W4A16": W4A16, + # Integer weight and activation schemes + "W8A8": INT8_W8A8, + "INT8": INT8_W8A8, # alias for W8A8 + "W4A8": INT8_W4A8, + # Float weight and activation schemes + "FP8": FP8, + "FP8_DYNAMIC": FP8_DYNAMIC, +} diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/utils/__init__.py b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a91f9e5d2cd34e36262d83ed5772724fb4181781 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/utils/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa +from .helpers import * diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/utils/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51e0ba9a65e2e345c5d4a9a3a62eb4aabe89de66 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/utils/__pycache__/helpers.cpython-311.pyc b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/utils/__pycache__/helpers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02ac4e874b43752d9dcebd5da8a2d990e5cac108 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/utils/__pycache__/helpers.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/utils/helpers.py b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/utils/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..9f65ee33030fca305a178724b7cb0744b5700520 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/compressed_tensors/quantization/utils/helpers.py @@ -0,0 +1,392 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Generator, List, Optional, Tuple

import torch
from compressed_tensors.quantization.quant_args import (
    FP8_DTYPE,
    QuantizationArgs,
    QuantizationStrategy,
    QuantizationType,
)
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from torch import FloatTensor, IntTensor, Tensor
from torch.nn import Module
from tqdm import tqdm


__all__ = [
    "infer_quantization_status",
    "is_module_quantized",
    "is_model_quantized",
    "module_type",
    "calculate_compression_ratio",
    "get_torch_bit_depth",
    "can_quantize",
    "parse_out_kv_cache_args",
    "KV_CACHE_TARGETS",
    "is_kv_cache_quant_scheme",
    "iter_named_leaf_modules",
    "iter_named_quantizable_modules",
    "compute_dynamic_scales_and_zp",
    "calculate_range",
    "calculate_qparams",
]

# target the self_attn layer
# QuantizedKVParameterCache is responsible for obtaining the k_scale and v_scale
KV_CACHE_TARGETS = ["re:.*self_attn$"]

_LOGGER: logging.Logger = logging.getLogger(__name__)


def calculate_qparams(
    min_vals: Tensor, max_vals: Tensor, quantization_args: QuantizationArgs
) -> Tuple[FloatTensor, IntTensor]:
    """
    :param min_vals: tensor of min value(s) to calculate scale(s) and zero point(s)
        from
    :param max_vals: tensor of max value(s) to calculate scale(s) and zero point(s)
        from
    :param quantization_args: settings to quantization
    :return: tuple of the calculated scale(s) and zero point(s); 0-dim results
        are reshaped to shape (1,)
    """
    # force the observed range to contain zero
    min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
    max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
    device = min_vals.device

    bit_min, bit_max = calculate_range(quantization_args, device)
    bit_range = bit_max - bit_min
    zp_dtype = quantization_args.pytorch_dtype()

    # lower bound keeps scales strictly positive so they can be divided by
    min_scale = torch.finfo(torch.float32).eps

    if quantization_args.symmetric:
        max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
        scales = max_val_pos / (float(bit_range) / 2)
        scales = torch.clamp(scales, min=min_scale)
        zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
    else:
        scales = (max_vals - min_vals) / float(bit_range)
        scales = torch.clamp(scales, min=min_scale)
        zero_points = bit_min - (min_vals / scales)
        zero_points = torch.clamp(zero_points, bit_min, bit_max)

    # match zero-points to quantized type
    zero_points = zero_points.to(zp_dtype)

    if scales.ndim == 0:
        scales = scales.reshape(1)
        zero_points = zero_points.reshape(1)

    return scales, zero_points


def compute_dynamic_scales_and_zp(value: Tensor, args: QuantizationArgs):
    """
    Returns the computed scales and zero points for dynamic activation
    quantization.

    :param value: tensor to calculate quantization parameters for
    :param args: quantization args; ``args.strategy`` must be TOKEN or TENSOR
    :return: tuple of scale and zero point derived from the observed tensor
    :raises ValueError: if ``args.strategy`` is neither TOKEN nor TENSOR
    """
    if args.strategy == QuantizationStrategy.TOKEN:
        # dims 1 and 2 are kept, every other dim of `value` is reduced
        # NOTE(review): which axes these are depends on the caller's tensor
        # layout, which is not visible here - confirm against callers
        dim = {1, 2}
        reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
    elif args.strategy == QuantizationStrategy.TENSOR:
        reduce_dims = None
    else:
        raise ValueError(
            f"One of {QuantizationStrategy.TOKEN} or {QuantizationStrategy.TENSOR} "
            "must be used for dynamic quantization"
        )

    if not reduce_dims:
        # nothing to reduce along: take the min/max over the whole tensor
        min_val, max_val = torch.aminmax(value)
    else:
        min_val = torch.amin(value, dim=reduce_dims, keepdims=True)
        max_val = torch.amax(value, dim=reduce_dims, keepdims=True)

    return calculate_qparams(min_val, max_val, args)


def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
    """
    Calculated the effective quantization range for the given Quantization Args

    :param quantization_args: quantization args to get range of
    :param device: device to store the range to
    :return: tuple endpoints (q_min, q_max) for the given quantization range
    :raises ValueError: for FLOAT quantization with num_bits != 8, or for a
        quantization type other than INT or FLOAT
    """
    if quantization_args.type == QuantizationType.INT:
        bit_range = 2**quantization_args.num_bits
        q_max = torch.tensor(bit_range / 2 - 1, device=device)
        q_min = torch.tensor(-bit_range / 2, device=device)
    elif quantization_args.type == QuantizationType.FLOAT:
        if quantization_args.num_bits != 8:
            raise ValueError(
                "Floating point quantization is only supported for 8 bits, "
                f"got {quantization_args.num_bits}"
            )
        fp_range_info = torch.finfo(FP8_DTYPE)
        q_max = torch.tensor(fp_range_info.max, device=device)
        q_min = torch.tensor(fp_range_info.min, device=device)
    else:
        raise ValueError(f"Invalid quantization type {quantization_args.type}")

    return q_min, q_max


def infer_quantization_status(model: Module) -> Optional["QuantizationStatus"]:  # noqa
    """
    Checks the quantization status of a model. Assumes all modules in the model have
    the same status, so only the first module carrying a non-None
    ``quantization_status`` attribute is checked.

    :param model: model to check quantization status for
    :return: quantization status if the model is quantized, otherwise None
    """
    for module in model.modules():
        status = getattr(module, "quantization_status", None)
        if status is not None:
            return status
    return None


def is_module_quantized(module: Module) -> bool:
    """
    Check if a module is quantized, based on the existence of a non-empty quantization
    scheme

    :param module: pytorch module to check
    :return: True if the module has a ``quantization_scheme`` with at least one of
        weights, input_activations or output_activations set, False otherwise
    """
    if not hasattr(module, "quantization_scheme"):
        return False

    scheme = module.quantization_scheme
    return (
        scheme.weights is not None
        or scheme.input_activations is not None
        or scheme.output_activations is not None
    )


def is_model_quantized(model: Module) -> bool:
    """
    Check if any modules in a model are quantized, based on the existence of a non-empty
    quantization scheme in at least one module

    :param model: pytorch model
    :return: True if model is quantized, False otherwise
    """
    return any(
        is_module_quantized(submodule)
        for _, submodule in iter_named_leaf_modules(model)
    )


def module_type(module: Module) -> str:
    """
    Gets a string representation of a module type

    :param module: pytorch module to get type of
    :return: module type as a string
    """
    return type(module).__name__


def _has_only_observer_children(module: Module) -> bool:
    """
    :param module: pytorch module to inspect
    :return: True if every direct child name contains "observer"; this is
        vacuously True for a module without children
    """
    return all("observer" in child_name for child_name, _ in module.named_children())


def iter_named_leaf_modules(model: Module) -> Generator[Tuple[str, Module], None, None]:
    """
    Yields modules whose direct children (if any) are all observers, i.e. modules
    that do not have any submodules except observers.

    NOTE: every module without children is yielded, which includes modules whose
    own name contains "observer".
    TODO: verify if an observer would ever be attached in this case and decide
    whether observers should be filtered out.

    :param model: model to get leaf modules of
    :returns: generator tuple of (name, leaf_submodule)
    """
    for name, submodule in model.named_modules():
        if _has_only_observer_children(submodule):
            yield name, submodule


def iter_named_quantizable_modules(
    model: Module, include_children: bool = True, include_attn: bool = False
) -> Generator[Tuple[str, Module], None, None]:
    """
    Yield name and submodule of
    - leaf modules, set by include_children
    - attention modules, set by include_attn

    A module satisfying both criteria is yielded twice when both flags are set.

    :param model: model to get leaf modules of
    :param include_children: flag to get the leaf modules (same selection as
        iter_named_leaf_modules)
    :param include_attn: flag to get the attention modules, i.e. modules whose
        name ends with "self_attn"
    :returns: generator tuple of (name, submodule)
    """
    for name, submodule in model.named_modules():
        if include_children and _has_only_observer_children(submodule):
            yield name, submodule
        if include_attn and name.endswith("self_attn"):
            yield name, submodule


def get_torch_bit_depth(value: torch.Tensor) -> int:
    """
    Determine the number of bits used to represent the dtype of a tensor

    :param value: tensor to check bit depth of
    :return: bit depth of each element in the value tensor
    """
    try:
        bit_depth = torch.finfo(value.dtype).bits
    except TypeError:
        # torch.finfo rejects non floating point dtypes
        bit_depth = torch.iinfo(value.dtype).bits

    return bit_depth


def can_quantize(value: torch.Tensor, quant_args: "QuantizationArgs") -> bool:  # noqa
    """
    Checks if value can be quantized by quant_args.

    :param value: tensor to check for quantization
    :param quant_args: QuantizationArgs to use for quantization
    :return: True only if the bit depth of value is strictly greater than
        ``quant_args.num_bits``. An equal bit depth returns False silently,
        a smaller bit depth returns False and logs a warning
    """
    bit_depth = get_torch_bit_depth(value)
    requested_depth = quant_args.num_bits
    if bit_depth < requested_depth:
        _LOGGER.warning(
            f"Can't quantize tensor with bit depth {bit_depth} to {requested_depth}. "
            "The QuantizationArgs provided are not compatible with the input tensor."
        )

    return bit_depth > requested_depth


def calculate_compression_ratio(model: Module) -> float:
    """
    Calculates the quantization compression ratio of a pytorch model, based on the
    number of bits needed to represent the total weights in compressed form. Does not
    take into account activation quantizations.

    :param model: pytorch module to calculate compression ratio for
    :return: compression ratio of the whole model
    :raises ZeroDivisionError: if no parameters are found in the leaf modules
    """
    total_compressed = 0.0
    total_uncompressed = 0.0
    for _, submodule in tqdm(
        iter_named_leaf_modules(model),
        desc="Calculating quantization compression ratio",
    ):
        # the compressed width only depends on the submodule, not the parameter
        weights_args = None
        if is_module_quantized(submodule):
            weights_args = submodule.quantization_scheme.weights

        # only count the parameters owned by this leaf; iterating over
        # model.parameters() here would apply this leaf's bit width to every
        # parameter in the model, once per leaf
        for parameter in submodule.parameters():
            uncompressed_bits = get_torch_bit_depth(parameter)
            compressed_bits = weights_args.num_bits if weights_args else uncompressed_bits

            num_weights = parameter.numel()
            total_compressed += compressed_bits * num_weights
            total_uncompressed += uncompressed_bits * num_weights

    return total_uncompressed / total_compressed


def is_kv_cache_quant_scheme(scheme: QuantizationScheme) -> bool:
    """
    Check whether the QuantizationScheme targets the kv cache.

    A scheme is treated as a kv cache scheme when at least one of its targets is
    literally equal to an entry of KV_CACHE_TARGETS. No regex matching is
    performed here and output_activations is not inspected.

    :param scheme: The QuantizationScheme to investigate
    :return: boolean flag
    """
    return any(target in KV_CACHE_TARGETS for target in scheme.targets)


def parse_out_kv_cache_args(
    quant_scheme_to_layers: List[QuantizationScheme],
) -> Tuple[Optional[QuantizationArgs], List[QuantizationScheme]]:
    """
    If possible, parse out the kv cache specific QuantizationArgs
    from the list of the QuantizationSchemes. If no kv cache
    specific QuantizationArgs available, this function acts
    as an identity function

    :param quant_scheme_to_layers: list of QuantizationSchemes
    :return: kv_cache_args (optional) and the list of the QuantizationSchemes
        with all kv cache schemes removed. kv_cache_args is the
        output_activations of the first kv cache scheme found, or None
    """
    kv_cache_schemes: List[QuantizationScheme] = []
    remaining_schemes: List[QuantizationScheme] = []
    for scheme in quant_scheme_to_layers:
        if is_kv_cache_quant_scheme(scheme):
            kv_cache_schemes.append(scheme)
        else:
            remaining_schemes.append(scheme)

    kv_cache_args = kv_cache_schemes[0].output_activations if kv_cache_schemes else None

    return kv_cache_args, remaining_schemes
"""
Functionality for storing and setting the version info for this package
"""


version_base = "0.9.1"
# True: the version is exactly version_base (release build)
# False: the current date is appended as a fourth "build" component
is_release = True


def _generate_version(
    is_release: bool,
    version_base: str,
):
    """
    Build the version string.

    :param is_release: when True the base version is returned unchanged
    :param version_base: base version string, e.g. "0.9.1"
    :return: version_base for releases, otherwise version_base followed by
        "." and today's date formatted as YYYYMMDD
    """
    from datetime import date

    if is_release:
        return version_base
    return f"{version_base}.{date.today().strftime('%Y%m%d')}"


__all__ = [
    "__version__",
    "version_base",
    "is_release",
    "version",
    "version_major",
    "version_minor",
    "version_patch",
    "version_build",
    "version_major_minor",
]
__version__ = _generate_version(is_release, version_base)

version = __version__
# split once and pad with None up to 4 components, so that a 3-part release
# version gets version_build = None while a 4-part version keeps its build date
_version_parts = version.split(".")
_version_parts += [None] * (4 - len(_version_parts))
version_major, version_minor, version_patch, version_build = _version_parts
version_major_minor = f"{version_major}.{version_minor}"
+ + +from .timm_sparse_attention import TimmSparseAttention # noqa diff --git a/.venv/lib/python3.11/site-packages/xformers/helpers/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/helpers/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..294ab7ac891c5cc17375bbd17456e2361d3de798 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/helpers/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/helpers/__pycache__/hierarchical_configs.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/helpers/__pycache__/hierarchical_configs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d67f323da10839799e2c31bbbee78812b80a44c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/helpers/__pycache__/hierarchical_configs.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/helpers/__pycache__/test_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/helpers/__pycache__/test_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8745663488318e4b7a28bb6be90a6642475d0abc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/helpers/__pycache__/test_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/helpers/__pycache__/timm_sparse_attention.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/helpers/__pycache__/timm_sparse_attention.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27186cc7e8a69e5046e9351d4abe5508cf1f3cb7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/helpers/__pycache__/timm_sparse_attention.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/helpers/hierarchical_configs.py 
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import copy
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from xformers._deprecation_warning import deprecated_function
from xformers.components.residual import ResidualNormStyle


@dataclass
class BasicLayerConfig:
    """Per-stage settings consumed by `get_hierarchical_configuration`."""

    embedding: int
    attention_mechanism: str
    patch_size: int
    stride: int
    padding: int
    seq_len: int
    feedforward: str
    normalization: str = "layernorm"
    repeat_layer: int = 1


def get_hierarchical_configuration(
    layer_base_configs: List[BasicLayerConfig],
    residual_norm_style: ResidualNormStyle = ResidualNormStyle.Pre,
    use_rotary_embeddings: bool = True,
    mlp_multiplier: int = 4,
    in_channels: int = 3,
    dim_head: Optional[int] = None,
):
    """
    A small helper to generate hierarchical xformers configurations,
    which correspond for instance to poolformer or swin architectures.

    Contrary to more "classical" Transformer architectures, which conserve the sequence/context
    length across layers, hierarchical Transformers trade the sequence length for the embedding dimension

    :param layer_base_configs: one BasicLayerConfig per stage, in order
    :param residual_norm_style: stored in each layer config as ``str(residual_norm_style)``
    :param use_rotary_embeddings: forwarded to every ``multi_head_config``
    :param mlp_multiplier: used as ``hidden_layer_multiplier`` of the feedforward config
    :param in_channels: input channels of the first patch embedding; each following
        stage uses the embedding size of the previous stage
    :param dim_head: when given, the number of heads of a stage is
        ``embedding // dim_head`` (the embedding must be divisible by it);
        otherwise a single head is used
    :return: list of layer configuration dicts. A stage with ``repeat_layer > 1``
        contributes ``repeat_layer - 1`` extra entries without
        ``patch_embedding_config``
    """
    deprecated_function(get_hierarchical_configuration)

    base_config: Dict[str, Any] = {
        "block_type": "encoder",
        "dim_model": 0,
        "use_triton": False,
        "residual_norm_style": str(residual_norm_style),
        "multi_head_config": {
            "num_heads": 1,
            "use_rotary_embeddings": use_rotary_embeddings,
            "attention": {
                "name": "TBD",
            },
        },
        "feedforward_config": {
            "name": "TBD",
            "activation": "gelu",
            "hidden_layer_multiplier": mlp_multiplier,
            "dropout": 0.0,
        },
        "position_encoding_config": {
            "name": "learnable",
            "seq_len": 0,
            "add_class_token": False,
        },
        "patch_embedding_config": {
            "in_channels": in_channels,
            "kernel_size": 0,
            "stride": 0,
            "padding": 0,
        },
    }

    xformers_config = []
    stage_in_channels = in_channels

    for layer_base_config in layer_base_configs:
        lc = copy.deepcopy(base_config)

        lc["normalization"] = layer_base_config.normalization

        # Fill in the changing model dimensions
        lc["dim_model"] = layer_base_config.embedding

        # Update the patches
        lc["patch_embedding_config"] = {
            "in_channels": stage_in_channels,
            "kernel_size": layer_base_config.patch_size,
            "stride": layer_base_config.stride,
            "padding": layer_base_config.padding,
        }

        # The next stage consumes what this stage produces
        stage_in_channels = layer_base_config.embedding

        lc["position_encoding_config"]["seq_len"] = layer_base_config.seq_len

        # Fill in the number of heads (defaults to 1)
        if dim_head is not None:
            assert layer_base_config.embedding % dim_head == 0
            lc["multi_head_config"]["num_heads"] = (
                layer_base_config.embedding // dim_head
            )

        # Fill in the attention mechanism
        lc["multi_head_config"]["attention"][
            "name"
        ] = layer_base_config.attention_mechanism

        # Fill in the feedforward
        lc["feedforward_config"]["name"] = layer_base_config.feedforward

        xformers_config.append(lc)

        # Handle repeated layers (without the patch embeddings)
        if layer_base_config.repeat_layer > 1:
            lc_repeat = copy.deepcopy(lc)
            lc_repeat.pop("patch_embedding_config")
            # independent copies, so that editing one repeated layer config
            # does not silently edit all of them
            xformers_config += [
                copy.deepcopy(lc_repeat)
                for _ in range(layer_base_config.repeat_layer - 1)
            ]

    return xformers_config


# ---------------------------------------------------------------------------
# xformers/helpers/test_utils.py
# ---------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import os
import sys
import tempfile

import torch

# pytorch on windows uses gloo not nccl
is_windows = sys.platform == "win32"


def init_torch_distributed_local():
    """
    Initialize a single-process (rank 0, world size 1) torch.distributed group.

    Does nothing when a process group is already initialized. The rendezvous
    goes through a freshly created temporary file. NCCL is used when CUDA is
    available and the platform is not Windows, GLOO otherwise.
    """
    if torch.distributed.is_initialized():
        return

    # mkstemp returns an open file descriptor next to the path; only the path
    # is needed, so close the descriptor instead of leaking it
    init_fd, init_path = tempfile.mkstemp()
    os.close(init_fd)
    init_url = "file://" + init_path

    backend = (
        torch.distributed.Backend.NCCL
        if torch.cuda.is_available() and not is_windows
        else torch.distributed.Backend.GLOO
    )
    torch.distributed.init_process_group(
        backend=backend,
        rank=0,
        world_size=1,
        init_method=init_url,
    )
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import torch

from xformers.components.attention.core import scaled_dot_product_attention


class TimmSparseAttention(torch.nn.Module):
    """
    Almost drop-in replacement for timm attention
    but using the sparsity-aware scaled_dot_product_attention from xformers
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        attn_drop=0.0,
        proj_drop=0.0,
        attn_mask=None,
    ):
        """
        :param dim: size of the last input dimension; also the output size
        :param num_heads: number of attention heads the projection is split into
        :param qkv_bias: whether the fused q/k/v projection has a bias
        :param attn_drop: dropout probability of the module handed to the attention call
        :param proj_drop: dropout probability applied after the output projection
        :param attn_mask: mask forwarded as-is to scaled_dot_product_attention
        """
        super().__init__()
        self.num_heads = num_heads

        # single linear layer producing q, k and v at once
        self.qkv = torch.nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.attn_drop = torch.nn.Dropout(attn_drop)
        self.proj = torch.nn.Linear(dim, dim)
        self.proj_drop = torch.nn.Dropout(proj_drop)
        self.attn_mask = attn_mask

    def forward(self, x):
        """
        :param x: 3-dimensional input, unpacked as (B, N, C)
        :return: tensor of shape (B, N, C)
        """
        batch, tokens, channels = x.shape
        heads = self.num_heads
        head_dim = channels // heads

        # (B, N, 3 * C) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        fused = self.qkv(x).reshape(batch, tokens, 3, heads, head_dim)
        fused = fused.permute(2, 0, 3, 1, 4)
        # merge batch and heads: (3, B * heads, N, head_dim)
        fused = fused.flatten(1, 2)

        q, k, v = fused.unbind()

        attended = scaled_dot_product_attention(
            q, k, v, self.attn_mask, dropout=self.attn_drop
        )

        # split batch and heads again, then bring the heads back next to head_dim
        attended = attended.reshape(batch, heads, tokens, head_dim)
        merged = attended.transpose(1, 2).reshape(batch, tokens, channels)

        return self.proj_drop(self.proj(merged))
diff --git a/.venv/lib/python3.11/site-packages/xformers/triton/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/triton/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a381fd37c94d581fb42065d5d01d09f932ec7b0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/triton/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/triton/__pycache__/vararg_kernel.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/triton/__pycache__/vararg_kernel.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cee7d74a388b557e8c93e4014cb377aded3df4e7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/triton/__pycache__/vararg_kernel.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/triton/vararg_kernel.py b/.venv/lib/python3.11/site-packages/xformers/triton/vararg_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..c55beccb01d0993c84076eae176e09471ff4212d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/triton/vararg_kernel.py @@ -0,0 +1,188 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
import ast
import copy
import functools
import linecache
import os
import sys
import tempfile
from typing import Any, Dict, List

import triton


class _ForLoopUnroller(ast.NodeTransformer):
    """
    Rewrites one copy of a loop body for a fixed iteration `loop_iter`:
    the loop variable `target` becomes the iteration number, and
    `var[target]` becomes the name `var<loop_iter>` when `var` is one of
    `inline_variables`.
    """

    def __init__(self, target, inline_variables, loop_iter):
        self.loop_iter = loop_iter
        self.target = target
        self.inline_variables = inline_variables

    def visit_Name(self, node):
        if node.id != self.target:
            return node
        # a Name whose id is the iteration number: unparses as the number
        return ast.Name(str(self.loop_iter))

    def visit_Subscript(self, node):
        # Pattern-matching `value[slice]`
        if (
            isinstance(node.slice, ast.Name)
            and node.slice.id == self.target
            and isinstance(node.value, ast.Name)
            and node.value.id in self.inline_variables
        ):
            return ast.Name(f"{node.value.id}{self.loop_iter}")
        # Not a vararg access: still rewrite the loop variable inside it
        # (e.g. `other[i]`), otherwise it would survive the unrolling as an
        # undefined name
        return self.generic_visit(node)


class _VisitorUnrollKernel(ast.NodeTransformer):
    """
    Expands every `VAR_ARGS_ARRAY` variable into `N` numbered variables and
    unrolls `for i in range(len(<such variable>))` loops `N` times.
    """

    def __init__(self, N):
        self.inline_variables = set()
        self.N = N

    def _visit_statements(self, statements):
        """Visit statements, flattening visitors that return a list of nodes."""
        visited = []
        for statement in statements:
            result = self.visit(statement)
            if result is None:
                continue
            if isinstance(result, list):
                visited.extend(result)
            else:
                visited.append(result)
        return visited

    def _iterates_over_inline_variable(self, iter_node):
        """True if `iter_node` is exactly `range(len(<inline variable>))`."""
        if not (
            isinstance(iter_node, ast.Call)
            and isinstance(iter_node.func, ast.Name)
            and iter_node.func.id == "range"
            and len(iter_node.args) == 1
        ):
            return False
        len_call = iter_node.args[0]
        if not (
            isinstance(len_call, ast.Call)
            and isinstance(len_call.func, ast.Name)
            and len_call.func.id == "len"
            and len(len_call.args) == 1
        ):
            return False
        len_arg = len_call.args[0]
        return isinstance(len_arg, ast.Name) and len_arg.id in self.inline_variables

    def visit_AnnAssign(self, node):
        # Pattern-matching:
        # var_name: "VAR_ARGS_ARRAY"
        if (
            node.value is None
            and node.simple == 1
            and isinstance(node.target, ast.Name)
            and isinstance(node.annotation, ast.Constant)
            and node.annotation.value == "VAR_ARGS_ARRAY"
        ):
            # the declaration itself is dropped, only the name is remembered
            self.inline_variables.add(node.target.id)
            return []
        if node.value is not None:
            node.value = self.visit(node.value)
        if node.annotation is not None:
            node.annotation = self.visit(node.annotation)
        if node.target is not None:
            node.target = self.visit(node.target)
        return node

    def visit_arguments(self, node):
        # Replace `args` annotated with `VAR_ARGS_ARRAY`
        new_args = []
        for arg in node.args:
            if (
                arg.annotation is not None
                and isinstance(arg.annotation, ast.Constant)
                and arg.annotation.value == "VAR_ARGS_ARRAY"
            ):
                self.inline_variables.add(arg.arg)
                new_args += [ast.arg(f"{arg.arg}{i}") for i in range(self.N)]
                continue
            new_args.append(arg)
        if node.vararg is not None:
            self.inline_variables.add(node.vararg.arg)
            new_args += [ast.arg(f"{node.vararg.arg}{i}") for i in range(self.N)]
            node.vararg = None
            # without `*vararg`, keyword-only arguments become regular ones
            new_args += node.kwonlyargs
            node.kwonlyargs = []
        node.args = new_args
        return node

    def visit_For(self, node):
        if not self._iterates_over_inline_variable(node.iter):
            # regular loop: keep it, but still transform its body
            node.body = self._visit_statements(node.body)
            return node
        # We know we have to modify this loop
        new_nodes = []
        for i in range(self.N):
            unroller = _ForLoopUnroller(
                target=node.target.id,
                inline_variables=self.inline_variables,
                loop_iter=i,
            )
            for body in node.body:
                body = copy.deepcopy(body)
                new_node = ast.fix_missing_locations(unroller.visit(body))
                new_nodes += self._visit_statements([new_node])
        return new_nodes


# Hackfix to get access to get source-code for
# `exec`-created functions - see https://stackoverflow.com/a/69668999
_getlines_orig = None
_FILENAME_TO_SRC: Dict[str, List[str]] = {}

# Materializing the codegen to disk can be useful for external tools, e.g. ncu
# Disabled by default because writing to disk at module import time is unexpected and error-prone.
_should_materialize_codegen = os.environ.get("XFORMERS_MATERIALIZE_CODEGEN") == "1"
_tmp_dir = None


def _monkey_patched_getlines(filename, module_globals=None):
    """Serve generated sources from memory, defer to linecache for the rest."""
    if filename in _FILENAME_TO_SRC:
        return _FILENAME_TO_SRC[filename]
    else:
        return _getlines_orig(filename, module_globals)  # type: ignore


@functools.lru_cache(None)
def unroll_varargs(kernel, N: int):
    """
    Specializes a triton kernel with variable number of inputs
    to a specific number of inputs `N`.
    NOTE: Because it's quite costly to call `triton.jit`,
    we cache the returned value with `lru_cache`

    :param kernel: triton kernel to specialize; its `fn` attribute is used
    :param N: number of arguments every `VAR_ARGS_ARRAY` is expanded into
    :return: the `triton.jit`-ed specialized kernel, with `src` set to the
        generated source
    :raises RuntimeError: on python 3.8 and below (`ast.unparse` is missing)
    """
    global _FILENAME_TO_SRC, _getlines_orig, _tmp_dir

    k = triton.JITFunction(kernel.fn)
    parsed = ast.parse(k.src)
    nodeVisitor = _VisitorUnrollKernel(N=N)
    parsed = nodeVisitor.visit(parsed)
    parsed = ast.fix_missing_locations(parsed)

    # NOTE: `ast.unparse` requires python 3.9+
    if (sys.version_info.major, sys.version_info.minor) <= (3, 8):
        raise RuntimeError("Error: This functionality requires python 3.9 or above")
    new_src = ast.unparse(parsed)  # type: ignore

    # Now we want to `eval` the function, but we need all this
    # boilerplate code to make sure triton can run `inspect.getsource`

    fn_basename = f"unroll_varargs-{kernel.fn.__name__}-{N}"
    if _should_materialize_codegen:
        if not _tmp_dir:
            # kept in a module global so the directory outlives this call
            _tmp_dir = tempfile.TemporaryDirectory()
        fn_filename = os.path.join(_tmp_dir.name, f"{fn_basename}.py")
        with open(fn_filename, "w") as f:
            f.write(new_src)
    else:
        # Patch `getlines` only the first time
        if not _FILENAME_TO_SRC:
            _getlines_orig = linecache.getlines
            linecache.getlines = _monkey_patched_getlines
        fn_filename = f"<{fn_basename}>"
        _FILENAME_TO_SRC[fn_filename] = new_src.splitlines(keepends=True)

    # Create function given source
    code = compile(new_src, fn_filename, "exec")

    _locals: Dict[str, Any] = {}
    exec(code, kernel.fn.__globals__, _locals)
    # the generated source must define the kernel function and nothing else
    assert len(_locals) == 1, len(_locals)
    fn = next(iter(_locals.values()))

    jitted_fn = triton.jit(fn)
    jitted_fn.src = new_src
    return jitted_fn


# Note: just import this to make mypy happy
# when annotating variables with `VAR_ARGS_ARRAY`
VAR_ARGS_ARRAY = List[Any]