diff --git a/lib/python3.10/site-packages/cachetools/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/cachetools/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5a042bc810b67d58fd20df03ebb2de0b668f2f6 Binary files /dev/null and b/lib/python3.10/site-packages/cachetools/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/cachetools/__pycache__/_decorators.cpython-310.pyc b/lib/python3.10/site-packages/cachetools/__pycache__/_decorators.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fcb956d636c9e286c96bf210966d71b23c63e737 Binary files /dev/null and b/lib/python3.10/site-packages/cachetools/__pycache__/_decorators.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/cachetools/__pycache__/func.cpython-310.pyc b/lib/python3.10/site-packages/cachetools/__pycache__/func.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d017ca73f27a0bf13c516efbc39ac86944134e6 Binary files /dev/null and b/lib/python3.10/site-packages/cachetools/__pycache__/func.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/cachetools/__pycache__/keys.cpython-310.pyc b/lib/python3.10/site-packages/cachetools/__pycache__/keys.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd06003a55293cf1d99091f0ca6ef525e111daab Binary files /dev/null and b/lib/python3.10/site-packages/cachetools/__pycache__/keys.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/comm-0.2.3.dist-info/licenses/LICENSE b/lib/python3.10/site-packages/comm-0.2.3.dist-info/licenses/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..eee1b58d9ed05bdf064d2550c708a6015285de67 --- /dev/null +++ b/lib/python3.10/site-packages/comm-0.2.3.dist-info/licenses/LICENSE @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2022, Jupyter +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/lib/python3.10/site-packages/filelock/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/filelock/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f32216f60e2a9bc8970c3425a4a7235f52bba0d Binary files /dev/null and b/lib/python3.10/site-packages/filelock/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/filelock/__pycache__/_api.cpython-310.pyc b/lib/python3.10/site-packages/filelock/__pycache__/_api.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62402eefa294b0680ef218e96a095e41fdb70e12 Binary files /dev/null and b/lib/python3.10/site-packages/filelock/__pycache__/_api.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/filelock/__pycache__/_error.cpython-310.pyc b/lib/python3.10/site-packages/filelock/__pycache__/_error.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5c8889d998d4212d9aa979ea3b1c2b75d1d406c Binary files /dev/null and b/lib/python3.10/site-packages/filelock/__pycache__/_error.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/filelock/__pycache__/_soft.cpython-310.pyc b/lib/python3.10/site-packages/filelock/__pycache__/_soft.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca41dcb26664df7da8491159f4f79d12faf3c4a7 Binary files /dev/null and b/lib/python3.10/site-packages/filelock/__pycache__/_soft.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/filelock/__pycache__/_unix.cpython-310.pyc b/lib/python3.10/site-packages/filelock/__pycache__/_unix.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97e44f237b6e6e4af1f51a4395a5cd7809dbf375 Binary files /dev/null and b/lib/python3.10/site-packages/filelock/__pycache__/_unix.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/filelock/__pycache__/_util.cpython-310.pyc b/lib/python3.10/site-packages/filelock/__pycache__/_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1080fb96f743de4a8f31a05d18ea45a372ab658 Binary files /dev/null and b/lib/python3.10/site-packages/filelock/__pycache__/_util.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/filelock/__pycache__/_windows.cpython-310.pyc b/lib/python3.10/site-packages/filelock/__pycache__/_windows.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d44ee25643c76973733c3c782ce167bdbc1be06 Binary files /dev/null and b/lib/python3.10/site-packages/filelock/__pycache__/_windows.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/filelock/__pycache__/asyncio.cpython-310.pyc b/lib/python3.10/site-packages/filelock/__pycache__/asyncio.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5b111d22aaf8e34b0dee66b52a1e6fb4feab395 Binary files /dev/null and b/lib/python3.10/site-packages/filelock/__pycache__/asyncio.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/filelock/__pycache__/version.cpython-310.pyc b/lib/python3.10/site-packages/filelock/__pycache__/version.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4fcffb887e27ea66be790eb6bbcc1d83d8a259d0 Binary files /dev/null and b/lib/python3.10/site-packages/filelock/__pycache__/version.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/auth/__init__.py b/lib/python3.10/site-packages/google/auth/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..765bbd70581a2a74eaabee8b856a86ef102b1103 --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/__init__.py @@ -0,0 +1,53 @@ +# Copyright 2016 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Google Auth Library for Python.""" + +import logging +import sys +import warnings + +from google.auth import version as google_auth_version +from google.auth._default import ( + default, + load_credentials_from_dict, + load_credentials_from_file, +) + + +__version__ = google_auth_version.__version__ + + +__all__ = ["default", "load_credentials_from_file", "load_credentials_from_dict"] + + +class Python37DeprecationWarning(DeprecationWarning): # pragma: NO COVER + """ + Deprecation warning raised when Python 3.7 runtime is detected. + Python 3.7 support will be dropped after January 1, 2024. + """ + + pass + + +# Checks if the current runtime is Python 3.7. +if sys.version_info.major == 3 and sys.version_info.minor == 7: # pragma: NO COVER + message = ( + "After January 1, 2024, new releases of this library will drop support " + "for Python 3.7." + ) + warnings.warn(message, Python37DeprecationWarning) + +# Set default logging handler to avoid "No handler found" warnings. +logging.getLogger(__name__).addHandler(logging.NullHandler()) diff --git a/lib/python3.10/site-packages/google/auth/_cloud_sdk.py b/lib/python3.10/site-packages/google/auth/_cloud_sdk.py new file mode 100644 index 0000000000000000000000000000000000000000..a94411949bdfdd5b1550e8887ee45d7baffe1d18 --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/_cloud_sdk.py @@ -0,0 +1,153 @@ +# Copyright 2015 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for reading the Google Cloud SDK's configuration.""" + +import os +import subprocess + +from google.auth import _helpers +from google.auth import environment_vars +from google.auth import exceptions + + +# The ~/.config subdirectory containing gcloud credentials. +_CONFIG_DIRECTORY = "gcloud" +# Windows systems store config at %APPDATA%\gcloud +_WINDOWS_CONFIG_ROOT_ENV_VAR = "APPDATA" +# The name of the file in the Cloud SDK config that contains default +# credentials. +_CREDENTIALS_FILENAME = "application_default_credentials.json" +# The name of the Cloud SDK shell script +_CLOUD_SDK_POSIX_COMMAND = "gcloud" +_CLOUD_SDK_WINDOWS_COMMAND = "gcloud.cmd" +# The command to get the Cloud SDK configuration +_CLOUD_SDK_CONFIG_GET_PROJECT_COMMAND = ("config", "get", "project") +# The command to get google user access token +_CLOUD_SDK_USER_ACCESS_TOKEN_COMMAND = ("auth", "print-access-token") +# Cloud SDK's application-default client ID +CLOUD_SDK_CLIENT_ID = ( + "764086051850-6qr4p6gpi6hn506pt8ejuq83di341hur.apps.googleusercontent.com" +) + + +def get_config_path(): + """Returns the absolute path the the Cloud SDK's configuration directory. + + Returns: + str: The Cloud SDK config path. + """ + # If the path is explicitly set, return that. + try: + return os.environ[environment_vars.CLOUD_SDK_CONFIG_DIR] + except KeyError: + pass + + # Non-windows systems store this at ~/.config/gcloud + if os.name != "nt": + return os.path.join(os.path.expanduser("~"), ".config", _CONFIG_DIRECTORY) + # Windows systems store config at %APPDATA%\gcloud + else: + try: + return os.path.join( + os.environ[_WINDOWS_CONFIG_ROOT_ENV_VAR], _CONFIG_DIRECTORY + ) + except KeyError: + # This should never happen unless someone is really + # messing with things, but we'll cover the case anyway. + drive = os.environ.get("SystemDrive", "C:") + return os.path.join(drive, "\\", _CONFIG_DIRECTORY) + + +def get_application_default_credentials_path(): + """Gets the path to the application default credentials file. + + The path may or may not exist. + + Returns: + str: The full path to application default credentials. + """ + config_path = get_config_path() + return os.path.join(config_path, _CREDENTIALS_FILENAME) + + +def _run_subprocess_ignore_stderr(command): + """ Return subprocess.check_output with the given command and ignores stderr.""" + with open(os.devnull, "w") as devnull: + output = subprocess.check_output(command, stderr=devnull) + return output + + +def get_project_id(): + """Gets the project ID from the Cloud SDK. + + Returns: + Optional[str]: The project ID. + """ + if os.name == "nt": + command = _CLOUD_SDK_WINDOWS_COMMAND + else: + command = _CLOUD_SDK_POSIX_COMMAND + + try: + # Ignore the stderr coming from gcloud, so it won't be mixed into the output. + # https://github.com/googleapis/google-auth-library-python/issues/673 + project = _run_subprocess_ignore_stderr( + (command,) + _CLOUD_SDK_CONFIG_GET_PROJECT_COMMAND + ) + + # Turn bytes into a string and remove "\n" + project = _helpers.from_bytes(project).strip() + return project if project else None + except (subprocess.CalledProcessError, OSError, IOError): + return None + + +def get_auth_access_token(account=None): + """Load user access token with the ``gcloud auth print-access-token`` command. + + Args: + account (Optional[str]): Account to get the access token for. If not + specified, the current active account will be used. + + Returns: + str: The user access token. + + Raises: + google.auth.exceptions.UserAccessTokenError: if failed to get access + token from gcloud. + """ + if os.name == "nt": + command = _CLOUD_SDK_WINDOWS_COMMAND + else: + command = _CLOUD_SDK_POSIX_COMMAND + + try: + if account: + command = ( + (command,) + + _CLOUD_SDK_USER_ACCESS_TOKEN_COMMAND + + ("--account=" + account,) + ) + else: + command = (command,) + _CLOUD_SDK_USER_ACCESS_TOKEN_COMMAND + + access_token = subprocess.check_output(command, stderr=subprocess.STDOUT) + # remove the trailing "\n" + return access_token.decode("utf-8").strip() + except (subprocess.CalledProcessError, OSError, IOError) as caught_exc: + new_exc = exceptions.UserAccessTokenError( + "Failed to obtain access token", caught_exc + ) + raise new_exc from caught_exc diff --git a/lib/python3.10/site-packages/google/auth/_credentials_async.py b/lib/python3.10/site-packages/google/auth/_credentials_async.py new file mode 100644 index 0000000000000000000000000000000000000000..760758d851b05d880db9e43b91910b177701e1aa --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/_credentials_async.py @@ -0,0 +1,171 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Interfaces for credentials.""" + +import abc +import inspect + +from google.auth import credentials + + +class Credentials(credentials.Credentials, metaclass=abc.ABCMeta): + """Async inherited credentials class from google.auth.credentials. + The added functionality is the before_request call which requires + async/await syntax. + All credentials have a :attr:`token` that is used for authentication and + may also optionally set an :attr:`expiry` to indicate when the token will + no longer be valid. + + Most credentials will be :attr:`invalid` until :meth:`refresh` is called. + Credentials can do this automatically before the first HTTP request in + :meth:`before_request`. + + Although the token and expiration will change as the credentials are + :meth:`refreshed ` and used, credentials should be considered + immutable. Various credentials will accept configuration such as private + keys, scopes, and other options. These options are not changeable after + construction. Some classes will provide mechanisms to copy the credentials + with modifications such as :meth:`ScopedCredentials.with_scopes`. + """ + + async def before_request(self, request, method, url, headers): + """Performs credential-specific before request logic. + + Refreshes the credentials if necessary, then calls :meth:`apply` to + apply the token to the authentication header. + + Args: + request (google.auth.transport.Request): The object used to make + HTTP requests. + method (str): The request's HTTP method or the RPC method being + invoked. + url (str): The request's URI or the RPC service's URI. + headers (Mapping): The request's headers. + """ + # pylint: disable=unused-argument + # (Subclasses may use these arguments to ascertain information about + # the http request.) + + if not self.valid: + if inspect.iscoroutinefunction(self.refresh): + await self.refresh(request) + else: + self.refresh(request) + self.apply(headers) + + +class CredentialsWithQuotaProject(credentials.CredentialsWithQuotaProject): + """Abstract base for credentials supporting ``with_quota_project`` factory""" + + +class AnonymousCredentials(credentials.AnonymousCredentials, Credentials): + """Credentials that do not provide any authentication information. + + These are useful in the case of services that support anonymous access or + local service emulators that do not use credentials. This class inherits + from the sync anonymous credentials file, but is kept if async credentials + is initialized and we would like anonymous credentials. + """ + + +class ReadOnlyScoped(credentials.ReadOnlyScoped, metaclass=abc.ABCMeta): + """Interface for credentials whose scopes can be queried. + + OAuth 2.0-based credentials allow limiting access using scopes as described + in `RFC6749 Section 3.3`_. + If a credential class implements this interface then the credentials either + use scopes in their implementation. + + Some credentials require scopes in order to obtain a token. You can check + if scoping is necessary with :attr:`requires_scopes`:: + + if credentials.requires_scopes: + # Scoping is required. + credentials = _credentials_async.with_scopes(scopes=['one', 'two']) + + Credentials that require scopes must either be constructed with scopes:: + + credentials = SomeScopedCredentials(scopes=['one', 'two']) + + Or must copy an existing instance using :meth:`with_scopes`:: + + scoped_credentials = _credentials_async.with_scopes(scopes=['one', 'two']) + + Some credentials have scopes but do not allow or require scopes to be set, + these credentials can be used as-is. + + .. _RFC6749 Section 3.3: https://tools.ietf.org/html/rfc6749#section-3.3 + """ + + +class Scoped(credentials.Scoped): + """Interface for credentials whose scopes can be replaced while copying. + + OAuth 2.0-based credentials allow limiting access using scopes as described + in `RFC6749 Section 3.3`_. + If a credential class implements this interface then the credentials either + use scopes in their implementation. + + Some credentials require scopes in order to obtain a token. You can check + if scoping is necessary with :attr:`requires_scopes`:: + + if credentials.requires_scopes: + # Scoping is required. + credentials = _credentials_async.create_scoped(['one', 'two']) + + Credentials that require scopes must either be constructed with scopes:: + + credentials = SomeScopedCredentials(scopes=['one', 'two']) + + Or must copy an existing instance using :meth:`with_scopes`:: + + scoped_credentials = credentials.with_scopes(scopes=['one', 'two']) + + Some credentials have scopes but do not allow or require scopes to be set, + these credentials can be used as-is. + + .. _RFC6749 Section 3.3: https://tools.ietf.org/html/rfc6749#section-3.3 + """ + + +def with_scopes_if_required(credentials, scopes): + """Creates a copy of the credentials with scopes if scoping is required. + + This helper function is useful when you do not know (or care to know) the + specific type of credentials you are using (such as when you use + :func:`google.auth.default`). This function will call + :meth:`Scoped.with_scopes` if the credentials are scoped credentials and if + the credentials require scoping. Otherwise, it will return the credentials + as-is. + + Args: + credentials (google.auth.credentials.Credentials): The credentials to + scope if necessary. + scopes (Sequence[str]): The list of scopes to use. + + Returns: + google.auth._credentials_async.Credentials: Either a new set of scoped + credentials, or the passed in credentials instance if no scoping + was required. + """ + if isinstance(credentials, Scoped) and credentials.requires_scopes: + return credentials.with_scopes(scopes) + else: + return credentials + + +class Signing(credentials.Signing, metaclass=abc.ABCMeta): + """Interface for credentials that can cryptographically sign messages.""" diff --git a/lib/python3.10/site-packages/google/auth/_credentials_base.py b/lib/python3.10/site-packages/google/auth/_credentials_base.py new file mode 100644 index 0000000000000000000000000000000000000000..64d5ce34b9a352984541ed7f78b42b94f4d01e26 --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/_credentials_base.py @@ -0,0 +1,75 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Interface for base credentials.""" + +import abc + +from google.auth import _helpers + + +class _BaseCredentials(metaclass=abc.ABCMeta): + """Base class for all credentials. + + All credentials have a :attr:`token` that is used for authentication and + may also optionally set an :attr:`expiry` to indicate when the token will + no longer be valid. + + Most credentials will be :attr:`invalid` until :meth:`refresh` is called. + Credentials can do this automatically before the first HTTP request in + :meth:`before_request`. + + Although the token and expiration will change as the credentials are + :meth:`refreshed ` and used, credentials should be considered + immutable. Various credentials will accept configuration such as private + keys, scopes, and other options. These options are not changeable after + construction. Some classes will provide mechanisms to copy the credentials + with modifications such as :meth:`ScopedCredentials.with_scopes`. + + Attributes: + token (Optional[str]): The bearer token that can be used in HTTP headers to make + authenticated requests. + """ + + def __init__(self): + self.token = None + + @abc.abstractmethod + def refresh(self, request): + """Refreshes the access token. + + Args: + request (google.auth.transport.Request): The object used to make + HTTP requests. + + Raises: + google.auth.exceptions.RefreshError: If the credentials could + not be refreshed. + """ + # pylint: disable=missing-raises-doc + # (pylint doesn't recognize that this is abstract) + raise NotImplementedError("Refresh must be implemented") + + def _apply(self, headers, token=None): + """Apply the token to the authentication header. + + Args: + headers (Mapping): The HTTP request headers. + token (Optional[str]): If specified, overrides the current access + token. + """ + headers["authorization"] = "Bearer {}".format( + _helpers.from_bytes(token or self.token) + ) diff --git a/lib/python3.10/site-packages/google/auth/_default.py b/lib/python3.10/site-packages/google/auth/_default.py new file mode 100644 index 0000000000000000000000000000000000000000..cf0cdd7729885b959c195f9f47f77cf558f746b3 --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/_default.py @@ -0,0 +1,685 @@ +# Copyright 2015 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Application default credentials. + +Implements application default credentials and project ID detection. +""" + +import io +import json +import logging +import os +import warnings + +from google.auth import environment_vars +from google.auth import exceptions +import google.auth.transport._http_client + +_LOGGER = logging.getLogger(__name__) + +# Valid types accepted for file-based credentials. +_AUTHORIZED_USER_TYPE = "authorized_user" +_SERVICE_ACCOUNT_TYPE = "service_account" +_EXTERNAL_ACCOUNT_TYPE = "external_account" +_EXTERNAL_ACCOUNT_AUTHORIZED_USER_TYPE = "external_account_authorized_user" +_IMPERSONATED_SERVICE_ACCOUNT_TYPE = "impersonated_service_account" +_GDCH_SERVICE_ACCOUNT_TYPE = "gdch_service_account" +_VALID_TYPES = ( + _AUTHORIZED_USER_TYPE, + _SERVICE_ACCOUNT_TYPE, + _EXTERNAL_ACCOUNT_TYPE, + _EXTERNAL_ACCOUNT_AUTHORIZED_USER_TYPE, + _IMPERSONATED_SERVICE_ACCOUNT_TYPE, + _GDCH_SERVICE_ACCOUNT_TYPE, +) + +# Help message when no credentials can be found. +_CLOUD_SDK_MISSING_CREDENTIALS = """\ +Your default credentials were not found. To set up Application Default Credentials, \ +see https://cloud.google.com/docs/authentication/external/set-up-adc for more information.\ +""" + +# Warning when using Cloud SDK user credentials +_CLOUD_SDK_CREDENTIALS_WARNING = """\ +Your application has authenticated using end user credentials from Google \ +Cloud SDK without a quota project. You might receive a "quota exceeded" \ +or "API not enabled" error. See the following page for troubleshooting: \ +https://cloud.google.com/docs/authentication/adc-troubleshooting/user-creds. \ +""" + +# The subject token type used for AWS external_account credentials. +_AWS_SUBJECT_TOKEN_TYPE = "urn:ietf:params:aws:token-type:aws4_request" + + +def _warn_about_problematic_credentials(credentials): + """Determines if the credentials are problematic. + + Credentials from the Cloud SDK that are associated with Cloud SDK's project + are problematic because they may not have APIs enabled and have limited + quota. If this is the case, warn about it. + """ + from google.auth import _cloud_sdk + + if credentials.client_id == _cloud_sdk.CLOUD_SDK_CLIENT_ID: + warnings.warn(_CLOUD_SDK_CREDENTIALS_WARNING) + + +def load_credentials_from_file( + filename, scopes=None, default_scopes=None, quota_project_id=None, request=None +): + """Loads Google credentials from a file. + + The credentials file must be a service account key, stored authorized + user credentials, external account credentials, or impersonated service + account credentials. + + .. warning:: + Important: If you accept a credential configuration (credential JSON/File/Stream) + from an external source for authentication to Google Cloud Platform, you must + validate it before providing it to any Google API or client library. Providing an + unvalidated credential configuration to Google APIs or libraries can compromise + the security of your systems and data. For more information, refer to + `Validate credential configurations from external sources`_. + + .. _Validate credential configurations from external sources: + https://cloud.google.com/docs/authentication/external/externally-sourced-credentials + + Args: + filename (str): The full path to the credentials file. + scopes (Optional[Sequence[str]]): The list of scopes for the credentials. If + specified, the credentials will automatically be scoped if + necessary + default_scopes (Optional[Sequence[str]]): Default scopes passed by a + Google client library. Use 'scopes' for user-defined scopes. + quota_project_id (Optional[str]): The project ID used for + quota and billing. + request (Optional[google.auth.transport.Request]): An object used to make + HTTP requests. This is used to determine the associated project ID + for a workload identity pool resource (external account credentials). + If not specified, then it will use a + google.auth.transport.requests.Request client to make requests. + + Returns: + Tuple[google.auth.credentials.Credentials, Optional[str]]: Loaded + credentials and the project ID. Authorized user credentials do not + have the project ID information. External account credentials project + IDs may not always be determined. + + Raises: + google.auth.exceptions.DefaultCredentialsError: if the file is in the + wrong format or is missing. + """ + if not os.path.exists(filename): + raise exceptions.DefaultCredentialsError( + "File {} was not found.".format(filename) + ) + + with io.open(filename, "r") as file_obj: + try: + info = json.load(file_obj) + except ValueError as caught_exc: + new_exc = exceptions.DefaultCredentialsError( + "File {} is not a valid json file.".format(filename), caught_exc + ) + raise new_exc from caught_exc + return _load_credentials_from_info( + filename, info, scopes, default_scopes, quota_project_id, request + ) + + +def load_credentials_from_dict( + info, scopes=None, default_scopes=None, quota_project_id=None, request=None +): + """Loads Google credentials from a dict. + + The credentials file must be a service account key, stored authorized + user credentials, external account credentials, or impersonated service + account credentials. + + .. warning:: + Important: If you accept a credential configuration (credential JSON/File/Stream) + from an external source for authentication to Google Cloud Platform, you must + validate it before providing it to any Google API or client library. Providing an + unvalidated credential configuration to Google APIs or libraries can compromise + the security of your systems and data. For more information, refer to + `Validate credential configurations from external sources`_. + + .. _Validate credential configurations from external sources: + https://cloud.google.com/docs/authentication/external/externally-sourced-credentials + + Args: + info (Dict[str, Any]): A dict object containing the credentials + scopes (Optional[Sequence[str]]): The list of scopes for the credentials. If + specified, the credentials will automatically be scoped if + necessary + default_scopes (Optional[Sequence[str]]): Default scopes passed by a + Google client library. Use 'scopes' for user-defined scopes. + quota_project_id (Optional[str]): The project ID used for + quota and billing. + request (Optional[google.auth.transport.Request]): An object used to make + HTTP requests. This is used to determine the associated project ID + for a workload identity pool resource (external account credentials). + If not specified, then it will use a + google.auth.transport.requests.Request client to make requests. + + Returns: + Tuple[google.auth.credentials.Credentials, Optional[str]]: Loaded + credentials and the project ID. Authorized user credentials do not + have the project ID information. External account credentials project + IDs may not always be determined. + + Raises: + google.auth.exceptions.DefaultCredentialsError: if the file is in the + wrong format or is missing. + """ + if not isinstance(info, dict): + raise exceptions.DefaultCredentialsError( + "info object was of type {} but dict type was expected.".format(type(info)) + ) + + return _load_credentials_from_info( + "dict object", info, scopes, default_scopes, quota_project_id, request + ) + + +def _load_credentials_from_info( + filename, info, scopes, default_scopes, quota_project_id, request +): + from google.auth.credentials import CredentialsWithQuotaProject + + credential_type = info.get("type") + + if credential_type == _AUTHORIZED_USER_TYPE: + credentials, project_id = _get_authorized_user_credentials( + filename, info, scopes + ) + + elif credential_type == _SERVICE_ACCOUNT_TYPE: + credentials, project_id = _get_service_account_credentials( + filename, info, scopes, default_scopes + ) + + elif credential_type == _EXTERNAL_ACCOUNT_TYPE: + credentials, project_id = _get_external_account_credentials( + info, + filename, + scopes=scopes, + default_scopes=default_scopes, + request=request, + ) + + elif credential_type == _EXTERNAL_ACCOUNT_AUTHORIZED_USER_TYPE: + credentials, project_id = _get_external_account_authorized_user_credentials( + filename, info, request + ) + + elif credential_type == _IMPERSONATED_SERVICE_ACCOUNT_TYPE: + credentials, project_id = _get_impersonated_service_account_credentials( + filename, info, scopes + ) + elif credential_type == _GDCH_SERVICE_ACCOUNT_TYPE: + credentials, project_id = _get_gdch_service_account_credentials(filename, info) + else: + raise exceptions.DefaultCredentialsError( + "The file {file} does not have a valid type. " + "Type is {type}, expected one of {valid_types}.".format( + file=filename, type=credential_type, valid_types=_VALID_TYPES + ) + ) + if isinstance(credentials, CredentialsWithQuotaProject): + credentials = _apply_quota_project_id(credentials, quota_project_id) + return credentials, project_id + + +def _get_gcloud_sdk_credentials(quota_project_id=None): + """Gets the credentials and project ID from the Cloud SDK.""" + from google.auth import _cloud_sdk + + _LOGGER.debug("Checking Cloud SDK credentials as part of auth process...") + + # Check if application default credentials exist. + credentials_filename = _cloud_sdk.get_application_default_credentials_path() + + if not os.path.isfile(credentials_filename): + _LOGGER.debug("Cloud SDK credentials not found on disk; not using them") + return None, None + + credentials, project_id = load_credentials_from_file( + credentials_filename, quota_project_id=quota_project_id + ) + credentials._cred_file_path = credentials_filename + + if not project_id: + project_id = _cloud_sdk.get_project_id() + + return credentials, project_id + + +def _get_explicit_environ_credentials(quota_project_id=None): + """Gets credentials from the GOOGLE_APPLICATION_CREDENTIALS environment + variable.""" + from google.auth import _cloud_sdk + + cloud_sdk_adc_path = _cloud_sdk.get_application_default_credentials_path() + explicit_file = os.environ.get(environment_vars.CREDENTIALS) + + _LOGGER.debug( + "Checking %s for explicit credentials as part of auth process...", explicit_file + ) + + if explicit_file is not None and explicit_file == cloud_sdk_adc_path: + # Cloud sdk flow calls gcloud to fetch project id, so if the explicit + # file path is cloud sdk credentials path, then we should fall back + # to cloud sdk flow, otherwise project id cannot be obtained. + _LOGGER.debug( + "Explicit credentials path %s is the same as Cloud SDK credentials path, fall back to Cloud SDK credentials flow...", + explicit_file, + ) + return _get_gcloud_sdk_credentials(quota_project_id=quota_project_id) + + if explicit_file is not None: + credentials, project_id = load_credentials_from_file( + os.environ[environment_vars.CREDENTIALS], quota_project_id=quota_project_id + ) + credentials._cred_file_path = f"{explicit_file} file via the GOOGLE_APPLICATION_CREDENTIALS environment variable" + + return credentials, project_id + + else: + return None, None + + +def _get_gae_credentials(): + """Gets Google App Engine App Identity credentials and project ID.""" + # If not GAE gen1, prefer the metadata service even if the GAE APIs are + # available as per https://google.aip.dev/auth/4115. + if os.environ.get(environment_vars.LEGACY_APPENGINE_RUNTIME) != "python27": + return None, None + + # While this library is normally bundled with app_engine, there are + # some cases where it's not available, so we tolerate ImportError. + try: + _LOGGER.debug("Checking for App Engine runtime as part of auth process...") + import google.auth.app_engine as app_engine + except ImportError: + _LOGGER.warning("Import of App Engine auth library failed.") + return None, None + + try: + credentials = app_engine.Credentials() + project_id = app_engine.get_project_id() + return credentials, project_id + except EnvironmentError: + _LOGGER.debug( + "No App Engine library was found so cannot authentication via App Engine Identity Credentials." + ) + return None, None + + +def _get_gce_credentials(request=None, quota_project_id=None): + """Gets credentials and project ID from the GCE Metadata Service.""" + # Ping requires a transport, but we want application default credentials + # to require no arguments. So, we'll use the _http_client transport which + # uses http.client. This is only acceptable because the metadata server + # doesn't do SSL and never requires proxies. + + # While this library is normally bundled with compute_engine, there are + # some cases where it's not available, so we tolerate ImportError. + try: + from google.auth import compute_engine + from google.auth.compute_engine import _metadata + except ImportError: + _LOGGER.warning("Import of Compute Engine auth library failed.") + return None, None + + if request is None: + request = google.auth.transport._http_client.Request() + + if _metadata.is_on_gce(request=request): + # Get the project ID. + try: + project_id = _metadata.get_project_id(request=request) + except exceptions.TransportError: + project_id = None + + cred = compute_engine.Credentials() + cred = _apply_quota_project_id(cred, quota_project_id) + + return cred, project_id + else: + _LOGGER.warning( + "Authentication failed using Compute Engine authentication due to unavailable metadata server." + ) + return None, None + + +def _get_external_account_credentials( + info, filename, scopes=None, default_scopes=None, request=None +): + """Loads external account Credentials from the parsed external account info. + + The credentials information must correspond to a supported external account + credentials. + + Args: + info (Mapping[str, str]): The external account info in Google format. + filename (str): The full path to the credentials file. + scopes (Optional[Sequence[str]]): The list of scopes for the credentials. If + specified, the credentials will automatically be scoped if + necessary. + default_scopes (Optional[Sequence[str]]): Default scopes passed by a + Google client library. Use 'scopes' for user-defined scopes. + request (Optional[google.auth.transport.Request]): An object used to make + HTTP requests. This is used to determine the associated project ID + for a workload identity pool resource (external account credentials). + If not specified, then it will use a + google.auth.transport.requests.Request client to make requests. + + Returns: + Tuple[google.auth.credentials.Credentials, Optional[str]]: Loaded + credentials and the project ID. External account credentials project + IDs may not always be determined. + + Raises: + google.auth.exceptions.DefaultCredentialsError: if the info dictionary + is in the wrong format or is missing required information. + """ + # There are currently 3 types of external_account credentials. + if info.get("subject_token_type") == _AWS_SUBJECT_TOKEN_TYPE: + # Check if configuration corresponds to an AWS credentials. + from google.auth import aws + + credentials = aws.Credentials.from_info( + info, scopes=scopes, default_scopes=default_scopes + ) + elif ( + info.get("credential_source") is not None + and info.get("credential_source").get("executable") is not None + ): + from google.auth import pluggable + + credentials = pluggable.Credentials.from_info( + info, scopes=scopes, default_scopes=default_scopes + ) + else: + try: + # Check if configuration corresponds to an Identity Pool credentials. + from google.auth import identity_pool + + credentials = identity_pool.Credentials.from_info( + info, scopes=scopes, default_scopes=default_scopes + ) + except ValueError: + # If the configuration is invalid or does not correspond to any + # supported external_account credentials, raise an error. + raise exceptions.DefaultCredentialsError( + "Failed to load external account credentials from {}".format(filename) + ) + if request is None: + import google.auth.transport.requests + + request = google.auth.transport.requests.Request() + + return credentials, credentials.get_project_id(request=request) + + +def _get_external_account_authorized_user_credentials( + filename, info, scopes=None, default_scopes=None, request=None +): + try: + from google.auth import external_account_authorized_user + + credentials = external_account_authorized_user.Credentials.from_info(info) + except ValueError: + raise exceptions.DefaultCredentialsError( + "Failed to load external account authorized user credentials from {}".format( + filename + ) + ) + + return credentials, None + + +def _get_authorized_user_credentials(filename, info, scopes=None): + from google.oauth2 import credentials + + try: + credentials = credentials.Credentials.from_authorized_user_info( + info, scopes=scopes + ) + except ValueError as caught_exc: + msg = "Failed to load authorized user credentials from {}".format(filename) + new_exc = exceptions.DefaultCredentialsError(msg, caught_exc) + raise new_exc from caught_exc + return credentials, None + + +def _get_service_account_credentials(filename, info, scopes=None, default_scopes=None): + from google.oauth2 import service_account + + try: + credentials = service_account.Credentials.from_service_account_info( + info, scopes=scopes, default_scopes=default_scopes + ) + except ValueError as caught_exc: + msg = "Failed to load service account credentials from {}".format(filename) + new_exc = exceptions.DefaultCredentialsError(msg, caught_exc) + raise new_exc from caught_exc + return credentials, info.get("project_id") + + +def _get_impersonated_service_account_credentials(filename, info, scopes): + from google.auth import impersonated_credentials + + try: + credentials = impersonated_credentials.Credentials.from_impersonated_service_account_info( + info, scopes=scopes + ) + except ValueError as caught_exc: + msg = "Failed to load impersonated service account credentials from {}".format( + filename + ) + new_exc = exceptions.DefaultCredentialsError(msg, caught_exc) + raise new_exc from caught_exc + return credentials, None + + +def _get_gdch_service_account_credentials(filename, info): + from google.oauth2 import gdch_credentials + + try: + credentials = gdch_credentials.ServiceAccountCredentials.from_service_account_info( + info + ) + except ValueError as caught_exc: + msg = "Failed to load GDCH service account credentials from {}".format(filename) + new_exc = exceptions.DefaultCredentialsError(msg, caught_exc) + raise new_exc from caught_exc + return credentials, info.get("project") + + +def get_api_key_credentials(key): + """Return credentials with the given API key.""" + from google.auth import api_key + + return api_key.Credentials(key) + + +def _apply_quota_project_id(credentials, quota_project_id): + if quota_project_id: + credentials = credentials.with_quota_project(quota_project_id) + else: + credentials = credentials.with_quota_project_from_environment() + + from google.oauth2 import credentials as authorized_user_credentials + + if isinstance(credentials, authorized_user_credentials.Credentials) and ( + not credentials.quota_project_id + ): + _warn_about_problematic_credentials(credentials) + return credentials + + +def default(scopes=None, request=None, quota_project_id=None, default_scopes=None): + """Gets the default credentials for the current environment. + + `Application Default Credentials`_ provides an easy way to obtain + credentials to call Google APIs for server-to-server or local applications. + This function acquires credentials from the environment in the following + order: + + 1. If the environment variable ``GOOGLE_APPLICATION_CREDENTIALS`` is set + to the path of a valid service account JSON private key file, then it is + loaded and returned. The project ID returned is the project ID defined + in the service account file if available (some older files do not + contain project ID information). + + If the environment variable is set to the path of a valid external + account JSON configuration file (workload identity federation), then the + configuration file is used to determine and retrieve the external + credentials from the current environment (AWS, Azure, etc). + These will then be exchanged for Google access tokens via the Google STS + endpoint. + The project ID returned in this case is the one corresponding to the + underlying workload identity pool resource if determinable. + + If the environment variable is set to the path of a valid GDCH service + account JSON file (`Google Distributed Cloud Hosted`_), then a GDCH + credential will be returned. The project ID returned is the project + specified in the JSON file. + 2. If the `Google Cloud SDK`_ is installed and has application default + credentials set they are loaded and returned. + + To enable application default credentials with the Cloud SDK run:: + + gcloud auth application-default login + + If the Cloud SDK has an active project, the project ID is returned. The + active project can be set using:: + + gcloud config set project + + 3. If the application is running in the `App Engine standard environment`_ + (first generation) then the credentials and project ID from the + `App Identity Service`_ are used. + 4. If the application is running in `Compute Engine`_ or `Cloud Run`_ or + the `App Engine flexible environment`_ or the `App Engine standard + environment`_ (second generation) then the credentials and project ID + are obtained from the `Metadata Service`_. + 5. If no credentials are found, + :class:`~google.auth.exceptions.DefaultCredentialsError` will be raised. + + .. _Application Default Credentials: https://developers.google.com\ + /identity/protocols/application-default-credentials + .. _Google Cloud SDK: https://cloud.google.com/sdk + .. _App Engine standard environment: https://cloud.google.com/appengine + .. _App Identity Service: https://cloud.google.com/appengine/docs/python\ + /appidentity/ + .. _Compute Engine: https://cloud.google.com/compute + .. _App Engine flexible environment: https://cloud.google.com\ + /appengine/flexible + .. _Metadata Service: https://cloud.google.com/compute/docs\ + /storing-retrieving-metadata + .. _Cloud Run: https://cloud.google.com/run + .. _Google Distributed Cloud Hosted: https://cloud.google.com/blog/topics\ + /hybrid-cloud/announcing-google-distributed-cloud-edge-and-hosted + + Example:: + + import google.auth + + credentials, project_id = google.auth.default() + + Args: + scopes (Sequence[str]): The list of scopes for the credentials. If + specified, the credentials will automatically be scoped if + necessary. + request (Optional[google.auth.transport.Request]): An object used to make + HTTP requests. This is used to either detect whether the application + is running on Compute Engine or to determine the associated project + ID for a workload identity pool resource (external account + credentials). If not specified, then it will either use the standard + library http client to make requests for Compute Engine credentials + or a google.auth.transport.requests.Request client for external + account credentials. + quota_project_id (Optional[str]): The project ID used for + quota and billing. + default_scopes (Optional[Sequence[str]]): Default scopes passed by a + Google client library. Use 'scopes' for user-defined scopes. + Returns: + Tuple[~google.auth.credentials.Credentials, Optional[str]]: + the current environment's credentials and project ID. Project ID + may be None, which indicates that the Project ID could not be + ascertained from the environment. + + Raises: + ~google.auth.exceptions.DefaultCredentialsError: + If no credentials were found, or if the credentials found were + invalid. + """ + from google.auth.credentials import with_scopes_if_required + from google.auth.credentials import CredentialsWithQuotaProject + + explicit_project_id = os.environ.get( + environment_vars.PROJECT, os.environ.get(environment_vars.LEGACY_PROJECT) + ) + + checkers = ( + # Avoid passing scopes here to prevent passing scopes to user credentials. + # with_scopes_if_required() below will ensure scopes/default scopes are + # safely set on the returned credentials since requires_scopes will + # guard against setting scopes on user credentials. + lambda: _get_explicit_environ_credentials(quota_project_id=quota_project_id), + lambda: _get_gcloud_sdk_credentials(quota_project_id=quota_project_id), + _get_gae_credentials, + lambda: _get_gce_credentials(request, quota_project_id=quota_project_id), + ) + + for checker in checkers: + credentials, project_id = checker() + if credentials is not None: + credentials = with_scopes_if_required( + credentials, scopes, default_scopes=default_scopes + ) + + effective_project_id = explicit_project_id or project_id + + # For external account credentials, scopes are required to determine + # the project ID. Try to get the project ID again if not yet + # determined. + if not effective_project_id and callable( + getattr(credentials, "get_project_id", None) + ): + if request is None: + import google.auth.transport.requests + + request = google.auth.transport.requests.Request() + effective_project_id = credentials.get_project_id(request=request) + + if quota_project_id and isinstance( + credentials, CredentialsWithQuotaProject + ): + credentials = credentials.with_quota_project(quota_project_id) + + if not effective_project_id: + _LOGGER.warning( + "No project ID could be determined. Consider running " + "`gcloud config set project` or setting the %s " + "environment variable", + environment_vars.PROJECT, + ) + return credentials, effective_project_id + + raise exceptions.DefaultCredentialsError(_CLOUD_SDK_MISSING_CREDENTIALS) diff --git a/lib/python3.10/site-packages/google/auth/_default_async.py b/lib/python3.10/site-packages/google/auth/_default_async.py new file mode 100644 index 0000000000000000000000000000000000000000..2e53e20887595e8fdf3b46b3d1031eb507accf8e --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/_default_async.py @@ -0,0 +1,282 @@ +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Application default credentials. + +Implements application default credentials and project ID detection. +""" + +import io +import json +import os + +from google.auth import _default +from google.auth import environment_vars +from google.auth import exceptions + + +def load_credentials_from_file(filename, scopes=None, quota_project_id=None): + """Loads Google credentials from a file. + + The credentials file must be a service account key or stored authorized + user credentials. + + Args: + filename (str): The full path to the credentials file. + scopes (Optional[Sequence[str]]): The list of scopes for the credentials. If + specified, the credentials will automatically be scoped if + necessary + quota_project_id (Optional[str]): The project ID used for + quota and billing. + + Returns: + Tuple[google.auth.credentials.Credentials, Optional[str]]: Loaded + credentials and the project ID. Authorized user credentials do not + have the project ID information. + + Raises: + google.auth.exceptions.DefaultCredentialsError: if the file is in the + wrong format or is missing. + """ + if not os.path.exists(filename): + raise exceptions.DefaultCredentialsError( + "File {} was not found.".format(filename) + ) + + with io.open(filename, "r") as file_obj: + try: + info = json.load(file_obj) + except ValueError as caught_exc: + new_exc = exceptions.DefaultCredentialsError( + "File {} is not a valid json file.".format(filename), caught_exc + ) + raise new_exc from caught_exc + + # The type key should indicate that the file is either a service account + # credentials file or an authorized user credentials file. + credential_type = info.get("type") + + if credential_type == _default._AUTHORIZED_USER_TYPE: + from google.oauth2 import _credentials_async as credentials + + try: + credentials = credentials.Credentials.from_authorized_user_info( + info, scopes=scopes + ) + except ValueError as caught_exc: + msg = "Failed to load authorized user credentials from {}".format(filename) + new_exc = exceptions.DefaultCredentialsError(msg, caught_exc) + raise new_exc from caught_exc + if quota_project_id: + credentials = credentials.with_quota_project(quota_project_id) + if not credentials.quota_project_id: + _default._warn_about_problematic_credentials(credentials) + return credentials, None + + elif credential_type == _default._SERVICE_ACCOUNT_TYPE: + from google.oauth2 import _service_account_async as service_account + + try: + credentials = service_account.Credentials.from_service_account_info( + info, scopes=scopes + ).with_quota_project(quota_project_id) + except ValueError as caught_exc: + msg = "Failed to load service account credentials from {}".format(filename) + new_exc = exceptions.DefaultCredentialsError(msg, caught_exc) + raise new_exc from caught_exc + return credentials, info.get("project_id") + + else: + raise exceptions.DefaultCredentialsError( + "The file {file} does not have a valid type. " + "Type is {type}, expected one of {valid_types}.".format( + file=filename, type=credential_type, valid_types=_default._VALID_TYPES + ) + ) + + +def _get_gcloud_sdk_credentials(quota_project_id=None): + """Gets the credentials and project ID from the Cloud SDK.""" + from google.auth import _cloud_sdk + + # Check if application default credentials exist. + credentials_filename = _cloud_sdk.get_application_default_credentials_path() + + if not os.path.isfile(credentials_filename): + return None, None + + credentials, project_id = load_credentials_from_file( + credentials_filename, quota_project_id=quota_project_id + ) + + if not project_id: + project_id = _cloud_sdk.get_project_id() + + return credentials, project_id + + +def _get_explicit_environ_credentials(quota_project_id=None): + """Gets credentials from the GOOGLE_APPLICATION_CREDENTIALS environment + variable.""" + from google.auth import _cloud_sdk + + cloud_sdk_adc_path = _cloud_sdk.get_application_default_credentials_path() + explicit_file = os.environ.get(environment_vars.CREDENTIALS) + + if explicit_file is not None and explicit_file == cloud_sdk_adc_path: + # Cloud sdk flow calls gcloud to fetch project id, so if the explicit + # file path is cloud sdk credentials path, then we should fall back + # to cloud sdk flow, otherwise project id cannot be obtained. + return _get_gcloud_sdk_credentials(quota_project_id=quota_project_id) + + if explicit_file is not None: + credentials, project_id = load_credentials_from_file( + os.environ[environment_vars.CREDENTIALS], quota_project_id=quota_project_id + ) + + return credentials, project_id + + else: + return None, None + + +def _get_gae_credentials(): + """Gets Google App Engine App Identity credentials and project ID.""" + # While this library is normally bundled with app_engine, there are + # some cases where it's not available, so we tolerate ImportError. + + return _default._get_gae_credentials() + + +def _get_gce_credentials(request=None): + """Gets credentials and project ID from the GCE Metadata Service.""" + # Ping requires a transport, but we want application default credentials + # to require no arguments. So, we'll use the _http_client transport which + # uses http.client. This is only acceptable because the metadata server + # doesn't do SSL and never requires proxies. + + # While this library is normally bundled with compute_engine, there are + # some cases where it's not available, so we tolerate ImportError. + + return _default._get_gce_credentials(request) + + +def default_async(scopes=None, request=None, quota_project_id=None): + """Gets the default credentials for the current environment. + + `Application Default Credentials`_ provides an easy way to obtain + credentials to call Google APIs for server-to-server or local applications. + This function acquires credentials from the environment in the following + order: + + 1. If the environment variable ``GOOGLE_APPLICATION_CREDENTIALS`` is set + to the path of a valid service account JSON private key file, then it is + loaded and returned. The project ID returned is the project ID defined + in the service account file if available (some older files do not + contain project ID information). + 2. If the `Google Cloud SDK`_ is installed and has application default + credentials set they are loaded and returned. + + To enable application default credentials with the Cloud SDK run:: + + gcloud auth application-default login + + If the Cloud SDK has an active project, the project ID is returned. The + active project can be set using:: + + gcloud config set project + + 3. If the application is running in the `App Engine standard environment`_ + (first generation) then the credentials and project ID from the + `App Identity Service`_ are used. + 4. If the application is running in `Compute Engine`_ or `Cloud Run`_ or + the `App Engine flexible environment`_ or the `App Engine standard + environment`_ (second generation) then the credentials and project ID + are obtained from the `Metadata Service`_. + 5. If no credentials are found, + :class:`~google.auth.exceptions.DefaultCredentialsError` will be raised. + + .. _Application Default Credentials: https://developers.google.com\ + /identity/protocols/application-default-credentials + .. _Google Cloud SDK: https://cloud.google.com/sdk + .. _App Engine standard environment: https://cloud.google.com/appengine + .. _App Identity Service: https://cloud.google.com/appengine/docs/python\ + /appidentity/ + .. _Compute Engine: https://cloud.google.com/compute + .. _App Engine flexible environment: https://cloud.google.com\ + /appengine/flexible + .. _Metadata Service: https://cloud.google.com/compute/docs\ + /storing-retrieving-metadata + .. _Cloud Run: https://cloud.google.com/run + + Example:: + + import google.auth + + credentials, project_id = google.auth.default() + + Args: + scopes (Sequence[str]): The list of scopes for the credentials. If + specified, the credentials will automatically be scoped if + necessary. + request (google.auth.transport.Request): An object used to make + HTTP requests. This is used to detect whether the application + is running on Compute Engine. If not specified, then it will + use the standard library http client to make requests. + quota_project_id (Optional[str]): The project ID used for + quota and billing. + Returns: + Tuple[~google.auth.credentials.Credentials, Optional[str]]: + the current environment's credentials and project ID. Project ID + may be None, which indicates that the Project ID could not be + ascertained from the environment. + + Raises: + ~google.auth.exceptions.DefaultCredentialsError: + If no credentials were found, or if the credentials found were + invalid. + """ + from google.auth._credentials_async import with_scopes_if_required + from google.auth.credentials import CredentialsWithQuotaProject + + explicit_project_id = os.environ.get( + environment_vars.PROJECT, os.environ.get(environment_vars.LEGACY_PROJECT) + ) + + checkers = ( + lambda: _get_explicit_environ_credentials(quota_project_id=quota_project_id), + lambda: _get_gcloud_sdk_credentials(quota_project_id=quota_project_id), + _get_gae_credentials, + lambda: _get_gce_credentials(request), + ) + + for checker in checkers: + credentials, project_id = checker() + if credentials is not None: + credentials = with_scopes_if_required(credentials, scopes) + if quota_project_id and isinstance( + credentials, CredentialsWithQuotaProject + ): + credentials = credentials.with_quota_project(quota_project_id) + effective_project_id = explicit_project_id or project_id + if not effective_project_id: + _default._LOGGER.warning( + "No project ID could be determined. Consider running " + "`gcloud config set project` or setting the %s " + "environment variable", + environment_vars.PROJECT, + ) + return credentials, effective_project_id + + raise exceptions.DefaultCredentialsError(_default._CLOUD_SDK_MISSING_CREDENTIALS) diff --git a/lib/python3.10/site-packages/google/auth/_exponential_backoff.py b/lib/python3.10/site-packages/google/auth/_exponential_backoff.py new file mode 100644 index 0000000000000000000000000000000000000000..89853448f9fc71f47198d35607926361271d0dd1 --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/_exponential_backoff.py @@ -0,0 +1,164 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import random +import time + +from google.auth import exceptions + +# The default amount of retry attempts +_DEFAULT_RETRY_TOTAL_ATTEMPTS = 3 + +# The default initial backoff period (1.0 second). +_DEFAULT_INITIAL_INTERVAL_SECONDS = 1.0 + +# The default randomization factor (0.1 which results in a random period ranging +# between 10% below and 10% above the retry interval). +_DEFAULT_RANDOMIZATION_FACTOR = 0.1 + +# The default multiplier value (2 which is 100% increase per back off). +_DEFAULT_MULTIPLIER = 2.0 + +"""Exponential Backoff Utility + +This is a private module that implements the exponential back off algorithm. +It can be used as a utility for code that needs to retry on failure, for example +an HTTP request. +""" + + +class _BaseExponentialBackoff: + """An exponential backoff iterator base class. + + Args: + total_attempts Optional[int]: + The maximum amount of retries that should happen. + The default value is 3 attempts. + initial_wait_seconds Optional[int]: + The amount of time to sleep in the first backoff. This parameter + should be in seconds. + The default value is 1 second. + randomization_factor Optional[float]: + The amount of jitter that should be in each backoff. For example, + a value of 0.1 will introduce a jitter range of 10% to the + current backoff period. + The default value is 0.1. + multiplier Optional[float]: + The backoff multipler. This adjusts how much each backoff will + increase. For example a value of 2.0 leads to a 200% backoff + on each attempt. If the initial_wait is 1.0 it would look like + this sequence [1.0, 2.0, 4.0, 8.0]. + The default value is 2.0. + """ + + def __init__( + self, + total_attempts=_DEFAULT_RETRY_TOTAL_ATTEMPTS, + initial_wait_seconds=_DEFAULT_INITIAL_INTERVAL_SECONDS, + randomization_factor=_DEFAULT_RANDOMIZATION_FACTOR, + multiplier=_DEFAULT_MULTIPLIER, + ): + if total_attempts < 1: + raise exceptions.InvalidValue( + f"total_attempts must be greater than or equal to 1 but was {total_attempts}" + ) + + self._total_attempts = total_attempts + self._initial_wait_seconds = initial_wait_seconds + + self._current_wait_in_seconds = self._initial_wait_seconds + + self._randomization_factor = randomization_factor + self._multiplier = multiplier + self._backoff_count = 0 + + @property + def total_attempts(self): + """The total amount of backoff attempts that will be made.""" + return self._total_attempts + + @property + def backoff_count(self): + """The current amount of backoff attempts that have been made.""" + return self._backoff_count + + def _reset(self): + self._backoff_count = 0 + self._current_wait_in_seconds = self._initial_wait_seconds + + def _calculate_jitter(self): + jitter_variance = self._current_wait_in_seconds * self._randomization_factor + jitter = random.uniform( + self._current_wait_in_seconds - jitter_variance, + self._current_wait_in_seconds + jitter_variance, + ) + + return jitter + + +class ExponentialBackoff(_BaseExponentialBackoff): + """An exponential backoff iterator. This can be used in a for loop to + perform requests with exponential backoff. + """ + + def __init__(self, *args, **kwargs): + super(ExponentialBackoff, self).__init__(*args, **kwargs) + + def __iter__(self): + self._reset() + return self + + def __next__(self): + if self._backoff_count >= self._total_attempts: + raise StopIteration + self._backoff_count += 1 + + if self._backoff_count <= 1: + return self._backoff_count + + jitter = self._calculate_jitter() + + time.sleep(jitter) + + self._current_wait_in_seconds *= self._multiplier + return self._backoff_count + + +class AsyncExponentialBackoff(_BaseExponentialBackoff): + """An async exponential backoff iterator. This can be used in a for loop to + perform async requests with exponential backoff. + """ + + def __init__(self, *args, **kwargs): + super(AsyncExponentialBackoff, self).__init__(*args, **kwargs) + + def __aiter__(self): + self._reset() + return self + + async def __anext__(self): + if self._backoff_count >= self._total_attempts: + raise StopAsyncIteration + self._backoff_count += 1 + + if self._backoff_count <= 1: + return self._backoff_count + + jitter = self._calculate_jitter() + + await asyncio.sleep(jitter) + + self._current_wait_in_seconds *= self._multiplier + return self._backoff_count diff --git a/lib/python3.10/site-packages/google/auth/_helpers.py b/lib/python3.10/site-packages/google/auth/_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..78fe22f72644786830c8e0dbb60404c10eba0059 --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/_helpers.py @@ -0,0 +1,513 @@ +# Copyright 2015 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helper functions for commonly used utilities.""" + +import base64 +import calendar +import datetime +from email.message import Message +import hashlib +import json +import logging +import sys +from typing import Any, Dict, Mapping, Optional, Union +import urllib + +from google.auth import exceptions + + +# _BASE_LOGGER_NAME is the base logger for all google-based loggers. +_BASE_LOGGER_NAME = "google" + +# _LOGGING_INITIALIZED ensures that base logger is only configured once +# (unless already configured by the end-user). +_LOGGING_INITIALIZED = False + + +# The smallest MDS cache used by this library stores tokens until 4 minutes from +# expiry. +REFRESH_THRESHOLD = datetime.timedelta(minutes=3, seconds=45) + +# TODO(https://github.com/googleapis/google-auth-library-python/issues/1684): Audit and update the list below. +_SENSITIVE_FIELDS = { + "accessToken", + "access_token", + "id_token", + "client_id", + "refresh_token", + "client_secret", +} + + +def copy_docstring(source_class): + """Decorator that copies a method's docstring from another class. + + Args: + source_class (type): The class that has the documented method. + + Returns: + Callable: A decorator that will copy the docstring of the same + named method in the source class to the decorated method. + """ + + def decorator(method): + """Decorator implementation. + + Args: + method (Callable): The method to copy the docstring to. + + Returns: + Callable: the same method passed in with an updated docstring. + + Raises: + google.auth.exceptions.InvalidOperation: if the method already has a docstring. + """ + if method.__doc__: + raise exceptions.InvalidOperation("Method already has a docstring.") + + source_method = getattr(source_class, method.__name__) + method.__doc__ = source_method.__doc__ + + return method + + return decorator + + +def parse_content_type(header_value): + """Parse a 'content-type' header value to get just the plain media-type (without parameters). + + This is done using the class Message from email.message as suggested in PEP 594 + (because the cgi is now deprecated and will be removed in python 3.13, + see https://peps.python.org/pep-0594/#cgi). + + Args: + header_value (str): The value of a 'content-type' header as a string. + + Returns: + str: A string with just the lowercase media-type from the parsed 'content-type' header. + If the provided content-type is not parsable, returns 'text/plain', + the default value for textual files. + """ + m = Message() + m["content-type"] = header_value + return ( + m.get_content_type() + ) # Despite the name, actually returns just the media-type + + +def utcnow(): + """Returns the current UTC datetime. + + Returns: + datetime: The current time in UTC. + """ + # We used datetime.utcnow() before, since it's deprecated from python 3.12, + # we are using datetime.now(timezone.utc) now. "utcnow()" is offset-native + # (no timezone info), but "now()" is offset-aware (with timezone info). + # This will cause datetime comparison problem. For backward compatibility, + # we need to remove the timezone info. + now = datetime.datetime.now(datetime.timezone.utc) + now = now.replace(tzinfo=None) + return now + + +def datetime_to_secs(value): + """Convert a datetime object to the number of seconds since the UNIX epoch. + + Args: + value (datetime): The datetime to convert. + + Returns: + int: The number of seconds since the UNIX epoch. + """ + return calendar.timegm(value.utctimetuple()) + + +def to_bytes(value, encoding="utf-8"): + """Converts a string value to bytes, if necessary. + + Args: + value (Union[str, bytes]): The value to be converted. + encoding (str): The encoding to use to convert unicode to bytes. + Defaults to "utf-8". + + Returns: + bytes: The original value converted to bytes (if unicode) or as + passed in if it started out as bytes. + + Raises: + google.auth.exceptions.InvalidValue: If the value could not be converted to bytes. + """ + result = value.encode(encoding) if isinstance(value, str) else value + if isinstance(result, bytes): + return result + else: + raise exceptions.InvalidValue( + "{0!r} could not be converted to bytes".format(value) + ) + + +def from_bytes(value): + """Converts bytes to a string value, if necessary. + + Args: + value (Union[str, bytes]): The value to be converted. + + Returns: + str: The original value converted to unicode (if bytes) or as passed in + if it started out as unicode. + + Raises: + google.auth.exceptions.InvalidValue: If the value could not be converted to unicode. + """ + result = value.decode("utf-8") if isinstance(value, bytes) else value + if isinstance(result, str): + return result + else: + raise exceptions.InvalidValue( + "{0!r} could not be converted to unicode".format(value) + ) + + +def update_query(url, params, remove=None): + """Updates a URL's query parameters. + + Replaces any current values if they are already present in the URL. + + Args: + url (str): The URL to update. + params (Mapping[str, str]): A mapping of query parameter + keys to values. + remove (Sequence[str]): Parameters to remove from the query string. + + Returns: + str: The URL with updated query parameters. + + Examples: + + >>> url = 'http://example.com?a=1' + >>> update_query(url, {'a': '2'}) + http://example.com?a=2 + >>> update_query(url, {'b': '3'}) + http://example.com?a=1&b=3 + >> update_query(url, {'b': '3'}, remove=['a']) + http://example.com?b=3 + + """ + if remove is None: + remove = [] + + # Split the URL into parts. + parts = urllib.parse.urlparse(url) + # Parse the query string. + query_params = urllib.parse.parse_qs(parts.query) + # Update the query parameters with the new parameters. + query_params.update(params) + # Remove any values specified in remove. + query_params = { + key: value for key, value in query_params.items() if key not in remove + } + # Re-encoded the query string. + new_query = urllib.parse.urlencode(query_params, doseq=True) + # Unsplit the url. + new_parts = parts._replace(query=new_query) + return urllib.parse.urlunparse(new_parts) + + +def scopes_to_string(scopes): + """Converts scope value to a string suitable for sending to OAuth 2.0 + authorization servers. + + Args: + scopes (Sequence[str]): The sequence of scopes to convert. + + Returns: + str: The scopes formatted as a single string. + """ + return " ".join(scopes) + + +def string_to_scopes(scopes): + """Converts stringifed scopes value to a list. + + Args: + scopes (Union[Sequence, str]): The string of space-separated scopes + to convert. + Returns: + Sequence(str): The separated scopes. + """ + if not scopes: + return [] + + return scopes.split(" ") + + +def padded_urlsafe_b64decode(value): + """Decodes base64 strings lacking padding characters. + + Google infrastructure tends to omit the base64 padding characters. + + Args: + value (Union[str, bytes]): The encoded value. + + Returns: + bytes: The decoded value + """ + b64string = to_bytes(value) + padded = b64string + b"=" * (-len(b64string) % 4) + return base64.urlsafe_b64decode(padded) + + +def unpadded_urlsafe_b64encode(value): + """Encodes base64 strings removing any padding characters. + + `rfc 7515`_ defines Base64url to NOT include any padding + characters, but the stdlib doesn't do that by default. + + _rfc7515: https://tools.ietf.org/html/rfc7515#page-6 + + Args: + value (Union[str|bytes]): The bytes-like value to encode + + Returns: + Union[str|bytes]: The encoded value + """ + return base64.urlsafe_b64encode(value).rstrip(b"=") + + +def is_python_3(): + """Check if the Python interpreter is Python 2 or 3. + + Returns: + bool: True if the Python interpreter is Python 3 and False otherwise. + """ + return sys.version_info > (3, 0) + + +def _hash_sensitive_info(data: Union[dict, list]) -> Union[dict, list, str]: + """ + Hashes sensitive information within a dictionary. + + Args: + data: The dictionary containing data to be processed. + + Returns: + A new dictionary with sensitive values replaced by their SHA512 hashes. + If the input is a list, returns a list with each element recursively processed. + If the input is neither a dict nor a list, returns the type of the input as a string. + + """ + if isinstance(data, dict): + hashed_data: Dict[Any, Union[Optional[str], dict, list]] = {} + for key, value in data.items(): + if key in _SENSITIVE_FIELDS and not isinstance(value, (dict, list)): + hashed_data[key] = _hash_value(value, key) + elif isinstance(value, (dict, list)): + hashed_data[key] = _hash_sensitive_info(value) + else: + hashed_data[key] = value + return hashed_data + elif isinstance(data, list): + hashed_list = [] + for val in data: + hashed_list.append(_hash_sensitive_info(val)) + return hashed_list + else: + # TODO(https://github.com/googleapis/google-auth-library-python/issues/1701): + # Investigate and hash sensitive info before logging when the data type is + # not a dict or a list. + return str(type(data)) + + +def _hash_value(value, field_name: str) -> Optional[str]: + """Hashes a value and returns a formatted hash string.""" + if value is None: + return None + encoded_value = str(value).encode("utf-8") + hash_object = hashlib.sha512() + hash_object.update(encoded_value) + hex_digest = hash_object.hexdigest() + return f"hashed_{field_name}-{hex_digest}" + + +def _logger_configured(logger: logging.Logger) -> bool: + """Determines whether `logger` has non-default configuration + + Args: + logger: The logger to check. + + Returns: + bool: Whether the logger has any non-default configuration. + """ + return ( + logger.handlers != [] or logger.level != logging.NOTSET or not logger.propagate + ) + + +def is_logging_enabled(logger: logging.Logger) -> bool: + """ + Checks if debug logging is enabled for the given logger. + + Args: + logger: The logging.Logger instance to check. + + Returns: + True if debug logging is enabled, False otherwise. + """ + # NOTE: Log propagation to the root logger is disabled unless + # the base logger i.e. logging.getLogger("google") is + # explicitly configured by the end user. Ideally this + # needs to happen in the client layer (already does for GAPICs). + # However, this is implemented here to avoid logging + # (if a root logger is configured) when a version of google-auth + # which supports logging is used with: + # - an older version of a GAPIC which does not support logging. + # - Apiary client which does not support logging. + global _LOGGING_INITIALIZED + if not _LOGGING_INITIALIZED: + base_logger = logging.getLogger(_BASE_LOGGER_NAME) + if not _logger_configured(base_logger): + base_logger.propagate = False + _LOGGING_INITIALIZED = True + + return logger.isEnabledFor(logging.DEBUG) + + +def request_log( + logger: logging.Logger, + method: str, + url: str, + body: Optional[bytes], + headers: Optional[Mapping[str, str]], +) -> None: + """ + Logs an HTTP request at the DEBUG level if logging is enabled. + + Args: + logger: The logging.Logger instance to use. + method: The HTTP method (e.g., "GET", "POST"). + url: The URL of the request. + body: The request body (can be None). + headers: The request headers (can be None). + """ + if is_logging_enabled(logger): + content_type = ( + headers["Content-Type"] if headers and "Content-Type" in headers else "" + ) + json_body = _parse_request_body(body, content_type=content_type) + logged_body = _hash_sensitive_info(json_body) + logger.debug( + "Making request...", + extra={ + "httpRequest": { + "method": method, + "url": url, + "body": logged_body, + "headers": headers, + } + }, + ) + + +def _parse_request_body(body: Optional[bytes], content_type: str = "") -> Any: + """ + Parses a request body, handling bytes and string types, and different content types. + + Args: + body (Optional[bytes]): The request body. + content_type (str): The content type of the request body, e.g., "application/json", + "application/x-www-form-urlencoded", or "text/plain". If empty, attempts + to parse as JSON. + + Returns: + Parsed body (dict, str, or None). + - JSON: Decodes if content_type is "application/json" or None (fallback). + - URL-encoded: Parses if content_type is "application/x-www-form-urlencoded". + - Plain text: Returns string if content_type is "text/plain". + - None: Returns if body is None, UTF-8 decode fails, or content_type is unknown. + """ + if body is None: + return None + try: + body_str = body.decode("utf-8") + except (UnicodeDecodeError, AttributeError): + return None + content_type = content_type.lower() + if not content_type or "application/json" in content_type: + try: + return json.loads(body_str) + except (json.JSONDecodeError, TypeError): + return body_str + if "application/x-www-form-urlencoded" in content_type: + parsed_query = urllib.parse.parse_qs(body_str) + result = {k: v[0] for k, v in parsed_query.items()} + return result + if "text/plain" in content_type: + return body_str + return None + + +def _parse_response(response: Any) -> Any: + """ + Parses a response, attempting to decode JSON. + + Args: + response: The response object to parse. This can be any type, but + it is expected to have a `json()` method if it contains JSON. + + Returns: + The parsed response. If the response contains valid JSON, the + decoded JSON object (e.g., a dictionary or list) is returned. + If the response does not have a `json()` method or if the JSON + decoding fails, None is returned. + """ + try: + json_response = response.json() + return json_response + except Exception: + # TODO(https://github.com/googleapis/google-auth-library-python/issues/1744): + # Parse and return response payload as json based on different content types. + return None + + +def _response_log_base(logger: logging.Logger, parsed_response: Any) -> None: + """ + Logs a parsed HTTP response at the DEBUG level. + + This internal helper function takes a parsed response and logs it + using the provided logger. It also applies a hashing function to + potentially sensitive information before logging. + + Args: + logger: The logging.Logger instance to use for logging. + parsed_response: The parsed HTTP response object (e.g., a dictionary, + list, or the original response if parsing failed). + """ + + logged_response = _hash_sensitive_info(parsed_response) + logger.debug("Response received...", extra={"httpResponse": logged_response}) + + +def response_log(logger: logging.Logger, response: Any) -> None: + """ + Logs an HTTP response at the DEBUG level if logging is enabled. + + Args: + logger: The logging.Logger instance to use. + response: The HTTP response object to log. + """ + if is_logging_enabled(logger): + json_response = _parse_response(response) + _response_log_base(logger, json_response) diff --git a/lib/python3.10/site-packages/google/auth/_jwt_async.py b/lib/python3.10/site-packages/google/auth/_jwt_async.py new file mode 100644 index 0000000000000000000000000000000000000000..3a1abc5b85c91e4ff754fa79156fcb67217182e3 --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/_jwt_async.py @@ -0,0 +1,164 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""JSON Web Tokens + +Provides support for creating (encoding) and verifying (decoding) JWTs, +especially JWTs generated and consumed by Google infrastructure. + +See `rfc7519`_ for more details on JWTs. + +To encode a JWT use :func:`encode`:: + + from google.auth import crypt + from google.auth import jwt_async + + signer = crypt.Signer(private_key) + payload = {'some': 'payload'} + encoded = jwt_async.encode(signer, payload) + +To decode a JWT and verify claims use :func:`decode`:: + + claims = jwt_async.decode(encoded, certs=public_certs) + +You can also skip verification:: + + claims = jwt_async.decode(encoded, verify=False) + +.. _rfc7519: https://tools.ietf.org/html/rfc7519 + + +NOTE: This async support is experimental and marked internal. This surface may +change in minor releases. +""" + +from google.auth import _credentials_async +from google.auth import jwt + + +def encode(signer, payload, header=None, key_id=None): + """Make a signed JWT. + + Args: + signer (google.auth.crypt.Signer): The signer used to sign the JWT. + payload (Mapping[str, str]): The JWT payload. + header (Mapping[str, str]): Additional JWT header payload. + key_id (str): The key id to add to the JWT header. If the + signer has a key id it will be used as the default. If this is + specified it will override the signer's key id. + + Returns: + bytes: The encoded JWT. + """ + return jwt.encode(signer, payload, header, key_id) + + +def decode(token, certs=None, verify=True, audience=None): + """Decode and verify a JWT. + + Args: + token (str): The encoded JWT. + certs (Union[str, bytes, Mapping[str, Union[str, bytes]]]): The + certificate used to validate the JWT signature. If bytes or string, + it must the the public key certificate in PEM format. If a mapping, + it must be a mapping of key IDs to public key certificates in PEM + format. The mapping must contain the same key ID that's specified + in the token's header. + verify (bool): Whether to perform signature and claim validation. + Verification is done by default. + audience (str): The audience claim, 'aud', that this JWT should + contain. If None then the JWT's 'aud' parameter is not verified. + + Returns: + Mapping[str, str]: The deserialized JSON payload in the JWT. + + Raises: + ValueError: if any verification checks failed. + """ + + return jwt.decode(token, certs, verify, audience) + + +class Credentials( + jwt.Credentials, _credentials_async.Signing, _credentials_async.Credentials +): + """Credentials that use a JWT as the bearer token. + + These credentials require an "audience" claim. This claim identifies the + intended recipient of the bearer token. + + The constructor arguments determine the claims for the JWT that is + sent with requests. Usually, you'll construct these credentials with + one of the helper constructors as shown in the next section. + + To create JWT credentials using a Google service account private key + JSON file:: + + audience = 'https://pubsub.googleapis.com/google.pubsub.v1.Publisher' + credentials = jwt_async.Credentials.from_service_account_file( + 'service-account.json', + audience=audience) + + If you already have the service account file loaded and parsed:: + + service_account_info = json.load(open('service_account.json')) + credentials = jwt_async.Credentials.from_service_account_info( + service_account_info, + audience=audience) + + Both helper methods pass on arguments to the constructor, so you can + specify the JWT claims:: + + credentials = jwt_async.Credentials.from_service_account_file( + 'service-account.json', + audience=audience, + additional_claims={'meta': 'data'}) + + You can also construct the credentials directly if you have a + :class:`~google.auth.crypt.Signer` instance:: + + credentials = jwt_async.Credentials( + signer, + issuer='your-issuer', + subject='your-subject', + audience=audience) + + The claims are considered immutable. If you want to modify the claims, + you can easily create another instance using :meth:`with_claims`:: + + new_audience = ( + 'https://pubsub.googleapis.com/google.pubsub.v1.Subscriber') + new_credentials = credentials.with_claims(audience=new_audience) + """ + + +class OnDemandCredentials( + jwt.OnDemandCredentials, _credentials_async.Signing, _credentials_async.Credentials +): + """On-demand JWT credentials. + + Like :class:`Credentials`, this class uses a JWT as the bearer token for + authentication. However, this class does not require the audience at + construction time. Instead, it will generate a new token on-demand for + each request using the request URI as the audience. It caches tokens + so that multiple requests to the same URI do not incur the overhead + of generating a new token every time. + + This behavior is especially useful for `gRPC`_ clients. A gRPC service may + have multiple audience and gRPC clients may not know all of the audiences + required for accessing a particular service. With these credentials, + no knowledge of the audiences is required ahead of time. + + .. _grpc: http://www.grpc.io/ + """ diff --git a/lib/python3.10/site-packages/google/auth/_oauth2client.py b/lib/python3.10/site-packages/google/auth/_oauth2client.py new file mode 100644 index 0000000000000000000000000000000000000000..8b83ff23c1aebdaa24556bd90a56350ee185f2a8 --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/_oauth2client.py @@ -0,0 +1,167 @@ +# Copyright 2016 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for transitioning from oauth2client to google-auth. + +.. warning:: + This module is private as it is intended to assist first-party downstream + clients with the transition from oauth2client to google-auth. +""" + +from __future__ import absolute_import + +from google.auth import _helpers +import google.auth.app_engine +import google.auth.compute_engine +import google.oauth2.credentials +import google.oauth2.service_account + +try: + import oauth2client.client # type: ignore + import oauth2client.contrib.gce # type: ignore + import oauth2client.service_account # type: ignore +except ImportError as caught_exc: + raise ImportError("oauth2client is not installed.") from caught_exc + +try: + import oauth2client.contrib.appengine # type: ignore + + _HAS_APPENGINE = True +except ImportError: + _HAS_APPENGINE = False + + +_CONVERT_ERROR_TMPL = "Unable to convert {} to a google-auth credentials class." + + +def _convert_oauth2_credentials(credentials): + """Converts to :class:`google.oauth2.credentials.Credentials`. + + Args: + credentials (Union[oauth2client.client.OAuth2Credentials, + oauth2client.client.GoogleCredentials]): The credentials to + convert. + + Returns: + google.oauth2.credentials.Credentials: The converted credentials. + """ + new_credentials = google.oauth2.credentials.Credentials( + token=credentials.access_token, + refresh_token=credentials.refresh_token, + token_uri=credentials.token_uri, + client_id=credentials.client_id, + client_secret=credentials.client_secret, + scopes=credentials.scopes, + ) + + new_credentials._expires = credentials.token_expiry + + return new_credentials + + +def _convert_service_account_credentials(credentials): + """Converts to :class:`google.oauth2.service_account.Credentials`. + + Args: + credentials (Union[ + oauth2client.service_account.ServiceAccountCredentials, + oauth2client.service_account._JWTAccessCredentials]): The + credentials to convert. + + Returns: + google.oauth2.service_account.Credentials: The converted credentials. + """ + info = credentials.serialization_data.copy() + info["token_uri"] = credentials.token_uri + return google.oauth2.service_account.Credentials.from_service_account_info(info) + + +def _convert_gce_app_assertion_credentials(credentials): + """Converts to :class:`google.auth.compute_engine.Credentials`. + + Args: + credentials (oauth2client.contrib.gce.AppAssertionCredentials): The + credentials to convert. + + Returns: + google.oauth2.service_account.Credentials: The converted credentials. + """ + return google.auth.compute_engine.Credentials( + service_account_email=credentials.service_account_email + ) + + +def _convert_appengine_app_assertion_credentials(credentials): + """Converts to :class:`google.auth.app_engine.Credentials`. + + Args: + credentials (oauth2client.contrib.app_engine.AppAssertionCredentials): + The credentials to convert. + + Returns: + google.oauth2.service_account.Credentials: The converted credentials. + """ + # pylint: disable=invalid-name + return google.auth.app_engine.Credentials( + scopes=_helpers.string_to_scopes(credentials.scope), + service_account_id=credentials.service_account_id, + ) + + +_CLASS_CONVERSION_MAP = { + oauth2client.client.OAuth2Credentials: _convert_oauth2_credentials, + oauth2client.client.GoogleCredentials: _convert_oauth2_credentials, + oauth2client.service_account.ServiceAccountCredentials: _convert_service_account_credentials, + oauth2client.service_account._JWTAccessCredentials: _convert_service_account_credentials, + oauth2client.contrib.gce.AppAssertionCredentials: _convert_gce_app_assertion_credentials, +} + +if _HAS_APPENGINE: + _CLASS_CONVERSION_MAP[ + oauth2client.contrib.appengine.AppAssertionCredentials + ] = _convert_appengine_app_assertion_credentials + + +def convert(credentials): + """Convert oauth2client credentials to google-auth credentials. + + This class converts: + + - :class:`oauth2client.client.OAuth2Credentials` to + :class:`google.oauth2.credentials.Credentials`. + - :class:`oauth2client.client.GoogleCredentials` to + :class:`google.oauth2.credentials.Credentials`. + - :class:`oauth2client.service_account.ServiceAccountCredentials` to + :class:`google.oauth2.service_account.Credentials`. + - :class:`oauth2client.service_account._JWTAccessCredentials` to + :class:`google.oauth2.service_account.Credentials`. + - :class:`oauth2client.contrib.gce.AppAssertionCredentials` to + :class:`google.auth.compute_engine.Credentials`. + - :class:`oauth2client.contrib.appengine.AppAssertionCredentials` to + :class:`google.auth.app_engine.Credentials`. + + Returns: + google.auth.credentials.Credentials: The converted credentials. + + Raises: + ValueError: If the credentials could not be converted. + """ + + credentials_class = type(credentials) + + try: + return _CLASS_CONVERSION_MAP[credentials_class](credentials) + except KeyError as caught_exc: + new_exc = ValueError(_CONVERT_ERROR_TMPL.format(credentials_class)) + raise new_exc from caught_exc diff --git a/lib/python3.10/site-packages/google/auth/_refresh_worker.py b/lib/python3.10/site-packages/google/auth/_refresh_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..674032d8496b58b4699d9f49a099bf3b57c456ba --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/_refresh_worker.py @@ -0,0 +1,109 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import logging +import threading + +import google.auth.exceptions as e + +_LOGGER = logging.getLogger(__name__) + + +class RefreshThreadManager: + """ + Organizes exactly one background job that refresh a token. + """ + + def __init__(self): + """Initializes the manager.""" + + self._worker = None + self._lock = threading.Lock() # protects access to worker threads. + + def start_refresh(self, cred, request): + """Starts a refresh thread for the given credentials. + The credentials are refreshed using the request parameter. + request and cred MUST not be None + + Returns True if a background refresh was kicked off. False otherwise. + + Args: + cred: A credentials object. + request: A request object. + Returns: + bool + """ + if cred is None or request is None: + raise e.InvalidValue( + "Unable to start refresh. cred and request must be valid and instantiated objects." + ) + + with self._lock: + if self._worker is not None and self._worker._error_info is not None: + return False + + if self._worker is None or not self._worker.is_alive(): # pragma: NO COVER + self._worker = RefreshThread(cred=cred, request=copy.deepcopy(request)) + self._worker.start() + return True + + def clear_error(self): + """ + Removes any errors that were stored from previous background refreshes. + """ + with self._lock: + if self._worker: + self._worker._error_info = None + + def __getstate__(self): + """Pickle helper that serializes the _lock attribute.""" + state = self.__dict__.copy() + state["_lock"] = None + return state + + def __setstate__(self, state): + """Pickle helper that deserializes the _lock attribute.""" + state["_lock"] = threading.Lock() + self.__dict__.update(state) + + +class RefreshThread(threading.Thread): + """ + Thread that refreshes credentials. + """ + + def __init__(self, cred, request, **kwargs): + """Initializes the thread. + + Args: + cred: A Credential object to refresh. + request: A Request object used to perform a credential refresh. + **kwargs: Additional keyword arguments. + """ + + super().__init__(**kwargs) + self._cred = cred + self._request = request + self._error_info = None + + def run(self): + """ + Perform the credential refresh. + """ + try: + self._cred.refresh(self._request) + except Exception as err: # pragma: NO COVER + _LOGGER.error(f"Background refresh failed due to: {err}") + self._error_info = err diff --git a/lib/python3.10/site-packages/google/auth/_service_account_info.py b/lib/python3.10/site-packages/google/auth/_service_account_info.py new file mode 100644 index 0000000000000000000000000000000000000000..6b64adcaebb06e27aa5ea4eab0f38f4f311f3ae7 --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/_service_account_info.py @@ -0,0 +1,80 @@ +# Copyright 2016 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helper functions for loading data from a Google service account file.""" + +import io +import json + +from google.auth import crypt +from google.auth import exceptions + + +def from_dict(data, require=None, use_rsa_signer=True): + """Validates a dictionary containing Google service account data. + + Creates and returns a :class:`google.auth.crypt.Signer` instance from the + private key specified in the data. + + Args: + data (Mapping[str, str]): The service account data + require (Sequence[str]): List of keys required to be present in the + info. + use_rsa_signer (Optional[bool]): Whether to use RSA signer or EC signer. + We use RSA signer by default. + + Returns: + google.auth.crypt.Signer: A signer created from the private key in the + service account file. + + Raises: + MalformedError: if the data was in the wrong format, or if one of the + required keys is missing. + """ + keys_needed = set(require if require is not None else []) + + missing = keys_needed.difference(data.keys()) + + if missing: + raise exceptions.MalformedError( + "Service account info was not in the expected format, missing " + "fields {}.".format(", ".join(missing)) + ) + + # Create a signer. + if use_rsa_signer: + signer = crypt.RSASigner.from_service_account_info(data) + else: + signer = crypt.ES256Signer.from_service_account_info(data) + + return signer + + +def from_filename(filename, require=None, use_rsa_signer=True): + """Reads a Google service account JSON file and returns its parsed info. + + Args: + filename (str): The path to the service account .json file. + require (Sequence[str]): List of keys required to be present in the + info. + use_rsa_signer (Optional[bool]): Whether to use RSA signer or EC signer. + We use RSA signer by default. + + Returns: + Tuple[ Mapping[str, str], google.auth.crypt.Signer ]: The verified + info and a signer instance. + """ + with io.open(filename, "r", encoding="utf-8") as json_file: + data = json.load(json_file) + return data, from_dict(data, require=require, use_rsa_signer=use_rsa_signer) diff --git a/lib/python3.10/site-packages/google/auth/aio/__init__.py b/lib/python3.10/site-packages/google/auth/aio/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..331708cba62c41e0cddc2e7c96f377200ee4bc1f --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/aio/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Google Auth AIO Library for Python.""" + +import logging + +from google.auth import version as google_auth_version + + +__version__ = google_auth_version.__version__ + +# Set default logging handler to avoid "No handler found" warnings. +logging.getLogger(__name__).addHandler(logging.NullHandler()) diff --git a/lib/python3.10/site-packages/google/auth/aio/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/google/auth/aio/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbe306f6aa5c380e76297d85f8543eb38439ac87 Binary files /dev/null and b/lib/python3.10/site-packages/google/auth/aio/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/auth/aio/__pycache__/_helpers.cpython-310.pyc b/lib/python3.10/site-packages/google/auth/aio/__pycache__/_helpers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5103e2d8d3e123df2d632d2eba5d3f1353031f0f Binary files /dev/null and b/lib/python3.10/site-packages/google/auth/aio/__pycache__/_helpers.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/auth/aio/__pycache__/credentials.cpython-310.pyc b/lib/python3.10/site-packages/google/auth/aio/__pycache__/credentials.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf3ed26734847df6777ba0ba1b7e5b53a9f22b7d Binary files /dev/null and b/lib/python3.10/site-packages/google/auth/aio/__pycache__/credentials.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/auth/aio/_helpers.py b/lib/python3.10/site-packages/google/auth/aio/_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..1d5a4618cb61b1f410a9c9294e99a916445d180c --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/aio/_helpers.py @@ -0,0 +1,62 @@ +# Copyright 2025 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helper functions for commonly used utilities.""" + + +import logging +from typing import Any + +from google.auth import _helpers + + +async def _parse_response_async(response: Any) -> Any: + """ + Parses an async response, attempting to decode JSON. + + Args: + response: The response object to parse. This can be any type, but + it is expected to have a `json()` method if it contains JSON. + + Returns: + The parsed response. If the response contains valid JSON, the + decoded JSON object (e.g., a dictionary) is returned. + If the response does not have a `json()` method or if the JSON + decoding fails, None is returned. + """ + try: + json_response = await response.json() + return json_response + except Exception: + # TODO(https://github.com/googleapis/google-auth-library-python/issues/1745): + # Parse and return response payload as json based on different content types. + return None + + +async def response_log_async(logger: logging.Logger, response: Any) -> None: + """ + Logs an Async HTTP response at the DEBUG level if logging is enabled. + + Args: + logger: The logging.Logger instance to use. + response: The HTTP response object to log. + """ + if _helpers.is_logging_enabled(logger): + # TODO(https://github.com/googleapis/google-auth-library-python/issues/1755): + # Parsing the response for async streaming logging results in + # the stream to be empty downstream. For now, we will not be logging + # the response for async responses until we investigate further. + # json_response = await _parse_response_async(response) + json_response = None + _helpers._response_log_base(logger, json_response) diff --git a/lib/python3.10/site-packages/google/auth/aio/credentials.py b/lib/python3.10/site-packages/google/auth/aio/credentials.py new file mode 100644 index 0000000000000000000000000000000000000000..3bc6a5a6762a5064b4c5767ad0ab049819515b2d --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/aio/credentials.py @@ -0,0 +1,143 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Interfaces for asynchronous credentials.""" + + +from google.auth import _helpers +from google.auth import exceptions +from google.auth._credentials_base import _BaseCredentials + + +class Credentials(_BaseCredentials): + """Base class for all asynchronous credentials. + + All credentials have a :attr:`token` that is used for authentication and + may also optionally set an :attr:`expiry` to indicate when the token will + no longer be valid. + + Most credentials will be :attr:`invalid` until :meth:`refresh` is called. + Credentials can do this automatically before the first HTTP request in + :meth:`before_request`. + + Although the token and expiration will change as the credentials are + :meth:`refreshed ` and used, credentials should be considered + immutable. Various credentials will accept configuration such as private + keys, scopes, and other options. These options are not changeable after + construction. Some classes will provide mechanisms to copy the credentials + with modifications such as :meth:`ScopedCredentials.with_scopes`. + """ + + def __init__(self): + super(Credentials, self).__init__() + + async def apply(self, headers, token=None): + """Apply the token to the authentication header. + + Args: + headers (Mapping): The HTTP request headers. + token (Optional[str]): If specified, overrides the current access + token. + """ + self._apply(headers, token=token) + + async def refresh(self, request): + """Refreshes the access token. + + Args: + request (google.auth.aio.transport.Request): The object used to make + HTTP requests. + + Raises: + google.auth.exceptions.RefreshError: If the credentials could + not be refreshed. + """ + raise NotImplementedError("Refresh must be implemented") + + async def before_request(self, request, method, url, headers): + """Performs credential-specific before request logic. + + Refreshes the credentials if necessary, then calls :meth:`apply` to + apply the token to the authentication header. + + Args: + request (google.auth.aio.transport.Request): The object used to make + HTTP requests. + method (str): The request's HTTP method or the RPC method being + invoked. + url (str): The request's URI or the RPC service's URI. + headers (Mapping): The request's headers. + """ + await self.apply(headers) + + +class StaticCredentials(Credentials): + """Asynchronous Credentials representing an immutable access token. + + The credentials are considered immutable except the tokens which can be + configured in the constructor :: + + credentials = StaticCredentials(token="token123") + + StaticCredentials does not support :meth `refresh` and assumes that the configured + token is valid and not expired. StaticCredentials will never attempt to + refresh the token. + """ + + def __init__(self, token): + """ + Args: + token (str): The access token. + """ + super(StaticCredentials, self).__init__() + self.token = token + + @_helpers.copy_docstring(Credentials) + async def refresh(self, request): + raise exceptions.InvalidOperation("Static credentials cannot be refreshed.") + + # Note: before_request should never try to refresh access tokens. + # StaticCredentials intentionally does not support it. + @_helpers.copy_docstring(Credentials) + async def before_request(self, request, method, url, headers): + await self.apply(headers) + + +class AnonymousCredentials(Credentials): + """Asynchronous Credentials that do not provide any authentication information. + + These are useful in the case of services that support anonymous access or + local service emulators that do not use credentials. + """ + + async def refresh(self, request): + """Raises :class:``InvalidOperation``, anonymous credentials cannot be + refreshed.""" + raise exceptions.InvalidOperation("Anonymous credentials cannot be refreshed.") + + async def apply(self, headers, token=None): + """Anonymous credentials do nothing to the request. + + The optional ``token`` argument is not supported. + + Raises: + google.auth.exceptions.InvalidValue: If a token was specified. + """ + if token is not None: + raise exceptions.InvalidValue("Anonymous credentials don't support tokens.") + + async def before_request(self, request, method, url, headers): + """Anonymous credentials do nothing to the request.""" + pass diff --git a/lib/python3.10/site-packages/google/auth/aio/transport/__init__.py b/lib/python3.10/site-packages/google/auth/aio/transport/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..166a3be50914885deff2f3295a862808982b5161 --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/aio/transport/__init__.py @@ -0,0 +1,144 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transport - Asynchronous HTTP client library support. + +:mod:`google.auth.aio` is designed to work with various asynchronous client libraries such +as aiohttp. In order to work across these libraries with different +interfaces some abstraction is needed. + +This module provides two interfaces that are implemented by transport adapters +to support HTTP libraries. :class:`Request` defines the interface expected by +:mod:`google.auth` to make asynchronous requests. :class:`Response` defines the interface +for the return value of :class:`Request`. +""" + +import abc +from typing import AsyncGenerator, Mapping, Optional + +import google.auth.transport + + +_DEFAULT_TIMEOUT_SECONDS = 180 + +DEFAULT_RETRYABLE_STATUS_CODES = google.auth.transport.DEFAULT_RETRYABLE_STATUS_CODES +"""Sequence[int]: HTTP status codes indicating a request can be retried. +""" + + +DEFAULT_MAX_RETRY_ATTEMPTS = 3 +"""int: How many times to retry a request.""" + + +class Response(metaclass=abc.ABCMeta): + """Asynchronous HTTP Response Interface.""" + + @property + @abc.abstractmethod + def status_code(self) -> int: + """ + The HTTP response status code. + + Returns: + int: The HTTP response status code. + + """ + raise NotImplementedError("status_code must be implemented.") + + @property + @abc.abstractmethod + def headers(self) -> Mapping[str, str]: + """The HTTP response headers. + + Returns: + Mapping[str, str]: The HTTP response headers. + """ + raise NotImplementedError("headers must be implemented.") + + @abc.abstractmethod + async def content(self, chunk_size: int) -> AsyncGenerator[bytes, None]: + """The raw response content. + + Args: + chunk_size (int): The size of each chunk. + + Yields: + AsyncGenerator[bytes, None]: An asynchronous generator yielding + response chunks as bytes. + """ + raise NotImplementedError("content must be implemented.") + + @abc.abstractmethod + async def read(self) -> bytes: + """Read the entire response content as bytes. + + Returns: + bytes: The entire response content. + """ + raise NotImplementedError("read must be implemented.") + + @abc.abstractmethod + async def close(self): + """Close the response after it is fully consumed to resource.""" + raise NotImplementedError("close must be implemented.") + + +class Request(metaclass=abc.ABCMeta): + """Interface for a callable that makes HTTP requests. + + Specific transport implementations should provide an implementation of + this that adapts their specific request / response API. + + .. automethod:: __call__ + """ + + @abc.abstractmethod + async def __call__( + self, + url: str, + method: str, + body: Optional[bytes], + headers: Optional[Mapping[str, str]], + timeout: float, + **kwargs + ) -> Response: + """Make an HTTP request. + + Args: + url (str): The URI to be requested. + method (str): The HTTP method to use for the request. Defaults + to 'GET'. + body (Optional[bytes]): The payload / body in HTTP request. + headers (Mapping[str, str]): Request headers. + timeout (float): The number of seconds to wait for a + response from the server. If not specified or if None, the + transport-specific default timeout will be used. + kwargs: Additional arguments passed on to the transport's + request method. + + Returns: + google.auth.aio.transport.Response: The HTTP response. + + Raises: + google.auth.exceptions.TransportError: If any exception occurred. + """ + # pylint: disable=redundant-returns-doc, missing-raises-doc + # (pylint doesn't play well with abstract docstrings.) + raise NotImplementedError("__call__ must be implemented.") + + async def close(self) -> None: + """ + Close the underlying session. + """ + raise NotImplementedError("close must be implemented.") diff --git a/lib/python3.10/site-packages/google/auth/aio/transport/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/google/auth/aio/transport/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bceb36f3a81722e48e581373ad066e46714eced Binary files /dev/null and b/lib/python3.10/site-packages/google/auth/aio/transport/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/auth/aio/transport/__pycache__/aiohttp.cpython-310.pyc b/lib/python3.10/site-packages/google/auth/aio/transport/__pycache__/aiohttp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c1c425342f23b0bf138c05f179219a16bbe958b Binary files /dev/null and b/lib/python3.10/site-packages/google/auth/aio/transport/__pycache__/aiohttp.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/auth/aio/transport/__pycache__/sessions.cpython-310.pyc b/lib/python3.10/site-packages/google/auth/aio/transport/__pycache__/sessions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99ab474686a9d9a635daa5a82afea4925aa7f0f7 Binary files /dev/null and b/lib/python3.10/site-packages/google/auth/aio/transport/__pycache__/sessions.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/auth/aio/transport/aiohttp.py b/lib/python3.10/site-packages/google/auth/aio/transport/aiohttp.py new file mode 100644 index 0000000000000000000000000000000000000000..67a19f952d242462c881b81ee4840270e05547ea --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/aio/transport/aiohttp.py @@ -0,0 +1,190 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transport adapter for Asynchronous HTTP Requests based on aiohttp. +""" + +import asyncio +import logging +from typing import AsyncGenerator, Mapping, Optional + +try: + import aiohttp # type: ignore +except ImportError as caught_exc: # pragma: NO COVER + raise ImportError( + "The aiohttp library is not installed from please install the aiohttp package to use the aiohttp transport." + ) from caught_exc + +from google.auth import _helpers +from google.auth import exceptions +from google.auth.aio import _helpers as _helpers_async +from google.auth.aio import transport + +_LOGGER = logging.getLogger(__name__) + + +class Response(transport.Response): + """ + Represents an HTTP response and its data. It is returned by ``google.auth.aio.transport.sessions.AsyncAuthorizedSession``. + + Args: + response (aiohttp.ClientResponse): An instance of aiohttp.ClientResponse. + + Attributes: + status_code (int): The HTTP status code of the response. + headers (Mapping[str, str]): The HTTP headers of the response. + """ + + def __init__(self, response: aiohttp.ClientResponse): + self._response = response + + @property + @_helpers.copy_docstring(transport.Response) + def status_code(self) -> int: + return self._response.status + + @property + @_helpers.copy_docstring(transport.Response) + def headers(self) -> Mapping[str, str]: + return {key: value for key, value in self._response.headers.items()} + + @_helpers.copy_docstring(transport.Response) + async def content(self, chunk_size: int = 1024) -> AsyncGenerator[bytes, None]: + try: + async for chunk in self._response.content.iter_chunked( + chunk_size + ): # pragma: no branch + yield chunk + except aiohttp.ClientPayloadError as exc: + raise exceptions.ResponseError( + "Failed to read from the payload stream." + ) from exc + + @_helpers.copy_docstring(transport.Response) + async def read(self) -> bytes: + try: + return await self._response.read() + except aiohttp.ClientResponseError as exc: + raise exceptions.ResponseError("Failed to read the response body.") from exc + + @_helpers.copy_docstring(transport.Response) + async def close(self): + self._response.close() + + +class Request(transport.Request): + """Asynchronous Requests request adapter. + + This class is used internally for making requests using aiohttp + in a consistent way. If you use :class:`google.auth.aio.transport.sessions.AsyncAuthorizedSession` + you do not need to construct or use this class directly. + + This class can be useful if you want to configure a Request callable + with a custom ``aiohttp.ClientSession`` in :class:`AuthorizedSession` or if + you want to manually refresh a :class:`~google.auth.aio.credentials.Credentials` instance:: + + import aiohttp + import google.auth.aio.transport.aiohttp + + # Default example: + request = google.auth.aio.transport.aiohttp.Request() + await credentials.refresh(request) + + # Custom aiohttp Session Example: + session = session=aiohttp.ClientSession(auto_decompress=False) + request = google.auth.aio.transport.aiohttp.Request(session=session) + auth_sesion = google.auth.aio.transport.sessions.AsyncAuthorizedSession(auth_request=request) + + Args: + session (aiohttp.ClientSession): An instance :class:`aiohttp.ClientSession` used + to make HTTP requests. If not specified, a session will be created. + + .. automethod:: __call__ + """ + + def __init__(self, session: aiohttp.ClientSession = None): + self._session = session + self._closed = False + + async def __call__( + self, + url: str, + method: str = "GET", + body: Optional[bytes] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: float = transport._DEFAULT_TIMEOUT_SECONDS, + **kwargs, + ) -> transport.Response: + """ + Make an HTTP request using aiohttp. + + Args: + url (str): The URL to be requested. + method (Optional[str]): + The HTTP method to use for the request. Defaults to 'GET'. + body (Optional[bytes]): + The payload or body in HTTP request. + headers (Optional[Mapping[str, str]]): + Request headers. + timeout (float): The number of seconds to wait for a + response from the server. If not specified or if None, the + requests default timeout will be used. + kwargs: Additional arguments passed through to the underlying + aiohttp :meth:`aiohttp.Session.request` method. + + Returns: + google.auth.aio.transport.Response: The HTTP response. + + Raises: + - google.auth.exceptions.TransportError: If the request fails or if the session is closed. + - google.auth.exceptions.TimeoutError: If the request times out. + """ + + try: + if self._closed: + raise exceptions.TransportError("session is closed.") + + if not self._session: + self._session = aiohttp.ClientSession() + + client_timeout = aiohttp.ClientTimeout(total=timeout) + _helpers.request_log(_LOGGER, method, url, body, headers) + response = await self._session.request( + method, + url, + data=body, + headers=headers, + timeout=client_timeout, + **kwargs, + ) + await _helpers_async.response_log_async(_LOGGER, response) + return Response(response) + + except aiohttp.ClientError as caught_exc: + client_exc = exceptions.TransportError(f"Failed to send request to {url}.") + raise client_exc from caught_exc + + except asyncio.TimeoutError as caught_exc: + timeout_exc = exceptions.TimeoutError( + f"Request timed out after {timeout} seconds." + ) + raise timeout_exc from caught_exc + + async def close(self) -> None: + """ + Close the underlying aiohttp session to release the acquired resources. + """ + if not self._closed and self._session: + await self._session.close() + self._closed = True diff --git a/lib/python3.10/site-packages/google/auth/aio/transport/sessions.py b/lib/python3.10/site-packages/google/auth/aio/transport/sessions.py new file mode 100644 index 0000000000000000000000000000000000000000..fea7cbbb2c3dd8089841a6f87a139cc1a04003f1 --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/aio/transport/sessions.py @@ -0,0 +1,268 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from contextlib import asynccontextmanager +import functools +import time +from typing import Mapping, Optional + +from google.auth import _exponential_backoff, exceptions +from google.auth.aio import transport +from google.auth.aio.credentials import Credentials +from google.auth.exceptions import TimeoutError + +try: + from google.auth.aio.transport.aiohttp import Request as AiohttpRequest + + AIOHTTP_INSTALLED = True +except ImportError: # pragma: NO COVER + AIOHTTP_INSTALLED = False + + +@asynccontextmanager +async def timeout_guard(timeout): + """ + timeout_guard is an asynchronous context manager to apply a timeout to an asynchronous block of code. + + Args: + timeout (float): The time in seconds before the context manager times out. + + Raises: + google.auth.exceptions.TimeoutError: If the code within the context exceeds the provided timeout. + + Usage: + async with timeout_guard(10) as with_timeout: + await with_timeout(async_function()) + """ + start = time.monotonic() + total_timeout = timeout + + def _remaining_time(): + elapsed = time.monotonic() - start + remaining = total_timeout - elapsed + if remaining <= 0: + raise TimeoutError( + f"Context manager exceeded the configured timeout of {total_timeout}s." + ) + return remaining + + async def with_timeout(coro): + try: + remaining = _remaining_time() + response = await asyncio.wait_for(coro, remaining) + return response + except (asyncio.TimeoutError, TimeoutError) as e: + raise TimeoutError( + f"The operation {coro} exceeded the configured timeout of {total_timeout}s." + ) from e + + try: + yield with_timeout + + finally: + _remaining_time() + + +class AsyncAuthorizedSession: + """This is an asynchronous implementation of :class:`google.auth.requests.AuthorizedSession` class. + We utilize an instance of a class that implements :class:`google.auth.aio.transport.Request` configured + by the caller or otherwise default to `google.auth.aio.transport.aiohttp.Request` if the external aiohttp + package is installed. + + A Requests Session class with credentials. + + This class is used to perform asynchronous requests to API endpoints that require + authorization:: + + import aiohttp + from google.auth.aio.transport import sessions + + async with sessions.AsyncAuthorizedSession(credentials) as authed_session: + response = await authed_session.request( + 'GET', 'https://www.googleapis.com/storage/v1/b') + + The underlying :meth:`request` implementation handles adding the + credentials' headers to the request and refreshing credentials as needed. + + Args: + credentials (google.auth.aio.credentials.Credentials): + The credentials to add to the request. + auth_request (Optional[google.auth.aio.transport.Request]): + An instance of a class that implements + :class:`~google.auth.aio.transport.Request` used to make requests + and refresh credentials. If not passed, + an instance of :class:`~google.auth.aio.transport.aiohttp.Request` + is created. + + Raises: + - google.auth.exceptions.TransportError: If `auth_request` is `None` + and the external package `aiohttp` is not installed. + - google.auth.exceptions.InvalidType: If the provided credentials are + not of type `google.auth.aio.credentials.Credentials`. + """ + + def __init__( + self, credentials: Credentials, auth_request: Optional[transport.Request] = None + ): + if not isinstance(credentials, Credentials): + raise exceptions.InvalidType( + f"The configured credentials of type {type(credentials)} are invalid and must be of type `google.auth.aio.credentials.Credentials`" + ) + self._credentials = credentials + _auth_request = auth_request + if not _auth_request and AIOHTTP_INSTALLED: + _auth_request = AiohttpRequest() + if _auth_request is None: + raise exceptions.TransportError( + "`auth_request` must either be configured or the external package `aiohttp` must be installed to use the default value." + ) + self._auth_request = _auth_request + + async def request( + self, + method: str, + url: str, + data: Optional[bytes] = None, + headers: Optional[Mapping[str, str]] = None, + max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS, + timeout: float = transport._DEFAULT_TIMEOUT_SECONDS, + **kwargs, + ) -> transport.Response: + """ + Args: + method (str): The http method used to make the request. + url (str): The URI to be requested. + data (Optional[bytes]): The payload or body in HTTP request. + headers (Optional[Mapping[str, str]]): Request headers. + timeout (float): + The amount of time in seconds to wait for the server response + with each individual request. + max_allowed_time (float): + If the method runs longer than this, a ``Timeout`` exception is + automatically raised. Unlike the ``timeout`` parameter, this + value applies to the total method execution time, even if + multiple requests are made under the hood. + + Mind that it is not guaranteed that the timeout error is raised + at ``max_allowed_time``. It might take longer, for example, if + an underlying request takes a lot of time, but the request + itself does not timeout, e.g. if a large file is being + transmitted. The timout error will be raised after such + request completes. + + Returns: + google.auth.aio.transport.Response: The HTTP response. + + Raises: + google.auth.exceptions.TimeoutError: If the method does not complete within + the configured `max_allowed_time` or the request exceeds the configured + `timeout`. + """ + + retries = _exponential_backoff.AsyncExponentialBackoff( + total_attempts=transport.DEFAULT_MAX_RETRY_ATTEMPTS + ) + async with timeout_guard(max_allowed_time) as with_timeout: + await with_timeout( + # Note: before_request will attempt to refresh credentials if expired. + self._credentials.before_request( + self._auth_request, method, url, headers + ) + ) + # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` + # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 + async for _ in retries: # pragma: no branch + response = await with_timeout( + self._auth_request(url, method, data, headers, timeout, **kwargs) + ) + if response.status_code not in transport.DEFAULT_RETRYABLE_STATUS_CODES: + break + return response + + @functools.wraps(request) + async def get( + self, + url: str, + data: Optional[bytes] = None, + headers: Optional[Mapping[str, str]] = None, + max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS, + timeout: float = transport._DEFAULT_TIMEOUT_SECONDS, + **kwargs, + ) -> transport.Response: + return await self.request( + "GET", url, data, headers, max_allowed_time, timeout, **kwargs + ) + + @functools.wraps(request) + async def post( + self, + url: str, + data: Optional[bytes] = None, + headers: Optional[Mapping[str, str]] = None, + max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS, + timeout: float = transport._DEFAULT_TIMEOUT_SECONDS, + **kwargs, + ) -> transport.Response: + return await self.request( + "POST", url, data, headers, max_allowed_time, timeout, **kwargs + ) + + @functools.wraps(request) + async def put( + self, + url: str, + data: Optional[bytes] = None, + headers: Optional[Mapping[str, str]] = None, + max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS, + timeout: float = transport._DEFAULT_TIMEOUT_SECONDS, + **kwargs, + ) -> transport.Response: + return await self.request( + "PUT", url, data, headers, max_allowed_time, timeout, **kwargs + ) + + @functools.wraps(request) + async def patch( + self, + url: str, + data: Optional[bytes] = None, + headers: Optional[Mapping[str, str]] = None, + max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS, + timeout: float = transport._DEFAULT_TIMEOUT_SECONDS, + **kwargs, + ) -> transport.Response: + return await self.request( + "PATCH", url, data, headers, max_allowed_time, timeout, **kwargs + ) + + @functools.wraps(request) + async def delete( + self, + url: str, + data: Optional[bytes] = None, + headers: Optional[Mapping[str, str]] = None, + max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS, + timeout: float = transport._DEFAULT_TIMEOUT_SECONDS, + **kwargs, + ) -> transport.Response: + return await self.request( + "DELETE", url, data, headers, max_allowed_time, timeout, **kwargs + ) + + async def close(self) -> None: + """ + Close the underlying auth request session. + """ + await self._auth_request.close() diff --git a/lib/python3.10/site-packages/google/auth/api_key.py b/lib/python3.10/site-packages/google/auth/api_key.py new file mode 100644 index 0000000000000000000000000000000000000000..4fdf7f2769ca8d66f94e27d384c72018c5018e71 --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/api_key.py @@ -0,0 +1,76 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Google API key support. +This module provides authentication using the `API key`_. +.. _API key: + https://cloud.google.com/docs/authentication/api-keys/ +""" + +from google.auth import _helpers +from google.auth import credentials +from google.auth import exceptions + + +class Credentials(credentials.Credentials): + """API key credentials. + These credentials use API key to provide authorization to applications. + """ + + def __init__(self, token): + """ + Args: + token (str): API key string + Raises: + ValueError: If the provided API key is not a non-empty string. + """ + super(Credentials, self).__init__() + if not token: + raise exceptions.InvalidValue("Token must be a non-empty API key string") + self.token = token + + @property + def expired(self): + return False + + @property + def valid(self): + return True + + @_helpers.copy_docstring(credentials.Credentials) + def refresh(self, request): + return + + def apply(self, headers, token=None): + """Apply the API key token to the x-goog-api-key header. + Args: + headers (Mapping): The HTTP request headers. + token (Optional[str]): If specified, overrides the current access + token. + """ + headers["x-goog-api-key"] = token or self.token + + def before_request(self, request, method, url, headers): + """Performs credential-specific before request logic. + Refreshes the credentials if necessary, then calls :meth:`apply` to + apply the token to the x-goog-api-key header. + Args: + request (google.auth.transport.Request): The object used to make + HTTP requests. + method (str): The request's HTTP method or the RPC method being + invoked. + url (str): The request's URI or the RPC service's URI. + headers (Mapping): The request's headers. + """ + self.apply(headers) diff --git a/lib/python3.10/site-packages/google/auth/app_engine.py b/lib/python3.10/site-packages/google/auth/app_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..7083ee61436bc1da91f16c8ee5241c52a5bd04ad --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/app_engine.py @@ -0,0 +1,180 @@ +# Copyright 2016 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Google App Engine standard environment support. + +This module provides authentication and signing for applications running on App +Engine in the standard environment using the `App Identity API`_. + + +.. _App Identity API: + https://cloud.google.com/appengine/docs/python/appidentity/ +""" + +import datetime + +from google.auth import _helpers +from google.auth import credentials +from google.auth import crypt +from google.auth import exceptions + +# pytype: disable=import-error +try: + from google.appengine.api import app_identity # type: ignore +except ImportError: + app_identity = None # type: ignore +# pytype: enable=import-error + + +class Signer(crypt.Signer): + """Signs messages using the App Engine App Identity service. + + This can be used in place of :class:`google.auth.crypt.Signer` when + running in the App Engine standard environment. + """ + + @property + def key_id(self): + """Optional[str]: The key ID used to identify this private key. + + .. warning:: + This is always ``None``. The key ID used by App Engine can not + be reliably determined ahead of time. + """ + return None + + @_helpers.copy_docstring(crypt.Signer) + def sign(self, message): + message = _helpers.to_bytes(message) + _, signature = app_identity.sign_blob(message) + return signature + + +def get_project_id(): + """Gets the project ID for the current App Engine application. + + Returns: + str: The project ID + + Raises: + google.auth.exceptions.OSError: If the App Engine APIs are unavailable. + """ + # pylint: disable=missing-raises-doc + # Pylint rightfully thinks google.auth.exceptions.OSError is OSError, but doesn't + # realize it's a valid alias. + if app_identity is None: + raise exceptions.OSError("The App Engine APIs are not available.") + return app_identity.get_application_id() + + +class Credentials( + credentials.Scoped, credentials.Signing, credentials.CredentialsWithQuotaProject +): + """App Engine standard environment credentials. + + These credentials use the App Engine App Identity API to obtain access + tokens. + """ + + def __init__( + self, + scopes=None, + default_scopes=None, + service_account_id=None, + quota_project_id=None, + ): + """ + Args: + scopes (Sequence[str]): Scopes to request from the App Identity + API. + default_scopes (Sequence[str]): Default scopes passed by a + Google client library. Use 'scopes' for user-defined scopes. + service_account_id (str): The service account ID passed into + :func:`google.appengine.api.app_identity.get_access_token`. + If not specified, the default application service account + ID will be used. + quota_project_id (Optional[str]): The project ID used for quota + and billing. + + Raises: + google.auth.exceptions.OSError: If the App Engine APIs are unavailable. + """ + # pylint: disable=missing-raises-doc + # Pylint rightfully thinks google.auth.exceptions.OSError is OSError, but doesn't + # realize it's a valid alias. + if app_identity is None: + raise exceptions.OSError("The App Engine APIs are not available.") + + super(Credentials, self).__init__() + self._scopes = scopes + self._default_scopes = default_scopes + self._service_account_id = service_account_id + self._signer = Signer() + self._quota_project_id = quota_project_id + + @_helpers.copy_docstring(credentials.Credentials) + def refresh(self, request): + scopes = self._scopes if self._scopes is not None else self._default_scopes + # pylint: disable=unused-argument + token, ttl = app_identity.get_access_token(scopes, self._service_account_id) + expiry = datetime.datetime.utcfromtimestamp(ttl) + + self.token, self.expiry = token, expiry + + @property + def service_account_email(self): + """The service account email.""" + if self._service_account_id is None: + self._service_account_id = app_identity.get_service_account_name() + return self._service_account_id + + @property + def requires_scopes(self): + """Checks if the credentials requires scopes. + + Returns: + bool: True if there are no scopes set otherwise False. + """ + return not self._scopes and not self._default_scopes + + @_helpers.copy_docstring(credentials.Scoped) + def with_scopes(self, scopes, default_scopes=None): + return self.__class__( + scopes=scopes, + default_scopes=default_scopes, + service_account_id=self._service_account_id, + quota_project_id=self.quota_project_id, + ) + + @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) + def with_quota_project(self, quota_project_id): + return self.__class__( + scopes=self._scopes, + service_account_id=self._service_account_id, + quota_project_id=quota_project_id, + ) + + @_helpers.copy_docstring(credentials.Signing) + def sign_bytes(self, message): + return self._signer.sign(message) + + @property # type: ignore + @_helpers.copy_docstring(credentials.Signing) + def signer_email(self): + return self.service_account_email + + @property # type: ignore + @_helpers.copy_docstring(credentials.Signing) + def signer(self): + return self._signer diff --git a/lib/python3.10/site-packages/google/auth/aws.py b/lib/python3.10/site-packages/google/auth/aws.py new file mode 100644 index 0000000000000000000000000000000000000000..28c065d3c7511d5ea9d5909f27628ed7ca4225f2 --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/aws.py @@ -0,0 +1,861 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""AWS Credentials and AWS Signature V4 Request Signer. + +This module provides credentials to access Google Cloud resources from Amazon +Web Services (AWS) workloads. These credentials are recommended over the +use of service account credentials in AWS as they do not involve the management +of long-live service account private keys. + +AWS Credentials are initialized using external_account arguments which are +typically loaded from the external credentials JSON file. + +This module also provides a definition for an abstract AWS security credentials supplier. +This supplier can be implemented to return valid AWS security credentials and an AWS region +and used to create AWS credentials. The credentials will then call the +supplier instead of using pre-defined methods such as calling the EC2 metadata endpoints. + +This module also provides a basic implementation of the +`AWS Signature Version 4`_ request signing algorithm. + +AWS Credentials use serialized signed requests to the +`AWS STS GetCallerIdentity`_ API that can be exchanged for Google access tokens +via the GCP STS endpoint. + +.. _AWS Signature Version 4: https://docs.aws.amazon.com/general/latest/gr/signature-version-4.html +.. _AWS STS GetCallerIdentity: https://docs.aws.amazon.com/STS/latest/APIReference/API_GetCallerIdentity.html +""" + +import abc +from dataclasses import dataclass +import hashlib +import hmac +import http.client as http_client +import json +import os +import posixpath +import re +from typing import Optional +import urllib +from urllib.parse import urljoin + +from google.auth import _helpers +from google.auth import environment_vars +from google.auth import exceptions +from google.auth import external_account + +# AWS Signature Version 4 signing algorithm identifier. +_AWS_ALGORITHM = "AWS4-HMAC-SHA256" +# The termination string for the AWS credential scope value as defined in +# https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html +_AWS_REQUEST_TYPE = "aws4_request" +# The AWS authorization header name for the security session token if available. +_AWS_SECURITY_TOKEN_HEADER = "x-amz-security-token" +# The AWS authorization header name for the auto-generated date. +_AWS_DATE_HEADER = "x-amz-date" +# The default AWS regional credential verification URL. +_DEFAULT_AWS_REGIONAL_CREDENTIAL_VERIFICATION_URL = ( + "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15" +) +# IMDSV2 session token lifetime. This is set to a low value because the session token is used immediately. +_IMDSV2_SESSION_TOKEN_TTL_SECONDS = "300" + + +class RequestSigner(object): + """Implements an AWS request signer based on the AWS Signature Version 4 signing + process. + https://docs.aws.amazon.com/general/latest/gr/signature-version-4.html + """ + + def __init__(self, region_name): + """Instantiates an AWS request signer used to compute authenticated signed + requests to AWS APIs based on the AWS Signature Version 4 signing process. + + Args: + region_name (str): The AWS region to use. + """ + + self._region_name = region_name + + def get_request_options( + self, + aws_security_credentials, + url, + method, + request_payload="", + additional_headers={}, + ): + """Generates the signed request for the provided HTTP request for calling + an AWS API. This follows the steps described at: + https://docs.aws.amazon.com/general/latest/gr/sigv4_signing.html + + Args: + aws_security_credentials (AWSSecurityCredentials): The AWS security credentials. + url (str): The AWS service URL containing the canonical URI and + query string. + method (str): The HTTP method used to call this API. + request_payload (Optional[str]): The optional request payload if + available. + additional_headers (Optional[Mapping[str, str]]): The optional + additional headers needed for the requested AWS API. + + Returns: + Mapping[str, str]: The AWS signed request dictionary object. + """ + + additional_headers = additional_headers or {} + + uri = urllib.parse.urlparse(url) + # Normalize the URL path. This is needed for the canonical_uri. + # os.path.normpath can't be used since it normalizes "/" paths + # to "\\" in Windows OS. + normalized_uri = urllib.parse.urlparse( + urljoin(url, posixpath.normpath(uri.path)) + ) + # Validate provided URL. + if not uri.hostname or uri.scheme != "https": + raise exceptions.InvalidResource("Invalid AWS service URL") + + header_map = _generate_authentication_header_map( + host=uri.hostname, + canonical_uri=normalized_uri.path or "/", + canonical_querystring=_get_canonical_querystring(uri.query), + method=method, + region=self._region_name, + aws_security_credentials=aws_security_credentials, + request_payload=request_payload, + additional_headers=additional_headers, + ) + headers = { + "Authorization": header_map.get("authorization_header"), + "host": uri.hostname, + } + # Add x-amz-date if available. + if "amz_date" in header_map: + headers[_AWS_DATE_HEADER] = header_map.get("amz_date") + # Append additional optional headers, eg. X-Amz-Target, Content-Type, etc. + for key in additional_headers: + headers[key] = additional_headers[key] + + # Add session token if available. + if aws_security_credentials.session_token is not None: + headers[_AWS_SECURITY_TOKEN_HEADER] = aws_security_credentials.session_token + + signed_request = {"url": url, "method": method, "headers": headers} + if request_payload: + signed_request["data"] = request_payload + return signed_request + + +def _get_canonical_querystring(query): + """Generates the canonical query string given a raw query string. + Logic is based on + https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html + + Args: + query (str): The raw query string. + + Returns: + str: The canonical query string. + """ + # Parse raw query string. + querystring = urllib.parse.parse_qs(query) + querystring_encoded_map = {} + for key in querystring: + quote_key = urllib.parse.quote(key, safe="-_.~") + # URI encode key. + querystring_encoded_map[quote_key] = [] + for item in querystring[key]: + # For each key, URI encode all values for that key. + querystring_encoded_map[quote_key].append( + urllib.parse.quote(item, safe="-_.~") + ) + # Sort values for each key. + querystring_encoded_map[quote_key].sort() + # Sort keys. + sorted_keys = list(querystring_encoded_map.keys()) + sorted_keys.sort() + # Reconstruct the query string. Preserve keys with multiple values. + querystring_encoded_pairs = [] + for key in sorted_keys: + for item in querystring_encoded_map[key]: + querystring_encoded_pairs.append("{}={}".format(key, item)) + return "&".join(querystring_encoded_pairs) + + +def _sign(key, msg): + """Creates the HMAC-SHA256 hash of the provided message using the provided + key. + + Args: + key (str): The HMAC-SHA256 key to use. + msg (str): The message to hash. + + Returns: + str: The computed hash bytes. + """ + return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest() + + +def _get_signing_key(key, date_stamp, region_name, service_name): + """Calculates the signing key used to calculate the signature for + AWS Signature Version 4 based on: + https://docs.aws.amazon.com/general/latest/gr/sigv4-calculate-signature.html + + Args: + key (str): The AWS secret access key. + date_stamp (str): The '%Y%m%d' date format. + region_name (str): The AWS region. + service_name (str): The AWS service name, eg. sts. + + Returns: + str: The signing key bytes. + """ + k_date = _sign(("AWS4" + key).encode("utf-8"), date_stamp) + k_region = _sign(k_date, region_name) + k_service = _sign(k_region, service_name) + k_signing = _sign(k_service, "aws4_request") + return k_signing + + +def _generate_authentication_header_map( + host, + canonical_uri, + canonical_querystring, + method, + region, + aws_security_credentials, + request_payload="", + additional_headers={}, +): + """Generates the authentication header map needed for generating the AWS + Signature Version 4 signed request. + + Args: + host (str): The AWS service URL hostname. + canonical_uri (str): The AWS service URL path name. + canonical_querystring (str): The AWS service URL query string. + method (str): The HTTP method used to call this API. + region (str): The AWS region. + aws_security_credentials (AWSSecurityCredentials): The AWS security credentials. + request_payload (Optional[str]): The optional request payload if + available. + additional_headers (Optional[Mapping[str, str]]): The optional + additional headers needed for the requested AWS API. + + Returns: + Mapping[str, str]: The AWS authentication header dictionary object. + This contains the x-amz-date and authorization header information. + """ + # iam.amazonaws.com host => iam service. + # sts.us-east-2.amazonaws.com host => sts service. + service_name = host.split(".")[0] + + current_time = _helpers.utcnow() + amz_date = current_time.strftime("%Y%m%dT%H%M%SZ") + date_stamp = current_time.strftime("%Y%m%d") + + # Change all additional headers to be lower case. + full_headers = {} + for key in additional_headers: + full_headers[key.lower()] = additional_headers[key] + # Add AWS session token if available. + if aws_security_credentials.session_token is not None: + full_headers[ + _AWS_SECURITY_TOKEN_HEADER + ] = aws_security_credentials.session_token + + # Required headers + full_headers["host"] = host + # Do not use generated x-amz-date if the date header is provided. + # Previously the date was not fixed with x-amz- and could be provided + # manually. + # https://github.com/boto/botocore/blob/879f8440a4e9ace5d3cf145ce8b3d5e5ffb892ef/tests/unit/auth/aws4_testsuite/get-header-value-trim.req + if "date" not in full_headers: + full_headers[_AWS_DATE_HEADER] = amz_date + + # Header keys need to be sorted alphabetically. + canonical_headers = "" + header_keys = list(full_headers.keys()) + header_keys.sort() + for key in header_keys: + canonical_headers = "{}{}:{}\n".format( + canonical_headers, key, full_headers[key] + ) + signed_headers = ";".join(header_keys) + + payload_hash = hashlib.sha256((request_payload or "").encode("utf-8")).hexdigest() + + # https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html + canonical_request = "{}\n{}\n{}\n{}\n{}\n{}".format( + method, + canonical_uri, + canonical_querystring, + canonical_headers, + signed_headers, + payload_hash, + ) + + credential_scope = "{}/{}/{}/{}".format( + date_stamp, region, service_name, _AWS_REQUEST_TYPE + ) + + # https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html + string_to_sign = "{}\n{}\n{}\n{}".format( + _AWS_ALGORITHM, + amz_date, + credential_scope, + hashlib.sha256(canonical_request.encode("utf-8")).hexdigest(), + ) + + # https://docs.aws.amazon.com/general/latest/gr/sigv4-calculate-signature.html + signing_key = _get_signing_key( + aws_security_credentials.secret_access_key, date_stamp, region, service_name + ) + signature = hmac.new( + signing_key, string_to_sign.encode("utf-8"), hashlib.sha256 + ).hexdigest() + + # https://docs.aws.amazon.com/general/latest/gr/sigv4-add-signature-to-request.html + authorization_header = "{} Credential={}/{}, SignedHeaders={}, Signature={}".format( + _AWS_ALGORITHM, + aws_security_credentials.access_key_id, + credential_scope, + signed_headers, + signature, + ) + + authentication_header = {"authorization_header": authorization_header} + # Do not use generated x-amz-date if the date header is provided. + if "date" not in full_headers: + authentication_header["amz_date"] = amz_date + return authentication_header + + +@dataclass +class AwsSecurityCredentials: + """A class that models AWS security credentials with an optional session token. + + Attributes: + access_key_id (str): The AWS security credentials access key id. + secret_access_key (str): The AWS security credentials secret access key. + session_token (Optional[str]): The optional AWS security credentials session token. This should be set when using temporary credentials. + """ + + access_key_id: str + secret_access_key: str + session_token: Optional[str] = None + + +class AwsSecurityCredentialsSupplier(metaclass=abc.ABCMeta): + """Base class for AWS security credential suppliers. This can be implemented with custom logic to retrieve + AWS security credentials to exchange for a Google Cloud access token. The AWS external account credential does + not cache the AWS security credentials, so caching logic should be added in the implementation. + """ + + @abc.abstractmethod + def get_aws_security_credentials(self, context, request): + """Returns the AWS security credentials for the requested context. + + .. warning: This is not cached by the calling Google credential, so caching logic should be implemented in the supplier. + + Args: + context (google.auth.externalaccount.SupplierContext): The context object + containing information about the requested audience and subject token type. + request (google.auth.transport.Request): The object used to make + HTTP requests. + + Raises: + google.auth.exceptions.RefreshError: If an error is encountered during + security credential retrieval logic. + + Returns: + AwsSecurityCredentials: The requested AWS security credentials. + """ + raise NotImplementedError("") + + @abc.abstractmethod + def get_aws_region(self, context, request): + """Returns the AWS region for the requested context. + + Args: + context (google.auth.externalaccount.SupplierContext): The context object + containing information about the requested audience and subject token type. + request (google.auth.transport.Request): The object used to make + HTTP requests. + + Raises: + google.auth.exceptions.RefreshError: If an error is encountered during + region retrieval logic. + + Returns: + str: The AWS region. + """ + raise NotImplementedError("") + + +class _DefaultAwsSecurityCredentialsSupplier(AwsSecurityCredentialsSupplier): + """Default implementation of AWS security credentials supplier. Supports retrieving + credentials and region via EC2 metadata endpoints and environment variables. + """ + + def __init__(self, credential_source): + self._region_url = credential_source.get("region_url") + self._security_credentials_url = credential_source.get("url") + self._imdsv2_session_token_url = credential_source.get( + "imdsv2_session_token_url" + ) + + @_helpers.copy_docstring(AwsSecurityCredentialsSupplier) + def get_aws_security_credentials(self, context, request): + + # Check environment variables for permanent credentials first. + # https://docs.aws.amazon.com/general/latest/gr/aws-sec-cred-types.html + env_aws_access_key_id = os.environ.get(environment_vars.AWS_ACCESS_KEY_ID) + env_aws_secret_access_key = os.environ.get( + environment_vars.AWS_SECRET_ACCESS_KEY + ) + # This is normally not available for permanent credentials. + env_aws_session_token = os.environ.get(environment_vars.AWS_SESSION_TOKEN) + if env_aws_access_key_id and env_aws_secret_access_key: + return AwsSecurityCredentials( + env_aws_access_key_id, env_aws_secret_access_key, env_aws_session_token + ) + + imdsv2_session_token = self._get_imdsv2_session_token(request) + role_name = self._get_metadata_role_name(request, imdsv2_session_token) + + # Get security credentials. + credentials = self._get_metadata_security_credentials( + request, role_name, imdsv2_session_token + ) + + return AwsSecurityCredentials( + credentials.get("AccessKeyId"), + credentials.get("SecretAccessKey"), + credentials.get("Token"), + ) + + @_helpers.copy_docstring(AwsSecurityCredentialsSupplier) + def get_aws_region(self, context, request): + # The AWS metadata server is not available in some AWS environments + # such as AWS lambda. Instead, it is available via environment + # variable. + env_aws_region = os.environ.get(environment_vars.AWS_REGION) + if env_aws_region is not None: + return env_aws_region + + env_aws_region = os.environ.get(environment_vars.AWS_DEFAULT_REGION) + if env_aws_region is not None: + return env_aws_region + + if not self._region_url: + raise exceptions.RefreshError("Unable to determine AWS region") + + headers = None + imdsv2_session_token = self._get_imdsv2_session_token(request) + if imdsv2_session_token is not None: + headers = {"X-aws-ec2-metadata-token": imdsv2_session_token} + + response = request(url=self._region_url, method="GET", headers=headers) + + # Support both string and bytes type response.data. + response_body = ( + response.data.decode("utf-8") + if hasattr(response.data, "decode") + else response.data + ) + + if response.status != http_client.OK: + raise exceptions.RefreshError( + "Unable to retrieve AWS region: {}".format(response_body) + ) + + # This endpoint will return the region in format: us-east-2b. + # Only the us-east-2 part should be used. + return response_body[:-1] + + def _get_imdsv2_session_token(self, request): + if request is not None and self._imdsv2_session_token_url is not None: + headers = { + "X-aws-ec2-metadata-token-ttl-seconds": _IMDSV2_SESSION_TOKEN_TTL_SECONDS + } + + imdsv2_session_token_response = request( + url=self._imdsv2_session_token_url, method="PUT", headers=headers + ) + + if imdsv2_session_token_response.status != http_client.OK: + raise exceptions.RefreshError( + "Unable to retrieve AWS Session Token: {}".format( + imdsv2_session_token_response.data + ) + ) + + return imdsv2_session_token_response.data + else: + return None + + def _get_metadata_security_credentials( + self, request, role_name, imdsv2_session_token + ): + """Retrieves the AWS security credentials required for signing AWS + requests from the AWS metadata server. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + role_name (str): The AWS role name required by the AWS metadata + server security_credentials endpoint in order to return the + credentials. + imdsv2_session_token (str): The AWS IMDSv2 session token to be added as a + header in the requests to AWS metadata endpoint. + + Returns: + Mapping[str, str]: The AWS metadata server security credentials + response. + + Raises: + google.auth.exceptions.RefreshError: If an error occurs while + retrieving the AWS security credentials. + """ + headers = {"Content-Type": "application/json"} + if imdsv2_session_token is not None: + headers["X-aws-ec2-metadata-token"] = imdsv2_session_token + + response = request( + url="{}/{}".format(self._security_credentials_url, role_name), + method="GET", + headers=headers, + ) + + # support both string and bytes type response.data + response_body = ( + response.data.decode("utf-8") + if hasattr(response.data, "decode") + else response.data + ) + + if response.status != http_client.OK: + raise exceptions.RefreshError( + "Unable to retrieve AWS security credentials: {}".format(response_body) + ) + + credentials_response = json.loads(response_body) + + return credentials_response + + def _get_metadata_role_name(self, request, imdsv2_session_token): + """Retrieves the AWS role currently attached to the current AWS + workload by querying the AWS metadata server. This is needed for the + AWS metadata server security credentials endpoint in order to retrieve + the AWS security credentials needed to sign requests to AWS APIs. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + imdsv2_session_token (str): The AWS IMDSv2 session token to be added as a + header in the requests to AWS metadata endpoint. + + Returns: + str: The AWS role name. + + Raises: + google.auth.exceptions.RefreshError: If an error occurs while + retrieving the AWS role name. + """ + if self._security_credentials_url is None: + raise exceptions.RefreshError( + "Unable to determine the AWS metadata server security credentials endpoint" + ) + + headers = None + if imdsv2_session_token is not None: + headers = {"X-aws-ec2-metadata-token": imdsv2_session_token} + + response = request( + url=self._security_credentials_url, method="GET", headers=headers + ) + + # support both string and bytes type response.data + response_body = ( + response.data.decode("utf-8") + if hasattr(response.data, "decode") + else response.data + ) + + if response.status != http_client.OK: + raise exceptions.RefreshError( + "Unable to retrieve AWS role name {}".format(response_body) + ) + + return response_body + + +class Credentials(external_account.Credentials): + """AWS external account credentials. + This is used to exchange serialized AWS signature v4 signed requests to + AWS STS GetCallerIdentity service for Google access tokens. + """ + + def __init__( + self, + audience, + subject_token_type, + token_url=external_account._DEFAULT_TOKEN_URL, + credential_source=None, + aws_security_credentials_supplier=None, + *args, + **kwargs + ): + """Instantiates an AWS workload external account credentials object. + + Args: + audience (str): The STS audience field. + subject_token_type (str): The subject token type based on the Oauth2.0 token exchange spec. + Expected values include:: + + “urn:ietf:params:aws:token-type:aws4_request” + + token_url (Optional [str]): The STS endpoint URL. If not provided, will default to "https://sts.googleapis.com/v1/token". + credential_source (Optional [Mapping]): The credential source dictionary used + to provide instructions on how to retrieve external credential to be exchanged for Google access tokens. + Either a credential source or an AWS security credentials supplier must be provided. + + Example credential_source for AWS credential:: + + { + "environment_id": "aws1", + "regional_cred_verification_url": "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "region_url": "http://169.254.169.254/latest/meta-data/placement/availability-zone", + "url": "http://169.254.169.254/latest/meta-data/iam/security-credentials", + imdsv2_session_token_url": "http://169.254.169.254/latest/api/token" + } + + aws_security_credentials_supplier (Optional [AwsSecurityCredentialsSupplier]): Optional AWS security credentials supplier. + This will be called to supply valid AWS security credentails which will then + be exchanged for Google access tokens. Either an AWS security credentials supplier + or a credential source must be provided. + args (List): Optional positional arguments passed into the underlying :meth:`~external_account.Credentials.__init__` method. + kwargs (Mapping): Optional keyword arguments passed into the underlying :meth:`~external_account.Credentials.__init__` method. + + Raises: + google.auth.exceptions.RefreshError: If an error is encountered during + access token retrieval logic. + ValueError: For invalid parameters. + + .. note:: Typically one of the helper constructors + :meth:`from_file` or + :meth:`from_info` are used instead of calling the constructor directly. + """ + super(Credentials, self).__init__( + audience=audience, + subject_token_type=subject_token_type, + token_url=token_url, + credential_source=credential_source, + *args, + **kwargs + ) + if credential_source is None and aws_security_credentials_supplier is None: + raise exceptions.InvalidValue( + "A valid credential source or AWS security credentials supplier must be provided." + ) + if ( + credential_source is not None + and aws_security_credentials_supplier is not None + ): + raise exceptions.InvalidValue( + "AWS credential cannot have both a credential source and an AWS security credentials supplier." + ) + + if aws_security_credentials_supplier: + self._aws_security_credentials_supplier = aws_security_credentials_supplier + # The regional cred verification URL would normally be provided through the credential source. So set it to the default one here. + self._cred_verification_url = ( + _DEFAULT_AWS_REGIONAL_CREDENTIAL_VERIFICATION_URL + ) + else: + environment_id = credential_source.get("environment_id") or "" + self._aws_security_credentials_supplier = _DefaultAwsSecurityCredentialsSupplier( + credential_source + ) + self._cred_verification_url = credential_source.get( + "regional_cred_verification_url" + ) + + # Get the environment ID, i.e. "aws1". Currently, only one version supported (1). + matches = re.match(r"^(aws)([\d]+)$", environment_id) + if matches: + env_id, env_version = matches.groups() + else: + env_id, env_version = (None, None) + + if env_id != "aws" or self._cred_verification_url is None: + raise exceptions.InvalidResource( + "No valid AWS 'credential_source' provided" + ) + elif env_version is None or int(env_version) != 1: + raise exceptions.InvalidValue( + "aws version '{}' is not supported in the current build.".format( + env_version + ) + ) + + self._target_resource = audience + self._request_signer = None + + def retrieve_subject_token(self, request): + """Retrieves the subject token using the credential_source object. + The subject token is a serialized `AWS GetCallerIdentity signed request`_. + + The logic is summarized as: + + Retrieve the AWS region from the AWS_REGION or AWS_DEFAULT_REGION + environment variable or from the AWS metadata server availability-zone + if not found in the environment variable. + + Check AWS credentials in environment variables. If not found, retrieve + from the AWS metadata server security-credentials endpoint. + + When retrieving AWS credentials from the metadata server + security-credentials endpoint, the AWS role needs to be determined by + calling the security-credentials endpoint without any argument. Then the + credentials can be retrieved via: security-credentials/role_name + + Generate the signed request to AWS STS GetCallerIdentity action. + + Inject x-goog-cloud-target-resource into header and serialize the + signed request. This will be the subject-token to pass to GCP STS. + + .. _AWS GetCallerIdentity signed request: + https://cloud.google.com/iam/docs/access-resources-aws#exchange-token + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + Returns: + str: The retrieved subject token. + """ + + # Initialize the request signer if not yet initialized after determining + # the current AWS region. + if self._request_signer is None: + self._region = self._aws_security_credentials_supplier.get_aws_region( + self._supplier_context, request + ) + self._request_signer = RequestSigner(self._region) + + # Retrieve the AWS security credentials needed to generate the signed + # request. + aws_security_credentials = self._aws_security_credentials_supplier.get_aws_security_credentials( + self._supplier_context, request + ) + # Generate the signed request to AWS STS GetCallerIdentity API. + # Use the required regional endpoint. Otherwise, the request will fail. + request_options = self._request_signer.get_request_options( + aws_security_credentials, + self._cred_verification_url.replace("{region}", self._region), + "POST", + ) + # The GCP STS endpoint expects the headers to be formatted as: + # [ + # {key: 'x-amz-date', value: '...'}, + # {key: 'Authorization', value: '...'}, + # ... + # ] + # And then serialized as: + # quote(json.dumps({ + # url: '...', + # method: 'POST', + # headers: [{key: 'x-amz-date', value: '...'}, ...] + # })) + request_headers = request_options.get("headers") + # The full, canonical resource name of the workload identity pool + # provider, with or without the HTTPS prefix. + # Including this header as part of the signature is recommended to + # ensure data integrity. + request_headers["x-goog-cloud-target-resource"] = self._target_resource + + # Serialize AWS signed request. + aws_signed_req = {} + aws_signed_req["url"] = request_options.get("url") + aws_signed_req["method"] = request_options.get("method") + aws_signed_req["headers"] = [] + # Reformat header to GCP STS expected format. + for key in request_headers.keys(): + aws_signed_req["headers"].append( + {"key": key, "value": request_headers[key]} + ) + + return urllib.parse.quote( + json.dumps(aws_signed_req, separators=(",", ":"), sort_keys=True) + ) + + def _create_default_metrics_options(self): + metrics_options = super(Credentials, self)._create_default_metrics_options() + metrics_options["source"] = "aws" + if self._has_custom_supplier(): + metrics_options["source"] = "programmatic" + return metrics_options + + def _has_custom_supplier(self): + return self._credential_source is None + + def _constructor_args(self): + args = super(Credentials, self)._constructor_args() + # If a custom supplier was used, append it to the args dict. + if self._has_custom_supplier(): + args.update( + { + "aws_security_credentials_supplier": self._aws_security_credentials_supplier + } + ) + return args + + @classmethod + def from_info(cls, info, **kwargs): + """Creates an AWS Credentials instance from parsed external account info. + + Args: + info (Mapping[str, str]): The AWS external account info in Google + format. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.aws.Credentials: The constructed credentials. + + Raises: + ValueError: For invalid parameters. + """ + aws_security_credentials_supplier = info.get( + "aws_security_credentials_supplier" + ) + kwargs.update( + {"aws_security_credentials_supplier": aws_security_credentials_supplier} + ) + return super(Credentials, cls).from_info(info, **kwargs) + + @classmethod + def from_file(cls, filename, **kwargs): + """Creates an AWS Credentials instance from an external account json file. + + Args: + filename (str): The path to the AWS external account json file. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.aws.Credentials: The constructed credentials. + """ + return super(Credentials, cls).from_file(filename, **kwargs) diff --git a/lib/python3.10/site-packages/google/auth/compute_engine/_metadata.py b/lib/python3.10/site-packages/google/auth/compute_engine/_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..ddbe8ac2f7089edb274575c134010390eb147701 --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/compute_engine/_metadata.py @@ -0,0 +1,379 @@ +# Copyright 2016 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Provides helper methods for talking to the Compute Engine metadata server. + +See https://cloud.google.com/compute/docs/metadata for more details. +""" + +import datetime +import http.client as http_client +import json +import logging +import os +from urllib.parse import urljoin + +from google.auth import _helpers +from google.auth import environment_vars +from google.auth import exceptions +from google.auth import metrics +from google.auth import transport +from google.auth._exponential_backoff import ExponentialBackoff + +_LOGGER = logging.getLogger(__name__) + +# Environment variable GCE_METADATA_HOST is originally named +# GCE_METADATA_ROOT. For compatibility reasons, here it checks +# the new variable first; if not set, the system falls back +# to the old variable. +_GCE_METADATA_HOST = os.getenv(environment_vars.GCE_METADATA_HOST, None) +if not _GCE_METADATA_HOST: + _GCE_METADATA_HOST = os.getenv( + environment_vars.GCE_METADATA_ROOT, "metadata.google.internal" + ) +_METADATA_ROOT = "http://{}/computeMetadata/v1/".format(_GCE_METADATA_HOST) + +# This is used to ping the metadata server, it avoids the cost of a DNS +# lookup. +_METADATA_IP_ROOT = "http://{}".format( + os.getenv(environment_vars.GCE_METADATA_IP, "169.254.169.254") +) +_METADATA_FLAVOR_HEADER = "metadata-flavor" +_METADATA_FLAVOR_VALUE = "Google" +_METADATA_HEADERS = {_METADATA_FLAVOR_HEADER: _METADATA_FLAVOR_VALUE} + +# Timeout in seconds to wait for the GCE metadata server when detecting the +# GCE environment. +try: + _METADATA_DEFAULT_TIMEOUT = int(os.getenv("GCE_METADATA_TIMEOUT", 3)) +except ValueError: # pragma: NO COVER + _METADATA_DEFAULT_TIMEOUT = 3 + +# Detect GCE Residency +_GOOGLE = "Google" +_GCE_PRODUCT_NAME_FILE = "/sys/class/dmi/id/product_name" + + +def is_on_gce(request): + """Checks to see if the code runs on Google Compute Engine + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + + Returns: + bool: True if the code runs on Google Compute Engine, False otherwise. + """ + if ping(request): + return True + + if os.name == "nt": + # TODO: implement GCE residency detection on Windows + return False + + # Detect GCE residency on Linux + return detect_gce_residency_linux() + + +def detect_gce_residency_linux(): + """Detect Google Compute Engine residency by smbios check on Linux + + Returns: + bool: True if the GCE product name file is detected, False otherwise. + """ + try: + with open(_GCE_PRODUCT_NAME_FILE, "r") as file_obj: + content = file_obj.read().strip() + + except Exception: + return False + + return content.startswith(_GOOGLE) + + +def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3): + """Checks to see if the metadata server is available. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + timeout (int): How long to wait for the metadata server to respond. + retry_count (int): How many times to attempt connecting to metadata + server using above timeout. + + Returns: + bool: True if the metadata server is reachable, False otherwise. + """ + # NOTE: The explicit ``timeout`` is a workaround. The underlying + # issue is that resolving an unknown host on some networks will take + # 20-30 seconds; making this timeout short fixes the issue, but + # could lead to false negatives in the event that we are on GCE, but + # the metadata resolution was particularly slow. The latter case is + # "unlikely". + headers = _METADATA_HEADERS.copy() + headers[metrics.API_CLIENT_HEADER] = metrics.mds_ping() + + backoff = ExponentialBackoff(total_attempts=retry_count) + + for attempt in backoff: + try: + response = request( + url=_METADATA_IP_ROOT, method="GET", headers=headers, timeout=timeout + ) + + metadata_flavor = response.headers.get(_METADATA_FLAVOR_HEADER) + return ( + response.status == http_client.OK + and metadata_flavor == _METADATA_FLAVOR_VALUE + ) + + except exceptions.TransportError as e: + _LOGGER.warning( + "Compute Engine Metadata server unavailable on " + "attempt %s of %s. Reason: %s", + attempt, + retry_count, + e, + ) + + return False + + +def get( + request, + path, + root=_METADATA_ROOT, + params=None, + recursive=False, + retry_count=5, + headers=None, + return_none_for_not_found_error=False, + timeout=_METADATA_DEFAULT_TIMEOUT, +): + """Fetch a resource from the metadata server. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + path (str): The resource to retrieve. For example, + ``'instance/service-accounts/default'``. + root (str): The full path to the metadata server root. + params (Optional[Mapping[str, str]]): A mapping of query parameter + keys to values. + recursive (bool): Whether to do a recursive query of metadata. See + https://cloud.google.com/compute/docs/metadata#aggcontents for more + details. + retry_count (int): How many times to attempt connecting to metadata + server using above timeout. + headers (Optional[Mapping[str, str]]): Headers for the request. + return_none_for_not_found_error (Optional[bool]): If True, returns None + for 404 error instead of throwing an exception. + timeout (int): How long to wait, in seconds for the metadata server to respond. + + Returns: + Union[Mapping, str]: If the metadata server returns JSON, a mapping of + the decoded JSON is returned. Otherwise, the response content is + returned as a string. + + Raises: + google.auth.exceptions.TransportError: if an error occurred while + retrieving metadata. + """ + base_url = urljoin(root, path) + query_params = {} if params is None else params + + headers_to_use = _METADATA_HEADERS.copy() + if headers: + headers_to_use.update(headers) + + if recursive: + query_params["recursive"] = "true" + + url = _helpers.update_query(base_url, query_params) + + backoff = ExponentialBackoff(total_attempts=retry_count) + failure_reason = None + for attempt in backoff: + try: + response = request( + url=url, method="GET", headers=headers_to_use, timeout=timeout + ) + if response.status in transport.DEFAULT_RETRYABLE_STATUS_CODES: + _LOGGER.warning( + "Compute Engine Metadata server unavailable on " + "attempt %s of %s. Response status: %s", + attempt, + retry_count, + response.status, + ) + failure_reason = ( + response.data.decode("utf-8") + if hasattr(response.data, "decode") + else response.data + ) + continue + else: + break + + except exceptions.TransportError as e: + _LOGGER.warning( + "Compute Engine Metadata server unavailable on " + "attempt %s of %s. Reason: %s", + attempt, + retry_count, + e, + ) + failure_reason = e + else: + raise exceptions.TransportError( + "Failed to retrieve {} from the Google Compute Engine " + "metadata service. Compute Engine Metadata server unavailable due to {}".format( + url, failure_reason + ) + ) + + content = _helpers.from_bytes(response.data) + + if response.status == http_client.NOT_FOUND and return_none_for_not_found_error: + return None + + if response.status == http_client.OK: + if ( + _helpers.parse_content_type(response.headers["content-type"]) + == "application/json" + ): + try: + return json.loads(content) + except ValueError as caught_exc: + new_exc = exceptions.TransportError( + "Received invalid JSON from the Google Compute Engine " + "metadata service: {:.20}".format(content) + ) + raise new_exc from caught_exc + else: + return content + + raise exceptions.TransportError( + "Failed to retrieve {} from the Google Compute Engine " + "metadata service. Status: {} Response:\n{}".format( + url, response.status, response.data + ), + response, + ) + + +def get_project_id(request): + """Get the Google Cloud Project ID from the metadata server. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + + Returns: + str: The project ID + + Raises: + google.auth.exceptions.TransportError: if an error occurred while + retrieving metadata. + """ + return get(request, "project/project-id") + + +def get_universe_domain(request): + """Get the universe domain value from the metadata server. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + + Returns: + str: The universe domain value. If the universe domain endpoint is not + not found, return the default value, which is googleapis.com + + Raises: + google.auth.exceptions.TransportError: if an error other than + 404 occurs while retrieving metadata. + """ + universe_domain = get( + request, "universe/universe-domain", return_none_for_not_found_error=True + ) + if not universe_domain: + return "googleapis.com" + return universe_domain + + +def get_service_account_info(request, service_account="default"): + """Get information about a service account from the metadata server. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + service_account (str): The string 'default' or a service account email + address. The determines which service account for which to acquire + information. + + Returns: + Mapping: The service account's information, for example:: + + { + 'email': '...', + 'scopes': ['scope', ...], + 'aliases': ['default', '...'] + } + + Raises: + google.auth.exceptions.TransportError: if an error occurred while + retrieving metadata. + """ + path = "instance/service-accounts/{0}/".format(service_account) + # See https://cloud.google.com/compute/docs/metadata#aggcontents + # for more on the use of 'recursive'. + return get(request, path, params={"recursive": "true"}) + + +def get_service_account_token(request, service_account="default", scopes=None): + """Get the OAuth 2.0 access token for a service account. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + service_account (str): The string 'default' or a service account email + address. The determines which service account for which to acquire + an access token. + scopes (Optional[Union[str, List[str]]]): Optional string or list of + strings with auth scopes. + Returns: + Tuple[str, datetime]: The access token and its expiration. + + Raises: + google.auth.exceptions.TransportError: if an error occurred while + retrieving metadata. + """ + if scopes: + if not isinstance(scopes, str): + scopes = ",".join(scopes) + params = {"scopes": scopes} + else: + params = None + + metrics_header = { + metrics.API_CLIENT_HEADER: metrics.token_request_access_token_mds() + } + + path = "instance/service-accounts/{0}/token".format(service_account) + token_json = get(request, path, params=params, headers=metrics_header) + token_expiry = _helpers.utcnow() + datetime.timedelta( + seconds=token_json["expires_in"] + ) + return token_json["access_token"], token_expiry diff --git a/lib/python3.10/site-packages/google/auth/compute_engine/credentials.py b/lib/python3.10/site-packages/google/auth/compute_engine/credentials.py new file mode 100644 index 0000000000000000000000000000000000000000..eb50d288b6a07ac5f85d44565d3170cab4a4c8cd --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/compute_engine/credentials.py @@ -0,0 +1,497 @@ +# Copyright 2016 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Google Compute Engine credentials. + +This module provides authentication for an application running on Google +Compute Engine using the Compute Engine metadata server. + +""" + +import datetime + +from google.auth import _helpers +from google.auth import credentials +from google.auth import exceptions +from google.auth import iam +from google.auth import jwt +from google.auth import metrics +from google.auth.compute_engine import _metadata +from google.oauth2 import _client + + +class Credentials( + credentials.Scoped, + credentials.CredentialsWithQuotaProject, + credentials.CredentialsWithUniverseDomain, +): + """Compute Engine Credentials. + + These credentials use the Google Compute Engine metadata server to obtain + OAuth 2.0 access tokens associated with the instance's service account, + and are also used for Cloud Run, Flex and App Engine (except for the Python + 2.7 runtime, which is supported only on older versions of this library). + + For more information about Compute Engine authentication, including how + to configure scopes, see the `Compute Engine authentication + documentation`_. + + .. note:: On Compute Engine the metadata server ignores requested scopes. + On Cloud Run, Flex and App Engine the server honours requested scopes. + + .. _Compute Engine authentication documentation: + https://cloud.google.com/compute/docs/authentication#using + """ + + def __init__( + self, + service_account_email="default", + quota_project_id=None, + scopes=None, + default_scopes=None, + universe_domain=None, + ): + """ + Args: + service_account_email (str): The service account email to use, or + 'default'. A Compute Engine instance may have multiple service + accounts. + quota_project_id (Optional[str]): The project ID used for quota and + billing. + scopes (Optional[Sequence[str]]): The list of scopes for the credentials. + default_scopes (Optional[Sequence[str]]): Default scopes passed by a + Google client library. Use 'scopes' for user-defined scopes. + universe_domain (Optional[str]): The universe domain. If not + provided or None, credential will attempt to fetch the value + from metadata server. If metadata server doesn't have universe + domain endpoint, then the default googleapis.com will be used. + """ + super(Credentials, self).__init__() + self._service_account_email = service_account_email + self._quota_project_id = quota_project_id + self._scopes = scopes + self._default_scopes = default_scopes + self._universe_domain_cached = False + if universe_domain: + self._universe_domain = universe_domain + self._universe_domain_cached = True + + def _retrieve_info(self, request): + """Retrieve information about the service account. + + Updates the scopes and retrieves the full service account email. + + Args: + request (google.auth.transport.Request): The object used to make + HTTP requests. + """ + info = _metadata.get_service_account_info( + request, service_account=self._service_account_email + ) + + self._service_account_email = info["email"] + + # Don't override scopes requested by the user. + if self._scopes is None: + self._scopes = info["scopes"] + + def _metric_header_for_usage(self): + return metrics.CRED_TYPE_SA_MDS + + def refresh(self, request): + """Refresh the access token and scopes. + + Args: + request (google.auth.transport.Request): The object used to make + HTTP requests. + + Raises: + google.auth.exceptions.RefreshError: If the Compute Engine metadata + service can't be reached if if the instance has not + credentials. + """ + scopes = self._scopes if self._scopes is not None else self._default_scopes + try: + self._retrieve_info(request) + # Always fetch token with default service account email. + self.token, self.expiry = _metadata.get_service_account_token( + request, service_account="default", scopes=scopes + ) + except exceptions.TransportError as caught_exc: + new_exc = exceptions.RefreshError(caught_exc) + raise new_exc from caught_exc + + @property + def service_account_email(self): + """The service account email. + + .. note:: This is not guaranteed to be set until :meth:`refresh` has been + called. + """ + return self._service_account_email + + @property + def requires_scopes(self): + return not self._scopes + + @property + def universe_domain(self): + if self._universe_domain_cached: + return self._universe_domain + + from google.auth.transport import requests as google_auth_requests + + self._universe_domain = _metadata.get_universe_domain( + google_auth_requests.Request() + ) + self._universe_domain_cached = True + return self._universe_domain + + @_helpers.copy_docstring(credentials.Credentials) + def get_cred_info(self): + return { + "credential_source": "metadata server", + "credential_type": "VM credentials", + "principal": self.service_account_email, + } + + @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) + def with_quota_project(self, quota_project_id): + creds = self.__class__( + service_account_email=self._service_account_email, + quota_project_id=quota_project_id, + scopes=self._scopes, + default_scopes=self._default_scopes, + ) + creds._universe_domain = self._universe_domain + creds._universe_domain_cached = self._universe_domain_cached + return creds + + @_helpers.copy_docstring(credentials.Scoped) + def with_scopes(self, scopes, default_scopes=None): + # Compute Engine credentials can not be scoped (the metadata service + # ignores the scopes parameter). App Engine, Cloud Run and Flex support + # requesting scopes. + creds = self.__class__( + scopes=scopes, + default_scopes=default_scopes, + service_account_email=self._service_account_email, + quota_project_id=self._quota_project_id, + ) + creds._universe_domain = self._universe_domain + creds._universe_domain_cached = self._universe_domain_cached + return creds + + @_helpers.copy_docstring(credentials.CredentialsWithUniverseDomain) + def with_universe_domain(self, universe_domain): + return self.__class__( + scopes=self._scopes, + default_scopes=self._default_scopes, + service_account_email=self._service_account_email, + quota_project_id=self._quota_project_id, + universe_domain=universe_domain, + ) + + +_DEFAULT_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds +_DEFAULT_TOKEN_URI = "https://www.googleapis.com/oauth2/v4/token" + + +class IDTokenCredentials( + credentials.CredentialsWithQuotaProject, + credentials.Signing, + credentials.CredentialsWithTokenUri, +): + """Open ID Connect ID Token-based service account credentials. + + These credentials relies on the default service account of a GCE instance. + + ID token can be requested from `GCE metadata server identity endpoint`_, IAM + token endpoint or other token endpoints you specify. If metadata server + identity endpoint is not used, the GCE instance must have been started with + a service account that has access to the IAM Cloud API. + + .. _GCE metadata server identity endpoint: + https://cloud.google.com/compute/docs/instances/verifying-instance-identity + """ + + def __init__( + self, + request, + target_audience, + token_uri=None, + additional_claims=None, + service_account_email=None, + signer=None, + use_metadata_identity_endpoint=False, + quota_project_id=None, + ): + """ + Args: + request (google.auth.transport.Request): The object used to make + HTTP requests. + target_audience (str): The intended audience for these credentials, + used when requesting the ID Token. The ID Token's ``aud`` claim + will be set to this string. + token_uri (str): The OAuth 2.0 Token URI. + additional_claims (Mapping[str, str]): Any additional claims for + the JWT assertion used in the authorization grant. + service_account_email (str): Optional explicit service account to + use to sign JWT tokens. + By default, this is the default GCE service account. + signer (google.auth.crypt.Signer): The signer used to sign JWTs. + In case the signer is specified, the request argument will be + ignored. + use_metadata_identity_endpoint (bool): Whether to use GCE metadata + identity endpoint. For backward compatibility the default value + is False. If set to True, ``token_uri``, ``additional_claims``, + ``service_account_email``, ``signer`` argument should not be set; + otherwise ValueError will be raised. + quota_project_id (Optional[str]): The project ID used for quota and + billing. + + Raises: + ValueError: + If ``use_metadata_identity_endpoint`` is set to True, and one of + ``token_uri``, ``additional_claims``, ``service_account_email``, + ``signer`` arguments is set. + """ + super(IDTokenCredentials, self).__init__() + + self._quota_project_id = quota_project_id + self._use_metadata_identity_endpoint = use_metadata_identity_endpoint + self._target_audience = target_audience + + if use_metadata_identity_endpoint: + if token_uri or additional_claims or service_account_email or signer: + raise exceptions.MalformedError( + "If use_metadata_identity_endpoint is set, token_uri, " + "additional_claims, service_account_email, signer arguments" + " must not be set" + ) + self._token_uri = None + self._additional_claims = None + self._signer = None + + if service_account_email is None: + sa_info = _metadata.get_service_account_info(request) + self._service_account_email = sa_info["email"] + else: + self._service_account_email = service_account_email + + if not use_metadata_identity_endpoint: + if signer is None: + signer = iam.Signer( + request=request, + credentials=Credentials(), + service_account_email=self._service_account_email, + ) + self._signer = signer + self._token_uri = token_uri or _DEFAULT_TOKEN_URI + + if additional_claims is not None: + self._additional_claims = additional_claims + else: + self._additional_claims = {} + + def with_target_audience(self, target_audience): + """Create a copy of these credentials with the specified target + audience. + Args: + target_audience (str): The intended audience for these credentials, + used when requesting the ID Token. + Returns: + google.auth.service_account.IDTokenCredentials: A new credentials + instance. + """ + # since the signer is already instantiated, + # the request is not needed + if self._use_metadata_identity_endpoint: + return self.__class__( + None, + target_audience=target_audience, + use_metadata_identity_endpoint=True, + quota_project_id=self._quota_project_id, + ) + else: + return self.__class__( + None, + service_account_email=self._service_account_email, + token_uri=self._token_uri, + target_audience=target_audience, + additional_claims=self._additional_claims.copy(), + signer=self.signer, + use_metadata_identity_endpoint=False, + quota_project_id=self._quota_project_id, + ) + + @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) + def with_quota_project(self, quota_project_id): + + # since the signer is already instantiated, + # the request is not needed + if self._use_metadata_identity_endpoint: + return self.__class__( + None, + target_audience=self._target_audience, + use_metadata_identity_endpoint=True, + quota_project_id=quota_project_id, + ) + else: + return self.__class__( + None, + service_account_email=self._service_account_email, + token_uri=self._token_uri, + target_audience=self._target_audience, + additional_claims=self._additional_claims.copy(), + signer=self.signer, + use_metadata_identity_endpoint=False, + quota_project_id=quota_project_id, + ) + + @_helpers.copy_docstring(credentials.CredentialsWithTokenUri) + def with_token_uri(self, token_uri): + + # since the signer is already instantiated, + # the request is not needed + if self._use_metadata_identity_endpoint: + raise exceptions.MalformedError( + "If use_metadata_identity_endpoint is set, token_uri" " must not be set" + ) + else: + return self.__class__( + None, + service_account_email=self._service_account_email, + token_uri=token_uri, + target_audience=self._target_audience, + additional_claims=self._additional_claims.copy(), + signer=self.signer, + use_metadata_identity_endpoint=False, + quota_project_id=self.quota_project_id, + ) + + def _make_authorization_grant_assertion(self): + """Create the OAuth 2.0 assertion. + This assertion is used during the OAuth 2.0 grant to acquire an + ID token. + Returns: + bytes: The authorization grant assertion. + """ + now = _helpers.utcnow() + lifetime = datetime.timedelta(seconds=_DEFAULT_TOKEN_LIFETIME_SECS) + expiry = now + lifetime + + payload = { + "iat": _helpers.datetime_to_secs(now), + "exp": _helpers.datetime_to_secs(expiry), + # The issuer must be the service account email. + "iss": self.service_account_email, + # The audience must be the auth token endpoint's URI + "aud": self._token_uri, + # The target audience specifies which service the ID token is + # intended for. + "target_audience": self._target_audience, + } + + payload.update(self._additional_claims) + + token = jwt.encode(self._signer, payload) + + return token + + def _call_metadata_identity_endpoint(self, request): + """Request ID token from metadata identity endpoint. + + Args: + request (google.auth.transport.Request): The object used to make + HTTP requests. + + Returns: + Tuple[str, datetime.datetime]: The ID token and the expiry of the ID token. + + Raises: + google.auth.exceptions.RefreshError: If the Compute Engine metadata + service can't be reached or if the instance has no credentials. + ValueError: If extracting expiry from the obtained ID token fails. + """ + try: + path = "instance/service-accounts/default/identity" + params = {"audience": self._target_audience, "format": "full"} + metrics_header = { + metrics.API_CLIENT_HEADER: metrics.token_request_id_token_mds() + } + id_token = _metadata.get( + request, path, params=params, headers=metrics_header + ) + except exceptions.TransportError as caught_exc: + new_exc = exceptions.RefreshError(caught_exc) + raise new_exc from caught_exc + + _, payload, _, _ = jwt._unverified_decode(id_token) + return id_token, datetime.datetime.utcfromtimestamp(payload["exp"]) + + def refresh(self, request): + """Refreshes the ID token. + + Args: + request (google.auth.transport.Request): The object used to make + HTTP requests. + + Raises: + google.auth.exceptions.RefreshError: If the credentials could + not be refreshed. + ValueError: If extracting expiry from the obtained ID token fails. + """ + if self._use_metadata_identity_endpoint: + self.token, self.expiry = self._call_metadata_identity_endpoint(request) + else: + assertion = self._make_authorization_grant_assertion() + access_token, expiry, _ = _client.id_token_jwt_grant( + request, self._token_uri, assertion + ) + self.token = access_token + self.expiry = expiry + + @property # type: ignore + @_helpers.copy_docstring(credentials.Signing) + def signer(self): + return self._signer + + def sign_bytes(self, message): + """Signs the given message. + + Args: + message (bytes): The message to sign. + + Returns: + bytes: The message's cryptographic signature. + + Raises: + ValueError: + Signer is not available if metadata identity endpoint is used. + """ + if self._use_metadata_identity_endpoint: + raise exceptions.InvalidOperation( + "Signer is not available if metadata identity endpoint is used" + ) + return self._signer.sign(message) + + @property + def service_account_email(self): + """The service account email.""" + return self._service_account_email + + @property + def signer_email(self): + return self._service_account_email diff --git a/lib/python3.10/site-packages/google/auth/credentials.py b/lib/python3.10/site-packages/google/auth/credentials.py new file mode 100644 index 0000000000000000000000000000000000000000..2c67e04432a4c1cf1c0d995bc19acdc5baf52289 --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/credentials.py @@ -0,0 +1,522 @@ +# Copyright 2016 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Interfaces for credentials.""" + +import abc +from enum import Enum +import os + +from google.auth import _helpers, environment_vars +from google.auth import exceptions +from google.auth import metrics +from google.auth._credentials_base import _BaseCredentials +from google.auth._refresh_worker import RefreshThreadManager + +DEFAULT_UNIVERSE_DOMAIN = "googleapis.com" + + +class Credentials(_BaseCredentials): + """Base class for all credentials. + + All credentials have a :attr:`token` that is used for authentication and + may also optionally set an :attr:`expiry` to indicate when the token will + no longer be valid. + + Most credentials will be :attr:`invalid` until :meth:`refresh` is called. + Credentials can do this automatically before the first HTTP request in + :meth:`before_request`. + + Although the token and expiration will change as the credentials are + :meth:`refreshed ` and used, credentials should be considered + immutable. Various credentials will accept configuration such as private + keys, scopes, and other options. These options are not changeable after + construction. Some classes will provide mechanisms to copy the credentials + with modifications such as :meth:`ScopedCredentials.with_scopes`. + """ + + def __init__(self): + super(Credentials, self).__init__() + + self.expiry = None + """Optional[datetime]: When the token expires and is no longer valid. + If this is None, the token is assumed to never expire.""" + self._quota_project_id = None + """Optional[str]: Project to use for quota and billing purposes.""" + self._trust_boundary = None + """Optional[dict]: Cache of a trust boundary response which has a list + of allowed regions and an encoded string representation of credentials + trust boundary.""" + self._universe_domain = DEFAULT_UNIVERSE_DOMAIN + """Optional[str]: The universe domain value, default is googleapis.com + """ + + self._use_non_blocking_refresh = False + self._refresh_worker = RefreshThreadManager() + + @property + def expired(self): + """Checks if the credentials are expired. + + Note that credentials can be invalid but not expired because + Credentials with :attr:`expiry` set to None is considered to never + expire. + + .. deprecated:: v2.24.0 + Prefer checking :attr:`token_state` instead. + """ + if not self.expiry: + return False + # Remove some threshold from expiry to err on the side of reporting + # expiration early so that we avoid the 401-refresh-retry loop. + skewed_expiry = self.expiry - _helpers.REFRESH_THRESHOLD + return _helpers.utcnow() >= skewed_expiry + + @property + def valid(self): + """Checks the validity of the credentials. + + This is True if the credentials have a :attr:`token` and the token + is not :attr:`expired`. + + .. deprecated:: v2.24.0 + Prefer checking :attr:`token_state` instead. + """ + return self.token is not None and not self.expired + + @property + def token_state(self): + """ + See `:obj:`TokenState` + """ + if self.token is None: + return TokenState.INVALID + + # Credentials that can't expire are always treated as fresh. + if self.expiry is None: + return TokenState.FRESH + + expired = _helpers.utcnow() >= self.expiry + if expired: + return TokenState.INVALID + + is_stale = _helpers.utcnow() >= (self.expiry - _helpers.REFRESH_THRESHOLD) + if is_stale: + return TokenState.STALE + + return TokenState.FRESH + + @property + def quota_project_id(self): + """Project to use for quota and billing purposes.""" + return self._quota_project_id + + @property + def universe_domain(self): + """The universe domain value.""" + return self._universe_domain + + def get_cred_info(self): + """The credential information JSON. + + The credential information will be added to auth related error messages + by client library. + + Returns: + Mapping[str, str]: The credential information JSON. + """ + return None + + @abc.abstractmethod + def refresh(self, request): + """Refreshes the access token. + + Args: + request (google.auth.transport.Request): The object used to make + HTTP requests. + + Raises: + google.auth.exceptions.RefreshError: If the credentials could + not be refreshed. + """ + # pylint: disable=missing-raises-doc + # (pylint doesn't recognize that this is abstract) + raise NotImplementedError("Refresh must be implemented") + + def _metric_header_for_usage(self): + """The x-goog-api-client header for token usage metric. + + This header will be added to the API service requests in before_request + method. For example, "cred-type/sa-jwt" means service account self + signed jwt access token is used in the API service request + authorization header. Children credentials classes need to override + this method to provide the header value, if the token usage metric is + needed. + + Returns: + str: The x-goog-api-client header value. + """ + return None + + def apply(self, headers, token=None): + """Apply the token to the authentication header. + + Args: + headers (Mapping): The HTTP request headers. + token (Optional[str]): If specified, overrides the current access + token. + """ + self._apply(headers, token=token) + """Trust boundary value will be a cached value from global lookup. + + The response of trust boundary will be a list of regions and a hex + encoded representation. + + An example of global lookup response: + { + "locations": [ + "us-central1", "us-east1", "europe-west1", "asia-east1" + ] + "encoded_locations": "0xA30" + } + """ + if self._trust_boundary is not None: + headers["x-allowed-locations"] = self._trust_boundary["encoded_locations"] + if self.quota_project_id: + headers["x-goog-user-project"] = self.quota_project_id + + def _blocking_refresh(self, request): + if not self.valid: + self.refresh(request) + + def _non_blocking_refresh(self, request): + use_blocking_refresh_fallback = False + + if self.token_state == TokenState.STALE: + use_blocking_refresh_fallback = not self._refresh_worker.start_refresh( + self, request + ) + + if self.token_state == TokenState.INVALID or use_blocking_refresh_fallback: + self.refresh(request) + # If the blocking refresh succeeds then we can clear the error info + # on the background refresh worker, and perform refreshes in a + # background thread. + self._refresh_worker.clear_error() + + def before_request(self, request, method, url, headers): + """Performs credential-specific before request logic. + + Refreshes the credentials if necessary, then calls :meth:`apply` to + apply the token to the authentication header. + + Args: + request (google.auth.transport.Request): The object used to make + HTTP requests. + method (str): The request's HTTP method or the RPC method being + invoked. + url (str): The request's URI or the RPC service's URI. + headers (Mapping): The request's headers. + """ + # pylint: disable=unused-argument + # (Subclasses may use these arguments to ascertain information about + # the http request.) + if self._use_non_blocking_refresh: + self._non_blocking_refresh(request) + else: + self._blocking_refresh(request) + + metrics.add_metric_header(headers, self._metric_header_for_usage()) + self.apply(headers) + + def with_non_blocking_refresh(self): + self._use_non_blocking_refresh = True + + +class CredentialsWithQuotaProject(Credentials): + """Abstract base for credentials supporting ``with_quota_project`` factory""" + + def with_quota_project(self, quota_project_id): + """Returns a copy of these credentials with a modified quota project. + + Args: + quota_project_id (str): The project to use for quota and + billing purposes + + Returns: + google.auth.credentials.Credentials: A new credentials instance. + """ + raise NotImplementedError("This credential does not support quota project.") + + def with_quota_project_from_environment(self): + quota_from_env = os.environ.get(environment_vars.GOOGLE_CLOUD_QUOTA_PROJECT) + if quota_from_env: + return self.with_quota_project(quota_from_env) + return self + + +class CredentialsWithTokenUri(Credentials): + """Abstract base for credentials supporting ``with_token_uri`` factory""" + + def with_token_uri(self, token_uri): + """Returns a copy of these credentials with a modified token uri. + + Args: + token_uri (str): The uri to use for fetching/exchanging tokens + + Returns: + google.auth.credentials.Credentials: A new credentials instance. + """ + raise NotImplementedError("This credential does not use token uri.") + + +class CredentialsWithUniverseDomain(Credentials): + """Abstract base for credentials supporting ``with_universe_domain`` factory""" + + def with_universe_domain(self, universe_domain): + """Returns a copy of these credentials with a modified universe domain. + + Args: + universe_domain (str): The universe domain to use + + Returns: + google.auth.credentials.Credentials: A new credentials instance. + """ + raise NotImplementedError( + "This credential does not support with_universe_domain." + ) + + +class AnonymousCredentials(Credentials): + """Credentials that do not provide any authentication information. + + These are useful in the case of services that support anonymous access or + local service emulators that do not use credentials. + """ + + @property + def expired(self): + """Returns `False`, anonymous credentials never expire.""" + return False + + @property + def valid(self): + """Returns `True`, anonymous credentials are always valid.""" + return True + + def refresh(self, request): + """Raises :class:``InvalidOperation``, anonymous credentials cannot be + refreshed.""" + raise exceptions.InvalidOperation("Anonymous credentials cannot be refreshed.") + + def apply(self, headers, token=None): + """Anonymous credentials do nothing to the request. + + The optional ``token`` argument is not supported. + + Raises: + google.auth.exceptions.InvalidValue: If a token was specified. + """ + if token is not None: + raise exceptions.InvalidValue("Anonymous credentials don't support tokens.") + + def before_request(self, request, method, url, headers): + """Anonymous credentials do nothing to the request.""" + + +class ReadOnlyScoped(metaclass=abc.ABCMeta): + """Interface for credentials whose scopes can be queried. + + OAuth 2.0-based credentials allow limiting access using scopes as described + in `RFC6749 Section 3.3`_. + If a credential class implements this interface then the credentials either + use scopes in their implementation. + + Some credentials require scopes in order to obtain a token. You can check + if scoping is necessary with :attr:`requires_scopes`:: + + if credentials.requires_scopes: + # Scoping is required. + credentials = credentials.with_scopes(scopes=['one', 'two']) + + Credentials that require scopes must either be constructed with scopes:: + + credentials = SomeScopedCredentials(scopes=['one', 'two']) + + Or must copy an existing instance using :meth:`with_scopes`:: + + scoped_credentials = credentials.with_scopes(scopes=['one', 'two']) + + Some credentials have scopes but do not allow or require scopes to be set, + these credentials can be used as-is. + + .. _RFC6749 Section 3.3: https://tools.ietf.org/html/rfc6749#section-3.3 + """ + + def __init__(self): + super(ReadOnlyScoped, self).__init__() + self._scopes = None + self._default_scopes = None + + @property + def scopes(self): + """Sequence[str]: the credentials' current set of scopes.""" + return self._scopes + + @property + def default_scopes(self): + """Sequence[str]: the credentials' current set of default scopes.""" + return self._default_scopes + + @abc.abstractproperty + def requires_scopes(self): + """True if these credentials require scopes to obtain an access token. + """ + return False + + def has_scopes(self, scopes): + """Checks if the credentials have the given scopes. + + .. warning: This method is not guaranteed to be accurate if the + credentials are :attr:`~Credentials.invalid`. + + Args: + scopes (Sequence[str]): The list of scopes to check. + + Returns: + bool: True if the credentials have the given scopes. + """ + credential_scopes = ( + self._scopes if self._scopes is not None else self._default_scopes + ) + return set(scopes).issubset(set(credential_scopes or [])) + + +class Scoped(ReadOnlyScoped): + """Interface for credentials whose scopes can be replaced while copying. + + OAuth 2.0-based credentials allow limiting access using scopes as described + in `RFC6749 Section 3.3`_. + If a credential class implements this interface then the credentials either + use scopes in their implementation. + + Some credentials require scopes in order to obtain a token. You can check + if scoping is necessary with :attr:`requires_scopes`:: + + if credentials.requires_scopes: + # Scoping is required. + credentials = credentials.create_scoped(['one', 'two']) + + Credentials that require scopes must either be constructed with scopes:: + + credentials = SomeScopedCredentials(scopes=['one', 'two']) + + Or must copy an existing instance using :meth:`with_scopes`:: + + scoped_credentials = credentials.with_scopes(scopes=['one', 'two']) + + Some credentials have scopes but do not allow or require scopes to be set, + these credentials can be used as-is. + + .. _RFC6749 Section 3.3: https://tools.ietf.org/html/rfc6749#section-3.3 + """ + + @abc.abstractmethod + def with_scopes(self, scopes, default_scopes=None): + """Create a copy of these credentials with the specified scopes. + + Args: + scopes (Sequence[str]): The list of scopes to attach to the + current credentials. + + Raises: + NotImplementedError: If the credentials' scopes can not be changed. + This can be avoided by checking :attr:`requires_scopes` before + calling this method. + """ + raise NotImplementedError("This class does not require scoping.") + + +def with_scopes_if_required(credentials, scopes, default_scopes=None): + """Creates a copy of the credentials with scopes if scoping is required. + + This helper function is useful when you do not know (or care to know) the + specific type of credentials you are using (such as when you use + :func:`google.auth.default`). This function will call + :meth:`Scoped.with_scopes` if the credentials are scoped credentials and if + the credentials require scoping. Otherwise, it will return the credentials + as-is. + + Args: + credentials (google.auth.credentials.Credentials): The credentials to + scope if necessary. + scopes (Sequence[str]): The list of scopes to use. + default_scopes (Sequence[str]): Default scopes passed by a + Google client library. Use 'scopes' for user-defined scopes. + + Returns: + google.auth.credentials.Credentials: Either a new set of scoped + credentials, or the passed in credentials instance if no scoping + was required. + """ + if isinstance(credentials, Scoped) and credentials.requires_scopes: + return credentials.with_scopes(scopes, default_scopes=default_scopes) + else: + return credentials + + +class Signing(metaclass=abc.ABCMeta): + """Interface for credentials that can cryptographically sign messages.""" + + @abc.abstractmethod + def sign_bytes(self, message): + """Signs the given message. + + Args: + message (bytes): The message to sign. + + Returns: + bytes: The message's cryptographic signature. + """ + # pylint: disable=missing-raises-doc,redundant-returns-doc + # (pylint doesn't recognize that this is abstract) + raise NotImplementedError("Sign bytes must be implemented.") + + @abc.abstractproperty + def signer_email(self): + """Optional[str]: An email address that identifies the signer.""" + # pylint: disable=missing-raises-doc + # (pylint doesn't recognize that this is abstract) + raise NotImplementedError("Signer email must be implemented.") + + @abc.abstractproperty + def signer(self): + """google.auth.crypt.Signer: The signer used to sign bytes.""" + # pylint: disable=missing-raises-doc + # (pylint doesn't recognize that this is abstract) + raise NotImplementedError("Signer must be implemented.") + + +class TokenState(Enum): + """ + Tracks the state of a token. + FRESH: The token is valid. It is not expired or close to expired, or the token has no expiry. + STALE: The token is close to expired, and should be refreshed. The token can be used normally. + INVALID: The token is expired or invalid. The token cannot be used for a normal operation. + """ + + FRESH = 1 + STALE = 2 + INVALID = 3 diff --git a/lib/python3.10/site-packages/google/auth/downscoped.py b/lib/python3.10/site-packages/google/auth/downscoped.py new file mode 100644 index 0000000000000000000000000000000000000000..ea75be90fe4e084108c50dab3f7d8c119edad4d1 --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/downscoped.py @@ -0,0 +1,512 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Downscoping with Credential Access Boundaries + +This module provides the ability to downscope credentials using +`Downscoping with Credential Access Boundaries`_. This is useful to restrict the +Identity and Access Management (IAM) permissions that a short-lived credential +can use. + +To downscope permissions of a source credential, a Credential Access Boundary +that specifies which resources the new credential can access, as well as +an upper bound on the permissions that are available on each resource, has to +be defined. A downscoped credential can then be instantiated using the source +credential and the Credential Access Boundary. + +The common pattern of usage is to have a token broker with elevated access +generate these downscoped credentials from higher access source credentials and +pass the downscoped short-lived access tokens to a token consumer via some +secure authenticated channel for limited access to Google Cloud Storage +resources. + +For example, a token broker can be set up on a server in a private network. +Various workloads (token consumers) in the same network will send authenticated +requests to that broker for downscoped tokens to access or modify specific google +cloud storage buckets. + +The broker will instantiate downscoped credentials instances that can be used to +generate short lived downscoped access tokens that can be passed to the token +consumer. These downscoped access tokens can be injected by the consumer into +google.oauth2.Credentials and used to initialize a storage client instance to +access Google Cloud Storage resources with restricted access. + +Note: Only Cloud Storage supports Credential Access Boundaries. Other Google +Cloud services do not support this feature. + +.. _Downscoping with Credential Access Boundaries: https://cloud.google.com/iam/docs/downscoping-short-lived-credentials +""" + +import datetime + +from google.auth import _helpers +from google.auth import credentials +from google.auth import exceptions +from google.oauth2 import sts + +# The maximum number of access boundary rules a Credential Access Boundary can +# contain. +_MAX_ACCESS_BOUNDARY_RULES_COUNT = 10 +# The token exchange grant_type used for exchanging credentials. +_STS_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:token-exchange" +# The token exchange requested_token_type. This is always an access_token. +_STS_REQUESTED_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:access_token" +# The STS token URL used to exchanged a short lived access token for a downscoped one. +_STS_TOKEN_URL_PATTERN = "https://sts.{}/v1/token" +# The subject token type to use when exchanging a short lived access token for a +# downscoped token. +_STS_SUBJECT_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:access_token" + + +class CredentialAccessBoundary(object): + """Defines a Credential Access Boundary which contains a list of access boundary + rules. Each rule contains information on the resource that the rule applies to, + the upper bound of the permissions that are available on that resource and an + optional condition to further restrict permissions. + """ + + def __init__(self, rules=[]): + """Instantiates a Credential Access Boundary. A Credential Access Boundary + can contain up to 10 access boundary rules. + + Args: + rules (Sequence[google.auth.downscoped.AccessBoundaryRule]): The list of + access boundary rules limiting the access that a downscoped credential + will have. + Raises: + InvalidType: If any of the rules are not a valid type. + InvalidValue: If the provided rules exceed the maximum allowed. + """ + self.rules = rules + + @property + def rules(self): + """Returns the list of access boundary rules defined on the Credential + Access Boundary. + + Returns: + Tuple[google.auth.downscoped.AccessBoundaryRule, ...]: The list of access + boundary rules defined on the Credential Access Boundary. These are returned + as an immutable tuple to prevent modification. + """ + return tuple(self._rules) + + @rules.setter + def rules(self, value): + """Updates the current rules on the Credential Access Boundary. This will overwrite + the existing set of rules. + + Args: + value (Sequence[google.auth.downscoped.AccessBoundaryRule]): The list of + access boundary rules limiting the access that a downscoped credential + will have. + Raises: + InvalidType: If any of the rules are not a valid type. + InvalidValue: If the provided rules exceed the maximum allowed. + """ + if len(value) > _MAX_ACCESS_BOUNDARY_RULES_COUNT: + raise exceptions.InvalidValue( + "Credential access boundary rules can have a maximum of {} rules.".format( + _MAX_ACCESS_BOUNDARY_RULES_COUNT + ) + ) + for access_boundary_rule in value: + if not isinstance(access_boundary_rule, AccessBoundaryRule): + raise exceptions.InvalidType( + "List of rules provided do not contain a valid 'google.auth.downscoped.AccessBoundaryRule'." + ) + # Make a copy of the original list. + self._rules = list(value) + + def add_rule(self, rule): + """Adds a single access boundary rule to the existing rules. + + Args: + rule (google.auth.downscoped.AccessBoundaryRule): The access boundary rule, + limiting the access that a downscoped credential will have, to be added to + the existing rules. + Raises: + InvalidType: If any of the rules are not a valid type. + InvalidValue: If the provided rules exceed the maximum allowed. + """ + if len(self.rules) == _MAX_ACCESS_BOUNDARY_RULES_COUNT: + raise exceptions.InvalidValue( + "Credential access boundary rules can have a maximum of {} rules.".format( + _MAX_ACCESS_BOUNDARY_RULES_COUNT + ) + ) + if not isinstance(rule, AccessBoundaryRule): + raise exceptions.InvalidType( + "The provided rule does not contain a valid 'google.auth.downscoped.AccessBoundaryRule'." + ) + self._rules.append(rule) + + def to_json(self): + """Generates the dictionary representation of the Credential Access Boundary. + This uses the format expected by the Security Token Service API as documented in + `Defining a Credential Access Boundary`_. + + .. _Defining a Credential Access Boundary: + https://cloud.google.com/iam/docs/downscoping-short-lived-credentials#define-boundary + + Returns: + Mapping: Credential Access Boundary Rule represented in a dictionary object. + """ + rules = [] + for access_boundary_rule in self.rules: + rules.append(access_boundary_rule.to_json()) + + return {"accessBoundary": {"accessBoundaryRules": rules}} + + +class AccessBoundaryRule(object): + """Defines an access boundary rule which contains information on the resource that + the rule applies to, the upper bound of the permissions that are available on that + resource and an optional condition to further restrict permissions. + """ + + def __init__( + self, available_resource, available_permissions, availability_condition=None + ): + """Instantiates a single access boundary rule. + + Args: + available_resource (str): The full resource name of the Cloud Storage bucket + that the rule applies to. Use the format + "//storage.googleapis.com/projects/_/buckets/bucket-name". + available_permissions (Sequence[str]): A list defining the upper bound that + the downscoped token will have on the available permissions for the + resource. Each value is the identifier for an IAM predefined role or + custom role, with the prefix "inRole:". For example: + "inRole:roles/storage.objectViewer". + Only the permissions in these roles will be available. + availability_condition (Optional[google.auth.downscoped.AvailabilityCondition]): + Optional condition that restricts the availability of permissions to + specific Cloud Storage objects. + + Raises: + InvalidType: If any of the parameters are not of the expected types. + InvalidValue: If any of the parameters are not of the expected values. + """ + self.available_resource = available_resource + self.available_permissions = available_permissions + self.availability_condition = availability_condition + + @property + def available_resource(self): + """Returns the current available resource. + + Returns: + str: The current available resource. + """ + return self._available_resource + + @available_resource.setter + def available_resource(self, value): + """Updates the current available resource. + + Args: + value (str): The updated value of the available resource. + + Raises: + google.auth.exceptions.InvalidType: If the value is not a string. + """ + if not isinstance(value, str): + raise exceptions.InvalidType( + "The provided available_resource is not a string." + ) + self._available_resource = value + + @property + def available_permissions(self): + """Returns the current available permissions. + + Returns: + Tuple[str, ...]: The current available permissions. These are returned + as an immutable tuple to prevent modification. + """ + return tuple(self._available_permissions) + + @available_permissions.setter + def available_permissions(self, value): + """Updates the current available permissions. + + Args: + value (Sequence[str]): The updated value of the available permissions. + + Raises: + InvalidType: If the value is not a list of strings. + InvalidValue: If the value is not valid. + """ + for available_permission in value: + if not isinstance(available_permission, str): + raise exceptions.InvalidType( + "Provided available_permissions are not a list of strings." + ) + if available_permission.find("inRole:") != 0: + raise exceptions.InvalidValue( + "available_permissions must be prefixed with 'inRole:'." + ) + # Make a copy of the original list. + self._available_permissions = list(value) + + @property + def availability_condition(self): + """Returns the current availability condition. + + Returns: + Optional[google.auth.downscoped.AvailabilityCondition]: The current + availability condition. + """ + return self._availability_condition + + @availability_condition.setter + def availability_condition(self, value): + """Updates the current availability condition. + + Args: + value (Optional[google.auth.downscoped.AvailabilityCondition]): The updated + value of the availability condition. + + Raises: + google.auth.exceptions.InvalidType: If the value is not of type google.auth.downscoped.AvailabilityCondition + or None. + """ + if not isinstance(value, AvailabilityCondition) and value is not None: + raise exceptions.InvalidType( + "The provided availability_condition is not a 'google.auth.downscoped.AvailabilityCondition' or None." + ) + self._availability_condition = value + + def to_json(self): + """Generates the dictionary representation of the access boundary rule. + This uses the format expected by the Security Token Service API as documented in + `Defining a Credential Access Boundary`_. + + .. _Defining a Credential Access Boundary: + https://cloud.google.com/iam/docs/downscoping-short-lived-credentials#define-boundary + + Returns: + Mapping: The access boundary rule represented in a dictionary object. + """ + json = { + "availablePermissions": list(self.available_permissions), + "availableResource": self.available_resource, + } + if self.availability_condition: + json["availabilityCondition"] = self.availability_condition.to_json() + return json + + +class AvailabilityCondition(object): + """An optional condition that can be used as part of a Credential Access Boundary + to further restrict permissions.""" + + def __init__(self, expression, title=None, description=None): + """Instantiates an availability condition using the provided expression and + optional title or description. + + Args: + expression (str): A condition expression that specifies the Cloud Storage + objects where permissions are available. For example, this expression + makes permissions available for objects whose name starts with "customer-a": + "resource.name.startsWith('projects/_/buckets/example-bucket/objects/customer-a')" + title (Optional[str]): An optional short string that identifies the purpose of + the condition. + description (Optional[str]): Optional details about the purpose of the condition. + + Raises: + InvalidType: If any of the parameters are not of the expected types. + InvalidValue: If any of the parameters are not of the expected values. + """ + self.expression = expression + self.title = title + self.description = description + + @property + def expression(self): + """Returns the current condition expression. + + Returns: + str: The current conditon expression. + """ + return self._expression + + @expression.setter + def expression(self, value): + """Updates the current condition expression. + + Args: + value (str): The updated value of the condition expression. + + Raises: + google.auth.exceptions.InvalidType: If the value is not of type string. + """ + if not isinstance(value, str): + raise exceptions.InvalidType("The provided expression is not a string.") + self._expression = value + + @property + def title(self): + """Returns the current title. + + Returns: + Optional[str]: The current title. + """ + return self._title + + @title.setter + def title(self, value): + """Updates the current title. + + Args: + value (Optional[str]): The updated value of the title. + + Raises: + google.auth.exceptions.InvalidType: If the value is not of type string or None. + """ + if not isinstance(value, str) and value is not None: + raise exceptions.InvalidType("The provided title is not a string or None.") + self._title = value + + @property + def description(self): + """Returns the current description. + + Returns: + Optional[str]: The current description. + """ + return self._description + + @description.setter + def description(self, value): + """Updates the current description. + + Args: + value (Optional[str]): The updated value of the description. + + Raises: + google.auth.exceptions.InvalidType: If the value is not of type string or None. + """ + if not isinstance(value, str) and value is not None: + raise exceptions.InvalidType( + "The provided description is not a string or None." + ) + self._description = value + + def to_json(self): + """Generates the dictionary representation of the availability condition. + This uses the format expected by the Security Token Service API as documented in + `Defining a Credential Access Boundary`_. + + .. _Defining a Credential Access Boundary: + https://cloud.google.com/iam/docs/downscoping-short-lived-credentials#define-boundary + + Returns: + Mapping[str, str]: The availability condition represented in a dictionary + object. + """ + json = {"expression": self.expression} + if self.title: + json["title"] = self.title + if self.description: + json["description"] = self.description + return json + + +class Credentials(credentials.CredentialsWithQuotaProject): + """Defines a set of Google credentials that are downscoped from an existing set + of Google OAuth2 credentials. This is useful to restrict the Identity and Access + Management (IAM) permissions that a short-lived credential can use. + The common pattern of usage is to have a token broker with elevated access + generate these downscoped credentials from higher access source credentials and + pass the downscoped short-lived access tokens to a token consumer via some + secure authenticated channel for limited access to Google Cloud Storage + resources. + """ + + def __init__( + self, + source_credentials, + credential_access_boundary, + quota_project_id=None, + universe_domain=credentials.DEFAULT_UNIVERSE_DOMAIN, + ): + """Instantiates a downscoped credentials object using the provided source + credentials and credential access boundary rules. + To downscope permissions of a source credential, a Credential Access Boundary + that specifies which resources the new credential can access, as well as an + upper bound on the permissions that are available on each resource, has to be + defined. A downscoped credential can then be instantiated using the source + credential and the Credential Access Boundary. + + Args: + source_credentials (google.auth.credentials.Credentials): The source credentials + to be downscoped based on the provided Credential Access Boundary rules. + credential_access_boundary (google.auth.downscoped.CredentialAccessBoundary): + The Credential Access Boundary which contains a list of access boundary + rules. Each rule contains information on the resource that the rule applies to, + the upper bound of the permissions that are available on that resource and an + optional condition to further restrict permissions. + quota_project_id (Optional[str]): The optional quota project ID. + universe_domain (Optional[str]): The universe domain value, default is googleapis.com + Raises: + google.auth.exceptions.RefreshError: If the source credentials + return an error on token refresh. + google.auth.exceptions.OAuthError: If the STS token exchange + endpoint returned an error during downscoped token generation. + """ + + super(Credentials, self).__init__() + self._source_credentials = source_credentials + self._credential_access_boundary = credential_access_boundary + self._quota_project_id = quota_project_id + self._universe_domain = universe_domain or credentials.DEFAULT_UNIVERSE_DOMAIN + self._sts_client = sts.Client( + _STS_TOKEN_URL_PATTERN.format(self.universe_domain) + ) + + @_helpers.copy_docstring(credentials.Credentials) + def refresh(self, request): + # Generate an access token from the source credentials. + self._source_credentials.refresh(request) + now = _helpers.utcnow() + # Exchange the access token for a downscoped access token. + response_data = self._sts_client.exchange_token( + request=request, + grant_type=_STS_GRANT_TYPE, + subject_token=self._source_credentials.token, + subject_token_type=_STS_SUBJECT_TOKEN_TYPE, + requested_token_type=_STS_REQUESTED_TOKEN_TYPE, + additional_options=self._credential_access_boundary.to_json(), + ) + self.token = response_data.get("access_token") + # For downscoping CAB flow, the STS endpoint may not return the expiration + # field for some flows. The generated downscoped token should always have + # the same expiration time as the source credentials. When no expires_in + # field is returned in the response, we can just get the expiration time + # from the source credentials. + if response_data.get("expires_in"): + lifetime = datetime.timedelta(seconds=response_data.get("expires_in")) + self.expiry = now + lifetime + else: + self.expiry = self._source_credentials.expiry + + @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) + def with_quota_project(self, quota_project_id): + return self.__class__( + self._source_credentials, + self._credential_access_boundary, + quota_project_id=quota_project_id, + ) diff --git a/lib/python3.10/site-packages/google/auth/environment_vars.py b/lib/python3.10/site-packages/google/auth/environment_vars.py new file mode 100644 index 0000000000000000000000000000000000000000..81f31571eb8080e0d8ec2418e7b27006dca72cd6 --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/environment_vars.py @@ -0,0 +1,84 @@ +# Copyright 2016 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Environment variables used by :mod:`google.auth`.""" + + +PROJECT = "GOOGLE_CLOUD_PROJECT" +"""Environment variable defining default project. + +This used by :func:`google.auth.default` to explicitly set a project ID. This +environment variable is also used by the Google Cloud Python Library. +""" + +LEGACY_PROJECT = "GCLOUD_PROJECT" +"""Previously used environment variable defining the default project. + +This environment variable is used instead of the current one in some +situations (such as Google App Engine). +""" + +GOOGLE_CLOUD_QUOTA_PROJECT = "GOOGLE_CLOUD_QUOTA_PROJECT" +"""Environment variable defining the project to be used for +quota and billing.""" + +CREDENTIALS = "GOOGLE_APPLICATION_CREDENTIALS" +"""Environment variable defining the location of Google application default +credentials.""" + +# The environment variable name which can replace ~/.config if set. +CLOUD_SDK_CONFIG_DIR = "CLOUDSDK_CONFIG" +"""Environment variable defines the location of Google Cloud SDK's config +files.""" + +# These two variables allow for customization of the addresses used when +# contacting the GCE metadata service. +GCE_METADATA_HOST = "GCE_METADATA_HOST" +"""Environment variable providing an alternate hostname or host:port to be +used for GCE metadata requests. + +This environment variable was originally named GCE_METADATA_ROOT. The system will +check this environemnt variable first; should there be no value present, +the system will fall back to the old variable. +""" + +GCE_METADATA_ROOT = "GCE_METADATA_ROOT" +"""Old environment variable for GCE_METADATA_HOST.""" + +GCE_METADATA_IP = "GCE_METADATA_IP" +"""Environment variable providing an alternate ip:port to be used for ip-only +GCE metadata requests.""" + +GOOGLE_API_USE_CLIENT_CERTIFICATE = "GOOGLE_API_USE_CLIENT_CERTIFICATE" +"""Environment variable controlling whether to use client certificate or not. + +The default value is false. Users have to explicitly set this value to true +in order to use client certificate to establish a mutual TLS channel.""" + +LEGACY_APPENGINE_RUNTIME = "APPENGINE_RUNTIME" +"""Gen1 environment variable defining the App Engine Runtime. + +Used to distinguish between GAE gen1 and GAE gen2+. +""" + +# AWS environment variables used with AWS workload identity pools to retrieve +# AWS security credentials and the AWS region needed to create a serialized +# signed requests to the AWS STS GetCalledIdentity API that can be exchanged +# for a Google access tokens via the GCP STS endpoint. +# When not available the AWS metadata server is used to retrieve these values. +AWS_ACCESS_KEY_ID = "AWS_ACCESS_KEY_ID" +AWS_SECRET_ACCESS_KEY = "AWS_SECRET_ACCESS_KEY" +AWS_SESSION_TOKEN = "AWS_SESSION_TOKEN" +AWS_REGION = "AWS_REGION" +AWS_DEFAULT_REGION = "AWS_DEFAULT_REGION" diff --git a/lib/python3.10/site-packages/google/auth/exceptions.py b/lib/python3.10/site-packages/google/auth/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..feb9f7411e04630bece9b6fc61fdc5cc0e9d0a15 --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/exceptions.py @@ -0,0 +1,108 @@ +# Copyright 2016 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Exceptions used in the google.auth package.""" + + +class GoogleAuthError(Exception): + """Base class for all google.auth errors.""" + + def __init__(self, *args, **kwargs): + super(GoogleAuthError, self).__init__(*args) + retryable = kwargs.get("retryable", False) + self._retryable = retryable + + @property + def retryable(self): + return self._retryable + + +class TransportError(GoogleAuthError): + """Used to indicate an error occurred during an HTTP request.""" + + +class RefreshError(GoogleAuthError): + """Used to indicate that an refreshing the credentials' access token + failed.""" + + +class UserAccessTokenError(GoogleAuthError): + """Used to indicate ``gcloud auth print-access-token`` command failed.""" + + +class DefaultCredentialsError(GoogleAuthError): + """Used to indicate that acquiring default credentials failed.""" + + +class MutualTLSChannelError(GoogleAuthError): + """Used to indicate that mutual TLS channel creation is failed, or mutual + TLS channel credentials is missing or invalid.""" + + +class ClientCertError(GoogleAuthError): + """Used to indicate that client certificate is missing or invalid.""" + + @property + def retryable(self): + return False + + +class OAuthError(GoogleAuthError): + """Used to indicate an error occurred during an OAuth related HTTP + request.""" + + +class ReauthFailError(RefreshError): + """An exception for when reauth failed.""" + + def __init__(self, message=None, **kwargs): + super(ReauthFailError, self).__init__( + "Reauthentication failed. {0}".format(message), **kwargs + ) + + +class ReauthSamlChallengeFailError(ReauthFailError): + """An exception for SAML reauth challenge failures.""" + + +class MalformedError(DefaultCredentialsError, ValueError): + """An exception for malformed data.""" + + +class InvalidResource(DefaultCredentialsError, ValueError): + """An exception for URL error.""" + + +class InvalidOperation(DefaultCredentialsError, ValueError): + """An exception for invalid operation.""" + + +class InvalidValue(DefaultCredentialsError, ValueError): + """Used to wrap general ValueError of python.""" + + +class InvalidType(DefaultCredentialsError, TypeError): + """Used to wrap general TypeError of python.""" + + +class OSError(DefaultCredentialsError, EnvironmentError): + """Used to wrap EnvironmentError(OSError after python3.3).""" + + +class TimeoutError(GoogleAuthError): + """Used to indicate a timeout error occurred during an HTTP request.""" + + +class ResponseError(GoogleAuthError): + """Used to indicate an error occurred when reading an HTTP response.""" diff --git a/lib/python3.10/site-packages/google/auth/external_account.py b/lib/python3.10/site-packages/google/auth/external_account.py new file mode 100644 index 0000000000000000000000000000000000000000..161e6c50ce326833914f41d974a91454c68f52b3 --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/external_account.py @@ -0,0 +1,628 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""External Account Credentials. + +This module provides credentials that exchange workload identity pool external +credentials for Google access tokens. This facilitates accessing Google Cloud +Platform resources from on-prem and non-Google Cloud platforms (e.g. AWS, +Microsoft Azure, OIDC identity providers), using native credentials retrieved +from the current environment without the need to copy, save and manage +long-lived service account credentials. + +Specifically, this is intended to use access tokens acquired using the GCP STS +token exchange endpoint following the `OAuth 2.0 Token Exchange`_ spec. + +.. _OAuth 2.0 Token Exchange: https://tools.ietf.org/html/rfc8693 +""" + +import abc +import copy +from dataclasses import dataclass +import datetime +import functools +import io +import json +import re + +from google.auth import _helpers +from google.auth import credentials +from google.auth import exceptions +from google.auth import impersonated_credentials +from google.auth import metrics +from google.oauth2 import sts +from google.oauth2 import utils + +# External account JSON type identifier. +_EXTERNAL_ACCOUNT_JSON_TYPE = "external_account" +# The token exchange grant_type used for exchanging credentials. +_STS_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:token-exchange" +# The token exchange requested_token_type. This is always an access_token. +_STS_REQUESTED_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:access_token" +# Cloud resource manager URL used to retrieve project information. +_CLOUD_RESOURCE_MANAGER = "https://cloudresourcemanager.googleapis.com/v1/projects/" +# Default Google sts token url. +_DEFAULT_TOKEN_URL = "https://sts.{universe_domain}/v1/token" + + +@dataclass +class SupplierContext: + """A context class that contains information about the requested third party credential that is passed + to AWS security credential and subject token suppliers. + + Attributes: + subject_token_type (str): The requested subject token type based on the Oauth2.0 token exchange spec. + Expected values include:: + + “urn:ietf:params:oauth:token-type:jwt” + “urn:ietf:params:oauth:token-type:id-token” + “urn:ietf:params:oauth:token-type:saml2” + “urn:ietf:params:aws:token-type:aws4_request” + + audience (str): The requested audience for the subject token. + """ + + subject_token_type: str + audience: str + + +class Credentials( + credentials.Scoped, + credentials.CredentialsWithQuotaProject, + credentials.CredentialsWithTokenUri, + metaclass=abc.ABCMeta, +): + """Base class for all external account credentials. + + This is used to instantiate Credentials for exchanging external account + credentials for Google access token and authorizing requests to Google APIs. + The base class implements the common logic for exchanging external account + credentials for Google access tokens. + """ + + def __init__( + self, + audience, + subject_token_type, + token_url, + credential_source, + service_account_impersonation_url=None, + service_account_impersonation_options=None, + client_id=None, + client_secret=None, + token_info_url=None, + quota_project_id=None, + scopes=None, + default_scopes=None, + workforce_pool_user_project=None, + universe_domain=credentials.DEFAULT_UNIVERSE_DOMAIN, + trust_boundary=None, + ): + """Instantiates an external account credentials object. + + Args: + audience (str): The STS audience field. + subject_token_type (str): The subject token type based on the Oauth2.0 token exchange spec. + Expected values include:: + + “urn:ietf:params:oauth:token-type:jwt” + “urn:ietf:params:oauth:token-type:id-token” + “urn:ietf:params:oauth:token-type:saml2” + “urn:ietf:params:aws:token-type:aws4_request” + + token_url (str): The STS endpoint URL. + credential_source (Mapping): The credential source dictionary. + service_account_impersonation_url (Optional[str]): The optional service account + impersonation generateAccessToken URL. + client_id (Optional[str]): The optional client ID. + client_secret (Optional[str]): The optional client secret. + token_info_url (str): The optional STS endpoint URL for token introspection. + quota_project_id (Optional[str]): The optional quota project ID. + scopes (Optional[Sequence[str]]): Optional scopes to request during the + authorization grant. + default_scopes (Optional[Sequence[str]]): Default scopes passed by a + Google client library. Use 'scopes' for user-defined scopes. + workforce_pool_user_project (Optona[str]): The optional workforce pool user + project number when the credential corresponds to a workforce pool and not + a workload identity pool. The underlying principal must still have + serviceusage.services.use IAM permission to use the project for + billing/quota. + universe_domain (str): The universe domain. The default universe + domain is googleapis.com. + trust_boundary (str): String representation of trust boundary meta. + Raises: + google.auth.exceptions.RefreshError: If the generateAccessToken + endpoint returned an error. + """ + super(Credentials, self).__init__() + self._audience = audience + self._subject_token_type = subject_token_type + self._universe_domain = universe_domain + self._token_url = token_url + if self._token_url == _DEFAULT_TOKEN_URL: + self._token_url = self._token_url.replace( + "{universe_domain}", self._universe_domain + ) + self._token_info_url = token_info_url + self._credential_source = credential_source + self._service_account_impersonation_url = service_account_impersonation_url + self._service_account_impersonation_options = ( + service_account_impersonation_options or {} + ) + self._client_id = client_id + self._client_secret = client_secret + self._quota_project_id = quota_project_id + self._scopes = scopes + self._default_scopes = default_scopes + self._workforce_pool_user_project = workforce_pool_user_project + self._trust_boundary = { + "locations": [], + "encoded_locations": "0x0", + } # expose a placeholder trust boundary value. + + if self._client_id: + self._client_auth = utils.ClientAuthentication( + utils.ClientAuthType.basic, self._client_id, self._client_secret + ) + else: + self._client_auth = None + self._sts_client = sts.Client(self._token_url, self._client_auth) + + self._metrics_options = self._create_default_metrics_options() + + self._impersonated_credentials = None + self._project_id = None + self._supplier_context = SupplierContext( + self._subject_token_type, self._audience + ) + self._cred_file_path = None + + if not self.is_workforce_pool and self._workforce_pool_user_project: + # Workload identity pools do not support workforce pool user projects. + raise exceptions.InvalidValue( + "workforce_pool_user_project should not be set for non-workforce pool " + "credentials" + ) + + @property + def info(self): + """Generates the dictionary representation of the current credentials. + + Returns: + Mapping: The dictionary representation of the credentials. This is the + reverse of "from_info" defined on the subclasses of this class. It is + useful for serializing the current credentials so it can deserialized + later. + """ + config_info = self._constructor_args() + config_info.update( + type=_EXTERNAL_ACCOUNT_JSON_TYPE, + service_account_impersonation=config_info.pop( + "service_account_impersonation_options", None + ), + ) + config_info.pop("scopes", None) + config_info.pop("default_scopes", None) + return {key: value for key, value in config_info.items() if value is not None} + + def _constructor_args(self): + args = { + "audience": self._audience, + "subject_token_type": self._subject_token_type, + "token_url": self._token_url, + "token_info_url": self._token_info_url, + "service_account_impersonation_url": self._service_account_impersonation_url, + "service_account_impersonation_options": copy.deepcopy( + self._service_account_impersonation_options + ) + or None, + "credential_source": copy.deepcopy(self._credential_source), + "quota_project_id": self._quota_project_id, + "client_id": self._client_id, + "client_secret": self._client_secret, + "workforce_pool_user_project": self._workforce_pool_user_project, + "scopes": self._scopes, + "default_scopes": self._default_scopes, + "universe_domain": self._universe_domain, + } + if not self.is_workforce_pool: + args.pop("workforce_pool_user_project") + return args + + @property + def service_account_email(self): + """Returns the service account email if service account impersonation is used. + + Returns: + Optional[str]: The service account email if impersonation is used. Otherwise + None is returned. + """ + if self._service_account_impersonation_url: + # Parse email from URL. The formal looks as follows: + # https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/name@project-id.iam.gserviceaccount.com:generateAccessToken + url = self._service_account_impersonation_url + start_index = url.rfind("/") + end_index = url.find(":generateAccessToken") + if start_index != -1 and end_index != -1 and start_index < end_index: + start_index = start_index + 1 + return url[start_index:end_index] + return None + + @property + def is_user(self): + """Returns whether the credentials represent a user (True) or workload (False). + Workloads behave similarly to service accounts. Currently workloads will use + service account impersonation but will eventually not require impersonation. + As a result, this property is more reliable than the service account email + property in determining if the credentials represent a user or workload. + + Returns: + bool: True if the credentials represent a user. False if they represent a + workload. + """ + # If service account impersonation is used, the credentials will always represent a + # service account. + if self._service_account_impersonation_url: + return False + return self.is_workforce_pool + + @property + def is_workforce_pool(self): + """Returns whether the credentials represent a workforce pool (True) or + workload (False) based on the credentials' audience. + + This will also return True for impersonated workforce pool credentials. + + Returns: + bool: True if the credentials represent a workforce pool. False if they + represent a workload. + """ + # Workforce pools representing users have the following audience format: + # //iam.googleapis.com/locations/$location/workforcePools/$poolId/providers/$providerId + p = re.compile(r"//iam\.googleapis\.com/locations/[^/]+/workforcePools/") + return p.match(self._audience or "") is not None + + @property + def requires_scopes(self): + """Checks if the credentials requires scopes. + + Returns: + bool: True if there are no scopes set otherwise False. + """ + return not self._scopes and not self._default_scopes + + @property + def project_number(self): + """Optional[str]: The project number corresponding to the workload identity pool.""" + + # STS audience pattern: + # //iam.googleapis.com/projects/$PROJECT_NUMBER/locations/... + components = self._audience.split("/") + try: + project_index = components.index("projects") + if project_index + 1 < len(components): + return components[project_index + 1] or None + except ValueError: + return None + + @property + def token_info_url(self): + """Optional[str]: The STS token introspection endpoint.""" + + return self._token_info_url + + @_helpers.copy_docstring(credentials.Credentials) + def get_cred_info(self): + if self._cred_file_path: + cred_info_json = { + "credential_source": self._cred_file_path, + "credential_type": "external account credentials", + } + if self.service_account_email: + cred_info_json["principal"] = self.service_account_email + return cred_info_json + return None + + @_helpers.copy_docstring(credentials.Scoped) + def with_scopes(self, scopes, default_scopes=None): + kwargs = self._constructor_args() + kwargs.update(scopes=scopes, default_scopes=default_scopes) + scoped = self.__class__(**kwargs) + scoped._cred_file_path = self._cred_file_path + scoped._metrics_options = self._metrics_options + return scoped + + @abc.abstractmethod + def retrieve_subject_token(self, request): + """Retrieves the subject token using the credential_source object. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + Returns: + str: The retrieved subject token. + """ + # pylint: disable=missing-raises-doc + # (pylint doesn't recognize that this is abstract) + raise NotImplementedError("retrieve_subject_token must be implemented") + + def get_project_id(self, request): + """Retrieves the project ID corresponding to the workload identity or workforce pool. + For workforce pool credentials, it returns the project ID corresponding to + the workforce_pool_user_project. + + When not determinable, None is returned. + + This is introduced to support the current pattern of using the Auth library: + + credentials, project_id = google.auth.default() + + The resource may not have permission (resourcemanager.projects.get) to + call this API or the required scopes may not be selected: + https://cloud.google.com/resource-manager/reference/rest/v1/projects/get#authorization-scopes + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + Returns: + Optional[str]: The project ID corresponding to the workload identity pool + or workforce pool if determinable. + """ + if self._project_id: + # If already retrieved, return the cached project ID value. + return self._project_id + scopes = self._scopes if self._scopes is not None else self._default_scopes + # Scopes are required in order to retrieve a valid access token. + project_number = self.project_number or self._workforce_pool_user_project + if project_number and scopes: + headers = {} + url = _CLOUD_RESOURCE_MANAGER + project_number + self.before_request(request, "GET", url, headers) + response = request(url=url, method="GET", headers=headers) + + response_body = ( + response.data.decode("utf-8") + if hasattr(response.data, "decode") + else response.data + ) + response_data = json.loads(response_body) + + if response.status == 200: + # Cache result as this field is immutable. + self._project_id = response_data.get("projectId") + return self._project_id + + return None + + @_helpers.copy_docstring(credentials.Credentials) + def refresh(self, request): + scopes = self._scopes if self._scopes is not None else self._default_scopes + + # Inject client certificate into request. + if self._mtls_required(): + request = functools.partial( + request, cert=self._get_mtls_cert_and_key_paths() + ) + + if self._should_initialize_impersonated_credentials(): + self._impersonated_credentials = self._initialize_impersonated_credentials() + + if self._impersonated_credentials: + self._impersonated_credentials.refresh(request) + self.token = self._impersonated_credentials.token + self.expiry = self._impersonated_credentials.expiry + else: + now = _helpers.utcnow() + additional_options = None + # Do not pass workforce_pool_user_project when client authentication + # is used. The client ID is sufficient for determining the user project. + if self._workforce_pool_user_project and not self._client_id: + additional_options = {"userProject": self._workforce_pool_user_project} + additional_headers = { + metrics.API_CLIENT_HEADER: metrics.byoid_metrics_header( + self._metrics_options + ) + } + response_data = self._sts_client.exchange_token( + request=request, + grant_type=_STS_GRANT_TYPE, + subject_token=self.retrieve_subject_token(request), + subject_token_type=self._subject_token_type, + audience=self._audience, + scopes=scopes, + requested_token_type=_STS_REQUESTED_TOKEN_TYPE, + additional_options=additional_options, + additional_headers=additional_headers, + ) + self.token = response_data.get("access_token") + expires_in = response_data.get("expires_in") + # Some services do not respect the OAUTH2.0 RFC and send expires_in as a + # JSON String. + if isinstance(expires_in, str): + expires_in = int(expires_in) + + lifetime = datetime.timedelta(seconds=expires_in) + + self.expiry = now + lifetime + + def _make_copy(self): + kwargs = self._constructor_args() + new_cred = self.__class__(**kwargs) + new_cred._cred_file_path = self._cred_file_path + new_cred._metrics_options = self._metrics_options + return new_cred + + @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) + def with_quota_project(self, quota_project_id): + # Return copy of instance with the provided quota project ID. + cred = self._make_copy() + cred._quota_project_id = quota_project_id + return cred + + @_helpers.copy_docstring(credentials.CredentialsWithTokenUri) + def with_token_uri(self, token_uri): + cred = self._make_copy() + cred._token_url = token_uri + return cred + + @_helpers.copy_docstring(credentials.CredentialsWithUniverseDomain) + def with_universe_domain(self, universe_domain): + cred = self._make_copy() + cred._universe_domain = universe_domain + return cred + + def _should_initialize_impersonated_credentials(self): + return ( + self._service_account_impersonation_url is not None + and self._impersonated_credentials is None + ) + + def _initialize_impersonated_credentials(self): + """Generates an impersonated credentials. + + For more details, see `projects.serviceAccounts.generateAccessToken`_. + + .. _projects.serviceAccounts.generateAccessToken: https://cloud.google.com/iam/docs/reference/credentials/rest/v1/projects.serviceAccounts/generateAccessToken + + Returns: + impersonated_credentials.Credential: The impersonated credentials + object. + + Raises: + google.auth.exceptions.RefreshError: If the generateAccessToken + endpoint returned an error. + """ + # Return copy of instance with no service account impersonation. + kwargs = self._constructor_args() + kwargs.update( + service_account_impersonation_url=None, + service_account_impersonation_options={}, + ) + source_credentials = self.__class__(**kwargs) + source_credentials._metrics_options = self._metrics_options + + # Determine target_principal. + target_principal = self.service_account_email + if not target_principal: + raise exceptions.RefreshError( + "Unable to determine target principal from service account impersonation URL." + ) + + scopes = self._scopes if self._scopes is not None else self._default_scopes + # Initialize and return impersonated credentials. + return impersonated_credentials.Credentials( + source_credentials=source_credentials, + target_principal=target_principal, + target_scopes=scopes, + quota_project_id=self._quota_project_id, + iam_endpoint_override=self._service_account_impersonation_url, + lifetime=self._service_account_impersonation_options.get( + "token_lifetime_seconds" + ), + ) + + def _create_default_metrics_options(self): + metrics_options = {} + if self._service_account_impersonation_url: + metrics_options["sa-impersonation"] = "true" + else: + metrics_options["sa-impersonation"] = "false" + if self._service_account_impersonation_options.get("token_lifetime_seconds"): + metrics_options["config-lifetime"] = "true" + else: + metrics_options["config-lifetime"] = "false" + + return metrics_options + + def _mtls_required(self): + """Returns a boolean representing whether the current credential is configured + for mTLS and should add a certificate to the outgoing calls to the sts and service + account impersonation endpoint. + + Returns: + bool: True if the credential is configured for mTLS, False if it is not. + """ + return False + + def _get_mtls_cert_and_key_paths(self): + """Gets the file locations for a certificate and private key file + to be used for configuring mTLS for the sts and service account + impersonation calls. Currently only expected to return a value when using + X509 workload identity federation. + + Returns: + Tuple[str, str]: The cert and key file locations as strings in a tuple. + + Raises: + NotImplementedError: When the current credential is not configured for + mTLS. + """ + raise NotImplementedError( + "_get_mtls_cert_and_key_location must be implemented." + ) + + @classmethod + def from_info(cls, info, **kwargs): + """Creates a Credentials instance from parsed external account info. + + Args: + info (Mapping[str, str]): The external account info in Google + format. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.identity_pool.Credentials: The constructed + credentials. + + Raises: + InvalidValue: For invalid parameters. + """ + return cls( + audience=info.get("audience"), + subject_token_type=info.get("subject_token_type"), + token_url=info.get("token_url"), + token_info_url=info.get("token_info_url"), + service_account_impersonation_url=info.get( + "service_account_impersonation_url" + ), + service_account_impersonation_options=info.get( + "service_account_impersonation" + ) + or {}, + client_id=info.get("client_id"), + client_secret=info.get("client_secret"), + credential_source=info.get("credential_source"), + quota_project_id=info.get("quota_project_id"), + workforce_pool_user_project=info.get("workforce_pool_user_project"), + universe_domain=info.get( + "universe_domain", credentials.DEFAULT_UNIVERSE_DOMAIN + ), + **kwargs + ) + + @classmethod + def from_file(cls, filename, **kwargs): + """Creates a Credentials instance from an external account json file. + + Args: + filename (str): The path to the external account json file. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.identity_pool.Credentials: The constructed + credentials. + """ + with io.open(filename, "r", encoding="utf-8") as json_file: + data = json.load(json_file) + return cls.from_info(data, **kwargs) diff --git a/lib/python3.10/site-packages/google/auth/external_account_authorized_user.py b/lib/python3.10/site-packages/google/auth/external_account_authorized_user.py new file mode 100644 index 0000000000000000000000000000000000000000..4d0c3c680697ddb7f9a10a1a93d86af3379c4d8d --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/external_account_authorized_user.py @@ -0,0 +1,380 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""External Account Authorized User Credentials. +This module provides credentials based on OAuth 2.0 access and refresh tokens. +These credentials usually access resources on behalf of a user (resource +owner). + +Specifically, these are sourced using external identities via Workforce Identity Federation. + +Obtaining the initial access and refresh token can be done through the Google Cloud CLI. + +Example credential: +{ + "type": "external_account_authorized_user", + "audience": "//iam.googleapis.com/locations/global/workforcePools/$WORKFORCE_POOL_ID/providers/$PROVIDER_ID", + "refresh_token": "refreshToken", + "token_url": "https://sts.googleapis.com/v1/oauth/token", + "token_info_url": "https://sts.googleapis.com/v1/instrospect", + "client_id": "clientId", + "client_secret": "clientSecret" +} +""" + +import datetime +import io +import json + +from google.auth import _helpers +from google.auth import credentials +from google.auth import exceptions +from google.oauth2 import sts +from google.oauth2 import utils + +_EXTERNAL_ACCOUNT_AUTHORIZED_USER_JSON_TYPE = "external_account_authorized_user" + + +class Credentials( + credentials.CredentialsWithQuotaProject, + credentials.ReadOnlyScoped, + credentials.CredentialsWithTokenUri, +): + """Credentials for External Account Authorized Users. + + This is used to instantiate Credentials for exchanging refresh tokens from + authorized users for Google access token and authorizing requests to Google + APIs. + + The credentials are considered immutable. If you want to modify the + quota project, use `with_quota_project` and if you want to modify the token + uri, use `with_token_uri`. + """ + + def __init__( + self, + token=None, + expiry=None, + refresh_token=None, + audience=None, + client_id=None, + client_secret=None, + token_url=None, + token_info_url=None, + revoke_url=None, + scopes=None, + quota_project_id=None, + universe_domain=credentials.DEFAULT_UNIVERSE_DOMAIN, + ): + """Instantiates a external account authorized user credentials object. + + Args: + token (str): The OAuth 2.0 access token. Can be None if refresh information + is provided. + expiry (datetime.datetime): The optional expiration datetime of the OAuth 2.0 access + token. + refresh_token (str): The optional OAuth 2.0 refresh token. If specified, + credentials can be refreshed. + audience (str): The optional STS audience which contains the resource name for the workforce + pool and the provider identifier in that pool. + client_id (str): The OAuth 2.0 client ID. Must be specified for refresh, can be left as + None if the token can not be refreshed. + client_secret (str): The OAuth 2.0 client secret. Must be specified for refresh, can be + left as None if the token can not be refreshed. + token_url (str): The optional STS token exchange endpoint for refresh. Must be specified for + refresh, can be left as None if the token can not be refreshed. + token_info_url (str): The optional STS endpoint URL for token introspection. + revoke_url (str): The optional STS endpoint URL for revoking tokens. + quota_project_id (str): The optional project ID used for quota and billing. + This project may be different from the project used to + create the credentials. + universe_domain (Optional[str]): The universe domain. The default value + is googleapis.com. + + Returns: + google.auth.external_account_authorized_user.Credentials: The + constructed credentials. + """ + super(Credentials, self).__init__() + + self.token = token + self.expiry = expiry + self._audience = audience + self._refresh_token = refresh_token + self._token_url = token_url + self._token_info_url = token_info_url + self._client_id = client_id + self._client_secret = client_secret + self._revoke_url = revoke_url + self._quota_project_id = quota_project_id + self._scopes = scopes + self._universe_domain = universe_domain or credentials.DEFAULT_UNIVERSE_DOMAIN + self._cred_file_path = None + + if not self.valid and not self.can_refresh: + raise exceptions.InvalidOperation( + "Token should be created with fields to make it valid (`token` and " + "`expiry`), or fields to allow it to refresh (`refresh_token`, " + "`token_url`, `client_id`, `client_secret`)." + ) + + self._client_auth = None + if self._client_id: + self._client_auth = utils.ClientAuthentication( + utils.ClientAuthType.basic, self._client_id, self._client_secret + ) + self._sts_client = sts.Client(self._token_url, self._client_auth) + + @property + def info(self): + """Generates the serializable dictionary representation of the current + credentials. + + Returns: + Mapping: The dictionary representation of the credentials. This is the + reverse of the "from_info" method defined in this class. It is + useful for serializing the current credentials so it can deserialized + later. + """ + config_info = self.constructor_args() + config_info.update(type=_EXTERNAL_ACCOUNT_AUTHORIZED_USER_JSON_TYPE) + if config_info["expiry"]: + config_info["expiry"] = config_info["expiry"].isoformat() + "Z" + + return {key: value for key, value in config_info.items() if value is not None} + + def constructor_args(self): + return { + "audience": self._audience, + "refresh_token": self._refresh_token, + "token_url": self._token_url, + "token_info_url": self._token_info_url, + "client_id": self._client_id, + "client_secret": self._client_secret, + "token": self.token, + "expiry": self.expiry, + "revoke_url": self._revoke_url, + "scopes": self._scopes, + "quota_project_id": self._quota_project_id, + "universe_domain": self._universe_domain, + } + + @property + def scopes(self): + """Optional[str]: The OAuth 2.0 permission scopes.""" + return self._scopes + + @property + def requires_scopes(self): + """ False: OAuth 2.0 credentials have their scopes set when + the initial token is requested and can not be changed.""" + return False + + @property + def client_id(self): + """Optional[str]: The OAuth 2.0 client ID.""" + return self._client_id + + @property + def client_secret(self): + """Optional[str]: The OAuth 2.0 client secret.""" + return self._client_secret + + @property + def audience(self): + """Optional[str]: The STS audience which contains the resource name for the + workforce pool and the provider identifier in that pool.""" + return self._audience + + @property + def refresh_token(self): + """Optional[str]: The OAuth 2.0 refresh token.""" + return self._refresh_token + + @property + def token_url(self): + """Optional[str]: The STS token exchange endpoint for refresh.""" + return self._token_url + + @property + def token_info_url(self): + """Optional[str]: The STS endpoint for token info.""" + return self._token_info_url + + @property + def revoke_url(self): + """Optional[str]: The STS endpoint for token revocation.""" + return self._revoke_url + + @property + def is_user(self): + """ True: This credential always represents a user.""" + return True + + @property + def can_refresh(self): + return all( + (self._refresh_token, self._token_url, self._client_id, self._client_secret) + ) + + def get_project_id(self, request=None): + """Retrieves the project ID corresponding to the workload identity or workforce pool. + For workforce pool credentials, it returns the project ID corresponding to + the workforce_pool_user_project. + + When not determinable, None is returned. + + Args: + request (google.auth.transport.requests.Request): Request object. + Unused here, but passed from _default.default(). + + Return: + str: project ID is not determinable for this credential type so it returns None + """ + + return None + + def to_json(self, strip=None): + """Utility function that creates a JSON representation of this + credential. + Args: + strip (Sequence[str]): Optional list of members to exclude from the + generated JSON. + Returns: + str: A JSON representation of this instance. When converted into + a dictionary, it can be passed to from_info() + to create a new instance. + """ + strip = strip if strip else [] + return json.dumps({k: v for (k, v) in self.info.items() if k not in strip}) + + def refresh(self, request): + """Refreshes the access token. + + Args: + request (google.auth.transport.Request): The object used to make + HTTP requests. + + Raises: + google.auth.exceptions.RefreshError: If the credentials could + not be refreshed. + """ + if not self.can_refresh: + raise exceptions.RefreshError( + "The credentials do not contain the necessary fields need to " + "refresh the access token. You must specify refresh_token, " + "token_url, client_id, and client_secret." + ) + + now = _helpers.utcnow() + response_data = self._make_sts_request(request) + + self.token = response_data.get("access_token") + + lifetime = datetime.timedelta(seconds=response_data.get("expires_in")) + self.expiry = now + lifetime + + if "refresh_token" in response_data: + self._refresh_token = response_data["refresh_token"] + + def _make_sts_request(self, request): + return self._sts_client.refresh_token(request, self._refresh_token) + + @_helpers.copy_docstring(credentials.Credentials) + def get_cred_info(self): + if self._cred_file_path: + return { + "credential_source": self._cred_file_path, + "credential_type": "external account authorized user credentials", + } + return None + + def _make_copy(self): + kwargs = self.constructor_args() + cred = self.__class__(**kwargs) + cred._cred_file_path = self._cred_file_path + return cred + + @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) + def with_quota_project(self, quota_project_id): + cred = self._make_copy() + cred._quota_project_id = quota_project_id + return cred + + @_helpers.copy_docstring(credentials.CredentialsWithTokenUri) + def with_token_uri(self, token_uri): + cred = self._make_copy() + cred._token_url = token_uri + return cred + + @_helpers.copy_docstring(credentials.CredentialsWithUniverseDomain) + def with_universe_domain(self, universe_domain): + cred = self._make_copy() + cred._universe_domain = universe_domain + return cred + + @classmethod + def from_info(cls, info, **kwargs): + """Creates a Credentials instance from parsed external account info. + + Args: + info (Mapping[str, str]): The external account info in Google + format. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.external_account_authorized_user.Credentials: The + constructed credentials. + + Raises: + ValueError: For invalid parameters. + """ + expiry = info.get("expiry") + if expiry: + expiry = datetime.datetime.strptime( + expiry.rstrip("Z").split(".")[0], "%Y-%m-%dT%H:%M:%S" + ) + return cls( + audience=info.get("audience"), + refresh_token=info.get("refresh_token"), + token_url=info.get("token_url"), + token_info_url=info.get("token_info_url"), + client_id=info.get("client_id"), + client_secret=info.get("client_secret"), + token=info.get("token"), + expiry=expiry, + revoke_url=info.get("revoke_url"), + quota_project_id=info.get("quota_project_id"), + scopes=info.get("scopes"), + universe_domain=info.get( + "universe_domain", credentials.DEFAULT_UNIVERSE_DOMAIN + ), + **kwargs + ) + + @classmethod + def from_file(cls, filename, **kwargs): + """Creates a Credentials instance from an external account json file. + + Args: + filename (str): The path to the external account json file. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.external_account_authorized_user.Credentials: The + constructed credentials. + """ + with io.open(filename, "r", encoding="utf-8") as json_file: + data = json.load(json_file) + return cls.from_info(data, **kwargs) diff --git a/lib/python3.10/site-packages/google/auth/iam.py b/lib/python3.10/site-packages/google/auth/iam.py new file mode 100644 index 0000000000000000000000000000000000000000..1e4cdffec188b74eeaa729e54ff4f2bdd806a5d6 --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/iam.py @@ -0,0 +1,136 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tools for using the Google `Cloud Identity and Access Management (IAM) +API`_'s auth-related functionality. + +.. _Cloud Identity and Access Management (IAM) API: + https://cloud.google.com/iam/docs/ +""" + +import base64 +import http.client as http_client +import json + +from google.auth import _exponential_backoff +from google.auth import _helpers +from google.auth import credentials +from google.auth import crypt +from google.auth import exceptions + +IAM_RETRY_CODES = { + http_client.INTERNAL_SERVER_ERROR, + http_client.BAD_GATEWAY, + http_client.SERVICE_UNAVAILABLE, + http_client.GATEWAY_TIMEOUT, +} + +_IAM_SCOPE = ["https://www.googleapis.com/auth/iam"] + +_IAM_ENDPOINT = ( + "https://iamcredentials.googleapis.com/v1/projects/-" + + "/serviceAccounts/{}:generateAccessToken" +) + +_IAM_SIGN_ENDPOINT = ( + "https://iamcredentials.googleapis.com/v1/projects/-" + + "/serviceAccounts/{}:signBlob" +) + +_IAM_SIGNJWT_ENDPOINT = ( + "https://iamcredentials.googleapis.com/v1/projects/-" + + "/serviceAccounts/{}:signJwt" +) + +_IAM_IDTOKEN_ENDPOINT = ( + "https://iamcredentials.googleapis.com/v1/" + + "projects/-/serviceAccounts/{}:generateIdToken" +) + + +class Signer(crypt.Signer): + """Signs messages using the IAM `signBlob API`_. + + This is useful when you need to sign bytes but do not have access to the + credential's private key file. + + .. _signBlob API: + https://cloud.google.com/iam/reference/rest/v1/projects.serviceAccounts + /signBlob + """ + + def __init__(self, request, credentials, service_account_email): + """ + Args: + request (google.auth.transport.Request): The object used to make + HTTP requests. + credentials (google.auth.credentials.Credentials): The credentials + that will be used to authenticate the request to the IAM API. + The credentials must have of one the following scopes: + + - https://www.googleapis.com/auth/iam + - https://www.googleapis.com/auth/cloud-platform + service_account_email (str): The service account email identifying + which service account to use to sign bytes. Often, this can + be the same as the service account email in the given + credentials. + """ + self._request = request + self._credentials = credentials + self._service_account_email = service_account_email + + def _make_signing_request(self, message): + """Makes a request to the API signBlob API.""" + message = _helpers.to_bytes(message) + + method = "POST" + url = _IAM_SIGN_ENDPOINT.replace( + credentials.DEFAULT_UNIVERSE_DOMAIN, self._credentials.universe_domain + ).format(self._service_account_email) + headers = {"Content-Type": "application/json"} + body = json.dumps( + {"payload": base64.b64encode(message).decode("utf-8")} + ).encode("utf-8") + + retries = _exponential_backoff.ExponentialBackoff() + for _ in retries: + self._credentials.before_request(self._request, method, url, headers) + + response = self._request(url=url, method=method, body=body, headers=headers) + + if response.status in IAM_RETRY_CODES: + continue + + if response.status != http_client.OK: + raise exceptions.TransportError( + "Error calling the IAM signBlob API: {}".format(response.data) + ) + + return json.loads(response.data.decode("utf-8")) + raise exceptions.TransportError("exhausted signBlob endpoint retries") + + @property + def key_id(self): + """Optional[str]: The key ID used to identify this private key. + + .. warning:: + This is always ``None``. The key ID used by IAM can not + be reliably determined ahead of time. + """ + return None + + @_helpers.copy_docstring(crypt.Signer) + def sign(self, message): + response = self._make_signing_request(message) + return base64.b64decode(response["signedBlob"]) diff --git a/lib/python3.10/site-packages/google/auth/identity_pool.py b/lib/python3.10/site-packages/google/auth/identity_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..c06f88428702cd4a6b96a1e0dbf5e21582e433af --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/identity_pool.py @@ -0,0 +1,528 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Identity Pool Credentials. + +This module provides credentials to access Google Cloud resources from on-prem +or non-Google Cloud platforms which support external credentials (e.g. OIDC ID +tokens) retrieved from local file locations or local servers. This includes +Microsoft Azure and OIDC identity providers (e.g. K8s workloads registered with +Hub with Hub workload identity enabled). + +These credentials are recommended over the use of service account credentials +in on-prem/non-Google Cloud platforms as they do not involve the management of +long-live service account private keys. + +Identity Pool Credentials are initialized using external_account +arguments which are typically loaded from an external credentials file or +an external credentials URL. + +This module also provides a definition for an abstract subject token supplier. +This supplier can be implemented to return a valid OIDC or SAML2.0 subject token +and used to create Identity Pool credentials. The credentials will then call the +supplier instead of using pre-defined methods such as reading a local file or +calling a URL. +""" + +try: + from collections.abc import Mapping +# Python 2.7 compatibility +except ImportError: # pragma: NO COVER + from collections import Mapping # type: ignore +import abc +import base64 +import json +import os +from typing import NamedTuple + +from google.auth import _helpers +from google.auth import exceptions +from google.auth import external_account +from google.auth.transport import _mtls_helper + + +class SubjectTokenSupplier(metaclass=abc.ABCMeta): + """Base class for subject token suppliers. This can be implemented with custom logic to retrieve + a subject token to exchange for a Google Cloud access token when using Workload or + Workforce Identity Federation. The identity pool credential does not cache the subject token, + so caching logic should be added in the implementation. + """ + + @abc.abstractmethod + def get_subject_token(self, context, request): + """Returns the requested subject token. The subject token must be valid. + + .. warning: This is not cached by the calling Google credential, so caching logic should be implemented in the supplier. + + Args: + context (google.auth.externalaccount.SupplierContext): The context object + containing information about the requested audience and subject token type. + request (google.auth.transport.Request): The object used to make + HTTP requests. + + Raises: + google.auth.exceptions.RefreshError: If an error is encountered during + subject token retrieval logic. + + Returns: + str: The requested subject token string. + """ + raise NotImplementedError("") + + +class _TokenContent(NamedTuple): + """Models the token content response from file and url internal suppliers. + Attributes: + content (str): The string content of the file or URL response. + location (str): The location the content was retrieved from. This will either be a file location or a URL. + """ + + content: str + location: str + + +class _FileSupplier(SubjectTokenSupplier): + """ Internal implementation of subject token supplier which supports reading a subject token from a file.""" + + def __init__(self, path, format_type, subject_token_field_name): + self._path = path + self._format_type = format_type + self._subject_token_field_name = subject_token_field_name + + @_helpers.copy_docstring(SubjectTokenSupplier) + def get_subject_token(self, context, request): + if not os.path.exists(self._path): + raise exceptions.RefreshError("File '{}' was not found.".format(self._path)) + + with open(self._path, "r", encoding="utf-8") as file_obj: + token_content = _TokenContent(file_obj.read(), self._path) + + return _parse_token_data( + token_content, self._format_type, self._subject_token_field_name + ) + + +class _UrlSupplier(SubjectTokenSupplier): + """ Internal implementation of subject token supplier which supports retrieving a subject token by calling a URL endpoint.""" + + def __init__(self, url, format_type, subject_token_field_name, headers): + self._url = url + self._format_type = format_type + self._subject_token_field_name = subject_token_field_name + self._headers = headers + + @_helpers.copy_docstring(SubjectTokenSupplier) + def get_subject_token(self, context, request): + response = request(url=self._url, method="GET", headers=self._headers) + + # support both string and bytes type response.data + response_body = ( + response.data.decode("utf-8") + if hasattr(response.data, "decode") + else response.data + ) + + if response.status != 200: + raise exceptions.RefreshError( + "Unable to retrieve Identity Pool subject token", response_body + ) + token_content = _TokenContent(response_body, self._url) + return _parse_token_data( + token_content, self._format_type, self._subject_token_field_name + ) + + +class _X509Supplier(SubjectTokenSupplier): + """Internal supplier for X509 workload credentials. This class is used internally and always returns an empty string as the subject token.""" + + def __init__(self, trust_chain_path, leaf_cert_callback): + self._trust_chain_path = trust_chain_path + self._leaf_cert_callback = leaf_cert_callback + + @_helpers.copy_docstring(SubjectTokenSupplier) + def get_subject_token(self, context, request): + # Import OpennSSL inline because it is an extra import only required by customers + # using mTLS. + from OpenSSL import crypto + + leaf_cert = crypto.load_certificate( + crypto.FILETYPE_PEM, self._leaf_cert_callback() + ) + trust_chain = self._read_trust_chain() + cert_chain = [] + + cert_chain.append(_X509Supplier._encode_cert(leaf_cert)) + + if trust_chain is None or len(trust_chain) == 0: + return json.dumps(cert_chain) + + # Append the first cert if it is not the leaf cert. + first_cert = _X509Supplier._encode_cert(trust_chain[0]) + if first_cert != cert_chain[0]: + cert_chain.append(first_cert) + + for i in range(1, len(trust_chain)): + encoded = _X509Supplier._encode_cert(trust_chain[i]) + # Check if the current cert is the leaf cert and raise an exception if it is. + if encoded == cert_chain[0]: + raise exceptions.RefreshError( + "The leaf certificate must be at the top of the trust chain file" + ) + else: + cert_chain.append(encoded) + return json.dumps(cert_chain) + + def _read_trust_chain(self): + # Import OpennSSL inline because it is an extra import only required by customers + # using mTLS. + from OpenSSL import crypto + + certificate_trust_chain = [] + # If no trust chain path was provided, return an empty list. + if self._trust_chain_path is None or self._trust_chain_path == "": + return certificate_trust_chain + try: + # Open the trust chain file. + with open(self._trust_chain_path, "rb") as f: + trust_chain_data = f.read() + # Split PEM data into individual certificates. + cert_blocks = trust_chain_data.split(b"-----BEGIN CERTIFICATE-----") + for cert_block in cert_blocks: + # Skip empty blocks. + if cert_block.strip(): + cert_data = b"-----BEGIN CERTIFICATE-----" + cert_block + try: + # Load each certificate and add it to the trust chain. + cert = crypto.load_certificate( + crypto.FILETYPE_PEM, cert_data + ) + certificate_trust_chain.append(cert) + except Exception as e: + raise exceptions.RefreshError( + "Error loading PEM certificates from the trust chain file '{}'".format( + self._trust_chain_path + ) + ) from e + return certificate_trust_chain + except FileNotFoundError: + raise exceptions.RefreshError( + "Trust chain file '{}' was not found.".format(self._trust_chain_path) + ) + + def _encode_cert(cert): + # Import OpennSSL inline because it is an extra import only required by customers + # using mTLS. + from OpenSSL import crypto + + return base64.b64encode( + crypto.dump_certificate(crypto.FILETYPE_ASN1, cert) + ).decode("utf-8") + + +def _parse_token_data(token_content, format_type="text", subject_token_field_name=None): + if format_type == "text": + token = token_content.content + else: + try: + # Parse file content as JSON. + response_data = json.loads(token_content.content) + # Get the subject_token. + token = response_data[subject_token_field_name] + except (KeyError, ValueError): + raise exceptions.RefreshError( + "Unable to parse subject_token from JSON file '{}' using key '{}'".format( + token_content.location, subject_token_field_name + ) + ) + if not token: + raise exceptions.RefreshError( + "Missing subject_token in the credential_source file" + ) + return token + + +class Credentials(external_account.Credentials): + """External account credentials sourced from files and URLs.""" + + def __init__( + self, + audience, + subject_token_type, + token_url=external_account._DEFAULT_TOKEN_URL, + credential_source=None, + subject_token_supplier=None, + *args, + **kwargs + ): + """Instantiates an external account credentials object from a file/URL. + + Args: + audience (str): The STS audience field. + subject_token_type (str): The subject token type based on the Oauth2.0 token exchange spec. + Expected values include:: + + “urn:ietf:params:oauth:token-type:jwt” + “urn:ietf:params:oauth:token-type:id-token” + “urn:ietf:params:oauth:token-type:saml2” + + token_url (Optional [str]): The STS endpoint URL. If not provided, will default to "https://sts.googleapis.com/v1/token". + credential_source (Optional [Mapping]): The credential source dictionary used to + provide instructions on how to retrieve external credential to be + exchanged for Google access tokens. Either a credential source or + a subject token supplier must be provided. + + Example credential_source for url-sourced credential:: + + { + "url": "http://www.example.com", + "format": { + "type": "json", + "subject_token_field_name": "access_token", + }, + "headers": {"foo": "bar"}, + } + + Example credential_source for file-sourced credential:: + + { + "file": "/path/to/token/file.txt" + } + subject_token_supplier (Optional [SubjectTokenSupplier]): Optional subject token supplier. + This will be called to supply a valid subject token which will then + be exchanged for Google access tokens. Either a subject token supplier + or a credential source must be provided. + args (List): Optional positional arguments passed into the underlying :meth:`~external_account.Credentials.__init__` method. + kwargs (Mapping): Optional keyword arguments passed into the underlying :meth:`~external_account.Credentials.__init__` method. + + Raises: + google.auth.exceptions.RefreshError: If an error is encountered during + access token retrieval logic. + ValueError: For invalid parameters. + + .. note:: Typically one of the helper constructors + :meth:`from_file` or + :meth:`from_info` are used instead of calling the constructor directly. + """ + + super(Credentials, self).__init__( + audience=audience, + subject_token_type=subject_token_type, + token_url=token_url, + credential_source=credential_source, + *args, + **kwargs + ) + if credential_source is None and subject_token_supplier is None: + raise exceptions.InvalidValue( + "A valid credential source or a subject token supplier must be provided." + ) + if credential_source is not None and subject_token_supplier is not None: + raise exceptions.InvalidValue( + "Identity pool credential cannot have both a credential source and a subject token supplier." + ) + + if subject_token_supplier is not None: + self._subject_token_supplier = subject_token_supplier + self._credential_source_file = None + self._credential_source_url = None + self._credential_source_certificate = None + else: + if not isinstance(credential_source, Mapping): + self._credential_source_executable = None + raise exceptions.MalformedError( + "Invalid credential_source. The credential_source is not a dict." + ) + self._credential_source_file = credential_source.get("file") + self._credential_source_url = credential_source.get("url") + self._credential_source_certificate = credential_source.get("certificate") + + # environment_id is only supported in AWS or dedicated future external + # account credentials. + if "environment_id" in credential_source: + raise exceptions.MalformedError( + "Invalid Identity Pool credential_source field 'environment_id'" + ) + + # check that only one of file, url, or certificate are provided. + self._validate_single_source() + + if self._credential_source_certificate: + self._validate_certificate_config() + else: + self._validate_file_or_url_config(credential_source) + + if self._credential_source_file: + self._subject_token_supplier = _FileSupplier( + self._credential_source_file, + self._credential_source_format_type, + self._credential_source_field_name, + ) + elif self._credential_source_url: + self._subject_token_supplier = _UrlSupplier( + self._credential_source_url, + self._credential_source_format_type, + self._credential_source_field_name, + self._credential_source_headers, + ) + else: # self._credential_source_certificate + self._subject_token_supplier = _X509Supplier( + self._trust_chain_path, self._get_cert_bytes + ) + + @_helpers.copy_docstring(external_account.Credentials) + def retrieve_subject_token(self, request): + return self._subject_token_supplier.get_subject_token( + self._supplier_context, request + ) + + def _get_mtls_cert_and_key_paths(self): + if self._credential_source_certificate is None: + raise exceptions.RefreshError( + 'The credential is not configured to use mtls requests. The credential should include a "certificate" section in the credential source.' + ) + else: + return _mtls_helper._get_workload_cert_and_key_paths( + self._certificate_config_location + ) + + def _get_cert_bytes(self): + cert_path, _ = self._get_mtls_cert_and_key_paths() + return _mtls_helper._read_cert_file(cert_path) + + def _mtls_required(self): + return self._credential_source_certificate is not None + + def _create_default_metrics_options(self): + metrics_options = super(Credentials, self)._create_default_metrics_options() + # Check that credential source is a dict before checking for credential type. This check needs to be done + # here because the external_account credential constructor needs to pass the metrics options to the + # impersonated credential object before the identity_pool credentials are validated. + if isinstance(self._credential_source, Mapping): + if self._credential_source.get("file"): + metrics_options["source"] = "file" + elif self._credential_source.get("url"): + metrics_options["source"] = "url" + else: + metrics_options["source"] = "x509" + else: + metrics_options["source"] = "programmatic" + return metrics_options + + def _has_custom_supplier(self): + return self._credential_source is None + + def _constructor_args(self): + args = super(Credentials, self)._constructor_args() + # If a custom supplier was used, append it to the args dict. + if self._has_custom_supplier(): + args.update({"subject_token_supplier": self._subject_token_supplier}) + return args + + def _validate_certificate_config(self): + self._certificate_config_location = self._credential_source_certificate.get( + "certificate_config_location" + ) + use_default = self._credential_source_certificate.get( + "use_default_certificate_config" + ) + self._trust_chain_path = self._credential_source_certificate.get( + "trust_chain_path" + ) + if self._certificate_config_location and use_default: + raise exceptions.MalformedError( + "Invalid certificate configuration, certificate_config_location cannot be specified when use_default_certificate_config = true." + ) + if not self._certificate_config_location and not use_default: + raise exceptions.MalformedError( + "Invalid certificate configuration, use_default_certificate_config should be true if no certificate_config_location is provided." + ) + + def _validate_file_or_url_config(self, credential_source): + self._credential_source_headers = credential_source.get("headers") + credential_source_format = credential_source.get("format", {}) + # Get credential_source format type. When not provided, this + # defaults to text. + self._credential_source_format_type = ( + credential_source_format.get("type") or "text" + ) + if self._credential_source_format_type not in ["text", "json"]: + raise exceptions.MalformedError( + "Invalid credential_source format '{}'".format( + self._credential_source_format_type + ) + ) + # For JSON types, get the required subject_token field name. + if self._credential_source_format_type == "json": + self._credential_source_field_name = credential_source_format.get( + "subject_token_field_name" + ) + if self._credential_source_field_name is None: + raise exceptions.MalformedError( + "Missing subject_token_field_name for JSON credential_source format" + ) + else: + self._credential_source_field_name = None + + def _validate_single_source(self): + credential_sources = [ + self._credential_source_file, + self._credential_source_url, + self._credential_source_certificate, + ] + valid_credential_sources = list( + filter(lambda source: source is not None, credential_sources) + ) + + if len(valid_credential_sources) > 1: + raise exceptions.MalformedError( + "Ambiguous credential_source. 'file', 'url', and 'certificate' are mutually exclusive.." + ) + if len(valid_credential_sources) != 1: + raise exceptions.MalformedError( + "Missing credential_source. A 'file', 'url', or 'certificate' must be provided." + ) + + @classmethod + def from_info(cls, info, **kwargs): + """Creates an Identity Pool Credentials instance from parsed external account info. + + Args: + info (Mapping[str, str]): The Identity Pool external account info in Google + format. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.identity_pool.Credentials: The constructed + credentials. + + Raises: + ValueError: For invalid parameters. + """ + subject_token_supplier = info.get("subject_token_supplier") + kwargs.update({"subject_token_supplier": subject_token_supplier}) + return super(Credentials, cls).from_info(info, **kwargs) + + @classmethod + def from_file(cls, filename, **kwargs): + """Creates an IdentityPool Credentials instance from an external account json file. + + Args: + filename (str): The path to the IdentityPool external account json file. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.identity_pool.Credentials: The constructed + credentials. + """ + return super(Credentials, cls).from_file(filename, **kwargs) diff --git a/lib/python3.10/site-packages/google/auth/impersonated_credentials.py b/lib/python3.10/site-packages/google/auth/impersonated_credentials.py new file mode 100644 index 0000000000000000000000000000000000000000..d49998cfbdc62f765515454067c95bab588db3b1 --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/impersonated_credentials.py @@ -0,0 +1,654 @@ +# Copyright 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Google Cloud Impersonated credentials. + +This module provides authentication for applications where local credentials +impersonates a remote service account using `IAM Credentials API`_. + +This class can be used to impersonate a service account as long as the original +Credential object has the "Service Account Token Creator" role on the target +service account. + + .. _IAM Credentials API: + https://cloud.google.com/iam/credentials/reference/rest/ +""" + +import base64 +import copy +from datetime import datetime +import http.client as http_client +import json + +from google.auth import _exponential_backoff +from google.auth import _helpers +from google.auth import credentials +from google.auth import exceptions +from google.auth import iam +from google.auth import jwt +from google.auth import metrics +from google.oauth2 import _client + + +_REFRESH_ERROR = "Unable to acquire impersonated credentials" + +_DEFAULT_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds + +_GOOGLE_OAUTH2_TOKEN_ENDPOINT = "https://oauth2.googleapis.com/token" + +_SOURCE_CREDENTIAL_AUTHORIZED_USER_TYPE = "authorized_user" +_SOURCE_CREDENTIAL_SERVICE_ACCOUNT_TYPE = "service_account" +_SOURCE_CREDENTIAL_EXTERNAL_ACCOUNT_AUTHORIZED_USER_TYPE = ( + "external_account_authorized_user" +) + + +def _make_iam_token_request( + request, + principal, + headers, + body, + universe_domain=credentials.DEFAULT_UNIVERSE_DOMAIN, + iam_endpoint_override=None, +): + """Makes a request to the Google Cloud IAM service for an access token. + Args: + request (Request): The Request object to use. + principal (str): The principal to request an access token for. + headers (Mapping[str, str]): Map of headers to transmit. + body (Mapping[str, str]): JSON Payload body for the iamcredentials + API call. + iam_endpoint_override (Optiona[str]): The full IAM endpoint override + with the target_principal embedded. This is useful when supporting + impersonation with regional endpoints. + + Raises: + google.auth.exceptions.TransportError: Raised if there is an underlying + HTTP connection error + google.auth.exceptions.RefreshError: Raised if the impersonated + credentials are not available. Common reasons are + `iamcredentials.googleapis.com` is not enabled or the + `Service Account Token Creator` is not assigned + """ + iam_endpoint = iam_endpoint_override or iam._IAM_ENDPOINT.replace( + credentials.DEFAULT_UNIVERSE_DOMAIN, universe_domain + ).format(principal) + + body = json.dumps(body).encode("utf-8") + + response = request(url=iam_endpoint, method="POST", headers=headers, body=body) + + # support both string and bytes type response.data + response_body = ( + response.data.decode("utf-8") + if hasattr(response.data, "decode") + else response.data + ) + + if response.status != http_client.OK: + raise exceptions.RefreshError(_REFRESH_ERROR, response_body) + + try: + token_response = json.loads(response_body) + token = token_response["accessToken"] + expiry = datetime.strptime(token_response["expireTime"], "%Y-%m-%dT%H:%M:%SZ") + + return token, expiry + + except (KeyError, ValueError) as caught_exc: + new_exc = exceptions.RefreshError( + "{}: No access token or invalid expiration in response.".format( + _REFRESH_ERROR + ), + response_body, + ) + raise new_exc from caught_exc + + +class Credentials( + credentials.Scoped, credentials.CredentialsWithQuotaProject, credentials.Signing +): + """This module defines impersonated credentials which are essentially + impersonated identities. + + Impersonated Credentials allows credentials issued to a user or + service account to impersonate another. The target service account must + grant the originating credential principal the + `Service Account Token Creator`_ IAM role: + + For more information about Token Creator IAM role and + IAMCredentials API, see + `Creating Short-Lived Service Account Credentials`_. + + .. _Service Account Token Creator: + https://cloud.google.com/iam/docs/service-accounts#the_service_account_token_creator_role + + .. _Creating Short-Lived Service Account Credentials: + https://cloud.google.com/iam/docs/creating-short-lived-service-account-credentials + + Usage: + + First grant source_credentials the `Service Account Token Creator` + role on the target account to impersonate. In this example, the + service account represented by svc_account.json has the + token creator role on + `impersonated-account@_project_.iam.gserviceaccount.com`. + + Enable the IAMCredentials API on the source project: + `gcloud services enable iamcredentials.googleapis.com`. + + Initialize a source credential which does not have access to + list bucket:: + + from google.oauth2 import service_account + + target_scopes = [ + 'https://www.googleapis.com/auth/devstorage.read_only'] + + source_credentials = ( + service_account.Credentials.from_service_account_file( + '/path/to/svc_account.json', + scopes=target_scopes)) + + Now use the source credentials to acquire credentials to impersonate + another service account:: + + from google.auth import impersonated_credentials + + target_credentials = impersonated_credentials.Credentials( + source_credentials=source_credentials, + target_principal='impersonated-account@_project_.iam.gserviceaccount.com', + target_scopes = target_scopes, + lifetime=500) + + Resource access is granted:: + + client = storage.Client(credentials=target_credentials) + buckets = client.list_buckets(project='your_project') + for bucket in buckets: + print(bucket.name) + """ + + def __init__( + self, + source_credentials, + target_principal, + target_scopes, + delegates=None, + subject=None, + lifetime=_DEFAULT_TOKEN_LIFETIME_SECS, + quota_project_id=None, + iam_endpoint_override=None, + ): + """ + Args: + source_credentials (google.auth.Credentials): The source credential + used as to acquire the impersonated credentials. + target_principal (str): The service account to impersonate. + target_scopes (Sequence[str]): Scopes to request during the + authorization grant. + delegates (Sequence[str]): The chained list of delegates required + to grant the final access_token. If set, the sequence of + identities must have "Service Account Token Creator" capability + granted to the prceeding identity. For example, if set to + [serviceAccountB, serviceAccountC], the source_credential + must have the Token Creator role on serviceAccountB. + serviceAccountB must have the Token Creator on + serviceAccountC. + Finally, C must have Token Creator on target_principal. + If left unset, source_credential must have that role on + target_principal. + lifetime (int): Number of seconds the delegated credential should + be valid for (upto 3600). + quota_project_id (Optional[str]): The project ID used for quota and billing. + This project may be different from the project used to + create the credentials. + iam_endpoint_override (Optional[str]): The full IAM endpoint override + with the target_principal embedded. This is useful when supporting + impersonation with regional endpoints. + subject (Optional[str]): sub field of a JWT. This field should only be set + if you wish to impersonate as a user. This feature is useful when + using domain wide delegation. + """ + + super(Credentials, self).__init__() + + self._source_credentials = copy.copy(source_credentials) + # Service account source credentials must have the _IAM_SCOPE + # added to refresh correctly. User credentials cannot have + # their original scopes modified. + if isinstance(self._source_credentials, credentials.Scoped): + self._source_credentials = self._source_credentials.with_scopes( + iam._IAM_SCOPE + ) + # If the source credential is service account and self signed jwt + # is needed, we need to create a jwt credential inside it + if ( + hasattr(self._source_credentials, "_create_self_signed_jwt") + and self._source_credentials._always_use_jwt_access + ): + self._source_credentials._create_self_signed_jwt(None) + + self._universe_domain = source_credentials.universe_domain + self._target_principal = target_principal + self._target_scopes = target_scopes + self._delegates = delegates + self._subject = subject + self._lifetime = lifetime or _DEFAULT_TOKEN_LIFETIME_SECS + self.token = None + self.expiry = _helpers.utcnow() + self._quota_project_id = quota_project_id + self._iam_endpoint_override = iam_endpoint_override + self._cred_file_path = None + + def _metric_header_for_usage(self): + return metrics.CRED_TYPE_SA_IMPERSONATE + + @_helpers.copy_docstring(credentials.Credentials) + def refresh(self, request): + self._update_token(request) + + def _update_token(self, request): + """Updates credentials with a new access_token representing + the impersonated account. + + Args: + request (google.auth.transport.requests.Request): Request object + to use for refreshing credentials. + """ + + # Refresh our source credentials if it is not valid. + if ( + self._source_credentials.token_state == credentials.TokenState.STALE + or self._source_credentials.token_state == credentials.TokenState.INVALID + ): + self._source_credentials.refresh(request) + + body = { + "delegates": self._delegates, + "scope": self._target_scopes, + "lifetime": str(self._lifetime) + "s", + } + + headers = { + "Content-Type": "application/json", + metrics.API_CLIENT_HEADER: metrics.token_request_access_token_impersonate(), + } + + # Apply the source credentials authentication info. + self._source_credentials.apply(headers) + + # If a subject is specified a domain-wide delegation auth-flow is initiated + # to impersonate as the provided subject (user). + if self._subject: + if self.universe_domain != credentials.DEFAULT_UNIVERSE_DOMAIN: + raise exceptions.GoogleAuthError( + "Domain-wide delegation is not supported in universes other " + + "than googleapis.com" + ) + + now = _helpers.utcnow() + payload = { + "iss": self._target_principal, + "scope": _helpers.scopes_to_string(self._target_scopes or ()), + "sub": self._subject, + "aud": _GOOGLE_OAUTH2_TOKEN_ENDPOINT, + "iat": _helpers.datetime_to_secs(now), + "exp": _helpers.datetime_to_secs(now) + _DEFAULT_TOKEN_LIFETIME_SECS, + } + + assertion = _sign_jwt_request( + request=request, + principal=self._target_principal, + headers=headers, + payload=payload, + delegates=self._delegates, + ) + + self.token, self.expiry, _ = _client.jwt_grant( + request, _GOOGLE_OAUTH2_TOKEN_ENDPOINT, assertion + ) + + return + + self.token, self.expiry = _make_iam_token_request( + request=request, + principal=self._target_principal, + headers=headers, + body=body, + universe_domain=self.universe_domain, + iam_endpoint_override=self._iam_endpoint_override, + ) + + def sign_bytes(self, message): + from google.auth.transport.requests import AuthorizedSession + + iam_sign_endpoint = iam._IAM_SIGN_ENDPOINT.replace( + credentials.DEFAULT_UNIVERSE_DOMAIN, self.universe_domain + ).format(self._target_principal) + + body = { + "payload": base64.b64encode(message).decode("utf-8"), + "delegates": self._delegates, + } + + headers = {"Content-Type": "application/json"} + + authed_session = AuthorizedSession(self._source_credentials) + + try: + retries = _exponential_backoff.ExponentialBackoff() + for _ in retries: + response = authed_session.post( + url=iam_sign_endpoint, headers=headers, json=body + ) + if response.status_code in iam.IAM_RETRY_CODES: + continue + if response.status_code != http_client.OK: + raise exceptions.TransportError( + "Error calling sign_bytes: {}".format(response.json()) + ) + + return base64.b64decode(response.json()["signedBlob"]) + finally: + authed_session.close() + raise exceptions.TransportError("exhausted signBlob endpoint retries") + + @property + def signer_email(self): + return self._target_principal + + @property + def service_account_email(self): + return self._target_principal + + @property + def signer(self): + return self + + @property + def requires_scopes(self): + return not self._target_scopes + + @_helpers.copy_docstring(credentials.Credentials) + def get_cred_info(self): + if self._cred_file_path: + return { + "credential_source": self._cred_file_path, + "credential_type": "impersonated credentials", + "principal": self._target_principal, + } + return None + + def _make_copy(self): + cred = self.__class__( + self._source_credentials, + target_principal=self._target_principal, + target_scopes=self._target_scopes, + delegates=self._delegates, + lifetime=self._lifetime, + quota_project_id=self._quota_project_id, + iam_endpoint_override=self._iam_endpoint_override, + ) + cred._cred_file_path = self._cred_file_path + return cred + + @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) + def with_quota_project(self, quota_project_id): + cred = self._make_copy() + cred._quota_project_id = quota_project_id + return cred + + @_helpers.copy_docstring(credentials.Scoped) + def with_scopes(self, scopes, default_scopes=None): + cred = self._make_copy() + cred._target_scopes = scopes or default_scopes + return cred + + @classmethod + def from_impersonated_service_account_info(cls, info, scopes=None): + """Creates a Credentials instance from parsed impersonated service account credentials info. + + Args: + info (Mapping[str, str]): The impersonated service account credentials info in Google + format. + scopes (Sequence[str]): Optional list of scopes to include in the + credentials. + + Returns: + google.oauth2.credentials.Credentials: The constructed + credentials. + + Raises: + InvalidType: If the info["source_credentials"] are not a supported impersonation type + InvalidValue: If the info["service_account_impersonation_url"] is not in the expected format. + ValueError: If the info is not in the expected format. + """ + + source_credentials_info = info.get("source_credentials") + source_credentials_type = source_credentials_info.get("type") + if source_credentials_type == _SOURCE_CREDENTIAL_AUTHORIZED_USER_TYPE: + from google.oauth2 import credentials + + source_credentials = credentials.Credentials.from_authorized_user_info( + source_credentials_info + ) + elif source_credentials_type == _SOURCE_CREDENTIAL_SERVICE_ACCOUNT_TYPE: + from google.oauth2 import service_account + + source_credentials = service_account.Credentials.from_service_account_info( + source_credentials_info + ) + elif ( + source_credentials_type + == _SOURCE_CREDENTIAL_EXTERNAL_ACCOUNT_AUTHORIZED_USER_TYPE + ): + from google.auth import external_account_authorized_user + + source_credentials = external_account_authorized_user.Credentials.from_info( + source_credentials_info + ) + else: + raise exceptions.InvalidType( + "source credential of type {} is not supported.".format( + source_credentials_type + ) + ) + + impersonation_url = info.get("service_account_impersonation_url") + start_index = impersonation_url.rfind("/") + end_index = impersonation_url.find(":generateAccessToken") + if start_index == -1 or end_index == -1 or start_index > end_index: + raise exceptions.InvalidValue( + "Cannot extract target principal from {}".format(impersonation_url) + ) + target_principal = impersonation_url[start_index + 1 : end_index] + delegates = info.get("delegates") + quota_project_id = info.get("quota_project_id") + + return cls( + source_credentials, + target_principal, + scopes, + delegates, + quota_project_id=quota_project_id, + ) + + +class IDTokenCredentials(credentials.CredentialsWithQuotaProject): + """Open ID Connect ID Token-based service account credentials. + + """ + + def __init__( + self, + target_credentials, + target_audience=None, + include_email=False, + quota_project_id=None, + ): + """ + Args: + target_credentials (google.auth.Credentials): The target + credential used as to acquire the id tokens for. + target_audience (string): Audience to issue the token for. + include_email (bool): Include email in IdToken + quota_project_id (Optional[str]): The project ID used for + quota and billing. + """ + super(IDTokenCredentials, self).__init__() + + if not isinstance(target_credentials, Credentials): + raise exceptions.GoogleAuthError( + "Provided Credential must be " "impersonated_credentials" + ) + self._target_credentials = target_credentials + self._target_audience = target_audience + self._include_email = include_email + self._quota_project_id = quota_project_id + + def from_credentials(self, target_credentials, target_audience=None): + return self.__class__( + target_credentials=target_credentials, + target_audience=target_audience, + include_email=self._include_email, + quota_project_id=self._quota_project_id, + ) + + def with_target_audience(self, target_audience): + return self.__class__( + target_credentials=self._target_credentials, + target_audience=target_audience, + include_email=self._include_email, + quota_project_id=self._quota_project_id, + ) + + def with_include_email(self, include_email): + return self.__class__( + target_credentials=self._target_credentials, + target_audience=self._target_audience, + include_email=include_email, + quota_project_id=self._quota_project_id, + ) + + @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) + def with_quota_project(self, quota_project_id): + return self.__class__( + target_credentials=self._target_credentials, + target_audience=self._target_audience, + include_email=self._include_email, + quota_project_id=quota_project_id, + ) + + @_helpers.copy_docstring(credentials.Credentials) + def refresh(self, request): + from google.auth.transport.requests import AuthorizedSession + + iam_sign_endpoint = iam._IAM_IDTOKEN_ENDPOINT.replace( + credentials.DEFAULT_UNIVERSE_DOMAIN, + self._target_credentials.universe_domain, + ).format(self._target_credentials.signer_email) + + body = { + "audience": self._target_audience, + "delegates": self._target_credentials._delegates, + "includeEmail": self._include_email, + } + + headers = { + "Content-Type": "application/json", + metrics.API_CLIENT_HEADER: metrics.token_request_id_token_impersonate(), + } + + authed_session = AuthorizedSession( + self._target_credentials._source_credentials, auth_request=request + ) + + try: + response = authed_session.post( + url=iam_sign_endpoint, + headers=headers, + data=json.dumps(body).encode("utf-8"), + ) + finally: + authed_session.close() + + if response.status_code != http_client.OK: + raise exceptions.RefreshError( + "Error getting ID token: {}".format(response.json()) + ) + + id_token = response.json()["token"] + self.token = id_token + self.expiry = datetime.utcfromtimestamp( + jwt.decode(id_token, verify=False)["exp"] + ) + + +def _sign_jwt_request(request, principal, headers, payload, delegates=[]): + """Makes a request to the Google Cloud IAM service to sign a JWT using a + service account's system-managed private key. + Args: + request (Request): The Request object to use. + principal (str): The principal to request an access token for. + headers (Mapping[str, str]): Map of headers to transmit. + payload (Mapping[str, str]): The JWT payload to sign. Must be a + serialized JSON object that contains a JWT Claims Set. + delegates (Sequence[str]): The chained list of delegates required + to grant the final access_token. If set, the sequence of + identities must have "Service Account Token Creator" capability + granted to the prceeding identity. For example, if set to + [serviceAccountB, serviceAccountC], the source_credential + must have the Token Creator role on serviceAccountB. + serviceAccountB must have the Token Creator on + serviceAccountC. + Finally, C must have Token Creator on target_principal. + If left unset, source_credential must have that role on + target_principal. + + Raises: + google.auth.exceptions.TransportError: Raised if there is an underlying + HTTP connection error + google.auth.exceptions.RefreshError: Raised if the impersonated + credentials are not available. Common reasons are + `iamcredentials.googleapis.com` is not enabled or the + `Service Account Token Creator` is not assigned + """ + iam_endpoint = iam._IAM_SIGNJWT_ENDPOINT.format(principal) + + body = {"delegates": delegates, "payload": json.dumps(payload)} + body = json.dumps(body).encode("utf-8") + + response = request(url=iam_endpoint, method="POST", headers=headers, body=body) + + # support both string and bytes type response.data + response_body = ( + response.data.decode("utf-8") + if hasattr(response.data, "decode") + else response.data + ) + + if response.status != http_client.OK: + raise exceptions.RefreshError(_REFRESH_ERROR, response_body) + + try: + jwt_response = json.loads(response_body) + signed_jwt = jwt_response["signedJwt"] + return signed_jwt + + except (KeyError, ValueError) as caught_exc: + new_exc = exceptions.RefreshError( + "{}: No signed JWT in response.".format(_REFRESH_ERROR), response_body + ) + raise new_exc from caught_exc diff --git a/lib/python3.10/site-packages/google/auth/jwt.py b/lib/python3.10/site-packages/google/auth/jwt.py new file mode 100644 index 0000000000000000000000000000000000000000..1ebd565d4ecd3e1365a095e30b69d0e665e09a11 --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/jwt.py @@ -0,0 +1,878 @@ +# Copyright 2016 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""JSON Web Tokens + +Provides support for creating (encoding) and verifying (decoding) JWTs, +especially JWTs generated and consumed by Google infrastructure. + +See `rfc7519`_ for more details on JWTs. + +To encode a JWT use :func:`encode`:: + + from google.auth import crypt + from google.auth import jwt + + signer = crypt.Signer(private_key) + payload = {'some': 'payload'} + encoded = jwt.encode(signer, payload) + +To decode a JWT and verify claims use :func:`decode`:: + + claims = jwt.decode(encoded, certs=public_certs) + +You can also skip verification:: + + claims = jwt.decode(encoded, verify=False) + +.. _rfc7519: https://tools.ietf.org/html/rfc7519 + +""" + +try: + from collections.abc import Mapping +# Python 2.7 compatibility +except ImportError: # pragma: NO COVER + from collections import Mapping # type: ignore +import copy +import datetime +import json +import urllib + +import cachetools + +from google.auth import _helpers +from google.auth import _service_account_info +from google.auth import crypt +from google.auth import exceptions +import google.auth.credentials + +try: + from google.auth.crypt import es256 +except ImportError: # pragma: NO COVER + es256 = None # type: ignore + +_DEFAULT_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds +_DEFAULT_MAX_CACHE_SIZE = 10 +_ALGORITHM_TO_VERIFIER_CLASS = {"RS256": crypt.RSAVerifier} +_CRYPTOGRAPHY_BASED_ALGORITHMS = frozenset(["ES256"]) + +if es256 is not None: # pragma: NO COVER + _ALGORITHM_TO_VERIFIER_CLASS["ES256"] = es256.ES256Verifier # type: ignore + + +def encode(signer, payload, header=None, key_id=None): + """Make a signed JWT. + + Args: + signer (google.auth.crypt.Signer): The signer used to sign the JWT. + payload (Mapping[str, str]): The JWT payload. + header (Mapping[str, str]): Additional JWT header payload. + key_id (str): The key id to add to the JWT header. If the + signer has a key id it will be used as the default. If this is + specified it will override the signer's key id. + + Returns: + bytes: The encoded JWT. + """ + if header is None: + header = {} + + if key_id is None: + key_id = signer.key_id + + header.update({"typ": "JWT"}) + + if "alg" not in header: + if es256 is not None and isinstance(signer, es256.ES256Signer): + header.update({"alg": "ES256"}) + else: + header.update({"alg": "RS256"}) + + if key_id is not None: + header["kid"] = key_id + + segments = [ + _helpers.unpadded_urlsafe_b64encode(json.dumps(header).encode("utf-8")), + _helpers.unpadded_urlsafe_b64encode(json.dumps(payload).encode("utf-8")), + ] + + signing_input = b".".join(segments) + signature = signer.sign(signing_input) + segments.append(_helpers.unpadded_urlsafe_b64encode(signature)) + + return b".".join(segments) + + +def _decode_jwt_segment(encoded_section): + """Decodes a single JWT segment.""" + section_bytes = _helpers.padded_urlsafe_b64decode(encoded_section) + try: + return json.loads(section_bytes.decode("utf-8")) + except ValueError as caught_exc: + new_exc = exceptions.MalformedError( + "Can't parse segment: {0}".format(section_bytes) + ) + raise new_exc from caught_exc + + +def _unverified_decode(token): + """Decodes a token and does no verification. + + Args: + token (Union[str, bytes]): The encoded JWT. + + Returns: + Tuple[Mapping, Mapping, str, str]: header, payload, signed_section, and + signature. + + Raises: + google.auth.exceptions.MalformedError: if there are an incorrect amount of segments in the token or segments of the wrong type. + """ + token = _helpers.to_bytes(token) + + if token.count(b".") != 2: + raise exceptions.MalformedError( + "Wrong number of segments in token: {0}".format(token) + ) + + encoded_header, encoded_payload, signature = token.split(b".") + signed_section = encoded_header + b"." + encoded_payload + signature = _helpers.padded_urlsafe_b64decode(signature) + + # Parse segments + header = _decode_jwt_segment(encoded_header) + payload = _decode_jwt_segment(encoded_payload) + + if not isinstance(header, Mapping): + raise exceptions.MalformedError( + "Header segment should be a JSON object: {0}".format(encoded_header) + ) + + if not isinstance(payload, Mapping): + raise exceptions.MalformedError( + "Payload segment should be a JSON object: {0}".format(encoded_payload) + ) + + return header, payload, signed_section, signature + + +def decode_header(token): + """Return the decoded header of a token. + + No verification is done. This is useful to extract the key id from + the header in order to acquire the appropriate certificate to verify + the token. + + Args: + token (Union[str, bytes]): the encoded JWT. + + Returns: + Mapping: The decoded JWT header. + """ + header, _, _, _ = _unverified_decode(token) + return header + + +def _verify_iat_and_exp(payload, clock_skew_in_seconds=0): + """Verifies the ``iat`` (Issued At) and ``exp`` (Expires) claims in a token + payload. + + Args: + payload (Mapping[str, str]): The JWT payload. + clock_skew_in_seconds (int): The clock skew used for `iat` and `exp` + validation. + + Raises: + google.auth.exceptions.InvalidValue: if value validation failed. + google.auth.exceptions.MalformedError: if schema validation failed. + """ + now = _helpers.datetime_to_secs(_helpers.utcnow()) + + # Make sure the iat and exp claims are present. + for key in ("iat", "exp"): + if key not in payload: + raise exceptions.MalformedError( + "Token does not contain required claim {}".format(key) + ) + + # Make sure the token wasn't issued in the future. + iat = payload["iat"] + # Err on the side of accepting a token that is slightly early to account + # for clock skew. + earliest = iat - clock_skew_in_seconds + if now < earliest: + raise exceptions.InvalidValue( + "Token used too early, {} < {}. Check that your computer's clock is set correctly.".format( + now, iat + ) + ) + + # Make sure the token wasn't issued in the past. + exp = payload["exp"] + # Err on the side of accepting a token that is slightly out of date + # to account for clow skew. + latest = exp + clock_skew_in_seconds + if latest < now: + raise exceptions.InvalidValue("Token expired, {} < {}".format(latest, now)) + + +def decode(token, certs=None, verify=True, audience=None, clock_skew_in_seconds=0): + """Decode and verify a JWT. + + Args: + token (str): The encoded JWT. + certs (Union[str, bytes, Mapping[str, Union[str, bytes]]]): The + certificate used to validate the JWT signature. If bytes or string, + it must the the public key certificate in PEM format. If a mapping, + it must be a mapping of key IDs to public key certificates in PEM + format. The mapping must contain the same key ID that's specified + in the token's header. + verify (bool): Whether to perform signature and claim validation. + Verification is done by default. + audience (str or list): The audience claim, 'aud', that this JWT should + contain. Or a list of audience claims. If None then the JWT's 'aud' + parameter is not verified. + clock_skew_in_seconds (int): The clock skew used for `iat` and `exp` + validation. + + Returns: + Mapping[str, str]: The deserialized JSON payload in the JWT. + + Raises: + google.auth.exceptions.InvalidValue: if value validation failed. + google.auth.exceptions.MalformedError: if schema validation failed. + """ + header, payload, signed_section, signature = _unverified_decode(token) + + if not verify: + return payload + + # Pluck the key id and algorithm from the header and make sure we have + # a verifier that can support it. + key_alg = header.get("alg") + key_id = header.get("kid") + + try: + verifier_cls = _ALGORITHM_TO_VERIFIER_CLASS[key_alg] + except KeyError as exc: + if key_alg in _CRYPTOGRAPHY_BASED_ALGORITHMS: + raise exceptions.InvalidValue( + "The key algorithm {} requires the cryptography package to be installed.".format( + key_alg + ) + ) from exc + else: + raise exceptions.InvalidValue( + "Unsupported signature algorithm {}".format(key_alg) + ) from exc + # If certs is specified as a dictionary of key IDs to certificates, then + # use the certificate identified by the key ID in the token header. + if isinstance(certs, Mapping): + if key_id: + if key_id not in certs: + raise exceptions.MalformedError( + "Certificate for key id {} not found.".format(key_id) + ) + certs_to_check = [certs[key_id]] + # If there's no key id in the header, check against all of the certs. + else: + certs_to_check = certs.values() + else: + certs_to_check = certs + + # Verify that the signature matches the message. + if not crypt.verify_signature( + signed_section, signature, certs_to_check, verifier_cls + ): + raise exceptions.MalformedError("Could not verify token signature.") + + # Verify the issued at and created times in the payload. + _verify_iat_and_exp(payload, clock_skew_in_seconds) + + # Check audience. + if audience is not None: + claim_audience = payload.get("aud") + if isinstance(audience, str): + audience = [audience] + if claim_audience not in audience: + raise exceptions.InvalidValue( + "Token has wrong audience {}, expected one of {}".format( + claim_audience, audience + ) + ) + + return payload + + +class Credentials( + google.auth.credentials.Signing, google.auth.credentials.CredentialsWithQuotaProject +): + """Credentials that use a JWT as the bearer token. + + These credentials require an "audience" claim. This claim identifies the + intended recipient of the bearer token. + + The constructor arguments determine the claims for the JWT that is + sent with requests. Usually, you'll construct these credentials with + one of the helper constructors as shown in the next section. + + To create JWT credentials using a Google service account private key + JSON file:: + + audience = 'https://pubsub.googleapis.com/google.pubsub.v1.Publisher' + credentials = jwt.Credentials.from_service_account_file( + 'service-account.json', + audience=audience) + + If you already have the service account file loaded and parsed:: + + service_account_info = json.load(open('service_account.json')) + credentials = jwt.Credentials.from_service_account_info( + service_account_info, + audience=audience) + + Both helper methods pass on arguments to the constructor, so you can + specify the JWT claims:: + + credentials = jwt.Credentials.from_service_account_file( + 'service-account.json', + audience=audience, + additional_claims={'meta': 'data'}) + + You can also construct the credentials directly if you have a + :class:`~google.auth.crypt.Signer` instance:: + + credentials = jwt.Credentials( + signer, + issuer='your-issuer', + subject='your-subject', + audience=audience) + + The claims are considered immutable. If you want to modify the claims, + you can easily create another instance using :meth:`with_claims`:: + + new_audience = ( + 'https://pubsub.googleapis.com/google.pubsub.v1.Subscriber') + new_credentials = credentials.with_claims(audience=new_audience) + """ + + def __init__( + self, + signer, + issuer, + subject, + audience, + additional_claims=None, + token_lifetime=_DEFAULT_TOKEN_LIFETIME_SECS, + quota_project_id=None, + ): + """ + Args: + signer (google.auth.crypt.Signer): The signer used to sign JWTs. + issuer (str): The `iss` claim. + subject (str): The `sub` claim. + audience (str): the `aud` claim. The intended audience for the + credentials. + additional_claims (Mapping[str, str]): Any additional claims for + the JWT payload. + token_lifetime (int): The amount of time in seconds for + which the token is valid. Defaults to 1 hour. + quota_project_id (Optional[str]): The project ID used for quota + and billing. + """ + super(Credentials, self).__init__() + self._signer = signer + self._issuer = issuer + self._subject = subject + self._audience = audience + self._token_lifetime = token_lifetime + self._quota_project_id = quota_project_id + + if additional_claims is None: + additional_claims = {} + + self._additional_claims = additional_claims + + @classmethod + def _from_signer_and_info(cls, signer, info, **kwargs): + """Creates a Credentials instance from a signer and service account + info. + + Args: + signer (google.auth.crypt.Signer): The signer used to sign JWTs. + info (Mapping[str, str]): The service account info. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.jwt.Credentials: The constructed credentials. + + Raises: + google.auth.exceptions.MalformedError: If the info is not in the expected format. + """ + kwargs.setdefault("subject", info["client_email"]) + kwargs.setdefault("issuer", info["client_email"]) + return cls(signer, **kwargs) + + @classmethod + def from_service_account_info(cls, info, **kwargs): + """Creates an Credentials instance from a dictionary. + + Args: + info (Mapping[str, str]): The service account info in Google + format. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.jwt.Credentials: The constructed credentials. + + Raises: + google.auth.exceptions.MalformedError: If the info is not in the expected format. + """ + signer = _service_account_info.from_dict(info, require=["client_email"]) + return cls._from_signer_and_info(signer, info, **kwargs) + + @classmethod + def from_service_account_file(cls, filename, **kwargs): + """Creates a Credentials instance from a service account .json file + in Google format. + + Args: + filename (str): The path to the service account .json file. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.jwt.Credentials: The constructed credentials. + """ + info, signer = _service_account_info.from_filename( + filename, require=["client_email"] + ) + return cls._from_signer_and_info(signer, info, **kwargs) + + @classmethod + def from_signing_credentials(cls, credentials, audience, **kwargs): + """Creates a new :class:`google.auth.jwt.Credentials` instance from an + existing :class:`google.auth.credentials.Signing` instance. + + The new instance will use the same signer as the existing instance and + will use the existing instance's signer email as the issuer and + subject by default. + + Example:: + + svc_creds = service_account.Credentials.from_service_account_file( + 'service_account.json') + audience = ( + 'https://pubsub.googleapis.com/google.pubsub.v1.Publisher') + jwt_creds = jwt.Credentials.from_signing_credentials( + svc_creds, audience=audience) + + Args: + credentials (google.auth.credentials.Signing): The credentials to + use to construct the new credentials. + audience (str): the `aud` claim. The intended audience for the + credentials. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.jwt.Credentials: A new Credentials instance. + """ + kwargs.setdefault("issuer", credentials.signer_email) + kwargs.setdefault("subject", credentials.signer_email) + return cls(credentials.signer, audience=audience, **kwargs) + + def with_claims( + self, issuer=None, subject=None, audience=None, additional_claims=None + ): + """Returns a copy of these credentials with modified claims. + + Args: + issuer (str): The `iss` claim. If unspecified the current issuer + claim will be used. + subject (str): The `sub` claim. If unspecified the current subject + claim will be used. + audience (str): the `aud` claim. If unspecified the current + audience claim will be used. + additional_claims (Mapping[str, str]): Any additional claims for + the JWT payload. This will be merged with the current + additional claims. + + Returns: + google.auth.jwt.Credentials: A new credentials instance. + """ + new_additional_claims = copy.deepcopy(self._additional_claims) + new_additional_claims.update(additional_claims or {}) + + return self.__class__( + self._signer, + issuer=issuer if issuer is not None else self._issuer, + subject=subject if subject is not None else self._subject, + audience=audience if audience is not None else self._audience, + additional_claims=new_additional_claims, + quota_project_id=self._quota_project_id, + ) + + @_helpers.copy_docstring(google.auth.credentials.CredentialsWithQuotaProject) + def with_quota_project(self, quota_project_id): + return self.__class__( + self._signer, + issuer=self._issuer, + subject=self._subject, + audience=self._audience, + additional_claims=self._additional_claims, + quota_project_id=quota_project_id, + ) + + def _make_jwt(self): + """Make a signed JWT. + + Returns: + Tuple[bytes, datetime]: The encoded JWT and the expiration. + """ + now = _helpers.utcnow() + lifetime = datetime.timedelta(seconds=self._token_lifetime) + expiry = now + lifetime + + payload = { + "iss": self._issuer, + "sub": self._subject, + "iat": _helpers.datetime_to_secs(now), + "exp": _helpers.datetime_to_secs(expiry), + } + if self._audience: + payload["aud"] = self._audience + + payload.update(self._additional_claims) + + jwt = encode(self._signer, payload) + + return jwt, expiry + + def refresh(self, request): + """Refreshes the access token. + + Args: + request (Any): Unused. + """ + # pylint: disable=unused-argument + # (pylint doesn't correctly recognize overridden methods.) + self.token, self.expiry = self._make_jwt() + + @_helpers.copy_docstring(google.auth.credentials.Signing) + def sign_bytes(self, message): + return self._signer.sign(message) + + @property # type: ignore + @_helpers.copy_docstring(google.auth.credentials.Signing) + def signer_email(self): + return self._issuer + + @property # type: ignore + @_helpers.copy_docstring(google.auth.credentials.Signing) + def signer(self): + return self._signer + + @property # type: ignore + def additional_claims(self): + """ Additional claims the JWT object was created with.""" + return self._additional_claims + + +class OnDemandCredentials( + google.auth.credentials.Signing, google.auth.credentials.CredentialsWithQuotaProject +): + """On-demand JWT credentials. + + Like :class:`Credentials`, this class uses a JWT as the bearer token for + authentication. However, this class does not require the audience at + construction time. Instead, it will generate a new token on-demand for + each request using the request URI as the audience. It caches tokens + so that multiple requests to the same URI do not incur the overhead + of generating a new token every time. + + This behavior is especially useful for `gRPC`_ clients. A gRPC service may + have multiple audience and gRPC clients may not know all of the audiences + required for accessing a particular service. With these credentials, + no knowledge of the audiences is required ahead of time. + + .. _grpc: http://www.grpc.io/ + """ + + def __init__( + self, + signer, + issuer, + subject, + additional_claims=None, + token_lifetime=_DEFAULT_TOKEN_LIFETIME_SECS, + max_cache_size=_DEFAULT_MAX_CACHE_SIZE, + quota_project_id=None, + ): + """ + Args: + signer (google.auth.crypt.Signer): The signer used to sign JWTs. + issuer (str): The `iss` claim. + subject (str): The `sub` claim. + additional_claims (Mapping[str, str]): Any additional claims for + the JWT payload. + token_lifetime (int): The amount of time in seconds for + which the token is valid. Defaults to 1 hour. + max_cache_size (int): The maximum number of JWT tokens to keep in + cache. Tokens are cached using :class:`cachetools.LRUCache`. + quota_project_id (Optional[str]): The project ID used for quota + and billing. + + """ + super(OnDemandCredentials, self).__init__() + self._signer = signer + self._issuer = issuer + self._subject = subject + self._token_lifetime = token_lifetime + self._quota_project_id = quota_project_id + + if additional_claims is None: + additional_claims = {} + + self._additional_claims = additional_claims + self._cache = cachetools.LRUCache(maxsize=max_cache_size) + + @classmethod + def _from_signer_and_info(cls, signer, info, **kwargs): + """Creates an OnDemandCredentials instance from a signer and service + account info. + + Args: + signer (google.auth.crypt.Signer): The signer used to sign JWTs. + info (Mapping[str, str]): The service account info. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.jwt.OnDemandCredentials: The constructed credentials. + + Raises: + google.auth.exceptions.MalformedError: If the info is not in the expected format. + """ + kwargs.setdefault("subject", info["client_email"]) + kwargs.setdefault("issuer", info["client_email"]) + return cls(signer, **kwargs) + + @classmethod + def from_service_account_info(cls, info, **kwargs): + """Creates an OnDemandCredentials instance from a dictionary. + + Args: + info (Mapping[str, str]): The service account info in Google + format. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.jwt.OnDemandCredentials: The constructed credentials. + + Raises: + google.auth.exceptions.MalformedError: If the info is not in the expected format. + """ + signer = _service_account_info.from_dict(info, require=["client_email"]) + return cls._from_signer_and_info(signer, info, **kwargs) + + @classmethod + def from_service_account_file(cls, filename, **kwargs): + """Creates an OnDemandCredentials instance from a service account .json + file in Google format. + + Args: + filename (str): The path to the service account .json file. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.jwt.OnDemandCredentials: The constructed credentials. + """ + info, signer = _service_account_info.from_filename( + filename, require=["client_email"] + ) + return cls._from_signer_and_info(signer, info, **kwargs) + + @classmethod + def from_signing_credentials(cls, credentials, **kwargs): + """Creates a new :class:`google.auth.jwt.OnDemandCredentials` instance + from an existing :class:`google.auth.credentials.Signing` instance. + + The new instance will use the same signer as the existing instance and + will use the existing instance's signer email as the issuer and + subject by default. + + Example:: + + svc_creds = service_account.Credentials.from_service_account_file( + 'service_account.json') + jwt_creds = jwt.OnDemandCredentials.from_signing_credentials( + svc_creds) + + Args: + credentials (google.auth.credentials.Signing): The credentials to + use to construct the new credentials. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.jwt.Credentials: A new Credentials instance. + """ + kwargs.setdefault("issuer", credentials.signer_email) + kwargs.setdefault("subject", credentials.signer_email) + return cls(credentials.signer, **kwargs) + + def with_claims(self, issuer=None, subject=None, additional_claims=None): + """Returns a copy of these credentials with modified claims. + + Args: + issuer (str): The `iss` claim. If unspecified the current issuer + claim will be used. + subject (str): The `sub` claim. If unspecified the current subject + claim will be used. + additional_claims (Mapping[str, str]): Any additional claims for + the JWT payload. This will be merged with the current + additional claims. + + Returns: + google.auth.jwt.OnDemandCredentials: A new credentials instance. + """ + new_additional_claims = copy.deepcopy(self._additional_claims) + new_additional_claims.update(additional_claims or {}) + + return self.__class__( + self._signer, + issuer=issuer if issuer is not None else self._issuer, + subject=subject if subject is not None else self._subject, + additional_claims=new_additional_claims, + max_cache_size=self._cache.maxsize, + quota_project_id=self._quota_project_id, + ) + + @_helpers.copy_docstring(google.auth.credentials.CredentialsWithQuotaProject) + def with_quota_project(self, quota_project_id): + + return self.__class__( + self._signer, + issuer=self._issuer, + subject=self._subject, + additional_claims=self._additional_claims, + max_cache_size=self._cache.maxsize, + quota_project_id=quota_project_id, + ) + + @property + def valid(self): + """Checks the validity of the credentials. + + These credentials are always valid because it generates tokens on + demand. + """ + return True + + def _make_jwt_for_audience(self, audience): + """Make a new JWT for the given audience. + + Args: + audience (str): The intended audience. + + Returns: + Tuple[bytes, datetime]: The encoded JWT and the expiration. + """ + now = _helpers.utcnow() + lifetime = datetime.timedelta(seconds=self._token_lifetime) + expiry = now + lifetime + + payload = { + "iss": self._issuer, + "sub": self._subject, + "iat": _helpers.datetime_to_secs(now), + "exp": _helpers.datetime_to_secs(expiry), + "aud": audience, + } + + payload.update(self._additional_claims) + + jwt = encode(self._signer, payload) + + return jwt, expiry + + def _get_jwt_for_audience(self, audience): + """Get a JWT For a given audience. + + If there is already an existing, non-expired token in the cache for + the audience, that token is used. Otherwise, a new token will be + created. + + Args: + audience (str): The intended audience. + + Returns: + bytes: The encoded JWT. + """ + token, expiry = self._cache.get(audience, (None, None)) + + if token is None or expiry < _helpers.utcnow(): + token, expiry = self._make_jwt_for_audience(audience) + self._cache[audience] = token, expiry + + return token + + def refresh(self, request): + """Raises an exception, these credentials can not be directly + refreshed. + + Args: + request (Any): Unused. + + Raises: + google.auth.RefreshError + """ + # pylint: disable=unused-argument + # (pylint doesn't correctly recognize overridden methods.) + raise exceptions.RefreshError( + "OnDemandCredentials can not be directly refreshed." + ) + + def before_request(self, request, method, url, headers): + """Performs credential-specific before request logic. + + Args: + request (Any): Unused. JWT credentials do not need to make an + HTTP request to refresh. + method (str): The request's HTTP method. + url (str): The request's URI. This is used as the audience claim + when generating the JWT. + headers (Mapping): The request's headers. + """ + # pylint: disable=unused-argument + # (pylint doesn't correctly recognize overridden methods.) + parts = urllib.parse.urlsplit(url) + # Strip query string and fragment + audience = urllib.parse.urlunsplit( + (parts.scheme, parts.netloc, parts.path, "", "") + ) + token = self._get_jwt_for_audience(audience) + self.apply(headers, token=token) + + @_helpers.copy_docstring(google.auth.credentials.Signing) + def sign_bytes(self, message): + return self._signer.sign(message) + + @property # type: ignore + @_helpers.copy_docstring(google.auth.credentials.Signing) + def signer_email(self): + return self._issuer + + @property # type: ignore + @_helpers.copy_docstring(google.auth.credentials.Signing) + def signer(self): + return self._signer diff --git a/lib/python3.10/site-packages/google/auth/metrics.py b/lib/python3.10/site-packages/google/auth/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..11e4b0773077f913b39fdd4eb122c5e6448b56b0 --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/metrics.py @@ -0,0 +1,154 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" We use x-goog-api-client header to report metrics. This module provides +the constants and helper methods to construct x-goog-api-client header. +""" + +import platform + +from google.auth import version + + +API_CLIENT_HEADER = "x-goog-api-client" + +# BYOID Specific consts +BYOID_HEADER_SECTION = "google-byoid-sdk" + +# Auth request type +REQUEST_TYPE_ACCESS_TOKEN = "auth-request-type/at" +REQUEST_TYPE_ID_TOKEN = "auth-request-type/it" +REQUEST_TYPE_MDS_PING = "auth-request-type/mds" +REQUEST_TYPE_REAUTH_START = "auth-request-type/re-start" +REQUEST_TYPE_REAUTH_CONTINUE = "auth-request-type/re-cont" + +# Credential type +CRED_TYPE_USER = "cred-type/u" +CRED_TYPE_SA_ASSERTION = "cred-type/sa" +CRED_TYPE_SA_JWT = "cred-type/jwt" +CRED_TYPE_SA_MDS = "cred-type/mds" +CRED_TYPE_SA_IMPERSONATE = "cred-type/imp" + + +# Versions +def python_and_auth_lib_version(): + return "gl-python/{} auth/{}".format(platform.python_version(), version.__version__) + + +# Token request metric header values + +# x-goog-api-client header value for access token request via metadata server. +# Example: "gl-python/3.7 auth/1.1 auth-request-type/at cred-type/mds" +def token_request_access_token_mds(): + return "{} {} {}".format( + python_and_auth_lib_version(), REQUEST_TYPE_ACCESS_TOKEN, CRED_TYPE_SA_MDS + ) + + +# x-goog-api-client header value for ID token request via metadata server. +# Example: "gl-python/3.7 auth/1.1 auth-request-type/it cred-type/mds" +def token_request_id_token_mds(): + return "{} {} {}".format( + python_and_auth_lib_version(), REQUEST_TYPE_ID_TOKEN, CRED_TYPE_SA_MDS + ) + + +# x-goog-api-client header value for impersonated credentials access token request. +# Example: "gl-python/3.7 auth/1.1 auth-request-type/at cred-type/imp" +def token_request_access_token_impersonate(): + return "{} {} {}".format( + python_and_auth_lib_version(), + REQUEST_TYPE_ACCESS_TOKEN, + CRED_TYPE_SA_IMPERSONATE, + ) + + +# x-goog-api-client header value for impersonated credentials ID token request. +# Example: "gl-python/3.7 auth/1.1 auth-request-type/it cred-type/imp" +def token_request_id_token_impersonate(): + return "{} {} {}".format( + python_and_auth_lib_version(), REQUEST_TYPE_ID_TOKEN, CRED_TYPE_SA_IMPERSONATE + ) + + +# x-goog-api-client header value for service account credentials access token +# request (assertion flow). +# Example: "gl-python/3.7 auth/1.1 auth-request-type/at cred-type/sa" +def token_request_access_token_sa_assertion(): + return "{} {} {}".format( + python_and_auth_lib_version(), REQUEST_TYPE_ACCESS_TOKEN, CRED_TYPE_SA_ASSERTION + ) + + +# x-goog-api-client header value for service account credentials ID token +# request (assertion flow). +# Example: "gl-python/3.7 auth/1.1 auth-request-type/it cred-type/sa" +def token_request_id_token_sa_assertion(): + return "{} {} {}".format( + python_and_auth_lib_version(), REQUEST_TYPE_ID_TOKEN, CRED_TYPE_SA_ASSERTION + ) + + +# x-goog-api-client header value for user credentials token request. +# Example: "gl-python/3.7 auth/1.1 cred-type/u" +def token_request_user(): + return "{} {}".format(python_and_auth_lib_version(), CRED_TYPE_USER) + + +# Miscellenous metrics + +# x-goog-api-client header value for metadata server ping. +# Example: "gl-python/3.7 auth/1.1 auth-request-type/mds" +def mds_ping(): + return "{} {}".format(python_and_auth_lib_version(), REQUEST_TYPE_MDS_PING) + + +# x-goog-api-client header value for reauth start endpoint calls. +# Example: "gl-python/3.7 auth/1.1 auth-request-type/re-start" +def reauth_start(): + return "{} {}".format(python_and_auth_lib_version(), REQUEST_TYPE_REAUTH_START) + + +# x-goog-api-client header value for reauth continue endpoint calls. +# Example: "gl-python/3.7 auth/1.1 cred-type/re-cont" +def reauth_continue(): + return "{} {}".format(python_and_auth_lib_version(), REQUEST_TYPE_REAUTH_CONTINUE) + + +# x-goog-api-client header value for BYOID calls to the Security Token Service exchange token endpoint. +# Example: "gl-python/3.7 auth/1.1 google-byoid-sdk source/aws sa-impersonation/true sa-impersonation/true" +def byoid_metrics_header(metrics_options): + header = "{} {}".format(python_and_auth_lib_version(), BYOID_HEADER_SECTION) + for key, value in metrics_options.items(): + header = "{} {}/{}".format(header, key, value) + return header + + +def add_metric_header(headers, metric_header_value): + """Add x-goog-api-client header with the given value. + + Args: + headers (Mapping[str, str]): The headers to which we will add the + metric header. + metric_header_value (Optional[str]): If value is None, do nothing; + if headers already has a x-goog-api-client header, append the value + to the existing header; otherwise add a new x-goog-api-client + header with the given value. + """ + if not metric_header_value: + return + if API_CLIENT_HEADER not in headers: + headers[API_CLIENT_HEADER] = metric_header_value + else: + headers[API_CLIENT_HEADER] += " " + metric_header_value diff --git a/lib/python3.10/site-packages/google/auth/pluggable.py b/lib/python3.10/site-packages/google/auth/pluggable.py new file mode 100644 index 0000000000000000000000000000000000000000..d725188f87a9a9025217b0cbe3a36b75885d5477 --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/pluggable.py @@ -0,0 +1,429 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pluggable Credentials. +Pluggable Credentials are initialized using external_account arguments which +are typically loaded from third-party executables. Unlike other +credentials that can be initialized with a list of explicit arguments, secrets +or credentials, external account clients use the environment and hints/guidelines +provided by the external_account JSON file to retrieve credentials and exchange +them for Google access tokens. + +Example credential_source for pluggable credential: +{ + "executable": { + "command": "/path/to/get/credentials.sh --arg1=value1 --arg2=value2", + "timeout_millis": 5000, + "output_file": "/path/to/generated/cached/credentials" + } +} +""" + +try: + from collections.abc import Mapping +# Python 2.7 compatibility +except ImportError: # pragma: NO COVER + from collections import Mapping # type: ignore +import json +import os +import subprocess +import sys +import time + +from google.auth import _helpers +from google.auth import exceptions +from google.auth import external_account + +# The max supported executable spec version. +EXECUTABLE_SUPPORTED_MAX_VERSION = 1 + +EXECUTABLE_TIMEOUT_MILLIS_DEFAULT = 30 * 1000 # 30 seconds +EXECUTABLE_TIMEOUT_MILLIS_LOWER_BOUND = 5 * 1000 # 5 seconds +EXECUTABLE_TIMEOUT_MILLIS_UPPER_BOUND = 120 * 1000 # 2 minutes + +EXECUTABLE_INTERACTIVE_TIMEOUT_MILLIS_LOWER_BOUND = 30 * 1000 # 30 seconds +EXECUTABLE_INTERACTIVE_TIMEOUT_MILLIS_UPPER_BOUND = 30 * 60 * 1000 # 30 minutes + + +class Credentials(external_account.Credentials): + """External account credentials sourced from executables.""" + + def __init__( + self, + audience, + subject_token_type, + token_url, + credential_source, + *args, + **kwargs + ): + """Instantiates an external account credentials object from a executables. + + Args: + audience (str): The STS audience field. + subject_token_type (str): The subject token type. + token_url (str): The STS endpoint URL. + credential_source (Mapping): The credential source dictionary used to + provide instructions on how to retrieve external credential to be + exchanged for Google access tokens. + + Example credential_source for pluggable credential: + + { + "executable": { + "command": "/path/to/get/credentials.sh --arg1=value1 --arg2=value2", + "timeout_millis": 5000, + "output_file": "/path/to/generated/cached/credentials" + } + } + args (List): Optional positional arguments passed into the underlying :meth:`~external_account.Credentials.__init__` method. + kwargs (Mapping): Optional keyword arguments passed into the underlying :meth:`~external_account.Credentials.__init__` method. + + Raises: + google.auth.exceptions.RefreshError: If an error is encountered during + access token retrieval logic. + google.auth.exceptions.InvalidValue: For invalid parameters. + google.auth.exceptions.MalformedError: For invalid parameters. + + .. note:: Typically one of the helper constructors + :meth:`from_file` or + :meth:`from_info` are used instead of calling the constructor directly. + """ + + self.interactive = kwargs.pop("interactive", False) + super(Credentials, self).__init__( + audience=audience, + subject_token_type=subject_token_type, + token_url=token_url, + credential_source=credential_source, + *args, + **kwargs + ) + if not isinstance(credential_source, Mapping): + self._credential_source_executable = None + raise exceptions.MalformedError( + "Missing credential_source. The credential_source is not a dict." + ) + self._credential_source_executable = credential_source.get("executable") + if not self._credential_source_executable: + raise exceptions.MalformedError( + "Missing credential_source. An 'executable' must be provided." + ) + self._credential_source_executable_command = self._credential_source_executable.get( + "command" + ) + self._credential_source_executable_timeout_millis = self._credential_source_executable.get( + "timeout_millis" + ) + self._credential_source_executable_interactive_timeout_millis = self._credential_source_executable.get( + "interactive_timeout_millis" + ) + self._credential_source_executable_output_file = self._credential_source_executable.get( + "output_file" + ) + + # Dummy value. This variable is only used via injection, not exposed to ctor + self._tokeninfo_username = "" + + if not self._credential_source_executable_command: + raise exceptions.MalformedError( + "Missing command field. Executable command must be provided." + ) + if not self._credential_source_executable_timeout_millis: + self._credential_source_executable_timeout_millis = ( + EXECUTABLE_TIMEOUT_MILLIS_DEFAULT + ) + elif ( + self._credential_source_executable_timeout_millis + < EXECUTABLE_TIMEOUT_MILLIS_LOWER_BOUND + or self._credential_source_executable_timeout_millis + > EXECUTABLE_TIMEOUT_MILLIS_UPPER_BOUND + ): + raise exceptions.InvalidValue("Timeout must be between 5 and 120 seconds.") + + if self._credential_source_executable_interactive_timeout_millis: + if ( + self._credential_source_executable_interactive_timeout_millis + < EXECUTABLE_INTERACTIVE_TIMEOUT_MILLIS_LOWER_BOUND + or self._credential_source_executable_interactive_timeout_millis + > EXECUTABLE_INTERACTIVE_TIMEOUT_MILLIS_UPPER_BOUND + ): + raise exceptions.InvalidValue( + "Interactive timeout must be between 30 seconds and 30 minutes." + ) + + @_helpers.copy_docstring(external_account.Credentials) + def retrieve_subject_token(self, request): + self._validate_running_mode() + + # Check output file. + if self._credential_source_executable_output_file is not None: + try: + with open( + self._credential_source_executable_output_file, encoding="utf-8" + ) as output_file: + response = json.load(output_file) + except Exception: + pass + else: + try: + # If the cached response is expired, _parse_subject_token will raise an error which will be ignored and we will call the executable again. + subject_token = self._parse_subject_token(response) + if ( + "expiration_time" not in response + ): # Always treat missing expiration_time as expired and proceed to executable run. + raise exceptions.RefreshError + except (exceptions.MalformedError, exceptions.InvalidValue): + raise + except exceptions.RefreshError: + pass + else: + return subject_token + + if not _helpers.is_python_3(): + raise exceptions.RefreshError( + "Pluggable auth is only supported for python 3.7+" + ) + + # Inject env vars. + env = os.environ.copy() + self._inject_env_variables(env) + env["GOOGLE_EXTERNAL_ACCOUNT_REVOKE"] = "0" + + # Run executable. + exe_timeout = ( + self._credential_source_executable_interactive_timeout_millis / 1000 + if self.interactive + else self._credential_source_executable_timeout_millis / 1000 + ) + exe_stdin = sys.stdin if self.interactive else None + exe_stdout = sys.stdout if self.interactive else subprocess.PIPE + exe_stderr = sys.stdout if self.interactive else subprocess.STDOUT + + result = subprocess.run( + self._credential_source_executable_command.split(), + timeout=exe_timeout, + stdin=exe_stdin, + stdout=exe_stdout, + stderr=exe_stderr, + env=env, + ) + if result.returncode != 0: + raise exceptions.RefreshError( + "Executable exited with non-zero return code {}. Error: {}".format( + result.returncode, result.stdout + ) + ) + + # Handle executable output. + response = json.loads(result.stdout.decode("utf-8")) if result.stdout else None + if not response and self._credential_source_executable_output_file is not None: + response = json.load( + open(self._credential_source_executable_output_file, encoding="utf-8") + ) + + subject_token = self._parse_subject_token(response) + return subject_token + + def revoke(self, request): + """Revokes the subject token using the credential_source object. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + Raises: + google.auth.exceptions.RefreshError: If the executable revocation + not properly executed. + + """ + if not self.interactive: + raise exceptions.InvalidValue( + "Revoke is only enabled under interactive mode." + ) + self._validate_running_mode() + + if not _helpers.is_python_3(): + raise exceptions.RefreshError( + "Pluggable auth is only supported for python 3.7+" + ) + + # Inject variables + env = os.environ.copy() + self._inject_env_variables(env) + env["GOOGLE_EXTERNAL_ACCOUNT_REVOKE"] = "1" + + # Run executable + result = subprocess.run( + self._credential_source_executable_command.split(), + timeout=self._credential_source_executable_interactive_timeout_millis + / 1000, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + ) + + if result.returncode != 0: + raise exceptions.RefreshError( + "Auth revoke failed on executable. Exit with non-zero return code {}. Error: {}".format( + result.returncode, result.stdout + ) + ) + + response = json.loads(result.stdout.decode("utf-8")) + self._validate_revoke_response(response) + + @property + def external_account_id(self): + """Returns the external account identifier. + + When service account impersonation is used the identifier is the service + account email. + + Without service account impersonation, this returns None, unless it is + being used by the Google Cloud CLI which populates this field. + """ + + return self.service_account_email or self._tokeninfo_username + + @classmethod + def from_info(cls, info, **kwargs): + """Creates a Pluggable Credentials instance from parsed external account info. + + Args: + info (Mapping[str, str]): The Pluggable external account info in Google + format. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.pluggable.Credentials: The constructed + credentials. + + Raises: + google.auth.exceptions.InvalidValue: For invalid parameters. + google.auth.exceptions.MalformedError: For invalid parameters. + """ + return super(Credentials, cls).from_info(info, **kwargs) + + @classmethod + def from_file(cls, filename, **kwargs): + """Creates an Pluggable Credentials instance from an external account json file. + + Args: + filename (str): The path to the Pluggable external account json file. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.pluggable.Credentials: The constructed + credentials. + """ + return super(Credentials, cls).from_file(filename, **kwargs) + + def _inject_env_variables(self, env): + env["GOOGLE_EXTERNAL_ACCOUNT_AUDIENCE"] = self._audience + env["GOOGLE_EXTERNAL_ACCOUNT_TOKEN_TYPE"] = self._subject_token_type + env["GOOGLE_EXTERNAL_ACCOUNT_ID"] = self.external_account_id + env["GOOGLE_EXTERNAL_ACCOUNT_INTERACTIVE"] = "1" if self.interactive else "0" + + if self._service_account_impersonation_url is not None: + env[ + "GOOGLE_EXTERNAL_ACCOUNT_IMPERSONATED_EMAIL" + ] = self.service_account_email + if self._credential_source_executable_output_file is not None: + env[ + "GOOGLE_EXTERNAL_ACCOUNT_OUTPUT_FILE" + ] = self._credential_source_executable_output_file + + def _parse_subject_token(self, response): + self._validate_response_schema(response) + if not response["success"]: + if "code" not in response or "message" not in response: + raise exceptions.MalformedError( + "Error code and message fields are required in the response." + ) + raise exceptions.RefreshError( + "Executable returned unsuccessful response: code: {}, message: {}.".format( + response["code"], response["message"] + ) + ) + if "expiration_time" in response and response["expiration_time"] < time.time(): + raise exceptions.RefreshError( + "The token returned by the executable is expired." + ) + if "token_type" not in response: + raise exceptions.MalformedError( + "The executable response is missing the token_type field." + ) + if ( + response["token_type"] == "urn:ietf:params:oauth:token-type:jwt" + or response["token_type"] == "urn:ietf:params:oauth:token-type:id_token" + ): # OIDC + return response["id_token"] + elif response["token_type"] == "urn:ietf:params:oauth:token-type:saml2": # SAML + return response["saml_response"] + else: + raise exceptions.RefreshError("Executable returned unsupported token type.") + + def _validate_revoke_response(self, response): + self._validate_response_schema(response) + if not response["success"]: + raise exceptions.RefreshError("Revoke failed with unsuccessful response.") + + def _validate_response_schema(self, response): + if "version" not in response: + raise exceptions.MalformedError( + "The executable response is missing the version field." + ) + if response["version"] > EXECUTABLE_SUPPORTED_MAX_VERSION: + raise exceptions.RefreshError( + "Executable returned unsupported version {}.".format( + response["version"] + ) + ) + + if "success" not in response: + raise exceptions.MalformedError( + "The executable response is missing the success field." + ) + + def _validate_running_mode(self): + env_allow_executables = os.environ.get( + "GOOGLE_EXTERNAL_ACCOUNT_ALLOW_EXECUTABLES" + ) + if env_allow_executables != "1": + raise exceptions.MalformedError( + "Executables need to be explicitly allowed (set GOOGLE_EXTERNAL_ACCOUNT_ALLOW_EXECUTABLES to '1') to run." + ) + + if self.interactive and not self._credential_source_executable_output_file: + raise exceptions.MalformedError( + "An output_file must be specified in the credential configuration for interactive mode." + ) + + if ( + self.interactive + and not self._credential_source_executable_interactive_timeout_millis + ): + raise exceptions.InvalidOperation( + "Interactive mode cannot run without an interactive timeout." + ) + + if self.interactive and not self.is_workforce_pool: + raise exceptions.InvalidValue( + "Interactive mode is only enabled for workforce pool." + ) + + def _create_default_metrics_options(self): + metrics_options = super(Credentials, self)._create_default_metrics_options() + metrics_options["source"] = "executable" + return metrics_options diff --git a/lib/python3.10/site-packages/google/auth/py.typed b/lib/python3.10/site-packages/google/auth/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..aa7b68923f2be5df6f0b6f5ab5c86e82d5ed9732 --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/py.typed @@ -0,0 +1,2 @@ +# Marker file for PEP 561. +# The google-auth package uses inline types. diff --git a/lib/python3.10/site-packages/google/auth/version.py b/lib/python3.10/site-packages/google/auth/version.py new file mode 100644 index 0000000000000000000000000000000000000000..318bf723b43287d30e505d749b267189225f537b --- /dev/null +++ b/lib/python3.10/site-packages/google/auth/version.py @@ -0,0 +1,15 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__version__ = "2.40.3" diff --git a/lib/python3.10/site-packages/google/oauth2/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/google/oauth2/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47dbd05d05fe8d1181555af5e834463a5a6e9d8b Binary files /dev/null and b/lib/python3.10/site-packages/google/oauth2/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/oauth2/__pycache__/_client.cpython-310.pyc b/lib/python3.10/site-packages/google/oauth2/__pycache__/_client.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58db4fc5c0761b8405dfa6e10e6a648872e2619c Binary files /dev/null and b/lib/python3.10/site-packages/google/oauth2/__pycache__/_client.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/oauth2/__pycache__/_client_async.cpython-310.pyc b/lib/python3.10/site-packages/google/oauth2/__pycache__/_client_async.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..904fe288c2aaf4ab8920231391ded68e8a7c66b6 Binary files /dev/null and b/lib/python3.10/site-packages/google/oauth2/__pycache__/_client_async.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/oauth2/__pycache__/_credentials_async.cpython-310.pyc b/lib/python3.10/site-packages/google/oauth2/__pycache__/_credentials_async.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc3b5352f78989c442513453350502d869793d04 Binary files /dev/null and b/lib/python3.10/site-packages/google/oauth2/__pycache__/_credentials_async.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/oauth2/__pycache__/_id_token_async.cpython-310.pyc b/lib/python3.10/site-packages/google/oauth2/__pycache__/_id_token_async.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e365246c1ea155de8109bb2c8cfc78ce53c8e0b9 Binary files /dev/null and b/lib/python3.10/site-packages/google/oauth2/__pycache__/_id_token_async.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/oauth2/__pycache__/_reauth_async.cpython-310.pyc b/lib/python3.10/site-packages/google/oauth2/__pycache__/_reauth_async.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..136c873e0c541dda2eed166b813dddc822652739 Binary files /dev/null and b/lib/python3.10/site-packages/google/oauth2/__pycache__/_reauth_async.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/oauth2/__pycache__/_service_account_async.cpython-310.pyc b/lib/python3.10/site-packages/google/oauth2/__pycache__/_service_account_async.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..841e83d28c1f0887fcc4d3cdcad0f9e25d71681c Binary files /dev/null and b/lib/python3.10/site-packages/google/oauth2/__pycache__/_service_account_async.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/oauth2/__pycache__/challenges.cpython-310.pyc b/lib/python3.10/site-packages/google/oauth2/__pycache__/challenges.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2edcd788060d3ce13156e9d25344fae624fc450 Binary files /dev/null and b/lib/python3.10/site-packages/google/oauth2/__pycache__/challenges.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/oauth2/__pycache__/credentials.cpython-310.pyc b/lib/python3.10/site-packages/google/oauth2/__pycache__/credentials.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9d858e635ccd37e52c909a3bf89952169075d08 Binary files /dev/null and b/lib/python3.10/site-packages/google/oauth2/__pycache__/credentials.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/oauth2/__pycache__/gdch_credentials.cpython-310.pyc b/lib/python3.10/site-packages/google/oauth2/__pycache__/gdch_credentials.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26a53a83d0111bb77c8cb8f1217cb970c0ff5554 Binary files /dev/null and b/lib/python3.10/site-packages/google/oauth2/__pycache__/gdch_credentials.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/oauth2/__pycache__/id_token.cpython-310.pyc b/lib/python3.10/site-packages/google/oauth2/__pycache__/id_token.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2633e54da73bad2a4b4adb7004a3a0647e335000 Binary files /dev/null and b/lib/python3.10/site-packages/google/oauth2/__pycache__/id_token.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/oauth2/__pycache__/reauth.cpython-310.pyc b/lib/python3.10/site-packages/google/oauth2/__pycache__/reauth.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18193a69864dd68a1feb4ef0ba5ae7e13b121ff6 Binary files /dev/null and b/lib/python3.10/site-packages/google/oauth2/__pycache__/reauth.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/oauth2/__pycache__/service_account.cpython-310.pyc b/lib/python3.10/site-packages/google/oauth2/__pycache__/service_account.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a41007901b92ebf4d0812dbbe1abea81e9f8550a Binary files /dev/null and b/lib/python3.10/site-packages/google/oauth2/__pycache__/service_account.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/oauth2/__pycache__/sts.cpython-310.pyc b/lib/python3.10/site-packages/google/oauth2/__pycache__/sts.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..884ad60a5dbbc4a238d5859ea0c7ec09fd26469c Binary files /dev/null and b/lib/python3.10/site-packages/google/oauth2/__pycache__/sts.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/oauth2/__pycache__/utils.cpython-310.pyc b/lib/python3.10/site-packages/google/oauth2/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0658de769443cddb1edbccfaee98e52e0d14f5c Binary files /dev/null and b/lib/python3.10/site-packages/google/oauth2/__pycache__/utils.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/oauth2/__pycache__/webauthn_handler.cpython-310.pyc b/lib/python3.10/site-packages/google/oauth2/__pycache__/webauthn_handler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bede07b417be07feacb4d565a26b64c0f5976548 Binary files /dev/null and b/lib/python3.10/site-packages/google/oauth2/__pycache__/webauthn_handler.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/oauth2/__pycache__/webauthn_handler_factory.cpython-310.pyc b/lib/python3.10/site-packages/google/oauth2/__pycache__/webauthn_handler_factory.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5deff445737c66622357104292f196807c8558d Binary files /dev/null and b/lib/python3.10/site-packages/google/oauth2/__pycache__/webauthn_handler_factory.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/oauth2/__pycache__/webauthn_types.cpython-310.pyc b/lib/python3.10/site-packages/google/oauth2/__pycache__/webauthn_types.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81d977a264df233bc4d59cb57560aed50ae8b3c4 Binary files /dev/null and b/lib/python3.10/site-packages/google/oauth2/__pycache__/webauthn_types.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/oauth2/_client.py b/lib/python3.10/site-packages/google/oauth2/_client.py new file mode 100644 index 0000000000000000000000000000000000000000..5a9fc3503c5390ef0eea8579e8cda487d3fad6e7 --- /dev/null +++ b/lib/python3.10/site-packages/google/oauth2/_client.py @@ -0,0 +1,508 @@ +# Copyright 2016 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OAuth 2.0 client. + +This is a client for interacting with an OAuth 2.0 authorization server's +token endpoint. + +For more information about the token endpoint, see +`Section 3.1 of rfc6749`_ + +.. _Section 3.1 of rfc6749: https://tools.ietf.org/html/rfc6749#section-3.2 +""" + +import datetime +import http.client as http_client +import json +import urllib + +from google.auth import _exponential_backoff +from google.auth import _helpers +from google.auth import credentials +from google.auth import exceptions +from google.auth import jwt +from google.auth import metrics +from google.auth import transport + +_URLENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded" +_JSON_CONTENT_TYPE = "application/json" +_JWT_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer" +_REFRESH_GRANT_TYPE = "refresh_token" + + +def _handle_error_response(response_data, retryable_error): + """Translates an error response into an exception. + + Args: + response_data (Mapping | str): The decoded response data. + retryable_error Optional[bool]: A boolean indicating if an error is retryable. + Defaults to False. + + Raises: + google.auth.exceptions.RefreshError: The errors contained in response_data. + """ + + retryable_error = retryable_error if retryable_error else False + + if isinstance(response_data, str): + raise exceptions.RefreshError(response_data, retryable=retryable_error) + try: + error_details = "{}: {}".format( + response_data["error"], response_data.get("error_description") + ) + # If no details could be extracted, use the response data. + except (KeyError, ValueError): + error_details = json.dumps(response_data) + + raise exceptions.RefreshError( + error_details, response_data, retryable=retryable_error + ) + + +def _can_retry(status_code, response_data): + """Checks if a request can be retried by inspecting the status code + and response body of the request. + + Args: + status_code (int): The response status code. + response_data (Mapping | str): The decoded response data. + + Returns: + bool: True if the response is retryable. False otherwise. + """ + if status_code in transport.DEFAULT_RETRYABLE_STATUS_CODES: + return True + + try: + # For a failed response, response_body could be a string + error_desc = response_data.get("error_description") or "" + error_code = response_data.get("error") or "" + + if not isinstance(error_code, str) or not isinstance(error_desc, str): + return False + + # Per Oauth 2.0 RFC https://www.rfc-editor.org/rfc/rfc6749.html#section-4.1.2.1 + # This is needed because a redirect will not return a 500 status code. + retryable_error_descriptions = { + "internal_failure", + "server_error", + "temporarily_unavailable", + } + + if any(e in retryable_error_descriptions for e in (error_code, error_desc)): + return True + + except AttributeError: + pass + + return False + + +def _parse_expiry(response_data): + """Parses the expiry field from a response into a datetime. + + Args: + response_data (Mapping): The JSON-parsed response data. + + Returns: + Optional[datetime]: The expiration or ``None`` if no expiration was + specified. + """ + expires_in = response_data.get("expires_in", None) + + if expires_in is not None: + # Some services do not respect the OAUTH2.0 RFC and send expires_in as a + # JSON String. + if isinstance(expires_in, str): + expires_in = int(expires_in) + + return _helpers.utcnow() + datetime.timedelta(seconds=expires_in) + else: + return None + + +def _token_endpoint_request_no_throw( + request, + token_uri, + body, + access_token=None, + use_json=False, + can_retry=True, + headers=None, + **kwargs +): + """Makes a request to the OAuth 2.0 authorization server's token endpoint. + This function doesn't throw on response errors. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + token_uri (str): The OAuth 2.0 authorizations server's token endpoint + URI. + body (Mapping[str, str]): The parameters to send in the request body. + access_token (Optional(str)): The access token needed to make the request. + use_json (Optional(bool)): Use urlencoded format or json format for the + content type. The default value is False. + can_retry (bool): Enable or disable request retry behavior. + headers (Optional[Mapping[str, str]]): The headers for the request. + kwargs: Additional arguments passed on to the request method. The + kwargs will be passed to `requests.request` method, see: + https://docs.python-requests.org/en/latest/api/#requests.request. + For example, you can use `cert=("cert_pem_path", "key_pem_path")` + to set up client side SSL certificate, and use + `verify="ca_bundle_path"` to set up the CA certificates for sever + side SSL certificate verification. + + Returns: + Tuple(bool, Mapping[str, str], Optional[bool]): A boolean indicating + if the request is successful, a mapping for the JSON-decoded response + data and in the case of an error a boolean indicating if the error + is retryable. + """ + if use_json: + headers_to_use = {"Content-Type": _JSON_CONTENT_TYPE} + body = json.dumps(body).encode("utf-8") + else: + headers_to_use = {"Content-Type": _URLENCODED_CONTENT_TYPE} + body = urllib.parse.urlencode(body).encode("utf-8") + + if access_token: + headers_to_use["Authorization"] = "Bearer {}".format(access_token) + + if headers: + headers_to_use.update(headers) + + response_data = {} + retryable_error = False + + retries = _exponential_backoff.ExponentialBackoff() + for _ in retries: + response = request( + method="POST", url=token_uri, headers=headers_to_use, body=body, **kwargs + ) + response_body = ( + response.data.decode("utf-8") + if hasattr(response.data, "decode") + else response.data + ) + + try: + # response_body should be a JSON + response_data = json.loads(response_body) + except ValueError: + response_data = response_body + + if response.status == http_client.OK: + return True, response_data, None + + retryable_error = _can_retry( + status_code=response.status, response_data=response_data + ) + + if not can_retry or not retryable_error: + return False, response_data, retryable_error + + return False, response_data, retryable_error + + +def _token_endpoint_request( + request, + token_uri, + body, + access_token=None, + use_json=False, + can_retry=True, + headers=None, + **kwargs +): + """Makes a request to the OAuth 2.0 authorization server's token endpoint. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + token_uri (str): The OAuth 2.0 authorizations server's token endpoint + URI. + body (Mapping[str, str]): The parameters to send in the request body. + access_token (Optional(str)): The access token needed to make the request. + use_json (Optional(bool)): Use urlencoded format or json format for the + content type. The default value is False. + can_retry (bool): Enable or disable request retry behavior. + headers (Optional[Mapping[str, str]]): The headers for the request. + kwargs: Additional arguments passed on to the request method. The + kwargs will be passed to `requests.request` method, see: + https://docs.python-requests.org/en/latest/api/#requests.request. + For example, you can use `cert=("cert_pem_path", "key_pem_path")` + to set up client side SSL certificate, and use + `verify="ca_bundle_path"` to set up the CA certificates for sever + side SSL certificate verification. + + Returns: + Mapping[str, str]: The JSON-decoded response data. + + Raises: + google.auth.exceptions.RefreshError: If the token endpoint returned + an error. + """ + + response_status_ok, response_data, retryable_error = _token_endpoint_request_no_throw( + request, + token_uri, + body, + access_token=access_token, + use_json=use_json, + can_retry=can_retry, + headers=headers, + **kwargs + ) + if not response_status_ok: + _handle_error_response(response_data, retryable_error) + return response_data + + +def jwt_grant(request, token_uri, assertion, can_retry=True): + """Implements the JWT Profile for OAuth 2.0 Authorization Grants. + + For more details, see `rfc7523 section 4`_. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + token_uri (str): The OAuth 2.0 authorizations server's token endpoint + URI. + assertion (str): The OAuth 2.0 assertion. + can_retry (bool): Enable or disable request retry behavior. + + Returns: + Tuple[str, Optional[datetime], Mapping[str, str]]: The access token, + expiration, and additional data returned by the token endpoint. + + Raises: + google.auth.exceptions.RefreshError: If the token endpoint returned + an error. + + .. _rfc7523 section 4: https://tools.ietf.org/html/rfc7523#section-4 + """ + body = {"assertion": assertion, "grant_type": _JWT_GRANT_TYPE} + + response_data = _token_endpoint_request( + request, + token_uri, + body, + can_retry=can_retry, + headers={ + metrics.API_CLIENT_HEADER: metrics.token_request_access_token_sa_assertion() + }, + ) + + try: + access_token = response_data["access_token"] + except KeyError as caught_exc: + new_exc = exceptions.RefreshError( + "No access token in response.", response_data, retryable=False + ) + raise new_exc from caught_exc + + expiry = _parse_expiry(response_data) + + return access_token, expiry, response_data + + +def call_iam_generate_id_token_endpoint( + request, + iam_id_token_endpoint, + signer_email, + audience, + access_token, + universe_domain=credentials.DEFAULT_UNIVERSE_DOMAIN, +): + """Call iam.generateIdToken endpoint to get ID token. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + iam_id_token_endpoint (str): The IAM ID token endpoint to use. + signer_email (str): The signer email used to form the IAM + generateIdToken endpoint. + audience (str): The audience for the ID token. + access_token (str): The access token used to call the IAM endpoint. + + Returns: + Tuple[str, datetime]: The ID token and expiration. + """ + body = {"audience": audience, "includeEmail": "true", "useEmailAzp": "true"} + + response_data = _token_endpoint_request( + request, + iam_id_token_endpoint.replace( + credentials.DEFAULT_UNIVERSE_DOMAIN, universe_domain + ).format(signer_email), + body, + access_token=access_token, + use_json=True, + ) + + try: + id_token = response_data["token"] + except KeyError as caught_exc: + new_exc = exceptions.RefreshError( + "No ID token in response.", response_data, retryable=False + ) + raise new_exc from caught_exc + + payload = jwt.decode(id_token, verify=False) + expiry = datetime.datetime.utcfromtimestamp(payload["exp"]) + + return id_token, expiry + + +def id_token_jwt_grant(request, token_uri, assertion, can_retry=True): + """Implements the JWT Profile for OAuth 2.0 Authorization Grants, but + requests an OpenID Connect ID Token instead of an access token. + + This is a variant on the standard JWT Profile that is currently unique + to Google. This was added for the benefit of authenticating to services + that require ID Tokens instead of access tokens or JWT bearer tokens. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + token_uri (str): The OAuth 2.0 authorization server's token endpoint + URI. + assertion (str): JWT token signed by a service account. The token's + payload must include a ``target_audience`` claim. + can_retry (bool): Enable or disable request retry behavior. + + Returns: + Tuple[str, Optional[datetime], Mapping[str, str]]: + The (encoded) Open ID Connect ID Token, expiration, and additional + data returned by the endpoint. + + Raises: + google.auth.exceptions.RefreshError: If the token endpoint returned + an error. + """ + body = {"assertion": assertion, "grant_type": _JWT_GRANT_TYPE} + + response_data = _token_endpoint_request( + request, + token_uri, + body, + can_retry=can_retry, + headers={ + metrics.API_CLIENT_HEADER: metrics.token_request_id_token_sa_assertion() + }, + ) + + try: + id_token = response_data["id_token"] + except KeyError as caught_exc: + new_exc = exceptions.RefreshError( + "No ID token in response.", response_data, retryable=False + ) + raise new_exc from caught_exc + + payload = jwt.decode(id_token, verify=False) + expiry = datetime.datetime.utcfromtimestamp(payload["exp"]) + + return id_token, expiry, response_data + + +def _handle_refresh_grant_response(response_data, refresh_token): + """Extract tokens from refresh grant response. + + Args: + response_data (Mapping[str, str]): Refresh grant response data. + refresh_token (str): Current refresh token. + + Returns: + Tuple[str, str, Optional[datetime], Mapping[str, str]]: The access token, + refresh token, expiration, and additional data returned by the token + endpoint. If response_data doesn't have refresh token, then the current + refresh token will be returned. + + Raises: + google.auth.exceptions.RefreshError: If the token endpoint returned + an error. + """ + try: + access_token = response_data["access_token"] + except KeyError as caught_exc: + new_exc = exceptions.RefreshError( + "No access token in response.", response_data, retryable=False + ) + raise new_exc from caught_exc + + refresh_token = response_data.get("refresh_token", refresh_token) + expiry = _parse_expiry(response_data) + + return access_token, refresh_token, expiry, response_data + + +def refresh_grant( + request, + token_uri, + refresh_token, + client_id, + client_secret, + scopes=None, + rapt_token=None, + can_retry=True, +): + """Implements the OAuth 2.0 refresh token grant. + + For more details, see `rfc678 section 6`_. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + token_uri (str): The OAuth 2.0 authorizations server's token endpoint + URI. + refresh_token (str): The refresh token to use to get a new access + token. + client_id (str): The OAuth 2.0 application's client ID. + client_secret (str): The Oauth 2.0 appliaction's client secret. + scopes (Optional(Sequence[str])): Scopes to request. If present, all + scopes must be authorized for the refresh token. Useful if refresh + token has a wild card scope (e.g. + 'https://www.googleapis.com/auth/any-api'). + rapt_token (Optional(str)): The reauth Proof Token. + can_retry (bool): Enable or disable request retry behavior. + + Returns: + Tuple[str, str, Optional[datetime], Mapping[str, str]]: The access + token, new or current refresh token, expiration, and additional data + returned by the token endpoint. + + Raises: + google.auth.exceptions.RefreshError: If the token endpoint returned + an error. + + .. _rfc6748 section 6: https://tools.ietf.org/html/rfc6749#section-6 + """ + body = { + "grant_type": _REFRESH_GRANT_TYPE, + "client_id": client_id, + "client_secret": client_secret, + "refresh_token": refresh_token, + } + if scopes: + body["scope"] = " ".join(scopes) + if rapt_token: + body["rapt"] = rapt_token + + response_data = _token_endpoint_request( + request, token_uri, body, can_retry=can_retry + ) + return _handle_refresh_grant_response(response_data, refresh_token) diff --git a/lib/python3.10/site-packages/google/oauth2/_client_async.py b/lib/python3.10/site-packages/google/oauth2/_client_async.py new file mode 100644 index 0000000000000000000000000000000000000000..8867f0a5274077ffd1e7be6e1d6f01614ea9006f --- /dev/null +++ b/lib/python3.10/site-packages/google/oauth2/_client_async.py @@ -0,0 +1,286 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OAuth 2.0 async client. + +This is a client for interacting with an OAuth 2.0 authorization server's +token endpoint. + +For more information about the token endpoint, see +`Section 3.1 of rfc6749`_ + +.. _Section 3.1 of rfc6749: https://tools.ietf.org/html/rfc6749#section-3.2 +""" + +import datetime +import http.client as http_client +import json +import urllib + +from google.auth import _exponential_backoff +from google.auth import exceptions +from google.auth import jwt +from google.oauth2 import _client as client + + +async def _token_endpoint_request_no_throw( + request, token_uri, body, access_token=None, use_json=False, can_retry=True +): + """Makes a request to the OAuth 2.0 authorization server's token endpoint. + This function doesn't throw on response errors. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + token_uri (str): The OAuth 2.0 authorizations server's token endpoint + URI. + body (Mapping[str, str]): The parameters to send in the request body. + access_token (Optional(str)): The access token needed to make the request. + use_json (Optional(bool)): Use urlencoded format or json format for the + content type. The default value is False. + can_retry (bool): Enable or disable request retry behavior. + + Returns: + Tuple(bool, Mapping[str, str], Optional[bool]): A boolean indicating + if the request is successful, a mapping for the JSON-decoded response + data and in the case of an error a boolean indicating if the error + is retryable. + """ + if use_json: + headers = {"Content-Type": client._JSON_CONTENT_TYPE} + body = json.dumps(body).encode("utf-8") + else: + headers = {"Content-Type": client._URLENCODED_CONTENT_TYPE} + body = urllib.parse.urlencode(body).encode("utf-8") + + if access_token: + headers["Authorization"] = "Bearer {}".format(access_token) + + response_data = {} + retryable_error = False + + retries = _exponential_backoff.ExponentialBackoff() + for _ in retries: + response = await request( + method="POST", url=token_uri, headers=headers, body=body + ) + + # Using data.read() resulted in zlib decompression errors. This may require future investigation. + response_body1 = await response.content() + + response_body = ( + response_body1.decode("utf-8") + if hasattr(response_body1, "decode") + else response_body1 + ) + + try: + response_data = json.loads(response_body) + except ValueError: + response_data = response_body + + if response.status == http_client.OK: + return True, response_data, None + + retryable_error = client._can_retry( + status_code=response.status, response_data=response_data + ) + + if not can_retry or not retryable_error: + return False, response_data, retryable_error + + return False, response_data, retryable_error + + +async def _token_endpoint_request( + request, token_uri, body, access_token=None, use_json=False, can_retry=True +): + """Makes a request to the OAuth 2.0 authorization server's token endpoint. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + token_uri (str): The OAuth 2.0 authorizations server's token endpoint + URI. + body (Mapping[str, str]): The parameters to send in the request body. + access_token (Optional(str)): The access token needed to make the request. + use_json (Optional(bool)): Use urlencoded format or json format for the + content type. The default value is False. + can_retry (bool): Enable or disable request retry behavior. + + Returns: + Mapping[str, str]: The JSON-decoded response data. + + Raises: + google.auth.exceptions.RefreshError: If the token endpoint returned + an error. + """ + + response_status_ok, response_data, retryable_error = await _token_endpoint_request_no_throw( + request, + token_uri, + body, + access_token=access_token, + use_json=use_json, + can_retry=can_retry, + ) + if not response_status_ok: + client._handle_error_response(response_data, retryable_error) + return response_data + + +async def jwt_grant(request, token_uri, assertion, can_retry=True): + """Implements the JWT Profile for OAuth 2.0 Authorization Grants. + + For more details, see `rfc7523 section 4`_. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + token_uri (str): The OAuth 2.0 authorizations server's token endpoint + URI. + assertion (str): The OAuth 2.0 assertion. + can_retry (bool): Enable or disable request retry behavior. + + Returns: + Tuple[str, Optional[datetime], Mapping[str, str]]: The access token, + expiration, and additional data returned by the token endpoint. + + Raises: + google.auth.exceptions.RefreshError: If the token endpoint returned + an error. + + .. _rfc7523 section 4: https://tools.ietf.org/html/rfc7523#section-4 + """ + body = {"assertion": assertion, "grant_type": client._JWT_GRANT_TYPE} + + response_data = await _token_endpoint_request( + request, token_uri, body, can_retry=can_retry + ) + + try: + access_token = response_data["access_token"] + except KeyError as caught_exc: + new_exc = exceptions.RefreshError( + "No access token in response.", response_data, retryable=False + ) + raise new_exc from caught_exc + + expiry = client._parse_expiry(response_data) + + return access_token, expiry, response_data + + +async def id_token_jwt_grant(request, token_uri, assertion, can_retry=True): + """Implements the JWT Profile for OAuth 2.0 Authorization Grants, but + requests an OpenID Connect ID Token instead of an access token. + + This is a variant on the standard JWT Profile that is currently unique + to Google. This was added for the benefit of authenticating to services + that require ID Tokens instead of access tokens or JWT bearer tokens. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + token_uri (str): The OAuth 2.0 authorization server's token endpoint + URI. + assertion (str): JWT token signed by a service account. The token's + payload must include a ``target_audience`` claim. + can_retry (bool): Enable or disable request retry behavior. + + Returns: + Tuple[str, Optional[datetime], Mapping[str, str]]: + The (encoded) Open ID Connect ID Token, expiration, and additional + data returned by the endpoint. + + Raises: + google.auth.exceptions.RefreshError: If the token endpoint returned + an error. + """ + body = {"assertion": assertion, "grant_type": client._JWT_GRANT_TYPE} + + response_data = await _token_endpoint_request( + request, token_uri, body, can_retry=can_retry + ) + + try: + id_token = response_data["id_token"] + except KeyError as caught_exc: + new_exc = exceptions.RefreshError( + "No ID token in response.", response_data, retryable=False + ) + raise new_exc from caught_exc + + payload = jwt.decode(id_token, verify=False) + expiry = datetime.datetime.utcfromtimestamp(payload["exp"]) + + return id_token, expiry, response_data + + +async def refresh_grant( + request, + token_uri, + refresh_token, + client_id, + client_secret, + scopes=None, + rapt_token=None, + can_retry=True, +): + """Implements the OAuth 2.0 refresh token grant. + + For more details, see `rfc678 section 6`_. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + token_uri (str): The OAuth 2.0 authorizations server's token endpoint + URI. + refresh_token (str): The refresh token to use to get a new access + token. + client_id (str): The OAuth 2.0 application's client ID. + client_secret (str): The Oauth 2.0 appliaction's client secret. + scopes (Optional(Sequence[str])): Scopes to request. If present, all + scopes must be authorized for the refresh token. Useful if refresh + token has a wild card scope (e.g. + 'https://www.googleapis.com/auth/any-api'). + rapt_token (Optional(str)): The reauth Proof Token. + can_retry (bool): Enable or disable request retry behavior. + + Returns: + Tuple[str, Optional[str], Optional[datetime], Mapping[str, str]]: The + access token, new or current refresh token, expiration, and additional data + returned by the token endpoint. + + Raises: + google.auth.exceptions.RefreshError: If the token endpoint returned + an error. + + .. _rfc6748 section 6: https://tools.ietf.org/html/rfc6749#section-6 + """ + body = { + "grant_type": client._REFRESH_GRANT_TYPE, + "client_id": client_id, + "client_secret": client_secret, + "refresh_token": refresh_token, + } + if scopes: + body["scope"] = " ".join(scopes) + if rapt_token: + body["rapt"] = rapt_token + + response_data = await _token_endpoint_request( + request, token_uri, body, can_retry=can_retry + ) + return client._handle_refresh_grant_response(response_data, refresh_token) diff --git a/lib/python3.10/site-packages/google/oauth2/_credentials_async.py b/lib/python3.10/site-packages/google/oauth2/_credentials_async.py new file mode 100644 index 0000000000000000000000000000000000000000..b5561aae02293b3ab7f3180803b07993331ab149 --- /dev/null +++ b/lib/python3.10/site-packages/google/oauth2/_credentials_async.py @@ -0,0 +1,118 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OAuth 2.0 Async Credentials. + +This module provides credentials based on OAuth 2.0 access and refresh tokens. +These credentials usually access resources on behalf of a user (resource +owner). + +Specifically, this is intended to use access tokens acquired using the +`Authorization Code grant`_ and can refresh those tokens using a +optional `refresh token`_. + +Obtaining the initial access and refresh token is outside of the scope of this +module. Consult `rfc6749 section 4.1`_ for complete details on the +Authorization Code grant flow. + +.. _Authorization Code grant: https://tools.ietf.org/html/rfc6749#section-1.3.1 +.. _refresh token: https://tools.ietf.org/html/rfc6749#section-6 +.. _rfc6749 section 4.1: https://tools.ietf.org/html/rfc6749#section-4.1 +""" + +from google.auth import _credentials_async as credentials +from google.auth import _helpers +from google.auth import exceptions +from google.oauth2 import _reauth_async as reauth +from google.oauth2 import credentials as oauth2_credentials + + +class Credentials(oauth2_credentials.Credentials): + """Credentials using OAuth 2.0 access and refresh tokens. + + The credentials are considered immutable. If you want to modify the + quota project, use :meth:`with_quota_project` or :: + + credentials = credentials.with_quota_project('myproject-123) + """ + + @_helpers.copy_docstring(credentials.Credentials) + async def refresh(self, request): + if ( + self._refresh_token is None + or self._token_uri is None + or self._client_id is None + or self._client_secret is None + ): + raise exceptions.RefreshError( + "The credentials do not contain the necessary fields need to " + "refresh the access token. You must specify refresh_token, " + "token_uri, client_id, and client_secret." + ) + + ( + access_token, + refresh_token, + expiry, + grant_response, + rapt_token, + ) = await reauth.refresh_grant( + request, + self._token_uri, + self._refresh_token, + self._client_id, + self._client_secret, + scopes=self._scopes, + rapt_token=self._rapt_token, + enable_reauth_refresh=self._enable_reauth_refresh, + ) + + self.token = access_token + self.expiry = expiry + self._refresh_token = refresh_token + self._id_token = grant_response.get("id_token") + self._rapt_token = rapt_token + + if self._scopes and "scope" in grant_response: + requested_scopes = frozenset(self._scopes) + granted_scopes = frozenset(grant_response["scope"].split()) + scopes_requested_but_not_granted = requested_scopes - granted_scopes + if scopes_requested_but_not_granted: + raise exceptions.RefreshError( + "Not all requested scopes were granted by the " + "authorization server, missing scopes {}.".format( + ", ".join(scopes_requested_but_not_granted) + ) + ) + + @_helpers.copy_docstring(credentials.Credentials) + async def before_request(self, request, method, url, headers): + if not self.valid: + await self.refresh(request) + self.apply(headers) + + +class UserAccessTokenCredentials(oauth2_credentials.UserAccessTokenCredentials): + """Access token credentials for user account. + + Obtain the access token for a given user account or the current active + user account with the ``gcloud auth print-access-token`` command. + + Args: + account (Optional[str]): Account to get the access token for. If not + specified, the current active account will be used. + quota_project_id (Optional[str]): The project ID used for quota + and billing. + + """ diff --git a/lib/python3.10/site-packages/google/oauth2/_id_token_async.py b/lib/python3.10/site-packages/google/oauth2/_id_token_async.py new file mode 100644 index 0000000000000000000000000000000000000000..6594e416aea07f114cda027d97f2c4d5d869a55d --- /dev/null +++ b/lib/python3.10/site-packages/google/oauth2/_id_token_async.py @@ -0,0 +1,285 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Google ID Token helpers. + +Provides support for verifying `OpenID Connect ID Tokens`_, especially ones +generated by Google infrastructure. + +To parse and verify an ID Token issued by Google's OAuth 2.0 authorization +server use :func:`verify_oauth2_token`. To verify an ID Token issued by +Firebase, use :func:`verify_firebase_token`. + +A general purpose ID Token verifier is available as :func:`verify_token`. + +Example:: + + from google.oauth2 import _id_token_async + from google.auth.transport import aiohttp_requests + + request = aiohttp_requests.Request() + + id_info = await _id_token_async.verify_oauth2_token( + token, request, 'my-client-id.example.com') + + if id_info['iss'] != 'https://accounts.google.com': + raise ValueError('Wrong issuer.') + + userid = id_info['sub'] + +By default, this will re-fetch certificates for each verification. Because +Google's public keys are only changed infrequently (on the order of once per +day), you may wish to take advantage of caching to reduce latency and the +potential for network errors. This can be accomplished using an external +library like `CacheControl`_ to create a cache-aware +:class:`google.auth.transport.Request`:: + + import cachecontrol + import google.auth.transport.requests + import requests + + session = requests.session() + cached_session = cachecontrol.CacheControl(session) + request = google.auth.transport.requests.Request(session=cached_session) + +.. _OpenID Connect ID Token: + http://openid.net/specs/openid-connect-core-1_0.html#IDToken +.. _CacheControl: https://cachecontrol.readthedocs.io +""" + +import http.client as http_client +import json +import os + +from google.auth import environment_vars +from google.auth import exceptions +from google.auth import jwt +from google.auth.transport import requests +from google.oauth2 import id_token as sync_id_token + + +async def _fetch_certs(request, certs_url): + """Fetches certificates. + + Google-style cerificate endpoints return JSON in the format of + ``{'key id': 'x509 certificate'}``. + + Args: + request (google.auth.transport.Request): The object used to make + HTTP requests. This must be an aiohttp request. + certs_url (str): The certificate endpoint URL. + + Returns: + Mapping[str, str]: A mapping of public key ID to x.509 certificate + data. + """ + response = await request(certs_url, method="GET") + + if response.status != http_client.OK: + raise exceptions.TransportError( + "Could not fetch certificates at {}".format(certs_url) + ) + + data = await response.content() + + return json.loads(data) + + +async def verify_token( + id_token, + request, + audience=None, + certs_url=sync_id_token._GOOGLE_OAUTH2_CERTS_URL, + clock_skew_in_seconds=0, +): + """Verifies an ID token and returns the decoded token. + + Args: + id_token (Union[str, bytes]): The encoded token. + request (google.auth.transport.Request): The object used to make + HTTP requests. This must be an aiohttp request. + audience (str): The audience that this token is intended for. If None + then the audience is not verified. + certs_url (str): The URL that specifies the certificates to use to + verify the token. This URL should return JSON in the format of + ``{'key id': 'x509 certificate'}``. + clock_skew_in_seconds (int): The clock skew used for `iat` and `exp` + validation. + + Returns: + Mapping[str, Any]: The decoded token. + """ + certs = await _fetch_certs(request, certs_url) + + return jwt.decode( + id_token, + certs=certs, + audience=audience, + clock_skew_in_seconds=clock_skew_in_seconds, + ) + + +async def verify_oauth2_token( + id_token, request, audience=None, clock_skew_in_seconds=0 +): + """Verifies an ID Token issued by Google's OAuth 2.0 authorization server. + + Args: + id_token (Union[str, bytes]): The encoded token. + request (google.auth.transport.Request): The object used to make + HTTP requests. This must be an aiohttp request. + audience (str): The audience that this token is intended for. This is + typically your application's OAuth 2.0 client ID. If None then the + audience is not verified. + clock_skew_in_seconds (int): The clock skew used for `iat` and `exp` + validation. + + Returns: + Mapping[str, Any]: The decoded token. + + Raises: + exceptions.GoogleAuthError: If the issuer is invalid. + """ + idinfo = await verify_token( + id_token, + request, + audience=audience, + certs_url=sync_id_token._GOOGLE_OAUTH2_CERTS_URL, + clock_skew_in_seconds=clock_skew_in_seconds, + ) + + if idinfo["iss"] not in sync_id_token._GOOGLE_ISSUERS: + raise exceptions.GoogleAuthError( + "Wrong issuer. 'iss' should be one of the following: {}".format( + sync_id_token._GOOGLE_ISSUERS + ) + ) + + return idinfo + + +async def verify_firebase_token( + id_token, request, audience=None, clock_skew_in_seconds=0 +): + """Verifies an ID Token issued by Firebase Authentication. + + Args: + id_token (Union[str, bytes]): The encoded token. + request (google.auth.transport.Request): The object used to make + HTTP requests. This must be an aiohttp request. + audience (str): The audience that this token is intended for. This is + typically your Firebase application ID. If None then the audience + is not verified. + clock_skew_in_seconds (int): The clock skew used for `iat` and `exp` + validation. + + Returns: + Mapping[str, Any]: The decoded token. + """ + return await verify_token( + id_token, + request, + audience=audience, + certs_url=sync_id_token._GOOGLE_APIS_CERTS_URL, + clock_skew_in_seconds=clock_skew_in_seconds, + ) + + +async def fetch_id_token(request, audience): + """Fetch the ID Token from the current environment. + + This function acquires ID token from the environment in the following order. + See https://google.aip.dev/auth/4110. + + 1. If the environment variable ``GOOGLE_APPLICATION_CREDENTIALS`` is set + to the path of a valid service account JSON file, then ID token is + acquired using this service account credentials. + 2. If the application is running in Compute Engine, App Engine or Cloud Run, + then the ID token are obtained from the metadata server. + 3. If metadata server doesn't exist and no valid service account credentials + are found, :class:`~google.auth.exceptions.DefaultCredentialsError` will + be raised. + + Example:: + + import google.oauth2._id_token_async + import google.auth.transport.aiohttp_requests + + request = google.auth.transport.aiohttp_requests.Request() + target_audience = "https://pubsub.googleapis.com" + + id_token = await google.oauth2._id_token_async.fetch_id_token(request, target_audience) + + Args: + request (google.auth.transport.aiohttp_requests.Request): A callable used to make + HTTP requests. + audience (str): The audience that this ID token is intended for. + + Returns: + str: The ID token. + + Raises: + ~google.auth.exceptions.DefaultCredentialsError: + If metadata server doesn't exist and no valid service account + credentials are found. + """ + # 1. Try to get credentials from the GOOGLE_APPLICATION_CREDENTIALS environment + # variable. + credentials_filename = os.environ.get(environment_vars.CREDENTIALS) + if credentials_filename: + if not ( + os.path.exists(credentials_filename) + and os.path.isfile(credentials_filename) + ): + raise exceptions.DefaultCredentialsError( + "GOOGLE_APPLICATION_CREDENTIALS path is either not found or invalid." + ) + + try: + with open(credentials_filename, "r") as f: + from google.oauth2 import _service_account_async as service_account + + info = json.load(f) + if info.get("type") == "service_account": + credentials = service_account.IDTokenCredentials.from_service_account_info( + info, target_audience=audience + ) + await credentials.refresh(request) + return credentials.token + except ValueError as caught_exc: + new_exc = exceptions.DefaultCredentialsError( + "GOOGLE_APPLICATION_CREDENTIALS is not valid service account credentials.", + caught_exc, + ) + raise new_exc from caught_exc + + # 2. Try to fetch ID token from metada server if it exists. The code works + # for GAE and Cloud Run metadata server as well. + try: + from google.auth import compute_engine + from google.auth.compute_engine import _metadata + + request_new = requests.Request() + if _metadata.ping(request_new): + credentials = compute_engine.IDTokenCredentials( + request_new, audience, use_metadata_identity_endpoint=True + ) + credentials.refresh(request_new) + return credentials.token + except (ImportError, exceptions.TransportError): + pass + + raise exceptions.DefaultCredentialsError( + "Neither metadata server or valid service account credentials are found." + ) diff --git a/lib/python3.10/site-packages/google/oauth2/_reauth_async.py b/lib/python3.10/site-packages/google/oauth2/_reauth_async.py new file mode 100644 index 0000000000000000000000000000000000000000..de3675c5239c42e36b5b8c496fea5848b9cbc784 --- /dev/null +++ b/lib/python3.10/site-packages/google/oauth2/_reauth_async.py @@ -0,0 +1,328 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A module that provides functions for handling rapt authentication. + +Reauth is a process of obtaining additional authentication (such as password, +security token, etc.) while refreshing OAuth 2.0 credentials for a user. + +Credentials that use the Reauth flow must have the reauth scope, +``https://www.googleapis.com/auth/accounts.reauth``. + +This module provides a high-level function for executing the Reauth process, +:func:`refresh_grant`, and lower-level helpers for doing the individual +steps of the reauth process. + +Those steps are: + +1. Obtaining a list of challenges from the reauth server. +2. Running through each challenge and sending the result back to the reauth + server. +3. Refreshing the access token using the returned rapt token. +""" + +import sys + +from google.auth import exceptions +from google.oauth2 import _client +from google.oauth2 import _client_async +from google.oauth2 import challenges +from google.oauth2 import reauth + + +async def _get_challenges( + request, supported_challenge_types, access_token, requested_scopes=None +): + """Does initial request to reauth API to get the challenges. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. This must be an aiohttp request. + supported_challenge_types (Sequence[str]): list of challenge names + supported by the manager. + access_token (str): Access token with reauth scopes. + requested_scopes (Optional(Sequence[str])): Authorized scopes for the credentials. + + Returns: + dict: The response from the reauth API. + """ + body = {"supportedChallengeTypes": supported_challenge_types} + if requested_scopes: + body["oauthScopesForDomainPolicyLookup"] = requested_scopes + + return await _client_async._token_endpoint_request( + request, + reauth._REAUTH_API + ":start", + body, + access_token=access_token, + use_json=True, + ) + + +async def _send_challenge_result( + request, session_id, challenge_id, client_input, access_token +): + """Attempt to refresh access token by sending next challenge result. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. This must be an aiohttp request. + session_id (str): session id returned by the initial reauth call. + challenge_id (str): challenge id returned by the initial reauth call. + client_input: dict with a challenge-specific client input. For example: + ``{'credential': password}`` for password challenge. + access_token (str): Access token with reauth scopes. + + Returns: + dict: The response from the reauth API. + """ + body = { + "sessionId": session_id, + "challengeId": challenge_id, + "action": "RESPOND", + "proposalResponse": client_input, + } + + return await _client_async._token_endpoint_request( + request, + reauth._REAUTH_API + "/{}:continue".format(session_id), + body, + access_token=access_token, + use_json=True, + ) + + +async def _run_next_challenge(msg, request, access_token): + """Get the next challenge from msg and run it. + + Args: + msg (dict): Reauth API response body (either from the initial request to + https://reauth.googleapis.com/v2/sessions:start or from sending the + previous challenge response to + https://reauth.googleapis.com/v2/sessions/id:continue) + request (google.auth.transport.Request): A callable used to make + HTTP requests. This must be an aiohttp request. + access_token (str): reauth access token + + Returns: + dict: The response from the reauth API. + + Raises: + google.auth.exceptions.ReauthError: if reauth failed. + """ + for challenge in msg["challenges"]: + if challenge["status"] != "READY": + # Skip non-activated challenges. + continue + c = challenges.AVAILABLE_CHALLENGES.get(challenge["challengeType"], None) + if not c: + raise exceptions.ReauthFailError( + "Unsupported challenge type {0}. Supported types: {1}".format( + challenge["challengeType"], + ",".join(list(challenges.AVAILABLE_CHALLENGES.keys())), + ) + ) + if not c.is_locally_eligible: + raise exceptions.ReauthFailError( + "Challenge {0} is not locally eligible".format( + challenge["challengeType"] + ) + ) + client_input = c.obtain_challenge_input(challenge) + if not client_input: + return None + return await _send_challenge_result( + request, + msg["sessionId"], + challenge["challengeId"], + client_input, + access_token, + ) + return None + + +async def _obtain_rapt(request, access_token, requested_scopes): + """Given an http request method and reauth access token, get rapt token. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. This must be an aiohttp request. + access_token (str): reauth access token + requested_scopes (Sequence[str]): scopes required by the client application + + Returns: + str: The rapt token. + + Raises: + google.auth.exceptions.ReauthError: if reauth failed + """ + msg = await _get_challenges( + request, + list(challenges.AVAILABLE_CHALLENGES.keys()), + access_token, + requested_scopes, + ) + + if msg["status"] == reauth._AUTHENTICATED: + return msg["encodedProofOfReauthToken"] + + for _ in range(0, reauth.RUN_CHALLENGE_RETRY_LIMIT): + if not ( + msg["status"] == reauth._CHALLENGE_REQUIRED + or msg["status"] == reauth._CHALLENGE_PENDING + ): + raise exceptions.ReauthFailError( + "Reauthentication challenge failed due to API error: {}".format( + msg["status"] + ) + ) + + if not reauth.is_interactive(): + raise exceptions.ReauthFailError( + "Reauthentication challenge could not be answered because you are not" + " in an interactive session." + ) + + msg = await _run_next_challenge(msg, request, access_token) + + if msg["status"] == reauth._AUTHENTICATED: + return msg["encodedProofOfReauthToken"] + + # If we got here it means we didn't get authenticated. + raise exceptions.ReauthFailError("Failed to obtain rapt token.") + + +async def get_rapt_token( + request, client_id, client_secret, refresh_token, token_uri, scopes=None +): + """Given an http request method and refresh_token, get rapt token. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. This must be an aiohttp request. + client_id (str): client id to get access token for reauth scope. + client_secret (str): client secret for the client_id + refresh_token (str): refresh token to refresh access token + token_uri (str): uri to refresh access token + scopes (Optional(Sequence[str])): scopes required by the client application + + Returns: + str: The rapt token. + Raises: + google.auth.exceptions.RefreshError: If reauth failed. + """ + sys.stderr.write("Reauthentication required.\n") + + # Get access token for reauth. + access_token, _, _, _ = await _client_async.refresh_grant( + request=request, + client_id=client_id, + client_secret=client_secret, + refresh_token=refresh_token, + token_uri=token_uri, + scopes=[reauth._REAUTH_SCOPE], + ) + + # Get rapt token from reauth API. + rapt_token = await _obtain_rapt(request, access_token, requested_scopes=scopes) + + return rapt_token + + +async def refresh_grant( + request, + token_uri, + refresh_token, + client_id, + client_secret, + scopes=None, + rapt_token=None, + enable_reauth_refresh=False, +): + """Implements the reauthentication flow. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. This must be an aiohttp request. + token_uri (str): The OAuth 2.0 authorizations server's token endpoint + URI. + refresh_token (str): The refresh token to use to get a new access + token. + client_id (str): The OAuth 2.0 application's client ID. + client_secret (str): The Oauth 2.0 appliaction's client secret. + scopes (Optional(Sequence[str])): Scopes to request. If present, all + scopes must be authorized for the refresh token. Useful if refresh + token has a wild card scope (e.g. + 'https://www.googleapis.com/auth/any-api'). + rapt_token (Optional(str)): The rapt token for reauth. + enable_reauth_refresh (Optional[bool]): Whether reauth refresh flow + should be used. The default value is False. This option is for + gcloud only, other users should use the default value. + + Returns: + Tuple[str, Optional[str], Optional[datetime], Mapping[str, str], str]: The + access token, new refresh token, expiration, the additional data + returned by the token endpoint, and the rapt token. + + Raises: + google.auth.exceptions.RefreshError: If the token endpoint returned + an error. + """ + body = { + "grant_type": _client._REFRESH_GRANT_TYPE, + "client_id": client_id, + "client_secret": client_secret, + "refresh_token": refresh_token, + } + if scopes: + body["scope"] = " ".join(scopes) + if rapt_token: + body["rapt"] = rapt_token + + response_status_ok, response_data, retryable_error = await _client_async._token_endpoint_request_no_throw( + request, token_uri, body + ) + if ( + not response_status_ok + and response_data.get("error") == reauth._REAUTH_NEEDED_ERROR + and ( + response_data.get("error_subtype") + == reauth._REAUTH_NEEDED_ERROR_INVALID_RAPT + or response_data.get("error_subtype") + == reauth._REAUTH_NEEDED_ERROR_RAPT_REQUIRED + ) + ): + if not enable_reauth_refresh: + raise exceptions.RefreshError( + "Reauthentication is needed. Please run `gcloud auth application-default login` to reauthenticate." + ) + + rapt_token = await get_rapt_token( + request, client_id, client_secret, refresh_token, token_uri, scopes=scopes + ) + body["rapt"] = rapt_token + ( + response_status_ok, + response_data, + retryable_error, + ) = await _client_async._token_endpoint_request_no_throw( + request, token_uri, body + ) + + if not response_status_ok: + _client._handle_error_response(response_data, retryable_error) + refresh_response = _client._handle_refresh_grant_response( + response_data, refresh_token + ) + return refresh_response + (rapt_token,) diff --git a/lib/python3.10/site-packages/google/oauth2/challenges.py b/lib/python3.10/site-packages/google/oauth2/challenges.py new file mode 100644 index 0000000000000000000000000000000000000000..6468498bcb6d027ef02e2af4ad8d89106c525ea8 --- /dev/null +++ b/lib/python3.10/site-packages/google/oauth2/challenges.py @@ -0,0 +1,281 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Challenges for reauthentication. +""" + +import abc +import base64 +import getpass +import sys + +from google.auth import _helpers +from google.auth import exceptions +from google.oauth2 import webauthn_handler_factory +from google.oauth2.webauthn_types import ( + AuthenticationExtensionsClientInputs, + GetRequest, + PublicKeyCredentialDescriptor, +) + + +REAUTH_ORIGIN = "https://accounts.google.com" +SAML_CHALLENGE_MESSAGE = ( + "Please run `gcloud auth login` to complete reauthentication with SAML." +) +WEBAUTHN_TIMEOUT_MS = 120000 # Two minute timeout + + +def get_user_password(text): + """Get password from user. + + Override this function with a different logic if you are using this library + outside a CLI. + + Args: + text (str): message for the password prompt. + + Returns: + str: password string. + """ + return getpass.getpass(text) + + +class ReauthChallenge(metaclass=abc.ABCMeta): + """Base class for reauth challenges.""" + + @property + @abc.abstractmethod + def name(self): # pragma: NO COVER + """Returns the name of the challenge.""" + raise NotImplementedError("name property must be implemented") + + @property + @abc.abstractmethod + def is_locally_eligible(self): # pragma: NO COVER + """Returns true if a challenge is supported locally on this machine.""" + raise NotImplementedError("is_locally_eligible property must be implemented") + + @abc.abstractmethod + def obtain_challenge_input(self, metadata): # pragma: NO COVER + """Performs logic required to obtain credentials and returns it. + + Args: + metadata (Mapping): challenge metadata returned in the 'challenges' field in + the initial reauth request. Includes the 'challengeType' field + and other challenge-specific fields. + + Returns: + response that will be send to the reauth service as the content of + the 'proposalResponse' field in the request body. Usually a dict + with the keys specific to the challenge. For example, + ``{'credential': password}`` for password challenge. + """ + raise NotImplementedError("obtain_challenge_input method must be implemented") + + +class PasswordChallenge(ReauthChallenge): + """Challenge that asks for user's password.""" + + @property + def name(self): + return "PASSWORD" + + @property + def is_locally_eligible(self): + return True + + @_helpers.copy_docstring(ReauthChallenge) + def obtain_challenge_input(self, unused_metadata): + passwd = get_user_password("Please enter your password:") + if not passwd: + passwd = " " # avoid the server crashing in case of no password :D + return {"credential": passwd} + + +class SecurityKeyChallenge(ReauthChallenge): + """Challenge that asks for user's security key touch.""" + + @property + def name(self): + return "SECURITY_KEY" + + @property + def is_locally_eligible(self): + return True + + @_helpers.copy_docstring(ReauthChallenge) + def obtain_challenge_input(self, metadata): + # Check if there is an available Webauthn Handler, if not use pyu2f + try: + factory = webauthn_handler_factory.WebauthnHandlerFactory() + webauthn_handler = factory.get_handler() + if webauthn_handler is not None: + sys.stderr.write("Please insert and touch your security key\n") + return self._obtain_challenge_input_webauthn(metadata, webauthn_handler) + except Exception: + # Attempt pyu2f if exception in webauthn flow + pass + + try: + import pyu2f.convenience.authenticator # type: ignore + import pyu2f.errors # type: ignore + import pyu2f.model # type: ignore + except ImportError: + raise exceptions.ReauthFailError( + "pyu2f dependency is required to use Security key reauth feature. " + "It can be installed via `pip install pyu2f` or `pip install google-auth[reauth]`." + ) + sk = metadata["securityKey"] + challenges = sk["challenges"] + # Read both 'applicationId' and 'relyingPartyId', if they are the same, use + # applicationId, if they are different, use relyingPartyId first and retry + # with applicationId + application_id = sk["applicationId"] + relying_party_id = sk["relyingPartyId"] + + if application_id != relying_party_id: + application_parameters = [relying_party_id, application_id] + else: + application_parameters = [application_id] + + challenge_data = [] + for c in challenges: + kh = c["keyHandle"].encode("ascii") + key = pyu2f.model.RegisteredKey(bytearray(base64.urlsafe_b64decode(kh))) + challenge = c["challenge"].encode("ascii") + challenge = base64.urlsafe_b64decode(challenge) + challenge_data.append({"key": key, "challenge": challenge}) + + # Track number of tries to suppress error message until all application_parameters + # are tried. + tries = 0 + for app_id in application_parameters: + try: + tries += 1 + api = pyu2f.convenience.authenticator.CreateCompositeAuthenticator( + REAUTH_ORIGIN + ) + response = api.Authenticate( + app_id, challenge_data, print_callback=sys.stderr.write + ) + return {"securityKey": response} + except pyu2f.errors.U2FError as e: + if e.code == pyu2f.errors.U2FError.DEVICE_INELIGIBLE: + # Only show error if all app_ids have been tried + if tries == len(application_parameters): + sys.stderr.write("Ineligible security key.\n") + return None + continue + if e.code == pyu2f.errors.U2FError.TIMEOUT: + sys.stderr.write( + "Timed out while waiting for security key touch.\n" + ) + else: + raise e + except pyu2f.errors.PluginError as e: + sys.stderr.write("Plugin error: {}.\n".format(e)) + continue + except pyu2f.errors.NoDeviceFoundError: + sys.stderr.write("No security key found.\n") + return None + + def _obtain_challenge_input_webauthn(self, metadata, webauthn_handler): + sk = metadata.get("securityKey") + if sk is None: + raise exceptions.InvalidValue("securityKey is None") + challenges = sk.get("challenges") + application_id = sk.get("applicationId") + relying_party_id = sk.get("relyingPartyId") + if challenges is None or len(challenges) < 1: + raise exceptions.InvalidValue("challenges is None or empty") + if application_id is None: + raise exceptions.InvalidValue("application_id is None") + if relying_party_id is None: + raise exceptions.InvalidValue("relying_party_id is None") + + allow_credentials = [] + for challenge in challenges: + kh = challenge.get("keyHandle") + if kh is None: + raise exceptions.InvalidValue("keyHandle is None") + key_handle = self._unpadded_urlsafe_b64recode(kh) + allow_credentials.append(PublicKeyCredentialDescriptor(id=key_handle)) + + extension = AuthenticationExtensionsClientInputs(appid=application_id) + + challenge = challenges[0].get("challenge") + if challenge is None: + raise exceptions.InvalidValue("challenge is None") + + get_request = GetRequest( + origin=REAUTH_ORIGIN, + rpid=relying_party_id, + challenge=self._unpadded_urlsafe_b64recode(challenge), + timeout_ms=WEBAUTHN_TIMEOUT_MS, + allow_credentials=allow_credentials, + user_verification="required", + extensions=extension, + ) + + try: + get_response = webauthn_handler.get(get_request) + except Exception as e: + sys.stderr.write("Webauthn Error: {}.\n".format(e)) + raise e + + response = { + "clientData": get_response.response.client_data_json, + "authenticatorData": get_response.response.authenticator_data, + "signatureData": get_response.response.signature, + "applicationId": application_id, + "keyHandle": get_response.id, + "securityKeyReplyType": 2, + } + return {"securityKey": response} + + def _unpadded_urlsafe_b64recode(self, s): + """Converts standard b64 encoded string to url safe b64 encoded string + with no padding.""" + b = base64.urlsafe_b64decode(s) + return base64.urlsafe_b64encode(b).decode().rstrip("=") + + +class SamlChallenge(ReauthChallenge): + """Challenge that asks the users to browse to their ID Providers. + + Currently SAML challenge is not supported. When obtaining the challenge + input, exception will be raised to instruct the users to run + `gcloud auth login` for reauthentication. + """ + + @property + def name(self): + return "SAML" + + @property + def is_locally_eligible(self): + return True + + def obtain_challenge_input(self, metadata): + # Magic Arch has not fully supported returning a proper dedirect URL + # for programmatic SAML users today. So we error our here and request + # users to use gcloud to complete a login. + raise exceptions.ReauthSamlChallengeFailError(SAML_CHALLENGE_MESSAGE) + + +AVAILABLE_CHALLENGES = { + challenge.name: challenge + for challenge in [SecurityKeyChallenge(), PasswordChallenge(), SamlChallenge()] +} diff --git a/lib/python3.10/site-packages/google/oauth2/credentials.py b/lib/python3.10/site-packages/google/oauth2/credentials.py new file mode 100644 index 0000000000000000000000000000000000000000..6e158089f31d22eab0c6f5e17055c4f7dff3570f --- /dev/null +++ b/lib/python3.10/site-packages/google/oauth2/credentials.py @@ -0,0 +1,614 @@ +# Copyright 2016 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OAuth 2.0 Credentials. + +This module provides credentials based on OAuth 2.0 access and refresh tokens. +These credentials usually access resources on behalf of a user (resource +owner). + +Specifically, this is intended to use access tokens acquired using the +`Authorization Code grant`_ and can refresh those tokens using a +optional `refresh token`_. + +Obtaining the initial access and refresh token is outside of the scope of this +module. Consult `rfc6749 section 4.1`_ for complete details on the +Authorization Code grant flow. + +.. _Authorization Code grant: https://tools.ietf.org/html/rfc6749#section-1.3.1 +.. _refresh token: https://tools.ietf.org/html/rfc6749#section-6 +.. _rfc6749 section 4.1: https://tools.ietf.org/html/rfc6749#section-4.1 +""" + +from datetime import datetime +import io +import json +import logging +import warnings + +from google.auth import _cloud_sdk +from google.auth import _helpers +from google.auth import credentials +from google.auth import exceptions +from google.auth import metrics +from google.oauth2 import reauth + +_LOGGER = logging.getLogger(__name__) + + +# The Google OAuth 2.0 token endpoint. Used for authorized user credentials. +_GOOGLE_OAUTH2_TOKEN_ENDPOINT = "https://oauth2.googleapis.com/token" + +# The Google OAuth 2.0 token info endpoint. Used for getting token info JSON from access tokens. +_GOOGLE_OAUTH2_TOKEN_INFO_ENDPOINT = "https://oauth2.googleapis.com/tokeninfo" + + +class Credentials(credentials.ReadOnlyScoped, credentials.CredentialsWithQuotaProject): + """Credentials using OAuth 2.0 access and refresh tokens. + + The credentials are considered immutable except the tokens and the token + expiry, which are updated after refresh. If you want to modify the quota + project, use :meth:`with_quota_project` or :: + + credentials = credentials.with_quota_project('myproject-123') + + Reauth is disabled by default. To enable reauth, set the + `enable_reauth_refresh` parameter to True in the constructor. Note that + reauth feature is intended for gcloud to use only. + If reauth is enabled, `pyu2f` dependency has to be installed in order to use security + key reauth feature. Dependency can be installed via `pip install pyu2f` or `pip install + google-auth[reauth]`. + """ + + def __init__( + self, + token, + refresh_token=None, + id_token=None, + token_uri=None, + client_id=None, + client_secret=None, + scopes=None, + default_scopes=None, + quota_project_id=None, + expiry=None, + rapt_token=None, + refresh_handler=None, + enable_reauth_refresh=False, + granted_scopes=None, + trust_boundary=None, + universe_domain=credentials.DEFAULT_UNIVERSE_DOMAIN, + account=None, + ): + """ + Args: + token (Optional(str)): The OAuth 2.0 access token. Can be None + if refresh information is provided. + refresh_token (str): The OAuth 2.0 refresh token. If specified, + credentials can be refreshed. + id_token (str): The Open ID Connect ID Token. + token_uri (str): The OAuth 2.0 authorization server's token + endpoint URI. Must be specified for refresh, can be left as + None if the token can not be refreshed. + client_id (str): The OAuth 2.0 client ID. Must be specified for + refresh, can be left as None if the token can not be refreshed. + client_secret(str): The OAuth 2.0 client secret. Must be specified + for refresh, can be left as None if the token can not be + refreshed. + scopes (Sequence[str]): The scopes used to obtain authorization. + This parameter is used by :meth:`has_scopes`. OAuth 2.0 + credentials can not request additional scopes after + authorization. The scopes must be derivable from the refresh + token if refresh information is provided (e.g. The refresh + token scopes are a superset of this or contain a wild card + scope like 'https://www.googleapis.com/auth/any-api'). + default_scopes (Sequence[str]): Default scopes passed by a + Google client library. Use 'scopes' for user-defined scopes. + quota_project_id (Optional[str]): The project ID used for quota and billing. + This project may be different from the project used to + create the credentials. + rapt_token (Optional[str]): The reauth Proof Token. + refresh_handler (Optional[Callable[[google.auth.transport.Request, Sequence[str]], [str, datetime]]]): + A callable which takes in the HTTP request callable and the list of + OAuth scopes and when called returns an access token string for the + requested scopes and its expiry datetime. This is useful when no + refresh tokens are provided and tokens are obtained by calling + some external process on demand. It is particularly useful for + retrieving downscoped tokens from a token broker. + enable_reauth_refresh (Optional[bool]): Whether reauth refresh flow + should be used. This flag is for gcloud to use only. + granted_scopes (Optional[Sequence[str]]): The scopes that were consented/granted by the user. + This could be different from the requested scopes and it could be empty if granted + and requested scopes were same. + trust_boundary (str): String representation of trust boundary meta. + universe_domain (Optional[str]): The universe domain. The default + universe domain is googleapis.com. + account (Optional[str]): The account associated with the credential. + """ + super(Credentials, self).__init__() + self.token = token + self.expiry = expiry + self._refresh_token = refresh_token + self._id_token = id_token + self._scopes = scopes + self._default_scopes = default_scopes + self._granted_scopes = granted_scopes + self._token_uri = token_uri + self._client_id = client_id + self._client_secret = client_secret + self._quota_project_id = quota_project_id + self._rapt_token = rapt_token + self.refresh_handler = refresh_handler + self._enable_reauth_refresh = enable_reauth_refresh + self._trust_boundary = trust_boundary + self._universe_domain = universe_domain or credentials.DEFAULT_UNIVERSE_DOMAIN + self._account = account or "" + self._cred_file_path = None + + def __getstate__(self): + """A __getstate__ method must exist for the __setstate__ to be called + This is identical to the default implementation. + See https://docs.python.org/3.7/library/pickle.html#object.__setstate__ + """ + state_dict = self.__dict__.copy() + # Remove _refresh_handler function as there are limitations pickling and + # unpickling certain callables (lambda, functools.partial instances) + # because they need to be importable. + # Instead, the refresh_handler setter should be used to repopulate this. + if "_refresh_handler" in state_dict: + del state_dict["_refresh_handler"] + + if "_refresh_worker" in state_dict: + del state_dict["_refresh_worker"] + return state_dict + + def __setstate__(self, d): + """Credentials pickled with older versions of the class do not have + all the attributes.""" + self.token = d.get("token") + self.expiry = d.get("expiry") + self._refresh_token = d.get("_refresh_token") + self._id_token = d.get("_id_token") + self._scopes = d.get("_scopes") + self._default_scopes = d.get("_default_scopes") + self._granted_scopes = d.get("_granted_scopes") + self._token_uri = d.get("_token_uri") + self._client_id = d.get("_client_id") + self._client_secret = d.get("_client_secret") + self._quota_project_id = d.get("_quota_project_id") + self._rapt_token = d.get("_rapt_token") + self._enable_reauth_refresh = d.get("_enable_reauth_refresh") + self._trust_boundary = d.get("_trust_boundary") + self._universe_domain = ( + d.get("_universe_domain") or credentials.DEFAULT_UNIVERSE_DOMAIN + ) + self._cred_file_path = d.get("_cred_file_path") + # The refresh_handler setter should be used to repopulate this. + self._refresh_handler = None + self._refresh_worker = None + self._use_non_blocking_refresh = d.get("_use_non_blocking_refresh", False) + self._account = d.get("_account", "") + + @property + def refresh_token(self): + """Optional[str]: The OAuth 2.0 refresh token.""" + return self._refresh_token + + @property + def scopes(self): + """Optional[str]: The OAuth 2.0 permission scopes.""" + return self._scopes + + @property + def granted_scopes(self): + """Optional[Sequence[str]]: The OAuth 2.0 permission scopes that were granted by the user.""" + return self._granted_scopes + + @property + def token_uri(self): + """Optional[str]: The OAuth 2.0 authorization server's token endpoint + URI.""" + return self._token_uri + + @property + def id_token(self): + """Optional[str]: The Open ID Connect ID Token. + + Depending on the authorization server and the scopes requested, this + may be populated when credentials are obtained and updated when + :meth:`refresh` is called. This token is a JWT. It can be verified + and decoded using :func:`google.oauth2.id_token.verify_oauth2_token`. + """ + return self._id_token + + @property + def client_id(self): + """Optional[str]: The OAuth 2.0 client ID.""" + return self._client_id + + @property + def client_secret(self): + """Optional[str]: The OAuth 2.0 client secret.""" + return self._client_secret + + @property + def requires_scopes(self): + """False: OAuth 2.0 credentials have their scopes set when + the initial token is requested and can not be changed.""" + return False + + @property + def rapt_token(self): + """Optional[str]: The reauth Proof Token.""" + return self._rapt_token + + @property + def refresh_handler(self): + """Returns the refresh handler if available. + + Returns: + Optional[Callable[[google.auth.transport.Request, Sequence[str]], [str, datetime]]]: + The current refresh handler. + """ + return self._refresh_handler + + @refresh_handler.setter + def refresh_handler(self, value): + """Updates the current refresh handler. + + Args: + value (Optional[Callable[[google.auth.transport.Request, Sequence[str]], [str, datetime]]]): + The updated value of the refresh handler. + + Raises: + TypeError: If the value is not a callable or None. + """ + if not callable(value) and value is not None: + raise TypeError("The provided refresh_handler is not a callable or None.") + self._refresh_handler = value + + @property + def account(self): + """str: The user account associated with the credential. If the account is unknown an empty string is returned.""" + return self._account + + def _make_copy(self): + cred = self.__class__( + self.token, + refresh_token=self.refresh_token, + id_token=self.id_token, + token_uri=self.token_uri, + client_id=self.client_id, + client_secret=self.client_secret, + scopes=self.scopes, + default_scopes=self.default_scopes, + granted_scopes=self.granted_scopes, + quota_project_id=self.quota_project_id, + rapt_token=self.rapt_token, + enable_reauth_refresh=self._enable_reauth_refresh, + trust_boundary=self._trust_boundary, + universe_domain=self._universe_domain, + account=self._account, + ) + cred._cred_file_path = self._cred_file_path + return cred + + @_helpers.copy_docstring(credentials.Credentials) + def get_cred_info(self): + if self._cred_file_path: + cred_info = { + "credential_source": self._cred_file_path, + "credential_type": "user credentials", + } + if self.account: + cred_info["principal"] = self.account + return cred_info + return None + + @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) + def with_quota_project(self, quota_project_id): + cred = self._make_copy() + cred._quota_project_id = quota_project_id + return cred + + @_helpers.copy_docstring(credentials.CredentialsWithTokenUri) + def with_token_uri(self, token_uri): + cred = self._make_copy() + cred._token_uri = token_uri + return cred + + def with_account(self, account): + """Returns a copy of these credentials with a modified account. + + Args: + account (str): The account to set + + Returns: + google.oauth2.credentials.Credentials: A new credentials instance. + """ + cred = self._make_copy() + cred._account = account + return cred + + @_helpers.copy_docstring(credentials.CredentialsWithUniverseDomain) + def with_universe_domain(self, universe_domain): + cred = self._make_copy() + cred._universe_domain = universe_domain + return cred + + def _metric_header_for_usage(self): + return metrics.CRED_TYPE_USER + + @_helpers.copy_docstring(credentials.Credentials) + def refresh(self, request): + if self._universe_domain != credentials.DEFAULT_UNIVERSE_DOMAIN: + raise exceptions.RefreshError( + "User credential refresh is only supported in the default " + "googleapis.com universe domain, but the current universe " + "domain is {}. If you created the credential with an access " + "token, it's likely that the provided token is expired now, " + "please update your code with a valid token.".format( + self._universe_domain + ) + ) + + scopes = self._scopes if self._scopes is not None else self._default_scopes + # Use refresh handler if available and no refresh token is + # available. This is useful in general when tokens are obtained by calling + # some external process on demand. It is particularly useful for retrieving + # downscoped tokens from a token broker. + if self._refresh_token is None and self.refresh_handler: + token, expiry = self.refresh_handler(request, scopes=scopes) + # Validate returned data. + if not isinstance(token, str): + raise exceptions.RefreshError( + "The refresh_handler returned token is not a string." + ) + if not isinstance(expiry, datetime): + raise exceptions.RefreshError( + "The refresh_handler returned expiry is not a datetime object." + ) + if _helpers.utcnow() >= expiry - _helpers.REFRESH_THRESHOLD: + raise exceptions.RefreshError( + "The credentials returned by the refresh_handler are " + "already expired." + ) + self.token = token + self.expiry = expiry + return + + if ( + self._refresh_token is None + or self._token_uri is None + or self._client_id is None + or self._client_secret is None + ): + raise exceptions.RefreshError( + "The credentials do not contain the necessary fields need to " + "refresh the access token. You must specify refresh_token, " + "token_uri, client_id, and client_secret." + ) + + ( + access_token, + refresh_token, + expiry, + grant_response, + rapt_token, + ) = reauth.refresh_grant( + request, + self._token_uri, + self._refresh_token, + self._client_id, + self._client_secret, + scopes=scopes, + rapt_token=self._rapt_token, + enable_reauth_refresh=self._enable_reauth_refresh, + ) + + self.token = access_token + self.expiry = expiry + self._refresh_token = refresh_token + self._id_token = grant_response.get("id_token") + self._rapt_token = rapt_token + + if scopes and "scope" in grant_response: + requested_scopes = frozenset(scopes) + self._granted_scopes = grant_response["scope"].split() + granted_scopes = frozenset(self._granted_scopes) + scopes_requested_but_not_granted = requested_scopes - granted_scopes + if scopes_requested_but_not_granted: + # User might be presented with unbundled scopes at the time of + # consent. So it is a valid scenario to not have all the requested + # scopes as part of granted scopes but log a warning in case the + # developer wants to debug the scenario. + _LOGGER.warning( + "Not all requested scopes were granted by the " + "authorization server, missing scopes {}.".format( + ", ".join(scopes_requested_but_not_granted) + ) + ) + + @classmethod + def from_authorized_user_info(cls, info, scopes=None): + """Creates a Credentials instance from parsed authorized user info. + + Args: + info (Mapping[str, str]): The authorized user info in Google + format. + scopes (Sequence[str]): Optional list of scopes to include in the + credentials. + + Returns: + google.oauth2.credentials.Credentials: The constructed + credentials. + + Raises: + ValueError: If the info is not in the expected format. + """ + keys_needed = set(("refresh_token", "client_id", "client_secret")) + missing = keys_needed.difference(info.keys()) + + if missing: + raise ValueError( + "Authorized user info was not in the expected format, missing " + "fields {}.".format(", ".join(missing)) + ) + + # access token expiry (datetime obj); auto-expire if not saved + expiry = info.get("expiry") + if expiry: + expiry = datetime.strptime( + expiry.rstrip("Z").split(".")[0], "%Y-%m-%dT%H:%M:%S" + ) + else: + expiry = _helpers.utcnow() - _helpers.REFRESH_THRESHOLD + + # process scopes, which needs to be a seq + if scopes is None and "scopes" in info: + scopes = info.get("scopes") + if isinstance(scopes, str): + scopes = scopes.split(" ") + + return cls( + token=info.get("token"), + refresh_token=info.get("refresh_token"), + token_uri=_GOOGLE_OAUTH2_TOKEN_ENDPOINT, # always overrides + scopes=scopes, + client_id=info.get("client_id"), + client_secret=info.get("client_secret"), + quota_project_id=info.get("quota_project_id"), # may not exist + expiry=expiry, + rapt_token=info.get("rapt_token"), # may not exist + trust_boundary=info.get("trust_boundary"), # may not exist + universe_domain=info.get("universe_domain"), # may not exist + account=info.get("account", ""), # may not exist + ) + + @classmethod + def from_authorized_user_file(cls, filename, scopes=None): + """Creates a Credentials instance from an authorized user json file. + + Args: + filename (str): The path to the authorized user json file. + scopes (Sequence[str]): Optional list of scopes to include in the + credentials. + + Returns: + google.oauth2.credentials.Credentials: The constructed + credentials. + + Raises: + ValueError: If the file is not in the expected format. + """ + with io.open(filename, "r", encoding="utf-8") as json_file: + data = json.load(json_file) + return cls.from_authorized_user_info(data, scopes) + + def to_json(self, strip=None): + """Utility function that creates a JSON representation of a Credentials + object. + + Args: + strip (Sequence[str]): Optional list of members to exclude from the + generated JSON. + + Returns: + str: A JSON representation of this instance. When converted into + a dictionary, it can be passed to from_authorized_user_info() + to create a new credential instance. + """ + prep = { + "token": self.token, + "refresh_token": self.refresh_token, + "token_uri": self.token_uri, + "client_id": self.client_id, + "client_secret": self.client_secret, + "scopes": self.scopes, + "rapt_token": self.rapt_token, + "universe_domain": self._universe_domain, + "account": self._account, + } + if self.expiry: # flatten expiry timestamp + prep["expiry"] = self.expiry.isoformat() + "Z" + + # Remove empty entries (those which are None) + prep = {k: v for k, v in prep.items() if v is not None} + + # Remove entries that explicitely need to be removed + if strip is not None: + prep = {k: v for k, v in prep.items() if k not in strip} + + return json.dumps(prep) + + +class UserAccessTokenCredentials(credentials.CredentialsWithQuotaProject): + """Access token credentials for user account. + + Obtain the access token for a given user account or the current active + user account with the ``gcloud auth print-access-token`` command. + + Args: + account (Optional[str]): Account to get the access token for. If not + specified, the current active account will be used. + quota_project_id (Optional[str]): The project ID used for quota + and billing. + """ + + def __init__(self, account=None, quota_project_id=None): + warnings.warn( + "UserAccessTokenCredentials is deprecated, please use " + "google.oauth2.credentials.Credentials instead. To use " + "that credential type, simply run " + "`gcloud auth application-default login` and let the " + "client libraries pick up the application default credentials." + ) + super(UserAccessTokenCredentials, self).__init__() + self._account = account + self._quota_project_id = quota_project_id + + def with_account(self, account): + """Create a new instance with the given account. + + Args: + account (str): Account to get the access token for. + + Returns: + google.oauth2.credentials.UserAccessTokenCredentials: The created + credentials with the given account. + """ + return self.__class__(account=account, quota_project_id=self._quota_project_id) + + @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) + def with_quota_project(self, quota_project_id): + return self.__class__(account=self._account, quota_project_id=quota_project_id) + + def refresh(self, request): + """Refreshes the access token. + + Args: + request (google.auth.transport.Request): This argument is required + by the base class interface but not used in this implementation, + so just set it to `None`. + + Raises: + google.auth.exceptions.UserAccessTokenError: If the access token + refresh failed. + """ + self.token = _cloud_sdk.get_auth_access_token(self._account) + + @_helpers.copy_docstring(credentials.Credentials) + def before_request(self, request, method, url, headers): + self.refresh(request) + self.apply(headers) diff --git a/lib/python3.10/site-packages/google/oauth2/gdch_credentials.py b/lib/python3.10/site-packages/google/oauth2/gdch_credentials.py new file mode 100644 index 0000000000000000000000000000000000000000..7410cfc2e05eaa84ab490ab4c5ca6c80b4d61891 --- /dev/null +++ b/lib/python3.10/site-packages/google/oauth2/gdch_credentials.py @@ -0,0 +1,251 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Experimental GDCH credentials support. +""" + +import datetime + +from google.auth import _helpers +from google.auth import _service_account_info +from google.auth import credentials +from google.auth import exceptions +from google.auth import jwt +from google.oauth2 import _client + + +TOKEN_EXCHANGE_TYPE = "urn:ietf:params:oauth:token-type:token-exchange" +ACCESS_TOKEN_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:access_token" +SERVICE_ACCOUNT_TOKEN_TYPE = "urn:k8s:params:oauth:token-type:serviceaccount" +JWT_LIFETIME = datetime.timedelta(seconds=3600) # 1 hour + + +class ServiceAccountCredentials(credentials.Credentials): + """Credentials for GDCH (`Google Distributed Cloud Hosted`_) for service + account users. + + .. _Google Distributed Cloud Hosted: + https://cloud.google.com/blog/topics/hybrid-cloud/\ + announcing-google-distributed-cloud-edge-and-hosted + + To create a GDCH service account credential, first create a JSON file of + the following format:: + + { + "type": "gdch_service_account", + "format_version": "1", + "project": "", + "private_key_id": "", + "private_key": "-----BEGIN EC PRIVATE KEY-----\n\n-----END EC PRIVATE KEY-----\n", + "name": "", + "ca_cert_path": "", + "token_uri": "https://service-identity./authenticate" + } + + The "format_version" field stands for the format of the JSON file. For now + it is always "1". The `private_key_id` and `private_key` is used for signing. + The `ca_cert_path` is used for token server TLS certificate verification. + + After the JSON file is created, set `GOOGLE_APPLICATION_CREDENTIALS` environment + variable to the JSON file path, then use the following code to create the + credential:: + + import google.auth + + credential, _ = google.auth.default() + credential = credential.with_gdch_audience("") + + We can also create the credential directly:: + + from google.oauth import gdch_credentials + + credential = gdch_credentials.ServiceAccountCredentials.from_service_account_file("") + credential = credential.with_gdch_audience("") + + The token is obtained in the following way. This class first creates a + self signed JWT. It uses the `name` value as the `iss` and `sub` claim, and + the `token_uri` as the `aud` claim, and signs the JWT with the `private_key`. + It then sends the JWT to the `token_uri` to exchange a final token for + `audience`. + """ + + def __init__( + self, signer, service_identity_name, project, audience, token_uri, ca_cert_path + ): + """ + Args: + signer (google.auth.crypt.Signer): The signer used to sign JWTs. + service_identity_name (str): The service identity name. It will be + used as the `iss` and `sub` claim in the self signed JWT. + project (str): The project. + audience (str): The audience for the final token. + token_uri (str): The token server uri. + ca_cert_path (str): The CA cert path for token server side TLS + certificate verification. If the token server uses well known + CA, then this parameter can be `None`. + """ + super(ServiceAccountCredentials, self).__init__() + self._signer = signer + self._service_identity_name = service_identity_name + self._project = project + self._audience = audience + self._token_uri = token_uri + self._ca_cert_path = ca_cert_path + + def _create_jwt(self): + now = _helpers.utcnow() + expiry = now + JWT_LIFETIME + iss_sub_value = "system:serviceaccount:{}:{}".format( + self._project, self._service_identity_name + ) + + payload = { + "iss": iss_sub_value, + "sub": iss_sub_value, + "aud": self._token_uri, + "iat": _helpers.datetime_to_secs(now), + "exp": _helpers.datetime_to_secs(expiry), + } + + return _helpers.from_bytes(jwt.encode(self._signer, payload)) + + @_helpers.copy_docstring(credentials.Credentials) + def refresh(self, request): + import google.auth.transport.requests + + if not isinstance(request, google.auth.transport.requests.Request): + raise exceptions.RefreshError( + "For GDCH service account credentials, request must be a google.auth.transport.requests.Request object" + ) + + # Create a self signed JWT, and do token exchange. + jwt_token = self._create_jwt() + request_body = { + "grant_type": TOKEN_EXCHANGE_TYPE, + "audience": self._audience, + "requested_token_type": ACCESS_TOKEN_TOKEN_TYPE, + "subject_token": jwt_token, + "subject_token_type": SERVICE_ACCOUNT_TOKEN_TYPE, + } + response_data = _client._token_endpoint_request( + request, + self._token_uri, + request_body, + access_token=None, + use_json=True, + verify=self._ca_cert_path, + ) + + self.token, _, self.expiry, _ = _client._handle_refresh_grant_response( + response_data, None + ) + + def with_gdch_audience(self, audience): + """Create a copy of GDCH credentials with the specified audience. + + Args: + audience (str): The intended audience for GDCH credentials. + """ + return self.__class__( + self._signer, + self._service_identity_name, + self._project, + audience, + self._token_uri, + self._ca_cert_path, + ) + + @classmethod + def _from_signer_and_info(cls, signer, info): + """Creates a Credentials instance from a signer and service account + info. + + Args: + signer (google.auth.crypt.Signer): The signer used to sign JWTs. + info (Mapping[str, str]): The service account info. + + Returns: + google.oauth2.gdch_credentials.ServiceAccountCredentials: The constructed + credentials. + + Raises: + ValueError: If the info is not in the expected format. + """ + if info["format_version"] != "1": + raise ValueError("Only format version 1 is supported") + + return cls( + signer, + info["name"], # service_identity_name + info["project"], + None, # audience + info["token_uri"], + info.get("ca_cert_path", None), + ) + + @classmethod + def from_service_account_info(cls, info): + """Creates a Credentials instance from parsed service account info. + + Args: + info (Mapping[str, str]): The service account info in Google + format. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.oauth2.gdch_credentials.ServiceAccountCredentials: The constructed + credentials. + + Raises: + ValueError: If the info is not in the expected format. + """ + signer = _service_account_info.from_dict( + info, + require=[ + "format_version", + "private_key_id", + "private_key", + "name", + "project", + "token_uri", + ], + use_rsa_signer=False, + ) + return cls._from_signer_and_info(signer, info) + + @classmethod + def from_service_account_file(cls, filename): + """Creates a Credentials instance from a service account json file. + + Args: + filename (str): The path to the service account json file. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.oauth2.gdch_credentials.ServiceAccountCredentials: The constructed + credentials. + """ + info, signer = _service_account_info.from_filename( + filename, + require=[ + "format_version", + "private_key_id", + "private_key", + "name", + "project", + "token_uri", + ], + use_rsa_signer=False, + ) + return cls._from_signer_and_info(signer, info) diff --git a/lib/python3.10/site-packages/google/oauth2/id_token.py b/lib/python3.10/site-packages/google/oauth2/id_token.py new file mode 100644 index 0000000000000000000000000000000000000000..a6c51ce63814c27d496a397a799632896358ad98 --- /dev/null +++ b/lib/python3.10/site-packages/google/oauth2/id_token.py @@ -0,0 +1,370 @@ +# Copyright 2016 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Google ID Token helpers. + +Provides support for verifying `OpenID Connect ID Tokens`_, especially ones +generated by Google infrastructure. + +To parse and verify an ID Token issued by Google's OAuth 2.0 authorization +server use :func:`verify_oauth2_token`. To verify an ID Token issued by +Firebase, use :func:`verify_firebase_token`. + +A general purpose ID Token verifier is available as :func:`verify_token`. + +Example:: + + from google.oauth2 import id_token + from google.auth.transport import requests + + request = requests.Request() + + id_info = id_token.verify_oauth2_token( + token, request, 'my-client-id.example.com') + + userid = id_info['sub'] + +By default, this will re-fetch certificates for each verification. Because +Google's public keys are only changed infrequently (on the order of once per +day), you may wish to take advantage of caching to reduce latency and the +potential for network errors. This can be accomplished using an external +library like `CacheControl`_ to create a cache-aware +:class:`google.auth.transport.Request`:: + + import cachecontrol + import google.auth.transport.requests + import requests + + session = requests.session() + cached_session = cachecontrol.CacheControl(session) + request = google.auth.transport.requests.Request(session=cached_session) + +.. _OpenID Connect ID Tokens: + http://openid.net/specs/openid-connect-core-1_0.html#IDToken +.. _CacheControl: https://cachecontrol.readthedocs.io +""" + +import http.client as http_client +import json +import os + +from google.auth import environment_vars +from google.auth import exceptions +from google.auth import jwt + + +# The URL that provides public certificates for verifying ID tokens issued +# by Google's OAuth 2.0 authorization server. +_GOOGLE_OAUTH2_CERTS_URL = "https://www.googleapis.com/oauth2/v1/certs" + +# The URL that provides public certificates for verifying ID tokens issued +# by Firebase and the Google APIs infrastructure +_GOOGLE_APIS_CERTS_URL = ( + "https://www.googleapis.com/robot/v1/metadata/x509" + "/securetoken@system.gserviceaccount.com" +) + +_GOOGLE_ISSUERS = ["accounts.google.com", "https://accounts.google.com"] + + +def _fetch_certs(request, certs_url): + """Fetches certificates. + + Google-style cerificate endpoints return JSON in the format of + ``{'key id': 'x509 certificate'}`` or a certificate array according + to the JWK spec (see https://tools.ietf.org/html/rfc7517). + + Args: + request (google.auth.transport.Request): The object used to make + HTTP requests. + certs_url (str): The certificate endpoint URL. + + Returns: + Mapping[str, str] | Mapping[str, list]: A mapping of public keys + in x.509 or JWK spec. + """ + response = request(certs_url, method="GET") + + if response.status != http_client.OK: + raise exceptions.TransportError( + "Could not fetch certificates at {}".format(certs_url) + ) + + return json.loads(response.data.decode("utf-8")) + + +def verify_token( + id_token, + request, + audience=None, + certs_url=_GOOGLE_OAUTH2_CERTS_URL, + clock_skew_in_seconds=0, +): + """Verifies an ID token and returns the decoded token. + + Args: + id_token (Union[str, bytes]): The encoded token. + request (google.auth.transport.Request): The object used to make + HTTP requests. + audience (str or list): The audience or audiences that this token is + intended for. If None then the audience is not verified. + certs_url (str): The URL that specifies the certificates to use to + verify the token. This URL should return JSON in the format of + ``{'key id': 'x509 certificate'}`` or a certificate array according to + the JWK spec (see https://tools.ietf.org/html/rfc7517). + clock_skew_in_seconds (int): The clock skew used for `iat` and `exp` + validation. + + Returns: + Mapping[str, Any]: The decoded token. + """ + certs = _fetch_certs(request, certs_url) + + if "keys" in certs: + try: + import jwt as jwt_lib # type: ignore + except ImportError as caught_exc: # pragma: NO COVER + raise ImportError( + "The pyjwt library is not installed, please install the pyjwt package to use the jwk certs format." + ) from caught_exc + jwks_client = jwt_lib.PyJWKClient(certs_url) + signing_key = jwks_client.get_signing_key_from_jwt(id_token) + return jwt_lib.decode( + id_token, + signing_key.key, + algorithms=[signing_key.algorithm_name], + audience=audience, + ) + else: + return jwt.decode( + id_token, + certs=certs, + audience=audience, + clock_skew_in_seconds=clock_skew_in_seconds, + ) + + +def verify_oauth2_token(id_token, request, audience=None, clock_skew_in_seconds=0): + """Verifies an ID Token issued by Google's OAuth 2.0 authorization server. + + Args: + id_token (Union[str, bytes]): The encoded token. + request (google.auth.transport.Request): The object used to make + HTTP requests. + audience (str): The audience that this token is intended for. This is + typically your application's OAuth 2.0 client ID. If None then the + audience is not verified. + clock_skew_in_seconds (int): The clock skew used for `iat` and `exp` + validation. + + Returns: + Mapping[str, Any]: The decoded token. + + Raises: + exceptions.GoogleAuthError: If the issuer is invalid. + ValueError: If token verification fails + """ + idinfo = verify_token( + id_token, + request, + audience=audience, + certs_url=_GOOGLE_OAUTH2_CERTS_URL, + clock_skew_in_seconds=clock_skew_in_seconds, + ) + + if idinfo["iss"] not in _GOOGLE_ISSUERS: + raise exceptions.GoogleAuthError( + "Wrong issuer. 'iss' should be one of the following: {}".format( + _GOOGLE_ISSUERS + ) + ) + + return idinfo + + +def verify_firebase_token(id_token, request, audience=None, clock_skew_in_seconds=0): + """Verifies an ID Token issued by Firebase Authentication. + + Args: + id_token (Union[str, bytes]): The encoded token. + request (google.auth.transport.Request): The object used to make + HTTP requests. + audience (str): The audience that this token is intended for. This is + typically your Firebase application ID. If None then the audience + is not verified. + clock_skew_in_seconds (int): The clock skew used for `iat` and `exp` + validation. + + Returns: + Mapping[str, Any]: The decoded token. + """ + return verify_token( + id_token, + request, + audience=audience, + certs_url=_GOOGLE_APIS_CERTS_URL, + clock_skew_in_seconds=clock_skew_in_seconds, + ) + + +def fetch_id_token_credentials(audience, request=None): + """Create the ID Token credentials from the current environment. + + This function acquires ID token from the environment in the following order. + See https://google.aip.dev/auth/4110. + + 1. If the environment variable ``GOOGLE_APPLICATION_CREDENTIALS`` is set + to the path of a valid service account JSON file, then ID token is + acquired using this service account credentials. + 2. If the application is running in Compute Engine, App Engine or Cloud Run, + then the ID token are obtained from the metadata server. + 3. If metadata server doesn't exist and no valid service account credentials + are found, :class:`~google.auth.exceptions.DefaultCredentialsError` will + be raised. + + Example:: + + import google.oauth2.id_token + import google.auth.transport.requests + + request = google.auth.transport.requests.Request() + target_audience = "https://pubsub.googleapis.com" + + # Create ID token credentials. + credentials = google.oauth2.id_token.fetch_id_token_credentials(target_audience, request=request) + + # Refresh the credential to obtain an ID token. + credentials.refresh(request) + + id_token = credentials.token + id_token_expiry = credentials.expiry + + Args: + audience (str): The audience that this ID token is intended for. + request (Optional[google.auth.transport.Request]): A callable used to make + HTTP requests. A request object will be created if not provided. + + Returns: + google.auth.credentials.Credentials: The ID token credentials. + + Raises: + ~google.auth.exceptions.DefaultCredentialsError: + If metadata server doesn't exist and no valid service account + credentials are found. + """ + # 1. Try to get credentials from the GOOGLE_APPLICATION_CREDENTIALS environment + # variable. + credentials_filename = os.environ.get(environment_vars.CREDENTIALS) + if credentials_filename: + if not ( + os.path.exists(credentials_filename) + and os.path.isfile(credentials_filename) + ): + raise exceptions.DefaultCredentialsError( + "GOOGLE_APPLICATION_CREDENTIALS path is either not found or invalid." + ) + + try: + with open(credentials_filename, "r") as f: + from google.oauth2 import service_account + + info = json.load(f) + if info.get("type") == "service_account": + return service_account.IDTokenCredentials.from_service_account_info( + info, target_audience=audience + ) + elif info.get("type") == "impersonated_service_account": + from google.auth import impersonated_credentials + + target_credentials = impersonated_credentials.Credentials.from_impersonated_service_account_info( + info + ) + + return impersonated_credentials.IDTokenCredentials( + target_credentials=target_credentials, + target_audience=audience, + include_email=True, + ) + except ValueError as caught_exc: + new_exc = exceptions.DefaultCredentialsError( + "GOOGLE_APPLICATION_CREDENTIALS is not valid service account credentials.", + caught_exc, + ) + raise new_exc from caught_exc + + # 2. Try to fetch ID token from metada server if it exists. The code + # works for GAE and Cloud Run metadata server as well. + try: + from google.auth import compute_engine + from google.auth.compute_engine import _metadata + + # Create a request object if not provided. + if not request: + import google.auth.transport.requests + + request = google.auth.transport.requests.Request() + + if _metadata.ping(request): + return compute_engine.IDTokenCredentials( + request, audience, use_metadata_identity_endpoint=True + ) + except (ImportError, exceptions.TransportError): + pass + + raise exceptions.DefaultCredentialsError( + "Neither metadata server or valid service account credentials are found." + ) + + +def fetch_id_token(request, audience): + """Fetch the ID Token from the current environment. + + This function acquires ID token from the environment in the following order. + See https://google.aip.dev/auth/4110. + + 1. If the environment variable ``GOOGLE_APPLICATION_CREDENTIALS`` is set + to the path of a valid service account JSON file, then ID token is + acquired using this service account credentials. + 2. If the application is running in Compute Engine, App Engine or Cloud Run, + then the ID token are obtained from the metadata server. + 3. If metadata server doesn't exist and no valid service account credentials + are found, :class:`~google.auth.exceptions.DefaultCredentialsError` will + be raised. + + Example:: + + import google.oauth2.id_token + import google.auth.transport.requests + + request = google.auth.transport.requests.Request() + target_audience = "https://pubsub.googleapis.com" + + id_token = google.oauth2.id_token.fetch_id_token(request, target_audience) + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + audience (str): The audience that this ID token is intended for. + + Returns: + str: The ID token. + + Raises: + ~google.auth.exceptions.DefaultCredentialsError: + If metadata server doesn't exist and no valid service account + credentials are found. + """ + id_token_credentials = fetch_id_token_credentials(audience, request=request) + id_token_credentials.refresh(request) + return id_token_credentials.token diff --git a/lib/python3.10/site-packages/google/oauth2/py.typed b/lib/python3.10/site-packages/google/oauth2/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..d82ed62c2f052f2d46ea9843ac82ed475f223593 --- /dev/null +++ b/lib/python3.10/site-packages/google/oauth2/py.typed @@ -0,0 +1,2 @@ +# Marker file for PEP 561. +# The google-oauth2 package uses inline types. diff --git a/lib/python3.10/site-packages/google/oauth2/reauth.py b/lib/python3.10/site-packages/google/oauth2/reauth.py new file mode 100644 index 0000000000000000000000000000000000000000..1e39e0bc7f410b5a34c7bdbc6856b8db08d7d979 --- /dev/null +++ b/lib/python3.10/site-packages/google/oauth2/reauth.py @@ -0,0 +1,369 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A module that provides functions for handling rapt authentication. + +Reauth is a process of obtaining additional authentication (such as password, +security token, etc.) while refreshing OAuth 2.0 credentials for a user. + +Credentials that use the Reauth flow must have the reauth scope, +``https://www.googleapis.com/auth/accounts.reauth``. + +This module provides a high-level function for executing the Reauth process, +:func:`refresh_grant`, and lower-level helpers for doing the individual +steps of the reauth process. + +Those steps are: + +1. Obtaining a list of challenges from the reauth server. +2. Running through each challenge and sending the result back to the reauth + server. +3. Refreshing the access token using the returned rapt token. +""" + +import sys + +from google.auth import exceptions +from google.auth import metrics +from google.oauth2 import _client +from google.oauth2 import challenges + + +_REAUTH_SCOPE = "https://www.googleapis.com/auth/accounts.reauth" +_REAUTH_API = "https://reauth.googleapis.com/v2/sessions" + +_REAUTH_NEEDED_ERROR = "invalid_grant" +_REAUTH_NEEDED_ERROR_INVALID_RAPT = "invalid_rapt" +_REAUTH_NEEDED_ERROR_RAPT_REQUIRED = "rapt_required" + +_AUTHENTICATED = "AUTHENTICATED" +_CHALLENGE_REQUIRED = "CHALLENGE_REQUIRED" +_CHALLENGE_PENDING = "CHALLENGE_PENDING" + + +# Override this global variable to set custom max number of rounds of reauth +# challenges should be run. +RUN_CHALLENGE_RETRY_LIMIT = 5 + + +def is_interactive(): + """Check if we are in an interractive environment. + + Override this function with a different logic if you are using this library + outside a CLI. + + If the rapt token needs refreshing, the user needs to answer the challenges. + If the user is not in an interractive environment, the challenges can not + be answered and we just wait for timeout for no reason. + + Returns: + bool: True if is interactive environment, False otherwise. + """ + + return sys.stdin.isatty() + + +def _get_challenges( + request, supported_challenge_types, access_token, requested_scopes=None +): + """Does initial request to reauth API to get the challenges. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + supported_challenge_types (Sequence[str]): list of challenge names + supported by the manager. + access_token (str): Access token with reauth scopes. + requested_scopes (Optional(Sequence[str])): Authorized scopes for the credentials. + + Returns: + dict: The response from the reauth API. + """ + body = {"supportedChallengeTypes": supported_challenge_types} + if requested_scopes: + body["oauthScopesForDomainPolicyLookup"] = requested_scopes + metrics_header = {metrics.API_CLIENT_HEADER: metrics.reauth_start()} + + return _client._token_endpoint_request( + request, + _REAUTH_API + ":start", + body, + access_token=access_token, + use_json=True, + headers=metrics_header, + ) + + +def _send_challenge_result( + request, session_id, challenge_id, client_input, access_token +): + """Attempt to refresh access token by sending next challenge result. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + session_id (str): session id returned by the initial reauth call. + challenge_id (str): challenge id returned by the initial reauth call. + client_input: dict with a challenge-specific client input. For example: + ``{'credential': password}`` for password challenge. + access_token (str): Access token with reauth scopes. + + Returns: + dict: The response from the reauth API. + """ + body = { + "sessionId": session_id, + "challengeId": challenge_id, + "action": "RESPOND", + "proposalResponse": client_input, + } + metrics_header = {metrics.API_CLIENT_HEADER: metrics.reauth_continue()} + + return _client._token_endpoint_request( + request, + _REAUTH_API + "/{}:continue".format(session_id), + body, + access_token=access_token, + use_json=True, + headers=metrics_header, + ) + + +def _run_next_challenge(msg, request, access_token): + """Get the next challenge from msg and run it. + + Args: + msg (dict): Reauth API response body (either from the initial request to + https://reauth.googleapis.com/v2/sessions:start or from sending the + previous challenge response to + https://reauth.googleapis.com/v2/sessions/id:continue) + request (google.auth.transport.Request): A callable used to make + HTTP requests. + access_token (str): reauth access token + + Returns: + dict: The response from the reauth API. + + Raises: + google.auth.exceptions.ReauthError: if reauth failed. + """ + for challenge in msg["challenges"]: + if challenge["status"] != "READY": + # Skip non-activated challenges. + continue + c = challenges.AVAILABLE_CHALLENGES.get(challenge["challengeType"], None) + if not c: + raise exceptions.ReauthFailError( + "Unsupported challenge type {0}. Supported types: {1}".format( + challenge["challengeType"], + ",".join(list(challenges.AVAILABLE_CHALLENGES.keys())), + ) + ) + if not c.is_locally_eligible: + raise exceptions.ReauthFailError( + "Challenge {0} is not locally eligible".format( + challenge["challengeType"] + ) + ) + client_input = c.obtain_challenge_input(challenge) + if not client_input: + return None + return _send_challenge_result( + request, + msg["sessionId"], + challenge["challengeId"], + client_input, + access_token, + ) + return None + + +def _obtain_rapt(request, access_token, requested_scopes): + """Given an http request method and reauth access token, get rapt token. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + access_token (str): reauth access token + requested_scopes (Sequence[str]): scopes required by the client application + + Returns: + str: The rapt token. + + Raises: + google.auth.exceptions.ReauthError: if reauth failed + """ + msg = _get_challenges( + request, + list(challenges.AVAILABLE_CHALLENGES.keys()), + access_token, + requested_scopes, + ) + + if msg["status"] == _AUTHENTICATED: + return msg["encodedProofOfReauthToken"] + + for _ in range(0, RUN_CHALLENGE_RETRY_LIMIT): + if not ( + msg["status"] == _CHALLENGE_REQUIRED or msg["status"] == _CHALLENGE_PENDING + ): + raise exceptions.ReauthFailError( + "Reauthentication challenge failed due to API error: {}".format( + msg["status"] + ) + ) + + if not is_interactive(): + raise exceptions.ReauthFailError( + "Reauthentication challenge could not be answered because you are not" + " in an interactive session." + ) + + msg = _run_next_challenge(msg, request, access_token) + + if not msg: + raise exceptions.ReauthFailError("Failed to obtain rapt token.") + if msg["status"] == _AUTHENTICATED: + return msg["encodedProofOfReauthToken"] + + # If we got here it means we didn't get authenticated. + raise exceptions.ReauthFailError("Failed to obtain rapt token.") + + +def get_rapt_token( + request, client_id, client_secret, refresh_token, token_uri, scopes=None +): + """Given an http request method and refresh_token, get rapt token. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + client_id (str): client id to get access token for reauth scope. + client_secret (str): client secret for the client_id + refresh_token (str): refresh token to refresh access token + token_uri (str): uri to refresh access token + scopes (Optional(Sequence[str])): scopes required by the client application + + Returns: + str: The rapt token. + Raises: + google.auth.exceptions.RefreshError: If reauth failed. + """ + sys.stderr.write("Reauthentication required.\n") + + # Get access token for reauth. + access_token, _, _, _ = _client.refresh_grant( + request=request, + client_id=client_id, + client_secret=client_secret, + refresh_token=refresh_token, + token_uri=token_uri, + scopes=[_REAUTH_SCOPE], + ) + + # Get rapt token from reauth API. + rapt_token = _obtain_rapt(request, access_token, requested_scopes=scopes) + sys.stderr.write("Reauthentication successful.\n") + + return rapt_token + + +def refresh_grant( + request, + token_uri, + refresh_token, + client_id, + client_secret, + scopes=None, + rapt_token=None, + enable_reauth_refresh=False, +): + """Implements the reauthentication flow. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + token_uri (str): The OAuth 2.0 authorizations server's token endpoint + URI. + refresh_token (str): The refresh token to use to get a new access + token. + client_id (str): The OAuth 2.0 application's client ID. + client_secret (str): The Oauth 2.0 appliaction's client secret. + scopes (Optional(Sequence[str])): Scopes to request. If present, all + scopes must be authorized for the refresh token. Useful if refresh + token has a wild card scope (e.g. + 'https://www.googleapis.com/auth/any-api'). + rapt_token (Optional(str)): The rapt token for reauth. + enable_reauth_refresh (Optional[bool]): Whether reauth refresh flow + should be used. The default value is False. This option is for + gcloud only, other users should use the default value. + + Returns: + Tuple[str, Optional[str], Optional[datetime], Mapping[str, str], str]: The + access token, new refresh token, expiration, the additional data + returned by the token endpoint, and the rapt token. + + Raises: + google.auth.exceptions.RefreshError: If the token endpoint returned + an error. + """ + body = { + "grant_type": _client._REFRESH_GRANT_TYPE, + "client_id": client_id, + "client_secret": client_secret, + "refresh_token": refresh_token, + } + if scopes: + body["scope"] = " ".join(scopes) + if rapt_token: + body["rapt"] = rapt_token + metrics_header = {metrics.API_CLIENT_HEADER: metrics.token_request_user()} + + response_status_ok, response_data, retryable_error = _client._token_endpoint_request_no_throw( + request, token_uri, body, headers=metrics_header + ) + + if not response_status_ok and isinstance(response_data, str): + raise exceptions.RefreshError(response_data, retryable=False) + + if ( + not response_status_ok + and response_data.get("error") == _REAUTH_NEEDED_ERROR + and ( + response_data.get("error_subtype") == _REAUTH_NEEDED_ERROR_INVALID_RAPT + or response_data.get("error_subtype") == _REAUTH_NEEDED_ERROR_RAPT_REQUIRED + ) + ): + if not enable_reauth_refresh: + raise exceptions.RefreshError( + "Reauthentication is needed. Please run `gcloud auth application-default login` to reauthenticate." + ) + + rapt_token = get_rapt_token( + request, client_id, client_secret, refresh_token, token_uri, scopes=scopes + ) + body["rapt"] = rapt_token + ( + response_status_ok, + response_data, + retryable_error, + ) = _client._token_endpoint_request_no_throw( + request, token_uri, body, headers=metrics_header + ) + + if not response_status_ok: + _client._handle_error_response(response_data, retryable_error) + return _client._handle_refresh_grant_response(response_data, refresh_token) + ( + rapt_token, + ) diff --git a/lib/python3.10/site-packages/google/oauth2/service_account.py b/lib/python3.10/site-packages/google/oauth2/service_account.py new file mode 100644 index 0000000000000000000000000000000000000000..3e84194ac7d10a12e27c916adc20fa93532c9835 --- /dev/null +++ b/lib/python3.10/site-packages/google/oauth2/service_account.py @@ -0,0 +1,847 @@ +# Copyright 2016 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Service Accounts: JSON Web Token (JWT) Profile for OAuth 2.0 + +This module implements the JWT Profile for OAuth 2.0 Authorization Grants +as defined by `RFC 7523`_ with particular support for how this RFC is +implemented in Google's infrastructure. Google refers to these credentials +as *Service Accounts*. + +Service accounts are used for server-to-server communication, such as +interactions between a web application server and a Google service. The +service account belongs to your application instead of to an individual end +user. In contrast to other OAuth 2.0 profiles, no users are involved and your +application "acts" as the service account. + +Typically an application uses a service account when the application uses +Google APIs to work with its own data rather than a user's data. For example, +an application that uses Google Cloud Datastore for data persistence would use +a service account to authenticate its calls to the Google Cloud Datastore API. +However, an application that needs to access a user's Drive documents would +use the normal OAuth 2.0 profile. + +Additionally, Google Apps domain administrators can grant service accounts +`domain-wide delegation`_ authority to access user data on behalf of users in +the domain. + +This profile uses a JWT to acquire an OAuth 2.0 access token. The JWT is used +in place of the usual authorization token returned during the standard +OAuth 2.0 Authorization Code grant. The JWT is only used for this purpose, as +the acquired access token is used as the bearer token when making requests +using these credentials. + +This profile differs from normal OAuth 2.0 profile because no user consent +step is required. The use of the private key allows this profile to assert +identity directly. + +This profile also differs from the :mod:`google.auth.jwt` authentication +because the JWT credentials use the JWT directly as the bearer token. This +profile instead only uses the JWT to obtain an OAuth 2.0 access token. The +obtained OAuth 2.0 access token is used as the bearer token. + +Domain-wide delegation +---------------------- + +Domain-wide delegation allows a service account to access user data on +behalf of any user in a Google Apps domain without consent from the user. +For example, an application that uses the Google Calendar API to add events to +the calendars of all users in a Google Apps domain would use a service account +to access the Google Calendar API on behalf of users. + +The Google Apps administrator must explicitly authorize the service account to +do this. This authorization step is referred to as "delegating domain-wide +authority" to a service account. + +You can use domain-wise delegation by creating a set of credentials with a +specific subject using :meth:`~Credentials.with_subject`. + +.. _RFC 7523: https://tools.ietf.org/html/rfc7523 +""" + +import copy +import datetime + +from google.auth import _helpers +from google.auth import _service_account_info +from google.auth import credentials +from google.auth import exceptions +from google.auth import iam +from google.auth import jwt +from google.auth import metrics +from google.oauth2 import _client + +_DEFAULT_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds +_GOOGLE_OAUTH2_TOKEN_ENDPOINT = "https://oauth2.googleapis.com/token" + + +class Credentials( + credentials.Signing, + credentials.Scoped, + credentials.CredentialsWithQuotaProject, + credentials.CredentialsWithTokenUri, +): + """Service account credentials + + Usually, you'll create these credentials with one of the helper + constructors. To create credentials using a Google service account + private key JSON file:: + + credentials = service_account.Credentials.from_service_account_file( + 'service-account.json') + + Or if you already have the service account file loaded:: + + service_account_info = json.load(open('service_account.json')) + credentials = service_account.Credentials.from_service_account_info( + service_account_info) + + Both helper methods pass on arguments to the constructor, so you can + specify additional scopes and a subject if necessary:: + + credentials = service_account.Credentials.from_service_account_file( + 'service-account.json', + scopes=['email'], + subject='user@example.com') + + The credentials are considered immutable. If you want to modify the scopes + or the subject used for delegation, use :meth:`with_scopes` or + :meth:`with_subject`:: + + scoped_credentials = credentials.with_scopes(['email']) + delegated_credentials = credentials.with_subject(subject) + + To add a quota project, use :meth:`with_quota_project`:: + + credentials = credentials.with_quota_project('myproject-123') + """ + + def __init__( + self, + signer, + service_account_email, + token_uri, + scopes=None, + default_scopes=None, + subject=None, + project_id=None, + quota_project_id=None, + additional_claims=None, + always_use_jwt_access=False, + universe_domain=credentials.DEFAULT_UNIVERSE_DOMAIN, + trust_boundary=None, + ): + """ + Args: + signer (google.auth.crypt.Signer): The signer used to sign JWTs. + service_account_email (str): The service account's email. + scopes (Sequence[str]): User-defined scopes to request during the + authorization grant. + default_scopes (Sequence[str]): Default scopes passed by a + Google client library. Use 'scopes' for user-defined scopes. + token_uri (str): The OAuth 2.0 Token URI. + subject (str): For domain-wide delegation, the email address of the + user to for which to request delegated access. + project_id (str): Project ID associated with the service account + credential. + quota_project_id (Optional[str]): The project ID used for quota and + billing. + additional_claims (Mapping[str, str]): Any additional claims for + the JWT assertion used in the authorization grant. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be always used. + universe_domain (str): The universe domain. The default + universe domain is googleapis.com. For default value self + signed jwt is used for token refresh. + trust_boundary (str): String representation of trust boundary meta. + + .. note:: Typically one of the helper constructors + :meth:`from_service_account_file` or + :meth:`from_service_account_info` are used instead of calling the + constructor directly. + """ + super(Credentials, self).__init__() + + self._cred_file_path = None + self._scopes = scopes + self._default_scopes = default_scopes + self._signer = signer + self._service_account_email = service_account_email + self._subject = subject + self._project_id = project_id + self._quota_project_id = quota_project_id + self._token_uri = token_uri + self._always_use_jwt_access = always_use_jwt_access + self._universe_domain = universe_domain or credentials.DEFAULT_UNIVERSE_DOMAIN + + if universe_domain != credentials.DEFAULT_UNIVERSE_DOMAIN: + self._always_use_jwt_access = True + + self._jwt_credentials = None + + if additional_claims is not None: + self._additional_claims = additional_claims + else: + self._additional_claims = {} + self._trust_boundary = {"locations": [], "encoded_locations": "0x0"} + + @classmethod + def _from_signer_and_info(cls, signer, info, **kwargs): + """Creates a Credentials instance from a signer and service account + info. + + Args: + signer (google.auth.crypt.Signer): The signer used to sign JWTs. + info (Mapping[str, str]): The service account info. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.jwt.Credentials: The constructed credentials. + + Raises: + ValueError: If the info is not in the expected format. + """ + return cls( + signer, + service_account_email=info["client_email"], + token_uri=info["token_uri"], + project_id=info.get("project_id"), + universe_domain=info.get( + "universe_domain", credentials.DEFAULT_UNIVERSE_DOMAIN + ), + trust_boundary=info.get("trust_boundary"), + **kwargs, + ) + + @classmethod + def from_service_account_info(cls, info, **kwargs): + """Creates a Credentials instance from parsed service account info. + + Args: + info (Mapping[str, str]): The service account info in Google + format. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.service_account.Credentials: The constructed + credentials. + + Raises: + ValueError: If the info is not in the expected format. + """ + signer = _service_account_info.from_dict( + info, require=["client_email", "token_uri"] + ) + return cls._from_signer_and_info(signer, info, **kwargs) + + @classmethod + def from_service_account_file(cls, filename, **kwargs): + """Creates a Credentials instance from a service account json file. + + Args: + filename (str): The path to the service account json file. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.service_account.Credentials: The constructed + credentials. + """ + info, signer = _service_account_info.from_filename( + filename, require=["client_email", "token_uri"] + ) + return cls._from_signer_and_info(signer, info, **kwargs) + + @property + def service_account_email(self): + """The service account email.""" + return self._service_account_email + + @property + def project_id(self): + """Project ID associated with this credential.""" + return self._project_id + + @property + def requires_scopes(self): + """Checks if the credentials requires scopes. + + Returns: + bool: True if there are no scopes set otherwise False. + """ + return True if not self._scopes else False + + def _make_copy(self): + cred = self.__class__( + self._signer, + service_account_email=self._service_account_email, + scopes=copy.copy(self._scopes), + default_scopes=copy.copy(self._default_scopes), + token_uri=self._token_uri, + subject=self._subject, + project_id=self._project_id, + quota_project_id=self._quota_project_id, + additional_claims=self._additional_claims.copy(), + always_use_jwt_access=self._always_use_jwt_access, + universe_domain=self._universe_domain, + ) + cred._cred_file_path = self._cred_file_path + return cred + + @_helpers.copy_docstring(credentials.Scoped) + def with_scopes(self, scopes, default_scopes=None): + cred = self._make_copy() + cred._scopes = scopes + cred._default_scopes = default_scopes + return cred + + def with_always_use_jwt_access(self, always_use_jwt_access): + """Create a copy of these credentials with the specified always_use_jwt_access value. + + Args: + always_use_jwt_access (bool): Whether always use self signed JWT or not. + + Returns: + google.auth.service_account.Credentials: A new credentials + instance. + Raises: + google.auth.exceptions.InvalidValue: If the universe domain is not + default and always_use_jwt_access is False. + """ + cred = self._make_copy() + if ( + cred._universe_domain != credentials.DEFAULT_UNIVERSE_DOMAIN + and not always_use_jwt_access + ): + raise exceptions.InvalidValue( + "always_use_jwt_access should be True for non-default universe domain" + ) + cred._always_use_jwt_access = always_use_jwt_access + return cred + + @_helpers.copy_docstring(credentials.CredentialsWithUniverseDomain) + def with_universe_domain(self, universe_domain): + cred = self._make_copy() + cred._universe_domain = universe_domain + if universe_domain != credentials.DEFAULT_UNIVERSE_DOMAIN: + cred._always_use_jwt_access = True + return cred + + def with_subject(self, subject): + """Create a copy of these credentials with the specified subject. + + Args: + subject (str): The subject claim. + + Returns: + google.auth.service_account.Credentials: A new credentials + instance. + """ + cred = self._make_copy() + cred._subject = subject + return cred + + def with_claims(self, additional_claims): + """Returns a copy of these credentials with modified claims. + + Args: + additional_claims (Mapping[str, str]): Any additional claims for + the JWT payload. This will be merged with the current + additional claims. + + Returns: + google.auth.service_account.Credentials: A new credentials + instance. + """ + new_additional_claims = copy.deepcopy(self._additional_claims) + new_additional_claims.update(additional_claims or {}) + cred = self._make_copy() + cred._additional_claims = new_additional_claims + return cred + + @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) + def with_quota_project(self, quota_project_id): + cred = self._make_copy() + cred._quota_project_id = quota_project_id + return cred + + @_helpers.copy_docstring(credentials.CredentialsWithTokenUri) + def with_token_uri(self, token_uri): + cred = self._make_copy() + cred._token_uri = token_uri + return cred + + def _make_authorization_grant_assertion(self): + """Create the OAuth 2.0 assertion. + + This assertion is used during the OAuth 2.0 grant to acquire an + access token. + + Returns: + bytes: The authorization grant assertion. + """ + now = _helpers.utcnow() + lifetime = datetime.timedelta(seconds=_DEFAULT_TOKEN_LIFETIME_SECS) + expiry = now + lifetime + + payload = { + "iat": _helpers.datetime_to_secs(now), + "exp": _helpers.datetime_to_secs(expiry), + # The issuer must be the service account email. + "iss": self._service_account_email, + # The audience must be the auth token endpoint's URI + "aud": _GOOGLE_OAUTH2_TOKEN_ENDPOINT, + "scope": _helpers.scopes_to_string(self._scopes or ()), + } + + payload.update(self._additional_claims) + + # The subject can be a user email for domain-wide delegation. + if self._subject: + payload.setdefault("sub", self._subject) + + token = jwt.encode(self._signer, payload) + + return token + + def _use_self_signed_jwt(self): + # Since domain wide delegation doesn't work with self signed JWT. If + # subject exists, then we should not use self signed JWT. + return self._subject is None and self._jwt_credentials is not None + + def _metric_header_for_usage(self): + if self._use_self_signed_jwt(): + return metrics.CRED_TYPE_SA_JWT + return metrics.CRED_TYPE_SA_ASSERTION + + @_helpers.copy_docstring(credentials.Credentials) + def refresh(self, request): + if self._always_use_jwt_access and not self._jwt_credentials: + # If self signed jwt should be used but jwt credential is not + # created, try to create one with scopes + self._create_self_signed_jwt(None) + + if ( + self._universe_domain != credentials.DEFAULT_UNIVERSE_DOMAIN + and self._subject + ): + raise exceptions.RefreshError( + "domain wide delegation is not supported for non-default universe domain" + ) + + if self._use_self_signed_jwt(): + self._jwt_credentials.refresh(request) + self.token = self._jwt_credentials.token.decode() + self.expiry = self._jwt_credentials.expiry + else: + assertion = self._make_authorization_grant_assertion() + access_token, expiry, _ = _client.jwt_grant( + request, self._token_uri, assertion + ) + self.token = access_token + self.expiry = expiry + + def _create_self_signed_jwt(self, audience): + """Create a self-signed JWT from the credentials if requirements are met. + + Args: + audience (str): The service URL. ``https://[API_ENDPOINT]/`` + """ + # https://google.aip.dev/auth/4111 + if self._always_use_jwt_access: + if self._scopes: + additional_claims = {"scope": " ".join(self._scopes)} + if ( + self._jwt_credentials is None + or self._jwt_credentials.additional_claims != additional_claims + ): + self._jwt_credentials = jwt.Credentials.from_signing_credentials( + self, None, additional_claims=additional_claims + ) + elif audience: + if ( + self._jwt_credentials is None + or self._jwt_credentials._audience != audience + ): + + self._jwt_credentials = jwt.Credentials.from_signing_credentials( + self, audience + ) + elif self._default_scopes: + additional_claims = {"scope": " ".join(self._default_scopes)} + if ( + self._jwt_credentials is None + or additional_claims != self._jwt_credentials.additional_claims + ): + self._jwt_credentials = jwt.Credentials.from_signing_credentials( + self, None, additional_claims=additional_claims + ) + elif not self._scopes and audience: + self._jwt_credentials = jwt.Credentials.from_signing_credentials( + self, audience + ) + + @_helpers.copy_docstring(credentials.Signing) + def sign_bytes(self, message): + return self._signer.sign(message) + + @property # type: ignore + @_helpers.copy_docstring(credentials.Signing) + def signer(self): + return self._signer + + @property # type: ignore + @_helpers.copy_docstring(credentials.Signing) + def signer_email(self): + return self._service_account_email + + @_helpers.copy_docstring(credentials.Credentials) + def get_cred_info(self): + if self._cred_file_path: + return { + "credential_source": self._cred_file_path, + "credential_type": "service account credentials", + "principal": self.service_account_email, + } + return None + + +class IDTokenCredentials( + credentials.Signing, + credentials.CredentialsWithQuotaProject, + credentials.CredentialsWithTokenUri, +): + """Open ID Connect ID Token-based service account credentials. + + These credentials are largely similar to :class:`.Credentials`, but instead + of using an OAuth 2.0 Access Token as the bearer token, they use an Open + ID Connect ID Token as the bearer token. These credentials are useful when + communicating to services that require ID Tokens and can not accept access + tokens. + + Usually, you'll create these credentials with one of the helper + constructors. To create credentials using a Google service account + private key JSON file:: + + credentials = ( + service_account.IDTokenCredentials.from_service_account_file( + 'service-account.json')) + + + Or if you already have the service account file loaded:: + + service_account_info = json.load(open('service_account.json')) + credentials = ( + service_account.IDTokenCredentials.from_service_account_info( + service_account_info)) + + + Both helper methods pass on arguments to the constructor, so you can + specify additional scopes and a subject if necessary:: + + credentials = ( + service_account.IDTokenCredentials.from_service_account_file( + 'service-account.json', + scopes=['email'], + subject='user@example.com')) + + + The credentials are considered immutable. If you want to modify the scopes + or the subject used for delegation, use :meth:`with_scopes` or + :meth:`with_subject`:: + + scoped_credentials = credentials.with_scopes(['email']) + delegated_credentials = credentials.with_subject(subject) + + """ + + def __init__( + self, + signer, + service_account_email, + token_uri, + target_audience, + additional_claims=None, + quota_project_id=None, + universe_domain=credentials.DEFAULT_UNIVERSE_DOMAIN, + ): + """ + Args: + signer (google.auth.crypt.Signer): The signer used to sign JWTs. + service_account_email (str): The service account's email. + token_uri (str): The OAuth 2.0 Token URI. + target_audience (str): The intended audience for these credentials, + used when requesting the ID Token. The ID Token's ``aud`` claim + will be set to this string. + additional_claims (Mapping[str, str]): Any additional claims for + the JWT assertion used in the authorization grant. + quota_project_id (Optional[str]): The project ID used for quota and billing. + universe_domain (str): The universe domain. The default + universe domain is googleapis.com. For default value IAM ID + token endponint is used for token refresh. Note that + iam.serviceAccountTokenCreator role is required to use the IAM + endpoint. + .. note:: Typically one of the helper constructors + :meth:`from_service_account_file` or + :meth:`from_service_account_info` are used instead of calling the + constructor directly. + """ + super(IDTokenCredentials, self).__init__() + self._signer = signer + self._service_account_email = service_account_email + self._token_uri = token_uri + self._target_audience = target_audience + self._quota_project_id = quota_project_id + self._use_iam_endpoint = False + + if not universe_domain: + self._universe_domain = credentials.DEFAULT_UNIVERSE_DOMAIN + else: + self._universe_domain = universe_domain + self._iam_id_token_endpoint = iam._IAM_IDTOKEN_ENDPOINT.replace( + "googleapis.com", self._universe_domain + ) + + if self._universe_domain != credentials.DEFAULT_UNIVERSE_DOMAIN: + self._use_iam_endpoint = True + + if additional_claims is not None: + self._additional_claims = additional_claims + else: + self._additional_claims = {} + + @classmethod + def _from_signer_and_info(cls, signer, info, **kwargs): + """Creates a credentials instance from a signer and service account + info. + + Args: + signer (google.auth.crypt.Signer): The signer used to sign JWTs. + info (Mapping[str, str]): The service account info. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.jwt.IDTokenCredentials: The constructed credentials. + + Raises: + ValueError: If the info is not in the expected format. + """ + kwargs.setdefault("service_account_email", info["client_email"]) + kwargs.setdefault("token_uri", info["token_uri"]) + if "universe_domain" in info: + kwargs["universe_domain"] = info["universe_domain"] + return cls(signer, **kwargs) + + @classmethod + def from_service_account_info(cls, info, **kwargs): + """Creates a credentials instance from parsed service account info. + + Args: + info (Mapping[str, str]): The service account info in Google + format. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.service_account.IDTokenCredentials: The constructed + credentials. + + Raises: + ValueError: If the info is not in the expected format. + """ + signer = _service_account_info.from_dict( + info, require=["client_email", "token_uri"] + ) + return cls._from_signer_and_info(signer, info, **kwargs) + + @classmethod + def from_service_account_file(cls, filename, **kwargs): + """Creates a credentials instance from a service account json file. + + Args: + filename (str): The path to the service account json file. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.service_account.IDTokenCredentials: The constructed + credentials. + """ + info, signer = _service_account_info.from_filename( + filename, require=["client_email", "token_uri"] + ) + return cls._from_signer_and_info(signer, info, **kwargs) + + def _make_copy(self): + cred = self.__class__( + self._signer, + service_account_email=self._service_account_email, + token_uri=self._token_uri, + target_audience=self._target_audience, + additional_claims=self._additional_claims.copy(), + quota_project_id=self.quota_project_id, + universe_domain=self._universe_domain, + ) + # _use_iam_endpoint is not exposed in the constructor + cred._use_iam_endpoint = self._use_iam_endpoint + return cred + + def with_target_audience(self, target_audience): + """Create a copy of these credentials with the specified target + audience. + + Args: + target_audience (str): The intended audience for these credentials, + used when requesting the ID Token. + + Returns: + google.auth.service_account.IDTokenCredentials: A new credentials + instance. + """ + cred = self._make_copy() + cred._target_audience = target_audience + return cred + + def _with_use_iam_endpoint(self, use_iam_endpoint): + """Create a copy of these credentials with the use_iam_endpoint value. + + Args: + use_iam_endpoint (bool): If True, IAM generateIdToken endpoint will + be used instead of the token_uri. Note that + iam.serviceAccountTokenCreator role is required to use the IAM + endpoint. The default value is False. This feature is currently + experimental and subject to change without notice. + + Returns: + google.auth.service_account.IDTokenCredentials: A new credentials + instance. + Raises: + google.auth.exceptions.InvalidValue: If the universe domain is not + default and use_iam_endpoint is False. + """ + cred = self._make_copy() + if ( + cred._universe_domain != credentials.DEFAULT_UNIVERSE_DOMAIN + and not use_iam_endpoint + ): + raise exceptions.InvalidValue( + "use_iam_endpoint should be True for non-default universe domain" + ) + cred._use_iam_endpoint = use_iam_endpoint + return cred + + @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) + def with_quota_project(self, quota_project_id): + cred = self._make_copy() + cred._quota_project_id = quota_project_id + return cred + + @_helpers.copy_docstring(credentials.CredentialsWithTokenUri) + def with_token_uri(self, token_uri): + cred = self._make_copy() + cred._token_uri = token_uri + return cred + + def _make_authorization_grant_assertion(self): + """Create the OAuth 2.0 assertion. + + This assertion is used during the OAuth 2.0 grant to acquire an + ID token. + + Returns: + bytes: The authorization grant assertion. + """ + now = _helpers.utcnow() + lifetime = datetime.timedelta(seconds=_DEFAULT_TOKEN_LIFETIME_SECS) + expiry = now + lifetime + + payload = { + "iat": _helpers.datetime_to_secs(now), + "exp": _helpers.datetime_to_secs(expiry), + # The issuer must be the service account email. + "iss": self.service_account_email, + # The audience must be the auth token endpoint's URI + "aud": _GOOGLE_OAUTH2_TOKEN_ENDPOINT, + # The target audience specifies which service the ID token is + # intended for. + "target_audience": self._target_audience, + } + + payload.update(self._additional_claims) + + token = jwt.encode(self._signer, payload) + + return token + + def _refresh_with_iam_endpoint(self, request): + """Use IAM generateIdToken endpoint to obtain an ID token. + + It works as follows: + + 1. First we create a self signed jwt with + https://www.googleapis.com/auth/iam being the scope. + + 2. Next we use the self signed jwt as the access token, and make a POST + request to IAM generateIdToken endpoint. The request body is: + { + "audience": self._target_audience, + "includeEmail": "true", + "useEmailAzp": "true", + } + + If the request is succesfully, it will return {"token":"the ID token"}, + and we can extract the ID token and compute its expiry. + """ + jwt_credentials = jwt.Credentials.from_signing_credentials( + self, + None, + additional_claims={"scope": "https://www.googleapis.com/auth/iam"}, + ) + jwt_credentials.refresh(request) + self.token, self.expiry = _client.call_iam_generate_id_token_endpoint( + request, + self._iam_id_token_endpoint, + self.signer_email, + self._target_audience, + jwt_credentials.token.decode(), + self._universe_domain, + ) + + @_helpers.copy_docstring(credentials.Credentials) + def refresh(self, request): + if self._use_iam_endpoint: + self._refresh_with_iam_endpoint(request) + else: + assertion = self._make_authorization_grant_assertion() + access_token, expiry, _ = _client.id_token_jwt_grant( + request, self._token_uri, assertion + ) + self.token = access_token + self.expiry = expiry + + @property + def service_account_email(self): + """The service account email.""" + return self._service_account_email + + @_helpers.copy_docstring(credentials.Signing) + def sign_bytes(self, message): + return self._signer.sign(message) + + @property # type: ignore + @_helpers.copy_docstring(credentials.Signing) + def signer(self): + return self._signer + + @property # type: ignore + @_helpers.copy_docstring(credentials.Signing) + def signer_email(self): + return self._service_account_email diff --git a/lib/python3.10/site-packages/google/oauth2/sts.py b/lib/python3.10/site-packages/google/oauth2/sts.py new file mode 100644 index 0000000000000000000000000000000000000000..ad3962735f7ad9103a92871fc856133490930c56 --- /dev/null +++ b/lib/python3.10/site-packages/google/oauth2/sts.py @@ -0,0 +1,176 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OAuth 2.0 Token Exchange Spec. + +This module defines a token exchange utility based on the `OAuth 2.0 Token +Exchange`_ spec. This will be mainly used to exchange external credentials +for GCP access tokens in workload identity pools to access Google APIs. + +The implementation will support various types of client authentication as +allowed in the spec. + +A deviation on the spec will be for additional Google specific options that +cannot be easily mapped to parameters defined in the RFC. + +The returned dictionary response will be based on the `rfc8693 section 2.2.1`_ +spec JSON response. + +.. _OAuth 2.0 Token Exchange: https://tools.ietf.org/html/rfc8693 +.. _rfc8693 section 2.2.1: https://tools.ietf.org/html/rfc8693#section-2.2.1 +""" + +import http.client as http_client +import json +import urllib + +from google.oauth2 import utils + + +_URLENCODED_HEADERS = {"Content-Type": "application/x-www-form-urlencoded"} + + +class Client(utils.OAuthClientAuthHandler): + """Implements the OAuth 2.0 token exchange spec based on + https://tools.ietf.org/html/rfc8693. + """ + + def __init__(self, token_exchange_endpoint, client_authentication=None): + """Initializes an STS client instance. + + Args: + token_exchange_endpoint (str): The token exchange endpoint. + client_authentication (Optional(google.oauth2.oauth2_utils.ClientAuthentication)): + The optional OAuth client authentication credentials if available. + """ + super(Client, self).__init__(client_authentication) + self._token_exchange_endpoint = token_exchange_endpoint + + def _make_request(self, request, headers, request_body): + # Initialize request headers. + request_headers = _URLENCODED_HEADERS.copy() + + # Inject additional headers. + if headers: + for k, v in dict(headers).items(): + request_headers[k] = v + + # Apply OAuth client authentication. + self.apply_client_authentication_options(request_headers, request_body) + + # Execute request. + response = request( + url=self._token_exchange_endpoint, + method="POST", + headers=request_headers, + body=urllib.parse.urlencode(request_body).encode("utf-8"), + ) + + response_body = ( + response.data.decode("utf-8") + if hasattr(response.data, "decode") + else response.data + ) + + # If non-200 response received, translate to OAuthError exception. + if response.status != http_client.OK: + utils.handle_error_response(response_body) + + response_data = json.loads(response_body) + + # Return successful response. + return response_data + + def exchange_token( + self, + request, + grant_type, + subject_token, + subject_token_type, + resource=None, + audience=None, + scopes=None, + requested_token_type=None, + actor_token=None, + actor_token_type=None, + additional_options=None, + additional_headers=None, + ): + """Exchanges the provided token for another type of token based on the + rfc8693 spec. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + grant_type (str): The OAuth 2.0 token exchange grant type. + subject_token (str): The OAuth 2.0 token exchange subject token. + subject_token_type (str): The OAuth 2.0 token exchange subject token type. + resource (Optional[str]): The optional OAuth 2.0 token exchange resource field. + audience (Optional[str]): The optional OAuth 2.0 token exchange audience field. + scopes (Optional[Sequence[str]]): The optional list of scopes to use. + requested_token_type (Optional[str]): The optional OAuth 2.0 token exchange requested + token type. + actor_token (Optional[str]): The optional OAuth 2.0 token exchange actor token. + actor_token_type (Optional[str]): The optional OAuth 2.0 token exchange actor token type. + additional_options (Optional[Mapping[str, str]]): The optional additional + non-standard Google specific options. + additional_headers (Optional[Mapping[str, str]]): The optional additional + headers to pass to the token exchange endpoint. + + Returns: + Mapping[str, str]: The token exchange JSON-decoded response data containing + the requested token and its expiration time. + + Raises: + google.auth.exceptions.OAuthError: If the token endpoint returned + an error. + """ + # Initialize request body. + request_body = { + "grant_type": grant_type, + "resource": resource, + "audience": audience, + "scope": " ".join(scopes or []), + "requested_token_type": requested_token_type, + "subject_token": subject_token, + "subject_token_type": subject_token_type, + "actor_token": actor_token, + "actor_token_type": actor_token_type, + "options": None, + } + # Add additional non-standard options. + if additional_options: + request_body["options"] = urllib.parse.quote(json.dumps(additional_options)) + # Remove empty fields in request body. + for k, v in dict(request_body).items(): + if v is None or v == "": + del request_body[k] + + return self._make_request(request, additional_headers, request_body) + + def refresh_token(self, request, refresh_token): + """Exchanges a refresh token for an access token based on the + RFC6749 spec. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + subject_token (str): The OAuth 2.0 refresh token. + """ + + return self._make_request( + request, + None, + {"grant_type": "refresh_token", "refresh_token": refresh_token}, + ) diff --git a/lib/python3.10/site-packages/google/oauth2/utils.py b/lib/python3.10/site-packages/google/oauth2/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d72ff1916631f17773d63244eac0179b9109c8d1 --- /dev/null +++ b/lib/python3.10/site-packages/google/oauth2/utils.py @@ -0,0 +1,168 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OAuth 2.0 Utilities. + +This module provides implementations for various OAuth 2.0 utilities. +This includes `OAuth error handling`_ and +`Client authentication for OAuth flows`_. + +OAuth error handling +-------------------- +This will define interfaces for handling OAuth related error responses as +stated in `RFC 6749 section 5.2`_. +This will include a common function to convert these HTTP error responses to a +:class:`google.auth.exceptions.OAuthError` exception. + + +Client authentication for OAuth flows +------------------------------------- +We introduce an interface for defining client authentication credentials based +on `RFC 6749 section 2.3.1`_. This will expose the following +capabilities: + + * Ability to support basic authentication via request header. + * Ability to support bearer token authentication via request header. + * Ability to support client ID / secret authentication via request body. + +.. _RFC 6749 section 2.3.1: https://tools.ietf.org/html/rfc6749#section-2.3.1 +.. _RFC 6749 section 5.2: https://tools.ietf.org/html/rfc6749#section-5.2 +""" + +import abc +import base64 +import enum +import json + +from google.auth import exceptions + + +# OAuth client authentication based on +# https://tools.ietf.org/html/rfc6749#section-2.3. +class ClientAuthType(enum.Enum): + basic = 1 + request_body = 2 + + +class ClientAuthentication(object): + """Defines the client authentication credentials for basic and request-body + types based on https://tools.ietf.org/html/rfc6749#section-2.3.1. + """ + + def __init__(self, client_auth_type, client_id, client_secret=None): + """Instantiates a client authentication object containing the client ID + and secret credentials for basic and response-body auth. + + Args: + client_auth_type (google.oauth2.oauth_utils.ClientAuthType): The + client authentication type. + client_id (str): The client ID. + client_secret (Optional[str]): The client secret. + """ + self.client_auth_type = client_auth_type + self.client_id = client_id + self.client_secret = client_secret + + +class OAuthClientAuthHandler(metaclass=abc.ABCMeta): + """Abstract class for handling client authentication in OAuth-based + operations. + """ + + def __init__(self, client_authentication=None): + """Instantiates an OAuth client authentication handler. + + Args: + client_authentication (Optional[google.oauth2.utils.ClientAuthentication]): + The OAuth client authentication credentials if available. + """ + super(OAuthClientAuthHandler, self).__init__() + self._client_authentication = client_authentication + + def apply_client_authentication_options( + self, headers, request_body=None, bearer_token=None + ): + """Applies client authentication on the OAuth request's headers or POST + body. + + Args: + headers (Mapping[str, str]): The HTTP request header. + request_body (Optional[Mapping[str, str]]): The HTTP request body + dictionary. For requests that do not support request body, this + is None and will be ignored. + bearer_token (Optional[str]): The optional bearer token. + """ + # Inject authenticated header. + self._inject_authenticated_headers(headers, bearer_token) + # Inject authenticated request body. + if bearer_token is None: + self._inject_authenticated_request_body(request_body) + + def _inject_authenticated_headers(self, headers, bearer_token=None): + if bearer_token is not None: + headers["Authorization"] = "Bearer %s" % bearer_token + elif ( + self._client_authentication is not None + and self._client_authentication.client_auth_type is ClientAuthType.basic + ): + username = self._client_authentication.client_id + password = self._client_authentication.client_secret or "" + + credentials = base64.b64encode( + ("%s:%s" % (username, password)).encode() + ).decode() + headers["Authorization"] = "Basic %s" % credentials + + def _inject_authenticated_request_body(self, request_body): + if ( + self._client_authentication is not None + and self._client_authentication.client_auth_type + is ClientAuthType.request_body + ): + if request_body is None: + raise exceptions.OAuthError( + "HTTP request does not support request-body" + ) + else: + request_body["client_id"] = self._client_authentication.client_id + request_body["client_secret"] = ( + self._client_authentication.client_secret or "" + ) + + +def handle_error_response(response_body): + """Translates an error response from an OAuth operation into an + OAuthError exception. + + Args: + response_body (str): The decoded response data. + + Raises: + google.auth.exceptions.OAuthError + """ + try: + error_components = [] + error_data = json.loads(response_body) + + error_components.append("Error code {}".format(error_data["error"])) + if "error_description" in error_data: + error_components.append(": {}".format(error_data["error_description"])) + if "error_uri" in error_data: + error_components.append(" - {}".format(error_data["error_uri"])) + error_details = "".join(error_components) + # If no details could be extracted, use the response data. + except (KeyError, ValueError): + error_details = response_body + + raise exceptions.OAuthError(error_details, response_body) diff --git a/lib/python3.10/site-packages/google/oauth2/webauthn_handler.py b/lib/python3.10/site-packages/google/oauth2/webauthn_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..e27c7e0990051455a405883c307f819b73d1af31 --- /dev/null +++ b/lib/python3.10/site-packages/google/oauth2/webauthn_handler.py @@ -0,0 +1,82 @@ +import abc +import os +import struct +import subprocess + +from google.auth import exceptions +from google.oauth2.webauthn_types import GetRequest, GetResponse + + +class WebAuthnHandler(abc.ABC): + @abc.abstractmethod + def is_available(self) -> bool: + """Check whether this WebAuthn handler is available""" + raise NotImplementedError("is_available method must be implemented") + + @abc.abstractmethod + def get(self, get_request: GetRequest) -> GetResponse: + """WebAuthn get (assertion)""" + raise NotImplementedError("get method must be implemented") + + +class PluginHandler(WebAuthnHandler): + """Offloads WebAuthn get reqeust to a pluggable command-line tool. + + Offloads WebAuthn get to a plugin which takes the form of a + command-line tool. The command-line tool is configurable via the + PluginHandler._ENV_VAR environment variable. + + The WebAuthn plugin should implement the following interface: + + Communication occurs over stdin/stdout, and messages are both sent and + received in the form: + + [4 bytes - payload size (little-endian)][variable bytes - json payload] + """ + + _ENV_VAR = "GOOGLE_AUTH_WEBAUTHN_PLUGIN" + + def is_available(self) -> bool: + try: + self._find_plugin() + except Exception: + return False + else: + return True + + def get(self, get_request: GetRequest) -> GetResponse: + request_json = get_request.to_json() + cmd = self._find_plugin() + response_json = self._call_plugin(cmd, request_json) + return GetResponse.from_json(response_json) + + def _call_plugin(self, cmd: str, input_json: str) -> str: + # Calculate length of input + input_length = len(input_json) + length_bytes_le = struct.pack(" str: + plugin_cmd = os.environ.get(PluginHandler._ENV_VAR) + if plugin_cmd is None: + raise exceptions.InvalidResource( + "{} env var is not set".format(PluginHandler._ENV_VAR) + ) + return plugin_cmd diff --git a/lib/python3.10/site-packages/google/oauth2/webauthn_handler_factory.py b/lib/python3.10/site-packages/google/oauth2/webauthn_handler_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..184329fed7e921cff1003a91ddd4c9819618a065 --- /dev/null +++ b/lib/python3.10/site-packages/google/oauth2/webauthn_handler_factory.py @@ -0,0 +1,16 @@ +from typing import List, Optional + +from google.oauth2.webauthn_handler import PluginHandler, WebAuthnHandler + + +class WebauthnHandlerFactory: + handlers: List[WebAuthnHandler] + + def __init__(self): + self.handlers = [PluginHandler()] + + def get_handler(self) -> Optional[WebAuthnHandler]: + for handler in self.handlers: + if handler.is_available(): + return handler + return None diff --git a/lib/python3.10/site-packages/google/oauth2/webauthn_types.py b/lib/python3.10/site-packages/google/oauth2/webauthn_types.py new file mode 100644 index 0000000000000000000000000000000000000000..24e984f3d33636259b60819fd372207267fd6ce4 --- /dev/null +++ b/lib/python3.10/site-packages/google/oauth2/webauthn_types.py @@ -0,0 +1,156 @@ +from dataclasses import dataclass +import json +from typing import Any, Dict, List, Optional + +from google.auth import exceptions + + +@dataclass(frozen=True) +class PublicKeyCredentialDescriptor: + """Descriptor for a security key based credential. + + https://www.w3.org/TR/webauthn-3/#dictionary-credential-descriptor + + Args: + id: credential id (key handle). + transports: <'usb'|'nfc'|'ble'|'internal'> List of supported transports. + """ + + id: str + transports: Optional[List[str]] = None + + def to_dict(self): + cred = {"type": "public-key", "id": self.id} + if self.transports: + cred["transports"] = self.transports + return cred + + +@dataclass +class AuthenticationExtensionsClientInputs: + """Client extensions inputs for WebAuthn extensions. + + Args: + appid: app id that can be asserted with in addition to rpid. + https://www.w3.org/TR/webauthn-3/#sctn-appid-extension + """ + + appid: Optional[str] = None + + def to_dict(self): + extensions = {} + if self.appid: + extensions["appid"] = self.appid + return extensions + + +@dataclass +class GetRequest: + """WebAuthn get request + + Args: + origin: Origin where the WebAuthn get assertion takes place. + rpid: Relying Party ID. + challenge: raw challenge. + timeout_ms: Timeout number in millisecond. + allow_credentials: List of allowed credentials. + user_verification: <'required'|'preferred'|'discouraged'> User verification requirement. + extensions: WebAuthn authentication extensions inputs. + """ + + origin: str + rpid: str + challenge: str + timeout_ms: Optional[int] = None + allow_credentials: Optional[List[PublicKeyCredentialDescriptor]] = None + user_verification: Optional[str] = None + extensions: Optional[AuthenticationExtensionsClientInputs] = None + + def to_json(self) -> str: + req_options: Dict[str, Any] = {"rpId": self.rpid, "challenge": self.challenge} + if self.timeout_ms: + req_options["timeout"] = self.timeout_ms + if self.allow_credentials: + req_options["allowCredentials"] = [ + c.to_dict() for c in self.allow_credentials + ] + if self.user_verification: + req_options["userVerification"] = self.user_verification + if self.extensions: + req_options["extensions"] = self.extensions.to_dict() + return json.dumps( + {"type": "get", "origin": self.origin, "requestData": req_options} + ) + + +@dataclass(frozen=True) +class AuthenticatorAssertionResponse: + """Authenticator response to a WebAuthn get (assertion) request. + + https://www.w3.org/TR/webauthn-3/#authenticatorassertionresponse + + Args: + client_data_json: client data JSON. + authenticator_data: authenticator data. + signature: signature. + user_handle: user handle. + """ + + client_data_json: str + authenticator_data: str + signature: str + user_handle: Optional[str] + + +@dataclass(frozen=True) +class GetResponse: + """WebAuthn get (assertion) response. + + Args: + id: credential id (key handle). + response: The authenticator assertion response. + authenticator_attachment: <'cross-platform'|'platform'> The attachment status of the authenticator. + client_extension_results: WebAuthn authentication extensions output results in a dictionary. + """ + + id: str + response: AuthenticatorAssertionResponse + authenticator_attachment: Optional[str] + client_extension_results: Optional[Dict] + + @staticmethod + def from_json(json_str: str): + """Verify and construct GetResponse from a JSON string.""" + try: + resp_json = json.loads(json_str) + except ValueError: + raise exceptions.MalformedError("Invalid Get JSON response") + if resp_json.get("type") != "getResponse": + raise exceptions.MalformedError( + "Invalid Get response type: {}".format(resp_json.get("type")) + ) + pk_cred = resp_json.get("responseData") + if pk_cred is None: + if resp_json.get("error"): + raise exceptions.ReauthFailError( + "WebAuthn.get failure: {}".format(resp_json["error"]) + ) + else: + raise exceptions.MalformedError("Get response is empty") + if pk_cred.get("type") != "public-key": + raise exceptions.MalformedError( + "Invalid credential type: {}".format(pk_cred.get("type")) + ) + assertion_json = pk_cred["response"] + assertion_resp = AuthenticatorAssertionResponse( + client_data_json=assertion_json["clientDataJSON"], + authenticator_data=assertion_json["authenticatorData"], + signature=assertion_json["signature"], + user_handle=assertion_json.get("userHandle"), + ) + return GetResponse( + id=pk_cred["id"], + response=assertion_resp, + authenticator_attachment=pk_cred.get("authenticatorAttachment"), + client_extension_results=pk_cred.get("clientExtensionResults"), + ) diff --git a/lib/python3.10/site-packages/google/protobuf/compiler/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/google/protobuf/compiler/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66c28536a61664ea68cda8bab84c348fd6163ced Binary files /dev/null and b/lib/python3.10/site-packages/google/protobuf/compiler/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/protobuf/compiler/__pycache__/plugin_pb2.cpython-310.pyc b/lib/python3.10/site-packages/google/protobuf/compiler/__pycache__/plugin_pb2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..485f1c46243760df7e9daff52e3709604b97a5eb Binary files /dev/null and b/lib/python3.10/site-packages/google/protobuf/compiler/__pycache__/plugin_pb2.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/protobuf/internal/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/google/protobuf/internal/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12d343701167a002cdb13573db669fd2df79cec1 Binary files /dev/null and b/lib/python3.10/site-packages/google/protobuf/internal/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/protobuf/internal/__pycache__/testing_refleaks.cpython-310.pyc b/lib/python3.10/site-packages/google/protobuf/internal/__pycache__/testing_refleaks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ddee84bc0e1ee551b808a1db40a5a205a082173d Binary files /dev/null and b/lib/python3.10/site-packages/google/protobuf/internal/__pycache__/testing_refleaks.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/protobuf/internal/__pycache__/type_checkers.cpython-310.pyc b/lib/python3.10/site-packages/google/protobuf/internal/__pycache__/type_checkers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7cdce3f12a80e914b7df824214792a5ecbbe0cb Binary files /dev/null and b/lib/python3.10/site-packages/google/protobuf/internal/__pycache__/type_checkers.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/protobuf/internal/__pycache__/well_known_types.cpython-310.pyc b/lib/python3.10/site-packages/google/protobuf/internal/__pycache__/well_known_types.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91d1f71e8f3a962eafecb959075249aef1cfdbb1 Binary files /dev/null and b/lib/python3.10/site-packages/google/protobuf/internal/__pycache__/well_known_types.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/protobuf/internal/__pycache__/wire_format.cpython-310.pyc b/lib/python3.10/site-packages/google/protobuf/internal/__pycache__/wire_format.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90d07c254fa1ba5f77080d3046574eb5f8d28c44 Binary files /dev/null and b/lib/python3.10/site-packages/google/protobuf/internal/__pycache__/wire_format.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/protobuf/pyext/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/google/protobuf/pyext/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2e5e5b97abb5243db30357ffcd08df03c110557 Binary files /dev/null and b/lib/python3.10/site-packages/google/protobuf/pyext/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/protobuf/pyext/__pycache__/cpp_message.cpython-310.pyc b/lib/python3.10/site-packages/google/protobuf/pyext/__pycache__/cpp_message.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..906c655051bba073c3270dfe28f06279ab48f83f Binary files /dev/null and b/lib/python3.10/site-packages/google/protobuf/pyext/__pycache__/cpp_message.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/protobuf/testdata/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/google/protobuf/testdata/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12ff554b51bb1bd9f80908a7e6fead5e5dd24f9c Binary files /dev/null and b/lib/python3.10/site-packages/google/protobuf/testdata/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/google/protobuf/util/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/google/protobuf/util/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7eddda46bea20a3c9f93ea8855fe93392f731f9c Binary files /dev/null and b/lib/python3.10/site-packages/google/protobuf/util/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b792694723e9b40d4261c7f72733d3fbe13da15 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/asteroidal.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/asteroidal.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6ca3ee8a2602faa9686456b0170298a894ad7ca Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/asteroidal.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/boundary.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/boundary.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b6818ebffa3c9064adae47ef1f8c6858cd19d09 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/boundary.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/bridges.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/bridges.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21db887b526fdbcd082cbad0499a416ba02a89c8 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/bridges.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/broadcasting.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/broadcasting.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb1f2552b57ea4b4495522426e01a7c0250fc212 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/broadcasting.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/chains.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/chains.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57c3466f10cba3dca904e41aabd160f4913fadd1 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/chains.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/chordal.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/chordal.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cea0194b912c542cda48d26d5cf0b15a9f806997 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/chordal.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/clique.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/clique.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..664ec6f9715b54e521a9cadd5eed4b5a4cd251b5 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/clique.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/cluster.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/cluster.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..231919c1089d7a126372ad7efcf89aa723c62d73 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/cluster.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/communicability_alg.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/communicability_alg.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20368b795d88e418766875c2814a88ce44e70775 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/communicability_alg.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/core.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/core.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..795e534cc19ac7ea40d2dfab733474f5f1f0315f Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/core.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/covering.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/covering.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8fd165ed029fcb4530d17c36436d734446eaf8d2 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/covering.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/cuts.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/cuts.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5110ca3b2d4c512da0768df95cc210cc73fe2716 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/cuts.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/cycles.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/cycles.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9785c3e6aeb47187585fb78be97b7f36b38ddfcb Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/cycles.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/d_separation.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/d_separation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c70b3a6a5cd51dc3cf45c39711aa9444e9102d5 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/d_separation.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/dag.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/dag.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87c7df34c8a9ebaecd8c236dc0c1c5c2deb7c4a6 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/dag.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/distance_measures.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/distance_measures.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..acc16b5622e2b1c7a120089a334bf2840fd81f0b Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/distance_measures.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/distance_regular.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/distance_regular.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..466f0b0cfcb3d6cfb55c4b63c24073bab58e2e51 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/distance_regular.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/dominance.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/dominance.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbf1fdc00139a2c6edebbe93447ecb4817bfcc24 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/dominance.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/dominating.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/dominating.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36622edc7faa53c06d429734cd7fcf5739f0a241 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/dominating.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/efficiency_measures.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/efficiency_measures.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0f52545891b2e7a2e863cd95234a2788e7733bd Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/efficiency_measures.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/euler.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/euler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df64ba86ad204c3975196fbcde07b770d232e305 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/euler.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/graph_hashing.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/graph_hashing.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..def87b03699c34be041d07d21708775861810cf5 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/graph_hashing.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/graphical.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/graphical.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2614bb33c5eefc753be567035a95c4b2805b01f9 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/graphical.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/hierarchy.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/hierarchy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a8620de428215082b7f6020a3caf614b72e2f31 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/hierarchy.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/hybrid.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/hybrid.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb9328b6caaf9601a70fc7f183300a15bd3ea1a9 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/hybrid.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/isolate.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/isolate.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10ce691e1be2becb5fd135250faee11291d3ed45 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/isolate.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/link_prediction.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/link_prediction.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d9826aaf9d14877801c0cf69d2c0067c5a75e71 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/link_prediction.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/lowest_common_ancestors.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/lowest_common_ancestors.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee8c7f03c0824b634058d19de259c371077abb0b Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/lowest_common_ancestors.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/matching.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/matching.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43518de259761a6c800f46470996804b786c2f85 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/matching.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/mis.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/mis.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab2aabeaec2292d198b24d14ab48f79da9e78a6a Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/mis.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/moral.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/moral.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..904477ef0ac8575a91484e121fc3814946c87f7c Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/moral.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/node_classification.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/node_classification.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ec9fe21330db67d2ca339d4ee19f7bb43bc9f79 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/node_classification.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/non_randomness.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/non_randomness.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99f5c78f2d66ecaec49194d49e25151444281ee9 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/non_randomness.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/planar_drawing.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/planar_drawing.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..242946fd87a974c779c46aa542d90521f3391e6f Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/planar_drawing.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/planarity.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/planarity.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..894110b5f13509f99ce5bef9d157324d31b14d3d Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/planarity.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/polynomials.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/polynomials.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49f68bbc05416458afa9fd24b99218dfc4f2cd8f Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/polynomials.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/reciprocity.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/reciprocity.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2430dec6e088d0bd80530da8b9b5daf8a942452 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/reciprocity.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/regular.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/regular.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62e0cf9281f903529077db719b3bae93817219ca Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/regular.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/richclub.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/richclub.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a030dbbaa0b34731d4b593d0894d43211df85a77 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/richclub.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/similarity.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/similarity.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f677a4aeca4b55f66c70979d97700246652adb4f Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/similarity.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/simple_paths.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/simple_paths.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0cae7bd0e90e903b814a9a662d6f0ca8f638955 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/simple_paths.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/smallworld.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/smallworld.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b0ee3109aa7812e8ba9bbcd04484a7e7ff27f6e Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/smallworld.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/smetric.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/smetric.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..954be1f3d7bb7e884e04b8334cb528244d21556d Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/smetric.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/sparsifiers.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/sparsifiers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55f96978937e00dfeaa4542e391d5fd1a61580a7 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/sparsifiers.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/structuralholes.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/structuralholes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7a975cffa01392cd0349fd09b8d1e1e14ce7565 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/structuralholes.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/summarization.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/summarization.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39fd9e43ab52392277e2c5a3c2108fc113501434 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/summarization.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/swap.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/swap.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..192c67d5e4acefcea5935d8a9843d444462b9bc7 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/swap.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/threshold.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/threshold.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fef5a98124b97ea6f158badb4e5b5a186043d85d Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/threshold.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/time_dependent.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/time_dependent.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..685ec6504ded18775438304038a39af4a42fcbd8 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/time_dependent.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/tournament.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/tournament.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a047becf04e283457cc355e9dfd674d769701d9e Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/tournament.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/triads.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/triads.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed4515f035de6c814d41d865f9f488bcaa781e6c Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/triads.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/vitality.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/vitality.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..667da215362537109d2eef577e6a2196c1df4eb1 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/vitality.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/voronoi.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/voronoi.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0df4a74ff831c4c0dd518b0979e0a95d554ba481 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/voronoi.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/walks.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/walks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98da495d639388937a21b72971462077a6b7c789 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/walks.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/__pycache__/wiener.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/wiener.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50eedd60d3634526c4db9aaba944341188cc5b88 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/__pycache__/wiener.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/approximation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d12a8f8b8e4a633675971c727d959222db8ab79d --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/approximation/__init__.py @@ -0,0 +1,25 @@ +"""Approximations of graph properties and Heuristic methods for optimization. + +The functions in this class are not imported into the top-level ``networkx`` +namespace so the easiest way to use them is with:: + + >>> from networkx.algorithms import approximation + +Another option is to import the specific function with +``from networkx.algorithms.approximation import function_name``. + +""" + +from networkx.algorithms.approximation.clustering_coefficient import * +from networkx.algorithms.approximation.clique import * +from networkx.algorithms.approximation.connectivity import * +from networkx.algorithms.approximation.distance_measures import * +from networkx.algorithms.approximation.dominating_set import * +from networkx.algorithms.approximation.kcomponents import * +from networkx.algorithms.approximation.matching import * +from networkx.algorithms.approximation.ramsey import * +from networkx.algorithms.approximation.steinertree import * +from networkx.algorithms.approximation.traveling_salesman import * +from networkx.algorithms.approximation.treewidth import * +from networkx.algorithms.approximation.vertex_cover import * +from networkx.algorithms.approximation.maxcut import * diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84bf15b39250a62f4279ef40efc9219aecf9cc98 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/clique.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/clique.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ab156057f16de446622052d892337a234b80bff Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/clique.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/clustering_coefficient.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/clustering_coefficient.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b151e3572ce829881cd2069f3811eeeb5e54a92 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/clustering_coefficient.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/connectivity.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/connectivity.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca67e2465a47eaba64a8729051d76676191a97e4 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/connectivity.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/distance_measures.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/distance_measures.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4ccff176eaeb211ab099e07a139cc032398eb6a Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/distance_measures.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/dominating_set.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/dominating_set.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..edd1832217289bfc9fd476f1140627fad6c12122 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/dominating_set.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/kcomponents.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/kcomponents.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a94a7c5345a916a327c5693db4d26c7305f5576 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/kcomponents.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/matching.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/matching.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53c97ea77cf77d09938265896f5952d07c03c8ca Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/matching.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/maxcut.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/maxcut.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c441f73d2989e4201de35c639dcf19e790208a6e Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/maxcut.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/ramsey.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/ramsey.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e640e8594cb093dd032a151fd32c3f2431714796 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/ramsey.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/steinertree.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/steinertree.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a01e86f078b11eb8fa66890cd3a8bc994e96f45 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/steinertree.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/traveling_salesman.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/traveling_salesman.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1483ea83efc5abb51a7aed29ada77f73ba159ea5 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/traveling_salesman.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/treewidth.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/treewidth.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30fe787f773345958d3ad86256dec2adcd03ac4c Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/treewidth.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/vertex_cover.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/vertex_cover.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d49aff3c0f509a96bf514c02f89b712b46d205d Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/approximation/__pycache__/vertex_cover.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/clique.py b/lib/python3.10/site-packages/networkx/algorithms/approximation/clique.py new file mode 100644 index 0000000000000000000000000000000000000000..ed0f3506369046c749d118c4264afeb4f054f2cd --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/approximation/clique.py @@ -0,0 +1,259 @@ +"""Functions for computing large cliques and maximum independent sets.""" + +import networkx as nx +from networkx.algorithms.approximation import ramsey +from networkx.utils import not_implemented_for + +__all__ = [ + "clique_removal", + "max_clique", + "large_clique_size", + "maximum_independent_set", +] + + +@not_implemented_for("directed") +@not_implemented_for("multigraph") +@nx._dispatchable +def maximum_independent_set(G): + """Returns an approximate maximum independent set. + + Independent set or stable set is a set of vertices in a graph, no two of + which are adjacent. That is, it is a set I of vertices such that for every + two vertices in I, there is no edge connecting the two. Equivalently, each + edge in the graph has at most one endpoint in I. The size of an independent + set is the number of vertices it contains [1]_. + + A maximum independent set is a largest independent set for a given graph G + and its size is denoted $\\alpha(G)$. The problem of finding such a set is called + the maximum independent set problem and is an NP-hard optimization problem. + As such, it is unlikely that there exists an efficient algorithm for finding + a maximum independent set of a graph. + + The Independent Set algorithm is based on [2]_. + + Parameters + ---------- + G : NetworkX graph + Undirected graph + + Returns + ------- + iset : Set + The apx-maximum independent set + + Examples + -------- + >>> G = nx.path_graph(10) + >>> nx.approximation.maximum_independent_set(G) + {0, 2, 4, 6, 9} + + Raises + ------ + NetworkXNotImplemented + If the graph is directed or is a multigraph. + + Notes + ----- + Finds the $O(|V|/(log|V|)^2)$ apx of independent set in the worst case. + + References + ---------- + .. [1] `Wikipedia: Independent set + `_ + .. [2] Boppana, R., & Halldórsson, M. M. (1992). + Approximating maximum independent sets by excluding subgraphs. + BIT Numerical Mathematics, 32(2), 180–196. Springer. + """ + iset, _ = clique_removal(G) + return iset + + +@not_implemented_for("directed") +@not_implemented_for("multigraph") +@nx._dispatchable +def max_clique(G): + r"""Find the Maximum Clique + + Finds the $O(|V|/(log|V|)^2)$ apx of maximum clique/independent set + in the worst case. + + Parameters + ---------- + G : NetworkX graph + Undirected graph + + Returns + ------- + clique : set + The apx-maximum clique of the graph + + Examples + -------- + >>> G = nx.path_graph(10) + >>> nx.approximation.max_clique(G) + {8, 9} + + Raises + ------ + NetworkXNotImplemented + If the graph is directed or is a multigraph. + + Notes + ----- + A clique in an undirected graph G = (V, E) is a subset of the vertex set + `C \subseteq V` such that for every two vertices in C there exists an edge + connecting the two. This is equivalent to saying that the subgraph + induced by C is complete (in some cases, the term clique may also refer + to the subgraph). + + A maximum clique is a clique of the largest possible size in a given graph. + The clique number `\omega(G)` of a graph G is the number of + vertices in a maximum clique in G. The intersection number of + G is the smallest number of cliques that together cover all edges of G. + + https://en.wikipedia.org/wiki/Maximum_clique + + References + ---------- + .. [1] Boppana, R., & Halldórsson, M. M. (1992). + Approximating maximum independent sets by excluding subgraphs. + BIT Numerical Mathematics, 32(2), 180–196. Springer. + doi:10.1007/BF01994876 + """ + # finding the maximum clique in a graph is equivalent to finding + # the independent set in the complementary graph + cgraph = nx.complement(G) + iset, _ = clique_removal(cgraph) + return iset + + +@not_implemented_for("directed") +@not_implemented_for("multigraph") +@nx._dispatchable +def clique_removal(G): + r"""Repeatedly remove cliques from the graph. + + Results in a $O(|V|/(\log |V|)^2)$ approximation of maximum clique + and independent set. Returns the largest independent set found, along + with found maximal cliques. + + Parameters + ---------- + G : NetworkX graph + Undirected graph + + Returns + ------- + max_ind_cliques : (set, list) tuple + 2-tuple of Maximal Independent Set and list of maximal cliques (sets). + + Examples + -------- + >>> G = nx.path_graph(10) + >>> nx.approximation.clique_removal(G) + ({0, 2, 4, 6, 9}, [{0, 1}, {2, 3}, {4, 5}, {6, 7}, {8, 9}]) + + Raises + ------ + NetworkXNotImplemented + If the graph is directed or is a multigraph. + + References + ---------- + .. [1] Boppana, R., & Halldórsson, M. M. (1992). + Approximating maximum independent sets by excluding subgraphs. + BIT Numerical Mathematics, 32(2), 180–196. Springer. + """ + graph = G.copy() + c_i, i_i = ramsey.ramsey_R2(graph) + cliques = [c_i] + isets = [i_i] + while graph: + graph.remove_nodes_from(c_i) + c_i, i_i = ramsey.ramsey_R2(graph) + if c_i: + cliques.append(c_i) + if i_i: + isets.append(i_i) + # Determine the largest independent set as measured by cardinality. + maxiset = max(isets, key=len) + return maxiset, cliques + + +@not_implemented_for("directed") +@not_implemented_for("multigraph") +@nx._dispatchable +def large_clique_size(G): + """Find the size of a large clique in a graph. + + A *clique* is a subset of nodes in which each pair of nodes is + adjacent. This function is a heuristic for finding the size of a + large clique in the graph. + + Parameters + ---------- + G : NetworkX graph + + Returns + ------- + k: integer + The size of a large clique in the graph. + + Examples + -------- + >>> G = nx.path_graph(10) + >>> nx.approximation.large_clique_size(G) + 2 + + Raises + ------ + NetworkXNotImplemented + If the graph is directed or is a multigraph. + + Notes + ----- + This implementation is from [1]_. Its worst case time complexity is + :math:`O(n d^2)`, where *n* is the number of nodes in the graph and + *d* is the maximum degree. + + This function is a heuristic, which means it may work well in + practice, but there is no rigorous mathematical guarantee on the + ratio between the returned number and the actual largest clique size + in the graph. + + References + ---------- + .. [1] Pattabiraman, Bharath, et al. + "Fast Algorithms for the Maximum Clique Problem on Massive Graphs + with Applications to Overlapping Community Detection." + *Internet Mathematics* 11.4-5 (2015): 421--448. + + + See also + -------- + + :func:`networkx.algorithms.approximation.clique.max_clique` + A function that returns an approximate maximum clique with a + guarantee on the approximation ratio. + + :mod:`networkx.algorithms.clique` + Functions for finding the exact maximum clique in a graph. + + """ + degrees = G.degree + + def _clique_heuristic(G, U, size, best_size): + if not U: + return max(best_size, size) + u = max(U, key=degrees) + U.remove(u) + N_prime = {v for v in G[u] if degrees[v] >= best_size} + return _clique_heuristic(G, U & N_prime, size + 1, best_size) + + best_size = 0 + nodes = (u for u in G if degrees[u] >= best_size) + for u in nodes: + neighbors = {v for v in G[u] if degrees[v] >= best_size} + best_size = _clique_heuristic(G, neighbors, 1, best_size) + return best_size diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/clustering_coefficient.py b/lib/python3.10/site-packages/networkx/algorithms/approximation/clustering_coefficient.py new file mode 100644 index 0000000000000000000000000000000000000000..545fc65533b8d8f44b35498aa7129c97efc0bc52 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/approximation/clustering_coefficient.py @@ -0,0 +1,71 @@ +import networkx as nx +from networkx.utils import not_implemented_for, py_random_state + +__all__ = ["average_clustering"] + + +@not_implemented_for("directed") +@py_random_state(2) +@nx._dispatchable(name="approximate_average_clustering") +def average_clustering(G, trials=1000, seed=None): + r"""Estimates the average clustering coefficient of G. + + The local clustering of each node in `G` is the fraction of triangles + that actually exist over all possible triangles in its neighborhood. + The average clustering coefficient of a graph `G` is the mean of + local clusterings. + + This function finds an approximate average clustering coefficient + for G by repeating `n` times (defined in `trials`) the following + experiment: choose a node at random, choose two of its neighbors + at random, and check if they are connected. The approximate + coefficient is the fraction of triangles found over the number + of trials [1]_. + + Parameters + ---------- + G : NetworkX graph + + trials : integer + Number of trials to perform (default 1000). + + seed : integer, random_state, or None (default) + Indicator of random number generation state. + See :ref:`Randomness`. + + Returns + ------- + c : float + Approximated average clustering coefficient. + + Examples + -------- + >>> from networkx.algorithms import approximation + >>> G = nx.erdos_renyi_graph(10, 0.2, seed=10) + >>> approximation.average_clustering(G, trials=1000, seed=10) + 0.214 + + Raises + ------ + NetworkXNotImplemented + If G is directed. + + References + ---------- + .. [1] Schank, Thomas, and Dorothea Wagner. Approximating clustering + coefficient and transitivity. Universität Karlsruhe, Fakultät für + Informatik, 2004. + https://doi.org/10.5445/IR/1000001239 + + """ + n = len(G) + triangles = 0 + nodes = list(G) + for i in [int(seed.random() * n) for i in range(trials)]: + nbrs = list(G[nodes[i]]) + if len(nbrs) < 2: + continue + u, v = seed.sample(nbrs, 2) + if u in G[v]: + triangles += 1 + return triangles / trials diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/connectivity.py b/lib/python3.10/site-packages/networkx/algorithms/approximation/connectivity.py new file mode 100644 index 0000000000000000000000000000000000000000..0b596fdf782bbb223e5203fc066bbac157a29b24 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/approximation/connectivity.py @@ -0,0 +1,412 @@ +"""Fast approximation for node connectivity""" + +import itertools +from operator import itemgetter + +import networkx as nx + +__all__ = [ + "local_node_connectivity", + "node_connectivity", + "all_pairs_node_connectivity", +] + + +@nx._dispatchable(name="approximate_local_node_connectivity") +def local_node_connectivity(G, source, target, cutoff=None): + """Compute node connectivity between source and target. + + Pairwise or local node connectivity between two distinct and nonadjacent + nodes is the minimum number of nodes that must be removed (minimum + separating cutset) to disconnect them. By Menger's theorem, this is equal + to the number of node independent paths (paths that share no nodes other + than source and target). Which is what we compute in this function. + + This algorithm is a fast approximation that gives an strict lower + bound on the actual number of node independent paths between two nodes [1]_. + It works for both directed and undirected graphs. + + Parameters + ---------- + + G : NetworkX graph + + source : node + Starting node for node connectivity + + target : node + Ending node for node connectivity + + cutoff : integer + Maximum node connectivity to consider. If None, the minimum degree + of source or target is used as a cutoff. Default value None. + + Returns + ------- + k: integer + pairwise node connectivity + + Examples + -------- + >>> # Platonic octahedral graph has node connectivity 4 + >>> # for each non adjacent node pair + >>> from networkx.algorithms import approximation as approx + >>> G = nx.octahedral_graph() + >>> approx.local_node_connectivity(G, 0, 5) + 4 + + Notes + ----- + This algorithm [1]_ finds node independents paths between two nodes by + computing their shortest path using BFS, marking the nodes of the path + found as 'used' and then searching other shortest paths excluding the + nodes marked as used until no more paths exist. It is not exact because + a shortest path could use nodes that, if the path were longer, may belong + to two different node independent paths. Thus it only guarantees an + strict lower bound on node connectivity. + + Note that the authors propose a further refinement, losing accuracy and + gaining speed, which is not implemented yet. + + See also + -------- + all_pairs_node_connectivity + node_connectivity + + References + ---------- + .. [1] White, Douglas R., and Mark Newman. 2001 A Fast Algorithm for + Node-Independent Paths. Santa Fe Institute Working Paper #01-07-035 + http://eclectic.ss.uci.edu/~drwhite/working.pdf + + """ + if target == source: + raise nx.NetworkXError("source and target have to be different nodes.") + + # Maximum possible node independent paths + if G.is_directed(): + possible = min(G.out_degree(source), G.in_degree(target)) + else: + possible = min(G.degree(source), G.degree(target)) + + K = 0 + if not possible: + return K + + if cutoff is None: + cutoff = float("inf") + + exclude = set() + for i in range(min(possible, cutoff)): + try: + path = _bidirectional_shortest_path(G, source, target, exclude) + exclude.update(set(path)) + K += 1 + except nx.NetworkXNoPath: + break + + return K + + +@nx._dispatchable(name="approximate_node_connectivity") +def node_connectivity(G, s=None, t=None): + r"""Returns an approximation for node connectivity for a graph or digraph G. + + Node connectivity is equal to the minimum number of nodes that + must be removed to disconnect G or render it trivial. By Menger's theorem, + this is equal to the number of node independent paths (paths that + share no nodes other than source and target). + + If source and target nodes are provided, this function returns the + local node connectivity: the minimum number of nodes that must be + removed to break all paths from source to target in G. + + This algorithm is based on a fast approximation that gives an strict lower + bound on the actual number of node independent paths between two nodes [1]_. + It works for both directed and undirected graphs. + + Parameters + ---------- + G : NetworkX graph + Undirected graph + + s : node + Source node. Optional. Default value: None. + + t : node + Target node. Optional. Default value: None. + + Returns + ------- + K : integer + Node connectivity of G, or local node connectivity if source + and target are provided. + + Examples + -------- + >>> # Platonic octahedral graph is 4-node-connected + >>> from networkx.algorithms import approximation as approx + >>> G = nx.octahedral_graph() + >>> approx.node_connectivity(G) + 4 + + Notes + ----- + This algorithm [1]_ finds node independents paths between two nodes by + computing their shortest path using BFS, marking the nodes of the path + found as 'used' and then searching other shortest paths excluding the + nodes marked as used until no more paths exist. It is not exact because + a shortest path could use nodes that, if the path were longer, may belong + to two different node independent paths. Thus it only guarantees an + strict lower bound on node connectivity. + + See also + -------- + all_pairs_node_connectivity + local_node_connectivity + + References + ---------- + .. [1] White, Douglas R., and Mark Newman. 2001 A Fast Algorithm for + Node-Independent Paths. Santa Fe Institute Working Paper #01-07-035 + http://eclectic.ss.uci.edu/~drwhite/working.pdf + + """ + if (s is not None and t is None) or (s is None and t is not None): + raise nx.NetworkXError("Both source and target must be specified.") + + # Local node connectivity + if s is not None and t is not None: + if s not in G: + raise nx.NetworkXError(f"node {s} not in graph") + if t not in G: + raise nx.NetworkXError(f"node {t} not in graph") + return local_node_connectivity(G, s, t) + + # Global node connectivity + if G.is_directed(): + connected_func = nx.is_weakly_connected + iter_func = itertools.permutations + + def neighbors(v): + return itertools.chain(G.predecessors(v), G.successors(v)) + + else: + connected_func = nx.is_connected + iter_func = itertools.combinations + neighbors = G.neighbors + + if not connected_func(G): + return 0 + + # Choose a node with minimum degree + v, minimum_degree = min(G.degree(), key=itemgetter(1)) + # Node connectivity is bounded by minimum degree + K = minimum_degree + # compute local node connectivity with all non-neighbors nodes + # and store the minimum + for w in set(G) - set(neighbors(v)) - {v}: + K = min(K, local_node_connectivity(G, v, w, cutoff=K)) + # Same for non adjacent pairs of neighbors of v + for x, y in iter_func(neighbors(v), 2): + if y not in G[x] and x != y: + K = min(K, local_node_connectivity(G, x, y, cutoff=K)) + return K + + +@nx._dispatchable(name="approximate_all_pairs_node_connectivity") +def all_pairs_node_connectivity(G, nbunch=None, cutoff=None): + """Compute node connectivity between all pairs of nodes. + + Pairwise or local node connectivity between two distinct and nonadjacent + nodes is the minimum number of nodes that must be removed (minimum + separating cutset) to disconnect them. By Menger's theorem, this is equal + to the number of node independent paths (paths that share no nodes other + than source and target). Which is what we compute in this function. + + This algorithm is a fast approximation that gives an strict lower + bound on the actual number of node independent paths between two nodes [1]_. + It works for both directed and undirected graphs. + + + Parameters + ---------- + G : NetworkX graph + + nbunch: container + Container of nodes. If provided node connectivity will be computed + only over pairs of nodes in nbunch. + + cutoff : integer + Maximum node connectivity to consider. If None, the minimum degree + of source or target is used as a cutoff in each pair of nodes. + Default value None. + + Returns + ------- + K : dictionary + Dictionary, keyed by source and target, of pairwise node connectivity + + Examples + -------- + A 3 node cycle with one extra node attached has connectivity 2 between all + nodes in the cycle and connectivity 1 between the extra node and the rest: + + >>> G = nx.cycle_graph(3) + >>> G.add_edge(2, 3) + >>> import pprint # for nice dictionary formatting + >>> pprint.pprint(nx.all_pairs_node_connectivity(G)) + {0: {1: 2, 2: 2, 3: 1}, + 1: {0: 2, 2: 2, 3: 1}, + 2: {0: 2, 1: 2, 3: 1}, + 3: {0: 1, 1: 1, 2: 1}} + + See Also + -------- + local_node_connectivity + node_connectivity + + References + ---------- + .. [1] White, Douglas R., and Mark Newman. 2001 A Fast Algorithm for + Node-Independent Paths. Santa Fe Institute Working Paper #01-07-035 + http://eclectic.ss.uci.edu/~drwhite/working.pdf + """ + if nbunch is None: + nbunch = G + else: + nbunch = set(nbunch) + + directed = G.is_directed() + if directed: + iter_func = itertools.permutations + else: + iter_func = itertools.combinations + + all_pairs = {n: {} for n in nbunch} + + for u, v in iter_func(nbunch, 2): + k = local_node_connectivity(G, u, v, cutoff=cutoff) + all_pairs[u][v] = k + if not directed: + all_pairs[v][u] = k + + return all_pairs + + +def _bidirectional_shortest_path(G, source, target, exclude): + """Returns shortest path between source and target ignoring nodes in the + container 'exclude'. + + Parameters + ---------- + + G : NetworkX graph + + source : node + Starting node for path + + target : node + Ending node for path + + exclude: container + Container for nodes to exclude from the search for shortest paths + + Returns + ------- + path: list + Shortest path between source and target ignoring nodes in 'exclude' + + Raises + ------ + NetworkXNoPath + If there is no path or if nodes are adjacent and have only one path + between them + + Notes + ----- + This function and its helper are originally from + networkx.algorithms.shortest_paths.unweighted and are modified to + accept the extra parameter 'exclude', which is a container for nodes + already used in other paths that should be ignored. + + References + ---------- + .. [1] White, Douglas R., and Mark Newman. 2001 A Fast Algorithm for + Node-Independent Paths. Santa Fe Institute Working Paper #01-07-035 + http://eclectic.ss.uci.edu/~drwhite/working.pdf + + """ + # call helper to do the real work + results = _bidirectional_pred_succ(G, source, target, exclude) + pred, succ, w = results + + # build path from pred+w+succ + path = [] + # from source to w + while w is not None: + path.append(w) + w = pred[w] + path.reverse() + # from w to target + w = succ[path[-1]] + while w is not None: + path.append(w) + w = succ[w] + + return path + + +def _bidirectional_pred_succ(G, source, target, exclude): + # does BFS from both source and target and meets in the middle + # excludes nodes in the container "exclude" from the search + + # handle either directed or undirected + if G.is_directed(): + Gpred = G.predecessors + Gsucc = G.successors + else: + Gpred = G.neighbors + Gsucc = G.neighbors + + # predecessor and successors in search + pred = {source: None} + succ = {target: None} + + # initialize fringes, start with forward + forward_fringe = [source] + reverse_fringe = [target] + + level = 0 + + while forward_fringe and reverse_fringe: + # Make sure that we iterate one step forward and one step backwards + # thus source and target will only trigger "found path" when they are + # adjacent and then they can be safely included in the container 'exclude' + level += 1 + if level % 2 != 0: + this_level = forward_fringe + forward_fringe = [] + for v in this_level: + for w in Gsucc(v): + if w in exclude: + continue + if w not in pred: + forward_fringe.append(w) + pred[w] = v + if w in succ: + return pred, succ, w # found path + else: + this_level = reverse_fringe + reverse_fringe = [] + for v in this_level: + for w in Gpred(v): + if w in exclude: + continue + if w not in succ: + succ[w] = v + reverse_fringe.append(w) + if w in pred: + return pred, succ, w # found path + + raise nx.NetworkXNoPath(f"No path between {source} and {target}.") diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/distance_measures.py b/lib/python3.10/site-packages/networkx/algorithms/approximation/distance_measures.py new file mode 100644 index 0000000000000000000000000000000000000000..d5847e65a2a401cd607436297fe4c1bbc81db3d9 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/approximation/distance_measures.py @@ -0,0 +1,150 @@ +"""Distance measures approximated metrics.""" + +import networkx as nx +from networkx.utils.decorators import py_random_state + +__all__ = ["diameter"] + + +@py_random_state(1) +@nx._dispatchable(name="approximate_diameter") +def diameter(G, seed=None): + """Returns a lower bound on the diameter of the graph G. + + The function computes a lower bound on the diameter (i.e., the maximum eccentricity) + of a directed or undirected graph G. The procedure used varies depending on the graph + being directed or not. + + If G is an `undirected` graph, then the function uses the `2-sweep` algorithm [1]_. + The main idea is to pick the farthest node from a random node and return its eccentricity. + + Otherwise, if G is a `directed` graph, the function uses the `2-dSweep` algorithm [2]_, + The procedure starts by selecting a random source node $s$ from which it performs a + forward and a backward BFS. Let $a_1$ and $a_2$ be the farthest nodes in the forward and + backward cases, respectively. Then, it computes the backward eccentricity of $a_1$ using + a backward BFS and the forward eccentricity of $a_2$ using a forward BFS. + Finally, it returns the best lower bound between the two. + + In both cases, the time complexity is linear with respect to the size of G. + + Parameters + ---------- + G : NetworkX graph + + seed : integer, random_state, or None (default) + Indicator of random number generation state. + See :ref:`Randomness`. + + Returns + ------- + d : integer + Lower Bound on the Diameter of G + + Examples + -------- + >>> G = nx.path_graph(10) # undirected graph + >>> nx.diameter(G) + 9 + >>> G = nx.cycle_graph(3, create_using=nx.DiGraph) # directed graph + >>> nx.diameter(G) + 2 + + Raises + ------ + NetworkXError + If the graph is empty or + If the graph is undirected and not connected or + If the graph is directed and not strongly connected. + + See Also + -------- + networkx.algorithms.distance_measures.diameter + + References + ---------- + .. [1] Magnien, Clémence, Matthieu Latapy, and Michel Habib. + *Fast computation of empirically tight bounds for the diameter of massive graphs.* + Journal of Experimental Algorithmics (JEA), 2009. + https://arxiv.org/pdf/0904.2728.pdf + .. [2] Crescenzi, Pierluigi, Roberto Grossi, Leonardo Lanzi, and Andrea Marino. + *On computing the diameter of real-world directed (weighted) graphs.* + International Symposium on Experimental Algorithms. Springer, Berlin, Heidelberg, 2012. + https://courses.cs.ut.ee/MTAT.03.238/2014_fall/uploads/Main/diameter.pdf + """ + # if G is empty + if not G: + raise nx.NetworkXError("Expected non-empty NetworkX graph!") + # if there's only a node + if G.number_of_nodes() == 1: + return 0 + # if G is directed + if G.is_directed(): + return _two_sweep_directed(G, seed) + # else if G is undirected + return _two_sweep_undirected(G, seed) + + +def _two_sweep_undirected(G, seed): + """Helper function for finding a lower bound on the diameter + for undirected Graphs. + + The idea is to pick the farthest node from a random node + and return its eccentricity. + + ``G`` is a NetworkX undirected graph. + + .. note:: + + ``seed`` is a random.Random or numpy.random.RandomState instance + """ + # select a random source node + source = seed.choice(list(G)) + # get the distances to the other nodes + distances = nx.shortest_path_length(G, source) + # if some nodes have not been visited, then the graph is not connected + if len(distances) != len(G): + raise nx.NetworkXError("Graph not connected.") + # take a node that is (one of) the farthest nodes from the source + *_, node = distances + # return the eccentricity of the node + return nx.eccentricity(G, node) + + +def _two_sweep_directed(G, seed): + """Helper function for finding a lower bound on the diameter + for directed Graphs. + + It implements 2-dSweep, the directed version of the 2-sweep algorithm. + The algorithm follows the following steps. + 1. Select a source node $s$ at random. + 2. Perform a forward BFS from $s$ to select a node $a_1$ at the maximum + distance from the source, and compute $LB_1$, the backward eccentricity of $a_1$. + 3. Perform a backward BFS from $s$ to select a node $a_2$ at the maximum + distance from the source, and compute $LB_2$, the forward eccentricity of $a_2$. + 4. Return the maximum between $LB_1$ and $LB_2$. + + ``G`` is a NetworkX directed graph. + + .. note:: + + ``seed`` is a random.Random or numpy.random.RandomState instance + """ + # get a new digraph G' with the edges reversed in the opposite direction + G_reversed = G.reverse() + # select a random source node + source = seed.choice(list(G)) + # compute forward distances from source + forward_distances = nx.shortest_path_length(G, source) + # compute backward distances from source + backward_distances = nx.shortest_path_length(G_reversed, source) + # if either the source can't reach every node or not every node + # can reach the source, then the graph is not strongly connected + n = len(G) + if len(forward_distances) != n or len(backward_distances) != n: + raise nx.NetworkXError("DiGraph not strongly connected.") + # take a node a_1 at the maximum distance from the source in G + *_, a_1 = forward_distances + # take a node a_2 at the maximum distance from the source in G_reversed + *_, a_2 = backward_distances + # return the max between the backward eccentricity of a_1 and the forward eccentricity of a_2 + return max(nx.eccentricity(G_reversed, a_1), nx.eccentricity(G, a_2)) diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/dominating_set.py b/lib/python3.10/site-packages/networkx/algorithms/approximation/dominating_set.py new file mode 100644 index 0000000000000000000000000000000000000000..e568a827ff99dad5390e8d7feb886cf92a3a6cad --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/approximation/dominating_set.py @@ -0,0 +1,149 @@ +"""Functions for finding node and edge dominating sets. + +A `dominating set`_ for an undirected graph *G* with vertex set *V* +and edge set *E* is a subset *D* of *V* such that every vertex not in +*D* is adjacent to at least one member of *D*. An `edge dominating set`_ +is a subset *F* of *E* such that every edge not in *F* is +incident to an endpoint of at least one edge in *F*. + +.. _dominating set: https://en.wikipedia.org/wiki/Dominating_set +.. _edge dominating set: https://en.wikipedia.org/wiki/Edge_dominating_set + +""" + +import networkx as nx + +from ...utils import not_implemented_for +from ..matching import maximal_matching + +__all__ = ["min_weighted_dominating_set", "min_edge_dominating_set"] + + +# TODO Why doesn't this algorithm work for directed graphs? +@not_implemented_for("directed") +@nx._dispatchable(node_attrs="weight") +def min_weighted_dominating_set(G, weight=None): + r"""Returns a dominating set that approximates the minimum weight node + dominating set. + + Parameters + ---------- + G : NetworkX graph + Undirected graph. + + weight : string + The node attribute storing the weight of an node. If provided, + the node attribute with this key must be a number for each + node. If not provided, each node is assumed to have weight one. + + Returns + ------- + min_weight_dominating_set : set + A set of nodes, the sum of whose weights is no more than `(\log + w(V)) w(V^*)`, where `w(V)` denotes the sum of the weights of + each node in the graph and `w(V^*)` denotes the sum of the + weights of each node in the minimum weight dominating set. + + Examples + -------- + >>> G = nx.Graph([(0, 1), (0, 4), (1, 4), (1, 2), (2, 3), (3, 4), (2, 5)]) + >>> nx.approximation.min_weighted_dominating_set(G) + {1, 2, 4} + + Raises + ------ + NetworkXNotImplemented + If G is directed. + + Notes + ----- + This algorithm computes an approximate minimum weighted dominating + set for the graph `G`. The returned solution has weight `(\log + w(V)) w(V^*)`, where `w(V)` denotes the sum of the weights of each + node in the graph and `w(V^*)` denotes the sum of the weights of + each node in the minimum weight dominating set for the graph. + + This implementation of the algorithm runs in $O(m)$ time, where $m$ + is the number of edges in the graph. + + References + ---------- + .. [1] Vazirani, Vijay V. + *Approximation Algorithms*. + Springer Science & Business Media, 2001. + + """ + # The unique dominating set for the null graph is the empty set. + if len(G) == 0: + return set() + + # This is the dominating set that will eventually be returned. + dom_set = set() + + def _cost(node_and_neighborhood): + """Returns the cost-effectiveness of greedily choosing the given + node. + + `node_and_neighborhood` is a two-tuple comprising a node and its + closed neighborhood. + + """ + v, neighborhood = node_and_neighborhood + return G.nodes[v].get(weight, 1) / len(neighborhood - dom_set) + + # This is a set of all vertices not already covered by the + # dominating set. + vertices = set(G) + # This is a dictionary mapping each node to the closed neighborhood + # of that node. + neighborhoods = {v: {v} | set(G[v]) for v in G} + + # Continue until all vertices are adjacent to some node in the + # dominating set. + while vertices: + # Find the most cost-effective node to add, along with its + # closed neighborhood. + dom_node, min_set = min(neighborhoods.items(), key=_cost) + # Add the node to the dominating set and reduce the remaining + # set of nodes to cover. + dom_set.add(dom_node) + del neighborhoods[dom_node] + vertices -= min_set + + return dom_set + + +@nx._dispatchable +def min_edge_dominating_set(G): + r"""Returns minimum cardinality edge dominating set. + + Parameters + ---------- + G : NetworkX graph + Undirected graph + + Returns + ------- + min_edge_dominating_set : set + Returns a set of dominating edges whose size is no more than 2 * OPT. + + Examples + -------- + >>> G = nx.petersen_graph() + >>> nx.approximation.min_edge_dominating_set(G) + {(0, 1), (4, 9), (6, 8), (5, 7), (2, 3)} + + Raises + ------ + ValueError + If the input graph `G` is empty. + + Notes + ----- + The algorithm computes an approximate solution to the edge dominating set + problem. The result is no more than 2 * OPT in terms of size of the set. + Runtime of the algorithm is $O(|E|)$. + """ + if not G: + raise ValueError("Expected non-empty NetworkX graph!") + return maximal_matching(G) diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/kcomponents.py b/lib/python3.10/site-packages/networkx/algorithms/approximation/kcomponents.py new file mode 100644 index 0000000000000000000000000000000000000000..f726a4e686103b2d30d443f21de326685c615bf4 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/approximation/kcomponents.py @@ -0,0 +1,369 @@ +"""Fast approximation for k-component structure""" + +import itertools +from collections import defaultdict +from collections.abc import Mapping +from functools import cached_property + +import networkx as nx +from networkx.algorithms.approximation import local_node_connectivity +from networkx.exception import NetworkXError +from networkx.utils import not_implemented_for + +__all__ = ["k_components"] + + +@not_implemented_for("directed") +@nx._dispatchable(name="approximate_k_components") +def k_components(G, min_density=0.95): + r"""Returns the approximate k-component structure of a graph G. + + A `k`-component is a maximal subgraph of a graph G that has, at least, + node connectivity `k`: we need to remove at least `k` nodes to break it + into more components. `k`-components have an inherent hierarchical + structure because they are nested in terms of connectivity: a connected + graph can contain several 2-components, each of which can contain + one or more 3-components, and so forth. + + This implementation is based on the fast heuristics to approximate + the `k`-component structure of a graph [1]_. Which, in turn, it is based on + a fast approximation algorithm for finding good lower bounds of the number + of node independent paths between two nodes [2]_. + + Parameters + ---------- + G : NetworkX graph + Undirected graph + + min_density : Float + Density relaxation threshold. Default value 0.95 + + Returns + ------- + k_components : dict + Dictionary with connectivity level `k` as key and a list of + sets of nodes that form a k-component of level `k` as values. + + Raises + ------ + NetworkXNotImplemented + If G is directed. + + Examples + -------- + >>> # Petersen graph has 10 nodes and it is triconnected, thus all + >>> # nodes are in a single component on all three connectivity levels + >>> from networkx.algorithms import approximation as apxa + >>> G = nx.petersen_graph() + >>> k_components = apxa.k_components(G) + + Notes + ----- + The logic of the approximation algorithm for computing the `k`-component + structure [1]_ is based on repeatedly applying simple and fast algorithms + for `k`-cores and biconnected components in order to narrow down the + number of pairs of nodes over which we have to compute White and Newman's + approximation algorithm for finding node independent paths [2]_. More + formally, this algorithm is based on Whitney's theorem, which states + an inclusion relation among node connectivity, edge connectivity, and + minimum degree for any graph G. This theorem implies that every + `k`-component is nested inside a `k`-edge-component, which in turn, + is contained in a `k`-core. Thus, this algorithm computes node independent + paths among pairs of nodes in each biconnected part of each `k`-core, + and repeats this procedure for each `k` from 3 to the maximal core number + of a node in the input graph. + + Because, in practice, many nodes of the core of level `k` inside a + bicomponent actually are part of a component of level k, the auxiliary + graph needed for the algorithm is likely to be very dense. Thus, we use + a complement graph data structure (see `AntiGraph`) to save memory. + AntiGraph only stores information of the edges that are *not* present + in the actual auxiliary graph. When applying algorithms to this + complement graph data structure, it behaves as if it were the dense + version. + + See also + -------- + k_components + + References + ---------- + .. [1] Torrents, J. and F. Ferraro (2015) Structural Cohesion: + Visualization and Heuristics for Fast Computation. + https://arxiv.org/pdf/1503.04476v1 + + .. [2] White, Douglas R., and Mark Newman (2001) A Fast Algorithm for + Node-Independent Paths. Santa Fe Institute Working Paper #01-07-035 + https://www.santafe.edu/research/results/working-papers/fast-approximation-algorithms-for-finding-node-ind + + .. [3] Moody, J. and D. White (2003). Social cohesion and embeddedness: + A hierarchical conception of social groups. + American Sociological Review 68(1), 103--28. + https://doi.org/10.2307/3088904 + + """ + # Dictionary with connectivity level (k) as keys and a list of + # sets of nodes that form a k-component as values + k_components = defaultdict(list) + # make a few functions local for speed + node_connectivity = local_node_connectivity + k_core = nx.k_core + core_number = nx.core_number + biconnected_components = nx.biconnected_components + combinations = itertools.combinations + # Exact solution for k = {1,2} + # There is a linear time algorithm for triconnectivity, if we had an + # implementation available we could start from k = 4. + for component in nx.connected_components(G): + # isolated nodes have connectivity 0 + comp = set(component) + if len(comp) > 1: + k_components[1].append(comp) + for bicomponent in nx.biconnected_components(G): + # avoid considering dyads as bicomponents + bicomp = set(bicomponent) + if len(bicomp) > 2: + k_components[2].append(bicomp) + # There is no k-component of k > maximum core number + # \kappa(G) <= \lambda(G) <= \delta(G) + g_cnumber = core_number(G) + max_core = max(g_cnumber.values()) + for k in range(3, max_core + 1): + C = k_core(G, k, core_number=g_cnumber) + for nodes in biconnected_components(C): + # Build a subgraph SG induced by the nodes that are part of + # each biconnected component of the k-core subgraph C. + if len(nodes) < k: + continue + SG = G.subgraph(nodes) + # Build auxiliary graph + H = _AntiGraph() + H.add_nodes_from(SG.nodes()) + for u, v in combinations(SG, 2): + K = node_connectivity(SG, u, v, cutoff=k) + if k > K: + H.add_edge(u, v) + for h_nodes in biconnected_components(H): + if len(h_nodes) <= k: + continue + SH = H.subgraph(h_nodes) + for Gc in _cliques_heuristic(SG, SH, k, min_density): + for k_nodes in biconnected_components(Gc): + Gk = nx.k_core(SG.subgraph(k_nodes), k) + if len(Gk) <= k: + continue + k_components[k].append(set(Gk)) + return k_components + + +def _cliques_heuristic(G, H, k, min_density): + h_cnumber = nx.core_number(H) + for i, c_value in enumerate(sorted(set(h_cnumber.values()), reverse=True)): + cands = {n for n, c in h_cnumber.items() if c == c_value} + # Skip checking for overlap for the highest core value + if i == 0: + overlap = False + else: + overlap = set.intersection( + *[{x for x in H[n] if x not in cands} for n in cands] + ) + if overlap and len(overlap) < k: + SH = H.subgraph(cands | overlap) + else: + SH = H.subgraph(cands) + sh_cnumber = nx.core_number(SH) + SG = nx.k_core(G.subgraph(SH), k) + while not (_same(sh_cnumber) and nx.density(SH) >= min_density): + # This subgraph must be writable => .copy() + SH = H.subgraph(SG).copy() + if len(SH) <= k: + break + sh_cnumber = nx.core_number(SH) + sh_deg = dict(SH.degree()) + min_deg = min(sh_deg.values()) + SH.remove_nodes_from(n for n, d in sh_deg.items() if d == min_deg) + SG = nx.k_core(G.subgraph(SH), k) + else: + yield SG + + +def _same(measure, tol=0): + vals = set(measure.values()) + if (max(vals) - min(vals)) <= tol: + return True + return False + + +class _AntiGraph(nx.Graph): + """ + Class for complement graphs. + + The main goal is to be able to work with big and dense graphs with + a low memory footprint. + + In this class you add the edges that *do not exist* in the dense graph, + the report methods of the class return the neighbors, the edges and + the degree as if it was the dense graph. Thus it's possible to use + an instance of this class with some of NetworkX functions. In this + case we only use k-core, connected_components, and biconnected_components. + """ + + all_edge_dict = {"weight": 1} + + def single_edge_dict(self): + return self.all_edge_dict + + edge_attr_dict_factory = single_edge_dict # type: ignore[assignment] + + def __getitem__(self, n): + """Returns a dict of neighbors of node n in the dense graph. + + Parameters + ---------- + n : node + A node in the graph. + + Returns + ------- + adj_dict : dictionary + The adjacency dictionary for nodes connected to n. + + """ + all_edge_dict = self.all_edge_dict + return { + node: all_edge_dict for node in set(self._adj) - set(self._adj[n]) - {n} + } + + def neighbors(self, n): + """Returns an iterator over all neighbors of node n in the + dense graph. + """ + try: + return iter(set(self._adj) - set(self._adj[n]) - {n}) + except KeyError as err: + raise NetworkXError(f"The node {n} is not in the graph.") from err + + class AntiAtlasView(Mapping): + """An adjacency inner dict for AntiGraph""" + + def __init__(self, graph, node): + self._graph = graph + self._atlas = graph._adj[node] + self._node = node + + def __len__(self): + return len(self._graph) - len(self._atlas) - 1 + + def __iter__(self): + return (n for n in self._graph if n not in self._atlas and n != self._node) + + def __getitem__(self, nbr): + nbrs = set(self._graph._adj) - set(self._atlas) - {self._node} + if nbr in nbrs: + return self._graph.all_edge_dict + raise KeyError(nbr) + + class AntiAdjacencyView(AntiAtlasView): + """An adjacency outer dict for AntiGraph""" + + def __init__(self, graph): + self._graph = graph + self._atlas = graph._adj + + def __len__(self): + return len(self._atlas) + + def __iter__(self): + return iter(self._graph) + + def __getitem__(self, node): + if node not in self._graph: + raise KeyError(node) + return self._graph.AntiAtlasView(self._graph, node) + + @cached_property + def adj(self): + return self.AntiAdjacencyView(self) + + def subgraph(self, nodes): + """This subgraph method returns a full AntiGraph. Not a View""" + nodes = set(nodes) + G = _AntiGraph() + G.add_nodes_from(nodes) + for n in G: + Gnbrs = G.adjlist_inner_dict_factory() + G._adj[n] = Gnbrs + for nbr, d in self._adj[n].items(): + if nbr in G._adj: + Gnbrs[nbr] = d + G._adj[nbr][n] = d + G.graph = self.graph + return G + + class AntiDegreeView(nx.reportviews.DegreeView): + def __iter__(self): + all_nodes = set(self._succ) + for n in self._nodes: + nbrs = all_nodes - set(self._succ[n]) - {n} + yield (n, len(nbrs)) + + def __getitem__(self, n): + nbrs = set(self._succ) - set(self._succ[n]) - {n} + # AntiGraph is a ThinGraph so all edges have weight 1 + return len(nbrs) + (n in nbrs) + + @cached_property + def degree(self): + """Returns an iterator for (node, degree) and degree for single node. + + The node degree is the number of edges adjacent to the node. + + Parameters + ---------- + nbunch : iterable container, optional (default=all nodes) + A container of nodes. The container will be iterated + through once. + + weight : string or None, optional (default=None) + The edge attribute that holds the numerical value used + as a weight. If None, then each edge has weight 1. + The degree is the sum of the edge weights adjacent to the node. + + Returns + ------- + deg: + Degree of the node, if a single node is passed as argument. + nd_iter : an iterator + The iterator returns two-tuples of (node, degree). + + See Also + -------- + degree + + Examples + -------- + >>> G = nx.path_graph(4) + >>> G.degree(0) # node 0 with degree 1 + 1 + >>> list(G.degree([0, 1])) + [(0, 1), (1, 2)] + + """ + return self.AntiDegreeView(self) + + def adjacency(self): + """Returns an iterator of (node, adjacency set) tuples for all nodes + in the dense graph. + + This is the fastest way to look at every edge. + For directed graphs, only outgoing adjacencies are included. + + Returns + ------- + adj_iter : iterator + An iterator of (node, adjacency set) for all nodes in + the graph. + + """ + for n in self._adj: + yield (n, set(self._adj) - set(self._adj[n]) - {n}) diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/matching.py b/lib/python3.10/site-packages/networkx/algorithms/approximation/matching.py new file mode 100644 index 0000000000000000000000000000000000000000..dc0891947cb88c38930e9f8c5479d397e3aa923b --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/approximation/matching.py @@ -0,0 +1,44 @@ +""" +************** +Graph Matching +************** + +Given a graph G = (V,E), a matching M in G is a set of pairwise non-adjacent +edges; that is, no two edges share a common vertex. + +`Wikipedia: Matching `_ +""" + +import networkx as nx + +__all__ = ["min_maximal_matching"] + + +@nx._dispatchable +def min_maximal_matching(G): + r"""Returns the minimum maximal matching of G. That is, out of all maximal + matchings of the graph G, the smallest is returned. + + Parameters + ---------- + G : NetworkX graph + Undirected graph + + Returns + ------- + min_maximal_matching : set + Returns a set of edges such that no two edges share a common endpoint + and every edge not in the set shares some common endpoint in the set. + Cardinality will be 2*OPT in the worst case. + + Notes + ----- + The algorithm computes an approximate solution for the minimum maximal + cardinality matching problem. The solution is no more than 2 * OPT in size. + Runtime is $O(|E|)$. + + References + ---------- + .. [1] Vazirani, Vijay Approximation Algorithms (2001) + """ + return nx.maximal_matching(G) diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/maxcut.py b/lib/python3.10/site-packages/networkx/algorithms/approximation/maxcut.py new file mode 100644 index 0000000000000000000000000000000000000000..f4e1da87c35ab821f4b3d0851bba19d599d8fa6a --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/approximation/maxcut.py @@ -0,0 +1,143 @@ +import networkx as nx +from networkx.utils.decorators import not_implemented_for, py_random_state + +__all__ = ["randomized_partitioning", "one_exchange"] + + +@not_implemented_for("directed") +@not_implemented_for("multigraph") +@py_random_state(1) +@nx._dispatchable(edge_attrs="weight") +def randomized_partitioning(G, seed=None, p=0.5, weight=None): + """Compute a random partitioning of the graph nodes and its cut value. + + A partitioning is calculated by observing each node + and deciding to add it to the partition with probability `p`, + returning a random cut and its corresponding value (the + sum of weights of edges connecting different partitions). + + Parameters + ---------- + G : NetworkX graph + + seed : integer, random_state, or None (default) + Indicator of random number generation state. + See :ref:`Randomness`. + + p : scalar + Probability for each node to be part of the first partition. + Should be in [0,1] + + weight : object + Edge attribute key to use as weight. If not specified, edges + have weight one. + + Returns + ------- + cut_size : scalar + Value of the minimum cut. + + partition : pair of node sets + A partitioning of the nodes that defines a minimum cut. + + Examples + -------- + >>> G = nx.complete_graph(5) + >>> cut_size, partition = nx.approximation.randomized_partitioning(G, seed=1) + >>> cut_size + 6 + >>> partition + ({0, 3, 4}, {1, 2}) + + Raises + ------ + NetworkXNotImplemented + If the graph is directed or is a multigraph. + """ + cut = {node for node in G.nodes() if seed.random() < p} + cut_size = nx.algorithms.cut_size(G, cut, weight=weight) + partition = (cut, G.nodes - cut) + return cut_size, partition + + +def _swap_node_partition(cut, node): + return cut - {node} if node in cut else cut.union({node}) + + +@not_implemented_for("directed") +@not_implemented_for("multigraph") +@py_random_state(2) +@nx._dispatchable(edge_attrs="weight") +def one_exchange(G, initial_cut=None, seed=None, weight=None): + """Compute a partitioning of the graphs nodes and the corresponding cut value. + + Use a greedy one exchange strategy to find a locally maximal cut + and its value, it works by finding the best node (one that gives + the highest gain to the cut value) to add to the current cut + and repeats this process until no improvement can be made. + + Parameters + ---------- + G : networkx Graph + Graph to find a maximum cut for. + + initial_cut : set + Cut to use as a starting point. If not supplied the algorithm + starts with an empty cut. + + seed : integer, random_state, or None (default) + Indicator of random number generation state. + See :ref:`Randomness`. + + weight : object + Edge attribute key to use as weight. If not specified, edges + have weight one. + + Returns + ------- + cut_value : scalar + Value of the maximum cut. + + partition : pair of node sets + A partitioning of the nodes that defines a maximum cut. + + Examples + -------- + >>> G = nx.complete_graph(5) + >>> curr_cut_size, partition = nx.approximation.one_exchange(G, seed=1) + >>> curr_cut_size + 6 + >>> partition + ({0, 2}, {1, 3, 4}) + + Raises + ------ + NetworkXNotImplemented + If the graph is directed or is a multigraph. + """ + if initial_cut is None: + initial_cut = set() + cut = set(initial_cut) + current_cut_size = nx.algorithms.cut_size(G, cut, weight=weight) + while True: + nodes = list(G.nodes()) + # Shuffling the nodes ensures random tie-breaks in the following call to max + seed.shuffle(nodes) + best_node_to_swap = max( + nodes, + key=lambda v: nx.algorithms.cut_size( + G, _swap_node_partition(cut, v), weight=weight + ), + default=None, + ) + potential_cut = _swap_node_partition(cut, best_node_to_swap) + potential_cut_size = nx.algorithms.cut_size(G, potential_cut, weight=weight) + + if potential_cut_size > current_cut_size: + cut = potential_cut + current_cut_size = potential_cut_size + else: + break + + partition = (cut, G.nodes - cut) + return current_cut_size, partition diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/ramsey.py b/lib/python3.10/site-packages/networkx/algorithms/approximation/ramsey.py new file mode 100644 index 0000000000000000000000000000000000000000..0552e4a942c9c99fbf132300d7288595307052ff --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/approximation/ramsey.py @@ -0,0 +1,53 @@ +""" +Ramsey numbers. +""" + +import networkx as nx +from networkx.utils import not_implemented_for + +from ...utils import arbitrary_element + +__all__ = ["ramsey_R2"] + + +@not_implemented_for("directed") +@not_implemented_for("multigraph") +@nx._dispatchable +def ramsey_R2(G): + r"""Compute the largest clique and largest independent set in `G`. + + This can be used to estimate bounds for the 2-color + Ramsey number `R(2;s,t)` for `G`. + + This is a recursive implementation which could run into trouble + for large recursions. Note that self-loop edges are ignored. + + Parameters + ---------- + G : NetworkX graph + Undirected graph + + Returns + ------- + max_pair : (set, set) tuple + Maximum clique, Maximum independent set. + + Raises + ------ + NetworkXNotImplemented + If the graph is directed or is a multigraph. + """ + if not G: + return set(), set() + + node = arbitrary_element(G) + nbrs = (nbr for nbr in nx.all_neighbors(G, node) if nbr != node) + nnbrs = nx.non_neighbors(G, node) + c_1, i_1 = ramsey_R2(G.subgraph(nbrs).copy()) + c_2, i_2 = ramsey_R2(G.subgraph(nnbrs).copy()) + + c_1.add(node) + i_2.add(node) + # Choose the larger of the two cliques and the larger of the two + # independent sets, according to cardinality. + return max(c_1, c_2, key=len), max(i_1, i_2, key=len) diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/steinertree.py b/lib/python3.10/site-packages/networkx/algorithms/approximation/steinertree.py new file mode 100644 index 0000000000000000000000000000000000000000..f4840effd340e9769dbcc52b72bf1e25968a215d --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/approximation/steinertree.py @@ -0,0 +1,231 @@ +from itertools import chain + +import networkx as nx +from networkx.utils import not_implemented_for, pairwise + +__all__ = ["metric_closure", "steiner_tree"] + + +@not_implemented_for("directed") +@nx._dispatchable(edge_attrs="weight", returns_graph=True) +def metric_closure(G, weight="weight"): + """Return the metric closure of a graph. + + The metric closure of a graph *G* is the complete graph in which each edge + is weighted by the shortest path distance between the nodes in *G* . + + Parameters + ---------- + G : NetworkX graph + + Returns + ------- + NetworkX graph + Metric closure of the graph `G`. + + """ + M = nx.Graph() + + Gnodes = set(G) + + # check for connected graph while processing first node + all_paths_iter = nx.all_pairs_dijkstra(G, weight=weight) + u, (distance, path) = next(all_paths_iter) + if Gnodes - set(distance): + msg = "G is not a connected graph. metric_closure is not defined." + raise nx.NetworkXError(msg) + Gnodes.remove(u) + for v in Gnodes: + M.add_edge(u, v, distance=distance[v], path=path[v]) + + # first node done -- now process the rest + for u, (distance, path) in all_paths_iter: + Gnodes.remove(u) + for v in Gnodes: + M.add_edge(u, v, distance=distance[v], path=path[v]) + + return M + + +def _mehlhorn_steiner_tree(G, terminal_nodes, weight): + paths = nx.multi_source_dijkstra_path(G, terminal_nodes) + + d_1 = {} + s = {} + for v in G.nodes(): + s[v] = paths[v][0] + d_1[(v, s[v])] = len(paths[v]) - 1 + + # G1-G4 names match those from the Mehlhorn 1988 paper. + G_1_prime = nx.Graph() + for u, v, data in G.edges(data=True): + su, sv = s[u], s[v] + weight_here = d_1[(u, su)] + data.get(weight, 1) + d_1[(v, sv)] + if not G_1_prime.has_edge(su, sv): + G_1_prime.add_edge(su, sv, weight=weight_here) + else: + new_weight = min(weight_here, G_1_prime[su][sv]["weight"]) + G_1_prime.add_edge(su, sv, weight=new_weight) + + G_2 = nx.minimum_spanning_edges(G_1_prime, data=True) + + G_3 = nx.Graph() + for u, v, d in G_2: + path = nx.shortest_path(G, u, v, weight) + for n1, n2 in pairwise(path): + G_3.add_edge(n1, n2) + + G_3_mst = list(nx.minimum_spanning_edges(G_3, data=False)) + if G.is_multigraph(): + G_3_mst = ( + (u, v, min(G[u][v], key=lambda k: G[u][v][k][weight])) for u, v in G_3_mst + ) + G_4 = G.edge_subgraph(G_3_mst).copy() + _remove_nonterminal_leaves(G_4, terminal_nodes) + return G_4.edges() + + +def _kou_steiner_tree(G, terminal_nodes, weight): + # H is the subgraph induced by terminal_nodes in the metric closure M of G. + M = metric_closure(G, weight=weight) + H = M.subgraph(terminal_nodes) + + # Use the 'distance' attribute of each edge provided by M. + mst_edges = nx.minimum_spanning_edges(H, weight="distance", data=True) + + # Create an iterator over each edge in each shortest path; repeats are okay + mst_all_edges = chain.from_iterable(pairwise(d["path"]) for u, v, d in mst_edges) + if G.is_multigraph(): + mst_all_edges = ( + (u, v, min(G[u][v], key=lambda k: G[u][v][k][weight])) + for u, v in mst_all_edges + ) + + # Find the MST again, over this new set of edges + G_S = G.edge_subgraph(mst_all_edges) + T_S = nx.minimum_spanning_edges(G_S, weight="weight", data=False) + + # Leaf nodes that are not terminal might still remain; remove them here + T_H = G.edge_subgraph(T_S).copy() + _remove_nonterminal_leaves(T_H, terminal_nodes) + + return T_H.edges() + + +def _remove_nonterminal_leaves(G, terminals): + terminal_set = set(terminals) + leaves = {n for n in G if len(set(G[n]) - {n}) == 1} + nonterminal_leaves = leaves - terminal_set + + while nonterminal_leaves: + # Removing a node may create new non-terminal leaves, so we limit + # search for candidate non-terminal nodes to neighbors of current + # non-terminal nodes + candidate_leaves = set.union(*(set(G[n]) for n in nonterminal_leaves)) + candidate_leaves -= nonterminal_leaves | terminal_set + # Remove current set of non-terminal nodes + G.remove_nodes_from(nonterminal_leaves) + # Find any new non-terminal nodes from the set of candidates + leaves = {n for n in candidate_leaves if len(set(G[n]) - {n}) == 1} + nonterminal_leaves = leaves - terminal_set + + +ALGORITHMS = { + "kou": _kou_steiner_tree, + "mehlhorn": _mehlhorn_steiner_tree, +} + + +@not_implemented_for("directed") +@nx._dispatchable(preserve_all_attrs=True, returns_graph=True) +def steiner_tree(G, terminal_nodes, weight="weight", method=None): + r"""Return an approximation to the minimum Steiner tree of a graph. + + The minimum Steiner tree of `G` w.r.t a set of `terminal_nodes` (also *S*) + is a tree within `G` that spans those nodes and has minimum size (sum of + edge weights) among all such trees. + + The approximation algorithm is specified with the `method` keyword + argument. All three available algorithms produce a tree whose weight is + within a ``(2 - (2 / l))`` factor of the weight of the optimal Steiner tree, + where ``l`` is the minimum number of leaf nodes across all possible Steiner + trees. + + * ``"kou"`` [2]_ (runtime $O(|S| |V|^2)$) computes the minimum spanning tree of + the subgraph of the metric closure of *G* induced by the terminal nodes, + where the metric closure of *G* is the complete graph in which each edge is + weighted by the shortest path distance between the nodes in *G*. + + * ``"mehlhorn"`` [3]_ (runtime $O(|E|+|V|\log|V|)$) modifies Kou et al.'s + algorithm, beginning by finding the closest terminal node for each + non-terminal. This data is used to create a complete graph containing only + the terminal nodes, in which edge is weighted with the shortest path + distance between them. The algorithm then proceeds in the same way as Kou + et al.. + + Parameters + ---------- + G : NetworkX graph + + terminal_nodes : list + A list of terminal nodes for which minimum steiner tree is + to be found. + + weight : string (default = 'weight') + Use the edge attribute specified by this string as the edge weight. + Any edge attribute not present defaults to 1. + + method : string, optional (default = 'mehlhorn') + The algorithm to use to approximate the Steiner tree. + Supported options: 'kou', 'mehlhorn'. + Other inputs produce a ValueError. + + Returns + ------- + NetworkX graph + Approximation to the minimum steiner tree of `G` induced by + `terminal_nodes` . + + Raises + ------ + NetworkXNotImplemented + If `G` is directed. + + ValueError + If the specified `method` is not supported. + + Notes + ----- + For multigraphs, the edge between two nodes with minimum weight is the + edge put into the Steiner tree. + + + References + ---------- + .. [1] Steiner_tree_problem on Wikipedia. + https://en.wikipedia.org/wiki/Steiner_tree_problem + .. [2] Kou, L., G. Markowsky, and L. Berman. 1981. + ‘A Fast Algorithm for Steiner Trees’. + Acta Informatica 15 (2): 141–45. + https://doi.org/10.1007/BF00288961. + .. [3] Mehlhorn, Kurt. 1988. + ‘A Faster Approximation Algorithm for the Steiner Problem in Graphs’. + Information Processing Letters 27 (3): 125–28. + https://doi.org/10.1016/0020-0190(88)90066-X. + """ + if method is None: + method = "mehlhorn" + + try: + algo = ALGORITHMS[method] + except KeyError as e: + raise ValueError(f"{method} is not a valid choice for an algorithm.") from e + + edges = algo(G, terminal_nodes, weight) + # For multigraph we should add the minimal weight edge keys + if G.is_multigraph(): + edges = ( + (u, v, min(G[u][v], key=lambda k: G[u][v][k][weight])) for u, v in edges + ) + T = G.edge_subgraph(edges) + return T diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19e5d6d3a5a92f6c082226222a2e35c436f9b09c Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_approx_clust_coeff.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_approx_clust_coeff.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75c4a4eff13ad9e7271724429d7a3fd0c5b48bac Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_approx_clust_coeff.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_clique.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_clique.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1effaecd62e2b40b6941b6e355422e9554f05468 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_clique.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_connectivity.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_connectivity.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e8d20a65d6297ca0556fa9b7380bf7283de2d59 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_connectivity.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_distance_measures.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_distance_measures.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f4daa38864b1bb1270c73480d72f99dc658c9e9 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_distance_measures.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_dominating_set.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_dominating_set.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ffff0e2f2bdb98ba18cbcf4e99a46e2c3ac89e7 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_dominating_set.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_kcomponents.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_kcomponents.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0894db565e03c62dbb48ed93bb1a6f704a9ab721 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_kcomponents.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_matching.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_matching.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f21ce42b96bb249db89227ec62cfcfcd7cf2169a Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_matching.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_maxcut.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_maxcut.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a53fbc383d9961b464e6d3a74dbb7aa48d5997b Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_maxcut.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_ramsey.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_ramsey.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e79d96145d0a6279d9cac818ae60d8aeb49f2ad0 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_ramsey.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_steinertree.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_steinertree.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d744bef005547413a077de077c26f9be0a720df Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_steinertree.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_traveling_salesman.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_traveling_salesman.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..523ceaa79ded53a38f302c56dd097053d3c805a2 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_traveling_salesman.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_treewidth.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_treewidth.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c05e76468e7b5f9df316473d4cf3c48a18fccce Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_treewidth.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_vertex_cover.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_vertex_cover.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd48de2385f867f33691bf587f14955f97fcbc7f Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/__pycache__/test_vertex_cover.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_approx_clust_coeff.py b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_approx_clust_coeff.py new file mode 100644 index 0000000000000000000000000000000000000000..5eab5c1ee79408c9f90a1993415a6c3d7d957141 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_approx_clust_coeff.py @@ -0,0 +1,41 @@ +import networkx as nx +from networkx.algorithms.approximation import average_clustering + +# This approximation has to be exact in regular graphs +# with no triangles or with all possible triangles. + + +def test_petersen(): + # Actual coefficient is 0 + G = nx.petersen_graph() + assert average_clustering(G, trials=len(G) // 2) == nx.average_clustering(G) + + +def test_petersen_seed(): + # Actual coefficient is 0 + G = nx.petersen_graph() + assert average_clustering(G, trials=len(G) // 2, seed=1) == nx.average_clustering(G) + + +def test_tetrahedral(): + # Actual coefficient is 1 + G = nx.tetrahedral_graph() + assert average_clustering(G, trials=len(G) // 2) == nx.average_clustering(G) + + +def test_dodecahedral(): + # Actual coefficient is 0 + G = nx.dodecahedral_graph() + assert average_clustering(G, trials=len(G) // 2) == nx.average_clustering(G) + + +def test_empty(): + G = nx.empty_graph(5) + assert average_clustering(G, trials=len(G) // 2) == 0 + + +def test_complete(): + G = nx.complete_graph(5) + assert average_clustering(G, trials=len(G) // 2) == 1 + G = nx.complete_graph(7) + assert average_clustering(G, trials=len(G) // 2) == 1 diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_clique.py b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_clique.py new file mode 100644 index 0000000000000000000000000000000000000000..b40dcb904063f2f1aca8811e73ec7d590beb5dc9 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_clique.py @@ -0,0 +1,112 @@ +"""Unit tests for the :mod:`networkx.algorithms.approximation.clique` module.""" + +import networkx as nx +from networkx.algorithms.approximation import ( + clique_removal, + large_clique_size, + max_clique, + maximum_independent_set, +) + + +def is_independent_set(G, nodes): + """Returns True if and only if `nodes` is a clique in `G`. + + `G` is a NetworkX graph. `nodes` is an iterable of nodes in + `G`. + + """ + return G.subgraph(nodes).number_of_edges() == 0 + + +def is_clique(G, nodes): + """Returns True if and only if `nodes` is an independent set + in `G`. + + `G` is an undirected simple graph. `nodes` is an iterable of + nodes in `G`. + + """ + H = G.subgraph(nodes) + n = len(H) + return H.number_of_edges() == n * (n - 1) // 2 + + +class TestCliqueRemoval: + """Unit tests for the + :func:`~networkx.algorithms.approximation.clique_removal` function. + + """ + + def test_trivial_graph(self): + G = nx.trivial_graph() + independent_set, cliques = clique_removal(G) + assert is_independent_set(G, independent_set) + assert all(is_clique(G, clique) for clique in cliques) + # In fact, we should only have 1-cliques, that is, singleton nodes. + assert all(len(clique) == 1 for clique in cliques) + + def test_complete_graph(self): + G = nx.complete_graph(10) + independent_set, cliques = clique_removal(G) + assert is_independent_set(G, independent_set) + assert all(is_clique(G, clique) for clique in cliques) + + def test_barbell_graph(self): + G = nx.barbell_graph(10, 5) + independent_set, cliques = clique_removal(G) + assert is_independent_set(G, independent_set) + assert all(is_clique(G, clique) for clique in cliques) + + +class TestMaxClique: + """Unit tests for the :func:`networkx.algorithms.approximation.max_clique` + function. + + """ + + def test_null_graph(self): + G = nx.null_graph() + assert len(max_clique(G)) == 0 + + def test_complete_graph(self): + graph = nx.complete_graph(30) + # this should return the entire graph + mc = max_clique(graph) + assert 30 == len(mc) + + def test_maximal_by_cardinality(self): + """Tests that the maximal clique is computed according to maximum + cardinality of the sets. + + For more information, see pull request #1531. + + """ + G = nx.complete_graph(5) + G.add_edge(4, 5) + clique = max_clique(G) + assert len(clique) > 1 + + G = nx.lollipop_graph(30, 2) + clique = max_clique(G) + assert len(clique) > 2 + + +def test_large_clique_size(): + G = nx.complete_graph(9) + nx.add_cycle(G, [9, 10, 11]) + G.add_edge(8, 9) + G.add_edge(1, 12) + G.add_node(13) + + assert large_clique_size(G) == 9 + G.remove_node(5) + assert large_clique_size(G) == 8 + G.remove_edge(2, 3) + assert large_clique_size(G) == 7 + + +def test_independent_set(): + # smoke test + G = nx.Graph() + assert len(maximum_independent_set(G)) == 0 diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_connectivity.py b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_connectivity.py new file mode 100644 index 0000000000000000000000000000000000000000..887db20bcaef8dd2641c64e963c789234aecbb20 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_connectivity.py @@ -0,0 +1,199 @@ +import pytest + +import networkx as nx +from networkx.algorithms import approximation as approx + + +def test_global_node_connectivity(): + # Figure 1 chapter on Connectivity + G = nx.Graph() + G.add_edges_from( + [ + (1, 2), + (1, 3), + (1, 4), + (1, 5), + (2, 3), + (2, 6), + (3, 4), + (3, 6), + (4, 6), + (4, 7), + (5, 7), + (6, 8), + (6, 9), + (7, 8), + (7, 10), + (8, 11), + (9, 10), + (9, 11), + (10, 11), + ] + ) + assert 2 == approx.local_node_connectivity(G, 1, 11) + assert 2 == approx.node_connectivity(G) + assert 2 == approx.node_connectivity(G, 1, 11) + + +def test_white_harary1(): + # Figure 1b white and harary (2001) + # A graph with high adhesion (edge connectivity) and low cohesion + # (node connectivity) + G = nx.disjoint_union(nx.complete_graph(4), nx.complete_graph(4)) + G.remove_node(7) + for i in range(4, 7): + G.add_edge(0, i) + G = nx.disjoint_union(G, nx.complete_graph(4)) + G.remove_node(G.order() - 1) + for i in range(7, 10): + G.add_edge(0, i) + assert 1 == approx.node_connectivity(G) + + +def test_complete_graphs(): + for n in range(5, 25, 5): + G = nx.complete_graph(n) + assert n - 1 == approx.node_connectivity(G) + assert n - 1 == approx.node_connectivity(G, 0, 3) + + +def test_empty_graphs(): + for k in range(5, 25, 5): + G = nx.empty_graph(k) + assert 0 == approx.node_connectivity(G) + assert 0 == approx.node_connectivity(G, 0, 3) + + +def test_petersen(): + G = nx.petersen_graph() + assert 3 == approx.node_connectivity(G) + assert 3 == approx.node_connectivity(G, 0, 5) + + +# Approximation fails with tutte graph +# def test_tutte(): +# G = nx.tutte_graph() +# assert_equal(3, approx.node_connectivity(G)) + + +def test_dodecahedral(): + G = nx.dodecahedral_graph() + assert 3 == approx.node_connectivity(G) + assert 3 == approx.node_connectivity(G, 0, 5) + + +def test_octahedral(): + G = nx.octahedral_graph() + assert 4 == approx.node_connectivity(G) + assert 4 == approx.node_connectivity(G, 0, 5) + + +# Approximation can fail with icosahedral graph depending +# on iteration order. +# def test_icosahedral(): +# G=nx.icosahedral_graph() +# assert_equal(5, approx.node_connectivity(G)) +# assert_equal(5, approx.node_connectivity(G, 0, 5)) + + +def test_only_source(): + G = nx.complete_graph(5) + pytest.raises(nx.NetworkXError, approx.node_connectivity, G, s=0) + + +def test_only_target(): + G = nx.complete_graph(5) + pytest.raises(nx.NetworkXError, approx.node_connectivity, G, t=0) + + +def test_missing_source(): + G = nx.path_graph(4) + pytest.raises(nx.NetworkXError, approx.node_connectivity, G, 10, 1) + + +def test_missing_target(): + G = nx.path_graph(4) + pytest.raises(nx.NetworkXError, approx.node_connectivity, G, 1, 10) + + +def test_source_equals_target(): + G = nx.complete_graph(5) + pytest.raises(nx.NetworkXError, approx.local_node_connectivity, G, 0, 0) + + +def test_directed_node_connectivity(): + G = nx.cycle_graph(10, create_using=nx.DiGraph()) # only one direction + D = nx.cycle_graph(10).to_directed() # 2 reciprocal edges + assert 1 == approx.node_connectivity(G) + assert 1 == approx.node_connectivity(G, 1, 4) + assert 2 == approx.node_connectivity(D) + assert 2 == approx.node_connectivity(D, 1, 4) + + +class TestAllPairsNodeConnectivityApprox: + @classmethod + def setup_class(cls): + cls.path = nx.path_graph(7) + cls.directed_path = nx.path_graph(7, create_using=nx.DiGraph()) + cls.cycle = nx.cycle_graph(7) + cls.directed_cycle = nx.cycle_graph(7, create_using=nx.DiGraph()) + cls.gnp = nx.gnp_random_graph(30, 0.1) + cls.directed_gnp = nx.gnp_random_graph(30, 0.1, directed=True) + cls.K20 = nx.complete_graph(20) + cls.K10 = nx.complete_graph(10) + cls.K5 = nx.complete_graph(5) + cls.G_list = [ + cls.path, + cls.directed_path, + cls.cycle, + cls.directed_cycle, + cls.gnp, + cls.directed_gnp, + cls.K10, + cls.K5, + cls.K20, + ] + + def test_cycles(self): + K_undir = approx.all_pairs_node_connectivity(self.cycle) + for source in K_undir: + for target, k in K_undir[source].items(): + assert k == 2 + K_dir = approx.all_pairs_node_connectivity(self.directed_cycle) + for source in K_dir: + for target, k in K_dir[source].items(): + assert k == 1 + + def test_complete(self): + for G in [self.K10, self.K5, self.K20]: + K = approx.all_pairs_node_connectivity(G) + for source in K: + for target, k in K[source].items(): + assert k == len(G) - 1 + + def test_paths(self): + K_undir = approx.all_pairs_node_connectivity(self.path) + for source in K_undir: + for target, k in K_undir[source].items(): + assert k == 1 + K_dir = approx.all_pairs_node_connectivity(self.directed_path) + for source in K_dir: + for target, k in K_dir[source].items(): + if source < target: + assert k == 1 + else: + assert k == 0 + + def test_cutoff(self): + for G in [self.K10, self.K5, self.K20]: + for mp in [2, 3, 4]: + paths = approx.all_pairs_node_connectivity(G, cutoff=mp) + for source in paths: + for target, K in paths[source].items(): + assert K == mp + + def test_all_pairs_connectivity_nbunch(self): + G = nx.complete_graph(5) + nbunch = [0, 2, 3] + C = approx.all_pairs_node_connectivity(G, nbunch=nbunch) + assert len(C) == len(nbunch) diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_distance_measures.py b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_distance_measures.py new file mode 100644 index 0000000000000000000000000000000000000000..3809a8fc667db5d95fb41074e7bab3ff551e1784 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_distance_measures.py @@ -0,0 +1,59 @@ +"""Unit tests for the :mod:`networkx.algorithms.approximation.distance_measures` module.""" + +import pytest + +import networkx as nx +from networkx.algorithms.approximation import diameter + + +class TestDiameter: + """Unit tests for the approximate diameter function + :func:`~networkx.algorithms.approximation.distance_measures.diameter`. + """ + + def test_null_graph(self): + """Test empty graph.""" + G = nx.null_graph() + with pytest.raises( + nx.NetworkXError, match="Expected non-empty NetworkX graph!" + ): + diameter(G) + + def test_undirected_non_connected(self): + """Test an undirected disconnected graph.""" + graph = nx.path_graph(10) + graph.remove_edge(3, 4) + with pytest.raises(nx.NetworkXError, match="Graph not connected."): + diameter(graph) + + def test_directed_non_strongly_connected(self): + """Test a directed non strongly connected graph.""" + graph = nx.path_graph(10, create_using=nx.DiGraph()) + with pytest.raises(nx.NetworkXError, match="DiGraph not strongly connected."): + diameter(graph) + + def test_complete_undirected_graph(self): + """Test a complete undirected graph.""" + graph = nx.complete_graph(10) + assert diameter(graph) == 1 + + def test_complete_directed_graph(self): + """Test a complete directed graph.""" + graph = nx.complete_graph(10, create_using=nx.DiGraph()) + assert diameter(graph) == 1 + + def test_undirected_path_graph(self): + """Test an undirected path graph with 10 nodes.""" + graph = nx.path_graph(10) + assert diameter(graph) == 9 + + def test_directed_path_graph(self): + """Test a directed path graph with 10 nodes.""" + graph = nx.path_graph(10).to_directed() + assert diameter(graph) == 9 + + def test_single_node(self): + """Test a graph which contains just a node.""" + graph = nx.Graph() + graph.add_node(1) + assert diameter(graph) == 0 diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_dominating_set.py b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_dominating_set.py new file mode 100644 index 0000000000000000000000000000000000000000..6b90d85ecf73bb56370fd92fdec25e3bbbb91ce3 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_dominating_set.py @@ -0,0 +1,78 @@ +import pytest + +import networkx as nx +from networkx.algorithms.approximation import ( + min_edge_dominating_set, + min_weighted_dominating_set, +) + + +class TestMinWeightDominatingSet: + def test_min_weighted_dominating_set(self): + graph = nx.Graph() + graph.add_edge(1, 2) + graph.add_edge(1, 5) + graph.add_edge(2, 3) + graph.add_edge(2, 5) + graph.add_edge(3, 4) + graph.add_edge(3, 6) + graph.add_edge(5, 6) + + vertices = {1, 2, 3, 4, 5, 6} + # due to ties, this might be hard to test tight bounds + dom_set = min_weighted_dominating_set(graph) + for vertex in vertices - dom_set: + neighbors = set(graph.neighbors(vertex)) + assert len(neighbors & dom_set) > 0, "Non dominating set found!" + + def test_star_graph(self): + """Tests that an approximate dominating set for the star graph, + even when the center node does not have the smallest integer + label, gives just the center node. + + For more information, see #1527. + + """ + # Create a star graph in which the center node has the highest + # label instead of the lowest. + G = nx.star_graph(10) + G = nx.relabel_nodes(G, {0: 9, 9: 0}) + assert min_weighted_dominating_set(G) == {9} + + def test_null_graph(self): + """Tests that the unique dominating set for the null graph is an empty set""" + G = nx.Graph() + assert min_weighted_dominating_set(G) == set() + + def test_min_edge_dominating_set(self): + graph = nx.path_graph(5) + dom_set = min_edge_dominating_set(graph) + + # this is a crappy way to test, but good enough for now. + for edge in graph.edges(): + if edge in dom_set: + continue + else: + u, v = edge + found = False + for dom_edge in dom_set: + found |= u == dom_edge[0] or u == dom_edge[1] + assert found, "Non adjacent edge found!" + + graph = nx.complete_graph(10) + dom_set = min_edge_dominating_set(graph) + + # this is a crappy way to test, but good enough for now. + for edge in graph.edges(): + if edge in dom_set: + continue + else: + u, v = edge + found = False + for dom_edge in dom_set: + found |= u == dom_edge[0] or u == dom_edge[1] + assert found, "Non adjacent edge found!" + + graph = nx.Graph() # empty Networkx graph + with pytest.raises(ValueError, match="Expected non-empty NetworkX graph!"): + min_edge_dominating_set(graph) diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_kcomponents.py b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_kcomponents.py new file mode 100644 index 0000000000000000000000000000000000000000..65ba802171a6b43a5157f12010c8164e5e867eb8 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_kcomponents.py @@ -0,0 +1,303 @@ +# Test for approximation to k-components algorithm +import pytest + +import networkx as nx +from networkx.algorithms.approximation import k_components +from networkx.algorithms.approximation.kcomponents import _AntiGraph, _same + + +def build_k_number_dict(k_components): + k_num = {} + for k, comps in sorted(k_components.items()): + for comp in comps: + for node in comp: + k_num[node] = k + return k_num + + +## +# Some nice synthetic graphs +## + + +def graph_example_1(): + G = nx.convert_node_labels_to_integers( + nx.grid_graph([5, 5]), label_attribute="labels" + ) + rlabels = nx.get_node_attributes(G, "labels") + labels = {v: k for k, v in rlabels.items()} + + for nodes in [ + (labels[(0, 0)], labels[(1, 0)]), + (labels[(0, 4)], labels[(1, 4)]), + (labels[(3, 0)], labels[(4, 0)]), + (labels[(3, 4)], labels[(4, 4)]), + ]: + new_node = G.order() + 1 + # Petersen graph is triconnected + P = nx.petersen_graph() + G = nx.disjoint_union(G, P) + # Add two edges between the grid and P + G.add_edge(new_node + 1, nodes[0]) + G.add_edge(new_node, nodes[1]) + # K5 is 4-connected + K = nx.complete_graph(5) + G = nx.disjoint_union(G, K) + # Add three edges between P and K5 + G.add_edge(new_node + 2, new_node + 11) + G.add_edge(new_node + 3, new_node + 12) + G.add_edge(new_node + 4, new_node + 13) + # Add another K5 sharing a node + G = nx.disjoint_union(G, K) + nbrs = G[new_node + 10] + G.remove_node(new_node + 10) + for nbr in nbrs: + G.add_edge(new_node + 17, nbr) + G.add_edge(new_node + 16, new_node + 5) + return G + + +def torrents_and_ferraro_graph(): + G = nx.convert_node_labels_to_integers( + nx.grid_graph([5, 5]), label_attribute="labels" + ) + rlabels = nx.get_node_attributes(G, "labels") + labels = {v: k for k, v in rlabels.items()} + + for nodes in [(labels[(0, 4)], labels[(1, 4)]), (labels[(3, 4)], labels[(4, 4)])]: + new_node = G.order() + 1 + # Petersen graph is triconnected + P = nx.petersen_graph() + G = nx.disjoint_union(G, P) + # Add two edges between the grid and P + G.add_edge(new_node + 1, nodes[0]) + G.add_edge(new_node, nodes[1]) + # K5 is 4-connected + K = nx.complete_graph(5) + G = nx.disjoint_union(G, K) + # Add three edges between P and K5 + G.add_edge(new_node + 2, new_node + 11) + G.add_edge(new_node + 3, new_node + 12) + G.add_edge(new_node + 4, new_node + 13) + # Add another K5 sharing a node + G = nx.disjoint_union(G, K) + nbrs = G[new_node + 10] + G.remove_node(new_node + 10) + for nbr in nbrs: + G.add_edge(new_node + 17, nbr) + # Commenting this makes the graph not biconnected !! + # This stupid mistake make one reviewer very angry :P + G.add_edge(new_node + 16, new_node + 8) + + for nodes in [(labels[(0, 0)], labels[(1, 0)]), (labels[(3, 0)], labels[(4, 0)])]: + new_node = G.order() + 1 + # Petersen graph is triconnected + P = nx.petersen_graph() + G = nx.disjoint_union(G, P) + # Add two edges between the grid and P + G.add_edge(new_node + 1, nodes[0]) + G.add_edge(new_node, nodes[1]) + # K5 is 4-connected + K = nx.complete_graph(5) + G = nx.disjoint_union(G, K) + # Add three edges between P and K5 + G.add_edge(new_node + 2, new_node + 11) + G.add_edge(new_node + 3, new_node + 12) + G.add_edge(new_node + 4, new_node + 13) + # Add another K5 sharing two nodes + G = nx.disjoint_union(G, K) + nbrs = G[new_node + 10] + G.remove_node(new_node + 10) + for nbr in nbrs: + G.add_edge(new_node + 17, nbr) + nbrs2 = G[new_node + 9] + G.remove_node(new_node + 9) + for nbr in nbrs2: + G.add_edge(new_node + 18, nbr) + return G + + +# Helper function + + +def _check_connectivity(G): + result = k_components(G) + for k, components in result.items(): + if k < 3: + continue + for component in components: + C = G.subgraph(component) + K = nx.node_connectivity(C) + assert K >= k + + +def test_torrents_and_ferraro_graph(): + G = torrents_and_ferraro_graph() + _check_connectivity(G) + + +def test_example_1(): + G = graph_example_1() + _check_connectivity(G) + + +def test_karate_0(): + G = nx.karate_club_graph() + _check_connectivity(G) + + +def test_karate_1(): + karate_k_num = { + 0: 4, + 1: 4, + 2: 4, + 3: 4, + 4: 3, + 5: 3, + 6: 3, + 7: 4, + 8: 4, + 9: 2, + 10: 3, + 11: 1, + 12: 2, + 13: 4, + 14: 2, + 15: 2, + 16: 2, + 17: 2, + 18: 2, + 19: 3, + 20: 2, + 21: 2, + 22: 2, + 23: 3, + 24: 3, + 25: 3, + 26: 2, + 27: 3, + 28: 3, + 29: 3, + 30: 4, + 31: 3, + 32: 4, + 33: 4, + } + approx_karate_k_num = karate_k_num.copy() + approx_karate_k_num[24] = 2 + approx_karate_k_num[25] = 2 + G = nx.karate_club_graph() + k_comps = k_components(G) + k_num = build_k_number_dict(k_comps) + assert k_num in (karate_k_num, approx_karate_k_num) + + +def test_example_1_detail_3_and_4(): + G = graph_example_1() + result = k_components(G) + # In this example graph there are 8 3-components, 4 with 15 nodes + # and 4 with 5 nodes. + assert len(result[3]) == 8 + assert len([c for c in result[3] if len(c) == 15]) == 4 + assert len([c for c in result[3] if len(c) == 5]) == 4 + # There are also 8 4-components all with 5 nodes. + assert len(result[4]) == 8 + assert all(len(c) == 5 for c in result[4]) + # Finally check that the k-components detected have actually node + # connectivity >= k. + for k, components in result.items(): + if k < 3: + continue + for component in components: + K = nx.node_connectivity(G.subgraph(component)) + assert K >= k + + +def test_directed(): + with pytest.raises(nx.NetworkXNotImplemented): + G = nx.gnp_random_graph(10, 0.4, directed=True) + kc = k_components(G) + + +def test_same(): + equal = {"A": 2, "B": 2, "C": 2} + slightly_different = {"A": 2, "B": 1, "C": 2} + different = {"A": 2, "B": 8, "C": 18} + assert _same(equal) + assert not _same(slightly_different) + assert _same(slightly_different, tol=1) + assert not _same(different) + assert not _same(different, tol=4) + + +class TestAntiGraph: + @classmethod + def setup_class(cls): + cls.Gnp = nx.gnp_random_graph(20, 0.8, seed=42) + cls.Anp = _AntiGraph(nx.complement(cls.Gnp)) + cls.Gd = nx.davis_southern_women_graph() + cls.Ad = _AntiGraph(nx.complement(cls.Gd)) + cls.Gk = nx.karate_club_graph() + cls.Ak = _AntiGraph(nx.complement(cls.Gk)) + cls.GA = [(cls.Gnp, cls.Anp), (cls.Gd, cls.Ad), (cls.Gk, cls.Ak)] + + def test_size(self): + for G, A in self.GA: + n = G.order() + s = len(list(G.edges())) + len(list(A.edges())) + assert s == (n * (n - 1)) / 2 + + def test_degree(self): + for G, A in self.GA: + assert sorted(G.degree()) == sorted(A.degree()) + + def test_core_number(self): + for G, A in self.GA: + assert nx.core_number(G) == nx.core_number(A) + + def test_connected_components(self): + # ccs are same unless isolated nodes or any node has degree=len(G)-1 + # graphs in self.GA avoid this problem + for G, A in self.GA: + gc = [set(c) for c in nx.connected_components(G)] + ac = [set(c) for c in nx.connected_components(A)] + for comp in ac: + assert comp in gc + + def test_adj(self): + for G, A in self.GA: + for n, nbrs in G.adj.items(): + a_adj = sorted((n, sorted(ad)) for n, ad in A.adj.items()) + g_adj = sorted((n, sorted(ad)) for n, ad in G.adj.items()) + assert a_adj == g_adj + + def test_adjacency(self): + for G, A in self.GA: + a_adj = list(A.adjacency()) + for n, nbrs in G.adjacency(): + assert (n, set(nbrs)) in a_adj + + def test_neighbors(self): + for G, A in self.GA: + node = list(G.nodes())[0] + assert set(G.neighbors(node)) == set(A.neighbors(node)) + + def test_node_not_in_graph(self): + for G, A in self.GA: + node = "non_existent_node" + pytest.raises(nx.NetworkXError, A.neighbors, node) + pytest.raises(nx.NetworkXError, G.neighbors, node) + + def test_degree_thingraph(self): + for G, A in self.GA: + node = list(G.nodes())[0] + nodes = list(G.nodes())[1:4] + assert G.degree(node) == A.degree(node) + assert sum(d for n, d in G.degree()) == sum(d for n, d in A.degree()) + # AntiGraph is a ThinGraph, so all the weights are 1 + assert sum(d for n, d in A.degree()) == sum( + d for n, d in A.degree(weight="weight") + ) + assert sum(d for n, d in G.degree(nodes)) == sum( + d for n, d in A.degree(nodes) + ) diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_matching.py b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..f50da3d2e07310fc19e1db2bd18fdce23223771c --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_matching.py @@ -0,0 +1,8 @@ +import networkx as nx +import networkx.algorithms.approximation as a + + +def test_min_maximal_matching(): + # smoke test + G = nx.Graph() + assert len(a.min_maximal_matching(G)) == 0 diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_maxcut.py b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_maxcut.py new file mode 100644 index 0000000000000000000000000000000000000000..ef0424401e4b2a7e6d580c762f23cea240e55b3c --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_maxcut.py @@ -0,0 +1,94 @@ +import random + +import pytest + +import networkx as nx +from networkx.algorithms.approximation import maxcut + + +@pytest.mark.parametrize( + "f", (nx.approximation.randomized_partitioning, nx.approximation.one_exchange) +) +@pytest.mark.parametrize("graph_constructor", (nx.DiGraph, nx.MultiGraph)) +def test_raises_on_directed_and_multigraphs(f, graph_constructor): + G = graph_constructor([(0, 1), (1, 2)]) + with pytest.raises(nx.NetworkXNotImplemented): + f(G) + + +def _is_valid_cut(G, set1, set2): + union = set1.union(set2) + assert union == set(G.nodes) + assert len(set1) + len(set2) == G.number_of_nodes() + + +def _cut_is_locally_optimal(G, cut_size, set1): + # test if cut can be locally improved + for i, node in enumerate(set1): + cut_size_without_node = nx.algorithms.cut_size( + G, set1 - {node}, weight="weight" + ) + assert cut_size_without_node <= cut_size + + +def test_random_partitioning(): + G = nx.complete_graph(5) + _, (set1, set2) = maxcut.randomized_partitioning(G, seed=5) + _is_valid_cut(G, set1, set2) + + +def test_random_partitioning_all_to_one(): + G = nx.complete_graph(5) + _, (set1, set2) = maxcut.randomized_partitioning(G, p=1) + _is_valid_cut(G, set1, set2) + assert len(set1) == G.number_of_nodes() + assert len(set2) == 0 + + +def test_one_exchange_basic(): + G = nx.complete_graph(5) + random.seed(5) + for u, v, w in G.edges(data=True): + w["weight"] = random.randrange(-100, 100, 1) / 10 + + initial_cut = set(random.sample(sorted(G.nodes()), k=5)) + cut_size, (set1, set2) = maxcut.one_exchange( + G, initial_cut, weight="weight", seed=5 + ) + + _is_valid_cut(G, set1, set2) + _cut_is_locally_optimal(G, cut_size, set1) + + +def test_one_exchange_optimal(): + # Greedy one exchange should find the optimal solution for this graph (14) + G = nx.Graph() + G.add_edge(1, 2, weight=3) + G.add_edge(1, 3, weight=3) + G.add_edge(1, 4, weight=3) + G.add_edge(1, 5, weight=3) + G.add_edge(2, 3, weight=5) + + cut_size, (set1, set2) = maxcut.one_exchange(G, weight="weight", seed=5) + + _is_valid_cut(G, set1, set2) + _cut_is_locally_optimal(G, cut_size, set1) + # check global optimality + assert cut_size == 14 + + +def test_negative_weights(): + G = nx.complete_graph(5) + random.seed(5) + for u, v, w in G.edges(data=True): + w["weight"] = -1 * random.random() + + initial_cut = set(random.sample(sorted(G.nodes()), k=5)) + cut_size, (set1, set2) = maxcut.one_exchange(G, initial_cut, weight="weight") + + # make sure it is a valid cut + _is_valid_cut(G, set1, set2) + # check local optimality + _cut_is_locally_optimal(G, cut_size, set1) + # test that all nodes are in the same partition + assert len(set1) == len(G.nodes) or len(set2) == len(G.nodes) diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_ramsey.py b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_ramsey.py new file mode 100644 index 0000000000000000000000000000000000000000..32fe1fb8fa917c557954d9da0d960895a6953a11 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_ramsey.py @@ -0,0 +1,31 @@ +import networkx as nx +import networkx.algorithms.approximation as apxa + + +def test_ramsey(): + # this should only find the complete graph + graph = nx.complete_graph(10) + c, i = apxa.ramsey_R2(graph) + cdens = nx.density(graph.subgraph(c)) + assert cdens == 1.0, "clique not correctly found by ramsey!" + idens = nx.density(graph.subgraph(i)) + assert idens == 0.0, "i-set not correctly found by ramsey!" + + # this trivial graph has no cliques. should just find i-sets + graph = nx.trivial_graph() + c, i = apxa.ramsey_R2(graph) + assert c == {0}, "clique not correctly found by ramsey!" + assert i == {0}, "i-set not correctly found by ramsey!" + + graph = nx.barbell_graph(10, 5, nx.Graph()) + c, i = apxa.ramsey_R2(graph) + cdens = nx.density(graph.subgraph(c)) + assert cdens == 1.0, "clique not correctly found by ramsey!" + idens = nx.density(graph.subgraph(i)) + assert idens == 0.0, "i-set not correctly found by ramsey!" + + # add self-loops and test again + graph.add_edges_from([(n, n) for n in range(0, len(graph), 2)]) + cc, ii = apxa.ramsey_R2(graph) + assert cc == c + assert ii == i diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_steinertree.py b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_steinertree.py new file mode 100644 index 0000000000000000000000000000000000000000..1b074757cd71f3c67037699ff0e6db84bdf7e119 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_steinertree.py @@ -0,0 +1,265 @@ +import pytest + +import networkx as nx +from networkx.algorithms.approximation.steinertree import ( + _remove_nonterminal_leaves, + metric_closure, + steiner_tree, +) +from networkx.utils import edges_equal + + +class TestSteinerTree: + @classmethod + def setup_class(cls): + G1 = nx.Graph() + G1.add_edge(1, 2, weight=10) + G1.add_edge(2, 3, weight=10) + G1.add_edge(3, 4, weight=10) + G1.add_edge(4, 5, weight=10) + G1.add_edge(5, 6, weight=10) + G1.add_edge(2, 7, weight=1) + G1.add_edge(7, 5, weight=1) + + G2 = nx.Graph() + G2.add_edge(0, 5, weight=6) + G2.add_edge(1, 2, weight=2) + G2.add_edge(1, 5, weight=3) + G2.add_edge(2, 4, weight=4) + G2.add_edge(3, 5, weight=5) + G2.add_edge(4, 5, weight=1) + + G3 = nx.Graph() + G3.add_edge(1, 2, weight=8) + G3.add_edge(1, 9, weight=3) + G3.add_edge(1, 8, weight=6) + G3.add_edge(1, 10, weight=2) + G3.add_edge(1, 14, weight=3) + G3.add_edge(2, 3, weight=6) + G3.add_edge(3, 4, weight=3) + G3.add_edge(3, 10, weight=2) + G3.add_edge(3, 11, weight=1) + G3.add_edge(4, 5, weight=1) + G3.add_edge(4, 11, weight=1) + G3.add_edge(5, 6, weight=4) + G3.add_edge(5, 11, weight=2) + G3.add_edge(5, 12, weight=1) + G3.add_edge(5, 13, weight=3) + G3.add_edge(6, 7, weight=2) + G3.add_edge(6, 12, weight=3) + G3.add_edge(6, 13, weight=1) + G3.add_edge(7, 8, weight=3) + G3.add_edge(7, 9, weight=3) + G3.add_edge(7, 11, weight=5) + G3.add_edge(7, 13, weight=2) + G3.add_edge(7, 14, weight=4) + G3.add_edge(8, 9, weight=2) + G3.add_edge(9, 14, weight=1) + G3.add_edge(10, 11, weight=2) + G3.add_edge(10, 14, weight=1) + G3.add_edge(11, 12, weight=1) + G3.add_edge(11, 14, weight=7) + G3.add_edge(12, 14, weight=3) + G3.add_edge(12, 15, weight=1) + G3.add_edge(13, 14, weight=4) + G3.add_edge(13, 15, weight=1) + G3.add_edge(14, 15, weight=2) + + cls.G1 = G1 + cls.G2 = G2 + cls.G3 = G3 + cls.G1_term_nodes = [1, 2, 3, 4, 5] + cls.G2_term_nodes = [0, 2, 3] + cls.G3_term_nodes = [1, 3, 5, 6, 8, 10, 11, 12, 13] + + cls.methods = ["kou", "mehlhorn"] + + def test_connected_metric_closure(self): + G = self.G1.copy() + G.add_node(100) + pytest.raises(nx.NetworkXError, metric_closure, G) + + def test_metric_closure(self): + M = metric_closure(self.G1) + mc = [ + (1, 2, {"distance": 10, "path": [1, 2]}), + (1, 3, {"distance": 20, "path": [1, 2, 3]}), + (1, 4, {"distance": 22, "path": [1, 2, 7, 5, 4]}), + (1, 5, {"distance": 12, "path": [1, 2, 7, 5]}), + (1, 6, {"distance": 22, "path": [1, 2, 7, 5, 6]}), + (1, 7, {"distance": 11, "path": [1, 2, 7]}), + (2, 3, {"distance": 10, "path": [2, 3]}), + (2, 4, {"distance": 12, "path": [2, 7, 5, 4]}), + (2, 5, {"distance": 2, "path": [2, 7, 5]}), + (2, 6, {"distance": 12, "path": [2, 7, 5, 6]}), + (2, 7, {"distance": 1, "path": [2, 7]}), + (3, 4, {"distance": 10, "path": [3, 4]}), + (3, 5, {"distance": 12, "path": [3, 2, 7, 5]}), + (3, 6, {"distance": 22, "path": [3, 2, 7, 5, 6]}), + (3, 7, {"distance": 11, "path": [3, 2, 7]}), + (4, 5, {"distance": 10, "path": [4, 5]}), + (4, 6, {"distance": 20, "path": [4, 5, 6]}), + (4, 7, {"distance": 11, "path": [4, 5, 7]}), + (5, 6, {"distance": 10, "path": [5, 6]}), + (5, 7, {"distance": 1, "path": [5, 7]}), + (6, 7, {"distance": 11, "path": [6, 5, 7]}), + ] + assert edges_equal(list(M.edges(data=True)), mc) + + def test_steiner_tree(self): + valid_steiner_trees = [ + [ + [ + (1, 2, {"weight": 10}), + (2, 3, {"weight": 10}), + (2, 7, {"weight": 1}), + (3, 4, {"weight": 10}), + (5, 7, {"weight": 1}), + ], + [ + (1, 2, {"weight": 10}), + (2, 7, {"weight": 1}), + (3, 4, {"weight": 10}), + (4, 5, {"weight": 10}), + (5, 7, {"weight": 1}), + ], + [ + (1, 2, {"weight": 10}), + (2, 3, {"weight": 10}), + (2, 7, {"weight": 1}), + (4, 5, {"weight": 10}), + (5, 7, {"weight": 1}), + ], + ], + [ + [ + (0, 5, {"weight": 6}), + (1, 2, {"weight": 2}), + (1, 5, {"weight": 3}), + (3, 5, {"weight": 5}), + ], + [ + (0, 5, {"weight": 6}), + (4, 2, {"weight": 4}), + (4, 5, {"weight": 1}), + (3, 5, {"weight": 5}), + ], + ], + [ + [ + (1, 10, {"weight": 2}), + (3, 10, {"weight": 2}), + (3, 11, {"weight": 1}), + (5, 12, {"weight": 1}), + (6, 13, {"weight": 1}), + (8, 9, {"weight": 2}), + (9, 14, {"weight": 1}), + (10, 14, {"weight": 1}), + (11, 12, {"weight": 1}), + (12, 15, {"weight": 1}), + (13, 15, {"weight": 1}), + ] + ], + ] + for method in self.methods: + for G, term_nodes, valid_trees in zip( + [self.G1, self.G2, self.G3], + [self.G1_term_nodes, self.G2_term_nodes, self.G3_term_nodes], + valid_steiner_trees, + ): + S = steiner_tree(G, term_nodes, method=method) + assert any( + edges_equal(list(S.edges(data=True)), valid_tree) + for valid_tree in valid_trees + ) + + def test_multigraph_steiner_tree(self): + G = nx.MultiGraph() + G.add_edges_from( + [ + (1, 2, 0, {"weight": 1}), + (2, 3, 0, {"weight": 999}), + (2, 3, 1, {"weight": 1}), + (3, 4, 0, {"weight": 1}), + (3, 5, 0, {"weight": 1}), + ] + ) + terminal_nodes = [2, 4, 5] + expected_edges = [ + (2, 3, 1, {"weight": 1}), # edge with key 1 has lower weight + (3, 4, 0, {"weight": 1}), + (3, 5, 0, {"weight": 1}), + ] + for method in self.methods: + S = steiner_tree(G, terminal_nodes, method=method) + assert edges_equal(S.edges(data=True, keys=True), expected_edges) + + def test_remove_nonterminal_leaves(self): + G = nx.path_graph(10) + _remove_nonterminal_leaves(G, [4, 5, 6]) + + assert list(G) == [4, 5, 6] # only the terminal nodes are left + + +@pytest.mark.parametrize("method", ("kou", "mehlhorn")) +def test_steiner_tree_weight_attribute(method): + G = nx.star_graph(4) + # Add an edge attribute that is named something other than "weight" + nx.set_edge_attributes(G, {e: 10 for e in G.edges}, name="distance") + H = nx.approximation.steiner_tree(G, [1, 3], method=method, weight="distance") + assert nx.utils.edges_equal(H.edges, [(0, 1), (0, 3)]) + + +@pytest.mark.parametrize("method", ("kou", "mehlhorn")) +def test_steiner_tree_multigraph_weight_attribute(method): + G = nx.cycle_graph(3, create_using=nx.MultiGraph) + nx.set_edge_attributes(G, {e: 10 for e in G.edges}, name="distance") + G.add_edge(2, 0, distance=5) + H = nx.approximation.steiner_tree(G, list(G), method=method, weight="distance") + assert len(H.edges) == 2 and H.has_edge(2, 0, key=1) + assert sum(dist for *_, dist in H.edges(data="distance")) == 15 + + +@pytest.mark.parametrize("method", (None, "mehlhorn", "kou")) +def test_steiner_tree_methods(method): + G = nx.star_graph(4) + expected = nx.Graph([(0, 1), (0, 3)]) + st = nx.approximation.steiner_tree(G, [1, 3], method=method) + assert nx.utils.edges_equal(st.edges, expected.edges) + + +def test_steiner_tree_method_invalid(): + G = nx.star_graph(4) + with pytest.raises( + ValueError, match="invalid_method is not a valid choice for an algorithm." + ): + nx.approximation.steiner_tree(G, terminal_nodes=[1, 3], method="invalid_method") + + +def test_steiner_tree_remove_non_terminal_leaves_self_loop_edges(): + # To verify that the last step of the steiner tree approximation + # behaves in the case where a non-terminal leaf has a self loop edge + G = nx.path_graph(10) + + # Add self loops to the terminal nodes + G.add_edges_from([(2, 2), (3, 3), (4, 4), (7, 7), (8, 8)]) + + # Remove non-terminal leaves + _remove_nonterminal_leaves(G, [4, 5, 6, 7]) + + # The terminal nodes should be left + assert list(G) == [4, 5, 6, 7] # only the terminal nodes are left + + +def test_steiner_tree_non_terminal_leaves_multigraph_self_loop_edges(): + # To verify that the last step of the steiner tree approximation + # behaves in the case where a non-terminal leaf has a self loop edge + G = nx.MultiGraph() + G.add_edges_from([(i, i + 1) for i in range(10)]) + G.add_edges_from([(2, 2), (3, 3), (4, 4), (4, 4), (7, 7)]) + + # Remove non-terminal leaves + _remove_nonterminal_leaves(G, [4, 5, 6, 7]) + + # Only the terminal nodes should be left + assert list(G) == [4, 5, 6, 7] diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_traveling_salesman.py b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_traveling_salesman.py new file mode 100644 index 0000000000000000000000000000000000000000..2084c19af44b3d2a95a2b09af7f77275b2c6b357 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_traveling_salesman.py @@ -0,0 +1,977 @@ +"""Unit tests for the traveling_salesman module.""" + +import random + +import pytest + +import networkx as nx +import networkx.algorithms.approximation as nx_app + +pairwise = nx.utils.pairwise + + +def test_christofides_hamiltonian(): + random.seed(42) + G = nx.complete_graph(20) + for u, v in G.edges(): + G[u][v]["weight"] = random.randint(0, 10) + + H = nx.Graph() + H.add_edges_from(pairwise(nx_app.christofides(G))) + H.remove_edges_from(nx.find_cycle(H)) + assert len(H.edges) == 0 + + tree = nx.minimum_spanning_tree(G, weight="weight") + H = nx.Graph() + H.add_edges_from(pairwise(nx_app.christofides(G, tree))) + H.remove_edges_from(nx.find_cycle(H)) + assert len(H.edges) == 0 + + +def test_christofides_incomplete_graph(): + G = nx.complete_graph(10) + G.remove_edge(0, 1) + pytest.raises(nx.NetworkXError, nx_app.christofides, G) + + +def test_christofides_ignore_selfloops(): + G = nx.complete_graph(5) + G.add_edge(3, 3) + cycle = nx_app.christofides(G) + assert len(cycle) - 1 == len(G) == len(set(cycle)) + + +# set up graphs for other tests +class TestBase: + @classmethod + def setup_class(cls): + cls.DG = nx.DiGraph() + cls.DG.add_weighted_edges_from( + { + ("A", "B", 3), + ("A", "C", 17), + ("A", "D", 14), + ("B", "A", 3), + ("B", "C", 12), + ("B", "D", 16), + ("C", "A", 13), + ("C", "B", 12), + ("C", "D", 4), + ("D", "A", 14), + ("D", "B", 15), + ("D", "C", 2), + } + ) + cls.DG_cycle = ["D", "C", "B", "A", "D"] + cls.DG_cost = 31.0 + + cls.DG2 = nx.DiGraph() + cls.DG2.add_weighted_edges_from( + { + ("A", "B", 3), + ("A", "C", 17), + ("A", "D", 14), + ("B", "A", 30), + ("B", "C", 2), + ("B", "D", 16), + ("C", "A", 33), + ("C", "B", 32), + ("C", "D", 34), + ("D", "A", 14), + ("D", "B", 15), + ("D", "C", 2), + } + ) + cls.DG2_cycle = ["D", "A", "B", "C", "D"] + cls.DG2_cost = 53.0 + + cls.unweightedUG = nx.complete_graph(5, nx.Graph()) + cls.unweightedDG = nx.complete_graph(5, nx.DiGraph()) + + cls.incompleteUG = nx.Graph() + cls.incompleteUG.add_weighted_edges_from({(0, 1, 1), (1, 2, 3)}) + cls.incompleteDG = nx.DiGraph() + cls.incompleteDG.add_weighted_edges_from({(0, 1, 1), (1, 2, 3)}) + + cls.UG = nx.Graph() + cls.UG.add_weighted_edges_from( + { + ("A", "B", 3), + ("A", "C", 17), + ("A", "D", 14), + ("B", "C", 12), + ("B", "D", 16), + ("C", "D", 4), + } + ) + cls.UG_cycle = ["D", "C", "B", "A", "D"] + cls.UG_cost = 33.0 + + cls.UG2 = nx.Graph() + cls.UG2.add_weighted_edges_from( + { + ("A", "B", 1), + ("A", "C", 15), + ("A", "D", 5), + ("B", "C", 16), + ("B", "D", 8), + ("C", "D", 3), + } + ) + cls.UG2_cycle = ["D", "C", "B", "A", "D"] + cls.UG2_cost = 25.0 + + +def validate_solution(soln, cost, exp_soln, exp_cost): + assert soln == exp_soln + assert cost == exp_cost + + +def validate_symmetric_solution(soln, cost, exp_soln, exp_cost): + assert soln == exp_soln or soln == exp_soln[::-1] + assert cost == exp_cost + + +class TestGreedyTSP(TestBase): + def test_greedy(self): + cycle = nx_app.greedy_tsp(self.DG, source="D") + cost = sum(self.DG[n][nbr]["weight"] for n, nbr in pairwise(cycle)) + validate_solution(cycle, cost, ["D", "C", "B", "A", "D"], 31.0) + + cycle = nx_app.greedy_tsp(self.DG2, source="D") + cost = sum(self.DG2[n][nbr]["weight"] for n, nbr in pairwise(cycle)) + validate_solution(cycle, cost, ["D", "C", "B", "A", "D"], 78.0) + + cycle = nx_app.greedy_tsp(self.UG, source="D") + cost = sum(self.UG[n][nbr]["weight"] for n, nbr in pairwise(cycle)) + validate_solution(cycle, cost, ["D", "C", "B", "A", "D"], 33.0) + + cycle = nx_app.greedy_tsp(self.UG2, source="D") + cost = sum(self.UG2[n][nbr]["weight"] for n, nbr in pairwise(cycle)) + validate_solution(cycle, cost, ["D", "C", "A", "B", "D"], 27.0) + + def test_not_complete_graph(self): + pytest.raises(nx.NetworkXError, nx_app.greedy_tsp, self.incompleteUG) + pytest.raises(nx.NetworkXError, nx_app.greedy_tsp, self.incompleteDG) + + def test_not_weighted_graph(self): + nx_app.greedy_tsp(self.unweightedUG) + nx_app.greedy_tsp(self.unweightedDG) + + def test_two_nodes(self): + G = nx.Graph() + G.add_weighted_edges_from({(1, 2, 1)}) + cycle = nx_app.greedy_tsp(G) + cost = sum(G[n][nbr]["weight"] for n, nbr in pairwise(cycle)) + validate_solution(cycle, cost, [1, 2, 1], 2) + + def test_ignore_selfloops(self): + G = nx.complete_graph(5) + G.add_edge(3, 3) + cycle = nx_app.greedy_tsp(G) + assert len(cycle) - 1 == len(G) == len(set(cycle)) + + +class TestSimulatedAnnealingTSP(TestBase): + tsp = staticmethod(nx_app.simulated_annealing_tsp) + + def test_simulated_annealing_directed(self): + cycle = self.tsp(self.DG, "greedy", source="D", seed=42) + cost = sum(self.DG[n][nbr]["weight"] for n, nbr in pairwise(cycle)) + validate_solution(cycle, cost, self.DG_cycle, self.DG_cost) + + initial_sol = ["D", "B", "A", "C", "D"] + cycle = self.tsp(self.DG, initial_sol, source="D", seed=42) + cost = sum(self.DG[n][nbr]["weight"] for n, nbr in pairwise(cycle)) + validate_solution(cycle, cost, self.DG_cycle, self.DG_cost) + + initial_sol = ["D", "A", "C", "B", "D"] + cycle = self.tsp(self.DG, initial_sol, move="1-0", source="D", seed=42) + cost = sum(self.DG[n][nbr]["weight"] for n, nbr in pairwise(cycle)) + validate_solution(cycle, cost, self.DG_cycle, self.DG_cost) + + cycle = self.tsp(self.DG2, "greedy", source="D", seed=42) + cost = sum(self.DG2[n][nbr]["weight"] for n, nbr in pairwise(cycle)) + validate_solution(cycle, cost, self.DG2_cycle, self.DG2_cost) + + cycle = self.tsp(self.DG2, "greedy", move="1-0", source="D", seed=42) + cost = sum(self.DG2[n][nbr]["weight"] for n, nbr in pairwise(cycle)) + validate_solution(cycle, cost, self.DG2_cycle, self.DG2_cost) + + def test_simulated_annealing_undirected(self): + cycle = self.tsp(self.UG, "greedy", source="D", seed=42) + cost = sum(self.UG[n][nbr]["weight"] for n, nbr in pairwise(cycle)) + validate_solution(cycle, cost, self.UG_cycle, self.UG_cost) + + cycle = self.tsp(self.UG2, "greedy", source="D", seed=42) + cost = sum(self.UG2[n][nbr]["weight"] for n, nbr in pairwise(cycle)) + validate_symmetric_solution(cycle, cost, self.UG2_cycle, self.UG2_cost) + + cycle = self.tsp(self.UG2, "greedy", move="1-0", source="D", seed=42) + cost = sum(self.UG2[n][nbr]["weight"] for n, nbr in pairwise(cycle)) + validate_symmetric_solution(cycle, cost, self.UG2_cycle, self.UG2_cost) + + def test_error_on_input_order_mistake(self): + # see issue #4846 https://github.com/networkx/networkx/issues/4846 + pytest.raises(TypeError, self.tsp, self.UG, weight="weight") + pytest.raises(nx.NetworkXError, self.tsp, self.UG, "weight") + + def test_not_complete_graph(self): + pytest.raises(nx.NetworkXError, self.tsp, self.incompleteUG, "greedy", source=0) + pytest.raises(nx.NetworkXError, self.tsp, self.incompleteDG, "greedy", source=0) + + def test_ignore_selfloops(self): + G = nx.complete_graph(5) + G.add_edge(3, 3) + cycle = self.tsp(G, "greedy") + assert len(cycle) - 1 == len(G) == len(set(cycle)) + + def test_not_weighted_graph(self): + self.tsp(self.unweightedUG, "greedy") + self.tsp(self.unweightedDG, "greedy") + + def test_two_nodes(self): + G = nx.Graph() + G.add_weighted_edges_from({(1, 2, 1)}) + + cycle = self.tsp(G, "greedy", source=1, seed=42) + cost = sum(G[n][nbr]["weight"] for n, nbr in pairwise(cycle)) + validate_solution(cycle, cost, [1, 2, 1], 2) + + cycle = self.tsp(G, [1, 2, 1], source=1, seed=42) + cost = sum(G[n][nbr]["weight"] for n, nbr in pairwise(cycle)) + validate_solution(cycle, cost, [1, 2, 1], 2) + + def test_failure_of_costs_too_high_when_iterations_low(self): + # Simulated Annealing Version: + # set number of moves low and alpha high + cycle = self.tsp( + self.DG2, "greedy", source="D", move="1-0", alpha=1, N_inner=1, seed=42 + ) + cost = sum(self.DG2[n][nbr]["weight"] for n, nbr in pairwise(cycle)) + print(cycle, cost) + assert cost > self.DG2_cost + + # Try with an incorrect initial guess + initial_sol = ["D", "A", "B", "C", "D"] + cycle = self.tsp( + self.DG, + initial_sol, + source="D", + move="1-0", + alpha=0.1, + N_inner=1, + max_iterations=1, + seed=42, + ) + cost = sum(self.DG[n][nbr]["weight"] for n, nbr in pairwise(cycle)) + print(cycle, cost) + assert cost > self.DG_cost + + +class TestThresholdAcceptingTSP(TestSimulatedAnnealingTSP): + tsp = staticmethod(nx_app.threshold_accepting_tsp) + + def test_failure_of_costs_too_high_when_iterations_low(self): + # Threshold Version: + # set number of moves low and number of iterations low + cycle = self.tsp( + self.DG2, + "greedy", + source="D", + move="1-0", + N_inner=1, + max_iterations=1, + seed=4, + ) + cost = sum(self.DG2[n][nbr]["weight"] for n, nbr in pairwise(cycle)) + assert cost > self.DG2_cost + + # set threshold too low + initial_sol = ["D", "A", "B", "C", "D"] + cycle = self.tsp( + self.DG, initial_sol, source="D", move="1-0", threshold=-3, seed=42 + ) + cost = sum(self.DG[n][nbr]["weight"] for n, nbr in pairwise(cycle)) + assert cost > self.DG_cost + + +# Tests for function traveling_salesman_problem +def test_TSP_method(): + G = nx.cycle_graph(9) + G[4][5]["weight"] = 10 + + # Test using the old currying method + sa_tsp = lambda G, weight: nx_app.simulated_annealing_tsp( + G, "greedy", weight, source=4, seed=1 + ) + + path = nx_app.traveling_salesman_problem( + G, + method=sa_tsp, + cycle=False, + ) + print(path) + assert path == [4, 3, 2, 1, 0, 8, 7, 6, 5] + + +def test_TSP_unweighted(): + G = nx.cycle_graph(9) + path = nx_app.traveling_salesman_problem(G, nodes=[3, 6], cycle=False) + assert path in ([3, 4, 5, 6], [6, 5, 4, 3]) + + cycle = nx_app.traveling_salesman_problem(G, nodes=[3, 6]) + assert cycle in ([3, 4, 5, 6, 5, 4, 3], [6, 5, 4, 3, 4, 5, 6]) + + +def test_TSP_weighted(): + G = nx.cycle_graph(9) + G[0][1]["weight"] = 2 + G[1][2]["weight"] = 2 + G[2][3]["weight"] = 2 + G[3][4]["weight"] = 4 + G[4][5]["weight"] = 5 + G[5][6]["weight"] = 4 + G[6][7]["weight"] = 2 + G[7][8]["weight"] = 2 + G[8][0]["weight"] = 2 + tsp = nx_app.traveling_salesman_problem + + # path between 3 and 6 + expected_paths = ([3, 2, 1, 0, 8, 7, 6], [6, 7, 8, 0, 1, 2, 3]) + # cycle between 3 and 6 + expected_cycles = ( + [3, 2, 1, 0, 8, 7, 6, 7, 8, 0, 1, 2, 3], + [6, 7, 8, 0, 1, 2, 3, 2, 1, 0, 8, 7, 6], + ) + # path through all nodes + expected_tourpaths = ([5, 6, 7, 8, 0, 1, 2, 3, 4], [4, 3, 2, 1, 0, 8, 7, 6, 5]) + + # Check default method + cycle = tsp(G, nodes=[3, 6], weight="weight") + assert cycle in expected_cycles + + path = tsp(G, nodes=[3, 6], weight="weight", cycle=False) + assert path in expected_paths + + tourpath = tsp(G, weight="weight", cycle=False) + assert tourpath in expected_tourpaths + + # Check all methods + methods = [ + (nx_app.christofides, {}), + (nx_app.greedy_tsp, {}), + ( + nx_app.simulated_annealing_tsp, + {"init_cycle": "greedy"}, + ), + ( + nx_app.threshold_accepting_tsp, + {"init_cycle": "greedy"}, + ), + ] + for method, kwargs in methods: + cycle = tsp(G, nodes=[3, 6], weight="weight", method=method, **kwargs) + assert cycle in expected_cycles + + path = tsp( + G, nodes=[3, 6], weight="weight", method=method, cycle=False, **kwargs + ) + assert path in expected_paths + + tourpath = tsp(G, weight="weight", method=method, cycle=False, **kwargs) + assert tourpath in expected_tourpaths + + +def test_TSP_incomplete_graph_short_path(): + G = nx.cycle_graph(9) + G.add_edges_from([(4, 9), (9, 10), (10, 11), (11, 0)]) + G[4][5]["weight"] = 5 + + cycle = nx_app.traveling_salesman_problem(G) + print(cycle) + assert len(cycle) == 17 and len(set(cycle)) == 12 + + # make sure that cutting one edge out of complete graph formulation + # cuts out many edges out of the path of the TSP + path = nx_app.traveling_salesman_problem(G, cycle=False) + print(path) + assert len(path) == 13 and len(set(path)) == 12 + + +def test_held_karp_ascent(): + """ + Test the Held-Karp relaxation with the ascent method + """ + import networkx.algorithms.approximation.traveling_salesman as tsp + + np = pytest.importorskip("numpy") + pytest.importorskip("scipy") + + # Adjacency matrix from page 1153 of the 1970 Held and Karp paper + # which have been edited to be directional, but also symmetric + G_array = np.array( + [ + [0, 97, 60, 73, 17, 52], + [97, 0, 41, 52, 90, 30], + [60, 41, 0, 21, 35, 41], + [73, 52, 21, 0, 95, 46], + [17, 90, 35, 95, 0, 81], + [52, 30, 41, 46, 81, 0], + ] + ) + + solution_edges = [(1, 3), (2, 4), (3, 2), (4, 0), (5, 1), (0, 5)] + + G = nx.from_numpy_array(G_array, create_using=nx.DiGraph) + opt_hk, z_star = tsp.held_karp_ascent(G) + + # Check that the optimal weights are the same + assert round(opt_hk, 2) == 207.00 + # Check that the z_stars are the same + solution = nx.DiGraph() + solution.add_edges_from(solution_edges) + assert nx.utils.edges_equal(z_star.edges, solution.edges) + + +def test_ascent_fractional_solution(): + """ + Test the ascent method using a modified version of Figure 2 on page 1140 + in 'The Traveling Salesman Problem and Minimum Spanning Trees' by Held and + Karp + """ + import networkx.algorithms.approximation.traveling_salesman as tsp + + np = pytest.importorskip("numpy") + pytest.importorskip("scipy") + + # This version of Figure 2 has all of the edge weights multiplied by 100 + # and is a complete directed graph with infinite edge weights for the + # edges not listed in the original graph + G_array = np.array( + [ + [0, 100, 100, 100000, 100000, 1], + [100, 0, 100, 100000, 1, 100000], + [100, 100, 0, 1, 100000, 100000], + [100000, 100000, 1, 0, 100, 100], + [100000, 1, 100000, 100, 0, 100], + [1, 100000, 100000, 100, 100, 0], + ] + ) + + solution_z_star = { + (0, 1): 5 / 12, + (0, 2): 5 / 12, + (0, 5): 5 / 6, + (1, 0): 5 / 12, + (1, 2): 1 / 3, + (1, 4): 5 / 6, + (2, 0): 5 / 12, + (2, 1): 1 / 3, + (2, 3): 5 / 6, + (3, 2): 5 / 6, + (3, 4): 1 / 3, + (3, 5): 1 / 2, + (4, 1): 5 / 6, + (4, 3): 1 / 3, + (4, 5): 1 / 2, + (5, 0): 5 / 6, + (5, 3): 1 / 2, + (5, 4): 1 / 2, + } + + G = nx.from_numpy_array(G_array, create_using=nx.DiGraph) + opt_hk, z_star = tsp.held_karp_ascent(G) + + # Check that the optimal weights are the same + assert round(opt_hk, 2) == 303.00 + # Check that the z_stars are the same + assert {key: round(z_star[key], 4) for key in z_star} == { + key: round(solution_z_star[key], 4) for key in solution_z_star + } + + +def test_ascent_method_asymmetric(): + """ + Tests the ascent method using a truly asymmetric graph for which the + solution has been brute forced + """ + import networkx.algorithms.approximation.traveling_salesman as tsp + + np = pytest.importorskip("numpy") + pytest.importorskip("scipy") + + G_array = np.array( + [ + [0, 26, 63, 59, 69, 31, 41], + [62, 0, 91, 53, 75, 87, 47], + [47, 82, 0, 90, 15, 9, 18], + [68, 19, 5, 0, 58, 34, 93], + [11, 58, 53, 55, 0, 61, 79], + [88, 75, 13, 76, 98, 0, 40], + [41, 61, 55, 88, 46, 45, 0], + ] + ) + + solution_edges = [(0, 1), (1, 3), (3, 2), (2, 5), (5, 6), (4, 0), (6, 4)] + + G = nx.from_numpy_array(G_array, create_using=nx.DiGraph) + opt_hk, z_star = tsp.held_karp_ascent(G) + + # Check that the optimal weights are the same + assert round(opt_hk, 2) == 190.00 + # Check that the z_stars match. + solution = nx.DiGraph() + solution.add_edges_from(solution_edges) + assert nx.utils.edges_equal(z_star.edges, solution.edges) + + +def test_ascent_method_asymmetric_2(): + """ + Tests the ascent method using a truly asymmetric graph for which the + solution has been brute forced + """ + import networkx.algorithms.approximation.traveling_salesman as tsp + + np = pytest.importorskip("numpy") + pytest.importorskip("scipy") + + G_array = np.array( + [ + [0, 45, 39, 92, 29, 31], + [72, 0, 4, 12, 21, 60], + [81, 6, 0, 98, 70, 53], + [49, 71, 59, 0, 98, 94], + [74, 95, 24, 43, 0, 47], + [56, 43, 3, 65, 22, 0], + ] + ) + + solution_edges = [(0, 5), (5, 4), (1, 3), (3, 0), (2, 1), (4, 2)] + + G = nx.from_numpy_array(G_array, create_using=nx.DiGraph) + opt_hk, z_star = tsp.held_karp_ascent(G) + + # Check that the optimal weights are the same + assert round(opt_hk, 2) == 144.00 + # Check that the z_stars match. + solution = nx.DiGraph() + solution.add_edges_from(solution_edges) + assert nx.utils.edges_equal(z_star.edges, solution.edges) + + +def test_held_karp_ascent_asymmetric_3(): + """ + Tests the ascent method using a truly asymmetric graph with a fractional + solution for which the solution has been brute forced. + + In this graph their are two different optimal, integral solutions (which + are also the overall atsp solutions) to the Held Karp relaxation. However, + this particular graph has two different tours of optimal value and the + possible solutions in the held_karp_ascent function are not stored in an + ordered data structure. + """ + import networkx.algorithms.approximation.traveling_salesman as tsp + + np = pytest.importorskip("numpy") + pytest.importorskip("scipy") + + G_array = np.array( + [ + [0, 1, 5, 2, 7, 4], + [7, 0, 7, 7, 1, 4], + [4, 7, 0, 9, 2, 1], + [7, 2, 7, 0, 4, 4], + [5, 5, 4, 4, 0, 3], + [3, 9, 1, 3, 4, 0], + ] + ) + + solution1_edges = [(0, 3), (1, 4), (2, 5), (3, 1), (4, 2), (5, 0)] + + solution2_edges = [(0, 3), (3, 1), (1, 4), (4, 5), (2, 0), (5, 2)] + + G = nx.from_numpy_array(G_array, create_using=nx.DiGraph) + opt_hk, z_star = tsp.held_karp_ascent(G) + + assert round(opt_hk, 2) == 13.00 + # Check that the z_stars are the same + solution1 = nx.DiGraph() + solution1.add_edges_from(solution1_edges) + solution2 = nx.DiGraph() + solution2.add_edges_from(solution2_edges) + assert nx.utils.edges_equal(z_star.edges, solution1.edges) or nx.utils.edges_equal( + z_star.edges, solution2.edges + ) + + +def test_held_karp_ascent_fractional_asymmetric(): + """ + Tests the ascent method using a truly asymmetric graph with a fractional + solution for which the solution has been brute forced + """ + import networkx.algorithms.approximation.traveling_salesman as tsp + + np = pytest.importorskip("numpy") + pytest.importorskip("scipy") + + G_array = np.array( + [ + [0, 100, 150, 100000, 100000, 1], + [150, 0, 100, 100000, 1, 100000], + [100, 150, 0, 1, 100000, 100000], + [100000, 100000, 1, 0, 150, 100], + [100000, 2, 100000, 100, 0, 150], + [2, 100000, 100000, 150, 100, 0], + ] + ) + + solution_z_star = { + (0, 1): 5 / 12, + (0, 2): 5 / 12, + (0, 5): 5 / 6, + (1, 0): 5 / 12, + (1, 2): 5 / 12, + (1, 4): 5 / 6, + (2, 0): 5 / 12, + (2, 1): 5 / 12, + (2, 3): 5 / 6, + (3, 2): 5 / 6, + (3, 4): 5 / 12, + (3, 5): 5 / 12, + (4, 1): 5 / 6, + (4, 3): 5 / 12, + (4, 5): 5 / 12, + (5, 0): 5 / 6, + (5, 3): 5 / 12, + (5, 4): 5 / 12, + } + + G = nx.from_numpy_array(G_array, create_using=nx.DiGraph) + opt_hk, z_star = tsp.held_karp_ascent(G) + + # Check that the optimal weights are the same + assert round(opt_hk, 2) == 304.00 + # Check that the z_stars are the same + assert {key: round(z_star[key], 4) for key in z_star} == { + key: round(solution_z_star[key], 4) for key in solution_z_star + } + + +def test_spanning_tree_distribution(): + """ + Test that we can create an exponential distribution of spanning trees such + that the probability of each tree is proportional to the product of edge + weights. + + Results of this test have been confirmed with hypothesis testing from the + created distribution. + + This test uses the symmetric, fractional Held Karp solution. + """ + import networkx.algorithms.approximation.traveling_salesman as tsp + + pytest.importorskip("numpy") + pytest.importorskip("scipy") + + z_star = { + (0, 1): 5 / 12, + (0, 2): 5 / 12, + (0, 5): 5 / 6, + (1, 0): 5 / 12, + (1, 2): 1 / 3, + (1, 4): 5 / 6, + (2, 0): 5 / 12, + (2, 1): 1 / 3, + (2, 3): 5 / 6, + (3, 2): 5 / 6, + (3, 4): 1 / 3, + (3, 5): 1 / 2, + (4, 1): 5 / 6, + (4, 3): 1 / 3, + (4, 5): 1 / 2, + (5, 0): 5 / 6, + (5, 3): 1 / 2, + (5, 4): 1 / 2, + } + + solution_gamma = { + (0, 1): -0.6383, + (0, 2): -0.6827, + (0, 5): 0, + (1, 2): -1.0781, + (1, 4): 0, + (2, 3): 0, + (5, 3): -0.2820, + (5, 4): -0.3327, + (4, 3): -0.9927, + } + + # The undirected support of z_star + G = nx.MultiGraph() + for u, v in z_star: + if (u, v) in G.edges or (v, u) in G.edges: + continue + G.add_edge(u, v) + + gamma = tsp.spanning_tree_distribution(G, z_star) + + assert {key: round(gamma[key], 4) for key in gamma} == solution_gamma + + +def test_asadpour_tsp(): + """ + Test the complete asadpour tsp algorithm with the fractional, symmetric + Held Karp solution. This test also uses an incomplete graph as input. + """ + # This version of Figure 2 has all of the edge weights multiplied by 100 + # and the 0 weight edges have a weight of 1. + pytest.importorskip("numpy") + pytest.importorskip("scipy") + + edge_list = [ + (0, 1, 100), + (0, 2, 100), + (0, 5, 1), + (1, 2, 100), + (1, 4, 1), + (2, 3, 1), + (3, 4, 100), + (3, 5, 100), + (4, 5, 100), + (1, 0, 100), + (2, 0, 100), + (5, 0, 1), + (2, 1, 100), + (4, 1, 1), + (3, 2, 1), + (4, 3, 100), + (5, 3, 100), + (5, 4, 100), + ] + + G = nx.DiGraph() + G.add_weighted_edges_from(edge_list) + + tour = nx_app.traveling_salesman_problem( + G, weight="weight", method=nx_app.asadpour_atsp, seed=19 + ) + + # Check that the returned list is a valid tour. Because this is an + # incomplete graph, the conditions are not as strict. We need the tour to + # + # Start and end at the same node + # Pass through every vertex at least once + # Have a total cost at most ln(6) / ln(ln(6)) = 3.0723 times the optimal + # + # For the second condition it is possible to have the tour pass through the + # same vertex more then. Imagine that the tour on the complete version takes + # an edge not in the original graph. In the output this is substituted with + # the shortest path between those vertices, allowing vertices to appear more + # than once. + # + # Even though we are using a fixed seed, multiple tours have been known to + # be returned. The first two are from the original development of this test, + # and the third one from issue #5913 on GitHub. If other tours are returned, + # add it on the list of expected tours. + expected_tours = [ + [1, 4, 5, 0, 2, 3, 2, 1], + [3, 2, 0, 1, 4, 5, 3], + [3, 2, 1, 0, 5, 4, 3], + ] + + assert tour in expected_tours + + +def test_asadpour_real_world(): + """ + This test uses airline prices between the six largest cities in the US. + + * New York City -> JFK + * Los Angeles -> LAX + * Chicago -> ORD + * Houston -> IAH + * Phoenix -> PHX + * Philadelphia -> PHL + + Flight prices from August 2021 using Delta or American airlines to get + nonstop flight. The brute force solution found the optimal tour to cost $872 + + This test also uses the `source` keyword argument to ensure that the tour + always starts at city 0. + """ + np = pytest.importorskip("numpy") + pytest.importorskip("scipy") + + G_array = np.array( + [ + # JFK LAX ORD IAH PHX PHL + [0, 243, 199, 208, 169, 183], # JFK + [277, 0, 217, 123, 127, 252], # LAX + [297, 197, 0, 197, 123, 177], # ORD + [303, 169, 197, 0, 117, 117], # IAH + [257, 127, 160, 117, 0, 319], # PHX + [183, 332, 217, 117, 319, 0], # PHL + ] + ) + + node_list = ["JFK", "LAX", "ORD", "IAH", "PHX", "PHL"] + + expected_tours = [ + ["JFK", "LAX", "PHX", "ORD", "IAH", "PHL", "JFK"], + ["JFK", "ORD", "PHX", "LAX", "IAH", "PHL", "JFK"], + ] + + G = nx.from_numpy_array(G_array, nodelist=node_list, create_using=nx.DiGraph) + + tour = nx_app.traveling_salesman_problem( + G, weight="weight", method=nx_app.asadpour_atsp, seed=37, source="JFK" + ) + + assert tour in expected_tours + + +def test_asadpour_real_world_path(): + """ + This test uses airline prices between the six largest cities in the US. This + time using a path, not a cycle. + + * New York City -> JFK + * Los Angeles -> LAX + * Chicago -> ORD + * Houston -> IAH + * Phoenix -> PHX + * Philadelphia -> PHL + + Flight prices from August 2021 using Delta or American airlines to get + nonstop flight. The brute force solution found the optimal tour to cost $872 + """ + np = pytest.importorskip("numpy") + pytest.importorskip("scipy") + + G_array = np.array( + [ + # JFK LAX ORD IAH PHX PHL + [0, 243, 199, 208, 169, 183], # JFK + [277, 0, 217, 123, 127, 252], # LAX + [297, 197, 0, 197, 123, 177], # ORD + [303, 169, 197, 0, 117, 117], # IAH + [257, 127, 160, 117, 0, 319], # PHX + [183, 332, 217, 117, 319, 0], # PHL + ] + ) + + node_list = ["JFK", "LAX", "ORD", "IAH", "PHX", "PHL"] + + expected_paths = [ + ["ORD", "PHX", "LAX", "IAH", "PHL", "JFK"], + ["JFK", "PHL", "IAH", "ORD", "PHX", "LAX"], + ] + + G = nx.from_numpy_array(G_array, nodelist=node_list, create_using=nx.DiGraph) + + path = nx_app.traveling_salesman_problem( + G, weight="weight", cycle=False, method=nx_app.asadpour_atsp, seed=56 + ) + + assert path in expected_paths + + +def test_asadpour_disconnected_graph(): + """ + Test that the proper exception is raised when asadpour_atsp is given an + disconnected graph. + """ + + G = nx.complete_graph(4, create_using=nx.DiGraph) + # have to set edge weights so that if the exception is not raised, the + # function will complete and we will fail the test + nx.set_edge_attributes(G, 1, "weight") + G.add_node(5) + + pytest.raises(nx.NetworkXError, nx_app.asadpour_atsp, G) + + +def test_asadpour_incomplete_graph(): + """ + Test that the proper exception is raised when asadpour_atsp is given an + incomplete graph + """ + + G = nx.complete_graph(4, create_using=nx.DiGraph) + # have to set edge weights so that if the exception is not raised, the + # function will complete and we will fail the test + nx.set_edge_attributes(G, 1, "weight") + G.remove_edge(0, 1) + + pytest.raises(nx.NetworkXError, nx_app.asadpour_atsp, G) + + +def test_asadpour_empty_graph(): + """ + Test the asadpour_atsp function with an empty graph + """ + G = nx.DiGraph() + + pytest.raises(nx.NetworkXError, nx_app.asadpour_atsp, G) + + +@pytest.mark.slow +def test_asadpour_integral_held_karp(): + """ + This test uses an integral held karp solution and the held karp function + will return a graph rather than a dict, bypassing most of the asadpour + algorithm. + + At first glance, this test probably doesn't look like it ensures that we + skip the rest of the asadpour algorithm, but it does. We are not fixing a + see for the random number generator, so if we sample any spanning trees + the approximation would be different basically every time this test is + executed but it is not since held karp is deterministic and we do not + reach the portion of the code with the dependence on random numbers. + """ + np = pytest.importorskip("numpy") + + G_array = np.array( + [ + [0, 26, 63, 59, 69, 31, 41], + [62, 0, 91, 53, 75, 87, 47], + [47, 82, 0, 90, 15, 9, 18], + [68, 19, 5, 0, 58, 34, 93], + [11, 58, 53, 55, 0, 61, 79], + [88, 75, 13, 76, 98, 0, 40], + [41, 61, 55, 88, 46, 45, 0], + ] + ) + + G = nx.from_numpy_array(G_array, create_using=nx.DiGraph) + + for _ in range(2): + tour = nx_app.traveling_salesman_problem(G, method=nx_app.asadpour_atsp) + + assert [1, 3, 2, 5, 2, 6, 4, 0, 1] == tour + + +def test_directed_tsp_impossible(): + """ + Test the asadpour algorithm with a graph without a hamiltonian circuit + """ + pytest.importorskip("numpy") + + # In this graph, once we leave node 0 we cannot return + edges = [ + (0, 1, 10), + (0, 2, 11), + (0, 3, 12), + (1, 2, 4), + (1, 3, 6), + (2, 1, 3), + (2, 3, 2), + (3, 1, 5), + (3, 2, 1), + ] + + G = nx.DiGraph() + G.add_weighted_edges_from(edges) + + pytest.raises(nx.NetworkXError, nx_app.traveling_salesman_problem, G) diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_treewidth.py b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_treewidth.py new file mode 100644 index 0000000000000000000000000000000000000000..461b0f2ed2dd4d043902d054e10a5f39ffb069c9 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_treewidth.py @@ -0,0 +1,280 @@ +import itertools + +import networkx as nx +from networkx.algorithms.approximation import ( + treewidth_min_degree, + treewidth_min_fill_in, +) +from networkx.algorithms.approximation.treewidth import ( + MinDegreeHeuristic, + min_fill_in_heuristic, +) + + +def is_tree_decomp(graph, decomp): + """Check if the given tree decomposition is valid.""" + for x in graph.nodes(): + appear_once = False + for bag in decomp.nodes(): + if x in bag: + appear_once = True + break + assert appear_once + + # Check if each connected pair of nodes are at least once together in a bag + for x, y in graph.edges(): + appear_together = False + for bag in decomp.nodes(): + if x in bag and y in bag: + appear_together = True + break + assert appear_together + + # Check if the nodes associated with vertex v form a connected subset of T + for v in graph.nodes(): + subset = [] + for bag in decomp.nodes(): + if v in bag: + subset.append(bag) + sub_graph = decomp.subgraph(subset) + assert nx.is_connected(sub_graph) + + +class TestTreewidthMinDegree: + """Unit tests for the min_degree function""" + + @classmethod + def setup_class(cls): + """Setup for different kinds of trees""" + cls.complete = nx.Graph() + cls.complete.add_edge(1, 2) + cls.complete.add_edge(2, 3) + cls.complete.add_edge(1, 3) + + cls.small_tree = nx.Graph() + cls.small_tree.add_edge(1, 3) + cls.small_tree.add_edge(4, 3) + cls.small_tree.add_edge(2, 3) + cls.small_tree.add_edge(3, 5) + cls.small_tree.add_edge(5, 6) + cls.small_tree.add_edge(5, 7) + cls.small_tree.add_edge(6, 7) + + cls.deterministic_graph = nx.Graph() + cls.deterministic_graph.add_edge(0, 1) # deg(0) = 1 + + cls.deterministic_graph.add_edge(1, 2) # deg(1) = 2 + + cls.deterministic_graph.add_edge(2, 3) + cls.deterministic_graph.add_edge(2, 4) # deg(2) = 3 + + cls.deterministic_graph.add_edge(3, 4) + cls.deterministic_graph.add_edge(3, 5) + cls.deterministic_graph.add_edge(3, 6) # deg(3) = 4 + + cls.deterministic_graph.add_edge(4, 5) + cls.deterministic_graph.add_edge(4, 6) + cls.deterministic_graph.add_edge(4, 7) # deg(4) = 5 + + cls.deterministic_graph.add_edge(5, 6) + cls.deterministic_graph.add_edge(5, 7) + cls.deterministic_graph.add_edge(5, 8) + cls.deterministic_graph.add_edge(5, 9) # deg(5) = 6 + + cls.deterministic_graph.add_edge(6, 7) + cls.deterministic_graph.add_edge(6, 8) + cls.deterministic_graph.add_edge(6, 9) # deg(6) = 6 + + cls.deterministic_graph.add_edge(7, 8) + cls.deterministic_graph.add_edge(7, 9) # deg(7) = 5 + + cls.deterministic_graph.add_edge(8, 9) # deg(8) = 4 + + def test_petersen_graph(self): + """Test Petersen graph tree decomposition result""" + G = nx.petersen_graph() + _, decomp = treewidth_min_degree(G) + is_tree_decomp(G, decomp) + + def test_small_tree_treewidth(self): + """Test small tree + + Test if the computed treewidth of the known self.small_tree is 2. + As we know which value we can expect from our heuristic, values other + than two are regressions + """ + G = self.small_tree + # the order of removal should be [1,2,4]3[5,6,7] + # (with [] denoting any order of the containing nodes) + # resulting in treewidth 2 for the heuristic + treewidth, _ = treewidth_min_fill_in(G) + assert treewidth == 2 + + def test_heuristic_abort(self): + """Test heuristic abort condition for fully connected graph""" + graph = {} + for u in self.complete: + graph[u] = set() + for v in self.complete[u]: + if u != v: # ignore self-loop + graph[u].add(v) + + deg_heuristic = MinDegreeHeuristic(graph) + node = deg_heuristic.best_node(graph) + if node is None: + pass + else: + assert False + + def test_empty_graph(self): + """Test empty graph""" + G = nx.Graph() + _, _ = treewidth_min_degree(G) + + def test_two_component_graph(self): + G = nx.Graph() + G.add_node(1) + G.add_node(2) + treewidth, _ = treewidth_min_degree(G) + assert treewidth == 0 + + def test_not_sortable_nodes(self): + G = nx.Graph([(0, "a")]) + treewidth_min_degree(G) + + def test_heuristic_first_steps(self): + """Test first steps of min_degree heuristic""" + graph = { + n: set(self.deterministic_graph[n]) - {n} for n in self.deterministic_graph + } + deg_heuristic = MinDegreeHeuristic(graph) + elim_node = deg_heuristic.best_node(graph) + print(f"Graph {graph}:") + steps = [] + + while elim_node is not None: + print(f"Removing {elim_node}:") + steps.append(elim_node) + nbrs = graph[elim_node] + + for u, v in itertools.permutations(nbrs, 2): + if v not in graph[u]: + graph[u].add(v) + + for u in graph: + if elim_node in graph[u]: + graph[u].remove(elim_node) + + del graph[elim_node] + print(f"Graph {graph}:") + elim_node = deg_heuristic.best_node(graph) + + # check only the first 5 elements for equality + assert steps[:5] == [0, 1, 2, 3, 4] + + +class TestTreewidthMinFillIn: + """Unit tests for the treewidth_min_fill_in function.""" + + @classmethod + def setup_class(cls): + """Setup for different kinds of trees""" + cls.complete = nx.Graph() + cls.complete.add_edge(1, 2) + cls.complete.add_edge(2, 3) + cls.complete.add_edge(1, 3) + + cls.small_tree = nx.Graph() + cls.small_tree.add_edge(1, 2) + cls.small_tree.add_edge(2, 3) + cls.small_tree.add_edge(3, 4) + cls.small_tree.add_edge(1, 4) + cls.small_tree.add_edge(2, 4) + cls.small_tree.add_edge(4, 5) + cls.small_tree.add_edge(5, 6) + cls.small_tree.add_edge(5, 7) + cls.small_tree.add_edge(6, 7) + + cls.deterministic_graph = nx.Graph() + cls.deterministic_graph.add_edge(1, 2) + cls.deterministic_graph.add_edge(1, 3) + cls.deterministic_graph.add_edge(3, 4) + cls.deterministic_graph.add_edge(2, 4) + cls.deterministic_graph.add_edge(3, 5) + cls.deterministic_graph.add_edge(4, 5) + cls.deterministic_graph.add_edge(3, 6) + cls.deterministic_graph.add_edge(5, 6) + + def test_petersen_graph(self): + """Test Petersen graph tree decomposition result""" + G = nx.petersen_graph() + _, decomp = treewidth_min_fill_in(G) + is_tree_decomp(G, decomp) + + def test_small_tree_treewidth(self): + """Test if the computed treewidth of the known self.small_tree is 2""" + G = self.small_tree + # the order of removal should be [1,2,4]3[5,6,7] + # (with [] denoting any order of the containing nodes) + # resulting in treewidth 2 for the heuristic + treewidth, _ = treewidth_min_fill_in(G) + assert treewidth == 2 + + def test_heuristic_abort(self): + """Test if min_fill_in returns None for fully connected graph""" + graph = {} + for u in self.complete: + graph[u] = set() + for v in self.complete[u]: + if u != v: # ignore self-loop + graph[u].add(v) + next_node = min_fill_in_heuristic(graph) + if next_node is None: + pass + else: + assert False + + def test_empty_graph(self): + """Test empty graph""" + G = nx.Graph() + _, _ = treewidth_min_fill_in(G) + + def test_two_component_graph(self): + G = nx.Graph() + G.add_node(1) + G.add_node(2) + treewidth, _ = treewidth_min_fill_in(G) + assert treewidth == 0 + + def test_not_sortable_nodes(self): + G = nx.Graph([(0, "a")]) + treewidth_min_fill_in(G) + + def test_heuristic_first_steps(self): + """Test first steps of min_fill_in heuristic""" + graph = { + n: set(self.deterministic_graph[n]) - {n} for n in self.deterministic_graph + } + print(f"Graph {graph}:") + elim_node = min_fill_in_heuristic(graph) + steps = [] + + while elim_node is not None: + print(f"Removing {elim_node}:") + steps.append(elim_node) + nbrs = graph[elim_node] + + for u, v in itertools.permutations(nbrs, 2): + if v not in graph[u]: + graph[u].add(v) + + for u in graph: + if elim_node in graph[u]: + graph[u].remove(elim_node) + + del graph[elim_node] + print(f"Graph {graph}:") + elim_node = min_fill_in_heuristic(graph) + + # check only the first 2 elements for equality + assert steps[:2] == [6, 5] diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_vertex_cover.py b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_vertex_cover.py new file mode 100644 index 0000000000000000000000000000000000000000..5cc5a38df9a4139684005491e0183cd563487154 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/approximation/tests/test_vertex_cover.py @@ -0,0 +1,68 @@ +import networkx as nx +from networkx.algorithms.approximation import min_weighted_vertex_cover + + +def is_cover(G, node_cover): + return all({u, v} & node_cover for u, v in G.edges()) + + +class TestMWVC: + """Unit tests for the approximate minimum weighted vertex cover + function, + :func:`~networkx.algorithms.approximation.vertex_cover.min_weighted_vertex_cover`. + + """ + + def test_unweighted_directed(self): + # Create a star graph in which half the nodes are directed in + # and half are directed out. + G = nx.DiGraph() + G.add_edges_from((0, v) for v in range(1, 26)) + G.add_edges_from((v, 0) for v in range(26, 51)) + cover = min_weighted_vertex_cover(G) + assert 1 == len(cover) + assert is_cover(G, cover) + + def test_unweighted_undirected(self): + # create a simple star graph + size = 50 + sg = nx.star_graph(size) + cover = min_weighted_vertex_cover(sg) + assert 1 == len(cover) + assert is_cover(sg, cover) + + def test_weighted(self): + wg = nx.Graph() + wg.add_node(0, weight=10) + wg.add_node(1, weight=1) + wg.add_node(2, weight=1) + wg.add_node(3, weight=1) + wg.add_node(4, weight=1) + + wg.add_edge(0, 1) + wg.add_edge(0, 2) + wg.add_edge(0, 3) + wg.add_edge(0, 4) + + wg.add_edge(1, 2) + wg.add_edge(2, 3) + wg.add_edge(3, 4) + wg.add_edge(4, 1) + + cover = min_weighted_vertex_cover(wg, weight="weight") + csum = sum(wg.nodes[node]["weight"] for node in cover) + assert 4 == csum + assert is_cover(wg, cover) + + def test_unweighted_self_loop(self): + slg = nx.Graph() + slg.add_node(0) + slg.add_node(1) + slg.add_node(2) + + slg.add_edge(0, 1) + slg.add_edge(2, 2) + + cover = min_weighted_vertex_cover(slg) + assert 2 == len(cover) + assert is_cover(slg, cover) diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/traveling_salesman.py b/lib/python3.10/site-packages/networkx/algorithms/approximation/traveling_salesman.py new file mode 100644 index 0000000000000000000000000000000000000000..2080c99aec7df150cfa670150155052afc3ac289 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/approximation/traveling_salesman.py @@ -0,0 +1,1501 @@ +""" +================================= +Travelling Salesman Problem (TSP) +================================= + +Implementation of approximate algorithms +for solving and approximating the TSP problem. + +Categories of algorithms which are implemented: + +- Christofides (provides a 3/2-approximation of TSP) +- Greedy +- Simulated Annealing (SA) +- Threshold Accepting (TA) +- Asadpour Asymmetric Traveling Salesman Algorithm + +The Travelling Salesman Problem tries to find, given the weight +(distance) between all points where a salesman has to visit, the +route so that: + +- The total distance (cost) which the salesman travels is minimized. +- The salesman returns to the starting point. +- Note that for a complete graph, the salesman visits each point once. + +The function `travelling_salesman_problem` allows for incomplete +graphs by finding all-pairs shortest paths, effectively converting +the problem to a complete graph problem. It calls one of the +approximate methods on that problem and then converts the result +back to the original graph using the previously found shortest paths. + +TSP is an NP-hard problem in combinatorial optimization, +important in operations research and theoretical computer science. + +http://en.wikipedia.org/wiki/Travelling_salesman_problem +""" + +import math + +import networkx as nx +from networkx.algorithms.tree.mst import random_spanning_tree +from networkx.utils import not_implemented_for, pairwise, py_random_state + +__all__ = [ + "traveling_salesman_problem", + "christofides", + "asadpour_atsp", + "greedy_tsp", + "simulated_annealing_tsp", + "threshold_accepting_tsp", +] + + +def swap_two_nodes(soln, seed): + """Swap two nodes in `soln` to give a neighbor solution. + + Parameters + ---------- + soln : list of nodes + Current cycle of nodes + + seed : integer, random_state, or None (default) + Indicator of random number generation state. + See :ref:`Randomness`. + + Returns + ------- + list + The solution after move is applied. (A neighbor solution.) + + Notes + ----- + This function assumes that the incoming list `soln` is a cycle + (that the first and last element are the same) and also that + we don't want any move to change the first node in the list + (and thus not the last node either). + + The input list is changed as well as returned. Make a copy if needed. + + See Also + -------- + move_one_node + """ + a, b = seed.sample(range(1, len(soln) - 1), k=2) + soln[a], soln[b] = soln[b], soln[a] + return soln + + +def move_one_node(soln, seed): + """Move one node to another position to give a neighbor solution. + + The node to move and the position to move to are chosen randomly. + The first and last nodes are left untouched as soln must be a cycle + starting at that node. + + Parameters + ---------- + soln : list of nodes + Current cycle of nodes + + seed : integer, random_state, or None (default) + Indicator of random number generation state. + See :ref:`Randomness`. + + Returns + ------- + list + The solution after move is applied. (A neighbor solution.) + + Notes + ----- + This function assumes that the incoming list `soln` is a cycle + (that the first and last element are the same) and also that + we don't want any move to change the first node in the list + (and thus not the last node either). + + The input list is changed as well as returned. Make a copy if needed. + + See Also + -------- + swap_two_nodes + """ + a, b = seed.sample(range(1, len(soln) - 1), k=2) + soln.insert(b, soln.pop(a)) + return soln + + +@not_implemented_for("directed") +@nx._dispatchable(edge_attrs="weight") +def christofides(G, weight="weight", tree=None): + """Approximate a solution of the traveling salesman problem + + Compute a 3/2-approximation of the traveling salesman problem + in a complete undirected graph using Christofides [1]_ algorithm. + + Parameters + ---------- + G : Graph + `G` should be a complete weighted undirected graph. + The distance between all pairs of nodes should be included. + + weight : string, optional (default="weight") + Edge data key corresponding to the edge weight. + If any edge does not have this attribute the weight is set to 1. + + tree : NetworkX graph or None (default: None) + A minimum spanning tree of G. Or, if None, the minimum spanning + tree is computed using :func:`networkx.minimum_spanning_tree` + + Returns + ------- + list + List of nodes in `G` along a cycle with a 3/2-approximation of + the minimal Hamiltonian cycle. + + References + ---------- + .. [1] Christofides, Nicos. "Worst-case analysis of a new heuristic for + the travelling salesman problem." No. RR-388. Carnegie-Mellon Univ + Pittsburgh Pa Management Sciences Research Group, 1976. + """ + # Remove selfloops if necessary + loop_nodes = nx.nodes_with_selfloops(G) + try: + node = next(loop_nodes) + except StopIteration: + pass + else: + G = G.copy() + G.remove_edge(node, node) + G.remove_edges_from((n, n) for n in loop_nodes) + # Check that G is a complete graph + N = len(G) - 1 + # This check ignores selfloops which is what we want here. + if any(len(nbrdict) != N for n, nbrdict in G.adj.items()): + raise nx.NetworkXError("G must be a complete graph.") + + if tree is None: + tree = nx.minimum_spanning_tree(G, weight=weight) + L = G.copy() + L.remove_nodes_from([v for v, degree in tree.degree if not (degree % 2)]) + MG = nx.MultiGraph() + MG.add_edges_from(tree.edges) + edges = nx.min_weight_matching(L, weight=weight) + MG.add_edges_from(edges) + return _shortcutting(nx.eulerian_circuit(MG)) + + +def _shortcutting(circuit): + """Remove duplicate nodes in the path""" + nodes = [] + for u, v in circuit: + if v in nodes: + continue + if not nodes: + nodes.append(u) + nodes.append(v) + nodes.append(nodes[0]) + return nodes + + +@nx._dispatchable(edge_attrs="weight") +def traveling_salesman_problem( + G, weight="weight", nodes=None, cycle=True, method=None, **kwargs +): + """Find the shortest path in `G` connecting specified nodes + + This function allows approximate solution to the traveling salesman + problem on networks that are not complete graphs and/or where the + salesman does not need to visit all nodes. + + This function proceeds in two steps. First, it creates a complete + graph using the all-pairs shortest_paths between nodes in `nodes`. + Edge weights in the new graph are the lengths of the paths + between each pair of nodes in the original graph. + Second, an algorithm (default: `christofides` for undirected and + `asadpour_atsp` for directed) is used to approximate the minimal Hamiltonian + cycle on this new graph. The available algorithms are: + + - christofides + - greedy_tsp + - simulated_annealing_tsp + - threshold_accepting_tsp + - asadpour_atsp + + Once the Hamiltonian Cycle is found, this function post-processes to + accommodate the structure of the original graph. If `cycle` is ``False``, + the biggest weight edge is removed to make a Hamiltonian path. + Then each edge on the new complete graph used for that analysis is + replaced by the shortest_path between those nodes on the original graph. + If the input graph `G` includes edges with weights that do not adhere to + the triangle inequality, such as when `G` is not a complete graph (i.e + length of non-existent edges is infinity), then the returned path may + contain some repeating nodes (other than the starting node). + + Parameters + ---------- + G : NetworkX graph + A possibly weighted graph + + nodes : collection of nodes (default=G.nodes) + collection (list, set, etc.) of nodes to visit + + weight : string, optional (default="weight") + Edge data key corresponding to the edge weight. + If any edge does not have this attribute the weight is set to 1. + + cycle : bool (default: True) + Indicates whether a cycle should be returned, or a path. + Note: the cycle is the approximate minimal cycle. + The path simply removes the biggest edge in that cycle. + + method : function (default: None) + A function that returns a cycle on all nodes and approximates + the solution to the traveling salesman problem on a complete + graph. The returned cycle is then used to find a corresponding + solution on `G`. `method` should be callable; take inputs + `G`, and `weight`; and return a list of nodes along the cycle. + + Provided options include :func:`christofides`, :func:`greedy_tsp`, + :func:`simulated_annealing_tsp` and :func:`threshold_accepting_tsp`. + + If `method is None`: use :func:`christofides` for undirected `G` and + :func:`asadpour_atsp` for directed `G`. + + **kwargs : dict + Other keyword arguments to be passed to the `method` function passed in. + + Returns + ------- + list + List of nodes in `G` along a path with an approximation of the minimal + path through `nodes`. + + Raises + ------ + NetworkXError + If `G` is a directed graph it has to be strongly connected or the + complete version cannot be generated. + + Examples + -------- + >>> tsp = nx.approximation.traveling_salesman_problem + >>> G = nx.cycle_graph(9) + >>> G[4][5]["weight"] = 5 # all other weights are 1 + >>> tsp(G, nodes=[3, 6]) + [3, 2, 1, 0, 8, 7, 6, 7, 8, 0, 1, 2, 3] + >>> path = tsp(G, cycle=False) + >>> path in ([4, 3, 2, 1, 0, 8, 7, 6, 5], [5, 6, 7, 8, 0, 1, 2, 3, 4]) + True + + While no longer required, you can still build (curry) your own function + to provide parameter values to the methods. + + >>> SA_tsp = nx.approximation.simulated_annealing_tsp + >>> method = lambda G, weight: SA_tsp(G, "greedy", weight=weight, temp=500) + >>> path = tsp(G, cycle=False, method=method) + >>> path in ([4, 3, 2, 1, 0, 8, 7, 6, 5], [5, 6, 7, 8, 0, 1, 2, 3, 4]) + True + + Otherwise, pass other keyword arguments directly into the tsp function. + + >>> path = tsp( + ... G, + ... cycle=False, + ... method=nx.approximation.simulated_annealing_tsp, + ... init_cycle="greedy", + ... temp=500, + ... ) + >>> path in ([4, 3, 2, 1, 0, 8, 7, 6, 5], [5, 6, 7, 8, 0, 1, 2, 3, 4]) + True + """ + if method is None: + if G.is_directed(): + method = asadpour_atsp + else: + method = christofides + if nodes is None: + nodes = list(G.nodes) + + dist = {} + path = {} + for n, (d, p) in nx.all_pairs_dijkstra(G, weight=weight): + dist[n] = d + path[n] = p + + if G.is_directed(): + # If the graph is not strongly connected, raise an exception + if not nx.is_strongly_connected(G): + raise nx.NetworkXError("G is not strongly connected") + GG = nx.DiGraph() + else: + GG = nx.Graph() + for u in nodes: + for v in nodes: + if u == v: + continue + GG.add_edge(u, v, weight=dist[u][v]) + + best_GG = method(GG, weight=weight, **kwargs) + + if not cycle: + # find and remove the biggest edge + (u, v) = max(pairwise(best_GG), key=lambda x: dist[x[0]][x[1]]) + pos = best_GG.index(u) + 1 + while best_GG[pos] != v: + pos = best_GG[pos:].index(u) + 1 + best_GG = best_GG[pos:-1] + best_GG[:pos] + + best_path = [] + for u, v in pairwise(best_GG): + best_path.extend(path[u][v][:-1]) + best_path.append(v) + return best_path + + +@not_implemented_for("undirected") +@py_random_state(2) +@nx._dispatchable(edge_attrs="weight", mutates_input=True) +def asadpour_atsp(G, weight="weight", seed=None, source=None): + """ + Returns an approximate solution to the traveling salesman problem. + + This approximate solution is one of the best known approximations for the + asymmetric traveling salesman problem developed by Asadpour et al, + [1]_. The algorithm first solves the Held-Karp relaxation to find a lower + bound for the weight of the cycle. Next, it constructs an exponential + distribution of undirected spanning trees where the probability of an + edge being in the tree corresponds to the weight of that edge using a + maximum entropy rounding scheme. Next we sample that distribution + $2 \\lceil \\ln n \\rceil$ times and save the minimum sampled tree once the + direction of the arcs is added back to the edges. Finally, we augment + then short circuit that graph to find the approximate tour for the + salesman. + + Parameters + ---------- + G : nx.DiGraph + The graph should be a complete weighted directed graph. The + distance between all paris of nodes should be included and the triangle + inequality should hold. That is, the direct edge between any two nodes + should be the path of least cost. + + weight : string, optional (default="weight") + Edge data key corresponding to the edge weight. + If any edge does not have this attribute the weight is set to 1. + + seed : integer, random_state, or None (default) + Indicator of random number generation state. + See :ref:`Randomness`. + + source : node label (default=`None`) + If given, return the cycle starting and ending at the given node. + + Returns + ------- + cycle : list of nodes + Returns the cycle (list of nodes) that a salesman can follow to minimize + the total weight of the trip. + + Raises + ------ + NetworkXError + If `G` is not complete or has less than two nodes, the algorithm raises + an exception. + + NetworkXError + If `source` is not `None` and is not a node in `G`, the algorithm raises + an exception. + + NetworkXNotImplemented + If `G` is an undirected graph. + + References + ---------- + .. [1] A. Asadpour, M. X. Goemans, A. Madry, S. O. Gharan, and A. Saberi, + An o(log n/log log n)-approximation algorithm for the asymmetric + traveling salesman problem, Operations research, 65 (2017), + pp. 1043–1061 + + Examples + -------- + >>> import networkx as nx + >>> import networkx.algorithms.approximation as approx + >>> G = nx.complete_graph(3, create_using=nx.DiGraph) + >>> nx.set_edge_attributes( + ... G, + ... {(0, 1): 2, (1, 2): 2, (2, 0): 2, (0, 2): 1, (2, 1): 1, (1, 0): 1}, + ... "weight", + ... ) + >>> tour = approx.asadpour_atsp(G, source=0) + >>> tour + [0, 2, 1, 0] + """ + from math import ceil, exp + from math import log as ln + + # Check that G is a complete graph + N = len(G) - 1 + if N < 2: + raise nx.NetworkXError("G must have at least two nodes") + # This check ignores selfloops which is what we want here. + if any(len(nbrdict) - (n in nbrdict) != N for n, nbrdict in G.adj.items()): + raise nx.NetworkXError("G is not a complete DiGraph") + # Check that the source vertex, if given, is in the graph + if source is not None and source not in G.nodes: + raise nx.NetworkXError("Given source node not in G.") + + opt_hk, z_star = held_karp_ascent(G, weight) + + # Test to see if the ascent method found an integer solution or a fractional + # solution. If it is integral then z_star is a nx.Graph, otherwise it is + # a dict + if not isinstance(z_star, dict): + # Here we are using the shortcutting method to go from the list of edges + # returned from eulerian_circuit to a list of nodes + return _shortcutting(nx.eulerian_circuit(z_star, source=source)) + + # Create the undirected support of z_star + z_support = nx.MultiGraph() + for u, v in z_star: + if (u, v) not in z_support.edges: + edge_weight = min(G[u][v][weight], G[v][u][weight]) + z_support.add_edge(u, v, **{weight: edge_weight}) + + # Create the exponential distribution of spanning trees + gamma = spanning_tree_distribution(z_support, z_star) + + # Write the lambda values to the edges of z_support + z_support = nx.Graph(z_support) + lambda_dict = {(u, v): exp(gamma[(u, v)]) for u, v in z_support.edges()} + nx.set_edge_attributes(z_support, lambda_dict, "weight") + del gamma, lambda_dict + + # Sample 2 * ceil( ln(n) ) spanning trees and record the minimum one + minimum_sampled_tree = None + minimum_sampled_tree_weight = math.inf + for _ in range(2 * ceil(ln(G.number_of_nodes()))): + sampled_tree = random_spanning_tree(z_support, "weight", seed=seed) + sampled_tree_weight = sampled_tree.size(weight) + if sampled_tree_weight < minimum_sampled_tree_weight: + minimum_sampled_tree = sampled_tree.copy() + minimum_sampled_tree_weight = sampled_tree_weight + + # Orient the edges in that tree to keep the cost of the tree the same. + t_star = nx.MultiDiGraph() + for u, v, d in minimum_sampled_tree.edges(data=weight): + if d == G[u][v][weight]: + t_star.add_edge(u, v, **{weight: d}) + else: + t_star.add_edge(v, u, **{weight: d}) + + # Find the node demands needed to neutralize the flow of t_star in G + node_demands = {n: t_star.out_degree(n) - t_star.in_degree(n) for n in t_star} + nx.set_node_attributes(G, node_demands, "demand") + + # Find the min_cost_flow + flow_dict = nx.min_cost_flow(G, "demand") + + # Build the flow into t_star + for source, values in flow_dict.items(): + for target in values: + if (source, target) not in t_star.edges and values[target] > 0: + # IF values[target] > 0 we have to add that many edges + for _ in range(values[target]): + t_star.add_edge(source, target) + + # Return the shortcut eulerian circuit + circuit = nx.eulerian_circuit(t_star, source=source) + return _shortcutting(circuit) + + +@nx._dispatchable(edge_attrs="weight", mutates_input=True, returns_graph=True) +def held_karp_ascent(G, weight="weight"): + """ + Minimizes the Held-Karp relaxation of the TSP for `G` + + Solves the Held-Karp relaxation of the input complete digraph and scales + the output solution for use in the Asadpour [1]_ ASTP algorithm. + + The Held-Karp relaxation defines the lower bound for solutions to the + ATSP, although it does return a fractional solution. This is used in the + Asadpour algorithm as an initial solution which is later rounded to a + integral tree within the spanning tree polytopes. This function solves + the relaxation with the branch and bound method in [2]_. + + Parameters + ---------- + G : nx.DiGraph + The graph should be a complete weighted directed graph. + The distance between all paris of nodes should be included. + + weight : string, optional (default="weight") + Edge data key corresponding to the edge weight. + If any edge does not have this attribute the weight is set to 1. + + Returns + ------- + OPT : float + The cost for the optimal solution to the Held-Karp relaxation + z : dict or nx.Graph + A symmetrized and scaled version of the optimal solution to the + Held-Karp relaxation for use in the Asadpour algorithm. + + If an integral solution is found, then that is an optimal solution for + the ATSP problem and that is returned instead. + + References + ---------- + .. [1] A. Asadpour, M. X. Goemans, A. Madry, S. O. Gharan, and A. Saberi, + An o(log n/log log n)-approximation algorithm for the asymmetric + traveling salesman problem, Operations research, 65 (2017), + pp. 1043–1061 + + .. [2] M. Held, R. M. Karp, The traveling-salesman problem and minimum + spanning trees, Operations Research, 1970-11-01, Vol. 18 (6), + pp.1138-1162 + """ + import numpy as np + from scipy import optimize + + def k_pi(): + """ + Find the set of minimum 1-Arborescences for G at point pi. + + Returns + ------- + Set + The set of minimum 1-Arborescences + """ + # Create a copy of G without vertex 1. + G_1 = G.copy() + minimum_1_arborescences = set() + minimum_1_arborescence_weight = math.inf + + # node is node '1' in the Held and Karp paper + n = next(G.__iter__()) + G_1.remove_node(n) + + # Iterate over the spanning arborescences of the graph until we know + # that we have found the minimum 1-arborescences. My proposed strategy + # is to find the most extensive root to connect to from 'node 1' and + # the least expensive one. We then iterate over arborescences until + # the cost of the basic arborescence is the cost of the minimum one + # plus the difference between the most and least expensive roots, + # that way the cost of connecting 'node 1' will by definition not by + # minimum + min_root = {"node": None, weight: math.inf} + max_root = {"node": None, weight: -math.inf} + for u, v, d in G.edges(n, data=True): + if d[weight] < min_root[weight]: + min_root = {"node": v, weight: d[weight]} + if d[weight] > max_root[weight]: + max_root = {"node": v, weight: d[weight]} + + min_in_edge = min(G.in_edges(n, data=True), key=lambda x: x[2][weight]) + min_root[weight] = min_root[weight] + min_in_edge[2][weight] + max_root[weight] = max_root[weight] + min_in_edge[2][weight] + + min_arb_weight = math.inf + for arb in nx.ArborescenceIterator(G_1): + arb_weight = arb.size(weight) + if min_arb_weight == math.inf: + min_arb_weight = arb_weight + elif arb_weight > min_arb_weight + max_root[weight] - min_root[weight]: + break + # We have to pick the root node of the arborescence for the out + # edge of the first vertex as that is the only node without an + # edge directed into it. + for N, deg in arb.in_degree: + if deg == 0: + # root found + arb.add_edge(n, N, **{weight: G[n][N][weight]}) + arb_weight += G[n][N][weight] + break + + # We can pick the minimum weight in-edge for the vertex with + # a cycle. If there are multiple edges with the same, minimum + # weight, We need to add all of them. + # + # Delete the edge (N, v) so that we cannot pick it. + edge_data = G[N][n] + G.remove_edge(N, n) + min_weight = min(G.in_edges(n, data=weight), key=lambda x: x[2])[2] + min_edges = [ + (u, v, d) for u, v, d in G.in_edges(n, data=weight) if d == min_weight + ] + for u, v, d in min_edges: + new_arb = arb.copy() + new_arb.add_edge(u, v, **{weight: d}) + new_arb_weight = arb_weight + d + # Check to see the weight of the arborescence, if it is a + # new minimum, clear all of the old potential minimum + # 1-arborescences and add this is the only one. If its + # weight is above the known minimum, do not add it. + if new_arb_weight < minimum_1_arborescence_weight: + minimum_1_arborescences.clear() + minimum_1_arborescence_weight = new_arb_weight + # We have a 1-arborescence, add it to the set + if new_arb_weight == minimum_1_arborescence_weight: + minimum_1_arborescences.add(new_arb) + G.add_edge(N, n, **edge_data) + + return minimum_1_arborescences + + def direction_of_ascent(): + """ + Find the direction of ascent at point pi. + + See [1]_ for more information. + + Returns + ------- + dict + A mapping from the nodes of the graph which represents the direction + of ascent. + + References + ---------- + .. [1] M. Held, R. M. Karp, The traveling-salesman problem and minimum + spanning trees, Operations Research, 1970-11-01, Vol. 18 (6), + pp.1138-1162 + """ + # 1. Set d equal to the zero n-vector. + d = {} + for n in G: + d[n] = 0 + del n + # 2. Find a 1-Arborescence T^k such that k is in K(pi, d). + minimum_1_arborescences = k_pi() + while True: + # Reduce K(pi) to K(pi, d) + # Find the arborescence in K(pi) which increases the lest in + # direction d + min_k_d_weight = math.inf + min_k_d = None + for arborescence in minimum_1_arborescences: + weighted_cost = 0 + for n, deg in arborescence.degree: + weighted_cost += d[n] * (deg - 2) + if weighted_cost < min_k_d_weight: + min_k_d_weight = weighted_cost + min_k_d = arborescence + + # 3. If sum of d_i * v_{i, k} is greater than zero, terminate + if min_k_d_weight > 0: + return d, min_k_d + # 4. d_i = d_i + v_{i, k} + for n, deg in min_k_d.degree: + d[n] += deg - 2 + # Check that we do not need to terminate because the direction + # of ascent does not exist. This is done with linear + # programming. + c = np.full(len(minimum_1_arborescences), -1, dtype=int) + a_eq = np.empty((len(G) + 1, len(minimum_1_arborescences)), dtype=int) + b_eq = np.zeros(len(G) + 1, dtype=int) + b_eq[len(G)] = 1 + for arb_count, arborescence in enumerate(minimum_1_arborescences): + n_count = len(G) - 1 + for n, deg in arborescence.degree: + a_eq[n_count][arb_count] = deg - 2 + n_count -= 1 + a_eq[len(G)][arb_count] = 1 + program_result = optimize.linprog( + c, A_eq=a_eq, b_eq=b_eq, method="highs-ipm" + ) + # If the constants exist, then the direction of ascent doesn't + if program_result.success: + # There is no direction of ascent + return None, minimum_1_arborescences + + # 5. GO TO 2 + + def find_epsilon(k, d): + """ + Given the direction of ascent at pi, find the maximum distance we can go + in that direction. + + Parameters + ---------- + k_xy : set + The set of 1-arborescences which have the minimum rate of increase + in the direction of ascent + + d : dict + The direction of ascent + + Returns + ------- + float + The distance we can travel in direction `d` + """ + min_epsilon = math.inf + for e_u, e_v, e_w in G.edges(data=weight): + if (e_u, e_v) in k.edges: + continue + # Now, I have found a condition which MUST be true for the edges to + # be a valid substitute. The edge in the graph which is the + # substitute is the one with the same terminated end. This can be + # checked rather simply. + # + # Find the edge within k which is the substitute. Because k is a + # 1-arborescence, we know that they is only one such edges + # leading into every vertex. + if len(k.in_edges(e_v, data=weight)) > 1: + raise Exception + sub_u, sub_v, sub_w = next(k.in_edges(e_v, data=weight).__iter__()) + k.add_edge(e_u, e_v, **{weight: e_w}) + k.remove_edge(sub_u, sub_v) + if ( + max(d for n, d in k.in_degree()) <= 1 + and len(G) == k.number_of_edges() + and nx.is_weakly_connected(k) + ): + # Ascent method calculation + if d[sub_u] == d[e_u] or sub_w == e_w: + # Revert to the original graph + k.remove_edge(e_u, e_v) + k.add_edge(sub_u, sub_v, **{weight: sub_w}) + continue + epsilon = (sub_w - e_w) / (d[e_u] - d[sub_u]) + if 0 < epsilon < min_epsilon: + min_epsilon = epsilon + # Revert to the original graph + k.remove_edge(e_u, e_v) + k.add_edge(sub_u, sub_v, **{weight: sub_w}) + + return min_epsilon + + # I have to know that the elements in pi correspond to the correct elements + # in the direction of ascent, even if the node labels are not integers. + # Thus, I will use dictionaries to made that mapping. + pi_dict = {} + for n in G: + pi_dict[n] = 0 + del n + original_edge_weights = {} + for u, v, d in G.edges(data=True): + original_edge_weights[(u, v)] = d[weight] + dir_ascent, k_d = direction_of_ascent() + while dir_ascent is not None: + max_distance = find_epsilon(k_d, dir_ascent) + for n, v in dir_ascent.items(): + pi_dict[n] += max_distance * v + for u, v, d in G.edges(data=True): + d[weight] = original_edge_weights[(u, v)] + pi_dict[u] + dir_ascent, k_d = direction_of_ascent() + nx._clear_cache(G) + # k_d is no longer an individual 1-arborescence but rather a set of + # minimal 1-arborescences at the maximum point of the polytope and should + # be reflected as such + k_max = k_d + + # Search for a cycle within k_max. If a cycle exists, return it as the + # solution + for k in k_max: + if len([n for n in k if k.degree(n) == 2]) == G.order(): + # Tour found + # TODO: this branch does not restore original_edge_weights of G! + return k.size(weight), k + + # Write the original edge weights back to G and every member of k_max at + # the maximum point. Also average the number of times that edge appears in + # the set of minimal 1-arborescences. + x_star = {} + size_k_max = len(k_max) + for u, v, d in G.edges(data=True): + edge_count = 0 + d[weight] = original_edge_weights[(u, v)] + for k in k_max: + if (u, v) in k.edges(): + edge_count += 1 + k[u][v][weight] = original_edge_weights[(u, v)] + x_star[(u, v)] = edge_count / size_k_max + # Now symmetrize the edges in x_star and scale them according to (5) in + # reference [1] + z_star = {} + scale_factor = (G.order() - 1) / G.order() + for u, v in x_star: + frequency = x_star[(u, v)] + x_star[(v, u)] + if frequency > 0: + z_star[(u, v)] = scale_factor * frequency + del x_star + # Return the optimal weight and the z dict + return next(k_max.__iter__()).size(weight), z_star + + +@nx._dispatchable +def spanning_tree_distribution(G, z): + """ + Find the asadpour exponential distribution of spanning trees. + + Solves the Maximum Entropy Convex Program in the Asadpour algorithm [1]_ + using the approach in section 7 to build an exponential distribution of + undirected spanning trees. + + This algorithm ensures that the probability of any edge in a spanning + tree is proportional to the sum of the probabilities of the tress + containing that edge over the sum of the probabilities of all spanning + trees of the graph. + + Parameters + ---------- + G : nx.MultiGraph + The undirected support graph for the Held Karp relaxation + + z : dict + The output of `held_karp_ascent()`, a scaled version of the Held-Karp + solution. + + Returns + ------- + gamma : dict + The probability distribution which approximately preserves the marginal + probabilities of `z`. + """ + from math import exp + from math import log as ln + + def q(e): + """ + The value of q(e) is described in the Asadpour paper is "the + probability that edge e will be included in a spanning tree T that is + chosen with probability proportional to exp(gamma(T))" which + basically means that it is the total probability of the edge appearing + across the whole distribution. + + Parameters + ---------- + e : tuple + The `(u, v)` tuple describing the edge we are interested in + + Returns + ------- + float + The probability that a spanning tree chosen according to the + current values of gamma will include edge `e`. + """ + # Create the laplacian matrices + for u, v, d in G.edges(data=True): + d[lambda_key] = exp(gamma[(u, v)]) + G_Kirchhoff = nx.total_spanning_tree_weight(G, lambda_key) + G_e = nx.contracted_edge(G, e, self_loops=False) + G_e_Kirchhoff = nx.total_spanning_tree_weight(G_e, lambda_key) + + # Multiply by the weight of the contracted edge since it is not included + # in the total weight of the contracted graph. + return exp(gamma[(e[0], e[1])]) * G_e_Kirchhoff / G_Kirchhoff + + # initialize gamma to the zero dict + gamma = {} + for u, v, _ in G.edges: + gamma[(u, v)] = 0 + + # set epsilon + EPSILON = 0.2 + + # pick an edge attribute name that is unlikely to be in the graph + lambda_key = "spanning_tree_distribution's secret attribute name for lambda" + + while True: + # We need to know that know that no values of q_e are greater than + # (1 + epsilon) * z_e, however changing one gamma value can increase the + # value of a different q_e, so we have to complete the for loop without + # changing anything for the condition to be meet + in_range_count = 0 + # Search for an edge with q_e > (1 + epsilon) * z_e + for u, v in gamma: + e = (u, v) + q_e = q(e) + z_e = z[e] + if q_e > (1 + EPSILON) * z_e: + delta = ln( + (q_e * (1 - (1 + EPSILON / 2) * z_e)) + / ((1 - q_e) * (1 + EPSILON / 2) * z_e) + ) + gamma[e] -= delta + # Check that delta had the desired effect + new_q_e = q(e) + desired_q_e = (1 + EPSILON / 2) * z_e + if round(new_q_e, 8) != round(desired_q_e, 8): + raise nx.NetworkXError( + f"Unable to modify probability for edge ({u}, {v})" + ) + else: + in_range_count += 1 + # Check if the for loop terminated without changing any gamma + if in_range_count == len(gamma): + break + + # Remove the new edge attributes + for _, _, d in G.edges(data=True): + if lambda_key in d: + del d[lambda_key] + + return gamma + + +@nx._dispatchable(edge_attrs="weight") +def greedy_tsp(G, weight="weight", source=None): + """Return a low cost cycle starting at `source` and its cost. + + This approximates a solution to the traveling salesman problem. + It finds a cycle of all the nodes that a salesman can visit in order + to visit many nodes while minimizing total distance. + It uses a simple greedy algorithm. + In essence, this function returns a large cycle given a source point + for which the total cost of the cycle is minimized. + + Parameters + ---------- + G : Graph + The Graph should be a complete weighted undirected graph. + The distance between all pairs of nodes should be included. + + weight : string, optional (default="weight") + Edge data key corresponding to the edge weight. + If any edge does not have this attribute the weight is set to 1. + + source : node, optional (default: first node in list(G)) + Starting node. If None, defaults to ``next(iter(G))`` + + Returns + ------- + cycle : list of nodes + Returns the cycle (list of nodes) that a salesman + can follow to minimize total weight of the trip. + + Raises + ------ + NetworkXError + If `G` is not complete, the algorithm raises an exception. + + Examples + -------- + >>> from networkx.algorithms import approximation as approx + >>> G = nx.DiGraph() + >>> G.add_weighted_edges_from( + ... { + ... ("A", "B", 3), + ... ("A", "C", 17), + ... ("A", "D", 14), + ... ("B", "A", 3), + ... ("B", "C", 12), + ... ("B", "D", 16), + ... ("C", "A", 13), + ... ("C", "B", 12), + ... ("C", "D", 4), + ... ("D", "A", 14), + ... ("D", "B", 15), + ... ("D", "C", 2), + ... } + ... ) + >>> cycle = approx.greedy_tsp(G, source="D") + >>> cost = sum(G[n][nbr]["weight"] for n, nbr in nx.utils.pairwise(cycle)) + >>> cycle + ['D', 'C', 'B', 'A', 'D'] + >>> cost + 31 + + Notes + ----- + This implementation of a greedy algorithm is based on the following: + + - The algorithm adds a node to the solution at every iteration. + - The algorithm selects a node not already in the cycle whose connection + to the previous node adds the least cost to the cycle. + + A greedy algorithm does not always give the best solution. + However, it can construct a first feasible solution which can + be passed as a parameter to an iterative improvement algorithm such + as Simulated Annealing, or Threshold Accepting. + + Time complexity: It has a running time $O(|V|^2)$ + """ + # Check that G is a complete graph + N = len(G) - 1 + # This check ignores selfloops which is what we want here. + if any(len(nbrdict) - (n in nbrdict) != N for n, nbrdict in G.adj.items()): + raise nx.NetworkXError("G must be a complete graph.") + + if source is None: + source = nx.utils.arbitrary_element(G) + + if G.number_of_nodes() == 2: + neighbor = next(G.neighbors(source)) + return [source, neighbor, source] + + nodeset = set(G) + nodeset.remove(source) + cycle = [source] + next_node = source + while nodeset: + nbrdict = G[next_node] + next_node = min(nodeset, key=lambda n: nbrdict[n].get(weight, 1)) + cycle.append(next_node) + nodeset.remove(next_node) + cycle.append(cycle[0]) + return cycle + + +@py_random_state(9) +@nx._dispatchable(edge_attrs="weight") +def simulated_annealing_tsp( + G, + init_cycle, + weight="weight", + source=None, + temp=100, + move="1-1", + max_iterations=10, + N_inner=100, + alpha=0.01, + seed=None, +): + """Returns an approximate solution to the traveling salesman problem. + + This function uses simulated annealing to approximate the minimal cost + cycle through the nodes. Starting from a suboptimal solution, simulated + annealing perturbs that solution, occasionally accepting changes that make + the solution worse to escape from a locally optimal solution. The chance + of accepting such changes decreases over the iterations to encourage + an optimal result. In summary, the function returns a cycle starting + at `source` for which the total cost is minimized. It also returns the cost. + + The chance of accepting a proposed change is related to a parameter called + the temperature (annealing has a physical analogue of steel hardening + as it cools). As the temperature is reduced, the chance of moves that + increase cost goes down. + + Parameters + ---------- + G : Graph + `G` should be a complete weighted graph. + The distance between all pairs of nodes should be included. + + init_cycle : list of all nodes or "greedy" + The initial solution (a cycle through all nodes returning to the start). + This argument has no default to make you think about it. + If "greedy", use `greedy_tsp(G, weight)`. + Other common starting cycles are `list(G) + [next(iter(G))]` or the final + result of `simulated_annealing_tsp` when doing `threshold_accepting_tsp`. + + weight : string, optional (default="weight") + Edge data key corresponding to the edge weight. + If any edge does not have this attribute the weight is set to 1. + + source : node, optional (default: first node in list(G)) + Starting node. If None, defaults to ``next(iter(G))`` + + temp : int, optional (default=100) + The algorithm's temperature parameter. It represents the initial + value of temperature + + move : "1-1" or "1-0" or function, optional (default="1-1") + Indicator of what move to use when finding new trial solutions. + Strings indicate two special built-in moves: + + - "1-1": 1-1 exchange which transposes the position + of two elements of the current solution. + The function called is :func:`swap_two_nodes`. + For example if we apply 1-1 exchange in the solution + ``A = [3, 2, 1, 4, 3]`` + we can get the following by the transposition of 1 and 4 elements: + ``A' = [3, 2, 4, 1, 3]`` + - "1-0": 1-0 exchange which moves an node in the solution + to a new position. + The function called is :func:`move_one_node`. + For example if we apply 1-0 exchange in the solution + ``A = [3, 2, 1, 4, 3]`` + we can transfer the fourth element to the second position: + ``A' = [3, 4, 2, 1, 3]`` + + You may provide your own functions to enact a move from + one solution to a neighbor solution. The function must take + the solution as input along with a `seed` input to control + random number generation (see the `seed` input here). + Your function should maintain the solution as a cycle with + equal first and last node and all others appearing once. + Your function should return the new solution. + + max_iterations : int, optional (default=10) + Declared done when this number of consecutive iterations of + the outer loop occurs without any change in the best cost solution. + + N_inner : int, optional (default=100) + The number of iterations of the inner loop. + + alpha : float between (0, 1), optional (default=0.01) + Percentage of temperature decrease in each iteration + of outer loop + + seed : integer, random_state, or None (default) + Indicator of random number generation state. + See :ref:`Randomness`. + + Returns + ------- + cycle : list of nodes + Returns the cycle (list of nodes) that a salesman + can follow to minimize total weight of the trip. + + Raises + ------ + NetworkXError + If `G` is not complete the algorithm raises an exception. + + Examples + -------- + >>> from networkx.algorithms import approximation as approx + >>> G = nx.DiGraph() + >>> G.add_weighted_edges_from( + ... { + ... ("A", "B", 3), + ... ("A", "C", 17), + ... ("A", "D", 14), + ... ("B", "A", 3), + ... ("B", "C", 12), + ... ("B", "D", 16), + ... ("C", "A", 13), + ... ("C", "B", 12), + ... ("C", "D", 4), + ... ("D", "A", 14), + ... ("D", "B", 15), + ... ("D", "C", 2), + ... } + ... ) + >>> cycle = approx.simulated_annealing_tsp(G, "greedy", source="D") + >>> cost = sum(G[n][nbr]["weight"] for n, nbr in nx.utils.pairwise(cycle)) + >>> cycle + ['D', 'C', 'B', 'A', 'D'] + >>> cost + 31 + >>> incycle = ["D", "B", "A", "C", "D"] + >>> cycle = approx.simulated_annealing_tsp(G, incycle, source="D") + >>> cost = sum(G[n][nbr]["weight"] for n, nbr in nx.utils.pairwise(cycle)) + >>> cycle + ['D', 'C', 'B', 'A', 'D'] + >>> cost + 31 + + Notes + ----- + Simulated Annealing is a metaheuristic local search algorithm. + The main characteristic of this algorithm is that it accepts + even solutions which lead to the increase of the cost in order + to escape from low quality local optimal solutions. + + This algorithm needs an initial solution. If not provided, it is + constructed by a simple greedy algorithm. At every iteration, the + algorithm selects thoughtfully a neighbor solution. + Consider $c(x)$ cost of current solution and $c(x')$ cost of a + neighbor solution. + If $c(x') - c(x) <= 0$ then the neighbor solution becomes the current + solution for the next iteration. Otherwise, the algorithm accepts + the neighbor solution with probability $p = exp - ([c(x') - c(x)] / temp)$. + Otherwise the current solution is retained. + + `temp` is a parameter of the algorithm and represents temperature. + + Time complexity: + For $N_i$ iterations of the inner loop and $N_o$ iterations of the + outer loop, this algorithm has running time $O(N_i * N_o * |V|)$. + + For more information and how the algorithm is inspired see: + http://en.wikipedia.org/wiki/Simulated_annealing + """ + if move == "1-1": + move = swap_two_nodes + elif move == "1-0": + move = move_one_node + if init_cycle == "greedy": + # Construct an initial solution using a greedy algorithm. + cycle = greedy_tsp(G, weight=weight, source=source) + if G.number_of_nodes() == 2: + return cycle + + else: + cycle = list(init_cycle) + if source is None: + source = cycle[0] + elif source != cycle[0]: + raise nx.NetworkXError("source must be first node in init_cycle") + if cycle[0] != cycle[-1]: + raise nx.NetworkXError("init_cycle must be a cycle. (return to start)") + + if len(cycle) - 1 != len(G) or len(set(G.nbunch_iter(cycle))) != len(G): + raise nx.NetworkXError("init_cycle should be a cycle over all nodes in G.") + + # Check that G is a complete graph + N = len(G) - 1 + # This check ignores selfloops which is what we want here. + if any(len(nbrdict) - (n in nbrdict) != N for n, nbrdict in G.adj.items()): + raise nx.NetworkXError("G must be a complete graph.") + + if G.number_of_nodes() == 2: + neighbor = next(G.neighbors(source)) + return [source, neighbor, source] + + # Find the cost of initial solution + cost = sum(G[u][v].get(weight, 1) for u, v in pairwise(cycle)) + + count = 0 + best_cycle = cycle.copy() + best_cost = cost + while count <= max_iterations and temp > 0: + count += 1 + for i in range(N_inner): + adj_sol = move(cycle, seed) + adj_cost = sum(G[u][v].get(weight, 1) for u, v in pairwise(adj_sol)) + delta = adj_cost - cost + if delta <= 0: + # Set current solution the adjacent solution. + cycle = adj_sol + cost = adj_cost + + if cost < best_cost: + count = 0 + best_cycle = cycle.copy() + best_cost = cost + else: + # Accept even a worse solution with probability p. + p = math.exp(-delta / temp) + if p >= seed.random(): + cycle = adj_sol + cost = adj_cost + temp -= temp * alpha + + return best_cycle + + +@py_random_state(9) +@nx._dispatchable(edge_attrs="weight") +def threshold_accepting_tsp( + G, + init_cycle, + weight="weight", + source=None, + threshold=1, + move="1-1", + max_iterations=10, + N_inner=100, + alpha=0.1, + seed=None, +): + """Returns an approximate solution to the traveling salesman problem. + + This function uses threshold accepting methods to approximate the minimal cost + cycle through the nodes. Starting from a suboptimal solution, threshold + accepting methods perturb that solution, accepting any changes that make + the solution no worse than increasing by a threshold amount. Improvements + in cost are accepted, but so are changes leading to small increases in cost. + This allows the solution to leave suboptimal local minima in solution space. + The threshold is decreased slowly as iterations proceed helping to ensure + an optimum. In summary, the function returns a cycle starting at `source` + for which the total cost is minimized. + + Parameters + ---------- + G : Graph + `G` should be a complete weighted graph. + The distance between all pairs of nodes should be included. + + init_cycle : list or "greedy" + The initial solution (a cycle through all nodes returning to the start). + This argument has no default to make you think about it. + If "greedy", use `greedy_tsp(G, weight)`. + Other common starting cycles are `list(G) + [next(iter(G))]` or the final + result of `simulated_annealing_tsp` when doing `threshold_accepting_tsp`. + + weight : string, optional (default="weight") + Edge data key corresponding to the edge weight. + If any edge does not have this attribute the weight is set to 1. + + source : node, optional (default: first node in list(G)) + Starting node. If None, defaults to ``next(iter(G))`` + + threshold : int, optional (default=1) + The algorithm's threshold parameter. It represents the initial + threshold's value + + move : "1-1" or "1-0" or function, optional (default="1-1") + Indicator of what move to use when finding new trial solutions. + Strings indicate two special built-in moves: + + - "1-1": 1-1 exchange which transposes the position + of two elements of the current solution. + The function called is :func:`swap_two_nodes`. + For example if we apply 1-1 exchange in the solution + ``A = [3, 2, 1, 4, 3]`` + we can get the following by the transposition of 1 and 4 elements: + ``A' = [3, 2, 4, 1, 3]`` + - "1-0": 1-0 exchange which moves an node in the solution + to a new position. + The function called is :func:`move_one_node`. + For example if we apply 1-0 exchange in the solution + ``A = [3, 2, 1, 4, 3]`` + we can transfer the fourth element to the second position: + ``A' = [3, 4, 2, 1, 3]`` + + You may provide your own functions to enact a move from + one solution to a neighbor solution. The function must take + the solution as input along with a `seed` input to control + random number generation (see the `seed` input here). + Your function should maintain the solution as a cycle with + equal first and last node and all others appearing once. + Your function should return the new solution. + + max_iterations : int, optional (default=10) + Declared done when this number of consecutive iterations of + the outer loop occurs without any change in the best cost solution. + + N_inner : int, optional (default=100) + The number of iterations of the inner loop. + + alpha : float between (0, 1), optional (default=0.1) + Percentage of threshold decrease when there is at + least one acceptance of a neighbor solution. + If no inner loop moves are accepted the threshold remains unchanged. + + seed : integer, random_state, or None (default) + Indicator of random number generation state. + See :ref:`Randomness`. + + Returns + ------- + cycle : list of nodes + Returns the cycle (list of nodes) that a salesman + can follow to minimize total weight of the trip. + + Raises + ------ + NetworkXError + If `G` is not complete the algorithm raises an exception. + + Examples + -------- + >>> from networkx.algorithms import approximation as approx + >>> G = nx.DiGraph() + >>> G.add_weighted_edges_from( + ... { + ... ("A", "B", 3), + ... ("A", "C", 17), + ... ("A", "D", 14), + ... ("B", "A", 3), + ... ("B", "C", 12), + ... ("B", "D", 16), + ... ("C", "A", 13), + ... ("C", "B", 12), + ... ("C", "D", 4), + ... ("D", "A", 14), + ... ("D", "B", 15), + ... ("D", "C", 2), + ... } + ... ) + >>> cycle = approx.threshold_accepting_tsp(G, "greedy", source="D") + >>> cost = sum(G[n][nbr]["weight"] for n, nbr in nx.utils.pairwise(cycle)) + >>> cycle + ['D', 'C', 'B', 'A', 'D'] + >>> cost + 31 + >>> incycle = ["D", "B", "A", "C", "D"] + >>> cycle = approx.threshold_accepting_tsp(G, incycle, source="D") + >>> cost = sum(G[n][nbr]["weight"] for n, nbr in nx.utils.pairwise(cycle)) + >>> cycle + ['D', 'C', 'B', 'A', 'D'] + >>> cost + 31 + + Notes + ----- + Threshold Accepting is a metaheuristic local search algorithm. + The main characteristic of this algorithm is that it accepts + even solutions which lead to the increase of the cost in order + to escape from low quality local optimal solutions. + + This algorithm needs an initial solution. This solution can be + constructed by a simple greedy algorithm. At every iteration, it + selects thoughtfully a neighbor solution. + Consider $c(x)$ cost of current solution and $c(x')$ cost of + neighbor solution. + If $c(x') - c(x) <= threshold$ then the neighbor solution becomes the current + solution for the next iteration, where the threshold is named threshold. + + In comparison to the Simulated Annealing algorithm, the Threshold + Accepting algorithm does not accept very low quality solutions + (due to the presence of the threshold value). In the case of + Simulated Annealing, even a very low quality solution can + be accepted with probability $p$. + + Time complexity: + It has a running time $O(m * n * |V|)$ where $m$ and $n$ are the number + of times the outer and inner loop run respectively. + + For more information and how algorithm is inspired see: + https://doi.org/10.1016/0021-9991(90)90201-B + + See Also + -------- + simulated_annealing_tsp + + """ + if move == "1-1": + move = swap_two_nodes + elif move == "1-0": + move = move_one_node + if init_cycle == "greedy": + # Construct an initial solution using a greedy algorithm. + cycle = greedy_tsp(G, weight=weight, source=source) + if G.number_of_nodes() == 2: + return cycle + + else: + cycle = list(init_cycle) + if source is None: + source = cycle[0] + elif source != cycle[0]: + raise nx.NetworkXError("source must be first node in init_cycle") + if cycle[0] != cycle[-1]: + raise nx.NetworkXError("init_cycle must be a cycle. (return to start)") + + if len(cycle) - 1 != len(G) or len(set(G.nbunch_iter(cycle))) != len(G): + raise nx.NetworkXError("init_cycle is not all and only nodes.") + + # Check that G is a complete graph + N = len(G) - 1 + # This check ignores selfloops which is what we want here. + if any(len(nbrdict) - (n in nbrdict) != N for n, nbrdict in G.adj.items()): + raise nx.NetworkXError("G must be a complete graph.") + + if G.number_of_nodes() == 2: + neighbor = list(G.neighbors(source))[0] + return [source, neighbor, source] + + # Find the cost of initial solution + cost = sum(G[u][v].get(weight, 1) for u, v in pairwise(cycle)) + + count = 0 + best_cycle = cycle.copy() + best_cost = cost + while count <= max_iterations: + count += 1 + accepted = False + for i in range(N_inner): + adj_sol = move(cycle, seed) + adj_cost = sum(G[u][v].get(weight, 1) for u, v in pairwise(adj_sol)) + delta = adj_cost - cost + if delta <= threshold: + accepted = True + + # Set current solution the adjacent solution. + cycle = adj_sol + cost = adj_cost + + if cost < best_cost: + count = 0 + best_cycle = cycle.copy() + best_cost = cost + if accepted: + threshold -= threshold * alpha + + return best_cycle diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/treewidth.py b/lib/python3.10/site-packages/networkx/algorithms/approximation/treewidth.py new file mode 100644 index 0000000000000000000000000000000000000000..31d73f6368237c16ab6a66efee45e768ddfaef52 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/approximation/treewidth.py @@ -0,0 +1,252 @@ +"""Functions for computing treewidth decomposition. + +Treewidth of an undirected graph is a number associated with the graph. +It can be defined as the size of the largest vertex set (bag) in a tree +decomposition of the graph minus one. + +`Wikipedia: Treewidth `_ + +The notions of treewidth and tree decomposition have gained their +attractiveness partly because many graph and network problems that are +intractable (e.g., NP-hard) on arbitrary graphs become efficiently +solvable (e.g., with a linear time algorithm) when the treewidth of the +input graphs is bounded by a constant [1]_ [2]_. + +There are two different functions for computing a tree decomposition: +:func:`treewidth_min_degree` and :func:`treewidth_min_fill_in`. + +.. [1] Hans L. Bodlaender and Arie M. C. A. Koster. 2010. "Treewidth + computations I.Upper bounds". Inf. Comput. 208, 3 (March 2010),259-275. + http://dx.doi.org/10.1016/j.ic.2009.03.008 + +.. [2] Hans L. Bodlaender. "Discovering Treewidth". Institute of Information + and Computing Sciences, Utrecht University. + Technical Report UU-CS-2005-018. + http://www.cs.uu.nl + +.. [3] K. Wang, Z. Lu, and J. Hicks *Treewidth*. + https://web.archive.org/web/20210507025929/http://web.eecs.utk.edu/~cphill25/cs594_spring2015_projects/treewidth.pdf + +""" + +import itertools +import sys +from heapq import heapify, heappop, heappush + +import networkx as nx +from networkx.utils import not_implemented_for + +__all__ = ["treewidth_min_degree", "treewidth_min_fill_in"] + + +@not_implemented_for("directed") +@not_implemented_for("multigraph") +@nx._dispatchable(returns_graph=True) +def treewidth_min_degree(G): + """Returns a treewidth decomposition using the Minimum Degree heuristic. + + The heuristic chooses the nodes according to their degree, i.e., first + the node with the lowest degree is chosen, then the graph is updated + and the corresponding node is removed. Next, a new node with the lowest + degree is chosen, and so on. + + Parameters + ---------- + G : NetworkX graph + + Returns + ------- + Treewidth decomposition : (int, Graph) tuple + 2-tuple with treewidth and the corresponding decomposed tree. + """ + deg_heuristic = MinDegreeHeuristic(G) + return treewidth_decomp(G, lambda graph: deg_heuristic.best_node(graph)) + + +@not_implemented_for("directed") +@not_implemented_for("multigraph") +@nx._dispatchable(returns_graph=True) +def treewidth_min_fill_in(G): + """Returns a treewidth decomposition using the Minimum Fill-in heuristic. + + The heuristic chooses a node from the graph, where the number of edges + added turning the neighborhood of the chosen node into clique is as + small as possible. + + Parameters + ---------- + G : NetworkX graph + + Returns + ------- + Treewidth decomposition : (int, Graph) tuple + 2-tuple with treewidth and the corresponding decomposed tree. + """ + return treewidth_decomp(G, min_fill_in_heuristic) + + +class MinDegreeHeuristic: + """Implements the Minimum Degree heuristic. + + The heuristic chooses the nodes according to their degree + (number of neighbors), i.e., first the node with the lowest degree is + chosen, then the graph is updated and the corresponding node is + removed. Next, a new node with the lowest degree is chosen, and so on. + """ + + def __init__(self, graph): + self._graph = graph + + # nodes that have to be updated in the heap before each iteration + self._update_nodes = [] + + self._degreeq = [] # a heapq with 3-tuples (degree,unique_id,node) + self.count = itertools.count() + + # build heap with initial degrees + for n in graph: + self._degreeq.append((len(graph[n]), next(self.count), n)) + heapify(self._degreeq) + + def best_node(self, graph): + # update nodes in self._update_nodes + for n in self._update_nodes: + # insert changed degrees into degreeq + heappush(self._degreeq, (len(graph[n]), next(self.count), n)) + + # get the next valid (minimum degree) node + while self._degreeq: + (min_degree, _, elim_node) = heappop(self._degreeq) + if elim_node not in graph or len(graph[elim_node]) != min_degree: + # outdated entry in degreeq + continue + elif min_degree == len(graph) - 1: + # fully connected: abort condition + return None + + # remember to update nodes in the heap before getting the next node + self._update_nodes = graph[elim_node] + return elim_node + + # the heap is empty: abort + return None + + +def min_fill_in_heuristic(graph): + """Implements the Minimum Degree heuristic. + + Returns the node from the graph, where the number of edges added when + turning the neighborhood of the chosen node into clique is as small as + possible. This algorithm chooses the nodes using the Minimum Fill-In + heuristic. The running time of the algorithm is :math:`O(V^3)` and it uses + additional constant memory.""" + + if len(graph) == 0: + return None + + min_fill_in_node = None + + min_fill_in = sys.maxsize + + # sort nodes by degree + nodes_by_degree = sorted(graph, key=lambda x: len(graph[x])) + min_degree = len(graph[nodes_by_degree[0]]) + + # abort condition (handle complete graph) + if min_degree == len(graph) - 1: + return None + + for node in nodes_by_degree: + num_fill_in = 0 + nbrs = graph[node] + for nbr in nbrs: + # count how many nodes in nbrs current nbr is not connected to + # subtract 1 for the node itself + num_fill_in += len(nbrs - graph[nbr]) - 1 + if num_fill_in >= 2 * min_fill_in: + break + + num_fill_in /= 2 # divide by 2 because of double counting + + if num_fill_in < min_fill_in: # update min-fill-in node + if num_fill_in == 0: + return node + min_fill_in = num_fill_in + min_fill_in_node = node + + return min_fill_in_node + + +@nx._dispatchable(returns_graph=True) +def treewidth_decomp(G, heuristic=min_fill_in_heuristic): + """Returns a treewidth decomposition using the passed heuristic. + + Parameters + ---------- + G : NetworkX graph + heuristic : heuristic function + + Returns + ------- + Treewidth decomposition : (int, Graph) tuple + 2-tuple with treewidth and the corresponding decomposed tree. + """ + + # make dict-of-sets structure + graph = {n: set(G[n]) - {n} for n in G} + + # stack containing nodes and neighbors in the order from the heuristic + node_stack = [] + + # get first node from heuristic + elim_node = heuristic(graph) + while elim_node is not None: + # connect all neighbors with each other + nbrs = graph[elim_node] + for u, v in itertools.permutations(nbrs, 2): + if v not in graph[u]: + graph[u].add(v) + + # push node and its current neighbors on stack + node_stack.append((elim_node, nbrs)) + + # remove node from graph + for u in graph[elim_node]: + graph[u].remove(elim_node) + + del graph[elim_node] + elim_node = heuristic(graph) + + # the abort condition is met; put all remaining nodes into one bag + decomp = nx.Graph() + first_bag = frozenset(graph.keys()) + decomp.add_node(first_bag) + + treewidth = len(first_bag) - 1 + + while node_stack: + # get node and its neighbors from the stack + (curr_node, nbrs) = node_stack.pop() + + # find a bag all neighbors are in + old_bag = None + for bag in decomp.nodes: + if nbrs <= bag: + old_bag = bag + break + + if old_bag is None: + # no old_bag was found: just connect to the first_bag + old_bag = first_bag + + # create new node for decomposition + nbrs.add(curr_node) + new_bag = frozenset(nbrs) + + # update treewidth + treewidth = max(treewidth, len(new_bag) - 1) + + # add edge to decomposition (implicitly also adds the new node) + decomp.add_edge(old_bag, new_bag) + + return treewidth, decomp diff --git a/lib/python3.10/site-packages/networkx/algorithms/approximation/vertex_cover.py b/lib/python3.10/site-packages/networkx/algorithms/approximation/vertex_cover.py new file mode 100644 index 0000000000000000000000000000000000000000..13d7167cfc1e4494cbbb2ee8c774e9ffbc3ee495 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/approximation/vertex_cover.py @@ -0,0 +1,83 @@ +"""Functions for computing an approximate minimum weight vertex cover. + +A |vertex cover|_ is a subset of nodes such that each edge in the graph +is incident to at least one node in the subset. + +.. _vertex cover: https://en.wikipedia.org/wiki/Vertex_cover +.. |vertex cover| replace:: *vertex cover* + +""" + +import networkx as nx + +__all__ = ["min_weighted_vertex_cover"] + + +@nx._dispatchable(node_attrs="weight") +def min_weighted_vertex_cover(G, weight=None): + r"""Returns an approximate minimum weighted vertex cover. + + The set of nodes returned by this function is guaranteed to be a + vertex cover, and the total weight of the set is guaranteed to be at + most twice the total weight of the minimum weight vertex cover. In + other words, + + .. math:: + + w(S) \leq 2 * w(S^*), + + where $S$ is the vertex cover returned by this function, + $S^*$ is the vertex cover of minimum weight out of all vertex + covers of the graph, and $w$ is the function that computes the + sum of the weights of each node in that given set. + + Parameters + ---------- + G : NetworkX graph + + weight : string, optional (default = None) + If None, every node has weight 1. If a string, use this node + attribute as the node weight. A node without this attribute is + assumed to have weight 1. + + Returns + ------- + min_weighted_cover : set + Returns a set of nodes whose weight sum is no more than twice + the weight sum of the minimum weight vertex cover. + + Notes + ----- + For a directed graph, a vertex cover has the same definition: a set + of nodes such that each edge in the graph is incident to at least + one node in the set. Whether the node is the head or tail of the + directed edge is ignored. + + This is the local-ratio algorithm for computing an approximate + vertex cover. The algorithm greedily reduces the costs over edges, + iteratively building a cover. The worst-case runtime of this + implementation is $O(m \log n)$, where $n$ is the number + of nodes and $m$ the number of edges in the graph. + + References + ---------- + .. [1] Bar-Yehuda, R., and Even, S. (1985). "A local-ratio theorem for + approximating the weighted vertex cover problem." + *Annals of Discrete Mathematics*, 25, 27–46 + + + """ + cost = dict(G.nodes(data=weight, default=1)) + # While there are uncovered edges, choose an uncovered and update + # the cost of the remaining edges. + cover = set() + for u, v in G.edges(): + if u in cover or v in cover: + continue + if cost[u] <= cost[v]: + cover.add(u) + cost[v] -= cost[u] + else: + cover.add(v) + cost[u] -= cost[v] + return cover diff --git a/lib/python3.10/site-packages/networkx/algorithms/assortativity/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/assortativity/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4d9888609cbc43d4ba2121fcd0feda0985d1aebd --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/assortativity/__init__.py @@ -0,0 +1,5 @@ +from networkx.algorithms.assortativity.connectivity import * +from networkx.algorithms.assortativity.correlation import * +from networkx.algorithms.assortativity.mixing import * +from networkx.algorithms.assortativity.neighbor_degree import * +from networkx.algorithms.assortativity.pairs import * diff --git a/lib/python3.10/site-packages/networkx/algorithms/assortativity/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/assortativity/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b78c60721c913354db35d1c05dcdfa2be2ac5695 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/assortativity/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/assortativity/__pycache__/connectivity.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/assortativity/__pycache__/connectivity.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee02b2a8796177a9f16eb50f4c5aea71f7ab4257 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/assortativity/__pycache__/connectivity.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/assortativity/__pycache__/correlation.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/assortativity/__pycache__/correlation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d805ea37f6ee7d6c7e03d8aae05c7199193aca8 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/assortativity/__pycache__/correlation.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/assortativity/__pycache__/mixing.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/assortativity/__pycache__/mixing.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc9f1a60068cdd0ac4dc197408ffb39c8becc3fc Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/assortativity/__pycache__/mixing.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/assortativity/__pycache__/neighbor_degree.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/assortativity/__pycache__/neighbor_degree.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62ca804a85cffc6e03cc74866eb23443a1597fbc Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/assortativity/__pycache__/neighbor_degree.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/assortativity/__pycache__/pairs.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/assortativity/__pycache__/pairs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f329afb6ed7080f41377fda2b88a3c8a7e7824fd Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/assortativity/__pycache__/pairs.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/assortativity/connectivity.py b/lib/python3.10/site-packages/networkx/algorithms/assortativity/connectivity.py new file mode 100644 index 0000000000000000000000000000000000000000..c3fde0da68a1990da29ced6996620d709c52c13d --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/assortativity/connectivity.py @@ -0,0 +1,122 @@ +from collections import defaultdict + +import networkx as nx + +__all__ = ["average_degree_connectivity"] + + +@nx._dispatchable(edge_attrs="weight") +def average_degree_connectivity( + G, source="in+out", target="in+out", nodes=None, weight=None +): + r"""Compute the average degree connectivity of graph. + + The average degree connectivity is the average nearest neighbor degree of + nodes with degree k. For weighted graphs, an analogous measure can + be computed using the weighted average neighbors degree defined in + [1]_, for a node `i`, as + + .. math:: + + k_{nn,i}^{w} = \frac{1}{s_i} \sum_{j \in N(i)} w_{ij} k_j + + where `s_i` is the weighted degree of node `i`, + `w_{ij}` is the weight of the edge that links `i` and `j`, + and `N(i)` are the neighbors of node `i`. + + Parameters + ---------- + G : NetworkX graph + + source : "in"|"out"|"in+out" (default:"in+out") + Directed graphs only. Use "in"- or "out"-degree for source node. + + target : "in"|"out"|"in+out" (default:"in+out" + Directed graphs only. Use "in"- or "out"-degree for target node. + + nodes : list or iterable (optional) + Compute neighbor connectivity for these nodes. The default is all + nodes. + + weight : string or None, optional (default=None) + The edge attribute that holds the numerical value used as a weight. + If None, then each edge has weight 1. + + Returns + ------- + d : dict + A dictionary keyed by degree k with the value of average connectivity. + + Raises + ------ + NetworkXError + If either `source` or `target` are not one of 'in', + 'out', or 'in+out'. + If either `source` or `target` is passed for an undirected graph. + + Examples + -------- + >>> G = nx.path_graph(4) + >>> G.edges[1, 2]["weight"] = 3 + >>> nx.average_degree_connectivity(G) + {1: 2.0, 2: 1.5} + >>> nx.average_degree_connectivity(G, weight="weight") + {1: 2.0, 2: 1.75} + + See Also + -------- + average_neighbor_degree + + References + ---------- + .. [1] A. Barrat, M. Barthélemy, R. Pastor-Satorras, and A. Vespignani, + "The architecture of complex weighted networks". + PNAS 101 (11): 3747–3752 (2004). + """ + # First, determine the type of neighbors and the type of degree to use. + if G.is_directed(): + if source not in ("in", "out", "in+out"): + raise nx.NetworkXError('source must be one of "in", "out", or "in+out"') + if target not in ("in", "out", "in+out"): + raise nx.NetworkXError('target must be one of "in", "out", or "in+out"') + direction = {"out": G.out_degree, "in": G.in_degree, "in+out": G.degree} + neighbor_funcs = { + "out": G.successors, + "in": G.predecessors, + "in+out": G.neighbors, + } + source_degree = direction[source] + target_degree = direction[target] + neighbors = neighbor_funcs[source] + # `reverse` indicates whether to look at the in-edge when + # computing the weight of an edge. + reverse = source == "in" + else: + if source != "in+out" or target != "in+out": + raise nx.NetworkXError( + f"source and target arguments are only supported for directed graphs" + ) + source_degree = G.degree + target_degree = G.degree + neighbors = G.neighbors + reverse = False + dsum = defaultdict(int) + dnorm = defaultdict(int) + # Check if `source_nodes` is actually a single node in the graph. + source_nodes = source_degree(nodes) + if nodes in G: + source_nodes = [(nodes, source_degree(nodes))] + for n, k in source_nodes: + nbrdeg = target_degree(neighbors(n)) + if weight is None: + s = sum(d for n, d in nbrdeg) + else: # weight nbr degree by weight of (n,nbr) edge + if reverse: + s = sum(G[nbr][n].get(weight, 1) * d for nbr, d in nbrdeg) + else: + s = sum(G[n][nbr].get(weight, 1) * d for nbr, d in nbrdeg) + dnorm[k] += source_degree(n, weight=weight) + dsum[k] += s + + # normalize + return {k: avg if dnorm[k] == 0 else avg / dnorm[k] for k, avg in dsum.items()} diff --git a/lib/python3.10/site-packages/networkx/algorithms/assortativity/correlation.py b/lib/python3.10/site-packages/networkx/algorithms/assortativity/correlation.py new file mode 100644 index 0000000000000000000000000000000000000000..52ae7a12fa9de5705412538fc6bbe873755d9b7a --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/assortativity/correlation.py @@ -0,0 +1,302 @@ +"""Node assortativity coefficients and correlation measures.""" + +import networkx as nx +from networkx.algorithms.assortativity.mixing import ( + attribute_mixing_matrix, + degree_mixing_matrix, +) +from networkx.algorithms.assortativity.pairs import node_degree_xy + +__all__ = [ + "degree_pearson_correlation_coefficient", + "degree_assortativity_coefficient", + "attribute_assortativity_coefficient", + "numeric_assortativity_coefficient", +] + + +@nx._dispatchable(edge_attrs="weight") +def degree_assortativity_coefficient(G, x="out", y="in", weight=None, nodes=None): + """Compute degree assortativity of graph. + + Assortativity measures the similarity of connections + in the graph with respect to the node degree. + + Parameters + ---------- + G : NetworkX graph + + x: string ('in','out') + The degree type for source node (directed graphs only). + + y: string ('in','out') + The degree type for target node (directed graphs only). + + weight: string or None, optional (default=None) + The edge attribute that holds the numerical value used + as a weight. If None, then each edge has weight 1. + The degree is the sum of the edge weights adjacent to the node. + + nodes: list or iterable (optional) + Compute degree assortativity only for nodes in container. + The default is all nodes. + + Returns + ------- + r : float + Assortativity of graph by degree. + + Examples + -------- + >>> G = nx.path_graph(4) + >>> r = nx.degree_assortativity_coefficient(G) + >>> print(f"{r:3.1f}") + -0.5 + + See Also + -------- + attribute_assortativity_coefficient + numeric_assortativity_coefficient + degree_mixing_dict + degree_mixing_matrix + + Notes + ----- + This computes Eq. (21) in Ref. [1]_ , where e is the joint + probability distribution (mixing matrix) of the degrees. If G is + directed than the matrix e is the joint probability of the + user-specified degree type for the source and target. + + References + ---------- + .. [1] M. E. J. Newman, Mixing patterns in networks, + Physical Review E, 67 026126, 2003 + .. [2] Foster, J.G., Foster, D.V., Grassberger, P. & Paczuski, M. + Edge direction and the structure of networks, PNAS 107, 10815-20 (2010). + """ + if nodes is None: + nodes = G.nodes + + degrees = None + + if G.is_directed(): + indeg = ( + {d for _, d in G.in_degree(nodes, weight=weight)} + if "in" in (x, y) + else set() + ) + outdeg = ( + {d for _, d in G.out_degree(nodes, weight=weight)} + if "out" in (x, y) + else set() + ) + degrees = set.union(indeg, outdeg) + else: + degrees = {d for _, d in G.degree(nodes, weight=weight)} + + mapping = {d: i for i, d in enumerate(degrees)} + M = degree_mixing_matrix(G, x=x, y=y, nodes=nodes, weight=weight, mapping=mapping) + + return _numeric_ac(M, mapping=mapping) + + +@nx._dispatchable(edge_attrs="weight") +def degree_pearson_correlation_coefficient(G, x="out", y="in", weight=None, nodes=None): + """Compute degree assortativity of graph. + + Assortativity measures the similarity of connections + in the graph with respect to the node degree. + + This is the same as degree_assortativity_coefficient but uses the + potentially faster scipy.stats.pearsonr function. + + Parameters + ---------- + G : NetworkX graph + + x: string ('in','out') + The degree type for source node (directed graphs only). + + y: string ('in','out') + The degree type for target node (directed graphs only). + + weight: string or None, optional (default=None) + The edge attribute that holds the numerical value used + as a weight. If None, then each edge has weight 1. + The degree is the sum of the edge weights adjacent to the node. + + nodes: list or iterable (optional) + Compute pearson correlation of degrees only for specified nodes. + The default is all nodes. + + Returns + ------- + r : float + Assortativity of graph by degree. + + Examples + -------- + >>> G = nx.path_graph(4) + >>> r = nx.degree_pearson_correlation_coefficient(G) + >>> print(f"{r:3.1f}") + -0.5 + + Notes + ----- + This calls scipy.stats.pearsonr. + + References + ---------- + .. [1] M. E. J. Newman, Mixing patterns in networks + Physical Review E, 67 026126, 2003 + .. [2] Foster, J.G., Foster, D.V., Grassberger, P. & Paczuski, M. + Edge direction and the structure of networks, PNAS 107, 10815-20 (2010). + """ + import scipy as sp + + xy = node_degree_xy(G, x=x, y=y, nodes=nodes, weight=weight) + x, y = zip(*xy) + return float(sp.stats.pearsonr(x, y)[0]) + + +@nx._dispatchable(node_attrs="attribute") +def attribute_assortativity_coefficient(G, attribute, nodes=None): + """Compute assortativity for node attributes. + + Assortativity measures the similarity of connections + in the graph with respect to the given attribute. + + Parameters + ---------- + G : NetworkX graph + + attribute : string + Node attribute key + + nodes: list or iterable (optional) + Compute attribute assortativity for nodes in container. + The default is all nodes. + + Returns + ------- + r: float + Assortativity of graph for given attribute + + Examples + -------- + >>> G = nx.Graph() + >>> G.add_nodes_from([0, 1], color="red") + >>> G.add_nodes_from([2, 3], color="blue") + >>> G.add_edges_from([(0, 1), (2, 3)]) + >>> print(nx.attribute_assortativity_coefficient(G, "color")) + 1.0 + + Notes + ----- + This computes Eq. (2) in Ref. [1]_ , (trace(M)-sum(M^2))/(1-sum(M^2)), + where M is the joint probability distribution (mixing matrix) + of the specified attribute. + + References + ---------- + .. [1] M. E. J. Newman, Mixing patterns in networks, + Physical Review E, 67 026126, 2003 + """ + M = attribute_mixing_matrix(G, attribute, nodes) + return attribute_ac(M) + + +@nx._dispatchable(node_attrs="attribute") +def numeric_assortativity_coefficient(G, attribute, nodes=None): + """Compute assortativity for numerical node attributes. + + Assortativity measures the similarity of connections + in the graph with respect to the given numeric attribute. + + Parameters + ---------- + G : NetworkX graph + + attribute : string + Node attribute key. + + nodes: list or iterable (optional) + Compute numeric assortativity only for attributes of nodes in + container. The default is all nodes. + + Returns + ------- + r: float + Assortativity of graph for given attribute + + Examples + -------- + >>> G = nx.Graph() + >>> G.add_nodes_from([0, 1], size=2) + >>> G.add_nodes_from([2, 3], size=3) + >>> G.add_edges_from([(0, 1), (2, 3)]) + >>> print(nx.numeric_assortativity_coefficient(G, "size")) + 1.0 + + Notes + ----- + This computes Eq. (21) in Ref. [1]_ , which is the Pearson correlation + coefficient of the specified (scalar valued) attribute across edges. + + References + ---------- + .. [1] M. E. J. Newman, Mixing patterns in networks + Physical Review E, 67 026126, 2003 + """ + if nodes is None: + nodes = G.nodes + vals = {G.nodes[n][attribute] for n in nodes} + mapping = {d: i for i, d in enumerate(vals)} + M = attribute_mixing_matrix(G, attribute, nodes, mapping) + return _numeric_ac(M, mapping) + + +def attribute_ac(M): + """Compute assortativity for attribute matrix M. + + Parameters + ---------- + M : numpy.ndarray + 2D ndarray representing the attribute mixing matrix. + + Notes + ----- + This computes Eq. (2) in Ref. [1]_ , (trace(e)-sum(e^2))/(1-sum(e^2)), + where e is the joint probability distribution (mixing matrix) + of the specified attribute. + + References + ---------- + .. [1] M. E. J. Newman, Mixing patterns in networks, + Physical Review E, 67 026126, 2003 + """ + if M.sum() != 1.0: + M = M / M.sum() + s = (M @ M).sum() + t = M.trace() + r = (t - s) / (1 - s) + return float(r) + + +def _numeric_ac(M, mapping): + # M is a 2D numpy array + # numeric assortativity coefficient, pearsonr + import numpy as np + + if M.sum() != 1.0: + M = M / M.sum() + x = np.array(list(mapping.keys())) + y = x # x and y have the same support + idx = list(mapping.values()) + a = M.sum(axis=0) + b = M.sum(axis=1) + vara = (a[idx] * x**2).sum() - ((a[idx] * x).sum()) ** 2 + varb = (b[idx] * y**2).sum() - ((b[idx] * y).sum()) ** 2 + xy = np.outer(x, y) + ab = np.outer(a[idx], b[idx]) + return float((xy * (M - ab)).sum() / np.sqrt(vara * varb)) diff --git a/lib/python3.10/site-packages/networkx/algorithms/assortativity/mixing.py b/lib/python3.10/site-packages/networkx/algorithms/assortativity/mixing.py new file mode 100644 index 0000000000000000000000000000000000000000..1762d4e56c96624ecb4cccf1f2247f46159a12e4 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/assortativity/mixing.py @@ -0,0 +1,255 @@ +""" +Mixing matrices for node attributes and degree. +""" + +import networkx as nx +from networkx.algorithms.assortativity.pairs import node_attribute_xy, node_degree_xy +from networkx.utils import dict_to_numpy_array + +__all__ = [ + "attribute_mixing_matrix", + "attribute_mixing_dict", + "degree_mixing_matrix", + "degree_mixing_dict", + "mixing_dict", +] + + +@nx._dispatchable(node_attrs="attribute") +def attribute_mixing_dict(G, attribute, nodes=None, normalized=False): + """Returns dictionary representation of mixing matrix for attribute. + + Parameters + ---------- + G : graph + NetworkX graph object. + + attribute : string + Node attribute key. + + nodes: list or iterable (optional) + Unse nodes in container to build the dict. The default is all nodes. + + normalized : bool (default=False) + Return counts if False or probabilities if True. + + Examples + -------- + >>> G = nx.Graph() + >>> G.add_nodes_from([0, 1], color="red") + >>> G.add_nodes_from([2, 3], color="blue") + >>> G.add_edge(1, 3) + >>> d = nx.attribute_mixing_dict(G, "color") + >>> print(d["red"]["blue"]) + 1 + >>> print(d["blue"]["red"]) # d symmetric for undirected graphs + 1 + + Returns + ------- + d : dictionary + Counts or joint probability of occurrence of attribute pairs. + """ + xy_iter = node_attribute_xy(G, attribute, nodes) + return mixing_dict(xy_iter, normalized=normalized) + + +@nx._dispatchable(node_attrs="attribute") +def attribute_mixing_matrix(G, attribute, nodes=None, mapping=None, normalized=True): + """Returns mixing matrix for attribute. + + Parameters + ---------- + G : graph + NetworkX graph object. + + attribute : string + Node attribute key. + + nodes: list or iterable (optional) + Use only nodes in container to build the matrix. The default is + all nodes. + + mapping : dictionary, optional + Mapping from node attribute to integer index in matrix. + If not specified, an arbitrary ordering will be used. + + normalized : bool (default=True) + Return counts if False or probabilities if True. + + Returns + ------- + m: numpy array + Counts or joint probability of occurrence of attribute pairs. + + Notes + ----- + If each node has a unique attribute value, the unnormalized mixing matrix + will be equal to the adjacency matrix. To get a denser mixing matrix, + the rounding can be performed to form groups of nodes with equal values. + For example, the exact height of persons in cm (180.79155222, 163.9080892, + 163.30095355, 167.99016217, 168.21590163, ...) can be rounded to (180, 163, + 163, 168, 168, ...). + + Definitions of attribute mixing matrix vary on whether the matrix + should include rows for attribute values that don't arise. Here we + do not include such empty-rows. But you can force them to appear + by inputting a `mapping` that includes those values. + + Examples + -------- + >>> G = nx.path_graph(3) + >>> gender = {0: "male", 1: "female", 2: "female"} + >>> nx.set_node_attributes(G, gender, "gender") + >>> mapping = {"male": 0, "female": 1} + >>> mix_mat = nx.attribute_mixing_matrix(G, "gender", mapping=mapping) + >>> mix_mat + array([[0. , 0.25], + [0.25, 0.5 ]]) + """ + d = attribute_mixing_dict(G, attribute, nodes) + a = dict_to_numpy_array(d, mapping=mapping) + if normalized: + a = a / a.sum() + return a + + +@nx._dispatchable(edge_attrs="weight") +def degree_mixing_dict(G, x="out", y="in", weight=None, nodes=None, normalized=False): + """Returns dictionary representation of mixing matrix for degree. + + Parameters + ---------- + G : graph + NetworkX graph object. + + x: string ('in','out') + The degree type for source node (directed graphs only). + + y: string ('in','out') + The degree type for target node (directed graphs only). + + weight: string or None, optional (default=None) + The edge attribute that holds the numerical value used + as a weight. If None, then each edge has weight 1. + The degree is the sum of the edge weights adjacent to the node. + + normalized : bool (default=False) + Return counts if False or probabilities if True. + + Returns + ------- + d: dictionary + Counts or joint probability of occurrence of degree pairs. + """ + xy_iter = node_degree_xy(G, x=x, y=y, nodes=nodes, weight=weight) + return mixing_dict(xy_iter, normalized=normalized) + + +@nx._dispatchable(edge_attrs="weight") +def degree_mixing_matrix( + G, x="out", y="in", weight=None, nodes=None, normalized=True, mapping=None +): + """Returns mixing matrix for attribute. + + Parameters + ---------- + G : graph + NetworkX graph object. + + x: string ('in','out') + The degree type for source node (directed graphs only). + + y: string ('in','out') + The degree type for target node (directed graphs only). + + nodes: list or iterable (optional) + Build the matrix using only nodes in container. + The default is all nodes. + + weight: string or None, optional (default=None) + The edge attribute that holds the numerical value used + as a weight. If None, then each edge has weight 1. + The degree is the sum of the edge weights adjacent to the node. + + normalized : bool (default=True) + Return counts if False or probabilities if True. + + mapping : dictionary, optional + Mapping from node degree to integer index in matrix. + If not specified, an arbitrary ordering will be used. + + Returns + ------- + m: numpy array + Counts, or joint probability, of occurrence of node degree. + + Notes + ----- + Definitions of degree mixing matrix vary on whether the matrix + should include rows for degree values that don't arise. Here we + do not include such empty-rows. But you can force them to appear + by inputting a `mapping` that includes those values. See examples. + + Examples + -------- + >>> G = nx.star_graph(3) + >>> mix_mat = nx.degree_mixing_matrix(G) + >>> mix_mat + array([[0. , 0.5], + [0.5, 0. ]]) + + If you want every possible degree to appear as a row, even if no nodes + have that degree, use `mapping` as follows, + + >>> max_degree = max(deg for n, deg in G.degree) + >>> mapping = {x: x for x in range(max_degree + 1)} # identity mapping + >>> mix_mat = nx.degree_mixing_matrix(G, mapping=mapping) + >>> mix_mat + array([[0. , 0. , 0. , 0. ], + [0. , 0. , 0. , 0.5], + [0. , 0. , 0. , 0. ], + [0. , 0.5, 0. , 0. ]]) + """ + d = degree_mixing_dict(G, x=x, y=y, nodes=nodes, weight=weight) + a = dict_to_numpy_array(d, mapping=mapping) + if normalized: + a = a / a.sum() + return a + + +def mixing_dict(xy, normalized=False): + """Returns a dictionary representation of mixing matrix. + + Parameters + ---------- + xy : list or container of two-tuples + Pairs of (x,y) items. + + attribute : string + Node attribute key + + normalized : bool (default=False) + Return counts if False or probabilities if True. + + Returns + ------- + d: dictionary + Counts or Joint probability of occurrence of values in xy. + """ + d = {} + psum = 0.0 + for x, y in xy: + if x not in d: + d[x] = {} + if y not in d: + d[y] = {} + v = d[x].get(y, 0) + d[x][y] = v + 1 + psum += 1 + + if normalized: + for _, jdict in d.items(): + for j in jdict: + jdict[j] /= psum + return d diff --git a/lib/python3.10/site-packages/networkx/algorithms/assortativity/neighbor_degree.py b/lib/python3.10/site-packages/networkx/algorithms/assortativity/neighbor_degree.py new file mode 100644 index 0000000000000000000000000000000000000000..6488d041a8bdc93ef3591283781b81bcf7f47dab --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/assortativity/neighbor_degree.py @@ -0,0 +1,160 @@ +import networkx as nx + +__all__ = ["average_neighbor_degree"] + + +@nx._dispatchable(edge_attrs="weight") +def average_neighbor_degree(G, source="out", target="out", nodes=None, weight=None): + r"""Returns the average degree of the neighborhood of each node. + + In an undirected graph, the neighborhood `N(i)` of node `i` contains the + nodes that are connected to `i` by an edge. + + For directed graphs, `N(i)` is defined according to the parameter `source`: + + - if source is 'in', then `N(i)` consists of predecessors of node `i`. + - if source is 'out', then `N(i)` consists of successors of node `i`. + - if source is 'in+out', then `N(i)` is both predecessors and successors. + + The average neighborhood degree of a node `i` is + + .. math:: + + k_{nn,i} = \frac{1}{|N(i)|} \sum_{j \in N(i)} k_j + + where `N(i)` are the neighbors of node `i` and `k_j` is + the degree of node `j` which belongs to `N(i)`. For weighted + graphs, an analogous measure can be defined [1]_, + + .. math:: + + k_{nn,i}^{w} = \frac{1}{s_i} \sum_{j \in N(i)} w_{ij} k_j + + where `s_i` is the weighted degree of node `i`, `w_{ij}` + is the weight of the edge that links `i` and `j` and + `N(i)` are the neighbors of node `i`. + + + Parameters + ---------- + G : NetworkX graph + + source : string ("in"|"out"|"in+out"), optional (default="out") + Directed graphs only. + Use "in"- or "out"-neighbors of source node. + + target : string ("in"|"out"|"in+out"), optional (default="out") + Directed graphs only. + Use "in"- or "out"-degree for target node. + + nodes : list or iterable, optional (default=G.nodes) + Compute neighbor degree only for specified nodes. + + weight : string or None, optional (default=None) + The edge attribute that holds the numerical value used as a weight. + If None, then each edge has weight 1. + + Returns + ------- + d: dict + A dictionary keyed by node to the average degree of its neighbors. + + Raises + ------ + NetworkXError + If either `source` or `target` are not one of 'in', 'out', or 'in+out'. + If either `source` or `target` is passed for an undirected graph. + + Examples + -------- + >>> G = nx.path_graph(4) + >>> G.edges[0, 1]["weight"] = 5 + >>> G.edges[2, 3]["weight"] = 3 + + >>> nx.average_neighbor_degree(G) + {0: 2.0, 1: 1.5, 2: 1.5, 3: 2.0} + >>> nx.average_neighbor_degree(G, weight="weight") + {0: 2.0, 1: 1.1666666666666667, 2: 1.25, 3: 2.0} + + >>> G = nx.DiGraph() + >>> nx.add_path(G, [0, 1, 2, 3]) + >>> nx.average_neighbor_degree(G, source="in", target="in") + {0: 0.0, 1: 0.0, 2: 1.0, 3: 1.0} + + >>> nx.average_neighbor_degree(G, source="out", target="out") + {0: 1.0, 1: 1.0, 2: 0.0, 3: 0.0} + + See Also + -------- + average_degree_connectivity + + References + ---------- + .. [1] A. Barrat, M. Barthélemy, R. Pastor-Satorras, and A. Vespignani, + "The architecture of complex weighted networks". + PNAS 101 (11): 3747–3752 (2004). + """ + if G.is_directed(): + if source == "in": + source_degree = G.in_degree + elif source == "out": + source_degree = G.out_degree + elif source == "in+out": + source_degree = G.degree + else: + raise nx.NetworkXError( + f"source argument {source} must be 'in', 'out' or 'in+out'" + ) + + if target == "in": + target_degree = G.in_degree + elif target == "out": + target_degree = G.out_degree + elif target == "in+out": + target_degree = G.degree + else: + raise nx.NetworkXError( + f"target argument {target} must be 'in', 'out' or 'in+out'" + ) + else: + if source != "out" or target != "out": + raise nx.NetworkXError( + f"source and target arguments are only supported for directed graphs" + ) + source_degree = target_degree = G.degree + + # precompute target degrees -- should *not* be weighted degree + t_deg = dict(target_degree()) + + # Set up both predecessor and successor neighbor dicts leaving empty if not needed + G_P = G_S = {n: {} for n in G} + if G.is_directed(): + # "in" or "in+out" cases: G_P contains predecessors + if "in" in source: + G_P = G.pred + # "out" or "in+out" cases: G_S contains successors + if "out" in source: + G_S = G.succ + else: + # undirected leave G_P empty but G_S is the adjacency + G_S = G.adj + + # Main loop: Compute average degree of neighbors + avg = {} + for n, deg in source_degree(nodes, weight=weight): + # handle degree zero average + if deg == 0: + avg[n] = 0.0 + continue + + # we sum over both G_P and G_S, but one of the two is usually empty. + if weight is None: + avg[n] = ( + sum(t_deg[nbr] for nbr in G_S[n]) + sum(t_deg[nbr] for nbr in G_P[n]) + ) / deg + else: + avg[n] = ( + sum(dd.get(weight, 1) * t_deg[nbr] for nbr, dd in G_S[n].items()) + + sum(dd.get(weight, 1) * t_deg[nbr] for nbr, dd in G_P[n].items()) + ) / deg + return avg diff --git a/lib/python3.10/site-packages/networkx/algorithms/assortativity/pairs.py b/lib/python3.10/site-packages/networkx/algorithms/assortativity/pairs.py new file mode 100644 index 0000000000000000000000000000000000000000..ea5fd287545c80dd2ebbb2b253d5ab0ab7480743 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/assortativity/pairs.py @@ -0,0 +1,127 @@ +"""Generators of x-y pairs of node data.""" + +import networkx as nx + +__all__ = ["node_attribute_xy", "node_degree_xy"] + + +@nx._dispatchable(node_attrs="attribute") +def node_attribute_xy(G, attribute, nodes=None): + """Yields 2-tuples of node attribute values for all edges in `G`. + + This generator yields, for each edge in `G` incident to a node in `nodes`, + a 2-tuple of form ``(attribute value, attribute value)`` for the parameter + specified node-attribute. + + Parameters + ---------- + G: NetworkX graph + + attribute: key + The node attribute key. + + nodes: list or iterable (optional) + Use only edges that are incident to specified nodes. + The default is all nodes. + + Yields + ------ + (x, y): 2-tuple + Generates 2-tuple of (attribute, attribute) values. + + Examples + -------- + >>> G = nx.DiGraph() + >>> G.add_node(1, color="red") + >>> G.add_node(2, color="blue") + >>> G.add_node(3, color="green") + >>> G.add_edge(1, 2) + >>> list(nx.node_attribute_xy(G, "color")) + [('red', 'blue')] + + Notes + ----- + For undirected graphs, each edge is produced twice, once for each edge + representation (u, v) and (v, u), with the exception of self-loop edges + which only appear once. + """ + if nodes is None: + nodes = set(G) + else: + nodes = set(nodes) + Gnodes = G.nodes + for u, nbrsdict in G.adjacency(): + if u not in nodes: + continue + uattr = Gnodes[u].get(attribute, None) + if G.is_multigraph(): + for v, keys in nbrsdict.items(): + vattr = Gnodes[v].get(attribute, None) + for _ in keys: + yield (uattr, vattr) + else: + for v in nbrsdict: + vattr = Gnodes[v].get(attribute, None) + yield (uattr, vattr) + + +@nx._dispatchable(edge_attrs="weight") +def node_degree_xy(G, x="out", y="in", weight=None, nodes=None): + """Yields 2-tuples of ``(degree, degree)`` values for edges in `G`. + + This generator yields, for each edge in `G` incident to a node in `nodes`, + a 2-tuple of form ``(degree, degree)``. The node degrees are weighted + when a `weight` attribute is specified. + + Parameters + ---------- + G: NetworkX graph + + x: string ('in','out') + The degree type for source node (directed graphs only). + + y: string ('in','out') + The degree type for target node (directed graphs only). + + weight: string or None, optional (default=None) + The edge attribute that holds the numerical value used + as a weight. If None, then each edge has weight 1. + The degree is the sum of the edge weights adjacent to the node. + + nodes: list or iterable (optional) + Use only edges that are adjacency to specified nodes. + The default is all nodes. + + Yields + ------ + (x, y): 2-tuple + Generates 2-tuple of (degree, degree) values. + + Examples + -------- + >>> G = nx.DiGraph() + >>> G.add_edge(1, 2) + >>> list(nx.node_degree_xy(G, x="out", y="in")) + [(1, 1)] + >>> list(nx.node_degree_xy(G, x="in", y="out")) + [(0, 0)] + + Notes + ----- + For undirected graphs, each edge is produced twice, once for each edge + representation (u, v) and (v, u), with the exception of self-loop edges + which only appear once. + """ + nodes = set(G) if nodes is None else set(nodes) + if G.is_directed(): + direction = {"out": G.out_degree, "in": G.in_degree} + xdeg = direction[x] + ydeg = direction[y] + else: + xdeg = ydeg = G.degree + + for u, degu in xdeg(nodes, weight=weight): + # use G.edges to treat multigraphs correctly + neighbors = (nbr for _, nbr in G.edges(u) if nbr in nodes) + for _, degv in ydeg(neighbors, weight=weight): + yield degu, degv diff --git a/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3daac356a785bfacc3ed75aa22aa089074bc925a Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/__pycache__/base_test.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/__pycache__/base_test.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f96bf115a32aab8ab9f56008d7c06f19adb0bab Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/__pycache__/base_test.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/__pycache__/test_connectivity.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/__pycache__/test_connectivity.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e68706109a97a122fdadd1d213bb1c5270990ae6 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/__pycache__/test_connectivity.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/__pycache__/test_correlation.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/__pycache__/test_correlation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a823b84b2182528356c5cb0055b14c2c561e71a Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/__pycache__/test_correlation.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/__pycache__/test_mixing.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/__pycache__/test_mixing.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b13dabd786db417f4d4a503ded24d2298f563fe3 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/__pycache__/test_mixing.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/__pycache__/test_neighbor_degree.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/__pycache__/test_neighbor_degree.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0012dd19572b44f5921ac497fcfb01f5ea9fe17 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/__pycache__/test_neighbor_degree.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/__pycache__/test_pairs.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/__pycache__/test_pairs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbabde4dcf64cccc33cace1325470288142551a4 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/__pycache__/test_pairs.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/base_test.py b/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/base_test.py new file mode 100644 index 0000000000000000000000000000000000000000..46d6300649d3b4658a7263cad04354988b4da312 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/base_test.py @@ -0,0 +1,81 @@ +import networkx as nx + + +class BaseTestAttributeMixing: + @classmethod + def setup_class(cls): + G = nx.Graph() + G.add_nodes_from([0, 1], fish="one") + G.add_nodes_from([2, 3], fish="two") + G.add_nodes_from([4], fish="red") + G.add_nodes_from([5], fish="blue") + G.add_edges_from([(0, 1), (2, 3), (0, 4), (2, 5)]) + cls.G = G + + D = nx.DiGraph() + D.add_nodes_from([0, 1], fish="one") + D.add_nodes_from([2, 3], fish="two") + D.add_nodes_from([4], fish="red") + D.add_nodes_from([5], fish="blue") + D.add_edges_from([(0, 1), (2, 3), (0, 4), (2, 5)]) + cls.D = D + + M = nx.MultiGraph() + M.add_nodes_from([0, 1], fish="one") + M.add_nodes_from([2, 3], fish="two") + M.add_nodes_from([4], fish="red") + M.add_nodes_from([5], fish="blue") + M.add_edges_from([(0, 1), (0, 1), (2, 3)]) + cls.M = M + + S = nx.Graph() + S.add_nodes_from([0, 1], fish="one") + S.add_nodes_from([2, 3], fish="two") + S.add_nodes_from([4], fish="red") + S.add_nodes_from([5], fish="blue") + S.add_edge(0, 0) + S.add_edge(2, 2) + cls.S = S + + N = nx.Graph() + N.add_nodes_from([0, 1], margin=-2) + N.add_nodes_from([2, 3], margin=-2) + N.add_nodes_from([4], margin=-3) + N.add_nodes_from([5], margin=-4) + N.add_edges_from([(0, 1), (2, 3), (0, 4), (2, 5)]) + cls.N = N + + F = nx.Graph() + F.add_edges_from([(0, 3), (1, 3), (2, 3)], weight=0.5) + F.add_edge(0, 2, weight=1) + nx.set_node_attributes(F, dict(F.degree(weight="weight")), "margin") + cls.F = F + + K = nx.Graph() + K.add_nodes_from([1, 2], margin=-1) + K.add_nodes_from([3], margin=1) + K.add_nodes_from([4], margin=2) + K.add_edges_from([(3, 4), (1, 2), (1, 3)]) + cls.K = K + + +class BaseTestDegreeMixing: + @classmethod + def setup_class(cls): + cls.P4 = nx.path_graph(4) + cls.D = nx.DiGraph() + cls.D.add_edges_from([(0, 2), (0, 3), (1, 3), (2, 3)]) + cls.D2 = nx.DiGraph() + cls.D2.add_edges_from([(0, 3), (1, 0), (1, 2), (2, 4), (4, 1), (4, 3), (4, 2)]) + cls.M = nx.MultiGraph() + nx.add_path(cls.M, range(4)) + cls.M.add_edge(0, 1) + cls.S = nx.Graph() + cls.S.add_edges_from([(0, 0), (1, 1)]) + cls.W = nx.Graph() + cls.W.add_edges_from([(0, 3), (1, 3), (2, 3)], weight=0.5) + cls.W.add_edge(0, 2, weight=1) + S1 = nx.star_graph(4) + S2 = nx.star_graph(4) + cls.DS = nx.disjoint_union(S1, S2) + cls.DS.add_edge(4, 5) diff --git a/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/test_connectivity.py b/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/test_connectivity.py new file mode 100644 index 0000000000000000000000000000000000000000..21c6287bbe6b0bfc9aa41201b593f342b2d3976e --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/test_connectivity.py @@ -0,0 +1,143 @@ +from itertools import permutations + +import pytest + +import networkx as nx + + +class TestNeighborConnectivity: + def test_degree_p4(self): + G = nx.path_graph(4) + answer = {1: 2.0, 2: 1.5} + nd = nx.average_degree_connectivity(G) + assert nd == answer + + D = G.to_directed() + answer = {2: 2.0, 4: 1.5} + nd = nx.average_degree_connectivity(D) + assert nd == answer + + answer = {1: 2.0, 2: 1.5} + D = G.to_directed() + nd = nx.average_degree_connectivity(D, source="in", target="in") + assert nd == answer + + D = G.to_directed() + nd = nx.average_degree_connectivity(D, source="in", target="in") + assert nd == answer + + def test_degree_p4_weighted(self): + G = nx.path_graph(4) + G[1][2]["weight"] = 4 + answer = {1: 2.0, 2: 1.8} + nd = nx.average_degree_connectivity(G, weight="weight") + assert nd == answer + answer = {1: 2.0, 2: 1.5} + nd = nx.average_degree_connectivity(G) + assert nd == answer + + D = G.to_directed() + answer = {2: 2.0, 4: 1.8} + nd = nx.average_degree_connectivity(D, weight="weight") + assert nd == answer + + answer = {1: 2.0, 2: 1.8} + D = G.to_directed() + nd = nx.average_degree_connectivity( + D, weight="weight", source="in", target="in" + ) + assert nd == answer + + D = G.to_directed() + nd = nx.average_degree_connectivity( + D, source="in", target="out", weight="weight" + ) + assert nd == answer + + def test_weight_keyword(self): + G = nx.path_graph(4) + G[1][2]["other"] = 4 + answer = {1: 2.0, 2: 1.8} + nd = nx.average_degree_connectivity(G, weight="other") + assert nd == answer + answer = {1: 2.0, 2: 1.5} + nd = nx.average_degree_connectivity(G, weight=None) + assert nd == answer + + D = G.to_directed() + answer = {2: 2.0, 4: 1.8} + nd = nx.average_degree_connectivity(D, weight="other") + assert nd == answer + + answer = {1: 2.0, 2: 1.8} + D = G.to_directed() + nd = nx.average_degree_connectivity(D, weight="other", source="in", target="in") + assert nd == answer + + D = G.to_directed() + nd = nx.average_degree_connectivity(D, weight="other", source="in", target="in") + assert nd == answer + + def test_degree_barrat(self): + G = nx.star_graph(5) + G.add_edges_from([(5, 6), (5, 7), (5, 8), (5, 9)]) + G[0][5]["weight"] = 5 + nd = nx.average_degree_connectivity(G)[5] + assert nd == 1.8 + nd = nx.average_degree_connectivity(G, weight="weight")[5] + assert nd == pytest.approx(3.222222, abs=1e-5) + + def test_zero_deg(self): + G = nx.DiGraph() + G.add_edge(1, 2) + G.add_edge(1, 3) + G.add_edge(1, 4) + c = nx.average_degree_connectivity(G) + assert c == {1: 0, 3: 1} + c = nx.average_degree_connectivity(G, source="in", target="in") + assert c == {0: 0, 1: 0} + c = nx.average_degree_connectivity(G, source="in", target="out") + assert c == {0: 0, 1: 3} + c = nx.average_degree_connectivity(G, source="in", target="in+out") + assert c == {0: 0, 1: 3} + c = nx.average_degree_connectivity(G, source="out", target="out") + assert c == {0: 0, 3: 0} + c = nx.average_degree_connectivity(G, source="out", target="in") + assert c == {0: 0, 3: 1} + c = nx.average_degree_connectivity(G, source="out", target="in+out") + assert c == {0: 0, 3: 1} + + def test_in_out_weight(self): + G = nx.DiGraph() + G.add_edge(1, 2, weight=1) + G.add_edge(1, 3, weight=1) + G.add_edge(3, 1, weight=1) + for s, t in permutations(["in", "out", "in+out"], 2): + c = nx.average_degree_connectivity(G, source=s, target=t) + cw = nx.average_degree_connectivity(G, source=s, target=t, weight="weight") + assert c == cw + + def test_invalid_source(self): + with pytest.raises(nx.NetworkXError): + G = nx.DiGraph() + nx.average_degree_connectivity(G, source="bogus") + + def test_invalid_target(self): + with pytest.raises(nx.NetworkXError): + G = nx.DiGraph() + nx.average_degree_connectivity(G, target="bogus") + + def test_invalid_undirected_graph(self): + G = nx.Graph() + with pytest.raises(nx.NetworkXError): + nx.average_degree_connectivity(G, target="bogus") + with pytest.raises(nx.NetworkXError): + nx.average_degree_connectivity(G, source="bogus") + + def test_single_node(self): + # TODO Is this really the intended behavior for providing a + # single node as the argument `nodes`? Shouldn't the function + # just return the connectivity value itself? + G = nx.trivial_graph() + conn = nx.average_degree_connectivity(G, nodes=0) + assert conn == {0: 0} diff --git a/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/test_correlation.py b/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/test_correlation.py new file mode 100644 index 0000000000000000000000000000000000000000..5203f9449fd022525b97a19cbe78498e33fb09a3 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/test_correlation.py @@ -0,0 +1,123 @@ +import pytest + +np = pytest.importorskip("numpy") +pytest.importorskip("scipy") + + +import networkx as nx +from networkx.algorithms.assortativity.correlation import attribute_ac + +from .base_test import BaseTestAttributeMixing, BaseTestDegreeMixing + + +class TestDegreeMixingCorrelation(BaseTestDegreeMixing): + def test_degree_assortativity_undirected(self): + r = nx.degree_assortativity_coefficient(self.P4) + np.testing.assert_almost_equal(r, -1.0 / 2, decimal=4) + + def test_degree_assortativity_node_kwargs(self): + G = nx.Graph() + edges = [(0, 1), (0, 3), (1, 2), (1, 3), (1, 4), (5, 9), (9, 0)] + G.add_edges_from(edges) + r = nx.degree_assortativity_coefficient(G, nodes=[1, 2, 4]) + np.testing.assert_almost_equal(r, -1.0, decimal=4) + + def test_degree_assortativity_directed(self): + r = nx.degree_assortativity_coefficient(self.D) + np.testing.assert_almost_equal(r, -0.57735, decimal=4) + + def test_degree_assortativity_directed2(self): + """Test degree assortativity for a directed graph where the set of + in/out degree does not equal the total degree.""" + r = nx.degree_assortativity_coefficient(self.D2) + np.testing.assert_almost_equal(r, 0.14852, decimal=4) + + def test_degree_assortativity_multigraph(self): + r = nx.degree_assortativity_coefficient(self.M) + np.testing.assert_almost_equal(r, -1.0 / 7.0, decimal=4) + + def test_degree_pearson_assortativity_undirected(self): + r = nx.degree_pearson_correlation_coefficient(self.P4) + np.testing.assert_almost_equal(r, -1.0 / 2, decimal=4) + + def test_degree_pearson_assortativity_directed(self): + r = nx.degree_pearson_correlation_coefficient(self.D) + np.testing.assert_almost_equal(r, -0.57735, decimal=4) + + def test_degree_pearson_assortativity_directed2(self): + """Test degree assortativity with Pearson for a directed graph where + the set of in/out degree does not equal the total degree.""" + r = nx.degree_pearson_correlation_coefficient(self.D2) + np.testing.assert_almost_equal(r, 0.14852, decimal=4) + + def test_degree_pearson_assortativity_multigraph(self): + r = nx.degree_pearson_correlation_coefficient(self.M) + np.testing.assert_almost_equal(r, -1.0 / 7.0, decimal=4) + + def test_degree_assortativity_weighted(self): + r = nx.degree_assortativity_coefficient(self.W, weight="weight") + np.testing.assert_almost_equal(r, -0.1429, decimal=4) + + def test_degree_assortativity_double_star(self): + r = nx.degree_assortativity_coefficient(self.DS) + np.testing.assert_almost_equal(r, -0.9339, decimal=4) + + +class TestAttributeMixingCorrelation(BaseTestAttributeMixing): + def test_attribute_assortativity_undirected(self): + r = nx.attribute_assortativity_coefficient(self.G, "fish") + assert r == 6.0 / 22.0 + + def test_attribute_assortativity_directed(self): + r = nx.attribute_assortativity_coefficient(self.D, "fish") + assert r == 1.0 / 3.0 + + def test_attribute_assortativity_multigraph(self): + r = nx.attribute_assortativity_coefficient(self.M, "fish") + assert r == 1.0 + + def test_attribute_assortativity_coefficient(self): + # from "Mixing patterns in networks" + # fmt: off + a = np.array([[0.258, 0.016, 0.035, 0.013], + [0.012, 0.157, 0.058, 0.019], + [0.013, 0.023, 0.306, 0.035], + [0.005, 0.007, 0.024, 0.016]]) + # fmt: on + r = attribute_ac(a) + np.testing.assert_almost_equal(r, 0.623, decimal=3) + + def test_attribute_assortativity_coefficient2(self): + # fmt: off + a = np.array([[0.18, 0.02, 0.01, 0.03], + [0.02, 0.20, 0.03, 0.02], + [0.01, 0.03, 0.16, 0.01], + [0.03, 0.02, 0.01, 0.22]]) + # fmt: on + r = attribute_ac(a) + np.testing.assert_almost_equal(r, 0.68, decimal=2) + + def test_attribute_assortativity(self): + a = np.array([[50, 50, 0], [50, 50, 0], [0, 0, 2]]) + r = attribute_ac(a) + np.testing.assert_almost_equal(r, 0.029, decimal=3) + + def test_attribute_assortativity_negative(self): + r = nx.numeric_assortativity_coefficient(self.N, "margin") + np.testing.assert_almost_equal(r, -0.2903, decimal=4) + + def test_assortativity_node_kwargs(self): + G = nx.Graph() + G.add_nodes_from([0, 1], size=2) + G.add_nodes_from([2, 3], size=3) + G.add_edges_from([(0, 1), (2, 3)]) + r = nx.numeric_assortativity_coefficient(G, "size", nodes=[0, 3]) + np.testing.assert_almost_equal(r, 1.0, decimal=4) + + def test_attribute_assortativity_float(self): + r = nx.numeric_assortativity_coefficient(self.F, "margin") + np.testing.assert_almost_equal(r, -0.1429, decimal=4) + + def test_attribute_assortativity_mixed(self): + r = nx.numeric_assortativity_coefficient(self.K, "margin") + np.testing.assert_almost_equal(r, 0.4340, decimal=4) diff --git a/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/test_mixing.py b/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/test_mixing.py new file mode 100644 index 0000000000000000000000000000000000000000..9af09867235b9092837b517ca542e8a85eb602ac --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/test_mixing.py @@ -0,0 +1,176 @@ +import pytest + +np = pytest.importorskip("numpy") + + +import networkx as nx + +from .base_test import BaseTestAttributeMixing, BaseTestDegreeMixing + + +class TestDegreeMixingDict(BaseTestDegreeMixing): + def test_degree_mixing_dict_undirected(self): + d = nx.degree_mixing_dict(self.P4) + d_result = {1: {2: 2}, 2: {1: 2, 2: 2}} + assert d == d_result + + def test_degree_mixing_dict_undirected_normalized(self): + d = nx.degree_mixing_dict(self.P4, normalized=True) + d_result = {1: {2: 1.0 / 3}, 2: {1: 1.0 / 3, 2: 1.0 / 3}} + assert d == d_result + + def test_degree_mixing_dict_directed(self): + d = nx.degree_mixing_dict(self.D) + print(d) + d_result = {1: {3: 2}, 2: {1: 1, 3: 1}, 3: {}} + assert d == d_result + + def test_degree_mixing_dict_multigraph(self): + d = nx.degree_mixing_dict(self.M) + d_result = {1: {2: 1}, 2: {1: 1, 3: 3}, 3: {2: 3}} + assert d == d_result + + def test_degree_mixing_dict_weighted(self): + d = nx.degree_mixing_dict(self.W, weight="weight") + d_result = {0.5: {1.5: 1}, 1.5: {1.5: 6, 0.5: 1}} + assert d == d_result + + +class TestDegreeMixingMatrix(BaseTestDegreeMixing): + def test_degree_mixing_matrix_undirected(self): + # fmt: off + a_result = np.array([[0, 2], + [2, 2]] + ) + # fmt: on + a = nx.degree_mixing_matrix(self.P4, normalized=False) + np.testing.assert_equal(a, a_result) + a = nx.degree_mixing_matrix(self.P4) + np.testing.assert_equal(a, a_result / a_result.sum()) + + def test_degree_mixing_matrix_directed(self): + # fmt: off + a_result = np.array([[0, 0, 2], + [1, 0, 1], + [0, 0, 0]] + ) + # fmt: on + a = nx.degree_mixing_matrix(self.D, normalized=False) + np.testing.assert_equal(a, a_result) + a = nx.degree_mixing_matrix(self.D) + np.testing.assert_equal(a, a_result / a_result.sum()) + + def test_degree_mixing_matrix_multigraph(self): + # fmt: off + a_result = np.array([[0, 1, 0], + [1, 0, 3], + [0, 3, 0]] + ) + # fmt: on + a = nx.degree_mixing_matrix(self.M, normalized=False) + np.testing.assert_equal(a, a_result) + a = nx.degree_mixing_matrix(self.M) + np.testing.assert_equal(a, a_result / a_result.sum()) + + def test_degree_mixing_matrix_selfloop(self): + # fmt: off + a_result = np.array([[2]]) + # fmt: on + a = nx.degree_mixing_matrix(self.S, normalized=False) + np.testing.assert_equal(a, a_result) + a = nx.degree_mixing_matrix(self.S) + np.testing.assert_equal(a, a_result / a_result.sum()) + + def test_degree_mixing_matrix_weighted(self): + a_result = np.array([[0.0, 1.0], [1.0, 6.0]]) + a = nx.degree_mixing_matrix(self.W, weight="weight", normalized=False) + np.testing.assert_equal(a, a_result) + a = nx.degree_mixing_matrix(self.W, weight="weight") + np.testing.assert_equal(a, a_result / float(a_result.sum())) + + def test_degree_mixing_matrix_mapping(self): + a_result = np.array([[6.0, 1.0], [1.0, 0.0]]) + mapping = {0.5: 1, 1.5: 0} + a = nx.degree_mixing_matrix( + self.W, weight="weight", normalized=False, mapping=mapping + ) + np.testing.assert_equal(a, a_result) + + +class TestAttributeMixingDict(BaseTestAttributeMixing): + def test_attribute_mixing_dict_undirected(self): + d = nx.attribute_mixing_dict(self.G, "fish") + d_result = { + "one": {"one": 2, "red": 1}, + "two": {"two": 2, "blue": 1}, + "red": {"one": 1}, + "blue": {"two": 1}, + } + assert d == d_result + + def test_attribute_mixing_dict_directed(self): + d = nx.attribute_mixing_dict(self.D, "fish") + d_result = { + "one": {"one": 1, "red": 1}, + "two": {"two": 1, "blue": 1}, + "red": {}, + "blue": {}, + } + assert d == d_result + + def test_attribute_mixing_dict_multigraph(self): + d = nx.attribute_mixing_dict(self.M, "fish") + d_result = {"one": {"one": 4}, "two": {"two": 2}} + assert d == d_result + + +class TestAttributeMixingMatrix(BaseTestAttributeMixing): + def test_attribute_mixing_matrix_undirected(self): + mapping = {"one": 0, "two": 1, "red": 2, "blue": 3} + a_result = np.array([[2, 0, 1, 0], [0, 2, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0]]) + a = nx.attribute_mixing_matrix( + self.G, "fish", mapping=mapping, normalized=False + ) + np.testing.assert_equal(a, a_result) + a = nx.attribute_mixing_matrix(self.G, "fish", mapping=mapping) + np.testing.assert_equal(a, a_result / a_result.sum()) + + def test_attribute_mixing_matrix_directed(self): + mapping = {"one": 0, "two": 1, "red": 2, "blue": 3} + a_result = np.array([[1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 0, 0], [0, 0, 0, 0]]) + a = nx.attribute_mixing_matrix( + self.D, "fish", mapping=mapping, normalized=False + ) + np.testing.assert_equal(a, a_result) + a = nx.attribute_mixing_matrix(self.D, "fish", mapping=mapping) + np.testing.assert_equal(a, a_result / a_result.sum()) + + def test_attribute_mixing_matrix_multigraph(self): + mapping = {"one": 0, "two": 1, "red": 2, "blue": 3} + a_result = np.array([[4, 0, 0, 0], [0, 2, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]) + a = nx.attribute_mixing_matrix( + self.M, "fish", mapping=mapping, normalized=False + ) + np.testing.assert_equal(a, a_result) + a = nx.attribute_mixing_matrix(self.M, "fish", mapping=mapping) + np.testing.assert_equal(a, a_result / a_result.sum()) + + def test_attribute_mixing_matrix_negative(self): + mapping = {-2: 0, -3: 1, -4: 2} + a_result = np.array([[4.0, 1.0, 1.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]) + a = nx.attribute_mixing_matrix( + self.N, "margin", mapping=mapping, normalized=False + ) + np.testing.assert_equal(a, a_result) + a = nx.attribute_mixing_matrix(self.N, "margin", mapping=mapping) + np.testing.assert_equal(a, a_result / float(a_result.sum())) + + def test_attribute_mixing_matrix_float(self): + mapping = {0.5: 1, 1.5: 0} + a_result = np.array([[6.0, 1.0], [1.0, 0.0]]) + a = nx.attribute_mixing_matrix( + self.F, "margin", mapping=mapping, normalized=False + ) + np.testing.assert_equal(a, a_result) + a = nx.attribute_mixing_matrix(self.F, "margin", mapping=mapping) + np.testing.assert_equal(a, a_result / a_result.sum()) diff --git a/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/test_neighbor_degree.py b/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/test_neighbor_degree.py new file mode 100644 index 0000000000000000000000000000000000000000..bf1252d532079d4de6de4659943ce008eb9018b3 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/test_neighbor_degree.py @@ -0,0 +1,108 @@ +import pytest + +import networkx as nx + + +class TestAverageNeighbor: + def test_degree_p4(self): + G = nx.path_graph(4) + answer = {0: 2, 1: 1.5, 2: 1.5, 3: 2} + nd = nx.average_neighbor_degree(G) + assert nd == answer + + D = G.to_directed() + nd = nx.average_neighbor_degree(D) + assert nd == answer + + D = nx.DiGraph(G.edges(data=True)) + nd = nx.average_neighbor_degree(D) + assert nd == {0: 1, 1: 1, 2: 0, 3: 0} + nd = nx.average_neighbor_degree(D, "in", "out") + assert nd == {0: 0, 1: 1, 2: 1, 3: 1} + nd = nx.average_neighbor_degree(D, "out", "in") + assert nd == {0: 1, 1: 1, 2: 1, 3: 0} + nd = nx.average_neighbor_degree(D, "in", "in") + assert nd == {0: 0, 1: 0, 2: 1, 3: 1} + + def test_degree_p4_weighted(self): + G = nx.path_graph(4) + G[1][2]["weight"] = 4 + answer = {0: 2, 1: 1.8, 2: 1.8, 3: 2} + nd = nx.average_neighbor_degree(G, weight="weight") + assert nd == answer + + D = G.to_directed() + nd = nx.average_neighbor_degree(D, weight="weight") + assert nd == answer + + D = nx.DiGraph(G.edges(data=True)) + print(D.edges(data=True)) + nd = nx.average_neighbor_degree(D, weight="weight") + assert nd == {0: 1, 1: 1, 2: 0, 3: 0} + nd = nx.average_neighbor_degree(D, "out", "out", weight="weight") + assert nd == {0: 1, 1: 1, 2: 0, 3: 0} + nd = nx.average_neighbor_degree(D, "in", "in", weight="weight") + assert nd == {0: 0, 1: 0, 2: 1, 3: 1} + nd = nx.average_neighbor_degree(D, "in", "out", weight="weight") + assert nd == {0: 0, 1: 1, 2: 1, 3: 1} + nd = nx.average_neighbor_degree(D, "out", "in", weight="weight") + assert nd == {0: 1, 1: 1, 2: 1, 3: 0} + nd = nx.average_neighbor_degree(D, source="in+out", weight="weight") + assert nd == {0: 1.0, 1: 1.0, 2: 0.8, 3: 1.0} + nd = nx.average_neighbor_degree(D, target="in+out", weight="weight") + assert nd == {0: 2.0, 1: 2.0, 2: 1.0, 3: 0.0} + + D = G.to_directed() + nd = nx.average_neighbor_degree(D, weight="weight") + assert nd == answer + nd = nx.average_neighbor_degree(D, source="out", target="out", weight="weight") + assert nd == answer + + D = G.to_directed() + nd = nx.average_neighbor_degree(D, source="in", target="in", weight="weight") + assert nd == answer + + def test_degree_k4(self): + G = nx.complete_graph(4) + answer = {0: 3, 1: 3, 2: 3, 3: 3} + nd = nx.average_neighbor_degree(G) + assert nd == answer + + D = G.to_directed() + nd = nx.average_neighbor_degree(D) + assert nd == answer + + D = G.to_directed() + nd = nx.average_neighbor_degree(D) + assert nd == answer + + D = G.to_directed() + nd = nx.average_neighbor_degree(D, source="in", target="in") + assert nd == answer + + def test_degree_k4_nodes(self): + G = nx.complete_graph(4) + answer = {1: 3.0, 2: 3.0} + nd = nx.average_neighbor_degree(G, nodes=[1, 2]) + assert nd == answer + + def test_degree_barrat(self): + G = nx.star_graph(5) + G.add_edges_from([(5, 6), (5, 7), (5, 8), (5, 9)]) + G[0][5]["weight"] = 5 + nd = nx.average_neighbor_degree(G)[5] + assert nd == 1.8 + nd = nx.average_neighbor_degree(G, weight="weight")[5] + assert nd == pytest.approx(3.222222, abs=1e-5) + + def test_error_invalid_source_target(self): + G = nx.path_graph(4) + with pytest.raises(nx.NetworkXError): + nx.average_neighbor_degree(G, "error") + with pytest.raises(nx.NetworkXError): + nx.average_neighbor_degree(G, "in", "error") + G = G.to_directed() + with pytest.raises(nx.NetworkXError): + nx.average_neighbor_degree(G, "error") + with pytest.raises(nx.NetworkXError): + nx.average_neighbor_degree(G, "in", "error") diff --git a/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/test_pairs.py b/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/test_pairs.py new file mode 100644 index 0000000000000000000000000000000000000000..3984292be84dd7b306066809fb3c50a7cf0424f4 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/assortativity/tests/test_pairs.py @@ -0,0 +1,87 @@ +import networkx as nx + +from .base_test import BaseTestAttributeMixing, BaseTestDegreeMixing + + +class TestAttributeMixingXY(BaseTestAttributeMixing): + def test_node_attribute_xy_undirected(self): + attrxy = sorted(nx.node_attribute_xy(self.G, "fish")) + attrxy_result = sorted( + [ + ("one", "one"), + ("one", "one"), + ("two", "two"), + ("two", "two"), + ("one", "red"), + ("red", "one"), + ("blue", "two"), + ("two", "blue"), + ] + ) + assert attrxy == attrxy_result + + def test_node_attribute_xy_undirected_nodes(self): + attrxy = sorted(nx.node_attribute_xy(self.G, "fish", nodes=["one", "yellow"])) + attrxy_result = sorted([]) + assert attrxy == attrxy_result + + def test_node_attribute_xy_directed(self): + attrxy = sorted(nx.node_attribute_xy(self.D, "fish")) + attrxy_result = sorted( + [("one", "one"), ("two", "two"), ("one", "red"), ("two", "blue")] + ) + assert attrxy == attrxy_result + + def test_node_attribute_xy_multigraph(self): + attrxy = sorted(nx.node_attribute_xy(self.M, "fish")) + attrxy_result = [ + ("one", "one"), + ("one", "one"), + ("one", "one"), + ("one", "one"), + ("two", "two"), + ("two", "two"), + ] + assert attrxy == attrxy_result + + def test_node_attribute_xy_selfloop(self): + attrxy = sorted(nx.node_attribute_xy(self.S, "fish")) + attrxy_result = [("one", "one"), ("two", "two")] + assert attrxy == attrxy_result + + +class TestDegreeMixingXY(BaseTestDegreeMixing): + def test_node_degree_xy_undirected(self): + xy = sorted(nx.node_degree_xy(self.P4)) + xy_result = sorted([(1, 2), (2, 1), (2, 2), (2, 2), (1, 2), (2, 1)]) + assert xy == xy_result + + def test_node_degree_xy_undirected_nodes(self): + xy = sorted(nx.node_degree_xy(self.P4, nodes=[0, 1, -1])) + xy_result = sorted([(1, 2), (2, 1)]) + assert xy == xy_result + + def test_node_degree_xy_directed(self): + xy = sorted(nx.node_degree_xy(self.D)) + xy_result = sorted([(2, 1), (2, 3), (1, 3), (1, 3)]) + assert xy == xy_result + + def test_node_degree_xy_multigraph(self): + xy = sorted(nx.node_degree_xy(self.M)) + xy_result = sorted( + [(2, 3), (2, 3), (3, 2), (3, 2), (2, 3), (3, 2), (1, 2), (2, 1)] + ) + assert xy == xy_result + + def test_node_degree_xy_selfloop(self): + xy = sorted(nx.node_degree_xy(self.S)) + xy_result = sorted([(2, 2), (2, 2)]) + assert xy == xy_result + + def test_node_degree_xy_weighted(self): + G = nx.Graph() + G.add_edge(1, 2, weight=7) + G.add_edge(2, 3, weight=10) + xy = sorted(nx.node_degree_xy(G, weight="weight")) + xy_result = sorted([(7, 17), (17, 10), (17, 7), (10, 17)]) + assert xy == xy_result diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/bipartite/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7839db96a712bd5db28c29112ac20864c15d9539 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/bipartite/__init__.py @@ -0,0 +1,87 @@ +r"""This module provides functions and operations for bipartite +graphs. Bipartite graphs `B = (U, V, E)` have two node sets `U,V` and edges in +`E` that only connect nodes from opposite sets. It is common in the literature +to use an spatial analogy referring to the two node sets as top and bottom nodes. + +The bipartite algorithms are not imported into the networkx namespace +at the top level so the easiest way to use them is with: + +>>> from networkx.algorithms import bipartite + +NetworkX does not have a custom bipartite graph class but the Graph() +or DiGraph() classes can be used to represent bipartite graphs. However, +you have to keep track of which set each node belongs to, and make +sure that there is no edge between nodes of the same set. The convention used +in NetworkX is to use a node attribute named `bipartite` with values 0 or 1 to +identify the sets each node belongs to. This convention is not enforced in +the source code of bipartite functions, it's only a recommendation. + +For example: + +>>> B = nx.Graph() +>>> # Add nodes with the node attribute "bipartite" +>>> B.add_nodes_from([1, 2, 3, 4], bipartite=0) +>>> B.add_nodes_from(["a", "b", "c"], bipartite=1) +>>> # Add edges only between nodes of opposite node sets +>>> B.add_edges_from([(1, "a"), (1, "b"), (2, "b"), (2, "c"), (3, "c"), (4, "a")]) + +Many algorithms of the bipartite module of NetworkX require, as an argument, a +container with all the nodes that belong to one set, in addition to the bipartite +graph `B`. The functions in the bipartite package do not check that the node set +is actually correct nor that the input graph is actually bipartite. +If `B` is connected, you can find the two node sets using a two-coloring +algorithm: + +>>> nx.is_connected(B) +True +>>> bottom_nodes, top_nodes = bipartite.sets(B) + +However, if the input graph is not connected, there are more than one possible +colorations. This is the reason why we require the user to pass a container +with all nodes of one bipartite node set as an argument to most bipartite +functions. In the face of ambiguity, we refuse the temptation to guess and +raise an :exc:`AmbiguousSolution ` +Exception if the input graph for +:func:`bipartite.sets ` +is disconnected. + +Using the `bipartite` node attribute, you can easily get the two node sets: + +>>> top_nodes = {n for n, d in B.nodes(data=True) if d["bipartite"] == 0} +>>> bottom_nodes = set(B) - top_nodes + +So you can easily use the bipartite algorithms that require, as an argument, a +container with all nodes that belong to one node set: + +>>> print(round(bipartite.density(B, bottom_nodes), 2)) +0.5 +>>> G = bipartite.projected_graph(B, top_nodes) + +All bipartite graph generators in NetworkX build bipartite graphs with the +`bipartite` node attribute. Thus, you can use the same approach: + +>>> RB = bipartite.random_graph(5, 7, 0.2) +>>> RB_top = {n for n, d in RB.nodes(data=True) if d["bipartite"] == 0} +>>> RB_bottom = set(RB) - RB_top +>>> list(RB_top) +[0, 1, 2, 3, 4] +>>> list(RB_bottom) +[5, 6, 7, 8, 9, 10, 11] + +For other bipartite graph generators see +:mod:`Generators `. + +""" + +from networkx.algorithms.bipartite.basic import * +from networkx.algorithms.bipartite.centrality import * +from networkx.algorithms.bipartite.cluster import * +from networkx.algorithms.bipartite.covering import * +from networkx.algorithms.bipartite.edgelist import * +from networkx.algorithms.bipartite.matching import * +from networkx.algorithms.bipartite.matrix import * +from networkx.algorithms.bipartite.projection import * +from networkx.algorithms.bipartite.redundancy import * +from networkx.algorithms.bipartite.spectral import * +from networkx.algorithms.bipartite.generators import * +from networkx.algorithms.bipartite.extendability import * diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60291b2e217f6b5c492f31c85e70f99f5a6679bb Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/basic.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/basic.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2b580721109dae8c6e2d4c2553a6980a03849eb Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/basic.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/centrality.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/centrality.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e024293dd67069cec81d56f2b1f3438be874e964 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/centrality.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/cluster.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/cluster.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7fb53af49c66c6fb514aff2f657d219dc025cd2b Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/cluster.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/covering.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/covering.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..499e7521a038dc97f2ade85be377e51fd06f7417 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/covering.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/edgelist.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/edgelist.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26a1d838633c04c5daf71d31faa66e515eb524ef Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/edgelist.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/extendability.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/extendability.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fba9bea5c2bba03f853ea11dd3b2a2bebcf960c1 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/extendability.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/generators.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/generators.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60fa82b6b5950f1a678047a3006da36c37a6e52d Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/generators.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/matching.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/matching.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..444a2f7ddd5e8706ad17d04c89d80574f4179154 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/matching.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/matrix.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/matrix.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7dcff6a78c1031edc0c9b73b84cd6c370c646587 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/matrix.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/projection.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/projection.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47c18452c4b3321ad74df2990d0fac850640cd35 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/projection.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/redundancy.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/redundancy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..667ab863ad785b5a856542c3db1527a6b93c039e Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/redundancy.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/spectral.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/spectral.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..837eaee0127c9ad03573dd56cf0870d12d3bde33 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/bipartite/__pycache__/spectral.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/basic.py b/lib/python3.10/site-packages/networkx/algorithms/bipartite/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..8d9a4d5b341bf9a14048acc1132e6f450685cc62 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/bipartite/basic.py @@ -0,0 +1,322 @@ +""" +========================== +Bipartite Graph Algorithms +========================== +""" + +import networkx as nx +from networkx.algorithms.components import connected_components +from networkx.exception import AmbiguousSolution + +__all__ = [ + "is_bipartite", + "is_bipartite_node_set", + "color", + "sets", + "density", + "degrees", +] + + +@nx._dispatchable +def color(G): + """Returns a two-coloring of the graph. + + Raises an exception if the graph is not bipartite. + + Parameters + ---------- + G : NetworkX graph + + Returns + ------- + color : dictionary + A dictionary keyed by node with a 1 or 0 as data for each node color. + + Raises + ------ + NetworkXError + If the graph is not two-colorable. + + Examples + -------- + >>> from networkx.algorithms import bipartite + >>> G = nx.path_graph(4) + >>> c = bipartite.color(G) + >>> print(c) + {0: 1, 1: 0, 2: 1, 3: 0} + + You can use this to set a node attribute indicating the bipartite set: + + >>> nx.set_node_attributes(G, c, "bipartite") + >>> print(G.nodes[0]["bipartite"]) + 1 + >>> print(G.nodes[1]["bipartite"]) + 0 + """ + if G.is_directed(): + import itertools + + def neighbors(v): + return itertools.chain.from_iterable([G.predecessors(v), G.successors(v)]) + + else: + neighbors = G.neighbors + + color = {} + for n in G: # handle disconnected graphs + if n in color or len(G[n]) == 0: # skip isolates + continue + queue = [n] + color[n] = 1 # nodes seen with color (1 or 0) + while queue: + v = queue.pop() + c = 1 - color[v] # opposite color of node v + for w in neighbors(v): + if w in color: + if color[w] == color[v]: + raise nx.NetworkXError("Graph is not bipartite.") + else: + color[w] = c + queue.append(w) + # color isolates with 0 + color.update(dict.fromkeys(nx.isolates(G), 0)) + return color + + +@nx._dispatchable +def is_bipartite(G): + """Returns True if graph G is bipartite, False if not. + + Parameters + ---------- + G : NetworkX graph + + Examples + -------- + >>> from networkx.algorithms import bipartite + >>> G = nx.path_graph(4) + >>> print(bipartite.is_bipartite(G)) + True + + See Also + -------- + color, is_bipartite_node_set + """ + try: + color(G) + return True + except nx.NetworkXError: + return False + + +@nx._dispatchable +def is_bipartite_node_set(G, nodes): + """Returns True if nodes and G/nodes are a bipartition of G. + + Parameters + ---------- + G : NetworkX graph + + nodes: list or container + Check if nodes are a one of a bipartite set. + + Examples + -------- + >>> from networkx.algorithms import bipartite + >>> G = nx.path_graph(4) + >>> X = set([1, 3]) + >>> bipartite.is_bipartite_node_set(G, X) + True + + Notes + ----- + An exception is raised if the input nodes are not distinct, because in this + case some bipartite algorithms will yield incorrect results. + For connected graphs the bipartite sets are unique. This function handles + disconnected graphs. + """ + S = set(nodes) + + if len(S) < len(nodes): + # this should maybe just return False? + raise AmbiguousSolution( + "The input node set contains duplicates.\n" + "This may lead to incorrect results when using it in bipartite algorithms.\n" + "Consider using set(nodes) as the input" + ) + + for CC in (G.subgraph(c).copy() for c in connected_components(G)): + X, Y = sets(CC) + if not ( + (X.issubset(S) and Y.isdisjoint(S)) or (Y.issubset(S) and X.isdisjoint(S)) + ): + return False + return True + + +@nx._dispatchable +def sets(G, top_nodes=None): + """Returns bipartite node sets of graph G. + + Raises an exception if the graph is not bipartite or if the input + graph is disconnected and thus more than one valid solution exists. + See :mod:`bipartite documentation ` + for further details on how bipartite graphs are handled in NetworkX. + + Parameters + ---------- + G : NetworkX graph + + top_nodes : container, optional + Container with all nodes in one bipartite node set. If not supplied + it will be computed. But if more than one solution exists an exception + will be raised. + + Returns + ------- + X : set + Nodes from one side of the bipartite graph. + Y : set + Nodes from the other side. + + Raises + ------ + AmbiguousSolution + Raised if the input bipartite graph is disconnected and no container + with all nodes in one bipartite set is provided. When determining + the nodes in each bipartite set more than one valid solution is + possible if the input graph is disconnected. + NetworkXError + Raised if the input graph is not bipartite. + + Examples + -------- + >>> from networkx.algorithms import bipartite + >>> G = nx.path_graph(4) + >>> X, Y = bipartite.sets(G) + >>> list(X) + [0, 2] + >>> list(Y) + [1, 3] + + See Also + -------- + color + + """ + if G.is_directed(): + is_connected = nx.is_weakly_connected + else: + is_connected = nx.is_connected + if top_nodes is not None: + X = set(top_nodes) + Y = set(G) - X + else: + if not is_connected(G): + msg = "Disconnected graph: Ambiguous solution for bipartite sets." + raise nx.AmbiguousSolution(msg) + c = color(G) + X = {n for n, is_top in c.items() if is_top} + Y = {n for n, is_top in c.items() if not is_top} + return (X, Y) + + +@nx._dispatchable(graphs="B") +def density(B, nodes): + """Returns density of bipartite graph B. + + Parameters + ---------- + B : NetworkX graph + + nodes: list or container + Nodes in one node set of the bipartite graph. + + Returns + ------- + d : float + The bipartite density + + Examples + -------- + >>> from networkx.algorithms import bipartite + >>> G = nx.complete_bipartite_graph(3, 2) + >>> X = set([0, 1, 2]) + >>> bipartite.density(G, X) + 1.0 + >>> Y = set([3, 4]) + >>> bipartite.density(G, Y) + 1.0 + + Notes + ----- + The container of nodes passed as argument must contain all nodes + in one of the two bipartite node sets to avoid ambiguity in the + case of disconnected graphs. + See :mod:`bipartite documentation ` + for further details on how bipartite graphs are handled in NetworkX. + + See Also + -------- + color + """ + n = len(B) + m = nx.number_of_edges(B) + nb = len(nodes) + nt = n - nb + if m == 0: # includes cases n==0 and n==1 + d = 0.0 + else: + if B.is_directed(): + d = m / (2 * nb * nt) + else: + d = m / (nb * nt) + return d + + +@nx._dispatchable(graphs="B", edge_attrs="weight") +def degrees(B, nodes, weight=None): + """Returns the degrees of the two node sets in the bipartite graph B. + + Parameters + ---------- + B : NetworkX graph + + nodes: list or container + Nodes in one node set of the bipartite graph. + + weight : string or None, optional (default=None) + The edge attribute that holds the numerical value used as a weight. + If None, then each edge has weight 1. + The degree is the sum of the edge weights adjacent to the node. + + Returns + ------- + (degX,degY) : tuple of dictionaries + The degrees of the two bipartite sets as dictionaries keyed by node. + + Examples + -------- + >>> from networkx.algorithms import bipartite + >>> G = nx.complete_bipartite_graph(3, 2) + >>> Y = set([3, 4]) + >>> degX, degY = bipartite.degrees(G, Y) + >>> dict(degX) + {0: 2, 1: 2, 2: 2} + + Notes + ----- + The container of nodes passed as argument must contain all nodes + in one of the two bipartite node sets to avoid ambiguity in the + case of disconnected graphs. + See :mod:`bipartite documentation ` + for further details on how bipartite graphs are handled in NetworkX. + + See Also + -------- + color, density + """ + bottom = set(nodes) + top = set(B) - bottom + return (B.degree(top, weight), B.degree(bottom, weight)) diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/centrality.py b/lib/python3.10/site-packages/networkx/algorithms/bipartite/centrality.py new file mode 100644 index 0000000000000000000000000000000000000000..42d7270ee7d0bb18b56a55dc4c17dc19f5dc77a7 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/bipartite/centrality.py @@ -0,0 +1,290 @@ +import networkx as nx + +__all__ = ["degree_centrality", "betweenness_centrality", "closeness_centrality"] + + +@nx._dispatchable(name="bipartite_degree_centrality") +def degree_centrality(G, nodes): + r"""Compute the degree centrality for nodes in a bipartite network. + + The degree centrality for a node `v` is the fraction of nodes + connected to it. + + Parameters + ---------- + G : graph + A bipartite network + + nodes : list or container + Container with all nodes in one bipartite node set. + + Returns + ------- + centrality : dictionary + Dictionary keyed by node with bipartite degree centrality as the value. + + Examples + -------- + >>> G = nx.wheel_graph(5) + >>> top_nodes = {0, 1, 2} + >>> nx.bipartite.degree_centrality(G, nodes=top_nodes) + {0: 2.0, 1: 1.5, 2: 1.5, 3: 1.0, 4: 1.0} + + See Also + -------- + betweenness_centrality + closeness_centrality + :func:`~networkx.algorithms.bipartite.basic.sets` + :func:`~networkx.algorithms.bipartite.basic.is_bipartite` + + Notes + ----- + The nodes input parameter must contain all nodes in one bipartite node set, + but the dictionary returned contains all nodes from both bipartite node + sets. See :mod:`bipartite documentation ` + for further details on how bipartite graphs are handled in NetworkX. + + For unipartite networks, the degree centrality values are + normalized by dividing by the maximum possible degree (which is + `n-1` where `n` is the number of nodes in G). + + In the bipartite case, the maximum possible degree of a node in a + bipartite node set is the number of nodes in the opposite node set + [1]_. The degree centrality for a node `v` in the bipartite + sets `U` with `n` nodes and `V` with `m` nodes is + + .. math:: + + d_{v} = \frac{deg(v)}{m}, \mbox{for} v \in U , + + d_{v} = \frac{deg(v)}{n}, \mbox{for} v \in V , + + + where `deg(v)` is the degree of node `v`. + + References + ---------- + .. [1] Borgatti, S.P. and Halgin, D. In press. "Analyzing Affiliation + Networks". In Carrington, P. and Scott, J. (eds) The Sage Handbook + of Social Network Analysis. Sage Publications. + https://dx.doi.org/10.4135/9781446294413.n28 + """ + top = set(nodes) + bottom = set(G) - top + s = 1.0 / len(bottom) + centrality = {n: d * s for n, d in G.degree(top)} + s = 1.0 / len(top) + centrality.update({n: d * s for n, d in G.degree(bottom)}) + return centrality + + +@nx._dispatchable(name="bipartite_betweenness_centrality") +def betweenness_centrality(G, nodes): + r"""Compute betweenness centrality for nodes in a bipartite network. + + Betweenness centrality of a node `v` is the sum of the + fraction of all-pairs shortest paths that pass through `v`. + + Values of betweenness are normalized by the maximum possible + value which for bipartite graphs is limited by the relative size + of the two node sets [1]_. + + Let `n` be the number of nodes in the node set `U` and + `m` be the number of nodes in the node set `V`, then + nodes in `U` are normalized by dividing by + + .. math:: + + \frac{1}{2} [m^2 (s + 1)^2 + m (s + 1)(2t - s - 1) - t (2s - t + 3)] , + + where + + .. math:: + + s = (n - 1) \div m , t = (n - 1) \mod m , + + and nodes in `V` are normalized by dividing by + + .. math:: + + \frac{1}{2} [n^2 (p + 1)^2 + n (p + 1)(2r - p - 1) - r (2p - r + 3)] , + + where, + + .. math:: + + p = (m - 1) \div n , r = (m - 1) \mod n . + + Parameters + ---------- + G : graph + A bipartite graph + + nodes : list or container + Container with all nodes in one bipartite node set. + + Returns + ------- + betweenness : dictionary + Dictionary keyed by node with bipartite betweenness centrality + as the value. + + Examples + -------- + >>> G = nx.cycle_graph(4) + >>> top_nodes = {1, 2} + >>> nx.bipartite.betweenness_centrality(G, nodes=top_nodes) + {0: 0.25, 1: 0.25, 2: 0.25, 3: 0.25} + + See Also + -------- + degree_centrality + closeness_centrality + :func:`~networkx.algorithms.bipartite.basic.sets` + :func:`~networkx.algorithms.bipartite.basic.is_bipartite` + + Notes + ----- + The nodes input parameter must contain all nodes in one bipartite node set, + but the dictionary returned contains all nodes from both node sets. + See :mod:`bipartite documentation ` + for further details on how bipartite graphs are handled in NetworkX. + + + References + ---------- + .. [1] Borgatti, S.P. and Halgin, D. In press. "Analyzing Affiliation + Networks". In Carrington, P. and Scott, J. (eds) The Sage Handbook + of Social Network Analysis. Sage Publications. + https://dx.doi.org/10.4135/9781446294413.n28 + """ + top = set(nodes) + bottom = set(G) - top + n = len(top) + m = len(bottom) + s, t = divmod(n - 1, m) + bet_max_top = ( + ((m**2) * ((s + 1) ** 2)) + + (m * (s + 1) * (2 * t - s - 1)) + - (t * ((2 * s) - t + 3)) + ) / 2.0 + p, r = divmod(m - 1, n) + bet_max_bot = ( + ((n**2) * ((p + 1) ** 2)) + + (n * (p + 1) * (2 * r - p - 1)) + - (r * ((2 * p) - r + 3)) + ) / 2.0 + betweenness = nx.betweenness_centrality(G, normalized=False, weight=None) + for node in top: + betweenness[node] /= bet_max_top + for node in bottom: + betweenness[node] /= bet_max_bot + return betweenness + + +@nx._dispatchable(name="bipartite_closeness_centrality") +def closeness_centrality(G, nodes, normalized=True): + r"""Compute the closeness centrality for nodes in a bipartite network. + + The closeness of a node is the distance to all other nodes in the + graph or in the case that the graph is not connected to all other nodes + in the connected component containing that node. + + Parameters + ---------- + G : graph + A bipartite network + + nodes : list or container + Container with all nodes in one bipartite node set. + + normalized : bool, optional + If True (default) normalize by connected component size. + + Returns + ------- + closeness : dictionary + Dictionary keyed by node with bipartite closeness centrality + as the value. + + Examples + -------- + >>> G = nx.wheel_graph(5) + >>> top_nodes = {0, 1, 2} + >>> nx.bipartite.closeness_centrality(G, nodes=top_nodes) + {0: 1.5, 1: 1.2, 2: 1.2, 3: 1.0, 4: 1.0} + + See Also + -------- + betweenness_centrality + degree_centrality + :func:`~networkx.algorithms.bipartite.basic.sets` + :func:`~networkx.algorithms.bipartite.basic.is_bipartite` + + Notes + ----- + The nodes input parameter must contain all nodes in one bipartite node set, + but the dictionary returned contains all nodes from both node sets. + See :mod:`bipartite documentation ` + for further details on how bipartite graphs are handled in NetworkX. + + + Closeness centrality is normalized by the minimum distance possible. + In the bipartite case the minimum distance for a node in one bipartite + node set is 1 from all nodes in the other node set and 2 from all + other nodes in its own set [1]_. Thus the closeness centrality + for node `v` in the two bipartite sets `U` with + `n` nodes and `V` with `m` nodes is + + .. math:: + + c_{v} = \frac{m + 2(n - 1)}{d}, \mbox{for} v \in U, + + c_{v} = \frac{n + 2(m - 1)}{d}, \mbox{for} v \in V, + + where `d` is the sum of the distances from `v` to all + other nodes. + + Higher values of closeness indicate higher centrality. + + As in the unipartite case, setting normalized=True causes the + values to normalized further to n-1 / size(G)-1 where n is the + number of nodes in the connected part of graph containing the + node. If the graph is not completely connected, this algorithm + computes the closeness centrality for each connected part + separately. + + References + ---------- + .. [1] Borgatti, S.P. and Halgin, D. In press. "Analyzing Affiliation + Networks". In Carrington, P. and Scott, J. (eds) The Sage Handbook + of Social Network Analysis. Sage Publications. + https://dx.doi.org/10.4135/9781446294413.n28 + """ + closeness = {} + path_length = nx.single_source_shortest_path_length + top = set(nodes) + bottom = set(G) - top + n = len(top) + m = len(bottom) + for node in top: + sp = dict(path_length(G, node)) + totsp = sum(sp.values()) + if totsp > 0.0 and len(G) > 1: + closeness[node] = (m + 2 * (n - 1)) / totsp + if normalized: + s = (len(sp) - 1) / (len(G) - 1) + closeness[node] *= s + else: + closeness[node] = 0.0 + for node in bottom: + sp = dict(path_length(G, node)) + totsp = sum(sp.values()) + if totsp > 0.0 and len(G) > 1: + closeness[node] = (n + 2 * (m - 1)) / totsp + if normalized: + s = (len(sp) - 1) / (len(G) - 1) + closeness[node] *= s + else: + closeness[node] = 0.0 + return closeness diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/cluster.py b/lib/python3.10/site-packages/networkx/algorithms/bipartite/cluster.py new file mode 100644 index 0000000000000000000000000000000000000000..5b66b2802c97b44d956b158656cb58f22429b789 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/bipartite/cluster.py @@ -0,0 +1,278 @@ +"""Functions for computing clustering of pairs""" + +import itertools + +import networkx as nx + +__all__ = [ + "clustering", + "average_clustering", + "latapy_clustering", + "robins_alexander_clustering", +] + + +def cc_dot(nu, nv): + return len(nu & nv) / len(nu | nv) + + +def cc_max(nu, nv): + return len(nu & nv) / max(len(nu), len(nv)) + + +def cc_min(nu, nv): + return len(nu & nv) / min(len(nu), len(nv)) + + +modes = {"dot": cc_dot, "min": cc_min, "max": cc_max} + + +@nx._dispatchable +def latapy_clustering(G, nodes=None, mode="dot"): + r"""Compute a bipartite clustering coefficient for nodes. + + The bipartite clustering coefficient is a measure of local density + of connections defined as [1]_: + + .. math:: + + c_u = \frac{\sum_{v \in N(N(u))} c_{uv} }{|N(N(u))|} + + where `N(N(u))` are the second order neighbors of `u` in `G` excluding `u`, + and `c_{uv}` is the pairwise clustering coefficient between nodes + `u` and `v`. + + The mode selects the function for `c_{uv}` which can be: + + `dot`: + + .. math:: + + c_{uv}=\frac{|N(u)\cap N(v)|}{|N(u) \cup N(v)|} + + `min`: + + .. math:: + + c_{uv}=\frac{|N(u)\cap N(v)|}{min(|N(u)|,|N(v)|)} + + `max`: + + .. math:: + + c_{uv}=\frac{|N(u)\cap N(v)|}{max(|N(u)|,|N(v)|)} + + + Parameters + ---------- + G : graph + A bipartite graph + + nodes : list or iterable (optional) + Compute bipartite clustering for these nodes. The default + is all nodes in G. + + mode : string + The pairwise bipartite clustering method to be used in the computation. + It must be "dot", "max", or "min". + + Returns + ------- + clustering : dictionary + A dictionary keyed by node with the clustering coefficient value. + + + Examples + -------- + >>> from networkx.algorithms import bipartite + >>> G = nx.path_graph(4) # path graphs are bipartite + >>> c = bipartite.clustering(G) + >>> c[0] + 0.5 + >>> c = bipartite.clustering(G, mode="min") + >>> c[0] + 1.0 + + See Also + -------- + robins_alexander_clustering + average_clustering + networkx.algorithms.cluster.square_clustering + + References + ---------- + .. [1] Latapy, Matthieu, Clémence Magnien, and Nathalie Del Vecchio (2008). + Basic notions for the analysis of large two-mode networks. + Social Networks 30(1), 31--48. + """ + if not nx.algorithms.bipartite.is_bipartite(G): + raise nx.NetworkXError("Graph is not bipartite") + + try: + cc_func = modes[mode] + except KeyError as err: + raise nx.NetworkXError( + "Mode for bipartite clustering must be: dot, min or max" + ) from err + + if nodes is None: + nodes = G + ccs = {} + for v in nodes: + cc = 0.0 + nbrs2 = {u for nbr in G[v] for u in G[nbr]} - {v} + for u in nbrs2: + cc += cc_func(set(G[u]), set(G[v])) + if cc > 0.0: # len(nbrs2)>0 + cc /= len(nbrs2) + ccs[v] = cc + return ccs + + +clustering = latapy_clustering + + +@nx._dispatchable(name="bipartite_average_clustering") +def average_clustering(G, nodes=None, mode="dot"): + r"""Compute the average bipartite clustering coefficient. + + A clustering coefficient for the whole graph is the average, + + .. math:: + + C = \frac{1}{n}\sum_{v \in G} c_v, + + where `n` is the number of nodes in `G`. + + Similar measures for the two bipartite sets can be defined [1]_ + + .. math:: + + C_X = \frac{1}{|X|}\sum_{v \in X} c_v, + + where `X` is a bipartite set of `G`. + + Parameters + ---------- + G : graph + a bipartite graph + + nodes : list or iterable, optional + A container of nodes to use in computing the average. + The nodes should be either the entire graph (the default) or one of the + bipartite sets. + + mode : string + The pairwise bipartite clustering method. + It must be "dot", "max", or "min" + + Returns + ------- + clustering : float + The average bipartite clustering for the given set of nodes or the + entire graph if no nodes are specified. + + Examples + -------- + >>> from networkx.algorithms import bipartite + >>> G = nx.star_graph(3) # star graphs are bipartite + >>> bipartite.average_clustering(G) + 0.75 + >>> X, Y = bipartite.sets(G) + >>> bipartite.average_clustering(G, X) + 0.0 + >>> bipartite.average_clustering(G, Y) + 1.0 + + See Also + -------- + clustering + + Notes + ----- + The container of nodes passed to this function must contain all of the nodes + in one of the bipartite sets ("top" or "bottom") in order to compute + the correct average bipartite clustering coefficients. + See :mod:`bipartite documentation ` + for further details on how bipartite graphs are handled in NetworkX. + + + References + ---------- + .. [1] Latapy, Matthieu, Clémence Magnien, and Nathalie Del Vecchio (2008). + Basic notions for the analysis of large two-mode networks. + Social Networks 30(1), 31--48. + """ + if nodes is None: + nodes = G + ccs = latapy_clustering(G, nodes=nodes, mode=mode) + return sum(ccs[v] for v in nodes) / len(nodes) + + +@nx._dispatchable +def robins_alexander_clustering(G): + r"""Compute the bipartite clustering of G. + + Robins and Alexander [1]_ defined bipartite clustering coefficient as + four times the number of four cycles `C_4` divided by the number of + three paths `L_3` in a bipartite graph: + + .. math:: + + CC_4 = \frac{4 * C_4}{L_3} + + Parameters + ---------- + G : graph + a bipartite graph + + Returns + ------- + clustering : float + The Robins and Alexander bipartite clustering for the input graph. + + Examples + -------- + >>> from networkx.algorithms import bipartite + >>> G = nx.davis_southern_women_graph() + >>> print(round(bipartite.robins_alexander_clustering(G), 3)) + 0.468 + + See Also + -------- + latapy_clustering + networkx.algorithms.cluster.square_clustering + + References + ---------- + .. [1] Robins, G. and M. Alexander (2004). Small worlds among interlocking + directors: Network structure and distance in bipartite graphs. + Computational & Mathematical Organization Theory 10(1), 69–94. + + """ + if G.order() < 4 or G.size() < 3: + return 0 + L_3 = _threepaths(G) + if L_3 == 0: + return 0 + C_4 = _four_cycles(G) + return (4.0 * C_4) / L_3 + + +def _four_cycles(G): + cycles = 0 + for v in G: + for u, w in itertools.combinations(G[v], 2): + cycles += len((set(G[u]) & set(G[w])) - {v}) + return cycles / 4 + + +def _threepaths(G): + paths = 0 + for v in G: + for u in G[v]: + for w in set(G[u]) - {v}: + paths += len(set(G[w]) - {v, u}) + # Divide by two because we count each three path twice + # one for each possible starting point + return paths / 2 diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/covering.py b/lib/python3.10/site-packages/networkx/algorithms/bipartite/covering.py new file mode 100644 index 0000000000000000000000000000000000000000..f937903e5576ec7313a774863c8470a4a271a252 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/bipartite/covering.py @@ -0,0 +1,57 @@ +"""Functions related to graph covers.""" + +import networkx as nx +from networkx.algorithms.bipartite.matching import hopcroft_karp_matching +from networkx.algorithms.covering import min_edge_cover as _min_edge_cover +from networkx.utils import not_implemented_for + +__all__ = ["min_edge_cover"] + + +@not_implemented_for("directed") +@not_implemented_for("multigraph") +@nx._dispatchable(name="bipartite_min_edge_cover") +def min_edge_cover(G, matching_algorithm=None): + """Returns a set of edges which constitutes + the minimum edge cover of the graph. + + The smallest edge cover can be found in polynomial time by finding + a maximum matching and extending it greedily so that all nodes + are covered. + + Parameters + ---------- + G : NetworkX graph + An undirected bipartite graph. + + matching_algorithm : function + A function that returns a maximum cardinality matching in a + given bipartite graph. The function must take one input, the + graph ``G``, and return a dictionary mapping each node to its + mate. If not specified, + :func:`~networkx.algorithms.bipartite.matching.hopcroft_karp_matching` + will be used. Other possibilities include + :func:`~networkx.algorithms.bipartite.matching.eppstein_matching`, + + Returns + ------- + set + A set of the edges in a minimum edge cover of the graph, given as + pairs of nodes. It contains both the edges `(u, v)` and `(v, u)` + for given nodes `u` and `v` among the edges of minimum edge cover. + + Notes + ----- + An edge cover of a graph is a set of edges such that every node of + the graph is incident to at least one edge of the set. + A minimum edge cover is an edge covering of smallest cardinality. + + Due to its implementation, the worst-case running time of this algorithm + is bounded by the worst-case running time of the function + ``matching_algorithm``. + """ + if G.order() == 0: # Special case for the empty graph + return set() + if matching_algorithm is None: + matching_algorithm = hopcroft_karp_matching + return _min_edge_cover(G, matching_algorithm=matching_algorithm) diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/edgelist.py b/lib/python3.10/site-packages/networkx/algorithms/bipartite/edgelist.py new file mode 100644 index 0000000000000000000000000000000000000000..db6ef9d8e2773de9ec698db151661fac28adc326 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/bipartite/edgelist.py @@ -0,0 +1,360 @@ +""" +******************** +Bipartite Edge Lists +******************** +Read and write NetworkX graphs as bipartite edge lists. + +Format +------ +You can read or write three formats of edge lists with these functions. + +Node pairs with no data:: + + 1 2 + +Python dictionary as data:: + + 1 2 {'weight':7, 'color':'green'} + +Arbitrary data:: + + 1 2 7 green + +For each edge (u, v) the node u is assigned to part 0 and the node v to part 1. +""" + +__all__ = ["generate_edgelist", "write_edgelist", "parse_edgelist", "read_edgelist"] + +import networkx as nx +from networkx.utils import not_implemented_for, open_file + + +@open_file(1, mode="wb") +def write_edgelist(G, path, comments="#", delimiter=" ", data=True, encoding="utf-8"): + """Write a bipartite graph as a list of edges. + + Parameters + ---------- + G : Graph + A NetworkX bipartite graph + path : file or string + File or filename to write. If a file is provided, it must be + opened in 'wb' mode. Filenames ending in .gz or .bz2 will be compressed. + comments : string, optional + The character used to indicate the start of a comment + delimiter : string, optional + The string used to separate values. The default is whitespace. + data : bool or list, optional + If False write no edge data. + If True write a string representation of the edge data dictionary.. + If a list (or other iterable) is provided, write the keys specified + in the list. + encoding: string, optional + Specify which encoding to use when writing file. + + Examples + -------- + >>> G = nx.path_graph(4) + >>> G.add_nodes_from([0, 2], bipartite=0) + >>> G.add_nodes_from([1, 3], bipartite=1) + >>> nx.write_edgelist(G, "test.edgelist") + >>> fh = open("test.edgelist", "wb") + >>> nx.write_edgelist(G, fh) + >>> nx.write_edgelist(G, "test.edgelist.gz") + >>> nx.write_edgelist(G, "test.edgelist.gz", data=False) + + >>> G = nx.Graph() + >>> G.add_edge(1, 2, weight=7, color="red") + >>> nx.write_edgelist(G, "test.edgelist", data=False) + >>> nx.write_edgelist(G, "test.edgelist", data=["color"]) + >>> nx.write_edgelist(G, "test.edgelist", data=["color", "weight"]) + + See Also + -------- + write_edgelist + generate_edgelist + """ + for line in generate_edgelist(G, delimiter, data): + line += "\n" + path.write(line.encode(encoding)) + + +@not_implemented_for("directed") +def generate_edgelist(G, delimiter=" ", data=True): + """Generate a single line of the bipartite graph G in edge list format. + + Parameters + ---------- + G : NetworkX graph + The graph is assumed to have node attribute `part` set to 0,1 representing + the two graph parts + + delimiter : string, optional + Separator for node labels + + data : bool or list of keys + If False generate no edge data. If True use a dictionary + representation of edge data. If a list of keys use a list of data + values corresponding to the keys. + + Returns + ------- + lines : string + Lines of data in adjlist format. + + Examples + -------- + >>> from networkx.algorithms import bipartite + >>> G = nx.path_graph(4) + >>> G.add_nodes_from([0, 2], bipartite=0) + >>> G.add_nodes_from([1, 3], bipartite=1) + >>> G[1][2]["weight"] = 3 + >>> G[2][3]["capacity"] = 12 + >>> for line in bipartite.generate_edgelist(G, data=False): + ... print(line) + 0 1 + 2 1 + 2 3 + + >>> for line in bipartite.generate_edgelist(G): + ... print(line) + 0 1 {} + 2 1 {'weight': 3} + 2 3 {'capacity': 12} + + >>> for line in bipartite.generate_edgelist(G, data=["weight"]): + ... print(line) + 0 1 + 2 1 3 + 2 3 + """ + try: + part0 = [n for n, d in G.nodes.items() if d["bipartite"] == 0] + except BaseException as err: + raise AttributeError("Missing node attribute `bipartite`") from err + if data is True or data is False: + for n in part0: + for edge in G.edges(n, data=data): + yield delimiter.join(map(str, edge)) + else: + for n in part0: + for u, v, d in G.edges(n, data=True): + edge = [u, v] + try: + edge.extend(d[k] for k in data) + except KeyError: + pass # missing data for this edge, should warn? + yield delimiter.join(map(str, edge)) + + +@nx._dispatchable(name="bipartite_parse_edgelist", graphs=None, returns_graph=True) +def parse_edgelist( + lines, comments="#", delimiter=None, create_using=None, nodetype=None, data=True +): + """Parse lines of an edge list representation of a bipartite graph. + + Parameters + ---------- + lines : list or iterator of strings + Input data in edgelist format + comments : string, optional + Marker for comment lines + delimiter : string, optional + Separator for node labels + create_using: NetworkX graph container, optional + Use given NetworkX graph for holding nodes or edges. + nodetype : Python type, optional + Convert nodes to this type. + data : bool or list of (label,type) tuples + If False generate no edge data or if True use a dictionary + representation of edge data or a list tuples specifying dictionary + key names and types for edge data. + + Returns + ------- + G: NetworkX Graph + The bipartite graph corresponding to lines + + Examples + -------- + Edgelist with no data: + + >>> from networkx.algorithms import bipartite + >>> lines = ["1 2", "2 3", "3 4"] + >>> G = bipartite.parse_edgelist(lines, nodetype=int) + >>> sorted(G.nodes()) + [1, 2, 3, 4] + >>> sorted(G.nodes(data=True)) + [(1, {'bipartite': 0}), (2, {'bipartite': 0}), (3, {'bipartite': 0}), (4, {'bipartite': 1})] + >>> sorted(G.edges()) + [(1, 2), (2, 3), (3, 4)] + + Edgelist with data in Python dictionary representation: + + >>> lines = ["1 2 {'weight':3}", "2 3 {'weight':27}", "3 4 {'weight':3.0}"] + >>> G = bipartite.parse_edgelist(lines, nodetype=int) + >>> sorted(G.nodes()) + [1, 2, 3, 4] + >>> sorted(G.edges(data=True)) + [(1, 2, {'weight': 3}), (2, 3, {'weight': 27}), (3, 4, {'weight': 3.0})] + + Edgelist with data in a list: + + >>> lines = ["1 2 3", "2 3 27", "3 4 3.0"] + >>> G = bipartite.parse_edgelist(lines, nodetype=int, data=(("weight", float),)) + >>> sorted(G.nodes()) + [1, 2, 3, 4] + >>> sorted(G.edges(data=True)) + [(1, 2, {'weight': 3.0}), (2, 3, {'weight': 27.0}), (3, 4, {'weight': 3.0})] + + See Also + -------- + """ + from ast import literal_eval + + G = nx.empty_graph(0, create_using) + for line in lines: + p = line.find(comments) + if p >= 0: + line = line[:p] + if not len(line): + continue + # split line, should have 2 or more + s = line.rstrip("\n").split(delimiter) + if len(s) < 2: + continue + u = s.pop(0) + v = s.pop(0) + d = s + if nodetype is not None: + try: + u = nodetype(u) + v = nodetype(v) + except BaseException as err: + raise TypeError( + f"Failed to convert nodes {u},{v} to type {nodetype}." + ) from err + + if len(d) == 0 or data is False: + # no data or data type specified + edgedata = {} + elif data is True: + # no edge types specified + try: # try to evaluate as dictionary + edgedata = dict(literal_eval(" ".join(d))) + except BaseException as err: + raise TypeError( + f"Failed to convert edge data ({d}) to dictionary." + ) from err + else: + # convert edge data to dictionary with specified keys and type + if len(d) != len(data): + raise IndexError( + f"Edge data {d} and data_keys {data} are not the same length" + ) + edgedata = {} + for (edge_key, edge_type), edge_value in zip(data, d): + try: + edge_value = edge_type(edge_value) + except BaseException as err: + raise TypeError( + f"Failed to convert {edge_key} data " + f"{edge_value} to type {edge_type}." + ) from err + edgedata.update({edge_key: edge_value}) + G.add_node(u, bipartite=0) + G.add_node(v, bipartite=1) + G.add_edge(u, v, **edgedata) + return G + + +@open_file(0, mode="rb") +@nx._dispatchable(name="bipartite_read_edgelist", graphs=None, returns_graph=True) +def read_edgelist( + path, + comments="#", + delimiter=None, + create_using=None, + nodetype=None, + data=True, + edgetype=None, + encoding="utf-8", +): + """Read a bipartite graph from a list of edges. + + Parameters + ---------- + path : file or string + File or filename to read. If a file is provided, it must be + opened in 'rb' mode. + Filenames ending in .gz or .bz2 will be uncompressed. + comments : string, optional + The character used to indicate the start of a comment. + delimiter : string, optional + The string used to separate values. The default is whitespace. + create_using : Graph container, optional, + Use specified container to build graph. The default is networkx.Graph, + an undirected graph. + nodetype : int, float, str, Python type, optional + Convert node data from strings to specified type + data : bool or list of (label,type) tuples + Tuples specifying dictionary key names and types for edge data + edgetype : int, float, str, Python type, optional OBSOLETE + Convert edge data from strings to specified type and use as 'weight' + encoding: string, optional + Specify which encoding to use when reading file. + + Returns + ------- + G : graph + A networkx Graph or other type specified with create_using + + Examples + -------- + >>> from networkx.algorithms import bipartite + >>> G = nx.path_graph(4) + >>> G.add_nodes_from([0, 2], bipartite=0) + >>> G.add_nodes_from([1, 3], bipartite=1) + >>> bipartite.write_edgelist(G, "test.edgelist") + >>> G = bipartite.read_edgelist("test.edgelist") + + >>> fh = open("test.edgelist", "rb") + >>> G = bipartite.read_edgelist(fh) + >>> fh.close() + + >>> G = bipartite.read_edgelist("test.edgelist", nodetype=int) + + Edgelist with data in a list: + + >>> textline = "1 2 3" + >>> fh = open("test.edgelist", "w") + >>> d = fh.write(textline) + >>> fh.close() + >>> G = bipartite.read_edgelist( + ... "test.edgelist", nodetype=int, data=(("weight", float),) + ... ) + >>> list(G) + [1, 2] + >>> list(G.edges(data=True)) + [(1, 2, {'weight': 3.0})] + + See parse_edgelist() for more examples of formatting. + + See Also + -------- + parse_edgelist + + Notes + ----- + Since nodes must be hashable, the function nodetype must return hashable + types (e.g. int, float, str, frozenset - or tuples of those, etc.) + """ + lines = (line.decode(encoding) for line in path) + return parse_edgelist( + lines, + comments=comments, + delimiter=delimiter, + create_using=create_using, + nodetype=nodetype, + data=data, + ) diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/extendability.py b/lib/python3.10/site-packages/networkx/algorithms/bipartite/extendability.py new file mode 100644 index 0000000000000000000000000000000000000000..61d8d067d9792659ed7097340c5ece28d9dc2e8c --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/bipartite/extendability.py @@ -0,0 +1,105 @@ +"""Provides a function for computing the extendability of a graph which is +undirected, simple, connected and bipartite and contains at least one perfect matching.""" + +import networkx as nx +from networkx.utils import not_implemented_for + +__all__ = ["maximal_extendability"] + + +@not_implemented_for("directed") +@not_implemented_for("multigraph") +@nx._dispatchable +def maximal_extendability(G): + """Computes the extendability of a graph. + + The extendability of a graph is defined as the maximum $k$ for which `G` + is $k$-extendable. Graph `G` is $k$-extendable if and only if `G` has a + perfect matching and every set of $k$ independent edges can be extended + to a perfect matching in `G`. + + Parameters + ---------- + G : NetworkX Graph + A fully-connected bipartite graph without self-loops + + Returns + ------- + extendability : int + + Raises + ------ + NetworkXError + If the graph `G` is disconnected. + If the graph `G` is not bipartite. + If the graph `G` does not contain a perfect matching. + If the residual graph of `G` is not strongly connected. + + Notes + ----- + Definition: + Let `G` be a simple, connected, undirected and bipartite graph with a perfect + matching M and bipartition (U,V). The residual graph of `G`, denoted by $G_M$, + is the graph obtained from G by directing the edges of M from V to U and the + edges that do not belong to M from U to V. + + Lemma [1]_ : + Let M be a perfect matching of `G`. `G` is $k$-extendable if and only if its residual + graph $G_M$ is strongly connected and there are $k$ vertex-disjoint directed + paths between every vertex of U and every vertex of V. + + Assuming that input graph `G` is undirected, simple, connected, bipartite and contains + a perfect matching M, this function constructs the residual graph $G_M$ of G and + returns the minimum value among the maximum vertex-disjoint directed paths between + every vertex of U and every vertex of V in $G_M$. By combining the definitions + and the lemma, this value represents the extendability of the graph `G`. + + Time complexity O($n^3$ $m^2$)) where $n$ is the number of vertices + and $m$ is the number of edges. + + References + ---------- + .. [1] "A polynomial algorithm for the extendability problem in bipartite graphs", + J. Lakhal, L. Litzler, Information Processing Letters, 1998. + .. [2] "On n-extendible graphs", M. D. Plummer, Discrete Mathematics, 31:201–210, 1980 + https://doi.org/10.1016/0012-365X(80)90037-0 + + """ + if not nx.is_connected(G): + raise nx.NetworkXError("Graph G is not connected") + + if not nx.bipartite.is_bipartite(G): + raise nx.NetworkXError("Graph G is not bipartite") + + U, V = nx.bipartite.sets(G) + + maximum_matching = nx.bipartite.hopcroft_karp_matching(G) + + if not nx.is_perfect_matching(G, maximum_matching): + raise nx.NetworkXError("Graph G does not contain a perfect matching") + + # list of edges in perfect matching, directed from V to U + pm = [(node, maximum_matching[node]) for node in V & maximum_matching.keys()] + + # Direct all the edges of G, from V to U if in matching, else from U to V + directed_edges = [ + (x, y) if (x in V and (x, y) in pm) or (x in U and (y, x) not in pm) else (y, x) + for x, y in G.edges + ] + + # Construct the residual graph of G + residual_G = nx.DiGraph() + residual_G.add_nodes_from(G) + residual_G.add_edges_from(directed_edges) + + if not nx.is_strongly_connected(residual_G): + raise nx.NetworkXError("The residual graph of G is not strongly connected") + + # For node-pairs between V & U, keep min of max number of node-disjoint paths + # Variable $k$ stands for the extendability of graph G + k = float("inf") + for u in U: + for v in V: + num_paths = sum(1 for _ in nx.node_disjoint_paths(residual_G, u, v)) + k = k if k < num_paths else num_paths + return k diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/generators.py b/lib/python3.10/site-packages/networkx/algorithms/bipartite/generators.py new file mode 100644 index 0000000000000000000000000000000000000000..e8428f6bc77c4ec550f495477f7429b43129226d --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/bipartite/generators.py @@ -0,0 +1,604 @@ +""" +Generators and functions for bipartite graphs. +""" + +import math +import numbers +from functools import reduce + +import networkx as nx +from networkx.utils import nodes_or_number, py_random_state + +__all__ = [ + "configuration_model", + "havel_hakimi_graph", + "reverse_havel_hakimi_graph", + "alternating_havel_hakimi_graph", + "preferential_attachment_graph", + "random_graph", + "gnmk_random_graph", + "complete_bipartite_graph", +] + + +@nx._dispatchable(graphs=None, returns_graph=True) +@nodes_or_number([0, 1]) +def complete_bipartite_graph(n1, n2, create_using=None): + """Returns the complete bipartite graph `K_{n_1,n_2}`. + + The graph is composed of two partitions with nodes 0 to (n1 - 1) + in the first and nodes n1 to (n1 + n2 - 1) in the second. + Each node in the first is connected to each node in the second. + + Parameters + ---------- + n1, n2 : integer or iterable container of nodes + If integers, nodes are from `range(n1)` and `range(n1, n1 + n2)`. + If a container, the elements are the nodes. + create_using : NetworkX graph instance, (default: nx.Graph) + Return graph of this type. + + Notes + ----- + Nodes are the integers 0 to `n1 + n2 - 1` unless either n1 or n2 are + containers of nodes. If only one of n1 or n2 are integers, that + integer is replaced by `range` of that integer. + + The nodes are assigned the attribute 'bipartite' with the value 0 or 1 + to indicate which bipartite set the node belongs to. + + This function is not imported in the main namespace. + To use it use nx.bipartite.complete_bipartite_graph + """ + G = nx.empty_graph(0, create_using) + if G.is_directed(): + raise nx.NetworkXError("Directed Graph not supported") + + n1, top = n1 + n2, bottom = n2 + if isinstance(n1, numbers.Integral) and isinstance(n2, numbers.Integral): + bottom = [n1 + i for i in bottom] + G.add_nodes_from(top, bipartite=0) + G.add_nodes_from(bottom, bipartite=1) + if len(G) != len(top) + len(bottom): + raise nx.NetworkXError("Inputs n1 and n2 must contain distinct nodes") + G.add_edges_from((u, v) for u in top for v in bottom) + G.graph["name"] = f"complete_bipartite_graph({len(top)}, {len(bottom)})" + return G + + +@py_random_state(3) +@nx._dispatchable(name="bipartite_configuration_model", graphs=None, returns_graph=True) +def configuration_model(aseq, bseq, create_using=None, seed=None): + """Returns a random bipartite graph from two given degree sequences. + + Parameters + ---------- + aseq : list + Degree sequence for node set A. + bseq : list + Degree sequence for node set B. + create_using : NetworkX graph instance, optional + Return graph of this type. + seed : integer, random_state, or None (default) + Indicator of random number generation state. + See :ref:`Randomness`. + + The graph is composed of two partitions. Set A has nodes 0 to + (len(aseq) - 1) and set B has nodes len(aseq) to (len(bseq) - 1). + Nodes from set A are connected to nodes in set B by choosing + randomly from the possible free stubs, one in A and one in B. + + Notes + ----- + The sum of the two sequences must be equal: sum(aseq)=sum(bseq) + If no graph type is specified use MultiGraph with parallel edges. + If you want a graph with no parallel edges use create_using=Graph() + but then the resulting degree sequences might not be exact. + + The nodes are assigned the attribute 'bipartite' with the value 0 or 1 + to indicate which bipartite set the node belongs to. + + This function is not imported in the main namespace. + To use it use nx.bipartite.configuration_model + """ + G = nx.empty_graph(0, create_using, default=nx.MultiGraph) + if G.is_directed(): + raise nx.NetworkXError("Directed Graph not supported") + + # length and sum of each sequence + lena = len(aseq) + lenb = len(bseq) + suma = sum(aseq) + sumb = sum(bseq) + + if not suma == sumb: + raise nx.NetworkXError( + f"invalid degree sequences, sum(aseq)!=sum(bseq),{suma},{sumb}" + ) + + G = _add_nodes_with_bipartite_label(G, lena, lenb) + + if len(aseq) == 0 or max(aseq) == 0: + return G # done if no edges + + # build lists of degree-repeated vertex numbers + stubs = [[v] * aseq[v] for v in range(lena)] + astubs = [x for subseq in stubs for x in subseq] + + stubs = [[v] * bseq[v - lena] for v in range(lena, lena + lenb)] + bstubs = [x for subseq in stubs for x in subseq] + + # shuffle lists + seed.shuffle(astubs) + seed.shuffle(bstubs) + + G.add_edges_from([astubs[i], bstubs[i]] for i in range(suma)) + + G.name = "bipartite_configuration_model" + return G + + +@nx._dispatchable(name="bipartite_havel_hakimi_graph", graphs=None, returns_graph=True) +def havel_hakimi_graph(aseq, bseq, create_using=None): + """Returns a bipartite graph from two given degree sequences using a + Havel-Hakimi style construction. + + The graph is composed of two partitions. Set A has nodes 0 to + (len(aseq) - 1) and set B has nodes len(aseq) to (len(bseq) - 1). + Nodes from the set A are connected to nodes in the set B by + connecting the highest degree nodes in set A to the highest degree + nodes in set B until all stubs are connected. + + Parameters + ---------- + aseq : list + Degree sequence for node set A. + bseq : list + Degree sequence for node set B. + create_using : NetworkX graph instance, optional + Return graph of this type. + + Notes + ----- + The sum of the two sequences must be equal: sum(aseq)=sum(bseq) + If no graph type is specified use MultiGraph with parallel edges. + If you want a graph with no parallel edges use create_using=Graph() + but then the resulting degree sequences might not be exact. + + The nodes are assigned the attribute 'bipartite' with the value 0 or 1 + to indicate which bipartite set the node belongs to. + + This function is not imported in the main namespace. + To use it use nx.bipartite.havel_hakimi_graph + """ + G = nx.empty_graph(0, create_using, default=nx.MultiGraph) + if G.is_directed(): + raise nx.NetworkXError("Directed Graph not supported") + + # length of the each sequence + naseq = len(aseq) + nbseq = len(bseq) + + suma = sum(aseq) + sumb = sum(bseq) + + if not suma == sumb: + raise nx.NetworkXError( + f"invalid degree sequences, sum(aseq)!=sum(bseq),{suma},{sumb}" + ) + + G = _add_nodes_with_bipartite_label(G, naseq, nbseq) + + if len(aseq) == 0 or max(aseq) == 0: + return G # done if no edges + + # build list of degree-repeated vertex numbers + astubs = [[aseq[v], v] for v in range(naseq)] + bstubs = [[bseq[v - naseq], v] for v in range(naseq, naseq + nbseq)] + astubs.sort() + while astubs: + (degree, u) = astubs.pop() # take of largest degree node in the a set + if degree == 0: + break # done, all are zero + # connect the source to largest degree nodes in the b set + bstubs.sort() + for target in bstubs[-degree:]: + v = target[1] + G.add_edge(u, v) + target[0] -= 1 # note this updates bstubs too. + if target[0] == 0: + bstubs.remove(target) + + G.name = "bipartite_havel_hakimi_graph" + return G + + +@nx._dispatchable(graphs=None, returns_graph=True) +def reverse_havel_hakimi_graph(aseq, bseq, create_using=None): + """Returns a bipartite graph from two given degree sequences using a + Havel-Hakimi style construction. + + The graph is composed of two partitions. Set A has nodes 0 to + (len(aseq) - 1) and set B has nodes len(aseq) to (len(bseq) - 1). + Nodes from set A are connected to nodes in the set B by connecting + the highest degree nodes in set A to the lowest degree nodes in + set B until all stubs are connected. + + Parameters + ---------- + aseq : list + Degree sequence for node set A. + bseq : list + Degree sequence for node set B. + create_using : NetworkX graph instance, optional + Return graph of this type. + + Notes + ----- + The sum of the two sequences must be equal: sum(aseq)=sum(bseq) + If no graph type is specified use MultiGraph with parallel edges. + If you want a graph with no parallel edges use create_using=Graph() + but then the resulting degree sequences might not be exact. + + The nodes are assigned the attribute 'bipartite' with the value 0 or 1 + to indicate which bipartite set the node belongs to. + + This function is not imported in the main namespace. + To use it use nx.bipartite.reverse_havel_hakimi_graph + """ + G = nx.empty_graph(0, create_using, default=nx.MultiGraph) + if G.is_directed(): + raise nx.NetworkXError("Directed Graph not supported") + + # length of the each sequence + lena = len(aseq) + lenb = len(bseq) + suma = sum(aseq) + sumb = sum(bseq) + + if not suma == sumb: + raise nx.NetworkXError( + f"invalid degree sequences, sum(aseq)!=sum(bseq),{suma},{sumb}" + ) + + G = _add_nodes_with_bipartite_label(G, lena, lenb) + + if len(aseq) == 0 or max(aseq) == 0: + return G # done if no edges + + # build list of degree-repeated vertex numbers + astubs = [[aseq[v], v] for v in range(lena)] + bstubs = [[bseq[v - lena], v] for v in range(lena, lena + lenb)] + astubs.sort() + bstubs.sort() + while astubs: + (degree, u) = astubs.pop() # take of largest degree node in the a set + if degree == 0: + break # done, all are zero + # connect the source to the smallest degree nodes in the b set + for target in bstubs[0:degree]: + v = target[1] + G.add_edge(u, v) + target[0] -= 1 # note this updates bstubs too. + if target[0] == 0: + bstubs.remove(target) + + G.name = "bipartite_reverse_havel_hakimi_graph" + return G + + +@nx._dispatchable(graphs=None, returns_graph=True) +def alternating_havel_hakimi_graph(aseq, bseq, create_using=None): + """Returns a bipartite graph from two given degree sequences using + an alternating Havel-Hakimi style construction. + + The graph is composed of two partitions. Set A has nodes 0 to + (len(aseq) - 1) and set B has nodes len(aseq) to (len(bseq) - 1). + Nodes from the set A are connected to nodes in the set B by + connecting the highest degree nodes in set A to alternatively the + highest and the lowest degree nodes in set B until all stubs are + connected. + + Parameters + ---------- + aseq : list + Degree sequence for node set A. + bseq : list + Degree sequence for node set B. + create_using : NetworkX graph instance, optional + Return graph of this type. + + Notes + ----- + The sum of the two sequences must be equal: sum(aseq)=sum(bseq) + If no graph type is specified use MultiGraph with parallel edges. + If you want a graph with no parallel edges use create_using=Graph() + but then the resulting degree sequences might not be exact. + + The nodes are assigned the attribute 'bipartite' with the value 0 or 1 + to indicate which bipartite set the node belongs to. + + This function is not imported in the main namespace. + To use it use nx.bipartite.alternating_havel_hakimi_graph + """ + G = nx.empty_graph(0, create_using, default=nx.MultiGraph) + if G.is_directed(): + raise nx.NetworkXError("Directed Graph not supported") + + # length of the each sequence + naseq = len(aseq) + nbseq = len(bseq) + suma = sum(aseq) + sumb = sum(bseq) + + if not suma == sumb: + raise nx.NetworkXError( + f"invalid degree sequences, sum(aseq)!=sum(bseq),{suma},{sumb}" + ) + + G = _add_nodes_with_bipartite_label(G, naseq, nbseq) + + if len(aseq) == 0 or max(aseq) == 0: + return G # done if no edges + # build list of degree-repeated vertex numbers + astubs = [[aseq[v], v] for v in range(naseq)] + bstubs = [[bseq[v - naseq], v] for v in range(naseq, naseq + nbseq)] + while astubs: + astubs.sort() + (degree, u) = astubs.pop() # take of largest degree node in the a set + if degree == 0: + break # done, all are zero + bstubs.sort() + small = bstubs[0 : degree // 2] # add these low degree targets + large = bstubs[(-degree + degree // 2) :] # now high degree targets + stubs = [x for z in zip(large, small) for x in z] # combine, sorry + if len(stubs) < len(small) + len(large): # check for zip truncation + stubs.append(large.pop()) + for target in stubs: + v = target[1] + G.add_edge(u, v) + target[0] -= 1 # note this updates bstubs too. + if target[0] == 0: + bstubs.remove(target) + + G.name = "bipartite_alternating_havel_hakimi_graph" + return G + + +@py_random_state(3) +@nx._dispatchable(graphs=None, returns_graph=True) +def preferential_attachment_graph(aseq, p, create_using=None, seed=None): + """Create a bipartite graph with a preferential attachment model from + a given single degree sequence. + + The graph is composed of two partitions. Set A has nodes 0 to + (len(aseq) - 1) and set B has nodes starting with node len(aseq). + The number of nodes in set B is random. + + Parameters + ---------- + aseq : list + Degree sequence for node set A. + p : float + Probability that a new bottom node is added. + create_using : NetworkX graph instance, optional + Return graph of this type. + seed : integer, random_state, or None (default) + Indicator of random number generation state. + See :ref:`Randomness`. + + References + ---------- + .. [1] Guillaume, J.L. and Latapy, M., + Bipartite graphs as models of complex networks. + Physica A: Statistical Mechanics and its Applications, + 2006, 371(2), pp.795-813. + .. [2] Jean-Loup Guillaume and Matthieu Latapy, + Bipartite structure of all complex networks, + Inf. Process. Lett. 90, 2004, pg. 215-221 + https://doi.org/10.1016/j.ipl.2004.03.007 + + Notes + ----- + The nodes are assigned the attribute 'bipartite' with the value 0 or 1 + to indicate which bipartite set the node belongs to. + + This function is not imported in the main namespace. + To use it use nx.bipartite.preferential_attachment_graph + """ + G = nx.empty_graph(0, create_using, default=nx.MultiGraph) + if G.is_directed(): + raise nx.NetworkXError("Directed Graph not supported") + + if p > 1: + raise nx.NetworkXError(f"probability {p} > 1") + + naseq = len(aseq) + G = _add_nodes_with_bipartite_label(G, naseq, 0) + vv = [[v] * aseq[v] for v in range(naseq)] + while vv: + while vv[0]: + source = vv[0][0] + vv[0].remove(source) + if seed.random() < p or len(G) == naseq: + target = len(G) + G.add_node(target, bipartite=1) + G.add_edge(source, target) + else: + bb = [[b] * G.degree(b) for b in range(naseq, len(G))] + # flatten the list of lists into a list. + bbstubs = reduce(lambda x, y: x + y, bb) + # choose preferentially a bottom node. + target = seed.choice(bbstubs) + G.add_node(target, bipartite=1) + G.add_edge(source, target) + vv.remove(vv[0]) + G.name = "bipartite_preferential_attachment_model" + return G + + +@py_random_state(3) +@nx._dispatchable(graphs=None, returns_graph=True) +def random_graph(n, m, p, seed=None, directed=False): + """Returns a bipartite random graph. + + This is a bipartite version of the binomial (Erdős-Rényi) graph. + The graph is composed of two partitions. Set A has nodes 0 to + (n - 1) and set B has nodes n to (n + m - 1). + + Parameters + ---------- + n : int + The number of nodes in the first bipartite set. + m : int + The number of nodes in the second bipartite set. + p : float + Probability for edge creation. + seed : integer, random_state, or None (default) + Indicator of random number generation state. + See :ref:`Randomness`. + directed : bool, optional (default=False) + If True return a directed graph + + Notes + ----- + The bipartite random graph algorithm chooses each of the n*m (undirected) + or 2*nm (directed) possible edges with probability p. + + This algorithm is $O(n+m)$ where $m$ is the expected number of edges. + + The nodes are assigned the attribute 'bipartite' with the value 0 or 1 + to indicate which bipartite set the node belongs to. + + This function is not imported in the main namespace. + To use it use nx.bipartite.random_graph + + See Also + -------- + gnp_random_graph, configuration_model + + References + ---------- + .. [1] Vladimir Batagelj and Ulrik Brandes, + "Efficient generation of large random networks", + Phys. Rev. E, 71, 036113, 2005. + """ + G = nx.Graph() + G = _add_nodes_with_bipartite_label(G, n, m) + if directed: + G = nx.DiGraph(G) + G.name = f"fast_gnp_random_graph({n},{m},{p})" + + if p <= 0: + return G + if p >= 1: + return nx.complete_bipartite_graph(n, m) + + lp = math.log(1.0 - p) + + v = 0 + w = -1 + while v < n: + lr = math.log(1.0 - seed.random()) + w = w + 1 + int(lr / lp) + while w >= m and v < n: + w = w - m + v = v + 1 + if v < n: + G.add_edge(v, n + w) + + if directed: + # use the same algorithm to + # add edges from the "m" to "n" set + v = 0 + w = -1 + while v < n: + lr = math.log(1.0 - seed.random()) + w = w + 1 + int(lr / lp) + while w >= m and v < n: + w = w - m + v = v + 1 + if v < n: + G.add_edge(n + w, v) + + return G + + +@py_random_state(3) +@nx._dispatchable(graphs=None, returns_graph=True) +def gnmk_random_graph(n, m, k, seed=None, directed=False): + """Returns a random bipartite graph G_{n,m,k}. + + Produces a bipartite graph chosen randomly out of the set of all graphs + with n top nodes, m bottom nodes, and k edges. + The graph is composed of two sets of nodes. + Set A has nodes 0 to (n - 1) and set B has nodes n to (n + m - 1). + + Parameters + ---------- + n : int + The number of nodes in the first bipartite set. + m : int + The number of nodes in the second bipartite set. + k : int + The number of edges + seed : integer, random_state, or None (default) + Indicator of random number generation state. + See :ref:`Randomness`. + directed : bool, optional (default=False) + If True return a directed graph + + Examples + -------- + from nx.algorithms import bipartite + G = bipartite.gnmk_random_graph(10,20,50) + + See Also + -------- + gnm_random_graph + + Notes + ----- + If k > m * n then a complete bipartite graph is returned. + + This graph is a bipartite version of the `G_{nm}` random graph model. + + The nodes are assigned the attribute 'bipartite' with the value 0 or 1 + to indicate which bipartite set the node belongs to. + + This function is not imported in the main namespace. + To use it use nx.bipartite.gnmk_random_graph + """ + G = nx.Graph() + G = _add_nodes_with_bipartite_label(G, n, m) + if directed: + G = nx.DiGraph(G) + G.name = f"bipartite_gnm_random_graph({n},{m},{k})" + if n == 1 or m == 1: + return G + max_edges = n * m # max_edges for bipartite networks + if k >= max_edges: # Maybe we should raise an exception here + return nx.complete_bipartite_graph(n, m, create_using=G) + + top = [n for n, d in G.nodes(data=True) if d["bipartite"] == 0] + bottom = list(set(G) - set(top)) + edge_count = 0 + while edge_count < k: + # generate random edge,u,v + u = seed.choice(top) + v = seed.choice(bottom) + if v in G[u]: + continue + else: + G.add_edge(u, v) + edge_count += 1 + return G + + +def _add_nodes_with_bipartite_label(G, lena, lenb): + G.add_nodes_from(range(lena + lenb)) + b = dict(zip(range(lena), [0] * lena)) + b.update(dict(zip(range(lena, lena + lenb), [1] * lenb))) + nx.set_node_attributes(G, b, "bipartite") + return G diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/matching.py b/lib/python3.10/site-packages/networkx/algorithms/bipartite/matching.py new file mode 100644 index 0000000000000000000000000000000000000000..38a174780ac1eb7a42568aa6752b9adb82a2d984 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/bipartite/matching.py @@ -0,0 +1,590 @@ +# This module uses material from the Wikipedia article Hopcroft--Karp algorithm +# , accessed on +# January 3, 2015, which is released under the Creative Commons +# Attribution-Share-Alike License 3.0 +# . That article includes +# pseudocode, which has been translated into the corresponding Python code. +# +# Portions of this module use code from David Eppstein's Python Algorithms and +# Data Structures (PADS) library, which is dedicated to the public domain (for +# proof, see ). +"""Provides functions for computing maximum cardinality matchings and minimum +weight full matchings in a bipartite graph. + +If you don't care about the particular implementation of the maximum matching +algorithm, simply use the :func:`maximum_matching`. If you do care, you can +import one of the named maximum matching algorithms directly. + +For example, to find a maximum matching in the complete bipartite graph with +two vertices on the left and three vertices on the right: + +>>> G = nx.complete_bipartite_graph(2, 3) +>>> left, right = nx.bipartite.sets(G) +>>> list(left) +[0, 1] +>>> list(right) +[2, 3, 4] +>>> nx.bipartite.maximum_matching(G) +{0: 2, 1: 3, 2: 0, 3: 1} + +The dictionary returned by :func:`maximum_matching` includes a mapping for +vertices in both the left and right vertex sets. + +Similarly, :func:`minimum_weight_full_matching` produces, for a complete +weighted bipartite graph, a matching whose cardinality is the cardinality of +the smaller of the two partitions, and for which the sum of the weights of the +edges included in the matching is minimal. + +""" + +import collections +import itertools + +import networkx as nx +from networkx.algorithms.bipartite import sets as bipartite_sets +from networkx.algorithms.bipartite.matrix import biadjacency_matrix + +__all__ = [ + "maximum_matching", + "hopcroft_karp_matching", + "eppstein_matching", + "to_vertex_cover", + "minimum_weight_full_matching", +] + +INFINITY = float("inf") + + +@nx._dispatchable +def hopcroft_karp_matching(G, top_nodes=None): + """Returns the maximum cardinality matching of the bipartite graph `G`. + + A matching is a set of edges that do not share any nodes. A maximum + cardinality matching is a matching with the most edges possible. It + is not always unique. Finding a matching in a bipartite graph can be + treated as a networkx flow problem. + + The functions ``hopcroft_karp_matching`` and ``maximum_matching`` + are aliases of the same function. + + Parameters + ---------- + G : NetworkX graph + + Undirected bipartite graph + + top_nodes : container of nodes + + Container with all nodes in one bipartite node set. If not supplied + it will be computed. But if more than one solution exists an exception + will be raised. + + Returns + ------- + matches : dictionary + + The matching is returned as a dictionary, `matches`, such that + ``matches[v] == w`` if node `v` is matched to node `w`. Unmatched + nodes do not occur as a key in `matches`. + + Raises + ------ + AmbiguousSolution + Raised if the input bipartite graph is disconnected and no container + with all nodes in one bipartite set is provided. When determining + the nodes in each bipartite set more than one valid solution is + possible if the input graph is disconnected. + + Notes + ----- + This function is implemented with the `Hopcroft--Karp matching algorithm + `_ for + bipartite graphs. + + See :mod:`bipartite documentation ` + for further details on how bipartite graphs are handled in NetworkX. + + See Also + -------- + maximum_matching + hopcroft_karp_matching + eppstein_matching + + References + ---------- + .. [1] John E. Hopcroft and Richard M. Karp. "An n^{5 / 2} Algorithm for + Maximum Matchings in Bipartite Graphs" In: **SIAM Journal of Computing** + 2.4 (1973), pp. 225--231. . + + """ + + # First we define some auxiliary search functions. + # + # If you are a human reading these auxiliary search functions, the "global" + # variables `leftmatches`, `rightmatches`, `distances`, etc. are defined + # below the functions, so that they are initialized close to the initial + # invocation of the search functions. + def breadth_first_search(): + for v in left: + if leftmatches[v] is None: + distances[v] = 0 + queue.append(v) + else: + distances[v] = INFINITY + distances[None] = INFINITY + while queue: + v = queue.popleft() + if distances[v] < distances[None]: + for u in G[v]: + if distances[rightmatches[u]] is INFINITY: + distances[rightmatches[u]] = distances[v] + 1 + queue.append(rightmatches[u]) + return distances[None] is not INFINITY + + def depth_first_search(v): + if v is not None: + for u in G[v]: + if distances[rightmatches[u]] == distances[v] + 1: + if depth_first_search(rightmatches[u]): + rightmatches[u] = v + leftmatches[v] = u + return True + distances[v] = INFINITY + return False + return True + + # Initialize the "global" variables that maintain state during the search. + left, right = bipartite_sets(G, top_nodes) + leftmatches = {v: None for v in left} + rightmatches = {v: None for v in right} + distances = {} + queue = collections.deque() + + # Implementation note: this counter is incremented as pairs are matched but + # it is currently not used elsewhere in the computation. + num_matched_pairs = 0 + while breadth_first_search(): + for v in left: + if leftmatches[v] is None: + if depth_first_search(v): + num_matched_pairs += 1 + + # Strip the entries matched to `None`. + leftmatches = {k: v for k, v in leftmatches.items() if v is not None} + rightmatches = {k: v for k, v in rightmatches.items() if v is not None} + + # At this point, the left matches and the right matches are inverses of one + # another. In other words, + # + # leftmatches == {v, k for k, v in rightmatches.items()} + # + # Finally, we combine both the left matches and right matches. + return dict(itertools.chain(leftmatches.items(), rightmatches.items())) + + +@nx._dispatchable +def eppstein_matching(G, top_nodes=None): + """Returns the maximum cardinality matching of the bipartite graph `G`. + + Parameters + ---------- + G : NetworkX graph + + Undirected bipartite graph + + top_nodes : container + + Container with all nodes in one bipartite node set. If not supplied + it will be computed. But if more than one solution exists an exception + will be raised. + + Returns + ------- + matches : dictionary + + The matching is returned as a dictionary, `matching`, such that + ``matching[v] == w`` if node `v` is matched to node `w`. Unmatched + nodes do not occur as a key in `matching`. + + Raises + ------ + AmbiguousSolution + Raised if the input bipartite graph is disconnected and no container + with all nodes in one bipartite set is provided. When determining + the nodes in each bipartite set more than one valid solution is + possible if the input graph is disconnected. + + Notes + ----- + This function is implemented with David Eppstein's version of the algorithm + Hopcroft--Karp algorithm (see :func:`hopcroft_karp_matching`), which + originally appeared in the `Python Algorithms and Data Structures library + (PADS) `_. + + See :mod:`bipartite documentation ` + for further details on how bipartite graphs are handled in NetworkX. + + See Also + -------- + + hopcroft_karp_matching + + """ + # Due to its original implementation, a directed graph is needed + # so that the two sets of bipartite nodes can be distinguished + left, right = bipartite_sets(G, top_nodes) + G = nx.DiGraph(G.edges(left)) + # initialize greedy matching (redundant, but faster than full search) + matching = {} + for u in G: + for v in G[u]: + if v not in matching: + matching[v] = u + break + while True: + # structure residual graph into layers + # pred[u] gives the neighbor in the previous layer for u in U + # preds[v] gives a list of neighbors in the previous layer for v in V + # unmatched gives a list of unmatched vertices in final layer of V, + # and is also used as a flag value for pred[u] when u is in the first + # layer + preds = {} + unmatched = [] + pred = {u: unmatched for u in G} + for v in matching: + del pred[matching[v]] + layer = list(pred) + + # repeatedly extend layering structure by another pair of layers + while layer and not unmatched: + newLayer = {} + for u in layer: + for v in G[u]: + if v not in preds: + newLayer.setdefault(v, []).append(u) + layer = [] + for v in newLayer: + preds[v] = newLayer[v] + if v in matching: + layer.append(matching[v]) + pred[matching[v]] = v + else: + unmatched.append(v) + + # did we finish layering without finding any alternating paths? + if not unmatched: + # TODO - The lines between --- were unused and were thus commented + # out. This whole commented chunk should be reviewed to determine + # whether it should be built upon or completely removed. + # --- + # unlayered = {} + # for u in G: + # # TODO Why is extra inner loop necessary? + # for v in G[u]: + # if v not in preds: + # unlayered[v] = None + # --- + # TODO Originally, this function returned a three-tuple: + # + # return (matching, list(pred), list(unlayered)) + # + # For some reason, the documentation for this function + # indicated that the second and third elements of the returned + # three-tuple would be the vertices in the left and right vertex + # sets, respectively, that are also in the maximum independent set. + # However, what I think the author meant was that the second + # element is the list of vertices that were unmatched and the third + # element was the list of vertices that were matched. Since that + # seems to be the case, they don't really need to be returned, + # since that information can be inferred from the matching + # dictionary. + + # All the matched nodes must be a key in the dictionary + for key in matching.copy(): + matching[matching[key]] = key + return matching + + # recursively search backward through layers to find alternating paths + # recursion returns true if found path, false otherwise + def recurse(v): + if v in preds: + L = preds.pop(v) + for u in L: + if u in pred: + pu = pred.pop(u) + if pu is unmatched or recurse(pu): + matching[v] = u + return True + return False + + for v in unmatched: + recurse(v) + + +def _is_connected_by_alternating_path(G, v, matched_edges, unmatched_edges, targets): + """Returns True if and only if the vertex `v` is connected to one of + the target vertices by an alternating path in `G`. + + An *alternating path* is a path in which every other edge is in the + specified maximum matching (and the remaining edges in the path are not in + the matching). An alternating path may have matched edges in the even + positions or in the odd positions, as long as the edges alternate between + 'matched' and 'unmatched'. + + `G` is an undirected bipartite NetworkX graph. + + `v` is a vertex in `G`. + + `matched_edges` is a set of edges present in a maximum matching in `G`. + + `unmatched_edges` is a set of edges not present in a maximum + matching in `G`. + + `targets` is a set of vertices. + + """ + + def _alternating_dfs(u, along_matched=True): + """Returns True if and only if `u` is connected to one of the + targets by an alternating path. + + `u` is a vertex in the graph `G`. + + If `along_matched` is True, this step of the depth-first search + will continue only through edges in the given matching. Otherwise, it + will continue only through edges *not* in the given matching. + + """ + visited = set() + # Follow matched edges when depth is even, + # and follow unmatched edges when depth is odd. + initial_depth = 0 if along_matched else 1 + stack = [(u, iter(G[u]), initial_depth)] + while stack: + parent, children, depth = stack[-1] + valid_edges = matched_edges if depth % 2 else unmatched_edges + try: + child = next(children) + if child not in visited: + if (parent, child) in valid_edges or (child, parent) in valid_edges: + if child in targets: + return True + visited.add(child) + stack.append((child, iter(G[child]), depth + 1)) + except StopIteration: + stack.pop() + return False + + # Check for alternating paths starting with edges in the matching, then + # check for alternating paths starting with edges not in the + # matching. + return _alternating_dfs(v, along_matched=True) or _alternating_dfs( + v, along_matched=False + ) + + +def _connected_by_alternating_paths(G, matching, targets): + """Returns the set of vertices that are connected to one of the target + vertices by an alternating path in `G` or are themselves a target. + + An *alternating path* is a path in which every other edge is in the + specified maximum matching (and the remaining edges in the path are not in + the matching). An alternating path may have matched edges in the even + positions or in the odd positions, as long as the edges alternate between + 'matched' and 'unmatched'. + + `G` is an undirected bipartite NetworkX graph. + + `matching` is a dictionary representing a maximum matching in `G`, as + returned by, for example, :func:`maximum_matching`. + + `targets` is a set of vertices. + + """ + # Get the set of matched edges and the set of unmatched edges. Only include + # one version of each undirected edge (for example, include edge (1, 2) but + # not edge (2, 1)). Using frozensets as an intermediary step we do not + # require nodes to be orderable. + edge_sets = {frozenset((u, v)) for u, v in matching.items()} + matched_edges = {tuple(edge) for edge in edge_sets} + unmatched_edges = { + (u, v) for (u, v) in G.edges() if frozenset((u, v)) not in edge_sets + } + + return { + v + for v in G + if v in targets + or _is_connected_by_alternating_path( + G, v, matched_edges, unmatched_edges, targets + ) + } + + +@nx._dispatchable +def to_vertex_cover(G, matching, top_nodes=None): + """Returns the minimum vertex cover corresponding to the given maximum + matching of the bipartite graph `G`. + + Parameters + ---------- + G : NetworkX graph + + Undirected bipartite graph + + matching : dictionary + + A dictionary whose keys are vertices in `G` and whose values are the + distinct neighbors comprising the maximum matching for `G`, as returned + by, for example, :func:`maximum_matching`. The dictionary *must* + represent the maximum matching. + + top_nodes : container + + Container with all nodes in one bipartite node set. If not supplied + it will be computed. But if more than one solution exists an exception + will be raised. + + Returns + ------- + vertex_cover : :class:`set` + + The minimum vertex cover in `G`. + + Raises + ------ + AmbiguousSolution + Raised if the input bipartite graph is disconnected and no container + with all nodes in one bipartite set is provided. When determining + the nodes in each bipartite set more than one valid solution is + possible if the input graph is disconnected. + + Notes + ----- + This function is implemented using the procedure guaranteed by `Konig's + theorem + `_, + which proves an equivalence between a maximum matching and a minimum vertex + cover in bipartite graphs. + + Since a minimum vertex cover is the complement of a maximum independent set + for any graph, one can compute the maximum independent set of a bipartite + graph this way: + + >>> G = nx.complete_bipartite_graph(2, 3) + >>> matching = nx.bipartite.maximum_matching(G) + >>> vertex_cover = nx.bipartite.to_vertex_cover(G, matching) + >>> independent_set = set(G) - vertex_cover + >>> print(list(independent_set)) + [2, 3, 4] + + See :mod:`bipartite documentation ` + for further details on how bipartite graphs are handled in NetworkX. + + """ + # This is a Python implementation of the algorithm described at + # . + L, R = bipartite_sets(G, top_nodes) + # Let U be the set of unmatched vertices in the left vertex set. + unmatched_vertices = set(G) - set(matching) + U = unmatched_vertices & L + # Let Z be the set of vertices that are either in U or are connected to U + # by alternating paths. + Z = _connected_by_alternating_paths(G, matching, U) + # At this point, every edge either has a right endpoint in Z or a left + # endpoint not in Z. This gives us the vertex cover. + return (L - Z) | (R & Z) + + +#: Returns the maximum cardinality matching in the given bipartite graph. +#: +#: This function is simply an alias for :func:`hopcroft_karp_matching`. +maximum_matching = hopcroft_karp_matching + + +@nx._dispatchable(edge_attrs="weight") +def minimum_weight_full_matching(G, top_nodes=None, weight="weight"): + r"""Returns a minimum weight full matching of the bipartite graph `G`. + + Let :math:`G = ((U, V), E)` be a weighted bipartite graph with real weights + :math:`w : E \to \mathbb{R}`. This function then produces a matching + :math:`M \subseteq E` with cardinality + + .. math:: + \lvert M \rvert = \min(\lvert U \rvert, \lvert V \rvert), + + which minimizes the sum of the weights of the edges included in the + matching, :math:`\sum_{e \in M} w(e)`, or raises an error if no such + matching exists. + + When :math:`\lvert U \rvert = \lvert V \rvert`, this is commonly + referred to as a perfect matching; here, since we allow + :math:`\lvert U \rvert` and :math:`\lvert V \rvert` to differ, we + follow Karp [1]_ and refer to the matching as *full*. + + Parameters + ---------- + G : NetworkX graph + + Undirected bipartite graph + + top_nodes : container + + Container with all nodes in one bipartite node set. If not supplied + it will be computed. + + weight : string, optional (default='weight') + + The edge data key used to provide each value in the matrix. + If None, then each edge has weight 1. + + Returns + ------- + matches : dictionary + + The matching is returned as a dictionary, `matches`, such that + ``matches[v] == w`` if node `v` is matched to node `w`. Unmatched + nodes do not occur as a key in `matches`. + + Raises + ------ + ValueError + Raised if no full matching exists. + + ImportError + Raised if SciPy is not available. + + Notes + ----- + The problem of determining a minimum weight full matching is also known as + the rectangular linear assignment problem. This implementation defers the + calculation of the assignment to SciPy. + + References + ---------- + .. [1] Richard Manning Karp: + An algorithm to Solve the m x n Assignment Problem in Expected Time + O(mn log n). + Networks, 10(2):143–152, 1980. + + """ + import numpy as np + import scipy as sp + + left, right = nx.bipartite.sets(G, top_nodes) + U = list(left) + V = list(right) + # We explicitly create the biadjacency matrix having infinities + # where edges are missing (as opposed to zeros, which is what one would + # get by using toarray on the sparse matrix). + weights_sparse = biadjacency_matrix( + G, row_order=U, column_order=V, weight=weight, format="coo" + ) + weights = np.full(weights_sparse.shape, np.inf) + weights[weights_sparse.row, weights_sparse.col] = weights_sparse.data + left_matches = sp.optimize.linear_sum_assignment(weights) + d = {U[u]: V[v] for u, v in zip(*left_matches)} + # d will contain the matching from edges in left to right; we need to + # add the ones from right to left as well. + d.update({v: u for u, v in d.items()}) + return d diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/matrix.py b/lib/python3.10/site-packages/networkx/algorithms/bipartite/matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..bbfa47c70089bdc23dc5bb91f0221918e87d29db --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/bipartite/matrix.py @@ -0,0 +1,168 @@ +""" +==================== +Biadjacency matrices +==================== +""" + +import itertools + +import networkx as nx +from networkx.convert_matrix import _generate_weighted_edges + +__all__ = ["biadjacency_matrix", "from_biadjacency_matrix"] + + +@nx._dispatchable(edge_attrs="weight") +def biadjacency_matrix( + G, row_order, column_order=None, dtype=None, weight="weight", format="csr" +): + r"""Returns the biadjacency matrix of the bipartite graph G. + + Let `G = (U, V, E)` be a bipartite graph with node sets + `U = u_{1},...,u_{r}` and `V = v_{1},...,v_{s}`. The biadjacency + matrix [1]_ is the `r` x `s` matrix `B` in which `b_{i,j} = 1` + if, and only if, `(u_i, v_j) \in E`. If the parameter `weight` is + not `None` and matches the name of an edge attribute, its value is + used instead of 1. + + Parameters + ---------- + G : graph + A NetworkX graph + + row_order : list of nodes + The rows of the matrix are ordered according to the list of nodes. + + column_order : list, optional + The columns of the matrix are ordered according to the list of nodes. + If column_order is None, then the ordering of columns is arbitrary. + + dtype : NumPy data-type, optional + A valid NumPy dtype used to initialize the array. If None, then the + NumPy default is used. + + weight : string or None, optional (default='weight') + The edge data key used to provide each value in the matrix. + If None, then each edge has weight 1. + + format : str in {'bsr', 'csr', 'csc', 'coo', 'lil', 'dia', 'dok'} + The type of the matrix to be returned (default 'csr'). For + some algorithms different implementations of sparse matrices + can perform better. See [2]_ for details. + + Returns + ------- + M : SciPy sparse array + Biadjacency matrix representation of the bipartite graph G. + + Notes + ----- + No attempt is made to check that the input graph is bipartite. + + For directed bipartite graphs only successors are considered as neighbors. + To obtain an adjacency matrix with ones (or weight values) for both + predecessors and successors you have to generate two biadjacency matrices + where the rows of one of them are the columns of the other, and then add + one to the transpose of the other. + + See Also + -------- + adjacency_matrix + from_biadjacency_matrix + + References + ---------- + .. [1] https://en.wikipedia.org/wiki/Adjacency_matrix#Adjacency_matrix_of_a_bipartite_graph + .. [2] Scipy Dev. References, "Sparse Matrices", + https://docs.scipy.org/doc/scipy/reference/sparse.html + """ + import scipy as sp + + nlen = len(row_order) + if nlen == 0: + raise nx.NetworkXError("row_order is empty list") + if len(row_order) != len(set(row_order)): + msg = "Ambiguous ordering: `row_order` contained duplicates." + raise nx.NetworkXError(msg) + if column_order is None: + column_order = list(set(G) - set(row_order)) + mlen = len(column_order) + if len(column_order) != len(set(column_order)): + msg = "Ambiguous ordering: `column_order` contained duplicates." + raise nx.NetworkXError(msg) + + row_index = dict(zip(row_order, itertools.count())) + col_index = dict(zip(column_order, itertools.count())) + + if G.number_of_edges() == 0: + row, col, data = [], [], [] + else: + row, col, data = zip( + *( + (row_index[u], col_index[v], d.get(weight, 1)) + for u, v, d in G.edges(row_order, data=True) + if u in row_index and v in col_index + ) + ) + A = sp.sparse.coo_array((data, (row, col)), shape=(nlen, mlen), dtype=dtype) + try: + return A.asformat(format) + except ValueError as err: + raise nx.NetworkXError(f"Unknown sparse array format: {format}") from err + + +@nx._dispatchable(graphs=None, returns_graph=True) +def from_biadjacency_matrix(A, create_using=None, edge_attribute="weight"): + r"""Creates a new bipartite graph from a biadjacency matrix given as a + SciPy sparse array. + + Parameters + ---------- + A: scipy sparse array + A biadjacency matrix representation of a graph + + create_using: NetworkX graph + Use specified graph for result. The default is Graph() + + edge_attribute: string + Name of edge attribute to store matrix numeric value. The data will + have the same type as the matrix entry (int, float, (real,imag)). + + Notes + ----- + The nodes are labeled with the attribute `bipartite` set to an integer + 0 or 1 representing membership in part 0 or part 1 of the bipartite graph. + + If `create_using` is an instance of :class:`networkx.MultiGraph` or + :class:`networkx.MultiDiGraph` and the entries of `A` are of + type :class:`int`, then this function returns a multigraph (of the same + type as `create_using`) with parallel edges. In this case, `edge_attribute` + will be ignored. + + See Also + -------- + biadjacency_matrix + from_numpy_array + + References + ---------- + [1] https://en.wikipedia.org/wiki/Adjacency_matrix#Adjacency_matrix_of_a_bipartite_graph + """ + G = nx.empty_graph(0, create_using) + n, m = A.shape + # Make sure we get even the isolated nodes of the graph. + G.add_nodes_from(range(n), bipartite=0) + G.add_nodes_from(range(n, n + m), bipartite=1) + # Create an iterable over (u, v, w) triples and for each triple, add an + # edge from u to v with weight w. + triples = ((u, n + v, d) for (u, v, d) in _generate_weighted_edges(A)) + # If the entries in the adjacency matrix are integers and the graph is a + # multigraph, then create parallel edges, each with weight 1, for each + # entry in the adjacency matrix. Otherwise, create one edge for each + # positive entry in the adjacency matrix and set the weight of that edge to + # be the entry in the matrix. + if A.dtype.kind in ("i", "u") and G.is_multigraph(): + chain = itertools.chain.from_iterable + triples = chain(((u, v, 1) for d in range(w)) for (u, v, w) in triples) + G.add_weighted_edges_from(triples, weight=edge_attribute) + return G diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/projection.py b/lib/python3.10/site-packages/networkx/algorithms/bipartite/projection.py new file mode 100644 index 0000000000000000000000000000000000000000..7c2a26cf73ddf39e51fbd20d442abe736acedddd --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/bipartite/projection.py @@ -0,0 +1,526 @@ +"""One-mode (unipartite) projections of bipartite graphs.""" + +import networkx as nx +from networkx.exception import NetworkXAlgorithmError +from networkx.utils import not_implemented_for + +__all__ = [ + "projected_graph", + "weighted_projected_graph", + "collaboration_weighted_projected_graph", + "overlap_weighted_projected_graph", + "generic_weighted_projected_graph", +] + + +@nx._dispatchable( + graphs="B", preserve_node_attrs=True, preserve_graph_attrs=True, returns_graph=True +) +def projected_graph(B, nodes, multigraph=False): + r"""Returns the projection of B onto one of its node sets. + + Returns the graph G that is the projection of the bipartite graph B + onto the specified nodes. They retain their attributes and are connected + in G if they have a common neighbor in B. + + Parameters + ---------- + B : NetworkX graph + The input graph should be bipartite. + + nodes : list or iterable + Nodes to project onto (the "bottom" nodes). + + multigraph: bool (default=False) + If True return a multigraph where the multiple edges represent multiple + shared neighbors. They edge key in the multigraph is assigned to the + label of the neighbor. + + Returns + ------- + Graph : NetworkX graph or multigraph + A graph that is the projection onto the given nodes. + + Examples + -------- + >>> from networkx.algorithms import bipartite + >>> B = nx.path_graph(4) + >>> G = bipartite.projected_graph(B, [1, 3]) + >>> list(G) + [1, 3] + >>> list(G.edges()) + [(1, 3)] + + If nodes `a`, and `b` are connected through both nodes 1 and 2 then + building a multigraph results in two edges in the projection onto + [`a`, `b`]: + + >>> B = nx.Graph() + >>> B.add_edges_from([("a", 1), ("b", 1), ("a", 2), ("b", 2)]) + >>> G = bipartite.projected_graph(B, ["a", "b"], multigraph=True) + >>> print([sorted((u, v)) for u, v in G.edges()]) + [['a', 'b'], ['a', 'b']] + + Notes + ----- + No attempt is made to verify that the input graph B is bipartite. + Returns a simple graph that is the projection of the bipartite graph B + onto the set of nodes given in list nodes. If multigraph=True then + a multigraph is returned with an edge for every shared neighbor. + + Directed graphs are allowed as input. The output will also then + be a directed graph with edges if there is a directed path between + the nodes. + + The graph and node properties are (shallow) copied to the projected graph. + + See :mod:`bipartite documentation ` + for further details on how bipartite graphs are handled in NetworkX. + + See Also + -------- + is_bipartite, + is_bipartite_node_set, + sets, + weighted_projected_graph, + collaboration_weighted_projected_graph, + overlap_weighted_projected_graph, + generic_weighted_projected_graph + """ + if B.is_multigraph(): + raise nx.NetworkXError("not defined for multigraphs") + if B.is_directed(): + directed = True + if multigraph: + G = nx.MultiDiGraph() + else: + G = nx.DiGraph() + else: + directed = False + if multigraph: + G = nx.MultiGraph() + else: + G = nx.Graph() + G.graph.update(B.graph) + G.add_nodes_from((n, B.nodes[n]) for n in nodes) + for u in nodes: + nbrs2 = {v for nbr in B[u] for v in B[nbr] if v != u} + if multigraph: + for n in nbrs2: + if directed: + links = set(B[u]) & set(B.pred[n]) + else: + links = set(B[u]) & set(B[n]) + for l in links: + if not G.has_edge(u, n, l): + G.add_edge(u, n, key=l) + else: + G.add_edges_from((u, n) for n in nbrs2) + return G + + +@not_implemented_for("multigraph") +@nx._dispatchable(graphs="B", returns_graph=True) +def weighted_projected_graph(B, nodes, ratio=False): + r"""Returns a weighted projection of B onto one of its node sets. + + The weighted projected graph is the projection of the bipartite + network B onto the specified nodes with weights representing the + number of shared neighbors or the ratio between actual shared + neighbors and possible shared neighbors if ``ratio is True`` [1]_. + The nodes retain their attributes and are connected in the resulting + graph if they have an edge to a common node in the original graph. + + Parameters + ---------- + B : NetworkX graph + The input graph should be bipartite. + + nodes : list or iterable + Distinct nodes to project onto (the "bottom" nodes). + + ratio: Bool (default=False) + If True, edge weight is the ratio between actual shared neighbors + and maximum possible shared neighbors (i.e., the size of the other + node set). If False, edges weight is the number of shared neighbors. + + Returns + ------- + Graph : NetworkX graph + A graph that is the projection onto the given nodes. + + Examples + -------- + >>> from networkx.algorithms import bipartite + >>> B = nx.path_graph(4) + >>> G = bipartite.weighted_projected_graph(B, [1, 3]) + >>> list(G) + [1, 3] + >>> list(G.edges(data=True)) + [(1, 3, {'weight': 1})] + >>> G = bipartite.weighted_projected_graph(B, [1, 3], ratio=True) + >>> list(G.edges(data=True)) + [(1, 3, {'weight': 0.5})] + + Notes + ----- + No attempt is made to verify that the input graph B is bipartite, or that + the input nodes are distinct. However, if the length of the input nodes is + greater than or equal to the nodes in the graph B, an exception is raised. + If the nodes are not distinct but don't raise this error, the output weights + will be incorrect. + The graph and node properties are (shallow) copied to the projected graph. + + See :mod:`bipartite documentation ` + for further details on how bipartite graphs are handled in NetworkX. + + See Also + -------- + is_bipartite, + is_bipartite_node_set, + sets, + collaboration_weighted_projected_graph, + overlap_weighted_projected_graph, + generic_weighted_projected_graph + projected_graph + + References + ---------- + .. [1] Borgatti, S.P. and Halgin, D. In press. "Analyzing Affiliation + Networks". In Carrington, P. and Scott, J. (eds) The Sage Handbook + of Social Network Analysis. Sage Publications. + """ + if B.is_directed(): + pred = B.pred + G = nx.DiGraph() + else: + pred = B.adj + G = nx.Graph() + G.graph.update(B.graph) + G.add_nodes_from((n, B.nodes[n]) for n in nodes) + n_top = len(B) - len(nodes) + + if n_top < 1: + raise NetworkXAlgorithmError( + f"the size of the nodes to project onto ({len(nodes)}) is >= the graph size ({len(B)}).\n" + "They are either not a valid bipartite partition or contain duplicates" + ) + + for u in nodes: + unbrs = set(B[u]) + nbrs2 = {n for nbr in unbrs for n in B[nbr]} - {u} + for v in nbrs2: + vnbrs = set(pred[v]) + common = unbrs & vnbrs + if not ratio: + weight = len(common) + else: + weight = len(common) / n_top + G.add_edge(u, v, weight=weight) + return G + + +@not_implemented_for("multigraph") +@nx._dispatchable(graphs="B", returns_graph=True) +def collaboration_weighted_projected_graph(B, nodes): + r"""Newman's weighted projection of B onto one of its node sets. + + The collaboration weighted projection is the projection of the + bipartite network B onto the specified nodes with weights assigned + using Newman's collaboration model [1]_: + + .. math:: + + w_{u, v} = \sum_k \frac{\delta_{u}^{k} \delta_{v}^{k}}{d_k - 1} + + where `u` and `v` are nodes from the bottom bipartite node set, + and `k` is a node of the top node set. + The value `d_k` is the degree of node `k` in the bipartite + network and `\delta_{u}^{k}` is 1 if node `u` is + linked to node `k` in the original bipartite graph or 0 otherwise. + + The nodes retain their attributes and are connected in the resulting + graph if have an edge to a common node in the original bipartite + graph. + + Parameters + ---------- + B : NetworkX graph + The input graph should be bipartite. + + nodes : list or iterable + Nodes to project onto (the "bottom" nodes). + + Returns + ------- + Graph : NetworkX graph + A graph that is the projection onto the given nodes. + + Examples + -------- + >>> from networkx.algorithms import bipartite + >>> B = nx.path_graph(5) + >>> B.add_edge(1, 5) + >>> G = bipartite.collaboration_weighted_projected_graph(B, [0, 2, 4, 5]) + >>> list(G) + [0, 2, 4, 5] + >>> for edge in sorted(G.edges(data=True)): + ... print(edge) + (0, 2, {'weight': 0.5}) + (0, 5, {'weight': 0.5}) + (2, 4, {'weight': 1.0}) + (2, 5, {'weight': 0.5}) + + Notes + ----- + No attempt is made to verify that the input graph B is bipartite. + The graph and node properties are (shallow) copied to the projected graph. + + See :mod:`bipartite documentation ` + for further details on how bipartite graphs are handled in NetworkX. + + See Also + -------- + is_bipartite, + is_bipartite_node_set, + sets, + weighted_projected_graph, + overlap_weighted_projected_graph, + generic_weighted_projected_graph, + projected_graph + + References + ---------- + .. [1] Scientific collaboration networks: II. + Shortest paths, weighted networks, and centrality, + M. E. J. Newman, Phys. Rev. E 64, 016132 (2001). + """ + if B.is_directed(): + pred = B.pred + G = nx.DiGraph() + else: + pred = B.adj + G = nx.Graph() + G.graph.update(B.graph) + G.add_nodes_from((n, B.nodes[n]) for n in nodes) + for u in nodes: + unbrs = set(B[u]) + nbrs2 = {n for nbr in unbrs for n in B[nbr] if n != u} + for v in nbrs2: + vnbrs = set(pred[v]) + common_degree = (len(B[n]) for n in unbrs & vnbrs) + weight = sum(1.0 / (deg - 1) for deg in common_degree if deg > 1) + G.add_edge(u, v, weight=weight) + return G + + +@not_implemented_for("multigraph") +@nx._dispatchable(graphs="B", returns_graph=True) +def overlap_weighted_projected_graph(B, nodes, jaccard=True): + r"""Overlap weighted projection of B onto one of its node sets. + + The overlap weighted projection is the projection of the bipartite + network B onto the specified nodes with weights representing + the Jaccard index between the neighborhoods of the two nodes in the + original bipartite network [1]_: + + .. math:: + + w_{v, u} = \frac{|N(u) \cap N(v)|}{|N(u) \cup N(v)|} + + or if the parameter 'jaccard' is False, the fraction of common + neighbors by minimum of both nodes degree in the original + bipartite graph [1]_: + + .. math:: + + w_{v, u} = \frac{|N(u) \cap N(v)|}{min(|N(u)|, |N(v)|)} + + The nodes retain their attributes and are connected in the resulting + graph if have an edge to a common node in the original bipartite graph. + + Parameters + ---------- + B : NetworkX graph + The input graph should be bipartite. + + nodes : list or iterable + Nodes to project onto (the "bottom" nodes). + + jaccard: Bool (default=True) + + Returns + ------- + Graph : NetworkX graph + A graph that is the projection onto the given nodes. + + Examples + -------- + >>> from networkx.algorithms import bipartite + >>> B = nx.path_graph(5) + >>> nodes = [0, 2, 4] + >>> G = bipartite.overlap_weighted_projected_graph(B, nodes) + >>> list(G) + [0, 2, 4] + >>> list(G.edges(data=True)) + [(0, 2, {'weight': 0.5}), (2, 4, {'weight': 0.5})] + >>> G = bipartite.overlap_weighted_projected_graph(B, nodes, jaccard=False) + >>> list(G.edges(data=True)) + [(0, 2, {'weight': 1.0}), (2, 4, {'weight': 1.0})] + + Notes + ----- + No attempt is made to verify that the input graph B is bipartite. + The graph and node properties are (shallow) copied to the projected graph. + + See :mod:`bipartite documentation ` + for further details on how bipartite graphs are handled in NetworkX. + + See Also + -------- + is_bipartite, + is_bipartite_node_set, + sets, + weighted_projected_graph, + collaboration_weighted_projected_graph, + generic_weighted_projected_graph, + projected_graph + + References + ---------- + .. [1] Borgatti, S.P. and Halgin, D. In press. Analyzing Affiliation + Networks. In Carrington, P. and Scott, J. (eds) The Sage Handbook + of Social Network Analysis. Sage Publications. + + """ + if B.is_directed(): + pred = B.pred + G = nx.DiGraph() + else: + pred = B.adj + G = nx.Graph() + G.graph.update(B.graph) + G.add_nodes_from((n, B.nodes[n]) for n in nodes) + for u in nodes: + unbrs = set(B[u]) + nbrs2 = {n for nbr in unbrs for n in B[nbr]} - {u} + for v in nbrs2: + vnbrs = set(pred[v]) + if jaccard: + wt = len(unbrs & vnbrs) / len(unbrs | vnbrs) + else: + wt = len(unbrs & vnbrs) / min(len(unbrs), len(vnbrs)) + G.add_edge(u, v, weight=wt) + return G + + +@not_implemented_for("multigraph") +@nx._dispatchable(graphs="B", preserve_all_attrs=True, returns_graph=True) +def generic_weighted_projected_graph(B, nodes, weight_function=None): + r"""Weighted projection of B with a user-specified weight function. + + The bipartite network B is projected on to the specified nodes + with weights computed by a user-specified function. This function + must accept as a parameter the neighborhood sets of two nodes and + return an integer or a float. + + The nodes retain their attributes and are connected in the resulting graph + if they have an edge to a common node in the original graph. + + Parameters + ---------- + B : NetworkX graph + The input graph should be bipartite. + + nodes : list or iterable + Nodes to project onto (the "bottom" nodes). + + weight_function : function + This function must accept as parameters the same input graph + that this function, and two nodes; and return an integer or a float. + The default function computes the number of shared neighbors. + + Returns + ------- + Graph : NetworkX graph + A graph that is the projection onto the given nodes. + + Examples + -------- + >>> from networkx.algorithms import bipartite + >>> # Define some custom weight functions + >>> def jaccard(G, u, v): + ... unbrs = set(G[u]) + ... vnbrs = set(G[v]) + ... return float(len(unbrs & vnbrs)) / len(unbrs | vnbrs) + >>> def my_weight(G, u, v, weight="weight"): + ... w = 0 + ... for nbr in set(G[u]) & set(G[v]): + ... w += G[u][nbr].get(weight, 1) + G[v][nbr].get(weight, 1) + ... return w + >>> # A complete bipartite graph with 4 nodes and 4 edges + >>> B = nx.complete_bipartite_graph(2, 2) + >>> # Add some arbitrary weight to the edges + >>> for i, (u, v) in enumerate(B.edges()): + ... B.edges[u, v]["weight"] = i + 1 + >>> for edge in B.edges(data=True): + ... print(edge) + (0, 2, {'weight': 1}) + (0, 3, {'weight': 2}) + (1, 2, {'weight': 3}) + (1, 3, {'weight': 4}) + >>> # By default, the weight is the number of shared neighbors + >>> G = bipartite.generic_weighted_projected_graph(B, [0, 1]) + >>> print(list(G.edges(data=True))) + [(0, 1, {'weight': 2})] + >>> # To specify a custom weight function use the weight_function parameter + >>> G = bipartite.generic_weighted_projected_graph( + ... B, [0, 1], weight_function=jaccard + ... ) + >>> print(list(G.edges(data=True))) + [(0, 1, {'weight': 1.0})] + >>> G = bipartite.generic_weighted_projected_graph( + ... B, [0, 1], weight_function=my_weight + ... ) + >>> print(list(G.edges(data=True))) + [(0, 1, {'weight': 10})] + + Notes + ----- + No attempt is made to verify that the input graph B is bipartite. + The graph and node properties are (shallow) copied to the projected graph. + + See :mod:`bipartite documentation ` + for further details on how bipartite graphs are handled in NetworkX. + + See Also + -------- + is_bipartite, + is_bipartite_node_set, + sets, + weighted_projected_graph, + collaboration_weighted_projected_graph, + overlap_weighted_projected_graph, + projected_graph + + """ + if B.is_directed(): + pred = B.pred + G = nx.DiGraph() + else: + pred = B.adj + G = nx.Graph() + if weight_function is None: + + def weight_function(G, u, v): + # Notice that we use set(pred[v]) for handling the directed case. + return len(set(G[u]) & set(pred[v])) + + G.graph.update(B.graph) + G.add_nodes_from((n, B.nodes[n]) for n in nodes) + for u in nodes: + nbrs2 = {n for nbr in set(B[u]) for n in B[nbr]} - {u} + for v in nbrs2: + weight = weight_function(B, u, v) + G.add_edge(u, v, weight=weight) + return G diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/redundancy.py b/lib/python3.10/site-packages/networkx/algorithms/bipartite/redundancy.py new file mode 100644 index 0000000000000000000000000000000000000000..b622b975f0255ee4ac4bc56031c127ac58592abf --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/bipartite/redundancy.py @@ -0,0 +1,112 @@ +"""Node redundancy for bipartite graphs.""" + +from itertools import combinations + +import networkx as nx +from networkx import NetworkXError + +__all__ = ["node_redundancy"] + + +@nx._dispatchable +def node_redundancy(G, nodes=None): + r"""Computes the node redundancy coefficients for the nodes in the bipartite + graph `G`. + + The redundancy coefficient of a node `v` is the fraction of pairs of + neighbors of `v` that are both linked to other nodes. In a one-mode + projection these nodes would be linked together even if `v` were + not there. + + More formally, for any vertex `v`, the *redundancy coefficient of `v`* is + defined by + + .. math:: + + rc(v) = \frac{|\{\{u, w\} \subseteq N(v), + \: \exists v' \neq v,\: (v',u) \in E\: + \mathrm{and}\: (v',w) \in E\}|}{ \frac{|N(v)|(|N(v)|-1)}{2}}, + + where `N(v)` is the set of neighbors of `v` in `G`. + + Parameters + ---------- + G : graph + A bipartite graph + + nodes : list or iterable (optional) + Compute redundancy for these nodes. The default is all nodes in G. + + Returns + ------- + redundancy : dictionary + A dictionary keyed by node with the node redundancy value. + + Examples + -------- + Compute the redundancy coefficient of each node in a graph:: + + >>> from networkx.algorithms import bipartite + >>> G = nx.cycle_graph(4) + >>> rc = bipartite.node_redundancy(G) + >>> rc[0] + 1.0 + + Compute the average redundancy for the graph:: + + >>> from networkx.algorithms import bipartite + >>> G = nx.cycle_graph(4) + >>> rc = bipartite.node_redundancy(G) + >>> sum(rc.values()) / len(G) + 1.0 + + Compute the average redundancy for a set of nodes:: + + >>> from networkx.algorithms import bipartite + >>> G = nx.cycle_graph(4) + >>> rc = bipartite.node_redundancy(G) + >>> nodes = [0, 2] + >>> sum(rc[n] for n in nodes) / len(nodes) + 1.0 + + Raises + ------ + NetworkXError + If any of the nodes in the graph (or in `nodes`, if specified) has + (out-)degree less than two (which would result in division by zero, + according to the definition of the redundancy coefficient). + + References + ---------- + .. [1] Latapy, Matthieu, Clémence Magnien, and Nathalie Del Vecchio (2008). + Basic notions for the analysis of large two-mode networks. + Social Networks 30(1), 31--48. + + """ + if nodes is None: + nodes = G + if any(len(G[v]) < 2 for v in nodes): + raise NetworkXError( + "Cannot compute redundancy coefficient for a node" + " that has fewer than two neighbors." + ) + # TODO This can be trivially parallelized. + return {v: _node_redundancy(G, v) for v in nodes} + + +def _node_redundancy(G, v): + """Returns the redundancy of the node `v` in the bipartite graph `G`. + + If `G` is a graph with `n` nodes, the redundancy of a node is the ratio + of the "overlap" of `v` to the maximum possible overlap of `v` + according to its degree. The overlap of `v` is the number of pairs of + neighbors that have mutual neighbors themselves, other than `v`. + + `v` must have at least two neighbors in `G`. + + """ + n = len(G[v]) + overlap = sum( + 1 for (u, w) in combinations(G[v], 2) if (set(G[u]) & set(G[w])) - {v} + ) + return (2 * overlap) / (n * (n - 1)) diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/spectral.py b/lib/python3.10/site-packages/networkx/algorithms/bipartite/spectral.py new file mode 100644 index 0000000000000000000000000000000000000000..cb9388f6cb61cb3c5da865e22449f4e8f2d1e720 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/bipartite/spectral.py @@ -0,0 +1,69 @@ +""" +Spectral bipartivity measure. +""" + +import networkx as nx + +__all__ = ["spectral_bipartivity"] + + +@nx._dispatchable(edge_attrs="weight") +def spectral_bipartivity(G, nodes=None, weight="weight"): + """Returns the spectral bipartivity. + + Parameters + ---------- + G : NetworkX graph + + nodes : list or container optional(default is all nodes) + Nodes to return value of spectral bipartivity contribution. + + weight : string or None optional (default = 'weight') + Edge data key to use for edge weights. If None, weights set to 1. + + Returns + ------- + sb : float or dict + A single number if the keyword nodes is not specified, or + a dictionary keyed by node with the spectral bipartivity contribution + of that node as the value. + + Examples + -------- + >>> from networkx.algorithms import bipartite + >>> G = nx.path_graph(4) + >>> bipartite.spectral_bipartivity(G) + 1.0 + + Notes + ----- + This implementation uses Numpy (dense) matrices which are not efficient + for storing large sparse graphs. + + See Also + -------- + color + + References + ---------- + .. [1] E. Estrada and J. A. Rodríguez-Velázquez, "Spectral measures of + bipartivity in complex networks", PhysRev E 72, 046105 (2005) + """ + import scipy as sp + + nodelist = list(G) # ordering of nodes in matrix + A = nx.to_numpy_array(G, nodelist, weight=weight) + expA = sp.linalg.expm(A) + expmA = sp.linalg.expm(-A) + coshA = 0.5 * (expA + expmA) + if nodes is None: + # return single number for entire graph + return float(coshA.diagonal().sum() / expA.diagonal().sum()) + else: + # contribution for individual nodes + index = dict(zip(nodelist, range(len(nodelist)))) + sb = {} + for n in nodes: + i = index[n] + sb[n] = coshA.item(i, i) / expA.item(i, i) + return sb diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9afc0389d14595d1d5dcf50e92575ba57daaecc Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_basic.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_basic.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b829887f9d14edc898a05c437340f5bf8c1bce4 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_basic.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_centrality.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_centrality.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4f33f100146b9ae6f57e4fdaf49ea47a1027968 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_centrality.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_cluster.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_cluster.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8e44acea61238d480d5f3ae9112ec677cedf0e6 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_cluster.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_covering.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_covering.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fefbd937be82873ec4355b7095ef794ffc5c17e2 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_covering.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_edgelist.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_edgelist.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..702ac3328bb1ae5bd01dde5254576a008d8caad2 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_edgelist.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_extendability.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_extendability.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53e2e5cd3413aadf3af7e0a639038d701398e515 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_extendability.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_generators.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_generators.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aed979d962e07f8750e2e0ffbcb8ef06bb13a956 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_generators.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_matching.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_matching.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..edd899ceb92b10ac687aa866da69a58083d07174 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_matching.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_matrix.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_matrix.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a3997ec3b6c49c948362943a303c7b18fa60914 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_matrix.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_project.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_project.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d9ea54f8527673bfd1284035d23526542799ade Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_project.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_redundancy.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_redundancy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6ae2c4f91414b3fbe1b7687002024b176c2c4e4 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_redundancy.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_spectral_bipartivity.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_spectral_bipartivity.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6e0fc246053bd2150f6707c239838e902588bd0 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/__pycache__/test_spectral_bipartivity.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_basic.py b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_basic.py new file mode 100644 index 0000000000000000000000000000000000000000..655506b4f74110b57cb37db277e2be50bb0be8f4 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_basic.py @@ -0,0 +1,125 @@ +import pytest + +import networkx as nx +from networkx.algorithms import bipartite + + +class TestBipartiteBasic: + def test_is_bipartite(self): + assert bipartite.is_bipartite(nx.path_graph(4)) + assert bipartite.is_bipartite(nx.DiGraph([(1, 0)])) + assert not bipartite.is_bipartite(nx.complete_graph(3)) + + def test_bipartite_color(self): + G = nx.path_graph(4) + c = bipartite.color(G) + assert c == {0: 1, 1: 0, 2: 1, 3: 0} + + def test_not_bipartite_color(self): + with pytest.raises(nx.NetworkXError): + c = bipartite.color(nx.complete_graph(4)) + + def test_bipartite_directed(self): + G = bipartite.random_graph(10, 10, 0.1, directed=True) + assert bipartite.is_bipartite(G) + + def test_bipartite_sets(self): + G = nx.path_graph(4) + X, Y = bipartite.sets(G) + assert X == {0, 2} + assert Y == {1, 3} + + def test_bipartite_sets_directed(self): + G = nx.path_graph(4) + D = G.to_directed() + X, Y = bipartite.sets(D) + assert X == {0, 2} + assert Y == {1, 3} + + def test_bipartite_sets_given_top_nodes(self): + G = nx.path_graph(4) + top_nodes = [0, 2] + X, Y = bipartite.sets(G, top_nodes) + assert X == {0, 2} + assert Y == {1, 3} + + def test_bipartite_sets_disconnected(self): + with pytest.raises(nx.AmbiguousSolution): + G = nx.path_graph(4) + G.add_edges_from([(5, 6), (6, 7)]) + X, Y = bipartite.sets(G) + + def test_is_bipartite_node_set(self): + G = nx.path_graph(4) + + with pytest.raises(nx.AmbiguousSolution): + bipartite.is_bipartite_node_set(G, [1, 1, 2, 3]) + + assert bipartite.is_bipartite_node_set(G, [0, 2]) + assert bipartite.is_bipartite_node_set(G, [1, 3]) + assert not bipartite.is_bipartite_node_set(G, [1, 2]) + G.add_edge(10, 20) + assert bipartite.is_bipartite_node_set(G, [0, 2, 10]) + assert bipartite.is_bipartite_node_set(G, [0, 2, 20]) + assert bipartite.is_bipartite_node_set(G, [1, 3, 10]) + assert bipartite.is_bipartite_node_set(G, [1, 3, 20]) + + def test_bipartite_density(self): + G = nx.path_graph(5) + X, Y = bipartite.sets(G) + density = len(list(G.edges())) / (len(X) * len(Y)) + assert bipartite.density(G, X) == density + D = nx.DiGraph(G.edges()) + assert bipartite.density(D, X) == density / 2.0 + assert bipartite.density(nx.Graph(), {}) == 0.0 + + def test_bipartite_degrees(self): + G = nx.path_graph(5) + X = {1, 3} + Y = {0, 2, 4} + u, d = bipartite.degrees(G, Y) + assert dict(u) == {1: 2, 3: 2} + assert dict(d) == {0: 1, 2: 2, 4: 1} + + def test_bipartite_weighted_degrees(self): + G = nx.path_graph(5) + G.add_edge(0, 1, weight=0.1, other=0.2) + X = {1, 3} + Y = {0, 2, 4} + u, d = bipartite.degrees(G, Y, weight="weight") + assert dict(u) == {1: 1.1, 3: 2} + assert dict(d) == {0: 0.1, 2: 2, 4: 1} + u, d = bipartite.degrees(G, Y, weight="other") + assert dict(u) == {1: 1.2, 3: 2} + assert dict(d) == {0: 0.2, 2: 2, 4: 1} + + def test_biadjacency_matrix_weight(self): + pytest.importorskip("scipy") + G = nx.path_graph(5) + G.add_edge(0, 1, weight=2, other=4) + X = [1, 3] + Y = [0, 2, 4] + M = bipartite.biadjacency_matrix(G, X, weight="weight") + assert M[0, 0] == 2 + M = bipartite.biadjacency_matrix(G, X, weight="other") + assert M[0, 0] == 4 + + def test_biadjacency_matrix(self): + pytest.importorskip("scipy") + tops = [2, 5, 10] + bots = [5, 10, 15] + for i in range(len(tops)): + G = bipartite.random_graph(tops[i], bots[i], 0.2) + top = [n for n, d in G.nodes(data=True) if d["bipartite"] == 0] + M = bipartite.biadjacency_matrix(G, top) + assert M.shape[0] == tops[i] + assert M.shape[1] == bots[i] + + def test_biadjacency_matrix_order(self): + pytest.importorskip("scipy") + G = nx.path_graph(5) + G.add_edge(0, 1, weight=2) + X = [3, 1] + Y = [4, 2, 0] + M = bipartite.biadjacency_matrix(G, X, Y, weight="weight") + assert M[1, 2] == 2 diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_centrality.py b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_centrality.py new file mode 100644 index 0000000000000000000000000000000000000000..19fb5d117be94c688616a394ea3322e93bfa3e00 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_centrality.py @@ -0,0 +1,192 @@ +import pytest + +import networkx as nx +from networkx.algorithms import bipartite + + +class TestBipartiteCentrality: + @classmethod + def setup_class(cls): + cls.P4 = nx.path_graph(4) + cls.K3 = nx.complete_bipartite_graph(3, 3) + cls.C4 = nx.cycle_graph(4) + cls.davis = nx.davis_southern_women_graph() + cls.top_nodes = [ + n for n, d in cls.davis.nodes(data=True) if d["bipartite"] == 0 + ] + + def test_degree_centrality(self): + d = bipartite.degree_centrality(self.P4, [1, 3]) + answer = {0: 0.5, 1: 1.0, 2: 1.0, 3: 0.5} + assert d == answer + d = bipartite.degree_centrality(self.K3, [0, 1, 2]) + answer = {0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0, 4: 1.0, 5: 1.0} + assert d == answer + d = bipartite.degree_centrality(self.C4, [0, 2]) + answer = {0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0} + assert d == answer + + def test_betweenness_centrality(self): + c = bipartite.betweenness_centrality(self.P4, [1, 3]) + answer = {0: 0.0, 1: 1.0, 2: 1.0, 3: 0.0} + assert c == answer + c = bipartite.betweenness_centrality(self.K3, [0, 1, 2]) + answer = {0: 0.125, 1: 0.125, 2: 0.125, 3: 0.125, 4: 0.125, 5: 0.125} + assert c == answer + c = bipartite.betweenness_centrality(self.C4, [0, 2]) + answer = {0: 0.25, 1: 0.25, 2: 0.25, 3: 0.25} + assert c == answer + + def test_closeness_centrality(self): + c = bipartite.closeness_centrality(self.P4, [1, 3]) + answer = {0: 2.0 / 3, 1: 1.0, 2: 1.0, 3: 2.0 / 3} + assert c == answer + c = bipartite.closeness_centrality(self.K3, [0, 1, 2]) + answer = {0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0, 4: 1.0, 5: 1.0} + assert c == answer + c = bipartite.closeness_centrality(self.C4, [0, 2]) + answer = {0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0} + assert c == answer + G = nx.Graph() + G.add_node(0) + G.add_node(1) + c = bipartite.closeness_centrality(G, [0]) + assert c == {0: 0.0, 1: 0.0} + c = bipartite.closeness_centrality(G, [1]) + assert c == {0: 0.0, 1: 0.0} + + def test_bipartite_closeness_centrality_unconnected(self): + G = nx.complete_bipartite_graph(3, 3) + G.add_edge(6, 7) + c = bipartite.closeness_centrality(G, [0, 2, 4, 6], normalized=False) + answer = { + 0: 10.0 / 7, + 2: 10.0 / 7, + 4: 10.0 / 7, + 6: 10.0, + 1: 10.0 / 7, + 3: 10.0 / 7, + 5: 10.0 / 7, + 7: 10.0, + } + assert c == answer + + def test_davis_degree_centrality(self): + G = self.davis + deg = bipartite.degree_centrality(G, self.top_nodes) + answer = { + "E8": 0.78, + "E9": 0.67, + "E7": 0.56, + "Nora Fayette": 0.57, + "Evelyn Jefferson": 0.57, + "Theresa Anderson": 0.57, + "E6": 0.44, + "Sylvia Avondale": 0.50, + "Laura Mandeville": 0.50, + "Brenda Rogers": 0.50, + "Katherina Rogers": 0.43, + "E5": 0.44, + "Helen Lloyd": 0.36, + "E3": 0.33, + "Ruth DeSand": 0.29, + "Verne Sanderson": 0.29, + "E12": 0.33, + "Myra Liddel": 0.29, + "E11": 0.22, + "Eleanor Nye": 0.29, + "Frances Anderson": 0.29, + "Pearl Oglethorpe": 0.21, + "E4": 0.22, + "Charlotte McDowd": 0.29, + "E10": 0.28, + "Olivia Carleton": 0.14, + "Flora Price": 0.14, + "E2": 0.17, + "E1": 0.17, + "Dorothy Murchison": 0.14, + "E13": 0.17, + "E14": 0.17, + } + for node, value in answer.items(): + assert value == pytest.approx(deg[node], abs=1e-2) + + def test_davis_betweenness_centrality(self): + G = self.davis + bet = bipartite.betweenness_centrality(G, self.top_nodes) + answer = { + "E8": 0.24, + "E9": 0.23, + "E7": 0.13, + "Nora Fayette": 0.11, + "Evelyn Jefferson": 0.10, + "Theresa Anderson": 0.09, + "E6": 0.07, + "Sylvia Avondale": 0.07, + "Laura Mandeville": 0.05, + "Brenda Rogers": 0.05, + "Katherina Rogers": 0.05, + "E5": 0.04, + "Helen Lloyd": 0.04, + "E3": 0.02, + "Ruth DeSand": 0.02, + "Verne Sanderson": 0.02, + "E12": 0.02, + "Myra Liddel": 0.02, + "E11": 0.02, + "Eleanor Nye": 0.01, + "Frances Anderson": 0.01, + "Pearl Oglethorpe": 0.01, + "E4": 0.01, + "Charlotte McDowd": 0.01, + "E10": 0.01, + "Olivia Carleton": 0.01, + "Flora Price": 0.01, + "E2": 0.00, + "E1": 0.00, + "Dorothy Murchison": 0.00, + "E13": 0.00, + "E14": 0.00, + } + for node, value in answer.items(): + assert value == pytest.approx(bet[node], abs=1e-2) + + def test_davis_closeness_centrality(self): + G = self.davis + clos = bipartite.closeness_centrality(G, self.top_nodes) + answer = { + "E8": 0.85, + "E9": 0.79, + "E7": 0.73, + "Nora Fayette": 0.80, + "Evelyn Jefferson": 0.80, + "Theresa Anderson": 0.80, + "E6": 0.69, + "Sylvia Avondale": 0.77, + "Laura Mandeville": 0.73, + "Brenda Rogers": 0.73, + "Katherina Rogers": 0.73, + "E5": 0.59, + "Helen Lloyd": 0.73, + "E3": 0.56, + "Ruth DeSand": 0.71, + "Verne Sanderson": 0.71, + "E12": 0.56, + "Myra Liddel": 0.69, + "E11": 0.54, + "Eleanor Nye": 0.67, + "Frances Anderson": 0.67, + "Pearl Oglethorpe": 0.67, + "E4": 0.54, + "Charlotte McDowd": 0.60, + "E10": 0.55, + "Olivia Carleton": 0.59, + "Flora Price": 0.59, + "E2": 0.52, + "E1": 0.52, + "Dorothy Murchison": 0.65, + "E13": 0.52, + "E14": 0.52, + } + for node, value in answer.items(): + assert value == pytest.approx(clos[node], abs=1e-2) diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_cluster.py b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_cluster.py new file mode 100644 index 0000000000000000000000000000000000000000..72e2dbadd64e9e768d1541b2ce742c2b62278929 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_cluster.py @@ -0,0 +1,84 @@ +import pytest + +import networkx as nx +from networkx.algorithms import bipartite +from networkx.algorithms.bipartite.cluster import cc_dot, cc_max, cc_min + + +def test_pairwise_bipartite_cc_functions(): + # Test functions for different kinds of bipartite clustering coefficients + # between pairs of nodes using 3 example graphs from figure 5 p. 40 + # Latapy et al (2008) + G1 = nx.Graph([(0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (1, 5), (1, 6), (1, 7)]) + G2 = nx.Graph([(0, 2), (0, 3), (0, 4), (1, 3), (1, 4), (1, 5)]) + G3 = nx.Graph( + [(0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (1, 5), (1, 6), (1, 7), (1, 8), (1, 9)] + ) + result = { + 0: [1 / 3.0, 2 / 3.0, 2 / 5.0], + 1: [1 / 2.0, 2 / 3.0, 2 / 3.0], + 2: [2 / 8.0, 2 / 5.0, 2 / 5.0], + } + for i, G in enumerate([G1, G2, G3]): + assert bipartite.is_bipartite(G) + assert cc_dot(set(G[0]), set(G[1])) == result[i][0] + assert cc_min(set(G[0]), set(G[1])) == result[i][1] + assert cc_max(set(G[0]), set(G[1])) == result[i][2] + + +def test_star_graph(): + G = nx.star_graph(3) + # all modes are the same + answer = {0: 0, 1: 1, 2: 1, 3: 1} + assert bipartite.clustering(G, mode="dot") == answer + assert bipartite.clustering(G, mode="min") == answer + assert bipartite.clustering(G, mode="max") == answer + + +def test_not_bipartite(): + with pytest.raises(nx.NetworkXError): + bipartite.clustering(nx.complete_graph(4)) + + +def test_bad_mode(): + with pytest.raises(nx.NetworkXError): + bipartite.clustering(nx.path_graph(4), mode="foo") + + +def test_path_graph(): + G = nx.path_graph(4) + answer = {0: 0.5, 1: 0.5, 2: 0.5, 3: 0.5} + assert bipartite.clustering(G, mode="dot") == answer + assert bipartite.clustering(G, mode="max") == answer + answer = {0: 1, 1: 1, 2: 1, 3: 1} + assert bipartite.clustering(G, mode="min") == answer + + +def test_average_path_graph(): + G = nx.path_graph(4) + assert bipartite.average_clustering(G, mode="dot") == 0.5 + assert bipartite.average_clustering(G, mode="max") == 0.5 + assert bipartite.average_clustering(G, mode="min") == 1 + + +def test_ra_clustering_davis(): + G = nx.davis_southern_women_graph() + cc4 = round(bipartite.robins_alexander_clustering(G), 3) + assert cc4 == 0.468 + + +def test_ra_clustering_square(): + G = nx.path_graph(4) + G.add_edge(0, 3) + assert bipartite.robins_alexander_clustering(G) == 1.0 + + +def test_ra_clustering_zero(): + G = nx.Graph() + assert bipartite.robins_alexander_clustering(G) == 0 + G.add_nodes_from(range(4)) + assert bipartite.robins_alexander_clustering(G) == 0 + G.add_edges_from([(0, 1), (2, 3), (3, 4)]) + assert bipartite.robins_alexander_clustering(G) == 0 + G.add_edge(1, 2) + assert bipartite.robins_alexander_clustering(G) == 0 diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_covering.py b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_covering.py new file mode 100644 index 0000000000000000000000000000000000000000..9507e13492acbe505aa3394a24dbc41c095a037c --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_covering.py @@ -0,0 +1,33 @@ +import networkx as nx +from networkx.algorithms import bipartite + + +class TestMinEdgeCover: + """Tests for :func:`networkx.algorithms.bipartite.min_edge_cover`""" + + def test_empty_graph(self): + G = nx.Graph() + assert bipartite.min_edge_cover(G) == set() + + def test_graph_single_edge(self): + G = nx.Graph() + G.add_edge(0, 1) + assert bipartite.min_edge_cover(G) == {(0, 1), (1, 0)} + + def test_bipartite_default(self): + G = nx.Graph() + G.add_nodes_from([1, 2, 3, 4], bipartite=0) + G.add_nodes_from(["a", "b", "c"], bipartite=1) + G.add_edges_from([(1, "a"), (1, "b"), (2, "b"), (2, "c"), (3, "c"), (4, "a")]) + min_cover = bipartite.min_edge_cover(G) + assert nx.is_edge_cover(G, min_cover) + assert len(min_cover) == 8 + + def test_bipartite_explicit(self): + G = nx.Graph() + G.add_nodes_from([1, 2, 3, 4], bipartite=0) + G.add_nodes_from(["a", "b", "c"], bipartite=1) + G.add_edges_from([(1, "a"), (1, "b"), (2, "b"), (2, "c"), (3, "c"), (4, "a")]) + min_cover = bipartite.min_edge_cover(G, bipartite.eppstein_matching) + assert nx.is_edge_cover(G, min_cover) + assert len(min_cover) == 8 diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_edgelist.py b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_edgelist.py new file mode 100644 index 0000000000000000000000000000000000000000..66be8a2f5b3e1f9486594c63015295ad6a270efa --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_edgelist.py @@ -0,0 +1,240 @@ +""" +Unit tests for bipartite edgelists. +""" + +import io + +import pytest + +import networkx as nx +from networkx.algorithms import bipartite +from networkx.utils import edges_equal, graphs_equal, nodes_equal + + +class TestEdgelist: + @classmethod + def setup_class(cls): + cls.G = nx.Graph(name="test") + e = [("a", "b"), ("b", "c"), ("c", "d"), ("d", "e"), ("e", "f"), ("a", "f")] + cls.G.add_edges_from(e) + cls.G.add_nodes_from(["a", "c", "e"], bipartite=0) + cls.G.add_nodes_from(["b", "d", "f"], bipartite=1) + cls.G.add_node("g", bipartite=0) + cls.DG = nx.DiGraph(cls.G) + cls.MG = nx.MultiGraph() + cls.MG.add_edges_from([(1, 2), (1, 2), (1, 2)]) + cls.MG.add_node(1, bipartite=0) + cls.MG.add_node(2, bipartite=1) + + def test_read_edgelist_1(self): + s = b"""\ +# comment line +1 2 +# comment line +2 3 +""" + bytesIO = io.BytesIO(s) + G = bipartite.read_edgelist(bytesIO, nodetype=int) + assert edges_equal(G.edges(), [(1, 2), (2, 3)]) + + def test_read_edgelist_3(self): + s = b"""\ +# comment line +1 2 {'weight':2.0} +# comment line +2 3 {'weight':3.0} +""" + bytesIO = io.BytesIO(s) + G = bipartite.read_edgelist(bytesIO, nodetype=int, data=False) + assert edges_equal(G.edges(), [(1, 2), (2, 3)]) + + bytesIO = io.BytesIO(s) + G = bipartite.read_edgelist(bytesIO, nodetype=int, data=True) + assert edges_equal( + G.edges(data=True), [(1, 2, {"weight": 2.0}), (2, 3, {"weight": 3.0})] + ) + + def test_write_edgelist_1(self): + fh = io.BytesIO() + G = nx.Graph() + G.add_edges_from([(1, 2), (2, 3)]) + G.add_node(1, bipartite=0) + G.add_node(2, bipartite=1) + G.add_node(3, bipartite=0) + bipartite.write_edgelist(G, fh, data=False) + fh.seek(0) + assert fh.read() == b"1 2\n3 2\n" + + def test_write_edgelist_2(self): + fh = io.BytesIO() + G = nx.Graph() + G.add_edges_from([(1, 2), (2, 3)]) + G.add_node(1, bipartite=0) + G.add_node(2, bipartite=1) + G.add_node(3, bipartite=0) + bipartite.write_edgelist(G, fh, data=True) + fh.seek(0) + assert fh.read() == b"1 2 {}\n3 2 {}\n" + + def test_write_edgelist_3(self): + fh = io.BytesIO() + G = nx.Graph() + G.add_edge(1, 2, weight=2.0) + G.add_edge(2, 3, weight=3.0) + G.add_node(1, bipartite=0) + G.add_node(2, bipartite=1) + G.add_node(3, bipartite=0) + bipartite.write_edgelist(G, fh, data=True) + fh.seek(0) + assert fh.read() == b"1 2 {'weight': 2.0}\n3 2 {'weight': 3.0}\n" + + def test_write_edgelist_4(self): + fh = io.BytesIO() + G = nx.Graph() + G.add_edge(1, 2, weight=2.0) + G.add_edge(2, 3, weight=3.0) + G.add_node(1, bipartite=0) + G.add_node(2, bipartite=1) + G.add_node(3, bipartite=0) + bipartite.write_edgelist(G, fh, data=[("weight")]) + fh.seek(0) + assert fh.read() == b"1 2 2.0\n3 2 3.0\n" + + def test_unicode(self, tmp_path): + G = nx.Graph() + name1 = chr(2344) + chr(123) + chr(6543) + name2 = chr(5543) + chr(1543) + chr(324) + G.add_edge(name1, "Radiohead", **{name2: 3}) + G.add_node(name1, bipartite=0) + G.add_node("Radiohead", bipartite=1) + + fname = tmp_path / "edgelist.txt" + bipartite.write_edgelist(G, fname) + H = bipartite.read_edgelist(fname) + assert graphs_equal(G, H) + + def test_latin1_issue(self, tmp_path): + G = nx.Graph() + name1 = chr(2344) + chr(123) + chr(6543) + name2 = chr(5543) + chr(1543) + chr(324) + G.add_edge(name1, "Radiohead", **{name2: 3}) + G.add_node(name1, bipartite=0) + G.add_node("Radiohead", bipartite=1) + + fname = tmp_path / "edgelist.txt" + with pytest.raises(UnicodeEncodeError): + bipartite.write_edgelist(G, fname, encoding="latin-1") + + def test_latin1(self, tmp_path): + G = nx.Graph() + name1 = "Bj" + chr(246) + "rk" + name2 = chr(220) + "ber" + G.add_edge(name1, "Radiohead", **{name2: 3}) + G.add_node(name1, bipartite=0) + G.add_node("Radiohead", bipartite=1) + + fname = tmp_path / "edgelist.txt" + bipartite.write_edgelist(G, fname, encoding="latin-1") + H = bipartite.read_edgelist(fname, encoding="latin-1") + assert graphs_equal(G, H) + + def test_edgelist_graph(self, tmp_path): + G = self.G + fname = tmp_path / "edgelist.txt" + bipartite.write_edgelist(G, fname) + H = bipartite.read_edgelist(fname) + H2 = bipartite.read_edgelist(fname) + assert H is not H2 # they should be different graphs + G.remove_node("g") # isolated nodes are not written in edgelist + assert nodes_equal(list(H), list(G)) + assert edges_equal(list(H.edges()), list(G.edges())) + + def test_edgelist_integers(self, tmp_path): + G = nx.convert_node_labels_to_integers(self.G) + fname = tmp_path / "edgelist.txt" + bipartite.write_edgelist(G, fname) + H = bipartite.read_edgelist(fname, nodetype=int) + # isolated nodes are not written in edgelist + G.remove_nodes_from(list(nx.isolates(G))) + assert nodes_equal(list(H), list(G)) + assert edges_equal(list(H.edges()), list(G.edges())) + + def test_edgelist_multigraph(self, tmp_path): + G = self.MG + fname = tmp_path / "edgelist.txt" + bipartite.write_edgelist(G, fname) + H = bipartite.read_edgelist(fname, nodetype=int, create_using=nx.MultiGraph()) + H2 = bipartite.read_edgelist(fname, nodetype=int, create_using=nx.MultiGraph()) + assert H is not H2 # they should be different graphs + assert nodes_equal(list(H), list(G)) + assert edges_equal(list(H.edges()), list(G.edges())) + + def test_empty_digraph(self): + with pytest.raises(nx.NetworkXNotImplemented): + bytesIO = io.BytesIO() + bipartite.write_edgelist(nx.DiGraph(), bytesIO) + + def test_raise_attribute(self): + with pytest.raises(AttributeError): + G = nx.path_graph(4) + bytesIO = io.BytesIO() + bipartite.write_edgelist(G, bytesIO) + + def test_parse_edgelist(self): + """Tests for conditions specific to + parse_edge_list method""" + + # ignore strings of length less than 2 + lines = ["1 2", "2 3", "3 1", "4", " "] + G = bipartite.parse_edgelist(lines, nodetype=int) + assert list(G.nodes) == [1, 2, 3] + + # Exception raised when node is not convertible + # to specified data type + with pytest.raises(TypeError, match=".*Failed to convert nodes"): + lines = ["a b", "b c", "c a"] + G = bipartite.parse_edgelist(lines, nodetype=int) + + # Exception raised when format of data is not + # convertible to dictionary object + with pytest.raises(TypeError, match=".*Failed to convert edge data"): + lines = ["1 2 3", "2 3 4", "3 1 2"] + G = bipartite.parse_edgelist(lines, nodetype=int) + + # Exception raised when edge data and data + # keys are not of same length + with pytest.raises(IndexError): + lines = ["1 2 3 4", "2 3 4"] + G = bipartite.parse_edgelist( + lines, nodetype=int, data=[("weight", int), ("key", int)] + ) + + # Exception raised when edge data is not + # convertible to specified data type + with pytest.raises(TypeError, match=".*Failed to convert key data"): + lines = ["1 2 3 a", "2 3 4 b"] + G = bipartite.parse_edgelist( + lines, nodetype=int, data=[("weight", int), ("key", int)] + ) + + +def test_bipartite_edgelist_consistent_strip_handling(): + """See gh-7462 + + Input when printed looks like: + + A B interaction 2 + B C interaction 4 + C A interaction + + Note the trailing \\t in the last line, which indicates the existence of + an empty data field. + """ + lines = io.StringIO( + "A\tB\tinteraction\t2\nB\tC\tinteraction\t4\nC\tA\tinteraction\t" + ) + descr = [("type", str), ("weight", str)] + # Should not raise + G = nx.bipartite.parse_edgelist(lines, delimiter="\t", data=descr) + expected = [("A", "B", "2"), ("A", "C", ""), ("B", "C", "4")] + assert sorted(G.edges(data="weight")) == expected diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_extendability.py b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_extendability.py new file mode 100644 index 0000000000000000000000000000000000000000..17b7124341bd6b0e82b5f01b8e5c6f8d1235efb9 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_extendability.py @@ -0,0 +1,334 @@ +import pytest + +import networkx as nx + + +def test_selfloops_raises(): + G = nx.ladder_graph(3) + G.add_edge(0, 0) + with pytest.raises(nx.NetworkXError, match=".*not bipartite"): + nx.bipartite.maximal_extendability(G) + + +def test_disconnected_raises(): + G = nx.ladder_graph(3) + G.add_node("a") + with pytest.raises(nx.NetworkXError, match=".*not connected"): + nx.bipartite.maximal_extendability(G) + + +def test_not_bipartite_raises(): + G = nx.complete_graph(5) + with pytest.raises(nx.NetworkXError, match=".*not bipartite"): + nx.bipartite.maximal_extendability(G) + + +def test_no_perfect_matching_raises(): + G = nx.Graph([(0, 1), (0, 2)]) + with pytest.raises(nx.NetworkXError, match=".*not contain a perfect matching"): + nx.bipartite.maximal_extendability(G) + + +def test_residual_graph_not_strongly_connected_raises(): + G = nx.Graph([(1, 2), (2, 3), (3, 4)]) + with pytest.raises( + nx.NetworkXError, match="The residual graph of G is not strongly connected" + ): + nx.bipartite.maximal_extendability(G) + + +def test_ladder_graph_is_1(): + G = nx.ladder_graph(3) + assert nx.bipartite.maximal_extendability(G) == 1 + + +def test_cubical_graph_is_2(): + G = nx.cubical_graph() + assert nx.bipartite.maximal_extendability(G) == 2 + + +def test_k_is_3(): + G = nx.Graph( + [ + (1, 6), + (1, 7), + (1, 8), + (1, 9), + (2, 6), + (2, 7), + (2, 8), + (2, 10), + (3, 6), + (3, 8), + (3, 9), + (3, 10), + (4, 7), + (4, 8), + (4, 9), + (4, 10), + (5, 6), + (5, 7), + (5, 9), + (5, 10), + ] + ) + assert nx.bipartite.maximal_extendability(G) == 3 + + +def test_k_is_4(): + G = nx.Graph( + [ + (8, 1), + (8, 2), + (8, 3), + (8, 4), + (8, 5), + (9, 1), + (9, 2), + (9, 3), + (9, 4), + (9, 7), + (10, 1), + (10, 2), + (10, 3), + (10, 4), + (10, 6), + (11, 1), + (11, 2), + (11, 5), + (11, 6), + (11, 7), + (12, 1), + (12, 3), + (12, 5), + (12, 6), + (12, 7), + (13, 2), + (13, 4), + (13, 5), + (13, 6), + (13, 7), + (14, 3), + (14, 4), + (14, 5), + (14, 6), + (14, 7), + ] + ) + assert nx.bipartite.maximal_extendability(G) == 4 + + +def test_k_is_5(): + G = nx.Graph( + [ + (8, 1), + (8, 2), + (8, 3), + (8, 4), + (8, 5), + (8, 6), + (9, 1), + (9, 2), + (9, 3), + (9, 4), + (9, 5), + (9, 7), + (10, 1), + (10, 2), + (10, 3), + (10, 4), + (10, 6), + (10, 7), + (11, 1), + (11, 2), + (11, 3), + (11, 5), + (11, 6), + (11, 7), + (12, 1), + (12, 2), + (12, 4), + (12, 5), + (12, 6), + (12, 7), + (13, 1), + (13, 3), + (13, 4), + (13, 5), + (13, 6), + (13, 7), + (14, 2), + (14, 3), + (14, 4), + (14, 5), + (14, 6), + (14, 7), + ] + ) + assert nx.bipartite.maximal_extendability(G) == 5 + + +def test_k_is_6(): + G = nx.Graph( + [ + (9, 1), + (9, 2), + (9, 3), + (9, 4), + (9, 5), + (9, 6), + (9, 7), + (10, 1), + (10, 2), + (10, 3), + (10, 4), + (10, 5), + (10, 6), + (10, 8), + (11, 1), + (11, 2), + (11, 3), + (11, 4), + (11, 5), + (11, 7), + (11, 8), + (12, 1), + (12, 2), + (12, 3), + (12, 4), + (12, 6), + (12, 7), + (12, 8), + (13, 1), + (13, 2), + (13, 3), + (13, 5), + (13, 6), + (13, 7), + (13, 8), + (14, 1), + (14, 2), + (14, 4), + (14, 5), + (14, 6), + (14, 7), + (14, 8), + (15, 1), + (15, 3), + (15, 4), + (15, 5), + (15, 6), + (15, 7), + (15, 8), + (16, 2), + (16, 3), + (16, 4), + (16, 5), + (16, 6), + (16, 7), + (16, 8), + ] + ) + assert nx.bipartite.maximal_extendability(G) == 6 + + +def test_k_is_7(): + G = nx.Graph( + [ + (1, 11), + (1, 12), + (1, 13), + (1, 14), + (1, 15), + (1, 16), + (1, 17), + (1, 18), + (2, 11), + (2, 12), + (2, 13), + (2, 14), + (2, 15), + (2, 16), + (2, 17), + (2, 19), + (3, 11), + (3, 12), + (3, 13), + (3, 14), + (3, 15), + (3, 16), + (3, 17), + (3, 20), + (4, 11), + (4, 12), + (4, 13), + (4, 14), + (4, 15), + (4, 16), + (4, 17), + (4, 18), + (4, 19), + (4, 20), + (5, 11), + (5, 12), + (5, 13), + (5, 14), + (5, 15), + (5, 16), + (5, 17), + (5, 18), + (5, 19), + (5, 20), + (6, 11), + (6, 12), + (6, 13), + (6, 14), + (6, 15), + (6, 16), + (6, 17), + (6, 18), + (6, 19), + (6, 20), + (7, 11), + (7, 12), + (7, 13), + (7, 14), + (7, 15), + (7, 16), + (7, 17), + (7, 18), + (7, 19), + (7, 20), + (8, 11), + (8, 12), + (8, 13), + (8, 14), + (8, 15), + (8, 16), + (8, 17), + (8, 18), + (8, 19), + (8, 20), + (9, 11), + (9, 12), + (9, 13), + (9, 14), + (9, 15), + (9, 16), + (9, 17), + (9, 18), + (9, 19), + (9, 20), + (10, 11), + (10, 12), + (10, 13), + (10, 14), + (10, 15), + (10, 16), + (10, 17), + (10, 18), + (10, 19), + (10, 20), + ] + ) + assert nx.bipartite.maximal_extendability(G) == 7 diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_generators.py b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_generators.py new file mode 100644 index 0000000000000000000000000000000000000000..8b1e7793ffcf3b43eb7ff19e8662d60df4bc59df --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_generators.py @@ -0,0 +1,409 @@ +import numbers + +import pytest + +import networkx as nx + +from ..generators import ( + alternating_havel_hakimi_graph, + complete_bipartite_graph, + configuration_model, + gnmk_random_graph, + havel_hakimi_graph, + preferential_attachment_graph, + random_graph, + reverse_havel_hakimi_graph, +) + +""" +Generators - Bipartite +---------------------- +""" + + +class TestGeneratorsBipartite: + def test_complete_bipartite_graph(self): + G = complete_bipartite_graph(0, 0) + assert nx.is_isomorphic(G, nx.null_graph()) + + for i in [1, 5]: + G = complete_bipartite_graph(i, 0) + assert nx.is_isomorphic(G, nx.empty_graph(i)) + G = complete_bipartite_graph(0, i) + assert nx.is_isomorphic(G, nx.empty_graph(i)) + + G = complete_bipartite_graph(2, 2) + assert nx.is_isomorphic(G, nx.cycle_graph(4)) + + G = complete_bipartite_graph(1, 5) + assert nx.is_isomorphic(G, nx.star_graph(5)) + + G = complete_bipartite_graph(5, 1) + assert nx.is_isomorphic(G, nx.star_graph(5)) + + # complete_bipartite_graph(m1,m2) is a connected graph with + # m1+m2 nodes and m1*m2 edges + for m1, m2 in [(5, 11), (7, 3)]: + G = complete_bipartite_graph(m1, m2) + assert nx.number_of_nodes(G) == m1 + m2 + assert nx.number_of_edges(G) == m1 * m2 + + with pytest.raises(nx.NetworkXError): + complete_bipartite_graph(7, 3, create_using=nx.DiGraph) + with pytest.raises(nx.NetworkXError): + complete_bipartite_graph(7, 3, create_using=nx.MultiDiGraph) + + mG = complete_bipartite_graph(7, 3, create_using=nx.MultiGraph) + assert mG.is_multigraph() + assert sorted(mG.edges()) == sorted(G.edges()) + + mG = complete_bipartite_graph(7, 3, create_using=nx.MultiGraph) + assert mG.is_multigraph() + assert sorted(mG.edges()) == sorted(G.edges()) + + mG = complete_bipartite_graph(7, 3) # default to Graph + assert sorted(mG.edges()) == sorted(G.edges()) + assert not mG.is_multigraph() + assert not mG.is_directed() + + # specify nodes rather than number of nodes + for n1, n2 in [([1, 2], "ab"), (3, 2), (3, "ab"), ("ab", 3)]: + G = complete_bipartite_graph(n1, n2) + if isinstance(n1, numbers.Integral): + if isinstance(n2, numbers.Integral): + n2 = range(n1, n1 + n2) + n1 = range(n1) + elif isinstance(n2, numbers.Integral): + n2 = range(n2) + edges = {(u, v) for u in n1 for v in n2} + assert edges == set(G.edges) + assert G.size() == len(edges) + + # raise when node sets are not distinct + for n1, n2 in [([1, 2], 3), (3, [1, 2]), ("abc", "bcd")]: + pytest.raises(nx.NetworkXError, complete_bipartite_graph, n1, n2) + + def test_configuration_model(self): + aseq = [] + bseq = [] + G = configuration_model(aseq, bseq) + assert len(G) == 0 + + aseq = [0, 0] + bseq = [0, 0] + G = configuration_model(aseq, bseq) + assert len(G) == 4 + assert G.number_of_edges() == 0 + + aseq = [3, 3, 3, 3] + bseq = [2, 2, 2, 2, 2] + pytest.raises(nx.NetworkXError, configuration_model, aseq, bseq) + + aseq = [3, 3, 3, 3] + bseq = [2, 2, 2, 2, 2, 2] + G = configuration_model(aseq, bseq) + assert sorted(d for n, d in G.degree()) == [2, 2, 2, 2, 2, 2, 3, 3, 3, 3] + + aseq = [2, 2, 2, 2, 2, 2] + bseq = [3, 3, 3, 3] + G = configuration_model(aseq, bseq) + assert sorted(d for n, d in G.degree()) == [2, 2, 2, 2, 2, 2, 3, 3, 3, 3] + + aseq = [2, 2, 2, 1, 1, 1] + bseq = [3, 3, 3] + G = configuration_model(aseq, bseq) + assert G.is_multigraph() + assert not G.is_directed() + assert sorted(d for n, d in G.degree()) == [1, 1, 1, 2, 2, 2, 3, 3, 3] + + GU = nx.projected_graph(nx.Graph(G), range(len(aseq))) + assert GU.number_of_nodes() == 6 + + GD = nx.projected_graph(nx.Graph(G), range(len(aseq), len(aseq) + len(bseq))) + assert GD.number_of_nodes() == 3 + + G = reverse_havel_hakimi_graph(aseq, bseq, create_using=nx.Graph) + assert not G.is_multigraph() + assert not G.is_directed() + + pytest.raises( + nx.NetworkXError, configuration_model, aseq, bseq, create_using=nx.DiGraph() + ) + pytest.raises( + nx.NetworkXError, configuration_model, aseq, bseq, create_using=nx.DiGraph + ) + pytest.raises( + nx.NetworkXError, + configuration_model, + aseq, + bseq, + create_using=nx.MultiDiGraph, + ) + + def test_havel_hakimi_graph(self): + aseq = [] + bseq = [] + G = havel_hakimi_graph(aseq, bseq) + assert len(G) == 0 + + aseq = [0, 0] + bseq = [0, 0] + G = havel_hakimi_graph(aseq, bseq) + assert len(G) == 4 + assert G.number_of_edges() == 0 + + aseq = [3, 3, 3, 3] + bseq = [2, 2, 2, 2, 2] + pytest.raises(nx.NetworkXError, havel_hakimi_graph, aseq, bseq) + + bseq = [2, 2, 2, 2, 2, 2] + G = havel_hakimi_graph(aseq, bseq) + assert sorted(d for n, d in G.degree()) == [2, 2, 2, 2, 2, 2, 3, 3, 3, 3] + + aseq = [2, 2, 2, 2, 2, 2] + bseq = [3, 3, 3, 3] + G = havel_hakimi_graph(aseq, bseq) + assert G.is_multigraph() + assert not G.is_directed() + assert sorted(d for n, d in G.degree()) == [2, 2, 2, 2, 2, 2, 3, 3, 3, 3] + + GU = nx.projected_graph(nx.Graph(G), range(len(aseq))) + assert GU.number_of_nodes() == 6 + + GD = nx.projected_graph(nx.Graph(G), range(len(aseq), len(aseq) + len(bseq))) + assert GD.number_of_nodes() == 4 + + G = reverse_havel_hakimi_graph(aseq, bseq, create_using=nx.Graph) + assert not G.is_multigraph() + assert not G.is_directed() + + pytest.raises( + nx.NetworkXError, havel_hakimi_graph, aseq, bseq, create_using=nx.DiGraph + ) + pytest.raises( + nx.NetworkXError, havel_hakimi_graph, aseq, bseq, create_using=nx.DiGraph + ) + pytest.raises( + nx.NetworkXError, + havel_hakimi_graph, + aseq, + bseq, + create_using=nx.MultiDiGraph, + ) + + def test_reverse_havel_hakimi_graph(self): + aseq = [] + bseq = [] + G = reverse_havel_hakimi_graph(aseq, bseq) + assert len(G) == 0 + + aseq = [0, 0] + bseq = [0, 0] + G = reverse_havel_hakimi_graph(aseq, bseq) + assert len(G) == 4 + assert G.number_of_edges() == 0 + + aseq = [3, 3, 3, 3] + bseq = [2, 2, 2, 2, 2] + pytest.raises(nx.NetworkXError, reverse_havel_hakimi_graph, aseq, bseq) + + bseq = [2, 2, 2, 2, 2, 2] + G = reverse_havel_hakimi_graph(aseq, bseq) + assert sorted(d for n, d in G.degree()) == [2, 2, 2, 2, 2, 2, 3, 3, 3, 3] + + aseq = [2, 2, 2, 2, 2, 2] + bseq = [3, 3, 3, 3] + G = reverse_havel_hakimi_graph(aseq, bseq) + assert sorted(d for n, d in G.degree()) == [2, 2, 2, 2, 2, 2, 3, 3, 3, 3] + + aseq = [2, 2, 2, 1, 1, 1] + bseq = [3, 3, 3] + G = reverse_havel_hakimi_graph(aseq, bseq) + assert G.is_multigraph() + assert not G.is_directed() + assert sorted(d for n, d in G.degree()) == [1, 1, 1, 2, 2, 2, 3, 3, 3] + + GU = nx.projected_graph(nx.Graph(G), range(len(aseq))) + assert GU.number_of_nodes() == 6 + + GD = nx.projected_graph(nx.Graph(G), range(len(aseq), len(aseq) + len(bseq))) + assert GD.number_of_nodes() == 3 + + G = reverse_havel_hakimi_graph(aseq, bseq, create_using=nx.Graph) + assert not G.is_multigraph() + assert not G.is_directed() + + pytest.raises( + nx.NetworkXError, + reverse_havel_hakimi_graph, + aseq, + bseq, + create_using=nx.DiGraph, + ) + pytest.raises( + nx.NetworkXError, + reverse_havel_hakimi_graph, + aseq, + bseq, + create_using=nx.DiGraph, + ) + pytest.raises( + nx.NetworkXError, + reverse_havel_hakimi_graph, + aseq, + bseq, + create_using=nx.MultiDiGraph, + ) + + def test_alternating_havel_hakimi_graph(self): + aseq = [] + bseq = [] + G = alternating_havel_hakimi_graph(aseq, bseq) + assert len(G) == 0 + + aseq = [0, 0] + bseq = [0, 0] + G = alternating_havel_hakimi_graph(aseq, bseq) + assert len(G) == 4 + assert G.number_of_edges() == 0 + + aseq = [3, 3, 3, 3] + bseq = [2, 2, 2, 2, 2] + pytest.raises(nx.NetworkXError, alternating_havel_hakimi_graph, aseq, bseq) + + bseq = [2, 2, 2, 2, 2, 2] + G = alternating_havel_hakimi_graph(aseq, bseq) + assert sorted(d for n, d in G.degree()) == [2, 2, 2, 2, 2, 2, 3, 3, 3, 3] + + aseq = [2, 2, 2, 2, 2, 2] + bseq = [3, 3, 3, 3] + G = alternating_havel_hakimi_graph(aseq, bseq) + assert sorted(d for n, d in G.degree()) == [2, 2, 2, 2, 2, 2, 3, 3, 3, 3] + + aseq = [2, 2, 2, 1, 1, 1] + bseq = [3, 3, 3] + G = alternating_havel_hakimi_graph(aseq, bseq) + assert G.is_multigraph() + assert not G.is_directed() + assert sorted(d for n, d in G.degree()) == [1, 1, 1, 2, 2, 2, 3, 3, 3] + + GU = nx.projected_graph(nx.Graph(G), range(len(aseq))) + assert GU.number_of_nodes() == 6 + + GD = nx.projected_graph(nx.Graph(G), range(len(aseq), len(aseq) + len(bseq))) + assert GD.number_of_nodes() == 3 + + G = reverse_havel_hakimi_graph(aseq, bseq, create_using=nx.Graph) + assert not G.is_multigraph() + assert not G.is_directed() + + pytest.raises( + nx.NetworkXError, + alternating_havel_hakimi_graph, + aseq, + bseq, + create_using=nx.DiGraph, + ) + pytest.raises( + nx.NetworkXError, + alternating_havel_hakimi_graph, + aseq, + bseq, + create_using=nx.DiGraph, + ) + pytest.raises( + nx.NetworkXError, + alternating_havel_hakimi_graph, + aseq, + bseq, + create_using=nx.MultiDiGraph, + ) + + def test_preferential_attachment(self): + aseq = [3, 2, 1, 1] + G = preferential_attachment_graph(aseq, 0.5) + assert G.is_multigraph() + assert not G.is_directed() + + G = preferential_attachment_graph(aseq, 0.5, create_using=nx.Graph) + assert not G.is_multigraph() + assert not G.is_directed() + + pytest.raises( + nx.NetworkXError, + preferential_attachment_graph, + aseq, + 0.5, + create_using=nx.DiGraph(), + ) + pytest.raises( + nx.NetworkXError, + preferential_attachment_graph, + aseq, + 0.5, + create_using=nx.DiGraph(), + ) + pytest.raises( + nx.NetworkXError, + preferential_attachment_graph, + aseq, + 0.5, + create_using=nx.DiGraph(), + ) + + def test_random_graph(self): + n = 10 + m = 20 + G = random_graph(n, m, 0.9) + assert len(G) == 30 + assert nx.is_bipartite(G) + X, Y = nx.algorithms.bipartite.sets(G) + assert set(range(n)) == X + assert set(range(n, n + m)) == Y + + def test_random_digraph(self): + n = 10 + m = 20 + G = random_graph(n, m, 0.9, directed=True) + assert len(G) == 30 + assert nx.is_bipartite(G) + X, Y = nx.algorithms.bipartite.sets(G) + assert set(range(n)) == X + assert set(range(n, n + m)) == Y + + def test_gnmk_random_graph(self): + n = 10 + m = 20 + edges = 100 + # set seed because sometimes it is not connected + # which raises an error in bipartite.sets(G) below. + G = gnmk_random_graph(n, m, edges, seed=1234) + assert len(G) == n + m + assert nx.is_bipartite(G) + X, Y = nx.algorithms.bipartite.sets(G) + # print(X) + assert set(range(n)) == X + assert set(range(n, n + m)) == Y + assert edges == len(list(G.edges())) + + def test_gnmk_random_graph_complete(self): + n = 10 + m = 20 + edges = 200 + G = gnmk_random_graph(n, m, edges) + assert len(G) == n + m + assert nx.is_bipartite(G) + X, Y = nx.algorithms.bipartite.sets(G) + # print(X) + assert set(range(n)) == X + assert set(range(n, n + m)) == Y + assert edges == len(list(G.edges())) + + @pytest.mark.parametrize("n", (4, range(4), {0, 1, 2, 3})) + @pytest.mark.parametrize("m", (range(4, 7), {4, 5, 6})) + def test_complete_bipartite_graph_str(self, n, m): + """Ensure G.name is consistent for all inputs accepted by nodes_or_number. + See gh-7396""" + G = nx.complete_bipartite_graph(n, m) + ans = "Graph named 'complete_bipartite_graph(4, 3)' with 7 nodes and 12 edges" + assert str(G) == ans diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_matching.py b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..c24659ea8fcf01ab4a26a6c6959c4935ab9aad2d --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_matching.py @@ -0,0 +1,327 @@ +"""Unit tests for the :mod:`networkx.algorithms.bipartite.matching` module.""" + +import itertools + +import pytest + +import networkx as nx +from networkx.algorithms.bipartite.matching import ( + eppstein_matching, + hopcroft_karp_matching, + maximum_matching, + minimum_weight_full_matching, + to_vertex_cover, +) + + +class TestMatching: + """Tests for bipartite matching algorithms.""" + + def setup_method(self): + """Creates a bipartite graph for use in testing matching algorithms. + + The bipartite graph has a maximum cardinality matching that leaves + vertex 1 and vertex 10 unmatched. The first six numbers are the left + vertices and the next six numbers are the right vertices. + + """ + self.simple_graph = nx.complete_bipartite_graph(2, 3) + self.simple_solution = {0: 2, 1: 3, 2: 0, 3: 1} + + edges = [(0, 7), (0, 8), (2, 6), (2, 9), (3, 8), (4, 8), (4, 9), (5, 11)] + self.top_nodes = set(range(6)) + self.graph = nx.Graph() + self.graph.add_nodes_from(range(12)) + self.graph.add_edges_from(edges) + + # Example bipartite graph from issue 2127 + G = nx.Graph() + G.add_nodes_from( + [ + (1, "C"), + (1, "B"), + (0, "G"), + (1, "F"), + (1, "E"), + (0, "C"), + (1, "D"), + (1, "I"), + (0, "A"), + (0, "D"), + (0, "F"), + (0, "E"), + (0, "H"), + (1, "G"), + (1, "A"), + (0, "I"), + (0, "B"), + (1, "H"), + ] + ) + G.add_edge((1, "C"), (0, "A")) + G.add_edge((1, "B"), (0, "A")) + G.add_edge((0, "G"), (1, "I")) + G.add_edge((0, "G"), (1, "H")) + G.add_edge((1, "F"), (0, "A")) + G.add_edge((1, "F"), (0, "C")) + G.add_edge((1, "F"), (0, "E")) + G.add_edge((1, "E"), (0, "A")) + G.add_edge((1, "E"), (0, "C")) + G.add_edge((0, "C"), (1, "D")) + G.add_edge((0, "C"), (1, "I")) + G.add_edge((0, "C"), (1, "G")) + G.add_edge((0, "C"), (1, "H")) + G.add_edge((1, "D"), (0, "A")) + G.add_edge((1, "I"), (0, "A")) + G.add_edge((1, "I"), (0, "E")) + G.add_edge((0, "A"), (1, "G")) + G.add_edge((0, "A"), (1, "H")) + G.add_edge((0, "E"), (1, "G")) + G.add_edge((0, "E"), (1, "H")) + self.disconnected_graph = G + + def check_match(self, matching): + """Asserts that the matching is what we expect from the bipartite graph + constructed in the :meth:`setup` fixture. + + """ + # For the sake of brevity, rename `matching` to `M`. + M = matching + matched_vertices = frozenset(itertools.chain(*M.items())) + # Assert that the maximum number of vertices (10) is matched. + assert matched_vertices == frozenset(range(12)) - {1, 10} + # Assert that no vertex appears in two edges, or in other words, that + # the matching (u, v) and (v, u) both appear in the matching + # dictionary. + assert all(u == M[M[u]] for u in range(12) if u in M) + + def check_vertex_cover(self, vertices): + """Asserts that the given set of vertices is the vertex cover we + expected from the bipartite graph constructed in the :meth:`setup` + fixture. + + """ + # By Konig's theorem, the number of edges in a maximum matching equals + # the number of vertices in a minimum vertex cover. + assert len(vertices) == 5 + # Assert that the set is truly a vertex cover. + for u, v in self.graph.edges(): + assert u in vertices or v in vertices + # TODO Assert that the vertices are the correct ones. + + def test_eppstein_matching(self): + """Tests that David Eppstein's implementation of the Hopcroft--Karp + algorithm produces a maximum cardinality matching. + + """ + self.check_match(eppstein_matching(self.graph, self.top_nodes)) + + def test_hopcroft_karp_matching(self): + """Tests that the Hopcroft--Karp algorithm produces a maximum + cardinality matching in a bipartite graph. + + """ + self.check_match(hopcroft_karp_matching(self.graph, self.top_nodes)) + + def test_to_vertex_cover(self): + """Test for converting a maximum matching to a minimum vertex cover.""" + matching = maximum_matching(self.graph, self.top_nodes) + vertex_cover = to_vertex_cover(self.graph, matching, self.top_nodes) + self.check_vertex_cover(vertex_cover) + + def test_eppstein_matching_simple(self): + match = eppstein_matching(self.simple_graph) + assert match == self.simple_solution + + def test_hopcroft_karp_matching_simple(self): + match = hopcroft_karp_matching(self.simple_graph) + assert match == self.simple_solution + + def test_eppstein_matching_disconnected(self): + with pytest.raises(nx.AmbiguousSolution): + match = eppstein_matching(self.disconnected_graph) + + def test_hopcroft_karp_matching_disconnected(self): + with pytest.raises(nx.AmbiguousSolution): + match = hopcroft_karp_matching(self.disconnected_graph) + + def test_issue_2127(self): + """Test from issue 2127""" + # Build the example DAG + G = nx.DiGraph() + G.add_edge("A", "C") + G.add_edge("A", "B") + G.add_edge("C", "E") + G.add_edge("C", "D") + G.add_edge("E", "G") + G.add_edge("E", "F") + G.add_edge("G", "I") + G.add_edge("G", "H") + + tc = nx.transitive_closure(G) + btc = nx.Graph() + + # Create a bipartite graph based on the transitive closure of G + for v in tc.nodes(): + btc.add_node((0, v)) + btc.add_node((1, v)) + + for u, v in tc.edges(): + btc.add_edge((0, u), (1, v)) + + top_nodes = {n for n in btc if n[0] == 0} + matching = hopcroft_karp_matching(btc, top_nodes) + vertex_cover = to_vertex_cover(btc, matching, top_nodes) + independent_set = set(G) - {v for _, v in vertex_cover} + assert {"B", "D", "F", "I", "H"} == independent_set + + def test_vertex_cover_issue_2384(self): + G = nx.Graph([(0, 3), (1, 3), (1, 4), (2, 3)]) + matching = maximum_matching(G) + vertex_cover = to_vertex_cover(G, matching) + for u, v in G.edges(): + assert u in vertex_cover or v in vertex_cover + + def test_vertex_cover_issue_3306(self): + G = nx.Graph() + edges = [(0, 2), (1, 0), (1, 1), (1, 2), (2, 2)] + G.add_edges_from([((i, "L"), (j, "R")) for i, j in edges]) + + matching = maximum_matching(G) + vertex_cover = to_vertex_cover(G, matching) + for u, v in G.edges(): + assert u in vertex_cover or v in vertex_cover + + def test_unorderable_nodes(self): + a = object() + b = object() + c = object() + d = object() + e = object() + G = nx.Graph([(a, d), (b, d), (b, e), (c, d)]) + matching = maximum_matching(G) + vertex_cover = to_vertex_cover(G, matching) + for u, v in G.edges(): + assert u in vertex_cover or v in vertex_cover + + +def test_eppstein_matching(): + """Test in accordance to issue #1927""" + G = nx.Graph() + G.add_nodes_from(["a", 2, 3, 4], bipartite=0) + G.add_nodes_from([1, "b", "c"], bipartite=1) + G.add_edges_from([("a", 1), ("a", "b"), (2, "b"), (2, "c"), (3, "c"), (4, 1)]) + matching = eppstein_matching(G) + assert len(matching) == len(maximum_matching(G)) + assert all(x in set(matching.keys()) for x in set(matching.values())) + + +class TestMinimumWeightFullMatching: + @classmethod + def setup_class(cls): + pytest.importorskip("scipy") + + def test_minimum_weight_full_matching_incomplete_graph(self): + B = nx.Graph() + B.add_nodes_from([1, 2], bipartite=0) + B.add_nodes_from([3, 4], bipartite=1) + B.add_edge(1, 4, weight=100) + B.add_edge(2, 3, weight=100) + B.add_edge(2, 4, weight=50) + matching = minimum_weight_full_matching(B) + assert matching == {1: 4, 2: 3, 4: 1, 3: 2} + + def test_minimum_weight_full_matching_with_no_full_matching(self): + B = nx.Graph() + B.add_nodes_from([1, 2, 3], bipartite=0) + B.add_nodes_from([4, 5, 6], bipartite=1) + B.add_edge(1, 4, weight=100) + B.add_edge(2, 4, weight=100) + B.add_edge(3, 4, weight=50) + B.add_edge(3, 5, weight=50) + B.add_edge(3, 6, weight=50) + with pytest.raises(ValueError): + minimum_weight_full_matching(B) + + def test_minimum_weight_full_matching_square(self): + G = nx.complete_bipartite_graph(3, 3) + G.add_edge(0, 3, weight=400) + G.add_edge(0, 4, weight=150) + G.add_edge(0, 5, weight=400) + G.add_edge(1, 3, weight=400) + G.add_edge(1, 4, weight=450) + G.add_edge(1, 5, weight=600) + G.add_edge(2, 3, weight=300) + G.add_edge(2, 4, weight=225) + G.add_edge(2, 5, weight=300) + matching = minimum_weight_full_matching(G) + assert matching == {0: 4, 1: 3, 2: 5, 4: 0, 3: 1, 5: 2} + + def test_minimum_weight_full_matching_smaller_left(self): + G = nx.complete_bipartite_graph(3, 4) + G.add_edge(0, 3, weight=400) + G.add_edge(0, 4, weight=150) + G.add_edge(0, 5, weight=400) + G.add_edge(0, 6, weight=1) + G.add_edge(1, 3, weight=400) + G.add_edge(1, 4, weight=450) + G.add_edge(1, 5, weight=600) + G.add_edge(1, 6, weight=2) + G.add_edge(2, 3, weight=300) + G.add_edge(2, 4, weight=225) + G.add_edge(2, 5, weight=290) + G.add_edge(2, 6, weight=3) + matching = minimum_weight_full_matching(G) + assert matching == {0: 4, 1: 6, 2: 5, 4: 0, 5: 2, 6: 1} + + def test_minimum_weight_full_matching_smaller_top_nodes_right(self): + G = nx.complete_bipartite_graph(3, 4) + G.add_edge(0, 3, weight=400) + G.add_edge(0, 4, weight=150) + G.add_edge(0, 5, weight=400) + G.add_edge(0, 6, weight=1) + G.add_edge(1, 3, weight=400) + G.add_edge(1, 4, weight=450) + G.add_edge(1, 5, weight=600) + G.add_edge(1, 6, weight=2) + G.add_edge(2, 3, weight=300) + G.add_edge(2, 4, weight=225) + G.add_edge(2, 5, weight=290) + G.add_edge(2, 6, weight=3) + matching = minimum_weight_full_matching(G, top_nodes=[3, 4, 5, 6]) + assert matching == {0: 4, 1: 6, 2: 5, 4: 0, 5: 2, 6: 1} + + def test_minimum_weight_full_matching_smaller_right(self): + G = nx.complete_bipartite_graph(4, 3) + G.add_edge(0, 4, weight=400) + G.add_edge(0, 5, weight=400) + G.add_edge(0, 6, weight=300) + G.add_edge(1, 4, weight=150) + G.add_edge(1, 5, weight=450) + G.add_edge(1, 6, weight=225) + G.add_edge(2, 4, weight=400) + G.add_edge(2, 5, weight=600) + G.add_edge(2, 6, weight=290) + G.add_edge(3, 4, weight=1) + G.add_edge(3, 5, weight=2) + G.add_edge(3, 6, weight=3) + matching = minimum_weight_full_matching(G) + assert matching == {1: 4, 2: 6, 3: 5, 4: 1, 5: 3, 6: 2} + + def test_minimum_weight_full_matching_negative_weights(self): + G = nx.complete_bipartite_graph(2, 2) + G.add_edge(0, 2, weight=-2) + G.add_edge(0, 3, weight=0.2) + G.add_edge(1, 2, weight=-2) + G.add_edge(1, 3, weight=0.3) + matching = minimum_weight_full_matching(G) + assert matching == {0: 3, 1: 2, 2: 1, 3: 0} + + def test_minimum_weight_full_matching_different_weight_key(self): + G = nx.complete_bipartite_graph(2, 2) + G.add_edge(0, 2, mass=2) + G.add_edge(0, 3, mass=0.2) + G.add_edge(1, 2, mass=1) + G.add_edge(1, 3, mass=2) + matching = minimum_weight_full_matching(G, weight="mass") + assert matching == {0: 3, 1: 2, 2: 1, 3: 0} diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_matrix.py b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..53d83115118e9bbdf6238b89d171bf6b7b829477 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_matrix.py @@ -0,0 +1,84 @@ +import pytest + +np = pytest.importorskip("numpy") +sp = pytest.importorskip("scipy") +sparse = pytest.importorskip("scipy.sparse") + + +import networkx as nx +from networkx.algorithms import bipartite +from networkx.utils import edges_equal + + +class TestBiadjacencyMatrix: + def test_biadjacency_matrix_weight(self): + G = nx.path_graph(5) + G.add_edge(0, 1, weight=2, other=4) + X = [1, 3] + Y = [0, 2, 4] + M = bipartite.biadjacency_matrix(G, X, weight="weight") + assert M[0, 0] == 2 + M = bipartite.biadjacency_matrix(G, X, weight="other") + assert M[0, 0] == 4 + + def test_biadjacency_matrix(self): + tops = [2, 5, 10] + bots = [5, 10, 15] + for i in range(len(tops)): + G = bipartite.random_graph(tops[i], bots[i], 0.2) + top = [n for n, d in G.nodes(data=True) if d["bipartite"] == 0] + M = bipartite.biadjacency_matrix(G, top) + assert M.shape[0] == tops[i] + assert M.shape[1] == bots[i] + + def test_biadjacency_matrix_order(self): + G = nx.path_graph(5) + G.add_edge(0, 1, weight=2) + X = [3, 1] + Y = [4, 2, 0] + M = bipartite.biadjacency_matrix(G, X, Y, weight="weight") + assert M[1, 2] == 2 + + def test_biadjacency_matrix_empty_graph(self): + G = nx.empty_graph(2) + M = nx.bipartite.biadjacency_matrix(G, [0]) + assert np.array_equal(M.toarray(), np.array([[0]])) + + def test_null_graph(self): + with pytest.raises(nx.NetworkXError): + bipartite.biadjacency_matrix(nx.Graph(), []) + + def test_empty_graph(self): + with pytest.raises(nx.NetworkXError): + bipartite.biadjacency_matrix(nx.Graph([(1, 0)]), []) + + def test_duplicate_row(self): + with pytest.raises(nx.NetworkXError): + bipartite.biadjacency_matrix(nx.Graph([(1, 0)]), [1, 1]) + + def test_duplicate_col(self): + with pytest.raises(nx.NetworkXError): + bipartite.biadjacency_matrix(nx.Graph([(1, 0)]), [0], [1, 1]) + + def test_format_keyword(self): + with pytest.raises(nx.NetworkXError): + bipartite.biadjacency_matrix(nx.Graph([(1, 0)]), [0], format="foo") + + def test_from_biadjacency_roundtrip(self): + B1 = nx.path_graph(5) + M = bipartite.biadjacency_matrix(B1, [0, 2, 4]) + B2 = bipartite.from_biadjacency_matrix(M) + assert nx.is_isomorphic(B1, B2) + + def test_from_biadjacency_weight(self): + M = sparse.csc_matrix([[1, 2], [0, 3]]) + B = bipartite.from_biadjacency_matrix(M) + assert edges_equal(B.edges(), [(0, 2), (0, 3), (1, 3)]) + B = bipartite.from_biadjacency_matrix(M, edge_attribute="weight") + e = [(0, 2, {"weight": 1}), (0, 3, {"weight": 2}), (1, 3, {"weight": 3})] + assert edges_equal(B.edges(data=True), e) + + def test_from_biadjacency_multigraph(self): + M = sparse.csc_matrix([[1, 2], [0, 3]]) + B = bipartite.from_biadjacency_matrix(M, create_using=nx.MultiGraph()) + assert edges_equal(B.edges(), [(0, 2), (0, 3), (0, 3), (1, 3), (1, 3), (1, 3)]) diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_project.py b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_project.py new file mode 100644 index 0000000000000000000000000000000000000000..076bb42b668657cad51f6423e5aacf23a2a1cd28 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_project.py @@ -0,0 +1,407 @@ +import pytest + +import networkx as nx +from networkx.algorithms import bipartite +from networkx.utils import edges_equal, nodes_equal + + +class TestBipartiteProject: + def test_path_projected_graph(self): + G = nx.path_graph(4) + P = bipartite.projected_graph(G, [1, 3]) + assert nodes_equal(list(P), [1, 3]) + assert edges_equal(list(P.edges()), [(1, 3)]) + P = bipartite.projected_graph(G, [0, 2]) + assert nodes_equal(list(P), [0, 2]) + assert edges_equal(list(P.edges()), [(0, 2)]) + G = nx.MultiGraph([(0, 1)]) + with pytest.raises(nx.NetworkXError, match="not defined for multigraphs"): + bipartite.projected_graph(G, [0]) + + def test_path_projected_properties_graph(self): + G = nx.path_graph(4) + G.add_node(1, name="one") + G.add_node(2, name="two") + P = bipartite.projected_graph(G, [1, 3]) + assert nodes_equal(list(P), [1, 3]) + assert edges_equal(list(P.edges()), [(1, 3)]) + assert P.nodes[1]["name"] == G.nodes[1]["name"] + P = bipartite.projected_graph(G, [0, 2]) + assert nodes_equal(list(P), [0, 2]) + assert edges_equal(list(P.edges()), [(0, 2)]) + assert P.nodes[2]["name"] == G.nodes[2]["name"] + + def test_path_collaboration_projected_graph(self): + G = nx.path_graph(4) + P = bipartite.collaboration_weighted_projected_graph(G, [1, 3]) + assert nodes_equal(list(P), [1, 3]) + assert edges_equal(list(P.edges()), [(1, 3)]) + P[1][3]["weight"] = 1 + P = bipartite.collaboration_weighted_projected_graph(G, [0, 2]) + assert nodes_equal(list(P), [0, 2]) + assert edges_equal(list(P.edges()), [(0, 2)]) + P[0][2]["weight"] = 1 + + def test_directed_path_collaboration_projected_graph(self): + G = nx.DiGraph() + nx.add_path(G, range(4)) + P = bipartite.collaboration_weighted_projected_graph(G, [1, 3]) + assert nodes_equal(list(P), [1, 3]) + assert edges_equal(list(P.edges()), [(1, 3)]) + P[1][3]["weight"] = 1 + P = bipartite.collaboration_weighted_projected_graph(G, [0, 2]) + assert nodes_equal(list(P), [0, 2]) + assert edges_equal(list(P.edges()), [(0, 2)]) + P[0][2]["weight"] = 1 + + def test_path_weighted_projected_graph(self): + G = nx.path_graph(4) + + with pytest.raises(nx.NetworkXAlgorithmError): + bipartite.weighted_projected_graph(G, [1, 2, 3, 3]) + + P = bipartite.weighted_projected_graph(G, [1, 3]) + assert nodes_equal(list(P), [1, 3]) + assert edges_equal(list(P.edges()), [(1, 3)]) + P[1][3]["weight"] = 1 + P = bipartite.weighted_projected_graph(G, [0, 2]) + assert nodes_equal(list(P), [0, 2]) + assert edges_equal(list(P.edges()), [(0, 2)]) + P[0][2]["weight"] = 1 + + def test_digraph_weighted_projection(self): + G = nx.DiGraph([(0, 1), (1, 2), (2, 3), (3, 4)]) + P = bipartite.overlap_weighted_projected_graph(G, [1, 3]) + assert nx.get_edge_attributes(P, "weight") == {(1, 3): 1.0} + assert len(P) == 2 + + def test_path_weighted_projected_directed_graph(self): + G = nx.DiGraph() + nx.add_path(G, range(4)) + P = bipartite.weighted_projected_graph(G, [1, 3]) + assert nodes_equal(list(P), [1, 3]) + assert edges_equal(list(P.edges()), [(1, 3)]) + P[1][3]["weight"] = 1 + P = bipartite.weighted_projected_graph(G, [0, 2]) + assert nodes_equal(list(P), [0, 2]) + assert edges_equal(list(P.edges()), [(0, 2)]) + P[0][2]["weight"] = 1 + + def test_star_projected_graph(self): + G = nx.star_graph(3) + P = bipartite.projected_graph(G, [1, 2, 3]) + assert nodes_equal(list(P), [1, 2, 3]) + assert edges_equal(list(P.edges()), [(1, 2), (1, 3), (2, 3)]) + P = bipartite.weighted_projected_graph(G, [1, 2, 3]) + assert nodes_equal(list(P), [1, 2, 3]) + assert edges_equal(list(P.edges()), [(1, 2), (1, 3), (2, 3)]) + + P = bipartite.projected_graph(G, [0]) + assert nodes_equal(list(P), [0]) + assert edges_equal(list(P.edges()), []) + + def test_project_multigraph(self): + G = nx.Graph() + G.add_edge("a", 1) + G.add_edge("b", 1) + G.add_edge("a", 2) + G.add_edge("b", 2) + P = bipartite.projected_graph(G, "ab") + assert edges_equal(list(P.edges()), [("a", "b")]) + P = bipartite.weighted_projected_graph(G, "ab") + assert edges_equal(list(P.edges()), [("a", "b")]) + P = bipartite.projected_graph(G, "ab", multigraph=True) + assert edges_equal(list(P.edges()), [("a", "b"), ("a", "b")]) + + def test_project_collaboration(self): + G = nx.Graph() + G.add_edge("a", 1) + G.add_edge("b", 1) + G.add_edge("b", 2) + G.add_edge("c", 2) + G.add_edge("c", 3) + G.add_edge("c", 4) + G.add_edge("b", 4) + P = bipartite.collaboration_weighted_projected_graph(G, "abc") + assert P["a"]["b"]["weight"] == 1 + assert P["b"]["c"]["weight"] == 2 + + def test_directed_projection(self): + G = nx.DiGraph() + G.add_edge("A", 1) + G.add_edge(1, "B") + G.add_edge("A", 2) + G.add_edge("B", 2) + P = bipartite.projected_graph(G, "AB") + assert edges_equal(list(P.edges()), [("A", "B")]) + P = bipartite.weighted_projected_graph(G, "AB") + assert edges_equal(list(P.edges()), [("A", "B")]) + assert P["A"]["B"]["weight"] == 1 + + P = bipartite.projected_graph(G, "AB", multigraph=True) + assert edges_equal(list(P.edges()), [("A", "B")]) + + G = nx.DiGraph() + G.add_edge("A", 1) + G.add_edge(1, "B") + G.add_edge("A", 2) + G.add_edge(2, "B") + P = bipartite.projected_graph(G, "AB") + assert edges_equal(list(P.edges()), [("A", "B")]) + P = bipartite.weighted_projected_graph(G, "AB") + assert edges_equal(list(P.edges()), [("A", "B")]) + assert P["A"]["B"]["weight"] == 2 + + P = bipartite.projected_graph(G, "AB", multigraph=True) + assert edges_equal(list(P.edges()), [("A", "B"), ("A", "B")]) + + +class TestBipartiteWeightedProjection: + @classmethod + def setup_class(cls): + # Tore Opsahl's example + # http://toreopsahl.com/2009/05/01/projecting-two-mode-networks-onto-weighted-one-mode-networks/ + cls.G = nx.Graph() + cls.G.add_edge("A", 1) + cls.G.add_edge("A", 2) + cls.G.add_edge("B", 1) + cls.G.add_edge("B", 2) + cls.G.add_edge("B", 3) + cls.G.add_edge("B", 4) + cls.G.add_edge("B", 5) + cls.G.add_edge("C", 1) + cls.G.add_edge("D", 3) + cls.G.add_edge("E", 4) + cls.G.add_edge("E", 5) + cls.G.add_edge("E", 6) + cls.G.add_edge("F", 6) + # Graph based on figure 6 from Newman (2001) + cls.N = nx.Graph() + cls.N.add_edge("A", 1) + cls.N.add_edge("A", 2) + cls.N.add_edge("A", 3) + cls.N.add_edge("B", 1) + cls.N.add_edge("B", 2) + cls.N.add_edge("B", 3) + cls.N.add_edge("C", 1) + cls.N.add_edge("D", 1) + cls.N.add_edge("E", 3) + + def test_project_weighted_shared(self): + edges = [ + ("A", "B", 2), + ("A", "C", 1), + ("B", "C", 1), + ("B", "D", 1), + ("B", "E", 2), + ("E", "F", 1), + ] + Panswer = nx.Graph() + Panswer.add_weighted_edges_from(edges) + P = bipartite.weighted_projected_graph(self.G, "ABCDEF") + assert edges_equal(list(P.edges()), Panswer.edges()) + for u, v in list(P.edges()): + assert P[u][v]["weight"] == Panswer[u][v]["weight"] + + edges = [ + ("A", "B", 3), + ("A", "E", 1), + ("A", "C", 1), + ("A", "D", 1), + ("B", "E", 1), + ("B", "C", 1), + ("B", "D", 1), + ("C", "D", 1), + ] + Panswer = nx.Graph() + Panswer.add_weighted_edges_from(edges) + P = bipartite.weighted_projected_graph(self.N, "ABCDE") + assert edges_equal(list(P.edges()), Panswer.edges()) + for u, v in list(P.edges()): + assert P[u][v]["weight"] == Panswer[u][v]["weight"] + + def test_project_weighted_newman(self): + edges = [ + ("A", "B", 1.5), + ("A", "C", 0.5), + ("B", "C", 0.5), + ("B", "D", 1), + ("B", "E", 2), + ("E", "F", 1), + ] + Panswer = nx.Graph() + Panswer.add_weighted_edges_from(edges) + P = bipartite.collaboration_weighted_projected_graph(self.G, "ABCDEF") + assert edges_equal(list(P.edges()), Panswer.edges()) + for u, v in list(P.edges()): + assert P[u][v]["weight"] == Panswer[u][v]["weight"] + + edges = [ + ("A", "B", 11 / 6.0), + ("A", "E", 1 / 2.0), + ("A", "C", 1 / 3.0), + ("A", "D", 1 / 3.0), + ("B", "E", 1 / 2.0), + ("B", "C", 1 / 3.0), + ("B", "D", 1 / 3.0), + ("C", "D", 1 / 3.0), + ] + Panswer = nx.Graph() + Panswer.add_weighted_edges_from(edges) + P = bipartite.collaboration_weighted_projected_graph(self.N, "ABCDE") + assert edges_equal(list(P.edges()), Panswer.edges()) + for u, v in list(P.edges()): + assert P[u][v]["weight"] == Panswer[u][v]["weight"] + + def test_project_weighted_ratio(self): + edges = [ + ("A", "B", 2 / 6.0), + ("A", "C", 1 / 6.0), + ("B", "C", 1 / 6.0), + ("B", "D", 1 / 6.0), + ("B", "E", 2 / 6.0), + ("E", "F", 1 / 6.0), + ] + Panswer = nx.Graph() + Panswer.add_weighted_edges_from(edges) + P = bipartite.weighted_projected_graph(self.G, "ABCDEF", ratio=True) + assert edges_equal(list(P.edges()), Panswer.edges()) + for u, v in list(P.edges()): + assert P[u][v]["weight"] == Panswer[u][v]["weight"] + + edges = [ + ("A", "B", 3 / 3.0), + ("A", "E", 1 / 3.0), + ("A", "C", 1 / 3.0), + ("A", "D", 1 / 3.0), + ("B", "E", 1 / 3.0), + ("B", "C", 1 / 3.0), + ("B", "D", 1 / 3.0), + ("C", "D", 1 / 3.0), + ] + Panswer = nx.Graph() + Panswer.add_weighted_edges_from(edges) + P = bipartite.weighted_projected_graph(self.N, "ABCDE", ratio=True) + assert edges_equal(list(P.edges()), Panswer.edges()) + for u, v in list(P.edges()): + assert P[u][v]["weight"] == Panswer[u][v]["weight"] + + def test_project_weighted_overlap(self): + edges = [ + ("A", "B", 2 / 2.0), + ("A", "C", 1 / 1.0), + ("B", "C", 1 / 1.0), + ("B", "D", 1 / 1.0), + ("B", "E", 2 / 3.0), + ("E", "F", 1 / 1.0), + ] + Panswer = nx.Graph() + Panswer.add_weighted_edges_from(edges) + P = bipartite.overlap_weighted_projected_graph(self.G, "ABCDEF", jaccard=False) + assert edges_equal(list(P.edges()), Panswer.edges()) + for u, v in list(P.edges()): + assert P[u][v]["weight"] == Panswer[u][v]["weight"] + + edges = [ + ("A", "B", 3 / 3.0), + ("A", "E", 1 / 1.0), + ("A", "C", 1 / 1.0), + ("A", "D", 1 / 1.0), + ("B", "E", 1 / 1.0), + ("B", "C", 1 / 1.0), + ("B", "D", 1 / 1.0), + ("C", "D", 1 / 1.0), + ] + Panswer = nx.Graph() + Panswer.add_weighted_edges_from(edges) + P = bipartite.overlap_weighted_projected_graph(self.N, "ABCDE", jaccard=False) + assert edges_equal(list(P.edges()), Panswer.edges()) + for u, v in list(P.edges()): + assert P[u][v]["weight"] == Panswer[u][v]["weight"] + + def test_project_weighted_jaccard(self): + edges = [ + ("A", "B", 2 / 5.0), + ("A", "C", 1 / 2.0), + ("B", "C", 1 / 5.0), + ("B", "D", 1 / 5.0), + ("B", "E", 2 / 6.0), + ("E", "F", 1 / 3.0), + ] + Panswer = nx.Graph() + Panswer.add_weighted_edges_from(edges) + P = bipartite.overlap_weighted_projected_graph(self.G, "ABCDEF") + assert edges_equal(list(P.edges()), Panswer.edges()) + for u, v in list(P.edges()): + assert P[u][v]["weight"] == Panswer[u][v]["weight"] + + edges = [ + ("A", "B", 3 / 3.0), + ("A", "E", 1 / 3.0), + ("A", "C", 1 / 3.0), + ("A", "D", 1 / 3.0), + ("B", "E", 1 / 3.0), + ("B", "C", 1 / 3.0), + ("B", "D", 1 / 3.0), + ("C", "D", 1 / 1.0), + ] + Panswer = nx.Graph() + Panswer.add_weighted_edges_from(edges) + P = bipartite.overlap_weighted_projected_graph(self.N, "ABCDE") + assert edges_equal(list(P.edges()), Panswer.edges()) + for u, v in P.edges(): + assert P[u][v]["weight"] == Panswer[u][v]["weight"] + + def test_generic_weighted_projected_graph_simple(self): + def shared(G, u, v): + return len(set(G[u]) & set(G[v])) + + B = nx.path_graph(5) + G = bipartite.generic_weighted_projected_graph( + B, [0, 2, 4], weight_function=shared + ) + assert nodes_equal(list(G), [0, 2, 4]) + assert edges_equal( + list(G.edges(data=True)), + [(0, 2, {"weight": 1}), (2, 4, {"weight": 1})], + ) + + G = bipartite.generic_weighted_projected_graph(B, [0, 2, 4]) + assert nodes_equal(list(G), [0, 2, 4]) + assert edges_equal( + list(G.edges(data=True)), + [(0, 2, {"weight": 1}), (2, 4, {"weight": 1})], + ) + B = nx.DiGraph() + nx.add_path(B, range(5)) + G = bipartite.generic_weighted_projected_graph(B, [0, 2, 4]) + assert nodes_equal(list(G), [0, 2, 4]) + assert edges_equal( + list(G.edges(data=True)), [(0, 2, {"weight": 1}), (2, 4, {"weight": 1})] + ) + + def test_generic_weighted_projected_graph_custom(self): + def jaccard(G, u, v): + unbrs = set(G[u]) + vnbrs = set(G[v]) + return len(unbrs & vnbrs) / len(unbrs | vnbrs) + + def my_weight(G, u, v, weight="weight"): + w = 0 + for nbr in set(G[u]) & set(G[v]): + w += G.edges[u, nbr].get(weight, 1) + G.edges[v, nbr].get(weight, 1) + return w + + B = nx.bipartite.complete_bipartite_graph(2, 2) + for i, (u, v) in enumerate(B.edges()): + B.edges[u, v]["weight"] = i + 1 + G = bipartite.generic_weighted_projected_graph( + B, [0, 1], weight_function=jaccard + ) + assert edges_equal(list(G.edges(data=True)), [(0, 1, {"weight": 1.0})]) + G = bipartite.generic_weighted_projected_graph( + B, [0, 1], weight_function=my_weight + ) + assert edges_equal(list(G.edges(data=True)), [(0, 1, {"weight": 10})]) + G = bipartite.generic_weighted_projected_graph(B, [0, 1]) + assert edges_equal(list(G.edges(data=True)), [(0, 1, {"weight": 2})]) diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_redundancy.py b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_redundancy.py new file mode 100644 index 0000000000000000000000000000000000000000..8d979db8d8599b1ec59c2123d258bc29efaa4c9b --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_redundancy.py @@ -0,0 +1,35 @@ +"""Unit tests for the :mod:`networkx.algorithms.bipartite.redundancy` module.""" + +import pytest + +from networkx import NetworkXError, cycle_graph +from networkx.algorithms.bipartite import complete_bipartite_graph, node_redundancy + + +def test_no_redundant_nodes(): + G = complete_bipartite_graph(2, 2) + + # when nodes is None + rc = node_redundancy(G) + assert all(redundancy == 1 for redundancy in rc.values()) + + # when set of nodes is specified + rc = node_redundancy(G, (2, 3)) + assert rc == {2: 1.0, 3: 1.0} + + +def test_redundant_nodes(): + G = cycle_graph(6) + edge = {0, 3} + G.add_edge(*edge) + redundancy = node_redundancy(G) + for v in edge: + assert redundancy[v] == 2 / 3 + for v in set(G) - edge: + assert redundancy[v] == 1 + + +def test_not_enough_neighbors(): + with pytest.raises(NetworkXError): + G = complete_bipartite_graph(1, 2) + node_redundancy(G) diff --git a/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_spectral_bipartivity.py b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_spectral_bipartivity.py new file mode 100644 index 0000000000000000000000000000000000000000..b940649793d40aa73606914f3d48348761c329df --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/bipartite/tests/test_spectral_bipartivity.py @@ -0,0 +1,80 @@ +import pytest + +pytest.importorskip("scipy") + +import networkx as nx +from networkx.algorithms.bipartite import spectral_bipartivity as sb + +# Examples from Figure 1 +# E. Estrada and J. A. Rodríguez-Velázquez, "Spectral measures of +# bipartivity in complex networks", PhysRev E 72, 046105 (2005) + + +class TestSpectralBipartivity: + def test_star_like(self): + # star-like + + G = nx.star_graph(2) + G.add_edge(1, 2) + assert sb(G) == pytest.approx(0.843, abs=1e-3) + + G = nx.star_graph(3) + G.add_edge(1, 2) + assert sb(G) == pytest.approx(0.871, abs=1e-3) + + G = nx.star_graph(4) + G.add_edge(1, 2) + assert sb(G) == pytest.approx(0.890, abs=1e-3) + + def test_k23_like(self): + # K2,3-like + G = nx.complete_bipartite_graph(2, 3) + G.add_edge(0, 1) + assert sb(G) == pytest.approx(0.769, abs=1e-3) + + G = nx.complete_bipartite_graph(2, 3) + G.add_edge(2, 4) + assert sb(G) == pytest.approx(0.829, abs=1e-3) + + G = nx.complete_bipartite_graph(2, 3) + G.add_edge(2, 4) + G.add_edge(3, 4) + assert sb(G) == pytest.approx(0.731, abs=1e-3) + + G = nx.complete_bipartite_graph(2, 3) + G.add_edge(0, 1) + G.add_edge(2, 4) + assert sb(G) == pytest.approx(0.692, abs=1e-3) + + G = nx.complete_bipartite_graph(2, 3) + G.add_edge(2, 4) + G.add_edge(3, 4) + G.add_edge(0, 1) + assert sb(G) == pytest.approx(0.645, abs=1e-3) + + G = nx.complete_bipartite_graph(2, 3) + G.add_edge(2, 4) + G.add_edge(3, 4) + G.add_edge(2, 3) + assert sb(G) == pytest.approx(0.645, abs=1e-3) + + G = nx.complete_bipartite_graph(2, 3) + G.add_edge(2, 4) + G.add_edge(3, 4) + G.add_edge(2, 3) + G.add_edge(0, 1) + assert sb(G) == pytest.approx(0.597, abs=1e-3) + + def test_single_nodes(self): + # single nodes + G = nx.complete_bipartite_graph(2, 3) + G.add_edge(2, 4) + sbn = sb(G, nodes=[1, 2]) + assert sbn[1] == pytest.approx(0.85, abs=1e-2) + assert sbn[2] == pytest.approx(0.77, abs=1e-2) + + G = nx.complete_bipartite_graph(2, 3) + G.add_edge(0, 1) + sbn = sb(G, nodes=[1, 2]) + assert sbn[1] == pytest.approx(0.73, abs=1e-2) + assert sbn[2] == pytest.approx(0.82, abs=1e-2) diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c91a904a13496ecab5a3a6c8caa026970d99a540 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/__init__.py @@ -0,0 +1,20 @@ +from .betweenness import * +from .betweenness_subset import * +from .closeness import * +from .current_flow_betweenness import * +from .current_flow_betweenness_subset import * +from .current_flow_closeness import * +from .degree_alg import * +from .dispersion import * +from .eigenvector import * +from .group import * +from .harmonic import * +from .katz import * +from .load import * +from .percolation import * +from .reaching import * +from .second_order import * +from .subgraph_alg import * +from .trophic import * +from .voterank_alg import * +from .laplacian import * diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..708513e6f92992d7eb47d946ad24d8a97ed331bc Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/betweenness.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/betweenness.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..528dc29b04c9f44d548d5d4eceaab993837207cb Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/betweenness.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/betweenness_subset.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/betweenness_subset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c49c64a4c7f1f331011954668d1f09686fb2e9e4 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/betweenness_subset.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/closeness.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/closeness.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f2e985af8cdf31d5c94dbf3ec89e107ea05bb4f Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/closeness.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/current_flow_betweenness.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/current_flow_betweenness.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bffc2a95d6848bad6af88517e0b876318bf8e546 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/current_flow_betweenness.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/current_flow_betweenness_subset.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/current_flow_betweenness_subset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57b1f6cc8cc279f163417a323401f131a6084186 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/current_flow_betweenness_subset.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/current_flow_closeness.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/current_flow_closeness.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc66a3dc5776c8f7099432a134fe47021722d792 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/current_flow_closeness.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/degree_alg.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/degree_alg.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cd42775410b0d6b48074ede3a693572b497c30f Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/degree_alg.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/dispersion.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/dispersion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec48f111d20643a1e7d261f4482083e0c6b85cf6 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/dispersion.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/eigenvector.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/eigenvector.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..759a11b7c6303ad6305abe67dbf6c8f579bd1f40 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/eigenvector.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/flow_matrix.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/flow_matrix.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47ab3f0895b606f6a12ff153f8ac163f43c9e031 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/flow_matrix.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/group.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/group.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae50fa7fb20084c6a0224d24af082d58c4261649 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/group.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/harmonic.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/harmonic.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3248096c9678314e79d8e08af5b6159de04e305a Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/harmonic.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/katz.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/katz.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e828d01ce872a561ea0973acdf38fcba2997ac17 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/katz.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/laplacian.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/laplacian.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8a603a0f9319406567e5953f12b7b1e8f313c05 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/laplacian.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/load.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/load.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1808f00c01a2c33f26ff122bb28f4e9246413ff Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/load.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/percolation.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/percolation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..468edff1aded0e0c8625f0e91362c18960c179d0 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/percolation.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/reaching.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/reaching.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3dca1701ff28810adb9caa54a4fd5a3e4e9477a0 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/reaching.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/second_order.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/second_order.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..795fda39386b24265e571beaf2b43ed22aa8629b Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/second_order.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/subgraph_alg.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/subgraph_alg.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ed39e01ab673e6e80d914b3d5a3dc8d2372ed43 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/subgraph_alg.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/trophic.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/trophic.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c51c9ccdd7ac438193b07c1a332b35618e505c0 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/trophic.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/voterank_alg.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/voterank_alg.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54d4cea735f711b02817ef6d9c571f5ebe4c25f3 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/__pycache__/voterank_alg.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/betweenness.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/betweenness.py new file mode 100644 index 0000000000000000000000000000000000000000..42e09771d45eeda64ac5e52a663f8586cc14bc73 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/betweenness.py @@ -0,0 +1,436 @@ +"""Betweenness centrality measures.""" + +from collections import deque +from heapq import heappop, heappush +from itertools import count + +import networkx as nx +from networkx.algorithms.shortest_paths.weighted import _weight_function +from networkx.utils import py_random_state +from networkx.utils.decorators import not_implemented_for + +__all__ = ["betweenness_centrality", "edge_betweenness_centrality"] + + +@py_random_state(5) +@nx._dispatchable(edge_attrs="weight") +def betweenness_centrality( + G, k=None, normalized=True, weight=None, endpoints=False, seed=None +): + r"""Compute the shortest-path betweenness centrality for nodes. + + Betweenness centrality of a node $v$ is the sum of the + fraction of all-pairs shortest paths that pass through $v$ + + .. math:: + + c_B(v) =\sum_{s,t \in V} \frac{\sigma(s, t|v)}{\sigma(s, t)} + + where $V$ is the set of nodes, $\sigma(s, t)$ is the number of + shortest $(s, t)$-paths, and $\sigma(s, t|v)$ is the number of + those paths passing through some node $v$ other than $s, t$. + If $s = t$, $\sigma(s, t) = 1$, and if $v \in {s, t}$, + $\sigma(s, t|v) = 0$ [2]_. + + Parameters + ---------- + G : graph + A NetworkX graph. + + k : int, optional (default=None) + If k is not None use k node samples to estimate betweenness. + The value of k <= n where n is the number of nodes in the graph. + Higher values give better approximation. + + normalized : bool, optional + If True the betweenness values are normalized by `2/((n-1)(n-2))` + for graphs, and `1/((n-1)(n-2))` for directed graphs where `n` + is the number of nodes in G. + + weight : None or string, optional (default=None) + If None, all edge weights are considered equal. + Otherwise holds the name of the edge attribute used as weight. + Weights are used to calculate weighted shortest paths, so they are + interpreted as distances. + + endpoints : bool, optional + If True include the endpoints in the shortest path counts. + + seed : integer, random_state, or None (default) + Indicator of random number generation state. + See :ref:`Randomness`. + Note that this is only used if k is not None. + + Returns + ------- + nodes : dictionary + Dictionary of nodes with betweenness centrality as the value. + + See Also + -------- + edge_betweenness_centrality + load_centrality + + Notes + ----- + The algorithm is from Ulrik Brandes [1]_. + See [4]_ for the original first published version and [2]_ for details on + algorithms for variations and related metrics. + + For approximate betweenness calculations set k=#samples to use + k nodes ("pivots") to estimate the betweenness values. For an estimate + of the number of pivots needed see [3]_. + + For weighted graphs the edge weights must be greater than zero. + Zero edge weights can produce an infinite number of equal length + paths between pairs of nodes. + + The total number of paths between source and target is counted + differently for directed and undirected graphs. Directed paths + are easy to count. Undirected paths are tricky: should a path + from "u" to "v" count as 1 undirected path or as 2 directed paths? + + For betweenness_centrality we report the number of undirected + paths when G is undirected. + + For betweenness_centrality_subset the reporting is different. + If the source and target subsets are the same, then we want + to count undirected paths. But if the source and target subsets + differ -- for example, if sources is {0} and targets is {1}, + then we are only counting the paths in one direction. They are + undirected paths but we are counting them in a directed way. + To count them as undirected paths, each should count as half a path. + + This algorithm is not guaranteed to be correct if edge weights + are floating point numbers. As a workaround you can use integer + numbers by multiplying the relevant edge attributes by a convenient + constant factor (eg 100) and converting to integers. + + References + ---------- + .. [1] Ulrik Brandes: + A Faster Algorithm for Betweenness Centrality. + Journal of Mathematical Sociology 25(2):163-177, 2001. + https://doi.org/10.1080/0022250X.2001.9990249 + .. [2] Ulrik Brandes: + On Variants of Shortest-Path Betweenness + Centrality and their Generic Computation. + Social Networks 30(2):136-145, 2008. + https://doi.org/10.1016/j.socnet.2007.11.001 + .. [3] Ulrik Brandes and Christian Pich: + Centrality Estimation in Large Networks. + International Journal of Bifurcation and Chaos 17(7):2303-2318, 2007. + https://dx.doi.org/10.1142/S0218127407018403 + .. [4] Linton C. Freeman: + A set of measures of centrality based on betweenness. + Sociometry 40: 35–41, 1977 + https://doi.org/10.2307/3033543 + """ + betweenness = dict.fromkeys(G, 0.0) # b[v]=0 for v in G + if k is None: + nodes = G + else: + nodes = seed.sample(list(G.nodes()), k) + for s in nodes: + # single source shortest paths + if weight is None: # use BFS + S, P, sigma, _ = _single_source_shortest_path_basic(G, s) + else: # use Dijkstra's algorithm + S, P, sigma, _ = _single_source_dijkstra_path_basic(G, s, weight) + # accumulation + if endpoints: + betweenness, _ = _accumulate_endpoints(betweenness, S, P, sigma, s) + else: + betweenness, _ = _accumulate_basic(betweenness, S, P, sigma, s) + # rescaling + betweenness = _rescale( + betweenness, + len(G), + normalized=normalized, + directed=G.is_directed(), + k=k, + endpoints=endpoints, + ) + return betweenness + + +@py_random_state(4) +@nx._dispatchable(edge_attrs="weight") +def edge_betweenness_centrality(G, k=None, normalized=True, weight=None, seed=None): + r"""Compute betweenness centrality for edges. + + Betweenness centrality of an edge $e$ is the sum of the + fraction of all-pairs shortest paths that pass through $e$ + + .. math:: + + c_B(e) =\sum_{s,t \in V} \frac{\sigma(s, t|e)}{\sigma(s, t)} + + where $V$ is the set of nodes, $\sigma(s, t)$ is the number of + shortest $(s, t)$-paths, and $\sigma(s, t|e)$ is the number of + those paths passing through edge $e$ [2]_. + + Parameters + ---------- + G : graph + A NetworkX graph. + + k : int, optional (default=None) + If k is not None use k node samples to estimate betweenness. + The value of k <= n where n is the number of nodes in the graph. + Higher values give better approximation. + + normalized : bool, optional + If True the betweenness values are normalized by $2/(n(n-1))$ + for graphs, and $1/(n(n-1))$ for directed graphs where $n$ + is the number of nodes in G. + + weight : None or string, optional (default=None) + If None, all edge weights are considered equal. + Otherwise holds the name of the edge attribute used as weight. + Weights are used to calculate weighted shortest paths, so they are + interpreted as distances. + + seed : integer, random_state, or None (default) + Indicator of random number generation state. + See :ref:`Randomness`. + Note that this is only used if k is not None. + + Returns + ------- + edges : dictionary + Dictionary of edges with betweenness centrality as the value. + + See Also + -------- + betweenness_centrality + edge_load + + Notes + ----- + The algorithm is from Ulrik Brandes [1]_. + + For weighted graphs the edge weights must be greater than zero. + Zero edge weights can produce an infinite number of equal length + paths between pairs of nodes. + + References + ---------- + .. [1] A Faster Algorithm for Betweenness Centrality. Ulrik Brandes, + Journal of Mathematical Sociology 25(2):163-177, 2001. + https://doi.org/10.1080/0022250X.2001.9990249 + .. [2] Ulrik Brandes: On Variants of Shortest-Path Betweenness + Centrality and their Generic Computation. + Social Networks 30(2):136-145, 2008. + https://doi.org/10.1016/j.socnet.2007.11.001 + """ + betweenness = dict.fromkeys(G, 0.0) # b[v]=0 for v in G + # b[e]=0 for e in G.edges() + betweenness.update(dict.fromkeys(G.edges(), 0.0)) + if k is None: + nodes = G + else: + nodes = seed.sample(list(G.nodes()), k) + for s in nodes: + # single source shortest paths + if weight is None: # use BFS + S, P, sigma, _ = _single_source_shortest_path_basic(G, s) + else: # use Dijkstra's algorithm + S, P, sigma, _ = _single_source_dijkstra_path_basic(G, s, weight) + # accumulation + betweenness = _accumulate_edges(betweenness, S, P, sigma, s) + # rescaling + for n in G: # remove nodes to only return edges + del betweenness[n] + betweenness = _rescale_e( + betweenness, len(G), normalized=normalized, directed=G.is_directed() + ) + if G.is_multigraph(): + betweenness = _add_edge_keys(G, betweenness, weight=weight) + return betweenness + + +# helpers for betweenness centrality + + +def _single_source_shortest_path_basic(G, s): + S = [] + P = {} + for v in G: + P[v] = [] + sigma = dict.fromkeys(G, 0.0) # sigma[v]=0 for v in G + D = {} + sigma[s] = 1.0 + D[s] = 0 + Q = deque([s]) + while Q: # use BFS to find shortest paths + v = Q.popleft() + S.append(v) + Dv = D[v] + sigmav = sigma[v] + for w in G[v]: + if w not in D: + Q.append(w) + D[w] = Dv + 1 + if D[w] == Dv + 1: # this is a shortest path, count paths + sigma[w] += sigmav + P[w].append(v) # predecessors + return S, P, sigma, D + + +def _single_source_dijkstra_path_basic(G, s, weight): + weight = _weight_function(G, weight) + # modified from Eppstein + S = [] + P = {} + for v in G: + P[v] = [] + sigma = dict.fromkeys(G, 0.0) # sigma[v]=0 for v in G + D = {} + sigma[s] = 1.0 + push = heappush + pop = heappop + seen = {s: 0} + c = count() + Q = [] # use Q as heap with (distance,node id) tuples + push(Q, (0, next(c), s, s)) + while Q: + (dist, _, pred, v) = pop(Q) + if v in D: + continue # already searched this node. + sigma[v] += sigma[pred] # count paths + S.append(v) + D[v] = dist + for w, edgedata in G[v].items(): + vw_dist = dist + weight(v, w, edgedata) + if w not in D and (w not in seen or vw_dist < seen[w]): + seen[w] = vw_dist + push(Q, (vw_dist, next(c), v, w)) + sigma[w] = 0.0 + P[w] = [v] + elif vw_dist == seen[w]: # handle equal paths + sigma[w] += sigma[v] + P[w].append(v) + return S, P, sigma, D + + +def _accumulate_basic(betweenness, S, P, sigma, s): + delta = dict.fromkeys(S, 0) + while S: + w = S.pop() + coeff = (1 + delta[w]) / sigma[w] + for v in P[w]: + delta[v] += sigma[v] * coeff + if w != s: + betweenness[w] += delta[w] + return betweenness, delta + + +def _accumulate_endpoints(betweenness, S, P, sigma, s): + betweenness[s] += len(S) - 1 + delta = dict.fromkeys(S, 0) + while S: + w = S.pop() + coeff = (1 + delta[w]) / sigma[w] + for v in P[w]: + delta[v] += sigma[v] * coeff + if w != s: + betweenness[w] += delta[w] + 1 + return betweenness, delta + + +def _accumulate_edges(betweenness, S, P, sigma, s): + delta = dict.fromkeys(S, 0) + while S: + w = S.pop() + coeff = (1 + delta[w]) / sigma[w] + for v in P[w]: + c = sigma[v] * coeff + if (v, w) not in betweenness: + betweenness[(w, v)] += c + else: + betweenness[(v, w)] += c + delta[v] += c + if w != s: + betweenness[w] += delta[w] + return betweenness + + +def _rescale(betweenness, n, normalized, directed=False, k=None, endpoints=False): + if normalized: + if endpoints: + if n < 2: + scale = None # no normalization + else: + # Scale factor should include endpoint nodes + scale = 1 / (n * (n - 1)) + elif n <= 2: + scale = None # no normalization b=0 for all nodes + else: + scale = 1 / ((n - 1) * (n - 2)) + else: # rescale by 2 for undirected graphs + if not directed: + scale = 0.5 + else: + scale = None + if scale is not None: + if k is not None: + scale = scale * n / k + for v in betweenness: + betweenness[v] *= scale + return betweenness + + +def _rescale_e(betweenness, n, normalized, directed=False, k=None): + if normalized: + if n <= 1: + scale = None # no normalization b=0 for all nodes + else: + scale = 1 / (n * (n - 1)) + else: # rescale by 2 for undirected graphs + if not directed: + scale = 0.5 + else: + scale = None + if scale is not None: + if k is not None: + scale = scale * n / k + for v in betweenness: + betweenness[v] *= scale + return betweenness + + +@not_implemented_for("graph") +def _add_edge_keys(G, betweenness, weight=None): + r"""Adds the corrected betweenness centrality (BC) values for multigraphs. + + Parameters + ---------- + G : NetworkX graph. + + betweenness : dictionary + Dictionary mapping adjacent node tuples to betweenness centrality values. + + weight : string or function + See `_weight_function` for details. Defaults to `None`. + + Returns + ------- + edges : dictionary + The parameter `betweenness` including edges with keys and their + betweenness centrality values. + + The BC value is divided among edges of equal weight. + """ + _weight = _weight_function(G, weight) + + edge_bc = dict.fromkeys(G.edges, 0.0) + for u, v in betweenness: + d = G[u][v] + wt = _weight(u, v, d) + keys = [k for k in d if _weight(u, v, {k: d[k]}) == wt] + bc = betweenness[(u, v)] / len(keys) + for k in keys: + edge_bc[(u, v, k)] = bc + + return edge_bc diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/betweenness_subset.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/betweenness_subset.py new file mode 100644 index 0000000000000000000000000000000000000000..b9e99365ff33301693b79c7afe54c4591561b5db --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/betweenness_subset.py @@ -0,0 +1,275 @@ +"""Betweenness centrality measures for subsets of nodes.""" + +import networkx as nx +from networkx.algorithms.centrality.betweenness import ( + _add_edge_keys, +) +from networkx.algorithms.centrality.betweenness import ( + _single_source_dijkstra_path_basic as dijkstra, +) +from networkx.algorithms.centrality.betweenness import ( + _single_source_shortest_path_basic as shortest_path, +) + +__all__ = [ + "betweenness_centrality_subset", + "edge_betweenness_centrality_subset", +] + + +@nx._dispatchable(edge_attrs="weight") +def betweenness_centrality_subset(G, sources, targets, normalized=False, weight=None): + r"""Compute betweenness centrality for a subset of nodes. + + .. math:: + + c_B(v) =\sum_{s\in S, t \in T} \frac{\sigma(s, t|v)}{\sigma(s, t)} + + where $S$ is the set of sources, $T$ is the set of targets, + $\sigma(s, t)$ is the number of shortest $(s, t)$-paths, + and $\sigma(s, t|v)$ is the number of those paths + passing through some node $v$ other than $s, t$. + If $s = t$, $\sigma(s, t) = 1$, + and if $v \in {s, t}$, $\sigma(s, t|v) = 0$ [2]_. + + + Parameters + ---------- + G : graph + A NetworkX graph. + + sources: list of nodes + Nodes to use as sources for shortest paths in betweenness + + targets: list of nodes + Nodes to use as targets for shortest paths in betweenness + + normalized : bool, optional + If True the betweenness values are normalized by $2/((n-1)(n-2))$ + for graphs, and $1/((n-1)(n-2))$ for directed graphs where $n$ + is the number of nodes in G. + + weight : None or string, optional (default=None) + If None, all edge weights are considered equal. + Otherwise holds the name of the edge attribute used as weight. + Weights are used to calculate weighted shortest paths, so they are + interpreted as distances. + + Returns + ------- + nodes : dictionary + Dictionary of nodes with betweenness centrality as the value. + + See Also + -------- + edge_betweenness_centrality + load_centrality + + Notes + ----- + The basic algorithm is from [1]_. + + For weighted graphs the edge weights must be greater than zero. + Zero edge weights can produce an infinite number of equal length + paths between pairs of nodes. + + The normalization might seem a little strange but it is + designed to make betweenness_centrality(G) be the same as + betweenness_centrality_subset(G,sources=G.nodes(),targets=G.nodes()). + + The total number of paths between source and target is counted + differently for directed and undirected graphs. Directed paths + are easy to count. Undirected paths are tricky: should a path + from "u" to "v" count as 1 undirected path or as 2 directed paths? + + For betweenness_centrality we report the number of undirected + paths when G is undirected. + + For betweenness_centrality_subset the reporting is different. + If the source and target subsets are the same, then we want + to count undirected paths. But if the source and target subsets + differ -- for example, if sources is {0} and targets is {1}, + then we are only counting the paths in one direction. They are + undirected paths but we are counting them in a directed way. + To count them as undirected paths, each should count as half a path. + + References + ---------- + .. [1] Ulrik Brandes, A Faster Algorithm for Betweenness Centrality. + Journal of Mathematical Sociology 25(2):163-177, 2001. + https://doi.org/10.1080/0022250X.2001.9990249 + .. [2] Ulrik Brandes: On Variants of Shortest-Path Betweenness + Centrality and their Generic Computation. + Social Networks 30(2):136-145, 2008. + https://doi.org/10.1016/j.socnet.2007.11.001 + """ + b = dict.fromkeys(G, 0.0) # b[v]=0 for v in G + for s in sources: + # single source shortest paths + if weight is None: # use BFS + S, P, sigma, _ = shortest_path(G, s) + else: # use Dijkstra's algorithm + S, P, sigma, _ = dijkstra(G, s, weight) + b = _accumulate_subset(b, S, P, sigma, s, targets) + b = _rescale(b, len(G), normalized=normalized, directed=G.is_directed()) + return b + + +@nx._dispatchable(edge_attrs="weight") +def edge_betweenness_centrality_subset( + G, sources, targets, normalized=False, weight=None +): + r"""Compute betweenness centrality for edges for a subset of nodes. + + .. math:: + + c_B(v) =\sum_{s\in S,t \in T} \frac{\sigma(s, t|e)}{\sigma(s, t)} + + where $S$ is the set of sources, $T$ is the set of targets, + $\sigma(s, t)$ is the number of shortest $(s, t)$-paths, + and $\sigma(s, t|e)$ is the number of those paths + passing through edge $e$ [2]_. + + Parameters + ---------- + G : graph + A networkx graph. + + sources: list of nodes + Nodes to use as sources for shortest paths in betweenness + + targets: list of nodes + Nodes to use as targets for shortest paths in betweenness + + normalized : bool, optional + If True the betweenness values are normalized by `2/(n(n-1))` + for graphs, and `1/(n(n-1))` for directed graphs where `n` + is the number of nodes in G. + + weight : None or string, optional (default=None) + If None, all edge weights are considered equal. + Otherwise holds the name of the edge attribute used as weight. + Weights are used to calculate weighted shortest paths, so they are + interpreted as distances. + + Returns + ------- + edges : dictionary + Dictionary of edges with Betweenness centrality as the value. + + See Also + -------- + betweenness_centrality + edge_load + + Notes + ----- + The basic algorithm is from [1]_. + + For weighted graphs the edge weights must be greater than zero. + Zero edge weights can produce an infinite number of equal length + paths between pairs of nodes. + + The normalization might seem a little strange but it is the same + as in edge_betweenness_centrality() and is designed to make + edge_betweenness_centrality(G) be the same as + edge_betweenness_centrality_subset(G,sources=G.nodes(),targets=G.nodes()). + + References + ---------- + .. [1] Ulrik Brandes, A Faster Algorithm for Betweenness Centrality. + Journal of Mathematical Sociology 25(2):163-177, 2001. + https://doi.org/10.1080/0022250X.2001.9990249 + .. [2] Ulrik Brandes: On Variants of Shortest-Path Betweenness + Centrality and their Generic Computation. + Social Networks 30(2):136-145, 2008. + https://doi.org/10.1016/j.socnet.2007.11.001 + """ + b = dict.fromkeys(G, 0.0) # b[v]=0 for v in G + b.update(dict.fromkeys(G.edges(), 0.0)) # b[e] for e in G.edges() + for s in sources: + # single source shortest paths + if weight is None: # use BFS + S, P, sigma, _ = shortest_path(G, s) + else: # use Dijkstra's algorithm + S, P, sigma, _ = dijkstra(G, s, weight) + b = _accumulate_edges_subset(b, S, P, sigma, s, targets) + for n in G: # remove nodes to only return edges + del b[n] + b = _rescale_e(b, len(G), normalized=normalized, directed=G.is_directed()) + if G.is_multigraph(): + b = _add_edge_keys(G, b, weight=weight) + return b + + +def _accumulate_subset(betweenness, S, P, sigma, s, targets): + delta = dict.fromkeys(S, 0.0) + target_set = set(targets) - {s} + while S: + w = S.pop() + if w in target_set: + coeff = (delta[w] + 1.0) / sigma[w] + else: + coeff = delta[w] / sigma[w] + for v in P[w]: + delta[v] += sigma[v] * coeff + if w != s: + betweenness[w] += delta[w] + return betweenness + + +def _accumulate_edges_subset(betweenness, S, P, sigma, s, targets): + """edge_betweenness_centrality_subset helper.""" + delta = dict.fromkeys(S, 0) + target_set = set(targets) + while S: + w = S.pop() + for v in P[w]: + if w in target_set: + c = (sigma[v] / sigma[w]) * (1.0 + delta[w]) + else: + c = delta[w] / len(P[w]) + if (v, w) not in betweenness: + betweenness[(w, v)] += c + else: + betweenness[(v, w)] += c + delta[v] += c + if w != s: + betweenness[w] += delta[w] + return betweenness + + +def _rescale(betweenness, n, normalized, directed=False): + """betweenness_centrality_subset helper.""" + if normalized: + if n <= 2: + scale = None # no normalization b=0 for all nodes + else: + scale = 1.0 / ((n - 1) * (n - 2)) + else: # rescale by 2 for undirected graphs + if not directed: + scale = 0.5 + else: + scale = None + if scale is not None: + for v in betweenness: + betweenness[v] *= scale + return betweenness + + +def _rescale_e(betweenness, n, normalized, directed=False): + """edge_betweenness_centrality_subset helper.""" + if normalized: + if n <= 1: + scale = None # no normalization b=0 for all nodes + else: + scale = 1.0 / (n * (n - 1)) + else: # rescale by 2 for undirected graphs + if not directed: + scale = 0.5 + else: + scale = None + if scale is not None: + for v in betweenness: + betweenness[v] *= scale + return betweenness diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/closeness.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/closeness.py new file mode 100644 index 0000000000000000000000000000000000000000..1cc2f9599f2a3af8fa653b8faef58a9a8f2d2355 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/closeness.py @@ -0,0 +1,282 @@ +""" +Closeness centrality measures. +""" + +import functools + +import networkx as nx +from networkx.exception import NetworkXError +from networkx.utils.decorators import not_implemented_for + +__all__ = ["closeness_centrality", "incremental_closeness_centrality"] + + +@nx._dispatchable(edge_attrs="distance") +def closeness_centrality(G, u=None, distance=None, wf_improved=True): + r"""Compute closeness centrality for nodes. + + Closeness centrality [1]_ of a node `u` is the reciprocal of the + average shortest path distance to `u` over all `n-1` reachable nodes. + + .. math:: + + C(u) = \frac{n - 1}{\sum_{v=1}^{n-1} d(v, u)}, + + where `d(v, u)` is the shortest-path distance between `v` and `u`, + and `n-1` is the number of nodes reachable from `u`. Notice that the + closeness distance function computes the incoming distance to `u` + for directed graphs. To use outward distance, act on `G.reverse()`. + + Notice that higher values of closeness indicate higher centrality. + + Wasserman and Faust propose an improved formula for graphs with + more than one connected component. The result is "a ratio of the + fraction of actors in the group who are reachable, to the average + distance" from the reachable actors [2]_. You might think this + scale factor is inverted but it is not. As is, nodes from small + components receive a smaller closeness value. Letting `N` denote + the number of nodes in the graph, + + .. math:: + + C_{WF}(u) = \frac{n-1}{N-1} \frac{n - 1}{\sum_{v=1}^{n-1} d(v, u)}, + + Parameters + ---------- + G : graph + A NetworkX graph + + u : node, optional + Return only the value for node u + + distance : edge attribute key, optional (default=None) + Use the specified edge attribute as the edge distance in shortest + path calculations. If `None` (the default) all edges have a distance of 1. + Absent edge attributes are assigned a distance of 1. Note that no check + is performed to ensure that edges have the provided attribute. + + wf_improved : bool, optional (default=True) + If True, scale by the fraction of nodes reachable. This gives the + Wasserman and Faust improved formula. For single component graphs + it is the same as the original formula. + + Returns + ------- + nodes : dictionary + Dictionary of nodes with closeness centrality as the value. + + Examples + -------- + >>> G = nx.Graph([(0, 1), (0, 2), (0, 3), (1, 2), (1, 3)]) + >>> nx.closeness_centrality(G) + {0: 1.0, 1: 1.0, 2: 0.75, 3: 0.75} + + See Also + -------- + betweenness_centrality, load_centrality, eigenvector_centrality, + degree_centrality, incremental_closeness_centrality + + Notes + ----- + The closeness centrality is normalized to `(n-1)/(|G|-1)` where + `n` is the number of nodes in the connected part of graph + containing the node. If the graph is not completely connected, + this algorithm computes the closeness centrality for each + connected part separately scaled by that parts size. + + If the 'distance' keyword is set to an edge attribute key then the + shortest-path length will be computed using Dijkstra's algorithm with + that edge attribute as the edge weight. + + The closeness centrality uses *inward* distance to a node, not outward. + If you want to use outword distances apply the function to `G.reverse()` + + In NetworkX 2.2 and earlier a bug caused Dijkstra's algorithm to use the + outward distance rather than the inward distance. If you use a 'distance' + keyword and a DiGraph, your results will change between v2.2 and v2.3. + + References + ---------- + .. [1] Linton C. Freeman: Centrality in networks: I. + Conceptual clarification. Social Networks 1:215-239, 1979. + https://doi.org/10.1016/0378-8733(78)90021-7 + .. [2] pg. 201 of Wasserman, S. and Faust, K., + Social Network Analysis: Methods and Applications, 1994, + Cambridge University Press. + """ + if G.is_directed(): + G = G.reverse() # create a reversed graph view + + if distance is not None: + # use Dijkstra's algorithm with specified attribute as edge weight + path_length = functools.partial( + nx.single_source_dijkstra_path_length, weight=distance + ) + else: + path_length = nx.single_source_shortest_path_length + + if u is None: + nodes = G.nodes + else: + nodes = [u] + closeness_dict = {} + for n in nodes: + sp = path_length(G, n) + totsp = sum(sp.values()) + len_G = len(G) + _closeness_centrality = 0.0 + if totsp > 0.0 and len_G > 1: + _closeness_centrality = (len(sp) - 1.0) / totsp + # normalize to number of nodes-1 in connected part + if wf_improved: + s = (len(sp) - 1.0) / (len_G - 1) + _closeness_centrality *= s + closeness_dict[n] = _closeness_centrality + if u is not None: + return closeness_dict[u] + return closeness_dict + + +@not_implemented_for("directed") +@nx._dispatchable(mutates_input=True) +def incremental_closeness_centrality( + G, edge, prev_cc=None, insertion=True, wf_improved=True +): + r"""Incremental closeness centrality for nodes. + + Compute closeness centrality for nodes using level-based work filtering + as described in Incremental Algorithms for Closeness Centrality by Sariyuce et al. + + Level-based work filtering detects unnecessary updates to the closeness + centrality and filters them out. + + --- + From "Incremental Algorithms for Closeness Centrality": + + Theorem 1: Let :math:`G = (V, E)` be a graph and u and v be two vertices in V + such that there is no edge (u, v) in E. Let :math:`G' = (V, E \cup uv)` + Then :math:`cc[s] = cc'[s]` if and only if :math:`\left|dG(s, u) - dG(s, v)\right| \leq 1`. + + Where :math:`dG(u, v)` denotes the length of the shortest path between + two vertices u, v in a graph G, cc[s] is the closeness centrality for a + vertex s in V, and cc'[s] is the closeness centrality for a + vertex s in V, with the (u, v) edge added. + --- + + We use Theorem 1 to filter out updates when adding or removing an edge. + When adding an edge (u, v), we compute the shortest path lengths from all + other nodes to u and to v before the node is added. When removing an edge, + we compute the shortest path lengths after the edge is removed. Then we + apply Theorem 1 to use previously computed closeness centrality for nodes + where :math:`\left|dG(s, u) - dG(s, v)\right| \leq 1`. This works only for + undirected, unweighted graphs; the distance argument is not supported. + + Closeness centrality [1]_ of a node `u` is the reciprocal of the + sum of the shortest path distances from `u` to all `n-1` other nodes. + Since the sum of distances depends on the number of nodes in the + graph, closeness is normalized by the sum of minimum possible + distances `n-1`. + + .. math:: + + C(u) = \frac{n - 1}{\sum_{v=1}^{n-1} d(v, u)}, + + where `d(v, u)` is the shortest-path distance between `v` and `u`, + and `n` is the number of nodes in the graph. + + Notice that higher values of closeness indicate higher centrality. + + Parameters + ---------- + G : graph + A NetworkX graph + + edge : tuple + The modified edge (u, v) in the graph. + + prev_cc : dictionary + The previous closeness centrality for all nodes in the graph. + + insertion : bool, optional + If True (default) the edge was inserted, otherwise it was deleted from the graph. + + wf_improved : bool, optional (default=True) + If True, scale by the fraction of nodes reachable. This gives the + Wasserman and Faust improved formula. For single component graphs + it is the same as the original formula. + + Returns + ------- + nodes : dictionary + Dictionary of nodes with closeness centrality as the value. + + See Also + -------- + betweenness_centrality, load_centrality, eigenvector_centrality, + degree_centrality, closeness_centrality + + Notes + ----- + The closeness centrality is normalized to `(n-1)/(|G|-1)` where + `n` is the number of nodes in the connected part of graph + containing the node. If the graph is not completely connected, + this algorithm computes the closeness centrality for each + connected part separately. + + References + ---------- + .. [1] Freeman, L.C., 1979. Centrality in networks: I. + Conceptual clarification. Social Networks 1, 215--239. + https://doi.org/10.1016/0378-8733(78)90021-7 + .. [2] Sariyuce, A.E. ; Kaya, K. ; Saule, E. ; Catalyiirek, U.V. Incremental + Algorithms for Closeness Centrality. 2013 IEEE International Conference on Big Data + http://sariyuce.com/papers/bigdata13.pdf + """ + if prev_cc is not None and set(prev_cc.keys()) != set(G.nodes()): + raise NetworkXError("prev_cc and G do not have the same nodes") + + # Unpack edge + (u, v) = edge + path_length = nx.single_source_shortest_path_length + + if insertion: + # For edge insertion, we want shortest paths before the edge is inserted + du = path_length(G, u) + dv = path_length(G, v) + + G.add_edge(u, v) + else: + G.remove_edge(u, v) + + # For edge removal, we want shortest paths after the edge is removed + du = path_length(G, u) + dv = path_length(G, v) + + if prev_cc is None: + return nx.closeness_centrality(G) + + nodes = G.nodes() + closeness_dict = {} + for n in nodes: + if n in du and n in dv and abs(du[n] - dv[n]) <= 1: + closeness_dict[n] = prev_cc[n] + else: + sp = path_length(G, n) + totsp = sum(sp.values()) + len_G = len(G) + _closeness_centrality = 0.0 + if totsp > 0.0 and len_G > 1: + _closeness_centrality = (len(sp) - 1.0) / totsp + # normalize to number of nodes-1 in connected part + if wf_improved: + s = (len(sp) - 1.0) / (len_G - 1) + _closeness_centrality *= s + closeness_dict[n] = _closeness_centrality + + # Leave the graph as we found it + if insertion: + G.remove_edge(u, v) + else: + G.add_edge(u, v) + + return closeness_dict diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/current_flow_betweenness.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/current_flow_betweenness.py new file mode 100644 index 0000000000000000000000000000000000000000..bfde279afc8729780be597c494f3e8fa4a281ab7 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/current_flow_betweenness.py @@ -0,0 +1,342 @@ +"""Current-flow betweenness centrality measures.""" + +import networkx as nx +from networkx.algorithms.centrality.flow_matrix import ( + CGInverseLaplacian, + FullInverseLaplacian, + SuperLUInverseLaplacian, + flow_matrix_row, +) +from networkx.utils import ( + not_implemented_for, + py_random_state, + reverse_cuthill_mckee_ordering, +) + +__all__ = [ + "current_flow_betweenness_centrality", + "approximate_current_flow_betweenness_centrality", + "edge_current_flow_betweenness_centrality", +] + + +@not_implemented_for("directed") +@py_random_state(7) +@nx._dispatchable(edge_attrs="weight") +def approximate_current_flow_betweenness_centrality( + G, + normalized=True, + weight=None, + dtype=float, + solver="full", + epsilon=0.5, + kmax=10000, + seed=None, +): + r"""Compute the approximate current-flow betweenness centrality for nodes. + + Approximates the current-flow betweenness centrality within absolute + error of epsilon with high probability [1]_. + + + Parameters + ---------- + G : graph + A NetworkX graph + + normalized : bool, optional (default=True) + If True the betweenness values are normalized by 2/[(n-1)(n-2)] where + n is the number of nodes in G. + + weight : string or None, optional (default=None) + Key for edge data used as the edge weight. + If None, then use 1 as each edge weight. + The weight reflects the capacity or the strength of the + edge. + + dtype : data type (float) + Default data type for internal matrices. + Set to np.float32 for lower memory consumption. + + solver : string (default='full') + Type of linear solver to use for computing the flow matrix. + Options are "full" (uses most memory), "lu" (recommended), and + "cg" (uses least memory). + + epsilon: float + Absolute error tolerance. + + kmax: int + Maximum number of sample node pairs to use for approximation. + + seed : integer, random_state, or None (default) + Indicator of random number generation state. + See :ref:`Randomness`. + + Returns + ------- + nodes : dictionary + Dictionary of nodes with betweenness centrality as the value. + + See Also + -------- + current_flow_betweenness_centrality + + Notes + ----- + The running time is $O((1/\epsilon^2)m{\sqrt k} \log n)$ + and the space required is $O(m)$ for $n$ nodes and $m$ edges. + + If the edges have a 'weight' attribute they will be used as + weights in this algorithm. Unspecified weights are set to 1. + + References + ---------- + .. [1] Ulrik Brandes and Daniel Fleischer: + Centrality Measures Based on Current Flow. + Proc. 22nd Symp. Theoretical Aspects of Computer Science (STACS '05). + LNCS 3404, pp. 533-544. Springer-Verlag, 2005. + https://doi.org/10.1007/978-3-540-31856-9_44 + """ + import numpy as np + + if not nx.is_connected(G): + raise nx.NetworkXError("Graph not connected.") + solvername = { + "full": FullInverseLaplacian, + "lu": SuperLUInverseLaplacian, + "cg": CGInverseLaplacian, + } + n = G.number_of_nodes() + ordering = list(reverse_cuthill_mckee_ordering(G)) + # make a copy with integer labels according to rcm ordering + # this could be done without a copy if we really wanted to + H = nx.relabel_nodes(G, dict(zip(ordering, range(n)))) + L = nx.laplacian_matrix(H, nodelist=range(n), weight=weight).asformat("csc") + L = L.astype(dtype) + C = solvername[solver](L, dtype=dtype) # initialize solver + betweenness = dict.fromkeys(H, 0.0) + nb = (n - 1.0) * (n - 2.0) # normalization factor + cstar = n * (n - 1) / nb + l = 1 # parameter in approximation, adjustable + k = l * int(np.ceil((cstar / epsilon) ** 2 * np.log(n))) + if k > kmax: + msg = f"Number random pairs k>kmax ({k}>{kmax}) " + raise nx.NetworkXError(msg, "Increase kmax or epsilon") + cstar2k = cstar / (2 * k) + for _ in range(k): + s, t = pair = seed.sample(range(n), 2) + b = np.zeros(n, dtype=dtype) + b[s] = 1 + b[t] = -1 + p = C.solve(b) + for v in H: + if v in pair: + continue + for nbr in H[v]: + w = H[v][nbr].get(weight, 1.0) + betweenness[v] += float(w * np.abs(p[v] - p[nbr]) * cstar2k) + if normalized: + factor = 1.0 + else: + factor = nb / 2.0 + # remap to original node names and "unnormalize" if required + return {ordering[k]: v * factor for k, v in betweenness.items()} + + +@not_implemented_for("directed") +@nx._dispatchable(edge_attrs="weight") +def current_flow_betweenness_centrality( + G, normalized=True, weight=None, dtype=float, solver="full" +): + r"""Compute current-flow betweenness centrality for nodes. + + Current-flow betweenness centrality uses an electrical current + model for information spreading in contrast to betweenness + centrality which uses shortest paths. + + Current-flow betweenness centrality is also known as + random-walk betweenness centrality [2]_. + + Parameters + ---------- + G : graph + A NetworkX graph + + normalized : bool, optional (default=True) + If True the betweenness values are normalized by 2/[(n-1)(n-2)] where + n is the number of nodes in G. + + weight : string or None, optional (default=None) + Key for edge data used as the edge weight. + If None, then use 1 as each edge weight. + The weight reflects the capacity or the strength of the + edge. + + dtype : data type (float) + Default data type for internal matrices. + Set to np.float32 for lower memory consumption. + + solver : string (default='full') + Type of linear solver to use for computing the flow matrix. + Options are "full" (uses most memory), "lu" (recommended), and + "cg" (uses least memory). + + Returns + ------- + nodes : dictionary + Dictionary of nodes with betweenness centrality as the value. + + See Also + -------- + approximate_current_flow_betweenness_centrality + betweenness_centrality + edge_betweenness_centrality + edge_current_flow_betweenness_centrality + + Notes + ----- + Current-flow betweenness can be computed in $O(I(n-1)+mn \log n)$ + time [1]_, where $I(n-1)$ is the time needed to compute the + inverse Laplacian. For a full matrix this is $O(n^3)$ but using + sparse methods you can achieve $O(nm{\sqrt k})$ where $k$ is the + Laplacian matrix condition number. + + The space required is $O(nw)$ where $w$ is the width of the sparse + Laplacian matrix. Worse case is $w=n$ for $O(n^2)$. + + If the edges have a 'weight' attribute they will be used as + weights in this algorithm. Unspecified weights are set to 1. + + References + ---------- + .. [1] Centrality Measures Based on Current Flow. + Ulrik Brandes and Daniel Fleischer, + Proc. 22nd Symp. Theoretical Aspects of Computer Science (STACS '05). + LNCS 3404, pp. 533-544. Springer-Verlag, 2005. + https://doi.org/10.1007/978-3-540-31856-9_44 + + .. [2] A measure of betweenness centrality based on random walks, + M. E. J. Newman, Social Networks 27, 39-54 (2005). + """ + if not nx.is_connected(G): + raise nx.NetworkXError("Graph not connected.") + N = G.number_of_nodes() + ordering = list(reverse_cuthill_mckee_ordering(G)) + # make a copy with integer labels according to rcm ordering + # this could be done without a copy if we really wanted to + H = nx.relabel_nodes(G, dict(zip(ordering, range(N)))) + betweenness = dict.fromkeys(H, 0.0) # b[n]=0 for n in H + for row, (s, t) in flow_matrix_row(H, weight=weight, dtype=dtype, solver=solver): + pos = dict(zip(row.argsort()[::-1], range(N))) + for i in range(N): + betweenness[s] += (i - pos[i]) * row.item(i) + betweenness[t] += (N - i - 1 - pos[i]) * row.item(i) + if normalized: + nb = (N - 1.0) * (N - 2.0) # normalization factor + else: + nb = 2.0 + return {ordering[n]: (b - n) * 2.0 / nb for n, b in betweenness.items()} + + +@not_implemented_for("directed") +@nx._dispatchable(edge_attrs="weight") +def edge_current_flow_betweenness_centrality( + G, normalized=True, weight=None, dtype=float, solver="full" +): + r"""Compute current-flow betweenness centrality for edges. + + Current-flow betweenness centrality uses an electrical current + model for information spreading in contrast to betweenness + centrality which uses shortest paths. + + Current-flow betweenness centrality is also known as + random-walk betweenness centrality [2]_. + + Parameters + ---------- + G : graph + A NetworkX graph + + normalized : bool, optional (default=True) + If True the betweenness values are normalized by 2/[(n-1)(n-2)] where + n is the number of nodes in G. + + weight : string or None, optional (default=None) + Key for edge data used as the edge weight. + If None, then use 1 as each edge weight. + The weight reflects the capacity or the strength of the + edge. + + dtype : data type (default=float) + Default data type for internal matrices. + Set to np.float32 for lower memory consumption. + + solver : string (default='full') + Type of linear solver to use for computing the flow matrix. + Options are "full" (uses most memory), "lu" (recommended), and + "cg" (uses least memory). + + Returns + ------- + nodes : dictionary + Dictionary of edge tuples with betweenness centrality as the value. + + Raises + ------ + NetworkXError + The algorithm does not support DiGraphs. + If the input graph is an instance of DiGraph class, NetworkXError + is raised. + + See Also + -------- + betweenness_centrality + edge_betweenness_centrality + current_flow_betweenness_centrality + + Notes + ----- + Current-flow betweenness can be computed in $O(I(n-1)+mn \log n)$ + time [1]_, where $I(n-1)$ is the time needed to compute the + inverse Laplacian. For a full matrix this is $O(n^3)$ but using + sparse methods you can achieve $O(nm{\sqrt k})$ where $k$ is the + Laplacian matrix condition number. + + The space required is $O(nw)$ where $w$ is the width of the sparse + Laplacian matrix. Worse case is $w=n$ for $O(n^2)$. + + If the edges have a 'weight' attribute they will be used as + weights in this algorithm. Unspecified weights are set to 1. + + References + ---------- + .. [1] Centrality Measures Based on Current Flow. + Ulrik Brandes and Daniel Fleischer, + Proc. 22nd Symp. Theoretical Aspects of Computer Science (STACS '05). + LNCS 3404, pp. 533-544. Springer-Verlag, 2005. + https://doi.org/10.1007/978-3-540-31856-9_44 + + .. [2] A measure of betweenness centrality based on random walks, + M. E. J. Newman, Social Networks 27, 39-54 (2005). + """ + if not nx.is_connected(G): + raise nx.NetworkXError("Graph not connected.") + N = G.number_of_nodes() + ordering = list(reverse_cuthill_mckee_ordering(G)) + # make a copy with integer labels according to rcm ordering + # this could be done without a copy if we really wanted to + H = nx.relabel_nodes(G, dict(zip(ordering, range(N)))) + edges = (tuple(sorted((u, v))) for u, v in H.edges()) + betweenness = dict.fromkeys(edges, 0.0) + if normalized: + nb = (N - 1.0) * (N - 2.0) # normalization factor + else: + nb = 2.0 + for row, (e) in flow_matrix_row(H, weight=weight, dtype=dtype, solver=solver): + pos = dict(zip(row.argsort()[::-1], range(1, N + 1))) + for i in range(N): + betweenness[e] += (i + 1 - pos[i]) * row.item(i) + betweenness[e] += (N - i - pos[i]) * row.item(i) + betweenness[e] /= nb + return {(ordering[s], ordering[t]): b for (s, t), b in betweenness.items()} diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/current_flow_betweenness_subset.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/current_flow_betweenness_subset.py new file mode 100644 index 0000000000000000000000000000000000000000..911718c80bd50589abe645e44e862add4fc8dbcd --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/current_flow_betweenness_subset.py @@ -0,0 +1,227 @@ +"""Current-flow betweenness centrality measures for subsets of nodes.""" + +import networkx as nx +from networkx.algorithms.centrality.flow_matrix import flow_matrix_row +from networkx.utils import not_implemented_for, reverse_cuthill_mckee_ordering + +__all__ = [ + "current_flow_betweenness_centrality_subset", + "edge_current_flow_betweenness_centrality_subset", +] + + +@not_implemented_for("directed") +@nx._dispatchable(edge_attrs="weight") +def current_flow_betweenness_centrality_subset( + G, sources, targets, normalized=True, weight=None, dtype=float, solver="lu" +): + r"""Compute current-flow betweenness centrality for subsets of nodes. + + Current-flow betweenness centrality uses an electrical current + model for information spreading in contrast to betweenness + centrality which uses shortest paths. + + Current-flow betweenness centrality is also known as + random-walk betweenness centrality [2]_. + + Parameters + ---------- + G : graph + A NetworkX graph + + sources: list of nodes + Nodes to use as sources for current + + targets: list of nodes + Nodes to use as sinks for current + + normalized : bool, optional (default=True) + If True the betweenness values are normalized by b=b/(n-1)(n-2) where + n is the number of nodes in G. + + weight : string or None, optional (default=None) + Key for edge data used as the edge weight. + If None, then use 1 as each edge weight. + The weight reflects the capacity or the strength of the + edge. + + dtype: data type (float) + Default data type for internal matrices. + Set to np.float32 for lower memory consumption. + + solver: string (default='lu') + Type of linear solver to use for computing the flow matrix. + Options are "full" (uses most memory), "lu" (recommended), and + "cg" (uses least memory). + + Returns + ------- + nodes : dictionary + Dictionary of nodes with betweenness centrality as the value. + + See Also + -------- + approximate_current_flow_betweenness_centrality + betweenness_centrality + edge_betweenness_centrality + edge_current_flow_betweenness_centrality + + Notes + ----- + Current-flow betweenness can be computed in $O(I(n-1)+mn \log n)$ + time [1]_, where $I(n-1)$ is the time needed to compute the + inverse Laplacian. For a full matrix this is $O(n^3)$ but using + sparse methods you can achieve $O(nm{\sqrt k})$ where $k$ is the + Laplacian matrix condition number. + + The space required is $O(nw)$ where $w$ is the width of the sparse + Laplacian matrix. Worse case is $w=n$ for $O(n^2)$. + + If the edges have a 'weight' attribute they will be used as + weights in this algorithm. Unspecified weights are set to 1. + + References + ---------- + .. [1] Centrality Measures Based on Current Flow. + Ulrik Brandes and Daniel Fleischer, + Proc. 22nd Symp. Theoretical Aspects of Computer Science (STACS '05). + LNCS 3404, pp. 533-544. Springer-Verlag, 2005. + https://doi.org/10.1007/978-3-540-31856-9_44 + + .. [2] A measure of betweenness centrality based on random walks, + M. E. J. Newman, Social Networks 27, 39-54 (2005). + """ + import numpy as np + + from networkx.utils import reverse_cuthill_mckee_ordering + + if not nx.is_connected(G): + raise nx.NetworkXError("Graph not connected.") + N = G.number_of_nodes() + ordering = list(reverse_cuthill_mckee_ordering(G)) + # make a copy with integer labels according to rcm ordering + # this could be done without a copy if we really wanted to + mapping = dict(zip(ordering, range(N))) + H = nx.relabel_nodes(G, mapping) + betweenness = dict.fromkeys(H, 0.0) # b[n]=0 for n in H + for row, (s, t) in flow_matrix_row(H, weight=weight, dtype=dtype, solver=solver): + for ss in sources: + i = mapping[ss] + for tt in targets: + j = mapping[tt] + betweenness[s] += 0.5 * abs(row.item(i) - row.item(j)) + betweenness[t] += 0.5 * abs(row.item(i) - row.item(j)) + if normalized: + nb = (N - 1.0) * (N - 2.0) # normalization factor + else: + nb = 2.0 + for node in H: + betweenness[node] = betweenness[node] / nb + 1.0 / (2 - N) + return {ordering[node]: value for node, value in betweenness.items()} + + +@not_implemented_for("directed") +@nx._dispatchable(edge_attrs="weight") +def edge_current_flow_betweenness_centrality_subset( + G, sources, targets, normalized=True, weight=None, dtype=float, solver="lu" +): + r"""Compute current-flow betweenness centrality for edges using subsets + of nodes. + + Current-flow betweenness centrality uses an electrical current + model for information spreading in contrast to betweenness + centrality which uses shortest paths. + + Current-flow betweenness centrality is also known as + random-walk betweenness centrality [2]_. + + Parameters + ---------- + G : graph + A NetworkX graph + + sources: list of nodes + Nodes to use as sources for current + + targets: list of nodes + Nodes to use as sinks for current + + normalized : bool, optional (default=True) + If True the betweenness values are normalized by b=b/(n-1)(n-2) where + n is the number of nodes in G. + + weight : string or None, optional (default=None) + Key for edge data used as the edge weight. + If None, then use 1 as each edge weight. + The weight reflects the capacity or the strength of the + edge. + + dtype: data type (float) + Default data type for internal matrices. + Set to np.float32 for lower memory consumption. + + solver: string (default='lu') + Type of linear solver to use for computing the flow matrix. + Options are "full" (uses most memory), "lu" (recommended), and + "cg" (uses least memory). + + Returns + ------- + nodes : dict + Dictionary of edge tuples with betweenness centrality as the value. + + See Also + -------- + betweenness_centrality + edge_betweenness_centrality + current_flow_betweenness_centrality + + Notes + ----- + Current-flow betweenness can be computed in $O(I(n-1)+mn \log n)$ + time [1]_, where $I(n-1)$ is the time needed to compute the + inverse Laplacian. For a full matrix this is $O(n^3)$ but using + sparse methods you can achieve $O(nm{\sqrt k})$ where $k$ is the + Laplacian matrix condition number. + + The space required is $O(nw)$ where $w$ is the width of the sparse + Laplacian matrix. Worse case is $w=n$ for $O(n^2)$. + + If the edges have a 'weight' attribute they will be used as + weights in this algorithm. Unspecified weights are set to 1. + + References + ---------- + .. [1] Centrality Measures Based on Current Flow. + Ulrik Brandes and Daniel Fleischer, + Proc. 22nd Symp. Theoretical Aspects of Computer Science (STACS '05). + LNCS 3404, pp. 533-544. Springer-Verlag, 2005. + https://doi.org/10.1007/978-3-540-31856-9_44 + + .. [2] A measure of betweenness centrality based on random walks, + M. E. J. Newman, Social Networks 27, 39-54 (2005). + """ + import numpy as np + + if not nx.is_connected(G): + raise nx.NetworkXError("Graph not connected.") + N = G.number_of_nodes() + ordering = list(reverse_cuthill_mckee_ordering(G)) + # make a copy with integer labels according to rcm ordering + # this could be done without a copy if we really wanted to + mapping = dict(zip(ordering, range(N))) + H = nx.relabel_nodes(G, mapping) + edges = (tuple(sorted((u, v))) for u, v in H.edges()) + betweenness = dict.fromkeys(edges, 0.0) + if normalized: + nb = (N - 1.0) * (N - 2.0) # normalization factor + else: + nb = 2.0 + for row, (e) in flow_matrix_row(H, weight=weight, dtype=dtype, solver=solver): + for ss in sources: + i = mapping[ss] + for tt in targets: + j = mapping[tt] + betweenness[e] += 0.5 * abs(row.item(i) - row.item(j)) + betweenness[e] /= nb + return {(ordering[s], ordering[t]): value for (s, t), value in betweenness.items()} diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/current_flow_closeness.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/current_flow_closeness.py new file mode 100644 index 0000000000000000000000000000000000000000..67f86397bdcd61b344256b2b4c08f2c21986e05a --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/current_flow_closeness.py @@ -0,0 +1,96 @@ +"""Current-flow closeness centrality measures.""" + +import networkx as nx +from networkx.algorithms.centrality.flow_matrix import ( + CGInverseLaplacian, + FullInverseLaplacian, + SuperLUInverseLaplacian, +) +from networkx.utils import not_implemented_for, reverse_cuthill_mckee_ordering + +__all__ = ["current_flow_closeness_centrality", "information_centrality"] + + +@not_implemented_for("directed") +@nx._dispatchable(edge_attrs="weight") +def current_flow_closeness_centrality(G, weight=None, dtype=float, solver="lu"): + """Compute current-flow closeness centrality for nodes. + + Current-flow closeness centrality is variant of closeness + centrality based on effective resistance between nodes in + a network. This metric is also known as information centrality. + + Parameters + ---------- + G : graph + A NetworkX graph. + + weight : None or string, optional (default=None) + If None, all edge weights are considered equal. + Otherwise holds the name of the edge attribute used as weight. + The weight reflects the capacity or the strength of the + edge. + + dtype: data type (default=float) + Default data type for internal matrices. + Set to np.float32 for lower memory consumption. + + solver: string (default='lu') + Type of linear solver to use for computing the flow matrix. + Options are "full" (uses most memory), "lu" (recommended), and + "cg" (uses least memory). + + Returns + ------- + nodes : dictionary + Dictionary of nodes with current flow closeness centrality as the value. + + See Also + -------- + closeness_centrality + + Notes + ----- + The algorithm is from Brandes [1]_. + + See also [2]_ for the original definition of information centrality. + + References + ---------- + .. [1] Ulrik Brandes and Daniel Fleischer, + Centrality Measures Based on Current Flow. + Proc. 22nd Symp. Theoretical Aspects of Computer Science (STACS '05). + LNCS 3404, pp. 533-544. Springer-Verlag, 2005. + https://doi.org/10.1007/978-3-540-31856-9_44 + + .. [2] Karen Stephenson and Marvin Zelen: + Rethinking centrality: Methods and examples. + Social Networks 11(1):1-37, 1989. + https://doi.org/10.1016/0378-8733(89)90016-6 + """ + if not nx.is_connected(G): + raise nx.NetworkXError("Graph not connected.") + solvername = { + "full": FullInverseLaplacian, + "lu": SuperLUInverseLaplacian, + "cg": CGInverseLaplacian, + } + N = G.number_of_nodes() + ordering = list(reverse_cuthill_mckee_ordering(G)) + # make a copy with integer labels according to rcm ordering + # this could be done without a copy if we really wanted to + H = nx.relabel_nodes(G, dict(zip(ordering, range(N)))) + betweenness = dict.fromkeys(H, 0.0) # b[n]=0 for n in H + N = H.number_of_nodes() + L = nx.laplacian_matrix(H, nodelist=range(N), weight=weight).asformat("csc") + L = L.astype(dtype) + C2 = solvername[solver](L, width=1, dtype=dtype) # initialize solver + for v in H: + col = C2.get_row(v) + for w in H: + betweenness[v] += col.item(v) - 2 * col.item(w) + betweenness[w] += col.item(v) + return {ordering[node]: 1 / value for node, value in betweenness.items()} + + +information_centrality = current_flow_closeness_centrality diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/degree_alg.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/degree_alg.py new file mode 100644 index 0000000000000000000000000000000000000000..b3c1e321be3f9f3febce5a9104bde09924847001 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/degree_alg.py @@ -0,0 +1,150 @@ +"""Degree centrality measures.""" + +import networkx as nx +from networkx.utils.decorators import not_implemented_for + +__all__ = ["degree_centrality", "in_degree_centrality", "out_degree_centrality"] + + +@nx._dispatchable +def degree_centrality(G): + """Compute the degree centrality for nodes. + + The degree centrality for a node v is the fraction of nodes it + is connected to. + + Parameters + ---------- + G : graph + A networkx graph + + Returns + ------- + nodes : dictionary + Dictionary of nodes with degree centrality as the value. + + Examples + -------- + >>> G = nx.Graph([(0, 1), (0, 2), (0, 3), (1, 2), (1, 3)]) + >>> nx.degree_centrality(G) + {0: 1.0, 1: 1.0, 2: 0.6666666666666666, 3: 0.6666666666666666} + + See Also + -------- + betweenness_centrality, load_centrality, eigenvector_centrality + + Notes + ----- + The degree centrality values are normalized by dividing by the maximum + possible degree in a simple graph n-1 where n is the number of nodes in G. + + For multigraphs or graphs with self loops the maximum degree might + be higher than n-1 and values of degree centrality greater than 1 + are possible. + """ + if len(G) <= 1: + return {n: 1 for n in G} + + s = 1.0 / (len(G) - 1.0) + centrality = {n: d * s for n, d in G.degree()} + return centrality + + +@not_implemented_for("undirected") +@nx._dispatchable +def in_degree_centrality(G): + """Compute the in-degree centrality for nodes. + + The in-degree centrality for a node v is the fraction of nodes its + incoming edges are connected to. + + Parameters + ---------- + G : graph + A NetworkX graph + + Returns + ------- + nodes : dictionary + Dictionary of nodes with in-degree centrality as values. + + Raises + ------ + NetworkXNotImplemented + If G is undirected. + + Examples + -------- + >>> G = nx.DiGraph([(0, 1), (0, 2), (0, 3), (1, 2), (1, 3)]) + >>> nx.in_degree_centrality(G) + {0: 0.0, 1: 0.3333333333333333, 2: 0.6666666666666666, 3: 0.6666666666666666} + + See Also + -------- + degree_centrality, out_degree_centrality + + Notes + ----- + The degree centrality values are normalized by dividing by the maximum + possible degree in a simple graph n-1 where n is the number of nodes in G. + + For multigraphs or graphs with self loops the maximum degree might + be higher than n-1 and values of degree centrality greater than 1 + are possible. + """ + if len(G) <= 1: + return {n: 1 for n in G} + + s = 1.0 / (len(G) - 1.0) + centrality = {n: d * s for n, d in G.in_degree()} + return centrality + + +@not_implemented_for("undirected") +@nx._dispatchable +def out_degree_centrality(G): + """Compute the out-degree centrality for nodes. + + The out-degree centrality for a node v is the fraction of nodes its + outgoing edges are connected to. + + Parameters + ---------- + G : graph + A NetworkX graph + + Returns + ------- + nodes : dictionary + Dictionary of nodes with out-degree centrality as values. + + Raises + ------ + NetworkXNotImplemented + If G is undirected. + + Examples + -------- + >>> G = nx.DiGraph([(0, 1), (0, 2), (0, 3), (1, 2), (1, 3)]) + >>> nx.out_degree_centrality(G) + {0: 1.0, 1: 0.6666666666666666, 2: 0.0, 3: 0.0} + + See Also + -------- + degree_centrality, in_degree_centrality + + Notes + ----- + The degree centrality values are normalized by dividing by the maximum + possible degree in a simple graph n-1 where n is the number of nodes in G. + + For multigraphs or graphs with self loops the maximum degree might + be higher than n-1 and values of degree centrality greater than 1 + are possible. + """ + if len(G) <= 1: + return {n: 1 for n in G} + + s = 1.0 / (len(G) - 1.0) + centrality = {n: d * s for n, d in G.out_degree()} + return centrality diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/dispersion.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/dispersion.py new file mode 100644 index 0000000000000000000000000000000000000000..a3fa68583a9d18a40e6fbd4c8267e25f7a13c60a --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/dispersion.py @@ -0,0 +1,107 @@ +from itertools import combinations + +import networkx as nx + +__all__ = ["dispersion"] + + +@nx._dispatchable +def dispersion(G, u=None, v=None, normalized=True, alpha=1.0, b=0.0, c=0.0): + r"""Calculate dispersion between `u` and `v` in `G`. + + A link between two actors (`u` and `v`) has a high dispersion when their + mutual ties (`s` and `t`) are not well connected with each other. + + Parameters + ---------- + G : graph + A NetworkX graph. + u : node, optional + The source for the dispersion score (e.g. ego node of the network). + v : node, optional + The target of the dispersion score if specified. + normalized : bool + If True (default) normalize by the embeddedness of the nodes (u and v). + alpha, b, c : float + Parameters for the normalization procedure. When `normalized` is True, + the dispersion value is normalized by:: + + result = ((dispersion + b) ** alpha) / (embeddedness + c) + + as long as the denominator is nonzero. + + Returns + ------- + nodes : dictionary + If u (v) is specified, returns a dictionary of nodes with dispersion + score for all "target" ("source") nodes. If neither u nor v is + specified, returns a dictionary of dictionaries for all nodes 'u' in the + graph with a dispersion score for each node 'v'. + + Notes + ----- + This implementation follows Lars Backstrom and Jon Kleinberg [1]_. Typical + usage would be to run dispersion on the ego network $G_u$ if $u$ were + specified. Running :func:`dispersion` with neither $u$ nor $v$ specified + can take some time to complete. + + References + ---------- + .. [1] Romantic Partnerships and the Dispersion of Social Ties: + A Network Analysis of Relationship Status on Facebook. + Lars Backstrom, Jon Kleinberg. + https://arxiv.org/pdf/1310.6753v1.pdf + + """ + + def _dispersion(G_u, u, v): + """dispersion for all nodes 'v' in a ego network G_u of node 'u'""" + u_nbrs = set(G_u[u]) + ST = {n for n in G_u[v] if n in u_nbrs} + set_uv = {u, v} + # all possible ties of connections that u and b share + possib = combinations(ST, 2) + total = 0 + for s, t in possib: + # neighbors of s that are in G_u, not including u and v + nbrs_s = u_nbrs.intersection(G_u[s]) - set_uv + # s and t are not directly connected + if t not in nbrs_s: + # s and t do not share a connection + if nbrs_s.isdisjoint(G_u[t]): + # tick for disp(u, v) + total += 1 + # neighbors that u and v share + embeddedness = len(ST) + + dispersion_val = total + if normalized: + dispersion_val = (total + b) ** alpha + if embeddedness + c != 0: + dispersion_val /= embeddedness + c + + return dispersion_val + + if u is None: + # v and u are not specified + if v is None: + results = {n: {} for n in G} + for u in G: + for v in G[u]: + results[u][v] = _dispersion(G, u, v) + # u is not specified, but v is + else: + results = dict.fromkeys(G[v], {}) + for u in G[v]: + results[u] = _dispersion(G, v, u) + else: + # u is specified with no target v + if v is None: + results = dict.fromkeys(G[u], {}) + for v in G[u]: + results[v] = _dispersion(G, u, v) + # both u and v are specified + else: + results = _dispersion(G, u, v) + + return results diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/eigenvector.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/eigenvector.py new file mode 100644 index 0000000000000000000000000000000000000000..b8cf63e8dc3562df2d570c3590501a51654367ff --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/eigenvector.py @@ -0,0 +1,357 @@ +"""Functions for computing eigenvector centrality.""" + +import math + +import networkx as nx +from networkx.utils import not_implemented_for + +__all__ = ["eigenvector_centrality", "eigenvector_centrality_numpy"] + + +@not_implemented_for("multigraph") +@nx._dispatchable(edge_attrs="weight") +def eigenvector_centrality(G, max_iter=100, tol=1.0e-6, nstart=None, weight=None): + r"""Compute the eigenvector centrality for the graph G. + + Eigenvector centrality computes the centrality for a node by adding + the centrality of its predecessors. The centrality for node $i$ is the + $i$-th element of a left eigenvector associated with the eigenvalue $\lambda$ + of maximum modulus that is positive. Such an eigenvector $x$ is + defined up to a multiplicative constant by the equation + + .. math:: + + \lambda x^T = x^T A, + + where $A$ is the adjacency matrix of the graph G. By definition of + row-column product, the equation above is equivalent to + + .. math:: + + \lambda x_i = \sum_{j\to i}x_j. + + That is, adding the eigenvector centralities of the predecessors of + $i$ one obtains the eigenvector centrality of $i$ multiplied by + $\lambda$. In the case of undirected graphs, $x$ also solves the familiar + right-eigenvector equation $Ax = \lambda x$. + + By virtue of the Perron–Frobenius theorem [1]_, if G is strongly + connected there is a unique eigenvector $x$, and all its entries + are strictly positive. + + If G is not strongly connected there might be several left + eigenvectors associated with $\lambda$, and some of their elements + might be zero. + + Parameters + ---------- + G : graph + A networkx graph. + + max_iter : integer, optional (default=100) + Maximum number of power iterations. + + tol : float, optional (default=1.0e-6) + Error tolerance (in Euclidean norm) used to check convergence in + power iteration. + + nstart : dictionary, optional (default=None) + Starting value of power iteration for each node. Must have a nonzero + projection on the desired eigenvector for the power method to converge. + If None, this implementation uses an all-ones vector, which is a safe + choice. + + weight : None or string, optional (default=None) + If None, all edge weights are considered equal. Otherwise holds the + name of the edge attribute used as weight. In this measure the + weight is interpreted as the connection strength. + + Returns + ------- + nodes : dictionary + Dictionary of nodes with eigenvector centrality as the value. The + associated vector has unit Euclidean norm and the values are + nonegative. + + Examples + -------- + >>> G = nx.path_graph(4) + >>> centrality = nx.eigenvector_centrality(G) + >>> sorted((v, f"{c:0.2f}") for v, c in centrality.items()) + [(0, '0.37'), (1, '0.60'), (2, '0.60'), (3, '0.37')] + + Raises + ------ + NetworkXPointlessConcept + If the graph G is the null graph. + + NetworkXError + If each value in `nstart` is zero. + + PowerIterationFailedConvergence + If the algorithm fails to converge to the specified tolerance + within the specified number of iterations of the power iteration + method. + + See Also + -------- + eigenvector_centrality_numpy + :func:`~networkx.algorithms.link_analysis.pagerank_alg.pagerank` + :func:`~networkx.algorithms.link_analysis.hits_alg.hits` + + Notes + ----- + Eigenvector centrality was introduced by Landau [2]_ for chess + tournaments. It was later rediscovered by Wei [3]_ and then + popularized by Kendall [4]_ in the context of sport ranking. Berge + introduced a general definition for graphs based on social connections + [5]_. Bonacich [6]_ reintroduced again eigenvector centrality and made + it popular in link analysis. + + This function computes the left dominant eigenvector, which corresponds + to adding the centrality of predecessors: this is the usual approach. + To add the centrality of successors first reverse the graph with + ``G.reverse()``. + + The implementation uses power iteration [7]_ to compute a dominant + eigenvector starting from the provided vector `nstart`. Convergence is + guaranteed as long as `nstart` has a nonzero projection on a dominant + eigenvector, which certainly happens using the default value. + + The method stops when the change in the computed vector between two + iterations is smaller than an error tolerance of ``G.number_of_nodes() + * tol`` or after ``max_iter`` iterations, but in the second case it + raises an exception. + + This implementation uses $(A + I)$ rather than the adjacency matrix + $A$ because the change preserves eigenvectors, but it shifts the + spectrum, thus guaranteeing convergence even for networks with + negative eigenvalues of maximum modulus. + + References + ---------- + .. [1] Abraham Berman and Robert J. Plemmons. + "Nonnegative Matrices in the Mathematical Sciences." + Classics in Applied Mathematics. SIAM, 1994. + + .. [2] Edmund Landau. + "Zur relativen Wertbemessung der Turnierresultate." + Deutsches Wochenschach, 11:366–369, 1895. + + .. [3] Teh-Hsing Wei. + "The Algebraic Foundations of Ranking Theory." + PhD thesis, University of Cambridge, 1952. + + .. [4] Maurice G. Kendall. + "Further contributions to the theory of paired comparisons." + Biometrics, 11(1):43–62, 1955. + https://www.jstor.org/stable/3001479 + + .. [5] Claude Berge + "Théorie des graphes et ses applications." + Dunod, Paris, France, 1958. + + .. [6] Phillip Bonacich. + "Technique for analyzing overlapping memberships." + Sociological Methodology, 4:176–185, 1972. + https://www.jstor.org/stable/270732 + + .. [7] Power iteration:: https://en.wikipedia.org/wiki/Power_iteration + + """ + if len(G) == 0: + raise nx.NetworkXPointlessConcept( + "cannot compute centrality for the null graph" + ) + # If no initial vector is provided, start with the all-ones vector. + if nstart is None: + nstart = {v: 1 for v in G} + if all(v == 0 for v in nstart.values()): + raise nx.NetworkXError("initial vector cannot have all zero values") + # Normalize the initial vector so that each entry is in [0, 1]. This is + # guaranteed to never have a divide-by-zero error by the previous line. + nstart_sum = sum(nstart.values()) + x = {k: v / nstart_sum for k, v in nstart.items()} + nnodes = G.number_of_nodes() + # make up to max_iter iterations + for _ in range(max_iter): + xlast = x + x = xlast.copy() # Start with xlast times I to iterate with (A+I) + # do the multiplication y^T = x^T A (left eigenvector) + for n in x: + for nbr in G[n]: + w = G[n][nbr].get(weight, 1) if weight else 1 + x[nbr] += xlast[n] * w + # Normalize the vector. The normalization denominator `norm` + # should never be zero by the Perron--Frobenius + # theorem. However, in case it is due to numerical error, we + # assume the norm to be one instead. + norm = math.hypot(*x.values()) or 1 + x = {k: v / norm for k, v in x.items()} + # Check for convergence (in the L_1 norm). + if sum(abs(x[n] - xlast[n]) for n in x) < nnodes * tol: + return x + raise nx.PowerIterationFailedConvergence(max_iter) + + +@nx._dispatchable(edge_attrs="weight") +def eigenvector_centrality_numpy(G, weight=None, max_iter=50, tol=0): + r"""Compute the eigenvector centrality for the graph `G`. + + Eigenvector centrality computes the centrality for a node by adding + the centrality of its predecessors. The centrality for node $i$ is the + $i$-th element of a left eigenvector associated with the eigenvalue $\lambda$ + of maximum modulus that is positive. Such an eigenvector $x$ is + defined up to a multiplicative constant by the equation + + .. math:: + + \lambda x^T = x^T A, + + where $A$ is the adjacency matrix of the graph `G`. By definition of + row-column product, the equation above is equivalent to + + .. math:: + + \lambda x_i = \sum_{j\to i}x_j. + + That is, adding the eigenvector centralities of the predecessors of + $i$ one obtains the eigenvector centrality of $i$ multiplied by + $\lambda$. In the case of undirected graphs, $x$ also solves the familiar + right-eigenvector equation $Ax = \lambda x$. + + By virtue of the Perron--Frobenius theorem [1]_, if `G` is (strongly) + connected, there is a unique eigenvector $x$, and all its entries + are strictly positive. + + However, if `G` is not (strongly) connected, there might be several left + eigenvectors associated with $\lambda$, and some of their elements + might be zero. + Depending on the method used to choose eigenvectors, round-off error can affect + which of the infinitely many eigenvectors is reported. + This can lead to inconsistent results for the same graph, + which the underlying implementation is not robust to. + For this reason, only (strongly) connected graphs are accepted. + + Parameters + ---------- + G : graph + A connected NetworkX graph. + + weight : None or string, optional (default=None) + If ``None``, all edge weights are considered equal. Otherwise holds the + name of the edge attribute used as weight. In this measure the + weight is interpreted as the connection strength. + + max_iter : integer, optional (default=50) + Maximum number of Arnoldi update iterations allowed. + + tol : float, optional (default=0) + Relative accuracy for eigenvalues (stopping criterion). + The default value of 0 implies machine precision. + + Returns + ------- + nodes : dict of nodes + Dictionary of nodes with eigenvector centrality as the value. The + associated vector has unit Euclidean norm and the values are + nonnegative. + + Examples + -------- + >>> G = nx.path_graph(4) + >>> centrality = nx.eigenvector_centrality_numpy(G) + >>> print([f"{node} {centrality[node]:0.2f}" for node in centrality]) + ['0 0.37', '1 0.60', '2 0.60', '3 0.37'] + + Raises + ------ + NetworkXPointlessConcept + If the graph `G` is the null graph. + + ArpackNoConvergence + When the requested convergence is not obtained. The currently + converged eigenvalues and eigenvectors can be found as + eigenvalues and eigenvectors attributes of the exception object. + + AmbiguousSolution + If `G` is not connected. + + See Also + -------- + :func:`scipy.sparse.linalg.eigs` + eigenvector_centrality + :func:`~networkx.algorithms.link_analysis.pagerank_alg.pagerank` + :func:`~networkx.algorithms.link_analysis.hits_alg.hits` + + Notes + ----- + Eigenvector centrality was introduced by Landau [2]_ for chess + tournaments. It was later rediscovered by Wei [3]_ and then + popularized by Kendall [4]_ in the context of sport ranking. Berge + introduced a general definition for graphs based on social connections + [5]_. Bonacich [6]_ reintroduced again eigenvector centrality and made + it popular in link analysis. + + This function computes the left dominant eigenvector, which corresponds + to adding the centrality of predecessors: this is the usual approach. + To add the centrality of successors first reverse the graph with + ``G.reverse()``. + + This implementation uses the + :func:`SciPy sparse eigenvalue solver` (ARPACK) + to find the largest eigenvalue/eigenvector pair using Arnoldi iterations + [7]_. + + References + ---------- + .. [1] Abraham Berman and Robert J. Plemmons. + "Nonnegative Matrices in the Mathematical Sciences". + Classics in Applied Mathematics. SIAM, 1994. + + .. [2] Edmund Landau. + "Zur relativen Wertbemessung der Turnierresultate". + Deutsches Wochenschach, 11:366--369, 1895. + + .. [3] Teh-Hsing Wei. + "The Algebraic Foundations of Ranking Theory". + PhD thesis, University of Cambridge, 1952. + + .. [4] Maurice G. Kendall. + "Further contributions to the theory of paired comparisons". + Biometrics, 11(1):43--62, 1955. + https://www.jstor.org/stable/3001479 + + .. [5] Claude Berge. + "Théorie des graphes et ses applications". + Dunod, Paris, France, 1958. + + .. [6] Phillip Bonacich. + "Technique for analyzing overlapping memberships". + Sociological Methodology, 4:176--185, 1972. + https://www.jstor.org/stable/270732 + + .. [7] Arnoldi, W. E. (1951). + "The principle of minimized iterations in the solution of the matrix eigenvalue problem". + Quarterly of Applied Mathematics. 9 (1): 17--29. + https://doi.org/10.1090/qam/42792 + """ + import numpy as np + import scipy as sp + + if len(G) == 0: + raise nx.NetworkXPointlessConcept( + "cannot compute centrality for the null graph" + ) + connected = nx.is_strongly_connected(G) if G.is_directed() else nx.is_connected(G) + if not connected: # See gh-6888. + raise nx.AmbiguousSolution( + "`eigenvector_centrality_numpy` does not give consistent results for disconnected graphs" + ) + M = nx.to_scipy_sparse_array(G, nodelist=list(G), weight=weight, dtype=float) + _, eigenvector = sp.sparse.linalg.eigs( + M.T, k=1, which="LR", maxiter=max_iter, tol=tol + ) + largest = eigenvector.flatten().real + norm = np.sign(largest.sum()) * sp.linalg.norm(largest) + return dict(zip(G, (largest / norm).tolist())) diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/flow_matrix.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/flow_matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..e72b5e976c003c9e870f0c17e0fea25bb6e0596a --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/flow_matrix.py @@ -0,0 +1,130 @@ +# Helpers for current-flow betweenness and current-flow closeness +# Lazy computations for inverse Laplacian and flow-matrix rows. +import networkx as nx + + +@nx._dispatchable(edge_attrs="weight") +def flow_matrix_row(G, weight=None, dtype=float, solver="lu"): + # Generate a row of the current-flow matrix + import numpy as np + + solvername = { + "full": FullInverseLaplacian, + "lu": SuperLUInverseLaplacian, + "cg": CGInverseLaplacian, + } + n = G.number_of_nodes() + L = nx.laplacian_matrix(G, nodelist=range(n), weight=weight).asformat("csc") + L = L.astype(dtype) + C = solvername[solver](L, dtype=dtype) # initialize solver + w = C.w # w is the Laplacian matrix width + # row-by-row flow matrix + for u, v in sorted(sorted((u, v)) for u, v in G.edges()): + B = np.zeros(w, dtype=dtype) + c = G[u][v].get(weight, 1.0) + B[u % w] = c + B[v % w] = -c + # get only the rows needed in the inverse laplacian + # and multiply to get the flow matrix row + row = B @ C.get_rows(u, v) + yield row, (u, v) + + +# Class to compute the inverse laplacian only for specified rows +# Allows computation of the current-flow matrix without storing entire +# inverse laplacian matrix +class InverseLaplacian: + def __init__(self, L, width=None, dtype=None): + global np + import numpy as np + + (n, n) = L.shape + self.dtype = dtype + self.n = n + if width is None: + self.w = self.width(L) + else: + self.w = width + self.C = np.zeros((self.w, n), dtype=dtype) + self.L1 = L[1:, 1:] + self.init_solver(L) + + def init_solver(self, L): + pass + + def solve(self, r): + raise nx.NetworkXError("Implement solver") + + def solve_inverse(self, r): + raise nx.NetworkXError("Implement solver") + + def get_rows(self, r1, r2): + for r in range(r1, r2 + 1): + self.C[r % self.w, 1:] = self.solve_inverse(r) + return self.C + + def get_row(self, r): + self.C[r % self.w, 1:] = self.solve_inverse(r) + return self.C[r % self.w] + + def width(self, L): + m = 0 + for i, row in enumerate(L): + w = 0 + y = np.nonzero(row)[-1] + if len(y) > 0: + v = y - i + w = v.max() - v.min() + 1 + m = max(w, m) + return m + + +class FullInverseLaplacian(InverseLaplacian): + def init_solver(self, L): + self.IL = np.zeros(L.shape, dtype=self.dtype) + self.IL[1:, 1:] = np.linalg.inv(self.L1.todense()) + + def solve(self, rhs): + s = np.zeros(rhs.shape, dtype=self.dtype) + s = self.IL @ rhs + return s + + def solve_inverse(self, r): + return self.IL[r, 1:] + + +class SuperLUInverseLaplacian(InverseLaplacian): + def init_solver(self, L): + import scipy as sp + + self.lusolve = sp.sparse.linalg.factorized(self.L1.tocsc()) + + def solve_inverse(self, r): + rhs = np.zeros(self.n, dtype=self.dtype) + rhs[r] = 1 + return self.lusolve(rhs[1:]) + + def solve(self, rhs): + s = np.zeros(rhs.shape, dtype=self.dtype) + s[1:] = self.lusolve(rhs[1:]) + return s + + +class CGInverseLaplacian(InverseLaplacian): + def init_solver(self, L): + global sp + import scipy as sp + + ilu = sp.sparse.linalg.spilu(self.L1.tocsc()) + n = self.n - 1 + self.M = sp.sparse.linalg.LinearOperator(shape=(n, n), matvec=ilu.solve) + + def solve(self, rhs): + s = np.zeros(rhs.shape, dtype=self.dtype) + s[1:] = sp.sparse.linalg.cg(self.L1, rhs[1:], M=self.M, atol=0)[0] + return s + + def solve_inverse(self, r): + rhs = np.zeros(self.n, self.dtype) + rhs[r] = 1 + return sp.sparse.linalg.cg(self.L1, rhs[1:], M=self.M, atol=0)[0] diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/group.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/group.py new file mode 100644 index 0000000000000000000000000000000000000000..7c48742a3b16c7500237c622a410d3a27e1cb953 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/group.py @@ -0,0 +1,787 @@ +"""Group centrality measures.""" + +from copy import deepcopy + +import networkx as nx +from networkx.algorithms.centrality.betweenness import ( + _accumulate_endpoints, + _single_source_dijkstra_path_basic, + _single_source_shortest_path_basic, +) +from networkx.utils.decorators import not_implemented_for + +__all__ = [ + "group_betweenness_centrality", + "group_closeness_centrality", + "group_degree_centrality", + "group_in_degree_centrality", + "group_out_degree_centrality", + "prominent_group", +] + + +@nx._dispatchable(edge_attrs="weight") +def group_betweenness_centrality(G, C, normalized=True, weight=None, endpoints=False): + r"""Compute the group betweenness centrality for a group of nodes. + + Group betweenness centrality of a group of nodes $C$ is the sum of the + fraction of all-pairs shortest paths that pass through any vertex in $C$ + + .. math:: + + c_B(v) =\sum_{s,t \in V} \frac{\sigma(s, t|v)}{\sigma(s, t)} + + where $V$ is the set of nodes, $\sigma(s, t)$ is the number of + shortest $(s, t)$-paths, and $\sigma(s, t|C)$ is the number of + those paths passing through some node in group $C$. Note that + $(s, t)$ are not members of the group ($V-C$ is the set of nodes + in $V$ that are not in $C$). + + Parameters + ---------- + G : graph + A NetworkX graph. + + C : list or set or list of lists or list of sets + A group or a list of groups containing nodes which belong to G, for which group betweenness + centrality is to be calculated. + + normalized : bool, optional (default=True) + If True, group betweenness is normalized by `1/((|V|-|C|)(|V|-|C|-1))` + where `|V|` is the number of nodes in G and `|C|` is the number of nodes in C. + + weight : None or string, optional (default=None) + If None, all edge weights are considered equal. + Otherwise holds the name of the edge attribute used as weight. + The weight of an edge is treated as the length or distance between the two sides. + + endpoints : bool, optional (default=False) + If True include the endpoints in the shortest path counts. + + Raises + ------ + NodeNotFound + If node(s) in C are not present in G. + + Returns + ------- + betweenness : list of floats or float + If C is a single group then return a float. If C is a list with + several groups then return a list of group betweenness centralities. + + See Also + -------- + betweenness_centrality + + Notes + ----- + Group betweenness centrality is described in [1]_ and its importance discussed in [3]_. + The initial implementation of the algorithm is mentioned in [2]_. This function uses + an improved algorithm presented in [4]_. + + The number of nodes in the group must be a maximum of n - 2 where `n` + is the total number of nodes in the graph. + + For weighted graphs the edge weights must be greater than zero. + Zero edge weights can produce an infinite number of equal length + paths between pairs of nodes. + + The total number of paths between source and target is counted + differently for directed and undirected graphs. Directed paths + between "u" and "v" are counted as two possible paths (one each + direction) while undirected paths between "u" and "v" are counted + as one path. Said another way, the sum in the expression above is + over all ``s != t`` for directed graphs and for ``s < t`` for undirected graphs. + + + References + ---------- + .. [1] M G Everett and S P Borgatti: + The Centrality of Groups and Classes. + Journal of Mathematical Sociology. 23(3): 181-201. 1999. + http://www.analytictech.com/borgatti/group_centrality.htm + .. [2] Ulrik Brandes: + On Variants of Shortest-Path Betweenness + Centrality and their Generic Computation. + Social Networks 30(2):136-145, 2008. + http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.72.9610&rep=rep1&type=pdf + .. [3] Sourav Medya et. al.: + Group Centrality Maximization via Network Design. + SIAM International Conference on Data Mining, SDM 2018, 126–134. + https://sites.cs.ucsb.edu/~arlei/pubs/sdm18.pdf + .. [4] Rami Puzis, Yuval Elovici, and Shlomi Dolev. + "Fast algorithm for successive computation of group betweenness centrality." + https://journals.aps.org/pre/pdf/10.1103/PhysRevE.76.056709 + + """ + GBC = [] # initialize betweenness + list_of_groups = True + # check weather C contains one or many groups + if any(el in G for el in C): + C = [C] + list_of_groups = False + set_v = {node for group in C for node in group} + if set_v - G.nodes: # element(s) of C not in G + raise nx.NodeNotFound(f"The node(s) {set_v - G.nodes} are in C but not in G.") + + # pre-processing + PB, sigma, D = _group_preprocessing(G, set_v, weight) + + # the algorithm for each group + for group in C: + group = set(group) # set of nodes in group + # initialize the matrices of the sigma and the PB + GBC_group = 0 + sigma_m = deepcopy(sigma) + PB_m = deepcopy(PB) + sigma_m_v = deepcopy(sigma_m) + PB_m_v = deepcopy(PB_m) + for v in group: + GBC_group += PB_m[v][v] + for x in group: + for y in group: + dxvy = 0 + dxyv = 0 + dvxy = 0 + if not ( + sigma_m[x][y] == 0 or sigma_m[x][v] == 0 or sigma_m[v][y] == 0 + ): + if D[x][v] == D[x][y] + D[y][v]: + dxyv = sigma_m[x][y] * sigma_m[y][v] / sigma_m[x][v] + if D[x][y] == D[x][v] + D[v][y]: + dxvy = sigma_m[x][v] * sigma_m[v][y] / sigma_m[x][y] + if D[v][y] == D[v][x] + D[x][y]: + dvxy = sigma_m[v][x] * sigma[x][y] / sigma[v][y] + sigma_m_v[x][y] = sigma_m[x][y] * (1 - dxvy) + PB_m_v[x][y] = PB_m[x][y] - PB_m[x][y] * dxvy + if y != v: + PB_m_v[x][y] -= PB_m[x][v] * dxyv + if x != v: + PB_m_v[x][y] -= PB_m[v][y] * dvxy + sigma_m, sigma_m_v = sigma_m_v, sigma_m + PB_m, PB_m_v = PB_m_v, PB_m + + # endpoints + v, c = len(G), len(group) + if not endpoints: + scale = 0 + # if the graph is connected then subtract the endpoints from + # the count for all the nodes in the graph. else count how many + # nodes are connected to the group's nodes and subtract that. + if nx.is_directed(G): + if nx.is_strongly_connected(G): + scale = c * (2 * v - c - 1) + elif nx.is_connected(G): + scale = c * (2 * v - c - 1) + if scale == 0: + for group_node1 in group: + for node in D[group_node1]: + if node != group_node1: + if node in group: + scale += 1 + else: + scale += 2 + GBC_group -= scale + + # normalized + if normalized: + scale = 1 / ((v - c) * (v - c - 1)) + GBC_group *= scale + + # If undirected than count only the undirected edges + elif not G.is_directed(): + GBC_group /= 2 + + GBC.append(GBC_group) + if list_of_groups: + return GBC + return GBC[0] + + +def _group_preprocessing(G, set_v, weight): + sigma = {} + delta = {} + D = {} + betweenness = dict.fromkeys(G, 0) + for s in G: + if weight is None: # use BFS + S, P, sigma[s], D[s] = _single_source_shortest_path_basic(G, s) + else: # use Dijkstra's algorithm + S, P, sigma[s], D[s] = _single_source_dijkstra_path_basic(G, s, weight) + betweenness, delta[s] = _accumulate_endpoints(betweenness, S, P, sigma[s], s) + for i in delta[s]: # add the paths from s to i and rescale sigma + if s != i: + delta[s][i] += 1 + if weight is not None: + sigma[s][i] = sigma[s][i] / 2 + # building the path betweenness matrix only for nodes that appear in the group + PB = dict.fromkeys(G) + for group_node1 in set_v: + PB[group_node1] = dict.fromkeys(G, 0.0) + for group_node2 in set_v: + if group_node2 not in D[group_node1]: + continue + for node in G: + # if node is connected to the two group nodes than continue + if group_node2 in D[node] and group_node1 in D[node]: + if ( + D[node][group_node2] + == D[node][group_node1] + D[group_node1][group_node2] + ): + PB[group_node1][group_node2] += ( + delta[node][group_node2] + * sigma[node][group_node1] + * sigma[group_node1][group_node2] + / sigma[node][group_node2] + ) + return PB, sigma, D + + +@nx._dispatchable(edge_attrs="weight") +def prominent_group( + G, k, weight=None, C=None, endpoints=False, normalized=True, greedy=False +): + r"""Find the prominent group of size $k$ in graph $G$. The prominence of the + group is evaluated by the group betweenness centrality. + + Group betweenness centrality of a group of nodes $C$ is the sum of the + fraction of all-pairs shortest paths that pass through any vertex in $C$ + + .. math:: + + c_B(v) =\sum_{s,t \in V} \frac{\sigma(s, t|v)}{\sigma(s, t)} + + where $V$ is the set of nodes, $\sigma(s, t)$ is the number of + shortest $(s, t)$-paths, and $\sigma(s, t|C)$ is the number of + those paths passing through some node in group $C$. Note that + $(s, t)$ are not members of the group ($V-C$ is the set of nodes + in $V$ that are not in $C$). + + Parameters + ---------- + G : graph + A NetworkX graph. + + k : int + The number of nodes in the group. + + normalized : bool, optional (default=True) + If True, group betweenness is normalized by ``1/((|V|-|C|)(|V|-|C|-1))`` + where ``|V|`` is the number of nodes in G and ``|C|`` is the number of + nodes in C. + + weight : None or string, optional (default=None) + If None, all edge weights are considered equal. + Otherwise holds the name of the edge attribute used as weight. + The weight of an edge is treated as the length or distance between the two sides. + + endpoints : bool, optional (default=False) + If True include the endpoints in the shortest path counts. + + C : list or set, optional (default=None) + list of nodes which won't be candidates of the prominent group. + + greedy : bool, optional (default=False) + Using a naive greedy algorithm in order to find non-optimal prominent + group. For scale free networks the results are negligibly below the optimal + results. + + Raises + ------ + NodeNotFound + If node(s) in C are not present in G. + + Returns + ------- + max_GBC : float + The group betweenness centrality of the prominent group. + + max_group : list + The list of nodes in the prominent group. + + See Also + -------- + betweenness_centrality, group_betweenness_centrality + + Notes + ----- + Group betweenness centrality is described in [1]_ and its importance discussed in [3]_. + The algorithm is described in [2]_ and is based on techniques mentioned in [4]_. + + The number of nodes in the group must be a maximum of ``n - 2`` where ``n`` + is the total number of nodes in the graph. + + For weighted graphs the edge weights must be greater than zero. + Zero edge weights can produce an infinite number of equal length + paths between pairs of nodes. + + The total number of paths between source and target is counted + differently for directed and undirected graphs. Directed paths + between "u" and "v" are counted as two possible paths (one each + direction) while undirected paths between "u" and "v" are counted + as one path. Said another way, the sum in the expression above is + over all ``s != t`` for directed graphs and for ``s < t`` for undirected graphs. + + References + ---------- + .. [1] M G Everett and S P Borgatti: + The Centrality of Groups and Classes. + Journal of Mathematical Sociology. 23(3): 181-201. 1999. + http://www.analytictech.com/borgatti/group_centrality.htm + .. [2] Rami Puzis, Yuval Elovici, and Shlomi Dolev: + "Finding the Most Prominent Group in Complex Networks" + AI communications 20(4): 287-296, 2007. + https://www.researchgate.net/profile/Rami_Puzis2/publication/220308855 + .. [3] Sourav Medya et. al.: + Group Centrality Maximization via Network Design. + SIAM International Conference on Data Mining, SDM 2018, 126–134. + https://sites.cs.ucsb.edu/~arlei/pubs/sdm18.pdf + .. [4] Rami Puzis, Yuval Elovici, and Shlomi Dolev. + "Fast algorithm for successive computation of group betweenness centrality." + https://journals.aps.org/pre/pdf/10.1103/PhysRevE.76.056709 + """ + import numpy as np + import pandas as pd + + if C is not None: + C = set(C) + if C - G.nodes: # element(s) of C not in G + raise nx.NodeNotFound(f"The node(s) {C - G.nodes} are in C but not in G.") + nodes = list(G.nodes - C) + else: + nodes = list(G.nodes) + DF_tree = nx.Graph() + DF_tree.__networkx_cache__ = None # Disable caching + PB, sigma, D = _group_preprocessing(G, nodes, weight) + betweenness = pd.DataFrame.from_dict(PB) + if C is not None: + for node in C: + # remove from the betweenness all the nodes not part of the group + betweenness.drop(index=node, inplace=True) + betweenness.drop(columns=node, inplace=True) + CL = [node for _, node in sorted(zip(np.diag(betweenness), nodes), reverse=True)] + max_GBC = 0 + max_group = [] + DF_tree.add_node( + 1, + CL=CL, + betweenness=betweenness, + GBC=0, + GM=[], + sigma=sigma, + cont=dict(zip(nodes, np.diag(betweenness))), + ) + + # the algorithm + DF_tree.nodes[1]["heu"] = 0 + for i in range(k): + DF_tree.nodes[1]["heu"] += DF_tree.nodes[1]["cont"][DF_tree.nodes[1]["CL"][i]] + max_GBC, DF_tree, max_group = _dfbnb( + G, k, DF_tree, max_GBC, 1, D, max_group, nodes, greedy + ) + + v = len(G) + if not endpoints: + scale = 0 + # if the graph is connected then subtract the endpoints from + # the count for all the nodes in the graph. else count how many + # nodes are connected to the group's nodes and subtract that. + if nx.is_directed(G): + if nx.is_strongly_connected(G): + scale = k * (2 * v - k - 1) + elif nx.is_connected(G): + scale = k * (2 * v - k - 1) + if scale == 0: + for group_node1 in max_group: + for node in D[group_node1]: + if node != group_node1: + if node in max_group: + scale += 1 + else: + scale += 2 + max_GBC -= scale + + # normalized + if normalized: + scale = 1 / ((v - k) * (v - k - 1)) + max_GBC *= scale + + # If undirected then count only the undirected edges + elif not G.is_directed(): + max_GBC /= 2 + max_GBC = float(f"{max_GBC:.2f}") + return max_GBC, max_group + + +def _dfbnb(G, k, DF_tree, max_GBC, root, D, max_group, nodes, greedy): + # stopping condition - if we found a group of size k and with higher GBC then prune + if len(DF_tree.nodes[root]["GM"]) == k and DF_tree.nodes[root]["GBC"] > max_GBC: + return DF_tree.nodes[root]["GBC"], DF_tree, DF_tree.nodes[root]["GM"] + # stopping condition - if the size of group members equal to k or there are less than + # k - |GM| in the candidate list or the heuristic function plus the GBC is below the + # maximal GBC found then prune + if ( + len(DF_tree.nodes[root]["GM"]) == k + or len(DF_tree.nodes[root]["CL"]) <= k - len(DF_tree.nodes[root]["GM"]) + or DF_tree.nodes[root]["GBC"] + DF_tree.nodes[root]["heu"] <= max_GBC + ): + return max_GBC, DF_tree, max_group + + # finding the heuristic of both children + node_p, node_m, DF_tree = _heuristic(k, root, DF_tree, D, nodes, greedy) + + # finding the child with the bigger heuristic + GBC and expand + # that node first if greedy then only expand the plus node + if greedy: + max_GBC, DF_tree, max_group = _dfbnb( + G, k, DF_tree, max_GBC, node_p, D, max_group, nodes, greedy + ) + + elif ( + DF_tree.nodes[node_p]["GBC"] + DF_tree.nodes[node_p]["heu"] + > DF_tree.nodes[node_m]["GBC"] + DF_tree.nodes[node_m]["heu"] + ): + max_GBC, DF_tree, max_group = _dfbnb( + G, k, DF_tree, max_GBC, node_p, D, max_group, nodes, greedy + ) + max_GBC, DF_tree, max_group = _dfbnb( + G, k, DF_tree, max_GBC, node_m, D, max_group, nodes, greedy + ) + else: + max_GBC, DF_tree, max_group = _dfbnb( + G, k, DF_tree, max_GBC, node_m, D, max_group, nodes, greedy + ) + max_GBC, DF_tree, max_group = _dfbnb( + G, k, DF_tree, max_GBC, node_p, D, max_group, nodes, greedy + ) + return max_GBC, DF_tree, max_group + + +def _heuristic(k, root, DF_tree, D, nodes, greedy): + import numpy as np + + # This helper function add two nodes to DF_tree - one left son and the + # other right son, finds their heuristic, CL, GBC, and GM + node_p = DF_tree.number_of_nodes() + 1 + node_m = DF_tree.number_of_nodes() + 2 + added_node = DF_tree.nodes[root]["CL"][0] + + # adding the plus node + DF_tree.add_nodes_from([(node_p, deepcopy(DF_tree.nodes[root]))]) + DF_tree.nodes[node_p]["GM"].append(added_node) + DF_tree.nodes[node_p]["GBC"] += DF_tree.nodes[node_p]["cont"][added_node] + root_node = DF_tree.nodes[root] + for x in nodes: + for y in nodes: + dxvy = 0 + dxyv = 0 + dvxy = 0 + if not ( + root_node["sigma"][x][y] == 0 + or root_node["sigma"][x][added_node] == 0 + or root_node["sigma"][added_node][y] == 0 + ): + if D[x][added_node] == D[x][y] + D[y][added_node]: + dxyv = ( + root_node["sigma"][x][y] + * root_node["sigma"][y][added_node] + / root_node["sigma"][x][added_node] + ) + if D[x][y] == D[x][added_node] + D[added_node][y]: + dxvy = ( + root_node["sigma"][x][added_node] + * root_node["sigma"][added_node][y] + / root_node["sigma"][x][y] + ) + if D[added_node][y] == D[added_node][x] + D[x][y]: + dvxy = ( + root_node["sigma"][added_node][x] + * root_node["sigma"][x][y] + / root_node["sigma"][added_node][y] + ) + DF_tree.nodes[node_p]["sigma"][x][y] = root_node["sigma"][x][y] * (1 - dxvy) + DF_tree.nodes[node_p]["betweenness"].loc[y, x] = ( + root_node["betweenness"][x][y] - root_node["betweenness"][x][y] * dxvy + ) + if y != added_node: + DF_tree.nodes[node_p]["betweenness"].loc[y, x] -= ( + root_node["betweenness"][x][added_node] * dxyv + ) + if x != added_node: + DF_tree.nodes[node_p]["betweenness"].loc[y, x] -= ( + root_node["betweenness"][added_node][y] * dvxy + ) + + DF_tree.nodes[node_p]["CL"] = [ + node + for _, node in sorted( + zip(np.diag(DF_tree.nodes[node_p]["betweenness"]), nodes), reverse=True + ) + if node not in DF_tree.nodes[node_p]["GM"] + ] + DF_tree.nodes[node_p]["cont"] = dict( + zip(nodes, np.diag(DF_tree.nodes[node_p]["betweenness"])) + ) + DF_tree.nodes[node_p]["heu"] = 0 + for i in range(k - len(DF_tree.nodes[node_p]["GM"])): + DF_tree.nodes[node_p]["heu"] += DF_tree.nodes[node_p]["cont"][ + DF_tree.nodes[node_p]["CL"][i] + ] + + # adding the minus node - don't insert the first node in the CL to GM + # Insert minus node only if isn't greedy type algorithm + if not greedy: + DF_tree.add_nodes_from([(node_m, deepcopy(DF_tree.nodes[root]))]) + DF_tree.nodes[node_m]["CL"].pop(0) + DF_tree.nodes[node_m]["cont"].pop(added_node) + DF_tree.nodes[node_m]["heu"] = 0 + for i in range(k - len(DF_tree.nodes[node_m]["GM"])): + DF_tree.nodes[node_m]["heu"] += DF_tree.nodes[node_m]["cont"][ + DF_tree.nodes[node_m]["CL"][i] + ] + else: + node_m = None + + return node_p, node_m, DF_tree + + +@nx._dispatchable(edge_attrs="weight") +def group_closeness_centrality(G, S, weight=None): + r"""Compute the group closeness centrality for a group of nodes. + + Group closeness centrality of a group of nodes $S$ is a measure + of how close the group is to the other nodes in the graph. + + .. math:: + + c_{close}(S) = \frac{|V-S|}{\sum_{v \in V-S} d_{S, v}} + + d_{S, v} = min_{u \in S} (d_{u, v}) + + where $V$ is the set of nodes, $d_{S, v}$ is the distance of + the group $S$ from $v$ defined as above. ($V-S$ is the set of nodes + in $V$ that are not in $S$). + + Parameters + ---------- + G : graph + A NetworkX graph. + + S : list or set + S is a group of nodes which belong to G, for which group closeness + centrality is to be calculated. + + weight : None or string, optional (default=None) + If None, all edge weights are considered equal. + Otherwise holds the name of the edge attribute used as weight. + The weight of an edge is treated as the length or distance between the two sides. + + Raises + ------ + NodeNotFound + If node(s) in S are not present in G. + + Returns + ------- + closeness : float + Group closeness centrality of the group S. + + See Also + -------- + closeness_centrality + + Notes + ----- + The measure was introduced in [1]_. + The formula implemented here is described in [2]_. + + Higher values of closeness indicate greater centrality. + + It is assumed that 1 / 0 is 0 (required in the case of directed graphs, + or when a shortest path length is 0). + + The number of nodes in the group must be a maximum of n - 1 where `n` + is the total number of nodes in the graph. + + For directed graphs, the incoming distance is utilized here. To use the + outward distance, act on `G.reverse()`. + + For weighted graphs the edge weights must be greater than zero. + Zero edge weights can produce an infinite number of equal length + paths between pairs of nodes. + + References + ---------- + .. [1] M G Everett and S P Borgatti: + The Centrality of Groups and Classes. + Journal of Mathematical Sociology. 23(3): 181-201. 1999. + http://www.analytictech.com/borgatti/group_centrality.htm + .. [2] J. Zhao et. al.: + Measuring and Maximizing Group Closeness Centrality over + Disk Resident Graphs. + WWWConference Proceedings, 2014. 689-694. + https://doi.org/10.1145/2567948.2579356 + """ + if G.is_directed(): + G = G.reverse() # reverse view + closeness = 0 # initialize to 0 + V = set(G) # set of nodes in G + S = set(S) # set of nodes in group S + V_S = V - S # set of nodes in V but not S + shortest_path_lengths = nx.multi_source_dijkstra_path_length(G, S, weight=weight) + # accumulation + for v in V_S: + try: + closeness += shortest_path_lengths[v] + except KeyError: # no path exists + closeness += 0 + try: + closeness = len(V_S) / closeness + except ZeroDivisionError: # 1 / 0 assumed as 0 + closeness = 0 + return closeness + + +@nx._dispatchable +def group_degree_centrality(G, S): + """Compute the group degree centrality for a group of nodes. + + Group degree centrality of a group of nodes $S$ is the fraction + of non-group members connected to group members. + + Parameters + ---------- + G : graph + A NetworkX graph. + + S : list or set + S is a group of nodes which belong to G, for which group degree + centrality is to be calculated. + + Raises + ------ + NetworkXError + If node(s) in S are not in G. + + Returns + ------- + centrality : float + Group degree centrality of the group S. + + See Also + -------- + degree_centrality + group_in_degree_centrality + group_out_degree_centrality + + Notes + ----- + The measure was introduced in [1]_. + + The number of nodes in the group must be a maximum of n - 1 where `n` + is the total number of nodes in the graph. + + References + ---------- + .. [1] M G Everett and S P Borgatti: + The Centrality of Groups and Classes. + Journal of Mathematical Sociology. 23(3): 181-201. 1999. + http://www.analytictech.com/borgatti/group_centrality.htm + """ + centrality = len(set().union(*[set(G.neighbors(i)) for i in S]) - set(S)) + centrality /= len(G.nodes()) - len(S) + return centrality + + +@not_implemented_for("undirected") +@nx._dispatchable +def group_in_degree_centrality(G, S): + """Compute the group in-degree centrality for a group of nodes. + + Group in-degree centrality of a group of nodes $S$ is the fraction + of non-group members connected to group members by incoming edges. + + Parameters + ---------- + G : graph + A NetworkX graph. + + S : list or set + S is a group of nodes which belong to G, for which group in-degree + centrality is to be calculated. + + Returns + ------- + centrality : float + Group in-degree centrality of the group S. + + Raises + ------ + NetworkXNotImplemented + If G is undirected. + + NodeNotFound + If node(s) in S are not in G. + + See Also + -------- + degree_centrality + group_degree_centrality + group_out_degree_centrality + + Notes + ----- + The number of nodes in the group must be a maximum of n - 1 where `n` + is the total number of nodes in the graph. + + `G.neighbors(i)` gives nodes with an outward edge from i, in a DiGraph, + so for group in-degree centrality, the reverse graph is used. + """ + return group_degree_centrality(G.reverse(), S) + + +@not_implemented_for("undirected") +@nx._dispatchable +def group_out_degree_centrality(G, S): + """Compute the group out-degree centrality for a group of nodes. + + Group out-degree centrality of a group of nodes $S$ is the fraction + of non-group members connected to group members by outgoing edges. + + Parameters + ---------- + G : graph + A NetworkX graph. + + S : list or set + S is a group of nodes which belong to G, for which group in-degree + centrality is to be calculated. + + Returns + ------- + centrality : float + Group out-degree centrality of the group S. + + Raises + ------ + NetworkXNotImplemented + If G is undirected. + + NodeNotFound + If node(s) in S are not in G. + + See Also + -------- + degree_centrality + group_degree_centrality + group_in_degree_centrality + + Notes + ----- + The number of nodes in the group must be a maximum of n - 1 where `n` + is the total number of nodes in the graph. + + `G.neighbors(i)` gives nodes with an outward edge from i, in a DiGraph, + so for group out-degree centrality, the graph itself is used. + """ + return group_degree_centrality(G, S) diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/harmonic.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/harmonic.py new file mode 100644 index 0000000000000000000000000000000000000000..236e14919a805a907da28ac2691eb53cc993ad2d --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/harmonic.py @@ -0,0 +1,89 @@ +"""Functions for computing the harmonic centrality of a graph.""" + +from functools import partial + +import networkx as nx + +__all__ = ["harmonic_centrality"] + + +@nx._dispatchable(edge_attrs="distance") +def harmonic_centrality(G, nbunch=None, distance=None, sources=None): + r"""Compute harmonic centrality for nodes. + + Harmonic centrality [1]_ of a node `u` is the sum of the reciprocal + of the shortest path distances from all other nodes to `u` + + .. math:: + + C(u) = \sum_{v \neq u} \frac{1}{d(v, u)} + + where `d(v, u)` is the shortest-path distance between `v` and `u`. + + If `sources` is given as an argument, the returned harmonic centrality + values are calculated as the sum of the reciprocals of the shortest + path distances from the nodes specified in `sources` to `u` instead + of from all nodes to `u`. + + Notice that higher values indicate higher centrality. + + Parameters + ---------- + G : graph + A NetworkX graph + + nbunch : container (default: all nodes in G) + Container of nodes for which harmonic centrality values are calculated. + + sources : container (default: all nodes in G) + Container of nodes `v` over which reciprocal distances are computed. + Nodes not in `G` are silently ignored. + + distance : edge attribute key, optional (default=None) + Use the specified edge attribute as the edge distance in shortest + path calculations. If `None`, then each edge will have distance equal to 1. + + Returns + ------- + nodes : dictionary + Dictionary of nodes with harmonic centrality as the value. + + See Also + -------- + betweenness_centrality, load_centrality, eigenvector_centrality, + degree_centrality, closeness_centrality + + Notes + ----- + If the 'distance' keyword is set to an edge attribute key then the + shortest-path length will be computed using Dijkstra's algorithm with + that edge attribute as the edge weight. + + References + ---------- + .. [1] Boldi, Paolo, and Sebastiano Vigna. "Axioms for centrality." + Internet Mathematics 10.3-4 (2014): 222-262. + """ + + nbunch = set(G.nbunch_iter(nbunch) if nbunch is not None else G.nodes) + sources = set(G.nbunch_iter(sources) if sources is not None else G.nodes) + + centrality = {u: 0 for u in nbunch} + + transposed = False + if len(nbunch) < len(sources): + transposed = True + nbunch, sources = sources, nbunch + if nx.is_directed(G): + G = nx.reverse(G, copy=False) + + spl = partial(nx.shortest_path_length, G, weight=distance) + for v in sources: + dist = spl(v) + for u in nbunch.intersection(dist): + d = dist[u] + if d == 0: # handle u == v and edges with 0 weight + continue + centrality[v if transposed else u] += 1 / d + + return centrality diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/katz.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/katz.py new file mode 100644 index 0000000000000000000000000000000000000000..4bd087bc3e55de4f71413033f969ad22e8acddd7 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/katz.py @@ -0,0 +1,331 @@ +"""Katz centrality.""" + +import math + +import networkx as nx +from networkx.utils import not_implemented_for + +__all__ = ["katz_centrality", "katz_centrality_numpy"] + + +@not_implemented_for("multigraph") +@nx._dispatchable(edge_attrs="weight") +def katz_centrality( + G, + alpha=0.1, + beta=1.0, + max_iter=1000, + tol=1.0e-6, + nstart=None, + normalized=True, + weight=None, +): + r"""Compute the Katz centrality for the nodes of the graph G. + + Katz centrality computes the centrality for a node based on the centrality + of its neighbors. It is a generalization of the eigenvector centrality. The + Katz centrality for node $i$ is + + .. math:: + + x_i = \alpha \sum_{j} A_{ij} x_j + \beta, + + where $A$ is the adjacency matrix of graph G with eigenvalues $\lambda$. + + The parameter $\beta$ controls the initial centrality and + + .. math:: + + \alpha < \frac{1}{\lambda_{\max}}. + + Katz centrality computes the relative influence of a node within a + network by measuring the number of the immediate neighbors (first + degree nodes) and also all other nodes in the network that connect + to the node under consideration through these immediate neighbors. + + Extra weight can be provided to immediate neighbors through the + parameter $\beta$. Connections made with distant neighbors + are, however, penalized by an attenuation factor $\alpha$ which + should be strictly less than the inverse largest eigenvalue of the + adjacency matrix in order for the Katz centrality to be computed + correctly. More information is provided in [1]_. + + Parameters + ---------- + G : graph + A NetworkX graph. + + alpha : float, optional (default=0.1) + Attenuation factor + + beta : scalar or dictionary, optional (default=1.0) + Weight attributed to the immediate neighborhood. If not a scalar, the + dictionary must have a value for every node. + + max_iter : integer, optional (default=1000) + Maximum number of iterations in power method. + + tol : float, optional (default=1.0e-6) + Error tolerance used to check convergence in power method iteration. + + nstart : dictionary, optional + Starting value of Katz iteration for each node. + + normalized : bool, optional (default=True) + If True normalize the resulting values. + + weight : None or string, optional (default=None) + If None, all edge weights are considered equal. + Otherwise holds the name of the edge attribute used as weight. + In this measure the weight is interpreted as the connection strength. + + Returns + ------- + nodes : dictionary + Dictionary of nodes with Katz centrality as the value. + + Raises + ------ + NetworkXError + If the parameter `beta` is not a scalar but lacks a value for at least + one node + + PowerIterationFailedConvergence + If the algorithm fails to converge to the specified tolerance + within the specified number of iterations of the power iteration + method. + + Examples + -------- + >>> import math + >>> G = nx.path_graph(4) + >>> phi = (1 + math.sqrt(5)) / 2.0 # largest eigenvalue of adj matrix + >>> centrality = nx.katz_centrality(G, 1 / phi - 0.01) + >>> for n, c in sorted(centrality.items()): + ... print(f"{n} {c:.2f}") + 0 0.37 + 1 0.60 + 2 0.60 + 3 0.37 + + See Also + -------- + katz_centrality_numpy + eigenvector_centrality + eigenvector_centrality_numpy + :func:`~networkx.algorithms.link_analysis.pagerank_alg.pagerank` + :func:`~networkx.algorithms.link_analysis.hits_alg.hits` + + Notes + ----- + Katz centrality was introduced by [2]_. + + This algorithm it uses the power method to find the eigenvector + corresponding to the largest eigenvalue of the adjacency matrix of ``G``. + The parameter ``alpha`` should be strictly less than the inverse of largest + eigenvalue of the adjacency matrix for the algorithm to converge. + You can use ``max(nx.adjacency_spectrum(G))`` to get $\lambda_{\max}$ the largest + eigenvalue of the adjacency matrix. + The iteration will stop after ``max_iter`` iterations or an error tolerance of + ``number_of_nodes(G) * tol`` has been reached. + + For strongly connected graphs, as $\alpha \to 1/\lambda_{\max}$, and $\beta > 0$, + Katz centrality approaches the results for eigenvector centrality. + + For directed graphs this finds "left" eigenvectors which corresponds + to the in-edges in the graph. For out-edges Katz centrality, + first reverse the graph with ``G.reverse()``. + + References + ---------- + .. [1] Mark E. J. Newman: + Networks: An Introduction. + Oxford University Press, USA, 2010, p. 720. + .. [2] Leo Katz: + A New Status Index Derived from Sociometric Index. + Psychometrika 18(1):39–43, 1953 + https://link.springer.com/content/pdf/10.1007/BF02289026.pdf + """ + if len(G) == 0: + return {} + + nnodes = G.number_of_nodes() + + if nstart is None: + # choose starting vector with entries of 0 + x = {n: 0 for n in G} + else: + x = nstart + + try: + b = dict.fromkeys(G, float(beta)) + except (TypeError, ValueError, AttributeError) as err: + b = beta + if set(beta) != set(G): + raise nx.NetworkXError( + "beta dictionary must have a value for every node" + ) from err + + # make up to max_iter iterations + for _ in range(max_iter): + xlast = x + x = dict.fromkeys(xlast, 0) + # do the multiplication y^T = Alpha * x^T A + Beta + for n in x: + for nbr in G[n]: + x[nbr] += xlast[n] * G[n][nbr].get(weight, 1) + for n in x: + x[n] = alpha * x[n] + b[n] + + # check convergence + error = sum(abs(x[n] - xlast[n]) for n in x) + if error < nnodes * tol: + if normalized: + # normalize vector + try: + s = 1.0 / math.hypot(*x.values()) + except ZeroDivisionError: + s = 1.0 + else: + s = 1 + for n in x: + x[n] *= s + return x + raise nx.PowerIterationFailedConvergence(max_iter) + + +@not_implemented_for("multigraph") +@nx._dispatchable(edge_attrs="weight") +def katz_centrality_numpy(G, alpha=0.1, beta=1.0, normalized=True, weight=None): + r"""Compute the Katz centrality for the graph G. + + Katz centrality computes the centrality for a node based on the centrality + of its neighbors. It is a generalization of the eigenvector centrality. The + Katz centrality for node $i$ is + + .. math:: + + x_i = \alpha \sum_{j} A_{ij} x_j + \beta, + + where $A$ is the adjacency matrix of graph G with eigenvalues $\lambda$. + + The parameter $\beta$ controls the initial centrality and + + .. math:: + + \alpha < \frac{1}{\lambda_{\max}}. + + Katz centrality computes the relative influence of a node within a + network by measuring the number of the immediate neighbors (first + degree nodes) and also all other nodes in the network that connect + to the node under consideration through these immediate neighbors. + + Extra weight can be provided to immediate neighbors through the + parameter $\beta$. Connections made with distant neighbors + are, however, penalized by an attenuation factor $\alpha$ which + should be strictly less than the inverse largest eigenvalue of the + adjacency matrix in order for the Katz centrality to be computed + correctly. More information is provided in [1]_. + + Parameters + ---------- + G : graph + A NetworkX graph + + alpha : float + Attenuation factor + + beta : scalar or dictionary, optional (default=1.0) + Weight attributed to the immediate neighborhood. If not a scalar the + dictionary must have an value for every node. + + normalized : bool + If True normalize the resulting values. + + weight : None or string, optional + If None, all edge weights are considered equal. + Otherwise holds the name of the edge attribute used as weight. + In this measure the weight is interpreted as the connection strength. + + Returns + ------- + nodes : dictionary + Dictionary of nodes with Katz centrality as the value. + + Raises + ------ + NetworkXError + If the parameter `beta` is not a scalar but lacks a value for at least + one node + + Examples + -------- + >>> import math + >>> G = nx.path_graph(4) + >>> phi = (1 + math.sqrt(5)) / 2.0 # largest eigenvalue of adj matrix + >>> centrality = nx.katz_centrality_numpy(G, 1 / phi) + >>> for n, c in sorted(centrality.items()): + ... print(f"{n} {c:.2f}") + 0 0.37 + 1 0.60 + 2 0.60 + 3 0.37 + + See Also + -------- + katz_centrality + eigenvector_centrality_numpy + eigenvector_centrality + :func:`~networkx.algorithms.link_analysis.pagerank_alg.pagerank` + :func:`~networkx.algorithms.link_analysis.hits_alg.hits` + + Notes + ----- + Katz centrality was introduced by [2]_. + + This algorithm uses a direct linear solver to solve the above equation. + The parameter ``alpha`` should be strictly less than the inverse of largest + eigenvalue of the adjacency matrix for there to be a solution. + You can use ``max(nx.adjacency_spectrum(G))`` to get $\lambda_{\max}$ the largest + eigenvalue of the adjacency matrix. + + For strongly connected graphs, as $\alpha \to 1/\lambda_{\max}$, and $\beta > 0$, + Katz centrality approaches the results for eigenvector centrality. + + For directed graphs this finds "left" eigenvectors which corresponds + to the in-edges in the graph. For out-edges Katz centrality, + first reverse the graph with ``G.reverse()``. + + References + ---------- + .. [1] Mark E. J. Newman: + Networks: An Introduction. + Oxford University Press, USA, 2010, p. 173. + .. [2] Leo Katz: + A New Status Index Derived from Sociometric Index. + Psychometrika 18(1):39–43, 1953 + https://link.springer.com/content/pdf/10.1007/BF02289026.pdf + """ + import numpy as np + + if len(G) == 0: + return {} + try: + nodelist = beta.keys() + if set(nodelist) != set(G): + raise nx.NetworkXError("beta dictionary must have a value for every node") + b = np.array(list(beta.values()), dtype=float) + except AttributeError: + nodelist = list(G) + try: + b = np.ones((len(nodelist), 1)) * beta + except (TypeError, ValueError, AttributeError) as err: + raise nx.NetworkXError("beta must be a number") from err + + A = nx.adjacency_matrix(G, nodelist=nodelist, weight=weight).todense().T + n = A.shape[0] + centrality = np.linalg.solve(np.eye(n, n) - (alpha * A), b).squeeze() + + # Normalize: rely on truediv to cast to float, then tolist to make Python numbers + norm = np.sign(sum(centrality)) * np.linalg.norm(centrality) if normalized else 1 + return dict(zip(nodelist, (centrality / norm).tolist())) diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/laplacian.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/laplacian.py new file mode 100644 index 0000000000000000000000000000000000000000..efb6e8f67b6a0980381dd121051ce62608b8a984 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/laplacian.py @@ -0,0 +1,150 @@ +""" +Laplacian centrality measures. +""" + +import networkx as nx + +__all__ = ["laplacian_centrality"] + + +@nx._dispatchable(edge_attrs="weight") +def laplacian_centrality( + G, normalized=True, nodelist=None, weight="weight", walk_type=None, alpha=0.95 +): + r"""Compute the Laplacian centrality for nodes in the graph `G`. + + The Laplacian Centrality of a node ``i`` is measured by the drop in the + Laplacian Energy after deleting node ``i`` from the graph. The Laplacian Energy + is the sum of the squared eigenvalues of a graph's Laplacian matrix. + + .. math:: + + C_L(u_i,G) = \frac{(\Delta E)_i}{E_L (G)} = \frac{E_L (G)-E_L (G_i)}{E_L (G)} + + E_L (G) = \sum_{i=0}^n \lambda_i^2 + + Where $E_L (G)$ is the Laplacian energy of graph `G`, + E_L (G_i) is the Laplacian energy of graph `G` after deleting node ``i`` + and $\lambda_i$ are the eigenvalues of `G`'s Laplacian matrix. + This formula shows the normalized value. Without normalization, + the numerator on the right side is returned. + + Parameters + ---------- + G : graph + A networkx graph + + normalized : bool (default = True) + If True the centrality score is scaled so the sum over all nodes is 1. + If False the centrality score for each node is the drop in Laplacian + energy when that node is removed. + + nodelist : list, optional (default = None) + The rows and columns are ordered according to the nodes in nodelist. + If nodelist is None, then the ordering is produced by G.nodes(). + + weight: string or None, optional (default=`weight`) + Optional parameter `weight` to compute the Laplacian matrix. + The edge data key used to compute each value in the matrix. + If None, then each edge has weight 1. + + walk_type : string or None, optional (default=None) + Optional parameter `walk_type` used when calling + :func:`directed_laplacian_matrix `. + One of ``"random"``, ``"lazy"``, or ``"pagerank"``. If ``walk_type=None`` + (the default), then a value is selected according to the properties of `G`: + - ``walk_type="random"`` if `G` is strongly connected and aperiodic + - ``walk_type="lazy"`` if `G` is strongly connected but not aperiodic + - ``walk_type="pagerank"`` for all other cases. + + alpha : real (default = 0.95) + Optional parameter `alpha` used when calling + :func:`directed_laplacian_matrix `. + (1 - alpha) is the teleportation probability used with pagerank. + + Returns + ------- + nodes : dictionary + Dictionary of nodes with Laplacian centrality as the value. + + Examples + -------- + >>> G = nx.Graph() + >>> edges = [(0, 1, 4), (0, 2, 2), (2, 1, 1), (1, 3, 2), (1, 4, 2), (4, 5, 1)] + >>> G.add_weighted_edges_from(edges) + >>> sorted((v, f"{c:0.2f}") for v, c in laplacian_centrality(G).items()) + [(0, '0.70'), (1, '0.90'), (2, '0.28'), (3, '0.22'), (4, '0.26'), (5, '0.04')] + + Notes + ----- + The algorithm is implemented based on [1]_ with an extension to directed graphs + using the ``directed_laplacian_matrix`` function. + + Raises + ------ + NetworkXPointlessConcept + If the graph `G` is the null graph. + ZeroDivisionError + If the graph `G` has no edges (is empty) and normalization is requested. + + References + ---------- + .. [1] Qi, X., Fuller, E., Wu, Q., Wu, Y., and Zhang, C.-Q. (2012). + Laplacian centrality: A new centrality measure for weighted networks. + Information Sciences, 194:240-253. + https://math.wvu.edu/~cqzhang/Publication-files/my-paper/INS-2012-Laplacian-W.pdf + + See Also + -------- + :func:`~networkx.linalg.laplacianmatrix.directed_laplacian_matrix` + :func:`~networkx.linalg.laplacianmatrix.laplacian_matrix` + """ + import numpy as np + import scipy as sp + + if len(G) == 0: + raise nx.NetworkXPointlessConcept("null graph has no centrality defined") + if G.size(weight=weight) == 0: + if normalized: + raise ZeroDivisionError("graph with no edges has zero full energy") + return {n: 0 for n in G} + + if nodelist is not None: + nodeset = set(G.nbunch_iter(nodelist)) + if len(nodeset) != len(nodelist): + raise nx.NetworkXError("nodelist has duplicate nodes or nodes not in G") + nodes = nodelist + [n for n in G if n not in nodeset] + else: + nodelist = nodes = list(G) + + if G.is_directed(): + lap_matrix = nx.directed_laplacian_matrix(G, nodes, weight, walk_type, alpha) + else: + lap_matrix = nx.laplacian_matrix(G, nodes, weight).toarray() + + full_energy = np.power(sp.linalg.eigh(lap_matrix, eigvals_only=True), 2).sum() + + # calculate laplacian centrality + laplace_centralities_dict = {} + for i, node in enumerate(nodelist): + # remove row and col i from lap_matrix + all_but_i = list(np.arange(lap_matrix.shape[0])) + all_but_i.remove(i) + A_2 = lap_matrix[all_but_i, :][:, all_but_i] + + # Adjust diagonal for removed row + new_diag = lap_matrix.diagonal() - abs(lap_matrix[:, i]) + np.fill_diagonal(A_2, new_diag[all_but_i]) + + if len(all_but_i) > 0: # catches degenerate case of single node + new_energy = np.power(sp.linalg.eigh(A_2, eigvals_only=True), 2).sum() + else: + new_energy = 0.0 + + lapl_cent = full_energy - new_energy + if normalized: + lapl_cent = lapl_cent / full_energy + + laplace_centralities_dict[node] = float(lapl_cent) + + return laplace_centralities_dict diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/load.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/load.py new file mode 100644 index 0000000000000000000000000000000000000000..fc46edd6fa2a1555181058aa17c68cf8a9820429 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/load.py @@ -0,0 +1,200 @@ +"""Load centrality.""" + +from operator import itemgetter + +import networkx as nx + +__all__ = ["load_centrality", "edge_load_centrality"] + + +@nx._dispatchable(edge_attrs="weight") +def newman_betweenness_centrality(G, v=None, cutoff=None, normalized=True, weight=None): + """Compute load centrality for nodes. + + The load centrality of a node is the fraction of all shortest + paths that pass through that node. + + Parameters + ---------- + G : graph + A networkx graph. + + normalized : bool, optional (default=True) + If True the betweenness values are normalized by b=b/(n-1)(n-2) where + n is the number of nodes in G. + + weight : None or string, optional (default=None) + If None, edge weights are ignored. + Otherwise holds the name of the edge attribute used as weight. + The weight of an edge is treated as the length or distance between the two sides. + + cutoff : bool, optional (default=None) + If specified, only consider paths of length <= cutoff. + + Returns + ------- + nodes : dictionary + Dictionary of nodes with centrality as the value. + + See Also + -------- + betweenness_centrality + + Notes + ----- + Load centrality is slightly different than betweenness. It was originally + introduced by [2]_. For this load algorithm see [1]_. + + References + ---------- + .. [1] Mark E. J. Newman: + Scientific collaboration networks. II. + Shortest paths, weighted networks, and centrality. + Physical Review E 64, 016132, 2001. + http://journals.aps.org/pre/abstract/10.1103/PhysRevE.64.016132 + .. [2] Kwang-Il Goh, Byungnam Kahng and Doochul Kim + Universal behavior of Load Distribution in Scale-Free Networks. + Physical Review Letters 87(27):1–4, 2001. + https://doi.org/10.1103/PhysRevLett.87.278701 + """ + if v is not None: # only one node + betweenness = 0.0 + for source in G: + ubetween = _node_betweenness(G, source, cutoff, False, weight) + betweenness += ubetween[v] if v in ubetween else 0 + if normalized: + order = G.order() + if order <= 2: + return betweenness # no normalization b=0 for all nodes + betweenness *= 1.0 / ((order - 1) * (order - 2)) + else: + betweenness = {}.fromkeys(G, 0.0) + for source in betweenness: + ubetween = _node_betweenness(G, source, cutoff, False, weight) + for vk in ubetween: + betweenness[vk] += ubetween[vk] + if normalized: + order = G.order() + if order <= 2: + return betweenness # no normalization b=0 for all nodes + scale = 1.0 / ((order - 1) * (order - 2)) + for v in betweenness: + betweenness[v] *= scale + return betweenness # all nodes + + +def _node_betweenness(G, source, cutoff=False, normalized=True, weight=None): + """Node betweenness_centrality helper: + + See betweenness_centrality for what you probably want. + This actually computes "load" and not betweenness. + See https://networkx.lanl.gov/ticket/103 + + This calculates the load of each node for paths from a single source. + (The fraction of number of shortests paths from source that go + through each node.) + + To get the load for a node you need to do all-pairs shortest paths. + + If weight is not None then use Dijkstra for finding shortest paths. + """ + # get the predecessor and path length data + if weight is None: + (pred, length) = nx.predecessor(G, source, cutoff=cutoff, return_seen=True) + else: + (pred, length) = nx.dijkstra_predecessor_and_distance(G, source, cutoff, weight) + + # order the nodes by path length + onodes = [(l, vert) for (vert, l) in length.items()] + onodes.sort() + onodes[:] = [vert for (l, vert) in onodes if l > 0] + + # initialize betweenness + between = {}.fromkeys(length, 1.0) + + while onodes: + v = onodes.pop() + if v in pred: + num_paths = len(pred[v]) # Discount betweenness if more than + for x in pred[v]: # one shortest path. + if x == source: # stop if hit source because all remaining v + break # also have pred[v]==[source] + between[x] += between[v] / num_paths + # remove source + for v in between: + between[v] -= 1 + # rescale to be between 0 and 1 + if normalized: + l = len(between) + if l > 2: + # scale by 1/the number of possible paths + scale = 1 / ((l - 1) * (l - 2)) + for v in between: + between[v] *= scale + return between + + +load_centrality = newman_betweenness_centrality + + +@nx._dispatchable +def edge_load_centrality(G, cutoff=False): + """Compute edge load. + + WARNING: This concept of edge load has not been analysed + or discussed outside of NetworkX that we know of. + It is based loosely on load_centrality in the sense that + it counts the number of shortest paths which cross each edge. + This function is for demonstration and testing purposes. + + Parameters + ---------- + G : graph + A networkx graph + + cutoff : bool, optional (default=False) + If specified, only consider paths of length <= cutoff. + + Returns + ------- + A dict keyed by edge 2-tuple to the number of shortest paths + which use that edge. Where more than one path is shortest + the count is divided equally among paths. + """ + betweenness = {} + for u, v in G.edges(): + betweenness[(u, v)] = 0.0 + betweenness[(v, u)] = 0.0 + + for source in G: + ubetween = _edge_betweenness(G, source, cutoff=cutoff) + for e, ubetweenv in ubetween.items(): + betweenness[e] += ubetweenv # cumulative total + return betweenness + + +def _edge_betweenness(G, source, nodes=None, cutoff=False): + """Edge betweenness helper.""" + # get the predecessor data + (pred, length) = nx.predecessor(G, source, cutoff=cutoff, return_seen=True) + # order the nodes by path length + onodes = [n for n, d in sorted(length.items(), key=itemgetter(1))] + # initialize betweenness, doesn't account for any edge weights + between = {} + for u, v in G.edges(nodes): + between[(u, v)] = 1.0 + between[(v, u)] = 1.0 + + while onodes: # work through all paths + v = onodes.pop() + if v in pred: + # Discount betweenness if more than one shortest path. + num_paths = len(pred[v]) + for w in pred[v]: + if w in pred: + # Discount betweenness, mult path + num_paths = len(pred[w]) + for x in pred[w]: + between[(w, x)] += between[(v, w)] / num_paths + between[(x, w)] += between[(w, v)] / num_paths + return between diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/percolation.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/percolation.py new file mode 100644 index 0000000000000000000000000000000000000000..0d4c87132b48fe02f6a86e06f4ada0d7a72239f1 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/percolation.py @@ -0,0 +1,128 @@ +"""Percolation centrality measures.""" + +import networkx as nx +from networkx.algorithms.centrality.betweenness import ( + _single_source_dijkstra_path_basic as dijkstra, +) +from networkx.algorithms.centrality.betweenness import ( + _single_source_shortest_path_basic as shortest_path, +) + +__all__ = ["percolation_centrality"] + + +@nx._dispatchable(node_attrs="attribute", edge_attrs="weight") +def percolation_centrality(G, attribute="percolation", states=None, weight=None): + r"""Compute the percolation centrality for nodes. + + Percolation centrality of a node $v$, at a given time, is defined + as the proportion of ‘percolated paths’ that go through that node. + + This measure quantifies relative impact of nodes based on their + topological connectivity, as well as their percolation states. + + Percolation states of nodes are used to depict network percolation + scenarios (such as during infection transmission in a social network + of individuals, spreading of computer viruses on computer networks, or + transmission of disease over a network of towns) over time. In this + measure usually the percolation state is expressed as a decimal + between 0.0 and 1.0. + + When all nodes are in the same percolated state this measure is + equivalent to betweenness centrality. + + Parameters + ---------- + G : graph + A NetworkX graph. + + attribute : None or string, optional (default='percolation') + Name of the node attribute to use for percolation state, used + if `states` is None. If a node does not set the attribute the + state of that node will be set to the default value of 1. + If all nodes do not have the attribute all nodes will be set to + 1 and the centrality measure will be equivalent to betweenness centrality. + + states : None or dict, optional (default=None) + Specify percolation states for the nodes, nodes as keys states + as values. + + weight : None or string, optional (default=None) + If None, all edge weights are considered equal. + Otherwise holds the name of the edge attribute used as weight. + The weight of an edge is treated as the length or distance between the two sides. + + + Returns + ------- + nodes : dictionary + Dictionary of nodes with percolation centrality as the value. + + See Also + -------- + betweenness_centrality + + Notes + ----- + The algorithm is from Mahendra Piraveenan, Mikhail Prokopenko, and + Liaquat Hossain [1]_ + Pair dependencies are calculated and accumulated using [2]_ + + For weighted graphs the edge weights must be greater than zero. + Zero edge weights can produce an infinite number of equal length + paths between pairs of nodes. + + References + ---------- + .. [1] Mahendra Piraveenan, Mikhail Prokopenko, Liaquat Hossain + Percolation Centrality: Quantifying Graph-Theoretic Impact of Nodes + during Percolation in Networks + http://journals.plos.org/plosone/article?id=10.1371/journal.pone.0053095 + .. [2] Ulrik Brandes: + A Faster Algorithm for Betweenness Centrality. + Journal of Mathematical Sociology 25(2):163-177, 2001. + https://doi.org/10.1080/0022250X.2001.9990249 + """ + percolation = dict.fromkeys(G, 0.0) # b[v]=0 for v in G + + nodes = G + + if states is None: + states = nx.get_node_attributes(nodes, attribute, default=1) + + # sum of all percolation states + p_sigma_x_t = 0.0 + for v in states.values(): + p_sigma_x_t += v + + for s in nodes: + # single source shortest paths + if weight is None: # use BFS + S, P, sigma, _ = shortest_path(G, s) + else: # use Dijkstra's algorithm + S, P, sigma, _ = dijkstra(G, s, weight) + # accumulation + percolation = _accumulate_percolation( + percolation, S, P, sigma, s, states, p_sigma_x_t + ) + + n = len(G) + + for v in percolation: + percolation[v] *= 1 / (n - 2) + + return percolation + + +def _accumulate_percolation(percolation, S, P, sigma, s, states, p_sigma_x_t): + delta = dict.fromkeys(S, 0) + while S: + w = S.pop() + coeff = (1 + delta[w]) / sigma[w] + for v in P[w]: + delta[v] += sigma[v] * coeff + if w != s: + # percolation weight + pw_s_w = states[s] / (p_sigma_x_t - states[w]) + percolation[w] += delta[w] * pw_s_w + return percolation diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/reaching.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/reaching.py new file mode 100644 index 0000000000000000000000000000000000000000..378e8a0589c6bcaa7ddd2696c167328c1f32ae26 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/reaching.py @@ -0,0 +1,209 @@ +"""Functions for computing reaching centrality of a node or a graph.""" + +import networkx as nx +from networkx.utils import pairwise + +__all__ = ["global_reaching_centrality", "local_reaching_centrality"] + + +def _average_weight(G, path, weight=None): + """Returns the average weight of an edge in a weighted path. + + Parameters + ---------- + G : graph + A networkx graph. + + path: list + A list of vertices that define the path. + + weight : None or string, optional (default=None) + If None, edge weights are ignored. Then the average weight of an edge + is assumed to be the multiplicative inverse of the length of the path. + Otherwise holds the name of the edge attribute used as weight. + """ + path_length = len(path) - 1 + if path_length <= 0: + return 0 + if weight is None: + return 1 / path_length + total_weight = sum(G.edges[i, j][weight] for i, j in pairwise(path)) + return total_weight / path_length + + +@nx._dispatchable(edge_attrs="weight") +def global_reaching_centrality(G, weight=None, normalized=True): + """Returns the global reaching centrality of a directed graph. + + The *global reaching centrality* of a weighted directed graph is the + average over all nodes of the difference between the local reaching + centrality of the node and the greatest local reaching centrality of + any node in the graph [1]_. For more information on the local + reaching centrality, see :func:`local_reaching_centrality`. + Informally, the local reaching centrality is the proportion of the + graph that is reachable from the neighbors of the node. + + Parameters + ---------- + G : DiGraph + A networkx DiGraph. + + weight : None or string, optional (default=None) + Attribute to use for edge weights. If ``None``, each edge weight + is assumed to be one. A higher weight implies a stronger + connection between nodes and a *shorter* path length. + + normalized : bool, optional (default=True) + Whether to normalize the edge weights by the total sum of edge + weights. + + Returns + ------- + h : float + The global reaching centrality of the graph. + + Examples + -------- + >>> G = nx.DiGraph() + >>> G.add_edge(1, 2) + >>> G.add_edge(1, 3) + >>> nx.global_reaching_centrality(G) + 1.0 + >>> G.add_edge(3, 2) + >>> nx.global_reaching_centrality(G) + 0.75 + + See also + -------- + local_reaching_centrality + + References + ---------- + .. [1] Mones, Enys, Lilla Vicsek, and Tamás Vicsek. + "Hierarchy Measure for Complex Networks." + *PLoS ONE* 7.3 (2012): e33799. + https://doi.org/10.1371/journal.pone.0033799 + """ + if nx.is_negatively_weighted(G, weight=weight): + raise nx.NetworkXError("edge weights must be positive") + total_weight = G.size(weight=weight) + if total_weight <= 0: + raise nx.NetworkXError("Size of G must be positive") + # If provided, weights must be interpreted as connection strength + # (so higher weights are more likely to be chosen). However, the + # shortest path algorithms in NetworkX assume the provided "weight" + # is actually a distance (so edges with higher weight are less + # likely to be chosen). Therefore we need to invert the weights when + # computing shortest paths. + # + # If weight is None, we leave it as-is so that the shortest path + # algorithm can use a faster, unweighted algorithm. + if weight is not None: + + def as_distance(u, v, d): + return total_weight / d.get(weight, 1) + + shortest_paths = nx.shortest_path(G, weight=as_distance) + else: + shortest_paths = nx.shortest_path(G) + + centrality = local_reaching_centrality + # TODO This can be trivially parallelized. + lrc = [ + centrality(G, node, paths=paths, weight=weight, normalized=normalized) + for node, paths in shortest_paths.items() + ] + + max_lrc = max(lrc) + return sum(max_lrc - c for c in lrc) / (len(G) - 1) + + +@nx._dispatchable(edge_attrs="weight") +def local_reaching_centrality(G, v, paths=None, weight=None, normalized=True): + """Returns the local reaching centrality of a node in a directed + graph. + + The *local reaching centrality* of a node in a directed graph is the + proportion of other nodes reachable from that node [1]_. + + Parameters + ---------- + G : DiGraph + A NetworkX DiGraph. + + v : node + A node in the directed graph `G`. + + paths : dictionary (default=None) + If this is not `None` it must be a dictionary representation + of single-source shortest paths, as computed by, for example, + :func:`networkx.shortest_path` with source node `v`. Use this + keyword argument if you intend to invoke this function many + times but don't want the paths to be recomputed each time. + + weight : None or string, optional (default=None) + Attribute to use for edge weights. If `None`, each edge weight + is assumed to be one. A higher weight implies a stronger + connection between nodes and a *shorter* path length. + + normalized : bool, optional (default=True) + Whether to normalize the edge weights by the total sum of edge + weights. + + Returns + ------- + h : float + The local reaching centrality of the node ``v`` in the graph + ``G``. + + Examples + -------- + >>> G = nx.DiGraph() + >>> G.add_edges_from([(1, 2), (1, 3)]) + >>> nx.local_reaching_centrality(G, 3) + 0.0 + >>> G.add_edge(3, 2) + >>> nx.local_reaching_centrality(G, 3) + 0.5 + + See also + -------- + global_reaching_centrality + + References + ---------- + .. [1] Mones, Enys, Lilla Vicsek, and Tamás Vicsek. + "Hierarchy Measure for Complex Networks." + *PLoS ONE* 7.3 (2012): e33799. + https://doi.org/10.1371/journal.pone.0033799 + """ + # Corner case: graph with single node containing a self-loop + if (total_weight := G.size(weight=weight)) > 0 and len(G) == 1: + raise nx.NetworkXError( + "local_reaching_centrality of a single node with self-loop not well-defined" + ) + if paths is None: + if nx.is_negatively_weighted(G, weight=weight): + raise nx.NetworkXError("edge weights must be positive") + if total_weight <= 0: + raise nx.NetworkXError("Size of G must be positive") + if weight is not None: + # Interpret weights as lengths. + def as_distance(u, v, d): + return total_weight / d.get(weight, 1) + + paths = nx.shortest_path(G, source=v, weight=as_distance) + else: + paths = nx.shortest_path(G, source=v) + # If the graph is unweighted, simply return the proportion of nodes + # reachable from the source node ``v``. + if weight is None and G.is_directed(): + return (len(paths) - 1) / (len(G) - 1) + if normalized and weight is not None: + norm = G.size(weight=weight) / G.size() + else: + norm = 1 + # TODO This can be trivially parallelized. + avgw = (_average_weight(G, path, weight=weight) for path in paths.values()) + sum_avg_weight = sum(avgw) / norm + return sum_avg_weight / (len(G) - 1) diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/second_order.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/second_order.py new file mode 100644 index 0000000000000000000000000000000000000000..35583cd63e55d14c0c389040cbdeab39b27d1bf9 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/second_order.py @@ -0,0 +1,141 @@ +"""Copyright (c) 2015 – Thomson Licensing, SAS + +Redistribution and use in source and binary forms, with or without +modification, are permitted (subject to the limitations in the +disclaimer below) provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. + +* Neither the name of Thomson Licensing, or Technicolor, nor the names +of its contributors may be used to endorse or promote products derived +from this software without specific prior written permission. + +NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE +GRANTED BY THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT +HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED +WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR +BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE +OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN +IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + +import networkx as nx +from networkx.utils import not_implemented_for + +# Authors: Erwan Le Merrer (erwan.lemerrer@technicolor.com) + +__all__ = ["second_order_centrality"] + + +@not_implemented_for("directed") +@nx._dispatchable(edge_attrs="weight") +def second_order_centrality(G, weight="weight"): + """Compute the second order centrality for nodes of G. + + The second order centrality of a given node is the standard deviation of + the return times to that node of a perpetual random walk on G: + + Parameters + ---------- + G : graph + A NetworkX connected and undirected graph. + + weight : string or None, optional (default="weight") + The name of an edge attribute that holds the numerical value + used as a weight. If None then each edge has weight 1. + + Returns + ------- + nodes : dictionary + Dictionary keyed by node with second order centrality as the value. + + Examples + -------- + >>> G = nx.star_graph(10) + >>> soc = nx.second_order_centrality(G) + >>> print(sorted(soc.items(), key=lambda x: x[1])[0][0]) # pick first id + 0 + + Raises + ------ + NetworkXException + If the graph G is empty, non connected or has negative weights. + + See Also + -------- + betweenness_centrality + + Notes + ----- + Lower values of second order centrality indicate higher centrality. + + The algorithm is from Kermarrec, Le Merrer, Sericola and Trédan [1]_. + + This code implements the analytical version of the algorithm, i.e., + there is no simulation of a random walk process involved. The random walk + is here unbiased (corresponding to eq 6 of the paper [1]_), thus the + centrality values are the standard deviations for random walk return times + on the transformed input graph G (equal in-degree at each nodes by adding + self-loops). + + Complexity of this implementation, made to run locally on a single machine, + is O(n^3), with n the size of G, which makes it viable only for small + graphs. + + References + ---------- + .. [1] Anne-Marie Kermarrec, Erwan Le Merrer, Bruno Sericola, Gilles Trédan + "Second order centrality: Distributed assessment of nodes criticity in + complex networks", Elsevier Computer Communications 34(5):619-628, 2011. + """ + import numpy as np + + n = len(G) + + if n == 0: + raise nx.NetworkXException("Empty graph.") + if not nx.is_connected(G): + raise nx.NetworkXException("Non connected graph.") + if any(d.get(weight, 0) < 0 for u, v, d in G.edges(data=True)): + raise nx.NetworkXException("Graph has negative edge weights.") + + # balancing G for Metropolis-Hastings random walks + G = nx.DiGraph(G) + in_deg = dict(G.in_degree(weight=weight)) + d_max = max(in_deg.values()) + for i, deg in in_deg.items(): + if deg < d_max: + G.add_edge(i, i, weight=d_max - deg) + + P = nx.to_numpy_array(G) + P /= P.sum(axis=1)[:, np.newaxis] # to transition probability matrix + + def _Qj(P, j): + P = P.copy() + P[:, j] = 0 + return P + + M = np.empty([n, n]) + + for i in range(n): + M[:, i] = np.linalg.solve( + np.identity(n) - _Qj(P, i), np.ones([n, 1])[:, 0] + ) # eq 3 + + return dict( + zip( + G.nodes, + (float(np.sqrt(2 * np.sum(M[:, i]) - n * (n + 1))) for i in range(n)), + ) + ) # eq 6 diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/subgraph_alg.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/subgraph_alg.py new file mode 100644 index 0000000000000000000000000000000000000000..0a49e6f4dd79c88b11c8dd2d3c9a0a35f2d562d8 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/subgraph_alg.py @@ -0,0 +1,340 @@ +""" +Subraph centrality and communicability betweenness. +""" + +import networkx as nx +from networkx.utils import not_implemented_for + +__all__ = [ + "subgraph_centrality_exp", + "subgraph_centrality", + "communicability_betweenness_centrality", + "estrada_index", +] + + +@not_implemented_for("directed") +@not_implemented_for("multigraph") +@nx._dispatchable +def subgraph_centrality_exp(G): + r"""Returns the subgraph centrality for each node of G. + + Subgraph centrality of a node `n` is the sum of weighted closed + walks of all lengths starting and ending at node `n`. The weights + decrease with path length. Each closed walk is associated with a + connected subgraph ([1]_). + + Parameters + ---------- + G: graph + + Returns + ------- + nodes:dictionary + Dictionary of nodes with subgraph centrality as the value. + + Raises + ------ + NetworkXError + If the graph is not undirected and simple. + + See Also + -------- + subgraph_centrality: + Alternative algorithm of the subgraph centrality for each node of G. + + Notes + ----- + This version of the algorithm exponentiates the adjacency matrix. + + The subgraph centrality of a node `u` in G can be found using + the matrix exponential of the adjacency matrix of G [1]_, + + .. math:: + + SC(u)=(e^A)_{uu} . + + References + ---------- + .. [1] Ernesto Estrada, Juan A. Rodriguez-Velazquez, + "Subgraph centrality in complex networks", + Physical Review E 71, 056103 (2005). + https://arxiv.org/abs/cond-mat/0504730 + + Examples + -------- + (Example from [1]_) + >>> G = nx.Graph( + ... [ + ... (1, 2), + ... (1, 5), + ... (1, 8), + ... (2, 3), + ... (2, 8), + ... (3, 4), + ... (3, 6), + ... (4, 5), + ... (4, 7), + ... (5, 6), + ... (6, 7), + ... (7, 8), + ... ] + ... ) + >>> sc = nx.subgraph_centrality_exp(G) + >>> print([f"{node} {sc[node]:0.2f}" for node in sorted(sc)]) + ['1 3.90', '2 3.90', '3 3.64', '4 3.71', '5 3.64', '6 3.71', '7 3.64', '8 3.90'] + """ + # alternative implementation that calculates the matrix exponential + import scipy as sp + + nodelist = list(G) # ordering of nodes in matrix + A = nx.to_numpy_array(G, nodelist) + # convert to 0-1 matrix + A[A != 0.0] = 1 + expA = sp.linalg.expm(A) + # convert diagonal to dictionary keyed by node + sc = dict(zip(nodelist, map(float, expA.diagonal()))) + return sc + + +@not_implemented_for("directed") +@not_implemented_for("multigraph") +@nx._dispatchable +def subgraph_centrality(G): + r"""Returns subgraph centrality for each node in G. + + Subgraph centrality of a node `n` is the sum of weighted closed + walks of all lengths starting and ending at node `n`. The weights + decrease with path length. Each closed walk is associated with a + connected subgraph ([1]_). + + Parameters + ---------- + G: graph + + Returns + ------- + nodes : dictionary + Dictionary of nodes with subgraph centrality as the value. + + Raises + ------ + NetworkXError + If the graph is not undirected and simple. + + See Also + -------- + subgraph_centrality_exp: + Alternative algorithm of the subgraph centrality for each node of G. + + Notes + ----- + This version of the algorithm computes eigenvalues and eigenvectors + of the adjacency matrix. + + Subgraph centrality of a node `u` in G can be found using + a spectral decomposition of the adjacency matrix [1]_, + + .. math:: + + SC(u)=\sum_{j=1}^{N}(v_{j}^{u})^2 e^{\lambda_{j}}, + + where `v_j` is an eigenvector of the adjacency matrix `A` of G + corresponding to the eigenvalue `\lambda_j`. + + Examples + -------- + (Example from [1]_) + >>> G = nx.Graph( + ... [ + ... (1, 2), + ... (1, 5), + ... (1, 8), + ... (2, 3), + ... (2, 8), + ... (3, 4), + ... (3, 6), + ... (4, 5), + ... (4, 7), + ... (5, 6), + ... (6, 7), + ... (7, 8), + ... ] + ... ) + >>> sc = nx.subgraph_centrality(G) + >>> print([f"{node} {sc[node]:0.2f}" for node in sorted(sc)]) + ['1 3.90', '2 3.90', '3 3.64', '4 3.71', '5 3.64', '6 3.71', '7 3.64', '8 3.90'] + + References + ---------- + .. [1] Ernesto Estrada, Juan A. Rodriguez-Velazquez, + "Subgraph centrality in complex networks", + Physical Review E 71, 056103 (2005). + https://arxiv.org/abs/cond-mat/0504730 + + """ + import numpy as np + + nodelist = list(G) # ordering of nodes in matrix + A = nx.to_numpy_array(G, nodelist) + # convert to 0-1 matrix + A[np.nonzero(A)] = 1 + w, v = np.linalg.eigh(A) + vsquare = np.array(v) ** 2 + expw = np.exp(w) + xg = vsquare @ expw + # convert vector dictionary keyed by node + sc = dict(zip(nodelist, map(float, xg))) + return sc + + +@not_implemented_for("directed") +@not_implemented_for("multigraph") +@nx._dispatchable +def communicability_betweenness_centrality(G): + r"""Returns subgraph communicability for all pairs of nodes in G. + + Communicability betweenness measure makes use of the number of walks + connecting every pair of nodes as the basis of a betweenness centrality + measure. + + Parameters + ---------- + G: graph + + Returns + ------- + nodes : dictionary + Dictionary of nodes with communicability betweenness as the value. + + Raises + ------ + NetworkXError + If the graph is not undirected and simple. + + Notes + ----- + Let `G=(V,E)` be a simple undirected graph with `n` nodes and `m` edges, + and `A` denote the adjacency matrix of `G`. + + Let `G(r)=(V,E(r))` be the graph resulting from + removing all edges connected to node `r` but not the node itself. + + The adjacency matrix for `G(r)` is `A+E(r)`, where `E(r)` has nonzeros + only in row and column `r`. + + The subraph betweenness of a node `r` is [1]_ + + .. math:: + + \omega_{r} = \frac{1}{C}\sum_{p}\sum_{q}\frac{G_{prq}}{G_{pq}}, + p\neq q, q\neq r, + + where + `G_{prq}=(e^{A}_{pq} - (e^{A+E(r)})_{pq}` is the number of walks + involving node r, + `G_{pq}=(e^{A})_{pq}` is the number of closed walks starting + at node `p` and ending at node `q`, + and `C=(n-1)^{2}-(n-1)` is a normalization factor equal to the + number of terms in the sum. + + The resulting `\omega_{r}` takes values between zero and one. + The lower bound cannot be attained for a connected + graph, and the upper bound is attained in the star graph. + + References + ---------- + .. [1] Ernesto Estrada, Desmond J. Higham, Naomichi Hatano, + "Communicability Betweenness in Complex Networks" + Physica A 388 (2009) 764-774. + https://arxiv.org/abs/0905.4102 + + Examples + -------- + >>> G = nx.Graph([(0, 1), (1, 2), (1, 5), (5, 4), (2, 4), (2, 3), (4, 3), (3, 6)]) + >>> cbc = nx.communicability_betweenness_centrality(G) + >>> print([f"{node} {cbc[node]:0.2f}" for node in sorted(cbc)]) + ['0 0.03', '1 0.45', '2 0.51', '3 0.45', '4 0.40', '5 0.19', '6 0.03'] + """ + import numpy as np + import scipy as sp + + nodelist = list(G) # ordering of nodes in matrix + n = len(nodelist) + A = nx.to_numpy_array(G, nodelist) + # convert to 0-1 matrix + A[np.nonzero(A)] = 1 + expA = sp.linalg.expm(A) + mapping = dict(zip(nodelist, range(n))) + cbc = {} + for v in G: + # remove row and col of node v + i = mapping[v] + row = A[i, :].copy() + col = A[:, i].copy() + A[i, :] = 0 + A[:, i] = 0 + B = (expA - sp.linalg.expm(A)) / expA + # sum with row/col of node v and diag set to zero + B[i, :] = 0 + B[:, i] = 0 + B -= np.diag(np.diag(B)) + cbc[v] = float(B.sum()) + # put row and col back + A[i, :] = row + A[:, i] = col + # rescale when more than two nodes + order = len(cbc) + if order > 2: + scale = 1.0 / ((order - 1.0) ** 2 - (order - 1.0)) + cbc = {node: value * scale for node, value in cbc.items()} + return cbc + + +@nx._dispatchable +def estrada_index(G): + r"""Returns the Estrada index of a the graph G. + + The Estrada Index is a topological index of folding or 3D "compactness" ([1]_). + + Parameters + ---------- + G: graph + + Returns + ------- + estrada index: float + + Raises + ------ + NetworkXError + If the graph is not undirected and simple. + + Notes + ----- + Let `G=(V,E)` be a simple undirected graph with `n` nodes and let + `\lambda_{1}\leq\lambda_{2}\leq\cdots\lambda_{n}` + be a non-increasing ordering of the eigenvalues of its adjacency + matrix `A`. The Estrada index is ([1]_, [2]_) + + .. math:: + EE(G)=\sum_{j=1}^n e^{\lambda _j}. + + References + ---------- + .. [1] E. Estrada, "Characterization of 3D molecular structure", + Chem. Phys. Lett. 319, 713 (2000). + https://doi.org/10.1016/S0009-2614(00)00158-5 + .. [2] José Antonio de la Peñaa, Ivan Gutman, Juan Rada, + "Estimating the Estrada index", + Linear Algebra and its Applications. 427, 1 (2007). + https://doi.org/10.1016/j.laa.2007.06.020 + + Examples + -------- + >>> G = nx.Graph([(0, 1), (1, 2), (1, 5), (5, 4), (2, 4), (2, 3), (4, 3), (3, 6)]) + >>> ei = nx.estrada_index(G) + >>> print(f"{ei:0.5}") + 20.55 + """ + return sum(subgraph_centrality(G).values()) diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85e8b8d27bcd69acecbc50a8b5dee8bab6d0b7a6 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_betweenness_centrality.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_betweenness_centrality.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29e973d604cf110731fb69d98cdaac43252b2923 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_betweenness_centrality.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_betweenness_centrality_subset.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_betweenness_centrality_subset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43a846fc8dd1ed36231b8e8bc2a32879b218217a Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_betweenness_centrality_subset.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_closeness_centrality.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_closeness_centrality.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98b05be787e5f48709ed4f5019557e53eba60307 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_closeness_centrality.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_current_flow_betweenness_centrality.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_current_flow_betweenness_centrality.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4e6779b10e7de501e7af4862078d6be92139bd1 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_current_flow_betweenness_centrality.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_current_flow_betweenness_centrality_subset.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_current_flow_betweenness_centrality_subset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a72318473ba0b0632c76d8dfaf29eddbf1dd20ae Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_current_flow_betweenness_centrality_subset.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_current_flow_closeness.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_current_flow_closeness.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb774bb713bfc50e73b838c5bbb9fd53986b8fab Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_current_flow_closeness.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_degree_centrality.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_degree_centrality.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32e7a7e7c90a73a5d802c7ef384194b0f4be3733 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_degree_centrality.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_dispersion.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_dispersion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e02575c7986512b9369de27f27445d33d2ffca66 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_dispersion.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_eigenvector_centrality.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_eigenvector_centrality.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..329fd1ab93a819eab80e3bbd4661dd26c2485946 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_eigenvector_centrality.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_group.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_group.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d217ed7f3012cb0daa89a819686acfd7006d1da Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_group.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_harmonic_centrality.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_harmonic_centrality.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c27b062f8627b6eaf1f237b3bf1a893f56065ee5 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_harmonic_centrality.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_katz_centrality.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_katz_centrality.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0dae8690764c8b08cd6b1c8a3815d014ee7b9705 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_katz_centrality.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_laplacian_centrality.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_laplacian_centrality.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4741dc651acf5c006a052b53832419534efda9bd Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_laplacian_centrality.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_load_centrality.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_load_centrality.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d8285a1910410b05ad05add9f20032d290dd95c Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_load_centrality.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_percolation_centrality.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_percolation_centrality.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3a2b1344a9329ba53d36ce8ad6165b52c2d3b0d Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_percolation_centrality.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_reaching.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_reaching.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..022e23ecfdd19c14c7e5e8cdbfa8109a7e517b14 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_reaching.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_second_order_centrality.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_second_order_centrality.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1509f8a47a55260195a3ba33315f4efdb8b436d5 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_second_order_centrality.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_subgraph.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_subgraph.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7affdf91e169fb03e02768450e29e5b53fdc7633 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_subgraph.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_trophic.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_trophic.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14201cc97473755adfe3cd2e6538224afb26221f Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_trophic.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_voterank.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_voterank.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ddf336bc6ddc2c1f10922ce2f4b1437f5718dd4b Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/__pycache__/test_voterank.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_betweenness_centrality.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_betweenness_centrality.py new file mode 100644 index 0000000000000000000000000000000000000000..4c059cf980666f7e14a80929f84c80bc38749432 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_betweenness_centrality.py @@ -0,0 +1,780 @@ +import pytest + +import networkx as nx + + +def weighted_G(): + G = nx.Graph() + G.add_edge(0, 1, weight=3) + G.add_edge(0, 2, weight=2) + G.add_edge(0, 3, weight=6) + G.add_edge(0, 4, weight=4) + G.add_edge(1, 3, weight=5) + G.add_edge(1, 5, weight=5) + G.add_edge(2, 4, weight=1) + G.add_edge(3, 4, weight=2) + G.add_edge(3, 5, weight=1) + G.add_edge(4, 5, weight=4) + return G + + +class TestBetweennessCentrality: + def test_K5(self): + """Betweenness centrality: K5""" + G = nx.complete_graph(5) + b = nx.betweenness_centrality(G, weight=None, normalized=False) + b_answer = {0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0, 4: 0.0} + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_K5_endpoints(self): + """Betweenness centrality: K5 endpoints""" + G = nx.complete_graph(5) + b = nx.betweenness_centrality(G, weight=None, normalized=False, endpoints=True) + b_answer = {0: 4.0, 1: 4.0, 2: 4.0, 3: 4.0, 4: 4.0} + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + # normalized = True case + b = nx.betweenness_centrality(G, weight=None, normalized=True, endpoints=True) + b_answer = {0: 0.4, 1: 0.4, 2: 0.4, 3: 0.4, 4: 0.4} + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_P3_normalized(self): + """Betweenness centrality: P3 normalized""" + G = nx.path_graph(3) + b = nx.betweenness_centrality(G, weight=None, normalized=True) + b_answer = {0: 0.0, 1: 1.0, 2: 0.0} + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_P3(self): + """Betweenness centrality: P3""" + G = nx.path_graph(3) + b_answer = {0: 0.0, 1: 1.0, 2: 0.0} + b = nx.betweenness_centrality(G, weight=None, normalized=False) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_sample_from_P3(self): + """Betweenness centrality: P3 sample""" + G = nx.path_graph(3) + b_answer = {0: 0.0, 1: 1.0, 2: 0.0} + b = nx.betweenness_centrality(G, k=3, weight=None, normalized=False, seed=1) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + b = nx.betweenness_centrality(G, k=2, weight=None, normalized=False, seed=1) + # python versions give different results with same seed + b_approx1 = {0: 0.0, 1: 1.5, 2: 0.0} + b_approx2 = {0: 0.0, 1: 0.75, 2: 0.0} + for n in sorted(G): + assert b[n] in (b_approx1[n], b_approx2[n]) + + def test_P3_endpoints(self): + """Betweenness centrality: P3 endpoints""" + G = nx.path_graph(3) + b_answer = {0: 2.0, 1: 3.0, 2: 2.0} + b = nx.betweenness_centrality(G, weight=None, normalized=False, endpoints=True) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + # normalized = True case + b_answer = {0: 2 / 3, 1: 1.0, 2: 2 / 3} + b = nx.betweenness_centrality(G, weight=None, normalized=True, endpoints=True) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_krackhardt_kite_graph(self): + """Betweenness centrality: Krackhardt kite graph""" + G = nx.krackhardt_kite_graph() + b_answer = { + 0: 1.667, + 1: 1.667, + 2: 0.000, + 3: 7.333, + 4: 0.000, + 5: 16.667, + 6: 16.667, + 7: 28.000, + 8: 16.000, + 9: 0.000, + } + for b in b_answer: + b_answer[b] /= 2 + b = nx.betweenness_centrality(G, weight=None, normalized=False) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-3) + + def test_krackhardt_kite_graph_normalized(self): + """Betweenness centrality: Krackhardt kite graph normalized""" + G = nx.krackhardt_kite_graph() + b_answer = { + 0: 0.023, + 1: 0.023, + 2: 0.000, + 3: 0.102, + 4: 0.000, + 5: 0.231, + 6: 0.231, + 7: 0.389, + 8: 0.222, + 9: 0.000, + } + b = nx.betweenness_centrality(G, weight=None, normalized=True) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-3) + + def test_florentine_families_graph(self): + """Betweenness centrality: Florentine families graph""" + G = nx.florentine_families_graph() + b_answer = { + "Acciaiuoli": 0.000, + "Albizzi": 0.212, + "Barbadori": 0.093, + "Bischeri": 0.104, + "Castellani": 0.055, + "Ginori": 0.000, + "Guadagni": 0.255, + "Lamberteschi": 0.000, + "Medici": 0.522, + "Pazzi": 0.000, + "Peruzzi": 0.022, + "Ridolfi": 0.114, + "Salviati": 0.143, + "Strozzi": 0.103, + "Tornabuoni": 0.092, + } + + b = nx.betweenness_centrality(G, weight=None, normalized=True) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-3) + + def test_les_miserables_graph(self): + """Betweenness centrality: Les Miserables graph""" + G = nx.les_miserables_graph() + b_answer = { + "Napoleon": 0.000, + "Myriel": 0.177, + "MlleBaptistine": 0.000, + "MmeMagloire": 0.000, + "CountessDeLo": 0.000, + "Geborand": 0.000, + "Champtercier": 0.000, + "Cravatte": 0.000, + "Count": 0.000, + "OldMan": 0.000, + "Valjean": 0.570, + "Labarre": 0.000, + "Marguerite": 0.000, + "MmeDeR": 0.000, + "Isabeau": 0.000, + "Gervais": 0.000, + "Listolier": 0.000, + "Tholomyes": 0.041, + "Fameuil": 0.000, + "Blacheville": 0.000, + "Favourite": 0.000, + "Dahlia": 0.000, + "Zephine": 0.000, + "Fantine": 0.130, + "MmeThenardier": 0.029, + "Thenardier": 0.075, + "Cosette": 0.024, + "Javert": 0.054, + "Fauchelevent": 0.026, + "Bamatabois": 0.008, + "Perpetue": 0.000, + "Simplice": 0.009, + "Scaufflaire": 0.000, + "Woman1": 0.000, + "Judge": 0.000, + "Champmathieu": 0.000, + "Brevet": 0.000, + "Chenildieu": 0.000, + "Cochepaille": 0.000, + "Pontmercy": 0.007, + "Boulatruelle": 0.000, + "Eponine": 0.011, + "Anzelma": 0.000, + "Woman2": 0.000, + "MotherInnocent": 0.000, + "Gribier": 0.000, + "MmeBurgon": 0.026, + "Jondrette": 0.000, + "Gavroche": 0.165, + "Gillenormand": 0.020, + "Magnon": 0.000, + "MlleGillenormand": 0.048, + "MmePontmercy": 0.000, + "MlleVaubois": 0.000, + "LtGillenormand": 0.000, + "Marius": 0.132, + "BaronessT": 0.000, + "Mabeuf": 0.028, + "Enjolras": 0.043, + "Combeferre": 0.001, + "Prouvaire": 0.000, + "Feuilly": 0.001, + "Courfeyrac": 0.005, + "Bahorel": 0.002, + "Bossuet": 0.031, + "Joly": 0.002, + "Grantaire": 0.000, + "MotherPlutarch": 0.000, + "Gueulemer": 0.005, + "Babet": 0.005, + "Claquesous": 0.005, + "Montparnasse": 0.004, + "Toussaint": 0.000, + "Child1": 0.000, + "Child2": 0.000, + "Brujon": 0.000, + "MmeHucheloup": 0.000, + } + + b = nx.betweenness_centrality(G, weight=None, normalized=True) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-3) + + def test_ladder_graph(self): + """Betweenness centrality: Ladder graph""" + G = nx.Graph() # ladder_graph(3) + G.add_edges_from([(0, 1), (0, 2), (1, 3), (2, 3), (2, 4), (4, 5), (3, 5)]) + b_answer = {0: 1.667, 1: 1.667, 2: 6.667, 3: 6.667, 4: 1.667, 5: 1.667} + for b in b_answer: + b_answer[b] /= 2 + b = nx.betweenness_centrality(G, weight=None, normalized=False) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-3) + + def test_disconnected_path(self): + """Betweenness centrality: disconnected path""" + G = nx.Graph() + nx.add_path(G, [0, 1, 2]) + nx.add_path(G, [3, 4, 5, 6]) + b_answer = {0: 0, 1: 1, 2: 0, 3: 0, 4: 2, 5: 2, 6: 0} + b = nx.betweenness_centrality(G, weight=None, normalized=False) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_disconnected_path_endpoints(self): + """Betweenness centrality: disconnected path endpoints""" + G = nx.Graph() + nx.add_path(G, [0, 1, 2]) + nx.add_path(G, [3, 4, 5, 6]) + b_answer = {0: 2, 1: 3, 2: 2, 3: 3, 4: 5, 5: 5, 6: 3} + b = nx.betweenness_centrality(G, weight=None, normalized=False, endpoints=True) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + # normalized = True case + b = nx.betweenness_centrality(G, weight=None, normalized=True, endpoints=True) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n] / 21, abs=1e-7) + + def test_directed_path(self): + """Betweenness centrality: directed path""" + G = nx.DiGraph() + nx.add_path(G, [0, 1, 2]) + b = nx.betweenness_centrality(G, weight=None, normalized=False) + b_answer = {0: 0.0, 1: 1.0, 2: 0.0} + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_directed_path_normalized(self): + """Betweenness centrality: directed path normalized""" + G = nx.DiGraph() + nx.add_path(G, [0, 1, 2]) + b = nx.betweenness_centrality(G, weight=None, normalized=True) + b_answer = {0: 0.0, 1: 0.5, 2: 0.0} + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + +class TestWeightedBetweennessCentrality: + def test_K5(self): + """Weighted betweenness centrality: K5""" + G = nx.complete_graph(5) + b = nx.betweenness_centrality(G, weight="weight", normalized=False) + b_answer = {0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0, 4: 0.0} + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_P3_normalized(self): + """Weighted betweenness centrality: P3 normalized""" + G = nx.path_graph(3) + b = nx.betweenness_centrality(G, weight="weight", normalized=True) + b_answer = {0: 0.0, 1: 1.0, 2: 0.0} + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_P3(self): + """Weighted betweenness centrality: P3""" + G = nx.path_graph(3) + b_answer = {0: 0.0, 1: 1.0, 2: 0.0} + b = nx.betweenness_centrality(G, weight="weight", normalized=False) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_krackhardt_kite_graph(self): + """Weighted betweenness centrality: Krackhardt kite graph""" + G = nx.krackhardt_kite_graph() + b_answer = { + 0: 1.667, + 1: 1.667, + 2: 0.000, + 3: 7.333, + 4: 0.000, + 5: 16.667, + 6: 16.667, + 7: 28.000, + 8: 16.000, + 9: 0.000, + } + for b in b_answer: + b_answer[b] /= 2 + + b = nx.betweenness_centrality(G, weight="weight", normalized=False) + + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-3) + + def test_krackhardt_kite_graph_normalized(self): + """Weighted betweenness centrality: + Krackhardt kite graph normalized + """ + G = nx.krackhardt_kite_graph() + b_answer = { + 0: 0.023, + 1: 0.023, + 2: 0.000, + 3: 0.102, + 4: 0.000, + 5: 0.231, + 6: 0.231, + 7: 0.389, + 8: 0.222, + 9: 0.000, + } + b = nx.betweenness_centrality(G, weight="weight", normalized=True) + + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-3) + + def test_florentine_families_graph(self): + """Weighted betweenness centrality: + Florentine families graph""" + G = nx.florentine_families_graph() + b_answer = { + "Acciaiuoli": 0.000, + "Albizzi": 0.212, + "Barbadori": 0.093, + "Bischeri": 0.104, + "Castellani": 0.055, + "Ginori": 0.000, + "Guadagni": 0.255, + "Lamberteschi": 0.000, + "Medici": 0.522, + "Pazzi": 0.000, + "Peruzzi": 0.022, + "Ridolfi": 0.114, + "Salviati": 0.143, + "Strozzi": 0.103, + "Tornabuoni": 0.092, + } + + b = nx.betweenness_centrality(G, weight="weight", normalized=True) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-3) + + def test_les_miserables_graph(self): + """Weighted betweenness centrality: Les Miserables graph""" + G = nx.les_miserables_graph() + b_answer = { + "Napoleon": 0.000, + "Myriel": 0.177, + "MlleBaptistine": 0.000, + "MmeMagloire": 0.000, + "CountessDeLo": 0.000, + "Geborand": 0.000, + "Champtercier": 0.000, + "Cravatte": 0.000, + "Count": 0.000, + "OldMan": 0.000, + "Valjean": 0.454, + "Labarre": 0.000, + "Marguerite": 0.009, + "MmeDeR": 0.000, + "Isabeau": 0.000, + "Gervais": 0.000, + "Listolier": 0.000, + "Tholomyes": 0.066, + "Fameuil": 0.000, + "Blacheville": 0.000, + "Favourite": 0.000, + "Dahlia": 0.000, + "Zephine": 0.000, + "Fantine": 0.114, + "MmeThenardier": 0.046, + "Thenardier": 0.129, + "Cosette": 0.075, + "Javert": 0.193, + "Fauchelevent": 0.026, + "Bamatabois": 0.080, + "Perpetue": 0.000, + "Simplice": 0.001, + "Scaufflaire": 0.000, + "Woman1": 0.000, + "Judge": 0.000, + "Champmathieu": 0.000, + "Brevet": 0.000, + "Chenildieu": 0.000, + "Cochepaille": 0.000, + "Pontmercy": 0.023, + "Boulatruelle": 0.000, + "Eponine": 0.023, + "Anzelma": 0.000, + "Woman2": 0.000, + "MotherInnocent": 0.000, + "Gribier": 0.000, + "MmeBurgon": 0.026, + "Jondrette": 0.000, + "Gavroche": 0.285, + "Gillenormand": 0.024, + "Magnon": 0.005, + "MlleGillenormand": 0.036, + "MmePontmercy": 0.005, + "MlleVaubois": 0.000, + "LtGillenormand": 0.015, + "Marius": 0.072, + "BaronessT": 0.004, + "Mabeuf": 0.089, + "Enjolras": 0.003, + "Combeferre": 0.000, + "Prouvaire": 0.000, + "Feuilly": 0.004, + "Courfeyrac": 0.001, + "Bahorel": 0.007, + "Bossuet": 0.028, + "Joly": 0.000, + "Grantaire": 0.036, + "MotherPlutarch": 0.000, + "Gueulemer": 0.025, + "Babet": 0.015, + "Claquesous": 0.042, + "Montparnasse": 0.050, + "Toussaint": 0.011, + "Child1": 0.000, + "Child2": 0.000, + "Brujon": 0.002, + "MmeHucheloup": 0.034, + } + + b = nx.betweenness_centrality(G, weight="weight", normalized=True) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-3) + + def test_ladder_graph(self): + """Weighted betweenness centrality: Ladder graph""" + G = nx.Graph() # ladder_graph(3) + G.add_edges_from([(0, 1), (0, 2), (1, 3), (2, 3), (2, 4), (4, 5), (3, 5)]) + b_answer = {0: 1.667, 1: 1.667, 2: 6.667, 3: 6.667, 4: 1.667, 5: 1.667} + for b in b_answer: + b_answer[b] /= 2 + b = nx.betweenness_centrality(G, weight="weight", normalized=False) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-3) + + def test_G(self): + """Weighted betweenness centrality: G""" + G = weighted_G() + b_answer = {0: 2.0, 1: 0.0, 2: 4.0, 3: 3.0, 4: 4.0, 5: 0.0} + b = nx.betweenness_centrality(G, weight="weight", normalized=False) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_G2(self): + """Weighted betweenness centrality: G2""" + G = nx.DiGraph() + G.add_weighted_edges_from( + [ + ("s", "u", 10), + ("s", "x", 5), + ("u", "v", 1), + ("u", "x", 2), + ("v", "y", 1), + ("x", "u", 3), + ("x", "v", 5), + ("x", "y", 2), + ("y", "s", 7), + ("y", "v", 6), + ] + ) + + b_answer = {"y": 5.0, "x": 5.0, "s": 4.0, "u": 2.0, "v": 2.0} + + b = nx.betweenness_centrality(G, weight="weight", normalized=False) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_G3(self): + """Weighted betweenness centrality: G3""" + G = nx.MultiGraph(weighted_G()) + es = list(G.edges(data=True))[::2] # duplicate every other edge + G.add_edges_from(es) + b_answer = {0: 2.0, 1: 0.0, 2: 4.0, 3: 3.0, 4: 4.0, 5: 0.0} + b = nx.betweenness_centrality(G, weight="weight", normalized=False) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_G4(self): + """Weighted betweenness centrality: G4""" + G = nx.MultiDiGraph() + G.add_weighted_edges_from( + [ + ("s", "u", 10), + ("s", "x", 5), + ("s", "x", 6), + ("u", "v", 1), + ("u", "x", 2), + ("v", "y", 1), + ("v", "y", 1), + ("x", "u", 3), + ("x", "v", 5), + ("x", "y", 2), + ("x", "y", 3), + ("y", "s", 7), + ("y", "v", 6), + ("y", "v", 6), + ] + ) + + b_answer = {"y": 5.0, "x": 5.0, "s": 4.0, "u": 2.0, "v": 2.0} + + b = nx.betweenness_centrality(G, weight="weight", normalized=False) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + +class TestEdgeBetweennessCentrality: + def test_K5(self): + """Edge betweenness centrality: K5""" + G = nx.complete_graph(5) + b = nx.edge_betweenness_centrality(G, weight=None, normalized=False) + b_answer = dict.fromkeys(G.edges(), 1) + for n in sorted(G.edges()): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_normalized_K5(self): + """Edge betweenness centrality: K5""" + G = nx.complete_graph(5) + b = nx.edge_betweenness_centrality(G, weight=None, normalized=True) + b_answer = dict.fromkeys(G.edges(), 1 / 10) + for n in sorted(G.edges()): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_C4(self): + """Edge betweenness centrality: C4""" + G = nx.cycle_graph(4) + b = nx.edge_betweenness_centrality(G, weight=None, normalized=True) + b_answer = {(0, 1): 2, (0, 3): 2, (1, 2): 2, (2, 3): 2} + for n in sorted(G.edges()): + assert b[n] == pytest.approx(b_answer[n] / 6, abs=1e-7) + + def test_P4(self): + """Edge betweenness centrality: P4""" + G = nx.path_graph(4) + b = nx.edge_betweenness_centrality(G, weight=None, normalized=False) + b_answer = {(0, 1): 3, (1, 2): 4, (2, 3): 3} + for n in sorted(G.edges()): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_normalized_P4(self): + """Edge betweenness centrality: P4""" + G = nx.path_graph(4) + b = nx.edge_betweenness_centrality(G, weight=None, normalized=True) + b_answer = {(0, 1): 3, (1, 2): 4, (2, 3): 3} + for n in sorted(G.edges()): + assert b[n] == pytest.approx(b_answer[n] / 6, abs=1e-7) + + def test_balanced_tree(self): + """Edge betweenness centrality: balanced tree""" + G = nx.balanced_tree(r=2, h=2) + b = nx.edge_betweenness_centrality(G, weight=None, normalized=False) + b_answer = {(0, 1): 12, (0, 2): 12, (1, 3): 6, (1, 4): 6, (2, 5): 6, (2, 6): 6} + for n in sorted(G.edges()): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + +class TestWeightedEdgeBetweennessCentrality: + def test_K5(self): + """Edge betweenness centrality: K5""" + G = nx.complete_graph(5) + b = nx.edge_betweenness_centrality(G, weight="weight", normalized=False) + b_answer = dict.fromkeys(G.edges(), 1) + for n in sorted(G.edges()): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_C4(self): + """Edge betweenness centrality: C4""" + G = nx.cycle_graph(4) + b = nx.edge_betweenness_centrality(G, weight="weight", normalized=False) + b_answer = {(0, 1): 2, (0, 3): 2, (1, 2): 2, (2, 3): 2} + for n in sorted(G.edges()): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_P4(self): + """Edge betweenness centrality: P4""" + G = nx.path_graph(4) + b = nx.edge_betweenness_centrality(G, weight="weight", normalized=False) + b_answer = {(0, 1): 3, (1, 2): 4, (2, 3): 3} + for n in sorted(G.edges()): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_balanced_tree(self): + """Edge betweenness centrality: balanced tree""" + G = nx.balanced_tree(r=2, h=2) + b = nx.edge_betweenness_centrality(G, weight="weight", normalized=False) + b_answer = {(0, 1): 12, (0, 2): 12, (1, 3): 6, (1, 4): 6, (2, 5): 6, (2, 6): 6} + for n in sorted(G.edges()): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_weighted_graph(self): + """Edge betweenness centrality: weighted""" + eList = [ + (0, 1, 5), + (0, 2, 4), + (0, 3, 3), + (0, 4, 2), + (1, 2, 4), + (1, 3, 1), + (1, 4, 3), + (2, 4, 5), + (3, 4, 4), + ] + G = nx.Graph() + G.add_weighted_edges_from(eList) + b = nx.edge_betweenness_centrality(G, weight="weight", normalized=False) + b_answer = { + (0, 1): 0.0, + (0, 2): 1.0, + (0, 3): 2.0, + (0, 4): 1.0, + (1, 2): 2.0, + (1, 3): 3.5, + (1, 4): 1.5, + (2, 4): 1.0, + (3, 4): 0.5, + } + for n in sorted(G.edges()): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_normalized_weighted_graph(self): + """Edge betweenness centrality: normalized weighted""" + eList = [ + (0, 1, 5), + (0, 2, 4), + (0, 3, 3), + (0, 4, 2), + (1, 2, 4), + (1, 3, 1), + (1, 4, 3), + (2, 4, 5), + (3, 4, 4), + ] + G = nx.Graph() + G.add_weighted_edges_from(eList) + b = nx.edge_betweenness_centrality(G, weight="weight", normalized=True) + b_answer = { + (0, 1): 0.0, + (0, 2): 1.0, + (0, 3): 2.0, + (0, 4): 1.0, + (1, 2): 2.0, + (1, 3): 3.5, + (1, 4): 1.5, + (2, 4): 1.0, + (3, 4): 0.5, + } + norm = len(G) * (len(G) - 1) / 2 + for n in sorted(G.edges()): + assert b[n] == pytest.approx(b_answer[n] / norm, abs=1e-7) + + def test_weighted_multigraph(self): + """Edge betweenness centrality: weighted multigraph""" + eList = [ + (0, 1, 5), + (0, 1, 4), + (0, 2, 4), + (0, 3, 3), + (0, 3, 3), + (0, 4, 2), + (1, 2, 4), + (1, 3, 1), + (1, 3, 2), + (1, 4, 3), + (1, 4, 4), + (2, 4, 5), + (3, 4, 4), + (3, 4, 4), + ] + G = nx.MultiGraph() + G.add_weighted_edges_from(eList) + b = nx.edge_betweenness_centrality(G, weight="weight", normalized=False) + b_answer = { + (0, 1, 0): 0.0, + (0, 1, 1): 0.5, + (0, 2, 0): 1.0, + (0, 3, 0): 0.75, + (0, 3, 1): 0.75, + (0, 4, 0): 1.0, + (1, 2, 0): 2.0, + (1, 3, 0): 3.0, + (1, 3, 1): 0.0, + (1, 4, 0): 1.5, + (1, 4, 1): 0.0, + (2, 4, 0): 1.0, + (3, 4, 0): 0.25, + (3, 4, 1): 0.25, + } + for n in sorted(G.edges(keys=True)): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_normalized_weighted_multigraph(self): + """Edge betweenness centrality: normalized weighted multigraph""" + eList = [ + (0, 1, 5), + (0, 1, 4), + (0, 2, 4), + (0, 3, 3), + (0, 3, 3), + (0, 4, 2), + (1, 2, 4), + (1, 3, 1), + (1, 3, 2), + (1, 4, 3), + (1, 4, 4), + (2, 4, 5), + (3, 4, 4), + (3, 4, 4), + ] + G = nx.MultiGraph() + G.add_weighted_edges_from(eList) + b = nx.edge_betweenness_centrality(G, weight="weight", normalized=True) + b_answer = { + (0, 1, 0): 0.0, + (0, 1, 1): 0.5, + (0, 2, 0): 1.0, + (0, 3, 0): 0.75, + (0, 3, 1): 0.75, + (0, 4, 0): 1.0, + (1, 2, 0): 2.0, + (1, 3, 0): 3.0, + (1, 3, 1): 0.0, + (1, 4, 0): 1.5, + (1, 4, 1): 0.0, + (2, 4, 0): 1.0, + (3, 4, 0): 0.25, + (3, 4, 1): 0.25, + } + norm = len(G) * (len(G) - 1) / 2 + for n in sorted(G.edges(keys=True)): + assert b[n] == pytest.approx(b_answer[n] / norm, abs=1e-7) diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_betweenness_centrality_subset.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_betweenness_centrality_subset.py new file mode 100644 index 0000000000000000000000000000000000000000..a35a401a28e31d279c0d715f79f8a7cc5738050f --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_betweenness_centrality_subset.py @@ -0,0 +1,340 @@ +import pytest + +import networkx as nx + + +class TestSubsetBetweennessCentrality: + def test_K5(self): + """Betweenness Centrality Subset: K5""" + G = nx.complete_graph(5) + b = nx.betweenness_centrality_subset( + G, sources=[0], targets=[1, 3], weight=None + ) + b_answer = {0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0, 4: 0.0} + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_P5_directed(self): + """Betweenness Centrality Subset: P5 directed""" + G = nx.DiGraph() + nx.add_path(G, range(5)) + b_answer = {0: 0, 1: 1, 2: 1, 3: 0, 4: 0, 5: 0} + b = nx.betweenness_centrality_subset(G, sources=[0], targets=[3], weight=None) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_P5(self): + """Betweenness Centrality Subset: P5""" + G = nx.Graph() + nx.add_path(G, range(5)) + b_answer = {0: 0, 1: 0.5, 2: 0.5, 3: 0, 4: 0, 5: 0} + b = nx.betweenness_centrality_subset(G, sources=[0], targets=[3], weight=None) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_P5_multiple_target(self): + """Betweenness Centrality Subset: P5 multiple target""" + G = nx.Graph() + nx.add_path(G, range(5)) + b_answer = {0: 0, 1: 1, 2: 1, 3: 0.5, 4: 0, 5: 0} + b = nx.betweenness_centrality_subset( + G, sources=[0], targets=[3, 4], weight=None + ) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_box(self): + """Betweenness Centrality Subset: box""" + G = nx.Graph() + G.add_edges_from([(0, 1), (0, 2), (1, 3), (2, 3)]) + b_answer = {0: 0, 1: 0.25, 2: 0.25, 3: 0} + b = nx.betweenness_centrality_subset(G, sources=[0], targets=[3], weight=None) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_box_and_path(self): + """Betweenness Centrality Subset: box and path""" + G = nx.Graph() + G.add_edges_from([(0, 1), (0, 2), (1, 3), (2, 3), (3, 4), (4, 5)]) + b_answer = {0: 0, 1: 0.5, 2: 0.5, 3: 0.5, 4: 0, 5: 0} + b = nx.betweenness_centrality_subset( + G, sources=[0], targets=[3, 4], weight=None + ) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_box_and_path2(self): + """Betweenness Centrality Subset: box and path multiple target""" + G = nx.Graph() + G.add_edges_from([(0, 1), (1, 2), (2, 3), (1, 20), (20, 3), (3, 4)]) + b_answer = {0: 0, 1: 1.0, 2: 0.5, 20: 0.5, 3: 0.5, 4: 0} + b = nx.betweenness_centrality_subset( + G, sources=[0], targets=[3, 4], weight=None + ) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_diamond_multi_path(self): + """Betweenness Centrality Subset: Diamond Multi Path""" + G = nx.Graph() + G.add_edges_from( + [ + (1, 2), + (1, 3), + (1, 4), + (1, 5), + (1, 10), + (10, 11), + (11, 12), + (12, 9), + (2, 6), + (3, 6), + (4, 6), + (5, 7), + (7, 8), + (6, 8), + (8, 9), + ] + ) + b = nx.betweenness_centrality_subset(G, sources=[1], targets=[9], weight=None) + + expected_b = { + 1: 0, + 2: 1.0 / 10, + 3: 1.0 / 10, + 4: 1.0 / 10, + 5: 1.0 / 10, + 6: 3.0 / 10, + 7: 1.0 / 10, + 8: 4.0 / 10, + 9: 0, + 10: 1.0 / 10, + 11: 1.0 / 10, + 12: 1.0 / 10, + } + + for n in sorted(G): + assert b[n] == pytest.approx(expected_b[n], abs=1e-7) + + def test_normalized_p2(self): + """ + Betweenness Centrality Subset: Normalized P2 + if n <= 2: no normalization, betweenness centrality should be 0 for all nodes. + """ + G = nx.Graph() + nx.add_path(G, range(2)) + b_answer = {0: 0, 1: 0.0} + b = nx.betweenness_centrality_subset( + G, sources=[0], targets=[1], normalized=True, weight=None + ) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_normalized_P5_directed(self): + """Betweenness Centrality Subset: Normalized Directed P5""" + G = nx.DiGraph() + nx.add_path(G, range(5)) + b_answer = {0: 0, 1: 1.0 / 12.0, 2: 1.0 / 12.0, 3: 0, 4: 0, 5: 0} + b = nx.betweenness_centrality_subset( + G, sources=[0], targets=[3], normalized=True, weight=None + ) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_weighted_graph(self): + """Betweenness Centrality Subset: Weighted Graph""" + G = nx.DiGraph() + G.add_edge(0, 1, weight=3) + G.add_edge(0, 2, weight=2) + G.add_edge(0, 3, weight=6) + G.add_edge(0, 4, weight=4) + G.add_edge(1, 3, weight=5) + G.add_edge(1, 5, weight=5) + G.add_edge(2, 4, weight=1) + G.add_edge(3, 4, weight=2) + G.add_edge(3, 5, weight=1) + G.add_edge(4, 5, weight=4) + b_answer = {0: 0.0, 1: 0.0, 2: 0.5, 3: 0.5, 4: 0.5, 5: 0.0} + b = nx.betweenness_centrality_subset( + G, sources=[0], targets=[5], normalized=False, weight="weight" + ) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + +class TestEdgeSubsetBetweennessCentrality: + def test_K5(self): + """Edge betweenness subset centrality: K5""" + G = nx.complete_graph(5) + b = nx.edge_betweenness_centrality_subset( + G, sources=[0], targets=[1, 3], weight=None + ) + b_answer = dict.fromkeys(G.edges(), 0) + b_answer[(0, 3)] = b_answer[(0, 1)] = 0.5 + for n in sorted(G.edges()): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_P5_directed(self): + """Edge betweenness subset centrality: P5 directed""" + G = nx.DiGraph() + nx.add_path(G, range(5)) + b_answer = dict.fromkeys(G.edges(), 0) + b_answer[(0, 1)] = b_answer[(1, 2)] = b_answer[(2, 3)] = 1 + b = nx.edge_betweenness_centrality_subset( + G, sources=[0], targets=[3], weight=None + ) + for n in sorted(G.edges()): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_P5(self): + """Edge betweenness subset centrality: P5""" + G = nx.Graph() + nx.add_path(G, range(5)) + b_answer = dict.fromkeys(G.edges(), 0) + b_answer[(0, 1)] = b_answer[(1, 2)] = b_answer[(2, 3)] = 0.5 + b = nx.edge_betweenness_centrality_subset( + G, sources=[0], targets=[3], weight=None + ) + for n in sorted(G.edges()): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_P5_multiple_target(self): + """Edge betweenness subset centrality: P5 multiple target""" + G = nx.Graph() + nx.add_path(G, range(5)) + b_answer = dict.fromkeys(G.edges(), 0) + b_answer[(0, 1)] = b_answer[(1, 2)] = b_answer[(2, 3)] = 1 + b_answer[(3, 4)] = 0.5 + b = nx.edge_betweenness_centrality_subset( + G, sources=[0], targets=[3, 4], weight=None + ) + for n in sorted(G.edges()): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_box(self): + """Edge betweenness subset centrality: box""" + G = nx.Graph() + G.add_edges_from([(0, 1), (0, 2), (1, 3), (2, 3)]) + b_answer = dict.fromkeys(G.edges(), 0) + b_answer[(0, 1)] = b_answer[(0, 2)] = 0.25 + b_answer[(1, 3)] = b_answer[(2, 3)] = 0.25 + b = nx.edge_betweenness_centrality_subset( + G, sources=[0], targets=[3], weight=None + ) + for n in sorted(G.edges()): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_box_and_path(self): + """Edge betweenness subset centrality: box and path""" + G = nx.Graph() + G.add_edges_from([(0, 1), (0, 2), (1, 3), (2, 3), (3, 4), (4, 5)]) + b_answer = dict.fromkeys(G.edges(), 0) + b_answer[(0, 1)] = b_answer[(0, 2)] = 0.5 + b_answer[(1, 3)] = b_answer[(2, 3)] = 0.5 + b_answer[(3, 4)] = 0.5 + b = nx.edge_betweenness_centrality_subset( + G, sources=[0], targets=[3, 4], weight=None + ) + for n in sorted(G.edges()): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_box_and_path2(self): + """Edge betweenness subset centrality: box and path multiple target""" + G = nx.Graph() + G.add_edges_from([(0, 1), (1, 2), (2, 3), (1, 20), (20, 3), (3, 4)]) + b_answer = dict.fromkeys(G.edges(), 0) + b_answer[(0, 1)] = 1.0 + b_answer[(1, 20)] = b_answer[(3, 20)] = 0.5 + b_answer[(1, 2)] = b_answer[(2, 3)] = 0.5 + b_answer[(3, 4)] = 0.5 + b = nx.edge_betweenness_centrality_subset( + G, sources=[0], targets=[3, 4], weight=None + ) + for n in sorted(G.edges()): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_diamond_multi_path(self): + """Edge betweenness subset centrality: Diamond Multi Path""" + G = nx.Graph() + G.add_edges_from( + [ + (1, 2), + (1, 3), + (1, 4), + (1, 5), + (1, 10), + (10, 11), + (11, 12), + (12, 9), + (2, 6), + (3, 6), + (4, 6), + (5, 7), + (7, 8), + (6, 8), + (8, 9), + ] + ) + b_answer = dict.fromkeys(G.edges(), 0) + b_answer[(8, 9)] = 0.4 + b_answer[(6, 8)] = b_answer[(7, 8)] = 0.2 + b_answer[(2, 6)] = b_answer[(3, 6)] = b_answer[(4, 6)] = 0.2 / 3.0 + b_answer[(1, 2)] = b_answer[(1, 3)] = b_answer[(1, 4)] = 0.2 / 3.0 + b_answer[(5, 7)] = 0.2 + b_answer[(1, 5)] = 0.2 + b_answer[(9, 12)] = 0.1 + b_answer[(11, 12)] = b_answer[(10, 11)] = b_answer[(1, 10)] = 0.1 + b = nx.edge_betweenness_centrality_subset( + G, sources=[1], targets=[9], weight=None + ) + for n in G.edges(): + sort_n = tuple(sorted(n)) + assert b[n] == pytest.approx(b_answer[sort_n], abs=1e-7) + + def test_normalized_p1(self): + """ + Edge betweenness subset centrality: P1 + if n <= 1: no normalization b=0 for all nodes + """ + G = nx.Graph() + nx.add_path(G, range(1)) + b_answer = dict.fromkeys(G.edges(), 0) + b = nx.edge_betweenness_centrality_subset( + G, sources=[0], targets=[0], normalized=True, weight=None + ) + for n in G.edges(): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_normalized_P5_directed(self): + """Edge betweenness subset centrality: Normalized Directed P5""" + G = nx.DiGraph() + nx.add_path(G, range(5)) + b_answer = dict.fromkeys(G.edges(), 0) + b_answer[(0, 1)] = b_answer[(1, 2)] = b_answer[(2, 3)] = 0.05 + b = nx.edge_betweenness_centrality_subset( + G, sources=[0], targets=[3], normalized=True, weight=None + ) + for n in G.edges(): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_weighted_graph(self): + """Edge betweenness subset centrality: Weighted Graph""" + G = nx.DiGraph() + G.add_edge(0, 1, weight=3) + G.add_edge(0, 2, weight=2) + G.add_edge(0, 3, weight=6) + G.add_edge(0, 4, weight=4) + G.add_edge(1, 3, weight=5) + G.add_edge(1, 5, weight=5) + G.add_edge(2, 4, weight=1) + G.add_edge(3, 4, weight=2) + G.add_edge(3, 5, weight=1) + G.add_edge(4, 5, weight=4) + b_answer = dict.fromkeys(G.edges(), 0) + b_answer[(0, 2)] = b_answer[(2, 4)] = b_answer[(4, 5)] = 0.5 + b_answer[(0, 3)] = b_answer[(3, 5)] = 0.5 + b = nx.edge_betweenness_centrality_subset( + G, sources=[0], targets=[5], normalized=False, weight="weight" + ) + for n in G.edges(): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_closeness_centrality.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_closeness_centrality.py new file mode 100644 index 0000000000000000000000000000000000000000..7bdb7e7c38bbfe644ca914ea071e54308cf8c76e --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_closeness_centrality.py @@ -0,0 +1,307 @@ +""" +Tests for closeness centrality. +""" + +import pytest + +import networkx as nx + + +class TestClosenessCentrality: + @classmethod + def setup_class(cls): + cls.K = nx.krackhardt_kite_graph() + cls.P3 = nx.path_graph(3) + cls.P4 = nx.path_graph(4) + cls.K5 = nx.complete_graph(5) + + cls.C4 = nx.cycle_graph(4) + cls.T = nx.balanced_tree(r=2, h=2) + cls.Gb = nx.Graph() + cls.Gb.add_edges_from([(0, 1), (0, 2), (1, 3), (2, 3), (2, 4), (4, 5), (3, 5)]) + + F = nx.florentine_families_graph() + cls.F = F + + cls.LM = nx.les_miserables_graph() + + # Create random undirected, unweighted graph for testing incremental version + cls.undirected_G = nx.fast_gnp_random_graph(n=100, p=0.6, seed=123) + cls.undirected_G_cc = nx.closeness_centrality(cls.undirected_G) + + def test_wf_improved(self): + G = nx.union(self.P4, nx.path_graph([4, 5, 6])) + c = nx.closeness_centrality(G) + cwf = nx.closeness_centrality(G, wf_improved=False) + res = {0: 0.25, 1: 0.375, 2: 0.375, 3: 0.25, 4: 0.222, 5: 0.333, 6: 0.222} + wf_res = {0: 0.5, 1: 0.75, 2: 0.75, 3: 0.5, 4: 0.667, 5: 1.0, 6: 0.667} + for n in G: + assert c[n] == pytest.approx(res[n], abs=1e-3) + assert cwf[n] == pytest.approx(wf_res[n], abs=1e-3) + + def test_digraph(self): + G = nx.path_graph(3, create_using=nx.DiGraph()) + c = nx.closeness_centrality(G) + cr = nx.closeness_centrality(G.reverse()) + d = {0: 0.0, 1: 0.500, 2: 0.667} + dr = {0: 0.667, 1: 0.500, 2: 0.0} + for n in sorted(self.P3): + assert c[n] == pytest.approx(d[n], abs=1e-3) + assert cr[n] == pytest.approx(dr[n], abs=1e-3) + + def test_k5_closeness(self): + c = nx.closeness_centrality(self.K5) + d = {0: 1.000, 1: 1.000, 2: 1.000, 3: 1.000, 4: 1.000} + for n in sorted(self.K5): + assert c[n] == pytest.approx(d[n], abs=1e-3) + + def test_p3_closeness(self): + c = nx.closeness_centrality(self.P3) + d = {0: 0.667, 1: 1.000, 2: 0.667} + for n in sorted(self.P3): + assert c[n] == pytest.approx(d[n], abs=1e-3) + + def test_krackhardt_closeness(self): + c = nx.closeness_centrality(self.K) + d = { + 0: 0.529, + 1: 0.529, + 2: 0.500, + 3: 0.600, + 4: 0.500, + 5: 0.643, + 6: 0.643, + 7: 0.600, + 8: 0.429, + 9: 0.310, + } + for n in sorted(self.K): + assert c[n] == pytest.approx(d[n], abs=1e-3) + + def test_florentine_families_closeness(self): + c = nx.closeness_centrality(self.F) + d = { + "Acciaiuoli": 0.368, + "Albizzi": 0.483, + "Barbadori": 0.4375, + "Bischeri": 0.400, + "Castellani": 0.389, + "Ginori": 0.333, + "Guadagni": 0.467, + "Lamberteschi": 0.326, + "Medici": 0.560, + "Pazzi": 0.286, + "Peruzzi": 0.368, + "Ridolfi": 0.500, + "Salviati": 0.389, + "Strozzi": 0.4375, + "Tornabuoni": 0.483, + } + for n in sorted(self.F): + assert c[n] == pytest.approx(d[n], abs=1e-3) + + def test_les_miserables_closeness(self): + c = nx.closeness_centrality(self.LM) + d = { + "Napoleon": 0.302, + "Myriel": 0.429, + "MlleBaptistine": 0.413, + "MmeMagloire": 0.413, + "CountessDeLo": 0.302, + "Geborand": 0.302, + "Champtercier": 0.302, + "Cravatte": 0.302, + "Count": 0.302, + "OldMan": 0.302, + "Valjean": 0.644, + "Labarre": 0.394, + "Marguerite": 0.413, + "MmeDeR": 0.394, + "Isabeau": 0.394, + "Gervais": 0.394, + "Listolier": 0.341, + "Tholomyes": 0.392, + "Fameuil": 0.341, + "Blacheville": 0.341, + "Favourite": 0.341, + "Dahlia": 0.341, + "Zephine": 0.341, + "Fantine": 0.461, + "MmeThenardier": 0.461, + "Thenardier": 0.517, + "Cosette": 0.478, + "Javert": 0.517, + "Fauchelevent": 0.402, + "Bamatabois": 0.427, + "Perpetue": 0.318, + "Simplice": 0.418, + "Scaufflaire": 0.394, + "Woman1": 0.396, + "Judge": 0.404, + "Champmathieu": 0.404, + "Brevet": 0.404, + "Chenildieu": 0.404, + "Cochepaille": 0.404, + "Pontmercy": 0.373, + "Boulatruelle": 0.342, + "Eponine": 0.396, + "Anzelma": 0.352, + "Woman2": 0.402, + "MotherInnocent": 0.398, + "Gribier": 0.288, + "MmeBurgon": 0.344, + "Jondrette": 0.257, + "Gavroche": 0.514, + "Gillenormand": 0.442, + "Magnon": 0.335, + "MlleGillenormand": 0.442, + "MmePontmercy": 0.315, + "MlleVaubois": 0.308, + "LtGillenormand": 0.365, + "Marius": 0.531, + "BaronessT": 0.352, + "Mabeuf": 0.396, + "Enjolras": 0.481, + "Combeferre": 0.392, + "Prouvaire": 0.357, + "Feuilly": 0.392, + "Courfeyrac": 0.400, + "Bahorel": 0.394, + "Bossuet": 0.475, + "Joly": 0.394, + "Grantaire": 0.358, + "MotherPlutarch": 0.285, + "Gueulemer": 0.463, + "Babet": 0.463, + "Claquesous": 0.452, + "Montparnasse": 0.458, + "Toussaint": 0.402, + "Child1": 0.342, + "Child2": 0.342, + "Brujon": 0.380, + "MmeHucheloup": 0.353, + } + for n in sorted(self.LM): + assert c[n] == pytest.approx(d[n], abs=1e-3) + + def test_weighted_closeness(self): + edges = [ + ("s", "u", 10), + ("s", "x", 5), + ("u", "v", 1), + ("u", "x", 2), + ("v", "y", 1), + ("x", "u", 3), + ("x", "v", 5), + ("x", "y", 2), + ("y", "s", 7), + ("y", "v", 6), + ] + XG = nx.Graph() + XG.add_weighted_edges_from(edges) + c = nx.closeness_centrality(XG, distance="weight") + d = {"y": 0.200, "x": 0.286, "s": 0.138, "u": 0.235, "v": 0.200} + for n in sorted(XG): + assert c[n] == pytest.approx(d[n], abs=1e-3) + + # + # Tests for incremental closeness centrality. + # + @staticmethod + def pick_add_edge(g): + u = nx.utils.arbitrary_element(g) + possible_nodes = set(g.nodes()) + neighbors = list(g.neighbors(u)) + [u] + possible_nodes.difference_update(neighbors) + v = nx.utils.arbitrary_element(possible_nodes) + return (u, v) + + @staticmethod + def pick_remove_edge(g): + u = nx.utils.arbitrary_element(g) + possible_nodes = list(g.neighbors(u)) + v = nx.utils.arbitrary_element(possible_nodes) + return (u, v) + + def test_directed_raises(self): + with pytest.raises(nx.NetworkXNotImplemented): + dir_G = nx.gn_graph(n=5) + prev_cc = None + edge = self.pick_add_edge(dir_G) + insert = True + nx.incremental_closeness_centrality(dir_G, edge, prev_cc, insert) + + def test_wrong_size_prev_cc_raises(self): + with pytest.raises(nx.NetworkXError): + G = self.undirected_G.copy() + edge = self.pick_add_edge(G) + insert = True + prev_cc = self.undirected_G_cc.copy() + prev_cc.pop(0) + nx.incremental_closeness_centrality(G, edge, prev_cc, insert) + + def test_wrong_nodes_prev_cc_raises(self): + with pytest.raises(nx.NetworkXError): + G = self.undirected_G.copy() + edge = self.pick_add_edge(G) + insert = True + prev_cc = self.undirected_G_cc.copy() + num_nodes = len(prev_cc) + prev_cc.pop(0) + prev_cc[num_nodes] = 0.5 + nx.incremental_closeness_centrality(G, edge, prev_cc, insert) + + def test_zero_centrality(self): + G = nx.path_graph(3) + prev_cc = nx.closeness_centrality(G) + edge = self.pick_remove_edge(G) + test_cc = nx.incremental_closeness_centrality(G, edge, prev_cc, insertion=False) + G.remove_edges_from([edge]) + real_cc = nx.closeness_centrality(G) + shared_items = set(test_cc.items()) & set(real_cc.items()) + assert len(shared_items) == len(real_cc) + assert 0 in test_cc.values() + + def test_incremental(self): + # Check that incremental and regular give same output + G = self.undirected_G.copy() + prev_cc = None + for i in range(5): + if i % 2 == 0: + # Remove an edge + insert = False + edge = self.pick_remove_edge(G) + else: + # Add an edge + insert = True + edge = self.pick_add_edge(G) + + # start = timeit.default_timer() + test_cc = nx.incremental_closeness_centrality(G, edge, prev_cc, insert) + # inc_elapsed = (timeit.default_timer() - start) + # print(f"incremental time: {inc_elapsed}") + + if insert: + G.add_edges_from([edge]) + else: + G.remove_edges_from([edge]) + + # start = timeit.default_timer() + real_cc = nx.closeness_centrality(G) + # reg_elapsed = (timeit.default_timer() - start) + # print(f"regular time: {reg_elapsed}") + # Example output: + # incremental time: 0.208 + # regular time: 0.276 + # incremental time: 0.00683 + # regular time: 0.260 + # incremental time: 0.0224 + # regular time: 0.278 + # incremental time: 0.00804 + # regular time: 0.208 + # incremental time: 0.00947 + # regular time: 0.188 + + assert set(test_cc.items()) == set(real_cc.items()) + + prev_cc = test_cc diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_current_flow_betweenness_centrality.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_current_flow_betweenness_centrality.py new file mode 100644 index 0000000000000000000000000000000000000000..4e3d4385c9b266975140d49b739d09fbd449d8a6 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_current_flow_betweenness_centrality.py @@ -0,0 +1,197 @@ +import pytest + +import networkx as nx +from networkx import approximate_current_flow_betweenness_centrality as approximate_cfbc +from networkx import edge_current_flow_betweenness_centrality as edge_current_flow + +np = pytest.importorskip("numpy") +pytest.importorskip("scipy") + + +class TestFlowBetweennessCentrality: + def test_K4_normalized(self): + """Betweenness centrality: K4""" + G = nx.complete_graph(4) + b = nx.current_flow_betweenness_centrality(G, normalized=True) + b_answer = {0: 0.25, 1: 0.25, 2: 0.25, 3: 0.25} + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + G.add_edge(0, 1, weight=0.5, other=0.3) + b = nx.current_flow_betweenness_centrality(G, normalized=True, weight=None) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + wb_answer = {0: 0.2222222, 1: 0.2222222, 2: 0.30555555, 3: 0.30555555} + b = nx.current_flow_betweenness_centrality(G, normalized=True, weight="weight") + for n in sorted(G): + assert b[n] == pytest.approx(wb_answer[n], abs=1e-7) + wb_answer = {0: 0.2051282, 1: 0.2051282, 2: 0.33974358, 3: 0.33974358} + b = nx.current_flow_betweenness_centrality(G, normalized=True, weight="other") + for n in sorted(G): + assert b[n] == pytest.approx(wb_answer[n], abs=1e-7) + + def test_K4(self): + """Betweenness centrality: K4""" + G = nx.complete_graph(4) + for solver in ["full", "lu", "cg"]: + b = nx.current_flow_betweenness_centrality( + G, normalized=False, solver=solver + ) + b_answer = {0: 0.75, 1: 0.75, 2: 0.75, 3: 0.75} + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_P4_normalized(self): + """Betweenness centrality: P4 normalized""" + G = nx.path_graph(4) + b = nx.current_flow_betweenness_centrality(G, normalized=True) + b_answer = {0: 0, 1: 2.0 / 3, 2: 2.0 / 3, 3: 0} + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_P4(self): + """Betweenness centrality: P4""" + G = nx.path_graph(4) + b = nx.current_flow_betweenness_centrality(G, normalized=False) + b_answer = {0: 0, 1: 2, 2: 2, 3: 0} + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_star(self): + """Betweenness centrality: star""" + G = nx.Graph() + nx.add_star(G, ["a", "b", "c", "d"]) + b = nx.current_flow_betweenness_centrality(G, normalized=True) + b_answer = {"a": 1.0, "b": 0.0, "c": 0.0, "d": 0.0} + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_solvers2(self): + """Betweenness centrality: alternate solvers""" + G = nx.complete_graph(4) + for solver in ["full", "lu", "cg"]: + b = nx.current_flow_betweenness_centrality( + G, normalized=False, solver=solver + ) + b_answer = {0: 0.75, 1: 0.75, 2: 0.75, 3: 0.75} + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + +class TestApproximateFlowBetweennessCentrality: + def test_K4_normalized(self): + "Approximate current-flow betweenness centrality: K4 normalized" + G = nx.complete_graph(4) + b = nx.current_flow_betweenness_centrality(G, normalized=True) + epsilon = 0.1 + ba = approximate_cfbc(G, normalized=True, epsilon=0.5 * epsilon) + for n in sorted(G): + np.testing.assert_allclose(b[n], ba[n], atol=epsilon) + + def test_K4(self): + "Approximate current-flow betweenness centrality: K4" + G = nx.complete_graph(4) + b = nx.current_flow_betweenness_centrality(G, normalized=False) + epsilon = 0.1 + ba = approximate_cfbc(G, normalized=False, epsilon=0.5 * epsilon) + for n in sorted(G): + np.testing.assert_allclose(b[n], ba[n], atol=epsilon * len(G) ** 2) + + def test_star(self): + "Approximate current-flow betweenness centrality: star" + G = nx.Graph() + nx.add_star(G, ["a", "b", "c", "d"]) + b = nx.current_flow_betweenness_centrality(G, normalized=True) + epsilon = 0.1 + ba = approximate_cfbc(G, normalized=True, epsilon=0.5 * epsilon) + for n in sorted(G): + np.testing.assert_allclose(b[n], ba[n], atol=epsilon) + + def test_grid(self): + "Approximate current-flow betweenness centrality: 2d grid" + G = nx.grid_2d_graph(4, 4) + b = nx.current_flow_betweenness_centrality(G, normalized=True) + epsilon = 0.1 + ba = approximate_cfbc(G, normalized=True, epsilon=0.5 * epsilon) + for n in sorted(G): + np.testing.assert_allclose(b[n], ba[n], atol=epsilon) + + def test_seed(self): + G = nx.complete_graph(4) + b = approximate_cfbc(G, normalized=False, epsilon=0.05, seed=1) + b_answer = {0: 0.75, 1: 0.75, 2: 0.75, 3: 0.75} + for n in sorted(G): + np.testing.assert_allclose(b[n], b_answer[n], atol=0.1) + + def test_solvers(self): + "Approximate current-flow betweenness centrality: solvers" + G = nx.complete_graph(4) + epsilon = 0.1 + for solver in ["full", "lu", "cg"]: + b = approximate_cfbc( + G, normalized=False, solver=solver, epsilon=0.5 * epsilon + ) + b_answer = {0: 0.75, 1: 0.75, 2: 0.75, 3: 0.75} + for n in sorted(G): + np.testing.assert_allclose(b[n], b_answer[n], atol=epsilon) + + def test_lower_kmax(self): + G = nx.complete_graph(4) + with pytest.raises(nx.NetworkXError, match="Increase kmax or epsilon"): + nx.approximate_current_flow_betweenness_centrality(G, kmax=4) + + +class TestWeightedFlowBetweennessCentrality: + pass + + +class TestEdgeFlowBetweennessCentrality: + def test_K4(self): + """Edge flow betweenness centrality: K4""" + G = nx.complete_graph(4) + b = edge_current_flow(G, normalized=True) + b_answer = dict.fromkeys(G.edges(), 0.25) + for (s, t), v1 in b_answer.items(): + v2 = b.get((s, t), b.get((t, s))) + assert v1 == pytest.approx(v2, abs=1e-7) + + def test_K4_normalized(self): + """Edge flow betweenness centrality: K4""" + G = nx.complete_graph(4) + b = edge_current_flow(G, normalized=False) + b_answer = dict.fromkeys(G.edges(), 0.75) + for (s, t), v1 in b_answer.items(): + v2 = b.get((s, t), b.get((t, s))) + assert v1 == pytest.approx(v2, abs=1e-7) + + def test_C4(self): + """Edge flow betweenness centrality: C4""" + G = nx.cycle_graph(4) + b = edge_current_flow(G, normalized=False) + b_answer = {(0, 1): 1.25, (0, 3): 1.25, (1, 2): 1.25, (2, 3): 1.25} + for (s, t), v1 in b_answer.items(): + v2 = b.get((s, t), b.get((t, s))) + assert v1 == pytest.approx(v2, abs=1e-7) + + def test_P4(self): + """Edge betweenness centrality: P4""" + G = nx.path_graph(4) + b = edge_current_flow(G, normalized=False) + b_answer = {(0, 1): 1.5, (1, 2): 2.0, (2, 3): 1.5} + for (s, t), v1 in b_answer.items(): + v2 = b.get((s, t), b.get((t, s))) + assert v1 == pytest.approx(v2, abs=1e-7) + + +@pytest.mark.parametrize( + "centrality_func", + ( + nx.current_flow_betweenness_centrality, + nx.edge_current_flow_betweenness_centrality, + nx.approximate_current_flow_betweenness_centrality, + ), +) +def test_unconnected_graphs_betweenness_centrality(centrality_func): + G = nx.Graph([(1, 2), (3, 4)]) + G.add_node(5) + with pytest.raises(nx.NetworkXError, match="Graph not connected"): + centrality_func(G) diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_current_flow_betweenness_centrality_subset.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_current_flow_betweenness_centrality_subset.py new file mode 100644 index 0000000000000000000000000000000000000000..7b1611b07bbf890f5e45bba7a42c298bd8f4e749 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_current_flow_betweenness_centrality_subset.py @@ -0,0 +1,147 @@ +import pytest + +pytest.importorskip("numpy") +pytest.importorskip("scipy") + +import networkx as nx +from networkx import edge_current_flow_betweenness_centrality as edge_current_flow +from networkx import ( + edge_current_flow_betweenness_centrality_subset as edge_current_flow_subset, +) + + +class TestFlowBetweennessCentrality: + def test_K4_normalized(self): + """Betweenness centrality: K4""" + G = nx.complete_graph(4) + b = nx.current_flow_betweenness_centrality_subset( + G, list(G), list(G), normalized=True + ) + b_answer = nx.current_flow_betweenness_centrality(G, normalized=True) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_K4(self): + """Betweenness centrality: K4""" + G = nx.complete_graph(4) + b = nx.current_flow_betweenness_centrality_subset( + G, list(G), list(G), normalized=True + ) + b_answer = nx.current_flow_betweenness_centrality(G, normalized=True) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + # test weighted network + G.add_edge(0, 1, weight=0.5, other=0.3) + b = nx.current_flow_betweenness_centrality_subset( + G, list(G), list(G), normalized=True, weight=None + ) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + b = nx.current_flow_betweenness_centrality_subset( + G, list(G), list(G), normalized=True + ) + b_answer = nx.current_flow_betweenness_centrality(G, normalized=True) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + b = nx.current_flow_betweenness_centrality_subset( + G, list(G), list(G), normalized=True, weight="other" + ) + b_answer = nx.current_flow_betweenness_centrality( + G, normalized=True, weight="other" + ) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_P4_normalized(self): + """Betweenness centrality: P4 normalized""" + G = nx.path_graph(4) + b = nx.current_flow_betweenness_centrality_subset( + G, list(G), list(G), normalized=True + ) + b_answer = nx.current_flow_betweenness_centrality(G, normalized=True) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_P4(self): + """Betweenness centrality: P4""" + G = nx.path_graph(4) + b = nx.current_flow_betweenness_centrality_subset( + G, list(G), list(G), normalized=True + ) + b_answer = nx.current_flow_betweenness_centrality(G, normalized=True) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_star(self): + """Betweenness centrality: star""" + G = nx.Graph() + nx.add_star(G, ["a", "b", "c", "d"]) + b = nx.current_flow_betweenness_centrality_subset( + G, list(G), list(G), normalized=True + ) + b_answer = nx.current_flow_betweenness_centrality(G, normalized=True) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + +# class TestWeightedFlowBetweennessCentrality(): +# pass + + +class TestEdgeFlowBetweennessCentrality: + def test_K4_normalized(self): + """Betweenness centrality: K4""" + G = nx.complete_graph(4) + b = edge_current_flow_subset(G, list(G), list(G), normalized=True) + b_answer = edge_current_flow(G, normalized=True) + for (s, t), v1 in b_answer.items(): + v2 = b.get((s, t), b.get((t, s))) + assert v1 == pytest.approx(v2, abs=1e-7) + + def test_K4(self): + """Betweenness centrality: K4""" + G = nx.complete_graph(4) + b = edge_current_flow_subset(G, list(G), list(G), normalized=False) + b_answer = edge_current_flow(G, normalized=False) + for (s, t), v1 in b_answer.items(): + v2 = b.get((s, t), b.get((t, s))) + assert v1 == pytest.approx(v2, abs=1e-7) + # test weighted network + G.add_edge(0, 1, weight=0.5, other=0.3) + b = edge_current_flow_subset(G, list(G), list(G), normalized=False, weight=None) + # weight is None => same as unweighted network + for (s, t), v1 in b_answer.items(): + v2 = b.get((s, t), b.get((t, s))) + assert v1 == pytest.approx(v2, abs=1e-7) + + b = edge_current_flow_subset(G, list(G), list(G), normalized=False) + b_answer = edge_current_flow(G, normalized=False) + for (s, t), v1 in b_answer.items(): + v2 = b.get((s, t), b.get((t, s))) + assert v1 == pytest.approx(v2, abs=1e-7) + + b = edge_current_flow_subset( + G, list(G), list(G), normalized=False, weight="other" + ) + b_answer = edge_current_flow(G, normalized=False, weight="other") + for (s, t), v1 in b_answer.items(): + v2 = b.get((s, t), b.get((t, s))) + assert v1 == pytest.approx(v2, abs=1e-7) + + def test_C4(self): + """Edge betweenness centrality: C4""" + G = nx.cycle_graph(4) + b = edge_current_flow_subset(G, list(G), list(G), normalized=True) + b_answer = edge_current_flow(G, normalized=True) + for (s, t), v1 in b_answer.items(): + v2 = b.get((s, t), b.get((t, s))) + assert v1 == pytest.approx(v2, abs=1e-7) + + def test_P4(self): + """Edge betweenness centrality: P4""" + G = nx.path_graph(4) + b = edge_current_flow_subset(G, list(G), list(G), normalized=True) + b_answer = edge_current_flow(G, normalized=True) + for (s, t), v1 in b_answer.items(): + v2 = b.get((s, t), b.get((t, s))) + assert v1 == pytest.approx(v2, abs=1e-7) diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_current_flow_closeness.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_current_flow_closeness.py new file mode 100644 index 0000000000000000000000000000000000000000..2528d622855938b8f569d4fb33309ebed1dbd7c8 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_current_flow_closeness.py @@ -0,0 +1,43 @@ +import pytest + +pytest.importorskip("numpy") +pytest.importorskip("scipy") + +import networkx as nx + + +class TestFlowClosenessCentrality: + def test_K4(self): + """Closeness centrality: K4""" + G = nx.complete_graph(4) + b = nx.current_flow_closeness_centrality(G) + b_answer = {0: 2.0 / 3, 1: 2.0 / 3, 2: 2.0 / 3, 3: 2.0 / 3} + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_P4(self): + """Closeness centrality: P4""" + G = nx.path_graph(4) + b = nx.current_flow_closeness_centrality(G) + b_answer = {0: 1.0 / 6, 1: 1.0 / 4, 2: 1.0 / 4, 3: 1.0 / 6} + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_star(self): + """Closeness centrality: star""" + G = nx.Graph() + nx.add_star(G, ["a", "b", "c", "d"]) + b = nx.current_flow_closeness_centrality(G) + b_answer = {"a": 1.0 / 3, "b": 0.6 / 3, "c": 0.6 / 3, "d": 0.6 / 3} + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_current_flow_closeness_centrality_not_connected(self): + G = nx.Graph() + G.add_nodes_from([1, 2, 3]) + with pytest.raises(nx.NetworkXError): + nx.current_flow_closeness_centrality(G) + + +class TestWeightedFlowClosenessCentrality: + pass diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_degree_centrality.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_degree_centrality.py new file mode 100644 index 0000000000000000000000000000000000000000..e39aa3b19f248acdd3a23e126e426bfd1c45c4c7 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_degree_centrality.py @@ -0,0 +1,144 @@ +""" +Unit tests for degree centrality. +""" + +import pytest + +import networkx as nx + + +class TestDegreeCentrality: + def setup_method(self): + self.K = nx.krackhardt_kite_graph() + self.P3 = nx.path_graph(3) + self.K5 = nx.complete_graph(5) + + F = nx.Graph() # Florentine families + F.add_edge("Acciaiuoli", "Medici") + F.add_edge("Castellani", "Peruzzi") + F.add_edge("Castellani", "Strozzi") + F.add_edge("Castellani", "Barbadori") + F.add_edge("Medici", "Barbadori") + F.add_edge("Medici", "Ridolfi") + F.add_edge("Medici", "Tornabuoni") + F.add_edge("Medici", "Albizzi") + F.add_edge("Medici", "Salviati") + F.add_edge("Salviati", "Pazzi") + F.add_edge("Peruzzi", "Strozzi") + F.add_edge("Peruzzi", "Bischeri") + F.add_edge("Strozzi", "Ridolfi") + F.add_edge("Strozzi", "Bischeri") + F.add_edge("Ridolfi", "Tornabuoni") + F.add_edge("Tornabuoni", "Guadagni") + F.add_edge("Albizzi", "Ginori") + F.add_edge("Albizzi", "Guadagni") + F.add_edge("Bischeri", "Guadagni") + F.add_edge("Guadagni", "Lamberteschi") + self.F = F + + G = nx.DiGraph() + G.add_edge(0, 5) + G.add_edge(1, 5) + G.add_edge(2, 5) + G.add_edge(3, 5) + G.add_edge(4, 5) + G.add_edge(5, 6) + G.add_edge(5, 7) + G.add_edge(5, 8) + self.G = G + + def test_degree_centrality_1(self): + d = nx.degree_centrality(self.K5) + exact = dict(zip(range(5), [1] * 5)) + for n, dc in d.items(): + assert exact[n] == pytest.approx(dc, abs=1e-7) + + def test_degree_centrality_2(self): + d = nx.degree_centrality(self.P3) + exact = {0: 0.5, 1: 1, 2: 0.5} + for n, dc in d.items(): + assert exact[n] == pytest.approx(dc, abs=1e-7) + + def test_degree_centrality_3(self): + d = nx.degree_centrality(self.K) + exact = { + 0: 0.444, + 1: 0.444, + 2: 0.333, + 3: 0.667, + 4: 0.333, + 5: 0.556, + 6: 0.556, + 7: 0.333, + 8: 0.222, + 9: 0.111, + } + for n, dc in d.items(): + assert exact[n] == pytest.approx(float(f"{dc:.3f}"), abs=1e-7) + + def test_degree_centrality_4(self): + d = nx.degree_centrality(self.F) + names = sorted(self.F.nodes()) + dcs = [ + 0.071, + 0.214, + 0.143, + 0.214, + 0.214, + 0.071, + 0.286, + 0.071, + 0.429, + 0.071, + 0.214, + 0.214, + 0.143, + 0.286, + 0.214, + ] + exact = dict(zip(names, dcs)) + for n, dc in d.items(): + assert exact[n] == pytest.approx(float(f"{dc:.3f}"), abs=1e-7) + + def test_indegree_centrality(self): + d = nx.in_degree_centrality(self.G) + exact = { + 0: 0.0, + 1: 0.0, + 2: 0.0, + 3: 0.0, + 4: 0.0, + 5: 0.625, + 6: 0.125, + 7: 0.125, + 8: 0.125, + } + for n, dc in d.items(): + assert exact[n] == pytest.approx(dc, abs=1e-7) + + def test_outdegree_centrality(self): + d = nx.out_degree_centrality(self.G) + exact = { + 0: 0.125, + 1: 0.125, + 2: 0.125, + 3: 0.125, + 4: 0.125, + 5: 0.375, + 6: 0.0, + 7: 0.0, + 8: 0.0, + } + for n, dc in d.items(): + assert exact[n] == pytest.approx(dc, abs=1e-7) + + def test_small_graph_centrality(self): + G = nx.empty_graph(create_using=nx.DiGraph) + assert {} == nx.degree_centrality(G) + assert {} == nx.out_degree_centrality(G) + assert {} == nx.in_degree_centrality(G) + + G = nx.empty_graph(1, create_using=nx.DiGraph) + assert {0: 1} == nx.degree_centrality(G) + assert {0: 1} == nx.out_degree_centrality(G) + assert {0: 1} == nx.in_degree_centrality(G) diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_dispersion.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_dispersion.py new file mode 100644 index 0000000000000000000000000000000000000000..05de1c43659a44f2dbf45368bf2ee552dd61dd78 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_dispersion.py @@ -0,0 +1,73 @@ +import networkx as nx + + +def small_ego_G(): + """The sample network from https://arxiv.org/pdf/1310.6753v1.pdf""" + edges = [ + ("a", "b"), + ("a", "c"), + ("b", "c"), + ("b", "d"), + ("b", "e"), + ("b", "f"), + ("c", "d"), + ("c", "f"), + ("c", "h"), + ("d", "f"), + ("e", "f"), + ("f", "h"), + ("h", "j"), + ("h", "k"), + ("i", "j"), + ("i", "k"), + ("j", "k"), + ("u", "a"), + ("u", "b"), + ("u", "c"), + ("u", "d"), + ("u", "e"), + ("u", "f"), + ("u", "g"), + ("u", "h"), + ("u", "i"), + ("u", "j"), + ("u", "k"), + ] + G = nx.Graph() + G.add_edges_from(edges) + + return G + + +class TestDispersion: + def test_article(self): + """our algorithm matches article's""" + G = small_ego_G() + disp_uh = nx.dispersion(G, "u", "h", normalized=False) + disp_ub = nx.dispersion(G, "u", "b", normalized=False) + assert disp_uh == 4 + assert disp_ub == 1 + + def test_results_length(self): + """there is a result for every node""" + G = small_ego_G() + disp = nx.dispersion(G) + disp_Gu = nx.dispersion(G, "u") + disp_uv = nx.dispersion(G, "u", "h") + assert len(disp) == len(G) + assert len(disp_Gu) == len(G) - 1 + assert isinstance(disp_uv, float) + + def test_dispersion_v_only(self): + G = small_ego_G() + disp_G_h = nx.dispersion(G, v="h", normalized=False) + disp_G_h_normalized = nx.dispersion(G, v="h", normalized=True) + assert disp_G_h == {"c": 0, "f": 0, "j": 0, "k": 0, "u": 4} + assert disp_G_h_normalized == {"c": 0.0, "f": 0.0, "j": 0.0, "k": 0.0, "u": 1.0} + + def test_impossible_things(self): + G = nx.karate_club_graph() + disp = nx.dispersion(G) + for u in disp: + for v in disp[u]: + assert disp[u][v] >= 0 diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_eigenvector_centrality.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_eigenvector_centrality.py new file mode 100644 index 0000000000000000000000000000000000000000..cfc9ee7953f9088f958cfa108e29ffd17e723223 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_eigenvector_centrality.py @@ -0,0 +1,187 @@ +import math + +import pytest + +np = pytest.importorskip("numpy") +pytest.importorskip("scipy") + + +import networkx as nx + + +class TestEigenvectorCentrality: + def test_K5(self): + """Eigenvector centrality: K5""" + G = nx.complete_graph(5) + b = nx.eigenvector_centrality(G) + v = math.sqrt(1 / 5.0) + b_answer = dict.fromkeys(G, v) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + nstart = {n: 1 for n in G} + b = nx.eigenvector_centrality(G, nstart=nstart) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + b = nx.eigenvector_centrality_numpy(G) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-3) + + def test_P3(self): + """Eigenvector centrality: P3""" + G = nx.path_graph(3) + b_answer = {0: 0.5, 1: 0.7071, 2: 0.5} + b = nx.eigenvector_centrality_numpy(G) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-4) + b = nx.eigenvector_centrality(G) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-4) + + def test_P3_unweighted(self): + """Eigenvector centrality: P3""" + G = nx.path_graph(3) + b_answer = {0: 0.5, 1: 0.7071, 2: 0.5} + b = nx.eigenvector_centrality_numpy(G, weight=None) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-4) + + def test_maxiter(self): + with pytest.raises(nx.PowerIterationFailedConvergence): + G = nx.path_graph(3) + nx.eigenvector_centrality(G, max_iter=0) + + +class TestEigenvectorCentralityDirected: + @classmethod + def setup_class(cls): + G = nx.DiGraph() + + edges = [ + (1, 2), + (1, 3), + (2, 4), + (3, 2), + (3, 5), + (4, 2), + (4, 5), + (4, 6), + (5, 6), + (5, 7), + (5, 8), + (6, 8), + (7, 1), + (7, 5), + (7, 8), + (8, 6), + (8, 7), + ] + + G.add_edges_from(edges, weight=2.0) + cls.G = G.reverse() + cls.G.evc = [ + 0.25368793, + 0.19576478, + 0.32817092, + 0.40430835, + 0.48199885, + 0.15724483, + 0.51346196, + 0.32475403, + ] + + H = nx.DiGraph() + + edges = [ + (1, 2), + (1, 3), + (2, 4), + (3, 2), + (3, 5), + (4, 2), + (4, 5), + (4, 6), + (5, 6), + (5, 7), + (5, 8), + (6, 8), + (7, 1), + (7, 5), + (7, 8), + (8, 6), + (8, 7), + ] + + G.add_edges_from(edges) + cls.H = G.reverse() + cls.H.evc = [ + 0.25368793, + 0.19576478, + 0.32817092, + 0.40430835, + 0.48199885, + 0.15724483, + 0.51346196, + 0.32475403, + ] + + def test_eigenvector_centrality_weighted(self): + G = self.G + p = nx.eigenvector_centrality(G) + for a, b in zip(list(p.values()), self.G.evc): + assert a == pytest.approx(b, abs=1e-4) + + def test_eigenvector_centrality_weighted_numpy(self): + G = self.G + p = nx.eigenvector_centrality_numpy(G) + for a, b in zip(list(p.values()), self.G.evc): + assert a == pytest.approx(b, abs=1e-7) + + def test_eigenvector_centrality_unweighted(self): + G = self.H + p = nx.eigenvector_centrality(G) + for a, b in zip(list(p.values()), self.G.evc): + assert a == pytest.approx(b, abs=1e-4) + + def test_eigenvector_centrality_unweighted_numpy(self): + G = self.H + p = nx.eigenvector_centrality_numpy(G) + for a, b in zip(list(p.values()), self.G.evc): + assert a == pytest.approx(b, abs=1e-7) + + +class TestEigenvectorCentralityExceptions: + def test_multigraph(self): + with pytest.raises(nx.NetworkXException): + nx.eigenvector_centrality(nx.MultiGraph()) + + def test_multigraph_numpy(self): + with pytest.raises(nx.NetworkXException): + nx.eigenvector_centrality_numpy(nx.MultiGraph()) + + def test_null(self): + with pytest.raises(nx.NetworkXException): + nx.eigenvector_centrality(nx.Graph()) + + def test_null_numpy(self): + with pytest.raises(nx.NetworkXException): + nx.eigenvector_centrality_numpy(nx.Graph()) + + @pytest.mark.parametrize( + "G", + [ + nx.empty_graph(3), + nx.DiGraph([(0, 1), (1, 2)]), + ], + ) + def test_disconnected_numpy(self, G): + msg = "does not give consistent results for disconnected" + with pytest.raises(nx.AmbiguousSolution, match=msg): + nx.eigenvector_centrality_numpy(G) + + def test_zero_nstart(self): + G = nx.Graph([(1, 2), (1, 3), (2, 3)]) + with pytest.raises( + nx.NetworkXException, match="initial vector cannot have all zero values" + ): + nx.eigenvector_centrality(G, nstart={v: 0 for v in G}) diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_group.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_group.py new file mode 100644 index 0000000000000000000000000000000000000000..82343f28702382c18676fc776511bd7efdf22a78 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_group.py @@ -0,0 +1,277 @@ +""" +Tests for Group Centrality Measures +""" + +import pytest + +import networkx as nx + + +class TestGroupBetweennessCentrality: + def test_group_betweenness_single_node(self): + """ + Group betweenness centrality for single node group + """ + G = nx.path_graph(5) + C = [1] + b = nx.group_betweenness_centrality( + G, C, weight=None, normalized=False, endpoints=False + ) + b_answer = 3.0 + assert b == b_answer + + def test_group_betweenness_with_endpoints(self): + """ + Group betweenness centrality for single node group + """ + G = nx.path_graph(5) + C = [1] + b = nx.group_betweenness_centrality( + G, C, weight=None, normalized=False, endpoints=True + ) + b_answer = 7.0 + assert b == b_answer + + def test_group_betweenness_normalized(self): + """ + Group betweenness centrality for group with more than + 1 node and normalized + """ + G = nx.path_graph(5) + C = [1, 3] + b = nx.group_betweenness_centrality( + G, C, weight=None, normalized=True, endpoints=False + ) + b_answer = 1.0 + assert b == b_answer + + def test_two_group_betweenness_value_zero(self): + """ + Group betweenness centrality value of 0 + """ + G = nx.cycle_graph(7) + C = [[0, 1, 6], [0, 1, 5]] + b = nx.group_betweenness_centrality(G, C, weight=None, normalized=False) + b_answer = [0.0, 3.0] + assert b == b_answer + + def test_group_betweenness_value_zero(self): + """ + Group betweenness centrality value of 0 + """ + G = nx.cycle_graph(6) + C = [0, 1, 5] + b = nx.group_betweenness_centrality(G, C, weight=None, normalized=False) + b_answer = 0.0 + assert b == b_answer + + def test_group_betweenness_disconnected_graph(self): + """ + Group betweenness centrality in a disconnected graph + """ + G = nx.path_graph(5) + G.remove_edge(0, 1) + C = [1] + b = nx.group_betweenness_centrality(G, C, weight=None, normalized=False) + b_answer = 0.0 + assert b == b_answer + + def test_group_betweenness_node_not_in_graph(self): + """ + Node(s) in C not in graph, raises NodeNotFound exception + """ + with pytest.raises(nx.NodeNotFound): + nx.group_betweenness_centrality(nx.path_graph(5), [4, 7, 8]) + + def test_group_betweenness_directed_weighted(self): + """ + Group betweenness centrality in a directed and weighted graph + """ + G = nx.DiGraph() + G.add_edge(1, 0, weight=1) + G.add_edge(0, 2, weight=2) + G.add_edge(1, 2, weight=3) + G.add_edge(3, 1, weight=4) + G.add_edge(2, 3, weight=1) + G.add_edge(4, 3, weight=6) + G.add_edge(2, 4, weight=7) + C = [1, 2] + b = nx.group_betweenness_centrality(G, C, weight="weight", normalized=False) + b_answer = 5.0 + assert b == b_answer + + +class TestProminentGroup: + np = pytest.importorskip("numpy") + pd = pytest.importorskip("pandas") + + def test_prominent_group_single_node(self): + """ + Prominent group for single node + """ + G = nx.path_graph(5) + k = 1 + b, g = nx.prominent_group(G, k, normalized=False, endpoints=False) + b_answer, g_answer = 4.0, [2] + assert b == b_answer and g == g_answer + + def test_prominent_group_with_c(self): + """ + Prominent group without some nodes + """ + G = nx.path_graph(5) + k = 1 + b, g = nx.prominent_group(G, k, normalized=False, C=[2]) + b_answer, g_answer = 3.0, [1] + assert b == b_answer and g == g_answer + + def test_prominent_group_normalized_endpoints(self): + """ + Prominent group with normalized result, with endpoints + """ + G = nx.cycle_graph(7) + k = 2 + b, g = nx.prominent_group(G, k, normalized=True, endpoints=True) + b_answer, g_answer = 1.7, [2, 5] + assert b == b_answer and g == g_answer + + def test_prominent_group_disconnected_graph(self): + """ + Prominent group of disconnected graph + """ + G = nx.path_graph(6) + G.remove_edge(0, 1) + k = 1 + b, g = nx.prominent_group(G, k, weight=None, normalized=False) + b_answer, g_answer = 4.0, [3] + assert b == b_answer and g == g_answer + + def test_prominent_group_node_not_in_graph(self): + """ + Node(s) in C not in graph, raises NodeNotFound exception + """ + with pytest.raises(nx.NodeNotFound): + nx.prominent_group(nx.path_graph(5), 1, C=[10]) + + def test_group_betweenness_directed_weighted(self): + """ + Group betweenness centrality in a directed and weighted graph + """ + G = nx.DiGraph() + G.add_edge(1, 0, weight=1) + G.add_edge(0, 2, weight=2) + G.add_edge(1, 2, weight=3) + G.add_edge(3, 1, weight=4) + G.add_edge(2, 3, weight=1) + G.add_edge(4, 3, weight=6) + G.add_edge(2, 4, weight=7) + k = 2 + b, g = nx.prominent_group(G, k, weight="weight", normalized=False) + b_answer, g_answer = 5.0, [1, 2] + assert b == b_answer and g == g_answer + + def test_prominent_group_greedy_algorithm(self): + """ + Group betweenness centrality in a greedy algorithm + """ + G = nx.cycle_graph(7) + k = 2 + b, g = nx.prominent_group(G, k, normalized=True, endpoints=True, greedy=True) + b_answer, g_answer = 1.7, [6, 3] + assert b == b_answer and g == g_answer + + +class TestGroupClosenessCentrality: + def test_group_closeness_single_node(self): + """ + Group closeness centrality for a single node group + """ + G = nx.path_graph(5) + c = nx.group_closeness_centrality(G, [1]) + c_answer = nx.closeness_centrality(G, 1) + assert c == c_answer + + def test_group_closeness_disconnected(self): + """ + Group closeness centrality for a disconnected graph + """ + G = nx.Graph() + G.add_nodes_from([1, 2, 3, 4]) + c = nx.group_closeness_centrality(G, [1, 2]) + c_answer = 0 + assert c == c_answer + + def test_group_closeness_multiple_node(self): + """ + Group closeness centrality for a group with more than + 1 node + """ + G = nx.path_graph(4) + c = nx.group_closeness_centrality(G, [1, 2]) + c_answer = 1 + assert c == c_answer + + def test_group_closeness_node_not_in_graph(self): + """ + Node(s) in S not in graph, raises NodeNotFound exception + """ + with pytest.raises(nx.NodeNotFound): + nx.group_closeness_centrality(nx.path_graph(5), [6, 7, 8]) + + +class TestGroupDegreeCentrality: + def test_group_degree_centrality_single_node(self): + """ + Group degree centrality for a single node group + """ + G = nx.path_graph(4) + d = nx.group_degree_centrality(G, [1]) + d_answer = nx.degree_centrality(G)[1] + assert d == d_answer + + def test_group_degree_centrality_multiple_node(self): + """ + Group degree centrality for group with more than + 1 node + """ + G = nx.Graph() + G.add_nodes_from([1, 2, 3, 4, 5, 6, 7, 8]) + G.add_edges_from( + [(1, 2), (1, 3), (1, 6), (1, 7), (1, 8), (2, 3), (2, 4), (2, 5)] + ) + d = nx.group_degree_centrality(G, [1, 2]) + d_answer = 1 + assert d == d_answer + + def test_group_in_degree_centrality(self): + """ + Group in-degree centrality in a DiGraph + """ + G = nx.DiGraph() + G.add_nodes_from([1, 2, 3, 4, 5, 6, 7, 8]) + G.add_edges_from( + [(1, 2), (1, 3), (1, 6), (1, 7), (1, 8), (2, 3), (2, 4), (2, 5)] + ) + d = nx.group_in_degree_centrality(G, [1, 2]) + d_answer = 0 + assert d == d_answer + + def test_group_out_degree_centrality(self): + """ + Group out-degree centrality in a DiGraph + """ + G = nx.DiGraph() + G.add_nodes_from([1, 2, 3, 4, 5, 6, 7, 8]) + G.add_edges_from( + [(1, 2), (1, 3), (1, 6), (1, 7), (1, 8), (2, 3), (2, 4), (2, 5)] + ) + d = nx.group_out_degree_centrality(G, [1, 2]) + d_answer = 1 + assert d == d_answer + + def test_group_degree_centrality_node_not_in_graph(self): + """ + Node(s) in S not in graph, raises NetworkXError + """ + with pytest.raises(nx.NetworkXError): + nx.group_degree_centrality(nx.path_graph(5), [6, 7, 8]) diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_harmonic_centrality.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_harmonic_centrality.py new file mode 100644 index 0000000000000000000000000000000000000000..4b3dc4ac356701eb562a3179a69023ad83a8d74e --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_harmonic_centrality.py @@ -0,0 +1,122 @@ +""" +Tests for degree centrality. +""" + +import pytest + +import networkx as nx +from networkx.algorithms.centrality import harmonic_centrality + + +class TestClosenessCentrality: + @classmethod + def setup_class(cls): + cls.P3 = nx.path_graph(3) + cls.P4 = nx.path_graph(4) + cls.K5 = nx.complete_graph(5) + + cls.C4 = nx.cycle_graph(4) + cls.C4_directed = nx.cycle_graph(4, create_using=nx.DiGraph) + + cls.C5 = nx.cycle_graph(5) + + cls.T = nx.balanced_tree(r=2, h=2) + + cls.Gb = nx.DiGraph() + cls.Gb.add_edges_from([(0, 1), (0, 2), (0, 4), (2, 1), (2, 3), (4, 3)]) + + def test_p3_harmonic(self): + c = harmonic_centrality(self.P3) + d = {0: 1.5, 1: 2, 2: 1.5} + for n in sorted(self.P3): + assert c[n] == pytest.approx(d[n], abs=1e-3) + + def test_p4_harmonic(self): + c = harmonic_centrality(self.P4) + d = {0: 1.8333333, 1: 2.5, 2: 2.5, 3: 1.8333333} + for n in sorted(self.P4): + assert c[n] == pytest.approx(d[n], abs=1e-3) + + def test_clique_complete(self): + c = harmonic_centrality(self.K5) + d = {0: 4, 1: 4, 2: 4, 3: 4, 4: 4} + for n in sorted(self.P3): + assert c[n] == pytest.approx(d[n], abs=1e-3) + + def test_cycle_C4(self): + c = harmonic_centrality(self.C4) + d = {0: 2.5, 1: 2.5, 2: 2.5, 3: 2.5} + for n in sorted(self.C4): + assert c[n] == pytest.approx(d[n], abs=1e-3) + + def test_cycle_C5(self): + c = harmonic_centrality(self.C5) + d = {0: 3, 1: 3, 2: 3, 3: 3, 4: 3, 5: 4} + for n in sorted(self.C5): + assert c[n] == pytest.approx(d[n], abs=1e-3) + + def test_bal_tree(self): + c = harmonic_centrality(self.T) + d = {0: 4.0, 1: 4.1666, 2: 4.1666, 3: 2.8333, 4: 2.8333, 5: 2.8333, 6: 2.8333} + for n in sorted(self.T): + assert c[n] == pytest.approx(d[n], abs=1e-3) + + def test_exampleGraph(self): + c = harmonic_centrality(self.Gb) + d = {0: 0, 1: 2, 2: 1, 3: 2.5, 4: 1} + for n in sorted(self.Gb): + assert c[n] == pytest.approx(d[n], abs=1e-3) + + def test_weighted_harmonic(self): + XG = nx.DiGraph() + XG.add_weighted_edges_from( + [ + ("a", "b", 10), + ("d", "c", 5), + ("a", "c", 1), + ("e", "f", 2), + ("f", "c", 1), + ("a", "f", 3), + ] + ) + c = harmonic_centrality(XG, distance="weight") + d = {"a": 0, "b": 0.1, "c": 2.533, "d": 0, "e": 0, "f": 0.83333} + for n in sorted(XG): + assert c[n] == pytest.approx(d[n], abs=1e-3) + + def test_empty(self): + G = nx.DiGraph() + c = harmonic_centrality(G, distance="weight") + d = {} + assert c == d + + def test_singleton(self): + G = nx.DiGraph() + G.add_node(0) + c = harmonic_centrality(G, distance="weight") + d = {0: 0} + assert c == d + + def test_cycle_c4_directed(self): + c = harmonic_centrality(self.C4_directed, nbunch=[0, 1], sources=[1, 2]) + d = {0: 0.833, 1: 0.333} + for n in [0, 1]: + assert c[n] == pytest.approx(d[n], abs=1e-3) + + def test_cycle_c4_directed_subset(self): + c = harmonic_centrality(self.C4_directed, nbunch=[0, 1]) + d = 1.833 + for n in [0, 1]: + assert c[n] == pytest.approx(d, abs=1e-3) + + def test_p3_harmonic_subset(self): + c = harmonic_centrality(self.P3, sources=[0, 1]) + d = {0: 1, 1: 1, 2: 1.5} + for n in self.P3: + assert c[n] == pytest.approx(d[n], abs=1e-3) + + def test_p4_harmonic_subset(self): + c = harmonic_centrality(self.P4, nbunch=[2, 3], sources=[0, 1]) + d = {2: 1.5, 3: 0.8333333} + for n in [2, 3]: + assert c[n] == pytest.approx(d[n], abs=1e-3) diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_katz_centrality.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_katz_centrality.py new file mode 100644 index 0000000000000000000000000000000000000000..0927f00bc5c31ad1134dae0c8f59367baed67bb6 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_katz_centrality.py @@ -0,0 +1,345 @@ +import math + +import pytest + +import networkx as nx + + +class TestKatzCentrality: + def test_K5(self): + """Katz centrality: K5""" + G = nx.complete_graph(5) + alpha = 0.1 + b = nx.katz_centrality(G, alpha) + v = math.sqrt(1 / 5.0) + b_answer = dict.fromkeys(G, v) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + nstart = {n: 1 for n in G} + b = nx.katz_centrality(G, alpha, nstart=nstart) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + + def test_P3(self): + """Katz centrality: P3""" + alpha = 0.1 + G = nx.path_graph(3) + b_answer = {0: 0.5598852584152165, 1: 0.6107839182711449, 2: 0.5598852584152162} + b = nx.katz_centrality(G, alpha) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-4) + + def test_maxiter(self): + with pytest.raises(nx.PowerIterationFailedConvergence): + nx.katz_centrality(nx.path_graph(3), 0.1, max_iter=0) + + def test_beta_as_scalar(self): + alpha = 0.1 + beta = 0.1 + b_answer = {0: 0.5598852584152165, 1: 0.6107839182711449, 2: 0.5598852584152162} + G = nx.path_graph(3) + b = nx.katz_centrality(G, alpha, beta) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-4) + + def test_beta_as_dict(self): + alpha = 0.1 + beta = {0: 1.0, 1: 1.0, 2: 1.0} + b_answer = {0: 0.5598852584152165, 1: 0.6107839182711449, 2: 0.5598852584152162} + G = nx.path_graph(3) + b = nx.katz_centrality(G, alpha, beta) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-4) + + def test_multiple_alpha(self): + alpha_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6] + for alpha in alpha_list: + b_answer = { + 0.1: { + 0: 0.5598852584152165, + 1: 0.6107839182711449, + 2: 0.5598852584152162, + }, + 0.2: { + 0: 0.5454545454545454, + 1: 0.6363636363636365, + 2: 0.5454545454545454, + }, + 0.3: { + 0: 0.5333964609104419, + 1: 0.6564879518897746, + 2: 0.5333964609104419, + }, + 0.4: { + 0: 0.5232045649263551, + 1: 0.6726915834767423, + 2: 0.5232045649263551, + }, + 0.5: { + 0: 0.5144957746691622, + 1: 0.6859943117075809, + 2: 0.5144957746691622, + }, + 0.6: { + 0: 0.5069794004195823, + 1: 0.6970966755769258, + 2: 0.5069794004195823, + }, + } + G = nx.path_graph(3) + b = nx.katz_centrality(G, alpha) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[alpha][n], abs=1e-4) + + def test_multigraph(self): + with pytest.raises(nx.NetworkXException): + nx.katz_centrality(nx.MultiGraph(), 0.1) + + def test_empty(self): + e = nx.katz_centrality(nx.Graph(), 0.1) + assert e == {} + + def test_bad_beta(self): + with pytest.raises(nx.NetworkXException): + G = nx.Graph([(0, 1)]) + beta = {0: 77} + nx.katz_centrality(G, 0.1, beta=beta) + + def test_bad_beta_number(self): + with pytest.raises(nx.NetworkXException): + G = nx.Graph([(0, 1)]) + nx.katz_centrality(G, 0.1, beta="foo") + + +class TestKatzCentralityNumpy: + @classmethod + def setup_class(cls): + global np + np = pytest.importorskip("numpy") + pytest.importorskip("scipy") + + def test_K5(self): + """Katz centrality: K5""" + G = nx.complete_graph(5) + alpha = 0.1 + b = nx.katz_centrality(G, alpha) + v = math.sqrt(1 / 5.0) + b_answer = dict.fromkeys(G, v) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + b = nx.eigenvector_centrality_numpy(G) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-3) + + def test_P3(self): + """Katz centrality: P3""" + alpha = 0.1 + G = nx.path_graph(3) + b_answer = {0: 0.5598852584152165, 1: 0.6107839182711449, 2: 0.5598852584152162} + b = nx.katz_centrality_numpy(G, alpha) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-4) + + def test_beta_as_scalar(self): + alpha = 0.1 + beta = 0.1 + b_answer = {0: 0.5598852584152165, 1: 0.6107839182711449, 2: 0.5598852584152162} + G = nx.path_graph(3) + b = nx.katz_centrality_numpy(G, alpha, beta) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-4) + + def test_beta_as_dict(self): + alpha = 0.1 + beta = {0: 1.0, 1: 1.0, 2: 1.0} + b_answer = {0: 0.5598852584152165, 1: 0.6107839182711449, 2: 0.5598852584152162} + G = nx.path_graph(3) + b = nx.katz_centrality_numpy(G, alpha, beta) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-4) + + def test_multiple_alpha(self): + alpha_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6] + for alpha in alpha_list: + b_answer = { + 0.1: { + 0: 0.5598852584152165, + 1: 0.6107839182711449, + 2: 0.5598852584152162, + }, + 0.2: { + 0: 0.5454545454545454, + 1: 0.6363636363636365, + 2: 0.5454545454545454, + }, + 0.3: { + 0: 0.5333964609104419, + 1: 0.6564879518897746, + 2: 0.5333964609104419, + }, + 0.4: { + 0: 0.5232045649263551, + 1: 0.6726915834767423, + 2: 0.5232045649263551, + }, + 0.5: { + 0: 0.5144957746691622, + 1: 0.6859943117075809, + 2: 0.5144957746691622, + }, + 0.6: { + 0: 0.5069794004195823, + 1: 0.6970966755769258, + 2: 0.5069794004195823, + }, + } + G = nx.path_graph(3) + b = nx.katz_centrality_numpy(G, alpha) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[alpha][n], abs=1e-4) + + def test_multigraph(self): + with pytest.raises(nx.NetworkXException): + nx.katz_centrality(nx.MultiGraph(), 0.1) + + def test_empty(self): + e = nx.katz_centrality(nx.Graph(), 0.1) + assert e == {} + + def test_bad_beta(self): + with pytest.raises(nx.NetworkXException): + G = nx.Graph([(0, 1)]) + beta = {0: 77} + nx.katz_centrality_numpy(G, 0.1, beta=beta) + + def test_bad_beta_numbe(self): + with pytest.raises(nx.NetworkXException): + G = nx.Graph([(0, 1)]) + nx.katz_centrality_numpy(G, 0.1, beta="foo") + + def test_K5_unweighted(self): + """Katz centrality: K5""" + G = nx.complete_graph(5) + alpha = 0.1 + b = nx.katz_centrality(G, alpha, weight=None) + v = math.sqrt(1 / 5.0) + b_answer = dict.fromkeys(G, v) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-7) + b = nx.eigenvector_centrality_numpy(G, weight=None) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-3) + + def test_P3_unweighted(self): + """Katz centrality: P3""" + alpha = 0.1 + G = nx.path_graph(3) + b_answer = {0: 0.5598852584152165, 1: 0.6107839182711449, 2: 0.5598852584152162} + b = nx.katz_centrality_numpy(G, alpha, weight=None) + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-4) + + +class TestKatzCentralityDirected: + @classmethod + def setup_class(cls): + G = nx.DiGraph() + edges = [ + (1, 2), + (1, 3), + (2, 4), + (3, 2), + (3, 5), + (4, 2), + (4, 5), + (4, 6), + (5, 6), + (5, 7), + (5, 8), + (6, 8), + (7, 1), + (7, 5), + (7, 8), + (8, 6), + (8, 7), + ] + G.add_edges_from(edges, weight=2.0) + cls.G = G.reverse() + cls.G.alpha = 0.1 + cls.G.evc = [ + 0.3289589783189635, + 0.2832077296243516, + 0.3425906003685471, + 0.3970420865198392, + 0.41074871061646284, + 0.272257430756461, + 0.4201989685435462, + 0.34229059218038554, + ] + + H = nx.DiGraph(edges) + cls.H = G.reverse() + cls.H.alpha = 0.1 + cls.H.evc = [ + 0.3289589783189635, + 0.2832077296243516, + 0.3425906003685471, + 0.3970420865198392, + 0.41074871061646284, + 0.272257430756461, + 0.4201989685435462, + 0.34229059218038554, + ] + + def test_katz_centrality_weighted(self): + G = self.G + alpha = self.G.alpha + p = nx.katz_centrality(G, alpha, weight="weight") + for a, b in zip(list(p.values()), self.G.evc): + assert a == pytest.approx(b, abs=1e-7) + + def test_katz_centrality_unweighted(self): + H = self.H + alpha = self.H.alpha + p = nx.katz_centrality(H, alpha, weight="weight") + for a, b in zip(list(p.values()), self.H.evc): + assert a == pytest.approx(b, abs=1e-7) + + +class TestKatzCentralityDirectedNumpy(TestKatzCentralityDirected): + @classmethod + def setup_class(cls): + global np + np = pytest.importorskip("numpy") + pytest.importorskip("scipy") + super().setup_class() + + def test_katz_centrality_weighted(self): + G = self.G + alpha = self.G.alpha + p = nx.katz_centrality_numpy(G, alpha, weight="weight") + for a, b in zip(list(p.values()), self.G.evc): + assert a == pytest.approx(b, abs=1e-7) + + def test_katz_centrality_unweighted(self): + H = self.H + alpha = self.H.alpha + p = nx.katz_centrality_numpy(H, alpha, weight="weight") + for a, b in zip(list(p.values()), self.H.evc): + assert a == pytest.approx(b, abs=1e-7) + + +class TestKatzEigenvectorVKatz: + @classmethod + def setup_class(cls): + global np + np = pytest.importorskip("numpy") + pytest.importorskip("scipy") + + def test_eigenvector_v_katz_random(self): + G = nx.gnp_random_graph(10, 0.5, seed=1234) + l = max(np.linalg.eigvals(nx.adjacency_matrix(G).todense())) + e = nx.eigenvector_centrality_numpy(G) + k = nx.katz_centrality_numpy(G, 1.0 / l) + for n in G: + assert e[n] == pytest.approx(k[n], abs=1e-7) diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_laplacian_centrality.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_laplacian_centrality.py new file mode 100644 index 0000000000000000000000000000000000000000..21aa28b0b7c155078ab9c1a25e14d9aafa65683d --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_laplacian_centrality.py @@ -0,0 +1,221 @@ +import pytest + +import networkx as nx + +np = pytest.importorskip("numpy") +sp = pytest.importorskip("scipy") + + +def test_laplacian_centrality_null_graph(): + G = nx.Graph() + with pytest.raises(nx.NetworkXPointlessConcept): + d = nx.laplacian_centrality(G, normalized=False) + + +def test_laplacian_centrality_single_node(): + """See gh-6571""" + G = nx.empty_graph(1) + assert nx.laplacian_centrality(G, normalized=False) == {0: 0} + with pytest.raises(ZeroDivisionError): + nx.laplacian_centrality(G, normalized=True) + + +def test_laplacian_centrality_unconnected_nodes(): + """laplacian_centrality on a unconnected node graph should return 0 + + For graphs without edges, the Laplacian energy is 0 and is unchanged with + node removal, so:: + + LC(v) = LE(G) - LE(G - v) = 0 - 0 = 0 + """ + G = nx.empty_graph(3) + assert nx.laplacian_centrality(G, normalized=False) == {0: 0, 1: 0, 2: 0} + + +def test_laplacian_centrality_empty_graph(): + G = nx.empty_graph(3) + with pytest.raises(ZeroDivisionError): + d = nx.laplacian_centrality(G, normalized=True) + + +def test_laplacian_centrality_E(): + E = nx.Graph() + E.add_weighted_edges_from( + [(0, 1, 4), (4, 5, 1), (0, 2, 2), (2, 1, 1), (1, 3, 2), (1, 4, 2)] + ) + d = nx.laplacian_centrality(E) + exact = { + 0: 0.700000, + 1: 0.900000, + 2: 0.280000, + 3: 0.220000, + 4: 0.260000, + 5: 0.040000, + } + + for n, dc in d.items(): + assert exact[n] == pytest.approx(dc, abs=1e-7) + + # Check not normalized + full_energy = 200 + dnn = nx.laplacian_centrality(E, normalized=False) + for n, dc in dnn.items(): + assert exact[n] * full_energy == pytest.approx(dc, abs=1e-7) + + # Check unweighted not-normalized version + duw_nn = nx.laplacian_centrality(E, normalized=False, weight=None) + print(duw_nn) + exact_uw_nn = { + 0: 18, + 1: 34, + 2: 18, + 3: 10, + 4: 16, + 5: 6, + } + for n, dc in duw_nn.items(): + assert exact_uw_nn[n] == pytest.approx(dc, abs=1e-7) + + # Check unweighted version + duw = nx.laplacian_centrality(E, weight=None) + full_energy = 42 + for n, dc in duw.items(): + assert exact_uw_nn[n] / full_energy == pytest.approx(dc, abs=1e-7) + + +def test_laplacian_centrality_KC(): + KC = nx.karate_club_graph() + d = nx.laplacian_centrality(KC) + exact = { + 0: 0.2543593, + 1: 0.1724524, + 2: 0.2166053, + 3: 0.0964646, + 4: 0.0350344, + 5: 0.0571109, + 6: 0.0540713, + 7: 0.0788674, + 8: 0.1222204, + 9: 0.0217565, + 10: 0.0308751, + 11: 0.0215965, + 12: 0.0174372, + 13: 0.118861, + 14: 0.0366341, + 15: 0.0548712, + 16: 0.0172772, + 17: 0.0191969, + 18: 0.0225564, + 19: 0.0331147, + 20: 0.0279955, + 21: 0.0246361, + 22: 0.0382339, + 23: 0.1294193, + 24: 0.0227164, + 25: 0.0644697, + 26: 0.0281555, + 27: 0.075188, + 28: 0.0364742, + 29: 0.0707087, + 30: 0.0708687, + 31: 0.131019, + 32: 0.2370821, + 33: 0.3066709, + } + for n, dc in d.items(): + assert exact[n] == pytest.approx(dc, abs=1e-7) + + # Check not normalized + full_energy = 12502 + dnn = nx.laplacian_centrality(KC, normalized=False) + for n, dc in dnn.items(): + assert exact[n] * full_energy == pytest.approx(dc, abs=1e-3) + + +def test_laplacian_centrality_K(): + K = nx.krackhardt_kite_graph() + d = nx.laplacian_centrality(K) + exact = { + 0: 0.3010753, + 1: 0.3010753, + 2: 0.2258065, + 3: 0.483871, + 4: 0.2258065, + 5: 0.3870968, + 6: 0.3870968, + 7: 0.1935484, + 8: 0.0752688, + 9: 0.0322581, + } + for n, dc in d.items(): + assert exact[n] == pytest.approx(dc, abs=1e-7) + + # Check not normalized + full_energy = 186 + dnn = nx.laplacian_centrality(K, normalized=False) + for n, dc in dnn.items(): + assert exact[n] * full_energy == pytest.approx(dc, abs=1e-3) + + +def test_laplacian_centrality_P3(): + P3 = nx.path_graph(3) + d = nx.laplacian_centrality(P3) + exact = {0: 0.6, 1: 1.0, 2: 0.6} + for n, dc in d.items(): + assert exact[n] == pytest.approx(dc, abs=1e-7) + + +def test_laplacian_centrality_K5(): + K5 = nx.complete_graph(5) + d = nx.laplacian_centrality(K5) + exact = {0: 0.52, 1: 0.52, 2: 0.52, 3: 0.52, 4: 0.52} + for n, dc in d.items(): + assert exact[n] == pytest.approx(dc, abs=1e-7) + + +def test_laplacian_centrality_FF(): + FF = nx.florentine_families_graph() + d = nx.laplacian_centrality(FF) + exact = { + "Acciaiuoli": 0.0804598, + "Medici": 0.4022989, + "Castellani": 0.1724138, + "Peruzzi": 0.183908, + "Strozzi": 0.2528736, + "Barbadori": 0.137931, + "Ridolfi": 0.2183908, + "Tornabuoni": 0.2183908, + "Albizzi": 0.1954023, + "Salviati": 0.1149425, + "Pazzi": 0.0344828, + "Bischeri": 0.1954023, + "Guadagni": 0.2298851, + "Ginori": 0.045977, + "Lamberteschi": 0.0574713, + } + for n, dc in d.items(): + assert exact[n] == pytest.approx(dc, abs=1e-7) + + +def test_laplacian_centrality_DG(): + DG = nx.DiGraph([(0, 5), (1, 5), (2, 5), (3, 5), (4, 5), (5, 6), (5, 7), (5, 8)]) + d = nx.laplacian_centrality(DG) + exact = { + 0: 0.2123352, + 5: 0.515391, + 1: 0.2123352, + 2: 0.2123352, + 3: 0.2123352, + 4: 0.2123352, + 6: 0.2952031, + 7: 0.2952031, + 8: 0.2952031, + } + for n, dc in d.items(): + assert exact[n] == pytest.approx(dc, abs=1e-7) + + # Check not normalized + full_energy = 9.50704 + dnn = nx.laplacian_centrality(DG, normalized=False) + for n, dc in dnn.items(): + assert exact[n] * full_energy == pytest.approx(dc, abs=1e-4) diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_load_centrality.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_load_centrality.py new file mode 100644 index 0000000000000000000000000000000000000000..bf096039cd76542cc4c963ab896ee8fc4b295224 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_load_centrality.py @@ -0,0 +1,344 @@ +import pytest + +import networkx as nx + + +class TestLoadCentrality: + @classmethod + def setup_class(cls): + G = nx.Graph() + G.add_edge(0, 1, weight=3) + G.add_edge(0, 2, weight=2) + G.add_edge(0, 3, weight=6) + G.add_edge(0, 4, weight=4) + G.add_edge(1, 3, weight=5) + G.add_edge(1, 5, weight=5) + G.add_edge(2, 4, weight=1) + G.add_edge(3, 4, weight=2) + G.add_edge(3, 5, weight=1) + G.add_edge(4, 5, weight=4) + cls.G = G + cls.exact_weighted = {0: 4.0, 1: 0.0, 2: 8.0, 3: 6.0, 4: 8.0, 5: 0.0} + cls.K = nx.krackhardt_kite_graph() + cls.P3 = nx.path_graph(3) + cls.P4 = nx.path_graph(4) + cls.K5 = nx.complete_graph(5) + cls.P2 = nx.path_graph(2) + + cls.C4 = nx.cycle_graph(4) + cls.T = nx.balanced_tree(r=2, h=2) + cls.Gb = nx.Graph() + cls.Gb.add_edges_from([(0, 1), (0, 2), (1, 3), (2, 3), (2, 4), (4, 5), (3, 5)]) + cls.F = nx.florentine_families_graph() + cls.LM = nx.les_miserables_graph() + cls.D = nx.cycle_graph(3, create_using=nx.DiGraph()) + cls.D.add_edges_from([(3, 0), (4, 3)]) + + def test_not_strongly_connected(self): + b = nx.load_centrality(self.D) + result = {0: 5.0 / 12, 1: 1.0 / 4, 2: 1.0 / 12, 3: 1.0 / 4, 4: 0.000} + for n in sorted(self.D): + assert result[n] == pytest.approx(b[n], abs=1e-3) + assert result[n] == pytest.approx(nx.load_centrality(self.D, n), abs=1e-3) + + def test_P2_normalized_load(self): + G = self.P2 + c = nx.load_centrality(G, normalized=True) + d = {0: 0.000, 1: 0.000} + for n in sorted(G): + assert c[n] == pytest.approx(d[n], abs=1e-3) + + def test_weighted_load(self): + b = nx.load_centrality(self.G, weight="weight", normalized=False) + for n in sorted(self.G): + assert b[n] == self.exact_weighted[n] + + def test_k5_load(self): + G = self.K5 + c = nx.load_centrality(G) + d = {0: 0.000, 1: 0.000, 2: 0.000, 3: 0.000, 4: 0.000} + for n in sorted(G): + assert c[n] == pytest.approx(d[n], abs=1e-3) + + def test_p3_load(self): + G = self.P3 + c = nx.load_centrality(G) + d = {0: 0.000, 1: 1.000, 2: 0.000} + for n in sorted(G): + assert c[n] == pytest.approx(d[n], abs=1e-3) + c = nx.load_centrality(G, v=1) + assert c == pytest.approx(1.0, abs=1e-7) + c = nx.load_centrality(G, v=1, normalized=True) + assert c == pytest.approx(1.0, abs=1e-7) + + def test_p2_load(self): + G = nx.path_graph(2) + c = nx.load_centrality(G) + d = {0: 0.000, 1: 0.000} + for n in sorted(G): + assert c[n] == pytest.approx(d[n], abs=1e-3) + + def test_krackhardt_load(self): + G = self.K + c = nx.load_centrality(G) + d = { + 0: 0.023, + 1: 0.023, + 2: 0.000, + 3: 0.102, + 4: 0.000, + 5: 0.231, + 6: 0.231, + 7: 0.389, + 8: 0.222, + 9: 0.000, + } + for n in sorted(G): + assert c[n] == pytest.approx(d[n], abs=1e-3) + + def test_florentine_families_load(self): + G = self.F + c = nx.load_centrality(G) + d = { + "Acciaiuoli": 0.000, + "Albizzi": 0.211, + "Barbadori": 0.093, + "Bischeri": 0.104, + "Castellani": 0.055, + "Ginori": 0.000, + "Guadagni": 0.251, + "Lamberteschi": 0.000, + "Medici": 0.522, + "Pazzi": 0.000, + "Peruzzi": 0.022, + "Ridolfi": 0.117, + "Salviati": 0.143, + "Strozzi": 0.106, + "Tornabuoni": 0.090, + } + for n in sorted(G): + assert c[n] == pytest.approx(d[n], abs=1e-3) + + def test_les_miserables_load(self): + G = self.LM + c = nx.load_centrality(G) + d = { + "Napoleon": 0.000, + "Myriel": 0.177, + "MlleBaptistine": 0.000, + "MmeMagloire": 0.000, + "CountessDeLo": 0.000, + "Geborand": 0.000, + "Champtercier": 0.000, + "Cravatte": 0.000, + "Count": 0.000, + "OldMan": 0.000, + "Valjean": 0.567, + "Labarre": 0.000, + "Marguerite": 0.000, + "MmeDeR": 0.000, + "Isabeau": 0.000, + "Gervais": 0.000, + "Listolier": 0.000, + "Tholomyes": 0.043, + "Fameuil": 0.000, + "Blacheville": 0.000, + "Favourite": 0.000, + "Dahlia": 0.000, + "Zephine": 0.000, + "Fantine": 0.128, + "MmeThenardier": 0.029, + "Thenardier": 0.075, + "Cosette": 0.024, + "Javert": 0.054, + "Fauchelevent": 0.026, + "Bamatabois": 0.008, + "Perpetue": 0.000, + "Simplice": 0.009, + "Scaufflaire": 0.000, + "Woman1": 0.000, + "Judge": 0.000, + "Champmathieu": 0.000, + "Brevet": 0.000, + "Chenildieu": 0.000, + "Cochepaille": 0.000, + "Pontmercy": 0.007, + "Boulatruelle": 0.000, + "Eponine": 0.012, + "Anzelma": 0.000, + "Woman2": 0.000, + "MotherInnocent": 0.000, + "Gribier": 0.000, + "MmeBurgon": 0.026, + "Jondrette": 0.000, + "Gavroche": 0.164, + "Gillenormand": 0.021, + "Magnon": 0.000, + "MlleGillenormand": 0.047, + "MmePontmercy": 0.000, + "MlleVaubois": 0.000, + "LtGillenormand": 0.000, + "Marius": 0.133, + "BaronessT": 0.000, + "Mabeuf": 0.028, + "Enjolras": 0.041, + "Combeferre": 0.001, + "Prouvaire": 0.000, + "Feuilly": 0.001, + "Courfeyrac": 0.006, + "Bahorel": 0.002, + "Bossuet": 0.032, + "Joly": 0.002, + "Grantaire": 0.000, + "MotherPlutarch": 0.000, + "Gueulemer": 0.005, + "Babet": 0.005, + "Claquesous": 0.005, + "Montparnasse": 0.004, + "Toussaint": 0.000, + "Child1": 0.000, + "Child2": 0.000, + "Brujon": 0.000, + "MmeHucheloup": 0.000, + } + for n in sorted(G): + assert c[n] == pytest.approx(d[n], abs=1e-3) + + def test_unnormalized_k5_load(self): + G = self.K5 + c = nx.load_centrality(G, normalized=False) + d = {0: 0.000, 1: 0.000, 2: 0.000, 3: 0.000, 4: 0.000} + for n in sorted(G): + assert c[n] == pytest.approx(d[n], abs=1e-3) + + def test_unnormalized_p3_load(self): + G = self.P3 + c = nx.load_centrality(G, normalized=False) + d = {0: 0.000, 1: 2.000, 2: 0.000} + for n in sorted(G): + assert c[n] == pytest.approx(d[n], abs=1e-3) + + def test_unnormalized_krackhardt_load(self): + G = self.K + c = nx.load_centrality(G, normalized=False) + d = { + 0: 1.667, + 1: 1.667, + 2: 0.000, + 3: 7.333, + 4: 0.000, + 5: 16.667, + 6: 16.667, + 7: 28.000, + 8: 16.000, + 9: 0.000, + } + + for n in sorted(G): + assert c[n] == pytest.approx(d[n], abs=1e-3) + + def test_unnormalized_florentine_families_load(self): + G = self.F + c = nx.load_centrality(G, normalized=False) + + d = { + "Acciaiuoli": 0.000, + "Albizzi": 38.333, + "Barbadori": 17.000, + "Bischeri": 19.000, + "Castellani": 10.000, + "Ginori": 0.000, + "Guadagni": 45.667, + "Lamberteschi": 0.000, + "Medici": 95.000, + "Pazzi": 0.000, + "Peruzzi": 4.000, + "Ridolfi": 21.333, + "Salviati": 26.000, + "Strozzi": 19.333, + "Tornabuoni": 16.333, + } + for n in sorted(G): + assert c[n] == pytest.approx(d[n], abs=1e-3) + + def test_load_betweenness_difference(self): + # Difference Between Load and Betweenness + # --------------------------------------- The smallest graph + # that shows the difference between load and betweenness is + # G=ladder_graph(3) (Graph B below) + + # Graph A and B are from Tao Zhou, Jian-Guo Liu, Bing-Hong + # Wang: Comment on "Scientific collaboration + # networks. II. Shortest paths, weighted networks, and + # centrality". https://arxiv.org/pdf/physics/0511084 + + # Notice that unlike here, their calculation adds to 1 to the + # betweenness of every node i for every path from i to every + # other node. This is exactly what it should be, based on + # Eqn. (1) in their paper: the eqn is B(v) = \sum_{s\neq t, + # s\neq v}{\frac{\sigma_{st}(v)}{\sigma_{st}}}, therefore, + # they allow v to be the target node. + + # We follow Brandes 2001, who follows Freeman 1977 that make + # the sum for betweenness of v exclude paths where v is either + # the source or target node. To agree with their numbers, we + # must additionally, remove edge (4,8) from the graph, see AC + # example following (there is a mistake in the figure in their + # paper - personal communication). + + # A = nx.Graph() + # A.add_edges_from([(0,1), (1,2), (1,3), (2,4), + # (3,5), (4,6), (4,7), (4,8), + # (5,8), (6,9), (7,9), (8,9)]) + B = nx.Graph() # ladder_graph(3) + B.add_edges_from([(0, 1), (0, 2), (1, 3), (2, 3), (2, 4), (4, 5), (3, 5)]) + c = nx.load_centrality(B, normalized=False) + d = {0: 1.750, 1: 1.750, 2: 6.500, 3: 6.500, 4: 1.750, 5: 1.750} + for n in sorted(B): + assert c[n] == pytest.approx(d[n], abs=1e-3) + + def test_c4_edge_load(self): + G = self.C4 + c = nx.edge_load_centrality(G) + d = {(0, 1): 6.000, (0, 3): 6.000, (1, 2): 6.000, (2, 3): 6.000} + for n in G.edges(): + assert c[n] == pytest.approx(d[n], abs=1e-3) + + def test_p4_edge_load(self): + G = self.P4 + c = nx.edge_load_centrality(G) + d = {(0, 1): 6.000, (1, 2): 8.000, (2, 3): 6.000} + for n in G.edges(): + assert c[n] == pytest.approx(d[n], abs=1e-3) + + def test_k5_edge_load(self): + G = self.K5 + c = nx.edge_load_centrality(G) + d = { + (0, 1): 5.000, + (0, 2): 5.000, + (0, 3): 5.000, + (0, 4): 5.000, + (1, 2): 5.000, + (1, 3): 5.000, + (1, 4): 5.000, + (2, 3): 5.000, + (2, 4): 5.000, + (3, 4): 5.000, + } + for n in G.edges(): + assert c[n] == pytest.approx(d[n], abs=1e-3) + + def test_tree_edge_load(self): + G = self.T + c = nx.edge_load_centrality(G) + d = { + (0, 1): 24.000, + (0, 2): 24.000, + (1, 3): 12.000, + (1, 4): 12.000, + (2, 5): 12.000, + (2, 6): 12.000, + } + for n in G.edges(): + assert c[n] == pytest.approx(d[n], abs=1e-3) diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_percolation_centrality.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_percolation_centrality.py new file mode 100644 index 0000000000000000000000000000000000000000..0cb8f52965c975013d41be7c3de874cd86ee693a --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_percolation_centrality.py @@ -0,0 +1,87 @@ +import pytest + +import networkx as nx + + +def example1a_G(): + G = nx.Graph() + G.add_node(1, percolation=0.1) + G.add_node(2, percolation=0.2) + G.add_node(3, percolation=0.2) + G.add_node(4, percolation=0.2) + G.add_node(5, percolation=0.3) + G.add_node(6, percolation=0.2) + G.add_node(7, percolation=0.5) + G.add_node(8, percolation=0.5) + G.add_edges_from([(1, 4), (2, 4), (3, 4), (4, 5), (5, 6), (6, 7), (6, 8)]) + return G + + +def example1b_G(): + G = nx.Graph() + G.add_node(1, percolation=0.3) + G.add_node(2, percolation=0.5) + G.add_node(3, percolation=0.5) + G.add_node(4, percolation=0.2) + G.add_node(5, percolation=0.3) + G.add_node(6, percolation=0.2) + G.add_node(7, percolation=0.1) + G.add_node(8, percolation=0.1) + G.add_edges_from([(1, 4), (2, 4), (3, 4), (4, 5), (5, 6), (6, 7), (6, 8)]) + return G + + +def test_percolation_example1a(): + """percolation centrality: example 1a""" + G = example1a_G() + p = nx.percolation_centrality(G) + p_answer = {4: 0.625, 6: 0.667} + for n, k in p_answer.items(): + assert p[n] == pytest.approx(k, abs=1e-3) + + +def test_percolation_example1b(): + """percolation centrality: example 1a""" + G = example1b_G() + p = nx.percolation_centrality(G) + p_answer = {4: 0.825, 6: 0.4} + for n, k in p_answer.items(): + assert p[n] == pytest.approx(k, abs=1e-3) + + +def test_converge_to_betweenness(): + """percolation centrality: should converge to betweenness + centrality when all nodes are percolated the same""" + # taken from betweenness test test_florentine_families_graph + G = nx.florentine_families_graph() + b_answer = { + "Acciaiuoli": 0.000, + "Albizzi": 0.212, + "Barbadori": 0.093, + "Bischeri": 0.104, + "Castellani": 0.055, + "Ginori": 0.000, + "Guadagni": 0.255, + "Lamberteschi": 0.000, + "Medici": 0.522, + "Pazzi": 0.000, + "Peruzzi": 0.022, + "Ridolfi": 0.114, + "Salviati": 0.143, + "Strozzi": 0.103, + "Tornabuoni": 0.092, + } + + # If no initial state is provided, state for + # every node defaults to 1 + p_answer = nx.percolation_centrality(G) + assert p_answer == pytest.approx(b_answer, abs=1e-3) + + p_states = {k: 0.3 for k, v in b_answer.items()} + p_answer = nx.percolation_centrality(G, states=p_states) + assert p_answer == pytest.approx(b_answer, abs=1e-3) + + +def test_default_percolation(): + G = nx.erdos_renyi_graph(42, 0.42, seed=42) + assert nx.percolation_centrality(G) == pytest.approx(nx.betweenness_centrality(G)) diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_reaching.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_reaching.py new file mode 100644 index 0000000000000000000000000000000000000000..35d50e701216bc6c003517ca3ed92cddc261c286 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_reaching.py @@ -0,0 +1,140 @@ +"""Unit tests for the :mod:`networkx.algorithms.centrality.reaching` module.""" + +import pytest + +import networkx as nx + + +class TestGlobalReachingCentrality: + """Unit tests for the global reaching centrality function.""" + + def test_non_positive_weights(self): + with pytest.raises(nx.NetworkXError): + G = nx.DiGraph() + nx.global_reaching_centrality(G, weight="weight") + + def test_negatively_weighted(self): + with pytest.raises(nx.NetworkXError): + G = nx.Graph() + G.add_weighted_edges_from([(0, 1, -2), (1, 2, +1)]) + nx.global_reaching_centrality(G, weight="weight") + + def test_directed_star(self): + G = nx.DiGraph() + G.add_weighted_edges_from([(1, 2, 0.5), (1, 3, 0.5)]) + grc = nx.global_reaching_centrality + assert grc(G, normalized=False, weight="weight") == 0.5 + assert grc(G) == 1 + + def test_undirected_unweighted_star(self): + G = nx.star_graph(2) + grc = nx.global_reaching_centrality + assert grc(G, normalized=False, weight=None) == 0.25 + + def test_undirected_weighted_star(self): + G = nx.Graph() + G.add_weighted_edges_from([(1, 2, 1), (1, 3, 2)]) + grc = nx.global_reaching_centrality + assert grc(G, normalized=False, weight="weight") == 0.375 + + def test_cycle_directed_unweighted(self): + G = nx.DiGraph() + G.add_edge(1, 2) + G.add_edge(2, 1) + assert nx.global_reaching_centrality(G, weight=None) == 0 + + def test_cycle_undirected_unweighted(self): + G = nx.Graph() + G.add_edge(1, 2) + assert nx.global_reaching_centrality(G, weight=None) == 0 + + def test_cycle_directed_weighted(self): + G = nx.DiGraph() + G.add_weighted_edges_from([(1, 2, 1), (2, 1, 1)]) + assert nx.global_reaching_centrality(G) == 0 + + def test_cycle_undirected_weighted(self): + G = nx.Graph() + G.add_edge(1, 2, weight=1) + grc = nx.global_reaching_centrality + assert grc(G, normalized=False) == 0 + + def test_directed_weighted(self): + G = nx.DiGraph() + G.add_edge("A", "B", weight=5) + G.add_edge("B", "C", weight=1) + G.add_edge("B", "D", weight=0.25) + G.add_edge("D", "E", weight=1) + + denom = len(G) - 1 + A_local = sum([5, 3, 2.625, 2.0833333333333]) / denom + B_local = sum([1, 0.25, 0.625]) / denom + C_local = 0 + D_local = sum([1]) / denom + E_local = 0 + + local_reach_ctrs = [A_local, C_local, B_local, D_local, E_local] + max_local = max(local_reach_ctrs) + expected = sum(max_local - lrc for lrc in local_reach_ctrs) / denom + grc = nx.global_reaching_centrality + actual = grc(G, normalized=False, weight="weight") + assert expected == pytest.approx(actual, abs=1e-7) + + def test_single_node_with_cycle(self): + G = nx.DiGraph([(1, 1)]) + with pytest.raises(nx.NetworkXError, match="local_reaching_centrality"): + nx.global_reaching_centrality(G) + + def test_single_node_with_weighted_cycle(self): + G = nx.DiGraph() + G.add_weighted_edges_from([(1, 1, 2)]) + with pytest.raises(nx.NetworkXError, match="local_reaching_centrality"): + nx.global_reaching_centrality(G, weight="weight") + + +class TestLocalReachingCentrality: + """Unit tests for the local reaching centrality function.""" + + def test_non_positive_weights(self): + with pytest.raises(nx.NetworkXError): + G = nx.DiGraph() + G.add_weighted_edges_from([(0, 1, 0)]) + nx.local_reaching_centrality(G, 0, weight="weight") + + def test_negatively_weighted(self): + with pytest.raises(nx.NetworkXError): + G = nx.Graph() + G.add_weighted_edges_from([(0, 1, -2), (1, 2, +1)]) + nx.local_reaching_centrality(G, 0, weight="weight") + + def test_undirected_unweighted_star(self): + G = nx.star_graph(2) + grc = nx.local_reaching_centrality + assert grc(G, 1, weight=None, normalized=False) == 0.75 + + def test_undirected_weighted_star(self): + G = nx.Graph() + G.add_weighted_edges_from([(1, 2, 1), (1, 3, 2)]) + centrality = nx.local_reaching_centrality( + G, 1, normalized=False, weight="weight" + ) + assert centrality == 1.5 + + def test_undirected_weighted_normalized(self): + G = nx.Graph() + G.add_weighted_edges_from([(1, 2, 1), (1, 3, 2)]) + centrality = nx.local_reaching_centrality( + G, 1, normalized=True, weight="weight" + ) + assert centrality == 1.0 + + def test_single_node_with_cycle(self): + G = nx.DiGraph([(1, 1)]) + with pytest.raises(nx.NetworkXError, match="local_reaching_centrality"): + nx.local_reaching_centrality(G, 1) + + def test_single_node_with_weighted_cycle(self): + G = nx.DiGraph() + G.add_weighted_edges_from([(1, 1, 2)]) + with pytest.raises(nx.NetworkXError, match="local_reaching_centrality"): + nx.local_reaching_centrality(G, 1, weight="weight") diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_second_order_centrality.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_second_order_centrality.py new file mode 100644 index 0000000000000000000000000000000000000000..cc3047866079fd9fe4cf43a6793cf160a0c0cdce --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_second_order_centrality.py @@ -0,0 +1,82 @@ +""" +Tests for second order centrality. +""" + +import pytest + +pytest.importorskip("numpy") +pytest.importorskip("scipy") + +import networkx as nx + + +def test_empty(): + with pytest.raises(nx.NetworkXException): + G = nx.empty_graph() + nx.second_order_centrality(G) + + +def test_non_connected(): + with pytest.raises(nx.NetworkXException): + G = nx.Graph() + G.add_node(0) + G.add_node(1) + nx.second_order_centrality(G) + + +def test_non_negative_edge_weights(): + with pytest.raises(nx.NetworkXException): + G = nx.path_graph(2) + G.add_edge(0, 1, weight=-1) + nx.second_order_centrality(G) + + +def test_weight_attribute(): + G = nx.Graph() + G.add_weighted_edges_from([(0, 1, 1.0), (1, 2, 3.5)], weight="w") + expected = {0: 3.431, 1: 3.082, 2: 5.612} + b = nx.second_order_centrality(G, weight="w") + + for n in sorted(G): + assert b[n] == pytest.approx(expected[n], abs=1e-2) + + +def test_one_node_graph(): + """Second order centrality: single node""" + G = nx.Graph() + G.add_node(0) + G.add_edge(0, 0) + assert nx.second_order_centrality(G)[0] == 0 + + +def test_P3(): + """Second order centrality: line graph, as defined in paper""" + G = nx.path_graph(3) + b_answer = {0: 3.741, 1: 1.414, 2: 3.741} + + b = nx.second_order_centrality(G) + + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-2) + + +def test_K3(): + """Second order centrality: complete graph, as defined in paper""" + G = nx.complete_graph(3) + b_answer = {0: 1.414, 1: 1.414, 2: 1.414} + + b = nx.second_order_centrality(G) + + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-2) + + +def test_ring_graph(): + """Second order centrality: ring graph, as defined in paper""" + G = nx.cycle_graph(5) + b_answer = {0: 4.472, 1: 4.472, 2: 4.472, 3: 4.472, 4: 4.472} + + b = nx.second_order_centrality(G) + + for n in sorted(G): + assert b[n] == pytest.approx(b_answer[n], abs=1e-2) diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_subgraph.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_subgraph.py new file mode 100644 index 0000000000000000000000000000000000000000..710927515baa4786e4be15ddf25ad34e423563d2 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_subgraph.py @@ -0,0 +1,110 @@ +import pytest + +pytest.importorskip("numpy") +pytest.importorskip("scipy") + +import networkx as nx +from networkx.algorithms.centrality.subgraph_alg import ( + communicability_betweenness_centrality, + estrada_index, + subgraph_centrality, + subgraph_centrality_exp, +) + + +class TestSubgraph: + def test_subgraph_centrality(self): + answer = {0: 1.5430806348152433, 1: 1.5430806348152433} + result = subgraph_centrality(nx.path_graph(2)) + for k, v in result.items(): + assert answer[k] == pytest.approx(v, abs=1e-7) + + answer1 = { + "1": 1.6445956054135658, + "Albert": 2.4368257358712189, + "Aric": 2.4368257358712193, + "Dan": 3.1306328496328168, + "Franck": 2.3876142275231915, + } + G1 = nx.Graph( + [ + ("Franck", "Aric"), + ("Aric", "Dan"), + ("Dan", "Albert"), + ("Albert", "Franck"), + ("Dan", "1"), + ("Franck", "Albert"), + ] + ) + result1 = subgraph_centrality(G1) + for k, v in result1.items(): + assert answer1[k] == pytest.approx(v, abs=1e-7) + result1 = subgraph_centrality_exp(G1) + for k, v in result1.items(): + assert answer1[k] == pytest.approx(v, abs=1e-7) + + def test_subgraph_centrality_big_graph(self): + g199 = nx.complete_graph(199) + g200 = nx.complete_graph(200) + + comm199 = nx.subgraph_centrality(g199) + comm199_exp = nx.subgraph_centrality_exp(g199) + + comm200 = nx.subgraph_centrality(g200) + comm200_exp = nx.subgraph_centrality_exp(g200) + + def test_communicability_betweenness_centrality_small(self): + result = communicability_betweenness_centrality(nx.path_graph(2)) + assert result == {0: 0, 1: 0} + + result = communicability_betweenness_centrality(nx.path_graph(1)) + assert result == {0: 0} + + result = communicability_betweenness_centrality(nx.path_graph(0)) + assert result == {} + + answer = {0: 0.1411224421177313, 1: 1.0, 2: 0.1411224421177313} + result = communicability_betweenness_centrality(nx.path_graph(3)) + for k, v in result.items(): + assert answer[k] == pytest.approx(v, abs=1e-7) + + result = communicability_betweenness_centrality(nx.complete_graph(3)) + for k, v in result.items(): + assert 0.49786143366223296 == pytest.approx(v, abs=1e-7) + + def test_communicability_betweenness_centrality(self): + answer = { + 0: 0.07017447951484615, + 1: 0.71565598701107991, + 2: 0.71565598701107991, + 3: 0.07017447951484615, + } + result = communicability_betweenness_centrality(nx.path_graph(4)) + for k, v in result.items(): + assert answer[k] == pytest.approx(v, abs=1e-7) + + answer1 = { + "1": 0.060039074193949521, + "Albert": 0.315470761661372, + "Aric": 0.31547076166137211, + "Dan": 0.68297778678316201, + "Franck": 0.21977926617449497, + } + G1 = nx.Graph( + [ + ("Franck", "Aric"), + ("Aric", "Dan"), + ("Dan", "Albert"), + ("Albert", "Franck"), + ("Dan", "1"), + ("Franck", "Albert"), + ] + ) + result1 = communicability_betweenness_centrality(G1) + for k, v in result1.items(): + assert answer1[k] == pytest.approx(v, abs=1e-7) + + def test_estrada_index(self): + answer = 1041.2470334195475 + result = estrada_index(nx.karate_club_graph()) + assert answer == pytest.approx(result, abs=1e-7) diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_trophic.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_trophic.py new file mode 100644 index 0000000000000000000000000000000000000000..e6880d52a47013f653ff27d6d80f5e207a98a4ef --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_trophic.py @@ -0,0 +1,302 @@ +"""Test trophic levels, trophic differences and trophic coherence""" + +import pytest + +np = pytest.importorskip("numpy") +pytest.importorskip("scipy") + +import networkx as nx + + +def test_trophic_levels(): + """Trivial example""" + G = nx.DiGraph() + G.add_edge("a", "b") + G.add_edge("b", "c") + + d = nx.trophic_levels(G) + assert d == {"a": 1, "b": 2, "c": 3} + + +def test_trophic_levels_levine(): + """Example from Figure 5 in Stephen Levine (1980) J. theor. Biol. 83, + 195-207 + """ + S = nx.DiGraph() + S.add_edge(1, 2, weight=1.0) + S.add_edge(1, 3, weight=0.2) + S.add_edge(1, 4, weight=0.8) + S.add_edge(2, 3, weight=0.2) + S.add_edge(2, 5, weight=0.3) + S.add_edge(4, 3, weight=0.6) + S.add_edge(4, 5, weight=0.7) + S.add_edge(5, 4, weight=0.2) + + # save copy for later, test intermediate implementation details first + S2 = S.copy() + + # drop nodes of in-degree zero + z = [nid for nid, d in S.in_degree if d == 0] + for nid in z: + S.remove_node(nid) + + # find adjacency matrix + q = nx.linalg.graphmatrix.adjacency_matrix(S).T + + # fmt: off + expected_q = np.array([ + [0, 0, 0., 0], + [0.2, 0, 0.6, 0], + [0, 0, 0, 0.2], + [0.3, 0, 0.7, 0] + ]) + # fmt: on + assert np.array_equal(q.todense(), expected_q) + + # must be square, size of number of nodes + assert len(q.shape) == 2 + assert q.shape[0] == q.shape[1] + assert q.shape[0] == len(S) + + nn = q.shape[0] + + i = np.eye(nn) + n = np.linalg.inv(i - q) + y = np.asarray(n) @ np.ones(nn) + + expected_y = np.array([1, 2.07906977, 1.46511628, 2.3255814]) + assert np.allclose(y, expected_y) + + expected_d = {1: 1, 2: 2, 3: 3.07906977, 4: 2.46511628, 5: 3.3255814} + + d = nx.trophic_levels(S2) + + for nid, level in d.items(): + expected_level = expected_d[nid] + assert expected_level == pytest.approx(level, abs=1e-7) + + +def test_trophic_levels_simple(): + matrix_a = np.array([[0, 0], [1, 0]]) + G = nx.from_numpy_array(matrix_a, create_using=nx.DiGraph) + d = nx.trophic_levels(G) + assert d[0] == pytest.approx(2, abs=1e-7) + assert d[1] == pytest.approx(1, abs=1e-7) + + +def test_trophic_levels_more_complex(): + # fmt: off + matrix = np.array([ + [0, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1], + [0, 0, 0, 0] + ]) + # fmt: on + G = nx.from_numpy_array(matrix, create_using=nx.DiGraph) + d = nx.trophic_levels(G) + expected_result = [1, 2, 3, 4] + for ind in range(4): + assert d[ind] == pytest.approx(expected_result[ind], abs=1e-7) + + # fmt: off + matrix = np.array([ + [0, 1, 1, 0], + [0, 0, 1, 1], + [0, 0, 0, 1], + [0, 0, 0, 0] + ]) + # fmt: on + G = nx.from_numpy_array(matrix, create_using=nx.DiGraph) + d = nx.trophic_levels(G) + + expected_result = [1, 2, 2.5, 3.25] + print("Calculated result: ", d) + print("Expected Result: ", expected_result) + + for ind in range(4): + assert d[ind] == pytest.approx(expected_result[ind], abs=1e-7) + + +def test_trophic_levels_even_more_complex(): + # fmt: off + # Another, bigger matrix + matrix = np.array([ + [0, 0, 0, 0, 0], + [0, 1, 0, 1, 0], + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1, 0] + ]) + # Generated this linear system using pen and paper: + K = np.array([ + [1, 0, -1, 0, 0], + [0, 0.5, 0, -0.5, 0], + [0, 0, 1, 0, 0], + [0, -0.5, 0, 1, -0.5], + [0, 0, 0, 0, 1], + ]) + # fmt: on + result_1 = np.ravel(np.linalg.inv(K) @ np.ones(5)) + G = nx.from_numpy_array(matrix, create_using=nx.DiGraph) + result_2 = nx.trophic_levels(G) + + for ind in range(5): + assert result_1[ind] == pytest.approx(result_2[ind], abs=1e-7) + + +def test_trophic_levels_singular_matrix(): + """Should raise an error with graphs with only non-basal nodes""" + matrix = np.identity(4) + G = nx.from_numpy_array(matrix, create_using=nx.DiGraph) + with pytest.raises(nx.NetworkXError) as e: + nx.trophic_levels(G) + msg = ( + "Trophic levels are only defined for graphs where every node " + + "has a path from a basal node (basal nodes are nodes with no " + + "incoming edges)." + ) + assert msg in str(e.value) + + +def test_trophic_levels_singular_with_basal(): + """Should fail to compute if there are any parts of the graph which are not + reachable from any basal node (with in-degree zero). + """ + G = nx.DiGraph() + # a has in-degree zero + G.add_edge("a", "b") + + # b is one level above a, c and d + G.add_edge("c", "b") + G.add_edge("d", "b") + + # c and d form a loop, neither are reachable from a + G.add_edge("c", "d") + G.add_edge("d", "c") + + with pytest.raises(nx.NetworkXError) as e: + nx.trophic_levels(G) + msg = ( + "Trophic levels are only defined for graphs where every node " + + "has a path from a basal node (basal nodes are nodes with no " + + "incoming edges)." + ) + assert msg in str(e.value) + + # if self-loops are allowed, smaller example: + G = nx.DiGraph() + G.add_edge("a", "b") # a has in-degree zero + G.add_edge("c", "b") # b is one level above a and c + G.add_edge("c", "c") # c has a self-loop + with pytest.raises(nx.NetworkXError) as e: + nx.trophic_levels(G) + msg = ( + "Trophic levels are only defined for graphs where every node " + + "has a path from a basal node (basal nodes are nodes with no " + + "incoming edges)." + ) + assert msg in str(e.value) + + +def test_trophic_differences(): + matrix_a = np.array([[0, 1], [0, 0]]) + G = nx.from_numpy_array(matrix_a, create_using=nx.DiGraph) + diffs = nx.trophic_differences(G) + assert diffs[(0, 1)] == pytest.approx(1, abs=1e-7) + + # fmt: off + matrix_b = np.array([ + [0, 1, 1, 0], + [0, 0, 1, 1], + [0, 0, 0, 1], + [0, 0, 0, 0] + ]) + # fmt: on + G = nx.from_numpy_array(matrix_b, create_using=nx.DiGraph) + diffs = nx.trophic_differences(G) + + assert diffs[(0, 1)] == pytest.approx(1, abs=1e-7) + assert diffs[(0, 2)] == pytest.approx(1.5, abs=1e-7) + assert diffs[(1, 2)] == pytest.approx(0.5, abs=1e-7) + assert diffs[(1, 3)] == pytest.approx(1.25, abs=1e-7) + assert diffs[(2, 3)] == pytest.approx(0.75, abs=1e-7) + + +def test_trophic_incoherence_parameter_no_cannibalism(): + matrix_a = np.array([[0, 1], [0, 0]]) + G = nx.from_numpy_array(matrix_a, create_using=nx.DiGraph) + q = nx.trophic_incoherence_parameter(G, cannibalism=False) + assert q == pytest.approx(0, abs=1e-7) + + # fmt: off + matrix_b = np.array([ + [0, 1, 1, 0], + [0, 0, 1, 1], + [0, 0, 0, 1], + [0, 0, 0, 0] + ]) + # fmt: on + G = nx.from_numpy_array(matrix_b, create_using=nx.DiGraph) + q = nx.trophic_incoherence_parameter(G, cannibalism=False) + assert q == pytest.approx(np.std([1, 1.5, 0.5, 0.75, 1.25]), abs=1e-7) + + # fmt: off + matrix_c = np.array([ + [0, 1, 1, 0], + [0, 1, 1, 1], + [0, 0, 0, 1], + [0, 0, 0, 1] + ]) + # fmt: on + G = nx.from_numpy_array(matrix_c, create_using=nx.DiGraph) + q = nx.trophic_incoherence_parameter(G, cannibalism=False) + # Ignore the -link + assert q == pytest.approx(np.std([1, 1.5, 0.5, 0.75, 1.25]), abs=1e-7) + + # no self-loops case + # fmt: off + matrix_d = np.array([ + [0, 1, 1, 0], + [0, 0, 1, 1], + [0, 0, 0, 1], + [0, 0, 0, 0] + ]) + # fmt: on + G = nx.from_numpy_array(matrix_d, create_using=nx.DiGraph) + q = nx.trophic_incoherence_parameter(G, cannibalism=False) + # Ignore the -link + assert q == pytest.approx(np.std([1, 1.5, 0.5, 0.75, 1.25]), abs=1e-7) + + +def test_trophic_incoherence_parameter_cannibalism(): + matrix_a = np.array([[0, 1], [0, 0]]) + G = nx.from_numpy_array(matrix_a, create_using=nx.DiGraph) + q = nx.trophic_incoherence_parameter(G, cannibalism=True) + assert q == pytest.approx(0, abs=1e-7) + + # fmt: off + matrix_b = np.array([ + [0, 0, 0, 0, 0], + [0, 1, 0, 1, 0], + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1, 0] + ]) + # fmt: on + G = nx.from_numpy_array(matrix_b, create_using=nx.DiGraph) + q = nx.trophic_incoherence_parameter(G, cannibalism=True) + assert q == pytest.approx(2, abs=1e-7) + + # fmt: off + matrix_c = np.array([ + [0, 1, 1, 0], + [0, 0, 1, 1], + [0, 0, 0, 1], + [0, 0, 0, 0] + ]) + # fmt: on + G = nx.from_numpy_array(matrix_c, create_using=nx.DiGraph) + q = nx.trophic_incoherence_parameter(G, cannibalism=True) + # Ignore the -link + assert q == pytest.approx(np.std([1, 1.5, 0.5, 0.75, 1.25]), abs=1e-7) diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_voterank.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_voterank.py new file mode 100644 index 0000000000000000000000000000000000000000..a5cfb6107a64f1c1e0923297a81c7df7172d5afe --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/tests/test_voterank.py @@ -0,0 +1,64 @@ +""" +Unit tests for VoteRank. +""" + +import networkx as nx + + +class TestVoteRankCentrality: + # Example Graph present in reference paper + def test_voterank_centrality_1(self): + G = nx.Graph() + G.add_edges_from( + [ + (7, 8), + (7, 5), + (7, 9), + (5, 0), + (0, 1), + (0, 2), + (0, 3), + (0, 4), + (1, 6), + (2, 6), + (3, 6), + (4, 6), + ] + ) + assert [0, 7, 6] == nx.voterank(G) + + def test_voterank_emptygraph(self): + G = nx.Graph() + assert [] == nx.voterank(G) + + # Graph unit test + def test_voterank_centrality_2(self): + G = nx.florentine_families_graph() + d = nx.voterank(G, 4) + exact = ["Medici", "Strozzi", "Guadagni", "Castellani"] + assert exact == d + + # DiGraph unit test + def test_voterank_centrality_3(self): + G = nx.gnc_graph(10, seed=7) + d = nx.voterank(G, 4) + exact = [3, 6, 8] + assert exact == d + + # MultiGraph unit test + def test_voterank_centrality_4(self): + G = nx.MultiGraph() + G.add_edges_from( + [(0, 1), (0, 1), (1, 2), (2, 5), (2, 5), (5, 6), (5, 6), (2, 4), (4, 3)] + ) + exact = [2, 1, 5, 4] + assert exact == nx.voterank(G) + + # MultiDiGraph unit test + def test_voterank_centrality_5(self): + G = nx.MultiDiGraph() + G.add_edges_from( + [(0, 1), (0, 1), (1, 2), (2, 5), (2, 5), (5, 6), (5, 6), (2, 4), (4, 3)] + ) + exact = [2, 0, 5, 4] + assert exact == nx.voterank(G) diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/trophic.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/trophic.py new file mode 100644 index 0000000000000000000000000000000000000000..9e461ced59be8b67f3a644f142acca77a4a114fb --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/trophic.py @@ -0,0 +1,163 @@ +"""Trophic levels""" + +import networkx as nx +from networkx.utils import not_implemented_for + +__all__ = ["trophic_levels", "trophic_differences", "trophic_incoherence_parameter"] + + +@not_implemented_for("undirected") +@nx._dispatchable(edge_attrs="weight") +def trophic_levels(G, weight="weight"): + r"""Compute the trophic levels of nodes. + + The trophic level of a node $i$ is + + .. math:: + + s_i = 1 + \frac{1}{k^{in}_i} \sum_{j} a_{ij} s_j + + where $k^{in}_i$ is the in-degree of i + + .. math:: + + k^{in}_i = \sum_{j} a_{ij} + + and nodes with $k^{in}_i = 0$ have $s_i = 1$ by convention. + + These are calculated using the method outlined in Levine [1]_. + + Parameters + ---------- + G : DiGraph + A directed networkx graph + + Returns + ------- + nodes : dict + Dictionary of nodes with trophic level as the value. + + References + ---------- + .. [1] Stephen Levine (1980) J. theor. Biol. 83, 195-207 + """ + import numpy as np + + # find adjacency matrix + a = nx.adjacency_matrix(G, weight=weight).T.toarray() + + # drop rows/columns where in-degree is zero + rowsum = np.sum(a, axis=1) + p = a[rowsum != 0][:, rowsum != 0] + # normalise so sum of in-degree weights is 1 along each row + p = p / rowsum[rowsum != 0][:, np.newaxis] + + # calculate trophic levels + nn = p.shape[0] + i = np.eye(nn) + try: + n = np.linalg.inv(i - p) + except np.linalg.LinAlgError as err: + # LinAlgError is raised when there is a non-basal node + msg = ( + "Trophic levels are only defined for graphs where every " + + "node has a path from a basal node (basal nodes are nodes " + + "with no incoming edges)." + ) + raise nx.NetworkXError(msg) from err + y = n.sum(axis=1) + 1 + + levels = {} + + # all nodes with in-degree zero have trophic level == 1 + zero_node_ids = (node_id for node_id, degree in G.in_degree if degree == 0) + for node_id in zero_node_ids: + levels[node_id] = 1 + + # all other nodes have levels as calculated + nonzero_node_ids = (node_id for node_id, degree in G.in_degree if degree != 0) + for i, node_id in enumerate(nonzero_node_ids): + levels[node_id] = y.item(i) + + return levels + + +@not_implemented_for("undirected") +@nx._dispatchable(edge_attrs="weight") +def trophic_differences(G, weight="weight"): + r"""Compute the trophic differences of the edges of a directed graph. + + The trophic difference $x_ij$ for each edge is defined in Johnson et al. + [1]_ as: + + .. math:: + x_ij = s_j - s_i + + Where $s_i$ is the trophic level of node $i$. + + Parameters + ---------- + G : DiGraph + A directed networkx graph + + Returns + ------- + diffs : dict + Dictionary of edges with trophic differences as the value. + + References + ---------- + .. [1] Samuel Johnson, Virginia Dominguez-Garcia, Luca Donetti, Miguel A. + Munoz (2014) PNAS "Trophic coherence determines food-web stability" + """ + levels = trophic_levels(G, weight=weight) + diffs = {} + for u, v in G.edges: + diffs[(u, v)] = levels[v] - levels[u] + return diffs + + +@not_implemented_for("undirected") +@nx._dispatchable(edge_attrs="weight") +def trophic_incoherence_parameter(G, weight="weight", cannibalism=False): + r"""Compute the trophic incoherence parameter of a graph. + + Trophic coherence is defined as the homogeneity of the distribution of + trophic distances: the more similar, the more coherent. This is measured by + the standard deviation of the trophic differences and referred to as the + trophic incoherence parameter $q$ by [1]. + + Parameters + ---------- + G : DiGraph + A directed networkx graph + + cannibalism: Boolean + If set to False, self edges are not considered in the calculation + + Returns + ------- + trophic_incoherence_parameter : float + The trophic coherence of a graph + + References + ---------- + .. [1] Samuel Johnson, Virginia Dominguez-Garcia, Luca Donetti, Miguel A. + Munoz (2014) PNAS "Trophic coherence determines food-web stability" + """ + import numpy as np + + if cannibalism: + diffs = trophic_differences(G, weight=weight) + else: + # If no cannibalism, remove self-edges + self_loops = list(nx.selfloop_edges(G)) + if self_loops: + # Make a copy so we do not change G's edges in memory + G_2 = G.copy() + G_2.remove_edges_from(self_loops) + else: + # Avoid copy otherwise + G_2 = G + diffs = trophic_differences(G_2, weight=weight) + return float(np.std(list(diffs.values()))) diff --git a/lib/python3.10/site-packages/networkx/algorithms/centrality/voterank_alg.py b/lib/python3.10/site-packages/networkx/algorithms/centrality/voterank_alg.py new file mode 100644 index 0000000000000000000000000000000000000000..9b510b2886a5e2eca1eb3396f55f67abc4ce9f30 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/centrality/voterank_alg.py @@ -0,0 +1,95 @@ +"""Algorithm to select influential nodes in a graph using VoteRank.""" + +import networkx as nx + +__all__ = ["voterank"] + + +@nx._dispatchable +def voterank(G, number_of_nodes=None): + """Select a list of influential nodes in a graph using VoteRank algorithm + + VoteRank [1]_ computes a ranking of the nodes in a graph G based on a + voting scheme. With VoteRank, all nodes vote for each of its in-neighbors + and the node with the highest votes is elected iteratively. The voting + ability of out-neighbors of elected nodes is decreased in subsequent turns. + + Parameters + ---------- + G : graph + A NetworkX graph. + + number_of_nodes : integer, optional + Number of ranked nodes to extract (default all nodes). + + Returns + ------- + voterank : list + Ordered list of computed seeds. + Only nodes with positive number of votes are returned. + + Examples + -------- + >>> G = nx.Graph([(0, 1), (0, 2), (0, 3), (1, 4)]) + >>> nx.voterank(G) + [0, 1] + + The algorithm can be used both for undirected and directed graphs. + However, the directed version is different in two ways: + (i) nodes only vote for their in-neighbors and + (ii) only the voting ability of elected node and its out-neighbors are updated: + + >>> G = nx.DiGraph([(0, 1), (2, 1), (2, 3), (3, 4)]) + >>> nx.voterank(G) + [2, 3] + + Notes + ----- + Each edge is treated independently in case of multigraphs. + + References + ---------- + .. [1] Zhang, J.-X. et al. (2016). + Identifying a set of influential spreaders in complex networks. + Sci. Rep. 6, 27823; doi: 10.1038/srep27823. + """ + influential_nodes = [] + vote_rank = {} + if len(G) == 0: + return influential_nodes + if number_of_nodes is None or number_of_nodes > len(G): + number_of_nodes = len(G) + if G.is_directed(): + # For directed graphs compute average out-degree + avgDegree = sum(deg for _, deg in G.out_degree()) / len(G) + else: + # For undirected graphs compute average degree + avgDegree = sum(deg for _, deg in G.degree()) / len(G) + # step 1 - initiate all nodes to (0,1) (score, voting ability) + for n in G.nodes(): + vote_rank[n] = [0, 1] + # Repeat steps 1b to 4 until num_seeds are elected. + for _ in range(number_of_nodes): + # step 1b - reset rank + for n in G.nodes(): + vote_rank[n][0] = 0 + # step 2 - vote + for n, nbr in G.edges(): + # In directed graphs nodes only vote for their in-neighbors + vote_rank[n][0] += vote_rank[nbr][1] + if not G.is_directed(): + vote_rank[nbr][0] += vote_rank[n][1] + for n in influential_nodes: + vote_rank[n][0] = 0 + # step 3 - select top node + n = max(G.nodes, key=lambda x: vote_rank[x][0]) + if vote_rank[n][0] == 0: + return influential_nodes + influential_nodes.append(n) + # weaken the selected node + vote_rank[n] = [0, 0] + # step 4 - update voterank properties + for _, nbr in G.edges(n): + vote_rank[nbr][1] -= 1 / avgDegree + vote_rank[nbr][1] = max(vote_rank[nbr][1], 0) + return influential_nodes diff --git a/lib/python3.10/site-packages/networkx/algorithms/coloring/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/coloring/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..39381d9f163a5400f362b91a89215bfc915a8022 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/coloring/__init__.py @@ -0,0 +1,4 @@ +from networkx.algorithms.coloring.greedy_coloring import * +from networkx.algorithms.coloring.equitable_coloring import equitable_color + +__all__ = ["greedy_color", "equitable_color"] diff --git a/lib/python3.10/site-packages/networkx/algorithms/coloring/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/coloring/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07876cc28975e276f083e578275916f4e0e7c2a3 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/coloring/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/coloring/__pycache__/equitable_coloring.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/coloring/__pycache__/equitable_coloring.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d96dbca90571639c4adcae02e740f4f8bd82d20b Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/coloring/__pycache__/equitable_coloring.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/coloring/__pycache__/greedy_coloring.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/coloring/__pycache__/greedy_coloring.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac6302c3663055715324c25425e33a7606539f8e Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/coloring/__pycache__/greedy_coloring.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/coloring/equitable_coloring.py b/lib/python3.10/site-packages/networkx/algorithms/coloring/equitable_coloring.py new file mode 100644 index 0000000000000000000000000000000000000000..e464a07447045fcdaa8e7ca4ea56552fb00e2826 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/coloring/equitable_coloring.py @@ -0,0 +1,505 @@ +""" +Equitable coloring of graphs with bounded degree. +""" + +from collections import defaultdict + +import networkx as nx + +__all__ = ["equitable_color"] + + +@nx._dispatchable +def is_coloring(G, coloring): + """Determine if the coloring is a valid coloring for the graph G.""" + # Verify that the coloring is valid. + return all(coloring[s] != coloring[d] for s, d in G.edges) + + +@nx._dispatchable +def is_equitable(G, coloring, num_colors=None): + """Determines if the coloring is valid and equitable for the graph G.""" + + if not is_coloring(G, coloring): + return False + + # Verify whether it is equitable. + color_set_size = defaultdict(int) + for color in coloring.values(): + color_set_size[color] += 1 + + if num_colors is not None: + for color in range(num_colors): + if color not in color_set_size: + # These colors do not have any vertices attached to them. + color_set_size[color] = 0 + + # If there are more than 2 distinct values, the coloring cannot be equitable + all_set_sizes = set(color_set_size.values()) + if len(all_set_sizes) == 0 and num_colors is None: # Was an empty graph + return True + elif len(all_set_sizes) == 1: + return True + elif len(all_set_sizes) == 2: + a, b = list(all_set_sizes) + return abs(a - b) <= 1 + else: # len(all_set_sizes) > 2: + return False + + +def make_C_from_F(F): + C = defaultdict(list) + for node, color in F.items(): + C[color].append(node) + + return C + + +def make_N_from_L_C(L, C): + nodes = L.keys() + colors = C.keys() + return { + (node, color): sum(1 for v in L[node] if v in C[color]) + for node in nodes + for color in colors + } + + +def make_H_from_C_N(C, N): + return { + (c1, c2): sum(1 for node in C[c1] if N[(node, c2)] == 0) for c1 in C for c2 in C + } + + +def change_color(u, X, Y, N, H, F, C, L): + """Change the color of 'u' from X to Y and update N, H, F, C.""" + assert F[u] == X and X != Y + + # Change the class of 'u' from X to Y + F[u] = Y + + for k in C: + # 'u' witnesses an edge from k -> Y instead of from k -> X now. + if N[u, k] == 0: + H[(X, k)] -= 1 + H[(Y, k)] += 1 + + for v in L[u]: + # 'v' has lost a neighbor in X and gained one in Y + N[(v, X)] -= 1 + N[(v, Y)] += 1 + + if N[(v, X)] == 0: + # 'v' witnesses F[v] -> X + H[(F[v], X)] += 1 + + if N[(v, Y)] == 1: + # 'v' no longer witnesses F[v] -> Y + H[(F[v], Y)] -= 1 + + C[X].remove(u) + C[Y].append(u) + + +def move_witnesses(src_color, dst_color, N, H, F, C, T_cal, L): + """Move witness along a path from src_color to dst_color.""" + X = src_color + while X != dst_color: + Y = T_cal[X] + # Move _any_ witness from X to Y = T_cal[X] + w = next(x for x in C[X] if N[(x, Y)] == 0) + change_color(w, X, Y, N=N, H=H, F=F, C=C, L=L) + X = Y + + +@nx._dispatchable(mutates_input=True) +def pad_graph(G, num_colors): + """Add a disconnected complete clique K_p such that the number of nodes in + the graph becomes a multiple of `num_colors`. + + Assumes that the graph's nodes are labelled using integers. + + Returns the number of nodes with each color. + """ + + n_ = len(G) + r = num_colors - 1 + + # Ensure that the number of nodes in G is a multiple of (r + 1) + s = n_ // (r + 1) + if n_ != s * (r + 1): + p = (r + 1) - n_ % (r + 1) + s += 1 + + # Complete graph K_p between (imaginary) nodes [n_, ... , n_ + p] + K = nx.relabel_nodes(nx.complete_graph(p), {idx: idx + n_ for idx in range(p)}) + G.add_edges_from(K.edges) + + return s + + +def procedure_P(V_minus, V_plus, N, H, F, C, L, excluded_colors=None): + """Procedure P as described in the paper.""" + + if excluded_colors is None: + excluded_colors = set() + + A_cal = set() + T_cal = {} + R_cal = [] + + # BFS to determine A_cal, i.e. colors reachable from V- + reachable = [V_minus] + marked = set(reachable) + idx = 0 + + while idx < len(reachable): + pop = reachable[idx] + idx += 1 + + A_cal.add(pop) + R_cal.append(pop) + + # TODO: Checking whether a color has been visited can be made faster by + # using a look-up table instead of testing for membership in a set by a + # logarithmic factor. + next_layer = [] + for k in C: + if ( + H[(k, pop)] > 0 + and k not in A_cal + and k not in excluded_colors + and k not in marked + ): + next_layer.append(k) + + for dst in next_layer: + # Record that `dst` can reach `pop` + T_cal[dst] = pop + + marked.update(next_layer) + reachable.extend(next_layer) + + # Variables for the algorithm + b = len(C) - len(A_cal) + + if V_plus in A_cal: + # Easy case: V+ is in A_cal + # Move one node from V+ to V- using T_cal to find the parents. + move_witnesses(V_plus, V_minus, N=N, H=H, F=F, C=C, T_cal=T_cal, L=L) + else: + # If there is a solo edge, we can resolve the situation by + # moving witnesses from B to A, making G[A] equitable and then + # recursively balancing G[B - w] with a different V_minus and + # but the same V_plus. + + A_0 = set() + A_cal_0 = set() + num_terminal_sets_found = 0 + made_equitable = False + + for W_1 in R_cal[::-1]: + for v in C[W_1]: + X = None + + for U in C: + if N[(v, U)] == 0 and U in A_cal and U != W_1: + X = U + + # v does not witness an edge in H[A_cal] + if X is None: + continue + + for U in C: + # Note: Departing from the paper here. + if N[(v, U)] >= 1 and U not in A_cal: + X_prime = U + w = v + + try: + # Finding the solo neighbor of w in X_prime + y = next( + node + for node in L[w] + if F[node] == X_prime and N[(node, W_1)] == 1 + ) + except StopIteration: + pass + else: + W = W_1 + + # Move w from W to X, now X has one extra node. + change_color(w, W, X, N=N, H=H, F=F, C=C, L=L) + + # Move witness from X to V_minus, making the coloring + # equitable. + move_witnesses( + src_color=X, + dst_color=V_minus, + N=N, + H=H, + F=F, + C=C, + T_cal=T_cal, + L=L, + ) + + # Move y from X_prime to W, making W the correct size. + change_color(y, X_prime, W, N=N, H=H, F=F, C=C, L=L) + + # Then call the procedure on G[B - y] + procedure_P( + V_minus=X_prime, + V_plus=V_plus, + N=N, + H=H, + C=C, + F=F, + L=L, + excluded_colors=excluded_colors.union(A_cal), + ) + made_equitable = True + break + + if made_equitable: + break + else: + # No node in W_1 was found such that + # it had a solo-neighbor. + A_cal_0.add(W_1) + A_0.update(C[W_1]) + num_terminal_sets_found += 1 + + if num_terminal_sets_found == b: + # Otherwise, construct the maximal independent set and find + # a pair of z_1, z_2 as in Case II. + + # BFS to determine B_cal': the set of colors reachable from V+ + B_cal_prime = set() + T_cal_prime = {} + + reachable = [V_plus] + marked = set(reachable) + idx = 0 + while idx < len(reachable): + pop = reachable[idx] + idx += 1 + + B_cal_prime.add(pop) + + # No need to check for excluded_colors here because + # they only exclude colors from A_cal + next_layer = [ + k + for k in C + if H[(pop, k)] > 0 and k not in B_cal_prime and k not in marked + ] + + for dst in next_layer: + T_cal_prime[pop] = dst + + marked.update(next_layer) + reachable.extend(next_layer) + + # Construct the independent set of G[B'] + I_set = set() + I_covered = set() + W_covering = {} + + B_prime = [node for k in B_cal_prime for node in C[k]] + + # Add the nodes in V_plus to I first. + for z in C[V_plus] + B_prime: + if z in I_covered or F[z] not in B_cal_prime: + continue + + I_set.add(z) + I_covered.add(z) + I_covered.update(list(L[z])) + + for w in L[z]: + if F[w] in A_cal_0 and N[(z, F[w])] == 1: + if w not in W_covering: + W_covering[w] = z + else: + # Found z1, z2 which have the same solo + # neighbor in some W + z_1 = W_covering[w] + # z_2 = z + + Z = F[z_1] + W = F[w] + + # shift nodes along W, V- + move_witnesses( + W, V_minus, N=N, H=H, F=F, C=C, T_cal=T_cal, L=L + ) + + # shift nodes along V+ to Z + move_witnesses( + V_plus, + Z, + N=N, + H=H, + F=F, + C=C, + T_cal=T_cal_prime, + L=L, + ) + + # change color of z_1 to W + change_color(z_1, Z, W, N=N, H=H, F=F, C=C, L=L) + + # change color of w to some color in B_cal + W_plus = next( + k for k in C if N[(w, k)] == 0 and k not in A_cal + ) + change_color(w, W, W_plus, N=N, H=H, F=F, C=C, L=L) + + # recurse with G[B \cup W*] + excluded_colors.update( + [k for k in C if k != W and k not in B_cal_prime] + ) + procedure_P( + V_minus=W, + V_plus=W_plus, + N=N, + H=H, + C=C, + F=F, + L=L, + excluded_colors=excluded_colors, + ) + + made_equitable = True + break + + if made_equitable: + break + else: + assert False, ( + "Must find a w which is the solo neighbor " + "of two vertices in B_cal_prime." + ) + + if made_equitable: + break + + +@nx._dispatchable +def equitable_color(G, num_colors): + """Provides an equitable coloring for nodes of `G`. + + Attempts to color a graph using `num_colors` colors, where no neighbors of + a node can have same color as the node itself and the number of nodes with + each color differ by at most 1. `num_colors` must be greater than the + maximum degree of `G`. The algorithm is described in [1]_ and has + complexity O(num_colors * n**2). + + Parameters + ---------- + G : networkX graph + The nodes of this graph will be colored. + + num_colors : number of colors to use + This number must be at least one more than the maximum degree of nodes + in the graph. + + Returns + ------- + A dictionary with keys representing nodes and values representing + corresponding coloring. + + Examples + -------- + >>> G = nx.cycle_graph(4) + >>> nx.coloring.equitable_color(G, num_colors=3) # doctest: +SKIP + {0: 2, 1: 1, 2: 2, 3: 0} + + Raises + ------ + NetworkXAlgorithmError + If `num_colors` is not at least the maximum degree of the graph `G` + + References + ---------- + .. [1] Kierstead, H. A., Kostochka, A. V., Mydlarz, M., & Szemerédi, E. + (2010). A fast algorithm for equitable coloring. Combinatorica, 30(2), + 217-224. + """ + + # Map nodes to integers for simplicity later. + nodes_to_int = {} + int_to_nodes = {} + + for idx, node in enumerate(G.nodes): + nodes_to_int[node] = idx + int_to_nodes[idx] = node + + G = nx.relabel_nodes(G, nodes_to_int, copy=True) + + # Basic graph statistics and sanity check. + if len(G.nodes) > 0: + r_ = max(G.degree(node) for node in G.nodes) + else: + r_ = 0 + + if r_ >= num_colors: + raise nx.NetworkXAlgorithmError( + f"Graph has maximum degree {r_}, needs " + f"{r_ + 1} (> {num_colors}) colors for guaranteed coloring." + ) + + # Ensure that the number of nodes in G is a multiple of (r + 1) + pad_graph(G, num_colors) + + # Starting the algorithm. + # L = {node: list(G.neighbors(node)) for node in G.nodes} + L_ = {node: [] for node in G.nodes} + + # Arbitrary equitable allocation of colors to nodes. + F = {node: idx % num_colors for idx, node in enumerate(G.nodes)} + + C = make_C_from_F(F) + + # The neighborhood is empty initially. + N = make_N_from_L_C(L_, C) + + # Currently all nodes witness all edges. + H = make_H_from_C_N(C, N) + + # Start of algorithm. + edges_seen = set() + + for u in sorted(G.nodes): + for v in sorted(G.neighbors(u)): + # Do not double count edges if (v, u) has already been seen. + if (v, u) in edges_seen: + continue + + edges_seen.add((u, v)) + + L_[u].append(v) + L_[v].append(u) + + N[(u, F[v])] += 1 + N[(v, F[u])] += 1 + + if F[u] != F[v]: + # Were 'u' and 'v' witnesses for F[u] -> F[v] or F[v] -> F[u]? + if N[(u, F[v])] == 1: + H[F[u], F[v]] -= 1 # u cannot witness an edge between F[u], F[v] + + if N[(v, F[u])] == 1: + H[F[v], F[u]] -= 1 # v cannot witness an edge between F[v], F[u] + + if N[(u, F[u])] != 0: + # Find the first color where 'u' does not have any neighbors. + Y = next(k for k in C if N[(u, k)] == 0) + X = F[u] + change_color(u, X, Y, N=N, H=H, F=F, C=C, L=L_) + + # Procedure P + procedure_P(V_minus=X, V_plus=Y, N=N, H=H, F=F, C=C, L=L_) + + return {int_to_nodes[x]: F[x] for x in int_to_nodes} diff --git a/lib/python3.10/site-packages/networkx/algorithms/coloring/greedy_coloring.py b/lib/python3.10/site-packages/networkx/algorithms/coloring/greedy_coloring.py new file mode 100644 index 0000000000000000000000000000000000000000..9be07803fa85823617bdf4ad6c30966f96b741e4 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/coloring/greedy_coloring.py @@ -0,0 +1,565 @@ +""" +Greedy graph coloring using various strategies. +""" + +import itertools +from collections import defaultdict, deque + +import networkx as nx +from networkx.utils import arbitrary_element, py_random_state + +__all__ = [ + "greedy_color", + "strategy_connected_sequential", + "strategy_connected_sequential_bfs", + "strategy_connected_sequential_dfs", + "strategy_independent_set", + "strategy_largest_first", + "strategy_random_sequential", + "strategy_saturation_largest_first", + "strategy_smallest_last", +] + + +def strategy_largest_first(G, colors): + """Returns a list of the nodes of ``G`` in decreasing order by + degree. + + ``G`` is a NetworkX graph. ``colors`` is ignored. + + """ + return sorted(G, key=G.degree, reverse=True) + + +@py_random_state(2) +def strategy_random_sequential(G, colors, seed=None): + """Returns a random permutation of the nodes of ``G`` as a list. + + ``G`` is a NetworkX graph. ``colors`` is ignored. + + seed : integer, random_state, or None (default) + Indicator of random number generation state. + See :ref:`Randomness`. + """ + nodes = list(G) + seed.shuffle(nodes) + return nodes + + +def strategy_smallest_last(G, colors): + """Returns a deque of the nodes of ``G``, "smallest" last. + + Specifically, the degrees of each node are tracked in a bucket queue. + From this, the node of minimum degree is repeatedly popped from the + graph, updating its neighbors' degrees. + + ``G`` is a NetworkX graph. ``colors`` is ignored. + + This implementation of the strategy runs in $O(n + m)$ time + (ignoring polylogarithmic factors), where $n$ is the number of nodes + and $m$ is the number of edges. + + This strategy is related to :func:`strategy_independent_set`: if we + interpret each node removed as an independent set of size one, then + this strategy chooses an independent set of size one instead of a + maximal independent set. + + """ + H = G.copy() + result = deque() + + # Build initial degree list (i.e. the bucket queue data structure) + degrees = defaultdict(set) # set(), for fast random-access removals + lbound = float("inf") + for node, d in H.degree(): + degrees[d].add(node) + lbound = min(lbound, d) # Lower bound on min-degree. + + def find_min_degree(): + # Save time by starting the iterator at `lbound`, not 0. + # The value that we find will be our new `lbound`, which we set later. + return next(d for d in itertools.count(lbound) if d in degrees) + + for _ in G: + # Pop a min-degree node and add it to the list. + min_degree = find_min_degree() + u = degrees[min_degree].pop() + if not degrees[min_degree]: # Clean up the degree list. + del degrees[min_degree] + result.appendleft(u) + + # Update degrees of removed node's neighbors. + for v in H[u]: + degree = H.degree(v) + degrees[degree].remove(v) + if not degrees[degree]: # Clean up the degree list. + del degrees[degree] + degrees[degree - 1].add(v) + + # Finally, remove the node. + H.remove_node(u) + lbound = min_degree - 1 # Subtract 1 in case of tied neighbors. + + return result + + +def _maximal_independent_set(G): + """Returns a maximal independent set of nodes in ``G`` by repeatedly + choosing an independent node of minimum degree (with respect to the + subgraph of unchosen nodes). + + """ + result = set() + remaining = set(G) + while remaining: + G = G.subgraph(remaining) + v = min(remaining, key=G.degree) + result.add(v) + remaining -= set(G[v]) | {v} + return result + + +def strategy_independent_set(G, colors): + """Uses a greedy independent set removal strategy to determine the + colors. + + This function updates ``colors`` **in-place** and return ``None``, + unlike the other strategy functions in this module. + + This algorithm repeatedly finds and removes a maximal independent + set, assigning each node in the set an unused color. + + ``G`` is a NetworkX graph. + + This strategy is related to :func:`strategy_smallest_last`: in that + strategy, an independent set of size one is chosen at each step + instead of a maximal independent set. + + """ + remaining_nodes = set(G) + while len(remaining_nodes) > 0: + nodes = _maximal_independent_set(G.subgraph(remaining_nodes)) + remaining_nodes -= nodes + yield from nodes + + +def strategy_connected_sequential_bfs(G, colors): + """Returns an iterable over nodes in ``G`` in the order given by a + breadth-first traversal. + + The generated sequence has the property that for each node except + the first, at least one neighbor appeared earlier in the sequence. + + ``G`` is a NetworkX graph. ``colors`` is ignored. + + """ + return strategy_connected_sequential(G, colors, "bfs") + + +def strategy_connected_sequential_dfs(G, colors): + """Returns an iterable over nodes in ``G`` in the order given by a + depth-first traversal. + + The generated sequence has the property that for each node except + the first, at least one neighbor appeared earlier in the sequence. + + ``G`` is a NetworkX graph. ``colors`` is ignored. + + """ + return strategy_connected_sequential(G, colors, "dfs") + + +def strategy_connected_sequential(G, colors, traversal="bfs"): + """Returns an iterable over nodes in ``G`` in the order given by a + breadth-first or depth-first traversal. + + ``traversal`` must be one of the strings ``'dfs'`` or ``'bfs'``, + representing depth-first traversal or breadth-first traversal, + respectively. + + The generated sequence has the property that for each node except + the first, at least one neighbor appeared earlier in the sequence. + + ``G`` is a NetworkX graph. ``colors`` is ignored. + + """ + if traversal == "bfs": + traverse = nx.bfs_edges + elif traversal == "dfs": + traverse = nx.dfs_edges + else: + raise nx.NetworkXError( + "Please specify one of the strings 'bfs' or" + " 'dfs' for connected sequential ordering" + ) + for component in nx.connected_components(G): + source = arbitrary_element(component) + # Yield the source node, then all the nodes in the specified + # traversal order. + yield source + for _, end in traverse(G.subgraph(component), source): + yield end + + +def strategy_saturation_largest_first(G, colors): + """Iterates over all the nodes of ``G`` in "saturation order" (also + known as "DSATUR"). + + ``G`` is a NetworkX graph. ``colors`` is a dictionary mapping nodes of + ``G`` to colors, for those nodes that have already been colored. + + """ + distinct_colors = {v: set() for v in G} + + # Add the node color assignments given in colors to the + # distinct colors set for each neighbor of that node + for node, color in colors.items(): + for neighbor in G[node]: + distinct_colors[neighbor].add(color) + + # Check that the color assignments in colors are valid + # i.e. no neighboring nodes have the same color + if len(colors) >= 2: + for node, color in colors.items(): + if color in distinct_colors[node]: + raise nx.NetworkXError("Neighboring nodes must have different colors") + + # If 0 nodes have been colored, simply choose the node of highest degree. + if not colors: + node = max(G, key=G.degree) + yield node + # Add the color 0 to the distinct colors set for each + # neighbor of that node. + for v in G[node]: + distinct_colors[v].add(0) + + while len(G) != len(colors): + # Update the distinct color sets for the neighbors. + for node, color in colors.items(): + for neighbor in G[node]: + distinct_colors[neighbor].add(color) + + # Compute the maximum saturation and the set of nodes that + # achieve that saturation. + saturation = {v: len(c) for v, c in distinct_colors.items() if v not in colors} + # Yield the node with the highest saturation, and break ties by + # degree. + node = max(saturation, key=lambda v: (saturation[v], G.degree(v))) + yield node + + +#: Dictionary mapping name of a strategy as a string to the strategy function. +STRATEGIES = { + "largest_first": strategy_largest_first, + "random_sequential": strategy_random_sequential, + "smallest_last": strategy_smallest_last, + "independent_set": strategy_independent_set, + "connected_sequential_bfs": strategy_connected_sequential_bfs, + "connected_sequential_dfs": strategy_connected_sequential_dfs, + "connected_sequential": strategy_connected_sequential, + "saturation_largest_first": strategy_saturation_largest_first, + "DSATUR": strategy_saturation_largest_first, +} + + +@nx._dispatchable +def greedy_color(G, strategy="largest_first", interchange=False): + """Color a graph using various strategies of greedy graph coloring. + + Attempts to color a graph using as few colors as possible, where no + neighbors of a node can have same color as the node itself. The + given strategy determines the order in which nodes are colored. + + The strategies are described in [1]_, and smallest-last is based on + [2]_. + + Parameters + ---------- + G : NetworkX graph + + strategy : string or function(G, colors) + A function (or a string representing a function) that provides + the coloring strategy, by returning nodes in the ordering they + should be colored. ``G`` is the graph, and ``colors`` is a + dictionary of the currently assigned colors, keyed by nodes. The + function must return an iterable over all the nodes in ``G``. + + If the strategy function is an iterator generator (that is, a + function with ``yield`` statements), keep in mind that the + ``colors`` dictionary will be updated after each ``yield``, since + this function chooses colors greedily. + + If ``strategy`` is a string, it must be one of the following, + each of which represents one of the built-in strategy functions. + + * ``'largest_first'`` + * ``'random_sequential'`` + * ``'smallest_last'`` + * ``'independent_set'`` + * ``'connected_sequential_bfs'`` + * ``'connected_sequential_dfs'`` + * ``'connected_sequential'`` (alias for the previous strategy) + * ``'saturation_largest_first'`` + * ``'DSATUR'`` (alias for the previous strategy) + + interchange: bool + Will use the color interchange algorithm described by [3]_ if set + to ``True``. + + Note that ``saturation_largest_first`` and ``independent_set`` + do not work with interchange. Furthermore, if you use + interchange with your own strategy function, you cannot rely + on the values in the ``colors`` argument. + + Returns + ------- + A dictionary with keys representing nodes and values representing + corresponding coloring. + + Examples + -------- + >>> G = nx.cycle_graph(4) + >>> d = nx.coloring.greedy_color(G, strategy="largest_first") + >>> d in [{0: 0, 1: 1, 2: 0, 3: 1}, {0: 1, 1: 0, 2: 1, 3: 0}] + True + + Raises + ------ + NetworkXPointlessConcept + If ``strategy`` is ``saturation_largest_first`` or + ``independent_set`` and ``interchange`` is ``True``. + + References + ---------- + .. [1] Adrian Kosowski, and Krzysztof Manuszewski, + Classical Coloring of Graphs, Graph Colorings, 2-19, 2004. + ISBN 0-8218-3458-4. + .. [2] David W. Matula, and Leland L. Beck, "Smallest-last + ordering and clustering and graph coloring algorithms." *J. ACM* 30, + 3 (July 1983), 417–427. + .. [3] Maciej M. Sysło, Narsingh Deo, Janusz S. Kowalik, + Discrete Optimization Algorithms with Pascal Programs, 415-424, 1983. + ISBN 0-486-45353-7. + + """ + if len(G) == 0: + return {} + # Determine the strategy provided by the caller. + strategy = STRATEGIES.get(strategy, strategy) + if not callable(strategy): + raise nx.NetworkXError( + f"strategy must be callable or a valid string. {strategy} not valid." + ) + # Perform some validation on the arguments before executing any + # strategy functions. + if interchange: + if strategy is strategy_independent_set: + msg = "interchange cannot be used with independent_set" + raise nx.NetworkXPointlessConcept(msg) + if strategy is strategy_saturation_largest_first: + msg = "interchange cannot be used with" " saturation_largest_first" + raise nx.NetworkXPointlessConcept(msg) + colors = {} + nodes = strategy(G, colors) + if interchange: + return _greedy_coloring_with_interchange(G, nodes) + for u in nodes: + # Set to keep track of colors of neighbors + nbr_colors = {colors[v] for v in G[u] if v in colors} + # Find the first unused color. + for color in itertools.count(): + if color not in nbr_colors: + break + # Assign the new color to the current node. + colors[u] = color + return colors + + +# Tools for coloring with interchanges +class _Node: + __slots__ = ["node_id", "color", "adj_list", "adj_color"] + + def __init__(self, node_id, n): + self.node_id = node_id + self.color = -1 + self.adj_list = None + self.adj_color = [None for _ in range(n)] + + def __repr__(self): + return ( + f"Node_id: {self.node_id}, Color: {self.color}, " + f"Adj_list: ({self.adj_list}), adj_color: ({self.adj_color})" + ) + + def assign_color(self, adj_entry, color): + adj_entry.col_prev = None + adj_entry.col_next = self.adj_color[color] + self.adj_color[color] = adj_entry + if adj_entry.col_next is not None: + adj_entry.col_next.col_prev = adj_entry + + def clear_color(self, adj_entry, color): + if adj_entry.col_prev is None: + self.adj_color[color] = adj_entry.col_next + else: + adj_entry.col_prev.col_next = adj_entry.col_next + if adj_entry.col_next is not None: + adj_entry.col_next.col_prev = adj_entry.col_prev + + def iter_neighbors(self): + adj_node = self.adj_list + while adj_node is not None: + yield adj_node + adj_node = adj_node.next + + def iter_neighbors_color(self, color): + adj_color_node = self.adj_color[color] + while adj_color_node is not None: + yield adj_color_node.node_id + adj_color_node = adj_color_node.col_next + + +class _AdjEntry: + __slots__ = ["node_id", "next", "mate", "col_next", "col_prev"] + + def __init__(self, node_id): + self.node_id = node_id + self.next = None + self.mate = None + self.col_next = None + self.col_prev = None + + def __repr__(self): + col_next = None if self.col_next is None else self.col_next.node_id + col_prev = None if self.col_prev is None else self.col_prev.node_id + return ( + f"Node_id: {self.node_id}, Next: ({self.next}), " + f"Mate: ({self.mate.node_id}), " + f"col_next: ({col_next}), col_prev: ({col_prev})" + ) + + +def _greedy_coloring_with_interchange(G, nodes): + """Return a coloring for `original_graph` using interchange approach + + This procedure is an adaption of the algorithm described by [1]_, + and is an implementation of coloring with interchange. Please be + advised, that the datastructures used are rather complex because + they are optimized to minimize the time spent identifying + subcomponents of the graph, which are possible candidates for color + interchange. + + Parameters + ---------- + G : NetworkX graph + The graph to be colored + + nodes : list + nodes ordered using the strategy of choice + + Returns + ------- + dict : + A dictionary keyed by node to a color value + + References + ---------- + .. [1] Maciej M. Syslo, Narsingh Deo, Janusz S. Kowalik, + Discrete Optimization Algorithms with Pascal Programs, 415-424, 1983. + ISBN 0-486-45353-7. + """ + n = len(G) + + graph = {node: _Node(node, n) for node in G} + + for node1, node2 in G.edges(): + adj_entry1 = _AdjEntry(node2) + adj_entry2 = _AdjEntry(node1) + adj_entry1.mate = adj_entry2 + adj_entry2.mate = adj_entry1 + node1_head = graph[node1].adj_list + adj_entry1.next = node1_head + graph[node1].adj_list = adj_entry1 + node2_head = graph[node2].adj_list + adj_entry2.next = node2_head + graph[node2].adj_list = adj_entry2 + + k = 0 + for node in nodes: + # Find the smallest possible, unused color + neighbors = graph[node].iter_neighbors() + col_used = {graph[adj_node.node_id].color for adj_node in neighbors} + col_used.discard(-1) + k1 = next(itertools.dropwhile(lambda x: x in col_used, itertools.count())) + + # k1 is now the lowest available color + if k1 > k: + connected = True + visited = set() + col1 = -1 + col2 = -1 + while connected and col1 < k: + col1 += 1 + neighbor_cols = graph[node].iter_neighbors_color(col1) + col1_adj = list(neighbor_cols) + + col2 = col1 + while connected and col2 < k: + col2 += 1 + visited = set(col1_adj) + frontier = list(col1_adj) + i = 0 + while i < len(frontier): + search_node = frontier[i] + i += 1 + col_opp = col2 if graph[search_node].color == col1 else col1 + neighbor_cols = graph[search_node].iter_neighbors_color(col_opp) + + for neighbor in neighbor_cols: + if neighbor not in visited: + visited.add(neighbor) + frontier.append(neighbor) + + # Search if node is not adj to any col2 vertex + connected = ( + len( + visited.intersection(graph[node].iter_neighbors_color(col2)) + ) + > 0 + ) + + # If connected is false then we can swap !!! + if not connected: + # Update all the nodes in the component + for search_node in visited: + graph[search_node].color = ( + col2 if graph[search_node].color == col1 else col1 + ) + col2_adj = graph[search_node].adj_color[col2] + graph[search_node].adj_color[col2] = graph[search_node].adj_color[ + col1 + ] + graph[search_node].adj_color[col1] = col2_adj + + # Update all the neighboring nodes + for search_node in visited: + col = graph[search_node].color + col_opp = col1 if col == col2 else col2 + for adj_node in graph[search_node].iter_neighbors(): + if graph[adj_node.node_id].color != col_opp: + # Direct reference to entry + adj_mate = adj_node.mate + graph[adj_node.node_id].clear_color(adj_mate, col_opp) + graph[adj_node.node_id].assign_color(adj_mate, col) + k1 = col1 + + # We can color this node color k1 + graph[node].color = k1 + k = max(k1, k) + + # Update the neighbors of this node + for adj_node in graph[node].iter_neighbors(): + adj_mate = adj_node.mate + graph[adj_node.node_id].assign_color(adj_mate, k1) + + return {node.node_id: node.color for node in graph.values()} diff --git a/lib/python3.10/site-packages/networkx/algorithms/coloring/tests/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/coloring/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/networkx/algorithms/coloring/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/coloring/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f08713ee55d41df84c169e32f018d578f09ef85e Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/coloring/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/coloring/tests/__pycache__/test_coloring.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/coloring/tests/__pycache__/test_coloring.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2146beaae63002d2580525327481f2a716719bbc Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/coloring/tests/__pycache__/test_coloring.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/coloring/tests/test_coloring.py b/lib/python3.10/site-packages/networkx/algorithms/coloring/tests/test_coloring.py new file mode 100644 index 0000000000000000000000000000000000000000..1e5a913c7c07bc0060274e611cef18b054d71238 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/coloring/tests/test_coloring.py @@ -0,0 +1,863 @@ +"""Greedy coloring test suite.""" + +import itertools + +import pytest + +import networkx as nx + +is_coloring = nx.algorithms.coloring.equitable_coloring.is_coloring +is_equitable = nx.algorithms.coloring.equitable_coloring.is_equitable + + +ALL_STRATEGIES = [ + "largest_first", + "random_sequential", + "smallest_last", + "independent_set", + "connected_sequential_bfs", + "connected_sequential_dfs", + "connected_sequential", + "saturation_largest_first", + "DSATUR", +] + +# List of strategies where interchange=True results in an error +INTERCHANGE_INVALID = ["independent_set", "saturation_largest_first", "DSATUR"] + + +class TestColoring: + def test_basic_cases(self): + def check_basic_case(graph_func, n_nodes, strategy, interchange): + graph = graph_func() + coloring = nx.coloring.greedy_color( + graph, strategy=strategy, interchange=interchange + ) + assert verify_length(coloring, n_nodes) + assert verify_coloring(graph, coloring) + + for graph_func, n_nodes in BASIC_TEST_CASES.items(): + for interchange in [True, False]: + for strategy in ALL_STRATEGIES: + check_basic_case(graph_func, n_nodes, strategy, False) + if strategy not in INTERCHANGE_INVALID: + check_basic_case(graph_func, n_nodes, strategy, True) + + def test_special_cases(self): + def check_special_case(strategy, graph_func, interchange, colors): + graph = graph_func() + coloring = nx.coloring.greedy_color( + graph, strategy=strategy, interchange=interchange + ) + if not hasattr(colors, "__len__"): + colors = [colors] + assert any(verify_length(coloring, n_colors) for n_colors in colors) + assert verify_coloring(graph, coloring) + + for strategy, arglist in SPECIAL_TEST_CASES.items(): + for args in arglist: + check_special_case(strategy, args[0], args[1], args[2]) + + def test_interchange_invalid(self): + graph = one_node_graph() + for strategy in INTERCHANGE_INVALID: + pytest.raises( + nx.NetworkXPointlessConcept, + nx.coloring.greedy_color, + graph, + strategy=strategy, + interchange=True, + ) + + def test_bad_inputs(self): + graph = one_node_graph() + pytest.raises( + nx.NetworkXError, + nx.coloring.greedy_color, + graph, + strategy="invalid strategy", + ) + + def test_strategy_as_function(self): + graph = lf_shc() + colors_1 = nx.coloring.greedy_color(graph, "largest_first") + colors_2 = nx.coloring.greedy_color(graph, nx.coloring.strategy_largest_first) + assert colors_1 == colors_2 + + def test_seed_argument(self): + graph = lf_shc() + rs = nx.coloring.strategy_random_sequential + c1 = nx.coloring.greedy_color(graph, lambda g, c: rs(g, c, seed=1)) + for u, v in graph.edges: + assert c1[u] != c1[v] + + def test_is_coloring(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (1, 2)]) + coloring = {0: 0, 1: 1, 2: 0} + assert is_coloring(G, coloring) + + coloring[0] = 1 + assert not is_coloring(G, coloring) + assert not is_equitable(G, coloring) + + def test_is_equitable(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (1, 2)]) + coloring = {0: 0, 1: 1, 2: 0} + assert is_equitable(G, coloring) + + G.add_edges_from([(2, 3), (2, 4), (2, 5)]) + coloring[3] = 1 + coloring[4] = 1 + coloring[5] = 1 + assert is_coloring(G, coloring) + assert not is_equitable(G, coloring) + + def test_num_colors(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (0, 2), (0, 3)]) + pytest.raises(nx.NetworkXAlgorithmError, nx.coloring.equitable_color, G, 2) + + def test_equitable_color(self): + G = nx.fast_gnp_random_graph(n=10, p=0.2, seed=42) + coloring = nx.coloring.equitable_color(G, max_degree(G) + 1) + assert is_equitable(G, coloring) + + def test_equitable_color_empty(self): + G = nx.empty_graph() + coloring = nx.coloring.equitable_color(G, max_degree(G) + 1) + assert is_equitable(G, coloring) + + def test_equitable_color_large(self): + G = nx.fast_gnp_random_graph(100, 0.1, seed=42) + coloring = nx.coloring.equitable_color(G, max_degree(G) + 1) + assert is_equitable(G, coloring, num_colors=max_degree(G) + 1) + + def test_case_V_plus_not_in_A_cal(self): + # Hand crafted case to avoid the easy case. + L = { + 0: [2, 5], + 1: [3, 4], + 2: [0, 8], + 3: [1, 7], + 4: [1, 6], + 5: [0, 6], + 6: [4, 5], + 7: [3], + 8: [2], + } + + F = { + # Color 0 + 0: 0, + 1: 0, + # Color 1 + 2: 1, + 3: 1, + 4: 1, + 5: 1, + # Color 2 + 6: 2, + 7: 2, + 8: 2, + } + + C = nx.algorithms.coloring.equitable_coloring.make_C_from_F(F) + N = nx.algorithms.coloring.equitable_coloring.make_N_from_L_C(L, C) + H = nx.algorithms.coloring.equitable_coloring.make_H_from_C_N(C, N) + + nx.algorithms.coloring.equitable_coloring.procedure_P( + V_minus=0, V_plus=1, N=N, H=H, F=F, C=C, L=L + ) + check_state(L=L, N=N, H=H, F=F, C=C) + + def test_cast_no_solo(self): + L = { + 0: [8, 9], + 1: [10, 11], + 2: [8], + 3: [9], + 4: [10, 11], + 5: [8], + 6: [9], + 7: [10, 11], + 8: [0, 2, 5], + 9: [0, 3, 6], + 10: [1, 4, 7], + 11: [1, 4, 7], + } + + F = {0: 0, 1: 0, 2: 2, 3: 2, 4: 2, 5: 3, 6: 3, 7: 3, 8: 1, 9: 1, 10: 1, 11: 1} + + C = nx.algorithms.coloring.equitable_coloring.make_C_from_F(F) + N = nx.algorithms.coloring.equitable_coloring.make_N_from_L_C(L, C) + H = nx.algorithms.coloring.equitable_coloring.make_H_from_C_N(C, N) + + nx.algorithms.coloring.equitable_coloring.procedure_P( + V_minus=0, V_plus=1, N=N, H=H, F=F, C=C, L=L + ) + check_state(L=L, N=N, H=H, F=F, C=C) + + def test_hard_prob(self): + # Tests for two levels of recursion. + num_colors, s = 5, 5 + + G = nx.Graph() + G.add_edges_from( + [ + (0, 10), + (0, 11), + (0, 12), + (0, 23), + (10, 4), + (10, 9), + (10, 20), + (11, 4), + (11, 8), + (11, 16), + (12, 9), + (12, 22), + (12, 23), + (23, 7), + (1, 17), + (1, 18), + (1, 19), + (1, 24), + (17, 5), + (17, 13), + (17, 22), + (18, 5), + (19, 5), + (19, 6), + (19, 8), + (24, 7), + (24, 16), + (2, 4), + (2, 13), + (2, 14), + (2, 15), + (4, 6), + (13, 5), + (13, 21), + (14, 6), + (14, 15), + (15, 6), + (15, 21), + (3, 16), + (3, 20), + (3, 21), + (3, 22), + (16, 8), + (20, 8), + (21, 9), + (22, 7), + ] + ) + F = {node: node // s for node in range(num_colors * s)} + F[s - 1] = num_colors - 1 + + params = make_params_from_graph(G=G, F=F) + + nx.algorithms.coloring.equitable_coloring.procedure_P( + V_minus=0, V_plus=num_colors - 1, **params + ) + check_state(**params) + + def test_hardest_prob(self): + # Tests for two levels of recursion. + num_colors, s = 10, 4 + + G = nx.Graph() + G.add_edges_from( + [ + (0, 19), + (0, 24), + (0, 29), + (0, 30), + (0, 35), + (19, 3), + (19, 7), + (19, 9), + (19, 15), + (19, 21), + (19, 24), + (19, 30), + (19, 38), + (24, 5), + (24, 11), + (24, 13), + (24, 20), + (24, 30), + (24, 37), + (24, 38), + (29, 6), + (29, 10), + (29, 13), + (29, 15), + (29, 16), + (29, 17), + (29, 20), + (29, 26), + (30, 6), + (30, 10), + (30, 15), + (30, 22), + (30, 23), + (30, 39), + (35, 6), + (35, 9), + (35, 14), + (35, 18), + (35, 22), + (35, 23), + (35, 25), + (35, 27), + (1, 20), + (1, 26), + (1, 31), + (1, 34), + (1, 38), + (20, 4), + (20, 8), + (20, 14), + (20, 18), + (20, 28), + (20, 33), + (26, 7), + (26, 10), + (26, 14), + (26, 18), + (26, 21), + (26, 32), + (26, 39), + (31, 5), + (31, 8), + (31, 13), + (31, 16), + (31, 17), + (31, 21), + (31, 25), + (31, 27), + (34, 7), + (34, 8), + (34, 13), + (34, 18), + (34, 22), + (34, 23), + (34, 25), + (34, 27), + (38, 4), + (38, 9), + (38, 12), + (38, 14), + (38, 21), + (38, 27), + (2, 3), + (2, 18), + (2, 21), + (2, 28), + (2, 32), + (2, 33), + (2, 36), + (2, 37), + (2, 39), + (3, 5), + (3, 9), + (3, 13), + (3, 22), + (3, 23), + (3, 25), + (3, 27), + (18, 6), + (18, 11), + (18, 15), + (18, 39), + (21, 4), + (21, 10), + (21, 14), + (21, 36), + (28, 6), + (28, 10), + (28, 14), + (28, 16), + (28, 17), + (28, 25), + (28, 27), + (32, 5), + (32, 10), + (32, 12), + (32, 16), + (32, 17), + (32, 22), + (32, 23), + (33, 7), + (33, 10), + (33, 12), + (33, 16), + (33, 17), + (33, 25), + (33, 27), + (36, 5), + (36, 8), + (36, 15), + (36, 16), + (36, 17), + (36, 25), + (36, 27), + (37, 5), + (37, 11), + (37, 15), + (37, 16), + (37, 17), + (37, 22), + (37, 23), + (39, 7), + (39, 8), + (39, 15), + (39, 22), + (39, 23), + ] + ) + F = {node: node // s for node in range(num_colors * s)} + F[s - 1] = num_colors - 1 # V- = 0, V+ = num_colors - 1 + + params = make_params_from_graph(G=G, F=F) + + nx.algorithms.coloring.equitable_coloring.procedure_P( + V_minus=0, V_plus=num_colors - 1, **params + ) + check_state(**params) + + def test_strategy_saturation_largest_first(self): + def color_remaining_nodes( + G, + colored_nodes, + full_color_assignment=None, + nodes_to_add_between_calls=1, + ): + color_assignments = [] + aux_colored_nodes = colored_nodes.copy() + + node_iterator = nx.algorithms.coloring.greedy_coloring.strategy_saturation_largest_first( + G, aux_colored_nodes + ) + + for u in node_iterator: + # Set to keep track of colors of neighbors + nbr_colors = { + aux_colored_nodes[v] for v in G[u] if v in aux_colored_nodes + } + # Find the first unused color. + for color in itertools.count(): + if color not in nbr_colors: + break + aux_colored_nodes[u] = color + color_assignments.append((u, color)) + + # Color nodes between iterations + for i in range(nodes_to_add_between_calls - 1): + if not len(color_assignments) + len(colored_nodes) >= len( + full_color_assignment + ): + full_color_assignment_node, color = full_color_assignment[ + len(color_assignments) + len(colored_nodes) + ] + + # Assign the new color to the current node. + aux_colored_nodes[full_color_assignment_node] = color + color_assignments.append((full_color_assignment_node, color)) + + return color_assignments, aux_colored_nodes + + for G, _, _ in SPECIAL_TEST_CASES["saturation_largest_first"]: + G = G() + + # Check that function still works when nodes are colored between iterations + for nodes_to_add_between_calls in range(1, 5): + # Get a full color assignment, (including the order in which nodes were colored) + colored_nodes = {} + full_color_assignment, full_colored_nodes = color_remaining_nodes( + G, colored_nodes + ) + + # For each node in the color assignment, add it to colored_nodes and re-run the function + for ind, (node, color) in enumerate(full_color_assignment): + colored_nodes[node] = color + + ( + partial_color_assignment, + partial_colored_nodes, + ) = color_remaining_nodes( + G, + colored_nodes, + full_color_assignment=full_color_assignment, + nodes_to_add_between_calls=nodes_to_add_between_calls, + ) + + # Check that the color assignment and order of remaining nodes are the same + assert full_color_assignment[ind + 1 :] == partial_color_assignment + assert full_colored_nodes == partial_colored_nodes + + +# ############################ Utility functions ############################ +def verify_coloring(graph, coloring): + for node in graph.nodes(): + if node not in coloring: + return False + + color = coloring[node] + for neighbor in graph.neighbors(node): + if coloring[neighbor] == color: + return False + + return True + + +def verify_length(coloring, expected): + coloring = dict_to_sets(coloring) + return len(coloring) == expected + + +def dict_to_sets(colors): + if len(colors) == 0: + return [] + + k = max(colors.values()) + 1 + sets = [set() for _ in range(k)] + + for node, color in colors.items(): + sets[color].add(node) + + return sets + + +# ############################ Graph Generation ############################ + + +def empty_graph(): + return nx.Graph() + + +def one_node_graph(): + graph = nx.Graph() + graph.add_nodes_from([1]) + return graph + + +def two_node_graph(): + graph = nx.Graph() + graph.add_nodes_from([1, 2]) + graph.add_edges_from([(1, 2)]) + return graph + + +def three_node_clique(): + graph = nx.Graph() + graph.add_nodes_from([1, 2, 3]) + graph.add_edges_from([(1, 2), (1, 3), (2, 3)]) + return graph + + +def disconnected(): + graph = nx.Graph() + graph.add_edges_from([(1, 2), (2, 3), (4, 5), (5, 6)]) + return graph + + +def rs_shc(): + graph = nx.Graph() + graph.add_nodes_from([1, 2, 3, 4]) + graph.add_edges_from([(1, 2), (2, 3), (3, 4)]) + return graph + + +def slf_shc(): + graph = nx.Graph() + graph.add_nodes_from([1, 2, 3, 4, 5, 6, 7]) + graph.add_edges_from( + [(1, 2), (1, 5), (1, 6), (2, 3), (2, 7), (3, 4), (3, 7), (4, 5), (4, 6), (5, 6)] + ) + return graph + + +def slf_hc(): + graph = nx.Graph() + graph.add_nodes_from([1, 2, 3, 4, 5, 6, 7, 8]) + graph.add_edges_from( + [ + (1, 2), + (1, 3), + (1, 4), + (1, 5), + (2, 3), + (2, 4), + (2, 6), + (5, 7), + (5, 8), + (6, 7), + (6, 8), + (7, 8), + ] + ) + return graph + + +def lf_shc(): + graph = nx.Graph() + graph.add_nodes_from([1, 2, 3, 4, 5, 6]) + graph.add_edges_from([(6, 1), (1, 4), (4, 3), (3, 2), (2, 5)]) + return graph + + +def lf_hc(): + graph = nx.Graph() + graph.add_nodes_from([1, 2, 3, 4, 5, 6, 7]) + graph.add_edges_from( + [ + (1, 7), + (1, 6), + (1, 3), + (1, 4), + (7, 2), + (2, 6), + (2, 3), + (2, 5), + (5, 3), + (5, 4), + (4, 3), + ] + ) + return graph + + +def sl_shc(): + graph = nx.Graph() + graph.add_nodes_from([1, 2, 3, 4, 5, 6]) + graph.add_edges_from( + [(1, 2), (1, 3), (2, 3), (1, 4), (2, 5), (3, 6), (4, 5), (4, 6), (5, 6)] + ) + return graph + + +def sl_hc(): + graph = nx.Graph() + graph.add_nodes_from([1, 2, 3, 4, 5, 6, 7, 8]) + graph.add_edges_from( + [ + (1, 2), + (1, 3), + (1, 5), + (1, 7), + (2, 3), + (2, 4), + (2, 8), + (8, 4), + (8, 6), + (8, 7), + (7, 5), + (7, 6), + (3, 4), + (4, 6), + (6, 5), + (5, 3), + ] + ) + return graph + + +def gis_shc(): + graph = nx.Graph() + graph.add_nodes_from([1, 2, 3, 4]) + graph.add_edges_from([(1, 2), (2, 3), (3, 4)]) + return graph + + +def gis_hc(): + graph = nx.Graph() + graph.add_nodes_from([1, 2, 3, 4, 5, 6]) + graph.add_edges_from([(1, 5), (2, 5), (3, 6), (4, 6), (5, 6)]) + return graph + + +def cs_shc(): + graph = nx.Graph() + graph.add_nodes_from([1, 2, 3, 4, 5]) + graph.add_edges_from([(1, 2), (1, 5), (2, 3), (2, 4), (2, 5), (3, 4), (4, 5)]) + return graph + + +def rsi_shc(): + graph = nx.Graph() + graph.add_nodes_from([1, 2, 3, 4, 5, 6]) + graph.add_edges_from( + [(1, 2), (1, 5), (1, 6), (2, 3), (3, 4), (4, 5), (4, 6), (5, 6)] + ) + return graph + + +def lfi_shc(): + graph = nx.Graph() + graph.add_nodes_from([1, 2, 3, 4, 5, 6, 7]) + graph.add_edges_from( + [(1, 2), (1, 5), (1, 6), (2, 3), (2, 7), (3, 4), (3, 7), (4, 5), (4, 6), (5, 6)] + ) + return graph + + +def lfi_hc(): + graph = nx.Graph() + graph.add_nodes_from([1, 2, 3, 4, 5, 6, 7, 8, 9]) + graph.add_edges_from( + [ + (1, 2), + (1, 5), + (1, 6), + (1, 7), + (2, 3), + (2, 8), + (2, 9), + (3, 4), + (3, 8), + (3, 9), + (4, 5), + (4, 6), + (4, 7), + (5, 6), + ] + ) + return graph + + +def sli_shc(): + graph = nx.Graph() + graph.add_nodes_from([1, 2, 3, 4, 5, 6, 7]) + graph.add_edges_from( + [ + (1, 2), + (1, 3), + (1, 5), + (1, 7), + (2, 3), + (2, 6), + (3, 4), + (4, 5), + (4, 6), + (5, 7), + (6, 7), + ] + ) + return graph + + +def sli_hc(): + graph = nx.Graph() + graph.add_nodes_from([1, 2, 3, 4, 5, 6, 7, 8, 9]) + graph.add_edges_from( + [ + (1, 2), + (1, 3), + (1, 4), + (1, 5), + (2, 3), + (2, 7), + (2, 8), + (2, 9), + (3, 6), + (3, 7), + (3, 9), + (4, 5), + (4, 6), + (4, 8), + (4, 9), + (5, 6), + (5, 7), + (5, 8), + (6, 7), + (6, 9), + (7, 8), + (8, 9), + ] + ) + return graph + + +# -------------------------------------------------------------------------- +# Basic tests for all strategies +# For each basic graph function, specify the number of expected colors. +BASIC_TEST_CASES = { + empty_graph: 0, + one_node_graph: 1, + two_node_graph: 2, + disconnected: 2, + three_node_clique: 3, +} + + +# -------------------------------------------------------------------------- +# Special test cases. Each strategy has a list of tuples of the form +# (graph function, interchange, valid # of colors) +SPECIAL_TEST_CASES = { + "random_sequential": [ + (rs_shc, False, (2, 3)), + (rs_shc, True, 2), + (rsi_shc, True, (3, 4)), + ], + "saturation_largest_first": [(slf_shc, False, (3, 4)), (slf_hc, False, 4)], + "largest_first": [ + (lf_shc, False, (2, 3)), + (lf_hc, False, 4), + (lf_shc, True, 2), + (lf_hc, True, 3), + (lfi_shc, True, (3, 4)), + (lfi_hc, True, 4), + ], + "smallest_last": [ + (sl_shc, False, (3, 4)), + (sl_hc, False, 5), + (sl_shc, True, 3), + (sl_hc, True, 4), + (sli_shc, True, (3, 4)), + (sli_hc, True, 5), + ], + "independent_set": [(gis_shc, False, (2, 3)), (gis_hc, False, 3)], + "connected_sequential": [(cs_shc, False, (3, 4)), (cs_shc, True, 3)], + "connected_sequential_dfs": [(cs_shc, False, (3, 4))], +} + + +# -------------------------------------------------------------------------- +# Helper functions to test +# (graph function, interchange, valid # of colors) + + +def check_state(L, N, H, F, C): + s = len(C[0]) + num_colors = len(C.keys()) + + assert all(u in L[v] for u in L for v in L[u]) + assert all(F[u] != F[v] for u in L for v in L[u]) + assert all(len(L[u]) < num_colors for u in L) + assert all(len(C[x]) == s for x in C) + assert all(H[(c1, c2)] >= 0 for c1 in C for c2 in C) + assert all(N[(u, F[u])] == 0 for u in F) + + +def max_degree(G): + """Get the maximum degree of any node in G.""" + return max(G.degree(node) for node in G.nodes) if len(G.nodes) > 0 else 0 + + +def make_params_from_graph(G, F): + """Returns {N, L, H, C} from the given graph.""" + num_nodes = len(G) + L = {u: [] for u in range(num_nodes)} + for u, v in G.edges: + L[u].append(v) + L[v].append(u) + + C = nx.algorithms.coloring.equitable_coloring.make_C_from_F(F) + N = nx.algorithms.coloring.equitable_coloring.make_N_from_L_C(L, C) + H = nx.algorithms.coloring.equitable_coloring.make_H_from_C_N(C, N) + + return {"N": N, "F": F, "C": C, "H": H, "L": L} diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/community/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4dfa8481ac466714e74c23a4d3b564308d929b48 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/community/__init__.py @@ -0,0 +1,26 @@ +"""Functions for computing and measuring community structure. + +The ``community`` subpackage can be accessed by using :mod:`networkx.community`, then accessing the +functions as attributes of ``community``. For example:: + + >>> import networkx as nx + >>> G = nx.barbell_graph(5, 1) + >>> communities_generator = nx.community.girvan_newman(G) + >>> top_level_communities = next(communities_generator) + >>> next_level_communities = next(communities_generator) + >>> sorted(map(sorted, next_level_communities)) + [[0, 1, 2, 3, 4], [5], [6, 7, 8, 9, 10]] + +""" + +from networkx.algorithms.community.asyn_fluid import * +from networkx.algorithms.community.centrality import * +from networkx.algorithms.community.divisive import * +from networkx.algorithms.community.kclique import * +from networkx.algorithms.community.kernighan_lin import * +from networkx.algorithms.community.label_propagation import * +from networkx.algorithms.community.lukes import * +from networkx.algorithms.community.modularity_max import * +from networkx.algorithms.community.quality import * +from networkx.algorithms.community.community_utils import * +from networkx.algorithms.community.louvain import * diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3505e4690f39990f30d687eb667cacfd5e7c663 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/asyn_fluid.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/asyn_fluid.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5f5718bf84bcf8d495ea985bbe6a51d941c8bf1 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/asyn_fluid.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/centrality.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/centrality.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82db30766598bd3429cc712ad349c0f49edfc3f0 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/centrality.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/community_utils.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/community_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c511bd3201d21c40b13505ec3d48898e314fc8e3 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/community_utils.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/divisive.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/divisive.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1daeabe639178e35dff9e2c76b7a3b44374c065c Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/divisive.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/kclique.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/kclique.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f99fd885e2c8a9a63e32372da64310932c8ccd5 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/kclique.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/kernighan_lin.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/kernighan_lin.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..486260bfc666e24283411a6c624f6ba6fbe82be4 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/kernighan_lin.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/label_propagation.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/label_propagation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c553ad62ca41c961e075f4001ac242b7de3a9858 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/label_propagation.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/louvain.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/louvain.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f036be911dd527d3e79531bacf28dac536c5d3b Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/louvain.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/lukes.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/lukes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..480f4dd2f22d9a49698ef51c2a462880ec7fca86 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/lukes.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/modularity_max.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/modularity_max.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8629cc25699733dbc2bb8979860babfa9e5cae3c Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/modularity_max.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/quality.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/quality.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49e162d514c23a93a18543b57b35f47158fc9e26 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/community/__pycache__/quality.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/asyn_fluid.py b/lib/python3.10/site-packages/networkx/algorithms/community/asyn_fluid.py new file mode 100644 index 0000000000000000000000000000000000000000..fea72c1bfdbdd451ef653751b56c22445d5da51d --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/community/asyn_fluid.py @@ -0,0 +1,151 @@ +"""Asynchronous Fluid Communities algorithm for community detection.""" + +from collections import Counter + +import networkx as nx +from networkx.algorithms.components import is_connected +from networkx.exception import NetworkXError +from networkx.utils import groups, not_implemented_for, py_random_state + +__all__ = ["asyn_fluidc"] + + +@not_implemented_for("directed") +@not_implemented_for("multigraph") +@py_random_state(3) +@nx._dispatchable +def asyn_fluidc(G, k, max_iter=100, seed=None): + """Returns communities in `G` as detected by Fluid Communities algorithm. + + The asynchronous fluid communities algorithm is described in + [1]_. The algorithm is based on the simple idea of fluids interacting + in an environment, expanding and pushing each other. Its initialization is + random, so found communities may vary on different executions. + + The algorithm proceeds as follows. First each of the initial k communities + is initialized in a random vertex in the graph. Then the algorithm iterates + over all vertices in a random order, updating the community of each vertex + based on its own community and the communities of its neighbors. This + process is performed several times until convergence. + At all times, each community has a total density of 1, which is equally + distributed among the vertices it contains. If a vertex changes of + community, vertex densities of affected communities are adjusted + immediately. When a complete iteration over all vertices is done, such that + no vertex changes the community it belongs to, the algorithm has converged + and returns. + + This is the original version of the algorithm described in [1]_. + Unfortunately, it does not support weighted graphs yet. + + Parameters + ---------- + G : NetworkX graph + Graph must be simple and undirected. + + k : integer + The number of communities to be found. + + max_iter : integer + The number of maximum iterations allowed. By default 100. + + seed : integer, random_state, or None (default) + Indicator of random number generation state. + See :ref:`Randomness`. + + Returns + ------- + communities : iterable + Iterable of communities given as sets of nodes. + + Notes + ----- + k variable is not an optional argument. + + References + ---------- + .. [1] Parés F., Garcia-Gasulla D. et al. "Fluid Communities: A + Competitive and Highly Scalable Community Detection Algorithm". + [https://arxiv.org/pdf/1703.09307.pdf]. + """ + # Initial checks + if not isinstance(k, int): + raise NetworkXError("k must be an integer.") + if not k > 0: + raise NetworkXError("k must be greater than 0.") + if not is_connected(G): + raise NetworkXError("Fluid Communities require connected Graphs.") + if len(G) < k: + raise NetworkXError("k cannot be bigger than the number of nodes.") + # Initialization + max_density = 1.0 + vertices = list(G) + seed.shuffle(vertices) + communities = {n: i for i, n in enumerate(vertices[:k])} + density = {} + com_to_numvertices = {} + for vertex in communities: + com_to_numvertices[communities[vertex]] = 1 + density[communities[vertex]] = max_density + # Set up control variables and start iterating + iter_count = 0 + cont = True + while cont: + cont = False + iter_count += 1 + # Loop over all vertices in graph in a random order + vertices = list(G) + seed.shuffle(vertices) + for vertex in vertices: + # Updating rule + com_counter = Counter() + # Take into account self vertex community + try: + com_counter.update({communities[vertex]: density[communities[vertex]]}) + except KeyError: + pass + # Gather neighbor vertex communities + for v in G[vertex]: + try: + com_counter.update({communities[v]: density[communities[v]]}) + except KeyError: + continue + # Check which is the community with highest density + new_com = -1 + if len(com_counter.keys()) > 0: + max_freq = max(com_counter.values()) + best_communities = [ + com + for com, freq in com_counter.items() + if (max_freq - freq) < 0.0001 + ] + # If actual vertex com in best communities, it is preserved + try: + if communities[vertex] in best_communities: + new_com = communities[vertex] + except KeyError: + pass + # If vertex community changes... + if new_com == -1: + # Set flag of non-convergence + cont = True + # Randomly chose a new community from candidates + new_com = seed.choice(best_communities) + # Update previous community status + try: + com_to_numvertices[communities[vertex]] -= 1 + density[communities[vertex]] = ( + max_density / com_to_numvertices[communities[vertex]] + ) + except KeyError: + pass + # Update new community status + communities[vertex] = new_com + com_to_numvertices[communities[vertex]] += 1 + density[communities[vertex]] = ( + max_density / com_to_numvertices[communities[vertex]] + ) + # If maximum iterations reached --> output actual results + if iter_count > max_iter: + break + # Return results by grouping communities as list of vertices + return iter(groups(communities).values()) diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/centrality.py b/lib/python3.10/site-packages/networkx/algorithms/community/centrality.py new file mode 100644 index 0000000000000000000000000000000000000000..43281701d2b630710acba8f3cef6693356aa461a --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/community/centrality.py @@ -0,0 +1,171 @@ +"""Functions for computing communities based on centrality notions.""" + +import networkx as nx + +__all__ = ["girvan_newman"] + + +@nx._dispatchable(preserve_edge_attrs="most_valuable_edge") +def girvan_newman(G, most_valuable_edge=None): + """Finds communities in a graph using the Girvan–Newman method. + + Parameters + ---------- + G : NetworkX graph + + most_valuable_edge : function + Function that takes a graph as input and outputs an edge. The + edge returned by this function will be recomputed and removed at + each iteration of the algorithm. + + If not specified, the edge with the highest + :func:`networkx.edge_betweenness_centrality` will be used. + + Returns + ------- + iterator + Iterator over tuples of sets of nodes in `G`. Each set of node + is a community, each tuple is a sequence of communities at a + particular level of the algorithm. + + Examples + -------- + To get the first pair of communities:: + + >>> G = nx.path_graph(10) + >>> comp = nx.community.girvan_newman(G) + >>> tuple(sorted(c) for c in next(comp)) + ([0, 1, 2, 3, 4], [5, 6, 7, 8, 9]) + + To get only the first *k* tuples of communities, use + :func:`itertools.islice`:: + + >>> import itertools + >>> G = nx.path_graph(8) + >>> k = 2 + >>> comp = nx.community.girvan_newman(G) + >>> for communities in itertools.islice(comp, k): + ... print(tuple(sorted(c) for c in communities)) + ... + ([0, 1, 2, 3], [4, 5, 6, 7]) + ([0, 1], [2, 3], [4, 5, 6, 7]) + + To stop getting tuples of communities once the number of communities + is greater than *k*, use :func:`itertools.takewhile`:: + + >>> import itertools + >>> G = nx.path_graph(8) + >>> k = 4 + >>> comp = nx.community.girvan_newman(G) + >>> limited = itertools.takewhile(lambda c: len(c) <= k, comp) + >>> for communities in limited: + ... print(tuple(sorted(c) for c in communities)) + ... + ([0, 1, 2, 3], [4, 5, 6, 7]) + ([0, 1], [2, 3], [4, 5, 6, 7]) + ([0, 1], [2, 3], [4, 5], [6, 7]) + + To just choose an edge to remove based on the weight:: + + >>> from operator import itemgetter + >>> G = nx.path_graph(10) + >>> edges = G.edges() + >>> nx.set_edge_attributes(G, {(u, v): v for u, v in edges}, "weight") + >>> def heaviest(G): + ... u, v, w = max(G.edges(data="weight"), key=itemgetter(2)) + ... return (u, v) + ... + >>> comp = nx.community.girvan_newman(G, most_valuable_edge=heaviest) + >>> tuple(sorted(c) for c in next(comp)) + ([0, 1, 2, 3, 4, 5, 6, 7, 8], [9]) + + To utilize edge weights when choosing an edge with, for example, the + highest betweenness centrality:: + + >>> from networkx import edge_betweenness_centrality as betweenness + >>> def most_central_edge(G): + ... centrality = betweenness(G, weight="weight") + ... return max(centrality, key=centrality.get) + ... + >>> G = nx.path_graph(10) + >>> comp = nx.community.girvan_newman(G, most_valuable_edge=most_central_edge) + >>> tuple(sorted(c) for c in next(comp)) + ([0, 1, 2, 3, 4], [5, 6, 7, 8, 9]) + + To specify a different ranking algorithm for edges, use the + `most_valuable_edge` keyword argument:: + + >>> from networkx import edge_betweenness_centrality + >>> from random import random + >>> def most_central_edge(G): + ... centrality = edge_betweenness_centrality(G) + ... max_cent = max(centrality.values()) + ... # Scale the centrality values so they are between 0 and 1, + ... # and add some random noise. + ... centrality = {e: c / max_cent for e, c in centrality.items()} + ... # Add some random noise. + ... centrality = {e: c + random() for e, c in centrality.items()} + ... return max(centrality, key=centrality.get) + ... + >>> G = nx.path_graph(10) + >>> comp = nx.community.girvan_newman(G, most_valuable_edge=most_central_edge) + + Notes + ----- + The Girvan–Newman algorithm detects communities by progressively + removing edges from the original graph. The algorithm removes the + "most valuable" edge, traditionally the edge with the highest + betweenness centrality, at each step. As the graph breaks down into + pieces, the tightly knit community structure is exposed and the + result can be depicted as a dendrogram. + + """ + # If the graph is already empty, simply return its connected + # components. + if G.number_of_edges() == 0: + yield tuple(nx.connected_components(G)) + return + # If no function is provided for computing the most valuable edge, + # use the edge betweenness centrality. + if most_valuable_edge is None: + + def most_valuable_edge(G): + """Returns the edge with the highest betweenness centrality + in the graph `G`. + + """ + # We have guaranteed that the graph is non-empty, so this + # dictionary will never be empty. + betweenness = nx.edge_betweenness_centrality(G) + return max(betweenness, key=betweenness.get) + + # The copy of G here must include the edge weight data. + g = G.copy().to_undirected() + # Self-loops must be removed because their removal has no effect on + # the connected components of the graph. + g.remove_edges_from(nx.selfloop_edges(g)) + while g.number_of_edges() > 0: + yield _without_most_central_edges(g, most_valuable_edge) + + +def _without_most_central_edges(G, most_valuable_edge): + """Returns the connected components of the graph that results from + repeatedly removing the most "valuable" edge in the graph. + + `G` must be a non-empty graph. This function modifies the graph `G` + in-place; that is, it removes edges on the graph `G`. + + `most_valuable_edge` is a function that takes the graph `G` as input + (or a subgraph with one or more edges of `G` removed) and returns an + edge. That edge will be removed and this process will be repeated + until the number of connected components in the graph increases. + + """ + original_num_components = nx.number_connected_components(G) + num_new_components = original_num_components + while num_new_components <= original_num_components: + edge = most_valuable_edge(G) + G.remove_edge(*edge) + new_components = tuple(nx.connected_components(G)) + num_new_components = len(new_components) + return new_components diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/community_utils.py b/lib/python3.10/site-packages/networkx/algorithms/community/community_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ba73a6b30b28410b49babd8f996927a43931124d --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/community/community_utils.py @@ -0,0 +1,30 @@ +"""Helper functions for community-finding algorithms.""" + +import networkx as nx + +__all__ = ["is_partition"] + + +@nx._dispatchable +def is_partition(G, communities): + """Returns *True* if `communities` is a partition of the nodes of `G`. + + A partition of a universe set is a family of pairwise disjoint sets + whose union is the entire universe set. + + Parameters + ---------- + G : NetworkX graph. + + communities : list or iterable of sets of nodes + If not a list, the iterable is converted internally to a list. + If it is an iterator it is exhausted. + + """ + # Alternate implementation: + # return all(sum(1 if v in c else 0 for c in communities) == 1 for v in G) + if not isinstance(communities, list): + communities = list(communities) + nodes = {n for c in communities for n in c if n in G} + + return len(G) == len(nodes) == sum(len(c) for c in communities) diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/divisive.py b/lib/python3.10/site-packages/networkx/algorithms/community/divisive.py new file mode 100644 index 0000000000000000000000000000000000000000..be3c7d863e9d28f6e9c56faea4a60a640f8892bb --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/community/divisive.py @@ -0,0 +1,216 @@ +import functools + +import networkx as nx + +__all__ = [ + "edge_betweenness_partition", + "edge_current_flow_betweenness_partition", +] + + +@nx._dispatchable(edge_attrs="weight") +def edge_betweenness_partition(G, number_of_sets, *, weight=None): + """Partition created by iteratively removing the highest edge betweenness edge. + + This algorithm works by calculating the edge betweenness for all + edges and removing the edge with the highest value. It is then + determined whether the graph has been broken into at least + `number_of_sets` connected components. + If not the process is repeated. + + Parameters + ---------- + G : NetworkX Graph, DiGraph or MultiGraph + Graph to be partitioned + + number_of_sets : int + Number of sets in the desired partition of the graph + + weight : key, optional, default=None + The key to use if using weights for edge betweenness calculation + + Returns + ------- + C : list of sets + Partition of the nodes of G + + Raises + ------ + NetworkXError + If number_of_sets is <= 0 or if number_of_sets > len(G) + + Examples + -------- + >>> G = nx.karate_club_graph() + >>> part = nx.community.edge_betweenness_partition(G, 2) + >>> {0, 1, 3, 4, 5, 6, 7, 10, 11, 12, 13, 16, 17, 19, 21} in part + True + >>> { + ... 2, + ... 8, + ... 9, + ... 14, + ... 15, + ... 18, + ... 20, + ... 22, + ... 23, + ... 24, + ... 25, + ... 26, + ... 27, + ... 28, + ... 29, + ... 30, + ... 31, + ... 32, + ... 33, + ... } in part + True + + See Also + -------- + edge_current_flow_betweenness_partition + + Notes + ----- + This algorithm is fairly slow, as both the calculation of connected + components and edge betweenness relies on all pairs shortest + path algorithms. They could potentially be combined to cut down + on overall computation time. + + References + ---------- + .. [1] Santo Fortunato 'Community Detection in Graphs' Physical Reports + Volume 486, Issue 3-5 p. 75-174 + http://arxiv.org/abs/0906.0612 + """ + if number_of_sets <= 0: + raise nx.NetworkXError("number_of_sets must be >0") + if number_of_sets == 1: + return [set(G)] + if number_of_sets == len(G): + return [{n} for n in G] + if number_of_sets > len(G): + raise nx.NetworkXError("number_of_sets must be <= len(G)") + + H = G.copy() + partition = list(nx.connected_components(H)) + while len(partition) < number_of_sets: + ranking = nx.edge_betweenness_centrality(H, weight=weight) + edge = max(ranking, key=ranking.get) + H.remove_edge(*edge) + partition = list(nx.connected_components(H)) + return partition + + +@nx._dispatchable(edge_attrs="weight") +def edge_current_flow_betweenness_partition(G, number_of_sets, *, weight=None): + """Partition created by removing the highest edge current flow betweenness edge. + + This algorithm works by calculating the edge current flow + betweenness for all edges and removing the edge with the + highest value. It is then determined whether the graph has + been broken into at least `number_of_sets` connected + components. If not the process is repeated. + + Parameters + ---------- + G : NetworkX Graph, DiGraph or MultiGraph + Graph to be partitioned + + number_of_sets : int + Number of sets in the desired partition of the graph + + weight : key, optional (default=None) + The edge attribute key to use as weights for + edge current flow betweenness calculations + + Returns + ------- + C : list of sets + Partition of G + + Raises + ------ + NetworkXError + If number_of_sets is <= 0 or number_of_sets > len(G) + + Examples + -------- + >>> G = nx.karate_club_graph() + >>> part = nx.community.edge_current_flow_betweenness_partition(G, 2) + >>> {0, 1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 16, 17, 19, 21} in part + True + >>> {8, 14, 15, 18, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33} in part + True + + + See Also + -------- + edge_betweenness_partition + + Notes + ----- + This algorithm is extremely slow, as the recalculation of the edge + current flow betweenness is extremely slow. + + References + ---------- + .. [1] Santo Fortunato 'Community Detection in Graphs' Physical Reports + Volume 486, Issue 3-5 p. 75-174 + http://arxiv.org/abs/0906.0612 + """ + if number_of_sets <= 0: + raise nx.NetworkXError("number_of_sets must be >0") + elif number_of_sets == 1: + return [set(G)] + elif number_of_sets == len(G): + return [{n} for n in G] + elif number_of_sets > len(G): + raise nx.NetworkXError("number_of_sets must be <= len(G)") + + rank = functools.partial( + nx.edge_current_flow_betweenness_centrality, normalized=False, weight=weight + ) + + # current flow requires a connected network so we track the components explicitly + H = G.copy() + partition = list(nx.connected_components(H)) + if len(partition) > 1: + Hcc_subgraphs = [H.subgraph(cc).copy() for cc in partition] + else: + Hcc_subgraphs = [H] + + ranking = {} + for Hcc in Hcc_subgraphs: + ranking.update(rank(Hcc)) + + while len(partition) < number_of_sets: + edge = max(ranking, key=ranking.get) + for cc, Hcc in zip(partition, Hcc_subgraphs): + if edge[0] in cc: + Hcc.remove_edge(*edge) + del ranking[edge] + splitcc_list = list(nx.connected_components(Hcc)) + if len(splitcc_list) > 1: + # there are 2 connected components. split off smaller one + cc_new = min(splitcc_list, key=len) + Hcc_new = Hcc.subgraph(cc_new).copy() + # update edge rankings for Hcc_new + newranks = rank(Hcc_new) + for e, r in newranks.items(): + ranking[e if e in ranking else e[::-1]] = r + # append new cc and Hcc to their lists. + partition.append(cc_new) + Hcc_subgraphs.append(Hcc_new) + + # leave existing cc and Hcc in their lists, but shrink them + Hcc.remove_nodes_from(cc_new) + cc.difference_update(cc_new) + # update edge rankings for Hcc whether it was split or not + newranks = rank(Hcc) + for e, r in newranks.items(): + ranking[e if e in ranking else e[::-1]] = r + break + return partition diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/kclique.py b/lib/python3.10/site-packages/networkx/algorithms/community/kclique.py new file mode 100644 index 0000000000000000000000000000000000000000..c72491042046b6f79ba5c7cb4a90ac8822491d84 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/community/kclique.py @@ -0,0 +1,79 @@ +from collections import defaultdict + +import networkx as nx + +__all__ = ["k_clique_communities"] + + +@nx._dispatchable +def k_clique_communities(G, k, cliques=None): + """Find k-clique communities in graph using the percolation method. + + A k-clique community is the union of all cliques of size k that + can be reached through adjacent (sharing k-1 nodes) k-cliques. + + Parameters + ---------- + G : NetworkX graph + + k : int + Size of smallest clique + + cliques: list or generator + Precomputed cliques (use networkx.find_cliques(G)) + + Returns + ------- + Yields sets of nodes, one for each k-clique community. + + Examples + -------- + >>> G = nx.complete_graph(5) + >>> K5 = nx.convert_node_labels_to_integers(G, first_label=2) + >>> G.add_edges_from(K5.edges()) + >>> c = list(nx.community.k_clique_communities(G, 4)) + >>> sorted(list(c[0])) + [0, 1, 2, 3, 4, 5, 6] + >>> list(nx.community.k_clique_communities(G, 6)) + [] + + References + ---------- + .. [1] Gergely Palla, Imre Derényi, Illés Farkas1, and Tamás Vicsek, + Uncovering the overlapping community structure of complex networks + in nature and society Nature 435, 814-818, 2005, + doi:10.1038/nature03607 + """ + if k < 2: + raise nx.NetworkXError(f"k={k}, k must be greater than 1.") + if cliques is None: + cliques = nx.find_cliques(G) + cliques = [frozenset(c) for c in cliques if len(c) >= k] + + # First index which nodes are in which cliques + membership_dict = defaultdict(list) + for clique in cliques: + for node in clique: + membership_dict[node].append(clique) + + # For each clique, see which adjacent cliques percolate + perc_graph = nx.Graph() + perc_graph.add_nodes_from(cliques) + for clique in cliques: + for adj_clique in _get_adjacent_cliques(clique, membership_dict): + if len(clique.intersection(adj_clique)) >= (k - 1): + perc_graph.add_edge(clique, adj_clique) + + # Connected components of clique graph with perc edges + # are the percolated cliques + for component in nx.connected_components(perc_graph): + yield (frozenset.union(*component)) + + +def _get_adjacent_cliques(clique, membership_dict): + adjacent_cliques = set() + for n in clique: + for adj_clique in membership_dict[n]: + if clique != adj_clique: + adjacent_cliques.add(adj_clique) + return adjacent_cliques diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/kernighan_lin.py b/lib/python3.10/site-packages/networkx/algorithms/community/kernighan_lin.py new file mode 100644 index 0000000000000000000000000000000000000000..f6397d82be6fa94273a81a613411cc6a74d8c4cc --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/community/kernighan_lin.py @@ -0,0 +1,139 @@ +"""Functions for computing the Kernighan–Lin bipartition algorithm.""" + +from itertools import count + +import networkx as nx +from networkx.algorithms.community.community_utils import is_partition +from networkx.utils import BinaryHeap, not_implemented_for, py_random_state + +__all__ = ["kernighan_lin_bisection"] + + +def _kernighan_lin_sweep(edges, side): + """ + This is a modified form of Kernighan-Lin, which moves single nodes at a + time, alternating between sides to keep the bisection balanced. We keep + two min-heaps of swap costs to make optimal-next-move selection fast. + """ + costs0, costs1 = costs = BinaryHeap(), BinaryHeap() + for u, side_u, edges_u in zip(count(), side, edges): + cost_u = sum(w if side[v] else -w for v, w in edges_u) + costs[side_u].insert(u, cost_u if side_u else -cost_u) + + def _update_costs(costs_x, x): + for y, w in edges[x]: + costs_y = costs[side[y]] + cost_y = costs_y.get(y) + if cost_y is not None: + cost_y += 2 * (-w if costs_x is costs_y else w) + costs_y.insert(y, cost_y, True) + + i = 0 + totcost = 0 + while costs0 and costs1: + u, cost_u = costs0.pop() + _update_costs(costs0, u) + v, cost_v = costs1.pop() + _update_costs(costs1, v) + totcost += cost_u + cost_v + i += 1 + yield totcost, i, (u, v) + + +@not_implemented_for("directed") +@py_random_state(4) +@nx._dispatchable(edge_attrs="weight") +def kernighan_lin_bisection(G, partition=None, max_iter=10, weight="weight", seed=None): + """Partition a graph into two blocks using the Kernighan–Lin + algorithm. + + This algorithm partitions a network into two sets by iteratively + swapping pairs of nodes to reduce the edge cut between the two sets. The + pairs are chosen according to a modified form of Kernighan-Lin [1]_, which + moves node individually, alternating between sides to keep the bisection + balanced. + + Parameters + ---------- + G : NetworkX graph + Graph must be undirected. + + partition : tuple + Pair of iterables containing an initial partition. If not + specified, a random balanced partition is used. + + max_iter : int + Maximum number of times to attempt swaps to find an + improvement before giving up. + + weight : key + Edge data key to use as weight. If None, the weights are all + set to one. + + seed : integer, random_state, or None (default) + Indicator of random number generation state. + See :ref:`Randomness`. + Only used if partition is None + + Returns + ------- + partition : tuple + A pair of sets of nodes representing the bipartition. + + Raises + ------ + NetworkXError + If partition is not a valid partition of the nodes of the graph. + + References + ---------- + .. [1] Kernighan, B. W.; Lin, Shen (1970). + "An efficient heuristic procedure for partitioning graphs." + *Bell Systems Technical Journal* 49: 291--307. + Oxford University Press 2011. + + """ + n = len(G) + labels = list(G) + seed.shuffle(labels) + index = {v: i for i, v in enumerate(labels)} + + if partition is None: + side = [0] * (n // 2) + [1] * ((n + 1) // 2) + else: + try: + A, B = partition + except (TypeError, ValueError) as err: + raise nx.NetworkXError("partition must be two sets") from err + if not is_partition(G, (A, B)): + raise nx.NetworkXError("partition invalid") + side = [0] * n + for a in A: + side[index[a]] = 1 + + if G.is_multigraph(): + edges = [ + [ + (index[u], sum(e.get(weight, 1) for e in d.values())) + for u, d in G[v].items() + ] + for v in labels + ] + else: + edges = [ + [(index[u], e.get(weight, 1)) for u, e in G[v].items()] for v in labels + ] + + for i in range(max_iter): + costs = list(_kernighan_lin_sweep(edges, side)) + min_cost, min_i, _ = min(costs) + if min_cost >= 0: + break + + for _, _, (u, v) in costs[:min_i]: + side[u] = 1 + side[v] = 0 + + A = {u for u, s in zip(labels, side) if s == 0} + B = {u for u, s in zip(labels, side) if s == 1} + return A, B diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/label_propagation.py b/lib/python3.10/site-packages/networkx/algorithms/community/label_propagation.py new file mode 100644 index 0000000000000000000000000000000000000000..7488028655af419c617bd3573b98071e012c4eda --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/community/label_propagation.py @@ -0,0 +1,338 @@ +""" +Label propagation community detection algorithms. +""" + +from collections import Counter, defaultdict, deque + +import networkx as nx +from networkx.utils import groups, not_implemented_for, py_random_state + +__all__ = [ + "label_propagation_communities", + "asyn_lpa_communities", + "fast_label_propagation_communities", +] + + +@py_random_state("seed") +@nx._dispatchable(edge_attrs="weight") +def fast_label_propagation_communities(G, *, weight=None, seed=None): + """Returns communities in `G` as detected by fast label propagation. + + The fast label propagation algorithm is described in [1]_. The algorithm is + probabilistic and the found communities may vary in different executions. + + The algorithm operates as follows. First, the community label of each node is + set to a unique label. The algorithm then repeatedly updates the labels of + the nodes to the most frequent label in their neighborhood. In case of ties, + a random label is chosen from the most frequent labels. + + The algorithm maintains a queue of nodes that still need to be processed. + Initially, all nodes are added to the queue in a random order. Then the nodes + are removed from the queue one by one and processed. If a node updates its label, + all its neighbors that have a different label are added to the queue (if not + already in the queue). The algorithm stops when the queue is empty. + + Parameters + ---------- + G : Graph, DiGraph, MultiGraph, or MultiDiGraph + Any NetworkX graph. + + weight : string, or None (default) + The edge attribute representing a non-negative weight of an edge. If None, + each edge is assumed to have weight one. The weight of an edge is used in + determining the frequency with which a label appears among the neighbors of + a node (edge with weight `w` is equivalent to `w` unweighted edges). + + seed : integer, random_state, or None (default) + Indicator of random number generation state. See :ref:`Randomness`. + + Returns + ------- + communities : iterable + Iterable of communities given as sets of nodes. + + Notes + ----- + Edge directions are ignored for directed graphs. + Edge weights must be non-negative numbers. + + References + ---------- + .. [1] Vincent A. Traag & Lovro Šubelj. "Large network community detection by + fast label propagation." Scientific Reports 13 (2023): 2701. + https://doi.org/10.1038/s41598-023-29610-z + """ + + # Queue of nodes to be processed. + nodes_queue = deque(G) + seed.shuffle(nodes_queue) + + # Set of nodes in the queue. + nodes_set = set(G) + + # Assign unique label to each node. + comms = {node: i for i, node in enumerate(G)} + + while nodes_queue: + # Remove next node from the queue to process. + node = nodes_queue.popleft() + nodes_set.remove(node) + + # Isolated nodes retain their initial label. + if G.degree(node) > 0: + # Compute frequency of labels in node's neighborhood. + label_freqs = _fast_label_count(G, comms, node, weight) + max_freq = max(label_freqs.values()) + + # Always sample new label from most frequent labels. + comm = seed.choice( + [comm for comm in label_freqs if label_freqs[comm] == max_freq] + ) + + if comms[node] != comm: + comms[node] = comm + + # Add neighbors that have different label to the queue. + for nbr in nx.all_neighbors(G, node): + if comms[nbr] != comm and nbr not in nodes_set: + nodes_queue.append(nbr) + nodes_set.add(nbr) + + yield from groups(comms).values() + + +def _fast_label_count(G, comms, node, weight=None): + """Computes the frequency of labels in the neighborhood of a node. + + Returns a dictionary keyed by label to the frequency of that label. + """ + + if weight is None: + # Unweighted (un)directed simple graph. + if not G.is_multigraph(): + label_freqs = Counter(map(comms.get, nx.all_neighbors(G, node))) + + # Unweighted (un)directed multigraph. + else: + label_freqs = defaultdict(int) + for nbr in G[node]: + label_freqs[comms[nbr]] += len(G[node][nbr]) + + if G.is_directed(): + for nbr in G.pred[node]: + label_freqs[comms[nbr]] += len(G.pred[node][nbr]) + + else: + # Weighted undirected simple/multigraph. + label_freqs = defaultdict(float) + for _, nbr, w in G.edges(node, data=weight, default=1): + label_freqs[comms[nbr]] += w + + # Weighted directed simple/multigraph. + if G.is_directed(): + for nbr, _, w in G.in_edges(node, data=weight, default=1): + label_freqs[comms[nbr]] += w + + return label_freqs + + +@py_random_state(2) +@nx._dispatchable(edge_attrs="weight") +def asyn_lpa_communities(G, weight=None, seed=None): + """Returns communities in `G` as detected by asynchronous label + propagation. + + The asynchronous label propagation algorithm is described in + [1]_. The algorithm is probabilistic and the found communities may + vary on different executions. + + The algorithm proceeds as follows. After initializing each node with + a unique label, the algorithm repeatedly sets the label of a node to + be the label that appears most frequently among that nodes + neighbors. The algorithm halts when each node has the label that + appears most frequently among its neighbors. The algorithm is + asynchronous because each node is updated without waiting for + updates on the remaining nodes. + + This generalized version of the algorithm in [1]_ accepts edge + weights. + + Parameters + ---------- + G : Graph + + weight : string + The edge attribute representing the weight of an edge. + If None, each edge is assumed to have weight one. In this + algorithm, the weight of an edge is used in determining the + frequency with which a label appears among the neighbors of a + node: a higher weight means the label appears more often. + + seed : integer, random_state, or None (default) + Indicator of random number generation state. + See :ref:`Randomness`. + + Returns + ------- + communities : iterable + Iterable of communities given as sets of nodes. + + Notes + ----- + Edge weight attributes must be numerical. + + References + ---------- + .. [1] Raghavan, Usha Nandini, Réka Albert, and Soundar Kumara. "Near + linear time algorithm to detect community structures in large-scale + networks." Physical Review E 76.3 (2007): 036106. + """ + + labels = {n: i for i, n in enumerate(G)} + cont = True + + while cont: + cont = False + nodes = list(G) + seed.shuffle(nodes) + + for node in nodes: + if not G[node]: + continue + + # Get label frequencies among adjacent nodes. + # Depending on the order they are processed in, + # some nodes will be in iteration t and others in t-1, + # making the algorithm asynchronous. + if weight is None: + # initialising a Counter from an iterator of labels is + # faster for getting unweighted label frequencies + label_freq = Counter(map(labels.get, G[node])) + else: + # updating a defaultdict is substantially faster + # for getting weighted label frequencies + label_freq = defaultdict(float) + for _, v, wt in G.edges(node, data=weight, default=1): + label_freq[labels[v]] += wt + + # Get the labels that appear with maximum frequency. + max_freq = max(label_freq.values()) + best_labels = [ + label for label, freq in label_freq.items() if freq == max_freq + ] + + # If the node does not have one of the maximum frequency labels, + # randomly choose one of them and update the node's label. + # Continue the iteration as long as at least one node + # doesn't have a maximum frequency label. + if labels[node] not in best_labels: + labels[node] = seed.choice(best_labels) + cont = True + + yield from groups(labels).values() + + +@not_implemented_for("directed") +@nx._dispatchable +def label_propagation_communities(G): + """Generates community sets determined by label propagation + + Finds communities in `G` using a semi-synchronous label propagation + method [1]_. This method combines the advantages of both the synchronous + and asynchronous models. Not implemented for directed graphs. + + Parameters + ---------- + G : graph + An undirected NetworkX graph. + + Returns + ------- + communities : iterable + A dict_values object that contains a set of nodes for each community. + + Raises + ------ + NetworkXNotImplemented + If the graph is directed + + References + ---------- + .. [1] Cordasco, G., & Gargano, L. (2010, December). Community detection + via semi-synchronous label propagation algorithms. In Business + Applications of Social Network Analysis (BASNA), 2010 IEEE International + Workshop on (pp. 1-8). IEEE. + """ + coloring = _color_network(G) + # Create a unique label for each node in the graph + labeling = {v: k for k, v in enumerate(G)} + while not _labeling_complete(labeling, G): + # Update the labels of every node with the same color. + for color, nodes in coloring.items(): + for n in nodes: + _update_label(n, labeling, G) + + clusters = defaultdict(set) + for node, label in labeling.items(): + clusters[label].add(node) + return clusters.values() + + +def _color_network(G): + """Colors the network so that neighboring nodes all have distinct colors. + + Returns a dict keyed by color to a set of nodes with that color. + """ + coloring = {} # color => set(node) + colors = nx.coloring.greedy_color(G) + for node, color in colors.items(): + if color in coloring: + coloring[color].add(node) + else: + coloring[color] = {node} + return coloring + + +def _labeling_complete(labeling, G): + """Determines whether or not LPA is done. + + Label propagation is complete when all nodes have a label that is + in the set of highest frequency labels amongst its neighbors. + + Nodes with no neighbors are considered complete. + """ + return all( + labeling[v] in _most_frequent_labels(v, labeling, G) for v in G if len(G[v]) > 0 + ) + + +def _most_frequent_labels(node, labeling, G): + """Returns a set of all labels with maximum frequency in `labeling`. + + Input `labeling` should be a dict keyed by node to labels. + """ + if not G[node]: + # Nodes with no neighbors are themselves a community and are labeled + # accordingly, hence the immediate if statement. + return {labeling[node]} + + # Compute the frequencies of all neighbors of node + freqs = Counter(labeling[q] for q in G[node]) + max_freq = max(freqs.values()) + return {label for label, freq in freqs.items() if freq == max_freq} + + +def _update_label(node, labeling, G): + """Updates the label of a node using the Prec-Max tie breaking algorithm + + The algorithm is explained in: 'Community Detection via Semi-Synchronous + Label Propagation Algorithms' Cordasco and Gargano, 2011 + """ + high_labels = _most_frequent_labels(node, labeling, G) + if len(high_labels) == 1: + labeling[node] = high_labels.pop() + elif len(high_labels) > 1: + # Prec-Max + if labeling[node] not in high_labels: + labeling[node] = max(high_labels) diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/louvain.py b/lib/python3.10/site-packages/networkx/algorithms/community/louvain.py new file mode 100644 index 0000000000000000000000000000000000000000..959c93a5104b1f68ad67aef1b85bd28adbfc32ec --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/community/louvain.py @@ -0,0 +1,382 @@ +"""Function for detecting communities based on Louvain Community Detection +Algorithm""" + +import itertools +from collections import defaultdict, deque + +import networkx as nx +from networkx.algorithms.community import modularity +from networkx.utils import py_random_state + +__all__ = ["louvain_communities", "louvain_partitions"] + + +@py_random_state("seed") +@nx._dispatchable(edge_attrs="weight") +def louvain_communities( + G, weight="weight", resolution=1, threshold=0.0000001, max_level=None, seed=None +): + r"""Find the best partition of a graph using the Louvain Community Detection + Algorithm. + + Louvain Community Detection Algorithm is a simple method to extract the community + structure of a network. This is a heuristic method based on modularity optimization. [1]_ + + The algorithm works in 2 steps. On the first step it assigns every node to be + in its own community and then for each node it tries to find the maximum positive + modularity gain by moving each node to all of its neighbor communities. If no positive + gain is achieved the node remains in its original community. + + The modularity gain obtained by moving an isolated node $i$ into a community $C$ can + easily be calculated by the following formula (combining [1]_ [2]_ and some algebra): + + .. math:: + \Delta Q = \frac{k_{i,in}}{2m} - \gamma\frac{ \Sigma_{tot} \cdot k_i}{2m^2} + + where $m$ is the size of the graph, $k_{i,in}$ is the sum of the weights of the links + from $i$ to nodes in $C$, $k_i$ is the sum of the weights of the links incident to node $i$, + $\Sigma_{tot}$ is the sum of the weights of the links incident to nodes in $C$ and $\gamma$ + is the resolution parameter. + + For the directed case the modularity gain can be computed using this formula according to [3]_ + + .. math:: + \Delta Q = \frac{k_{i,in}}{m} + - \gamma\frac{k_i^{out} \cdot\Sigma_{tot}^{in} + k_i^{in} \cdot \Sigma_{tot}^{out}}{m^2} + + where $k_i^{out}$, $k_i^{in}$ are the outer and inner weighted degrees of node $i$ and + $\Sigma_{tot}^{in}$, $\Sigma_{tot}^{out}$ are the sum of in-going and out-going links incident + to nodes in $C$. + + The first phase continues until no individual move can improve the modularity. + + The second phase consists in building a new network whose nodes are now the communities + found in the first phase. To do so, the weights of the links between the new nodes are given by + the sum of the weight of the links between nodes in the corresponding two communities. Once this + phase is complete it is possible to reapply the first phase creating bigger communities with + increased modularity. + + The above two phases are executed until no modularity gain is achieved (or is less than + the `threshold`, or until `max_levels` is reached). + + Be careful with self-loops in the input graph. These are treated as + previously reduced communities -- as if the process had been started + in the middle of the algorithm. Large self-loop edge weights thus + represent strong communities and in practice may be hard to add + other nodes to. If your input graph edge weights for self-loops + do not represent already reduced communities you may want to remove + the self-loops before inputting that graph. + + Parameters + ---------- + G : NetworkX graph + weight : string or None, optional (default="weight") + The name of an edge attribute that holds the numerical value + used as a weight. If None then each edge has weight 1. + resolution : float, optional (default=1) + If resolution is less than 1, the algorithm favors larger communities. + Greater than 1 favors smaller communities + threshold : float, optional (default=0.0000001) + Modularity gain threshold for each level. If the gain of modularity + between 2 levels of the algorithm is less than the given threshold + then the algorithm stops and returns the resulting communities. + max_level : int or None, optional (default=None) + The maximum number of levels (steps of the algorithm) to compute. + Must be a positive integer or None. If None, then there is no max + level and the threshold parameter determines the stopping condition. + seed : integer, random_state, or None (default) + Indicator of random number generation state. + See :ref:`Randomness`. + + Returns + ------- + list + A list of sets (partition of `G`). Each set represents one community and contains + all the nodes that constitute it. + + Examples + -------- + >>> import networkx as nx + >>> G = nx.petersen_graph() + >>> nx.community.louvain_communities(G, seed=123) + [{0, 4, 5, 7, 9}, {1, 2, 3, 6, 8}] + + Notes + ----- + The order in which the nodes are considered can affect the final output. In the algorithm + the ordering happens using a random shuffle. + + References + ---------- + .. [1] Blondel, V.D. et al. Fast unfolding of communities in + large networks. J. Stat. Mech 10008, 1-12(2008). https://doi.org/10.1088/1742-5468/2008/10/P10008 + .. [2] Traag, V.A., Waltman, L. & van Eck, N.J. From Louvain to Leiden: guaranteeing + well-connected communities. Sci Rep 9, 5233 (2019). https://doi.org/10.1038/s41598-019-41695-z + .. [3] Nicolas Dugué, Anthony Perez. Directed Louvain : maximizing modularity in directed networks. + [Research Report] Université d’Orléans. 2015. hal-01231784. https://hal.archives-ouvertes.fr/hal-01231784 + + See Also + -------- + louvain_partitions + """ + + partitions = louvain_partitions(G, weight, resolution, threshold, seed) + if max_level is not None: + if max_level <= 0: + raise ValueError("max_level argument must be a positive integer or None") + partitions = itertools.islice(partitions, max_level) + final_partition = deque(partitions, maxlen=1) + return final_partition.pop() + + +@py_random_state("seed") +@nx._dispatchable(edge_attrs="weight") +def louvain_partitions( + G, weight="weight", resolution=1, threshold=0.0000001, seed=None +): + """Yields partitions for each level of the Louvain Community Detection Algorithm + + Louvain Community Detection Algorithm is a simple method to extract the community + structure of a network. This is a heuristic method based on modularity optimization. [1]_ + + The partitions at each level (step of the algorithm) form a dendrogram of communities. + A dendrogram is a diagram representing a tree and each level represents + a partition of the G graph. The top level contains the smallest communities + and as you traverse to the bottom of the tree the communities get bigger + and the overall modularity increases making the partition better. + + Each level is generated by executing the two phases of the Louvain Community + Detection Algorithm. + + Be careful with self-loops in the input graph. These are treated as + previously reduced communities -- as if the process had been started + in the middle of the algorithm. Large self-loop edge weights thus + represent strong communities and in practice may be hard to add + other nodes to. If your input graph edge weights for self-loops + do not represent already reduced communities you may want to remove + the self-loops before inputting that graph. + + Parameters + ---------- + G : NetworkX graph + weight : string or None, optional (default="weight") + The name of an edge attribute that holds the numerical value + used as a weight. If None then each edge has weight 1. + resolution : float, optional (default=1) + If resolution is less than 1, the algorithm favors larger communities. + Greater than 1 favors smaller communities + threshold : float, optional (default=0.0000001) + Modularity gain threshold for each level. If the gain of modularity + between 2 levels of the algorithm is less than the given threshold + then the algorithm stops and returns the resulting communities. + seed : integer, random_state, or None (default) + Indicator of random number generation state. + See :ref:`Randomness`. + + Yields + ------ + list + A list of sets (partition of `G`). Each set represents one community and contains + all the nodes that constitute it. + + References + ---------- + .. [1] Blondel, V.D. et al. Fast unfolding of communities in + large networks. J. Stat. Mech 10008, 1-12(2008) + + See Also + -------- + louvain_communities + """ + + partition = [{u} for u in G.nodes()] + if nx.is_empty(G): + yield partition + return + mod = modularity(G, partition, resolution=resolution, weight=weight) + is_directed = G.is_directed() + if G.is_multigraph(): + graph = _convert_multigraph(G, weight, is_directed) + else: + graph = G.__class__() + graph.add_nodes_from(G) + graph.add_weighted_edges_from(G.edges(data=weight, default=1)) + + m = graph.size(weight="weight") + partition, inner_partition, improvement = _one_level( + graph, m, partition, resolution, is_directed, seed + ) + improvement = True + while improvement: + # gh-5901 protect the sets in the yielded list from further manipulation here + yield [s.copy() for s in partition] + new_mod = modularity( + graph, inner_partition, resolution=resolution, weight="weight" + ) + if new_mod - mod <= threshold: + return + mod = new_mod + graph = _gen_graph(graph, inner_partition) + partition, inner_partition, improvement = _one_level( + graph, m, partition, resolution, is_directed, seed + ) + + +def _one_level(G, m, partition, resolution=1, is_directed=False, seed=None): + """Calculate one level of the Louvain partitions tree + + Parameters + ---------- + G : NetworkX Graph/DiGraph + The graph from which to detect communities + m : number + The size of the graph `G`. + partition : list of sets of nodes + A valid partition of the graph `G` + resolution : positive number + The resolution parameter for computing the modularity of a partition + is_directed : bool + True if `G` is a directed graph. + seed : integer, random_state, or None (default) + Indicator of random number generation state. + See :ref:`Randomness`. + + """ + node2com = {u: i for i, u in enumerate(G.nodes())} + inner_partition = [{u} for u in G.nodes()] + if is_directed: + in_degrees = dict(G.in_degree(weight="weight")) + out_degrees = dict(G.out_degree(weight="weight")) + Stot_in = list(in_degrees.values()) + Stot_out = list(out_degrees.values()) + # Calculate weights for both in and out neighbors without considering self-loops + nbrs = {} + for u in G: + nbrs[u] = defaultdict(float) + for _, n, wt in G.out_edges(u, data="weight"): + if u != n: + nbrs[u][n] += wt + for n, _, wt in G.in_edges(u, data="weight"): + if u != n: + nbrs[u][n] += wt + else: + degrees = dict(G.degree(weight="weight")) + Stot = list(degrees.values()) + nbrs = {u: {v: data["weight"] for v, data in G[u].items() if v != u} for u in G} + rand_nodes = list(G.nodes) + seed.shuffle(rand_nodes) + nb_moves = 1 + improvement = False + while nb_moves > 0: + nb_moves = 0 + for u in rand_nodes: + best_mod = 0 + best_com = node2com[u] + weights2com = _neighbor_weights(nbrs[u], node2com) + if is_directed: + in_degree = in_degrees[u] + out_degree = out_degrees[u] + Stot_in[best_com] -= in_degree + Stot_out[best_com] -= out_degree + remove_cost = ( + -weights2com[best_com] / m + + resolution + * (out_degree * Stot_in[best_com] + in_degree * Stot_out[best_com]) + / m**2 + ) + else: + degree = degrees[u] + Stot[best_com] -= degree + remove_cost = -weights2com[best_com] / m + resolution * ( + Stot[best_com] * degree + ) / (2 * m**2) + for nbr_com, wt in weights2com.items(): + if is_directed: + gain = ( + remove_cost + + wt / m + - resolution + * ( + out_degree * Stot_in[nbr_com] + + in_degree * Stot_out[nbr_com] + ) + / m**2 + ) + else: + gain = ( + remove_cost + + wt / m + - resolution * (Stot[nbr_com] * degree) / (2 * m**2) + ) + if gain > best_mod: + best_mod = gain + best_com = nbr_com + if is_directed: + Stot_in[best_com] += in_degree + Stot_out[best_com] += out_degree + else: + Stot[best_com] += degree + if best_com != node2com[u]: + com = G.nodes[u].get("nodes", {u}) + partition[node2com[u]].difference_update(com) + inner_partition[node2com[u]].remove(u) + partition[best_com].update(com) + inner_partition[best_com].add(u) + improvement = True + nb_moves += 1 + node2com[u] = best_com + partition = list(filter(len, partition)) + inner_partition = list(filter(len, inner_partition)) + return partition, inner_partition, improvement + + +def _neighbor_weights(nbrs, node2com): + """Calculate weights between node and its neighbor communities. + + Parameters + ---------- + nbrs : dictionary + Dictionary with nodes' neighbors as keys and their edge weight as value. + node2com : dictionary + Dictionary with all graph's nodes as keys and their community index as value. + + """ + weights = defaultdict(float) + for nbr, wt in nbrs.items(): + weights[node2com[nbr]] += wt + return weights + + +def _gen_graph(G, partition): + """Generate a new graph based on the partitions of a given graph""" + H = G.__class__() + node2com = {} + for i, part in enumerate(partition): + nodes = set() + for node in part: + node2com[node] = i + nodes.update(G.nodes[node].get("nodes", {node})) + H.add_node(i, nodes=nodes) + + for node1, node2, wt in G.edges(data=True): + wt = wt["weight"] + com1 = node2com[node1] + com2 = node2com[node2] + temp = H.get_edge_data(com1, com2, {"weight": 0})["weight"] + H.add_edge(com1, com2, weight=wt + temp) + return H + + +def _convert_multigraph(G, weight, is_directed): + """Convert a Multigraph to normal Graph""" + if is_directed: + H = nx.DiGraph() + else: + H = nx.Graph() + H.add_nodes_from(G) + for u, v, wt in G.edges(data=weight, default=1): + if H.has_edge(u, v): + H[u][v]["weight"] += wt + else: + H.add_edge(u, v, weight=wt) + return H diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/lukes.py b/lib/python3.10/site-packages/networkx/algorithms/community/lukes.py new file mode 100644 index 0000000000000000000000000000000000000000..08dd7cd52ff414c1397e3effea504853f3c9caf7 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/community/lukes.py @@ -0,0 +1,227 @@ +"""Lukes Algorithm for exact optimal weighted tree partitioning.""" + +from copy import deepcopy +from functools import lru_cache +from random import choice + +import networkx as nx +from networkx.utils import not_implemented_for + +__all__ = ["lukes_partitioning"] + +D_EDGE_W = "weight" +D_EDGE_VALUE = 1.0 +D_NODE_W = "weight" +D_NODE_VALUE = 1 +PKEY = "partitions" +CLUSTER_EVAL_CACHE_SIZE = 2048 + + +def _split_n_from(n, min_size_of_first_part): + # splits j in two parts of which the first is at least + # the second argument + assert n >= min_size_of_first_part + for p1 in range(min_size_of_first_part, n + 1): + yield p1, n - p1 + + +@nx._dispatchable(node_attrs="node_weight", edge_attrs="edge_weight") +def lukes_partitioning(G, max_size, node_weight=None, edge_weight=None): + """Optimal partitioning of a weighted tree using the Lukes algorithm. + + This algorithm partitions a connected, acyclic graph featuring integer + node weights and float edge weights. The resulting clusters are such + that the total weight of the nodes in each cluster does not exceed + max_size and that the weight of the edges that are cut by the partition + is minimum. The algorithm is based on [1]_. + + Parameters + ---------- + G : NetworkX graph + + max_size : int + Maximum weight a partition can have in terms of sum of + node_weight for all nodes in the partition + + edge_weight : key + Edge data key to use as weight. If None, the weights are all + set to one. + + node_weight : key + Node data key to use as weight. If None, the weights are all + set to one. The data must be int. + + Returns + ------- + partition : list + A list of sets of nodes representing the clusters of the + partition. + + Raises + ------ + NotATree + If G is not a tree. + TypeError + If any of the values of node_weight is not int. + + References + ---------- + .. [1] Lukes, J. A. (1974). + "Efficient Algorithm for the Partitioning of Trees." + IBM Journal of Research and Development, 18(3), 217–224. + + """ + # First sanity check and tree preparation + if not nx.is_tree(G): + raise nx.NotATree("lukes_partitioning works only on trees") + else: + if nx.is_directed(G): + root = [n for n, d in G.in_degree() if d == 0] + assert len(root) == 1 + root = root[0] + t_G = deepcopy(G) + else: + root = choice(list(G.nodes)) + # this has the desirable side effect of not inheriting attributes + t_G = nx.dfs_tree(G, root) + + # Since we do not want to screw up the original graph, + # if we have a blank attribute, we make a deepcopy + if edge_weight is None or node_weight is None: + safe_G = deepcopy(G) + if edge_weight is None: + nx.set_edge_attributes(safe_G, D_EDGE_VALUE, D_EDGE_W) + edge_weight = D_EDGE_W + if node_weight is None: + nx.set_node_attributes(safe_G, D_NODE_VALUE, D_NODE_W) + node_weight = D_NODE_W + else: + safe_G = G + + # Second sanity check + # The values of node_weight MUST BE int. + # I cannot see any room for duck typing without incurring serious + # danger of subtle bugs. + all_n_attr = nx.get_node_attributes(safe_G, node_weight).values() + for x in all_n_attr: + if not isinstance(x, int): + raise TypeError( + "lukes_partitioning needs integer " + f"values for node_weight ({node_weight})" + ) + + # SUBROUTINES ----------------------- + # these functions are defined here for two reasons: + # - brevity: we can leverage global "safe_G" + # - caching: signatures are hashable + + @not_implemented_for("undirected") + # this is intended to be called only on t_G + def _leaves(gr): + for x in gr.nodes: + if not nx.descendants(gr, x): + yield x + + @not_implemented_for("undirected") + def _a_parent_of_leaves_only(gr): + tleaves = set(_leaves(gr)) + for n in set(gr.nodes) - tleaves: + if all(x in tleaves for x in nx.descendants(gr, n)): + return n + + @lru_cache(CLUSTER_EVAL_CACHE_SIZE) + def _value_of_cluster(cluster): + valid_edges = [e for e in safe_G.edges if e[0] in cluster and e[1] in cluster] + return sum(safe_G.edges[e][edge_weight] for e in valid_edges) + + def _value_of_partition(partition): + return sum(_value_of_cluster(frozenset(c)) for c in partition) + + @lru_cache(CLUSTER_EVAL_CACHE_SIZE) + def _weight_of_cluster(cluster): + return sum(safe_G.nodes[n][node_weight] for n in cluster) + + def _pivot(partition, node): + ccx = [c for c in partition if node in c] + assert len(ccx) == 1 + return ccx[0] + + def _concatenate_or_merge(partition_1, partition_2, x, i, ref_weight): + ccx = _pivot(partition_1, x) + cci = _pivot(partition_2, i) + merged_xi = ccx.union(cci) + + # We first check if we can do the merge. + # If so, we do the actual calculations, otherwise we concatenate + if _weight_of_cluster(frozenset(merged_xi)) <= ref_weight: + cp1 = list(filter(lambda x: x != ccx, partition_1)) + cp2 = list(filter(lambda x: x != cci, partition_2)) + + option_2 = [merged_xi] + cp1 + cp2 + return option_2, _value_of_partition(option_2) + else: + option_1 = partition_1 + partition_2 + return option_1, _value_of_partition(option_1) + + # INITIALIZATION ----------------------- + leaves = set(_leaves(t_G)) + for lv in leaves: + t_G.nodes[lv][PKEY] = {} + slot = safe_G.nodes[lv][node_weight] + t_G.nodes[lv][PKEY][slot] = [{lv}] + t_G.nodes[lv][PKEY][0] = [{lv}] + + for inner in [x for x in t_G.nodes if x not in leaves]: + t_G.nodes[inner][PKEY] = {} + slot = safe_G.nodes[inner][node_weight] + t_G.nodes[inner][PKEY][slot] = [{inner}] + nx._clear_cache(t_G) + + # CORE ALGORITHM ----------------------- + while True: + x_node = _a_parent_of_leaves_only(t_G) + weight_of_x = safe_G.nodes[x_node][node_weight] + best_value = 0 + best_partition = None + bp_buffer = {} + x_descendants = nx.descendants(t_G, x_node) + for i_node in x_descendants: + for j in range(weight_of_x, max_size + 1): + for a, b in _split_n_from(j, weight_of_x): + if ( + a not in t_G.nodes[x_node][PKEY] + or b not in t_G.nodes[i_node][PKEY] + ): + # it's not possible to form this particular weight sum + continue + + part1 = t_G.nodes[x_node][PKEY][a] + part2 = t_G.nodes[i_node][PKEY][b] + part, value = _concatenate_or_merge(part1, part2, x_node, i_node, j) + + if j not in bp_buffer or bp_buffer[j][1] < value: + # we annotate in the buffer the best partition for j + bp_buffer[j] = part, value + + # we also keep track of the overall best partition + if best_value <= value: + best_value = value + best_partition = part + + # as illustrated in Lukes, once we finished a child, we can + # discharge the partitions we found into the graph + # (the key phrase is make all x == x') + # so that they are used by the subsequent children + for w, (best_part_for_vl, vl) in bp_buffer.items(): + t_G.nodes[x_node][PKEY][w] = best_part_for_vl + bp_buffer.clear() + + # the absolute best partition for this node + # across all weights has to be stored at 0 + t_G.nodes[x_node][PKEY][0] = best_partition + t_G.remove_nodes_from(x_descendants) + + if x_node == root: + # the 0-labeled partition of root + # is the optimal one for the whole tree + return t_G.nodes[root][PKEY][0] diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/modularity_max.py b/lib/python3.10/site-packages/networkx/algorithms/community/modularity_max.py new file mode 100644 index 0000000000000000000000000000000000000000..f465e01c6b20ec0a34c7d8402ebdfb6e3e1b4e0e --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/community/modularity_max.py @@ -0,0 +1,451 @@ +"""Functions for detecting communities based on modularity.""" + +from collections import defaultdict + +import networkx as nx +from networkx.algorithms.community.quality import modularity +from networkx.utils import not_implemented_for +from networkx.utils.mapped_queue import MappedQueue + +__all__ = [ + "greedy_modularity_communities", + "naive_greedy_modularity_communities", +] + + +def _greedy_modularity_communities_generator(G, weight=None, resolution=1): + r"""Yield community partitions of G and the modularity change at each step. + + This function performs Clauset-Newman-Moore greedy modularity maximization [2]_ + At each step of the process it yields the change in modularity that will occur in + the next step followed by yielding the new community partition after that step. + + Greedy modularity maximization begins with each node in its own community + and repeatedly joins the pair of communities that lead to the largest + modularity until one community contains all nodes (the partition has one set). + + This function maximizes the generalized modularity, where `resolution` + is the resolution parameter, often expressed as $\gamma$. + See :func:`~networkx.algorithms.community.quality.modularity`. + + Parameters + ---------- + G : NetworkX graph + + weight : string or None, optional (default=None) + The name of an edge attribute that holds the numerical value used + as a weight. If None, then each edge has weight 1. + The degree is the sum of the edge weights adjacent to the node. + + resolution : float (default=1) + If resolution is less than 1, modularity favors larger communities. + Greater than 1 favors smaller communities. + + Yields + ------ + Alternating yield statements produce the following two objects: + + communities: dict_values + A dict_values of frozensets of nodes, one for each community. + This represents a partition of the nodes of the graph into communities. + The first yield is the partition with each node in its own community. + + dq: float + The change in modularity when merging the next two communities + that leads to the largest modularity. + + See Also + -------- + modularity + + References + ---------- + .. [1] Newman, M. E. J. "Networks: An Introduction", page 224 + Oxford University Press 2011. + .. [2] Clauset, A., Newman, M. E., & Moore, C. + "Finding community structure in very large networks." + Physical Review E 70(6), 2004. + .. [3] Reichardt and Bornholdt "Statistical Mechanics of Community + Detection" Phys. Rev. E74, 2006. + .. [4] Newman, M. E. J."Analysis of weighted networks" + Physical Review E 70(5 Pt 2):056131, 2004. + """ + directed = G.is_directed() + N = G.number_of_nodes() + + # Count edges (or the sum of edge-weights for weighted graphs) + m = G.size(weight) + q0 = 1 / m + + # Calculate degrees (notation from the papers) + # a : the fraction of (weighted) out-degree for each node + # b : the fraction of (weighted) in-degree for each node + if directed: + a = {node: deg_out * q0 for node, deg_out in G.out_degree(weight=weight)} + b = {node: deg_in * q0 for node, deg_in in G.in_degree(weight=weight)} + else: + a = b = {node: deg * q0 * 0.5 for node, deg in G.degree(weight=weight)} + + # this preliminary step collects the edge weights for each node pair + # It handles multigraph and digraph and works fine for graph. + dq_dict = defaultdict(lambda: defaultdict(float)) + for u, v, wt in G.edges(data=weight, default=1): + if u == v: + continue + dq_dict[u][v] += wt + dq_dict[v][u] += wt + + # now scale and subtract the expected edge-weights term + for u, nbrdict in dq_dict.items(): + for v, wt in nbrdict.items(): + dq_dict[u][v] = q0 * wt - resolution * (a[u] * b[v] + b[u] * a[v]) + + # Use -dq to get a max_heap instead of a min_heap + # dq_heap holds a heap for each node's neighbors + dq_heap = {u: MappedQueue({(u, v): -dq for v, dq in dq_dict[u].items()}) for u in G} + # H -> all_dq_heap holds a heap with the best items for each node + H = MappedQueue([dq_heap[n].heap[0] for n in G if len(dq_heap[n]) > 0]) + + # Initialize single-node communities + communities = {n: frozenset([n]) for n in G} + yield communities.values() + + # Merge the two communities that lead to the largest modularity + while len(H) > 1: + # Find best merge + # Remove from heap of row maxes + # Ties will be broken by choosing the pair with lowest min community id + try: + negdq, u, v = H.pop() + except IndexError: + break + dq = -negdq + yield dq + # Remove best merge from row u heap + dq_heap[u].pop() + # Push new row max onto H + if len(dq_heap[u]) > 0: + H.push(dq_heap[u].heap[0]) + # If this element was also at the root of row v, we need to remove the + # duplicate entry from H + if dq_heap[v].heap[0] == (v, u): + H.remove((v, u)) + # Remove best merge from row v heap + dq_heap[v].remove((v, u)) + # Push new row max onto H + if len(dq_heap[v]) > 0: + H.push(dq_heap[v].heap[0]) + else: + # Duplicate wasn't in H, just remove from row v heap + dq_heap[v].remove((v, u)) + + # Perform merge + communities[v] = frozenset(communities[u] | communities[v]) + del communities[u] + + # Get neighbor communities connected to the merged communities + u_nbrs = set(dq_dict[u]) + v_nbrs = set(dq_dict[v]) + all_nbrs = (u_nbrs | v_nbrs) - {u, v} + both_nbrs = u_nbrs & v_nbrs + # Update dq for merge of u into v + for w in all_nbrs: + # Calculate new dq value + if w in both_nbrs: + dq_vw = dq_dict[v][w] + dq_dict[u][w] + elif w in v_nbrs: + dq_vw = dq_dict[v][w] - resolution * (a[u] * b[w] + a[w] * b[u]) + else: # w in u_nbrs + dq_vw = dq_dict[u][w] - resolution * (a[v] * b[w] + a[w] * b[v]) + # Update rows v and w + for row, col in [(v, w), (w, v)]: + dq_heap_row = dq_heap[row] + # Update dict for v,w only (u is removed below) + dq_dict[row][col] = dq_vw + # Save old max of per-row heap + if len(dq_heap_row) > 0: + d_oldmax = dq_heap_row.heap[0] + else: + d_oldmax = None + # Add/update heaps + d = (row, col) + d_negdq = -dq_vw + # Save old value for finding heap index + if w in v_nbrs: + # Update existing element in per-row heap + dq_heap_row.update(d, d, priority=d_negdq) + else: + # We're creating a new nonzero element, add to heap + dq_heap_row.push(d, priority=d_negdq) + # Update heap of row maxes if necessary + if d_oldmax is None: + # No entries previously in this row, push new max + H.push(d, priority=d_negdq) + else: + # We've updated an entry in this row, has the max changed? + row_max = dq_heap_row.heap[0] + if d_oldmax != row_max or d_oldmax.priority != row_max.priority: + H.update(d_oldmax, row_max) + + # Remove row/col u from dq_dict matrix + for w in dq_dict[u]: + # Remove from dict + dq_old = dq_dict[w][u] + del dq_dict[w][u] + # Remove from heaps if we haven't already + if w != v: + # Remove both row and column + for row, col in [(w, u), (u, w)]: + dq_heap_row = dq_heap[row] + # Check if replaced dq is row max + d_old = (row, col) + if dq_heap_row.heap[0] == d_old: + # Update per-row heap and heap of row maxes + dq_heap_row.remove(d_old) + H.remove(d_old) + # Update row max + if len(dq_heap_row) > 0: + H.push(dq_heap_row.heap[0]) + else: + # Only update per-row heap + dq_heap_row.remove(d_old) + + del dq_dict[u] + # Mark row u as deleted, but keep placeholder + dq_heap[u] = MappedQueue() + # Merge u into v and update a + a[v] += a[u] + a[u] = 0 + if directed: + b[v] += b[u] + b[u] = 0 + + yield communities.values() + + +@nx._dispatchable(edge_attrs="weight") +def greedy_modularity_communities( + G, + weight=None, + resolution=1, + cutoff=1, + best_n=None, +): + r"""Find communities in G using greedy modularity maximization. + + This function uses Clauset-Newman-Moore greedy modularity maximization [2]_ + to find the community partition with the largest modularity. + + Greedy modularity maximization begins with each node in its own community + and repeatedly joins the pair of communities that lead to the largest + modularity until no further increase in modularity is possible (a maximum). + Two keyword arguments adjust the stopping condition. `cutoff` is a lower + limit on the number of communities so you can stop the process before + reaching a maximum (used to save computation time). `best_n` is an upper + limit on the number of communities so you can make the process continue + until at most n communities remain even if the maximum modularity occurs + for more. To obtain exactly n communities, set both `cutoff` and `best_n` to n. + + This function maximizes the generalized modularity, where `resolution` + is the resolution parameter, often expressed as $\gamma$. + See :func:`~networkx.algorithms.community.quality.modularity`. + + Parameters + ---------- + G : NetworkX graph + + weight : string or None, optional (default=None) + The name of an edge attribute that holds the numerical value used + as a weight. If None, then each edge has weight 1. + The degree is the sum of the edge weights adjacent to the node. + + resolution : float, optional (default=1) + If resolution is less than 1, modularity favors larger communities. + Greater than 1 favors smaller communities. + + cutoff : int, optional (default=1) + A minimum number of communities below which the merging process stops. + The process stops at this number of communities even if modularity + is not maximized. The goal is to let the user stop the process early. + The process stops before the cutoff if it finds a maximum of modularity. + + best_n : int or None, optional (default=None) + A maximum number of communities above which the merging process will + not stop. This forces community merging to continue after modularity + starts to decrease until `best_n` communities remain. + If ``None``, don't force it to continue beyond a maximum. + + Raises + ------ + ValueError : If the `cutoff` or `best_n` value is not in the range + ``[1, G.number_of_nodes()]``, or if `best_n` < `cutoff`. + + Returns + ------- + communities: list + A list of frozensets of nodes, one for each community. + Sorted by length with largest communities first. + + Examples + -------- + >>> G = nx.karate_club_graph() + >>> c = nx.community.greedy_modularity_communities(G) + >>> sorted(c[0]) + [8, 14, 15, 18, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33] + + See Also + -------- + modularity + + References + ---------- + .. [1] Newman, M. E. J. "Networks: An Introduction", page 224 + Oxford University Press 2011. + .. [2] Clauset, A., Newman, M. E., & Moore, C. + "Finding community structure in very large networks." + Physical Review E 70(6), 2004. + .. [3] Reichardt and Bornholdt "Statistical Mechanics of Community + Detection" Phys. Rev. E74, 2006. + .. [4] Newman, M. E. J."Analysis of weighted networks" + Physical Review E 70(5 Pt 2):056131, 2004. + """ + if not G.size(): + return [{n} for n in G] + + if (cutoff < 1) or (cutoff > G.number_of_nodes()): + raise ValueError(f"cutoff must be between 1 and {len(G)}. Got {cutoff}.") + if best_n is not None: + if (best_n < 1) or (best_n > G.number_of_nodes()): + raise ValueError(f"best_n must be between 1 and {len(G)}. Got {best_n}.") + if best_n < cutoff: + raise ValueError(f"Must have best_n >= cutoff. Got {best_n} < {cutoff}") + if best_n == 1: + return [set(G)] + else: + best_n = G.number_of_nodes() + + # retrieve generator object to construct output + community_gen = _greedy_modularity_communities_generator( + G, weight=weight, resolution=resolution + ) + + # construct the first best community + communities = next(community_gen) + + # continue merging communities until one of the breaking criteria is satisfied + while len(communities) > cutoff: + try: + dq = next(community_gen) + # StopIteration occurs when communities are the connected components + except StopIteration: + communities = sorted(communities, key=len, reverse=True) + # if best_n requires more merging, merge big sets for highest modularity + while len(communities) > best_n: + comm1, comm2, *rest = communities + communities = [comm1 ^ comm2] + communities.extend(rest) + return communities + + # keep going unless max_mod is reached or best_n says to merge more + if dq < 0 and len(communities) <= best_n: + break + communities = next(community_gen) + + return sorted(communities, key=len, reverse=True) + + +@not_implemented_for("directed") +@not_implemented_for("multigraph") +@nx._dispatchable(edge_attrs="weight") +def naive_greedy_modularity_communities(G, resolution=1, weight=None): + r"""Find communities in G using greedy modularity maximization. + + This implementation is O(n^4), much slower than alternatives, but it is + provided as an easy-to-understand reference implementation. + + Greedy modularity maximization begins with each node in its own community + and joins the pair of communities that most increases modularity until no + such pair exists. + + This function maximizes the generalized modularity, where `resolution` + is the resolution parameter, often expressed as $\gamma$. + See :func:`~networkx.algorithms.community.quality.modularity`. + + Parameters + ---------- + G : NetworkX graph + Graph must be simple and undirected. + + resolution : float (default=1) + If resolution is less than 1, modularity favors larger communities. + Greater than 1 favors smaller communities. + + weight : string or None, optional (default=None) + The name of an edge attribute that holds the numerical value used + as a weight. If None, then each edge has weight 1. + The degree is the sum of the edge weights adjacent to the node. + + Returns + ------- + list + A list of sets of nodes, one for each community. + Sorted by length with largest communities first. + + Examples + -------- + >>> G = nx.karate_club_graph() + >>> c = nx.community.naive_greedy_modularity_communities(G) + >>> sorted(c[0]) + [8, 14, 15, 18, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33] + + See Also + -------- + greedy_modularity_communities + modularity + """ + # First create one community for each node + communities = [frozenset([u]) for u in G.nodes()] + # Track merges + merges = [] + # Greedily merge communities until no improvement is possible + old_modularity = None + new_modularity = modularity(G, communities, resolution=resolution, weight=weight) + while old_modularity is None or new_modularity > old_modularity: + # Save modularity for comparison + old_modularity = new_modularity + # Find best pair to merge + trial_communities = list(communities) + to_merge = None + for i, u in enumerate(communities): + for j, v in enumerate(communities): + # Skip i==j and empty communities + if j <= i or len(u) == 0 or len(v) == 0: + continue + # Merge communities u and v + trial_communities[j] = u | v + trial_communities[i] = frozenset([]) + trial_modularity = modularity( + G, trial_communities, resolution=resolution, weight=weight + ) + if trial_modularity >= new_modularity: + # Check if strictly better or tie + if trial_modularity > new_modularity: + # Found new best, save modularity and group indexes + new_modularity = trial_modularity + to_merge = (i, j, new_modularity - old_modularity) + elif to_merge and min(i, j) < min(to_merge[0], to_merge[1]): + # Break ties by choosing pair with lowest min id + new_modularity = trial_modularity + to_merge = (i, j, new_modularity - old_modularity) + # Un-merge + trial_communities[i] = u + trial_communities[j] = v + if to_merge is not None: + # If the best merge improves modularity, use it + merges.append(to_merge) + i, j, dq = to_merge + u, v = communities[i], communities[j] + communities[j] = u | v + communities[i] = frozenset([]) + # Remove empty communities and sort + return sorted((c for c in communities if len(c) > 0), key=len, reverse=True) diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/quality.py b/lib/python3.10/site-packages/networkx/algorithms/community/quality.py new file mode 100644 index 0000000000000000000000000000000000000000..f09a6d454af24b38841de473d9e4584bb3a460e9 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/community/quality.py @@ -0,0 +1,346 @@ +"""Functions for measuring the quality of a partition (into +communities). + +""" + +from itertools import combinations + +import networkx as nx +from networkx import NetworkXError +from networkx.algorithms.community.community_utils import is_partition +from networkx.utils.decorators import argmap + +__all__ = ["modularity", "partition_quality"] + + +class NotAPartition(NetworkXError): + """Raised if a given collection is not a partition.""" + + def __init__(self, G, collection): + msg = f"{collection} is not a valid partition of the graph {G}" + super().__init__(msg) + + +def _require_partition(G, partition): + """Decorator to check that a valid partition is input to a function + + Raises :exc:`networkx.NetworkXError` if the partition is not valid. + + This decorator should be used on functions whose first two arguments + are a graph and a partition of the nodes of that graph (in that + order):: + + >>> @require_partition + ... def foo(G, partition): + ... print("partition is valid!") + ... + >>> G = nx.complete_graph(5) + >>> partition = [{0, 1}, {2, 3}, {4}] + >>> foo(G, partition) + partition is valid! + >>> partition = [{0}, {2, 3}, {4}] + >>> foo(G, partition) + Traceback (most recent call last): + ... + networkx.exception.NetworkXError: `partition` is not a valid partition of the nodes of G + >>> partition = [{0, 1}, {1, 2, 3}, {4}] + >>> foo(G, partition) + Traceback (most recent call last): + ... + networkx.exception.NetworkXError: `partition` is not a valid partition of the nodes of G + + """ + if is_partition(G, partition): + return G, partition + raise nx.NetworkXError("`partition` is not a valid partition of the nodes of G") + + +require_partition = argmap(_require_partition, (0, 1)) + + +@nx._dispatchable +def intra_community_edges(G, partition): + """Returns the number of intra-community edges for a partition of `G`. + + Parameters + ---------- + G : NetworkX graph. + + partition : iterable of sets of nodes + This must be a partition of the nodes of `G`. + + The "intra-community edges" are those edges joining a pair of nodes + in the same block of the partition. + + """ + return sum(G.subgraph(block).size() for block in partition) + + +@nx._dispatchable +def inter_community_edges(G, partition): + """Returns the number of inter-community edges for a partition of `G`. + according to the given + partition of the nodes of `G`. + + Parameters + ---------- + G : NetworkX graph. + + partition : iterable of sets of nodes + This must be a partition of the nodes of `G`. + + The *inter-community edges* are those edges joining a pair of nodes + in different blocks of the partition. + + Implementation note: this function creates an intermediate graph + that may require the same amount of memory as that of `G`. + + """ + # Alternate implementation that does not require constructing a new + # graph object (but does require constructing an affiliation + # dictionary): + # + # aff = dict(chain.from_iterable(((v, block) for v in block) + # for block in partition)) + # return sum(1 for u, v in G.edges() if aff[u] != aff[v]) + # + MG = nx.MultiDiGraph if G.is_directed() else nx.MultiGraph + return nx.quotient_graph(G, partition, create_using=MG).size() + + +@nx._dispatchable +def inter_community_non_edges(G, partition): + """Returns the number of inter-community non-edges according to the + given partition of the nodes of `G`. + + Parameters + ---------- + G : NetworkX graph. + + partition : iterable of sets of nodes + This must be a partition of the nodes of `G`. + + A *non-edge* is a pair of nodes (undirected if `G` is undirected) + that are not adjacent in `G`. The *inter-community non-edges* are + those non-edges on a pair of nodes in different blocks of the + partition. + + Implementation note: this function creates two intermediate graphs, + which may require up to twice the amount of memory as required to + store `G`. + + """ + # Alternate implementation that does not require constructing two + # new graph objects (but does require constructing an affiliation + # dictionary): + # + # aff = dict(chain.from_iterable(((v, block) for v in block) + # for block in partition)) + # return sum(1 for u, v in nx.non_edges(G) if aff[u] != aff[v]) + # + return inter_community_edges(nx.complement(G), partition) + + +@nx._dispatchable(edge_attrs="weight") +def modularity(G, communities, weight="weight", resolution=1): + r"""Returns the modularity of the given partition of the graph. + + Modularity is defined in [1]_ as + + .. math:: + Q = \frac{1}{2m} \sum_{ij} \left( A_{ij} - \gamma\frac{k_ik_j}{2m}\right) + \delta(c_i,c_j) + + where $m$ is the number of edges (or sum of all edge weights as in [5]_), + $A$ is the adjacency matrix of `G`, $k_i$ is the (weighted) degree of $i$, + $\gamma$ is the resolution parameter, and $\delta(c_i, c_j)$ is 1 if $i$ and + $j$ are in the same community else 0. + + According to [2]_ (and verified by some algebra) this can be reduced to + + .. math:: + Q = \sum_{c=1}^{n} + \left[ \frac{L_c}{m} - \gamma\left( \frac{k_c}{2m} \right) ^2 \right] + + where the sum iterates over all communities $c$, $m$ is the number of edges, + $L_c$ is the number of intra-community links for community $c$, + $k_c$ is the sum of degrees of the nodes in community $c$, + and $\gamma$ is the resolution parameter. + + The resolution parameter sets an arbitrary tradeoff between intra-group + edges and inter-group edges. More complex grouping patterns can be + discovered by analyzing the same network with multiple values of gamma + and then combining the results [3]_. That said, it is very common to + simply use gamma=1. More on the choice of gamma is in [4]_. + + The second formula is the one actually used in calculation of the modularity. + For directed graphs the second formula replaces $k_c$ with $k^{in}_c k^{out}_c$. + + Parameters + ---------- + G : NetworkX Graph + + communities : list or iterable of set of nodes + These node sets must represent a partition of G's nodes. + + weight : string or None, optional (default="weight") + The edge attribute that holds the numerical value used + as a weight. If None or an edge does not have that attribute, + then that edge has weight 1. + + resolution : float (default=1) + If resolution is less than 1, modularity favors larger communities. + Greater than 1 favors smaller communities. + + Returns + ------- + Q : float + The modularity of the partition. + + Raises + ------ + NotAPartition + If `communities` is not a partition of the nodes of `G`. + + Examples + -------- + >>> G = nx.barbell_graph(3, 0) + >>> nx.community.modularity(G, [{0, 1, 2}, {3, 4, 5}]) + 0.35714285714285715 + >>> nx.community.modularity(G, nx.community.label_propagation_communities(G)) + 0.35714285714285715 + + References + ---------- + .. [1] M. E. J. Newman "Networks: An Introduction", page 224. + Oxford University Press, 2011. + .. [2] Clauset, Aaron, Mark EJ Newman, and Cristopher Moore. + "Finding community structure in very large networks." + Phys. Rev. E 70.6 (2004). + .. [3] Reichardt and Bornholdt "Statistical Mechanics of Community Detection" + Phys. Rev. E 74, 016110, 2006. https://doi.org/10.1103/PhysRevE.74.016110 + .. [4] M. E. J. Newman, "Equivalence between modularity optimization and + maximum likelihood methods for community detection" + Phys. Rev. E 94, 052315, 2016. https://doi.org/10.1103/PhysRevE.94.052315 + .. [5] Blondel, V.D. et al. "Fast unfolding of communities in large + networks" J. Stat. Mech 10008, 1-12 (2008). + https://doi.org/10.1088/1742-5468/2008/10/P10008 + """ + if not isinstance(communities, list): + communities = list(communities) + if not is_partition(G, communities): + raise NotAPartition(G, communities) + + directed = G.is_directed() + if directed: + out_degree = dict(G.out_degree(weight=weight)) + in_degree = dict(G.in_degree(weight=weight)) + m = sum(out_degree.values()) + norm = 1 / m**2 + else: + out_degree = in_degree = dict(G.degree(weight=weight)) + deg_sum = sum(out_degree.values()) + m = deg_sum / 2 + norm = 1 / deg_sum**2 + + def community_contribution(community): + comm = set(community) + L_c = sum(wt for u, v, wt in G.edges(comm, data=weight, default=1) if v in comm) + + out_degree_sum = sum(out_degree[u] for u in comm) + in_degree_sum = sum(in_degree[u] for u in comm) if directed else out_degree_sum + + return L_c / m - resolution * out_degree_sum * in_degree_sum * norm + + return sum(map(community_contribution, communities)) + + +@require_partition +@nx._dispatchable +def partition_quality(G, partition): + """Returns the coverage and performance of a partition of G. + + The *coverage* of a partition is the ratio of the number of + intra-community edges to the total number of edges in the graph. + + The *performance* of a partition is the number of + intra-community edges plus inter-community non-edges divided by the total + number of potential edges. + + This algorithm has complexity $O(C^2 + L)$ where C is the number of communities and L is the number of links. + + Parameters + ---------- + G : NetworkX graph + + partition : sequence + Partition of the nodes of `G`, represented as a sequence of + sets of nodes (blocks). Each block of the partition represents a + community. + + Returns + ------- + (float, float) + The (coverage, performance) tuple of the partition, as defined above. + + Raises + ------ + NetworkXError + If `partition` is not a valid partition of the nodes of `G`. + + Notes + ----- + If `G` is a multigraph; + - for coverage, the multiplicity of edges is counted + - for performance, the result is -1 (total number of possible edges is not defined) + + References + ---------- + .. [1] Santo Fortunato. + "Community Detection in Graphs". + *Physical Reports*, Volume 486, Issue 3--5 pp. 75--174 + + """ + + node_community = {} + for i, community in enumerate(partition): + for node in community: + node_community[node] = i + + # `performance` is not defined for multigraphs + if not G.is_multigraph(): + # Iterate over the communities, quadratic, to calculate `possible_inter_community_edges` + possible_inter_community_edges = sum( + len(p1) * len(p2) for p1, p2 in combinations(partition, 2) + ) + + if G.is_directed(): + possible_inter_community_edges *= 2 + else: + possible_inter_community_edges = 0 + + # Compute the number of edges in the complete graph -- `n` nodes, + # directed or undirected, depending on `G` + n = len(G) + total_pairs = n * (n - 1) + if not G.is_directed(): + total_pairs //= 2 + + intra_community_edges = 0 + inter_community_non_edges = possible_inter_community_edges + + # Iterate over the links to count `intra_community_edges` and `inter_community_non_edges` + for e in G.edges(): + if node_community[e[0]] == node_community[e[1]]: + intra_community_edges += 1 + else: + inter_community_non_edges -= 1 + + coverage = intra_community_edges / len(G.edges) + + if G.is_multigraph(): + performance = -1.0 + else: + performance = (intra_community_edges + inter_community_non_edges) / total_pairs + + return coverage, performance diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/tests/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/community/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..209f30fdb25b31844de753f5e7c3b4990457e2d9 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_asyn_fluid.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_asyn_fluid.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa386661f8def5d8468787ded073caff7bd32fb7 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_asyn_fluid.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_centrality.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_centrality.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20e723ac0567425be101e336be46ef419bb4d03a Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_centrality.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_divisive.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_divisive.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39dbd623bb1b30eeaee61af0d0a210389f17a91a Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_divisive.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_kclique.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_kclique.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93ca61e061570fcb9cc2aac082d3e3c959636dd4 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_kclique.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_kernighan_lin.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_kernighan_lin.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36352d1bc83a68f4052b4bf2741c82a84733b87b Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_kernighan_lin.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_label_propagation.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_label_propagation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10a939d9a53e09d59d925565c091aa1f8a10ef24 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_label_propagation.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_louvain.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_louvain.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8b298f4cfcbf1f069ca9236474af2e984e08652 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_louvain.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_lukes.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_lukes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43828dd9a9cba772e393fa6fd4a65d1466970d1e Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_lukes.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_modularity_max.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_modularity_max.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4080626e9e67f8fda477a10e1d5fa4ff115f7c3 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_modularity_max.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_quality.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_quality.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fff58dc0d40114b5df3643a7a769db98918eda1c Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_quality.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_utils.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3f7ce6daeb25d1b0d26da2344f4213fb2e8c9e9 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/community/tests/__pycache__/test_utils.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_asyn_fluid.py b/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_asyn_fluid.py new file mode 100644 index 0000000000000000000000000000000000000000..6c023be773d80b5ef71254ba55547f622ff4e43d --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_asyn_fluid.py @@ -0,0 +1,136 @@ +import pytest + +import networkx as nx +from networkx import Graph, NetworkXError +from networkx.algorithms.community import asyn_fluidc + + +@pytest.mark.parametrize("graph_constructor", (nx.DiGraph, nx.MultiGraph)) +def test_raises_on_directed_and_multigraphs(graph_constructor): + G = graph_constructor([(0, 1), (1, 2)]) + with pytest.raises(nx.NetworkXNotImplemented): + nx.community.asyn_fluidc(G, 1) + + +def test_exceptions(): + test = Graph() + test.add_node("a") + pytest.raises(NetworkXError, asyn_fluidc, test, "hi") + pytest.raises(NetworkXError, asyn_fluidc, test, -1) + pytest.raises(NetworkXError, asyn_fluidc, test, 3) + test.add_node("b") + pytest.raises(NetworkXError, asyn_fluidc, test, 1) + + +def test_single_node(): + test = Graph() + + test.add_node("a") + + # ground truth + ground_truth = {frozenset(["a"])} + + communities = asyn_fluidc(test, 1) + result = {frozenset(c) for c in communities} + assert result == ground_truth + + +def test_two_nodes(): + test = Graph() + + test.add_edge("a", "b") + + # ground truth + ground_truth = {frozenset(["a"]), frozenset(["b"])} + + communities = asyn_fluidc(test, 2) + result = {frozenset(c) for c in communities} + assert result == ground_truth + + +def test_two_clique_communities(): + test = Graph() + + # c1 + test.add_edge("a", "b") + test.add_edge("a", "c") + test.add_edge("b", "c") + + # connection + test.add_edge("c", "d") + + # c2 + test.add_edge("d", "e") + test.add_edge("d", "f") + test.add_edge("f", "e") + + # ground truth + ground_truth = {frozenset(["a", "c", "b"]), frozenset(["e", "d", "f"])} + + communities = asyn_fluidc(test, 2, seed=7) + result = {frozenset(c) for c in communities} + assert result == ground_truth + + +def test_five_clique_ring(): + test = Graph() + + # c1 + test.add_edge("1a", "1b") + test.add_edge("1a", "1c") + test.add_edge("1a", "1d") + test.add_edge("1b", "1c") + test.add_edge("1b", "1d") + test.add_edge("1c", "1d") + + # c2 + test.add_edge("2a", "2b") + test.add_edge("2a", "2c") + test.add_edge("2a", "2d") + test.add_edge("2b", "2c") + test.add_edge("2b", "2d") + test.add_edge("2c", "2d") + + # c3 + test.add_edge("3a", "3b") + test.add_edge("3a", "3c") + test.add_edge("3a", "3d") + test.add_edge("3b", "3c") + test.add_edge("3b", "3d") + test.add_edge("3c", "3d") + + # c4 + test.add_edge("4a", "4b") + test.add_edge("4a", "4c") + test.add_edge("4a", "4d") + test.add_edge("4b", "4c") + test.add_edge("4b", "4d") + test.add_edge("4c", "4d") + + # c5 + test.add_edge("5a", "5b") + test.add_edge("5a", "5c") + test.add_edge("5a", "5d") + test.add_edge("5b", "5c") + test.add_edge("5b", "5d") + test.add_edge("5c", "5d") + + # connections + test.add_edge("1a", "2c") + test.add_edge("2a", "3c") + test.add_edge("3a", "4c") + test.add_edge("4a", "5c") + test.add_edge("5a", "1c") + + # ground truth + ground_truth = { + frozenset(["1a", "1b", "1c", "1d"]), + frozenset(["2a", "2b", "2c", "2d"]), + frozenset(["3a", "3b", "3c", "3d"]), + frozenset(["4a", "4b", "4c", "4d"]), + frozenset(["5a", "5b", "5c", "5d"]), + } + + communities = asyn_fluidc(test, 5, seed=9) + result = {frozenset(c) for c in communities} + assert result == ground_truth diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_centrality.py b/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_centrality.py new file mode 100644 index 0000000000000000000000000000000000000000..1a8982f0d7fa16f6c7114a0bafe85bf205988c93 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_centrality.py @@ -0,0 +1,85 @@ +"""Unit tests for the :mod:`networkx.algorithms.community.centrality` +module. + +""" + +from operator import itemgetter + +import networkx as nx + + +def set_of_sets(iterable): + return set(map(frozenset, iterable)) + + +def validate_communities(result, expected): + assert set_of_sets(result) == set_of_sets(expected) + + +def validate_possible_communities(result, *expected): + assert any(set_of_sets(result) == set_of_sets(p) for p in expected) + + +class TestGirvanNewman: + """Unit tests for the + :func:`networkx.algorithms.community.centrality.girvan_newman` + function. + + """ + + def test_no_edges(self): + G = nx.empty_graph(3) + communities = list(nx.community.girvan_newman(G)) + assert len(communities) == 1 + validate_communities(communities[0], [{0}, {1}, {2}]) + + def test_undirected(self): + # Start with the graph .-.-.-. + G = nx.path_graph(4) + communities = list(nx.community.girvan_newman(G)) + assert len(communities) == 3 + # After one removal, we get the graph .-. .-. + validate_communities(communities[0], [{0, 1}, {2, 3}]) + # After the next, we get the graph .-. . ., but there are two + # symmetric possible versions. + validate_possible_communities( + communities[1], [{0}, {1}, {2, 3}], [{0, 1}, {2}, {3}] + ) + # After the last removal, we always get the empty graph. + validate_communities(communities[2], [{0}, {1}, {2}, {3}]) + + def test_directed(self): + G = nx.DiGraph(nx.path_graph(4)) + communities = list(nx.community.girvan_newman(G)) + assert len(communities) == 3 + validate_communities(communities[0], [{0, 1}, {2, 3}]) + validate_possible_communities( + communities[1], [{0}, {1}, {2, 3}], [{0, 1}, {2}, {3}] + ) + validate_communities(communities[2], [{0}, {1}, {2}, {3}]) + + def test_selfloops(self): + G = nx.path_graph(4) + G.add_edge(0, 0) + G.add_edge(2, 2) + communities = list(nx.community.girvan_newman(G)) + assert len(communities) == 3 + validate_communities(communities[0], [{0, 1}, {2, 3}]) + validate_possible_communities( + communities[1], [{0}, {1}, {2, 3}], [{0, 1}, {2}, {3}] + ) + validate_communities(communities[2], [{0}, {1}, {2}, {3}]) + + def test_most_valuable_edge(self): + G = nx.Graph() + G.add_weighted_edges_from([(0, 1, 3), (1, 2, 2), (2, 3, 1)]) + # Let the most valuable edge be the one with the highest weight. + + def heaviest(G): + return max(G.edges(data="weight"), key=itemgetter(2))[:2] + + communities = list(nx.community.girvan_newman(G, heaviest)) + assert len(communities) == 3 + validate_communities(communities[0], [{0}, {1, 2, 3}]) + validate_communities(communities[1], [{0}, {1}, {2, 3}]) + validate_communities(communities[2], [{0}, {1}, {2}, {3}]) diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_divisive.py b/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_divisive.py new file mode 100644 index 0000000000000000000000000000000000000000..6331503f97eaabee965a7f7f302b30e88601687e --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_divisive.py @@ -0,0 +1,106 @@ +import pytest + +import networkx as nx + + +def test_edge_betweenness_partition(): + G = nx.barbell_graph(3, 0) + C = nx.community.edge_betweenness_partition(G, 2) + answer = [{0, 1, 2}, {3, 4, 5}] + assert len(C) == len(answer) + for s in answer: + assert s in C + + G = nx.barbell_graph(3, 1) + C = nx.community.edge_betweenness_partition(G, 3) + answer = [{0, 1, 2}, {4, 5, 6}, {3}] + assert len(C) == len(answer) + for s in answer: + assert s in C + + C = nx.community.edge_betweenness_partition(G, 7) + answer = [{n} for n in G] + assert len(C) == len(answer) + for s in answer: + assert s in C + + C = nx.community.edge_betweenness_partition(G, 1) + assert C == [set(G)] + + C = nx.community.edge_betweenness_partition(G, 1, weight="weight") + assert C == [set(G)] + + with pytest.raises(nx.NetworkXError): + nx.community.edge_betweenness_partition(G, 0) + + with pytest.raises(nx.NetworkXError): + nx.community.edge_betweenness_partition(G, -1) + + with pytest.raises(nx.NetworkXError): + nx.community.edge_betweenness_partition(G, 10) + + +def test_edge_current_flow_betweenness_partition(): + pytest.importorskip("scipy") + + G = nx.barbell_graph(3, 0) + C = nx.community.edge_current_flow_betweenness_partition(G, 2) + answer = [{0, 1, 2}, {3, 4, 5}] + assert len(C) == len(answer) + for s in answer: + assert s in C + + G = nx.barbell_graph(3, 1) + C = nx.community.edge_current_flow_betweenness_partition(G, 2) + answers = [[{0, 1, 2, 3}, {4, 5, 6}], [{0, 1, 2}, {3, 4, 5, 6}]] + assert len(C) == len(answers[0]) + assert any(all(s in answer for s in C) for answer in answers) + + C = nx.community.edge_current_flow_betweenness_partition(G, 3) + answer = [{0, 1, 2}, {4, 5, 6}, {3}] + assert len(C) == len(answer) + for s in answer: + assert s in C + + C = nx.community.edge_current_flow_betweenness_partition(G, 4) + answers = [[{1, 2}, {4, 5, 6}, {3}, {0}], [{0, 1, 2}, {5, 6}, {3}, {4}]] + assert len(C) == len(answers[0]) + assert any(all(s in answer for s in C) for answer in answers) + + C = nx.community.edge_current_flow_betweenness_partition(G, 5) + answer = [{1, 2}, {5, 6}, {3}, {0}, {4}] + assert len(C) == len(answer) + for s in answer: + assert s in C + + C = nx.community.edge_current_flow_betweenness_partition(G, 6) + answers = [[{2}, {5, 6}, {3}, {0}, {4}, {1}], [{1, 2}, {6}, {3}, {0}, {4}, {5}]] + assert len(C) == len(answers[0]) + assert any(all(s in answer for s in C) for answer in answers) + + C = nx.community.edge_current_flow_betweenness_partition(G, 7) + answer = [{n} for n in G] + assert len(C) == len(answer) + for s in answer: + assert s in C + + C = nx.community.edge_current_flow_betweenness_partition(G, 1) + assert C == [set(G)] + + C = nx.community.edge_current_flow_betweenness_partition(G, 1, weight="weight") + assert C == [set(G)] + + with pytest.raises(nx.NetworkXError): + nx.community.edge_current_flow_betweenness_partition(G, 0) + + with pytest.raises(nx.NetworkXError): + nx.community.edge_current_flow_betweenness_partition(G, -1) + + with pytest.raises(nx.NetworkXError): + nx.community.edge_current_flow_betweenness_partition(G, 10) + + N = 10 + G = nx.empty_graph(N) + for i in range(2, N - 1): + C = nx.community.edge_current_flow_betweenness_partition(G, i) + assert C == [{n} for n in G] diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_kclique.py b/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_kclique.py new file mode 100644 index 0000000000000000000000000000000000000000..aa0b7e823e2780f734180562fa3fb8ce5a671312 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_kclique.py @@ -0,0 +1,91 @@ +from itertools import combinations + +import pytest + +import networkx as nx + + +def test_overlapping_K5(): + G = nx.Graph() + G.add_edges_from(combinations(range(5), 2)) # Add a five clique + G.add_edges_from(combinations(range(2, 7), 2)) # Add another five clique + c = list(nx.community.k_clique_communities(G, 4)) + assert c == [frozenset(range(7))] + c = set(nx.community.k_clique_communities(G, 5)) + assert c == {frozenset(range(5)), frozenset(range(2, 7))} + + +def test_isolated_K5(): + G = nx.Graph() + G.add_edges_from(combinations(range(5), 2)) # Add a five clique + G.add_edges_from(combinations(range(5, 10), 2)) # Add another five clique + c = set(nx.community.k_clique_communities(G, 5)) + assert c == {frozenset(range(5)), frozenset(range(5, 10))} + + +class TestZacharyKarateClub: + def setup_method(self): + self.G = nx.karate_club_graph() + + def _check_communities(self, k, expected): + communities = set(nx.community.k_clique_communities(self.G, k)) + assert communities == expected + + def test_k2(self): + # clique percolation with k=2 is just connected components + expected = {frozenset(self.G)} + self._check_communities(2, expected) + + def test_k3(self): + comm1 = [ + 0, + 1, + 2, + 3, + 7, + 8, + 12, + 13, + 14, + 15, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + ] + comm2 = [0, 4, 5, 6, 10, 16] + comm3 = [24, 25, 31] + expected = {frozenset(comm1), frozenset(comm2), frozenset(comm3)} + self._check_communities(3, expected) + + def test_k4(self): + expected = { + frozenset([0, 1, 2, 3, 7, 13]), + frozenset([8, 32, 30, 33]), + frozenset([32, 33, 29, 23]), + } + self._check_communities(4, expected) + + def test_k5(self): + expected = {frozenset([0, 1, 2, 3, 7, 13])} + self._check_communities(5, expected) + + def test_k6(self): + expected = set() + self._check_communities(6, expected) + + +def test_bad_k(): + with pytest.raises(nx.NetworkXError): + list(nx.community.k_clique_communities(nx.Graph(), 1)) diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_kernighan_lin.py b/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_kernighan_lin.py new file mode 100644 index 0000000000000000000000000000000000000000..25d53d5f33b78ac8ce837741a56baf62b64f3b68 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_kernighan_lin.py @@ -0,0 +1,92 @@ +"""Unit tests for the :mod:`networkx.algorithms.community.kernighan_lin` +module. +""" + +from itertools import permutations + +import pytest + +import networkx as nx +from networkx.algorithms.community import kernighan_lin_bisection + + +def assert_partition_equal(x, y): + assert set(map(frozenset, x)) == set(map(frozenset, y)) + + +def test_partition(): + G = nx.barbell_graph(3, 0) + C = kernighan_lin_bisection(G) + assert_partition_equal(C, [{0, 1, 2}, {3, 4, 5}]) + + +def test_partition_argument(): + G = nx.barbell_graph(3, 0) + partition = [{0, 1, 2}, {3, 4, 5}] + C = kernighan_lin_bisection(G, partition) + assert_partition_equal(C, partition) + + +def test_partition_argument_non_integer_nodes(): + G = nx.Graph([("A", "B"), ("A", "C"), ("B", "C"), ("C", "D")]) + partition = ({"A", "B"}, {"C", "D"}) + C = kernighan_lin_bisection(G, partition) + assert_partition_equal(C, partition) + + +def test_seed_argument(): + G = nx.barbell_graph(3, 0) + C = kernighan_lin_bisection(G, seed=1) + assert_partition_equal(C, [{0, 1, 2}, {3, 4, 5}]) + + +def test_non_disjoint_partition(): + with pytest.raises(nx.NetworkXError): + G = nx.barbell_graph(3, 0) + partition = ({0, 1, 2}, {2, 3, 4, 5}) + kernighan_lin_bisection(G, partition) + + +def test_too_many_blocks(): + with pytest.raises(nx.NetworkXError): + G = nx.barbell_graph(3, 0) + partition = ({0, 1}, {2}, {3, 4, 5}) + kernighan_lin_bisection(G, partition) + + +def test_multigraph(): + G = nx.cycle_graph(4) + M = nx.MultiGraph(G.edges()) + M.add_edges_from(G.edges()) + M.remove_edge(1, 2) + for labels in permutations(range(4)): + mapping = dict(zip(M, labels)) + A, B = kernighan_lin_bisection(nx.relabel_nodes(M, mapping), seed=0) + assert_partition_equal( + [A, B], [{mapping[0], mapping[1]}, {mapping[2], mapping[3]}] + ) + + +def test_max_iter_argument(): + G = nx.Graph( + [ + ("A", "B", {"weight": 1}), + ("A", "C", {"weight": 2}), + ("A", "D", {"weight": 3}), + ("A", "E", {"weight": 2}), + ("A", "F", {"weight": 4}), + ("B", "C", {"weight": 1}), + ("B", "D", {"weight": 4}), + ("B", "E", {"weight": 2}), + ("B", "F", {"weight": 1}), + ("C", "D", {"weight": 3}), + ("C", "E", {"weight": 2}), + ("C", "F", {"weight": 1}), + ("D", "E", {"weight": 4}), + ("D", "F", {"weight": 3}), + ("E", "F", {"weight": 2}), + ] + ) + partition = ({"A", "B", "C"}, {"D", "E", "F"}) + C = kernighan_lin_bisection(G, partition, max_iter=1) + assert_partition_equal(C, ({"A", "F", "C"}, {"D", "E", "B"})) diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_label_propagation.py b/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_label_propagation.py new file mode 100644 index 0000000000000000000000000000000000000000..4be72dbf27281c58d973e1d0d84d101fa369d43d --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_label_propagation.py @@ -0,0 +1,241 @@ +from itertools import chain, combinations + +import pytest + +import networkx as nx + + +def test_directed_not_supported(): + with pytest.raises(nx.NetworkXNotImplemented): + # not supported for directed graphs + test = nx.DiGraph() + test.add_edge("a", "b") + test.add_edge("a", "c") + test.add_edge("b", "d") + result = nx.community.label_propagation_communities(test) + + +def test_iterator_vs_iterable(): + G = nx.empty_graph("a") + assert list(nx.community.label_propagation_communities(G)) == [{"a"}] + for community in nx.community.label_propagation_communities(G): + assert community == {"a"} + pytest.raises(TypeError, next, nx.community.label_propagation_communities(G)) + + +def test_one_node(): + test = nx.Graph() + test.add_node("a") + + # The expected communities are: + ground_truth = {frozenset(["a"])} + + communities = nx.community.label_propagation_communities(test) + result = {frozenset(c) for c in communities} + assert result == ground_truth + + +def test_unconnected_communities(): + test = nx.Graph() + # community 1 + test.add_edge("a", "c") + test.add_edge("a", "d") + test.add_edge("d", "c") + # community 2 + test.add_edge("b", "e") + test.add_edge("e", "f") + test.add_edge("f", "b") + + # The expected communities are: + ground_truth = {frozenset(["a", "c", "d"]), frozenset(["b", "e", "f"])} + + communities = nx.community.label_propagation_communities(test) + result = {frozenset(c) for c in communities} + assert result == ground_truth + + +def test_connected_communities(): + test = nx.Graph() + # community 1 + test.add_edge("a", "b") + test.add_edge("c", "a") + test.add_edge("c", "b") + test.add_edge("d", "a") + test.add_edge("d", "b") + test.add_edge("d", "c") + test.add_edge("e", "a") + test.add_edge("e", "b") + test.add_edge("e", "c") + test.add_edge("e", "d") + # community 2 + test.add_edge("1", "2") + test.add_edge("3", "1") + test.add_edge("3", "2") + test.add_edge("4", "1") + test.add_edge("4", "2") + test.add_edge("4", "3") + test.add_edge("5", "1") + test.add_edge("5", "2") + test.add_edge("5", "3") + test.add_edge("5", "4") + # edge between community 1 and 2 + test.add_edge("a", "1") + # community 3 + test.add_edge("x", "y") + # community 4 with only a single node + test.add_node("z") + + # The expected communities are: + ground_truth1 = { + frozenset(["a", "b", "c", "d", "e"]), + frozenset(["1", "2", "3", "4", "5"]), + frozenset(["x", "y"]), + frozenset(["z"]), + } + ground_truth2 = { + frozenset(["a", "b", "c", "d", "e", "1", "2", "3", "4", "5"]), + frozenset(["x", "y"]), + frozenset(["z"]), + } + ground_truth = (ground_truth1, ground_truth2) + + communities = nx.community.label_propagation_communities(test) + result = {frozenset(c) for c in communities} + assert result in ground_truth + + +def test_termination(): + # ensure termination of asyn_lpa_communities in two cases + # that led to an endless loop in a previous version + test1 = nx.karate_club_graph() + test2 = nx.caveman_graph(2, 10) + test2.add_edges_from([(0, 20), (20, 10)]) + nx.community.asyn_lpa_communities(test1) + nx.community.asyn_lpa_communities(test2) + + +class TestAsynLpaCommunities: + def _check_communities(self, G, expected): + """Checks that the communities computed from the given graph ``G`` + using the :func:`~networkx.asyn_lpa_communities` function match + the set of nodes given in ``expected``. + + ``expected`` must be a :class:`set` of :class:`frozenset` + instances, each element of which is a node in the graph. + + """ + communities = nx.community.asyn_lpa_communities(G) + result = {frozenset(c) for c in communities} + assert result == expected + + def test_null_graph(self): + G = nx.null_graph() + ground_truth = set() + self._check_communities(G, ground_truth) + + def test_single_node(self): + G = nx.empty_graph(1) + ground_truth = {frozenset([0])} + self._check_communities(G, ground_truth) + + def test_simple_communities(self): + # This graph is the disjoint union of two triangles. + G = nx.Graph(["ab", "ac", "bc", "de", "df", "fe"]) + ground_truth = {frozenset("abc"), frozenset("def")} + self._check_communities(G, ground_truth) + + def test_seed_argument(self): + G = nx.Graph(["ab", "ac", "bc", "de", "df", "fe"]) + ground_truth = {frozenset("abc"), frozenset("def")} + communities = nx.community.asyn_lpa_communities(G, seed=1) + result = {frozenset(c) for c in communities} + assert result == ground_truth + + def test_several_communities(self): + # This graph is the disjoint union of five triangles. + ground_truth = {frozenset(range(3 * i, 3 * (i + 1))) for i in range(5)} + edges = chain.from_iterable(combinations(c, 2) for c in ground_truth) + G = nx.Graph(edges) + self._check_communities(G, ground_truth) + + +class TestFastLabelPropagationCommunities: + N = 100 # number of nodes + K = 15 # average node degree + + def _check_communities(self, G, truth, weight=None, seed=42): + C = nx.community.fast_label_propagation_communities(G, weight=weight, seed=seed) + assert {frozenset(c) for c in C} == truth + + def test_null_graph(self): + G = nx.null_graph() + truth = set() + self._check_communities(G, truth) + + def test_empty_graph(self): + G = nx.empty_graph(self.N) + truth = {frozenset([i]) for i in G} + self._check_communities(G, truth) + + def test_star_graph(self): + G = nx.star_graph(self.N) + truth = {frozenset(G)} + self._check_communities(G, truth) + + def test_complete_graph(self): + G = nx.complete_graph(self.N) + truth = {frozenset(G)} + self._check_communities(G, truth) + + def test_bipartite_graph(self): + G = nx.complete_bipartite_graph(self.N // 2, self.N // 2) + truth = {frozenset(G)} + self._check_communities(G, truth) + + def test_random_graph(self): + G = nx.gnm_random_graph(self.N, self.N * self.K // 2, seed=42) + truth = {frozenset(G)} + self._check_communities(G, truth) + + def test_disjoin_cliques(self): + G = nx.Graph(["ab", "AB", "AC", "BC", "12", "13", "14", "23", "24", "34"]) + truth = {frozenset("ab"), frozenset("ABC"), frozenset("1234")} + self._check_communities(G, truth) + + def test_ring_of_cliques(self): + N, K = self.N, self.K + G = nx.ring_of_cliques(N, K) + truth = {frozenset([K * i + k for k in range(K)]) for i in range(N)} + self._check_communities(G, truth) + + def test_larger_graph(self): + G = nx.gnm_random_graph(100 * self.N, 50 * self.N * self.K, seed=42) + nx.community.fast_label_propagation_communities(G) + + def test_graph_type(self): + G1 = nx.complete_graph(self.N, nx.MultiDiGraph()) + G2 = nx.MultiGraph(G1) + G3 = nx.DiGraph(G1) + G4 = nx.Graph(G1) + truth = {frozenset(G1)} + self._check_communities(G1, truth) + self._check_communities(G2, truth) + self._check_communities(G3, truth) + self._check_communities(G4, truth) + + def test_weight_argument(self): + G = nx.MultiDiGraph() + G.add_edge(1, 2, weight=1.41) + G.add_edge(2, 1, weight=1.41) + G.add_edge(2, 3) + G.add_edge(3, 4, weight=3.14) + truth = {frozenset({1, 2}), frozenset({3, 4})} + self._check_communities(G, truth, weight="weight") + + def test_seed_argument(self): + G = nx.karate_club_graph() + C = nx.community.fast_label_propagation_communities(G, seed=2023) + truth = {frozenset(c) for c in C} + self._check_communities(G, truth, seed=2023) + # smoke test that seed=None works + C = nx.community.fast_label_propagation_communities(G, seed=None) diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_louvain.py b/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_louvain.py new file mode 100644 index 0000000000000000000000000000000000000000..816e6f143fe0837b5d8617a927f49e557ccd9668 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_louvain.py @@ -0,0 +1,264 @@ +import pytest + +import networkx as nx + + +def test_modularity_increase(): + G = nx.LFR_benchmark_graph( + 250, 3, 1.5, 0.009, average_degree=5, min_community=20, seed=10 + ) + partition = [{u} for u in G.nodes()] + mod = nx.community.modularity(G, partition) + partition = nx.community.louvain_communities(G) + + assert nx.community.modularity(G, partition) > mod + + +def test_valid_partition(): + G = nx.LFR_benchmark_graph( + 250, 3, 1.5, 0.009, average_degree=5, min_community=20, seed=10 + ) + H = G.to_directed() + partition = nx.community.louvain_communities(G) + partition2 = nx.community.louvain_communities(H) + + assert nx.community.is_partition(G, partition) + assert nx.community.is_partition(H, partition2) + + +def test_karate_club_partition(): + G = nx.karate_club_graph() + part = [ + {0, 1, 2, 3, 7, 9, 11, 12, 13, 17, 19, 21}, + {16, 4, 5, 6, 10}, + {23, 25, 27, 28, 24, 31}, + {32, 33, 8, 14, 15, 18, 20, 22, 26, 29, 30}, + ] + partition = nx.community.louvain_communities(G, seed=2, weight=None) + + assert part == partition + + +def test_partition_iterator(): + G = nx.path_graph(15) + parts_iter = nx.community.louvain_partitions(G, seed=42) + first_part = next(parts_iter) + first_copy = [s.copy() for s in first_part] + + # gh-5901 reports sets changing after next partition is yielded + assert first_copy[0] == first_part[0] + second_part = next(parts_iter) + assert first_copy[0] == first_part[0] + + +def test_undirected_selfloops(): + G = nx.karate_club_graph() + expected_partition = nx.community.louvain_communities(G, seed=2, weight=None) + part = [ + {0, 1, 2, 3, 7, 9, 11, 12, 13, 17, 19, 21}, + {16, 4, 5, 6, 10}, + {23, 25, 27, 28, 24, 31}, + {32, 33, 8, 14, 15, 18, 20, 22, 26, 29, 30}, + ] + assert expected_partition == part + + G.add_weighted_edges_from([(i, i, i * 1000) for i in range(9)]) + # large self-loop weight impacts partition + partition = nx.community.louvain_communities(G, seed=2, weight="weight") + assert part != partition + + # small self-loop weights aren't enough to impact partition in this graph + partition = nx.community.louvain_communities(G, seed=2, weight=None) + assert part == partition + + +def test_directed_selfloops(): + G = nx.DiGraph() + G.add_nodes_from(range(11)) + G_edges = [ + (0, 2), + (0, 1), + (1, 0), + (2, 1), + (2, 0), + (3, 4), + (4, 3), + (7, 8), + (8, 7), + (9, 10), + (10, 9), + ] + G.add_edges_from(G_edges) + G_expected_partition = nx.community.louvain_communities(G, seed=123, weight=None) + + G.add_weighted_edges_from([(i, i, i * 1000) for i in range(3)]) + # large self-loop weight impacts partition + G_partition = nx.community.louvain_communities(G, seed=123, weight="weight") + assert G_partition != G_expected_partition + + # small self-loop weights aren't enough to impact partition in this graph + G_partition = nx.community.louvain_communities(G, seed=123, weight=None) + assert G_partition == G_expected_partition + + +def test_directed_partition(): + """ + Test 2 cases that were looping infinitely + from issues #5175 and #5704 + """ + G = nx.DiGraph() + H = nx.DiGraph() + G.add_nodes_from(range(10)) + H.add_nodes_from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) + G_edges = [ + (0, 2), + (0, 1), + (1, 0), + (2, 1), + (2, 0), + (3, 4), + (4, 3), + (7, 8), + (8, 7), + (9, 10), + (10, 9), + ] + H_edges = [ + (1, 2), + (1, 6), + (1, 9), + (2, 3), + (2, 4), + (2, 5), + (3, 4), + (4, 3), + (4, 5), + (5, 4), + (6, 7), + (6, 8), + (9, 10), + (9, 11), + (10, 11), + (11, 10), + ] + G.add_edges_from(G_edges) + H.add_edges_from(H_edges) + + G_expected_partition = [{0, 1, 2}, {3, 4}, {5}, {6}, {8, 7}, {9, 10}] + G_partition = nx.community.louvain_communities(G, seed=123, weight=None) + + H_expected_partition = [{2, 3, 4, 5}, {8, 1, 6, 7}, {9, 10, 11}] + H_partition = nx.community.louvain_communities(H, seed=123, weight=None) + + assert G_partition == G_expected_partition + assert H_partition == H_expected_partition + + +def test_none_weight_param(): + G = nx.karate_club_graph() + nx.set_edge_attributes( + G, {edge: i * i for i, edge in enumerate(G.edges)}, name="foo" + ) + + part = [ + {0, 1, 2, 3, 7, 9, 11, 12, 13, 17, 19, 21}, + {16, 4, 5, 6, 10}, + {23, 25, 27, 28, 24, 31}, + {32, 33, 8, 14, 15, 18, 20, 22, 26, 29, 30}, + ] + partition1 = nx.community.louvain_communities(G, weight=None, seed=2) + partition2 = nx.community.louvain_communities(G, weight="foo", seed=2) + partition3 = nx.community.louvain_communities(G, weight="weight", seed=2) + + assert part == partition1 + assert part != partition2 + assert part != partition3 + assert partition2 != partition3 + + +def test_quality(): + G = nx.LFR_benchmark_graph( + 250, 3, 1.5, 0.009, average_degree=5, min_community=20, seed=10 + ) + H = nx.gn_graph(200, seed=1234) + I = nx.MultiGraph(G) + J = nx.MultiDiGraph(H) + + partition = nx.community.louvain_communities(G) + partition2 = nx.community.louvain_communities(H) + partition3 = nx.community.louvain_communities(I) + partition4 = nx.community.louvain_communities(J) + + quality = nx.community.partition_quality(G, partition)[0] + quality2 = nx.community.partition_quality(H, partition2)[0] + quality3 = nx.community.partition_quality(I, partition3)[0] + quality4 = nx.community.partition_quality(J, partition4)[0] + + assert quality >= 0.65 + assert quality2 >= 0.65 + assert quality3 >= 0.65 + assert quality4 >= 0.65 + + +def test_multigraph(): + G = nx.karate_club_graph() + H = nx.MultiGraph(G) + G.add_edge(0, 1, weight=10) + H.add_edge(0, 1, weight=9) + G.add_edge(0, 9, foo=20) + H.add_edge(0, 9, foo=20) + + partition1 = nx.community.louvain_communities(G, seed=1234) + partition2 = nx.community.louvain_communities(H, seed=1234) + partition3 = nx.community.louvain_communities(H, weight="foo", seed=1234) + + assert partition1 == partition2 != partition3 + + +def test_resolution(): + G = nx.LFR_benchmark_graph( + 250, 3, 1.5, 0.009, average_degree=5, min_community=20, seed=10 + ) + + partition1 = nx.community.louvain_communities(G, resolution=0.5, seed=12) + partition2 = nx.community.louvain_communities(G, seed=12) + partition3 = nx.community.louvain_communities(G, resolution=2, seed=12) + + assert len(partition1) <= len(partition2) <= len(partition3) + + +def test_threshold(): + G = nx.LFR_benchmark_graph( + 250, 3, 1.5, 0.009, average_degree=5, min_community=20, seed=10 + ) + partition1 = nx.community.louvain_communities(G, threshold=0.3, seed=2) + partition2 = nx.community.louvain_communities(G, seed=2) + mod1 = nx.community.modularity(G, partition1) + mod2 = nx.community.modularity(G, partition2) + + assert mod1 <= mod2 + + +def test_empty_graph(): + G = nx.Graph() + G.add_nodes_from(range(5)) + expected = [{0}, {1}, {2}, {3}, {4}] + assert nx.community.louvain_communities(G) == expected + + +def test_max_level(): + G = nx.LFR_benchmark_graph( + 250, 3, 1.5, 0.009, average_degree=5, min_community=20, seed=10 + ) + parts_iter = nx.community.louvain_partitions(G, seed=42) + for max_level, expected in enumerate(parts_iter, 1): + partition = nx.community.louvain_communities(G, max_level=max_level, seed=42) + assert partition == expected + assert max_level > 1 # Ensure we are actually testing max_level + # max_level is an upper limit; it's okay if we stop before it's hit. + partition = nx.community.louvain_communities(G, max_level=max_level + 1, seed=42) + assert partition == expected + with pytest.raises( + ValueError, match="max_level argument must be a positive integer" + ): + nx.community.louvain_communities(G, max_level=0) diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_lukes.py b/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_lukes.py new file mode 100644 index 0000000000000000000000000000000000000000..cfa48f0f47667ce4c4fa96c175bee4cb95a4852f --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_lukes.py @@ -0,0 +1,152 @@ +from itertools import product + +import pytest + +import networkx as nx + +EWL = "e_weight" +NWL = "n_weight" + + +# first test from the Lukes original paper +def paper_1_case(float_edge_wt=False, explicit_node_wt=True, directed=False): + # problem-specific constants + limit = 3 + + # configuration + if float_edge_wt: + shift = 0.001 + else: + shift = 0 + + if directed: + example_1 = nx.DiGraph() + else: + example_1 = nx.Graph() + + # graph creation + example_1.add_edge(1, 2, **{EWL: 3 + shift}) + example_1.add_edge(1, 4, **{EWL: 2 + shift}) + example_1.add_edge(2, 3, **{EWL: 4 + shift}) + example_1.add_edge(2, 5, **{EWL: 6 + shift}) + + # node weights + if explicit_node_wt: + nx.set_node_attributes(example_1, 1, NWL) + wtu = NWL + else: + wtu = None + + # partitioning + clusters_1 = { + frozenset(x) + for x in nx.community.lukes_partitioning( + example_1, limit, node_weight=wtu, edge_weight=EWL + ) + } + + return clusters_1 + + +# second test from the Lukes original paper +def paper_2_case(explicit_edge_wt=True, directed=False): + # problem specific constants + byte_block_size = 32 + + # configuration + if directed: + example_2 = nx.DiGraph() + else: + example_2 = nx.Graph() + + if explicit_edge_wt: + edic = {EWL: 1} + wtu = EWL + else: + edic = {} + wtu = None + + # graph creation + example_2.add_edge("name", "home_address", **edic) + example_2.add_edge("name", "education", **edic) + example_2.add_edge("education", "bs", **edic) + example_2.add_edge("education", "ms", **edic) + example_2.add_edge("education", "phd", **edic) + example_2.add_edge("name", "telephone", **edic) + example_2.add_edge("telephone", "home", **edic) + example_2.add_edge("telephone", "office", **edic) + example_2.add_edge("office", "no1", **edic) + example_2.add_edge("office", "no2", **edic) + + example_2.nodes["name"][NWL] = 20 + example_2.nodes["education"][NWL] = 10 + example_2.nodes["bs"][NWL] = 1 + example_2.nodes["ms"][NWL] = 1 + example_2.nodes["phd"][NWL] = 1 + example_2.nodes["home_address"][NWL] = 8 + example_2.nodes["telephone"][NWL] = 8 + example_2.nodes["home"][NWL] = 8 + example_2.nodes["office"][NWL] = 4 + example_2.nodes["no1"][NWL] = 1 + example_2.nodes["no2"][NWL] = 1 + + # partitioning + clusters_2 = { + frozenset(x) + for x in nx.community.lukes_partitioning( + example_2, byte_block_size, node_weight=NWL, edge_weight=wtu + ) + } + + return clusters_2 + + +def test_paper_1_case(): + ground_truth = {frozenset([1, 4]), frozenset([2, 3, 5])} + + tf = (True, False) + for flt, nwt, drc in product(tf, tf, tf): + part = paper_1_case(flt, nwt, drc) + assert part == ground_truth + + +def test_paper_2_case(): + ground_truth = { + frozenset(["education", "bs", "ms", "phd"]), + frozenset(["name", "home_address"]), + frozenset(["telephone", "home", "office", "no1", "no2"]), + } + + tf = (True, False) + for ewt, drc in product(tf, tf): + part = paper_2_case(ewt, drc) + assert part == ground_truth + + +def test_mandatory_tree(): + not_a_tree = nx.complete_graph(4) + + with pytest.raises(nx.NotATree): + nx.community.lukes_partitioning(not_a_tree, 5) + + +def test_mandatory_integrality(): + byte_block_size = 32 + + ex_1_broken = nx.DiGraph() + + ex_1_broken.add_edge(1, 2, **{EWL: 3.2}) + ex_1_broken.add_edge(1, 4, **{EWL: 2.4}) + ex_1_broken.add_edge(2, 3, **{EWL: 4.0}) + ex_1_broken.add_edge(2, 5, **{EWL: 6.3}) + + ex_1_broken.nodes[1][NWL] = 1.2 # ! + ex_1_broken.nodes[2][NWL] = 1 + ex_1_broken.nodes[3][NWL] = 1 + ex_1_broken.nodes[4][NWL] = 1 + ex_1_broken.nodes[5][NWL] = 2 + + with pytest.raises(TypeError): + nx.community.lukes_partitioning( + ex_1_broken, byte_block_size, node_weight=NWL, edge_weight=EWL + ) diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_modularity_max.py b/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_modularity_max.py new file mode 100644 index 0000000000000000000000000000000000000000..0121367fc4ef766e2587610c3ea32ba33b12b259 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_modularity_max.py @@ -0,0 +1,340 @@ +import pytest + +import networkx as nx +from networkx.algorithms.community import ( + greedy_modularity_communities, + naive_greedy_modularity_communities, +) + + +@pytest.mark.parametrize( + "func", (greedy_modularity_communities, naive_greedy_modularity_communities) +) +def test_modularity_communities(func): + G = nx.karate_club_graph() + john_a = frozenset( + [8, 14, 15, 18, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33] + ) + mr_hi = frozenset([0, 4, 5, 6, 10, 11, 16, 19]) + overlap = frozenset([1, 2, 3, 7, 9, 12, 13, 17, 21]) + expected = {john_a, overlap, mr_hi} + assert set(func(G, weight=None)) == expected + + +@pytest.mark.parametrize( + "func", (greedy_modularity_communities, naive_greedy_modularity_communities) +) +def test_modularity_communities_categorical_labels(func): + # Using other than 0-starting contiguous integers as node-labels. + G = nx.Graph( + [ + ("a", "b"), + ("a", "c"), + ("b", "c"), + ("b", "d"), # inter-community edge + ("d", "e"), + ("d", "f"), + ("d", "g"), + ("f", "g"), + ("d", "e"), + ("f", "e"), + ] + ) + expected = {frozenset({"f", "g", "e", "d"}), frozenset({"a", "b", "c"})} + assert set(func(G)) == expected + + +def test_greedy_modularity_communities_components(): + # Test for gh-5530 + G = nx.Graph([(0, 1), (2, 3), (4, 5), (5, 6)]) + # usual case with 3 components + assert greedy_modularity_communities(G) == [{4, 5, 6}, {0, 1}, {2, 3}] + # best_n can make the algorithm continue even when modularity goes down + assert greedy_modularity_communities(G, best_n=3) == [{4, 5, 6}, {0, 1}, {2, 3}] + assert greedy_modularity_communities(G, best_n=2) == [{0, 1, 4, 5, 6}, {2, 3}] + assert greedy_modularity_communities(G, best_n=1) == [{0, 1, 2, 3, 4, 5, 6}] + + +def test_greedy_modularity_communities_relabeled(): + # Test for gh-4966 + G = nx.balanced_tree(2, 2) + mapping = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e", 5: "f", 6: "g", 7: "h"} + G = nx.relabel_nodes(G, mapping) + expected = [frozenset({"e", "d", "a", "b"}), frozenset({"c", "f", "g"})] + assert greedy_modularity_communities(G) == expected + + +def test_greedy_modularity_communities_directed(): + G = nx.DiGraph( + [ + ("a", "b"), + ("a", "c"), + ("b", "c"), + ("b", "d"), # inter-community edge + ("d", "e"), + ("d", "f"), + ("d", "g"), + ("f", "g"), + ("d", "e"), + ("f", "e"), + ] + ) + expected = [frozenset({"f", "g", "e", "d"}), frozenset({"a", "b", "c"})] + assert greedy_modularity_communities(G) == expected + + # with loops + G = nx.DiGraph() + G.add_edges_from( + [(1, 1), (1, 2), (1, 3), (2, 3), (1, 4), (4, 4), (5, 5), (4, 5), (4, 6), (5, 6)] + ) + expected = [frozenset({1, 2, 3}), frozenset({4, 5, 6})] + assert greedy_modularity_communities(G) == expected + + +@pytest.mark.parametrize( + "func", (greedy_modularity_communities, naive_greedy_modularity_communities) +) +def test_modularity_communities_weighted(func): + G = nx.balanced_tree(2, 3) + for a, b in G.edges: + if ((a == 1) or (a == 2)) and (b != 0): + G[a][b]["weight"] = 10.0 + else: + G[a][b]["weight"] = 1.0 + + expected = [{0, 1, 3, 4, 7, 8, 9, 10}, {2, 5, 6, 11, 12, 13, 14}] + + assert func(G, weight="weight") == expected + assert func(G, weight="weight", resolution=0.9) == expected + assert func(G, weight="weight", resolution=0.3) == expected + assert func(G, weight="weight", resolution=1.1) != expected + + +def test_modularity_communities_floating_point(): + # check for floating point error when used as key in the mapped_queue dict. + # Test for gh-4992 and gh-5000 + G = nx.Graph() + G.add_weighted_edges_from( + [(0, 1, 12), (1, 4, 71), (2, 3, 15), (2, 4, 10), (3, 6, 13)] + ) + expected = [{0, 1, 4}, {2, 3, 6}] + assert greedy_modularity_communities(G, weight="weight") == expected + assert ( + greedy_modularity_communities(G, weight="weight", resolution=0.99) == expected + ) + + +def test_modularity_communities_directed_weighted(): + G = nx.DiGraph() + G.add_weighted_edges_from( + [ + (1, 2, 5), + (1, 3, 3), + (2, 3, 6), + (2, 6, 1), + (1, 4, 1), + (4, 5, 3), + (4, 6, 7), + (5, 6, 2), + (5, 7, 5), + (5, 8, 4), + (6, 8, 3), + ] + ) + expected = [frozenset({4, 5, 6, 7, 8}), frozenset({1, 2, 3})] + assert greedy_modularity_communities(G, weight="weight") == expected + + # A large weight of the edge (2, 6) causes 6 to change group, even if it shares + # only one connection with the new group and 3 with the old one. + G[2][6]["weight"] = 20 + expected = [frozenset({1, 2, 3, 6}), frozenset({4, 5, 7, 8})] + assert greedy_modularity_communities(G, weight="weight") == expected + + +def test_greedy_modularity_communities_multigraph(): + G = nx.MultiGraph() + G.add_edges_from( + [ + (1, 2), + (1, 2), + (1, 3), + (2, 3), + (1, 4), + (2, 4), + (4, 5), + (5, 6), + (5, 7), + (5, 7), + (6, 7), + (7, 8), + (5, 8), + ] + ) + expected = [frozenset({1, 2, 3, 4}), frozenset({5, 6, 7, 8})] + assert greedy_modularity_communities(G) == expected + + # Converting (4, 5) into a multi-edge causes node 4 to change group. + G.add_edge(4, 5) + expected = [frozenset({4, 5, 6, 7, 8}), frozenset({1, 2, 3})] + assert greedy_modularity_communities(G) == expected + + +def test_greedy_modularity_communities_multigraph_weighted(): + G = nx.MultiGraph() + G.add_weighted_edges_from( + [ + (1, 2, 5), + (1, 2, 3), + (1, 3, 6), + (1, 3, 6), + (2, 3, 4), + (1, 4, 1), + (1, 4, 1), + (2, 4, 3), + (2, 4, 3), + (4, 5, 1), + (5, 6, 3), + (5, 6, 7), + (5, 6, 4), + (5, 7, 9), + (5, 7, 9), + (6, 7, 8), + (7, 8, 2), + (7, 8, 2), + (5, 8, 6), + (5, 8, 6), + ] + ) + expected = [frozenset({1, 2, 3, 4}), frozenset({5, 6, 7, 8})] + assert greedy_modularity_communities(G, weight="weight") == expected + + # Adding multi-edge (4, 5, 16) causes node 4 to change group. + G.add_edge(4, 5, weight=16) + expected = [frozenset({4, 5, 6, 7, 8}), frozenset({1, 2, 3})] + assert greedy_modularity_communities(G, weight="weight") == expected + + # Increasing the weight of edge (1, 4) causes node 4 to return to the former group. + G[1][4][1]["weight"] = 3 + expected = [frozenset({1, 2, 3, 4}), frozenset({5, 6, 7, 8})] + assert greedy_modularity_communities(G, weight="weight") == expected + + +def test_greed_modularity_communities_multidigraph(): + G = nx.MultiDiGraph() + G.add_edges_from( + [ + (1, 2), + (1, 2), + (3, 1), + (2, 3), + (2, 3), + (3, 2), + (1, 4), + (2, 4), + (4, 2), + (4, 5), + (5, 6), + (5, 6), + (6, 5), + (5, 7), + (6, 7), + (7, 8), + (5, 8), + (8, 4), + ] + ) + expected = [frozenset({1, 2, 3, 4}), frozenset({5, 6, 7, 8})] + assert greedy_modularity_communities(G, weight="weight") == expected + + +def test_greed_modularity_communities_multidigraph_weighted(): + G = nx.MultiDiGraph() + G.add_weighted_edges_from( + [ + (1, 2, 5), + (1, 2, 3), + (3, 1, 6), + (1, 3, 6), + (3, 2, 4), + (1, 4, 2), + (1, 4, 5), + (2, 4, 3), + (3, 2, 8), + (4, 2, 3), + (4, 3, 5), + (4, 5, 2), + (5, 6, 3), + (5, 6, 7), + (6, 5, 4), + (5, 7, 9), + (5, 7, 9), + (7, 6, 8), + (7, 8, 2), + (8, 7, 2), + (5, 8, 6), + (5, 8, 6), + ] + ) + expected = [frozenset({1, 2, 3, 4}), frozenset({5, 6, 7, 8})] + assert greedy_modularity_communities(G, weight="weight") == expected + + +def test_resolution_parameter_impact(): + G = nx.barbell_graph(5, 3) + + gamma = 1 + expected = [frozenset(range(5)), frozenset(range(8, 13)), frozenset(range(5, 8))] + assert greedy_modularity_communities(G, resolution=gamma) == expected + assert naive_greedy_modularity_communities(G, resolution=gamma) == expected + + gamma = 2.5 + expected = [{0, 1, 2, 3}, {9, 10, 11, 12}, {5, 6, 7}, {4}, {8}] + assert greedy_modularity_communities(G, resolution=gamma) == expected + assert naive_greedy_modularity_communities(G, resolution=gamma) == expected + + gamma = 0.3 + expected = [frozenset(range(8)), frozenset(range(8, 13))] + assert greedy_modularity_communities(G, resolution=gamma) == expected + assert naive_greedy_modularity_communities(G, resolution=gamma) == expected + + +def test_cutoff_parameter(): + G = nx.circular_ladder_graph(4) + + # No aggregation: + expected = [{k} for k in range(8)] + assert greedy_modularity_communities(G, cutoff=8) == expected + + # Aggregation to half order (number of nodes) + expected = [{k, k + 1} for k in range(0, 8, 2)] + assert greedy_modularity_communities(G, cutoff=4) == expected + + # Default aggregation case (here, 2 communities emerge) + expected = [frozenset(range(4)), frozenset(range(4, 8))] + assert greedy_modularity_communities(G, cutoff=1) == expected + + +def test_best_n(): + G = nx.barbell_graph(5, 3) + + # Same result as without enforcing cutoff: + best_n = 3 + expected = [frozenset(range(5)), frozenset(range(8, 13)), frozenset(range(5, 8))] + assert greedy_modularity_communities(G, best_n=best_n) == expected + + # One additional merging step: + best_n = 2 + expected = [frozenset(range(8)), frozenset(range(8, 13))] + assert greedy_modularity_communities(G, best_n=best_n) == expected + + # Two additional merging steps: + best_n = 1 + expected = [frozenset(range(13))] + assert greedy_modularity_communities(G, best_n=best_n) == expected + + +def test_greedy_modularity_communities_corner_cases(): + G = nx.empty_graph() + assert nx.community.greedy_modularity_communities(G) == [] + G.add_nodes_from(range(3)) + assert nx.community.greedy_modularity_communities(G) == [{0}, {1}, {2}] diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_quality.py b/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_quality.py new file mode 100644 index 0000000000000000000000000000000000000000..c502c7e352114ec55c1b2bf716483014105b7068 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_quality.py @@ -0,0 +1,139 @@ +"""Unit tests for the :mod:`networkx.algorithms.community.quality` +module. + +""" + +import pytest + +import networkx as nx +from networkx import barbell_graph +from networkx.algorithms.community import modularity, partition_quality +from networkx.algorithms.community.quality import inter_community_edges + + +class TestPerformance: + """Unit tests for the :func:`performance` function.""" + + def test_bad_partition(self): + """Tests that a poor partition has a low performance measure.""" + G = barbell_graph(3, 0) + partition = [{0, 1, 4}, {2, 3, 5}] + assert 8 / 15 == pytest.approx(partition_quality(G, partition)[1], abs=1e-7) + + def test_good_partition(self): + """Tests that a good partition has a high performance measure.""" + G = barbell_graph(3, 0) + partition = [{0, 1, 2}, {3, 4, 5}] + assert 14 / 15 == pytest.approx(partition_quality(G, partition)[1], abs=1e-7) + + +class TestCoverage: + """Unit tests for the :func:`coverage` function.""" + + def test_bad_partition(self): + """Tests that a poor partition has a low coverage measure.""" + G = barbell_graph(3, 0) + partition = [{0, 1, 4}, {2, 3, 5}] + assert 3 / 7 == pytest.approx(partition_quality(G, partition)[0], abs=1e-7) + + def test_good_partition(self): + """Tests that a good partition has a high coverage measure.""" + G = barbell_graph(3, 0) + partition = [{0, 1, 2}, {3, 4, 5}] + assert 6 / 7 == pytest.approx(partition_quality(G, partition)[0], abs=1e-7) + + +def test_modularity(): + G = nx.barbell_graph(3, 0) + C = [{0, 1, 4}, {2, 3, 5}] + assert (-16 / (14**2)) == pytest.approx(modularity(G, C), abs=1e-7) + C = [{0, 1, 2}, {3, 4, 5}] + assert (35 * 2) / (14**2) == pytest.approx(modularity(G, C), abs=1e-7) + + n = 1000 + G = nx.erdos_renyi_graph(n, 0.09, seed=42, directed=True) + C = [set(range(n // 2)), set(range(n // 2, n))] + assert 0.00017154251389292754 == pytest.approx(modularity(G, C), abs=1e-7) + + G = nx.margulis_gabber_galil_graph(10) + mid_value = G.number_of_nodes() // 2 + nodes = list(G.nodes) + C = [set(nodes[:mid_value]), set(nodes[mid_value:])] + assert 0.13 == pytest.approx(modularity(G, C), abs=1e-7) + + G = nx.DiGraph() + G.add_edges_from([(2, 1), (2, 3), (3, 4)]) + C = [{1, 2}, {3, 4}] + assert 2 / 9 == pytest.approx(modularity(G, C), abs=1e-7) + + +def test_modularity_resolution(): + G = nx.barbell_graph(3, 0) + C = [{0, 1, 4}, {2, 3, 5}] + assert modularity(G, C) == pytest.approx(3 / 7 - 100 / 14**2) + gamma = 2 + result = modularity(G, C, resolution=gamma) + assert result == pytest.approx(3 / 7 - gamma * 100 / 14**2) + gamma = 0.2 + result = modularity(G, C, resolution=gamma) + assert result == pytest.approx(3 / 7 - gamma * 100 / 14**2) + + C = [{0, 1, 2}, {3, 4, 5}] + assert modularity(G, C) == pytest.approx(6 / 7 - 98 / 14**2) + gamma = 2 + result = modularity(G, C, resolution=gamma) + assert result == pytest.approx(6 / 7 - gamma * 98 / 14**2) + gamma = 0.2 + result = modularity(G, C, resolution=gamma) + assert result == pytest.approx(6 / 7 - gamma * 98 / 14**2) + + G = nx.barbell_graph(5, 3) + C = [frozenset(range(5)), frozenset(range(8, 13)), frozenset(range(5, 8))] + gamma = 1 + result = modularity(G, C, resolution=gamma) + # This C is maximal for gamma=1: modularity = 0.518229 + assert result == pytest.approx((22 / 24) - gamma * (918 / (48**2))) + gamma = 2 + result = modularity(G, C, resolution=gamma) + assert result == pytest.approx((22 / 24) - gamma * (918 / (48**2))) + gamma = 0.2 + result = modularity(G, C, resolution=gamma) + assert result == pytest.approx((22 / 24) - gamma * (918 / (48**2))) + + C = [{0, 1, 2, 3}, {9, 10, 11, 12}, {5, 6, 7}, {4}, {8}] + gamma = 1 + result = modularity(G, C, resolution=gamma) + assert result == pytest.approx((14 / 24) - gamma * (598 / (48**2))) + gamma = 2.5 + result = modularity(G, C, resolution=gamma) + # This C is maximal for gamma=2.5: modularity = -0.06553819 + assert result == pytest.approx((14 / 24) - gamma * (598 / (48**2))) + gamma = 0.2 + result = modularity(G, C, resolution=gamma) + assert result == pytest.approx((14 / 24) - gamma * (598 / (48**2))) + + C = [frozenset(range(8)), frozenset(range(8, 13))] + gamma = 1 + result = modularity(G, C, resolution=gamma) + assert result == pytest.approx((23 / 24) - gamma * (1170 / (48**2))) + gamma = 2 + result = modularity(G, C, resolution=gamma) + assert result == pytest.approx((23 / 24) - gamma * (1170 / (48**2))) + gamma = 0.3 + result = modularity(G, C, resolution=gamma) + # This C is maximal for gamma=0.3: modularity = 0.805990 + assert result == pytest.approx((23 / 24) - gamma * (1170 / (48**2))) + + +def test_inter_community_edges_with_digraphs(): + G = nx.complete_graph(2, create_using=nx.DiGraph()) + partition = [{0}, {1}] + assert inter_community_edges(G, partition) == 2 + + G = nx.complete_graph(10, create_using=nx.DiGraph()) + partition = [{0}, {1, 2}, {3, 4, 5}, {6, 7, 8, 9}] + assert inter_community_edges(G, partition) == 70 + + G = nx.cycle_graph(4, create_using=nx.DiGraph()) + partition = [{0, 1}, {2, 3}] + assert inter_community_edges(G, partition) == 2 diff --git a/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_utils.py b/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ea019db9da8ae0eaacb9428845125515c36ff46a --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/community/tests/test_utils.py @@ -0,0 +1,26 @@ +"""Unit tests for the :mod:`networkx.algorithms.community.utils` module.""" + +import networkx as nx + + +def test_is_partition(): + G = nx.empty_graph(3) + assert nx.community.is_partition(G, [{0, 1}, {2}]) + assert nx.community.is_partition(G, ({0, 1}, {2})) + assert nx.community.is_partition(G, ([0, 1], [2])) + assert nx.community.is_partition(G, [[0, 1], [2]]) + + +def test_not_covering(): + G = nx.empty_graph(3) + assert not nx.community.is_partition(G, [{0}, {1}]) + + +def test_not_disjoint(): + G = nx.empty_graph(3) + assert not nx.community.is_partition(G, [{0, 1}, {1, 2}]) + + +def test_not_node(): + G = nx.empty_graph(3) + assert not nx.community.is_partition(G, [{0, 1}, {3}]) diff --git a/lib/python3.10/site-packages/networkx/algorithms/components/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f9ae2caba856daba534037f4a6f967abfad49552 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/components/__init__.py @@ -0,0 +1,6 @@ +from .connected import * +from .strongly_connected import * +from .weakly_connected import * +from .attracting import * +from .biconnected import * +from .semiconnected import * diff --git a/lib/python3.10/site-packages/networkx/algorithms/components/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/components/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3dabfc82276ff0f208f01e800e914708e692380 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/components/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/components/__pycache__/attracting.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/components/__pycache__/attracting.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8af480033a76871bd87016903de155826797e81f Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/components/__pycache__/attracting.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/components/__pycache__/biconnected.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/components/__pycache__/biconnected.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba4b6730dc03075ec81c781a2d6c1b89d58d2517 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/components/__pycache__/biconnected.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/components/__pycache__/connected.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/components/__pycache__/connected.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8191ffd97470fdd515923535b8be4a3923c7113 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/components/__pycache__/connected.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/components/__pycache__/semiconnected.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/components/__pycache__/semiconnected.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33ed3327dd91eaa4685a75e01abefc6728df7e02 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/components/__pycache__/semiconnected.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/components/__pycache__/strongly_connected.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/components/__pycache__/strongly_connected.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d66751f78aec8ff0ec959e31ef6ffc2114f73d8 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/components/__pycache__/strongly_connected.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/components/__pycache__/weakly_connected.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/components/__pycache__/weakly_connected.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0b50c1689efdc10269d5f2515fcac9284227342 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/components/__pycache__/weakly_connected.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/components/attracting.py b/lib/python3.10/site-packages/networkx/algorithms/components/attracting.py new file mode 100644 index 0000000000000000000000000000000000000000..3d77cd93d70efab5f29c77c7d135f4730e4c3a4a --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/components/attracting.py @@ -0,0 +1,115 @@ +"""Attracting components.""" + +import networkx as nx +from networkx.utils.decorators import not_implemented_for + +__all__ = [ + "number_attracting_components", + "attracting_components", + "is_attracting_component", +] + + +@not_implemented_for("undirected") +@nx._dispatchable +def attracting_components(G): + """Generates the attracting components in `G`. + + An attracting component in a directed graph `G` is a strongly connected + component with the property that a random walker on the graph will never + leave the component, once it enters the component. + + The nodes in attracting components can also be thought of as recurrent + nodes. If a random walker enters the attractor containing the node, then + the node will be visited infinitely often. + + To obtain induced subgraphs on each component use: + ``(G.subgraph(c).copy() for c in attracting_components(G))`` + + Parameters + ---------- + G : DiGraph, MultiDiGraph + The graph to be analyzed. + + Returns + ------- + attractors : generator of sets + A generator of sets of nodes, one for each attracting component of G. + + Raises + ------ + NetworkXNotImplemented + If the input graph is undirected. + + See Also + -------- + number_attracting_components + is_attracting_component + + """ + scc = list(nx.strongly_connected_components(G)) + cG = nx.condensation(G, scc) + for n in cG: + if cG.out_degree(n) == 0: + yield scc[n] + + +@not_implemented_for("undirected") +@nx._dispatchable +def number_attracting_components(G): + """Returns the number of attracting components in `G`. + + Parameters + ---------- + G : DiGraph, MultiDiGraph + The graph to be analyzed. + + Returns + ------- + n : int + The number of attracting components in G. + + Raises + ------ + NetworkXNotImplemented + If the input graph is undirected. + + See Also + -------- + attracting_components + is_attracting_component + + """ + return sum(1 for ac in attracting_components(G)) + + +@not_implemented_for("undirected") +@nx._dispatchable +def is_attracting_component(G): + """Returns True if `G` consists of a single attracting component. + + Parameters + ---------- + G : DiGraph, MultiDiGraph + The graph to be analyzed. + + Returns + ------- + attracting : bool + True if `G` has a single attracting component. Otherwise, False. + + Raises + ------ + NetworkXNotImplemented + If the input graph is undirected. + + See Also + -------- + attracting_components + number_attracting_components + + """ + ac = list(attracting_components(G)) + if len(ac) == 1: + return len(ac[0]) == len(G) + return False diff --git a/lib/python3.10/site-packages/networkx/algorithms/components/biconnected.py b/lib/python3.10/site-packages/networkx/algorithms/components/biconnected.py new file mode 100644 index 0000000000000000000000000000000000000000..fd0f3865bb18e9c9eb37d768c7fd3caceb1cde86 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/components/biconnected.py @@ -0,0 +1,394 @@ +"""Biconnected components and articulation points.""" + +from itertools import chain + +import networkx as nx +from networkx.utils.decorators import not_implemented_for + +__all__ = [ + "biconnected_components", + "biconnected_component_edges", + "is_biconnected", + "articulation_points", +] + + +@not_implemented_for("directed") +@nx._dispatchable +def is_biconnected(G): + """Returns True if the graph is biconnected, False otherwise. + + A graph is biconnected if, and only if, it cannot be disconnected by + removing only one node (and all edges incident on that node). If + removing a node increases the number of disconnected components + in the graph, that node is called an articulation point, or cut + vertex. A biconnected graph has no articulation points. + + Parameters + ---------- + G : NetworkX Graph + An undirected graph. + + Returns + ------- + biconnected : bool + True if the graph is biconnected, False otherwise. + + Raises + ------ + NetworkXNotImplemented + If the input graph is not undirected. + + Examples + -------- + >>> G = nx.path_graph(4) + >>> print(nx.is_biconnected(G)) + False + >>> G.add_edge(0, 3) + >>> print(nx.is_biconnected(G)) + True + + See Also + -------- + biconnected_components + articulation_points + biconnected_component_edges + is_strongly_connected + is_weakly_connected + is_connected + is_semiconnected + + Notes + ----- + The algorithm to find articulation points and biconnected + components is implemented using a non-recursive depth-first-search + (DFS) that keeps track of the highest level that back edges reach + in the DFS tree. A node `n` is an articulation point if, and only + if, there exists a subtree rooted at `n` such that there is no + back edge from any successor of `n` that links to a predecessor of + `n` in the DFS tree. By keeping track of all the edges traversed + by the DFS we can obtain the biconnected components because all + edges of a bicomponent will be traversed consecutively between + articulation points. + + References + ---------- + .. [1] Hopcroft, J.; Tarjan, R. (1973). + "Efficient algorithms for graph manipulation". + Communications of the ACM 16: 372–378. doi:10.1145/362248.362272 + + """ + bccs = biconnected_components(G) + try: + bcc = next(bccs) + except StopIteration: + # No bicomponents (empty graph?) + return False + try: + next(bccs) + except StopIteration: + # Only one bicomponent + return len(bcc) == len(G) + else: + # Multiple bicomponents + return False + + +@not_implemented_for("directed") +@nx._dispatchable +def biconnected_component_edges(G): + """Returns a generator of lists of edges, one list for each biconnected + component of the input graph. + + Biconnected components are maximal subgraphs such that the removal of a + node (and all edges incident on that node) will not disconnect the + subgraph. Note that nodes may be part of more than one biconnected + component. Those nodes are articulation points, or cut vertices. + However, each edge belongs to one, and only one, biconnected component. + + Notice that by convention a dyad is considered a biconnected component. + + Parameters + ---------- + G : NetworkX Graph + An undirected graph. + + Returns + ------- + edges : generator of lists + Generator of lists of edges, one list for each bicomponent. + + Raises + ------ + NetworkXNotImplemented + If the input graph is not undirected. + + Examples + -------- + >>> G = nx.barbell_graph(4, 2) + >>> print(nx.is_biconnected(G)) + False + >>> bicomponents_edges = list(nx.biconnected_component_edges(G)) + >>> len(bicomponents_edges) + 5 + >>> G.add_edge(2, 8) + >>> print(nx.is_biconnected(G)) + True + >>> bicomponents_edges = list(nx.biconnected_component_edges(G)) + >>> len(bicomponents_edges) + 1 + + See Also + -------- + is_biconnected, + biconnected_components, + articulation_points, + + Notes + ----- + The algorithm to find articulation points and biconnected + components is implemented using a non-recursive depth-first-search + (DFS) that keeps track of the highest level that back edges reach + in the DFS tree. A node `n` is an articulation point if, and only + if, there exists a subtree rooted at `n` such that there is no + back edge from any successor of `n` that links to a predecessor of + `n` in the DFS tree. By keeping track of all the edges traversed + by the DFS we can obtain the biconnected components because all + edges of a bicomponent will be traversed consecutively between + articulation points. + + References + ---------- + .. [1] Hopcroft, J.; Tarjan, R. (1973). + "Efficient algorithms for graph manipulation". + Communications of the ACM 16: 372–378. doi:10.1145/362248.362272 + + """ + yield from _biconnected_dfs(G, components=True) + + +@not_implemented_for("directed") +@nx._dispatchable +def biconnected_components(G): + """Returns a generator of sets of nodes, one set for each biconnected + component of the graph + + Biconnected components are maximal subgraphs such that the removal of a + node (and all edges incident on that node) will not disconnect the + subgraph. Note that nodes may be part of more than one biconnected + component. Those nodes are articulation points, or cut vertices. The + removal of articulation points will increase the number of connected + components of the graph. + + Notice that by convention a dyad is considered a biconnected component. + + Parameters + ---------- + G : NetworkX Graph + An undirected graph. + + Returns + ------- + nodes : generator + Generator of sets of nodes, one set for each biconnected component. + + Raises + ------ + NetworkXNotImplemented + If the input graph is not undirected. + + Examples + -------- + >>> G = nx.lollipop_graph(5, 1) + >>> print(nx.is_biconnected(G)) + False + >>> bicomponents = list(nx.biconnected_components(G)) + >>> len(bicomponents) + 2 + >>> G.add_edge(0, 5) + >>> print(nx.is_biconnected(G)) + True + >>> bicomponents = list(nx.biconnected_components(G)) + >>> len(bicomponents) + 1 + + You can generate a sorted list of biconnected components, largest + first, using sort. + + >>> G.remove_edge(0, 5) + >>> [len(c) for c in sorted(nx.biconnected_components(G), key=len, reverse=True)] + [5, 2] + + If you only want the largest connected component, it's more + efficient to use max instead of sort. + + >>> Gc = max(nx.biconnected_components(G), key=len) + + To create the components as subgraphs use: + ``(G.subgraph(c).copy() for c in biconnected_components(G))`` + + See Also + -------- + is_biconnected + articulation_points + biconnected_component_edges + k_components : this function is a special case where k=2 + bridge_components : similar to this function, but is defined using + 2-edge-connectivity instead of 2-node-connectivity. + + Notes + ----- + The algorithm to find articulation points and biconnected + components is implemented using a non-recursive depth-first-search + (DFS) that keeps track of the highest level that back edges reach + in the DFS tree. A node `n` is an articulation point if, and only + if, there exists a subtree rooted at `n` such that there is no + back edge from any successor of `n` that links to a predecessor of + `n` in the DFS tree. By keeping track of all the edges traversed + by the DFS we can obtain the biconnected components because all + edges of a bicomponent will be traversed consecutively between + articulation points. + + References + ---------- + .. [1] Hopcroft, J.; Tarjan, R. (1973). + "Efficient algorithms for graph manipulation". + Communications of the ACM 16: 372–378. doi:10.1145/362248.362272 + + """ + for comp in _biconnected_dfs(G, components=True): + yield set(chain.from_iterable(comp)) + + +@not_implemented_for("directed") +@nx._dispatchable +def articulation_points(G): + """Yield the articulation points, or cut vertices, of a graph. + + An articulation point or cut vertex is any node whose removal (along with + all its incident edges) increases the number of connected components of + a graph. An undirected connected graph without articulation points is + biconnected. Articulation points belong to more than one biconnected + component of a graph. + + Notice that by convention a dyad is considered a biconnected component. + + Parameters + ---------- + G : NetworkX Graph + An undirected graph. + + Yields + ------ + node + An articulation point in the graph. + + Raises + ------ + NetworkXNotImplemented + If the input graph is not undirected. + + Examples + -------- + + >>> G = nx.barbell_graph(4, 2) + >>> print(nx.is_biconnected(G)) + False + >>> len(list(nx.articulation_points(G))) + 4 + >>> G.add_edge(2, 8) + >>> print(nx.is_biconnected(G)) + True + >>> len(list(nx.articulation_points(G))) + 0 + + See Also + -------- + is_biconnected + biconnected_components + biconnected_component_edges + + Notes + ----- + The algorithm to find articulation points and biconnected + components is implemented using a non-recursive depth-first-search + (DFS) that keeps track of the highest level that back edges reach + in the DFS tree. A node `n` is an articulation point if, and only + if, there exists a subtree rooted at `n` such that there is no + back edge from any successor of `n` that links to a predecessor of + `n` in the DFS tree. By keeping track of all the edges traversed + by the DFS we can obtain the biconnected components because all + edges of a bicomponent will be traversed consecutively between + articulation points. + + References + ---------- + .. [1] Hopcroft, J.; Tarjan, R. (1973). + "Efficient algorithms for graph manipulation". + Communications of the ACM 16: 372–378. doi:10.1145/362248.362272 + + """ + seen = set() + for articulation in _biconnected_dfs(G, components=False): + if articulation not in seen: + seen.add(articulation) + yield articulation + + +@not_implemented_for("directed") +def _biconnected_dfs(G, components=True): + # depth-first search algorithm to generate articulation points + # and biconnected components + visited = set() + for start in G: + if start in visited: + continue + discovery = {start: 0} # time of first discovery of node during search + low = {start: 0} + root_children = 0 + visited.add(start) + edge_stack = [] + stack = [(start, start, iter(G[start]))] + edge_index = {} + while stack: + grandparent, parent, children = stack[-1] + try: + child = next(children) + if grandparent == child: + continue + if child in visited: + if discovery[child] <= discovery[parent]: # back edge + low[parent] = min(low[parent], discovery[child]) + if components: + edge_index[parent, child] = len(edge_stack) + edge_stack.append((parent, child)) + else: + low[child] = discovery[child] = len(discovery) + visited.add(child) + stack.append((parent, child, iter(G[child]))) + if components: + edge_index[parent, child] = len(edge_stack) + edge_stack.append((parent, child)) + + except StopIteration: + stack.pop() + if len(stack) > 1: + if low[parent] >= discovery[grandparent]: + if components: + ind = edge_index[grandparent, parent] + yield edge_stack[ind:] + del edge_stack[ind:] + + else: + yield grandparent + low[grandparent] = min(low[parent], low[grandparent]) + elif stack: # length 1 so grandparent is root + root_children += 1 + if components: + ind = edge_index[grandparent, parent] + yield edge_stack[ind:] + del edge_stack[ind:] + if not components: + # root node is articulation point if it has more than 1 child + if root_children > 1: + yield start diff --git a/lib/python3.10/site-packages/networkx/algorithms/components/connected.py b/lib/python3.10/site-packages/networkx/algorithms/components/connected.py new file mode 100644 index 0000000000000000000000000000000000000000..ebe0d8c157b57fe68589210f2fa5dcf1219cebc5 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/components/connected.py @@ -0,0 +1,216 @@ +"""Connected components.""" + +import networkx as nx +from networkx.utils.decorators import not_implemented_for + +from ...utils import arbitrary_element + +__all__ = [ + "number_connected_components", + "connected_components", + "is_connected", + "node_connected_component", +] + + +@not_implemented_for("directed") +@nx._dispatchable +def connected_components(G): + """Generate connected components. + + Parameters + ---------- + G : NetworkX graph + An undirected graph + + Returns + ------- + comp : generator of sets + A generator of sets of nodes, one for each component of G. + + Raises + ------ + NetworkXNotImplemented + If G is directed. + + Examples + -------- + Generate a sorted list of connected components, largest first. + + >>> G = nx.path_graph(4) + >>> nx.add_path(G, [10, 11, 12]) + >>> [len(c) for c in sorted(nx.connected_components(G), key=len, reverse=True)] + [4, 3] + + If you only want the largest connected component, it's more + efficient to use max instead of sort. + + >>> largest_cc = max(nx.connected_components(G), key=len) + + To create the induced subgraph of each component use: + + >>> S = [G.subgraph(c).copy() for c in nx.connected_components(G)] + + See Also + -------- + strongly_connected_components + weakly_connected_components + + Notes + ----- + For undirected graphs only. + + """ + seen = set() + n = len(G) + for v in G: + if v not in seen: + c = _plain_bfs(G, n, v) + seen.update(c) + yield c + + +@not_implemented_for("directed") +@nx._dispatchable +def number_connected_components(G): + """Returns the number of connected components. + + Parameters + ---------- + G : NetworkX graph + An undirected graph. + + Returns + ------- + n : integer + Number of connected components + + Raises + ------ + NetworkXNotImplemented + If G is directed. + + Examples + -------- + >>> G = nx.Graph([(0, 1), (1, 2), (5, 6), (3, 4)]) + >>> nx.number_connected_components(G) + 3 + + See Also + -------- + connected_components + number_weakly_connected_components + number_strongly_connected_components + + Notes + ----- + For undirected graphs only. + + """ + return sum(1 for cc in connected_components(G)) + + +@not_implemented_for("directed") +@nx._dispatchable +def is_connected(G): + """Returns True if the graph is connected, False otherwise. + + Parameters + ---------- + G : NetworkX Graph + An undirected graph. + + Returns + ------- + connected : bool + True if the graph is connected, false otherwise. + + Raises + ------ + NetworkXNotImplemented + If G is directed. + + Examples + -------- + >>> G = nx.path_graph(4) + >>> print(nx.is_connected(G)) + True + + See Also + -------- + is_strongly_connected + is_weakly_connected + is_semiconnected + is_biconnected + connected_components + + Notes + ----- + For undirected graphs only. + + """ + n = len(G) + if n == 0: + raise nx.NetworkXPointlessConcept( + "Connectivity is undefined for the null graph." + ) + return sum(1 for node in _plain_bfs(G, n, arbitrary_element(G))) == len(G) + + +@not_implemented_for("directed") +@nx._dispatchable +def node_connected_component(G, n): + """Returns the set of nodes in the component of graph containing node n. + + Parameters + ---------- + G : NetworkX Graph + An undirected graph. + + n : node label + A node in G + + Returns + ------- + comp : set + A set of nodes in the component of G containing node n. + + Raises + ------ + NetworkXNotImplemented + If G is directed. + + Examples + -------- + >>> G = nx.Graph([(0, 1), (1, 2), (5, 6), (3, 4)]) + >>> nx.node_connected_component(G, 0) # nodes of component that contains node 0 + {0, 1, 2} + + See Also + -------- + connected_components + + Notes + ----- + For undirected graphs only. + + """ + return _plain_bfs(G, len(G), n) + + +def _plain_bfs(G, n, source): + """A fast BFS node generator""" + adj = G._adj + seen = {source} + nextlevel = [source] + while nextlevel: + thislevel = nextlevel + nextlevel = [] + for v in thislevel: + for w in adj[v]: + if w not in seen: + seen.add(w) + nextlevel.append(w) + if len(seen) == n: + return seen + return seen diff --git a/lib/python3.10/site-packages/networkx/algorithms/components/semiconnected.py b/lib/python3.10/site-packages/networkx/algorithms/components/semiconnected.py new file mode 100644 index 0000000000000000000000000000000000000000..9ca5d762ca882524d1406f9295fa3a238fedb724 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/components/semiconnected.py @@ -0,0 +1,71 @@ +"""Semiconnectedness.""" + +import networkx as nx +from networkx.utils import not_implemented_for, pairwise + +__all__ = ["is_semiconnected"] + + +@not_implemented_for("undirected") +@nx._dispatchable +def is_semiconnected(G): + r"""Returns True if the graph is semiconnected, False otherwise. + + A graph is semiconnected if and only if for any pair of nodes, either one + is reachable from the other, or they are mutually reachable. + + This function uses a theorem that states that a DAG is semiconnected + if for any topological sort, for node $v_n$ in that sort, there is an + edge $(v_i, v_{i+1})$. That allows us to check if a non-DAG `G` is + semiconnected by condensing the graph: i.e. constructing a new graph `H` + with nodes being the strongly connected components of `G`, and edges + (scc_1, scc_2) if there is a edge $(v_1, v_2)$ in `G` for some + $v_1 \in scc_1$ and $v_2 \in scc_2$. That results in a DAG, so we compute + the topological sort of `H` and check if for every $n$ there is an edge + $(scc_n, scc_{n+1})$. + + Parameters + ---------- + G : NetworkX graph + A directed graph. + + Returns + ------- + semiconnected : bool + True if the graph is semiconnected, False otherwise. + + Raises + ------ + NetworkXNotImplemented + If the input graph is undirected. + + NetworkXPointlessConcept + If the graph is empty. + + Examples + -------- + >>> G = nx.path_graph(4, create_using=nx.DiGraph()) + >>> print(nx.is_semiconnected(G)) + True + >>> G = nx.DiGraph([(1, 2), (3, 2)]) + >>> print(nx.is_semiconnected(G)) + False + + See Also + -------- + is_strongly_connected + is_weakly_connected + is_connected + is_biconnected + """ + if len(G) == 0: + raise nx.NetworkXPointlessConcept( + "Connectivity is undefined for the null graph." + ) + + if not nx.is_weakly_connected(G): + return False + + H = nx.condensation(G) + + return all(H.has_edge(u, v) for u, v in pairwise(nx.topological_sort(H))) diff --git a/lib/python3.10/site-packages/networkx/algorithms/components/strongly_connected.py b/lib/python3.10/site-packages/networkx/algorithms/components/strongly_connected.py new file mode 100644 index 0000000000000000000000000000000000000000..393728ffe1f25a077aee6691fe913a81570ef0f1 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/components/strongly_connected.py @@ -0,0 +1,351 @@ +"""Strongly connected components.""" + +import networkx as nx +from networkx.utils.decorators import not_implemented_for + +__all__ = [ + "number_strongly_connected_components", + "strongly_connected_components", + "is_strongly_connected", + "kosaraju_strongly_connected_components", + "condensation", +] + + +@not_implemented_for("undirected") +@nx._dispatchable +def strongly_connected_components(G): + """Generate nodes in strongly connected components of graph. + + Parameters + ---------- + G : NetworkX Graph + A directed graph. + + Returns + ------- + comp : generator of sets + A generator of sets of nodes, one for each strongly connected + component of G. + + Raises + ------ + NetworkXNotImplemented + If G is undirected. + + Examples + -------- + Generate a sorted list of strongly connected components, largest first. + + >>> G = nx.cycle_graph(4, create_using=nx.DiGraph()) + >>> nx.add_cycle(G, [10, 11, 12]) + >>> [ + ... len(c) + ... for c in sorted(nx.strongly_connected_components(G), key=len, reverse=True) + ... ] + [4, 3] + + If you only want the largest component, it's more efficient to + use max instead of sort. + + >>> largest = max(nx.strongly_connected_components(G), key=len) + + See Also + -------- + connected_components + weakly_connected_components + kosaraju_strongly_connected_components + + Notes + ----- + Uses Tarjan's algorithm[1]_ with Nuutila's modifications[2]_. + Nonrecursive version of algorithm. + + References + ---------- + .. [1] Depth-first search and linear graph algorithms, R. Tarjan + SIAM Journal of Computing 1(2):146-160, (1972). + + .. [2] On finding the strongly connected components in a directed graph. + E. Nuutila and E. Soisalon-Soinen + Information Processing Letters 49(1): 9-14, (1994).. + + """ + preorder = {} + lowlink = {} + scc_found = set() + scc_queue = [] + i = 0 # Preorder counter + neighbors = {v: iter(G[v]) for v in G} + for source in G: + if source not in scc_found: + queue = [source] + while queue: + v = queue[-1] + if v not in preorder: + i = i + 1 + preorder[v] = i + done = True + for w in neighbors[v]: + if w not in preorder: + queue.append(w) + done = False + break + if done: + lowlink[v] = preorder[v] + for w in G[v]: + if w not in scc_found: + if preorder[w] > preorder[v]: + lowlink[v] = min([lowlink[v], lowlink[w]]) + else: + lowlink[v] = min([lowlink[v], preorder[w]]) + queue.pop() + if lowlink[v] == preorder[v]: + scc = {v} + while scc_queue and preorder[scc_queue[-1]] > preorder[v]: + k = scc_queue.pop() + scc.add(k) + scc_found.update(scc) + yield scc + else: + scc_queue.append(v) + + +@not_implemented_for("undirected") +@nx._dispatchable +def kosaraju_strongly_connected_components(G, source=None): + """Generate nodes in strongly connected components of graph. + + Parameters + ---------- + G : NetworkX Graph + A directed graph. + + Returns + ------- + comp : generator of sets + A generator of sets of nodes, one for each strongly connected + component of G. + + Raises + ------ + NetworkXNotImplemented + If G is undirected. + + Examples + -------- + Generate a sorted list of strongly connected components, largest first. + + >>> G = nx.cycle_graph(4, create_using=nx.DiGraph()) + >>> nx.add_cycle(G, [10, 11, 12]) + >>> [ + ... len(c) + ... for c in sorted( + ... nx.kosaraju_strongly_connected_components(G), key=len, reverse=True + ... ) + ... ] + [4, 3] + + If you only want the largest component, it's more efficient to + use max instead of sort. + + >>> largest = max(nx.kosaraju_strongly_connected_components(G), key=len) + + See Also + -------- + strongly_connected_components + + Notes + ----- + Uses Kosaraju's algorithm. + + """ + post = list(nx.dfs_postorder_nodes(G.reverse(copy=False), source=source)) + + seen = set() + while post: + r = post.pop() + if r in seen: + continue + c = nx.dfs_preorder_nodes(G, r) + new = {v for v in c if v not in seen} + seen.update(new) + yield new + + +@not_implemented_for("undirected") +@nx._dispatchable +def number_strongly_connected_components(G): + """Returns number of strongly connected components in graph. + + Parameters + ---------- + G : NetworkX graph + A directed graph. + + Returns + ------- + n : integer + Number of strongly connected components + + Raises + ------ + NetworkXNotImplemented + If G is undirected. + + Examples + -------- + >>> G = nx.DiGraph( + ... [(0, 1), (1, 2), (2, 0), (2, 3), (4, 5), (3, 4), (5, 6), (6, 3), (6, 7)] + ... ) + >>> nx.number_strongly_connected_components(G) + 3 + + See Also + -------- + strongly_connected_components + number_connected_components + number_weakly_connected_components + + Notes + ----- + For directed graphs only. + """ + return sum(1 for scc in strongly_connected_components(G)) + + +@not_implemented_for("undirected") +@nx._dispatchable +def is_strongly_connected(G): + """Test directed graph for strong connectivity. + + A directed graph is strongly connected if and only if every vertex in + the graph is reachable from every other vertex. + + Parameters + ---------- + G : NetworkX Graph + A directed graph. + + Returns + ------- + connected : bool + True if the graph is strongly connected, False otherwise. + + Examples + -------- + >>> G = nx.DiGraph([(0, 1), (1, 2), (2, 3), (3, 0), (2, 4), (4, 2)]) + >>> nx.is_strongly_connected(G) + True + >>> G.remove_edge(2, 3) + >>> nx.is_strongly_connected(G) + False + + Raises + ------ + NetworkXNotImplemented + If G is undirected. + + See Also + -------- + is_weakly_connected + is_semiconnected + is_connected + is_biconnected + strongly_connected_components + + Notes + ----- + For directed graphs only. + """ + if len(G) == 0: + raise nx.NetworkXPointlessConcept( + """Connectivity is undefined for the null graph.""" + ) + + return len(next(strongly_connected_components(G))) == len(G) + + +@not_implemented_for("undirected") +@nx._dispatchable(returns_graph=True) +def condensation(G, scc=None): + """Returns the condensation of G. + + The condensation of G is the graph with each of the strongly connected + components contracted into a single node. + + Parameters + ---------- + G : NetworkX DiGraph + A directed graph. + + scc: list or generator (optional, default=None) + Strongly connected components. If provided, the elements in + `scc` must partition the nodes in `G`. If not provided, it will be + calculated as scc=nx.strongly_connected_components(G). + + Returns + ------- + C : NetworkX DiGraph + The condensation graph C of G. The node labels are integers + corresponding to the index of the component in the list of + strongly connected components of G. C has a graph attribute named + 'mapping' with a dictionary mapping the original nodes to the + nodes in C to which they belong. Each node in C also has a node + attribute 'members' with the set of original nodes in G that + form the SCC that the node in C represents. + + Raises + ------ + NetworkXNotImplemented + If G is undirected. + + Examples + -------- + Contracting two sets of strongly connected nodes into two distinct SCC + using the barbell graph. + + >>> G = nx.barbell_graph(4, 0) + >>> G.remove_edge(3, 4) + >>> G = nx.DiGraph(G) + >>> H = nx.condensation(G) + >>> H.nodes.data() + NodeDataView({0: {'members': {0, 1, 2, 3}}, 1: {'members': {4, 5, 6, 7}}}) + >>> H.graph["mapping"] + {0: 0, 1: 0, 2: 0, 3: 0, 4: 1, 5: 1, 6: 1, 7: 1} + + Contracting a complete graph into one single SCC. + + >>> G = nx.complete_graph(7, create_using=nx.DiGraph) + >>> H = nx.condensation(G) + >>> H.nodes + NodeView((0,)) + >>> H.nodes.data() + NodeDataView({0: {'members': {0, 1, 2, 3, 4, 5, 6}}}) + + Notes + ----- + After contracting all strongly connected components to a single node, + the resulting graph is a directed acyclic graph. + + """ + if scc is None: + scc = nx.strongly_connected_components(G) + mapping = {} + members = {} + C = nx.DiGraph() + # Add mapping dict as graph attribute + C.graph["mapping"] = mapping + if len(G) == 0: + return C + for i, component in enumerate(scc): + members[i] = component + mapping.update((n, i) for n in component) + number_of_components = i + 1 + C.add_nodes_from(range(number_of_components)) + C.add_edges_from( + (mapping[u], mapping[v]) for u, v in G.edges() if mapping[u] != mapping[v] + ) + # Add a list of members (ie original nodes) to each node (ie scc) in C. + nx.set_node_attributes(C, members, "members") + return C diff --git a/lib/python3.10/site-packages/networkx/algorithms/components/tests/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/components/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/networkx/algorithms/components/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/components/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2349fdac1c560a0d0233577a99761c50dc0e0de Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/components/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/components/tests/__pycache__/test_attracting.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/components/tests/__pycache__/test_attracting.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb9ff910a793a8b9852eb1aa8f8626dace0a7ebc Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/components/tests/__pycache__/test_attracting.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/components/tests/__pycache__/test_biconnected.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/components/tests/__pycache__/test_biconnected.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eff2ac68047f2f78679ed53f5b592397ed3af6c5 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/components/tests/__pycache__/test_biconnected.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/components/tests/__pycache__/test_connected.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/components/tests/__pycache__/test_connected.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fefc39fb7ecf39a8b607c78a9958d792222aa1b3 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/components/tests/__pycache__/test_connected.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/components/tests/__pycache__/test_semiconnected.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/components/tests/__pycache__/test_semiconnected.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a1d7c646006b9231f7ed80133986010d4e69ab6 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/components/tests/__pycache__/test_semiconnected.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/components/tests/__pycache__/test_strongly_connected.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/components/tests/__pycache__/test_strongly_connected.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ea36fa47c341939c21b0c4629d4cdfdf6ae4b09 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/components/tests/__pycache__/test_strongly_connected.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/components/tests/__pycache__/test_weakly_connected.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/components/tests/__pycache__/test_weakly_connected.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f1ec77090b8018a9ceecf66626311dad8378204 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/components/tests/__pycache__/test_weakly_connected.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/components/tests/test_attracting.py b/lib/python3.10/site-packages/networkx/algorithms/components/tests/test_attracting.py new file mode 100644 index 0000000000000000000000000000000000000000..336c40ddc27162c1c2f5cc245f4fc840311506b5 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/components/tests/test_attracting.py @@ -0,0 +1,70 @@ +import pytest + +import networkx as nx +from networkx import NetworkXNotImplemented + + +class TestAttractingComponents: + @classmethod + def setup_class(cls): + cls.G1 = nx.DiGraph() + cls.G1.add_edges_from( + [ + (5, 11), + (11, 2), + (11, 9), + (11, 10), + (7, 11), + (7, 8), + (8, 9), + (3, 8), + (3, 10), + ] + ) + cls.G2 = nx.DiGraph() + cls.G2.add_edges_from([(0, 1), (0, 2), (1, 1), (1, 2), (2, 1)]) + + cls.G3 = nx.DiGraph() + cls.G3.add_edges_from([(0, 1), (1, 2), (2, 1), (0, 3), (3, 4), (4, 3)]) + + cls.G4 = nx.DiGraph() + + def test_attracting_components(self): + ac = list(nx.attracting_components(self.G1)) + assert {2} in ac + assert {9} in ac + assert {10} in ac + + ac = list(nx.attracting_components(self.G2)) + ac = [tuple(sorted(x)) for x in ac] + assert ac == [(1, 2)] + + ac = list(nx.attracting_components(self.G3)) + ac = [tuple(sorted(x)) for x in ac] + assert (1, 2) in ac + assert (3, 4) in ac + assert len(ac) == 2 + + ac = list(nx.attracting_components(self.G4)) + assert ac == [] + + def test_number_attacting_components(self): + assert nx.number_attracting_components(self.G1) == 3 + assert nx.number_attracting_components(self.G2) == 1 + assert nx.number_attracting_components(self.G3) == 2 + assert nx.number_attracting_components(self.G4) == 0 + + def test_is_attracting_component(self): + assert not nx.is_attracting_component(self.G1) + assert not nx.is_attracting_component(self.G2) + assert not nx.is_attracting_component(self.G3) + g2 = self.G3.subgraph([1, 2]) + assert nx.is_attracting_component(g2) + assert not nx.is_attracting_component(self.G4) + + def test_connected_raise(self): + G = nx.Graph() + with pytest.raises(NetworkXNotImplemented): + next(nx.attracting_components(G)) + pytest.raises(NetworkXNotImplemented, nx.number_attracting_components, G) + pytest.raises(NetworkXNotImplemented, nx.is_attracting_component, G) diff --git a/lib/python3.10/site-packages/networkx/algorithms/components/tests/test_biconnected.py b/lib/python3.10/site-packages/networkx/algorithms/components/tests/test_biconnected.py new file mode 100644 index 0000000000000000000000000000000000000000..19d2d8831ced26a516d101e735b6701f39865c1b --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/components/tests/test_biconnected.py @@ -0,0 +1,248 @@ +import pytest + +import networkx as nx +from networkx import NetworkXNotImplemented + + +def assert_components_edges_equal(x, y): + sx = {frozenset(frozenset(e) for e in c) for c in x} + sy = {frozenset(frozenset(e) for e in c) for c in y} + assert sx == sy + + +def assert_components_equal(x, y): + sx = {frozenset(c) for c in x} + sy = {frozenset(c) for c in y} + assert sx == sy + + +def test_barbell(): + G = nx.barbell_graph(8, 4) + nx.add_path(G, [7, 20, 21, 22]) + nx.add_cycle(G, [22, 23, 24, 25]) + pts = set(nx.articulation_points(G)) + assert pts == {7, 8, 9, 10, 11, 12, 20, 21, 22} + + answer = [ + {12, 13, 14, 15, 16, 17, 18, 19}, + {0, 1, 2, 3, 4, 5, 6, 7}, + {22, 23, 24, 25}, + {11, 12}, + {10, 11}, + {9, 10}, + {8, 9}, + {7, 8}, + {21, 22}, + {20, 21}, + {7, 20}, + ] + assert_components_equal(list(nx.biconnected_components(G)), answer) + + G.add_edge(2, 17) + pts = set(nx.articulation_points(G)) + assert pts == {7, 20, 21, 22} + + +def test_articulation_points_repetitions(): + G = nx.Graph() + G.add_edges_from([(0, 1), (1, 2), (1, 3)]) + assert list(nx.articulation_points(G)) == [1] + + +def test_articulation_points_cycle(): + G = nx.cycle_graph(3) + nx.add_cycle(G, [1, 3, 4]) + pts = set(nx.articulation_points(G)) + assert pts == {1} + + +def test_is_biconnected(): + G = nx.cycle_graph(3) + assert nx.is_biconnected(G) + nx.add_cycle(G, [1, 3, 4]) + assert not nx.is_biconnected(G) + + +def test_empty_is_biconnected(): + G = nx.empty_graph(5) + assert not nx.is_biconnected(G) + G.add_edge(0, 1) + assert not nx.is_biconnected(G) + + +def test_biconnected_components_cycle(): + G = nx.cycle_graph(3) + nx.add_cycle(G, [1, 3, 4]) + answer = [{0, 1, 2}, {1, 3, 4}] + assert_components_equal(list(nx.biconnected_components(G)), answer) + + +def test_biconnected_components1(): + # graph example from + # https://web.archive.org/web/20121229123447/http://www.ibluemojo.com/school/articul_algorithm.html + edges = [ + (0, 1), + (0, 5), + (0, 6), + (0, 14), + (1, 5), + (1, 6), + (1, 14), + (2, 4), + (2, 10), + (3, 4), + (3, 15), + (4, 6), + (4, 7), + (4, 10), + (5, 14), + (6, 14), + (7, 9), + (8, 9), + (8, 12), + (8, 13), + (10, 15), + (11, 12), + (11, 13), + (12, 13), + ] + G = nx.Graph(edges) + pts = set(nx.articulation_points(G)) + assert pts == {4, 6, 7, 8, 9} + comps = list(nx.biconnected_component_edges(G)) + answer = [ + [(3, 4), (15, 3), (10, 15), (10, 4), (2, 10), (4, 2)], + [(13, 12), (13, 8), (11, 13), (12, 11), (8, 12)], + [(9, 8)], + [(7, 9)], + [(4, 7)], + [(6, 4)], + [(14, 0), (5, 1), (5, 0), (14, 5), (14, 1), (6, 14), (6, 0), (1, 6), (0, 1)], + ] + assert_components_edges_equal(comps, answer) + + +def test_biconnected_components2(): + G = nx.Graph() + nx.add_cycle(G, "ABC") + nx.add_cycle(G, "CDE") + nx.add_cycle(G, "FIJHG") + nx.add_cycle(G, "GIJ") + G.add_edge("E", "G") + comps = list(nx.biconnected_component_edges(G)) + answer = [ + [ + tuple("GF"), + tuple("FI"), + tuple("IG"), + tuple("IJ"), + tuple("JG"), + tuple("JH"), + tuple("HG"), + ], + [tuple("EG")], + [tuple("CD"), tuple("DE"), tuple("CE")], + [tuple("AB"), tuple("BC"), tuple("AC")], + ] + assert_components_edges_equal(comps, answer) + + +def test_biconnected_davis(): + D = nx.davis_southern_women_graph() + bcc = list(nx.biconnected_components(D))[0] + assert set(D) == bcc # All nodes in a giant bicomponent + # So no articulation points + assert len(list(nx.articulation_points(D))) == 0 + + +def test_biconnected_karate(): + K = nx.karate_club_graph() + answer = [ + { + 0, + 1, + 2, + 3, + 7, + 8, + 9, + 12, + 13, + 14, + 15, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + }, + {0, 4, 5, 6, 10, 16}, + {0, 11}, + ] + bcc = list(nx.biconnected_components(K)) + assert_components_equal(bcc, answer) + assert set(nx.articulation_points(K)) == {0} + + +def test_biconnected_eppstein(): + # tests from http://www.ics.uci.edu/~eppstein/PADS/Biconnectivity.py + G1 = nx.Graph( + { + 0: [1, 2, 5], + 1: [0, 5], + 2: [0, 3, 4], + 3: [2, 4, 5, 6], + 4: [2, 3, 5, 6], + 5: [0, 1, 3, 4], + 6: [3, 4], + } + ) + G2 = nx.Graph( + { + 0: [2, 5], + 1: [3, 8], + 2: [0, 3, 5], + 3: [1, 2, 6, 8], + 4: [7], + 5: [0, 2], + 6: [3, 8], + 7: [4], + 8: [1, 3, 6], + } + ) + assert nx.is_biconnected(G1) + assert not nx.is_biconnected(G2) + answer_G2 = [{1, 3, 6, 8}, {0, 2, 5}, {2, 3}, {4, 7}] + bcc = list(nx.biconnected_components(G2)) + assert_components_equal(bcc, answer_G2) + + +def test_null_graph(): + G = nx.Graph() + assert not nx.is_biconnected(G) + assert list(nx.biconnected_components(G)) == [] + assert list(nx.biconnected_component_edges(G)) == [] + assert list(nx.articulation_points(G)) == [] + + +def test_connected_raise(): + DG = nx.DiGraph() + with pytest.raises(NetworkXNotImplemented): + next(nx.biconnected_components(DG)) + with pytest.raises(NetworkXNotImplemented): + next(nx.biconnected_component_edges(DG)) + with pytest.raises(NetworkXNotImplemented): + next(nx.articulation_points(DG)) + pytest.raises(NetworkXNotImplemented, nx.is_biconnected, DG) diff --git a/lib/python3.10/site-packages/networkx/algorithms/components/tests/test_connected.py b/lib/python3.10/site-packages/networkx/algorithms/components/tests/test_connected.py new file mode 100644 index 0000000000000000000000000000000000000000..207214c1262ed58ac1152a5917a270514748dc0a --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/components/tests/test_connected.py @@ -0,0 +1,138 @@ +import pytest + +import networkx as nx +from networkx import NetworkXNotImplemented +from networkx import convert_node_labels_to_integers as cnlti +from networkx.classes.tests import dispatch_interface + + +class TestConnected: + @classmethod + def setup_class(cls): + G1 = cnlti(nx.grid_2d_graph(2, 2), first_label=0, ordering="sorted") + G2 = cnlti(nx.lollipop_graph(3, 3), first_label=4, ordering="sorted") + G3 = cnlti(nx.house_graph(), first_label=10, ordering="sorted") + cls.G = nx.union(G1, G2) + cls.G = nx.union(cls.G, G3) + cls.DG = nx.DiGraph([(1, 2), (1, 3), (2, 3)]) + cls.grid = cnlti(nx.grid_2d_graph(4, 4), first_label=1) + + cls.gc = [] + G = nx.DiGraph() + G.add_edges_from( + [ + (1, 2), + (2, 3), + (2, 8), + (3, 4), + (3, 7), + (4, 5), + (5, 3), + (5, 6), + (7, 4), + (7, 6), + (8, 1), + (8, 7), + ] + ) + C = [[3, 4, 5, 7], [1, 2, 8], [6]] + cls.gc.append((G, C)) + + G = nx.DiGraph() + G.add_edges_from([(1, 2), (1, 3), (1, 4), (4, 2), (3, 4), (2, 3)]) + C = [[2, 3, 4], [1]] + cls.gc.append((G, C)) + + G = nx.DiGraph() + G.add_edges_from([(1, 2), (2, 3), (3, 2), (2, 1)]) + C = [[1, 2, 3]] + cls.gc.append((G, C)) + + # Eppstein's tests + G = nx.DiGraph({0: [1], 1: [2, 3], 2: [4, 5], 3: [4, 5], 4: [6], 5: [], 6: []}) + C = [[0], [1], [2], [3], [4], [5], [6]] + cls.gc.append((G, C)) + + G = nx.DiGraph({0: [1], 1: [2, 3, 4], 2: [0, 3], 3: [4], 4: [3]}) + C = [[0, 1, 2], [3, 4]] + cls.gc.append((G, C)) + + G = nx.DiGraph() + C = [] + cls.gc.append((G, C)) + + def test_connected_components(self): + # Test duplicated below + cc = nx.connected_components + G = self.G + C = { + frozenset([0, 1, 2, 3]), + frozenset([4, 5, 6, 7, 8, 9]), + frozenset([10, 11, 12, 13, 14]), + } + assert {frozenset(g) for g in cc(G)} == C + + def test_connected_components_nx_loopback(self): + # This tests the @nx._dispatchable mechanism, treating nx.connected_components + # as if it were a re-implementation from another package. + # Test duplicated from above + cc = nx.connected_components + G = dispatch_interface.convert(self.G) + C = { + frozenset([0, 1, 2, 3]), + frozenset([4, 5, 6, 7, 8, 9]), + frozenset([10, 11, 12, 13, 14]), + } + if "nx_loopback" in nx.config.backends or not nx.config.backends: + # If `nx.config.backends` is empty, then `_dispatchable.__call__` takes a + # "fast path" and does not check graph inputs, so using an unknown backend + # here will still work. + assert {frozenset(g) for g in cc(G)} == C + else: + # This raises, because "nx_loopback" is not registered as a backend. + with pytest.raises( + ImportError, match="'nx_loopback' backend is not installed" + ): + cc(G) + + def test_number_connected_components(self): + ncc = nx.number_connected_components + assert ncc(self.G) == 3 + + def test_number_connected_components2(self): + ncc = nx.number_connected_components + assert ncc(self.grid) == 1 + + def test_connected_components2(self): + cc = nx.connected_components + G = self.grid + C = {frozenset([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16])} + assert {frozenset(g) for g in cc(G)} == C + + def test_node_connected_components(self): + ncc = nx.node_connected_component + G = self.grid + C = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + assert ncc(G, 1) == C + + def test_is_connected(self): + assert nx.is_connected(self.grid) + G = nx.Graph() + G.add_nodes_from([1, 2]) + assert not nx.is_connected(G) + + def test_connected_raise(self): + with pytest.raises(NetworkXNotImplemented): + next(nx.connected_components(self.DG)) + pytest.raises(NetworkXNotImplemented, nx.number_connected_components, self.DG) + pytest.raises(NetworkXNotImplemented, nx.node_connected_component, self.DG, 1) + pytest.raises(NetworkXNotImplemented, nx.is_connected, self.DG) + pytest.raises(nx.NetworkXPointlessConcept, nx.is_connected, nx.Graph()) + + def test_connected_mutability(self): + G = self.grid + seen = set() + for component in nx.connected_components(G): + assert len(seen & component) == 0 + seen.update(component) + component.clear() diff --git a/lib/python3.10/site-packages/networkx/algorithms/components/tests/test_semiconnected.py b/lib/python3.10/site-packages/networkx/algorithms/components/tests/test_semiconnected.py new file mode 100644 index 0000000000000000000000000000000000000000..6376bbfb12a061e1724b0c74d2614e116149d8bf --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/components/tests/test_semiconnected.py @@ -0,0 +1,55 @@ +from itertools import chain + +import pytest + +import networkx as nx + + +class TestIsSemiconnected: + def test_undirected(self): + pytest.raises(nx.NetworkXNotImplemented, nx.is_semiconnected, nx.Graph()) + pytest.raises(nx.NetworkXNotImplemented, nx.is_semiconnected, nx.MultiGraph()) + + def test_empty(self): + pytest.raises(nx.NetworkXPointlessConcept, nx.is_semiconnected, nx.DiGraph()) + pytest.raises( + nx.NetworkXPointlessConcept, nx.is_semiconnected, nx.MultiDiGraph() + ) + + def test_single_node_graph(self): + G = nx.DiGraph() + G.add_node(0) + assert nx.is_semiconnected(G) + + def test_path(self): + G = nx.path_graph(100, create_using=nx.DiGraph()) + assert nx.is_semiconnected(G) + G.add_edge(100, 99) + assert not nx.is_semiconnected(G) + + def test_cycle(self): + G = nx.cycle_graph(100, create_using=nx.DiGraph()) + assert nx.is_semiconnected(G) + G = nx.path_graph(100, create_using=nx.DiGraph()) + G.add_edge(0, 99) + assert nx.is_semiconnected(G) + + def test_tree(self): + G = nx.DiGraph() + G.add_edges_from( + chain.from_iterable([(i, 2 * i + 1), (i, 2 * i + 2)] for i in range(100)) + ) + assert not nx.is_semiconnected(G) + + def test_dumbbell(self): + G = nx.cycle_graph(100, create_using=nx.DiGraph()) + G.add_edges_from((i + 100, (i + 1) % 100 + 100) for i in range(100)) + assert not nx.is_semiconnected(G) # G is disconnected. + G.add_edge(100, 99) + assert nx.is_semiconnected(G) + + def test_alternating_path(self): + G = nx.DiGraph( + chain.from_iterable([(i, i - 1), (i, i + 1)] for i in range(0, 100, 2)) + ) + assert not nx.is_semiconnected(G) diff --git a/lib/python3.10/site-packages/networkx/algorithms/components/tests/test_strongly_connected.py b/lib/python3.10/site-packages/networkx/algorithms/components/tests/test_strongly_connected.py new file mode 100644 index 0000000000000000000000000000000000000000..27f40988265b61eec9edb2bde64433f7396022f0 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/components/tests/test_strongly_connected.py @@ -0,0 +1,193 @@ +import pytest + +import networkx as nx +from networkx import NetworkXNotImplemented + + +class TestStronglyConnected: + @classmethod + def setup_class(cls): + cls.gc = [] + G = nx.DiGraph() + G.add_edges_from( + [ + (1, 2), + (2, 3), + (2, 8), + (3, 4), + (3, 7), + (4, 5), + (5, 3), + (5, 6), + (7, 4), + (7, 6), + (8, 1), + (8, 7), + ] + ) + C = {frozenset([3, 4, 5, 7]), frozenset([1, 2, 8]), frozenset([6])} + cls.gc.append((G, C)) + + G = nx.DiGraph() + G.add_edges_from([(1, 2), (1, 3), (1, 4), (4, 2), (3, 4), (2, 3)]) + C = {frozenset([2, 3, 4]), frozenset([1])} + cls.gc.append((G, C)) + + G = nx.DiGraph() + G.add_edges_from([(1, 2), (2, 3), (3, 2), (2, 1)]) + C = {frozenset([1, 2, 3])} + cls.gc.append((G, C)) + + # Eppstein's tests + G = nx.DiGraph({0: [1], 1: [2, 3], 2: [4, 5], 3: [4, 5], 4: [6], 5: [], 6: []}) + C = { + frozenset([0]), + frozenset([1]), + frozenset([2]), + frozenset([3]), + frozenset([4]), + frozenset([5]), + frozenset([6]), + } + cls.gc.append((G, C)) + + G = nx.DiGraph({0: [1], 1: [2, 3, 4], 2: [0, 3], 3: [4], 4: [3]}) + C = {frozenset([0, 1, 2]), frozenset([3, 4])} + cls.gc.append((G, C)) + + def test_tarjan(self): + scc = nx.strongly_connected_components + for G, C in self.gc: + assert {frozenset(g) for g in scc(G)} == C + + def test_kosaraju(self): + scc = nx.kosaraju_strongly_connected_components + for G, C in self.gc: + assert {frozenset(g) for g in scc(G)} == C + + def test_number_strongly_connected_components(self): + ncc = nx.number_strongly_connected_components + for G, C in self.gc: + assert ncc(G) == len(C) + + def test_is_strongly_connected(self): + for G, C in self.gc: + if len(C) == 1: + assert nx.is_strongly_connected(G) + else: + assert not nx.is_strongly_connected(G) + + def test_contract_scc1(self): + G = nx.DiGraph() + G.add_edges_from( + [ + (1, 2), + (2, 3), + (2, 11), + (2, 12), + (3, 4), + (4, 3), + (4, 5), + (5, 6), + (6, 5), + (6, 7), + (7, 8), + (7, 9), + (7, 10), + (8, 9), + (9, 7), + (10, 6), + (11, 2), + (11, 4), + (11, 6), + (12, 6), + (12, 11), + ] + ) + scc = list(nx.strongly_connected_components(G)) + cG = nx.condensation(G, scc) + # DAG + assert nx.is_directed_acyclic_graph(cG) + # nodes + assert sorted(cG.nodes()) == [0, 1, 2, 3] + # edges + mapping = {} + for i, component in enumerate(scc): + for n in component: + mapping[n] = i + edge = (mapping[2], mapping[3]) + assert cG.has_edge(*edge) + edge = (mapping[2], mapping[5]) + assert cG.has_edge(*edge) + edge = (mapping[3], mapping[5]) + assert cG.has_edge(*edge) + + def test_contract_scc_isolate(self): + # Bug found and fixed in [1687]. + G = nx.DiGraph() + G.add_edge(1, 2) + G.add_edge(2, 1) + scc = list(nx.strongly_connected_components(G)) + cG = nx.condensation(G, scc) + assert list(cG.nodes()) == [0] + assert list(cG.edges()) == [] + + def test_contract_scc_edge(self): + G = nx.DiGraph() + G.add_edge(1, 2) + G.add_edge(2, 1) + G.add_edge(2, 3) + G.add_edge(3, 4) + G.add_edge(4, 3) + scc = list(nx.strongly_connected_components(G)) + cG = nx.condensation(G, scc) + assert sorted(cG.nodes()) == [0, 1] + if 1 in scc[0]: + edge = (0, 1) + else: + edge = (1, 0) + assert list(cG.edges()) == [edge] + + def test_condensation_mapping_and_members(self): + G, C = self.gc[1] + C = sorted(C, key=len, reverse=True) + cG = nx.condensation(G) + mapping = cG.graph["mapping"] + assert all(n in G for n in mapping) + assert all(0 == cN for n, cN in mapping.items() if n in C[0]) + assert all(1 == cN for n, cN in mapping.items() if n in C[1]) + for n, d in cG.nodes(data=True): + assert set(C[n]) == cG.nodes[n]["members"] + + def test_null_graph(self): + G = nx.DiGraph() + assert list(nx.strongly_connected_components(G)) == [] + assert list(nx.kosaraju_strongly_connected_components(G)) == [] + assert len(nx.condensation(G)) == 0 + pytest.raises( + nx.NetworkXPointlessConcept, nx.is_strongly_connected, nx.DiGraph() + ) + + def test_connected_raise(self): + G = nx.Graph() + with pytest.raises(NetworkXNotImplemented): + next(nx.strongly_connected_components(G)) + with pytest.raises(NetworkXNotImplemented): + next(nx.kosaraju_strongly_connected_components(G)) + pytest.raises(NetworkXNotImplemented, nx.is_strongly_connected, G) + pytest.raises(NetworkXNotImplemented, nx.condensation, G) + + strong_cc_methods = ( + nx.strongly_connected_components, + nx.kosaraju_strongly_connected_components, + ) + + @pytest.mark.parametrize("get_components", strong_cc_methods) + def test_connected_mutability(self, get_components): + DG = nx.path_graph(5, create_using=nx.DiGraph) + G = nx.disjoint_union(DG, DG) + seen = set() + for component in get_components(G): + assert len(seen & component) == 0 + seen.update(component) + component.clear() diff --git a/lib/python3.10/site-packages/networkx/algorithms/components/tests/test_weakly_connected.py b/lib/python3.10/site-packages/networkx/algorithms/components/tests/test_weakly_connected.py new file mode 100644 index 0000000000000000000000000000000000000000..f014478930f598b02e6852e3109978288d023dfc --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/components/tests/test_weakly_connected.py @@ -0,0 +1,96 @@ +import pytest + +import networkx as nx +from networkx import NetworkXNotImplemented + + +class TestWeaklyConnected: + @classmethod + def setup_class(cls): + cls.gc = [] + G = nx.DiGraph() + G.add_edges_from( + [ + (1, 2), + (2, 3), + (2, 8), + (3, 4), + (3, 7), + (4, 5), + (5, 3), + (5, 6), + (7, 4), + (7, 6), + (8, 1), + (8, 7), + ] + ) + C = [[3, 4, 5, 7], [1, 2, 8], [6]] + cls.gc.append((G, C)) + + G = nx.DiGraph() + G.add_edges_from([(1, 2), (1, 3), (1, 4), (4, 2), (3, 4), (2, 3)]) + C = [[2, 3, 4], [1]] + cls.gc.append((G, C)) + + G = nx.DiGraph() + G.add_edges_from([(1, 2), (2, 3), (3, 2), (2, 1)]) + C = [[1, 2, 3]] + cls.gc.append((G, C)) + + # Eppstein's tests + G = nx.DiGraph({0: [1], 1: [2, 3], 2: [4, 5], 3: [4, 5], 4: [6], 5: [], 6: []}) + C = [[0], [1], [2], [3], [4], [5], [6]] + cls.gc.append((G, C)) + + G = nx.DiGraph({0: [1], 1: [2, 3, 4], 2: [0, 3], 3: [4], 4: [3]}) + C = [[0, 1, 2], [3, 4]] + cls.gc.append((G, C)) + + def test_weakly_connected_components(self): + for G, C in self.gc: + U = G.to_undirected() + w = {frozenset(g) for g in nx.weakly_connected_components(G)} + c = {frozenset(g) for g in nx.connected_components(U)} + assert w == c + + def test_number_weakly_connected_components(self): + for G, C in self.gc: + U = G.to_undirected() + w = nx.number_weakly_connected_components(G) + c = nx.number_connected_components(U) + assert w == c + + def test_is_weakly_connected(self): + for G, C in self.gc: + U = G.to_undirected() + assert nx.is_weakly_connected(G) == nx.is_connected(U) + + def test_null_graph(self): + G = nx.DiGraph() + assert list(nx.weakly_connected_components(G)) == [] + assert nx.number_weakly_connected_components(G) == 0 + with pytest.raises(nx.NetworkXPointlessConcept): + next(nx.is_weakly_connected(G)) + + def test_connected_raise(self): + G = nx.Graph() + with pytest.raises(NetworkXNotImplemented): + next(nx.weakly_connected_components(G)) + pytest.raises(NetworkXNotImplemented, nx.number_weakly_connected_components, G) + pytest.raises(NetworkXNotImplemented, nx.is_weakly_connected, G) + + def test_connected_mutability(self): + DG = nx.path_graph(5, create_using=nx.DiGraph) + G = nx.disjoint_union(DG, DG) + seen = set() + for component in nx.weakly_connected_components(G): + assert len(seen & component) == 0 + seen.update(component) + component.clear() + + +def test_is_weakly_connected_empty_graph_raises(): + G = nx.DiGraph() + with pytest.raises(nx.NetworkXPointlessConcept, match="Connectivity is undefined"): + nx.is_weakly_connected(G) diff --git a/lib/python3.10/site-packages/networkx/algorithms/components/weakly_connected.py b/lib/python3.10/site-packages/networkx/algorithms/components/weakly_connected.py new file mode 100644 index 0000000000000000000000000000000000000000..ecfac50a75177a7a87e41430953276e29778d6e0 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/components/weakly_connected.py @@ -0,0 +1,197 @@ +"""Weakly connected components.""" + +import networkx as nx +from networkx.utils.decorators import not_implemented_for + +__all__ = [ + "number_weakly_connected_components", + "weakly_connected_components", + "is_weakly_connected", +] + + +@not_implemented_for("undirected") +@nx._dispatchable +def weakly_connected_components(G): + """Generate weakly connected components of G. + + Parameters + ---------- + G : NetworkX graph + A directed graph + + Returns + ------- + comp : generator of sets + A generator of sets of nodes, one for each weakly connected + component of G. + + Raises + ------ + NetworkXNotImplemented + If G is undirected. + + Examples + -------- + Generate a sorted list of weakly connected components, largest first. + + >>> G = nx.path_graph(4, create_using=nx.DiGraph()) + >>> nx.add_path(G, [10, 11, 12]) + >>> [ + ... len(c) + ... for c in sorted(nx.weakly_connected_components(G), key=len, reverse=True) + ... ] + [4, 3] + + If you only want the largest component, it's more efficient to + use max instead of sort: + + >>> largest_cc = max(nx.weakly_connected_components(G), key=len) + + See Also + -------- + connected_components + strongly_connected_components + + Notes + ----- + For directed graphs only. + + """ + seen = set() + n = len(G) # must be outside the loop to avoid performance hit with graph views + for v in G: + if v not in seen: + c = set(_plain_bfs(G, n, v)) + seen.update(c) + yield c + + +@not_implemented_for("undirected") +@nx._dispatchable +def number_weakly_connected_components(G): + """Returns the number of weakly connected components in G. + + Parameters + ---------- + G : NetworkX graph + A directed graph. + + Returns + ------- + n : integer + Number of weakly connected components + + Raises + ------ + NetworkXNotImplemented + If G is undirected. + + Examples + -------- + >>> G = nx.DiGraph([(0, 1), (2, 1), (3, 4)]) + >>> nx.number_weakly_connected_components(G) + 2 + + See Also + -------- + weakly_connected_components + number_connected_components + number_strongly_connected_components + + Notes + ----- + For directed graphs only. + + """ + return sum(1 for wcc in weakly_connected_components(G)) + + +@not_implemented_for("undirected") +@nx._dispatchable +def is_weakly_connected(G): + """Test directed graph for weak connectivity. + + A directed graph is weakly connected if and only if the graph + is connected when the direction of the edge between nodes is ignored. + + Note that if a graph is strongly connected (i.e. the graph is connected + even when we account for directionality), it is by definition weakly + connected as well. + + Parameters + ---------- + G : NetworkX Graph + A directed graph. + + Returns + ------- + connected : bool + True if the graph is weakly connected, False otherwise. + + Raises + ------ + NetworkXNotImplemented + If G is undirected. + + Examples + -------- + >>> G = nx.DiGraph([(0, 1), (2, 1)]) + >>> G.add_node(3) + >>> nx.is_weakly_connected(G) # node 3 is not connected to the graph + False + >>> G.add_edge(2, 3) + >>> nx.is_weakly_connected(G) + True + + See Also + -------- + is_strongly_connected + is_semiconnected + is_connected + is_biconnected + weakly_connected_components + + Notes + ----- + For directed graphs only. + + """ + if len(G) == 0: + raise nx.NetworkXPointlessConcept( + """Connectivity is undefined for the null graph.""" + ) + + return len(next(weakly_connected_components(G))) == len(G) + + +def _plain_bfs(G, n, source): + """A fast BFS node generator + + The direction of the edge between nodes is ignored. + + For directed graphs only. + + """ + Gsucc = G._succ + Gpred = G._pred + seen = {source} + nextlevel = [source] + + yield source + while nextlevel: + thislevel = nextlevel + nextlevel = [] + for v in thislevel: + for w in Gsucc[v]: + if w not in seen: + seen.add(w) + nextlevel.append(w) + yield w + for w in Gpred[v]: + if w not in seen: + seen.add(w) + nextlevel.append(w) + yield w + if len(seen) == n: + return diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/connectivity/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d08a360628d4604bb37d350746e5c9796fe31d06 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/connectivity/__init__.py @@ -0,0 +1,11 @@ +"""Connectivity and cut algorithms""" + +from .connectivity import * +from .cuts import * +from .edge_augmentation import * +from .edge_kcomponents import * +from .disjoint_paths import * +from .kcomponents import * +from .kcutsets import * +from .stoerwagner import * +from .utils import * diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/connectivity/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a50f12801ceea27894dce48dc2023675f607c8d8 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/connectivity/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/__pycache__/connectivity.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/connectivity/__pycache__/connectivity.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9057bd16e137e761b2a498df3f5f0de0fdabb6cd Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/connectivity/__pycache__/connectivity.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/__pycache__/cuts.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/connectivity/__pycache__/cuts.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a38e07070ee269e6647b908c8a8782e0f90b081 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/connectivity/__pycache__/cuts.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/__pycache__/disjoint_paths.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/connectivity/__pycache__/disjoint_paths.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..607e06fdbdfefe589f270623c30558decf8a23f5 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/connectivity/__pycache__/disjoint_paths.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/__pycache__/edge_augmentation.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/connectivity/__pycache__/edge_augmentation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08f78fa7a2c16c9c64e4a675eda50de0b3f11c70 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/connectivity/__pycache__/edge_augmentation.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/__pycache__/edge_kcomponents.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/connectivity/__pycache__/edge_kcomponents.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3626fe5352165853826c4e66e026b65e29ff55a9 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/connectivity/__pycache__/edge_kcomponents.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/__pycache__/kcomponents.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/connectivity/__pycache__/kcomponents.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee7e29746b5db651df43c35b7f52798d5fdb9d7c Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/connectivity/__pycache__/kcomponents.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/__pycache__/kcutsets.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/connectivity/__pycache__/kcutsets.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7658776262d4f07c060c1986ab42304e25f18a0 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/connectivity/__pycache__/kcutsets.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/__pycache__/stoerwagner.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/connectivity/__pycache__/stoerwagner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6cb263e3134b9c6c0738deae1fdb05b70bc1e75 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/connectivity/__pycache__/stoerwagner.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/__pycache__/utils.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/connectivity/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9375f960a5eeead490e0b470d410513f56718038 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/connectivity/__pycache__/utils.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/connectivity.py b/lib/python3.10/site-packages/networkx/algorithms/connectivity/connectivity.py new file mode 100644 index 0000000000000000000000000000000000000000..2f85c865d678157d236a6e1168fabcdc7192b0ce --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/connectivity/connectivity.py @@ -0,0 +1,811 @@ +""" +Flow based connectivity algorithms +""" + +import itertools +from operator import itemgetter + +import networkx as nx + +# Define the default maximum flow function to use in all flow based +# connectivity algorithms. +from networkx.algorithms.flow import ( + boykov_kolmogorov, + build_residual_network, + dinitz, + edmonds_karp, + preflow_push, + shortest_augmenting_path, +) + +default_flow_func = edmonds_karp + +from .utils import build_auxiliary_edge_connectivity, build_auxiliary_node_connectivity + +__all__ = [ + "average_node_connectivity", + "local_node_connectivity", + "node_connectivity", + "local_edge_connectivity", + "edge_connectivity", + "all_pairs_node_connectivity", +] + + +@nx._dispatchable(graphs={"G": 0, "auxiliary?": 4}, preserve_graph_attrs={"auxiliary"}) +def local_node_connectivity( + G, s, t, flow_func=None, auxiliary=None, residual=None, cutoff=None +): + r"""Computes local node connectivity for nodes s and t. + + Local node connectivity for two non adjacent nodes s and t is the + minimum number of nodes that must be removed (along with their incident + edges) to disconnect them. + + This is a flow based implementation of node connectivity. We compute the + maximum flow on an auxiliary digraph build from the original input + graph (see below for details). + + Parameters + ---------- + G : NetworkX graph + Undirected graph + + s : node + Source node + + t : node + Target node + + flow_func : function + A function for computing the maximum flow among a pair of nodes. + The function has to accept at least three parameters: a Digraph, + a source node, and a target node. And return a residual network + that follows NetworkX conventions (see :meth:`maximum_flow` for + details). If flow_func is None, the default maximum flow function + (:meth:`edmonds_karp`) is used. See below for details. The choice + of the default function may change from version to version and + should not be relied on. Default value: None. + + auxiliary : NetworkX DiGraph + Auxiliary digraph to compute flow based node connectivity. It has + to have a graph attribute called mapping with a dictionary mapping + node names in G and in the auxiliary digraph. If provided + it will be reused instead of recreated. Default value: None. + + residual : NetworkX DiGraph + Residual network to compute maximum flow. If provided it will be + reused instead of recreated. Default value: None. + + cutoff : integer, float, or None (default: None) + If specified, the maximum flow algorithm will terminate when the + flow value reaches or exceeds the cutoff. This only works for flows + that support the cutoff parameter (most do) and is ignored otherwise. + + Returns + ------- + K : integer + local node connectivity for nodes s and t + + Examples + -------- + This function is not imported in the base NetworkX namespace, so you + have to explicitly import it from the connectivity package: + + >>> from networkx.algorithms.connectivity import local_node_connectivity + + We use in this example the platonic icosahedral graph, which has node + connectivity 5. + + >>> G = nx.icosahedral_graph() + >>> local_node_connectivity(G, 0, 6) + 5 + + If you need to compute local connectivity on several pairs of + nodes in the same graph, it is recommended that you reuse the + data structures that NetworkX uses in the computation: the + auxiliary digraph for node connectivity, and the residual + network for the underlying maximum flow computation. + + Example of how to compute local node connectivity among + all pairs of nodes of the platonic icosahedral graph reusing + the data structures. + + >>> import itertools + >>> # You also have to explicitly import the function for + >>> # building the auxiliary digraph from the connectivity package + >>> from networkx.algorithms.connectivity import build_auxiliary_node_connectivity + >>> H = build_auxiliary_node_connectivity(G) + >>> # And the function for building the residual network from the + >>> # flow package + >>> from networkx.algorithms.flow import build_residual_network + >>> # Note that the auxiliary digraph has an edge attribute named capacity + >>> R = build_residual_network(H, "capacity") + >>> result = dict.fromkeys(G, dict()) + >>> # Reuse the auxiliary digraph and the residual network by passing them + >>> # as parameters + >>> for u, v in itertools.combinations(G, 2): + ... k = local_node_connectivity(G, u, v, auxiliary=H, residual=R) + ... result[u][v] = k + >>> all(result[u][v] == 5 for u, v in itertools.combinations(G, 2)) + True + + You can also use alternative flow algorithms for computing node + connectivity. For instance, in dense networks the algorithm + :meth:`shortest_augmenting_path` will usually perform better than + the default :meth:`edmonds_karp` which is faster for sparse + networks with highly skewed degree distributions. Alternative flow + functions have to be explicitly imported from the flow package. + + >>> from networkx.algorithms.flow import shortest_augmenting_path + >>> local_node_connectivity(G, 0, 6, flow_func=shortest_augmenting_path) + 5 + + Notes + ----- + This is a flow based implementation of node connectivity. We compute the + maximum flow using, by default, the :meth:`edmonds_karp` algorithm (see: + :meth:`maximum_flow`) on an auxiliary digraph build from the original + input graph: + + For an undirected graph G having `n` nodes and `m` edges we derive a + directed graph H with `2n` nodes and `2m+n` arcs by replacing each + original node `v` with two nodes `v_A`, `v_B` linked by an (internal) + arc in H. Then for each edge (`u`, `v`) in G we add two arcs + (`u_B`, `v_A`) and (`v_B`, `u_A`) in H. Finally we set the attribute + capacity = 1 for each arc in H [1]_ . + + For a directed graph G having `n` nodes and `m` arcs we derive a + directed graph H with `2n` nodes and `m+n` arcs by replacing each + original node `v` with two nodes `v_A`, `v_B` linked by an (internal) + arc (`v_A`, `v_B`) in H. Then for each arc (`u`, `v`) in G we add one arc + (`u_B`, `v_A`) in H. Finally we set the attribute capacity = 1 for + each arc in H. + + This is equal to the local node connectivity because the value of + a maximum s-t-flow is equal to the capacity of a minimum s-t-cut. + + See also + -------- + :meth:`local_edge_connectivity` + :meth:`node_connectivity` + :meth:`minimum_node_cut` + :meth:`maximum_flow` + :meth:`edmonds_karp` + :meth:`preflow_push` + :meth:`shortest_augmenting_path` + + References + ---------- + .. [1] Kammer, Frank and Hanjo Taubig. Graph Connectivity. in Brandes and + Erlebach, 'Network Analysis: Methodological Foundations', Lecture + Notes in Computer Science, Volume 3418, Springer-Verlag, 2005. + http://www.informatik.uni-augsburg.de/thi/personen/kammer/Graph_Connectivity.pdf + + """ + if flow_func is None: + flow_func = default_flow_func + + if auxiliary is None: + H = build_auxiliary_node_connectivity(G) + else: + H = auxiliary + + mapping = H.graph.get("mapping", None) + if mapping is None: + raise nx.NetworkXError("Invalid auxiliary digraph.") + + kwargs = {"flow_func": flow_func, "residual": residual} + + if flow_func is not preflow_push: + kwargs["cutoff"] = cutoff + + if flow_func is shortest_augmenting_path: + kwargs["two_phase"] = True + + return nx.maximum_flow_value(H, f"{mapping[s]}B", f"{mapping[t]}A", **kwargs) + + +@nx._dispatchable +def node_connectivity(G, s=None, t=None, flow_func=None): + r"""Returns node connectivity for a graph or digraph G. + + Node connectivity is equal to the minimum number of nodes that + must be removed to disconnect G or render it trivial. If source + and target nodes are provided, this function returns the local node + connectivity: the minimum number of nodes that must be removed to break + all paths from source to target in G. + + Parameters + ---------- + G : NetworkX graph + Undirected graph + + s : node + Source node. Optional. Default value: None. + + t : node + Target node. Optional. Default value: None. + + flow_func : function + A function for computing the maximum flow among a pair of nodes. + The function has to accept at least three parameters: a Digraph, + a source node, and a target node. And return a residual network + that follows NetworkX conventions (see :meth:`maximum_flow` for + details). If flow_func is None, the default maximum flow function + (:meth:`edmonds_karp`) is used. See below for details. The + choice of the default function may change from version + to version and should not be relied on. Default value: None. + + Returns + ------- + K : integer + Node connectivity of G, or local node connectivity if source + and target are provided. + + Examples + -------- + >>> # Platonic icosahedral graph is 5-node-connected + >>> G = nx.icosahedral_graph() + >>> nx.node_connectivity(G) + 5 + + You can use alternative flow algorithms for the underlying maximum + flow computation. In dense networks the algorithm + :meth:`shortest_augmenting_path` will usually perform better + than the default :meth:`edmonds_karp`, which is faster for + sparse networks with highly skewed degree distributions. Alternative + flow functions have to be explicitly imported from the flow package. + + >>> from networkx.algorithms.flow import shortest_augmenting_path + >>> nx.node_connectivity(G, flow_func=shortest_augmenting_path) + 5 + + If you specify a pair of nodes (source and target) as parameters, + this function returns the value of local node connectivity. + + >>> nx.node_connectivity(G, 3, 7) + 5 + + If you need to perform several local computations among different + pairs of nodes on the same graph, it is recommended that you reuse + the data structures used in the maximum flow computations. See + :meth:`local_node_connectivity` for details. + + Notes + ----- + This is a flow based implementation of node connectivity. The + algorithm works by solving $O((n-\delta-1+\delta(\delta-1)/2))$ + maximum flow problems on an auxiliary digraph. Where $\delta$ + is the minimum degree of G. For details about the auxiliary + digraph and the computation of local node connectivity see + :meth:`local_node_connectivity`. This implementation is based + on algorithm 11 in [1]_. + + See also + -------- + :meth:`local_node_connectivity` + :meth:`edge_connectivity` + :meth:`maximum_flow` + :meth:`edmonds_karp` + :meth:`preflow_push` + :meth:`shortest_augmenting_path` + + References + ---------- + .. [1] Abdol-Hossein Esfahanian. Connectivity Algorithms. + http://www.cse.msu.edu/~cse835/Papers/Graph_connectivity_revised.pdf + + """ + if (s is not None and t is None) or (s is None and t is not None): + raise nx.NetworkXError("Both source and target must be specified.") + + # Local node connectivity + if s is not None and t is not None: + if s not in G: + raise nx.NetworkXError(f"node {s} not in graph") + if t not in G: + raise nx.NetworkXError(f"node {t} not in graph") + return local_node_connectivity(G, s, t, flow_func=flow_func) + + # Global node connectivity + if G.is_directed(): + if not nx.is_weakly_connected(G): + return 0 + iter_func = itertools.permutations + # It is necessary to consider both predecessors + # and successors for directed graphs + + def neighbors(v): + return itertools.chain.from_iterable([G.predecessors(v), G.successors(v)]) + + else: + if not nx.is_connected(G): + return 0 + iter_func = itertools.combinations + neighbors = G.neighbors + + # Reuse the auxiliary digraph and the residual network + H = build_auxiliary_node_connectivity(G) + R = build_residual_network(H, "capacity") + kwargs = {"flow_func": flow_func, "auxiliary": H, "residual": R} + + # Pick a node with minimum degree + # Node connectivity is bounded by degree. + v, K = min(G.degree(), key=itemgetter(1)) + # compute local node connectivity with all its non-neighbors nodes + for w in set(G) - set(neighbors(v)) - {v}: + kwargs["cutoff"] = K + K = min(K, local_node_connectivity(G, v, w, **kwargs)) + # Also for non adjacent pairs of neighbors of v + for x, y in iter_func(neighbors(v), 2): + if y in G[x]: + continue + kwargs["cutoff"] = K + K = min(K, local_node_connectivity(G, x, y, **kwargs)) + + return K + + +@nx._dispatchable +def average_node_connectivity(G, flow_func=None): + r"""Returns the average connectivity of a graph G. + + The average connectivity `\bar{\kappa}` of a graph G is the average + of local node connectivity over all pairs of nodes of G [1]_ . + + .. math:: + + \bar{\kappa}(G) = \frac{\sum_{u,v} \kappa_{G}(u,v)}{{n \choose 2}} + + Parameters + ---------- + + G : NetworkX graph + Undirected graph + + flow_func : function + A function for computing the maximum flow among a pair of nodes. + The function has to accept at least three parameters: a Digraph, + a source node, and a target node. And return a residual network + that follows NetworkX conventions (see :meth:`maximum_flow` for + details). If flow_func is None, the default maximum flow function + (:meth:`edmonds_karp`) is used. See :meth:`local_node_connectivity` + for details. The choice of the default function may change from + version to version and should not be relied on. Default value: None. + + Returns + ------- + K : float + Average node connectivity + + See also + -------- + :meth:`local_node_connectivity` + :meth:`node_connectivity` + :meth:`edge_connectivity` + :meth:`maximum_flow` + :meth:`edmonds_karp` + :meth:`preflow_push` + :meth:`shortest_augmenting_path` + + References + ---------- + .. [1] Beineke, L., O. Oellermann, and R. Pippert (2002). The average + connectivity of a graph. Discrete mathematics 252(1-3), 31-45. + http://www.sciencedirect.com/science/article/pii/S0012365X01001807 + + """ + if G.is_directed(): + iter_func = itertools.permutations + else: + iter_func = itertools.combinations + + # Reuse the auxiliary digraph and the residual network + H = build_auxiliary_node_connectivity(G) + R = build_residual_network(H, "capacity") + kwargs = {"flow_func": flow_func, "auxiliary": H, "residual": R} + + num, den = 0, 0 + for u, v in iter_func(G, 2): + num += local_node_connectivity(G, u, v, **kwargs) + den += 1 + + if den == 0: # Null Graph + return 0 + return num / den + + +@nx._dispatchable +def all_pairs_node_connectivity(G, nbunch=None, flow_func=None): + """Compute node connectivity between all pairs of nodes of G. + + Parameters + ---------- + G : NetworkX graph + Undirected graph + + nbunch: container + Container of nodes. If provided node connectivity will be computed + only over pairs of nodes in nbunch. + + flow_func : function + A function for computing the maximum flow among a pair of nodes. + The function has to accept at least three parameters: a Digraph, + a source node, and a target node. And return a residual network + that follows NetworkX conventions (see :meth:`maximum_flow` for + details). If flow_func is None, the default maximum flow function + (:meth:`edmonds_karp`) is used. See below for details. The + choice of the default function may change from version + to version and should not be relied on. Default value: None. + + Returns + ------- + all_pairs : dict + A dictionary with node connectivity between all pairs of nodes + in G, or in nbunch if provided. + + See also + -------- + :meth:`local_node_connectivity` + :meth:`edge_connectivity` + :meth:`local_edge_connectivity` + :meth:`maximum_flow` + :meth:`edmonds_karp` + :meth:`preflow_push` + :meth:`shortest_augmenting_path` + + """ + if nbunch is None: + nbunch = G + else: + nbunch = set(nbunch) + + directed = G.is_directed() + if directed: + iter_func = itertools.permutations + else: + iter_func = itertools.combinations + + all_pairs = {n: {} for n in nbunch} + + # Reuse auxiliary digraph and residual network + H = build_auxiliary_node_connectivity(G) + mapping = H.graph["mapping"] + R = build_residual_network(H, "capacity") + kwargs = {"flow_func": flow_func, "auxiliary": H, "residual": R} + + for u, v in iter_func(nbunch, 2): + K = local_node_connectivity(G, u, v, **kwargs) + all_pairs[u][v] = K + if not directed: + all_pairs[v][u] = K + + return all_pairs + + +@nx._dispatchable(graphs={"G": 0, "auxiliary?": 4}) +def local_edge_connectivity( + G, s, t, flow_func=None, auxiliary=None, residual=None, cutoff=None +): + r"""Returns local edge connectivity for nodes s and t in G. + + Local edge connectivity for two nodes s and t is the minimum number + of edges that must be removed to disconnect them. + + This is a flow based implementation of edge connectivity. We compute the + maximum flow on an auxiliary digraph build from the original + network (see below for details). This is equal to the local edge + connectivity because the value of a maximum s-t-flow is equal to the + capacity of a minimum s-t-cut (Ford and Fulkerson theorem) [1]_ . + + Parameters + ---------- + G : NetworkX graph + Undirected or directed graph + + s : node + Source node + + t : node + Target node + + flow_func : function + A function for computing the maximum flow among a pair of nodes. + The function has to accept at least three parameters: a Digraph, + a source node, and a target node. And return a residual network + that follows NetworkX conventions (see :meth:`maximum_flow` for + details). If flow_func is None, the default maximum flow function + (:meth:`edmonds_karp`) is used. See below for details. The + choice of the default function may change from version + to version and should not be relied on. Default value: None. + + auxiliary : NetworkX DiGraph + Auxiliary digraph for computing flow based edge connectivity. If + provided it will be reused instead of recreated. Default value: None. + + residual : NetworkX DiGraph + Residual network to compute maximum flow. If provided it will be + reused instead of recreated. Default value: None. + + cutoff : integer, float, or None (default: None) + If specified, the maximum flow algorithm will terminate when the + flow value reaches or exceeds the cutoff. This only works for flows + that support the cutoff parameter (most do) and is ignored otherwise. + + Returns + ------- + K : integer + local edge connectivity for nodes s and t. + + Examples + -------- + This function is not imported in the base NetworkX namespace, so you + have to explicitly import it from the connectivity package: + + >>> from networkx.algorithms.connectivity import local_edge_connectivity + + We use in this example the platonic icosahedral graph, which has edge + connectivity 5. + + >>> G = nx.icosahedral_graph() + >>> local_edge_connectivity(G, 0, 6) + 5 + + If you need to compute local connectivity on several pairs of + nodes in the same graph, it is recommended that you reuse the + data structures that NetworkX uses in the computation: the + auxiliary digraph for edge connectivity, and the residual + network for the underlying maximum flow computation. + + Example of how to compute local edge connectivity among + all pairs of nodes of the platonic icosahedral graph reusing + the data structures. + + >>> import itertools + >>> # You also have to explicitly import the function for + >>> # building the auxiliary digraph from the connectivity package + >>> from networkx.algorithms.connectivity import build_auxiliary_edge_connectivity + >>> H = build_auxiliary_edge_connectivity(G) + >>> # And the function for building the residual network from the + >>> # flow package + >>> from networkx.algorithms.flow import build_residual_network + >>> # Note that the auxiliary digraph has an edge attribute named capacity + >>> R = build_residual_network(H, "capacity") + >>> result = dict.fromkeys(G, dict()) + >>> # Reuse the auxiliary digraph and the residual network by passing them + >>> # as parameters + >>> for u, v in itertools.combinations(G, 2): + ... k = local_edge_connectivity(G, u, v, auxiliary=H, residual=R) + ... result[u][v] = k + >>> all(result[u][v] == 5 for u, v in itertools.combinations(G, 2)) + True + + You can also use alternative flow algorithms for computing edge + connectivity. For instance, in dense networks the algorithm + :meth:`shortest_augmenting_path` will usually perform better than + the default :meth:`edmonds_karp` which is faster for sparse + networks with highly skewed degree distributions. Alternative flow + functions have to be explicitly imported from the flow package. + + >>> from networkx.algorithms.flow import shortest_augmenting_path + >>> local_edge_connectivity(G, 0, 6, flow_func=shortest_augmenting_path) + 5 + + Notes + ----- + This is a flow based implementation of edge connectivity. We compute the + maximum flow using, by default, the :meth:`edmonds_karp` algorithm on an + auxiliary digraph build from the original input graph: + + If the input graph is undirected, we replace each edge (`u`,`v`) with + two reciprocal arcs (`u`, `v`) and (`v`, `u`) and then we set the attribute + 'capacity' for each arc to 1. If the input graph is directed we simply + add the 'capacity' attribute. This is an implementation of algorithm 1 + in [1]_. + + The maximum flow in the auxiliary network is equal to the local edge + connectivity because the value of a maximum s-t-flow is equal to the + capacity of a minimum s-t-cut (Ford and Fulkerson theorem). + + See also + -------- + :meth:`edge_connectivity` + :meth:`local_node_connectivity` + :meth:`node_connectivity` + :meth:`maximum_flow` + :meth:`edmonds_karp` + :meth:`preflow_push` + :meth:`shortest_augmenting_path` + + References + ---------- + .. [1] Abdol-Hossein Esfahanian. Connectivity Algorithms. + http://www.cse.msu.edu/~cse835/Papers/Graph_connectivity_revised.pdf + + """ + if flow_func is None: + flow_func = default_flow_func + + if auxiliary is None: + H = build_auxiliary_edge_connectivity(G) + else: + H = auxiliary + + kwargs = {"flow_func": flow_func, "residual": residual} + + if flow_func is not preflow_push: + kwargs["cutoff"] = cutoff + + if flow_func is shortest_augmenting_path: + kwargs["two_phase"] = True + + return nx.maximum_flow_value(H, s, t, **kwargs) + + +@nx._dispatchable +def edge_connectivity(G, s=None, t=None, flow_func=None, cutoff=None): + r"""Returns the edge connectivity of the graph or digraph G. + + The edge connectivity is equal to the minimum number of edges that + must be removed to disconnect G or render it trivial. If source + and target nodes are provided, this function returns the local edge + connectivity: the minimum number of edges that must be removed to + break all paths from source to target in G. + + Parameters + ---------- + G : NetworkX graph + Undirected or directed graph + + s : node + Source node. Optional. Default value: None. + + t : node + Target node. Optional. Default value: None. + + flow_func : function + A function for computing the maximum flow among a pair of nodes. + The function has to accept at least three parameters: a Digraph, + a source node, and a target node. And return a residual network + that follows NetworkX conventions (see :meth:`maximum_flow` for + details). If flow_func is None, the default maximum flow function + (:meth:`edmonds_karp`) is used. See below for details. The + choice of the default function may change from version + to version and should not be relied on. Default value: None. + + cutoff : integer, float, or None (default: None) + If specified, the maximum flow algorithm will terminate when the + flow value reaches or exceeds the cutoff. This only works for flows + that support the cutoff parameter (most do) and is ignored otherwise. + + Returns + ------- + K : integer + Edge connectivity for G, or local edge connectivity if source + and target were provided + + Examples + -------- + >>> # Platonic icosahedral graph is 5-edge-connected + >>> G = nx.icosahedral_graph() + >>> nx.edge_connectivity(G) + 5 + + You can use alternative flow algorithms for the underlying + maximum flow computation. In dense networks the algorithm + :meth:`shortest_augmenting_path` will usually perform better + than the default :meth:`edmonds_karp`, which is faster for + sparse networks with highly skewed degree distributions. + Alternative flow functions have to be explicitly imported + from the flow package. + + >>> from networkx.algorithms.flow import shortest_augmenting_path + >>> nx.edge_connectivity(G, flow_func=shortest_augmenting_path) + 5 + + If you specify a pair of nodes (source and target) as parameters, + this function returns the value of local edge connectivity. + + >>> nx.edge_connectivity(G, 3, 7) + 5 + + If you need to perform several local computations among different + pairs of nodes on the same graph, it is recommended that you reuse + the data structures used in the maximum flow computations. See + :meth:`local_edge_connectivity` for details. + + Notes + ----- + This is a flow based implementation of global edge connectivity. + For undirected graphs the algorithm works by finding a 'small' + dominating set of nodes of G (see algorithm 7 in [1]_ ) and + computing local maximum flow (see :meth:`local_edge_connectivity`) + between an arbitrary node in the dominating set and the rest of + nodes in it. This is an implementation of algorithm 6 in [1]_ . + For directed graphs, the algorithm does n calls to the maximum + flow function. This is an implementation of algorithm 8 in [1]_ . + + See also + -------- + :meth:`local_edge_connectivity` + :meth:`local_node_connectivity` + :meth:`node_connectivity` + :meth:`maximum_flow` + :meth:`edmonds_karp` + :meth:`preflow_push` + :meth:`shortest_augmenting_path` + :meth:`k_edge_components` + :meth:`k_edge_subgraphs` + + References + ---------- + .. [1] Abdol-Hossein Esfahanian. Connectivity Algorithms. + http://www.cse.msu.edu/~cse835/Papers/Graph_connectivity_revised.pdf + + """ + if (s is not None and t is None) or (s is None and t is not None): + raise nx.NetworkXError("Both source and target must be specified.") + + # Local edge connectivity + if s is not None and t is not None: + if s not in G: + raise nx.NetworkXError(f"node {s} not in graph") + if t not in G: + raise nx.NetworkXError(f"node {t} not in graph") + return local_edge_connectivity(G, s, t, flow_func=flow_func, cutoff=cutoff) + + # Global edge connectivity + # reuse auxiliary digraph and residual network + H = build_auxiliary_edge_connectivity(G) + R = build_residual_network(H, "capacity") + kwargs = {"flow_func": flow_func, "auxiliary": H, "residual": R} + + if G.is_directed(): + # Algorithm 8 in [1] + if not nx.is_weakly_connected(G): + return 0 + + # initial value for \lambda is minimum degree + L = min(d for n, d in G.degree()) + nodes = list(G) + n = len(nodes) + + if cutoff is not None: + L = min(cutoff, L) + + for i in range(n): + kwargs["cutoff"] = L + try: + L = min(L, local_edge_connectivity(G, nodes[i], nodes[i + 1], **kwargs)) + except IndexError: # last node! + L = min(L, local_edge_connectivity(G, nodes[i], nodes[0], **kwargs)) + return L + else: # undirected + # Algorithm 6 in [1] + if not nx.is_connected(G): + return 0 + + # initial value for \lambda is minimum degree + L = min(d for n, d in G.degree()) + + if cutoff is not None: + L = min(cutoff, L) + + # A dominating set is \lambda-covering + # We need a dominating set with at least two nodes + for node in G: + D = nx.dominating_set(G, start_with=node) + v = D.pop() + if D: + break + else: + # in complete graphs the dominating sets will always be of one node + # thus we return min degree + return L + + for w in D: + kwargs["cutoff"] = L + L = min(L, local_edge_connectivity(G, v, w, **kwargs)) + + return L diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/cuts.py b/lib/python3.10/site-packages/networkx/algorithms/connectivity/cuts.py new file mode 100644 index 0000000000000000000000000000000000000000..27124e1bdd6fece3a1b4fad13cbf6529f1cc8d71 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/connectivity/cuts.py @@ -0,0 +1,612 @@ +""" +Flow based cut algorithms +""" + +import itertools + +import networkx as nx + +# Define the default maximum flow function to use in all flow based +# cut algorithms. +from networkx.algorithms.flow import build_residual_network, edmonds_karp + +default_flow_func = edmonds_karp + +from .utils import build_auxiliary_edge_connectivity, build_auxiliary_node_connectivity + +__all__ = [ + "minimum_st_node_cut", + "minimum_node_cut", + "minimum_st_edge_cut", + "minimum_edge_cut", +] + + +@nx._dispatchable( + graphs={"G": 0, "auxiliary?": 4}, + preserve_edge_attrs={"auxiliary": {"capacity": float("inf")}}, + preserve_graph_attrs={"auxiliary"}, +) +def minimum_st_edge_cut(G, s, t, flow_func=None, auxiliary=None, residual=None): + """Returns the edges of the cut-set of a minimum (s, t)-cut. + + This function returns the set of edges of minimum cardinality that, + if removed, would destroy all paths among source and target in G. + Edge weights are not considered. See :meth:`minimum_cut` for + computing minimum cuts considering edge weights. + + Parameters + ---------- + G : NetworkX graph + + s : node + Source node for the flow. + + t : node + Sink node for the flow. + + auxiliary : NetworkX DiGraph + Auxiliary digraph to compute flow based node connectivity. It has + to have a graph attribute called mapping with a dictionary mapping + node names in G and in the auxiliary digraph. If provided + it will be reused instead of recreated. Default value: None. + + flow_func : function + A function for computing the maximum flow among a pair of nodes. + The function has to accept at least three parameters: a Digraph, + a source node, and a target node. And return a residual network + that follows NetworkX conventions (see :meth:`maximum_flow` for + details). If flow_func is None, the default maximum flow function + (:meth:`edmonds_karp`) is used. See :meth:`node_connectivity` for + details. The choice of the default function may change from version + to version and should not be relied on. Default value: None. + + residual : NetworkX DiGraph + Residual network to compute maximum flow. If provided it will be + reused instead of recreated. Default value: None. + + Returns + ------- + cutset : set + Set of edges that, if removed from the graph, will disconnect it. + + See also + -------- + :meth:`minimum_cut` + :meth:`minimum_node_cut` + :meth:`minimum_edge_cut` + :meth:`stoer_wagner` + :meth:`node_connectivity` + :meth:`edge_connectivity` + :meth:`maximum_flow` + :meth:`edmonds_karp` + :meth:`preflow_push` + :meth:`shortest_augmenting_path` + + Examples + -------- + This function is not imported in the base NetworkX namespace, so you + have to explicitly import it from the connectivity package: + + >>> from networkx.algorithms.connectivity import minimum_st_edge_cut + + We use in this example the platonic icosahedral graph, which has edge + connectivity 5. + + >>> G = nx.icosahedral_graph() + >>> len(minimum_st_edge_cut(G, 0, 6)) + 5 + + If you need to compute local edge cuts on several pairs of + nodes in the same graph, it is recommended that you reuse the + data structures that NetworkX uses in the computation: the + auxiliary digraph for edge connectivity, and the residual + network for the underlying maximum flow computation. + + Example of how to compute local edge cuts among all pairs of + nodes of the platonic icosahedral graph reusing the data + structures. + + >>> import itertools + >>> # You also have to explicitly import the function for + >>> # building the auxiliary digraph from the connectivity package + >>> from networkx.algorithms.connectivity import build_auxiliary_edge_connectivity + >>> H = build_auxiliary_edge_connectivity(G) + >>> # And the function for building the residual network from the + >>> # flow package + >>> from networkx.algorithms.flow import build_residual_network + >>> # Note that the auxiliary digraph has an edge attribute named capacity + >>> R = build_residual_network(H, "capacity") + >>> result = dict.fromkeys(G, dict()) + >>> # Reuse the auxiliary digraph and the residual network by passing them + >>> # as parameters + >>> for u, v in itertools.combinations(G, 2): + ... k = len(minimum_st_edge_cut(G, u, v, auxiliary=H, residual=R)) + ... result[u][v] = k + >>> all(result[u][v] == 5 for u, v in itertools.combinations(G, 2)) + True + + You can also use alternative flow algorithms for computing edge + cuts. For instance, in dense networks the algorithm + :meth:`shortest_augmenting_path` will usually perform better than + the default :meth:`edmonds_karp` which is faster for sparse + networks with highly skewed degree distributions. Alternative flow + functions have to be explicitly imported from the flow package. + + >>> from networkx.algorithms.flow import shortest_augmenting_path + >>> len(minimum_st_edge_cut(G, 0, 6, flow_func=shortest_augmenting_path)) + 5 + + """ + if flow_func is None: + flow_func = default_flow_func + + if auxiliary is None: + H = build_auxiliary_edge_connectivity(G) + else: + H = auxiliary + + kwargs = {"capacity": "capacity", "flow_func": flow_func, "residual": residual} + + cut_value, partition = nx.minimum_cut(H, s, t, **kwargs) + reachable, non_reachable = partition + # Any edge in the original graph linking the two sets in the + # partition is part of the edge cutset + cutset = set() + for u, nbrs in ((n, G[n]) for n in reachable): + cutset.update((u, v) for v in nbrs if v in non_reachable) + + return cutset + + +@nx._dispatchable( + graphs={"G": 0, "auxiliary?": 4}, + preserve_node_attrs={"auxiliary": {"id": None}}, + preserve_graph_attrs={"auxiliary"}, +) +def minimum_st_node_cut(G, s, t, flow_func=None, auxiliary=None, residual=None): + r"""Returns a set of nodes of minimum cardinality that disconnect source + from target in G. + + This function returns the set of nodes of minimum cardinality that, + if removed, would destroy all paths among source and target in G. + + Parameters + ---------- + G : NetworkX graph + + s : node + Source node. + + t : node + Target node. + + flow_func : function + A function for computing the maximum flow among a pair of nodes. + The function has to accept at least three parameters: a Digraph, + a source node, and a target node. And return a residual network + that follows NetworkX conventions (see :meth:`maximum_flow` for + details). If flow_func is None, the default maximum flow function + (:meth:`edmonds_karp`) is used. See below for details. The choice + of the default function may change from version to version and + should not be relied on. Default value: None. + + auxiliary : NetworkX DiGraph + Auxiliary digraph to compute flow based node connectivity. It has + to have a graph attribute called mapping with a dictionary mapping + node names in G and in the auxiliary digraph. If provided + it will be reused instead of recreated. Default value: None. + + residual : NetworkX DiGraph + Residual network to compute maximum flow. If provided it will be + reused instead of recreated. Default value: None. + + Returns + ------- + cutset : set + Set of nodes that, if removed, would destroy all paths between + source and target in G. + + Examples + -------- + This function is not imported in the base NetworkX namespace, so you + have to explicitly import it from the connectivity package: + + >>> from networkx.algorithms.connectivity import minimum_st_node_cut + + We use in this example the platonic icosahedral graph, which has node + connectivity 5. + + >>> G = nx.icosahedral_graph() + >>> len(minimum_st_node_cut(G, 0, 6)) + 5 + + If you need to compute local st cuts between several pairs of + nodes in the same graph, it is recommended that you reuse the + data structures that NetworkX uses in the computation: the + auxiliary digraph for node connectivity and node cuts, and the + residual network for the underlying maximum flow computation. + + Example of how to compute local st node cuts reusing the data + structures: + + >>> # You also have to explicitly import the function for + >>> # building the auxiliary digraph from the connectivity package + >>> from networkx.algorithms.connectivity import build_auxiliary_node_connectivity + >>> H = build_auxiliary_node_connectivity(G) + >>> # And the function for building the residual network from the + >>> # flow package + >>> from networkx.algorithms.flow import build_residual_network + >>> # Note that the auxiliary digraph has an edge attribute named capacity + >>> R = build_residual_network(H, "capacity") + >>> # Reuse the auxiliary digraph and the residual network by passing them + >>> # as parameters + >>> len(minimum_st_node_cut(G, 0, 6, auxiliary=H, residual=R)) + 5 + + You can also use alternative flow algorithms for computing minimum st + node cuts. For instance, in dense networks the algorithm + :meth:`shortest_augmenting_path` will usually perform better than + the default :meth:`edmonds_karp` which is faster for sparse + networks with highly skewed degree distributions. Alternative flow + functions have to be explicitly imported from the flow package. + + >>> from networkx.algorithms.flow import shortest_augmenting_path + >>> len(minimum_st_node_cut(G, 0, 6, flow_func=shortest_augmenting_path)) + 5 + + Notes + ----- + This is a flow based implementation of minimum node cut. The algorithm + is based in solving a number of maximum flow computations to determine + the capacity of the minimum cut on an auxiliary directed network that + corresponds to the minimum node cut of G. It handles both directed + and undirected graphs. This implementation is based on algorithm 11 + in [1]_. + + See also + -------- + :meth:`minimum_node_cut` + :meth:`minimum_edge_cut` + :meth:`stoer_wagner` + :meth:`node_connectivity` + :meth:`edge_connectivity` + :meth:`maximum_flow` + :meth:`edmonds_karp` + :meth:`preflow_push` + :meth:`shortest_augmenting_path` + + References + ---------- + .. [1] Abdol-Hossein Esfahanian. Connectivity Algorithms. + http://www.cse.msu.edu/~cse835/Papers/Graph_connectivity_revised.pdf + + """ + if auxiliary is None: + H = build_auxiliary_node_connectivity(G) + else: + H = auxiliary + + mapping = H.graph.get("mapping", None) + if mapping is None: + raise nx.NetworkXError("Invalid auxiliary digraph.") + if G.has_edge(s, t) or G.has_edge(t, s): + return {} + kwargs = {"flow_func": flow_func, "residual": residual, "auxiliary": H} + + # The edge cut in the auxiliary digraph corresponds to the node cut in the + # original graph. + edge_cut = minimum_st_edge_cut(H, f"{mapping[s]}B", f"{mapping[t]}A", **kwargs) + # Each node in the original graph maps to two nodes of the auxiliary graph + node_cut = {H.nodes[node]["id"] for edge in edge_cut for node in edge} + return node_cut - {s, t} + + +@nx._dispatchable +def minimum_node_cut(G, s=None, t=None, flow_func=None): + r"""Returns a set of nodes of minimum cardinality that disconnects G. + + If source and target nodes are provided, this function returns the + set of nodes of minimum cardinality that, if removed, would destroy + all paths among source and target in G. If not, it returns a set + of nodes of minimum cardinality that disconnects G. + + Parameters + ---------- + G : NetworkX graph + + s : node + Source node. Optional. Default value: None. + + t : node + Target node. Optional. Default value: None. + + flow_func : function + A function for computing the maximum flow among a pair of nodes. + The function has to accept at least three parameters: a Digraph, + a source node, and a target node. And return a residual network + that follows NetworkX conventions (see :meth:`maximum_flow` for + details). If flow_func is None, the default maximum flow function + (:meth:`edmonds_karp`) is used. See below for details. The + choice of the default function may change from version + to version and should not be relied on. Default value: None. + + Returns + ------- + cutset : set + Set of nodes that, if removed, would disconnect G. If source + and target nodes are provided, the set contains the nodes that + if removed, would destroy all paths between source and target. + + Examples + -------- + >>> # Platonic icosahedral graph has node connectivity 5 + >>> G = nx.icosahedral_graph() + >>> node_cut = nx.minimum_node_cut(G) + >>> len(node_cut) + 5 + + You can use alternative flow algorithms for the underlying maximum + flow computation. In dense networks the algorithm + :meth:`shortest_augmenting_path` will usually perform better + than the default :meth:`edmonds_karp`, which is faster for + sparse networks with highly skewed degree distributions. Alternative + flow functions have to be explicitly imported from the flow package. + + >>> from networkx.algorithms.flow import shortest_augmenting_path + >>> node_cut == nx.minimum_node_cut(G, flow_func=shortest_augmenting_path) + True + + If you specify a pair of nodes (source and target) as parameters, + this function returns a local st node cut. + + >>> len(nx.minimum_node_cut(G, 3, 7)) + 5 + + If you need to perform several local st cuts among different + pairs of nodes on the same graph, it is recommended that you reuse + the data structures used in the maximum flow computations. See + :meth:`minimum_st_node_cut` for details. + + Notes + ----- + This is a flow based implementation of minimum node cut. The algorithm + is based in solving a number of maximum flow computations to determine + the capacity of the minimum cut on an auxiliary directed network that + corresponds to the minimum node cut of G. It handles both directed + and undirected graphs. This implementation is based on algorithm 11 + in [1]_. + + See also + -------- + :meth:`minimum_st_node_cut` + :meth:`minimum_cut` + :meth:`minimum_edge_cut` + :meth:`stoer_wagner` + :meth:`node_connectivity` + :meth:`edge_connectivity` + :meth:`maximum_flow` + :meth:`edmonds_karp` + :meth:`preflow_push` + :meth:`shortest_augmenting_path` + + References + ---------- + .. [1] Abdol-Hossein Esfahanian. Connectivity Algorithms. + http://www.cse.msu.edu/~cse835/Papers/Graph_connectivity_revised.pdf + + """ + if (s is not None and t is None) or (s is None and t is not None): + raise nx.NetworkXError("Both source and target must be specified.") + + # Local minimum node cut. + if s is not None and t is not None: + if s not in G: + raise nx.NetworkXError(f"node {s} not in graph") + if t not in G: + raise nx.NetworkXError(f"node {t} not in graph") + return minimum_st_node_cut(G, s, t, flow_func=flow_func) + + # Global minimum node cut. + # Analog to the algorithm 11 for global node connectivity in [1]. + if G.is_directed(): + if not nx.is_weakly_connected(G): + raise nx.NetworkXError("Input graph is not connected") + iter_func = itertools.permutations + + def neighbors(v): + return itertools.chain.from_iterable([G.predecessors(v), G.successors(v)]) + + else: + if not nx.is_connected(G): + raise nx.NetworkXError("Input graph is not connected") + iter_func = itertools.combinations + neighbors = G.neighbors + + # Reuse the auxiliary digraph and the residual network. + H = build_auxiliary_node_connectivity(G) + R = build_residual_network(H, "capacity") + kwargs = {"flow_func": flow_func, "auxiliary": H, "residual": R} + + # Choose a node with minimum degree. + v = min(G, key=G.degree) + # Initial node cutset is all neighbors of the node with minimum degree. + min_cut = set(G[v]) + # Compute st node cuts between v and all its non-neighbors nodes in G. + for w in set(G) - set(neighbors(v)) - {v}: + this_cut = minimum_st_node_cut(G, v, w, **kwargs) + if len(min_cut) >= len(this_cut): + min_cut = this_cut + # Also for non adjacent pairs of neighbors of v. + for x, y in iter_func(neighbors(v), 2): + if y in G[x]: + continue + this_cut = minimum_st_node_cut(G, x, y, **kwargs) + if len(min_cut) >= len(this_cut): + min_cut = this_cut + + return min_cut + + +@nx._dispatchable +def minimum_edge_cut(G, s=None, t=None, flow_func=None): + r"""Returns a set of edges of minimum cardinality that disconnects G. + + If source and target nodes are provided, this function returns the + set of edges of minimum cardinality that, if removed, would break + all paths among source and target in G. If not, it returns a set of + edges of minimum cardinality that disconnects G. + + Parameters + ---------- + G : NetworkX graph + + s : node + Source node. Optional. Default value: None. + + t : node + Target node. Optional. Default value: None. + + flow_func : function + A function for computing the maximum flow among a pair of nodes. + The function has to accept at least three parameters: a Digraph, + a source node, and a target node. And return a residual network + that follows NetworkX conventions (see :meth:`maximum_flow` for + details). If flow_func is None, the default maximum flow function + (:meth:`edmonds_karp`) is used. See below for details. The + choice of the default function may change from version + to version and should not be relied on. Default value: None. + + Returns + ------- + cutset : set + Set of edges that, if removed, would disconnect G. If source + and target nodes are provided, the set contains the edges that + if removed, would destroy all paths between source and target. + + Examples + -------- + >>> # Platonic icosahedral graph has edge connectivity 5 + >>> G = nx.icosahedral_graph() + >>> len(nx.minimum_edge_cut(G)) + 5 + + You can use alternative flow algorithms for the underlying + maximum flow computation. In dense networks the algorithm + :meth:`shortest_augmenting_path` will usually perform better + than the default :meth:`edmonds_karp`, which is faster for + sparse networks with highly skewed degree distributions. + Alternative flow functions have to be explicitly imported + from the flow package. + + >>> from networkx.algorithms.flow import shortest_augmenting_path + >>> len(nx.minimum_edge_cut(G, flow_func=shortest_augmenting_path)) + 5 + + If you specify a pair of nodes (source and target) as parameters, + this function returns the value of local edge connectivity. + + >>> nx.edge_connectivity(G, 3, 7) + 5 + + If you need to perform several local computations among different + pairs of nodes on the same graph, it is recommended that you reuse + the data structures used in the maximum flow computations. See + :meth:`local_edge_connectivity` for details. + + Notes + ----- + This is a flow based implementation of minimum edge cut. For + undirected graphs the algorithm works by finding a 'small' dominating + set of nodes of G (see algorithm 7 in [1]_) and computing the maximum + flow between an arbitrary node in the dominating set and the rest of + nodes in it. This is an implementation of algorithm 6 in [1]_. For + directed graphs, the algorithm does n calls to the max flow function. + The function raises an error if the directed graph is not weakly + connected and returns an empty set if it is weakly connected. + It is an implementation of algorithm 8 in [1]_. + + See also + -------- + :meth:`minimum_st_edge_cut` + :meth:`minimum_node_cut` + :meth:`stoer_wagner` + :meth:`node_connectivity` + :meth:`edge_connectivity` + :meth:`maximum_flow` + :meth:`edmonds_karp` + :meth:`preflow_push` + :meth:`shortest_augmenting_path` + + References + ---------- + .. [1] Abdol-Hossein Esfahanian. Connectivity Algorithms. + http://www.cse.msu.edu/~cse835/Papers/Graph_connectivity_revised.pdf + + """ + if (s is not None and t is None) or (s is None and t is not None): + raise nx.NetworkXError("Both source and target must be specified.") + + # reuse auxiliary digraph and residual network + H = build_auxiliary_edge_connectivity(G) + R = build_residual_network(H, "capacity") + kwargs = {"flow_func": flow_func, "residual": R, "auxiliary": H} + + # Local minimum edge cut if s and t are not None + if s is not None and t is not None: + if s not in G: + raise nx.NetworkXError(f"node {s} not in graph") + if t not in G: + raise nx.NetworkXError(f"node {t} not in graph") + return minimum_st_edge_cut(H, s, t, **kwargs) + + # Global minimum edge cut + # Analog to the algorithm for global edge connectivity + if G.is_directed(): + # Based on algorithm 8 in [1] + if not nx.is_weakly_connected(G): + raise nx.NetworkXError("Input graph is not connected") + + # Initial cutset is all edges of a node with minimum degree + node = min(G, key=G.degree) + min_cut = set(G.edges(node)) + nodes = list(G) + n = len(nodes) + for i in range(n): + try: + this_cut = minimum_st_edge_cut(H, nodes[i], nodes[i + 1], **kwargs) + if len(this_cut) <= len(min_cut): + min_cut = this_cut + except IndexError: # Last node! + this_cut = minimum_st_edge_cut(H, nodes[i], nodes[0], **kwargs) + if len(this_cut) <= len(min_cut): + min_cut = this_cut + + return min_cut + + else: # undirected + # Based on algorithm 6 in [1] + if not nx.is_connected(G): + raise nx.NetworkXError("Input graph is not connected") + + # Initial cutset is all edges of a node with minimum degree + node = min(G, key=G.degree) + min_cut = set(G.edges(node)) + # A dominating set is \lambda-covering + # We need a dominating set with at least two nodes + for node in G: + D = nx.dominating_set(G, start_with=node) + v = D.pop() + if D: + break + else: + # in complete graphs the dominating set will always be of one node + # thus we return min_cut, which now contains the edges of a node + # with minimum degree + return min_cut + for w in D: + this_cut = minimum_st_edge_cut(H, v, w, **kwargs) + if len(this_cut) <= len(min_cut): + min_cut = this_cut + + return min_cut diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/disjoint_paths.py b/lib/python3.10/site-packages/networkx/algorithms/connectivity/disjoint_paths.py new file mode 100644 index 0000000000000000000000000000000000000000..00616492819023493af697b59ba2fbe6284f79cf --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/connectivity/disjoint_paths.py @@ -0,0 +1,408 @@ +"""Flow based node and edge disjoint paths.""" + +import networkx as nx + +# Define the default maximum flow function to use for the underlying +# maximum flow computations +from networkx.algorithms.flow import ( + edmonds_karp, + preflow_push, + shortest_augmenting_path, +) +from networkx.exception import NetworkXNoPath + +default_flow_func = edmonds_karp +from itertools import filterfalse as _filterfalse + +# Functions to build auxiliary data structures. +from .utils import build_auxiliary_edge_connectivity, build_auxiliary_node_connectivity + +__all__ = ["edge_disjoint_paths", "node_disjoint_paths"] + + +@nx._dispatchable( + graphs={"G": 0, "auxiliary?": 5}, + preserve_edge_attrs={"auxiliary": {"capacity": float("inf")}}, +) +def edge_disjoint_paths( + G, s, t, flow_func=None, cutoff=None, auxiliary=None, residual=None +): + """Returns the edges disjoint paths between source and target. + + Edge disjoint paths are paths that do not share any edge. The + number of edge disjoint paths between source and target is equal + to their edge connectivity. + + Parameters + ---------- + G : NetworkX graph + + s : node + Source node for the flow. + + t : node + Sink node for the flow. + + flow_func : function + A function for computing the maximum flow among a pair of nodes. + The function has to accept at least three parameters: a Digraph, + a source node, and a target node. And return a residual network + that follows NetworkX conventions (see :meth:`maximum_flow` for + details). If flow_func is None, the default maximum flow function + (:meth:`edmonds_karp`) is used. The choice of the default function + may change from version to version and should not be relied on. + Default value: None. + + cutoff : integer or None (default: None) + Maximum number of paths to yield. If specified, the maximum flow + algorithm will terminate when the flow value reaches or exceeds the + cutoff. This only works for flows that support the cutoff parameter + (most do) and is ignored otherwise. + + auxiliary : NetworkX DiGraph + Auxiliary digraph to compute flow based edge connectivity. It has + to have a graph attribute called mapping with a dictionary mapping + node names in G and in the auxiliary digraph. If provided + it will be reused instead of recreated. Default value: None. + + residual : NetworkX DiGraph + Residual network to compute maximum flow. If provided it will be + reused instead of recreated. Default value: None. + + Returns + ------- + paths : generator + A generator of edge independent paths. + + Raises + ------ + NetworkXNoPath + If there is no path between source and target. + + NetworkXError + If source or target are not in the graph G. + + See also + -------- + :meth:`node_disjoint_paths` + :meth:`edge_connectivity` + :meth:`maximum_flow` + :meth:`edmonds_karp` + :meth:`preflow_push` + :meth:`shortest_augmenting_path` + + Examples + -------- + We use in this example the platonic icosahedral graph, which has node + edge connectivity 5, thus there are 5 edge disjoint paths between any + pair of nodes. + + >>> G = nx.icosahedral_graph() + >>> len(list(nx.edge_disjoint_paths(G, 0, 6))) + 5 + + + If you need to compute edge disjoint paths on several pairs of + nodes in the same graph, it is recommended that you reuse the + data structures that NetworkX uses in the computation: the + auxiliary digraph for edge connectivity, and the residual + network for the underlying maximum flow computation. + + Example of how to compute edge disjoint paths among all pairs of + nodes of the platonic icosahedral graph reusing the data + structures. + + >>> import itertools + >>> # You also have to explicitly import the function for + >>> # building the auxiliary digraph from the connectivity package + >>> from networkx.algorithms.connectivity import build_auxiliary_edge_connectivity + >>> H = build_auxiliary_edge_connectivity(G) + >>> # And the function for building the residual network from the + >>> # flow package + >>> from networkx.algorithms.flow import build_residual_network + >>> # Note that the auxiliary digraph has an edge attribute named capacity + >>> R = build_residual_network(H, "capacity") + >>> result = {n: {} for n in G} + >>> # Reuse the auxiliary digraph and the residual network by passing them + >>> # as arguments + >>> for u, v in itertools.combinations(G, 2): + ... k = len(list(nx.edge_disjoint_paths(G, u, v, auxiliary=H, residual=R))) + ... result[u][v] = k + >>> all(result[u][v] == 5 for u, v in itertools.combinations(G, 2)) + True + + You can also use alternative flow algorithms for computing edge disjoint + paths. For instance, in dense networks the algorithm + :meth:`shortest_augmenting_path` will usually perform better than + the default :meth:`edmonds_karp` which is faster for sparse + networks with highly skewed degree distributions. Alternative flow + functions have to be explicitly imported from the flow package. + + >>> from networkx.algorithms.flow import shortest_augmenting_path + >>> len(list(nx.edge_disjoint_paths(G, 0, 6, flow_func=shortest_augmenting_path))) + 5 + + Notes + ----- + This is a flow based implementation of edge disjoint paths. We compute + the maximum flow between source and target on an auxiliary directed + network. The saturated edges in the residual network after running the + maximum flow algorithm correspond to edge disjoint paths between source + and target in the original network. This function handles both directed + and undirected graphs, and can use all flow algorithms from NetworkX flow + package. + + """ + if s not in G: + raise nx.NetworkXError(f"node {s} not in graph") + if t not in G: + raise nx.NetworkXError(f"node {t} not in graph") + + if flow_func is None: + flow_func = default_flow_func + + if auxiliary is None: + H = build_auxiliary_edge_connectivity(G) + else: + H = auxiliary + + # Maximum possible edge disjoint paths + possible = min(H.out_degree(s), H.in_degree(t)) + if not possible: + raise NetworkXNoPath + + if cutoff is None: + cutoff = possible + else: + cutoff = min(cutoff, possible) + + # Compute maximum flow between source and target. Flow functions in + # NetworkX return a residual network. + kwargs = { + "capacity": "capacity", + "residual": residual, + "cutoff": cutoff, + "value_only": True, + } + if flow_func is preflow_push: + del kwargs["cutoff"] + if flow_func is shortest_augmenting_path: + kwargs["two_phase"] = True + R = flow_func(H, s, t, **kwargs) + + if R.graph["flow_value"] == 0: + raise NetworkXNoPath + + # Saturated edges in the residual network form the edge disjoint paths + # between source and target + cutset = [ + (u, v) + for u, v, d in R.edges(data=True) + if d["capacity"] == d["flow"] and d["flow"] > 0 + ] + # This is equivalent of what flow.utils.build_flow_dict returns, but + # only for the nodes with saturated edges and without reporting 0 flows. + flow_dict = {n: {} for edge in cutset for n in edge} + for u, v in cutset: + flow_dict[u][v] = 1 + + # Rebuild the edge disjoint paths from the flow dictionary. + paths_found = 0 + for v in list(flow_dict[s]): + if paths_found >= cutoff: + # preflow_push does not support cutoff: we have to + # keep track of the paths founds and stop at cutoff. + break + path = [s] + if v == t: + path.append(v) + yield path + continue + u = v + while u != t: + path.append(u) + try: + u, _ = flow_dict[u].popitem() + except KeyError: + break + else: + path.append(t) + yield path + paths_found += 1 + + +@nx._dispatchable( + graphs={"G": 0, "auxiliary?": 5}, + preserve_node_attrs={"auxiliary": {"id": None}}, + preserve_graph_attrs={"auxiliary"}, +) +def node_disjoint_paths( + G, s, t, flow_func=None, cutoff=None, auxiliary=None, residual=None +): + r"""Computes node disjoint paths between source and target. + + Node disjoint paths are paths that only share their first and last + nodes. The number of node independent paths between two nodes is + equal to their local node connectivity. + + Parameters + ---------- + G : NetworkX graph + + s : node + Source node. + + t : node + Target node. + + flow_func : function + A function for computing the maximum flow among a pair of nodes. + The function has to accept at least three parameters: a Digraph, + a source node, and a target node. And return a residual network + that follows NetworkX conventions (see :meth:`maximum_flow` for + details). If flow_func is None, the default maximum flow function + (:meth:`edmonds_karp`) is used. See below for details. The choice + of the default function may change from version to version and + should not be relied on. Default value: None. + + cutoff : integer or None (default: None) + Maximum number of paths to yield. If specified, the maximum flow + algorithm will terminate when the flow value reaches or exceeds the + cutoff. This only works for flows that support the cutoff parameter + (most do) and is ignored otherwise. + + auxiliary : NetworkX DiGraph + Auxiliary digraph to compute flow based node connectivity. It has + to have a graph attribute called mapping with a dictionary mapping + node names in G and in the auxiliary digraph. If provided + it will be reused instead of recreated. Default value: None. + + residual : NetworkX DiGraph + Residual network to compute maximum flow. If provided it will be + reused instead of recreated. Default value: None. + + Returns + ------- + paths : generator + Generator of node disjoint paths. + + Raises + ------ + NetworkXNoPath + If there is no path between source and target. + + NetworkXError + If source or target are not in the graph G. + + Examples + -------- + We use in this example the platonic icosahedral graph, which has node + connectivity 5, thus there are 5 node disjoint paths between any pair + of non neighbor nodes. + + >>> G = nx.icosahedral_graph() + >>> len(list(nx.node_disjoint_paths(G, 0, 6))) + 5 + + If you need to compute node disjoint paths between several pairs of + nodes in the same graph, it is recommended that you reuse the + data structures that NetworkX uses in the computation: the + auxiliary digraph for node connectivity and node cuts, and the + residual network for the underlying maximum flow computation. + + Example of how to compute node disjoint paths reusing the data + structures: + + >>> # You also have to explicitly import the function for + >>> # building the auxiliary digraph from the connectivity package + >>> from networkx.algorithms.connectivity import build_auxiliary_node_connectivity + >>> H = build_auxiliary_node_connectivity(G) + >>> # And the function for building the residual network from the + >>> # flow package + >>> from networkx.algorithms.flow import build_residual_network + >>> # Note that the auxiliary digraph has an edge attribute named capacity + >>> R = build_residual_network(H, "capacity") + >>> # Reuse the auxiliary digraph and the residual network by passing them + >>> # as arguments + >>> len(list(nx.node_disjoint_paths(G, 0, 6, auxiliary=H, residual=R))) + 5 + + You can also use alternative flow algorithms for computing node disjoint + paths. For instance, in dense networks the algorithm + :meth:`shortest_augmenting_path` will usually perform better than + the default :meth:`edmonds_karp` which is faster for sparse + networks with highly skewed degree distributions. Alternative flow + functions have to be explicitly imported from the flow package. + + >>> from networkx.algorithms.flow import shortest_augmenting_path + >>> len(list(nx.node_disjoint_paths(G, 0, 6, flow_func=shortest_augmenting_path))) + 5 + + Notes + ----- + This is a flow based implementation of node disjoint paths. We compute + the maximum flow between source and target on an auxiliary directed + network. The saturated edges in the residual network after running the + maximum flow algorithm correspond to node disjoint paths between source + and target in the original network. This function handles both directed + and undirected graphs, and can use all flow algorithms from NetworkX flow + package. + + See also + -------- + :meth:`edge_disjoint_paths` + :meth:`node_connectivity` + :meth:`maximum_flow` + :meth:`edmonds_karp` + :meth:`preflow_push` + :meth:`shortest_augmenting_path` + + """ + if s not in G: + raise nx.NetworkXError(f"node {s} not in graph") + if t not in G: + raise nx.NetworkXError(f"node {t} not in graph") + + if auxiliary is None: + H = build_auxiliary_node_connectivity(G) + else: + H = auxiliary + + mapping = H.graph.get("mapping", None) + if mapping is None: + raise nx.NetworkXError("Invalid auxiliary digraph.") + + # Maximum possible edge disjoint paths + possible = min(H.out_degree(f"{mapping[s]}B"), H.in_degree(f"{mapping[t]}A")) + if not possible: + raise NetworkXNoPath + + if cutoff is None: + cutoff = possible + else: + cutoff = min(cutoff, possible) + + kwargs = { + "flow_func": flow_func, + "residual": residual, + "auxiliary": H, + "cutoff": cutoff, + } + + # The edge disjoint paths in the auxiliary digraph correspond to the node + # disjoint paths in the original graph. + paths_edges = edge_disjoint_paths(H, f"{mapping[s]}B", f"{mapping[t]}A", **kwargs) + for path in paths_edges: + # Each node in the original graph maps to two nodes in auxiliary graph + yield list(_unique_everseen(H.nodes[node]["id"] for node in path)) + + +def _unique_everseen(iterable): + # Adapted from https://docs.python.org/3/library/itertools.html examples + "List unique elements, preserving order. Remember all elements ever seen." + # unique_everseen('AAAABBBCCDAABBB') --> A B C D + seen = set() + seen_add = seen.add + for element in _filterfalse(seen.__contains__, iterable): + seen_add(element) + yield element diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/edge_augmentation.py b/lib/python3.10/site-packages/networkx/algorithms/connectivity/edge_augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..278a8e36717d7b32cfb3634dea163571de82e58c --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/connectivity/edge_augmentation.py @@ -0,0 +1,1270 @@ +""" +Algorithms for finding k-edge-augmentations + +A k-edge-augmentation is a set of edges, that once added to a graph, ensures +that the graph is k-edge-connected; i.e. the graph cannot be disconnected +unless k or more edges are removed. Typically, the goal is to find the +augmentation with minimum weight. In general, it is not guaranteed that a +k-edge-augmentation exists. + +See Also +-------- +:mod:`edge_kcomponents` : algorithms for finding k-edge-connected components +:mod:`connectivity` : algorithms for determining edge connectivity. +""" + +import itertools as it +import math +from collections import defaultdict, namedtuple + +import networkx as nx +from networkx.utils import not_implemented_for, py_random_state + +__all__ = ["k_edge_augmentation", "is_k_edge_connected", "is_locally_k_edge_connected"] + + +@not_implemented_for("directed") +@not_implemented_for("multigraph") +@nx._dispatchable +def is_k_edge_connected(G, k): + """Tests to see if a graph is k-edge-connected. + + Is it impossible to disconnect the graph by removing fewer than k edges? + If so, then G is k-edge-connected. + + Parameters + ---------- + G : NetworkX graph + An undirected graph. + + k : integer + edge connectivity to test for + + Returns + ------- + boolean + True if G is k-edge-connected. + + See Also + -------- + :func:`is_locally_k_edge_connected` + + Examples + -------- + >>> G = nx.barbell_graph(10, 0) + >>> nx.is_k_edge_connected(G, k=1) + True + >>> nx.is_k_edge_connected(G, k=2) + False + """ + if k < 1: + raise ValueError(f"k must be positive, not {k}") + # First try to quickly determine if G is not k-edge-connected + if G.number_of_nodes() < k + 1: + return False + elif any(d < k for n, d in G.degree()): + return False + else: + # Otherwise perform the full check + if k == 1: + return nx.is_connected(G) + elif k == 2: + return nx.is_connected(G) and not nx.has_bridges(G) + else: + return nx.edge_connectivity(G, cutoff=k) >= k + + +@not_implemented_for("directed") +@not_implemented_for("multigraph") +@nx._dispatchable +def is_locally_k_edge_connected(G, s, t, k): + """Tests to see if an edge in a graph is locally k-edge-connected. + + Is it impossible to disconnect s and t by removing fewer than k edges? + If so, then s and t are locally k-edge-connected in G. + + Parameters + ---------- + G : NetworkX graph + An undirected graph. + + s : node + Source node + + t : node + Target node + + k : integer + local edge connectivity for nodes s and t + + Returns + ------- + boolean + True if s and t are locally k-edge-connected in G. + + See Also + -------- + :func:`is_k_edge_connected` + + Examples + -------- + >>> from networkx.algorithms.connectivity import is_locally_k_edge_connected + >>> G = nx.barbell_graph(10, 0) + >>> is_locally_k_edge_connected(G, 5, 15, k=1) + True + >>> is_locally_k_edge_connected(G, 5, 15, k=2) + False + >>> is_locally_k_edge_connected(G, 1, 5, k=2) + True + """ + if k < 1: + raise ValueError(f"k must be positive, not {k}") + + # First try to quickly determine s, t is not k-locally-edge-connected in G + if G.degree(s) < k or G.degree(t) < k: + return False + else: + # Otherwise perform the full check + if k == 1: + return nx.has_path(G, s, t) + else: + localk = nx.connectivity.local_edge_connectivity(G, s, t, cutoff=k) + return localk >= k + + +@not_implemented_for("directed") +@not_implemented_for("multigraph") +@nx._dispatchable +def k_edge_augmentation(G, k, avail=None, weight=None, partial=False): + """Finds set of edges to k-edge-connect G. + + Adding edges from the augmentation to G make it impossible to disconnect G + unless k or more edges are removed. This function uses the most efficient + function available (depending on the value of k and if the problem is + weighted or unweighted) to search for a minimum weight subset of available + edges that k-edge-connects G. In general, finding a k-edge-augmentation is + NP-hard, so solutions are not guaranteed to be minimal. Furthermore, a + k-edge-augmentation may not exist. + + Parameters + ---------- + G : NetworkX graph + An undirected graph. + + k : integer + Desired edge connectivity + + avail : dict or a set of 2 or 3 tuples + The available edges that can be used in the augmentation. + + If unspecified, then all edges in the complement of G are available. + Otherwise, each item is an available edge (with an optional weight). + + In the unweighted case, each item is an edge ``(u, v)``. + + In the weighted case, each item is a 3-tuple ``(u, v, d)`` or a dict + with items ``(u, v): d``. The third item, ``d``, can be a dictionary + or a real number. If ``d`` is a dictionary ``d[weight]`` + correspondings to the weight. + + weight : string + key to use to find weights if ``avail`` is a set of 3-tuples where the + third item in each tuple is a dictionary. + + partial : boolean + If partial is True and no feasible k-edge-augmentation exists, then all + a partial k-edge-augmentation is generated. Adding the edges in a + partial augmentation to G, minimizes the number of k-edge-connected + components and maximizes the edge connectivity between those + components. For details, see :func:`partial_k_edge_augmentation`. + + Yields + ------ + edge : tuple + Edges that, once added to G, would cause G to become k-edge-connected. + If partial is False, an error is raised if this is not possible. + Otherwise, generated edges form a partial augmentation, which + k-edge-connects any part of G where it is possible, and maximally + connects the remaining parts. + + Raises + ------ + NetworkXUnfeasible + If partial is False and no k-edge-augmentation exists. + + NetworkXNotImplemented + If the input graph is directed or a multigraph. + + ValueError: + If k is less than 1 + + Notes + ----- + When k=1 this returns an optimal solution. + + When k=2 and ``avail`` is None, this returns an optimal solution. + Otherwise when k=2, this returns a 2-approximation of the optimal solution. + + For k>3, this problem is NP-hard and this uses a randomized algorithm that + produces a feasible solution, but provides no guarantees on the + solution weight. + + Examples + -------- + >>> # Unweighted cases + >>> G = nx.path_graph((1, 2, 3, 4)) + >>> G.add_node(5) + >>> sorted(nx.k_edge_augmentation(G, k=1)) + [(1, 5)] + >>> sorted(nx.k_edge_augmentation(G, k=2)) + [(1, 5), (5, 4)] + >>> sorted(nx.k_edge_augmentation(G, k=3)) + [(1, 4), (1, 5), (2, 5), (3, 5), (4, 5)] + >>> complement = list(nx.k_edge_augmentation(G, k=5, partial=True)) + >>> G.add_edges_from(complement) + >>> nx.edge_connectivity(G) + 4 + + >>> # Weighted cases + >>> G = nx.path_graph((1, 2, 3, 4)) + >>> G.add_node(5) + >>> # avail can be a tuple with a dict + >>> avail = [(1, 5, {"weight": 11}), (2, 5, {"weight": 10})] + >>> sorted(nx.k_edge_augmentation(G, k=1, avail=avail, weight="weight")) + [(2, 5)] + >>> # or avail can be a 3-tuple with a real number + >>> avail = [(1, 5, 11), (2, 5, 10), (4, 3, 1), (4, 5, 51)] + >>> sorted(nx.k_edge_augmentation(G, k=2, avail=avail)) + [(1, 5), (2, 5), (4, 5)] + >>> # or avail can be a dict + >>> avail = {(1, 5): 11, (2, 5): 10, (4, 3): 1, (4, 5): 51} + >>> sorted(nx.k_edge_augmentation(G, k=2, avail=avail)) + [(1, 5), (2, 5), (4, 5)] + >>> # If augmentation is infeasible, then a partial solution can be found + >>> avail = {(1, 5): 11} + >>> sorted(nx.k_edge_augmentation(G, k=2, avail=avail, partial=True)) + [(1, 5)] + """ + try: + if k <= 0: + raise ValueError(f"k must be a positive integer, not {k}") + elif G.number_of_nodes() < k + 1: + msg = f"impossible to {k} connect in graph with less than {k + 1} nodes" + raise nx.NetworkXUnfeasible(msg) + elif avail is not None and len(avail) == 0: + if not nx.is_k_edge_connected(G, k): + raise nx.NetworkXUnfeasible("no available edges") + aug_edges = [] + elif k == 1: + aug_edges = one_edge_augmentation( + G, avail=avail, weight=weight, partial=partial + ) + elif k == 2: + aug_edges = bridge_augmentation(G, avail=avail, weight=weight) + else: + # raise NotImplementedError(f'not implemented for k>2. k={k}') + aug_edges = greedy_k_edge_augmentation( + G, k=k, avail=avail, weight=weight, seed=0 + ) + # Do eager evaluation so we can catch any exceptions + # Before executing partial code. + yield from list(aug_edges) + except nx.NetworkXUnfeasible: + if partial: + # Return all available edges + if avail is None: + aug_edges = complement_edges(G) + else: + # If we can't k-edge-connect the entire graph, try to + # k-edge-connect as much as possible + aug_edges = partial_k_edge_augmentation( + G, k=k, avail=avail, weight=weight + ) + yield from aug_edges + else: + raise + + +@nx._dispatchable +def partial_k_edge_augmentation(G, k, avail, weight=None): + """Finds augmentation that k-edge-connects as much of the graph as possible. + + When a k-edge-augmentation is not possible, we can still try to find a + small set of edges that partially k-edge-connects as much of the graph as + possible. All possible edges are generated between remaining parts. + This minimizes the number of k-edge-connected subgraphs in the resulting + graph and maximizes the edge connectivity between those subgraphs. + + Parameters + ---------- + G : NetworkX graph + An undirected graph. + + k : integer + Desired edge connectivity + + avail : dict or a set of 2 or 3 tuples + For more details, see :func:`k_edge_augmentation`. + + weight : string + key to use to find weights if ``avail`` is a set of 3-tuples. + For more details, see :func:`k_edge_augmentation`. + + Yields + ------ + edge : tuple + Edges in the partial augmentation of G. These edges k-edge-connect any + part of G where it is possible, and maximally connects the remaining + parts. In other words, all edges from avail are generated except for + those within subgraphs that have already become k-edge-connected. + + Notes + ----- + Construct H that augments G with all edges in avail. + Find the k-edge-subgraphs of H. + For each k-edge-subgraph, if the number of nodes is more than k, then find + the k-edge-augmentation of that graph and add it to the solution. Then add + all edges in avail between k-edge subgraphs to the solution. + + See Also + -------- + :func:`k_edge_augmentation` + + Examples + -------- + >>> G = nx.path_graph((1, 2, 3, 4, 5, 6, 7)) + >>> G.add_node(8) + >>> avail = [(1, 3), (1, 4), (1, 5), (2, 4), (2, 5), (3, 5), (1, 8)] + >>> sorted(partial_k_edge_augmentation(G, k=2, avail=avail)) + [(1, 5), (1, 8)] + """ + + def _edges_between_disjoint(H, only1, only2): + """finds edges between disjoint nodes""" + only1_adj = {u: set(H.adj[u]) for u in only1} + for u, neighbs in only1_adj.items(): + # Find the neighbors of u in only1 that are also in only2 + neighbs12 = neighbs.intersection(only2) + for v in neighbs12: + yield (u, v) + + avail_uv, avail_w = _unpack_available_edges(avail, weight=weight, G=G) + + # Find which parts of the graph can be k-edge-connected + H = G.copy() + H.add_edges_from( + ( + (u, v, {"weight": w, "generator": (u, v)}) + for (u, v), w in zip(avail, avail_w) + ) + ) + k_edge_subgraphs = list(nx.k_edge_subgraphs(H, k=k)) + + # Generate edges to k-edge-connect internal subgraphs + for nodes in k_edge_subgraphs: + if len(nodes) > 1: + # Get the k-edge-connected subgraph + C = H.subgraph(nodes).copy() + # Find the internal edges that were available + sub_avail = { + d["generator"]: d["weight"] + for (u, v, d) in C.edges(data=True) + if "generator" in d + } + # Remove potential augmenting edges + C.remove_edges_from(sub_avail.keys()) + # Find a subset of these edges that makes the component + # k-edge-connected and ignore the rest + yield from nx.k_edge_augmentation(C, k=k, avail=sub_avail) + + # Generate all edges between CCs that could not be k-edge-connected + for cc1, cc2 in it.combinations(k_edge_subgraphs, 2): + for u, v in _edges_between_disjoint(H, cc1, cc2): + d = H.get_edge_data(u, v) + edge = d.get("generator", None) + if edge is not None: + yield edge + + +@not_implemented_for("multigraph") +@not_implemented_for("directed") +@nx._dispatchable +def one_edge_augmentation(G, avail=None, weight=None, partial=False): + """Finds minimum weight set of edges to connect G. + + Equivalent to :func:`k_edge_augmentation` when k=1. Adding the resulting + edges to G will make it 1-edge-connected. The solution is optimal for both + weighted and non-weighted variants. + + Parameters + ---------- + G : NetworkX graph + An undirected graph. + + avail : dict or a set of 2 or 3 tuples + For more details, see :func:`k_edge_augmentation`. + + weight : string + key to use to find weights if ``avail`` is a set of 3-tuples. + For more details, see :func:`k_edge_augmentation`. + + partial : boolean + If partial is True and no feasible k-edge-augmentation exists, then the + augmenting edges minimize the number of connected components. + + Yields + ------ + edge : tuple + Edges in the one-augmentation of G + + Raises + ------ + NetworkXUnfeasible + If partial is False and no one-edge-augmentation exists. + + Notes + ----- + Uses either :func:`unconstrained_one_edge_augmentation` or + :func:`weighted_one_edge_augmentation` depending on whether ``avail`` is + specified. Both algorithms are based on finding a minimum spanning tree. + As such both algorithms find optimal solutions and run in linear time. + + See Also + -------- + :func:`k_edge_augmentation` + """ + if avail is None: + return unconstrained_one_edge_augmentation(G) + else: + return weighted_one_edge_augmentation( + G, avail=avail, weight=weight, partial=partial + ) + + +@not_implemented_for("multigraph") +@not_implemented_for("directed") +@nx._dispatchable +def bridge_augmentation(G, avail=None, weight=None): + """Finds the a set of edges that bridge connects G. + + Equivalent to :func:`k_edge_augmentation` when k=2, and partial=False. + Adding the resulting edges to G will make it 2-edge-connected. If no + constraints are specified the returned set of edges is minimum an optimal, + otherwise the solution is approximated. + + Parameters + ---------- + G : NetworkX graph + An undirected graph. + + avail : dict or a set of 2 or 3 tuples + For more details, see :func:`k_edge_augmentation`. + + weight : string + key to use to find weights if ``avail`` is a set of 3-tuples. + For more details, see :func:`k_edge_augmentation`. + + Yields + ------ + edge : tuple + Edges in the bridge-augmentation of G + + Raises + ------ + NetworkXUnfeasible + If no bridge-augmentation exists. + + Notes + ----- + If there are no constraints the solution can be computed in linear time + using :func:`unconstrained_bridge_augmentation`. Otherwise, the problem + becomes NP-hard and is the solution is approximated by + :func:`weighted_bridge_augmentation`. + + See Also + -------- + :func:`k_edge_augmentation` + """ + if G.number_of_nodes() < 3: + raise nx.NetworkXUnfeasible("impossible to bridge connect less than 3 nodes") + if avail is None: + return unconstrained_bridge_augmentation(G) + else: + return weighted_bridge_augmentation(G, avail, weight=weight) + + +# --- Algorithms and Helpers --- + + +def _ordered(u, v): + """Returns the nodes in an undirected edge in lower-triangular order""" + return (u, v) if u < v else (v, u) + + +def _unpack_available_edges(avail, weight=None, G=None): + """Helper to separate avail into edges and corresponding weights""" + if weight is None: + weight = "weight" + if isinstance(avail, dict): + avail_uv = list(avail.keys()) + avail_w = list(avail.values()) + else: + + def _try_getitem(d): + try: + return d[weight] + except TypeError: + return d + + avail_uv = [tup[0:2] for tup in avail] + avail_w = [1 if len(tup) == 2 else _try_getitem(tup[-1]) for tup in avail] + + if G is not None: + # Edges already in the graph are filtered + flags = [not G.has_edge(u, v) for u, v in avail_uv] + avail_uv = list(it.compress(avail_uv, flags)) + avail_w = list(it.compress(avail_w, flags)) + return avail_uv, avail_w + + +MetaEdge = namedtuple("MetaEdge", ("meta_uv", "uv", "w")) + + +def _lightest_meta_edges(mapping, avail_uv, avail_w): + """Maps available edges in the original graph to edges in the metagraph. + + Parameters + ---------- + mapping : dict + mapping produced by :func:`collapse`, that maps each node in the + original graph to a node in the meta graph + + avail_uv : list + list of edges + + avail_w : list + list of edge weights + + Notes + ----- + Each node in the metagraph is a k-edge-connected component in the original + graph. We don't care about any edge within the same k-edge-connected + component, so we ignore self edges. We also are only interested in the + minimum weight edge bridging each k-edge-connected component so, we group + the edges by meta-edge and take the lightest in each group. + + Examples + -------- + >>> # Each group represents a meta-node + >>> groups = ([1, 2, 3], [4, 5], [6]) + >>> mapping = {n: meta_n for meta_n, ns in enumerate(groups) for n in ns} + >>> avail_uv = [(1, 2), (3, 6), (1, 4), (5, 2), (6, 1), (2, 6), (3, 1)] + >>> avail_w = [20, 99, 20, 15, 50, 99, 20] + >>> sorted(_lightest_meta_edges(mapping, avail_uv, avail_w)) + [MetaEdge(meta_uv=(0, 1), uv=(5, 2), w=15), MetaEdge(meta_uv=(0, 2), uv=(6, 1), w=50)] + """ + grouped_wuv = defaultdict(list) + for w, (u, v) in zip(avail_w, avail_uv): + # Order the meta-edge so it can be used as a dict key + meta_uv = _ordered(mapping[u], mapping[v]) + # Group each available edge using the meta-edge as a key + grouped_wuv[meta_uv].append((w, u, v)) + + # Now that all available edges are grouped, choose one per group + for (mu, mv), choices_wuv in grouped_wuv.items(): + # Ignore available edges within the same meta-node + if mu != mv: + # Choose the lightest available edge belonging to each meta-edge + w, u, v = min(choices_wuv) + yield MetaEdge((mu, mv), (u, v), w) + + +@nx._dispatchable +def unconstrained_one_edge_augmentation(G): + """Finds the smallest set of edges to connect G. + + This is a variant of the unweighted MST problem. + If G is not empty, a feasible solution always exists. + + Parameters + ---------- + G : NetworkX graph + An undirected graph. + + Yields + ------ + edge : tuple + Edges in the one-edge-augmentation of G + + See Also + -------- + :func:`one_edge_augmentation` + :func:`k_edge_augmentation` + + Examples + -------- + >>> G = nx.Graph([(1, 2), (2, 3), (4, 5)]) + >>> G.add_nodes_from([6, 7, 8]) + >>> sorted(unconstrained_one_edge_augmentation(G)) + [(1, 4), (4, 6), (6, 7), (7, 8)] + """ + ccs1 = list(nx.connected_components(G)) + C = collapse(G, ccs1) + # When we are not constrained, we can just make a meta graph tree. + meta_nodes = list(C.nodes()) + # build a path in the metagraph + meta_aug = list(zip(meta_nodes, meta_nodes[1:])) + # map that path to the original graph + inverse = defaultdict(list) + for k, v in C.graph["mapping"].items(): + inverse[v].append(k) + for mu, mv in meta_aug: + yield (inverse[mu][0], inverse[mv][0]) + + +@nx._dispatchable +def weighted_one_edge_augmentation(G, avail, weight=None, partial=False): + """Finds the minimum weight set of edges to connect G if one exists. + + This is a variant of the weighted MST problem. + + Parameters + ---------- + G : NetworkX graph + An undirected graph. + + avail : dict or a set of 2 or 3 tuples + For more details, see :func:`k_edge_augmentation`. + + weight : string + key to use to find weights if ``avail`` is a set of 3-tuples. + For more details, see :func:`k_edge_augmentation`. + + partial : boolean + If partial is True and no feasible k-edge-augmentation exists, then the + augmenting edges minimize the number of connected components. + + Yields + ------ + edge : tuple + Edges in the subset of avail chosen to connect G. + + See Also + -------- + :func:`one_edge_augmentation` + :func:`k_edge_augmentation` + + Examples + -------- + >>> G = nx.Graph([(1, 2), (2, 3), (4, 5)]) + >>> G.add_nodes_from([6, 7, 8]) + >>> # any edge not in avail has an implicit weight of infinity + >>> avail = [(1, 3), (1, 5), (4, 7), (4, 8), (6, 1), (8, 1), (8, 2)] + >>> sorted(weighted_one_edge_augmentation(G, avail)) + [(1, 5), (4, 7), (6, 1), (8, 1)] + >>> # find another solution by giving large weights to edges in the + >>> # previous solution (note some of the old edges must be used) + >>> avail = [(1, 3), (1, 5, 99), (4, 7, 9), (6, 1, 99), (8, 1, 99), (8, 2)] + >>> sorted(weighted_one_edge_augmentation(G, avail)) + [(1, 5), (4, 7), (6, 1), (8, 2)] + """ + avail_uv, avail_w = _unpack_available_edges(avail, weight=weight, G=G) + # Collapse CCs in the original graph into nodes in a metagraph + # Then find an MST of the metagraph instead of the original graph + C = collapse(G, nx.connected_components(G)) + mapping = C.graph["mapping"] + # Assign each available edge to an edge in the metagraph + candidate_mapping = _lightest_meta_edges(mapping, avail_uv, avail_w) + # nx.set_edge_attributes(C, name='weight', values=0) + C.add_edges_from( + (mu, mv, {"weight": w, "generator": uv}) + for (mu, mv), uv, w in candidate_mapping + ) + # Find MST of the meta graph + meta_mst = nx.minimum_spanning_tree(C) + if not partial and not nx.is_connected(meta_mst): + raise nx.NetworkXUnfeasible("Not possible to connect G with available edges") + # Yield the edge that generated the meta-edge + for mu, mv, d in meta_mst.edges(data=True): + if "generator" in d: + edge = d["generator"] + yield edge + + +@nx._dispatchable +def unconstrained_bridge_augmentation(G): + """Finds an optimal 2-edge-augmentation of G using the fewest edges. + + This is an implementation of the algorithm detailed in [1]_. + The basic idea is to construct a meta-graph of bridge-ccs, connect leaf + nodes of the trees to connect the entire graph, and finally connect the + leafs of the tree in dfs-preorder to bridge connect the entire graph. + + Parameters + ---------- + G : NetworkX graph + An undirected graph. + + Yields + ------ + edge : tuple + Edges in the bridge augmentation of G + + Notes + ----- + Input: a graph G. + First find the bridge components of G and collapse each bridge-cc into a + node of a metagraph graph C, which is guaranteed to be a forest of trees. + + C contains p "leafs" --- nodes with exactly one incident edge. + C contains q "isolated nodes" --- nodes with no incident edges. + + Theorem: If p + q > 1, then at least :math:`ceil(p / 2) + q` edges are + needed to bridge connect C. This algorithm achieves this min number. + + The method first adds enough edges to make G into a tree and then pairs + leafs in a simple fashion. + + Let n be the number of trees in C. Let v(i) be an isolated vertex in the + i-th tree if one exists, otherwise it is a pair of distinct leafs nodes + in the i-th tree. Alternating edges from these sets (i.e. adding edges + A1 = [(v(i)[0], v(i + 1)[1]), v(i + 1)[0], v(i + 2)[1])...]) connects C + into a tree T. This tree has p' = p + 2q - 2(n -1) leafs and no isolated + vertices. A1 has n - 1 edges. The next step finds ceil(p' / 2) edges to + biconnect any tree with p' leafs. + + Convert T into an arborescence T' by picking an arbitrary root node with + degree >= 2 and directing all edges away from the root. Note the + implementation implicitly constructs T'. + + The leafs of T are the nodes with no existing edges in T'. + Order the leafs of T' by DFS preorder. Then break this list in half + and add the zipped pairs to A2. + + The set A = A1 + A2 is the minimum augmentation in the metagraph. + + To convert this to edges in the original graph + + References + ---------- + .. [1] Eswaran, Kapali P., and R. Endre Tarjan. (1975) Augmentation problems. + http://epubs.siam.org/doi/abs/10.1137/0205044 + + See Also + -------- + :func:`bridge_augmentation` + :func:`k_edge_augmentation` + + Examples + -------- + >>> G = nx.path_graph((1, 2, 3, 4, 5, 6, 7)) + >>> sorted(unconstrained_bridge_augmentation(G)) + [(1, 7)] + >>> G = nx.path_graph((1, 2, 3, 2, 4, 5, 6, 7)) + >>> sorted(unconstrained_bridge_augmentation(G)) + [(1, 3), (3, 7)] + >>> G = nx.Graph([(0, 1), (0, 2), (1, 2)]) + >>> G.add_node(4) + >>> sorted(unconstrained_bridge_augmentation(G)) + [(1, 4), (4, 0)] + """ + # ----- + # Mapping of terms from (Eswaran and Tarjan): + # G = G_0 - the input graph + # C = G_0' - the bridge condensation of G. (This is a forest of trees) + # A1 = A_1 - the edges to connect the forest into a tree + # leaf = pendant - a node with degree of 1 + + # alpha(v) = maps the node v in G to its meta-node in C + # beta(x) = maps the meta-node x in C to any node in the bridge + # component of G corresponding to x. + + # find the 2-edge-connected components of G + bridge_ccs = list(nx.connectivity.bridge_components(G)) + # condense G into an forest C + C = collapse(G, bridge_ccs) + + # Choose pairs of distinct leaf nodes in each tree. If this is not + # possible then make a pair using the single isolated node in the tree. + vset1 = [ + tuple(cc) * 2 # case1: an isolated node + if len(cc) == 1 + else sorted(cc, key=C.degree)[0:2] # case2: pair of leaf nodes + for cc in nx.connected_components(C) + ] + if len(vset1) > 1: + # Use this set to construct edges that connect C into a tree. + nodes1 = [vs[0] for vs in vset1] + nodes2 = [vs[1] for vs in vset1] + A1 = list(zip(nodes1[1:], nodes2)) + else: + A1 = [] + # Connect each tree in the forest to construct an arborescence + T = C.copy() + T.add_edges_from(A1) + + # If there are only two leaf nodes, we simply connect them. + leafs = [n for n, d in T.degree() if d == 1] + if len(leafs) == 1: + A2 = [] + if len(leafs) == 2: + A2 = [tuple(leafs)] + else: + # Choose an arbitrary non-leaf root + try: + root = next(n for n, d in T.degree() if d > 1) + except StopIteration: # no nodes found with degree > 1 + return + # order the leaves of C by (induced directed) preorder + v2 = [n for n in nx.dfs_preorder_nodes(T, root) if T.degree(n) == 1] + # connecting first half of the leafs in pre-order to the second + # half will bridge connect the tree with the fewest edges. + half = math.ceil(len(v2) / 2) + A2 = list(zip(v2[:half], v2[-half:])) + + # collect the edges used to augment the original forest + aug_tree_edges = A1 + A2 + + # Construct the mapping (beta) from meta-nodes to regular nodes + inverse = defaultdict(list) + for k, v in C.graph["mapping"].items(): + inverse[v].append(k) + # sort so we choose minimum degree nodes first + inverse = { + mu: sorted(mapped, key=lambda u: (G.degree(u), u)) + for mu, mapped in inverse.items() + } + + # For each meta-edge, map back to an arbitrary pair in the original graph + G2 = G.copy() + for mu, mv in aug_tree_edges: + # Find the first available edge that doesn't exist and return it + for u, v in it.product(inverse[mu], inverse[mv]): + if not G2.has_edge(u, v): + G2.add_edge(u, v) + yield u, v + break + + +@nx._dispatchable +def weighted_bridge_augmentation(G, avail, weight=None): + """Finds an approximate min-weight 2-edge-augmentation of G. + + This is an implementation of the approximation algorithm detailed in [1]_. + It chooses a set of edges from avail to add to G that renders it + 2-edge-connected if such a subset exists. This is done by finding a + minimum spanning arborescence of a specially constructed metagraph. + + Parameters + ---------- + G : NetworkX graph + An undirected graph. + + avail : set of 2 or 3 tuples. + candidate edges (with optional weights) to choose from + + weight : string + key to use to find weights if avail is a set of 3-tuples where the + third item in each tuple is a dictionary. + + Yields + ------ + edge : tuple + Edges in the subset of avail chosen to bridge augment G. + + Notes + ----- + Finding a weighted 2-edge-augmentation is NP-hard. + Any edge not in ``avail`` is considered to have a weight of infinity. + The approximation factor is 2 if ``G`` is connected and 3 if it is not. + Runs in :math:`O(m + n log(n))` time + + References + ---------- + .. [1] Khuller, Samir, and Ramakrishna Thurimella. (1993) Approximation + algorithms for graph augmentation. + http://www.sciencedirect.com/science/article/pii/S0196677483710102 + + See Also + -------- + :func:`bridge_augmentation` + :func:`k_edge_augmentation` + + Examples + -------- + >>> G = nx.path_graph((1, 2, 3, 4)) + >>> # When the weights are equal, (1, 4) is the best + >>> avail = [(1, 4, 1), (1, 3, 1), (2, 4, 1)] + >>> sorted(weighted_bridge_augmentation(G, avail)) + [(1, 4)] + >>> # Giving (1, 4) a high weight makes the two edge solution the best. + >>> avail = [(1, 4, 1000), (1, 3, 1), (2, 4, 1)] + >>> sorted(weighted_bridge_augmentation(G, avail)) + [(1, 3), (2, 4)] + >>> # ------ + >>> G = nx.path_graph((1, 2, 3, 4)) + >>> G.add_node(5) + >>> avail = [(1, 5, 11), (2, 5, 10), (4, 3, 1), (4, 5, 1)] + >>> sorted(weighted_bridge_augmentation(G, avail=avail)) + [(1, 5), (4, 5)] + >>> avail = [(1, 5, 11), (2, 5, 10), (4, 3, 1), (4, 5, 51)] + >>> sorted(weighted_bridge_augmentation(G, avail=avail)) + [(1, 5), (2, 5), (4, 5)] + """ + + if weight is None: + weight = "weight" + + # If input G is not connected the approximation factor increases to 3 + if not nx.is_connected(G): + H = G.copy() + connectors = list(one_edge_augmentation(H, avail=avail, weight=weight)) + H.add_edges_from(connectors) + + yield from connectors + else: + connectors = [] + H = G + + if len(avail) == 0: + if nx.has_bridges(H): + raise nx.NetworkXUnfeasible("no augmentation possible") + + avail_uv, avail_w = _unpack_available_edges(avail, weight=weight, G=H) + + # Collapse input into a metagraph. Meta nodes are bridge-ccs + bridge_ccs = nx.connectivity.bridge_components(H) + C = collapse(H, bridge_ccs) + + # Use the meta graph to shrink avail to a small feasible subset + mapping = C.graph["mapping"] + # Choose the minimum weight feasible edge in each group + meta_to_wuv = { + (mu, mv): (w, uv) + for (mu, mv), uv, w in _lightest_meta_edges(mapping, avail_uv, avail_w) + } + + # Mapping of terms from (Khuller and Thurimella): + # C : G_0 = (V, E^0) + # This is the metagraph where each node is a 2-edge-cc in G. + # The edges in C represent bridges in the original graph. + # (mu, mv) : E - E^0 # they group both avail and given edges in E + # T : \Gamma + # D : G^D = (V, E_D) + + # The paper uses ancestor because children point to parents, which is + # contrary to networkx standards. So, we actually need to run + # nx.least_common_ancestor on the reversed Tree. + + # Pick an arbitrary leaf from C as the root + try: + root = next(n for n, d in C.degree() if d == 1) + except StopIteration: # no nodes found with degree == 1 + return + # Root C into a tree TR by directing all edges away from the root + # Note in their paper T directs edges towards the root + TR = nx.dfs_tree(C, root) + + # Add to D the directed edges of T and set their weight to zero + # This indicates that it costs nothing to use edges that were given. + D = nx.reverse(TR).copy() + + nx.set_edge_attributes(D, name="weight", values=0) + + # The LCA of mu and mv in T is the shared ancestor of mu and mv that is + # located farthest from the root. + lca_gen = nx.tree_all_pairs_lowest_common_ancestor( + TR, root=root, pairs=meta_to_wuv.keys() + ) + + for (mu, mv), lca in lca_gen: + w, uv = meta_to_wuv[(mu, mv)] + if lca == mu: + # If u is an ancestor of v in TR, then add edge u->v to D + D.add_edge(lca, mv, weight=w, generator=uv) + elif lca == mv: + # If v is an ancestor of u in TR, then add edge v->u to D + D.add_edge(lca, mu, weight=w, generator=uv) + else: + # If neither u nor v is a ancestor of the other in TR + # let t = lca(TR, u, v) and add edges t->u and t->v + # Track the original edge that GENERATED these edges. + D.add_edge(lca, mu, weight=w, generator=uv) + D.add_edge(lca, mv, weight=w, generator=uv) + + # Then compute a minimum rooted branching + try: + # Note the original edges must be directed towards to root for the + # branching to give us a bridge-augmentation. + A = _minimum_rooted_branching(D, root) + except nx.NetworkXException as err: + # If there is no branching then augmentation is not possible + raise nx.NetworkXUnfeasible("no 2-edge-augmentation possible") from err + + # For each edge e, in the branching that did not belong to the directed + # tree T, add the corresponding edge that **GENERATED** it (this is not + # necessarily e itself!) + + # ensure the third case does not generate edges twice + bridge_connectors = set() + for mu, mv in A.edges(): + data = D.get_edge_data(mu, mv) + if "generator" in data: + # Add the avail edge that generated the branching edge. + edge = data["generator"] + bridge_connectors.add(edge) + + yield from bridge_connectors + + +def _minimum_rooted_branching(D, root): + """Helper function to compute a minimum rooted branching (aka rooted + arborescence) + + Before the branching can be computed, the directed graph must be rooted by + removing the predecessors of root. + + A branching / arborescence of rooted graph G is a subgraph that contains a + directed path from the root to every other vertex. It is the directed + analog of the minimum spanning tree problem. + + References + ---------- + [1] Khuller, Samir (2002) Advanced Algorithms Lecture 24 Notes. + https://web.archive.org/web/20121030033722/https://www.cs.umd.edu/class/spring2011/cmsc651/lec07.pdf + """ + rooted = D.copy() + # root the graph by removing all predecessors to `root`. + rooted.remove_edges_from([(u, root) for u in D.predecessors(root)]) + # Then compute the branching / arborescence. + A = nx.minimum_spanning_arborescence(rooted) + return A + + +@nx._dispatchable(returns_graph=True) +def collapse(G, grouped_nodes): + """Collapses each group of nodes into a single node. + + This is similar to condensation, but works on undirected graphs. + + Parameters + ---------- + G : NetworkX Graph + + grouped_nodes: list or generator + Grouping of nodes to collapse. The grouping must be disjoint. + If grouped_nodes are strongly_connected_components then this is + equivalent to :func:`condensation`. + + Returns + ------- + C : NetworkX Graph + The collapsed graph C of G with respect to the node grouping. The node + labels are integers corresponding to the index of the component in the + list of grouped_nodes. C has a graph attribute named 'mapping' with a + dictionary mapping the original nodes to the nodes in C to which they + belong. Each node in C also has a node attribute 'members' with the set + of original nodes in G that form the group that the node in C + represents. + + Examples + -------- + >>> # Collapses a graph using disjoint groups, but not necessarily connected + >>> G = nx.Graph([(1, 0), (2, 3), (3, 1), (3, 4), (4, 5), (5, 6), (5, 7)]) + >>> G.add_node("A") + >>> grouped_nodes = [{0, 1, 2, 3}, {5, 6, 7}] + >>> C = collapse(G, grouped_nodes) + >>> members = nx.get_node_attributes(C, "members") + >>> sorted(members.keys()) + [0, 1, 2, 3] + >>> member_values = set(map(frozenset, members.values())) + >>> assert {0, 1, 2, 3} in member_values + >>> assert {4} in member_values + >>> assert {5, 6, 7} in member_values + >>> assert {"A"} in member_values + """ + mapping = {} + members = {} + C = G.__class__() + i = 0 # required if G is empty + remaining = set(G.nodes()) + for i, group in enumerate(grouped_nodes): + group = set(group) + assert remaining.issuperset( + group + ), "grouped nodes must exist in G and be disjoint" + remaining.difference_update(group) + members[i] = group + mapping.update((n, i) for n in group) + # remaining nodes are in their own group + for i, node in enumerate(remaining, start=i + 1): + group = {node} + members[i] = group + mapping.update((n, i) for n in group) + number_of_groups = i + 1 + C.add_nodes_from(range(number_of_groups)) + C.add_edges_from( + (mapping[u], mapping[v]) for u, v in G.edges() if mapping[u] != mapping[v] + ) + # Add a list of members (ie original nodes) to each node (ie scc) in C. + nx.set_node_attributes(C, name="members", values=members) + # Add mapping dict as graph attribute + C.graph["mapping"] = mapping + return C + + +@nx._dispatchable +def complement_edges(G): + """Returns only the edges in the complement of G + + Parameters + ---------- + G : NetworkX Graph + + Yields + ------ + edge : tuple + Edges in the complement of G + + Examples + -------- + >>> G = nx.path_graph((1, 2, 3, 4)) + >>> sorted(complement_edges(G)) + [(1, 3), (1, 4), (2, 4)] + >>> G = nx.path_graph((1, 2, 3, 4), nx.DiGraph()) + >>> sorted(complement_edges(G)) + [(1, 3), (1, 4), (2, 1), (2, 4), (3, 1), (3, 2), (4, 1), (4, 2), (4, 3)] + >>> G = nx.complete_graph(1000) + >>> sorted(complement_edges(G)) + [] + """ + G_adj = G._adj # Store as a variable to eliminate attribute lookup + if G.is_directed(): + for u, v in it.combinations(G.nodes(), 2): + if v not in G_adj[u]: + yield (u, v) + if u not in G_adj[v]: + yield (v, u) + else: + for u, v in it.combinations(G.nodes(), 2): + if v not in G_adj[u]: + yield (u, v) + + +def _compat_shuffle(rng, input): + """wrapper around rng.shuffle for python 2 compatibility reasons""" + rng.shuffle(input) + + +@not_implemented_for("multigraph") +@not_implemented_for("directed") +@py_random_state(4) +@nx._dispatchable +def greedy_k_edge_augmentation(G, k, avail=None, weight=None, seed=None): + """Greedy algorithm for finding a k-edge-augmentation + + Parameters + ---------- + G : NetworkX graph + An undirected graph. + + k : integer + Desired edge connectivity + + avail : dict or a set of 2 or 3 tuples + For more details, see :func:`k_edge_augmentation`. + + weight : string + key to use to find weights if ``avail`` is a set of 3-tuples. + For more details, see :func:`k_edge_augmentation`. + + seed : integer, random_state, or None (default) + Indicator of random number generation state. + See :ref:`Randomness`. + + Yields + ------ + edge : tuple + Edges in the greedy augmentation of G + + Notes + ----- + The algorithm is simple. Edges are incrementally added between parts of the + graph that are not yet locally k-edge-connected. Then edges are from the + augmenting set are pruned as long as local-edge-connectivity is not broken. + + This algorithm is greedy and does not provide optimality guarantees. It + exists only to provide :func:`k_edge_augmentation` with the ability to + generate a feasible solution for arbitrary k. + + See Also + -------- + :func:`k_edge_augmentation` + + Examples + -------- + >>> G = nx.path_graph((1, 2, 3, 4, 5, 6, 7)) + >>> sorted(greedy_k_edge_augmentation(G, k=2)) + [(1, 7)] + >>> sorted(greedy_k_edge_augmentation(G, k=1, avail=[])) + [] + >>> G = nx.path_graph((1, 2, 3, 4, 5, 6, 7)) + >>> avail = {(u, v): 1 for (u, v) in complement_edges(G)} + >>> # randomized pruning process can produce different solutions + >>> sorted(greedy_k_edge_augmentation(G, k=4, avail=avail, seed=2)) + [(1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (2, 4), (2, 6), (3, 7), (5, 7)] + >>> sorted(greedy_k_edge_augmentation(G, k=4, avail=avail, seed=3)) + [(1, 3), (1, 5), (1, 6), (2, 4), (2, 6), (3, 7), (4, 7), (5, 7)] + """ + # Result set + aug_edges = [] + + done = is_k_edge_connected(G, k) + if done: + return + if avail is None: + # all edges are available + avail_uv = list(complement_edges(G)) + avail_w = [1] * len(avail_uv) + else: + # Get the unique set of unweighted edges + avail_uv, avail_w = _unpack_available_edges(avail, weight=weight, G=G) + + # Greedy: order lightest edges. Use degree sum to tie-break + tiebreaker = [sum(map(G.degree, uv)) for uv in avail_uv] + avail_wduv = sorted(zip(avail_w, tiebreaker, avail_uv)) + avail_uv = [uv for w, d, uv in avail_wduv] + + # Incrementally add edges in until we are k-connected + H = G.copy() + for u, v in avail_uv: + done = False + if not is_locally_k_edge_connected(H, u, v, k=k): + # Only add edges in parts that are not yet locally k-edge-connected + aug_edges.append((u, v)) + H.add_edge(u, v) + # Did adding this edge help? + if H.degree(u) >= k and H.degree(v) >= k: + done = is_k_edge_connected(H, k) + if done: + break + + # Check for feasibility + if not done: + raise nx.NetworkXUnfeasible("not able to k-edge-connect with available edges") + + # Randomized attempt to reduce the size of the solution + _compat_shuffle(seed, aug_edges) + for u, v in list(aug_edges): + # Don't remove if we know it would break connectivity + if H.degree(u) <= k or H.degree(v) <= k: + continue + H.remove_edge(u, v) + aug_edges.remove((u, v)) + if not is_k_edge_connected(H, k=k): + # If removing this edge breaks feasibility, undo + H.add_edge(u, v) + aug_edges.append((u, v)) + + # Generate results + yield from aug_edges diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/edge_kcomponents.py b/lib/python3.10/site-packages/networkx/algorithms/connectivity/edge_kcomponents.py new file mode 100644 index 0000000000000000000000000000000000000000..96886f2ba39db1bb39812440e5d69b6f073b2af5 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/connectivity/edge_kcomponents.py @@ -0,0 +1,592 @@ +""" +Algorithms for finding k-edge-connected components and subgraphs. + +A k-edge-connected component (k-edge-cc) is a maximal set of nodes in G, such +that all pairs of node have an edge-connectivity of at least k. + +A k-edge-connected subgraph (k-edge-subgraph) is a maximal set of nodes in G, +such that the subgraph of G defined by the nodes has an edge-connectivity at +least k. +""" + +import itertools as it +from functools import partial + +import networkx as nx +from networkx.utils import arbitrary_element, not_implemented_for + +__all__ = [ + "k_edge_components", + "k_edge_subgraphs", + "bridge_components", + "EdgeComponentAuxGraph", +] + + +@not_implemented_for("multigraph") +@nx._dispatchable +def k_edge_components(G, k): + """Generates nodes in each maximal k-edge-connected component in G. + + Parameters + ---------- + G : NetworkX graph + + k : Integer + Desired edge connectivity + + Returns + ------- + k_edge_components : a generator of k-edge-ccs. Each set of returned nodes + will have k-edge-connectivity in the graph G. + + See Also + -------- + :func:`local_edge_connectivity` + :func:`k_edge_subgraphs` : similar to this function, but the subgraph + defined by the nodes must also have k-edge-connectivity. + :func:`k_components` : similar to this function, but uses node-connectivity + instead of edge-connectivity + + Raises + ------ + NetworkXNotImplemented + If the input graph is a multigraph. + + ValueError: + If k is less than 1 + + Notes + ----- + Attempts to use the most efficient implementation available based on k. + If k=1, this is simply connected components for directed graphs and + connected components for undirected graphs. + If k=2 on an efficient bridge connected component algorithm from _[1] is + run based on the chain decomposition. + Otherwise, the algorithm from _[2] is used. + + Examples + -------- + >>> import itertools as it + >>> from networkx.utils import pairwise + >>> paths = [ + ... (1, 2, 4, 3, 1, 4), + ... (5, 6, 7, 8, 5, 7, 8, 6), + ... ] + >>> G = nx.Graph() + >>> G.add_nodes_from(it.chain(*paths)) + >>> G.add_edges_from(it.chain(*[pairwise(path) for path in paths])) + >>> # note this returns {1, 4} unlike k_edge_subgraphs + >>> sorted(map(sorted, nx.k_edge_components(G, k=3))) + [[1, 4], [2], [3], [5, 6, 7, 8]] + + References + ---------- + .. [1] https://en.wikipedia.org/wiki/Bridge_%28graph_theory%29 + .. [2] Wang, Tianhao, et al. (2015) A simple algorithm for finding all + k-edge-connected components. + http://journals.plos.org/plosone/article?id=10.1371/journal.pone.0136264 + """ + # Compute k-edge-ccs using the most efficient algorithms available. + if k < 1: + raise ValueError("k cannot be less than 1") + if G.is_directed(): + if k == 1: + return nx.strongly_connected_components(G) + else: + # TODO: investigate https://arxiv.org/abs/1412.6466 for k=2 + aux_graph = EdgeComponentAuxGraph.construct(G) + return aux_graph.k_edge_components(k) + else: + if k == 1: + return nx.connected_components(G) + elif k == 2: + return bridge_components(G) + else: + aux_graph = EdgeComponentAuxGraph.construct(G) + return aux_graph.k_edge_components(k) + + +@not_implemented_for("multigraph") +@nx._dispatchable +def k_edge_subgraphs(G, k): + """Generates nodes in each maximal k-edge-connected subgraph in G. + + Parameters + ---------- + G : NetworkX graph + + k : Integer + Desired edge connectivity + + Returns + ------- + k_edge_subgraphs : a generator of k-edge-subgraphs + Each k-edge-subgraph is a maximal set of nodes that defines a subgraph + of G that is k-edge-connected. + + See Also + -------- + :func:`edge_connectivity` + :func:`k_edge_components` : similar to this function, but nodes only + need to have k-edge-connectivity within the graph G and the subgraphs + might not be k-edge-connected. + + Raises + ------ + NetworkXNotImplemented + If the input graph is a multigraph. + + ValueError: + If k is less than 1 + + Notes + ----- + Attempts to use the most efficient implementation available based on k. + If k=1, or k=2 and the graph is undirected, then this simply calls + `k_edge_components`. Otherwise the algorithm from _[1] is used. + + Examples + -------- + >>> import itertools as it + >>> from networkx.utils import pairwise + >>> paths = [ + ... (1, 2, 4, 3, 1, 4), + ... (5, 6, 7, 8, 5, 7, 8, 6), + ... ] + >>> G = nx.Graph() + >>> G.add_nodes_from(it.chain(*paths)) + >>> G.add_edges_from(it.chain(*[pairwise(path) for path in paths])) + >>> # note this does not return {1, 4} unlike k_edge_components + >>> sorted(map(sorted, nx.k_edge_subgraphs(G, k=3))) + [[1], [2], [3], [4], [5, 6, 7, 8]] + + References + ---------- + .. [1] Zhou, Liu, et al. (2012) Finding maximal k-edge-connected subgraphs + from a large graph. ACM International Conference on Extending Database + Technology 2012 480-–491. + https://openproceedings.org/2012/conf/edbt/ZhouLYLCL12.pdf + """ + if k < 1: + raise ValueError("k cannot be less than 1") + if G.is_directed(): + if k <= 1: + # For directed graphs , + # When k == 1, k-edge-ccs and k-edge-subgraphs are the same + return k_edge_components(G, k) + else: + return _k_edge_subgraphs_nodes(G, k) + else: + if k <= 2: + # For undirected graphs, + # when k <= 2, k-edge-ccs and k-edge-subgraphs are the same + return k_edge_components(G, k) + else: + return _k_edge_subgraphs_nodes(G, k) + + +def _k_edge_subgraphs_nodes(G, k): + """Helper to get the nodes from the subgraphs. + + This allows k_edge_subgraphs to return a generator. + """ + for C in general_k_edge_subgraphs(G, k): + yield set(C.nodes()) + + +@not_implemented_for("directed") +@not_implemented_for("multigraph") +@nx._dispatchable +def bridge_components(G): + """Finds all bridge-connected components G. + + Parameters + ---------- + G : NetworkX undirected graph + + Returns + ------- + bridge_components : a generator of 2-edge-connected components + + + See Also + -------- + :func:`k_edge_subgraphs` : this function is a special case for an + undirected graph where k=2. + :func:`biconnected_components` : similar to this function, but is defined + using 2-node-connectivity instead of 2-edge-connectivity. + + Raises + ------ + NetworkXNotImplemented + If the input graph is directed or a multigraph. + + Notes + ----- + Bridge-connected components are also known as 2-edge-connected components. + + Examples + -------- + >>> # The barbell graph with parameter zero has a single bridge + >>> G = nx.barbell_graph(5, 0) + >>> from networkx.algorithms.connectivity.edge_kcomponents import bridge_components + >>> sorted(map(sorted, bridge_components(G))) + [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] + """ + H = G.copy() + H.remove_edges_from(nx.bridges(G)) + yield from nx.connected_components(H) + + +class EdgeComponentAuxGraph: + r"""A simple algorithm to find all k-edge-connected components in a graph. + + Constructing the auxiliary graph (which may take some time) allows for the + k-edge-ccs to be found in linear time for arbitrary k. + + Notes + ----- + This implementation is based on [1]_. The idea is to construct an auxiliary + graph from which the k-edge-ccs can be extracted in linear time. The + auxiliary graph is constructed in $O(|V|\cdot F)$ operations, where F is the + complexity of max flow. Querying the components takes an additional $O(|V|)$ + operations. This algorithm can be slow for large graphs, but it handles an + arbitrary k and works for both directed and undirected inputs. + + The undirected case for k=1 is exactly connected components. + The undirected case for k=2 is exactly bridge connected components. + The directed case for k=1 is exactly strongly connected components. + + References + ---------- + .. [1] Wang, Tianhao, et al. (2015) A simple algorithm for finding all + k-edge-connected components. + http://journals.plos.org/plosone/article?id=10.1371/journal.pone.0136264 + + Examples + -------- + >>> import itertools as it + >>> from networkx.utils import pairwise + >>> from networkx.algorithms.connectivity import EdgeComponentAuxGraph + >>> # Build an interesting graph with multiple levels of k-edge-ccs + >>> paths = [ + ... (1, 2, 3, 4, 1, 3, 4, 2), # a 3-edge-cc (a 4 clique) + ... (5, 6, 7, 5), # a 2-edge-cc (a 3 clique) + ... (1, 5), # combine first two ccs into a 1-edge-cc + ... (0,), # add an additional disconnected 1-edge-cc + ... ] + >>> G = nx.Graph() + >>> G.add_nodes_from(it.chain(*paths)) + >>> G.add_edges_from(it.chain(*[pairwise(path) for path in paths])) + >>> # Constructing the AuxGraph takes about O(n ** 4) + >>> aux_graph = EdgeComponentAuxGraph.construct(G) + >>> # Once constructed, querying takes O(n) + >>> sorted(map(sorted, aux_graph.k_edge_components(k=1))) + [[0], [1, 2, 3, 4, 5, 6, 7]] + >>> sorted(map(sorted, aux_graph.k_edge_components(k=2))) + [[0], [1, 2, 3, 4], [5, 6, 7]] + >>> sorted(map(sorted, aux_graph.k_edge_components(k=3))) + [[0], [1, 2, 3, 4], [5], [6], [7]] + >>> sorted(map(sorted, aux_graph.k_edge_components(k=4))) + [[0], [1], [2], [3], [4], [5], [6], [7]] + + The auxiliary graph is primarily used for k-edge-ccs but it + can also speed up the queries of k-edge-subgraphs by refining the + search space. + + >>> import itertools as it + >>> from networkx.utils import pairwise + >>> from networkx.algorithms.connectivity import EdgeComponentAuxGraph + >>> paths = [ + ... (1, 2, 4, 3, 1, 4), + ... ] + >>> G = nx.Graph() + >>> G.add_nodes_from(it.chain(*paths)) + >>> G.add_edges_from(it.chain(*[pairwise(path) for path in paths])) + >>> aux_graph = EdgeComponentAuxGraph.construct(G) + >>> sorted(map(sorted, aux_graph.k_edge_subgraphs(k=3))) + [[1], [2], [3], [4]] + >>> sorted(map(sorted, aux_graph.k_edge_components(k=3))) + [[1, 4], [2], [3]] + """ + + # @not_implemented_for('multigraph') # TODO: fix decor for classmethods + @classmethod + def construct(EdgeComponentAuxGraph, G): + """Builds an auxiliary graph encoding edge-connectivity between nodes. + + Notes + ----- + Given G=(V, E), initialize an empty auxiliary graph A. + Choose an arbitrary source node s. Initialize a set N of available + nodes (that can be used as the sink). The algorithm picks an + arbitrary node t from N - {s}, and then computes the minimum st-cut + (S, T) with value w. If G is directed the minimum of the st-cut or + the ts-cut is used instead. Then, the edge (s, t) is added to the + auxiliary graph with weight w. The algorithm is called recursively + first using S as the available nodes and s as the source, and then + using T and t. Recursion stops when the source is the only available + node. + + Parameters + ---------- + G : NetworkX graph + """ + # workaround for classmethod decorator + not_implemented_for("multigraph")(lambda G: G)(G) + + def _recursive_build(H, A, source, avail): + # Terminate once the flow has been compute to every node. + if {source} == avail: + return + # pick an arbitrary node as the sink + sink = arbitrary_element(avail - {source}) + # find the minimum cut and its weight + value, (S, T) = nx.minimum_cut(H, source, sink) + if H.is_directed(): + # check if the reverse direction has a smaller cut + value_, (T_, S_) = nx.minimum_cut(H, sink, source) + if value_ < value: + value, S, T = value_, S_, T_ + # add edge with weight of cut to the aux graph + A.add_edge(source, sink, weight=value) + # recursively call until all but one node is used + _recursive_build(H, A, source, avail.intersection(S)) + _recursive_build(H, A, sink, avail.intersection(T)) + + # Copy input to ensure all edges have unit capacity + H = G.__class__() + H.add_nodes_from(G.nodes()) + H.add_edges_from(G.edges(), capacity=1) + + # A is the auxiliary graph to be constructed + # It is a weighted undirected tree + A = nx.Graph() + + # Pick an arbitrary node as the source + if H.number_of_nodes() > 0: + source = arbitrary_element(H.nodes()) + # Initialize a set of elements that can be chosen as the sink + avail = set(H.nodes()) + + # This constructs A + _recursive_build(H, A, source, avail) + + # This class is a container the holds the auxiliary graph A and + # provides access the k_edge_components function. + self = EdgeComponentAuxGraph() + self.A = A + self.H = H + return self + + def k_edge_components(self, k): + """Queries the auxiliary graph for k-edge-connected components. + + Parameters + ---------- + k : Integer + Desired edge connectivity + + Returns + ------- + k_edge_components : a generator of k-edge-ccs + + Notes + ----- + Given the auxiliary graph, the k-edge-connected components can be + determined in linear time by removing all edges with weights less than + k from the auxiliary graph. The resulting connected components are the + k-edge-ccs in the original graph. + """ + if k < 1: + raise ValueError("k cannot be less than 1") + A = self.A + # "traverse the auxiliary graph A and delete all edges with weights less + # than k" + aux_weights = nx.get_edge_attributes(A, "weight") + # Create a relevant graph with the auxiliary edges with weights >= k + R = nx.Graph() + R.add_nodes_from(A.nodes()) + R.add_edges_from(e for e, w in aux_weights.items() if w >= k) + + # Return the nodes that are k-edge-connected in the original graph + yield from nx.connected_components(R) + + def k_edge_subgraphs(self, k): + """Queries the auxiliary graph for k-edge-connected subgraphs. + + Parameters + ---------- + k : Integer + Desired edge connectivity + + Returns + ------- + k_edge_subgraphs : a generator of k-edge-subgraphs + + Notes + ----- + Refines the k-edge-ccs into k-edge-subgraphs. The running time is more + than $O(|V|)$. + + For single values of k it is faster to use `nx.k_edge_subgraphs`. + But for multiple values of k, it can be faster to build AuxGraph and + then use this method. + """ + if k < 1: + raise ValueError("k cannot be less than 1") + H = self.H + A = self.A + # "traverse the auxiliary graph A and delete all edges with weights less + # than k" + aux_weights = nx.get_edge_attributes(A, "weight") + # Create a relevant graph with the auxiliary edges with weights >= k + R = nx.Graph() + R.add_nodes_from(A.nodes()) + R.add_edges_from(e for e, w in aux_weights.items() if w >= k) + + # Return the components whose subgraphs are k-edge-connected + for cc in nx.connected_components(R): + if len(cc) < k: + # Early return optimization + for node in cc: + yield {node} + else: + # Call subgraph solution to refine the results + C = H.subgraph(cc) + yield from k_edge_subgraphs(C, k) + + +def _low_degree_nodes(G, k, nbunch=None): + """Helper for finding nodes with degree less than k.""" + # Nodes with degree less than k cannot be k-edge-connected. + if G.is_directed(): + # Consider both in and out degree in the directed case + seen = set() + for node, degree in G.out_degree(nbunch): + if degree < k: + seen.add(node) + yield node + for node, degree in G.in_degree(nbunch): + if node not in seen and degree < k: + seen.add(node) + yield node + else: + # Only the degree matters in the undirected case + for node, degree in G.degree(nbunch): + if degree < k: + yield node + + +def _high_degree_components(G, k): + """Helper for filtering components that can't be k-edge-connected. + + Removes and generates each node with degree less than k. Then generates + remaining components where all nodes have degree at least k. + """ + # Iteratively remove parts of the graph that are not k-edge-connected + H = G.copy() + singletons = set(_low_degree_nodes(H, k)) + while singletons: + # Only search neighbors of removed nodes + nbunch = set(it.chain.from_iterable(map(H.neighbors, singletons))) + nbunch.difference_update(singletons) + H.remove_nodes_from(singletons) + for node in singletons: + yield {node} + singletons = set(_low_degree_nodes(H, k, nbunch)) + + # Note: remaining connected components may not be k-edge-connected + if G.is_directed(): + yield from nx.strongly_connected_components(H) + else: + yield from nx.connected_components(H) + + +@nx._dispatchable(returns_graph=True) +def general_k_edge_subgraphs(G, k): + """General algorithm to find all maximal k-edge-connected subgraphs in `G`. + + Parameters + ---------- + G : nx.Graph + Graph in which all maximal k-edge-connected subgraphs will be found. + + k : int + + Yields + ------ + k_edge_subgraphs : Graph instances that are k-edge-subgraphs + Each k-edge-subgraph contains a maximal set of nodes that defines a + subgraph of `G` that is k-edge-connected. + + Notes + ----- + Implementation of the basic algorithm from [1]_. The basic idea is to find + a global minimum cut of the graph. If the cut value is at least k, then the + graph is a k-edge-connected subgraph and can be added to the results. + Otherwise, the cut is used to split the graph in two and the procedure is + applied recursively. If the graph is just a single node, then it is also + added to the results. At the end, each result is either guaranteed to be + a single node or a subgraph of G that is k-edge-connected. + + This implementation contains optimizations for reducing the number of calls + to max-flow, but there are other optimizations in [1]_ that could be + implemented. + + References + ---------- + .. [1] Zhou, Liu, et al. (2012) Finding maximal k-edge-connected subgraphs + from a large graph. ACM International Conference on Extending Database + Technology 2012 480-–491. + https://openproceedings.org/2012/conf/edbt/ZhouLYLCL12.pdf + + Examples + -------- + >>> from networkx.utils import pairwise + >>> paths = [ + ... (11, 12, 13, 14, 11, 13, 14, 12), # a 4-clique + ... (21, 22, 23, 24, 21, 23, 24, 22), # another 4-clique + ... # connect the cliques with high degree but low connectivity + ... (50, 13), + ... (12, 50, 22), + ... (13, 102, 23), + ... (14, 101, 24), + ... ] + >>> G = nx.Graph(it.chain(*[pairwise(path) for path in paths])) + >>> sorted(len(k_sg) for k_sg in k_edge_subgraphs(G, k=3)) + [1, 1, 1, 4, 4] + """ + if k < 1: + raise ValueError("k cannot be less than 1") + + # Node pruning optimization (incorporates early return) + # find_ccs is either connected_components/strongly_connected_components + find_ccs = partial(_high_degree_components, k=k) + + # Quick return optimization + if G.number_of_nodes() < k: + for node in G.nodes(): + yield G.subgraph([node]).copy() + return + + # Intermediate results + R0 = {G.subgraph(cc).copy() for cc in find_ccs(G)} + # Subdivide CCs in the intermediate results until they are k-conn + while R0: + G1 = R0.pop() + if G1.number_of_nodes() == 1: + yield G1 + else: + # Find a global minimum cut + cut_edges = nx.minimum_edge_cut(G1) + cut_value = len(cut_edges) + if cut_value < k: + # G1 is not k-edge-connected, so subdivide it + G1.remove_edges_from(cut_edges) + for cc in find_ccs(G1): + R0.add(G1.subgraph(cc).copy()) + else: + # Otherwise we found a k-edge-connected subgraph + yield G1 diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/kcomponents.py b/lib/python3.10/site-packages/networkx/algorithms/connectivity/kcomponents.py new file mode 100644 index 0000000000000000000000000000000000000000..e2f1ba289fb14a705ab6b80883d36f9cdcbed7f1 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/connectivity/kcomponents.py @@ -0,0 +1,223 @@ +""" +Moody and White algorithm for k-components +""" + +from collections import defaultdict +from itertools import combinations +from operator import itemgetter + +import networkx as nx + +# Define the default maximum flow function. +from networkx.algorithms.flow import edmonds_karp +from networkx.utils import not_implemented_for + +default_flow_func = edmonds_karp + +__all__ = ["k_components"] + + +@not_implemented_for("directed") +@nx._dispatchable +def k_components(G, flow_func=None): + r"""Returns the k-component structure of a graph G. + + A `k`-component is a maximal subgraph of a graph G that has, at least, + node connectivity `k`: we need to remove at least `k` nodes to break it + into more components. `k`-components have an inherent hierarchical + structure because they are nested in terms of connectivity: a connected + graph can contain several 2-components, each of which can contain + one or more 3-components, and so forth. + + Parameters + ---------- + G : NetworkX graph + + flow_func : function + Function to perform the underlying flow computations. Default value + :meth:`edmonds_karp`. This function performs better in sparse graphs with + right tailed degree distributions. :meth:`shortest_augmenting_path` will + perform better in denser graphs. + + Returns + ------- + k_components : dict + Dictionary with all connectivity levels `k` in the input Graph as keys + and a list of sets of nodes that form a k-component of level `k` as + values. + + Raises + ------ + NetworkXNotImplemented + If the input graph is directed. + + Examples + -------- + >>> # Petersen graph has 10 nodes and it is triconnected, thus all + >>> # nodes are in a single component on all three connectivity levels + >>> G = nx.petersen_graph() + >>> k_components = nx.k_components(G) + + Notes + ----- + Moody and White [1]_ (appendix A) provide an algorithm for identifying + k-components in a graph, which is based on Kanevsky's algorithm [2]_ + for finding all minimum-size node cut-sets of a graph (implemented in + :meth:`all_node_cuts` function): + + 1. Compute node connectivity, k, of the input graph G. + + 2. Identify all k-cutsets at the current level of connectivity using + Kanevsky's algorithm. + + 3. Generate new graph components based on the removal of + these cutsets. Nodes in a cutset belong to both sides + of the induced cut. + + 4. If the graph is neither complete nor trivial, return to 1; + else end. + + This implementation also uses some heuristics (see [3]_ for details) + to speed up the computation. + + See also + -------- + node_connectivity + all_node_cuts + biconnected_components : special case of this function when k=2 + k_edge_components : similar to this function, but uses edge-connectivity + instead of node-connectivity + + References + ---------- + .. [1] Moody, J. and D. White (2003). Social cohesion and embeddedness: + A hierarchical conception of social groups. + American Sociological Review 68(1), 103--28. + http://www2.asanet.org/journals/ASRFeb03MoodyWhite.pdf + + .. [2] Kanevsky, A. (1993). Finding all minimum-size separating vertex + sets in a graph. Networks 23(6), 533--541. + http://onlinelibrary.wiley.com/doi/10.1002/net.3230230604/abstract + + .. [3] Torrents, J. and F. Ferraro (2015). Structural Cohesion: + Visualization and Heuristics for Fast Computation. + https://arxiv.org/pdf/1503.04476v1 + + """ + # Dictionary with connectivity level (k) as keys and a list of + # sets of nodes that form a k-component as values. Note that + # k-components can overlap (but only k - 1 nodes). + k_components = defaultdict(list) + # Define default flow function + if flow_func is None: + flow_func = default_flow_func + # Bicomponents as a base to check for higher order k-components + for component in nx.connected_components(G): + # isolated nodes have connectivity 0 + comp = set(component) + if len(comp) > 1: + k_components[1].append(comp) + bicomponents = [G.subgraph(c) for c in nx.biconnected_components(G)] + for bicomponent in bicomponents: + bicomp = set(bicomponent) + # avoid considering dyads as bicomponents + if len(bicomp) > 2: + k_components[2].append(bicomp) + for B in bicomponents: + if len(B) <= 2: + continue + k = nx.node_connectivity(B, flow_func=flow_func) + if k > 2: + k_components[k].append(set(B)) + # Perform cuts in a DFS like order. + cuts = list(nx.all_node_cuts(B, k=k, flow_func=flow_func)) + stack = [(k, _generate_partition(B, cuts, k))] + while stack: + (parent_k, partition) = stack[-1] + try: + nodes = next(partition) + C = B.subgraph(nodes) + this_k = nx.node_connectivity(C, flow_func=flow_func) + if this_k > parent_k and this_k > 2: + k_components[this_k].append(set(C)) + cuts = list(nx.all_node_cuts(C, k=this_k, flow_func=flow_func)) + if cuts: + stack.append((this_k, _generate_partition(C, cuts, this_k))) + except StopIteration: + stack.pop() + + # This is necessary because k-components may only be reported at their + # maximum k level. But we want to return a dictionary in which keys are + # connectivity levels and values list of sets of components, without + # skipping any connectivity level. Also, it's possible that subsets of + # an already detected k-component appear at a level k. Checking for this + # in the while loop above penalizes the common case. Thus we also have to + # _consolidate all connectivity levels in _reconstruct_k_components. + return _reconstruct_k_components(k_components) + + +def _consolidate(sets, k): + """Merge sets that share k or more elements. + + See: http://rosettacode.org/wiki/Set_consolidation + + The iterative python implementation posted there is + faster than this because of the overhead of building a + Graph and calling nx.connected_components, but it's not + clear for us if we can use it in NetworkX because there + is no licence for the code. + + """ + G = nx.Graph() + nodes = dict(enumerate(sets)) + G.add_nodes_from(nodes) + G.add_edges_from( + (u, v) for u, v in combinations(nodes, 2) if len(nodes[u] & nodes[v]) >= k + ) + for component in nx.connected_components(G): + yield set.union(*[nodes[n] for n in component]) + + +def _generate_partition(G, cuts, k): + def has_nbrs_in_partition(G, node, partition): + return any(n in partition for n in G[node]) + + components = [] + nodes = {n for n, d in G.degree() if d > k} - {n for cut in cuts for n in cut} + H = G.subgraph(nodes) + for cc in nx.connected_components(H): + component = set(cc) + for cut in cuts: + for node in cut: + if has_nbrs_in_partition(G, node, cc): + component.add(node) + if len(component) < G.order(): + components.append(component) + yield from _consolidate(components, k + 1) + + +def _reconstruct_k_components(k_comps): + result = {} + max_k = max(k_comps) + for k in reversed(range(1, max_k + 1)): + if k == max_k: + result[k] = list(_consolidate(k_comps[k], k)) + elif k not in k_comps: + result[k] = list(_consolidate(result[k + 1], k)) + else: + nodes_at_k = set.union(*k_comps[k]) + to_add = [c for c in result[k + 1] if any(n not in nodes_at_k for n in c)] + if to_add: + result[k] = list(_consolidate(k_comps[k] + to_add, k)) + else: + result[k] = list(_consolidate(k_comps[k], k)) + return result + + +def build_k_number_dict(kcomps): + result = {} + for k, comps in sorted(kcomps.items(), key=itemgetter(0)): + for comp in comps: + for node in comp: + result[node] = k + return result diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/kcutsets.py b/lib/python3.10/site-packages/networkx/algorithms/connectivity/kcutsets.py new file mode 100644 index 0000000000000000000000000000000000000000..de26f4c5d85f42312a811509b7a9b92cd5db952c --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/connectivity/kcutsets.py @@ -0,0 +1,235 @@ +""" +Kanevsky all minimum node k cutsets algorithm. +""" + +import copy +from collections import defaultdict +from itertools import combinations +from operator import itemgetter + +import networkx as nx +from networkx.algorithms.flow import ( + build_residual_network, + edmonds_karp, + shortest_augmenting_path, +) + +from .utils import build_auxiliary_node_connectivity + +default_flow_func = edmonds_karp + + +__all__ = ["all_node_cuts"] + + +@nx._dispatchable +def all_node_cuts(G, k=None, flow_func=None): + r"""Returns all minimum k cutsets of an undirected graph G. + + This implementation is based on Kanevsky's algorithm [1]_ for finding all + minimum-size node cut-sets of an undirected graph G; ie the set (or sets) + of nodes of cardinality equal to the node connectivity of G. Thus if + removed, would break G into two or more connected components. + + Parameters + ---------- + G : NetworkX graph + Undirected graph + + k : Integer + Node connectivity of the input graph. If k is None, then it is + computed. Default value: None. + + flow_func : function + Function to perform the underlying flow computations. Default value is + :func:`~networkx.algorithms.flow.edmonds_karp`. This function performs + better in sparse graphs with right tailed degree distributions. + :func:`~networkx.algorithms.flow.shortest_augmenting_path` will + perform better in denser graphs. + + + Returns + ------- + cuts : a generator of node cutsets + Each node cutset has cardinality equal to the node connectivity of + the input graph. + + Examples + -------- + >>> # A two-dimensional grid graph has 4 cutsets of cardinality 2 + >>> G = nx.grid_2d_graph(5, 5) + >>> cutsets = list(nx.all_node_cuts(G)) + >>> len(cutsets) + 4 + >>> all(2 == len(cutset) for cutset in cutsets) + True + >>> nx.node_connectivity(G) + 2 + + Notes + ----- + This implementation is based on the sequential algorithm for finding all + minimum-size separating vertex sets in a graph [1]_. The main idea is to + compute minimum cuts using local maximum flow computations among a set + of nodes of highest degree and all other non-adjacent nodes in the Graph. + Once we find a minimum cut, we add an edge between the high degree + node and the target node of the local maximum flow computation to make + sure that we will not find that minimum cut again. + + See also + -------- + node_connectivity + edmonds_karp + shortest_augmenting_path + + References + ---------- + .. [1] Kanevsky, A. (1993). Finding all minimum-size separating vertex + sets in a graph. Networks 23(6), 533--541. + http://onlinelibrary.wiley.com/doi/10.1002/net.3230230604/abstract + + """ + if not nx.is_connected(G): + raise nx.NetworkXError("Input graph is disconnected.") + + # Address some corner cases first. + # For complete Graphs + + if nx.density(G) == 1: + yield from () + return + + # Initialize data structures. + # Keep track of the cuts already computed so we do not repeat them. + seen = [] + # Even-Tarjan reduction is what we call auxiliary digraph + # for node connectivity. + H = build_auxiliary_node_connectivity(G) + H_nodes = H.nodes # for speed + mapping = H.graph["mapping"] + # Keep a copy of original predecessors, H will be modified later. + # Shallow copy is enough. + original_H_pred = copy.copy(H._pred) + R = build_residual_network(H, "capacity") + kwargs = {"capacity": "capacity", "residual": R} + # Define default flow function + if flow_func is None: + flow_func = default_flow_func + if flow_func is shortest_augmenting_path: + kwargs["two_phase"] = True + # Begin the actual algorithm + # step 1: Find node connectivity k of G + if k is None: + k = nx.node_connectivity(G, flow_func=flow_func) + # step 2: + # Find k nodes with top degree, call it X: + X = {n for n, d in sorted(G.degree(), key=itemgetter(1), reverse=True)[:k]} + # Check if X is a k-node-cutset + if _is_separating_set(G, X): + seen.append(X) + yield X + + for x in X: + # step 3: Compute local connectivity flow of x with all other + # non adjacent nodes in G + non_adjacent = set(G) - {x} - set(G[x]) + for v in non_adjacent: + # step 4: compute maximum flow in an Even-Tarjan reduction H of G + # and step 5: build the associated residual network R + R = flow_func(H, f"{mapping[x]}B", f"{mapping[v]}A", **kwargs) + flow_value = R.graph["flow_value"] + + if flow_value == k: + # Find the nodes incident to the flow. + E1 = flowed_edges = [ + (u, w) for (u, w, d) in R.edges(data=True) if d["flow"] != 0 + ] + VE1 = incident_nodes = {n for edge in E1 for n in edge} + # Remove saturated edges form the residual network. + # Note that reversed edges are introduced with capacity 0 + # in the residual graph and they need to be removed too. + saturated_edges = [ + (u, w, d) + for (u, w, d) in R.edges(data=True) + if d["capacity"] == d["flow"] or d["capacity"] == 0 + ] + R.remove_edges_from(saturated_edges) + R_closure = nx.transitive_closure(R) + # step 6: shrink the strongly connected components of + # residual flow network R and call it L. + L = nx.condensation(R) + cmap = L.graph["mapping"] + inv_cmap = defaultdict(list) + for n, scc in cmap.items(): + inv_cmap[scc].append(n) + # Find the incident nodes in the condensed graph. + VE1 = {cmap[n] for n in VE1} + # step 7: Compute all antichains of L; + # they map to closed sets in H. + # Any edge in H that links a closed set is part of a cutset. + for antichain in nx.antichains(L): + # Only antichains that are subsets of incident nodes counts. + # Lemma 8 in reference. + if not set(antichain).issubset(VE1): + continue + # Nodes in an antichain of the condensation graph of + # the residual network map to a closed set of nodes that + # define a node partition of the auxiliary digraph H + # through taking all of antichain's predecessors in the + # transitive closure. + S = set() + for scc in antichain: + S.update(inv_cmap[scc]) + S_ancestors = set() + for n in S: + S_ancestors.update(R_closure._pred[n]) + S.update(S_ancestors) + if f"{mapping[x]}B" not in S or f"{mapping[v]}A" in S: + continue + # Find the cutset that links the node partition (S,~S) in H + cutset = set() + for u in S: + cutset.update((u, w) for w in original_H_pred[u] if w not in S) + # The edges in H that form the cutset are internal edges + # (ie edges that represent a node of the original graph G) + if any(H_nodes[u]["id"] != H_nodes[w]["id"] for u, w in cutset): + continue + node_cut = {H_nodes[u]["id"] for u, _ in cutset} + + if len(node_cut) == k: + # The cut is invalid if it includes internal edges of + # end nodes. The other half of Lemma 8 in ref. + if x in node_cut or v in node_cut: + continue + if node_cut not in seen: + yield node_cut + seen.append(node_cut) + + # Add an edge (x, v) to make sure that we do not + # find this cutset again. This is equivalent + # of adding the edge in the input graph + # G.add_edge(x, v) and then regenerate H and R: + # Add edges to the auxiliary digraph. + # See build_residual_network for convention we used + # in residual graphs. + H.add_edge(f"{mapping[x]}B", f"{mapping[v]}A", capacity=1) + H.add_edge(f"{mapping[v]}B", f"{mapping[x]}A", capacity=1) + # Add edges to the residual network. + R.add_edge(f"{mapping[x]}B", f"{mapping[v]}A", capacity=1) + R.add_edge(f"{mapping[v]}A", f"{mapping[x]}B", capacity=0) + R.add_edge(f"{mapping[v]}B", f"{mapping[x]}A", capacity=1) + R.add_edge(f"{mapping[x]}A", f"{mapping[v]}B", capacity=0) + + # Add again the saturated edges to reuse the residual network + R.add_edges_from(saturated_edges) + + +def _is_separating_set(G, cut): + """Assumes that the input graph is connected""" + if len(cut) == len(G) - 1: + return True + + H = nx.restricted_view(G, cut, []) + if nx.is_connected(H): + return False + return True diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/stoerwagner.py b/lib/python3.10/site-packages/networkx/algorithms/connectivity/stoerwagner.py new file mode 100644 index 0000000000000000000000000000000000000000..29604b148303703c73ad37baffec043abd4333e9 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/connectivity/stoerwagner.py @@ -0,0 +1,152 @@ +""" +Stoer-Wagner minimum cut algorithm. +""" + +from itertools import islice + +import networkx as nx + +from ...utils import BinaryHeap, arbitrary_element, not_implemented_for + +__all__ = ["stoer_wagner"] + + +@not_implemented_for("directed") +@not_implemented_for("multigraph") +@nx._dispatchable(edge_attrs="weight") +def stoer_wagner(G, weight="weight", heap=BinaryHeap): + r"""Returns the weighted minimum edge cut using the Stoer-Wagner algorithm. + + Determine the minimum edge cut of a connected graph using the + Stoer-Wagner algorithm. In weighted cases, all weights must be + nonnegative. + + The running time of the algorithm depends on the type of heaps used: + + ============== ============================================= + Type of heap Running time + ============== ============================================= + Binary heap $O(n (m + n) \log n)$ + Fibonacci heap $O(nm + n^2 \log n)$ + Pairing heap $O(2^{2 \sqrt{\log \log n}} nm + n^2 \log n)$ + ============== ============================================= + + Parameters + ---------- + G : NetworkX graph + Edges of the graph are expected to have an attribute named by the + weight parameter below. If this attribute is not present, the edge is + considered to have unit weight. + + weight : string + Name of the weight attribute of the edges. If the attribute is not + present, unit weight is assumed. Default value: 'weight'. + + heap : class + Type of heap to be used in the algorithm. It should be a subclass of + :class:`MinHeap` or implement a compatible interface. + + If a stock heap implementation is to be used, :class:`BinaryHeap` is + recommended over :class:`PairingHeap` for Python implementations without + optimized attribute accesses (e.g., CPython) despite a slower + asymptotic running time. For Python implementations with optimized + attribute accesses (e.g., PyPy), :class:`PairingHeap` provides better + performance. Default value: :class:`BinaryHeap`. + + Returns + ------- + cut_value : integer or float + The sum of weights of edges in a minimum cut. + + partition : pair of node lists + A partitioning of the nodes that defines a minimum cut. + + Raises + ------ + NetworkXNotImplemented + If the graph is directed or a multigraph. + + NetworkXError + If the graph has less than two nodes, is not connected or has a + negative-weighted edge. + + Examples + -------- + >>> G = nx.Graph() + >>> G.add_edge("x", "a", weight=3) + >>> G.add_edge("x", "b", weight=1) + >>> G.add_edge("a", "c", weight=3) + >>> G.add_edge("b", "c", weight=5) + >>> G.add_edge("b", "d", weight=4) + >>> G.add_edge("d", "e", weight=2) + >>> G.add_edge("c", "y", weight=2) + >>> G.add_edge("e", "y", weight=3) + >>> cut_value, partition = nx.stoer_wagner(G) + >>> cut_value + 4 + """ + n = len(G) + if n < 2: + raise nx.NetworkXError("graph has less than two nodes.") + if not nx.is_connected(G): + raise nx.NetworkXError("graph is not connected.") + + # Make a copy of the graph for internal use. + G = nx.Graph( + (u, v, {"weight": e.get(weight, 1)}) for u, v, e in G.edges(data=True) if u != v + ) + G.__networkx_cache__ = None # Disable caching + + for u, v, e in G.edges(data=True): + if e["weight"] < 0: + raise nx.NetworkXError("graph has a negative-weighted edge.") + + cut_value = float("inf") + nodes = set(G) + contractions = [] # contracted node pairs + + # Repeatedly pick a pair of nodes to contract until only one node is left. + for i in range(n - 1): + # Pick an arbitrary node u and create a set A = {u}. + u = arbitrary_element(G) + A = {u} + # Repeatedly pick the node "most tightly connected" to A and add it to + # A. The tightness of connectivity of a node not in A is defined by the + # of edges connecting it to nodes in A. + h = heap() # min-heap emulating a max-heap + for v, e in G[u].items(): + h.insert(v, -e["weight"]) + # Repeat until all but one node has been added to A. + for j in range(n - i - 2): + u = h.pop()[0] + A.add(u) + for v, e in G[u].items(): + if v not in A: + h.insert(v, h.get(v, 0) - e["weight"]) + # A and the remaining node v define a "cut of the phase". There is a + # minimum cut of the original graph that is also a cut of the phase. + # Due to contractions in earlier phases, v may in fact represent + # multiple nodes in the original graph. + v, w = h.min() + w = -w + if w < cut_value: + cut_value = w + best_phase = i + # Contract v and the last node added to A. + contractions.append((u, v)) + for w, e in G[v].items(): + if w != u: + if w not in G[u]: + G.add_edge(u, w, weight=e["weight"]) + else: + G[u][w]["weight"] += e["weight"] + G.remove_node(v) + + # Recover the optimal partitioning from the contractions. + G = nx.Graph(islice(contractions, best_phase)) + v = contractions[best_phase][1] + G.add_node(v) + reachable = set(nx.single_source_shortest_path_length(G, v)) + partition = (list(reachable), list(nodes - reachable)) + + return cut_value, partition diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..124da1895a9bf57545a323352f3dffcd8e29f173 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/__pycache__/test_connectivity.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/__pycache__/test_connectivity.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91dc70f4246b7f120e2bda9abeec60b21783b1c9 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/__pycache__/test_connectivity.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/__pycache__/test_cuts.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/__pycache__/test_cuts.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c10ccc9ad2b50b719faf5b908a0d968d7da371a7 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/__pycache__/test_cuts.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/__pycache__/test_disjoint_paths.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/__pycache__/test_disjoint_paths.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..978f860c2b0f4ddbbed590d4bb106af453b1882b Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/__pycache__/test_disjoint_paths.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/__pycache__/test_edge_augmentation.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/__pycache__/test_edge_augmentation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c30ff3e22b595564e59d5cbf144721e299a37d4c Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/__pycache__/test_edge_augmentation.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/__pycache__/test_edge_kcomponents.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/__pycache__/test_edge_kcomponents.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f96c943861e886cc57c8a223586c187335340f05 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/__pycache__/test_edge_kcomponents.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/__pycache__/test_kcomponents.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/__pycache__/test_kcomponents.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb96ee1959f245e381170167fd378e54431081d6 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/__pycache__/test_kcomponents.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/__pycache__/test_kcutsets.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/__pycache__/test_kcutsets.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04f7941a9e7c4f8e43fe93723d0a7819b0a063b5 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/__pycache__/test_kcutsets.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/__pycache__/test_stoer_wagner.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/__pycache__/test_stoer_wagner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..094666238d6ceddc8a405fd9ecb9fd43b3e41b26 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/__pycache__/test_stoer_wagner.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/test_connectivity.py b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/test_connectivity.py new file mode 100644 index 0000000000000000000000000000000000000000..7aef2477d1331bcefc7e5dfdacd415b27ffcd3c8 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/test_connectivity.py @@ -0,0 +1,421 @@ +import itertools + +import pytest + +import networkx as nx +from networkx.algorithms import flow +from networkx.algorithms.connectivity import ( + local_edge_connectivity, + local_node_connectivity, +) + +flow_funcs = [ + flow.boykov_kolmogorov, + flow.dinitz, + flow.edmonds_karp, + flow.preflow_push, + flow.shortest_augmenting_path, +] + + +# helper functions for tests + + +def _generate_no_biconnected(max_attempts=50): + attempts = 0 + while True: + G = nx.fast_gnp_random_graph(100, 0.0575, seed=42) + if nx.is_connected(G) and not nx.is_biconnected(G): + attempts = 0 + yield G + else: + if attempts >= max_attempts: + msg = f"Tried {max_attempts} times: no suitable Graph." + raise Exception(msg) + else: + attempts += 1 + + +def test_average_connectivity(): + # figure 1 from: + # Beineke, L., O. Oellermann, and R. Pippert (2002). The average + # connectivity of a graph. Discrete mathematics 252(1-3), 31-45 + # http://www.sciencedirect.com/science/article/pii/S0012365X01001807 + G1 = nx.path_graph(3) + G1.add_edges_from([(1, 3), (1, 4)]) + G2 = nx.path_graph(3) + G2.add_edges_from([(1, 3), (1, 4), (0, 3), (0, 4), (3, 4)]) + G3 = nx.Graph() + for flow_func in flow_funcs: + kwargs = {"flow_func": flow_func} + errmsg = f"Assertion failed in function: {flow_func.__name__}" + assert nx.average_node_connectivity(G1, **kwargs) == 1, errmsg + assert nx.average_node_connectivity(G2, **kwargs) == 2.2, errmsg + assert nx.average_node_connectivity(G3, **kwargs) == 0, errmsg + + +def test_average_connectivity_directed(): + G = nx.DiGraph([(1, 3), (1, 4), (1, 5)]) + for flow_func in flow_funcs: + errmsg = f"Assertion failed in function: {flow_func.__name__}" + assert nx.average_node_connectivity(G) == 0.25, errmsg + + +def test_articulation_points(): + Ggen = _generate_no_biconnected() + for flow_func in flow_funcs: + for i in range(3): + G = next(Ggen) + errmsg = f"Assertion failed in function: {flow_func.__name__}" + assert nx.node_connectivity(G, flow_func=flow_func) == 1, errmsg + + +def test_brandes_erlebach(): + # Figure 1 chapter 7: Connectivity + # http://www.informatik.uni-augsburg.de/thi/personen/kammer/Graph_Connectivity.pdf + G = nx.Graph() + G.add_edges_from( + [ + (1, 2), + (1, 3), + (1, 4), + (1, 5), + (2, 3), + (2, 6), + (3, 4), + (3, 6), + (4, 6), + (4, 7), + (5, 7), + (6, 8), + (6, 9), + (7, 8), + (7, 10), + (8, 11), + (9, 10), + (9, 11), + (10, 11), + ] + ) + for flow_func in flow_funcs: + kwargs = {"flow_func": flow_func} + errmsg = f"Assertion failed in function: {flow_func.__name__}" + assert 3 == local_edge_connectivity(G, 1, 11, **kwargs), errmsg + assert 3 == nx.edge_connectivity(G, 1, 11, **kwargs), errmsg + assert 2 == local_node_connectivity(G, 1, 11, **kwargs), errmsg + assert 2 == nx.node_connectivity(G, 1, 11, **kwargs), errmsg + assert 2 == nx.edge_connectivity(G, **kwargs), errmsg + assert 2 == nx.node_connectivity(G, **kwargs), errmsg + if flow_func is flow.preflow_push: + assert 3 == nx.edge_connectivity(G, 1, 11, cutoff=2, **kwargs), errmsg + else: + assert 2 == nx.edge_connectivity(G, 1, 11, cutoff=2, **kwargs), errmsg + + +def test_white_harary_1(): + # Figure 1b white and harary (2001) + # https://doi.org/10.1111/0081-1750.00098 + # A graph with high adhesion (edge connectivity) and low cohesion + # (vertex connectivity) + G = nx.disjoint_union(nx.complete_graph(4), nx.complete_graph(4)) + G.remove_node(7) + for i in range(4, 7): + G.add_edge(0, i) + G = nx.disjoint_union(G, nx.complete_graph(4)) + G.remove_node(G.order() - 1) + for i in range(7, 10): + G.add_edge(0, i) + for flow_func in flow_funcs: + errmsg = f"Assertion failed in function: {flow_func.__name__}" + assert 1 == nx.node_connectivity(G, flow_func=flow_func), errmsg + assert 3 == nx.edge_connectivity(G, flow_func=flow_func), errmsg + + +def test_white_harary_2(): + # Figure 8 white and harary (2001) + # https://doi.org/10.1111/0081-1750.00098 + G = nx.disjoint_union(nx.complete_graph(4), nx.complete_graph(4)) + G.add_edge(0, 4) + # kappa <= lambda <= delta + assert 3 == min(nx.core_number(G).values()) + for flow_func in flow_funcs: + errmsg = f"Assertion failed in function: {flow_func.__name__}" + assert 1 == nx.node_connectivity(G, flow_func=flow_func), errmsg + assert 1 == nx.edge_connectivity(G, flow_func=flow_func), errmsg + + +def test_complete_graphs(): + for n in range(5, 20, 5): + for flow_func in flow_funcs: + G = nx.complete_graph(n) + errmsg = f"Assertion failed in function: {flow_func.__name__}" + assert n - 1 == nx.node_connectivity(G, flow_func=flow_func), errmsg + assert n - 1 == nx.node_connectivity( + G.to_directed(), flow_func=flow_func + ), errmsg + assert n - 1 == nx.edge_connectivity(G, flow_func=flow_func), errmsg + assert n - 1 == nx.edge_connectivity( + G.to_directed(), flow_func=flow_func + ), errmsg + + +def test_empty_graphs(): + for k in range(5, 25, 5): + G = nx.empty_graph(k) + for flow_func in flow_funcs: + errmsg = f"Assertion failed in function: {flow_func.__name__}" + assert 0 == nx.node_connectivity(G, flow_func=flow_func), errmsg + assert 0 == nx.edge_connectivity(G, flow_func=flow_func), errmsg + + +def test_petersen(): + G = nx.petersen_graph() + for flow_func in flow_funcs: + errmsg = f"Assertion failed in function: {flow_func.__name__}" + assert 3 == nx.node_connectivity(G, flow_func=flow_func), errmsg + assert 3 == nx.edge_connectivity(G, flow_func=flow_func), errmsg + + +def test_tutte(): + G = nx.tutte_graph() + for flow_func in flow_funcs: + errmsg = f"Assertion failed in function: {flow_func.__name__}" + assert 3 == nx.node_connectivity(G, flow_func=flow_func), errmsg + assert 3 == nx.edge_connectivity(G, flow_func=flow_func), errmsg + + +def test_dodecahedral(): + G = nx.dodecahedral_graph() + for flow_func in flow_funcs: + errmsg = f"Assertion failed in function: {flow_func.__name__}" + assert 3 == nx.node_connectivity(G, flow_func=flow_func), errmsg + assert 3 == nx.edge_connectivity(G, flow_func=flow_func), errmsg + + +def test_octahedral(): + G = nx.octahedral_graph() + for flow_func in flow_funcs: + errmsg = f"Assertion failed in function: {flow_func.__name__}" + assert 4 == nx.node_connectivity(G, flow_func=flow_func), errmsg + assert 4 == nx.edge_connectivity(G, flow_func=flow_func), errmsg + + +def test_icosahedral(): + G = nx.icosahedral_graph() + for flow_func in flow_funcs: + errmsg = f"Assertion failed in function: {flow_func.__name__}" + assert 5 == nx.node_connectivity(G, flow_func=flow_func), errmsg + assert 5 == nx.edge_connectivity(G, flow_func=flow_func), errmsg + + +def test_missing_source(): + G = nx.path_graph(4) + for flow_func in flow_funcs: + pytest.raises( + nx.NetworkXError, nx.node_connectivity, G, 10, 1, flow_func=flow_func + ) + + +def test_missing_target(): + G = nx.path_graph(4) + for flow_func in flow_funcs: + pytest.raises( + nx.NetworkXError, nx.node_connectivity, G, 1, 10, flow_func=flow_func + ) + + +def test_edge_missing_source(): + G = nx.path_graph(4) + for flow_func in flow_funcs: + pytest.raises( + nx.NetworkXError, nx.edge_connectivity, G, 10, 1, flow_func=flow_func + ) + + +def test_edge_missing_target(): + G = nx.path_graph(4) + for flow_func in flow_funcs: + pytest.raises( + nx.NetworkXError, nx.edge_connectivity, G, 1, 10, flow_func=flow_func + ) + + +def test_not_weakly_connected(): + G = nx.DiGraph() + nx.add_path(G, [1, 2, 3]) + nx.add_path(G, [4, 5]) + for flow_func in flow_funcs: + errmsg = f"Assertion failed in function: {flow_func.__name__}" + assert nx.node_connectivity(G) == 0, errmsg + assert nx.edge_connectivity(G) == 0, errmsg + + +def test_not_connected(): + G = nx.Graph() + nx.add_path(G, [1, 2, 3]) + nx.add_path(G, [4, 5]) + for flow_func in flow_funcs: + errmsg = f"Assertion failed in function: {flow_func.__name__}" + assert nx.node_connectivity(G) == 0, errmsg + assert nx.edge_connectivity(G) == 0, errmsg + + +def test_directed_edge_connectivity(): + G = nx.cycle_graph(10, create_using=nx.DiGraph()) # only one direction + D = nx.cycle_graph(10).to_directed() # 2 reciprocal edges + for flow_func in flow_funcs: + errmsg = f"Assertion failed in function: {flow_func.__name__}" + assert 1 == nx.edge_connectivity(G, flow_func=flow_func), errmsg + assert 1 == local_edge_connectivity(G, 1, 4, flow_func=flow_func), errmsg + assert 1 == nx.edge_connectivity(G, 1, 4, flow_func=flow_func), errmsg + assert 2 == nx.edge_connectivity(D, flow_func=flow_func), errmsg + assert 2 == local_edge_connectivity(D, 1, 4, flow_func=flow_func), errmsg + assert 2 == nx.edge_connectivity(D, 1, 4, flow_func=flow_func), errmsg + + +def test_cutoff(): + G = nx.complete_graph(5) + for local_func in [local_edge_connectivity, local_node_connectivity]: + for flow_func in flow_funcs: + if flow_func is flow.preflow_push: + # cutoff is not supported by preflow_push + continue + for cutoff in [3, 2, 1]: + result = local_func(G, 0, 4, flow_func=flow_func, cutoff=cutoff) + assert cutoff == result, f"cutoff error in {flow_func.__name__}" + + +def test_invalid_auxiliary(): + G = nx.complete_graph(5) + pytest.raises(nx.NetworkXError, local_node_connectivity, G, 0, 3, auxiliary=G) + + +def test_interface_only_source(): + G = nx.complete_graph(5) + for interface_func in [nx.node_connectivity, nx.edge_connectivity]: + pytest.raises(nx.NetworkXError, interface_func, G, s=0) + + +def test_interface_only_target(): + G = nx.complete_graph(5) + for interface_func in [nx.node_connectivity, nx.edge_connectivity]: + pytest.raises(nx.NetworkXError, interface_func, G, t=3) + + +def test_edge_connectivity_flow_vs_stoer_wagner(): + graph_funcs = [nx.icosahedral_graph, nx.octahedral_graph, nx.dodecahedral_graph] + for graph_func in graph_funcs: + G = graph_func() + assert nx.stoer_wagner(G)[0] == nx.edge_connectivity(G) + + +class TestAllPairsNodeConnectivity: + @classmethod + def setup_class(cls): + cls.path = nx.path_graph(7) + cls.directed_path = nx.path_graph(7, create_using=nx.DiGraph()) + cls.cycle = nx.cycle_graph(7) + cls.directed_cycle = nx.cycle_graph(7, create_using=nx.DiGraph()) + cls.gnp = nx.gnp_random_graph(30, 0.1, seed=42) + cls.directed_gnp = nx.gnp_random_graph(30, 0.1, directed=True, seed=42) + cls.K20 = nx.complete_graph(20) + cls.K10 = nx.complete_graph(10) + cls.K5 = nx.complete_graph(5) + cls.G_list = [ + cls.path, + cls.directed_path, + cls.cycle, + cls.directed_cycle, + cls.gnp, + cls.directed_gnp, + cls.K10, + cls.K5, + cls.K20, + ] + + def test_cycles(self): + K_undir = nx.all_pairs_node_connectivity(self.cycle) + for source in K_undir: + for target, k in K_undir[source].items(): + assert k == 2 + K_dir = nx.all_pairs_node_connectivity(self.directed_cycle) + for source in K_dir: + for target, k in K_dir[source].items(): + assert k == 1 + + def test_complete(self): + for G in [self.K10, self.K5, self.K20]: + K = nx.all_pairs_node_connectivity(G) + for source in K: + for target, k in K[source].items(): + assert k == len(G) - 1 + + def test_paths(self): + K_undir = nx.all_pairs_node_connectivity(self.path) + for source in K_undir: + for target, k in K_undir[source].items(): + assert k == 1 + K_dir = nx.all_pairs_node_connectivity(self.directed_path) + for source in K_dir: + for target, k in K_dir[source].items(): + if source < target: + assert k == 1 + else: + assert k == 0 + + def test_all_pairs_connectivity_nbunch(self): + G = nx.complete_graph(5) + nbunch = [0, 2, 3] + C = nx.all_pairs_node_connectivity(G, nbunch=nbunch) + assert len(C) == len(nbunch) + + def test_all_pairs_connectivity_icosahedral(self): + G = nx.icosahedral_graph() + C = nx.all_pairs_node_connectivity(G) + assert all(5 == C[u][v] for u, v in itertools.combinations(G, 2)) + + def test_all_pairs_connectivity(self): + G = nx.Graph() + nodes = [0, 1, 2, 3] + nx.add_path(G, nodes) + A = {n: {} for n in G} + for u, v in itertools.combinations(nodes, 2): + A[u][v] = A[v][u] = nx.node_connectivity(G, u, v) + C = nx.all_pairs_node_connectivity(G) + assert sorted((k, sorted(v)) for k, v in A.items()) == sorted( + (k, sorted(v)) for k, v in C.items() + ) + + def test_all_pairs_connectivity_directed(self): + G = nx.DiGraph() + nodes = [0, 1, 2, 3] + nx.add_path(G, nodes) + A = {n: {} for n in G} + for u, v in itertools.permutations(nodes, 2): + A[u][v] = nx.node_connectivity(G, u, v) + C = nx.all_pairs_node_connectivity(G) + assert sorted((k, sorted(v)) for k, v in A.items()) == sorted( + (k, sorted(v)) for k, v in C.items() + ) + + def test_all_pairs_connectivity_nbunch_combinations(self): + G = nx.complete_graph(5) + nbunch = [0, 2, 3] + A = {n: {} for n in nbunch} + for u, v in itertools.combinations(nbunch, 2): + A[u][v] = A[v][u] = nx.node_connectivity(G, u, v) + C = nx.all_pairs_node_connectivity(G, nbunch=nbunch) + assert sorted((k, sorted(v)) for k, v in A.items()) == sorted( + (k, sorted(v)) for k, v in C.items() + ) + + def test_all_pairs_connectivity_nbunch_iter(self): + G = nx.complete_graph(5) + nbunch = [0, 2, 3] + A = {n: {} for n in nbunch} + for u, v in itertools.combinations(nbunch, 2): + A[u][v] = A[v][u] = nx.node_connectivity(G, u, v) + C = nx.all_pairs_node_connectivity(G, nbunch=iter(nbunch)) + assert sorted((k, sorted(v)) for k, v in A.items()) == sorted( + (k, sorted(v)) for k, v in C.items() + ) diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/test_cuts.py b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/test_cuts.py new file mode 100644 index 0000000000000000000000000000000000000000..7a485be399d87db147f7e4567f903fb5271ad63b --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/test_cuts.py @@ -0,0 +1,309 @@ +import pytest + +import networkx as nx +from networkx.algorithms import flow +from networkx.algorithms.connectivity import minimum_st_edge_cut, minimum_st_node_cut +from networkx.utils import arbitrary_element + +flow_funcs = [ + flow.boykov_kolmogorov, + flow.dinitz, + flow.edmonds_karp, + flow.preflow_push, + flow.shortest_augmenting_path, +] + +# Tests for node and edge cutsets + + +def _generate_no_biconnected(max_attempts=50): + attempts = 0 + while True: + G = nx.fast_gnp_random_graph(100, 0.0575, seed=42) + if nx.is_connected(G) and not nx.is_biconnected(G): + attempts = 0 + yield G + else: + if attempts >= max_attempts: + msg = f"Tried {attempts} times: no suitable Graph." + raise Exception(msg) + else: + attempts += 1 + + +def test_articulation_points(): + Ggen = _generate_no_biconnected() + for flow_func in flow_funcs: + errmsg = f"Assertion failed in function: {flow_func.__name__}" + for i in range(1): # change 1 to 3 or more for more realizations. + G = next(Ggen) + cut = nx.minimum_node_cut(G, flow_func=flow_func) + assert len(cut) == 1, errmsg + assert cut.pop() in set(nx.articulation_points(G)), errmsg + + +def test_brandes_erlebach_book(): + # Figure 1 chapter 7: Connectivity + # http://www.informatik.uni-augsburg.de/thi/personen/kammer/Graph_Connectivity.pdf + G = nx.Graph() + G.add_edges_from( + [ + (1, 2), + (1, 3), + (1, 4), + (1, 5), + (2, 3), + (2, 6), + (3, 4), + (3, 6), + (4, 6), + (4, 7), + (5, 7), + (6, 8), + (6, 9), + (7, 8), + (7, 10), + (8, 11), + (9, 10), + (9, 11), + (10, 11), + ] + ) + for flow_func in flow_funcs: + kwargs = {"flow_func": flow_func} + errmsg = f"Assertion failed in function: {flow_func.__name__}" + # edge cutsets + assert 3 == len(nx.minimum_edge_cut(G, 1, 11, **kwargs)), errmsg + edge_cut = nx.minimum_edge_cut(G, **kwargs) + # Node 5 has only two edges + assert 2 == len(edge_cut), errmsg + H = G.copy() + H.remove_edges_from(edge_cut) + assert not nx.is_connected(H), errmsg + # node cuts + assert {6, 7} == minimum_st_node_cut(G, 1, 11, **kwargs), errmsg + assert {6, 7} == nx.minimum_node_cut(G, 1, 11, **kwargs), errmsg + node_cut = nx.minimum_node_cut(G, **kwargs) + assert 2 == len(node_cut), errmsg + H = G.copy() + H.remove_nodes_from(node_cut) + assert not nx.is_connected(H), errmsg + + +def test_white_harary_paper(): + # Figure 1b white and harary (2001) + # https://doi.org/10.1111/0081-1750.00098 + # A graph with high adhesion (edge connectivity) and low cohesion + # (node connectivity) + G = nx.disjoint_union(nx.complete_graph(4), nx.complete_graph(4)) + G.remove_node(7) + for i in range(4, 7): + G.add_edge(0, i) + G = nx.disjoint_union(G, nx.complete_graph(4)) + G.remove_node(G.order() - 1) + for i in range(7, 10): + G.add_edge(0, i) + for flow_func in flow_funcs: + kwargs = {"flow_func": flow_func} + errmsg = f"Assertion failed in function: {flow_func.__name__}" + # edge cuts + edge_cut = nx.minimum_edge_cut(G, **kwargs) + assert 3 == len(edge_cut), errmsg + H = G.copy() + H.remove_edges_from(edge_cut) + assert not nx.is_connected(H), errmsg + # node cuts + node_cut = nx.minimum_node_cut(G, **kwargs) + assert {0} == node_cut, errmsg + H = G.copy() + H.remove_nodes_from(node_cut) + assert not nx.is_connected(H), errmsg + + +def test_petersen_cutset(): + G = nx.petersen_graph() + for flow_func in flow_funcs: + kwargs = {"flow_func": flow_func} + errmsg = f"Assertion failed in function: {flow_func.__name__}" + # edge cuts + edge_cut = nx.minimum_edge_cut(G, **kwargs) + assert 3 == len(edge_cut), errmsg + H = G.copy() + H.remove_edges_from(edge_cut) + assert not nx.is_connected(H), errmsg + # node cuts + node_cut = nx.minimum_node_cut(G, **kwargs) + assert 3 == len(node_cut), errmsg + H = G.copy() + H.remove_nodes_from(node_cut) + assert not nx.is_connected(H), errmsg + + +def test_octahedral_cutset(): + G = nx.octahedral_graph() + for flow_func in flow_funcs: + kwargs = {"flow_func": flow_func} + errmsg = f"Assertion failed in function: {flow_func.__name__}" + # edge cuts + edge_cut = nx.minimum_edge_cut(G, **kwargs) + assert 4 == len(edge_cut), errmsg + H = G.copy() + H.remove_edges_from(edge_cut) + assert not nx.is_connected(H), errmsg + # node cuts + node_cut = nx.minimum_node_cut(G, **kwargs) + assert 4 == len(node_cut), errmsg + H = G.copy() + H.remove_nodes_from(node_cut) + assert not nx.is_connected(H), errmsg + + +def test_icosahedral_cutset(): + G = nx.icosahedral_graph() + for flow_func in flow_funcs: + kwargs = {"flow_func": flow_func} + errmsg = f"Assertion failed in function: {flow_func.__name__}" + # edge cuts + edge_cut = nx.minimum_edge_cut(G, **kwargs) + assert 5 == len(edge_cut), errmsg + H = G.copy() + H.remove_edges_from(edge_cut) + assert not nx.is_connected(H), errmsg + # node cuts + node_cut = nx.minimum_node_cut(G, **kwargs) + assert 5 == len(node_cut), errmsg + H = G.copy() + H.remove_nodes_from(node_cut) + assert not nx.is_connected(H), errmsg + + +def test_node_cutset_exception(): + G = nx.Graph() + G.add_edges_from([(1, 2), (3, 4)]) + for flow_func in flow_funcs: + pytest.raises(nx.NetworkXError, nx.minimum_node_cut, G, flow_func=flow_func) + + +def test_node_cutset_random_graphs(): + for flow_func in flow_funcs: + errmsg = f"Assertion failed in function: {flow_func.__name__}" + for i in range(3): + G = nx.fast_gnp_random_graph(50, 0.25, seed=42) + if not nx.is_connected(G): + ccs = iter(nx.connected_components(G)) + start = arbitrary_element(next(ccs)) + G.add_edges_from((start, arbitrary_element(c)) for c in ccs) + cutset = nx.minimum_node_cut(G, flow_func=flow_func) + assert nx.node_connectivity(G) == len(cutset), errmsg + G.remove_nodes_from(cutset) + assert not nx.is_connected(G), errmsg + + +def test_edge_cutset_random_graphs(): + for flow_func in flow_funcs: + errmsg = f"Assertion failed in function: {flow_func.__name__}" + for i in range(3): + G = nx.fast_gnp_random_graph(50, 0.25, seed=42) + if not nx.is_connected(G): + ccs = iter(nx.connected_components(G)) + start = arbitrary_element(next(ccs)) + G.add_edges_from((start, arbitrary_element(c)) for c in ccs) + cutset = nx.minimum_edge_cut(G, flow_func=flow_func) + assert nx.edge_connectivity(G) == len(cutset), errmsg + G.remove_edges_from(cutset) + assert not nx.is_connected(G), errmsg + + +def test_empty_graphs(): + G = nx.Graph() + D = nx.DiGraph() + for interface_func in [nx.minimum_node_cut, nx.minimum_edge_cut]: + for flow_func in flow_funcs: + pytest.raises( + nx.NetworkXPointlessConcept, interface_func, G, flow_func=flow_func + ) + pytest.raises( + nx.NetworkXPointlessConcept, interface_func, D, flow_func=flow_func + ) + + +def test_unbounded(): + G = nx.complete_graph(5) + for flow_func in flow_funcs: + assert 4 == len(minimum_st_edge_cut(G, 1, 4, flow_func=flow_func)) + + +def test_missing_source(): + G = nx.path_graph(4) + for interface_func in [nx.minimum_edge_cut, nx.minimum_node_cut]: + for flow_func in flow_funcs: + pytest.raises( + nx.NetworkXError, interface_func, G, 10, 1, flow_func=flow_func + ) + + +def test_missing_target(): + G = nx.path_graph(4) + for interface_func in [nx.minimum_edge_cut, nx.minimum_node_cut]: + for flow_func in flow_funcs: + pytest.raises( + nx.NetworkXError, interface_func, G, 1, 10, flow_func=flow_func + ) + + +def test_not_weakly_connected(): + G = nx.DiGraph() + nx.add_path(G, [1, 2, 3]) + nx.add_path(G, [4, 5]) + for interface_func in [nx.minimum_edge_cut, nx.minimum_node_cut]: + for flow_func in flow_funcs: + pytest.raises(nx.NetworkXError, interface_func, G, flow_func=flow_func) + + +def test_not_connected(): + G = nx.Graph() + nx.add_path(G, [1, 2, 3]) + nx.add_path(G, [4, 5]) + for interface_func in [nx.minimum_edge_cut, nx.minimum_node_cut]: + for flow_func in flow_funcs: + pytest.raises(nx.NetworkXError, interface_func, G, flow_func=flow_func) + + +def tests_min_cut_complete(): + G = nx.complete_graph(5) + for interface_func in [nx.minimum_edge_cut, nx.minimum_node_cut]: + for flow_func in flow_funcs: + assert 4 == len(interface_func(G, flow_func=flow_func)) + + +def tests_min_cut_complete_directed(): + G = nx.complete_graph(5) + G = G.to_directed() + for interface_func in [nx.minimum_edge_cut, nx.minimum_node_cut]: + for flow_func in flow_funcs: + assert 4 == len(interface_func(G, flow_func=flow_func)) + + +def tests_minimum_st_node_cut(): + G = nx.Graph() + G.add_nodes_from([0, 1, 2, 3, 7, 8, 11, 12]) + G.add_edges_from([(7, 11), (1, 11), (1, 12), (12, 8), (0, 1)]) + nodelist = minimum_st_node_cut(G, 7, 11) + assert nodelist == {} + + +def test_invalid_auxiliary(): + G = nx.complete_graph(5) + pytest.raises(nx.NetworkXError, minimum_st_node_cut, G, 0, 3, auxiliary=G) + + +def test_interface_only_source(): + G = nx.complete_graph(5) + for interface_func in [nx.minimum_node_cut, nx.minimum_edge_cut]: + pytest.raises(nx.NetworkXError, interface_func, G, s=0) + + +def test_interface_only_target(): + G = nx.complete_graph(5) + for interface_func in [nx.minimum_node_cut, nx.minimum_edge_cut]: + pytest.raises(nx.NetworkXError, interface_func, G, t=3) diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/test_disjoint_paths.py b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/test_disjoint_paths.py new file mode 100644 index 0000000000000000000000000000000000000000..0c0fad9f5ca474a6b547a399f8f284f7ff6e33a4 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/test_disjoint_paths.py @@ -0,0 +1,249 @@ +import pytest + +import networkx as nx +from networkx.algorithms import flow +from networkx.utils import pairwise + +flow_funcs = [ + flow.boykov_kolmogorov, + flow.edmonds_karp, + flow.dinitz, + flow.preflow_push, + flow.shortest_augmenting_path, +] + + +def is_path(G, path): + return all(v in G[u] for u, v in pairwise(path)) + + +def are_edge_disjoint_paths(G, paths): + if not paths: + return False + for path in paths: + assert is_path(G, path) + paths_edges = [list(pairwise(p)) for p in paths] + num_of_edges = sum(len(e) for e in paths_edges) + num_unique_edges = len(set.union(*[set(es) for es in paths_edges])) + if num_of_edges == num_unique_edges: + return True + return False + + +def are_node_disjoint_paths(G, paths): + if not paths: + return False + for path in paths: + assert is_path(G, path) + # first and last nodes are source and target + st = {paths[0][0], paths[0][-1]} + num_of_nodes = len([n for path in paths for n in path if n not in st]) + num_unique_nodes = len({n for path in paths for n in path if n not in st}) + if num_of_nodes == num_unique_nodes: + return True + return False + + +def test_graph_from_pr_2053(): + G = nx.Graph() + G.add_edges_from( + [ + ("A", "B"), + ("A", "D"), + ("A", "F"), + ("A", "G"), + ("B", "C"), + ("B", "D"), + ("B", "G"), + ("C", "D"), + ("C", "E"), + ("C", "Z"), + ("D", "E"), + ("D", "F"), + ("E", "F"), + ("E", "Z"), + ("F", "Z"), + ("G", "Z"), + ] + ) + for flow_func in flow_funcs: + kwargs = {"flow_func": flow_func} + errmsg = f"Assertion failed in function: {flow_func.__name__}" + # edge disjoint paths + edge_paths = list(nx.edge_disjoint_paths(G, "A", "Z", **kwargs)) + assert are_edge_disjoint_paths(G, edge_paths), errmsg + assert nx.edge_connectivity(G, "A", "Z") == len(edge_paths), errmsg + # node disjoint paths + node_paths = list(nx.node_disjoint_paths(G, "A", "Z", **kwargs)) + assert are_node_disjoint_paths(G, node_paths), errmsg + assert nx.node_connectivity(G, "A", "Z") == len(node_paths), errmsg + + +def test_florentine_families(): + G = nx.florentine_families_graph() + for flow_func in flow_funcs: + kwargs = {"flow_func": flow_func} + errmsg = f"Assertion failed in function: {flow_func.__name__}" + # edge disjoint paths + edge_dpaths = list(nx.edge_disjoint_paths(G, "Medici", "Strozzi", **kwargs)) + assert are_edge_disjoint_paths(G, edge_dpaths), errmsg + assert nx.edge_connectivity(G, "Medici", "Strozzi") == len(edge_dpaths), errmsg + # node disjoint paths + node_dpaths = list(nx.node_disjoint_paths(G, "Medici", "Strozzi", **kwargs)) + assert are_node_disjoint_paths(G, node_dpaths), errmsg + assert nx.node_connectivity(G, "Medici", "Strozzi") == len(node_dpaths), errmsg + + +def test_karate(): + G = nx.karate_club_graph() + for flow_func in flow_funcs: + kwargs = {"flow_func": flow_func} + errmsg = f"Assertion failed in function: {flow_func.__name__}" + # edge disjoint paths + edge_dpaths = list(nx.edge_disjoint_paths(G, 0, 33, **kwargs)) + assert are_edge_disjoint_paths(G, edge_dpaths), errmsg + assert nx.edge_connectivity(G, 0, 33) == len(edge_dpaths), errmsg + # node disjoint paths + node_dpaths = list(nx.node_disjoint_paths(G, 0, 33, **kwargs)) + assert are_node_disjoint_paths(G, node_dpaths), errmsg + assert nx.node_connectivity(G, 0, 33) == len(node_dpaths), errmsg + + +def test_petersen_disjoint_paths(): + G = nx.petersen_graph() + for flow_func in flow_funcs: + kwargs = {"flow_func": flow_func} + errmsg = f"Assertion failed in function: {flow_func.__name__}" + # edge disjoint paths + edge_dpaths = list(nx.edge_disjoint_paths(G, 0, 6, **kwargs)) + assert are_edge_disjoint_paths(G, edge_dpaths), errmsg + assert 3 == len(edge_dpaths), errmsg + # node disjoint paths + node_dpaths = list(nx.node_disjoint_paths(G, 0, 6, **kwargs)) + assert are_node_disjoint_paths(G, node_dpaths), errmsg + assert 3 == len(node_dpaths), errmsg + + +def test_octahedral_disjoint_paths(): + G = nx.octahedral_graph() + for flow_func in flow_funcs: + kwargs = {"flow_func": flow_func} + errmsg = f"Assertion failed in function: {flow_func.__name__}" + # edge disjoint paths + edge_dpaths = list(nx.edge_disjoint_paths(G, 0, 5, **kwargs)) + assert are_edge_disjoint_paths(G, edge_dpaths), errmsg + assert 4 == len(edge_dpaths), errmsg + # node disjoint paths + node_dpaths = list(nx.node_disjoint_paths(G, 0, 5, **kwargs)) + assert are_node_disjoint_paths(G, node_dpaths), errmsg + assert 4 == len(node_dpaths), errmsg + + +def test_icosahedral_disjoint_paths(): + G = nx.icosahedral_graph() + for flow_func in flow_funcs: + kwargs = {"flow_func": flow_func} + errmsg = f"Assertion failed in function: {flow_func.__name__}" + # edge disjoint paths + edge_dpaths = list(nx.edge_disjoint_paths(G, 0, 6, **kwargs)) + assert are_edge_disjoint_paths(G, edge_dpaths), errmsg + assert 5 == len(edge_dpaths), errmsg + # node disjoint paths + node_dpaths = list(nx.node_disjoint_paths(G, 0, 6, **kwargs)) + assert are_node_disjoint_paths(G, node_dpaths), errmsg + assert 5 == len(node_dpaths), errmsg + + +def test_cutoff_disjoint_paths(): + G = nx.icosahedral_graph() + for flow_func in flow_funcs: + kwargs = {"flow_func": flow_func} + errmsg = f"Assertion failed in function: {flow_func.__name__}" + for cutoff in [2, 4]: + kwargs["cutoff"] = cutoff + # edge disjoint paths + edge_dpaths = list(nx.edge_disjoint_paths(G, 0, 6, **kwargs)) + assert are_edge_disjoint_paths(G, edge_dpaths), errmsg + assert cutoff == len(edge_dpaths), errmsg + # node disjoint paths + node_dpaths = list(nx.node_disjoint_paths(G, 0, 6, **kwargs)) + assert are_node_disjoint_paths(G, node_dpaths), errmsg + assert cutoff == len(node_dpaths), errmsg + + +def test_missing_source_edge_paths(): + with pytest.raises(nx.NetworkXError): + G = nx.path_graph(4) + list(nx.edge_disjoint_paths(G, 10, 1)) + + +def test_missing_source_node_paths(): + with pytest.raises(nx.NetworkXError): + G = nx.path_graph(4) + list(nx.node_disjoint_paths(G, 10, 1)) + + +def test_missing_target_edge_paths(): + with pytest.raises(nx.NetworkXError): + G = nx.path_graph(4) + list(nx.edge_disjoint_paths(G, 1, 10)) + + +def test_missing_target_node_paths(): + with pytest.raises(nx.NetworkXError): + G = nx.path_graph(4) + list(nx.node_disjoint_paths(G, 1, 10)) + + +def test_not_weakly_connected_edges(): + with pytest.raises(nx.NetworkXNoPath): + G = nx.DiGraph() + nx.add_path(G, [1, 2, 3]) + nx.add_path(G, [4, 5]) + list(nx.edge_disjoint_paths(G, 1, 5)) + + +def test_not_weakly_connected_nodes(): + with pytest.raises(nx.NetworkXNoPath): + G = nx.DiGraph() + nx.add_path(G, [1, 2, 3]) + nx.add_path(G, [4, 5]) + list(nx.node_disjoint_paths(G, 1, 5)) + + +def test_not_connected_edges(): + with pytest.raises(nx.NetworkXNoPath): + G = nx.Graph() + nx.add_path(G, [1, 2, 3]) + nx.add_path(G, [4, 5]) + list(nx.edge_disjoint_paths(G, 1, 5)) + + +def test_not_connected_nodes(): + with pytest.raises(nx.NetworkXNoPath): + G = nx.Graph() + nx.add_path(G, [1, 2, 3]) + nx.add_path(G, [4, 5]) + list(nx.node_disjoint_paths(G, 1, 5)) + + +def test_isolated_edges(): + with pytest.raises(nx.NetworkXNoPath): + G = nx.Graph() + G.add_node(1) + nx.add_path(G, [4, 5]) + list(nx.edge_disjoint_paths(G, 1, 5)) + + +def test_isolated_nodes(): + with pytest.raises(nx.NetworkXNoPath): + G = nx.Graph() + G.add_node(1) + nx.add_path(G, [4, 5]) + list(nx.node_disjoint_paths(G, 1, 5)) + + +def test_invalid_auxiliary(): + with pytest.raises(nx.NetworkXError): + G = nx.complete_graph(5) + list(nx.node_disjoint_paths(G, 0, 3, auxiliary=G)) diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/test_edge_augmentation.py b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/test_edge_augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..e1d92d99616ac593d3d0ed358a804732d629f62e --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/test_edge_augmentation.py @@ -0,0 +1,502 @@ +import itertools as it +import random + +import pytest + +import networkx as nx +from networkx.algorithms.connectivity import k_edge_augmentation +from networkx.algorithms.connectivity.edge_augmentation import ( + _unpack_available_edges, + collapse, + complement_edges, + is_k_edge_connected, + is_locally_k_edge_connected, +) +from networkx.utils import pairwise + +# This should be set to the largest k for which an efficient algorithm is +# explicitly defined. +MAX_EFFICIENT_K = 2 + + +def tarjan_bridge_graph(): + # graph from tarjan paper + # RE Tarjan - "A note on finding the bridges of a graph" + # Information Processing Letters, 1974 - Elsevier + # doi:10.1016/0020-0190(74)90003-9. + # define 2-connected components and bridges + ccs = [ + (1, 2, 4, 3, 1, 4), + (5, 6, 7, 5), + (8, 9, 10, 8), + (17, 18, 16, 15, 17), + (11, 12, 14, 13, 11, 14), + ] + bridges = [(4, 8), (3, 5), (3, 17)] + G = nx.Graph(it.chain(*(pairwise(path) for path in ccs + bridges))) + return G + + +def test_weight_key(): + G = nx.Graph() + G.add_nodes_from([1, 2, 3, 4, 5, 6, 7, 8, 9]) + G.add_edges_from([(3, 8), (1, 2), (2, 3)]) + impossible = {(3, 6), (3, 9)} + rng = random.Random(0) + avail_uv = list(set(complement_edges(G)) - impossible) + avail = [(u, v, {"cost": rng.random()}) for u, v in avail_uv] + + _augment_and_check(G, k=1) + _augment_and_check(G, k=1, avail=avail_uv) + _augment_and_check(G, k=1, avail=avail, weight="cost") + + _check_augmentations(G, avail, weight="cost") + + +def test_is_locally_k_edge_connected_exceptions(): + pytest.raises(nx.NetworkXNotImplemented, is_k_edge_connected, nx.DiGraph(), k=0) + pytest.raises(nx.NetworkXNotImplemented, is_k_edge_connected, nx.MultiGraph(), k=0) + pytest.raises(ValueError, is_k_edge_connected, nx.Graph(), k=0) + + +def test_is_k_edge_connected(): + G = nx.barbell_graph(10, 0) + assert is_k_edge_connected(G, k=1) + assert not is_k_edge_connected(G, k=2) + + G = nx.Graph() + G.add_nodes_from([5, 15]) + assert not is_k_edge_connected(G, k=1) + assert not is_k_edge_connected(G, k=2) + + G = nx.complete_graph(5) + assert is_k_edge_connected(G, k=1) + assert is_k_edge_connected(G, k=2) + assert is_k_edge_connected(G, k=3) + assert is_k_edge_connected(G, k=4) + + G = nx.compose(nx.complete_graph([0, 1, 2]), nx.complete_graph([3, 4, 5])) + assert not is_k_edge_connected(G, k=1) + assert not is_k_edge_connected(G, k=2) + assert not is_k_edge_connected(G, k=3) + + +def test_is_k_edge_connected_exceptions(): + pytest.raises( + nx.NetworkXNotImplemented, is_locally_k_edge_connected, nx.DiGraph(), 1, 2, k=0 + ) + pytest.raises( + nx.NetworkXNotImplemented, + is_locally_k_edge_connected, + nx.MultiGraph(), + 1, + 2, + k=0, + ) + pytest.raises(ValueError, is_locally_k_edge_connected, nx.Graph(), 1, 2, k=0) + + +def test_is_locally_k_edge_connected(): + G = nx.barbell_graph(10, 0) + assert is_locally_k_edge_connected(G, 5, 15, k=1) + assert not is_locally_k_edge_connected(G, 5, 15, k=2) + + G = nx.Graph() + G.add_nodes_from([5, 15]) + assert not is_locally_k_edge_connected(G, 5, 15, k=2) + + +def test_null_graph(): + G = nx.Graph() + _check_augmentations(G, max_k=MAX_EFFICIENT_K + 2) + + +def test_cliques(): + for n in range(1, 10): + G = nx.complete_graph(n) + _check_augmentations(G, max_k=MAX_EFFICIENT_K + 2) + + +def test_clique_and_node(): + for n in range(1, 10): + G = nx.complete_graph(n) + G.add_node(n + 1) + _check_augmentations(G, max_k=MAX_EFFICIENT_K + 2) + + +def test_point_graph(): + G = nx.Graph() + G.add_node(1) + _check_augmentations(G, max_k=MAX_EFFICIENT_K + 2) + + +def test_edgeless_graph(): + G = nx.Graph() + G.add_nodes_from([1, 2, 3, 4]) + _check_augmentations(G) + + +def test_invalid_k(): + G = nx.Graph() + pytest.raises(ValueError, list, k_edge_augmentation(G, k=-1)) + pytest.raises(ValueError, list, k_edge_augmentation(G, k=0)) + + +def test_unfeasible(): + G = tarjan_bridge_graph() + pytest.raises(nx.NetworkXUnfeasible, list, k_edge_augmentation(G, k=1, avail=[])) + + pytest.raises(nx.NetworkXUnfeasible, list, k_edge_augmentation(G, k=2, avail=[])) + + pytest.raises( + nx.NetworkXUnfeasible, list, k_edge_augmentation(G, k=2, avail=[(7, 9)]) + ) + + # partial solutions should not error if real solutions are infeasible + aug_edges = list(k_edge_augmentation(G, k=2, avail=[(7, 9)], partial=True)) + assert aug_edges == [(7, 9)] + + _check_augmentations(G, avail=[], max_k=MAX_EFFICIENT_K + 2) + + _check_augmentations(G, avail=[(7, 9)], max_k=MAX_EFFICIENT_K + 2) + + +def test_tarjan(): + G = tarjan_bridge_graph() + + aug_edges = set(_augment_and_check(G, k=2)[0]) + print(f"aug_edges = {aug_edges!r}") + # can't assert edge exactly equality due to non-determinant edge order + # but we do know the size of the solution must be 3 + assert len(aug_edges) == 3 + + avail = [ + (9, 7), + (8, 5), + (2, 10), + (6, 13), + (11, 18), + (1, 17), + (2, 3), + (16, 17), + (18, 14), + (15, 14), + ] + aug_edges = set(_augment_and_check(G, avail=avail, k=2)[0]) + + # Can't assert exact length since approximation depends on the order of a + # dict traversal. + assert len(aug_edges) <= 3 * 2 + + _check_augmentations(G, avail) + + +def test_configuration(): + # seeds = [2718183590, 2470619828, 1694705158, 3001036531, 2401251497] + seeds = [1001, 1002, 1003, 1004] + for seed in seeds: + deg_seq = nx.random_powerlaw_tree_sequence(20, seed=seed, tries=5000) + G = nx.Graph(nx.configuration_model(deg_seq, seed=seed)) + G.remove_edges_from(nx.selfloop_edges(G)) + _check_augmentations(G) + + +def test_shell(): + # seeds = [2057382236, 3331169846, 1840105863, 476020778, 2247498425] + seeds = [18] + for seed in seeds: + constructor = [(12, 70, 0.8), (15, 40, 0.6)] + G = nx.random_shell_graph(constructor, seed=seed) + _check_augmentations(G) + + +def test_karate(): + G = nx.karate_club_graph() + _check_augmentations(G) + + +def test_star(): + G = nx.star_graph(3) + _check_augmentations(G) + + G = nx.star_graph(5) + _check_augmentations(G) + + G = nx.star_graph(10) + _check_augmentations(G) + + +def test_barbell(): + G = nx.barbell_graph(5, 0) + _check_augmentations(G) + + G = nx.barbell_graph(5, 2) + _check_augmentations(G) + + G = nx.barbell_graph(5, 3) + _check_augmentations(G) + + G = nx.barbell_graph(5, 4) + _check_augmentations(G) + + +def test_bridge(): + G = nx.Graph([(2393, 2257), (2393, 2685), (2685, 2257), (1758, 2257)]) + _check_augmentations(G) + + +def test_gnp_augmentation(): + rng = random.Random(0) + G = nx.gnp_random_graph(30, 0.005, seed=0) + # Randomly make edges available + avail = { + (u, v): 1 + rng.random() for u, v in complement_edges(G) if rng.random() < 0.25 + } + _check_augmentations(G, avail) + + +def _assert_solution_properties(G, aug_edges, avail_dict=None): + """Checks that aug_edges are consistently formatted""" + if avail_dict is not None: + assert all( + e in avail_dict for e in aug_edges + ), "when avail is specified aug-edges should be in avail" + + unique_aug = set(map(tuple, map(sorted, aug_edges))) + unique_aug = list(map(tuple, map(sorted, aug_edges))) + assert len(aug_edges) == len(unique_aug), "edges should be unique" + + assert not any(u == v for u, v in unique_aug), "should be no self-edges" + + assert not any( + G.has_edge(u, v) for u, v in unique_aug + ), "aug edges and G.edges should be disjoint" + + +def _augment_and_check( + G, k, avail=None, weight=None, verbose=False, orig_k=None, max_aug_k=None +): + """ + Does one specific augmentation and checks for properties of the result + """ + if orig_k is None: + try: + orig_k = nx.edge_connectivity(G) + except nx.NetworkXPointlessConcept: + orig_k = 0 + info = {} + try: + if avail is not None: + # ensure avail is in dict form + avail_dict = dict(zip(*_unpack_available_edges(avail, weight=weight))) + else: + avail_dict = None + try: + # Find the augmentation if possible + generator = nx.k_edge_augmentation(G, k=k, weight=weight, avail=avail) + assert not isinstance(generator, list), "should always return an iter" + aug_edges = [] + for edge in generator: + aug_edges.append(edge) + except nx.NetworkXUnfeasible: + infeasible = True + info["infeasible"] = True + assert len(aug_edges) == 0, "should not generate anything if unfeasible" + + if avail is None: + n_nodes = G.number_of_nodes() + assert n_nodes <= k, ( + "unconstrained cases are only unfeasible if |V| <= k. " + f"Got |V|={n_nodes} and k={k}" + ) + else: + if max_aug_k is None: + G_aug_all = G.copy() + G_aug_all.add_edges_from(avail_dict.keys()) + try: + max_aug_k = nx.edge_connectivity(G_aug_all) + except nx.NetworkXPointlessConcept: + max_aug_k = 0 + + assert max_aug_k < k, ( + "avail should only be unfeasible if using all edges " + "does not achieve k-edge-connectivity" + ) + + # Test for a partial solution + partial_edges = list( + nx.k_edge_augmentation(G, k=k, weight=weight, partial=True, avail=avail) + ) + + info["n_partial_edges"] = len(partial_edges) + + if avail_dict is None: + assert set(partial_edges) == set( + complement_edges(G) + ), "unweighted partial solutions should be the complement" + elif len(avail_dict) > 0: + H = G.copy() + + # Find the partial / full augmented connectivity + H.add_edges_from(partial_edges) + partial_conn = nx.edge_connectivity(H) + + H.add_edges_from(set(avail_dict.keys())) + full_conn = nx.edge_connectivity(H) + + # Full connectivity should be no better than our partial + # solution. + assert ( + partial_conn == full_conn + ), "adding more edges should not increase k-conn" + + # Find the new edge-connectivity after adding the augmenting edges + aug_edges = partial_edges + else: + infeasible = False + + # Find the weight of the augmentation + num_edges = len(aug_edges) + if avail is not None: + total_weight = sum(avail_dict[e] for e in aug_edges) + else: + total_weight = num_edges + + info["total_weight"] = total_weight + info["num_edges"] = num_edges + + # Find the new edge-connectivity after adding the augmenting edges + G_aug = G.copy() + G_aug.add_edges_from(aug_edges) + try: + aug_k = nx.edge_connectivity(G_aug) + except nx.NetworkXPointlessConcept: + aug_k = 0 + info["aug_k"] = aug_k + + # Do checks + if not infeasible and orig_k < k: + assert info["aug_k"] >= k, f"connectivity should increase to k={k} or more" + + assert info["aug_k"] >= orig_k, "augmenting should never reduce connectivity" + + _assert_solution_properties(G, aug_edges, avail_dict) + + except Exception: + info["failed"] = True + print(f"edges = {list(G.edges())}") + print(f"nodes = {list(G.nodes())}") + print(f"aug_edges = {list(aug_edges)}") + print(f"info = {info}") + raise + else: + if verbose: + print(f"info = {info}") + + if infeasible: + aug_edges = None + return aug_edges, info + + +def _check_augmentations(G, avail=None, max_k=None, weight=None, verbose=False): + """Helper to check weighted/unweighted cases with multiple values of k""" + # Using all available edges, find the maximum edge-connectivity + try: + orig_k = nx.edge_connectivity(G) + except nx.NetworkXPointlessConcept: + orig_k = 0 + + if avail is not None: + all_aug_edges = _unpack_available_edges(avail, weight=weight)[0] + G_aug_all = G.copy() + G_aug_all.add_edges_from(all_aug_edges) + try: + max_aug_k = nx.edge_connectivity(G_aug_all) + except nx.NetworkXPointlessConcept: + max_aug_k = 0 + else: + max_aug_k = G.number_of_nodes() - 1 + + if max_k is None: + max_k = min(4, max_aug_k) + + avail_uniform = {e: 1 for e in complement_edges(G)} + + if verbose: + print("\n=== CHECK_AUGMENTATION ===") + print(f"G.number_of_nodes = {G.number_of_nodes()!r}") + print(f"G.number_of_edges = {G.number_of_edges()!r}") + print(f"max_k = {max_k!r}") + print(f"max_aug_k = {max_aug_k!r}") + print(f"orig_k = {orig_k!r}") + + # check augmentation for multiple values of k + for k in range(1, max_k + 1): + if verbose: + print("---------------") + print(f"Checking k = {k}") + + # Check the unweighted version + if verbose: + print("unweighted case") + aug_edges1, info1 = _augment_and_check(G, k=k, verbose=verbose, orig_k=orig_k) + + # Check that the weighted version with all available edges and uniform + # weights gives a similar solution to the unweighted case. + if verbose: + print("weighted uniform case") + aug_edges2, info2 = _augment_and_check( + G, + k=k, + avail=avail_uniform, + verbose=verbose, + orig_k=orig_k, + max_aug_k=G.number_of_nodes() - 1, + ) + + # Check the weighted version + if avail is not None: + if verbose: + print("weighted case") + aug_edges3, info3 = _augment_and_check( + G, + k=k, + avail=avail, + weight=weight, + verbose=verbose, + max_aug_k=max_aug_k, + orig_k=orig_k, + ) + + if aug_edges1 is not None: + # Check approximation ratios + if k == 1: + # when k=1, both solutions should be optimal + assert info2["total_weight"] == info1["total_weight"] + if k == 2: + # when k=2, the weighted version is an approximation + if orig_k == 0: + # the approximation ratio is 3 if G is not connected + assert info2["total_weight"] <= info1["total_weight"] * 3 + else: + # the approximation ratio is 2 if G is was connected + assert info2["total_weight"] <= info1["total_weight"] * 2 + _check_unconstrained_bridge_property(G, info1) + + +def _check_unconstrained_bridge_property(G, info1): + # Check Theorem 5 from Eswaran and Tarjan. (1975) Augmentation problems + import math + + bridge_ccs = list(nx.connectivity.bridge_components(G)) + # condense G into an forest C + C = collapse(G, bridge_ccs) + + p = len([n for n, d in C.degree() if d == 1]) # leafs + q = len([n for n, d in C.degree() if d == 0]) # isolated + if p + q > 1: + size_target = math.ceil(p / 2) + q + size_aug = info1["num_edges"] + assert ( + size_aug == size_target + ), "augmentation size is different from what theory predicts" diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/test_edge_kcomponents.py b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/test_edge_kcomponents.py new file mode 100644 index 0000000000000000000000000000000000000000..4a1f681ab3da3f1f965ecbbf8dcf84eb49a512b9 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/test_edge_kcomponents.py @@ -0,0 +1,488 @@ +import itertools as it + +import pytest + +import networkx as nx +from networkx.algorithms.connectivity import EdgeComponentAuxGraph, bridge_components +from networkx.algorithms.connectivity.edge_kcomponents import general_k_edge_subgraphs +from networkx.utils import pairwise + +# ---------------- +# Helper functions +# ---------------- + + +def fset(list_of_sets): + """allows == to be used for list of sets""" + return set(map(frozenset, list_of_sets)) + + +def _assert_subgraph_edge_connectivity(G, ccs_subgraph, k): + """ + tests properties of k-edge-connected subgraphs + + the actual edge connectivity should be no less than k unless the cc is a + single node. + """ + for cc in ccs_subgraph: + C = G.subgraph(cc) + if len(cc) > 1: + connectivity = nx.edge_connectivity(C) + assert connectivity >= k + + +def _memo_connectivity(G, u, v, memo): + edge = (u, v) + if edge in memo: + return memo[edge] + if not G.is_directed(): + redge = (v, u) + if redge in memo: + return memo[redge] + memo[edge] = nx.edge_connectivity(G, *edge) + return memo[edge] + + +def _all_pairs_connectivity(G, cc, k, memo): + # Brute force check + for u, v in it.combinations(cc, 2): + # Use a memoization dict to save on computation + connectivity = _memo_connectivity(G, u, v, memo) + if G.is_directed(): + connectivity = min(connectivity, _memo_connectivity(G, v, u, memo)) + assert connectivity >= k + + +def _assert_local_cc_edge_connectivity(G, ccs_local, k, memo): + """ + tests properties of k-edge-connected components + + the local edge connectivity between each pair of nodes in the original + graph should be no less than k unless the cc is a single node. + """ + for cc in ccs_local: + if len(cc) > 1: + # Strategy for testing a bit faster: If the subgraph has high edge + # connectivity then it must have local connectivity + C = G.subgraph(cc) + connectivity = nx.edge_connectivity(C) + if connectivity < k: + # Otherwise do the brute force (with memoization) check + _all_pairs_connectivity(G, cc, k, memo) + + +# Helper function +def _check_edge_connectivity(G): + """ + Helper - generates all k-edge-components using the aux graph. Checks the + both local and subgraph edge connectivity of each cc. Also checks that + alternate methods of computing the k-edge-ccs generate the same result. + """ + # Construct the auxiliary graph that can be used to make each k-cc or k-sub + aux_graph = EdgeComponentAuxGraph.construct(G) + + # memoize the local connectivity in this graph + memo = {} + + for k in it.count(1): + # Test "local" k-edge-components and k-edge-subgraphs + ccs_local = fset(aux_graph.k_edge_components(k)) + ccs_subgraph = fset(aux_graph.k_edge_subgraphs(k)) + + # Check connectivity properties that should be guaranteed by the + # algorithms. + _assert_local_cc_edge_connectivity(G, ccs_local, k, memo) + _assert_subgraph_edge_connectivity(G, ccs_subgraph, k) + + if k == 1 or k == 2 and not G.is_directed(): + assert ( + ccs_local == ccs_subgraph + ), "Subgraphs and components should be the same when k == 1 or (k == 2 and not G.directed())" + + if G.is_directed(): + # Test special case methods are the same as the aux graph + if k == 1: + alt_sccs = fset(nx.strongly_connected_components(G)) + assert alt_sccs == ccs_local, "k=1 failed alt" + assert alt_sccs == ccs_subgraph, "k=1 failed alt" + else: + # Test special case methods are the same as the aux graph + if k == 1: + alt_ccs = fset(nx.connected_components(G)) + assert alt_ccs == ccs_local, "k=1 failed alt" + assert alt_ccs == ccs_subgraph, "k=1 failed alt" + elif k == 2: + alt_bridge_ccs = fset(bridge_components(G)) + assert alt_bridge_ccs == ccs_local, "k=2 failed alt" + assert alt_bridge_ccs == ccs_subgraph, "k=2 failed alt" + # if new methods for k == 3 or k == 4 are implemented add them here + + # Check the general subgraph method works by itself + alt_subgraph_ccs = fset( + [set(C.nodes()) for C in general_k_edge_subgraphs(G, k=k)] + ) + assert alt_subgraph_ccs == ccs_subgraph, "alt subgraph method failed" + + # Stop once k is larger than all special case methods + # and we cannot break down ccs any further. + if k > 2 and all(len(cc) == 1 for cc in ccs_local): + break + + +# ---------------- +# Misc tests +# ---------------- + + +def test_zero_k_exception(): + G = nx.Graph() + # functions that return generators error immediately + pytest.raises(ValueError, nx.k_edge_components, G, k=0) + pytest.raises(ValueError, nx.k_edge_subgraphs, G, k=0) + + # actual generators only error when you get the first item + aux_graph = EdgeComponentAuxGraph.construct(G) + pytest.raises(ValueError, list, aux_graph.k_edge_components(k=0)) + pytest.raises(ValueError, list, aux_graph.k_edge_subgraphs(k=0)) + + pytest.raises(ValueError, list, general_k_edge_subgraphs(G, k=0)) + + +def test_empty_input(): + G = nx.Graph() + assert [] == list(nx.k_edge_components(G, k=5)) + assert [] == list(nx.k_edge_subgraphs(G, k=5)) + + G = nx.DiGraph() + assert [] == list(nx.k_edge_components(G, k=5)) + assert [] == list(nx.k_edge_subgraphs(G, k=5)) + + +def test_not_implemented(): + G = nx.MultiGraph() + pytest.raises(nx.NetworkXNotImplemented, EdgeComponentAuxGraph.construct, G) + pytest.raises(nx.NetworkXNotImplemented, nx.k_edge_components, G, k=2) + pytest.raises(nx.NetworkXNotImplemented, nx.k_edge_subgraphs, G, k=2) + with pytest.raises(nx.NetworkXNotImplemented): + next(bridge_components(G)) + with pytest.raises(nx.NetworkXNotImplemented): + next(bridge_components(nx.DiGraph())) + + +def test_general_k_edge_subgraph_quick_return(): + # tests quick return optimization + G = nx.Graph() + G.add_node(0) + subgraphs = list(general_k_edge_subgraphs(G, k=1)) + assert len(subgraphs) == 1 + for subgraph in subgraphs: + assert subgraph.number_of_nodes() == 1 + + G.add_node(1) + subgraphs = list(general_k_edge_subgraphs(G, k=1)) + assert len(subgraphs) == 2 + for subgraph in subgraphs: + assert subgraph.number_of_nodes() == 1 + + +# ---------------- +# Undirected tests +# ---------------- + + +def test_random_gnp(): + # seeds = [1550709854, 1309423156, 4208992358, 2785630813, 1915069929] + seeds = [12, 13] + + for seed in seeds: + G = nx.gnp_random_graph(20, 0.2, seed=seed) + _check_edge_connectivity(G) + + +def test_configuration(): + # seeds = [2718183590, 2470619828, 1694705158, 3001036531, 2401251497] + seeds = [14, 15] + for seed in seeds: + deg_seq = nx.random_powerlaw_tree_sequence(20, seed=seed, tries=5000) + G = nx.Graph(nx.configuration_model(deg_seq, seed=seed)) + G.remove_edges_from(nx.selfloop_edges(G)) + _check_edge_connectivity(G) + + +def test_shell(): + # seeds = [2057382236, 3331169846, 1840105863, 476020778, 2247498425] + seeds = [20] + for seed in seeds: + constructor = [(12, 70, 0.8), (15, 40, 0.6)] + G = nx.random_shell_graph(constructor, seed=seed) + _check_edge_connectivity(G) + + +def test_karate(): + G = nx.karate_club_graph() + _check_edge_connectivity(G) + + +def test_tarjan_bridge(): + # graph from tarjan paper + # RE Tarjan - "A note on finding the bridges of a graph" + # Information Processing Letters, 1974 - Elsevier + # doi:10.1016/0020-0190(74)90003-9. + # define 2-connected components and bridges + ccs = [ + (1, 2, 4, 3, 1, 4), + (5, 6, 7, 5), + (8, 9, 10, 8), + (17, 18, 16, 15, 17), + (11, 12, 14, 13, 11, 14), + ] + bridges = [(4, 8), (3, 5), (3, 17)] + G = nx.Graph(it.chain(*(pairwise(path) for path in ccs + bridges))) + _check_edge_connectivity(G) + + +def test_bridge_cc(): + # define 2-connected components and bridges + cc2 = [(1, 2, 4, 3, 1, 4), (8, 9, 10, 8), (11, 12, 13, 11)] + bridges = [(4, 8), (3, 5), (20, 21), (22, 23, 24)] + G = nx.Graph(it.chain(*(pairwise(path) for path in cc2 + bridges))) + bridge_ccs = fset(bridge_components(G)) + target_ccs = fset( + [{1, 2, 3, 4}, {5}, {8, 9, 10}, {11, 12, 13}, {20}, {21}, {22}, {23}, {24}] + ) + assert bridge_ccs == target_ccs + _check_edge_connectivity(G) + + +def test_undirected_aux_graph(): + # Graph similar to the one in + # http://journals.plos.org/plosone/article?id=10.1371/journal.pone.0136264 + a, b, c, d, e, f, g, h, i = "abcdefghi" + paths = [ + (a, d, b, f, c), + (a, e, b), + (a, e, b, c, g, b, a), + (c, b), + (f, g, f), + (h, i), + ] + G = nx.Graph(it.chain(*[pairwise(path) for path in paths])) + aux_graph = EdgeComponentAuxGraph.construct(G) + + components_1 = fset(aux_graph.k_edge_subgraphs(k=1)) + target_1 = fset([{a, b, c, d, e, f, g}, {h, i}]) + assert target_1 == components_1 + + # Check that the undirected case for k=1 agrees with CCs + alt_1 = fset(nx.k_edge_subgraphs(G, k=1)) + assert alt_1 == components_1 + + components_2 = fset(aux_graph.k_edge_subgraphs(k=2)) + target_2 = fset([{a, b, c, d, e, f, g}, {h}, {i}]) + assert target_2 == components_2 + + # Check that the undirected case for k=2 agrees with bridge components + alt_2 = fset(nx.k_edge_subgraphs(G, k=2)) + assert alt_2 == components_2 + + components_3 = fset(aux_graph.k_edge_subgraphs(k=3)) + target_3 = fset([{a}, {b, c, f, g}, {d}, {e}, {h}, {i}]) + assert target_3 == components_3 + + components_4 = fset(aux_graph.k_edge_subgraphs(k=4)) + target_4 = fset([{a}, {b}, {c}, {d}, {e}, {f}, {g}, {h}, {i}]) + assert target_4 == components_4 + + _check_edge_connectivity(G) + + +def test_local_subgraph_difference(): + paths = [ + (11, 12, 13, 14, 11, 13, 14, 12), # first 4-clique + (21, 22, 23, 24, 21, 23, 24, 22), # second 4-clique + # paths connecting each node of the 4 cliques + (11, 101, 21), + (12, 102, 22), + (13, 103, 23), + (14, 104, 24), + ] + G = nx.Graph(it.chain(*[pairwise(path) for path in paths])) + aux_graph = EdgeComponentAuxGraph.construct(G) + + # Each clique is returned separately in k-edge-subgraphs + subgraph_ccs = fset(aux_graph.k_edge_subgraphs(3)) + subgraph_target = fset( + [{101}, {102}, {103}, {104}, {21, 22, 23, 24}, {11, 12, 13, 14}] + ) + assert subgraph_ccs == subgraph_target + + # But in k-edge-ccs they are returned together + # because they are locally 3-edge-connected + local_ccs = fset(aux_graph.k_edge_components(3)) + local_target = fset([{101}, {102}, {103}, {104}, {11, 12, 13, 14, 21, 22, 23, 24}]) + assert local_ccs == local_target + + +def test_local_subgraph_difference_directed(): + dipaths = [(1, 2, 3, 4, 1), (1, 3, 1)] + G = nx.DiGraph(it.chain(*[pairwise(path) for path in dipaths])) + + assert fset(nx.k_edge_components(G, k=1)) == fset(nx.k_edge_subgraphs(G, k=1)) + + # Unlike undirected graphs, when k=2, for directed graphs there is a case + # where the k-edge-ccs are not the same as the k-edge-subgraphs. + # (in directed graphs ccs and subgraphs are the same when k=2) + assert fset(nx.k_edge_components(G, k=2)) != fset(nx.k_edge_subgraphs(G, k=2)) + + assert fset(nx.k_edge_components(G, k=3)) == fset(nx.k_edge_subgraphs(G, k=3)) + + _check_edge_connectivity(G) + + +def test_triangles(): + paths = [ + (11, 12, 13, 11), # first 3-clique + (21, 22, 23, 21), # second 3-clique + (11, 21), # connected by an edge + ] + G = nx.Graph(it.chain(*[pairwise(path) for path in paths])) + + # subgraph and ccs are the same in all cases here + assert fset(nx.k_edge_components(G, k=1)) == fset(nx.k_edge_subgraphs(G, k=1)) + + assert fset(nx.k_edge_components(G, k=2)) == fset(nx.k_edge_subgraphs(G, k=2)) + + assert fset(nx.k_edge_components(G, k=3)) == fset(nx.k_edge_subgraphs(G, k=3)) + + _check_edge_connectivity(G) + + +def test_four_clique(): + paths = [ + (11, 12, 13, 14, 11, 13, 14, 12), # first 4-clique + (21, 22, 23, 24, 21, 23, 24, 22), # second 4-clique + # paths connecting the 4 cliques such that they are + # 3-connected in G, but not in the subgraph. + # Case where the nodes bridging them do not have degree less than 3. + (100, 13), + (12, 100, 22), + (13, 200, 23), + (14, 300, 24), + ] + G = nx.Graph(it.chain(*[pairwise(path) for path in paths])) + + # The subgraphs and ccs are different for k=3 + local_ccs = fset(nx.k_edge_components(G, k=3)) + subgraphs = fset(nx.k_edge_subgraphs(G, k=3)) + assert local_ccs != subgraphs + + # The cliques ares in the same cc + clique1 = frozenset(paths[0]) + clique2 = frozenset(paths[1]) + assert clique1.union(clique2).union({100}) in local_ccs + + # but different subgraphs + assert clique1 in subgraphs + assert clique2 in subgraphs + + assert G.degree(100) == 3 + + _check_edge_connectivity(G) + + +def test_five_clique(): + # Make a graph that can be disconnected less than 4 edges, but no node has + # degree less than 4. + G = nx.disjoint_union(nx.complete_graph(5), nx.complete_graph(5)) + paths = [ + # add aux-connections + (1, 100, 6), + (2, 100, 7), + (3, 200, 8), + (4, 200, 100), + ] + G.add_edges_from(it.chain(*[pairwise(path) for path in paths])) + assert min(dict(nx.degree(G)).values()) == 4 + + # For k=3 they are the same + assert fset(nx.k_edge_components(G, k=3)) == fset(nx.k_edge_subgraphs(G, k=3)) + + # For k=4 they are the different + # the aux nodes are in the same CC as clique 1 but no the same subgraph + assert fset(nx.k_edge_components(G, k=4)) != fset(nx.k_edge_subgraphs(G, k=4)) + + # For k=5 they are not the same + assert fset(nx.k_edge_components(G, k=5)) != fset(nx.k_edge_subgraphs(G, k=5)) + + # For k=6 they are the same + assert fset(nx.k_edge_components(G, k=6)) == fset(nx.k_edge_subgraphs(G, k=6)) + _check_edge_connectivity(G) + + +# ---------------- +# Undirected tests +# ---------------- + + +def test_directed_aux_graph(): + # Graph similar to the one in + # http://journals.plos.org/plosone/article?id=10.1371/journal.pone.0136264 + a, b, c, d, e, f, g, h, i = "abcdefghi" + dipaths = [ + (a, d, b, f, c), + (a, e, b), + (a, e, b, c, g, b, a), + (c, b), + (f, g, f), + (h, i), + ] + G = nx.DiGraph(it.chain(*[pairwise(path) for path in dipaths])) + aux_graph = EdgeComponentAuxGraph.construct(G) + + components_1 = fset(aux_graph.k_edge_subgraphs(k=1)) + target_1 = fset([{a, b, c, d, e, f, g}, {h}, {i}]) + assert target_1 == components_1 + + # Check that the directed case for k=1 agrees with SCCs + alt_1 = fset(nx.strongly_connected_components(G)) + assert alt_1 == components_1 + + components_2 = fset(aux_graph.k_edge_subgraphs(k=2)) + target_2 = fset([{i}, {e}, {d}, {b, c, f, g}, {h}, {a}]) + assert target_2 == components_2 + + components_3 = fset(aux_graph.k_edge_subgraphs(k=3)) + target_3 = fset([{a}, {b}, {c}, {d}, {e}, {f}, {g}, {h}, {i}]) + assert target_3 == components_3 + + +def test_random_gnp_directed(): + # seeds = [3894723670, 500186844, 267231174, 2181982262, 1116750056] + seeds = [21] + for seed in seeds: + G = nx.gnp_random_graph(20, 0.2, directed=True, seed=seed) + _check_edge_connectivity(G) + + +def test_configuration_directed(): + # seeds = [671221681, 2403749451, 124433910, 672335939, 1193127215] + seeds = [67] + for seed in seeds: + deg_seq = nx.random_powerlaw_tree_sequence(20, seed=seed, tries=5000) + G = nx.DiGraph(nx.configuration_model(deg_seq, seed=seed)) + G.remove_edges_from(nx.selfloop_edges(G)) + _check_edge_connectivity(G) + + +def test_shell_directed(): + # seeds = [3134027055, 4079264063, 1350769518, 1405643020, 530038094] + seeds = [31] + for seed in seeds: + constructor = [(12, 70, 0.8), (15, 40, 0.6)] + G = nx.random_shell_graph(constructor, seed=seed).to_directed() + _check_edge_connectivity(G) + + +def test_karate_directed(): + G = nx.karate_club_graph().to_directed() + _check_edge_connectivity(G) diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/test_kcomponents.py b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/test_kcomponents.py new file mode 100644 index 0000000000000000000000000000000000000000..f4436acd07fe57cb510fee138b36f10923a9688a --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/test_kcomponents.py @@ -0,0 +1,296 @@ +# Test for Moody and White k-components algorithm +import pytest + +import networkx as nx +from networkx.algorithms.connectivity.kcomponents import ( + _consolidate, + build_k_number_dict, +) + +## +# A nice synthetic graph +## + + +def torrents_and_ferraro_graph(): + # Graph from https://arxiv.org/pdf/1503.04476v1 p.26 + G = nx.convert_node_labels_to_integers( + nx.grid_graph([5, 5]), label_attribute="labels" + ) + rlabels = nx.get_node_attributes(G, "labels") + labels = {v: k for k, v in rlabels.items()} + + for nodes in [(labels[(0, 4)], labels[(1, 4)]), (labels[(3, 4)], labels[(4, 4)])]: + new_node = G.order() + 1 + # Petersen graph is triconnected + P = nx.petersen_graph() + G = nx.disjoint_union(G, P) + # Add two edges between the grid and P + G.add_edge(new_node + 1, nodes[0]) + G.add_edge(new_node, nodes[1]) + # K5 is 4-connected + K = nx.complete_graph(5) + G = nx.disjoint_union(G, K) + # Add three edges between P and K5 + G.add_edge(new_node + 2, new_node + 11) + G.add_edge(new_node + 3, new_node + 12) + G.add_edge(new_node + 4, new_node + 13) + # Add another K5 sharing a node + G = nx.disjoint_union(G, K) + nbrs = G[new_node + 10] + G.remove_node(new_node + 10) + for nbr in nbrs: + G.add_edge(new_node + 17, nbr) + # This edge makes the graph biconnected; it's + # needed because K5s share only one node. + G.add_edge(new_node + 16, new_node + 8) + + for nodes in [(labels[(0, 0)], labels[(1, 0)]), (labels[(3, 0)], labels[(4, 0)])]: + new_node = G.order() + 1 + # Petersen graph is triconnected + P = nx.petersen_graph() + G = nx.disjoint_union(G, P) + # Add two edges between the grid and P + G.add_edge(new_node + 1, nodes[0]) + G.add_edge(new_node, nodes[1]) + # K5 is 4-connected + K = nx.complete_graph(5) + G = nx.disjoint_union(G, K) + # Add three edges between P and K5 + G.add_edge(new_node + 2, new_node + 11) + G.add_edge(new_node + 3, new_node + 12) + G.add_edge(new_node + 4, new_node + 13) + # Add another K5 sharing two nodes + G = nx.disjoint_union(G, K) + nbrs = G[new_node + 10] + G.remove_node(new_node + 10) + for nbr in nbrs: + G.add_edge(new_node + 17, nbr) + nbrs2 = G[new_node + 9] + G.remove_node(new_node + 9) + for nbr in nbrs2: + G.add_edge(new_node + 18, nbr) + return G + + +def test_directed(): + with pytest.raises(nx.NetworkXNotImplemented): + G = nx.gnp_random_graph(10, 0.2, directed=True, seed=42) + nx.k_components(G) + + +# Helper function +def _check_connectivity(G, k_components): + for k, components in k_components.items(): + if k < 3: + continue + # check that k-components have node connectivity >= k. + for component in components: + C = G.subgraph(component) + K = nx.node_connectivity(C) + assert K >= k + + +@pytest.mark.slow +def test_torrents_and_ferraro_graph(): + G = torrents_and_ferraro_graph() + result = nx.k_components(G) + _check_connectivity(G, result) + + # In this example graph there are 8 3-components, 4 with 15 nodes + # and 4 with 5 nodes. + assert len(result[3]) == 8 + assert len([c for c in result[3] if len(c) == 15]) == 4 + assert len([c for c in result[3] if len(c) == 5]) == 4 + # There are also 8 4-components all with 5 nodes. + assert len(result[4]) == 8 + assert all(len(c) == 5 for c in result[4]) + + +@pytest.mark.slow +def test_random_gnp(): + G = nx.gnp_random_graph(50, 0.2, seed=42) + result = nx.k_components(G) + _check_connectivity(G, result) + + +@pytest.mark.slow +def test_shell(): + constructor = [(20, 80, 0.8), (80, 180, 0.6)] + G = nx.random_shell_graph(constructor, seed=42) + result = nx.k_components(G) + _check_connectivity(G, result) + + +def test_configuration(): + deg_seq = nx.random_powerlaw_tree_sequence(100, tries=5, seed=72) + G = nx.Graph(nx.configuration_model(deg_seq)) + G.remove_edges_from(nx.selfloop_edges(G)) + result = nx.k_components(G) + _check_connectivity(G, result) + + +def test_karate(): + G = nx.karate_club_graph() + result = nx.k_components(G) + _check_connectivity(G, result) + + +def test_karate_component_number(): + karate_k_num = { + 0: 4, + 1: 4, + 2: 4, + 3: 4, + 4: 3, + 5: 3, + 6: 3, + 7: 4, + 8: 4, + 9: 2, + 10: 3, + 11: 1, + 12: 2, + 13: 4, + 14: 2, + 15: 2, + 16: 2, + 17: 2, + 18: 2, + 19: 3, + 20: 2, + 21: 2, + 22: 2, + 23: 3, + 24: 3, + 25: 3, + 26: 2, + 27: 3, + 28: 3, + 29: 3, + 30: 4, + 31: 3, + 32: 4, + 33: 4, + } + G = nx.karate_club_graph() + k_components = nx.k_components(G) + k_num = build_k_number_dict(k_components) + assert karate_k_num == k_num + + +def test_davis_southern_women(): + G = nx.davis_southern_women_graph() + result = nx.k_components(G) + _check_connectivity(G, result) + + +def test_davis_southern_women_detail_3_and_4(): + solution = { + 3: [ + { + "Nora Fayette", + "E10", + "Myra Liddel", + "E12", + "E14", + "Frances Anderson", + "Evelyn Jefferson", + "Ruth DeSand", + "Helen Lloyd", + "Eleanor Nye", + "E9", + "E8", + "E5", + "E4", + "E7", + "E6", + "E1", + "Verne Sanderson", + "E3", + "E2", + "Theresa Anderson", + "Pearl Oglethorpe", + "Katherina Rogers", + "Brenda Rogers", + "E13", + "Charlotte McDowd", + "Sylvia Avondale", + "Laura Mandeville", + } + ], + 4: [ + { + "Nora Fayette", + "E10", + "Verne Sanderson", + "E12", + "Frances Anderson", + "Evelyn Jefferson", + "Ruth DeSand", + "Helen Lloyd", + "Eleanor Nye", + "E9", + "E8", + "E5", + "E4", + "E7", + "E6", + "Myra Liddel", + "E3", + "Theresa Anderson", + "Katherina Rogers", + "Brenda Rogers", + "Charlotte McDowd", + "Sylvia Avondale", + "Laura Mandeville", + } + ], + } + G = nx.davis_southern_women_graph() + result = nx.k_components(G) + for k, components in result.items(): + if k < 3: + continue + assert len(components) == len(solution[k]) + for component in components: + assert component in solution[k] + + +def test_set_consolidation_rosettacode(): + # Tests from http://rosettacode.org/wiki/Set_consolidation + def list_of_sets_equal(result, solution): + assert {frozenset(s) for s in result} == {frozenset(s) for s in solution} + + question = [{"A", "B"}, {"C", "D"}] + solution = [{"A", "B"}, {"C", "D"}] + list_of_sets_equal(_consolidate(question, 1), solution) + question = [{"A", "B"}, {"B", "C"}] + solution = [{"A", "B", "C"}] + list_of_sets_equal(_consolidate(question, 1), solution) + question = [{"A", "B"}, {"C", "D"}, {"D", "B"}] + solution = [{"A", "C", "B", "D"}] + list_of_sets_equal(_consolidate(question, 1), solution) + question = [{"H", "I", "K"}, {"A", "B"}, {"C", "D"}, {"D", "B"}, {"F", "G", "H"}] + solution = [{"A", "C", "B", "D"}, {"G", "F", "I", "H", "K"}] + list_of_sets_equal(_consolidate(question, 1), solution) + question = [ + {"A", "H"}, + {"H", "I", "K"}, + {"A", "B"}, + {"C", "D"}, + {"D", "B"}, + {"F", "G", "H"}, + ] + solution = [{"A", "C", "B", "D", "G", "F", "I", "H", "K"}] + list_of_sets_equal(_consolidate(question, 1), solution) + question = [ + {"H", "I", "K"}, + {"A", "B"}, + {"C", "D"}, + {"D", "B"}, + {"F", "G", "H"}, + {"A", "H"}, + ] + solution = [{"A", "C", "B", "D", "G", "F", "I", "H", "K"}] + list_of_sets_equal(_consolidate(question, 1), solution) diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/test_kcutsets.py b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/test_kcutsets.py new file mode 100644 index 0000000000000000000000000000000000000000..4b4b5494a87c83a5455e98cbe6fef267f1a2e91a --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/test_kcutsets.py @@ -0,0 +1,273 @@ +# Jordi Torrents +# Test for k-cutsets +import itertools + +import pytest + +import networkx as nx +from networkx.algorithms import flow +from networkx.algorithms.connectivity.kcutsets import _is_separating_set + +MAX_CUTSETS_TO_TEST = 4 # originally 100. cut to decrease testing time + +flow_funcs = [ + flow.boykov_kolmogorov, + flow.dinitz, + flow.edmonds_karp, + flow.preflow_push, + flow.shortest_augmenting_path, +] + + +## +# Some nice synthetic graphs +## +def graph_example_1(): + G = nx.convert_node_labels_to_integers( + nx.grid_graph([5, 5]), label_attribute="labels" + ) + rlabels = nx.get_node_attributes(G, "labels") + labels = {v: k for k, v in rlabels.items()} + + for nodes in [ + (labels[(0, 0)], labels[(1, 0)]), + (labels[(0, 4)], labels[(1, 4)]), + (labels[(3, 0)], labels[(4, 0)]), + (labels[(3, 4)], labels[(4, 4)]), + ]: + new_node = G.order() + 1 + # Petersen graph is triconnected + P = nx.petersen_graph() + G = nx.disjoint_union(G, P) + # Add two edges between the grid and P + G.add_edge(new_node + 1, nodes[0]) + G.add_edge(new_node, nodes[1]) + # K5 is 4-connected + K = nx.complete_graph(5) + G = nx.disjoint_union(G, K) + # Add three edges between P and K5 + G.add_edge(new_node + 2, new_node + 11) + G.add_edge(new_node + 3, new_node + 12) + G.add_edge(new_node + 4, new_node + 13) + # Add another K5 sharing a node + G = nx.disjoint_union(G, K) + nbrs = G[new_node + 10] + G.remove_node(new_node + 10) + for nbr in nbrs: + G.add_edge(new_node + 17, nbr) + G.add_edge(new_node + 16, new_node + 5) + return G + + +def torrents_and_ferraro_graph(): + G = nx.convert_node_labels_to_integers( + nx.grid_graph([5, 5]), label_attribute="labels" + ) + rlabels = nx.get_node_attributes(G, "labels") + labels = {v: k for k, v in rlabels.items()} + + for nodes in [(labels[(0, 4)], labels[(1, 4)]), (labels[(3, 4)], labels[(4, 4)])]: + new_node = G.order() + 1 + # Petersen graph is triconnected + P = nx.petersen_graph() + G = nx.disjoint_union(G, P) + # Add two edges between the grid and P + G.add_edge(new_node + 1, nodes[0]) + G.add_edge(new_node, nodes[1]) + # K5 is 4-connected + K = nx.complete_graph(5) + G = nx.disjoint_union(G, K) + # Add three edges between P and K5 + G.add_edge(new_node + 2, new_node + 11) + G.add_edge(new_node + 3, new_node + 12) + G.add_edge(new_node + 4, new_node + 13) + # Add another K5 sharing a node + G = nx.disjoint_union(G, K) + nbrs = G[new_node + 10] + G.remove_node(new_node + 10) + for nbr in nbrs: + G.add_edge(new_node + 17, nbr) + # Commenting this makes the graph not biconnected !! + # This stupid mistake make one reviewer very angry :P + G.add_edge(new_node + 16, new_node + 8) + + for nodes in [(labels[(0, 0)], labels[(1, 0)]), (labels[(3, 0)], labels[(4, 0)])]: + new_node = G.order() + 1 + # Petersen graph is triconnected + P = nx.petersen_graph() + G = nx.disjoint_union(G, P) + # Add two edges between the grid and P + G.add_edge(new_node + 1, nodes[0]) + G.add_edge(new_node, nodes[1]) + # K5 is 4-connected + K = nx.complete_graph(5) + G = nx.disjoint_union(G, K) + # Add three edges between P and K5 + G.add_edge(new_node + 2, new_node + 11) + G.add_edge(new_node + 3, new_node + 12) + G.add_edge(new_node + 4, new_node + 13) + # Add another K5 sharing two nodes + G = nx.disjoint_union(G, K) + nbrs = G[new_node + 10] + G.remove_node(new_node + 10) + for nbr in nbrs: + G.add_edge(new_node + 17, nbr) + nbrs2 = G[new_node + 9] + G.remove_node(new_node + 9) + for nbr in nbrs2: + G.add_edge(new_node + 18, nbr) + return G + + +# Helper function +def _check_separating_sets(G): + for cc in nx.connected_components(G): + if len(cc) < 3: + continue + Gc = G.subgraph(cc) + node_conn = nx.node_connectivity(Gc) + all_cuts = nx.all_node_cuts(Gc) + # Only test a limited number of cut sets to reduce test time. + for cut in itertools.islice(all_cuts, MAX_CUTSETS_TO_TEST): + assert node_conn == len(cut) + assert not nx.is_connected(nx.restricted_view(G, cut, [])) + + +@pytest.mark.slow +def test_torrents_and_ferraro_graph(): + G = torrents_and_ferraro_graph() + _check_separating_sets(G) + + +def test_example_1(): + G = graph_example_1() + _check_separating_sets(G) + + +def test_random_gnp(): + G = nx.gnp_random_graph(100, 0.1, seed=42) + _check_separating_sets(G) + + +def test_shell(): + constructor = [(20, 80, 0.8), (80, 180, 0.6)] + G = nx.random_shell_graph(constructor, seed=42) + _check_separating_sets(G) + + +def test_configuration(): + deg_seq = nx.random_powerlaw_tree_sequence(100, tries=5, seed=72) + G = nx.Graph(nx.configuration_model(deg_seq)) + G.remove_edges_from(nx.selfloop_edges(G)) + _check_separating_sets(G) + + +def test_karate(): + G = nx.karate_club_graph() + _check_separating_sets(G) + + +def _generate_no_biconnected(max_attempts=50): + attempts = 0 + while True: + G = nx.fast_gnp_random_graph(100, 0.0575, seed=42) + if nx.is_connected(G) and not nx.is_biconnected(G): + attempts = 0 + yield G + else: + if attempts >= max_attempts: + msg = f"Tried {attempts} times: no suitable Graph." + raise Exception(msg) + else: + attempts += 1 + + +def test_articulation_points(): + Ggen = _generate_no_biconnected() + for i in range(1): # change 1 to 3 or more for more realizations. + G = next(Ggen) + articulation_points = [{a} for a in nx.articulation_points(G)] + for cut in nx.all_node_cuts(G): + assert cut in articulation_points + + +def test_grid_2d_graph(): + # All minimum node cuts of a 2d grid + # are the four pairs of nodes that are + # neighbors of the four corner nodes. + G = nx.grid_2d_graph(5, 5) + solution = [{(0, 1), (1, 0)}, {(3, 0), (4, 1)}, {(3, 4), (4, 3)}, {(0, 3), (1, 4)}] + for cut in nx.all_node_cuts(G): + assert cut in solution + + +def test_disconnected_graph(): + G = nx.fast_gnp_random_graph(100, 0.01, seed=42) + cuts = nx.all_node_cuts(G) + pytest.raises(nx.NetworkXError, next, cuts) + + +@pytest.mark.slow +def test_alternative_flow_functions(): + graphs = [nx.grid_2d_graph(4, 4), nx.cycle_graph(5)] + for G in graphs: + node_conn = nx.node_connectivity(G) + for flow_func in flow_funcs: + all_cuts = nx.all_node_cuts(G, flow_func=flow_func) + # Only test a limited number of cut sets to reduce test time. + for cut in itertools.islice(all_cuts, MAX_CUTSETS_TO_TEST): + assert node_conn == len(cut) + assert not nx.is_connected(nx.restricted_view(G, cut, [])) + + +def test_is_separating_set_complete_graph(): + G = nx.complete_graph(5) + assert _is_separating_set(G, {0, 1, 2, 3}) + + +def test_is_separating_set(): + for i in [5, 10, 15]: + G = nx.star_graph(i) + max_degree_node = max(G, key=G.degree) + assert _is_separating_set(G, {max_degree_node}) + + +def test_non_repeated_cuts(): + # The algorithm was repeating the cut {0, 1} for the giant biconnected + # component of the Karate club graph. + K = nx.karate_club_graph() + bcc = max(list(nx.biconnected_components(K)), key=len) + G = K.subgraph(bcc) + solution = [{32, 33}, {2, 33}, {0, 3}, {0, 1}, {29, 33}] + cuts = list(nx.all_node_cuts(G)) + if len(solution) != len(cuts): + print(f"Solution: {solution}") + print(f"Result: {cuts}") + assert len(solution) == len(cuts) + for cut in cuts: + assert cut in solution + + +def test_cycle_graph(): + G = nx.cycle_graph(5) + solution = [{0, 2}, {0, 3}, {1, 3}, {1, 4}, {2, 4}] + cuts = list(nx.all_node_cuts(G)) + assert len(solution) == len(cuts) + for cut in cuts: + assert cut in solution + + +def test_complete_graph(): + G = nx.complete_graph(5) + assert nx.node_connectivity(G) == 4 + assert list(nx.all_node_cuts(G)) == [] + + +def test_all_node_cuts_simple_case(): + G = nx.complete_graph(5) + G.remove_edges_from([(0, 1), (3, 4)]) + expected = [{0, 1, 2}, {2, 3, 4}] + actual = list(nx.all_node_cuts(G)) + assert len(actual) == len(expected) + for cut in actual: + assert cut in expected diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/test_stoer_wagner.py b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/test_stoer_wagner.py new file mode 100644 index 0000000000000000000000000000000000000000..2b9e2bab41eb29067166b6faa331e022d4074ce3 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/connectivity/tests/test_stoer_wagner.py @@ -0,0 +1,102 @@ +from itertools import chain + +import pytest + +import networkx as nx + + +def _check_partition(G, cut_value, partition, weight): + assert isinstance(partition, tuple) + assert len(partition) == 2 + assert isinstance(partition[0], list) + assert isinstance(partition[1], list) + assert len(partition[0]) > 0 + assert len(partition[1]) > 0 + assert sum(map(len, partition)) == len(G) + assert set(chain.from_iterable(partition)) == set(G) + partition = tuple(map(set, partition)) + w = 0 + for u, v, e in G.edges(data=True): + if (u in partition[0]) == (v in partition[1]): + w += e.get(weight, 1) + assert w == cut_value + + +def _test_stoer_wagner(G, answer, weight="weight"): + cut_value, partition = nx.stoer_wagner(G, weight, heap=nx.utils.PairingHeap) + assert cut_value == answer + _check_partition(G, cut_value, partition, weight) + cut_value, partition = nx.stoer_wagner(G, weight, heap=nx.utils.BinaryHeap) + assert cut_value == answer + _check_partition(G, cut_value, partition, weight) + + +def test_graph1(): + G = nx.Graph() + G.add_edge("x", "a", weight=3) + G.add_edge("x", "b", weight=1) + G.add_edge("a", "c", weight=3) + G.add_edge("b", "c", weight=5) + G.add_edge("b", "d", weight=4) + G.add_edge("d", "e", weight=2) + G.add_edge("c", "y", weight=2) + G.add_edge("e", "y", weight=3) + _test_stoer_wagner(G, 4) + + +def test_graph2(): + G = nx.Graph() + G.add_edge("x", "a") + G.add_edge("x", "b") + G.add_edge("a", "c") + G.add_edge("b", "c") + G.add_edge("b", "d") + G.add_edge("d", "e") + G.add_edge("c", "y") + G.add_edge("e", "y") + _test_stoer_wagner(G, 2) + + +def test_graph3(): + # Source: + # Stoer, M. and Wagner, F. (1997). "A simple min-cut algorithm". Journal of + # the ACM 44 (4), 585-591. + G = nx.Graph() + G.add_edge(1, 2, weight=2) + G.add_edge(1, 5, weight=3) + G.add_edge(2, 3, weight=3) + G.add_edge(2, 5, weight=2) + G.add_edge(2, 6, weight=2) + G.add_edge(3, 4, weight=4) + G.add_edge(3, 7, weight=2) + G.add_edge(4, 7, weight=2) + G.add_edge(4, 8, weight=2) + G.add_edge(5, 6, weight=3) + G.add_edge(6, 7, weight=1) + G.add_edge(7, 8, weight=3) + _test_stoer_wagner(G, 4) + + +def test_weight_name(): + G = nx.Graph() + G.add_edge(1, 2, weight=1, cost=8) + G.add_edge(1, 3, cost=2) + G.add_edge(2, 3, cost=4) + _test_stoer_wagner(G, 6, weight="cost") + + +def test_exceptions(): + G = nx.Graph() + pytest.raises(nx.NetworkXError, nx.stoer_wagner, G) + G.add_node(1) + pytest.raises(nx.NetworkXError, nx.stoer_wagner, G) + G.add_node(2) + pytest.raises(nx.NetworkXError, nx.stoer_wagner, G) + G.add_edge(1, 2, weight=-2) + pytest.raises(nx.NetworkXError, nx.stoer_wagner, G) + G = nx.DiGraph() + pytest.raises(nx.NetworkXNotImplemented, nx.stoer_wagner, G) + G = nx.MultiGraph() + pytest.raises(nx.NetworkXNotImplemented, nx.stoer_wagner, G) + G = nx.MultiDiGraph() + pytest.raises(nx.NetworkXNotImplemented, nx.stoer_wagner, G) diff --git a/lib/python3.10/site-packages/networkx/algorithms/connectivity/utils.py b/lib/python3.10/site-packages/networkx/algorithms/connectivity/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7bf9994598981e528f30e0deb15413c35f3dadbe --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/connectivity/utils.py @@ -0,0 +1,88 @@ +""" +Utilities for connectivity package +""" + +import networkx as nx + +__all__ = ["build_auxiliary_node_connectivity", "build_auxiliary_edge_connectivity"] + + +@nx._dispatchable(returns_graph=True) +def build_auxiliary_node_connectivity(G): + r"""Creates a directed graph D from an undirected graph G to compute flow + based node connectivity. + + For an undirected graph G having `n` nodes and `m` edges we derive a + directed graph D with `2n` nodes and `2m+n` arcs by replacing each + original node `v` with two nodes `vA`, `vB` linked by an (internal) + arc in D. Then for each edge (`u`, `v`) in G we add two arcs (`uB`, `vA`) + and (`vB`, `uA`) in D. Finally we set the attribute capacity = 1 for each + arc in D [1]_. + + For a directed graph having `n` nodes and `m` arcs we derive a + directed graph D with `2n` nodes and `m+n` arcs by replacing each + original node `v` with two nodes `vA`, `vB` linked by an (internal) + arc (`vA`, `vB`) in D. Then for each arc (`u`, `v`) in G we add one + arc (`uB`, `vA`) in D. Finally we set the attribute capacity = 1 for + each arc in D. + + A dictionary with a mapping between nodes in the original graph and the + auxiliary digraph is stored as a graph attribute: D.graph['mapping']. + + References + ---------- + .. [1] Kammer, Frank and Hanjo Taubig. Graph Connectivity. in Brandes and + Erlebach, 'Network Analysis: Methodological Foundations', Lecture + Notes in Computer Science, Volume 3418, Springer-Verlag, 2005. + https://doi.org/10.1007/978-3-540-31955-9_7 + + """ + directed = G.is_directed() + + mapping = {} + H = nx.DiGraph() + + for i, node in enumerate(G): + mapping[node] = i + H.add_node(f"{i}A", id=node) + H.add_node(f"{i}B", id=node) + H.add_edge(f"{i}A", f"{i}B", capacity=1) + + edges = [] + for source, target in G.edges(): + edges.append((f"{mapping[source]}B", f"{mapping[target]}A")) + if not directed: + edges.append((f"{mapping[target]}B", f"{mapping[source]}A")) + H.add_edges_from(edges, capacity=1) + + # Store mapping as graph attribute + H.graph["mapping"] = mapping + return H + + +@nx._dispatchable(returns_graph=True) +def build_auxiliary_edge_connectivity(G): + """Auxiliary digraph for computing flow based edge connectivity + + If the input graph is undirected, we replace each edge (`u`,`v`) with + two reciprocal arcs (`u`, `v`) and (`v`, `u`) and then we set the attribute + 'capacity' for each arc to 1. If the input graph is directed we simply + add the 'capacity' attribute. Part of algorithm 1 in [1]_ . + + References + ---------- + .. [1] Abdol-Hossein Esfahanian. Connectivity Algorithms. (this is a + chapter, look for the reference of the book). + http://www.cse.msu.edu/~cse835/Papers/Graph_connectivity_revised.pdf + """ + if G.is_directed(): + H = nx.DiGraph() + H.add_nodes_from(G.nodes()) + H.add_edges_from(G.edges(), capacity=1) + return H + else: + H = nx.DiGraph() + H.add_nodes_from(G.nodes()) + for source, target in G.edges(): + H.add_edges_from([(source, target), (target, source)], capacity=1) + return H diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/flow/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c5d19abed99501086359c87670edc31a680fe36c --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/flow/__init__.py @@ -0,0 +1,11 @@ +from .maxflow import * +from .mincost import * +from .boykovkolmogorov import * +from .dinitz_alg import * +from .edmondskarp import * +from .gomory_hu import * +from .preflowpush import * +from .shortestaugmentingpath import * +from .capacityscaling import * +from .networksimplex import * +from .utils import build_flow_dict, build_residual_network diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55c4aeb07de8b3f66055a24ac52ff49db93fb238 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/boykovkolmogorov.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/boykovkolmogorov.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea66e569b91521da829775ac8d7606b8687b4aa0 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/boykovkolmogorov.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/capacityscaling.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/capacityscaling.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e41d10ac4be1efb3ceedfd30a330a0aca0e98a37 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/capacityscaling.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/dinitz_alg.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/dinitz_alg.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1de36a44efc0b87c784d29d3f5bfb8a8ac1ad9c6 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/dinitz_alg.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/edmondskarp.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/edmondskarp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3432ce55f68b2a410847f2aef6caaceb516a38a Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/edmondskarp.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/gomory_hu.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/gomory_hu.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6085b11039a7b20890c1d511cf1042de84813bc2 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/gomory_hu.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/maxflow.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/maxflow.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3f3b7658a5b8d7cbf6261c89b97df2e7f54bdfd Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/maxflow.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/mincost.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/mincost.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6b2de35cf8b37df0f3451d309e8a74b8d276559 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/mincost.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/networksimplex.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/networksimplex.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fb33a7b0fdea4844761891cbb0930e2daad330a Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/networksimplex.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/preflowpush.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/preflowpush.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e965b6c9fa0410566da79cc5c453f2878c2eed99 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/preflowpush.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/shortestaugmentingpath.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/shortestaugmentingpath.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c5521ed799f9bc73c74b3b10ebd85295e8ea944 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/shortestaugmentingpath.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/utils.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c45e22a68c9f2d34c11489d82709ec4626eea7e0 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/flow/__pycache__/utils.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/boykovkolmogorov.py b/lib/python3.10/site-packages/networkx/algorithms/flow/boykovkolmogorov.py new file mode 100644 index 0000000000000000000000000000000000000000..30899c6c33e7ff508cfb13886a13ec96fef4ba44 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/flow/boykovkolmogorov.py @@ -0,0 +1,370 @@ +""" +Boykov-Kolmogorov algorithm for maximum flow problems. +""" + +from collections import deque +from operator import itemgetter + +import networkx as nx +from networkx.algorithms.flow.utils import build_residual_network + +__all__ = ["boykov_kolmogorov"] + + +@nx._dispatchable(edge_attrs={"capacity": float("inf")}, returns_graph=True) +def boykov_kolmogorov( + G, s, t, capacity="capacity", residual=None, value_only=False, cutoff=None +): + r"""Find a maximum single-commodity flow using Boykov-Kolmogorov algorithm. + + This function returns the residual network resulting after computing + the maximum flow. See below for details about the conventions + NetworkX uses for defining residual networks. + + This algorithm has worse case complexity $O(n^2 m |C|)$ for $n$ nodes, $m$ + edges, and $|C|$ the cost of the minimum cut [1]_. This implementation + uses the marking heuristic defined in [2]_ which improves its running + time in many practical problems. + + Parameters + ---------- + G : NetworkX graph + Edges of the graph are expected to have an attribute called + 'capacity'. If this attribute is not present, the edge is + considered to have infinite capacity. + + s : node + Source node for the flow. + + t : node + Sink node for the flow. + + capacity : string + Edges of the graph G are expected to have an attribute capacity + that indicates how much flow the edge can support. If this + attribute is not present, the edge is considered to have + infinite capacity. Default value: 'capacity'. + + residual : NetworkX graph + Residual network on which the algorithm is to be executed. If None, a + new residual network is created. Default value: None. + + value_only : bool + If True compute only the value of the maximum flow. This parameter + will be ignored by this algorithm because it is not applicable. + + cutoff : integer, float + If specified, the algorithm will terminate when the flow value reaches + or exceeds the cutoff. In this case, it may be unable to immediately + determine a minimum cut. Default value: None. + + Returns + ------- + R : NetworkX DiGraph + Residual network after computing the maximum flow. + + Raises + ------ + NetworkXError + The algorithm does not support MultiGraph and MultiDiGraph. If + the input graph is an instance of one of these two classes, a + NetworkXError is raised. + + NetworkXUnbounded + If the graph has a path of infinite capacity, the value of a + feasible flow on the graph is unbounded above and the function + raises a NetworkXUnbounded. + + See also + -------- + :meth:`maximum_flow` + :meth:`minimum_cut` + :meth:`preflow_push` + :meth:`shortest_augmenting_path` + + Notes + ----- + The residual network :samp:`R` from an input graph :samp:`G` has the + same nodes as :samp:`G`. :samp:`R` is a DiGraph that contains a pair + of edges :samp:`(u, v)` and :samp:`(v, u)` iff :samp:`(u, v)` is not a + self-loop, and at least one of :samp:`(u, v)` and :samp:`(v, u)` exists + in :samp:`G`. + + For each edge :samp:`(u, v)` in :samp:`R`, :samp:`R[u][v]['capacity']` + is equal to the capacity of :samp:`(u, v)` in :samp:`G` if it exists + in :samp:`G` or zero otherwise. If the capacity is infinite, + :samp:`R[u][v]['capacity']` will have a high arbitrary finite value + that does not affect the solution of the problem. This value is stored in + :samp:`R.graph['inf']`. For each edge :samp:`(u, v)` in :samp:`R`, + :samp:`R[u][v]['flow']` represents the flow function of :samp:`(u, v)` and + satisfies :samp:`R[u][v]['flow'] == -R[v][u]['flow']`. + + The flow value, defined as the total flow into :samp:`t`, the sink, is + stored in :samp:`R.graph['flow_value']`. If :samp:`cutoff` is not + specified, reachability to :samp:`t` using only edges :samp:`(u, v)` such + that :samp:`R[u][v]['flow'] < R[u][v]['capacity']` induces a minimum + :samp:`s`-:samp:`t` cut. + + Examples + -------- + >>> from networkx.algorithms.flow import boykov_kolmogorov + + The functions that implement flow algorithms and output a residual + network, such as this one, are not imported to the base NetworkX + namespace, so you have to explicitly import them from the flow package. + + >>> G = nx.DiGraph() + >>> G.add_edge("x", "a", capacity=3.0) + >>> G.add_edge("x", "b", capacity=1.0) + >>> G.add_edge("a", "c", capacity=3.0) + >>> G.add_edge("b", "c", capacity=5.0) + >>> G.add_edge("b", "d", capacity=4.0) + >>> G.add_edge("d", "e", capacity=2.0) + >>> G.add_edge("c", "y", capacity=2.0) + >>> G.add_edge("e", "y", capacity=3.0) + >>> R = boykov_kolmogorov(G, "x", "y") + >>> flow_value = nx.maximum_flow_value(G, "x", "y") + >>> flow_value + 3.0 + >>> flow_value == R.graph["flow_value"] + True + + A nice feature of the Boykov-Kolmogorov algorithm is that a partition + of the nodes that defines a minimum cut can be easily computed based + on the search trees used during the algorithm. These trees are stored + in the graph attribute `trees` of the residual network. + + >>> source_tree, target_tree = R.graph["trees"] + >>> partition = (set(source_tree), set(G) - set(source_tree)) + + Or equivalently: + + >>> partition = (set(G) - set(target_tree), set(target_tree)) + + References + ---------- + .. [1] Boykov, Y., & Kolmogorov, V. (2004). An experimental comparison + of min-cut/max-flow algorithms for energy minimization in vision. + Pattern Analysis and Machine Intelligence, IEEE Transactions on, + 26(9), 1124-1137. + https://doi.org/10.1109/TPAMI.2004.60 + + .. [2] Vladimir Kolmogorov. Graph-based Algorithms for Multi-camera + Reconstruction Problem. PhD thesis, Cornell University, CS Department, + 2003. pp. 109-114. + https://web.archive.org/web/20170809091249/https://pub.ist.ac.at/~vnk/papers/thesis.pdf + + """ + R = boykov_kolmogorov_impl(G, s, t, capacity, residual, cutoff) + R.graph["algorithm"] = "boykov_kolmogorov" + nx._clear_cache(R) + return R + + +def boykov_kolmogorov_impl(G, s, t, capacity, residual, cutoff): + if s not in G: + raise nx.NetworkXError(f"node {str(s)} not in graph") + if t not in G: + raise nx.NetworkXError(f"node {str(t)} not in graph") + if s == t: + raise nx.NetworkXError("source and sink are the same node") + + if residual is None: + R = build_residual_network(G, capacity) + else: + R = residual + + # Initialize/reset the residual network. + # This is way too slow + # nx.set_edge_attributes(R, 0, 'flow') + for u in R: + for e in R[u].values(): + e["flow"] = 0 + + # Use an arbitrary high value as infinite. It is computed + # when building the residual network. + INF = R.graph["inf"] + + if cutoff is None: + cutoff = INF + + R_succ = R.succ + R_pred = R.pred + + def grow(): + """Bidirectional breadth-first search for the growth stage. + + Returns a connecting edge, that is and edge that connects + a node from the source search tree with a node from the + target search tree. + The first node in the connecting edge is always from the + source tree and the last node from the target tree. + """ + while active: + u = active[0] + if u in source_tree: + this_tree = source_tree + other_tree = target_tree + neighbors = R_succ + else: + this_tree = target_tree + other_tree = source_tree + neighbors = R_pred + for v, attr in neighbors[u].items(): + if attr["capacity"] - attr["flow"] > 0: + if v not in this_tree: + if v in other_tree: + return (u, v) if this_tree is source_tree else (v, u) + this_tree[v] = u + dist[v] = dist[u] + 1 + timestamp[v] = timestamp[u] + active.append(v) + elif v in this_tree and _is_closer(u, v): + this_tree[v] = u + dist[v] = dist[u] + 1 + timestamp[v] = timestamp[u] + _ = active.popleft() + return None, None + + def augment(u, v): + """Augmentation stage. + + Reconstruct path and determine its residual capacity. + We start from a connecting edge, which links a node + from the source tree to a node from the target tree. + The connecting edge is the output of the grow function + and the input of this function. + """ + attr = R_succ[u][v] + flow = min(INF, attr["capacity"] - attr["flow"]) + path = [u] + # Trace a path from u to s in source_tree. + w = u + while w != s: + n = w + w = source_tree[n] + attr = R_pred[n][w] + flow = min(flow, attr["capacity"] - attr["flow"]) + path.append(w) + path.reverse() + # Trace a path from v to t in target_tree. + path.append(v) + w = v + while w != t: + n = w + w = target_tree[n] + attr = R_succ[n][w] + flow = min(flow, attr["capacity"] - attr["flow"]) + path.append(w) + # Augment flow along the path and check for saturated edges. + it = iter(path) + u = next(it) + these_orphans = [] + for v in it: + R_succ[u][v]["flow"] += flow + R_succ[v][u]["flow"] -= flow + if R_succ[u][v]["flow"] == R_succ[u][v]["capacity"]: + if v in source_tree: + source_tree[v] = None + these_orphans.append(v) + if u in target_tree: + target_tree[u] = None + these_orphans.append(u) + u = v + orphans.extend(sorted(these_orphans, key=dist.get)) + return flow + + def adopt(): + """Adoption stage. + + Reconstruct search trees by adopting or discarding orphans. + During augmentation stage some edges got saturated and thus + the source and target search trees broke down to forests, with + orphans as roots of some of its trees. We have to reconstruct + the search trees rooted to source and target before we can grow + them again. + """ + while orphans: + u = orphans.popleft() + if u in source_tree: + tree = source_tree + neighbors = R_pred + else: + tree = target_tree + neighbors = R_succ + nbrs = ((n, attr, dist[n]) for n, attr in neighbors[u].items() if n in tree) + for v, attr, d in sorted(nbrs, key=itemgetter(2)): + if attr["capacity"] - attr["flow"] > 0: + if _has_valid_root(v, tree): + tree[u] = v + dist[u] = dist[v] + 1 + timestamp[u] = time + break + else: + nbrs = ( + (n, attr, dist[n]) for n, attr in neighbors[u].items() if n in tree + ) + for v, attr, d in sorted(nbrs, key=itemgetter(2)): + if attr["capacity"] - attr["flow"] > 0: + if v not in active: + active.append(v) + if tree[v] == u: + tree[v] = None + orphans.appendleft(v) + if u in active: + active.remove(u) + del tree[u] + + def _has_valid_root(n, tree): + path = [] + v = n + while v is not None: + path.append(v) + if v in (s, t): + base_dist = 0 + break + elif timestamp[v] == time: + base_dist = dist[v] + break + v = tree[v] + else: + return False + length = len(path) + for i, u in enumerate(path, 1): + dist[u] = base_dist + length - i + timestamp[u] = time + return True + + def _is_closer(u, v): + return timestamp[v] <= timestamp[u] and dist[v] > dist[u] + 1 + + source_tree = {s: None} + target_tree = {t: None} + active = deque([s, t]) + orphans = deque() + flow_value = 0 + # data structures for the marking heuristic + time = 1 + timestamp = {s: time, t: time} + dist = {s: 0, t: 0} + while flow_value < cutoff: + # Growth stage + u, v = grow() + if u is None: + break + time += 1 + # Augmentation stage + flow_value += augment(u, v) + # Adoption stage + adopt() + + if flow_value * 2 > INF: + raise nx.NetworkXUnbounded("Infinite capacity path, flow unbounded above.") + + # Add source and target tree in a graph attribute. + # A partition that defines a minimum cut can be directly + # computed from the search trees as explained in the docstrings. + R.graph["trees"] = (source_tree, target_tree) + # Add the standard flow_value graph attribute. + R.graph["flow_value"] = flow_value + return R diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/capacityscaling.py b/lib/python3.10/site-packages/networkx/algorithms/flow/capacityscaling.py new file mode 100644 index 0000000000000000000000000000000000000000..bf68565c5486bb7b60e7ddcf6089e448bc6ddef1 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/flow/capacityscaling.py @@ -0,0 +1,407 @@ +""" +Capacity scaling minimum cost flow algorithm. +""" + +__all__ = ["capacity_scaling"] + +from itertools import chain +from math import log + +import networkx as nx + +from ...utils import BinaryHeap, arbitrary_element, not_implemented_for + + +def _detect_unboundedness(R): + """Detect infinite-capacity negative cycles.""" + G = nx.DiGraph() + G.add_nodes_from(R) + + # Value simulating infinity. + inf = R.graph["inf"] + # True infinity. + f_inf = float("inf") + for u in R: + for v, e in R[u].items(): + # Compute the minimum weight of infinite-capacity (u, v) edges. + w = f_inf + for k, e in e.items(): + if e["capacity"] == inf: + w = min(w, e["weight"]) + if w != f_inf: + G.add_edge(u, v, weight=w) + + if nx.negative_edge_cycle(G): + raise nx.NetworkXUnbounded( + "Negative cost cycle of infinite capacity found. " + "Min cost flow may be unbounded below." + ) + + +@not_implemented_for("undirected") +def _build_residual_network(G, demand, capacity, weight): + """Build a residual network and initialize a zero flow.""" + if sum(G.nodes[u].get(demand, 0) for u in G) != 0: + raise nx.NetworkXUnfeasible("Sum of the demands should be 0.") + + R = nx.MultiDiGraph() + R.add_nodes_from( + (u, {"excess": -G.nodes[u].get(demand, 0), "potential": 0}) for u in G + ) + + inf = float("inf") + # Detect selfloops with infinite capacities and negative weights. + for u, v, e in nx.selfloop_edges(G, data=True): + if e.get(weight, 0) < 0 and e.get(capacity, inf) == inf: + raise nx.NetworkXUnbounded( + "Negative cost cycle of infinite capacity found. " + "Min cost flow may be unbounded below." + ) + + # Extract edges with positive capacities. Self loops excluded. + if G.is_multigraph(): + edge_list = [ + (u, v, k, e) + for u, v, k, e in G.edges(data=True, keys=True) + if u != v and e.get(capacity, inf) > 0 + ] + else: + edge_list = [ + (u, v, 0, e) + for u, v, e in G.edges(data=True) + if u != v and e.get(capacity, inf) > 0 + ] + # Simulate infinity with the larger of the sum of absolute node imbalances + # the sum of finite edge capacities or any positive value if both sums are + # zero. This allows the infinite-capacity edges to be distinguished for + # unboundedness detection and directly participate in residual capacity + # calculation. + inf = ( + max( + sum(abs(R.nodes[u]["excess"]) for u in R), + 2 + * sum( + e[capacity] + for u, v, k, e in edge_list + if capacity in e and e[capacity] != inf + ), + ) + or 1 + ) + for u, v, k, e in edge_list: + r = min(e.get(capacity, inf), inf) + w = e.get(weight, 0) + # Add both (u, v) and (v, u) into the residual network marked with the + # original key. (key[1] == True) indicates the (u, v) is in the + # original network. + R.add_edge(u, v, key=(k, True), capacity=r, weight=w, flow=0) + R.add_edge(v, u, key=(k, False), capacity=0, weight=-w, flow=0) + + # Record the value simulating infinity. + R.graph["inf"] = inf + + _detect_unboundedness(R) + + return R + + +def _build_flow_dict(G, R, capacity, weight): + """Build a flow dictionary from a residual network.""" + inf = float("inf") + flow_dict = {} + if G.is_multigraph(): + for u in G: + flow_dict[u] = {} + for v, es in G[u].items(): + flow_dict[u][v] = { + # Always saturate negative selfloops. + k: ( + 0 + if ( + u != v or e.get(capacity, inf) <= 0 or e.get(weight, 0) >= 0 + ) + else e[capacity] + ) + for k, e in es.items() + } + for v, es in R[u].items(): + if v in flow_dict[u]: + flow_dict[u][v].update( + (k[0], e["flow"]) for k, e in es.items() if e["flow"] > 0 + ) + else: + for u in G: + flow_dict[u] = { + # Always saturate negative selfloops. + v: ( + 0 + if (u != v or e.get(capacity, inf) <= 0 or e.get(weight, 0) >= 0) + else e[capacity] + ) + for v, e in G[u].items() + } + flow_dict[u].update( + (v, e["flow"]) + for v, es in R[u].items() + for e in es.values() + if e["flow"] > 0 + ) + return flow_dict + + +@nx._dispatchable( + node_attrs="demand", edge_attrs={"capacity": float("inf"), "weight": 0} +) +def capacity_scaling( + G, demand="demand", capacity="capacity", weight="weight", heap=BinaryHeap +): + r"""Find a minimum cost flow satisfying all demands in digraph G. + + This is a capacity scaling successive shortest augmenting path algorithm. + + G is a digraph with edge costs and capacities and in which nodes + have demand, i.e., they want to send or receive some amount of + flow. A negative demand means that the node wants to send flow, a + positive demand means that the node want to receive flow. A flow on + the digraph G satisfies all demand if the net flow into each node + is equal to the demand of that node. + + Parameters + ---------- + G : NetworkX graph + DiGraph or MultiDiGraph on which a minimum cost flow satisfying all + demands is to be found. + + demand : string + Nodes of the graph G are expected to have an attribute demand + that indicates how much flow a node wants to send (negative + demand) or receive (positive demand). Note that the sum of the + demands should be 0 otherwise the problem in not feasible. If + this attribute is not present, a node is considered to have 0 + demand. Default value: 'demand'. + + capacity : string + Edges of the graph G are expected to have an attribute capacity + that indicates how much flow the edge can support. If this + attribute is not present, the edge is considered to have + infinite capacity. Default value: 'capacity'. + + weight : string + Edges of the graph G are expected to have an attribute weight + that indicates the cost incurred by sending one unit of flow on + that edge. If not present, the weight is considered to be 0. + Default value: 'weight'. + + heap : class + Type of heap to be used in the algorithm. It should be a subclass of + :class:`MinHeap` or implement a compatible interface. + + If a stock heap implementation is to be used, :class:`BinaryHeap` is + recommended over :class:`PairingHeap` for Python implementations without + optimized attribute accesses (e.g., CPython) despite a slower + asymptotic running time. For Python implementations with optimized + attribute accesses (e.g., PyPy), :class:`PairingHeap` provides better + performance. Default value: :class:`BinaryHeap`. + + Returns + ------- + flowCost : integer + Cost of a minimum cost flow satisfying all demands. + + flowDict : dictionary + If G is a digraph, a dict-of-dicts keyed by nodes such that + flowDict[u][v] is the flow on edge (u, v). + If G is a MultiDiGraph, a dict-of-dicts-of-dicts keyed by nodes + so that flowDict[u][v][key] is the flow on edge (u, v, key). + + Raises + ------ + NetworkXError + This exception is raised if the input graph is not directed, + not connected. + + NetworkXUnfeasible + This exception is raised in the following situations: + + * The sum of the demands is not zero. Then, there is no + flow satisfying all demands. + * There is no flow satisfying all demand. + + NetworkXUnbounded + This exception is raised if the digraph G has a cycle of + negative cost and infinite capacity. Then, the cost of a flow + satisfying all demands is unbounded below. + + Notes + ----- + This algorithm does not work if edge weights are floating-point numbers. + + See also + -------- + :meth:`network_simplex` + + Examples + -------- + A simple example of a min cost flow problem. + + >>> G = nx.DiGraph() + >>> G.add_node("a", demand=-5) + >>> G.add_node("d", demand=5) + >>> G.add_edge("a", "b", weight=3, capacity=4) + >>> G.add_edge("a", "c", weight=6, capacity=10) + >>> G.add_edge("b", "d", weight=1, capacity=9) + >>> G.add_edge("c", "d", weight=2, capacity=5) + >>> flowCost, flowDict = nx.capacity_scaling(G) + >>> flowCost + 24 + >>> flowDict + {'a': {'b': 4, 'c': 1}, 'd': {}, 'b': {'d': 4}, 'c': {'d': 1}} + + It is possible to change the name of the attributes used for the + algorithm. + + >>> G = nx.DiGraph() + >>> G.add_node("p", spam=-4) + >>> G.add_node("q", spam=2) + >>> G.add_node("a", spam=-2) + >>> G.add_node("d", spam=-1) + >>> G.add_node("t", spam=2) + >>> G.add_node("w", spam=3) + >>> G.add_edge("p", "q", cost=7, vacancies=5) + >>> G.add_edge("p", "a", cost=1, vacancies=4) + >>> G.add_edge("q", "d", cost=2, vacancies=3) + >>> G.add_edge("t", "q", cost=1, vacancies=2) + >>> G.add_edge("a", "t", cost=2, vacancies=4) + >>> G.add_edge("d", "w", cost=3, vacancies=4) + >>> G.add_edge("t", "w", cost=4, vacancies=1) + >>> flowCost, flowDict = nx.capacity_scaling( + ... G, demand="spam", capacity="vacancies", weight="cost" + ... ) + >>> flowCost + 37 + >>> flowDict + {'p': {'q': 2, 'a': 2}, 'q': {'d': 1}, 'a': {'t': 4}, 'd': {'w': 2}, 't': {'q': 1, 'w': 1}, 'w': {}} + """ + R = _build_residual_network(G, demand, capacity, weight) + + inf = float("inf") + # Account cost of negative selfloops. + flow_cost = sum( + 0 + if e.get(capacity, inf) <= 0 or e.get(weight, 0) >= 0 + else e[capacity] * e[weight] + for u, v, e in nx.selfloop_edges(G, data=True) + ) + + # Determine the maximum edge capacity. + wmax = max(chain([-inf], (e["capacity"] for u, v, e in R.edges(data=True)))) + if wmax == -inf: + # Residual network has no edges. + return flow_cost, _build_flow_dict(G, R, capacity, weight) + + R_nodes = R.nodes + R_succ = R.succ + + delta = 2 ** int(log(wmax, 2)) + while delta >= 1: + # Saturate Δ-residual edges with negative reduced costs to achieve + # Δ-optimality. + for u in R: + p_u = R_nodes[u]["potential"] + for v, es in R_succ[u].items(): + for k, e in es.items(): + flow = e["capacity"] - e["flow"] + if e["weight"] - p_u + R_nodes[v]["potential"] < 0: + flow = e["capacity"] - e["flow"] + if flow >= delta: + e["flow"] += flow + R_succ[v][u][(k[0], not k[1])]["flow"] -= flow + R_nodes[u]["excess"] -= flow + R_nodes[v]["excess"] += flow + # Determine the Δ-active nodes. + S = set() + T = set() + S_add = S.add + S_remove = S.remove + T_add = T.add + T_remove = T.remove + for u in R: + excess = R_nodes[u]["excess"] + if excess >= delta: + S_add(u) + elif excess <= -delta: + T_add(u) + # Repeatedly augment flow from S to T along shortest paths until + # Δ-feasibility is achieved. + while S and T: + s = arbitrary_element(S) + t = None + # Search for a shortest path in terms of reduce costs from s to + # any t in T in the Δ-residual network. + d = {} + pred = {s: None} + h = heap() + h_insert = h.insert + h_get = h.get + h_insert(s, 0) + while h: + u, d_u = h.pop() + d[u] = d_u + if u in T: + # Path found. + t = u + break + p_u = R_nodes[u]["potential"] + for v, es in R_succ[u].items(): + if v in d: + continue + wmin = inf + # Find the minimum-weighted (u, v) Δ-residual edge. + for k, e in es.items(): + if e["capacity"] - e["flow"] >= delta: + w = e["weight"] + if w < wmin: + wmin = w + kmin = k + emin = e + if wmin == inf: + continue + # Update the distance label of v. + d_v = d_u + wmin - p_u + R_nodes[v]["potential"] + if h_insert(v, d_v): + pred[v] = (u, kmin, emin) + if t is not None: + # Augment Δ units of flow from s to t. + while u != s: + v = u + u, k, e = pred[v] + e["flow"] += delta + R_succ[v][u][(k[0], not k[1])]["flow"] -= delta + # Account node excess and deficit. + R_nodes[s]["excess"] -= delta + R_nodes[t]["excess"] += delta + if R_nodes[s]["excess"] < delta: + S_remove(s) + if R_nodes[t]["excess"] > -delta: + T_remove(t) + # Update node potentials. + d_t = d[t] + for u, d_u in d.items(): + R_nodes[u]["potential"] -= d_u - d_t + else: + # Path not found. + S_remove(s) + delta //= 2 + + if any(R.nodes[u]["excess"] != 0 for u in R): + raise nx.NetworkXUnfeasible("No flow satisfying all demands.") + + # Calculate the flow cost. + for u in R: + for v, es in R_succ[u].items(): + for e in es.values(): + flow = e["flow"] + if flow > 0: + flow_cost += flow * e["weight"] + + return flow_cost, _build_flow_dict(G, R, capacity, weight) diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/dinitz_alg.py b/lib/python3.10/site-packages/networkx/algorithms/flow/dinitz_alg.py new file mode 100644 index 0000000000000000000000000000000000000000..f369642af2968094184741132a843f5dde81e428 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/flow/dinitz_alg.py @@ -0,0 +1,238 @@ +""" +Dinitz' algorithm for maximum flow problems. +""" + +from collections import deque + +import networkx as nx +from networkx.algorithms.flow.utils import build_residual_network +from networkx.utils import pairwise + +__all__ = ["dinitz"] + + +@nx._dispatchable(edge_attrs={"capacity": float("inf")}, returns_graph=True) +def dinitz(G, s, t, capacity="capacity", residual=None, value_only=False, cutoff=None): + """Find a maximum single-commodity flow using Dinitz' algorithm. + + This function returns the residual network resulting after computing + the maximum flow. See below for details about the conventions + NetworkX uses for defining residual networks. + + This algorithm has a running time of $O(n^2 m)$ for $n$ nodes and $m$ + edges [1]_. + + + Parameters + ---------- + G : NetworkX graph + Edges of the graph are expected to have an attribute called + 'capacity'. If this attribute is not present, the edge is + considered to have infinite capacity. + + s : node + Source node for the flow. + + t : node + Sink node for the flow. + + capacity : string + Edges of the graph G are expected to have an attribute capacity + that indicates how much flow the edge can support. If this + attribute is not present, the edge is considered to have + infinite capacity. Default value: 'capacity'. + + residual : NetworkX graph + Residual network on which the algorithm is to be executed. If None, a + new residual network is created. Default value: None. + + value_only : bool + If True compute only the value of the maximum flow. This parameter + will be ignored by this algorithm because it is not applicable. + + cutoff : integer, float + If specified, the algorithm will terminate when the flow value reaches + or exceeds the cutoff. In this case, it may be unable to immediately + determine a minimum cut. Default value: None. + + Returns + ------- + R : NetworkX DiGraph + Residual network after computing the maximum flow. + + Raises + ------ + NetworkXError + The algorithm does not support MultiGraph and MultiDiGraph. If + the input graph is an instance of one of these two classes, a + NetworkXError is raised. + + NetworkXUnbounded + If the graph has a path of infinite capacity, the value of a + feasible flow on the graph is unbounded above and the function + raises a NetworkXUnbounded. + + See also + -------- + :meth:`maximum_flow` + :meth:`minimum_cut` + :meth:`preflow_push` + :meth:`shortest_augmenting_path` + + Notes + ----- + The residual network :samp:`R` from an input graph :samp:`G` has the + same nodes as :samp:`G`. :samp:`R` is a DiGraph that contains a pair + of edges :samp:`(u, v)` and :samp:`(v, u)` iff :samp:`(u, v)` is not a + self-loop, and at least one of :samp:`(u, v)` and :samp:`(v, u)` exists + in :samp:`G`. + + For each edge :samp:`(u, v)` in :samp:`R`, :samp:`R[u][v]['capacity']` + is equal to the capacity of :samp:`(u, v)` in :samp:`G` if it exists + in :samp:`G` or zero otherwise. If the capacity is infinite, + :samp:`R[u][v]['capacity']` will have a high arbitrary finite value + that does not affect the solution of the problem. This value is stored in + :samp:`R.graph['inf']`. For each edge :samp:`(u, v)` in :samp:`R`, + :samp:`R[u][v]['flow']` represents the flow function of :samp:`(u, v)` and + satisfies :samp:`R[u][v]['flow'] == -R[v][u]['flow']`. + + The flow value, defined as the total flow into :samp:`t`, the sink, is + stored in :samp:`R.graph['flow_value']`. If :samp:`cutoff` is not + specified, reachability to :samp:`t` using only edges :samp:`(u, v)` such + that :samp:`R[u][v]['flow'] < R[u][v]['capacity']` induces a minimum + :samp:`s`-:samp:`t` cut. + + Examples + -------- + >>> from networkx.algorithms.flow import dinitz + + The functions that implement flow algorithms and output a residual + network, such as this one, are not imported to the base NetworkX + namespace, so you have to explicitly import them from the flow package. + + >>> G = nx.DiGraph() + >>> G.add_edge("x", "a", capacity=3.0) + >>> G.add_edge("x", "b", capacity=1.0) + >>> G.add_edge("a", "c", capacity=3.0) + >>> G.add_edge("b", "c", capacity=5.0) + >>> G.add_edge("b", "d", capacity=4.0) + >>> G.add_edge("d", "e", capacity=2.0) + >>> G.add_edge("c", "y", capacity=2.0) + >>> G.add_edge("e", "y", capacity=3.0) + >>> R = dinitz(G, "x", "y") + >>> flow_value = nx.maximum_flow_value(G, "x", "y") + >>> flow_value + 3.0 + >>> flow_value == R.graph["flow_value"] + True + + References + ---------- + .. [1] Dinitz' Algorithm: The Original Version and Even's Version. + 2006. Yefim Dinitz. In Theoretical Computer Science. Lecture + Notes in Computer Science. Volume 3895. pp 218-240. + https://doi.org/10.1007/11685654_10 + + """ + R = dinitz_impl(G, s, t, capacity, residual, cutoff) + R.graph["algorithm"] = "dinitz" + nx._clear_cache(R) + return R + + +def dinitz_impl(G, s, t, capacity, residual, cutoff): + if s not in G: + raise nx.NetworkXError(f"node {str(s)} not in graph") + if t not in G: + raise nx.NetworkXError(f"node {str(t)} not in graph") + if s == t: + raise nx.NetworkXError("source and sink are the same node") + + if residual is None: + R = build_residual_network(G, capacity) + else: + R = residual + + # Initialize/reset the residual network. + for u in R: + for e in R[u].values(): + e["flow"] = 0 + + # Use an arbitrary high value as infinite. It is computed + # when building the residual network. + INF = R.graph["inf"] + + if cutoff is None: + cutoff = INF + + R_succ = R.succ + R_pred = R.pred + + def breath_first_search(): + parents = {} + vertex_dist = {s: 0} + queue = deque([(s, 0)]) + # Record all the potential edges of shortest augmenting paths + while queue: + if t in parents: + break + u, dist = queue.popleft() + for v, attr in R_succ[u].items(): + if attr["capacity"] - attr["flow"] > 0: + if v in parents: + if vertex_dist[v] == dist + 1: + parents[v].append(u) + else: + parents[v] = deque([u]) + vertex_dist[v] = dist + 1 + queue.append((v, dist + 1)) + return parents + + def depth_first_search(parents): + # DFS to find all the shortest augmenting paths + """Build a path using DFS starting from the sink""" + total_flow = 0 + u = t + # path also functions as a stack + path = [u] + # The loop ends with no augmenting path left in the layered graph + while True: + if len(parents[u]) > 0: + v = parents[u][0] + path.append(v) + else: + path.pop() + if len(path) == 0: + break + v = path[-1] + parents[v].popleft() + # Augment the flow along the path found + if v == s: + flow = INF + for u, v in pairwise(path): + flow = min(flow, R_pred[u][v]["capacity"] - R_pred[u][v]["flow"]) + for u, v in pairwise(reversed(path)): + R_pred[v][u]["flow"] += flow + R_pred[u][v]["flow"] -= flow + # Find the proper node to continue the search + if R_pred[v][u]["capacity"] - R_pred[v][u]["flow"] == 0: + parents[v].popleft() + while path[-1] != v: + path.pop() + total_flow += flow + v = path[-1] + u = v + return total_flow + + flow_value = 0 + while flow_value < cutoff: + parents = breath_first_search() + if t not in parents: + break + this_flow = depth_first_search(parents) + if this_flow * 2 > INF: + raise nx.NetworkXUnbounded("Infinite capacity path, flow unbounded above.") + flow_value += this_flow + + R.graph["flow_value"] = flow_value + return R diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/edmondskarp.py b/lib/python3.10/site-packages/networkx/algorithms/flow/edmondskarp.py new file mode 100644 index 0000000000000000000000000000000000000000..50063268355ccc2e2ecbdf7f1a6704e7404475ec --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/flow/edmondskarp.py @@ -0,0 +1,241 @@ +""" +Edmonds-Karp algorithm for maximum flow problems. +""" + +import networkx as nx +from networkx.algorithms.flow.utils import build_residual_network + +__all__ = ["edmonds_karp"] + + +def edmonds_karp_core(R, s, t, cutoff): + """Implementation of the Edmonds-Karp algorithm.""" + R_nodes = R.nodes + R_pred = R.pred + R_succ = R.succ + + inf = R.graph["inf"] + + def augment(path): + """Augment flow along a path from s to t.""" + # Determine the path residual capacity. + flow = inf + it = iter(path) + u = next(it) + for v in it: + attr = R_succ[u][v] + flow = min(flow, attr["capacity"] - attr["flow"]) + u = v + if flow * 2 > inf: + raise nx.NetworkXUnbounded("Infinite capacity path, flow unbounded above.") + # Augment flow along the path. + it = iter(path) + u = next(it) + for v in it: + R_succ[u][v]["flow"] += flow + R_succ[v][u]["flow"] -= flow + u = v + return flow + + def bidirectional_bfs(): + """Bidirectional breadth-first search for an augmenting path.""" + pred = {s: None} + q_s = [s] + succ = {t: None} + q_t = [t] + while True: + q = [] + if len(q_s) <= len(q_t): + for u in q_s: + for v, attr in R_succ[u].items(): + if v not in pred and attr["flow"] < attr["capacity"]: + pred[v] = u + if v in succ: + return v, pred, succ + q.append(v) + if not q: + return None, None, None + q_s = q + else: + for u in q_t: + for v, attr in R_pred[u].items(): + if v not in succ and attr["flow"] < attr["capacity"]: + succ[v] = u + if v in pred: + return v, pred, succ + q.append(v) + if not q: + return None, None, None + q_t = q + + # Look for shortest augmenting paths using breadth-first search. + flow_value = 0 + while flow_value < cutoff: + v, pred, succ = bidirectional_bfs() + if pred is None: + break + path = [v] + # Trace a path from s to v. + u = v + while u != s: + u = pred[u] + path.append(u) + path.reverse() + # Trace a path from v to t. + u = v + while u != t: + u = succ[u] + path.append(u) + flow_value += augment(path) + + return flow_value + + +def edmonds_karp_impl(G, s, t, capacity, residual, cutoff): + """Implementation of the Edmonds-Karp algorithm.""" + if s not in G: + raise nx.NetworkXError(f"node {str(s)} not in graph") + if t not in G: + raise nx.NetworkXError(f"node {str(t)} not in graph") + if s == t: + raise nx.NetworkXError("source and sink are the same node") + + if residual is None: + R = build_residual_network(G, capacity) + else: + R = residual + + # Initialize/reset the residual network. + for u in R: + for e in R[u].values(): + e["flow"] = 0 + + if cutoff is None: + cutoff = float("inf") + R.graph["flow_value"] = edmonds_karp_core(R, s, t, cutoff) + + return R + + +@nx._dispatchable(edge_attrs={"capacity": float("inf")}, returns_graph=True) +def edmonds_karp( + G, s, t, capacity="capacity", residual=None, value_only=False, cutoff=None +): + """Find a maximum single-commodity flow using the Edmonds-Karp algorithm. + + This function returns the residual network resulting after computing + the maximum flow. See below for details about the conventions + NetworkX uses for defining residual networks. + + This algorithm has a running time of $O(n m^2)$ for $n$ nodes and $m$ + edges. + + + Parameters + ---------- + G : NetworkX graph + Edges of the graph are expected to have an attribute called + 'capacity'. If this attribute is not present, the edge is + considered to have infinite capacity. + + s : node + Source node for the flow. + + t : node + Sink node for the flow. + + capacity : string + Edges of the graph G are expected to have an attribute capacity + that indicates how much flow the edge can support. If this + attribute is not present, the edge is considered to have + infinite capacity. Default value: 'capacity'. + + residual : NetworkX graph + Residual network on which the algorithm is to be executed. If None, a + new residual network is created. Default value: None. + + value_only : bool + If True compute only the value of the maximum flow. This parameter + will be ignored by this algorithm because it is not applicable. + + cutoff : integer, float + If specified, the algorithm will terminate when the flow value reaches + or exceeds the cutoff. In this case, it may be unable to immediately + determine a minimum cut. Default value: None. + + Returns + ------- + R : NetworkX DiGraph + Residual network after computing the maximum flow. + + Raises + ------ + NetworkXError + The algorithm does not support MultiGraph and MultiDiGraph. If + the input graph is an instance of one of these two classes, a + NetworkXError is raised. + + NetworkXUnbounded + If the graph has a path of infinite capacity, the value of a + feasible flow on the graph is unbounded above and the function + raises a NetworkXUnbounded. + + See also + -------- + :meth:`maximum_flow` + :meth:`minimum_cut` + :meth:`preflow_push` + :meth:`shortest_augmenting_path` + + Notes + ----- + The residual network :samp:`R` from an input graph :samp:`G` has the + same nodes as :samp:`G`. :samp:`R` is a DiGraph that contains a pair + of edges :samp:`(u, v)` and :samp:`(v, u)` iff :samp:`(u, v)` is not a + self-loop, and at least one of :samp:`(u, v)` and :samp:`(v, u)` exists + in :samp:`G`. + + For each edge :samp:`(u, v)` in :samp:`R`, :samp:`R[u][v]['capacity']` + is equal to the capacity of :samp:`(u, v)` in :samp:`G` if it exists + in :samp:`G` or zero otherwise. If the capacity is infinite, + :samp:`R[u][v]['capacity']` will have a high arbitrary finite value + that does not affect the solution of the problem. This value is stored in + :samp:`R.graph['inf']`. For each edge :samp:`(u, v)` in :samp:`R`, + :samp:`R[u][v]['flow']` represents the flow function of :samp:`(u, v)` and + satisfies :samp:`R[u][v]['flow'] == -R[v][u]['flow']`. + + The flow value, defined as the total flow into :samp:`t`, the sink, is + stored in :samp:`R.graph['flow_value']`. If :samp:`cutoff` is not + specified, reachability to :samp:`t` using only edges :samp:`(u, v)` such + that :samp:`R[u][v]['flow'] < R[u][v]['capacity']` induces a minimum + :samp:`s`-:samp:`t` cut. + + Examples + -------- + >>> from networkx.algorithms.flow import edmonds_karp + + The functions that implement flow algorithms and output a residual + network, such as this one, are not imported to the base NetworkX + namespace, so you have to explicitly import them from the flow package. + + >>> G = nx.DiGraph() + >>> G.add_edge("x", "a", capacity=3.0) + >>> G.add_edge("x", "b", capacity=1.0) + >>> G.add_edge("a", "c", capacity=3.0) + >>> G.add_edge("b", "c", capacity=5.0) + >>> G.add_edge("b", "d", capacity=4.0) + >>> G.add_edge("d", "e", capacity=2.0) + >>> G.add_edge("c", "y", capacity=2.0) + >>> G.add_edge("e", "y", capacity=3.0) + >>> R = edmonds_karp(G, "x", "y") + >>> flow_value = nx.maximum_flow_value(G, "x", "y") + >>> flow_value + 3.0 + >>> flow_value == R.graph["flow_value"] + True + + """ + R = edmonds_karp_impl(G, s, t, capacity, residual, cutoff) + R.graph["algorithm"] = "edmonds_karp" + nx._clear_cache(R) + return R diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/gomory_hu.py b/lib/python3.10/site-packages/networkx/algorithms/flow/gomory_hu.py new file mode 100644 index 0000000000000000000000000000000000000000..69913da904547b3a9fe682467b69e696e9c8e0dc --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/flow/gomory_hu.py @@ -0,0 +1,178 @@ +""" +Gomory-Hu tree of undirected Graphs. +""" + +import networkx as nx +from networkx.utils import not_implemented_for + +from .edmondskarp import edmonds_karp +from .utils import build_residual_network + +default_flow_func = edmonds_karp + +__all__ = ["gomory_hu_tree"] + + +@not_implemented_for("directed") +@nx._dispatchable(edge_attrs={"capacity": float("inf")}, returns_graph=True) +def gomory_hu_tree(G, capacity="capacity", flow_func=None): + r"""Returns the Gomory-Hu tree of an undirected graph G. + + A Gomory-Hu tree of an undirected graph with capacities is a + weighted tree that represents the minimum s-t cuts for all s-t + pairs in the graph. + + It only requires `n-1` minimum cut computations instead of the + obvious `n(n-1)/2`. The tree represents all s-t cuts as the + minimum cut value among any pair of nodes is the minimum edge + weight in the shortest path between the two nodes in the + Gomory-Hu tree. + + The Gomory-Hu tree also has the property that removing the + edge with the minimum weight in the shortest path between + any two nodes leaves two connected components that form + a partition of the nodes in G that defines the minimum s-t + cut. + + See Examples section below for details. + + Parameters + ---------- + G : NetworkX graph + Undirected graph + + capacity : string + Edges of the graph G are expected to have an attribute capacity + that indicates how much flow the edge can support. If this + attribute is not present, the edge is considered to have + infinite capacity. Default value: 'capacity'. + + flow_func : function + Function to perform the underlying flow computations. Default value + :func:`edmonds_karp`. This function performs better in sparse graphs + with right tailed degree distributions. + :func:`shortest_augmenting_path` will perform better in denser + graphs. + + Returns + ------- + Tree : NetworkX graph + A NetworkX graph representing the Gomory-Hu tree of the input graph. + + Raises + ------ + NetworkXNotImplemented + Raised if the input graph is directed. + + NetworkXError + Raised if the input graph is an empty Graph. + + Examples + -------- + >>> G = nx.karate_club_graph() + >>> nx.set_edge_attributes(G, 1, "capacity") + >>> T = nx.gomory_hu_tree(G) + >>> # The value of the minimum cut between any pair + ... # of nodes in G is the minimum edge weight in the + ... # shortest path between the two nodes in the + ... # Gomory-Hu tree. + ... def minimum_edge_weight_in_shortest_path(T, u, v): + ... path = nx.shortest_path(T, u, v, weight="weight") + ... return min((T[u][v]["weight"], (u, v)) for (u, v) in zip(path, path[1:])) + >>> u, v = 0, 33 + >>> cut_value, edge = minimum_edge_weight_in_shortest_path(T, u, v) + >>> cut_value + 10 + >>> nx.minimum_cut_value(G, u, v) + 10 + >>> # The Gomory-Hu tree also has the property that removing the + ... # edge with the minimum weight in the shortest path between + ... # any two nodes leaves two connected components that form + ... # a partition of the nodes in G that defines the minimum s-t + ... # cut. + ... cut_value, edge = minimum_edge_weight_in_shortest_path(T, u, v) + >>> T.remove_edge(*edge) + >>> U, V = list(nx.connected_components(T)) + >>> # Thus U and V form a partition that defines a minimum cut + ... # between u and v in G. You can compute the edge cut set, + ... # that is, the set of edges that if removed from G will + ... # disconnect u from v in G, with this information: + ... cutset = set() + >>> for x, nbrs in ((n, G[n]) for n in U): + ... cutset.update((x, y) for y in nbrs if y in V) + >>> # Because we have set the capacities of all edges to 1 + ... # the cutset contains ten edges + ... len(cutset) + 10 + >>> # You can use any maximum flow algorithm for the underlying + ... # flow computations using the argument flow_func + ... from networkx.algorithms import flow + >>> T = nx.gomory_hu_tree(G, flow_func=flow.boykov_kolmogorov) + >>> cut_value, edge = minimum_edge_weight_in_shortest_path(T, u, v) + >>> cut_value + 10 + >>> nx.minimum_cut_value(G, u, v, flow_func=flow.boykov_kolmogorov) + 10 + + Notes + ----- + This implementation is based on Gusfield approach [1]_ to compute + Gomory-Hu trees, which does not require node contractions and has + the same computational complexity than the original method. + + See also + -------- + :func:`minimum_cut` + :func:`maximum_flow` + + References + ---------- + .. [1] Gusfield D: Very simple methods for all pairs network flow analysis. + SIAM J Comput 19(1):143-155, 1990. + + """ + if flow_func is None: + flow_func = default_flow_func + + if len(G) == 0: # empty graph + msg = "Empty Graph does not have a Gomory-Hu tree representation" + raise nx.NetworkXError(msg) + + # Start the tree as a star graph with an arbitrary node at the center + tree = {} + labels = {} + iter_nodes = iter(G) + root = next(iter_nodes) + for n in iter_nodes: + tree[n] = root + + # Reuse residual network + R = build_residual_network(G, capacity) + + # For all the leaves in the star graph tree (that is n-1 nodes). + for source in tree: + # Find neighbor in the tree + target = tree[source] + # compute minimum cut + cut_value, partition = nx.minimum_cut( + G, source, target, capacity=capacity, flow_func=flow_func, residual=R + ) + labels[(source, target)] = cut_value + # Update the tree + # Source will always be in partition[0] and target in partition[1] + for node in partition[0]: + if node != source and node in tree and tree[node] == target: + tree[node] = source + labels[node, source] = labels.get((node, target), cut_value) + # + if target != root and tree[target] in partition[0]: + labels[source, tree[target]] = labels[target, tree[target]] + labels[target, source] = cut_value + tree[source] = tree[target] + tree[target] = source + + # Build the tree + T = nx.Graph() + T.add_nodes_from(G) + T.add_weighted_edges_from(((u, v, labels[u, v]) for u, v in tree.items())) + return T diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/maxflow.py b/lib/python3.10/site-packages/networkx/algorithms/flow/maxflow.py new file mode 100644 index 0000000000000000000000000000000000000000..7993d87ba9ad8c3f3aa0639f82590f4c16f5f4b7 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/flow/maxflow.py @@ -0,0 +1,607 @@ +""" +Maximum flow (and minimum cut) algorithms on capacitated graphs. +""" + +import networkx as nx + +from .boykovkolmogorov import boykov_kolmogorov +from .dinitz_alg import dinitz +from .edmondskarp import edmonds_karp +from .preflowpush import preflow_push +from .shortestaugmentingpath import shortest_augmenting_path +from .utils import build_flow_dict + +# Define the default flow function for computing maximum flow. +default_flow_func = preflow_push + +__all__ = ["maximum_flow", "maximum_flow_value", "minimum_cut", "minimum_cut_value"] + + +@nx._dispatchable(graphs="flowG", edge_attrs={"capacity": float("inf")}) +def maximum_flow(flowG, _s, _t, capacity="capacity", flow_func=None, **kwargs): + """Find a maximum single-commodity flow. + + Parameters + ---------- + flowG : NetworkX graph + Edges of the graph are expected to have an attribute called + 'capacity'. If this attribute is not present, the edge is + considered to have infinite capacity. + + _s : node + Source node for the flow. + + _t : node + Sink node for the flow. + + capacity : string + Edges of the graph G are expected to have an attribute capacity + that indicates how much flow the edge can support. If this + attribute is not present, the edge is considered to have + infinite capacity. Default value: 'capacity'. + + flow_func : function + A function for computing the maximum flow among a pair of nodes + in a capacitated graph. The function has to accept at least three + parameters: a Graph or Digraph, a source node, and a target node. + And return a residual network that follows NetworkX conventions + (see Notes). If flow_func is None, the default maximum + flow function (:meth:`preflow_push`) is used. See below for + alternative algorithms. The choice of the default function may change + from version to version and should not be relied on. Default value: + None. + + kwargs : Any other keyword parameter is passed to the function that + computes the maximum flow. + + Returns + ------- + flow_value : integer, float + Value of the maximum flow, i.e., net outflow from the source. + + flow_dict : dict + A dictionary containing the value of the flow that went through + each edge. + + Raises + ------ + NetworkXError + The algorithm does not support MultiGraph and MultiDiGraph. If + the input graph is an instance of one of these two classes, a + NetworkXError is raised. + + NetworkXUnbounded + If the graph has a path of infinite capacity, the value of a + feasible flow on the graph is unbounded above and the function + raises a NetworkXUnbounded. + + See also + -------- + :meth:`maximum_flow_value` + :meth:`minimum_cut` + :meth:`minimum_cut_value` + :meth:`edmonds_karp` + :meth:`preflow_push` + :meth:`shortest_augmenting_path` + + Notes + ----- + The function used in the flow_func parameter has to return a residual + network that follows NetworkX conventions: + + The residual network :samp:`R` from an input graph :samp:`G` has the + same nodes as :samp:`G`. :samp:`R` is a DiGraph that contains a pair + of edges :samp:`(u, v)` and :samp:`(v, u)` iff :samp:`(u, v)` is not a + self-loop, and at least one of :samp:`(u, v)` and :samp:`(v, u)` exists + in :samp:`G`. + + For each edge :samp:`(u, v)` in :samp:`R`, :samp:`R[u][v]['capacity']` + is equal to the capacity of :samp:`(u, v)` in :samp:`G` if it exists + in :samp:`G` or zero otherwise. If the capacity is infinite, + :samp:`R[u][v]['capacity']` will have a high arbitrary finite value + that does not affect the solution of the problem. This value is stored in + :samp:`R.graph['inf']`. For each edge :samp:`(u, v)` in :samp:`R`, + :samp:`R[u][v]['flow']` represents the flow function of :samp:`(u, v)` and + satisfies :samp:`R[u][v]['flow'] == -R[v][u]['flow']`. + + The flow value, defined as the total flow into :samp:`t`, the sink, is + stored in :samp:`R.graph['flow_value']`. Reachability to :samp:`t` using + only edges :samp:`(u, v)` such that + :samp:`R[u][v]['flow'] < R[u][v]['capacity']` induces a minimum + :samp:`s`-:samp:`t` cut. + + Specific algorithms may store extra data in :samp:`R`. + + The function should supports an optional boolean parameter value_only. When + True, it can optionally terminate the algorithm as soon as the maximum flow + value and the minimum cut can be determined. + + Examples + -------- + >>> G = nx.DiGraph() + >>> G.add_edge("x", "a", capacity=3.0) + >>> G.add_edge("x", "b", capacity=1.0) + >>> G.add_edge("a", "c", capacity=3.0) + >>> G.add_edge("b", "c", capacity=5.0) + >>> G.add_edge("b", "d", capacity=4.0) + >>> G.add_edge("d", "e", capacity=2.0) + >>> G.add_edge("c", "y", capacity=2.0) + >>> G.add_edge("e", "y", capacity=3.0) + + maximum_flow returns both the value of the maximum flow and a + dictionary with all flows. + + >>> flow_value, flow_dict = nx.maximum_flow(G, "x", "y") + >>> flow_value + 3.0 + >>> print(flow_dict["x"]["b"]) + 1.0 + + You can also use alternative algorithms for computing the + maximum flow by using the flow_func parameter. + + >>> from networkx.algorithms.flow import shortest_augmenting_path + >>> flow_value == nx.maximum_flow(G, "x", "y", flow_func=shortest_augmenting_path)[ + ... 0 + ... ] + True + + """ + if flow_func is None: + if kwargs: + raise nx.NetworkXError( + "You have to explicitly set a flow_func if" + " you need to pass parameters via kwargs." + ) + flow_func = default_flow_func + + if not callable(flow_func): + raise nx.NetworkXError("flow_func has to be callable.") + + R = flow_func(flowG, _s, _t, capacity=capacity, value_only=False, **kwargs) + flow_dict = build_flow_dict(flowG, R) + + return (R.graph["flow_value"], flow_dict) + + +@nx._dispatchable(graphs="flowG", edge_attrs={"capacity": float("inf")}) +def maximum_flow_value(flowG, _s, _t, capacity="capacity", flow_func=None, **kwargs): + """Find the value of maximum single-commodity flow. + + Parameters + ---------- + flowG : NetworkX graph + Edges of the graph are expected to have an attribute called + 'capacity'. If this attribute is not present, the edge is + considered to have infinite capacity. + + _s : node + Source node for the flow. + + _t : node + Sink node for the flow. + + capacity : string + Edges of the graph G are expected to have an attribute capacity + that indicates how much flow the edge can support. If this + attribute is not present, the edge is considered to have + infinite capacity. Default value: 'capacity'. + + flow_func : function + A function for computing the maximum flow among a pair of nodes + in a capacitated graph. The function has to accept at least three + parameters: a Graph or Digraph, a source node, and a target node. + And return a residual network that follows NetworkX conventions + (see Notes). If flow_func is None, the default maximum + flow function (:meth:`preflow_push`) is used. See below for + alternative algorithms. The choice of the default function may change + from version to version and should not be relied on. Default value: + None. + + kwargs : Any other keyword parameter is passed to the function that + computes the maximum flow. + + Returns + ------- + flow_value : integer, float + Value of the maximum flow, i.e., net outflow from the source. + + Raises + ------ + NetworkXError + The algorithm does not support MultiGraph and MultiDiGraph. If + the input graph is an instance of one of these two classes, a + NetworkXError is raised. + + NetworkXUnbounded + If the graph has a path of infinite capacity, the value of a + feasible flow on the graph is unbounded above and the function + raises a NetworkXUnbounded. + + See also + -------- + :meth:`maximum_flow` + :meth:`minimum_cut` + :meth:`minimum_cut_value` + :meth:`edmonds_karp` + :meth:`preflow_push` + :meth:`shortest_augmenting_path` + + Notes + ----- + The function used in the flow_func parameter has to return a residual + network that follows NetworkX conventions: + + The residual network :samp:`R` from an input graph :samp:`G` has the + same nodes as :samp:`G`. :samp:`R` is a DiGraph that contains a pair + of edges :samp:`(u, v)` and :samp:`(v, u)` iff :samp:`(u, v)` is not a + self-loop, and at least one of :samp:`(u, v)` and :samp:`(v, u)` exists + in :samp:`G`. + + For each edge :samp:`(u, v)` in :samp:`R`, :samp:`R[u][v]['capacity']` + is equal to the capacity of :samp:`(u, v)` in :samp:`G` if it exists + in :samp:`G` or zero otherwise. If the capacity is infinite, + :samp:`R[u][v]['capacity']` will have a high arbitrary finite value + that does not affect the solution of the problem. This value is stored in + :samp:`R.graph['inf']`. For each edge :samp:`(u, v)` in :samp:`R`, + :samp:`R[u][v]['flow']` represents the flow function of :samp:`(u, v)` and + satisfies :samp:`R[u][v]['flow'] == -R[v][u]['flow']`. + + The flow value, defined as the total flow into :samp:`t`, the sink, is + stored in :samp:`R.graph['flow_value']`. Reachability to :samp:`t` using + only edges :samp:`(u, v)` such that + :samp:`R[u][v]['flow'] < R[u][v]['capacity']` induces a minimum + :samp:`s`-:samp:`t` cut. + + Specific algorithms may store extra data in :samp:`R`. + + The function should supports an optional boolean parameter value_only. When + True, it can optionally terminate the algorithm as soon as the maximum flow + value and the minimum cut can be determined. + + Examples + -------- + >>> G = nx.DiGraph() + >>> G.add_edge("x", "a", capacity=3.0) + >>> G.add_edge("x", "b", capacity=1.0) + >>> G.add_edge("a", "c", capacity=3.0) + >>> G.add_edge("b", "c", capacity=5.0) + >>> G.add_edge("b", "d", capacity=4.0) + >>> G.add_edge("d", "e", capacity=2.0) + >>> G.add_edge("c", "y", capacity=2.0) + >>> G.add_edge("e", "y", capacity=3.0) + + maximum_flow_value computes only the value of the + maximum flow: + + >>> flow_value = nx.maximum_flow_value(G, "x", "y") + >>> flow_value + 3.0 + + You can also use alternative algorithms for computing the + maximum flow by using the flow_func parameter. + + >>> from networkx.algorithms.flow import shortest_augmenting_path + >>> flow_value == nx.maximum_flow_value( + ... G, "x", "y", flow_func=shortest_augmenting_path + ... ) + True + + """ + if flow_func is None: + if kwargs: + raise nx.NetworkXError( + "You have to explicitly set a flow_func if" + " you need to pass parameters via kwargs." + ) + flow_func = default_flow_func + + if not callable(flow_func): + raise nx.NetworkXError("flow_func has to be callable.") + + R = flow_func(flowG, _s, _t, capacity=capacity, value_only=True, **kwargs) + + return R.graph["flow_value"] + + +@nx._dispatchable(graphs="flowG", edge_attrs={"capacity": float("inf")}) +def minimum_cut(flowG, _s, _t, capacity="capacity", flow_func=None, **kwargs): + """Compute the value and the node partition of a minimum (s, t)-cut. + + Use the max-flow min-cut theorem, i.e., the capacity of a minimum + capacity cut is equal to the flow value of a maximum flow. + + Parameters + ---------- + flowG : NetworkX graph + Edges of the graph are expected to have an attribute called + 'capacity'. If this attribute is not present, the edge is + considered to have infinite capacity. + + _s : node + Source node for the flow. + + _t : node + Sink node for the flow. + + capacity : string + Edges of the graph G are expected to have an attribute capacity + that indicates how much flow the edge can support. If this + attribute is not present, the edge is considered to have + infinite capacity. Default value: 'capacity'. + + flow_func : function + A function for computing the maximum flow among a pair of nodes + in a capacitated graph. The function has to accept at least three + parameters: a Graph or Digraph, a source node, and a target node. + And return a residual network that follows NetworkX conventions + (see Notes). If flow_func is None, the default maximum + flow function (:meth:`preflow_push`) is used. See below for + alternative algorithms. The choice of the default function may change + from version to version and should not be relied on. Default value: + None. + + kwargs : Any other keyword parameter is passed to the function that + computes the maximum flow. + + Returns + ------- + cut_value : integer, float + Value of the minimum cut. + + partition : pair of node sets + A partitioning of the nodes that defines a minimum cut. + + Raises + ------ + NetworkXUnbounded + If the graph has a path of infinite capacity, all cuts have + infinite capacity and the function raises a NetworkXError. + + See also + -------- + :meth:`maximum_flow` + :meth:`maximum_flow_value` + :meth:`minimum_cut_value` + :meth:`edmonds_karp` + :meth:`preflow_push` + :meth:`shortest_augmenting_path` + + Notes + ----- + The function used in the flow_func parameter has to return a residual + network that follows NetworkX conventions: + + The residual network :samp:`R` from an input graph :samp:`G` has the + same nodes as :samp:`G`. :samp:`R` is a DiGraph that contains a pair + of edges :samp:`(u, v)` and :samp:`(v, u)` iff :samp:`(u, v)` is not a + self-loop, and at least one of :samp:`(u, v)` and :samp:`(v, u)` exists + in :samp:`G`. + + For each edge :samp:`(u, v)` in :samp:`R`, :samp:`R[u][v]['capacity']` + is equal to the capacity of :samp:`(u, v)` in :samp:`G` if it exists + in :samp:`G` or zero otherwise. If the capacity is infinite, + :samp:`R[u][v]['capacity']` will have a high arbitrary finite value + that does not affect the solution of the problem. This value is stored in + :samp:`R.graph['inf']`. For each edge :samp:`(u, v)` in :samp:`R`, + :samp:`R[u][v]['flow']` represents the flow function of :samp:`(u, v)` and + satisfies :samp:`R[u][v]['flow'] == -R[v][u]['flow']`. + + The flow value, defined as the total flow into :samp:`t`, the sink, is + stored in :samp:`R.graph['flow_value']`. Reachability to :samp:`t` using + only edges :samp:`(u, v)` such that + :samp:`R[u][v]['flow'] < R[u][v]['capacity']` induces a minimum + :samp:`s`-:samp:`t` cut. + + Specific algorithms may store extra data in :samp:`R`. + + The function should supports an optional boolean parameter value_only. When + True, it can optionally terminate the algorithm as soon as the maximum flow + value and the minimum cut can be determined. + + Examples + -------- + >>> G = nx.DiGraph() + >>> G.add_edge("x", "a", capacity=3.0) + >>> G.add_edge("x", "b", capacity=1.0) + >>> G.add_edge("a", "c", capacity=3.0) + >>> G.add_edge("b", "c", capacity=5.0) + >>> G.add_edge("b", "d", capacity=4.0) + >>> G.add_edge("d", "e", capacity=2.0) + >>> G.add_edge("c", "y", capacity=2.0) + >>> G.add_edge("e", "y", capacity=3.0) + + minimum_cut computes both the value of the + minimum cut and the node partition: + + >>> cut_value, partition = nx.minimum_cut(G, "x", "y") + >>> reachable, non_reachable = partition + + 'partition' here is a tuple with the two sets of nodes that define + the minimum cut. You can compute the cut set of edges that induce + the minimum cut as follows: + + >>> cutset = set() + >>> for u, nbrs in ((n, G[n]) for n in reachable): + ... cutset.update((u, v) for v in nbrs if v in non_reachable) + >>> print(sorted(cutset)) + [('c', 'y'), ('x', 'b')] + >>> cut_value == sum(G.edges[u, v]["capacity"] for (u, v) in cutset) + True + + You can also use alternative algorithms for computing the + minimum cut by using the flow_func parameter. + + >>> from networkx.algorithms.flow import shortest_augmenting_path + >>> cut_value == nx.minimum_cut(G, "x", "y", flow_func=shortest_augmenting_path)[0] + True + + """ + if flow_func is None: + if kwargs: + raise nx.NetworkXError( + "You have to explicitly set a flow_func if" + " you need to pass parameters via kwargs." + ) + flow_func = default_flow_func + + if not callable(flow_func): + raise nx.NetworkXError("flow_func has to be callable.") + + if kwargs.get("cutoff") is not None and flow_func is preflow_push: + raise nx.NetworkXError("cutoff should not be specified.") + + R = flow_func(flowG, _s, _t, capacity=capacity, value_only=True, **kwargs) + # Remove saturated edges from the residual network + cutset = [(u, v, d) for u, v, d in R.edges(data=True) if d["flow"] == d["capacity"]] + R.remove_edges_from(cutset) + + # Then, reachable and non reachable nodes from source in the + # residual network form the node partition that defines + # the minimum cut. + non_reachable = set(dict(nx.shortest_path_length(R, target=_t))) + partition = (set(flowG) - non_reachable, non_reachable) + # Finally add again cutset edges to the residual network to make + # sure that it is reusable. + R.add_edges_from(cutset) + return (R.graph["flow_value"], partition) + + +@nx._dispatchable(graphs="flowG", edge_attrs={"capacity": float("inf")}) +def minimum_cut_value(flowG, _s, _t, capacity="capacity", flow_func=None, **kwargs): + """Compute the value of a minimum (s, t)-cut. + + Use the max-flow min-cut theorem, i.e., the capacity of a minimum + capacity cut is equal to the flow value of a maximum flow. + + Parameters + ---------- + flowG : NetworkX graph + Edges of the graph are expected to have an attribute called + 'capacity'. If this attribute is not present, the edge is + considered to have infinite capacity. + + _s : node + Source node for the flow. + + _t : node + Sink node for the flow. + + capacity : string + Edges of the graph G are expected to have an attribute capacity + that indicates how much flow the edge can support. If this + attribute is not present, the edge is considered to have + infinite capacity. Default value: 'capacity'. + + flow_func : function + A function for computing the maximum flow among a pair of nodes + in a capacitated graph. The function has to accept at least three + parameters: a Graph or Digraph, a source node, and a target node. + And return a residual network that follows NetworkX conventions + (see Notes). If flow_func is None, the default maximum + flow function (:meth:`preflow_push`) is used. See below for + alternative algorithms. The choice of the default function may change + from version to version and should not be relied on. Default value: + None. + + kwargs : Any other keyword parameter is passed to the function that + computes the maximum flow. + + Returns + ------- + cut_value : integer, float + Value of the minimum cut. + + Raises + ------ + NetworkXUnbounded + If the graph has a path of infinite capacity, all cuts have + infinite capacity and the function raises a NetworkXError. + + See also + -------- + :meth:`maximum_flow` + :meth:`maximum_flow_value` + :meth:`minimum_cut` + :meth:`edmonds_karp` + :meth:`preflow_push` + :meth:`shortest_augmenting_path` + + Notes + ----- + The function used in the flow_func parameter has to return a residual + network that follows NetworkX conventions: + + The residual network :samp:`R` from an input graph :samp:`G` has the + same nodes as :samp:`G`. :samp:`R` is a DiGraph that contains a pair + of edges :samp:`(u, v)` and :samp:`(v, u)` iff :samp:`(u, v)` is not a + self-loop, and at least one of :samp:`(u, v)` and :samp:`(v, u)` exists + in :samp:`G`. + + For each edge :samp:`(u, v)` in :samp:`R`, :samp:`R[u][v]['capacity']` + is equal to the capacity of :samp:`(u, v)` in :samp:`G` if it exists + in :samp:`G` or zero otherwise. If the capacity is infinite, + :samp:`R[u][v]['capacity']` will have a high arbitrary finite value + that does not affect the solution of the problem. This value is stored in + :samp:`R.graph['inf']`. For each edge :samp:`(u, v)` in :samp:`R`, + :samp:`R[u][v]['flow']` represents the flow function of :samp:`(u, v)` and + satisfies :samp:`R[u][v]['flow'] == -R[v][u]['flow']`. + + The flow value, defined as the total flow into :samp:`t`, the sink, is + stored in :samp:`R.graph['flow_value']`. Reachability to :samp:`t` using + only edges :samp:`(u, v)` such that + :samp:`R[u][v]['flow'] < R[u][v]['capacity']` induces a minimum + :samp:`s`-:samp:`t` cut. + + Specific algorithms may store extra data in :samp:`R`. + + The function should supports an optional boolean parameter value_only. When + True, it can optionally terminate the algorithm as soon as the maximum flow + value and the minimum cut can be determined. + + Examples + -------- + >>> G = nx.DiGraph() + >>> G.add_edge("x", "a", capacity=3.0) + >>> G.add_edge("x", "b", capacity=1.0) + >>> G.add_edge("a", "c", capacity=3.0) + >>> G.add_edge("b", "c", capacity=5.0) + >>> G.add_edge("b", "d", capacity=4.0) + >>> G.add_edge("d", "e", capacity=2.0) + >>> G.add_edge("c", "y", capacity=2.0) + >>> G.add_edge("e", "y", capacity=3.0) + + minimum_cut_value computes only the value of the + minimum cut: + + >>> cut_value = nx.minimum_cut_value(G, "x", "y") + >>> cut_value + 3.0 + + You can also use alternative algorithms for computing the + minimum cut by using the flow_func parameter. + + >>> from networkx.algorithms.flow import shortest_augmenting_path + >>> cut_value == nx.minimum_cut_value( + ... G, "x", "y", flow_func=shortest_augmenting_path + ... ) + True + + """ + if flow_func is None: + if kwargs: + raise nx.NetworkXError( + "You have to explicitly set a flow_func if" + " you need to pass parameters via kwargs." + ) + flow_func = default_flow_func + + if not callable(flow_func): + raise nx.NetworkXError("flow_func has to be callable.") + + if kwargs.get("cutoff") is not None and flow_func is preflow_push: + raise nx.NetworkXError("cutoff should not be specified.") + + R = flow_func(flowG, _s, _t, capacity=capacity, value_only=True, **kwargs) + + return R.graph["flow_value"] diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/mincost.py b/lib/python3.10/site-packages/networkx/algorithms/flow/mincost.py new file mode 100644 index 0000000000000000000000000000000000000000..2f9390d7a1c1e454ed7c2f8793d591b338115107 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/flow/mincost.py @@ -0,0 +1,356 @@ +""" +Minimum cost flow algorithms on directed connected graphs. +""" + +__all__ = ["min_cost_flow_cost", "min_cost_flow", "cost_of_flow", "max_flow_min_cost"] + +import networkx as nx + + +@nx._dispatchable( + node_attrs="demand", edge_attrs={"capacity": float("inf"), "weight": 0} +) +def min_cost_flow_cost(G, demand="demand", capacity="capacity", weight="weight"): + r"""Find the cost of a minimum cost flow satisfying all demands in digraph G. + + G is a digraph with edge costs and capacities and in which nodes + have demand, i.e., they want to send or receive some amount of + flow. A negative demand means that the node wants to send flow, a + positive demand means that the node want to receive flow. A flow on + the digraph G satisfies all demand if the net flow into each node + is equal to the demand of that node. + + Parameters + ---------- + G : NetworkX graph + DiGraph on which a minimum cost flow satisfying all demands is + to be found. + + demand : string + Nodes of the graph G are expected to have an attribute demand + that indicates how much flow a node wants to send (negative + demand) or receive (positive demand). Note that the sum of the + demands should be 0 otherwise the problem in not feasible. If + this attribute is not present, a node is considered to have 0 + demand. Default value: 'demand'. + + capacity : string + Edges of the graph G are expected to have an attribute capacity + that indicates how much flow the edge can support. If this + attribute is not present, the edge is considered to have + infinite capacity. Default value: 'capacity'. + + weight : string + Edges of the graph G are expected to have an attribute weight + that indicates the cost incurred by sending one unit of flow on + that edge. If not present, the weight is considered to be 0. + Default value: 'weight'. + + Returns + ------- + flowCost : integer, float + Cost of a minimum cost flow satisfying all demands. + + Raises + ------ + NetworkXError + This exception is raised if the input graph is not directed or + not connected. + + NetworkXUnfeasible + This exception is raised in the following situations: + + * The sum of the demands is not zero. Then, there is no + flow satisfying all demands. + * There is no flow satisfying all demand. + + NetworkXUnbounded + This exception is raised if the digraph G has a cycle of + negative cost and infinite capacity. Then, the cost of a flow + satisfying all demands is unbounded below. + + See also + -------- + cost_of_flow, max_flow_min_cost, min_cost_flow, network_simplex + + Notes + ----- + This algorithm is not guaranteed to work if edge weights or demands + are floating point numbers (overflows and roundoff errors can + cause problems). As a workaround you can use integer numbers by + multiplying the relevant edge attributes by a convenient + constant factor (eg 100). + + Examples + -------- + A simple example of a min cost flow problem. + + >>> G = nx.DiGraph() + >>> G.add_node("a", demand=-5) + >>> G.add_node("d", demand=5) + >>> G.add_edge("a", "b", weight=3, capacity=4) + >>> G.add_edge("a", "c", weight=6, capacity=10) + >>> G.add_edge("b", "d", weight=1, capacity=9) + >>> G.add_edge("c", "d", weight=2, capacity=5) + >>> flowCost = nx.min_cost_flow_cost(G) + >>> flowCost + 24 + """ + return nx.network_simplex(G, demand=demand, capacity=capacity, weight=weight)[0] + + +@nx._dispatchable( + node_attrs="demand", edge_attrs={"capacity": float("inf"), "weight": 0} +) +def min_cost_flow(G, demand="demand", capacity="capacity", weight="weight"): + r"""Returns a minimum cost flow satisfying all demands in digraph G. + + G is a digraph with edge costs and capacities and in which nodes + have demand, i.e., they want to send or receive some amount of + flow. A negative demand means that the node wants to send flow, a + positive demand means that the node want to receive flow. A flow on + the digraph G satisfies all demand if the net flow into each node + is equal to the demand of that node. + + Parameters + ---------- + G : NetworkX graph + DiGraph on which a minimum cost flow satisfying all demands is + to be found. + + demand : string + Nodes of the graph G are expected to have an attribute demand + that indicates how much flow a node wants to send (negative + demand) or receive (positive demand). Note that the sum of the + demands should be 0 otherwise the problem in not feasible. If + this attribute is not present, a node is considered to have 0 + demand. Default value: 'demand'. + + capacity : string + Edges of the graph G are expected to have an attribute capacity + that indicates how much flow the edge can support. If this + attribute is not present, the edge is considered to have + infinite capacity. Default value: 'capacity'. + + weight : string + Edges of the graph G are expected to have an attribute weight + that indicates the cost incurred by sending one unit of flow on + that edge. If not present, the weight is considered to be 0. + Default value: 'weight'. + + Returns + ------- + flowDict : dictionary + Dictionary of dictionaries keyed by nodes such that + flowDict[u][v] is the flow edge (u, v). + + Raises + ------ + NetworkXError + This exception is raised if the input graph is not directed or + not connected. + + NetworkXUnfeasible + This exception is raised in the following situations: + + * The sum of the demands is not zero. Then, there is no + flow satisfying all demands. + * There is no flow satisfying all demand. + + NetworkXUnbounded + This exception is raised if the digraph G has a cycle of + negative cost and infinite capacity. Then, the cost of a flow + satisfying all demands is unbounded below. + + See also + -------- + cost_of_flow, max_flow_min_cost, min_cost_flow_cost, network_simplex + + Notes + ----- + This algorithm is not guaranteed to work if edge weights or demands + are floating point numbers (overflows and roundoff errors can + cause problems). As a workaround you can use integer numbers by + multiplying the relevant edge attributes by a convenient + constant factor (eg 100). + + Examples + -------- + A simple example of a min cost flow problem. + + >>> G = nx.DiGraph() + >>> G.add_node("a", demand=-5) + >>> G.add_node("d", demand=5) + >>> G.add_edge("a", "b", weight=3, capacity=4) + >>> G.add_edge("a", "c", weight=6, capacity=10) + >>> G.add_edge("b", "d", weight=1, capacity=9) + >>> G.add_edge("c", "d", weight=2, capacity=5) + >>> flowDict = nx.min_cost_flow(G) + >>> flowDict + {'a': {'b': 4, 'c': 1}, 'd': {}, 'b': {'d': 4}, 'c': {'d': 1}} + """ + return nx.network_simplex(G, demand=demand, capacity=capacity, weight=weight)[1] + + +@nx._dispatchable(edge_attrs={"weight": 0}) +def cost_of_flow(G, flowDict, weight="weight"): + """Compute the cost of the flow given by flowDict on graph G. + + Note that this function does not check for the validity of the + flow flowDict. This function will fail if the graph G and the + flow don't have the same edge set. + + Parameters + ---------- + G : NetworkX graph + DiGraph on which a minimum cost flow satisfying all demands is + to be found. + + weight : string + Edges of the graph G are expected to have an attribute weight + that indicates the cost incurred by sending one unit of flow on + that edge. If not present, the weight is considered to be 0. + Default value: 'weight'. + + flowDict : dictionary + Dictionary of dictionaries keyed by nodes such that + flowDict[u][v] is the flow edge (u, v). + + Returns + ------- + cost : Integer, float + The total cost of the flow. This is given by the sum over all + edges of the product of the edge's flow and the edge's weight. + + See also + -------- + max_flow_min_cost, min_cost_flow, min_cost_flow_cost, network_simplex + + Notes + ----- + This algorithm is not guaranteed to work if edge weights or demands + are floating point numbers (overflows and roundoff errors can + cause problems). As a workaround you can use integer numbers by + multiplying the relevant edge attributes by a convenient + constant factor (eg 100). + + Examples + -------- + >>> G = nx.DiGraph() + >>> G.add_node("a", demand=-5) + >>> G.add_node("d", demand=5) + >>> G.add_edge("a", "b", weight=3, capacity=4) + >>> G.add_edge("a", "c", weight=6, capacity=10) + >>> G.add_edge("b", "d", weight=1, capacity=9) + >>> G.add_edge("c", "d", weight=2, capacity=5) + >>> flowDict = nx.min_cost_flow(G) + >>> flowDict + {'a': {'b': 4, 'c': 1}, 'd': {}, 'b': {'d': 4}, 'c': {'d': 1}} + >>> nx.cost_of_flow(G, flowDict) + 24 + """ + return sum((flowDict[u][v] * d.get(weight, 0) for u, v, d in G.edges(data=True))) + + +@nx._dispatchable(edge_attrs={"capacity": float("inf"), "weight": 0}) +def max_flow_min_cost(G, s, t, capacity="capacity", weight="weight"): + """Returns a maximum (s, t)-flow of minimum cost. + + G is a digraph with edge costs and capacities. There is a source + node s and a sink node t. This function finds a maximum flow from + s to t whose total cost is minimized. + + Parameters + ---------- + G : NetworkX graph + DiGraph on which a minimum cost flow satisfying all demands is + to be found. + + s: node label + Source of the flow. + + t: node label + Destination of the flow. + + capacity: string + Edges of the graph G are expected to have an attribute capacity + that indicates how much flow the edge can support. If this + attribute is not present, the edge is considered to have + infinite capacity. Default value: 'capacity'. + + weight: string + Edges of the graph G are expected to have an attribute weight + that indicates the cost incurred by sending one unit of flow on + that edge. If not present, the weight is considered to be 0. + Default value: 'weight'. + + Returns + ------- + flowDict: dictionary + Dictionary of dictionaries keyed by nodes such that + flowDict[u][v] is the flow edge (u, v). + + Raises + ------ + NetworkXError + This exception is raised if the input graph is not directed or + not connected. + + NetworkXUnbounded + This exception is raised if there is an infinite capacity path + from s to t in G. In this case there is no maximum flow. This + exception is also raised if the digraph G has a cycle of + negative cost and infinite capacity. Then, the cost of a flow + is unbounded below. + + See also + -------- + cost_of_flow, min_cost_flow, min_cost_flow_cost, network_simplex + + Notes + ----- + This algorithm is not guaranteed to work if edge weights or demands + are floating point numbers (overflows and roundoff errors can + cause problems). As a workaround you can use integer numbers by + multiplying the relevant edge attributes by a convenient + constant factor (eg 100). + + Examples + -------- + >>> G = nx.DiGraph() + >>> G.add_edges_from( + ... [ + ... (1, 2, {"capacity": 12, "weight": 4}), + ... (1, 3, {"capacity": 20, "weight": 6}), + ... (2, 3, {"capacity": 6, "weight": -3}), + ... (2, 6, {"capacity": 14, "weight": 1}), + ... (3, 4, {"weight": 9}), + ... (3, 5, {"capacity": 10, "weight": 5}), + ... (4, 2, {"capacity": 19, "weight": 13}), + ... (4, 5, {"capacity": 4, "weight": 0}), + ... (5, 7, {"capacity": 28, "weight": 2}), + ... (6, 5, {"capacity": 11, "weight": 1}), + ... (6, 7, {"weight": 8}), + ... (7, 4, {"capacity": 6, "weight": 6}), + ... ] + ... ) + >>> mincostFlow = nx.max_flow_min_cost(G, 1, 7) + >>> mincost = nx.cost_of_flow(G, mincostFlow) + >>> mincost + 373 + >>> from networkx.algorithms.flow import maximum_flow + >>> maxFlow = maximum_flow(G, 1, 7)[1] + >>> nx.cost_of_flow(G, maxFlow) >= mincost + True + >>> mincostFlowValue = sum((mincostFlow[u][7] for u in G.predecessors(7))) - sum( + ... (mincostFlow[7][v] for v in G.successors(7)) + ... ) + >>> mincostFlowValue == nx.maximum_flow_value(G, 1, 7) + True + + """ + maxFlow = nx.maximum_flow_value(G, s, t, capacity=capacity) + H = nx.DiGraph(G) + H.add_node(s, demand=-maxFlow) + H.add_node(t, demand=maxFlow) + return min_cost_flow(H, capacity=capacity, weight=weight) diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/networksimplex.py b/lib/python3.10/site-packages/networkx/algorithms/flow/networksimplex.py new file mode 100644 index 0000000000000000000000000000000000000000..a9822d968808eb0c7bb45794e13150ad659b311a --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/flow/networksimplex.py @@ -0,0 +1,666 @@ +""" +Minimum cost flow algorithms on directed connected graphs. +""" + +__all__ = ["network_simplex"] + +from itertools import chain, islice, repeat +from math import ceil, sqrt + +import networkx as nx +from networkx.utils import not_implemented_for + + +class _DataEssentialsAndFunctions: + def __init__( + self, G, multigraph, demand="demand", capacity="capacity", weight="weight" + ): + # Number all nodes and edges and hereafter reference them using ONLY their numbers + self.node_list = list(G) # nodes + self.node_indices = {u: i for i, u in enumerate(self.node_list)} # node indices + self.node_demands = [ + G.nodes[u].get(demand, 0) for u in self.node_list + ] # node demands + + self.edge_sources = [] # edge sources + self.edge_targets = [] # edge targets + if multigraph: + self.edge_keys = [] # edge keys + self.edge_indices = {} # edge indices + self.edge_capacities = [] # edge capacities + self.edge_weights = [] # edge weights + + if not multigraph: + edges = G.edges(data=True) + else: + edges = G.edges(data=True, keys=True) + + inf = float("inf") + edges = (e for e in edges if e[0] != e[1] and e[-1].get(capacity, inf) != 0) + for i, e in enumerate(edges): + self.edge_sources.append(self.node_indices[e[0]]) + self.edge_targets.append(self.node_indices[e[1]]) + if multigraph: + self.edge_keys.append(e[2]) + self.edge_indices[e[:-1]] = i + self.edge_capacities.append(e[-1].get(capacity, inf)) + self.edge_weights.append(e[-1].get(weight, 0)) + + # spanning tree specific data to be initialized + + self.edge_count = None # number of edges + self.edge_flow = None # edge flows + self.node_potentials = None # node potentials + self.parent = None # parent nodes + self.parent_edge = None # edges to parents + self.subtree_size = None # subtree sizes + self.next_node_dft = None # next nodes in depth-first thread + self.prev_node_dft = None # previous nodes in depth-first thread + self.last_descendent_dft = None # last descendants in depth-first thread + self._spanning_tree_initialized = ( + False # False until initialize_spanning_tree() is called + ) + + def initialize_spanning_tree(self, n, faux_inf): + self.edge_count = len(self.edge_indices) # number of edges + self.edge_flow = list( + chain(repeat(0, self.edge_count), (abs(d) for d in self.node_demands)) + ) # edge flows + self.node_potentials = [ + faux_inf if d <= 0 else -faux_inf for d in self.node_demands + ] # node potentials + self.parent = list(chain(repeat(-1, n), [None])) # parent nodes + self.parent_edge = list( + range(self.edge_count, self.edge_count + n) + ) # edges to parents + self.subtree_size = list(chain(repeat(1, n), [n + 1])) # subtree sizes + self.next_node_dft = list( + chain(range(1, n), [-1, 0]) + ) # next nodes in depth-first thread + self.prev_node_dft = list(range(-1, n)) # previous nodes in depth-first thread + self.last_descendent_dft = list( + chain(range(n), [n - 1]) + ) # last descendants in depth-first thread + self._spanning_tree_initialized = True # True only if all the assignments pass + + def find_apex(self, p, q): + """ + Find the lowest common ancestor of nodes p and q in the spanning tree. + """ + size_p = self.subtree_size[p] + size_q = self.subtree_size[q] + while True: + while size_p < size_q: + p = self.parent[p] + size_p = self.subtree_size[p] + while size_p > size_q: + q = self.parent[q] + size_q = self.subtree_size[q] + if size_p == size_q: + if p != q: + p = self.parent[p] + size_p = self.subtree_size[p] + q = self.parent[q] + size_q = self.subtree_size[q] + else: + return p + + def trace_path(self, p, w): + """ + Returns the nodes and edges on the path from node p to its ancestor w. + """ + Wn = [p] + We = [] + while p != w: + We.append(self.parent_edge[p]) + p = self.parent[p] + Wn.append(p) + return Wn, We + + def find_cycle(self, i, p, q): + """ + Returns the nodes and edges on the cycle containing edge i == (p, q) + when the latter is added to the spanning tree. + + The cycle is oriented in the direction from p to q. + """ + w = self.find_apex(p, q) + Wn, We = self.trace_path(p, w) + Wn.reverse() + We.reverse() + if We != [i]: + We.append(i) + WnR, WeR = self.trace_path(q, w) + del WnR[-1] + Wn += WnR + We += WeR + return Wn, We + + def augment_flow(self, Wn, We, f): + """ + Augment f units of flow along a cycle represented by Wn and We. + """ + for i, p in zip(We, Wn): + if self.edge_sources[i] == p: + self.edge_flow[i] += f + else: + self.edge_flow[i] -= f + + def trace_subtree(self, p): + """ + Yield the nodes in the subtree rooted at a node p. + """ + yield p + l = self.last_descendent_dft[p] + while p != l: + p = self.next_node_dft[p] + yield p + + def remove_edge(self, s, t): + """ + Remove an edge (s, t) where parent[t] == s from the spanning tree. + """ + size_t = self.subtree_size[t] + prev_t = self.prev_node_dft[t] + last_t = self.last_descendent_dft[t] + next_last_t = self.next_node_dft[last_t] + # Remove (s, t). + self.parent[t] = None + self.parent_edge[t] = None + # Remove the subtree rooted at t from the depth-first thread. + self.next_node_dft[prev_t] = next_last_t + self.prev_node_dft[next_last_t] = prev_t + self.next_node_dft[last_t] = t + self.prev_node_dft[t] = last_t + # Update the subtree sizes and last descendants of the (old) ancestors + # of t. + while s is not None: + self.subtree_size[s] -= size_t + if self.last_descendent_dft[s] == last_t: + self.last_descendent_dft[s] = prev_t + s = self.parent[s] + + def make_root(self, q): + """ + Make a node q the root of its containing subtree. + """ + ancestors = [] + while q is not None: + ancestors.append(q) + q = self.parent[q] + ancestors.reverse() + for p, q in zip(ancestors, islice(ancestors, 1, None)): + size_p = self.subtree_size[p] + last_p = self.last_descendent_dft[p] + prev_q = self.prev_node_dft[q] + last_q = self.last_descendent_dft[q] + next_last_q = self.next_node_dft[last_q] + # Make p a child of q. + self.parent[p] = q + self.parent[q] = None + self.parent_edge[p] = self.parent_edge[q] + self.parent_edge[q] = None + self.subtree_size[p] = size_p - self.subtree_size[q] + self.subtree_size[q] = size_p + # Remove the subtree rooted at q from the depth-first thread. + self.next_node_dft[prev_q] = next_last_q + self.prev_node_dft[next_last_q] = prev_q + self.next_node_dft[last_q] = q + self.prev_node_dft[q] = last_q + if last_p == last_q: + self.last_descendent_dft[p] = prev_q + last_p = prev_q + # Add the remaining parts of the subtree rooted at p as a subtree + # of q in the depth-first thread. + self.prev_node_dft[p] = last_q + self.next_node_dft[last_q] = p + self.next_node_dft[last_p] = q + self.prev_node_dft[q] = last_p + self.last_descendent_dft[q] = last_p + + def add_edge(self, i, p, q): + """ + Add an edge (p, q) to the spanning tree where q is the root of a subtree. + """ + last_p = self.last_descendent_dft[p] + next_last_p = self.next_node_dft[last_p] + size_q = self.subtree_size[q] + last_q = self.last_descendent_dft[q] + # Make q a child of p. + self.parent[q] = p + self.parent_edge[q] = i + # Insert the subtree rooted at q into the depth-first thread. + self.next_node_dft[last_p] = q + self.prev_node_dft[q] = last_p + self.prev_node_dft[next_last_p] = last_q + self.next_node_dft[last_q] = next_last_p + # Update the subtree sizes and last descendants of the (new) ancestors + # of q. + while p is not None: + self.subtree_size[p] += size_q + if self.last_descendent_dft[p] == last_p: + self.last_descendent_dft[p] = last_q + p = self.parent[p] + + def update_potentials(self, i, p, q): + """ + Update the potentials of the nodes in the subtree rooted at a node + q connected to its parent p by an edge i. + """ + if q == self.edge_targets[i]: + d = self.node_potentials[p] - self.edge_weights[i] - self.node_potentials[q] + else: + d = self.node_potentials[p] + self.edge_weights[i] - self.node_potentials[q] + for q in self.trace_subtree(q): + self.node_potentials[q] += d + + def reduced_cost(self, i): + """Returns the reduced cost of an edge i.""" + c = ( + self.edge_weights[i] + - self.node_potentials[self.edge_sources[i]] + + self.node_potentials[self.edge_targets[i]] + ) + return c if self.edge_flow[i] == 0 else -c + + def find_entering_edges(self): + """Yield entering edges until none can be found.""" + if self.edge_count == 0: + return + + # Entering edges are found by combining Dantzig's rule and Bland's + # rule. The edges are cyclically grouped into blocks of size B. Within + # each block, Dantzig's rule is applied to find an entering edge. The + # blocks to search is determined following Bland's rule. + B = int(ceil(sqrt(self.edge_count))) # pivot block size + M = (self.edge_count + B - 1) // B # number of blocks needed to cover all edges + m = 0 # number of consecutive blocks without eligible + # entering edges + f = 0 # first edge in block + while m < M: + # Determine the next block of edges. + l = f + B + if l <= self.edge_count: + edges = range(f, l) + else: + l -= self.edge_count + edges = chain(range(f, self.edge_count), range(l)) + f = l + # Find the first edge with the lowest reduced cost. + i = min(edges, key=self.reduced_cost) + c = self.reduced_cost(i) + if c >= 0: + # No entering edge found in the current block. + m += 1 + else: + # Entering edge found. + if self.edge_flow[i] == 0: + p = self.edge_sources[i] + q = self.edge_targets[i] + else: + p = self.edge_targets[i] + q = self.edge_sources[i] + yield i, p, q + m = 0 + # All edges have nonnegative reduced costs. The current flow is + # optimal. + + def residual_capacity(self, i, p): + """Returns the residual capacity of an edge i in the direction away + from its endpoint p. + """ + return ( + self.edge_capacities[i] - self.edge_flow[i] + if self.edge_sources[i] == p + else self.edge_flow[i] + ) + + def find_leaving_edge(self, Wn, We): + """Returns the leaving edge in a cycle represented by Wn and We.""" + j, s = min( + zip(reversed(We), reversed(Wn)), + key=lambda i_p: self.residual_capacity(*i_p), + ) + t = self.edge_targets[j] if self.edge_sources[j] == s else self.edge_sources[j] + return j, s, t + + +@not_implemented_for("undirected") +@nx._dispatchable( + node_attrs="demand", edge_attrs={"capacity": float("inf"), "weight": 0} +) +def network_simplex(G, demand="demand", capacity="capacity", weight="weight"): + r"""Find a minimum cost flow satisfying all demands in digraph G. + + This is a primal network simplex algorithm that uses the leaving + arc rule to prevent cycling. + + G is a digraph with edge costs and capacities and in which nodes + have demand, i.e., they want to send or receive some amount of + flow. A negative demand means that the node wants to send flow, a + positive demand means that the node want to receive flow. A flow on + the digraph G satisfies all demand if the net flow into each node + is equal to the demand of that node. + + Parameters + ---------- + G : NetworkX graph + DiGraph on which a minimum cost flow satisfying all demands is + to be found. + + demand : string + Nodes of the graph G are expected to have an attribute demand + that indicates how much flow a node wants to send (negative + demand) or receive (positive demand). Note that the sum of the + demands should be 0 otherwise the problem in not feasible. If + this attribute is not present, a node is considered to have 0 + demand. Default value: 'demand'. + + capacity : string + Edges of the graph G are expected to have an attribute capacity + that indicates how much flow the edge can support. If this + attribute is not present, the edge is considered to have + infinite capacity. Default value: 'capacity'. + + weight : string + Edges of the graph G are expected to have an attribute weight + that indicates the cost incurred by sending one unit of flow on + that edge. If not present, the weight is considered to be 0. + Default value: 'weight'. + + Returns + ------- + flowCost : integer, float + Cost of a minimum cost flow satisfying all demands. + + flowDict : dictionary + Dictionary of dictionaries keyed by nodes such that + flowDict[u][v] is the flow edge (u, v). + + Raises + ------ + NetworkXError + This exception is raised if the input graph is not directed or + not connected. + + NetworkXUnfeasible + This exception is raised in the following situations: + + * The sum of the demands is not zero. Then, there is no + flow satisfying all demands. + * There is no flow satisfying all demand. + + NetworkXUnbounded + This exception is raised if the digraph G has a cycle of + negative cost and infinite capacity. Then, the cost of a flow + satisfying all demands is unbounded below. + + Notes + ----- + This algorithm is not guaranteed to work if edge weights or demands + are floating point numbers (overflows and roundoff errors can + cause problems). As a workaround you can use integer numbers by + multiplying the relevant edge attributes by a convenient + constant factor (eg 100). + + See also + -------- + cost_of_flow, max_flow_min_cost, min_cost_flow, min_cost_flow_cost + + Examples + -------- + A simple example of a min cost flow problem. + + >>> G = nx.DiGraph() + >>> G.add_node("a", demand=-5) + >>> G.add_node("d", demand=5) + >>> G.add_edge("a", "b", weight=3, capacity=4) + >>> G.add_edge("a", "c", weight=6, capacity=10) + >>> G.add_edge("b", "d", weight=1, capacity=9) + >>> G.add_edge("c", "d", weight=2, capacity=5) + >>> flowCost, flowDict = nx.network_simplex(G) + >>> flowCost + 24 + >>> flowDict + {'a': {'b': 4, 'c': 1}, 'd': {}, 'b': {'d': 4}, 'c': {'d': 1}} + + The mincost flow algorithm can also be used to solve shortest path + problems. To find the shortest path between two nodes u and v, + give all edges an infinite capacity, give node u a demand of -1 and + node v a demand a 1. Then run the network simplex. The value of a + min cost flow will be the distance between u and v and edges + carrying positive flow will indicate the path. + + >>> G = nx.DiGraph() + >>> G.add_weighted_edges_from( + ... [ + ... ("s", "u", 10), + ... ("s", "x", 5), + ... ("u", "v", 1), + ... ("u", "x", 2), + ... ("v", "y", 1), + ... ("x", "u", 3), + ... ("x", "v", 5), + ... ("x", "y", 2), + ... ("y", "s", 7), + ... ("y", "v", 6), + ... ] + ... ) + >>> G.add_node("s", demand=-1) + >>> G.add_node("v", demand=1) + >>> flowCost, flowDict = nx.network_simplex(G) + >>> flowCost == nx.shortest_path_length(G, "s", "v", weight="weight") + True + >>> sorted([(u, v) for u in flowDict for v in flowDict[u] if flowDict[u][v] > 0]) + [('s', 'x'), ('u', 'v'), ('x', 'u')] + >>> nx.shortest_path(G, "s", "v", weight="weight") + ['s', 'x', 'u', 'v'] + + It is possible to change the name of the attributes used for the + algorithm. + + >>> G = nx.DiGraph() + >>> G.add_node("p", spam=-4) + >>> G.add_node("q", spam=2) + >>> G.add_node("a", spam=-2) + >>> G.add_node("d", spam=-1) + >>> G.add_node("t", spam=2) + >>> G.add_node("w", spam=3) + >>> G.add_edge("p", "q", cost=7, vacancies=5) + >>> G.add_edge("p", "a", cost=1, vacancies=4) + >>> G.add_edge("q", "d", cost=2, vacancies=3) + >>> G.add_edge("t", "q", cost=1, vacancies=2) + >>> G.add_edge("a", "t", cost=2, vacancies=4) + >>> G.add_edge("d", "w", cost=3, vacancies=4) + >>> G.add_edge("t", "w", cost=4, vacancies=1) + >>> flowCost, flowDict = nx.network_simplex( + ... G, demand="spam", capacity="vacancies", weight="cost" + ... ) + >>> flowCost + 37 + >>> flowDict + {'p': {'q': 2, 'a': 2}, 'q': {'d': 1}, 'a': {'t': 4}, 'd': {'w': 2}, 't': {'q': 1, 'w': 1}, 'w': {}} + + References + ---------- + .. [1] Z. Kiraly, P. Kovacs. + Efficient implementation of minimum-cost flow algorithms. + Acta Universitatis Sapientiae, Informatica 4(1):67--118. 2012. + .. [2] R. Barr, F. Glover, D. Klingman. + Enhancement of spanning tree labeling procedures for network + optimization. + INFOR 17(1):16--34. 1979. + """ + ########################################################################### + # Problem essentials extraction and sanity check + ########################################################################### + + if len(G) == 0: + raise nx.NetworkXError("graph has no nodes") + + multigraph = G.is_multigraph() + + # extracting data essential to problem + DEAF = _DataEssentialsAndFunctions( + G, multigraph, demand=demand, capacity=capacity, weight=weight + ) + + ########################################################################### + # Quick Error Detection + ########################################################################### + + inf = float("inf") + for u, d in zip(DEAF.node_list, DEAF.node_demands): + if abs(d) == inf: + raise nx.NetworkXError(f"node {u!r} has infinite demand") + for e, w in zip(DEAF.edge_indices, DEAF.edge_weights): + if abs(w) == inf: + raise nx.NetworkXError(f"edge {e!r} has infinite weight") + if not multigraph: + edges = nx.selfloop_edges(G, data=True) + else: + edges = nx.selfloop_edges(G, data=True, keys=True) + for e in edges: + if abs(e[-1].get(weight, 0)) == inf: + raise nx.NetworkXError(f"edge {e[:-1]!r} has infinite weight") + + ########################################################################### + # Quick Infeasibility Detection + ########################################################################### + + if sum(DEAF.node_demands) != 0: + raise nx.NetworkXUnfeasible("total node demand is not zero") + for e, c in zip(DEAF.edge_indices, DEAF.edge_capacities): + if c < 0: + raise nx.NetworkXUnfeasible(f"edge {e!r} has negative capacity") + if not multigraph: + edges = nx.selfloop_edges(G, data=True) + else: + edges = nx.selfloop_edges(G, data=True, keys=True) + for e in edges: + if e[-1].get(capacity, inf) < 0: + raise nx.NetworkXUnfeasible(f"edge {e[:-1]!r} has negative capacity") + + ########################################################################### + # Initialization + ########################################################################### + + # Add a dummy node -1 and connect all existing nodes to it with infinite- + # capacity dummy edges. Node -1 will serve as the root of the + # spanning tree of the network simplex method. The new edges will used to + # trivially satisfy the node demands and create an initial strongly + # feasible spanning tree. + for i, d in enumerate(DEAF.node_demands): + # Must be greater-than here. Zero-demand nodes must have + # edges pointing towards the root to ensure strong feasibility. + if d > 0: + DEAF.edge_sources.append(-1) + DEAF.edge_targets.append(i) + else: + DEAF.edge_sources.append(i) + DEAF.edge_targets.append(-1) + faux_inf = ( + 3 + * max( + chain( + [ + sum(c for c in DEAF.edge_capacities if c < inf), + sum(abs(w) for w in DEAF.edge_weights), + ], + (abs(d) for d in DEAF.node_demands), + ) + ) + or 1 + ) + + n = len(DEAF.node_list) # number of nodes + DEAF.edge_weights.extend(repeat(faux_inf, n)) + DEAF.edge_capacities.extend(repeat(faux_inf, n)) + + # Construct the initial spanning tree. + DEAF.initialize_spanning_tree(n, faux_inf) + + ########################################################################### + # Pivot loop + ########################################################################### + + for i, p, q in DEAF.find_entering_edges(): + Wn, We = DEAF.find_cycle(i, p, q) + j, s, t = DEAF.find_leaving_edge(Wn, We) + DEAF.augment_flow(Wn, We, DEAF.residual_capacity(j, s)) + # Do nothing more if the entering edge is the same as the leaving edge. + if i != j: + if DEAF.parent[t] != s: + # Ensure that s is the parent of t. + s, t = t, s + if We.index(i) > We.index(j): + # Ensure that q is in the subtree rooted at t. + p, q = q, p + DEAF.remove_edge(s, t) + DEAF.make_root(q) + DEAF.add_edge(i, p, q) + DEAF.update_potentials(i, p, q) + + ########################################################################### + # Infeasibility and unboundedness detection + ########################################################################### + + if any(DEAF.edge_flow[i] != 0 for i in range(-n, 0)): + raise nx.NetworkXUnfeasible("no flow satisfies all node demands") + + if any(DEAF.edge_flow[i] * 2 >= faux_inf for i in range(DEAF.edge_count)) or any( + e[-1].get(capacity, inf) == inf and e[-1].get(weight, 0) < 0 + for e in nx.selfloop_edges(G, data=True) + ): + raise nx.NetworkXUnbounded("negative cycle with infinite capacity found") + + ########################################################################### + # Flow cost calculation and flow dict construction + ########################################################################### + + del DEAF.edge_flow[DEAF.edge_count :] + flow_cost = sum(w * x for w, x in zip(DEAF.edge_weights, DEAF.edge_flow)) + flow_dict = {n: {} for n in DEAF.node_list} + + def add_entry(e): + """Add a flow dict entry.""" + d = flow_dict[e[0]] + for k in e[1:-2]: + try: + d = d[k] + except KeyError: + t = {} + d[k] = t + d = t + d[e[-2]] = e[-1] + + DEAF.edge_sources = ( + DEAF.node_list[s] for s in DEAF.edge_sources + ) # Use original nodes. + DEAF.edge_targets = ( + DEAF.node_list[t] for t in DEAF.edge_targets + ) # Use original nodes. + if not multigraph: + for e in zip(DEAF.edge_sources, DEAF.edge_targets, DEAF.edge_flow): + add_entry(e) + edges = G.edges(data=True) + else: + for e in zip( + DEAF.edge_sources, DEAF.edge_targets, DEAF.edge_keys, DEAF.edge_flow + ): + add_entry(e) + edges = G.edges(data=True, keys=True) + for e in edges: + if e[0] != e[1]: + if e[-1].get(capacity, inf) == 0: + add_entry(e[:-1] + (0,)) + else: + w = e[-1].get(weight, 0) + if w >= 0: + add_entry(e[:-1] + (0,)) + else: + c = e[-1][capacity] + flow_cost += w * c + add_entry(e[:-1] + (c,)) + + return flow_cost, flow_dict diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/preflowpush.py b/lib/python3.10/site-packages/networkx/algorithms/flow/preflowpush.py new file mode 100644 index 0000000000000000000000000000000000000000..42cadc2e2db6ecfb5a347499c89d5ae77f6af3d8 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/flow/preflowpush.py @@ -0,0 +1,425 @@ +""" +Highest-label preflow-push algorithm for maximum flow problems. +""" + +from collections import deque +from itertools import islice + +import networkx as nx + +from ...utils import arbitrary_element +from .utils import ( + CurrentEdge, + GlobalRelabelThreshold, + Level, + build_residual_network, + detect_unboundedness, +) + +__all__ = ["preflow_push"] + + +def preflow_push_impl(G, s, t, capacity, residual, global_relabel_freq, value_only): + """Implementation of the highest-label preflow-push algorithm.""" + if s not in G: + raise nx.NetworkXError(f"node {str(s)} not in graph") + if t not in G: + raise nx.NetworkXError(f"node {str(t)} not in graph") + if s == t: + raise nx.NetworkXError("source and sink are the same node") + + if global_relabel_freq is None: + global_relabel_freq = 0 + if global_relabel_freq < 0: + raise nx.NetworkXError("global_relabel_freq must be nonnegative.") + + if residual is None: + R = build_residual_network(G, capacity) + else: + R = residual + + detect_unboundedness(R, s, t) + + R_nodes = R.nodes + R_pred = R.pred + R_succ = R.succ + + # Initialize/reset the residual network. + for u in R: + R_nodes[u]["excess"] = 0 + for e in R_succ[u].values(): + e["flow"] = 0 + + def reverse_bfs(src): + """Perform a reverse breadth-first search from src in the residual + network. + """ + heights = {src: 0} + q = deque([(src, 0)]) + while q: + u, height = q.popleft() + height += 1 + for v, attr in R_pred[u].items(): + if v not in heights and attr["flow"] < attr["capacity"]: + heights[v] = height + q.append((v, height)) + return heights + + # Initialize heights of the nodes. + heights = reverse_bfs(t) + + if s not in heights: + # t is not reachable from s in the residual network. The maximum flow + # must be zero. + R.graph["flow_value"] = 0 + return R + + n = len(R) + # max_height represents the height of the highest level below level n with + # at least one active node. + max_height = max(heights[u] for u in heights if u != s) + heights[s] = n + + grt = GlobalRelabelThreshold(n, R.size(), global_relabel_freq) + + # Initialize heights and 'current edge' data structures of the nodes. + for u in R: + R_nodes[u]["height"] = heights[u] if u in heights else n + 1 + R_nodes[u]["curr_edge"] = CurrentEdge(R_succ[u]) + + def push(u, v, flow): + """Push flow units of flow from u to v.""" + R_succ[u][v]["flow"] += flow + R_succ[v][u]["flow"] -= flow + R_nodes[u]["excess"] -= flow + R_nodes[v]["excess"] += flow + + # The maximum flow must be nonzero now. Initialize the preflow by + # saturating all edges emanating from s. + for u, attr in R_succ[s].items(): + flow = attr["capacity"] + if flow > 0: + push(s, u, flow) + + # Partition nodes into levels. + levels = [Level() for i in range(2 * n)] + for u in R: + if u != s and u != t: + level = levels[R_nodes[u]["height"]] + if R_nodes[u]["excess"] > 0: + level.active.add(u) + else: + level.inactive.add(u) + + def activate(v): + """Move a node from the inactive set to the active set of its level.""" + if v != s and v != t: + level = levels[R_nodes[v]["height"]] + if v in level.inactive: + level.inactive.remove(v) + level.active.add(v) + + def relabel(u): + """Relabel a node to create an admissible edge.""" + grt.add_work(len(R_succ[u])) + return ( + min( + R_nodes[v]["height"] + for v, attr in R_succ[u].items() + if attr["flow"] < attr["capacity"] + ) + + 1 + ) + + def discharge(u, is_phase1): + """Discharge a node until it becomes inactive or, during phase 1 (see + below), its height reaches at least n. The node is known to have the + largest height among active nodes. + """ + height = R_nodes[u]["height"] + curr_edge = R_nodes[u]["curr_edge"] + # next_height represents the next height to examine after discharging + # the current node. During phase 1, it is capped to below n. + next_height = height + levels[height].active.remove(u) + while True: + v, attr = curr_edge.get() + if height == R_nodes[v]["height"] + 1 and attr["flow"] < attr["capacity"]: + flow = min(R_nodes[u]["excess"], attr["capacity"] - attr["flow"]) + push(u, v, flow) + activate(v) + if R_nodes[u]["excess"] == 0: + # The node has become inactive. + levels[height].inactive.add(u) + break + try: + curr_edge.move_to_next() + except StopIteration: + # We have run off the end of the adjacency list, and there can + # be no more admissible edges. Relabel the node to create one. + height = relabel(u) + if is_phase1 and height >= n - 1: + # Although the node is still active, with a height at least + # n - 1, it is now known to be on the s side of the minimum + # s-t cut. Stop processing it until phase 2. + levels[height].active.add(u) + break + # The first relabel operation after global relabeling may not + # increase the height of the node since the 'current edge' data + # structure is not rewound. Use height instead of (height - 1) + # in case other active nodes at the same level are missed. + next_height = height + R_nodes[u]["height"] = height + return next_height + + def gap_heuristic(height): + """Apply the gap heuristic.""" + # Move all nodes at levels (height + 1) to max_height to level n + 1. + for level in islice(levels, height + 1, max_height + 1): + for u in level.active: + R_nodes[u]["height"] = n + 1 + for u in level.inactive: + R_nodes[u]["height"] = n + 1 + levels[n + 1].active.update(level.active) + level.active.clear() + levels[n + 1].inactive.update(level.inactive) + level.inactive.clear() + + def global_relabel(from_sink): + """Apply the global relabeling heuristic.""" + src = t if from_sink else s + heights = reverse_bfs(src) + if not from_sink: + # s must be reachable from t. Remove t explicitly. + del heights[t] + max_height = max(heights.values()) + if from_sink: + # Also mark nodes from which t is unreachable for relabeling. This + # serves the same purpose as the gap heuristic. + for u in R: + if u not in heights and R_nodes[u]["height"] < n: + heights[u] = n + 1 + else: + # Shift the computed heights because the height of s is n. + for u in heights: + heights[u] += n + max_height += n + del heights[src] + for u, new_height in heights.items(): + old_height = R_nodes[u]["height"] + if new_height != old_height: + if u in levels[old_height].active: + levels[old_height].active.remove(u) + levels[new_height].active.add(u) + else: + levels[old_height].inactive.remove(u) + levels[new_height].inactive.add(u) + R_nodes[u]["height"] = new_height + return max_height + + # Phase 1: Find the maximum preflow by pushing as much flow as possible to + # t. + + height = max_height + while height > 0: + # Discharge active nodes in the current level. + while True: + level = levels[height] + if not level.active: + # All active nodes in the current level have been discharged. + # Move to the next lower level. + height -= 1 + break + # Record the old height and level for the gap heuristic. + old_height = height + old_level = level + u = arbitrary_element(level.active) + height = discharge(u, True) + if grt.is_reached(): + # Global relabeling heuristic: Recompute the exact heights of + # all nodes. + height = global_relabel(True) + max_height = height + grt.clear_work() + elif not old_level.active and not old_level.inactive: + # Gap heuristic: If the level at old_height is empty (a 'gap'), + # a minimum cut has been identified. All nodes with heights + # above old_height can have their heights set to n + 1 and not + # be further processed before a maximum preflow is found. + gap_heuristic(old_height) + height = old_height - 1 + max_height = height + else: + # Update the height of the highest level with at least one + # active node. + max_height = max(max_height, height) + + # A maximum preflow has been found. The excess at t is the maximum flow + # value. + if value_only: + R.graph["flow_value"] = R_nodes[t]["excess"] + return R + + # Phase 2: Convert the maximum preflow into a maximum flow by returning the + # excess to s. + + # Relabel all nodes so that they have accurate heights. + height = global_relabel(False) + grt.clear_work() + + # Continue to discharge the active nodes. + while height > n: + # Discharge active nodes in the current level. + while True: + level = levels[height] + if not level.active: + # All active nodes in the current level have been discharged. + # Move to the next lower level. + height -= 1 + break + u = arbitrary_element(level.active) + height = discharge(u, False) + if grt.is_reached(): + # Global relabeling heuristic. + height = global_relabel(False) + grt.clear_work() + + R.graph["flow_value"] = R_nodes[t]["excess"] + return R + + +@nx._dispatchable(edge_attrs={"capacity": float("inf")}, returns_graph=True) +def preflow_push( + G, s, t, capacity="capacity", residual=None, global_relabel_freq=1, value_only=False +): + r"""Find a maximum single-commodity flow using the highest-label + preflow-push algorithm. + + This function returns the residual network resulting after computing + the maximum flow. See below for details about the conventions + NetworkX uses for defining residual networks. + + This algorithm has a running time of $O(n^2 \sqrt{m})$ for $n$ nodes and + $m$ edges. + + + Parameters + ---------- + G : NetworkX graph + Edges of the graph are expected to have an attribute called + 'capacity'. If this attribute is not present, the edge is + considered to have infinite capacity. + + s : node + Source node for the flow. + + t : node + Sink node for the flow. + + capacity : string + Edges of the graph G are expected to have an attribute capacity + that indicates how much flow the edge can support. If this + attribute is not present, the edge is considered to have + infinite capacity. Default value: 'capacity'. + + residual : NetworkX graph + Residual network on which the algorithm is to be executed. If None, a + new residual network is created. Default value: None. + + global_relabel_freq : integer, float + Relative frequency of applying the global relabeling heuristic to speed + up the algorithm. If it is None, the heuristic is disabled. Default + value: 1. + + value_only : bool + If False, compute a maximum flow; otherwise, compute a maximum preflow + which is enough for computing the maximum flow value. Default value: + False. + + Returns + ------- + R : NetworkX DiGraph + Residual network after computing the maximum flow. + + Raises + ------ + NetworkXError + The algorithm does not support MultiGraph and MultiDiGraph. If + the input graph is an instance of one of these two classes, a + NetworkXError is raised. + + NetworkXUnbounded + If the graph has a path of infinite capacity, the value of a + feasible flow on the graph is unbounded above and the function + raises a NetworkXUnbounded. + + See also + -------- + :meth:`maximum_flow` + :meth:`minimum_cut` + :meth:`edmonds_karp` + :meth:`shortest_augmenting_path` + + Notes + ----- + The residual network :samp:`R` from an input graph :samp:`G` has the + same nodes as :samp:`G`. :samp:`R` is a DiGraph that contains a pair + of edges :samp:`(u, v)` and :samp:`(v, u)` iff :samp:`(u, v)` is not a + self-loop, and at least one of :samp:`(u, v)` and :samp:`(v, u)` exists + in :samp:`G`. For each node :samp:`u` in :samp:`R`, + :samp:`R.nodes[u]['excess']` represents the difference between flow into + :samp:`u` and flow out of :samp:`u`. + + For each edge :samp:`(u, v)` in :samp:`R`, :samp:`R[u][v]['capacity']` + is equal to the capacity of :samp:`(u, v)` in :samp:`G` if it exists + in :samp:`G` or zero otherwise. If the capacity is infinite, + :samp:`R[u][v]['capacity']` will have a high arbitrary finite value + that does not affect the solution of the problem. This value is stored in + :samp:`R.graph['inf']`. For each edge :samp:`(u, v)` in :samp:`R`, + :samp:`R[u][v]['flow']` represents the flow function of :samp:`(u, v)` and + satisfies :samp:`R[u][v]['flow'] == -R[v][u]['flow']`. + + The flow value, defined as the total flow into :samp:`t`, the sink, is + stored in :samp:`R.graph['flow_value']`. Reachability to :samp:`t` using + only edges :samp:`(u, v)` such that + :samp:`R[u][v]['flow'] < R[u][v]['capacity']` induces a minimum + :samp:`s`-:samp:`t` cut. + + Examples + -------- + >>> from networkx.algorithms.flow import preflow_push + + The functions that implement flow algorithms and output a residual + network, such as this one, are not imported to the base NetworkX + namespace, so you have to explicitly import them from the flow package. + + >>> G = nx.DiGraph() + >>> G.add_edge("x", "a", capacity=3.0) + >>> G.add_edge("x", "b", capacity=1.0) + >>> G.add_edge("a", "c", capacity=3.0) + >>> G.add_edge("b", "c", capacity=5.0) + >>> G.add_edge("b", "d", capacity=4.0) + >>> G.add_edge("d", "e", capacity=2.0) + >>> G.add_edge("c", "y", capacity=2.0) + >>> G.add_edge("e", "y", capacity=3.0) + >>> R = preflow_push(G, "x", "y") + >>> flow_value = nx.maximum_flow_value(G, "x", "y") + >>> flow_value == R.graph["flow_value"] + True + >>> # preflow_push also stores the maximum flow value + >>> # in the excess attribute of the sink node t + >>> flow_value == R.nodes["y"]["excess"] + True + >>> # For some problems, you might only want to compute a + >>> # maximum preflow. + >>> R = preflow_push(G, "x", "y", value_only=True) + >>> flow_value == R.graph["flow_value"] + True + >>> flow_value == R.nodes["y"]["excess"] + True + + """ + R = preflow_push_impl(G, s, t, capacity, residual, global_relabel_freq, value_only) + R.graph["algorithm"] = "preflow_push" + nx._clear_cache(R) + return R diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/shortestaugmentingpath.py b/lib/python3.10/site-packages/networkx/algorithms/flow/shortestaugmentingpath.py new file mode 100644 index 0000000000000000000000000000000000000000..9f1193f1cbfbe188ebf05105a2c6f1802baca6f1 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/flow/shortestaugmentingpath.py @@ -0,0 +1,300 @@ +""" +Shortest augmenting path algorithm for maximum flow problems. +""" + +from collections import deque + +import networkx as nx + +from .edmondskarp import edmonds_karp_core +from .utils import CurrentEdge, build_residual_network + +__all__ = ["shortest_augmenting_path"] + + +def shortest_augmenting_path_impl(G, s, t, capacity, residual, two_phase, cutoff): + """Implementation of the shortest augmenting path algorithm.""" + if s not in G: + raise nx.NetworkXError(f"node {str(s)} not in graph") + if t not in G: + raise nx.NetworkXError(f"node {str(t)} not in graph") + if s == t: + raise nx.NetworkXError("source and sink are the same node") + + if residual is None: + R = build_residual_network(G, capacity) + else: + R = residual + + R_nodes = R.nodes + R_pred = R.pred + R_succ = R.succ + + # Initialize/reset the residual network. + for u in R: + for e in R_succ[u].values(): + e["flow"] = 0 + + # Initialize heights of the nodes. + heights = {t: 0} + q = deque([(t, 0)]) + while q: + u, height = q.popleft() + height += 1 + for v, attr in R_pred[u].items(): + if v not in heights and attr["flow"] < attr["capacity"]: + heights[v] = height + q.append((v, height)) + + if s not in heights: + # t is not reachable from s in the residual network. The maximum flow + # must be zero. + R.graph["flow_value"] = 0 + return R + + n = len(G) + m = R.size() / 2 + + # Initialize heights and 'current edge' data structures of the nodes. + for u in R: + R_nodes[u]["height"] = heights[u] if u in heights else n + R_nodes[u]["curr_edge"] = CurrentEdge(R_succ[u]) + + # Initialize counts of nodes in each level. + counts = [0] * (2 * n - 1) + for u in R: + counts[R_nodes[u]["height"]] += 1 + + inf = R.graph["inf"] + + def augment(path): + """Augment flow along a path from s to t.""" + # Determine the path residual capacity. + flow = inf + it = iter(path) + u = next(it) + for v in it: + attr = R_succ[u][v] + flow = min(flow, attr["capacity"] - attr["flow"]) + u = v + if flow * 2 > inf: + raise nx.NetworkXUnbounded("Infinite capacity path, flow unbounded above.") + # Augment flow along the path. + it = iter(path) + u = next(it) + for v in it: + R_succ[u][v]["flow"] += flow + R_succ[v][u]["flow"] -= flow + u = v + return flow + + def relabel(u): + """Relabel a node to create an admissible edge.""" + height = n - 1 + for v, attr in R_succ[u].items(): + if attr["flow"] < attr["capacity"]: + height = min(height, R_nodes[v]["height"]) + return height + 1 + + if cutoff is None: + cutoff = float("inf") + + # Phase 1: Look for shortest augmenting paths using depth-first search. + + flow_value = 0 + path = [s] + u = s + d = n if not two_phase else int(min(m**0.5, 2 * n ** (2.0 / 3))) + done = R_nodes[s]["height"] >= d + while not done: + height = R_nodes[u]["height"] + curr_edge = R_nodes[u]["curr_edge"] + # Depth-first search for the next node on the path to t. + while True: + v, attr = curr_edge.get() + if height == R_nodes[v]["height"] + 1 and attr["flow"] < attr["capacity"]: + # Advance to the next node following an admissible edge. + path.append(v) + u = v + break + try: + curr_edge.move_to_next() + except StopIteration: + counts[height] -= 1 + if counts[height] == 0: + # Gap heuristic: If relabeling causes a level to become + # empty, a minimum cut has been identified. The algorithm + # can now be terminated. + R.graph["flow_value"] = flow_value + return R + height = relabel(u) + if u == s and height >= d: + if not two_phase: + # t is disconnected from s in the residual network. No + # more augmenting paths exist. + R.graph["flow_value"] = flow_value + return R + else: + # t is at least d steps away from s. End of phase 1. + done = True + break + counts[height] += 1 + R_nodes[u]["height"] = height + if u != s: + # After relabeling, the last edge on the path is no longer + # admissible. Retreat one step to look for an alternative. + path.pop() + u = path[-1] + break + if u == t: + # t is reached. Augment flow along the path and reset it for a new + # depth-first search. + flow_value += augment(path) + if flow_value >= cutoff: + R.graph["flow_value"] = flow_value + return R + path = [s] + u = s + + # Phase 2: Look for shortest augmenting paths using breadth-first search. + flow_value += edmonds_karp_core(R, s, t, cutoff - flow_value) + + R.graph["flow_value"] = flow_value + return R + + +@nx._dispatchable(edge_attrs={"capacity": float("inf")}, returns_graph=True) +def shortest_augmenting_path( + G, + s, + t, + capacity="capacity", + residual=None, + value_only=False, + two_phase=False, + cutoff=None, +): + r"""Find a maximum single-commodity flow using the shortest augmenting path + algorithm. + + This function returns the residual network resulting after computing + the maximum flow. See below for details about the conventions + NetworkX uses for defining residual networks. + + This algorithm has a running time of $O(n^2 m)$ for $n$ nodes and $m$ + edges. + + + Parameters + ---------- + G : NetworkX graph + Edges of the graph are expected to have an attribute called + 'capacity'. If this attribute is not present, the edge is + considered to have infinite capacity. + + s : node + Source node for the flow. + + t : node + Sink node for the flow. + + capacity : string + Edges of the graph G are expected to have an attribute capacity + that indicates how much flow the edge can support. If this + attribute is not present, the edge is considered to have + infinite capacity. Default value: 'capacity'. + + residual : NetworkX graph + Residual network on which the algorithm is to be executed. If None, a + new residual network is created. Default value: None. + + value_only : bool + If True compute only the value of the maximum flow. This parameter + will be ignored by this algorithm because it is not applicable. + + two_phase : bool + If True, a two-phase variant is used. The two-phase variant improves + the running time on unit-capacity networks from $O(nm)$ to + $O(\min(n^{2/3}, m^{1/2}) m)$. Default value: False. + + cutoff : integer, float + If specified, the algorithm will terminate when the flow value reaches + or exceeds the cutoff. In this case, it may be unable to immediately + determine a minimum cut. Default value: None. + + Returns + ------- + R : NetworkX DiGraph + Residual network after computing the maximum flow. + + Raises + ------ + NetworkXError + The algorithm does not support MultiGraph and MultiDiGraph. If + the input graph is an instance of one of these two classes, a + NetworkXError is raised. + + NetworkXUnbounded + If the graph has a path of infinite capacity, the value of a + feasible flow on the graph is unbounded above and the function + raises a NetworkXUnbounded. + + See also + -------- + :meth:`maximum_flow` + :meth:`minimum_cut` + :meth:`edmonds_karp` + :meth:`preflow_push` + + Notes + ----- + The residual network :samp:`R` from an input graph :samp:`G` has the + same nodes as :samp:`G`. :samp:`R` is a DiGraph that contains a pair + of edges :samp:`(u, v)` and :samp:`(v, u)` iff :samp:`(u, v)` is not a + self-loop, and at least one of :samp:`(u, v)` and :samp:`(v, u)` exists + in :samp:`G`. + + For each edge :samp:`(u, v)` in :samp:`R`, :samp:`R[u][v]['capacity']` + is equal to the capacity of :samp:`(u, v)` in :samp:`G` if it exists + in :samp:`G` or zero otherwise. If the capacity is infinite, + :samp:`R[u][v]['capacity']` will have a high arbitrary finite value + that does not affect the solution of the problem. This value is stored in + :samp:`R.graph['inf']`. For each edge :samp:`(u, v)` in :samp:`R`, + :samp:`R[u][v]['flow']` represents the flow function of :samp:`(u, v)` and + satisfies :samp:`R[u][v]['flow'] == -R[v][u]['flow']`. + + The flow value, defined as the total flow into :samp:`t`, the sink, is + stored in :samp:`R.graph['flow_value']`. If :samp:`cutoff` is not + specified, reachability to :samp:`t` using only edges :samp:`(u, v)` such + that :samp:`R[u][v]['flow'] < R[u][v]['capacity']` induces a minimum + :samp:`s`-:samp:`t` cut. + + Examples + -------- + >>> from networkx.algorithms.flow import shortest_augmenting_path + + The functions that implement flow algorithms and output a residual + network, such as this one, are not imported to the base NetworkX + namespace, so you have to explicitly import them from the flow package. + + >>> G = nx.DiGraph() + >>> G.add_edge("x", "a", capacity=3.0) + >>> G.add_edge("x", "b", capacity=1.0) + >>> G.add_edge("a", "c", capacity=3.0) + >>> G.add_edge("b", "c", capacity=5.0) + >>> G.add_edge("b", "d", capacity=4.0) + >>> G.add_edge("d", "e", capacity=2.0) + >>> G.add_edge("c", "y", capacity=2.0) + >>> G.add_edge("e", "y", capacity=3.0) + >>> R = shortest_augmenting_path(G, "x", "y") + >>> flow_value = nx.maximum_flow_value(G, "x", "y") + >>> flow_value + 3.0 + >>> flow_value == R.graph["flow_value"] + True + + """ + R = shortest_augmenting_path_impl(G, s, t, capacity, residual, two_phase, cutoff) + R.graph["algorithm"] = "shortest_augmenting_path" + nx._clear_cache(R) + return R diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/tests/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/flow/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/flow/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31b0a0c786ab3009ec494d8aa17997b10d304359 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/flow/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/tests/__pycache__/test_gomory_hu.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/flow/tests/__pycache__/test_gomory_hu.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..641c6464008da30141e2df393f1b589601625d8e Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/flow/tests/__pycache__/test_gomory_hu.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/tests/__pycache__/test_maxflow.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/flow/tests/__pycache__/test_maxflow.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df532180c491722481f1ebafa7b904ea48fb4c08 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/flow/tests/__pycache__/test_maxflow.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/tests/__pycache__/test_maxflow_large_graph.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/flow/tests/__pycache__/test_maxflow_large_graph.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..194c72060d464a696271972fc37434bb91d986b2 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/flow/tests/__pycache__/test_maxflow_large_graph.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/tests/__pycache__/test_mincost.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/flow/tests/__pycache__/test_mincost.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76b3502603d6117b4cf51de9b79843b0e21e8d9a Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/flow/tests/__pycache__/test_mincost.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/tests/__pycache__/test_networksimplex.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/flow/tests/__pycache__/test_networksimplex.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c125701f427e94e67da79a9e4757e76a61e21ee Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/flow/tests/__pycache__/test_networksimplex.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/tests/test_gomory_hu.py b/lib/python3.10/site-packages/networkx/algorithms/flow/tests/test_gomory_hu.py new file mode 100644 index 0000000000000000000000000000000000000000..1649ec82c719226e9caa68268d8953f7cae6ef74 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/flow/tests/test_gomory_hu.py @@ -0,0 +1,128 @@ +from itertools import combinations + +import pytest + +import networkx as nx +from networkx.algorithms.flow import ( + boykov_kolmogorov, + dinitz, + edmonds_karp, + preflow_push, + shortest_augmenting_path, +) + +flow_funcs = [ + boykov_kolmogorov, + dinitz, + edmonds_karp, + preflow_push, + shortest_augmenting_path, +] + + +class TestGomoryHuTree: + def minimum_edge_weight(self, T, u, v): + path = nx.shortest_path(T, u, v, weight="weight") + return min((T[u][v]["weight"], (u, v)) for (u, v) in zip(path, path[1:])) + + def compute_cutset(self, G, T_orig, edge): + T = T_orig.copy() + T.remove_edge(*edge) + U, V = list(nx.connected_components(T)) + cutset = set() + for x, nbrs in ((n, G[n]) for n in U): + cutset.update((x, y) for y in nbrs if y in V) + return cutset + + def test_default_flow_function_karate_club_graph(self): + G = nx.karate_club_graph() + nx.set_edge_attributes(G, 1, "capacity") + T = nx.gomory_hu_tree(G) + assert nx.is_tree(T) + for u, v in combinations(G, 2): + cut_value, edge = self.minimum_edge_weight(T, u, v) + assert nx.minimum_cut_value(G, u, v) == cut_value + + def test_karate_club_graph(self): + G = nx.karate_club_graph() + nx.set_edge_attributes(G, 1, "capacity") + for flow_func in flow_funcs: + T = nx.gomory_hu_tree(G, flow_func=flow_func) + assert nx.is_tree(T) + for u, v in combinations(G, 2): + cut_value, edge = self.minimum_edge_weight(T, u, v) + assert nx.minimum_cut_value(G, u, v) == cut_value + + def test_davis_southern_women_graph(self): + G = nx.davis_southern_women_graph() + nx.set_edge_attributes(G, 1, "capacity") + for flow_func in flow_funcs: + T = nx.gomory_hu_tree(G, flow_func=flow_func) + assert nx.is_tree(T) + for u, v in combinations(G, 2): + cut_value, edge = self.minimum_edge_weight(T, u, v) + assert nx.minimum_cut_value(G, u, v) == cut_value + + def test_florentine_families_graph(self): + G = nx.florentine_families_graph() + nx.set_edge_attributes(G, 1, "capacity") + for flow_func in flow_funcs: + T = nx.gomory_hu_tree(G, flow_func=flow_func) + assert nx.is_tree(T) + for u, v in combinations(G, 2): + cut_value, edge = self.minimum_edge_weight(T, u, v) + assert nx.minimum_cut_value(G, u, v) == cut_value + + @pytest.mark.slow + def test_les_miserables_graph_cutset(self): + G = nx.les_miserables_graph() + nx.set_edge_attributes(G, 1, "capacity") + for flow_func in flow_funcs: + T = nx.gomory_hu_tree(G, flow_func=flow_func) + assert nx.is_tree(T) + for u, v in combinations(G, 2): + cut_value, edge = self.minimum_edge_weight(T, u, v) + assert nx.minimum_cut_value(G, u, v) == cut_value + + def test_karate_club_graph_cutset(self): + G = nx.karate_club_graph() + nx.set_edge_attributes(G, 1, "capacity") + T = nx.gomory_hu_tree(G) + assert nx.is_tree(T) + u, v = 0, 33 + cut_value, edge = self.minimum_edge_weight(T, u, v) + cutset = self.compute_cutset(G, T, edge) + assert cut_value == len(cutset) + + def test_wikipedia_example(self): + # Example from https://en.wikipedia.org/wiki/Gomory%E2%80%93Hu_tree + G = nx.Graph() + G.add_weighted_edges_from( + ( + (0, 1, 1), + (0, 2, 7), + (1, 2, 1), + (1, 3, 3), + (1, 4, 2), + (2, 4, 4), + (3, 4, 1), + (3, 5, 6), + (4, 5, 2), + ) + ) + for flow_func in flow_funcs: + T = nx.gomory_hu_tree(G, capacity="weight", flow_func=flow_func) + assert nx.is_tree(T) + for u, v in combinations(G, 2): + cut_value, edge = self.minimum_edge_weight(T, u, v) + assert nx.minimum_cut_value(G, u, v, capacity="weight") == cut_value + + def test_directed_raises(self): + with pytest.raises(nx.NetworkXNotImplemented): + G = nx.DiGraph() + T = nx.gomory_hu_tree(G) + + def test_empty_raises(self): + with pytest.raises(nx.NetworkXError): + G = nx.empty_graph() + T = nx.gomory_hu_tree(G) diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/tests/test_maxflow.py b/lib/python3.10/site-packages/networkx/algorithms/flow/tests/test_maxflow.py new file mode 100644 index 0000000000000000000000000000000000000000..d7305a7b6320ef1b55c682386cd7320f33a78994 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/flow/tests/test_maxflow.py @@ -0,0 +1,573 @@ +"""Maximum flow algorithms test suite.""" + +import pytest + +import networkx as nx +from networkx.algorithms.flow import ( + boykov_kolmogorov, + build_flow_dict, + build_residual_network, + dinitz, + edmonds_karp, + preflow_push, + shortest_augmenting_path, +) + +flow_funcs = { + boykov_kolmogorov, + dinitz, + edmonds_karp, + preflow_push, + shortest_augmenting_path, +} + +max_min_funcs = {nx.maximum_flow, nx.minimum_cut} +flow_value_funcs = {nx.maximum_flow_value, nx.minimum_cut_value} +interface_funcs = max_min_funcs | flow_value_funcs +all_funcs = flow_funcs | interface_funcs + + +def compute_cutset(G, partition): + reachable, non_reachable = partition + cutset = set() + for u, nbrs in ((n, G[n]) for n in reachable): + cutset.update((u, v) for v in nbrs if v in non_reachable) + return cutset + + +def validate_flows(G, s, t, flowDict, solnValue, capacity, flow_func): + errmsg = f"Assertion failed in function: {flow_func.__name__}" + assert set(G) == set(flowDict), errmsg + for u in G: + assert set(G[u]) == set(flowDict[u]), errmsg + excess = {u: 0 for u in flowDict} + for u in flowDict: + for v, flow in flowDict[u].items(): + if capacity in G[u][v]: + assert flow <= G[u][v][capacity] + assert flow >= 0, errmsg + excess[u] -= flow + excess[v] += flow + for u, exc in excess.items(): + if u == s: + assert exc == -solnValue, errmsg + elif u == t: + assert exc == solnValue, errmsg + else: + assert exc == 0, errmsg + + +def validate_cuts(G, s, t, solnValue, partition, capacity, flow_func): + errmsg = f"Assertion failed in function: {flow_func.__name__}" + assert all(n in G for n in partition[0]), errmsg + assert all(n in G for n in partition[1]), errmsg + cutset = compute_cutset(G, partition) + assert all(G.has_edge(u, v) for (u, v) in cutset), errmsg + assert solnValue == sum(G[u][v][capacity] for (u, v) in cutset), errmsg + H = G.copy() + H.remove_edges_from(cutset) + if not G.is_directed(): + assert not nx.is_connected(H), errmsg + else: + assert not nx.is_strongly_connected(H), errmsg + + +def compare_flows_and_cuts(G, s, t, solnValue, capacity="capacity"): + for flow_func in flow_funcs: + errmsg = f"Assertion failed in function: {flow_func.__name__}" + R = flow_func(G, s, t, capacity) + # Test both legacy and new implementations. + flow_value = R.graph["flow_value"] + flow_dict = build_flow_dict(G, R) + assert flow_value == solnValue, errmsg + validate_flows(G, s, t, flow_dict, solnValue, capacity, flow_func) + # Minimum cut + cut_value, partition = nx.minimum_cut( + G, s, t, capacity=capacity, flow_func=flow_func + ) + validate_cuts(G, s, t, solnValue, partition, capacity, flow_func) + + +class TestMaxflowMinCutCommon: + def test_graph1(self): + # Trivial undirected graph + G = nx.Graph() + G.add_edge(1, 2, capacity=1.0) + + # solution flows + # {1: {2: 1.0}, 2: {1: 1.0}} + + compare_flows_and_cuts(G, 1, 2, 1.0) + + def test_graph2(self): + # A more complex undirected graph + # adapted from https://web.archive.org/web/20220815055650/https://www.topcoder.com/thrive/articles/Maximum%20Flow:%20Part%20One + G = nx.Graph() + G.add_edge("x", "a", capacity=3.0) + G.add_edge("x", "b", capacity=1.0) + G.add_edge("a", "c", capacity=3.0) + G.add_edge("b", "c", capacity=5.0) + G.add_edge("b", "d", capacity=4.0) + G.add_edge("d", "e", capacity=2.0) + G.add_edge("c", "y", capacity=2.0) + G.add_edge("e", "y", capacity=3.0) + + # H + # { + # "x": {"a": 3, "b": 1}, + # "a": {"c": 3, "x": 3}, + # "b": {"c": 1, "d": 2, "x": 1}, + # "c": {"a": 3, "b": 1, "y": 2}, + # "d": {"b": 2, "e": 2}, + # "e": {"d": 2, "y": 2}, + # "y": {"c": 2, "e": 2}, + # } + + compare_flows_and_cuts(G, "x", "y", 4.0) + + def test_digraph1(self): + # The classic directed graph example + G = nx.DiGraph() + G.add_edge("a", "b", capacity=1000.0) + G.add_edge("a", "c", capacity=1000.0) + G.add_edge("b", "c", capacity=1.0) + G.add_edge("b", "d", capacity=1000.0) + G.add_edge("c", "d", capacity=1000.0) + + # H + # { + # "a": {"b": 1000.0, "c": 1000.0}, + # "b": {"c": 0, "d": 1000.0}, + # "c": {"d": 1000.0}, + # "d": {}, + # } + + compare_flows_and_cuts(G, "a", "d", 2000.0) + + def test_digraph2(self): + # An example in which some edges end up with zero flow. + G = nx.DiGraph() + G.add_edge("s", "b", capacity=2) + G.add_edge("s", "c", capacity=1) + G.add_edge("c", "d", capacity=1) + G.add_edge("d", "a", capacity=1) + G.add_edge("b", "a", capacity=2) + G.add_edge("a", "t", capacity=2) + + # H + # { + # "s": {"b": 2, "c": 0}, + # "c": {"d": 0}, + # "d": {"a": 0}, + # "b": {"a": 2}, + # "a": {"t": 2}, + # "t": {}, + # } + + compare_flows_and_cuts(G, "s", "t", 2) + + def test_digraph3(self): + # A directed graph example from Cormen et al. + G = nx.DiGraph() + G.add_edge("s", "v1", capacity=16.0) + G.add_edge("s", "v2", capacity=13.0) + G.add_edge("v1", "v2", capacity=10.0) + G.add_edge("v2", "v1", capacity=4.0) + G.add_edge("v1", "v3", capacity=12.0) + G.add_edge("v3", "v2", capacity=9.0) + G.add_edge("v2", "v4", capacity=14.0) + G.add_edge("v4", "v3", capacity=7.0) + G.add_edge("v3", "t", capacity=20.0) + G.add_edge("v4", "t", capacity=4.0) + + # H + # { + # "s": {"v1": 12.0, "v2": 11.0}, + # "v2": {"v1": 0, "v4": 11.0}, + # "v1": {"v2": 0, "v3": 12.0}, + # "v3": {"v2": 0, "t": 19.0}, + # "v4": {"v3": 7.0, "t": 4.0}, + # "t": {}, + # } + + compare_flows_and_cuts(G, "s", "t", 23.0) + + def test_digraph4(self): + # A more complex directed graph + # from https://web.archive.org/web/20220815055650/https://www.topcoder.com/thrive/articles/Maximum%20Flow:%20Part%20One + G = nx.DiGraph() + G.add_edge("x", "a", capacity=3.0) + G.add_edge("x", "b", capacity=1.0) + G.add_edge("a", "c", capacity=3.0) + G.add_edge("b", "c", capacity=5.0) + G.add_edge("b", "d", capacity=4.0) + G.add_edge("d", "e", capacity=2.0) + G.add_edge("c", "y", capacity=2.0) + G.add_edge("e", "y", capacity=3.0) + + # H + # { + # "x": {"a": 2.0, "b": 1.0}, + # "a": {"c": 2.0}, + # "b": {"c": 0, "d": 1.0}, + # "c": {"y": 2.0}, + # "d": {"e": 1.0}, + # "e": {"y": 1.0}, + # "y": {}, + # } + + compare_flows_and_cuts(G, "x", "y", 3.0) + + def test_wikipedia_dinitz_example(self): + # Nice example from https://en.wikipedia.org/wiki/Dinic's_algorithm + G = nx.DiGraph() + G.add_edge("s", 1, capacity=10) + G.add_edge("s", 2, capacity=10) + G.add_edge(1, 3, capacity=4) + G.add_edge(1, 4, capacity=8) + G.add_edge(1, 2, capacity=2) + G.add_edge(2, 4, capacity=9) + G.add_edge(3, "t", capacity=10) + G.add_edge(4, 3, capacity=6) + G.add_edge(4, "t", capacity=10) + + # solution flows + # { + # 1: {2: 0, 3: 4, 4: 6}, + # 2: {4: 9}, + # 3: {"t": 9}, + # 4: {3: 5, "t": 10}, + # "s": {1: 10, 2: 9}, + # "t": {}, + # } + + compare_flows_and_cuts(G, "s", "t", 19) + + def test_optional_capacity(self): + # Test optional capacity parameter. + G = nx.DiGraph() + G.add_edge("x", "a", spam=3.0) + G.add_edge("x", "b", spam=1.0) + G.add_edge("a", "c", spam=3.0) + G.add_edge("b", "c", spam=5.0) + G.add_edge("b", "d", spam=4.0) + G.add_edge("d", "e", spam=2.0) + G.add_edge("c", "y", spam=2.0) + G.add_edge("e", "y", spam=3.0) + + # solution flows + # { + # "x": {"a": 2.0, "b": 1.0}, + # "a": {"c": 2.0}, + # "b": {"c": 0, "d": 1.0}, + # "c": {"y": 2.0}, + # "d": {"e": 1.0}, + # "e": {"y": 1.0}, + # "y": {}, + # } + solnValue = 3.0 + s = "x" + t = "y" + + compare_flows_and_cuts(G, s, t, solnValue, capacity="spam") + + def test_digraph_infcap_edges(self): + # DiGraph with infinite capacity edges + G = nx.DiGraph() + G.add_edge("s", "a") + G.add_edge("s", "b", capacity=30) + G.add_edge("a", "c", capacity=25) + G.add_edge("b", "c", capacity=12) + G.add_edge("a", "t", capacity=60) + G.add_edge("c", "t") + + # H + # { + # "s": {"a": 85, "b": 12}, + # "a": {"c": 25, "t": 60}, + # "b": {"c": 12}, + # "c": {"t": 37}, + # "t": {}, + # } + + compare_flows_and_cuts(G, "s", "t", 97) + + # DiGraph with infinite capacity digon + G = nx.DiGraph() + G.add_edge("s", "a", capacity=85) + G.add_edge("s", "b", capacity=30) + G.add_edge("a", "c") + G.add_edge("c", "a") + G.add_edge("b", "c", capacity=12) + G.add_edge("a", "t", capacity=60) + G.add_edge("c", "t", capacity=37) + + # H + # { + # "s": {"a": 85, "b": 12}, + # "a": {"c": 25, "t": 60}, + # "c": {"a": 0, "t": 37}, + # "b": {"c": 12}, + # "t": {}, + # } + + compare_flows_and_cuts(G, "s", "t", 97) + + def test_digraph_infcap_path(self): + # Graph with infinite capacity (s, t)-path + G = nx.DiGraph() + G.add_edge("s", "a") + G.add_edge("s", "b", capacity=30) + G.add_edge("a", "c") + G.add_edge("b", "c", capacity=12) + G.add_edge("a", "t", capacity=60) + G.add_edge("c", "t") + + for flow_func in all_funcs: + pytest.raises(nx.NetworkXUnbounded, flow_func, G, "s", "t") + + def test_graph_infcap_edges(self): + # Undirected graph with infinite capacity edges + G = nx.Graph() + G.add_edge("s", "a") + G.add_edge("s", "b", capacity=30) + G.add_edge("a", "c", capacity=25) + G.add_edge("b", "c", capacity=12) + G.add_edge("a", "t", capacity=60) + G.add_edge("c", "t") + + # H + # { + # "s": {"a": 85, "b": 12}, + # "a": {"c": 25, "s": 85, "t": 60}, + # "b": {"c": 12, "s": 12}, + # "c": {"a": 25, "b": 12, "t": 37}, + # "t": {"a": 60, "c": 37}, + # } + + compare_flows_and_cuts(G, "s", "t", 97) + + def test_digraph5(self): + # From ticket #429 by mfrasca. + G = nx.DiGraph() + G.add_edge("s", "a", capacity=2) + G.add_edge("s", "b", capacity=2) + G.add_edge("a", "b", capacity=5) + G.add_edge("a", "t", capacity=1) + G.add_edge("b", "a", capacity=1) + G.add_edge("b", "t", capacity=3) + # flow solution + # { + # "a": {"b": 1, "t": 1}, + # "b": {"a": 0, "t": 3}, + # "s": {"a": 2, "b": 2}, + # "t": {}, + # } + compare_flows_and_cuts(G, "s", "t", 4) + + def test_disconnected(self): + G = nx.Graph() + G.add_weighted_edges_from([(0, 1, 1), (1, 2, 1), (2, 3, 1)], weight="capacity") + G.remove_node(1) + assert nx.maximum_flow_value(G, 0, 3) == 0 + # flow solution + # {0: {}, 2: {3: 0}, 3: {2: 0}} + compare_flows_and_cuts(G, 0, 3, 0) + + def test_source_target_not_in_graph(self): + G = nx.Graph() + G.add_weighted_edges_from([(0, 1, 1), (1, 2, 1), (2, 3, 1)], weight="capacity") + G.remove_node(0) + for flow_func in all_funcs: + pytest.raises(nx.NetworkXError, flow_func, G, 0, 3) + G.add_weighted_edges_from([(0, 1, 1), (1, 2, 1), (2, 3, 1)], weight="capacity") + G.remove_node(3) + for flow_func in all_funcs: + pytest.raises(nx.NetworkXError, flow_func, G, 0, 3) + + def test_source_target_coincide(self): + G = nx.Graph() + G.add_node(0) + for flow_func in all_funcs: + pytest.raises(nx.NetworkXError, flow_func, G, 0, 0) + + def test_multigraphs_raise(self): + G = nx.MultiGraph() + M = nx.MultiDiGraph() + G.add_edges_from([(0, 1), (1, 0)], capacity=True) + for flow_func in all_funcs: + pytest.raises(nx.NetworkXError, flow_func, G, 0, 0) + + +class TestMaxFlowMinCutInterface: + def setup_method(self): + G = nx.DiGraph() + G.add_edge("x", "a", capacity=3.0) + G.add_edge("x", "b", capacity=1.0) + G.add_edge("a", "c", capacity=3.0) + G.add_edge("b", "c", capacity=5.0) + G.add_edge("b", "d", capacity=4.0) + G.add_edge("d", "e", capacity=2.0) + G.add_edge("c", "y", capacity=2.0) + G.add_edge("e", "y", capacity=3.0) + self.G = G + H = nx.DiGraph() + H.add_edge(0, 1, capacity=1.0) + H.add_edge(1, 2, capacity=1.0) + self.H = H + + def test_flow_func_not_callable(self): + elements = ["this_should_be_callable", 10, {1, 2, 3}] + G = nx.Graph() + G.add_weighted_edges_from([(0, 1, 1), (1, 2, 1), (2, 3, 1)], weight="capacity") + for flow_func in interface_funcs: + for element in elements: + pytest.raises(nx.NetworkXError, flow_func, G, 0, 1, flow_func=element) + pytest.raises(nx.NetworkXError, flow_func, G, 0, 1, flow_func=element) + + def test_flow_func_parameters(self): + G = self.G + fv = 3.0 + for interface_func in interface_funcs: + for flow_func in flow_funcs: + errmsg = ( + f"Assertion failed in function: {flow_func.__name__} " + f"in interface {interface_func.__name__}" + ) + result = interface_func(G, "x", "y", flow_func=flow_func) + if interface_func in max_min_funcs: + result = result[0] + assert fv == result, errmsg + + def test_minimum_cut_no_cutoff(self): + G = self.G + pytest.raises( + nx.NetworkXError, + nx.minimum_cut, + G, + "x", + "y", + flow_func=preflow_push, + cutoff=1.0, + ) + pytest.raises( + nx.NetworkXError, + nx.minimum_cut_value, + G, + "x", + "y", + flow_func=preflow_push, + cutoff=1.0, + ) + + def test_kwargs(self): + G = self.H + fv = 1.0 + to_test = ( + (shortest_augmenting_path, {"two_phase": True}), + (preflow_push, {"global_relabel_freq": 5}), + ) + for interface_func in interface_funcs: + for flow_func, kwargs in to_test: + errmsg = ( + f"Assertion failed in function: {flow_func.__name__} " + f"in interface {interface_func.__name__}" + ) + result = interface_func(G, 0, 2, flow_func=flow_func, **kwargs) + if interface_func in max_min_funcs: + result = result[0] + assert fv == result, errmsg + + def test_kwargs_default_flow_func(self): + G = self.H + for interface_func in interface_funcs: + pytest.raises( + nx.NetworkXError, interface_func, G, 0, 1, global_relabel_freq=2 + ) + + def test_reusing_residual(self): + G = self.G + fv = 3.0 + s, t = "x", "y" + R = build_residual_network(G, "capacity") + for interface_func in interface_funcs: + for flow_func in flow_funcs: + errmsg = ( + f"Assertion failed in function: {flow_func.__name__} " + f"in interface {interface_func.__name__}" + ) + for i in range(3): + result = interface_func( + G, "x", "y", flow_func=flow_func, residual=R + ) + if interface_func in max_min_funcs: + result = result[0] + assert fv == result, errmsg + + +# Tests specific to one algorithm +def test_preflow_push_global_relabel_freq(): + G = nx.DiGraph() + G.add_edge(1, 2, capacity=1) + R = preflow_push(G, 1, 2, global_relabel_freq=None) + assert R.graph["flow_value"] == 1 + pytest.raises(nx.NetworkXError, preflow_push, G, 1, 2, global_relabel_freq=-1) + + +def test_preflow_push_makes_enough_space(): + # From ticket #1542 + G = nx.DiGraph() + nx.add_path(G, [0, 1, 3], capacity=1) + nx.add_path(G, [1, 2, 3], capacity=1) + R = preflow_push(G, 0, 3, value_only=False) + assert R.graph["flow_value"] == 1 + + +def test_shortest_augmenting_path_two_phase(): + k = 5 + p = 1000 + G = nx.DiGraph() + for i in range(k): + G.add_edge("s", (i, 0), capacity=1) + nx.add_path(G, ((i, j) for j in range(p)), capacity=1) + G.add_edge((i, p - 1), "t", capacity=1) + R = shortest_augmenting_path(G, "s", "t", two_phase=True) + assert R.graph["flow_value"] == k + R = shortest_augmenting_path(G, "s", "t", two_phase=False) + assert R.graph["flow_value"] == k + + +class TestCutoff: + def test_cutoff(self): + k = 5 + p = 1000 + G = nx.DiGraph() + for i in range(k): + G.add_edge("s", (i, 0), capacity=2) + nx.add_path(G, ((i, j) for j in range(p)), capacity=2) + G.add_edge((i, p - 1), "t", capacity=2) + R = shortest_augmenting_path(G, "s", "t", two_phase=True, cutoff=k) + assert k <= R.graph["flow_value"] <= (2 * k) + R = shortest_augmenting_path(G, "s", "t", two_phase=False, cutoff=k) + assert k <= R.graph["flow_value"] <= (2 * k) + R = edmonds_karp(G, "s", "t", cutoff=k) + assert k <= R.graph["flow_value"] <= (2 * k) + R = dinitz(G, "s", "t", cutoff=k) + assert k <= R.graph["flow_value"] <= (2 * k) + R = boykov_kolmogorov(G, "s", "t", cutoff=k) + assert k <= R.graph["flow_value"] <= (2 * k) + + def test_complete_graph_cutoff(self): + G = nx.complete_graph(5) + nx.set_edge_attributes(G, {(u, v): 1 for u, v in G.edges()}, "capacity") + for flow_func in [ + shortest_augmenting_path, + edmonds_karp, + dinitz, + boykov_kolmogorov, + ]: + for cutoff in [3, 2, 1]: + result = nx.maximum_flow_value( + G, 0, 4, flow_func=flow_func, cutoff=cutoff + ) + assert cutoff == result, f"cutoff error in {flow_func.__name__}" diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/tests/test_maxflow_large_graph.py b/lib/python3.10/site-packages/networkx/algorithms/flow/tests/test_maxflow_large_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..b395cbc8f1eacf09f7d9e143e796a441d8e7a0f5 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/flow/tests/test_maxflow_large_graph.py @@ -0,0 +1,156 @@ +"""Maximum flow algorithms test suite on large graphs.""" + +import bz2 +import importlib.resources +import os +import pickle + +import pytest + +import networkx as nx +from networkx.algorithms.flow import ( + boykov_kolmogorov, + build_flow_dict, + build_residual_network, + dinitz, + edmonds_karp, + preflow_push, + shortest_augmenting_path, +) + +flow_funcs = [ + boykov_kolmogorov, + dinitz, + edmonds_karp, + preflow_push, + shortest_augmenting_path, +] + + +def gen_pyramid(N): + # This graph admits a flow of value 1 for which every arc is at + # capacity (except the arcs incident to the sink which have + # infinite capacity). + G = nx.DiGraph() + + for i in range(N - 1): + cap = 1.0 / (i + 2) + for j in range(i + 1): + G.add_edge((i, j), (i + 1, j), capacity=cap) + cap = 1.0 / (i + 1) - cap + G.add_edge((i, j), (i + 1, j + 1), capacity=cap) + cap = 1.0 / (i + 2) - cap + + for j in range(N): + G.add_edge((N - 1, j), "t") + + return G + + +def read_graph(name): + fname = ( + importlib.resources.files("networkx.algorithms.flow.tests") + / f"{name}.gpickle.bz2" + ) + + with bz2.BZ2File(fname, "rb") as f: + G = pickle.load(f) + return G + + +def validate_flows(G, s, t, soln_value, R, flow_func): + flow_value = R.graph["flow_value"] + flow_dict = build_flow_dict(G, R) + errmsg = f"Assertion failed in function: {flow_func.__name__}" + assert soln_value == flow_value, errmsg + assert set(G) == set(flow_dict), errmsg + for u in G: + assert set(G[u]) == set(flow_dict[u]), errmsg + excess = {u: 0 for u in flow_dict} + for u in flow_dict: + for v, flow in flow_dict[u].items(): + assert flow <= G[u][v].get("capacity", float("inf")), errmsg + assert flow >= 0, errmsg + excess[u] -= flow + excess[v] += flow + for u, exc in excess.items(): + if u == s: + assert exc == -soln_value, errmsg + elif u == t: + assert exc == soln_value, errmsg + else: + assert exc == 0, errmsg + + +class TestMaxflowLargeGraph: + def test_complete_graph(self): + N = 50 + G = nx.complete_graph(N) + nx.set_edge_attributes(G, 5, "capacity") + R = build_residual_network(G, "capacity") + kwargs = {"residual": R} + + for flow_func in flow_funcs: + kwargs["flow_func"] = flow_func + errmsg = f"Assertion failed in function: {flow_func.__name__}" + flow_value = nx.maximum_flow_value(G, 1, 2, **kwargs) + assert flow_value == 5 * (N - 1), errmsg + + def test_pyramid(self): + N = 10 + # N = 100 # this gives a graph with 5051 nodes + G = gen_pyramid(N) + R = build_residual_network(G, "capacity") + kwargs = {"residual": R} + + for flow_func in flow_funcs: + kwargs["flow_func"] = flow_func + errmsg = f"Assertion failed in function: {flow_func.__name__}" + flow_value = nx.maximum_flow_value(G, (0, 0), "t", **kwargs) + assert flow_value == pytest.approx(1.0, abs=1e-7) + + def test_gl1(self): + G = read_graph("gl1") + s = 1 + t = len(G) + R = build_residual_network(G, "capacity") + kwargs = {"residual": R} + + # do one flow_func to save time + flow_func = flow_funcs[0] + validate_flows(G, s, t, 156545, flow_func(G, s, t, **kwargs), flow_func) + + # for flow_func in flow_funcs: + # validate_flows(G, s, t, 156545, flow_func(G, s, t, **kwargs), + # flow_func) + + @pytest.mark.slow + def test_gw1(self): + G = read_graph("gw1") + s = 1 + t = len(G) + R = build_residual_network(G, "capacity") + kwargs = {"residual": R} + + for flow_func in flow_funcs: + validate_flows(G, s, t, 1202018, flow_func(G, s, t, **kwargs), flow_func) + + def test_wlm3(self): + G = read_graph("wlm3") + s = 1 + t = len(G) + R = build_residual_network(G, "capacity") + kwargs = {"residual": R} + + # do one flow_func to save time + flow_func = flow_funcs[0] + validate_flows(G, s, t, 11875108, flow_func(G, s, t, **kwargs), flow_func) + + # for flow_func in flow_funcs: + # validate_flows(G, s, t, 11875108, flow_func(G, s, t, **kwargs), + # flow_func) + + def test_preflow_push_global_relabel(self): + G = read_graph("gw1") + R = preflow_push(G, 1, len(G), global_relabel_freq=50) + assert R.graph["flow_value"] == 1202018 diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/tests/test_mincost.py b/lib/python3.10/site-packages/networkx/algorithms/flow/tests/test_mincost.py new file mode 100644 index 0000000000000000000000000000000000000000..5b1794b152f8d5523f660d3d7bd15f841e5470f9 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/flow/tests/test_mincost.py @@ -0,0 +1,476 @@ +import bz2 +import importlib.resources +import os +import pickle + +import pytest + +import networkx as nx + + +class TestMinCostFlow: + def test_simple_digraph(self): + G = nx.DiGraph() + G.add_node("a", demand=-5) + G.add_node("d", demand=5) + G.add_edge("a", "b", weight=3, capacity=4) + G.add_edge("a", "c", weight=6, capacity=10) + G.add_edge("b", "d", weight=1, capacity=9) + G.add_edge("c", "d", weight=2, capacity=5) + flowCost, H = nx.network_simplex(G) + soln = {"a": {"b": 4, "c": 1}, "b": {"d": 4}, "c": {"d": 1}, "d": {}} + assert flowCost == 24 + assert nx.min_cost_flow_cost(G) == 24 + assert H == soln + assert nx.min_cost_flow(G) == soln + assert nx.cost_of_flow(G, H) == 24 + + flowCost, H = nx.capacity_scaling(G) + assert flowCost == 24 + assert nx.cost_of_flow(G, H) == 24 + assert H == soln + + def test_negcycle_infcap(self): + G = nx.DiGraph() + G.add_node("s", demand=-5) + G.add_node("t", demand=5) + G.add_edge("s", "a", weight=1, capacity=3) + G.add_edge("a", "b", weight=3) + G.add_edge("c", "a", weight=-6) + G.add_edge("b", "d", weight=1) + G.add_edge("d", "c", weight=-2) + G.add_edge("d", "t", weight=1, capacity=3) + pytest.raises(nx.NetworkXUnfeasible, nx.network_simplex, G) + pytest.raises(nx.NetworkXUnbounded, nx.capacity_scaling, G) + + def test_sum_demands_not_zero(self): + G = nx.DiGraph() + G.add_node("s", demand=-5) + G.add_node("t", demand=4) + G.add_edge("s", "a", weight=1, capacity=3) + G.add_edge("a", "b", weight=3) + G.add_edge("a", "c", weight=-6) + G.add_edge("b", "d", weight=1) + G.add_edge("c", "d", weight=-2) + G.add_edge("d", "t", weight=1, capacity=3) + pytest.raises(nx.NetworkXUnfeasible, nx.network_simplex, G) + pytest.raises(nx.NetworkXUnfeasible, nx.capacity_scaling, G) + + def test_no_flow_satisfying_demands(self): + G = nx.DiGraph() + G.add_node("s", demand=-5) + G.add_node("t", demand=5) + G.add_edge("s", "a", weight=1, capacity=3) + G.add_edge("a", "b", weight=3) + G.add_edge("a", "c", weight=-6) + G.add_edge("b", "d", weight=1) + G.add_edge("c", "d", weight=-2) + G.add_edge("d", "t", weight=1, capacity=3) + pytest.raises(nx.NetworkXUnfeasible, nx.network_simplex, G) + pytest.raises(nx.NetworkXUnfeasible, nx.capacity_scaling, G) + + def test_transshipment(self): + G = nx.DiGraph() + G.add_node("a", demand=1) + G.add_node("b", demand=-2) + G.add_node("c", demand=-2) + G.add_node("d", demand=3) + G.add_node("e", demand=-4) + G.add_node("f", demand=-4) + G.add_node("g", demand=3) + G.add_node("h", demand=2) + G.add_node("r", demand=3) + G.add_edge("a", "c", weight=3) + G.add_edge("r", "a", weight=2) + G.add_edge("b", "a", weight=9) + G.add_edge("r", "c", weight=0) + G.add_edge("b", "r", weight=-6) + G.add_edge("c", "d", weight=5) + G.add_edge("e", "r", weight=4) + G.add_edge("e", "f", weight=3) + G.add_edge("h", "b", weight=4) + G.add_edge("f", "d", weight=7) + G.add_edge("f", "h", weight=12) + G.add_edge("g", "d", weight=12) + G.add_edge("f", "g", weight=-1) + G.add_edge("h", "g", weight=-10) + flowCost, H = nx.network_simplex(G) + soln = { + "a": {"c": 0}, + "b": {"a": 0, "r": 2}, + "c": {"d": 3}, + "d": {}, + "e": {"r": 3, "f": 1}, + "f": {"d": 0, "g": 3, "h": 2}, + "g": {"d": 0}, + "h": {"b": 0, "g": 0}, + "r": {"a": 1, "c": 1}, + } + assert flowCost == 41 + assert nx.min_cost_flow_cost(G) == 41 + assert H == soln + assert nx.min_cost_flow(G) == soln + assert nx.cost_of_flow(G, H) == 41 + + flowCost, H = nx.capacity_scaling(G) + assert flowCost == 41 + assert nx.cost_of_flow(G, H) == 41 + assert H == soln + + def test_max_flow_min_cost(self): + G = nx.DiGraph() + G.add_edge("s", "a", bandwidth=6) + G.add_edge("s", "c", bandwidth=10, cost=10) + G.add_edge("a", "b", cost=6) + G.add_edge("b", "d", bandwidth=8, cost=7) + G.add_edge("c", "d", cost=10) + G.add_edge("d", "t", bandwidth=5, cost=5) + soln = { + "s": {"a": 5, "c": 0}, + "a": {"b": 5}, + "b": {"d": 5}, + "c": {"d": 0}, + "d": {"t": 5}, + "t": {}, + } + flow = nx.max_flow_min_cost(G, "s", "t", capacity="bandwidth", weight="cost") + assert flow == soln + assert nx.cost_of_flow(G, flow, weight="cost") == 90 + + G.add_edge("t", "s", cost=-100) + flowCost, flow = nx.capacity_scaling(G, capacity="bandwidth", weight="cost") + G.remove_edge("t", "s") + assert flowCost == -410 + assert flow["t"]["s"] == 5 + del flow["t"]["s"] + assert flow == soln + assert nx.cost_of_flow(G, flow, weight="cost") == 90 + + def test_digraph1(self): + # From Bradley, S. P., Hax, A. C. and Magnanti, T. L. Applied + # Mathematical Programming. Addison-Wesley, 1977. + G = nx.DiGraph() + G.add_node(1, demand=-20) + G.add_node(4, demand=5) + G.add_node(5, demand=15) + G.add_edges_from( + [ + (1, 2, {"capacity": 15, "weight": 4}), + (1, 3, {"capacity": 8, "weight": 4}), + (2, 3, {"weight": 2}), + (2, 4, {"capacity": 4, "weight": 2}), + (2, 5, {"capacity": 10, "weight": 6}), + (3, 4, {"capacity": 15, "weight": 1}), + (3, 5, {"capacity": 5, "weight": 3}), + (4, 5, {"weight": 2}), + (5, 3, {"capacity": 4, "weight": 1}), + ] + ) + flowCost, H = nx.network_simplex(G) + soln = { + 1: {2: 12, 3: 8}, + 2: {3: 8, 4: 4, 5: 0}, + 3: {4: 11, 5: 5}, + 4: {5: 10}, + 5: {3: 0}, + } + assert flowCost == 150 + assert nx.min_cost_flow_cost(G) == 150 + assert H == soln + assert nx.min_cost_flow(G) == soln + assert nx.cost_of_flow(G, H) == 150 + + flowCost, H = nx.capacity_scaling(G) + assert flowCost == 150 + assert H == soln + assert nx.cost_of_flow(G, H) == 150 + + def test_digraph2(self): + # Example from ticket #430 from mfrasca. Original source: + # http://www.cs.princeton.edu/courses/archive/spr03/cs226/lectures/mincost.4up.pdf, slide 11. + G = nx.DiGraph() + G.add_edge("s", 1, capacity=12) + G.add_edge("s", 2, capacity=6) + G.add_edge("s", 3, capacity=14) + G.add_edge(1, 2, capacity=11, weight=4) + G.add_edge(2, 3, capacity=9, weight=6) + G.add_edge(1, 4, capacity=5, weight=5) + G.add_edge(1, 5, capacity=2, weight=12) + G.add_edge(2, 5, capacity=4, weight=4) + G.add_edge(2, 6, capacity=2, weight=6) + G.add_edge(3, 6, capacity=31, weight=3) + G.add_edge(4, 5, capacity=18, weight=4) + G.add_edge(5, 6, capacity=9, weight=5) + G.add_edge(4, "t", capacity=3) + G.add_edge(5, "t", capacity=7) + G.add_edge(6, "t", capacity=22) + flow = nx.max_flow_min_cost(G, "s", "t") + soln = { + 1: {2: 6, 4: 5, 5: 1}, + 2: {3: 6, 5: 4, 6: 2}, + 3: {6: 20}, + 4: {5: 2, "t": 3}, + 5: {6: 0, "t": 7}, + 6: {"t": 22}, + "s": {1: 12, 2: 6, 3: 14}, + "t": {}, + } + assert flow == soln + + G.add_edge("t", "s", weight=-100) + flowCost, flow = nx.capacity_scaling(G) + G.remove_edge("t", "s") + assert flow["t"]["s"] == 32 + assert flowCost == -3007 + del flow["t"]["s"] + assert flow == soln + assert nx.cost_of_flow(G, flow) == 193 + + def test_digraph3(self): + """Combinatorial Optimization: Algorithms and Complexity, + Papadimitriou Steiglitz at page 140 has an example, 7.1, but that + admits multiple solutions, so I alter it a bit. From ticket #430 + by mfrasca.""" + + G = nx.DiGraph() + G.add_edge("s", "a") + G["s"]["a"].update({0: 2, 1: 4}) + G.add_edge("s", "b") + G["s"]["b"].update({0: 2, 1: 1}) + G.add_edge("a", "b") + G["a"]["b"].update({0: 5, 1: 2}) + G.add_edge("a", "t") + G["a"]["t"].update({0: 1, 1: 5}) + G.add_edge("b", "a") + G["b"]["a"].update({0: 1, 1: 3}) + G.add_edge("b", "t") + G["b"]["t"].update({0: 3, 1: 2}) + + "PS.ex.7.1: testing main function" + sol = nx.max_flow_min_cost(G, "s", "t", capacity=0, weight=1) + flow = sum(v for v in sol["s"].values()) + assert 4 == flow + assert 23 == nx.cost_of_flow(G, sol, weight=1) + assert sol["s"] == {"a": 2, "b": 2} + assert sol["a"] == {"b": 1, "t": 1} + assert sol["b"] == {"a": 0, "t": 3} + assert sol["t"] == {} + + G.add_edge("t", "s") + G["t"]["s"].update({1: -100}) + flowCost, sol = nx.capacity_scaling(G, capacity=0, weight=1) + G.remove_edge("t", "s") + flow = sum(v for v in sol["s"].values()) + assert 4 == flow + assert sol["t"]["s"] == 4 + assert flowCost == -377 + del sol["t"]["s"] + assert sol["s"] == {"a": 2, "b": 2} + assert sol["a"] == {"b": 1, "t": 1} + assert sol["b"] == {"a": 0, "t": 3} + assert sol["t"] == {} + assert nx.cost_of_flow(G, sol, weight=1) == 23 + + def test_zero_capacity_edges(self): + """Address issue raised in ticket #617 by arv.""" + G = nx.DiGraph() + G.add_edges_from( + [ + (1, 2, {"capacity": 1, "weight": 1}), + (1, 5, {"capacity": 1, "weight": 1}), + (2, 3, {"capacity": 0, "weight": 1}), + (2, 5, {"capacity": 1, "weight": 1}), + (5, 3, {"capacity": 2, "weight": 1}), + (5, 4, {"capacity": 0, "weight": 1}), + (3, 4, {"capacity": 2, "weight": 1}), + ] + ) + G.nodes[1]["demand"] = -1 + G.nodes[2]["demand"] = -1 + G.nodes[4]["demand"] = 2 + + flowCost, H = nx.network_simplex(G) + soln = {1: {2: 0, 5: 1}, 2: {3: 0, 5: 1}, 3: {4: 2}, 4: {}, 5: {3: 2, 4: 0}} + assert flowCost == 6 + assert nx.min_cost_flow_cost(G) == 6 + assert H == soln + assert nx.min_cost_flow(G) == soln + assert nx.cost_of_flow(G, H) == 6 + + flowCost, H = nx.capacity_scaling(G) + assert flowCost == 6 + assert H == soln + assert nx.cost_of_flow(G, H) == 6 + + def test_digon(self): + """Check if digons are handled properly. Taken from ticket + #618 by arv.""" + nodes = [(1, {}), (2, {"demand": -4}), (3, {"demand": 4})] + edges = [ + (1, 2, {"capacity": 3, "weight": 600000}), + (2, 1, {"capacity": 2, "weight": 0}), + (2, 3, {"capacity": 5, "weight": 714285}), + (3, 2, {"capacity": 2, "weight": 0}), + ] + G = nx.DiGraph(edges) + G.add_nodes_from(nodes) + flowCost, H = nx.network_simplex(G) + soln = {1: {2: 0}, 2: {1: 0, 3: 4}, 3: {2: 0}} + assert flowCost == 2857140 + assert nx.min_cost_flow_cost(G) == 2857140 + assert H == soln + assert nx.min_cost_flow(G) == soln + assert nx.cost_of_flow(G, H) == 2857140 + + flowCost, H = nx.capacity_scaling(G) + assert flowCost == 2857140 + assert H == soln + assert nx.cost_of_flow(G, H) == 2857140 + + def test_deadend(self): + """Check if one-node cycles are handled properly. Taken from ticket + #2906 from @sshraven.""" + G = nx.DiGraph() + + G.add_nodes_from(range(5), demand=0) + G.nodes[4]["demand"] = -13 + G.nodes[3]["demand"] = 13 + + G.add_edges_from([(0, 2), (0, 3), (2, 1)], capacity=20, weight=0.1) + pytest.raises(nx.NetworkXUnfeasible, nx.min_cost_flow, G) + + def test_infinite_capacity_neg_digon(self): + """An infinite capacity negative cost digon results in an unbounded + instance.""" + nodes = [(1, {}), (2, {"demand": -4}), (3, {"demand": 4})] + edges = [ + (1, 2, {"weight": -600}), + (2, 1, {"weight": 0}), + (2, 3, {"capacity": 5, "weight": 714285}), + (3, 2, {"capacity": 2, "weight": 0}), + ] + G = nx.DiGraph(edges) + G.add_nodes_from(nodes) + pytest.raises(nx.NetworkXUnbounded, nx.network_simplex, G) + pytest.raises(nx.NetworkXUnbounded, nx.capacity_scaling, G) + + def test_finite_capacity_neg_digon(self): + """The digon should receive the maximum amount of flow it can handle. + Taken from ticket #749 by @chuongdo.""" + G = nx.DiGraph() + G.add_edge("a", "b", capacity=1, weight=-1) + G.add_edge("b", "a", capacity=1, weight=-1) + min_cost = -2 + assert nx.min_cost_flow_cost(G) == min_cost + + flowCost, H = nx.capacity_scaling(G) + assert flowCost == -2 + assert H == {"a": {"b": 1}, "b": {"a": 1}} + assert nx.cost_of_flow(G, H) == -2 + + def test_multidigraph(self): + """Multidigraphs are acceptable.""" + G = nx.MultiDiGraph() + G.add_weighted_edges_from([(1, 2, 1), (2, 3, 2)], weight="capacity") + flowCost, H = nx.network_simplex(G) + assert flowCost == 0 + assert H == {1: {2: {0: 0}}, 2: {3: {0: 0}}, 3: {}} + + flowCost, H = nx.capacity_scaling(G) + assert flowCost == 0 + assert H == {1: {2: {0: 0}}, 2: {3: {0: 0}}, 3: {}} + + def test_negative_selfloops(self): + """Negative selfloops should cause an exception if uncapacitated and + always be saturated otherwise. + """ + G = nx.DiGraph() + G.add_edge(1, 1, weight=-1) + pytest.raises(nx.NetworkXUnbounded, nx.network_simplex, G) + pytest.raises(nx.NetworkXUnbounded, nx.capacity_scaling, G) + G[1][1]["capacity"] = 2 + flowCost, H = nx.network_simplex(G) + assert flowCost == -2 + assert H == {1: {1: 2}} + flowCost, H = nx.capacity_scaling(G) + assert flowCost == -2 + assert H == {1: {1: 2}} + + G = nx.MultiDiGraph() + G.add_edge(1, 1, "x", weight=-1) + G.add_edge(1, 1, "y", weight=1) + pytest.raises(nx.NetworkXUnbounded, nx.network_simplex, G) + pytest.raises(nx.NetworkXUnbounded, nx.capacity_scaling, G) + G[1][1]["x"]["capacity"] = 2 + flowCost, H = nx.network_simplex(G) + assert flowCost == -2 + assert H == {1: {1: {"x": 2, "y": 0}}} + flowCost, H = nx.capacity_scaling(G) + assert flowCost == -2 + assert H == {1: {1: {"x": 2, "y": 0}}} + + def test_bone_shaped(self): + # From #1283 + G = nx.DiGraph() + G.add_node(0, demand=-4) + G.add_node(1, demand=2) + G.add_node(2, demand=2) + G.add_node(3, demand=4) + G.add_node(4, demand=-2) + G.add_node(5, demand=-2) + G.add_edge(0, 1, capacity=4) + G.add_edge(0, 2, capacity=4) + G.add_edge(4, 3, capacity=4) + G.add_edge(5, 3, capacity=4) + G.add_edge(0, 3, capacity=0) + flowCost, H = nx.network_simplex(G) + assert flowCost == 0 + assert H == {0: {1: 2, 2: 2, 3: 0}, 1: {}, 2: {}, 3: {}, 4: {3: 2}, 5: {3: 2}} + flowCost, H = nx.capacity_scaling(G) + assert flowCost == 0 + assert H == {0: {1: 2, 2: 2, 3: 0}, 1: {}, 2: {}, 3: {}, 4: {3: 2}, 5: {3: 2}} + + def test_exceptions(self): + G = nx.Graph() + pytest.raises(nx.NetworkXNotImplemented, nx.network_simplex, G) + pytest.raises(nx.NetworkXNotImplemented, nx.capacity_scaling, G) + G = nx.MultiGraph() + pytest.raises(nx.NetworkXNotImplemented, nx.network_simplex, G) + pytest.raises(nx.NetworkXNotImplemented, nx.capacity_scaling, G) + G = nx.DiGraph() + pytest.raises(nx.NetworkXError, nx.network_simplex, G) + # pytest.raises(nx.NetworkXError, nx.capacity_scaling, G) + G.add_node(0, demand=float("inf")) + pytest.raises(nx.NetworkXError, nx.network_simplex, G) + pytest.raises(nx.NetworkXUnfeasible, nx.capacity_scaling, G) + G.nodes[0]["demand"] = 0 + G.add_node(1, demand=0) + G.add_edge(0, 1, weight=-float("inf")) + pytest.raises(nx.NetworkXError, nx.network_simplex, G) + pytest.raises(nx.NetworkXUnfeasible, nx.capacity_scaling, G) + G[0][1]["weight"] = 0 + G.add_edge(0, 0, weight=float("inf")) + pytest.raises(nx.NetworkXError, nx.network_simplex, G) + # pytest.raises(nx.NetworkXError, nx.capacity_scaling, G) + G[0][0]["weight"] = 0 + G[0][1]["capacity"] = -1 + pytest.raises(nx.NetworkXUnfeasible, nx.network_simplex, G) + # pytest.raises(nx.NetworkXUnfeasible, nx.capacity_scaling, G) + G[0][1]["capacity"] = 0 + G[0][0]["capacity"] = -1 + pytest.raises(nx.NetworkXUnfeasible, nx.network_simplex, G) + # pytest.raises(nx.NetworkXUnfeasible, nx.capacity_scaling, G) + + def test_large(self): + fname = ( + importlib.resources.files("networkx.algorithms.flow.tests") + / "netgen-2.gpickle.bz2" + ) + with bz2.BZ2File(fname, "rb") as f: + G = pickle.load(f) + flowCost, flowDict = nx.network_simplex(G) + assert 6749969302 == flowCost + assert 6749969302 == nx.cost_of_flow(G, flowDict) + flowCost, flowDict = nx.capacity_scaling(G) + assert 6749969302 == flowCost + assert 6749969302 == nx.cost_of_flow(G, flowDict) diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/tests/test_networksimplex.py b/lib/python3.10/site-packages/networkx/algorithms/flow/tests/test_networksimplex.py new file mode 100644 index 0000000000000000000000000000000000000000..5b3b5f6dd069e38b8be536cf4037a46da2366cc2 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/flow/tests/test_networksimplex.py @@ -0,0 +1,387 @@ +import bz2 +import importlib.resources +import os +import pickle + +import pytest + +import networkx as nx + + +@pytest.fixture +def simple_flow_graph(): + G = nx.DiGraph() + G.add_node("a", demand=0) + G.add_node("b", demand=-5) + G.add_node("c", demand=50000000) + G.add_node("d", demand=-49999995) + G.add_edge("a", "b", weight=3, capacity=4) + G.add_edge("a", "c", weight=6, capacity=10) + G.add_edge("b", "d", weight=1, capacity=9) + G.add_edge("c", "d", weight=2, capacity=5) + return G + + +@pytest.fixture +def simple_no_flow_graph(): + G = nx.DiGraph() + G.add_node("s", demand=-5) + G.add_node("t", demand=5) + G.add_edge("s", "a", weight=1, capacity=3) + G.add_edge("a", "b", weight=3) + G.add_edge("a", "c", weight=-6) + G.add_edge("b", "d", weight=1) + G.add_edge("c", "d", weight=-2) + G.add_edge("d", "t", weight=1, capacity=3) + return G + + +def get_flowcost_from_flowdict(G, flowDict): + """Returns flow cost calculated from flow dictionary""" + flowCost = 0 + for u in flowDict: + for v in flowDict[u]: + flowCost += flowDict[u][v] * G[u][v]["weight"] + return flowCost + + +def test_infinite_demand_raise(simple_flow_graph): + G = simple_flow_graph + inf = float("inf") + nx.set_node_attributes(G, {"a": {"demand": inf}}) + pytest.raises(nx.NetworkXError, nx.network_simplex, G) + + +def test_neg_infinite_demand_raise(simple_flow_graph): + G = simple_flow_graph + inf = float("inf") + nx.set_node_attributes(G, {"a": {"demand": -inf}}) + pytest.raises(nx.NetworkXError, nx.network_simplex, G) + + +def test_infinite_weight_raise(simple_flow_graph): + G = simple_flow_graph + inf = float("inf") + nx.set_edge_attributes( + G, {("a", "b"): {"weight": inf}, ("b", "d"): {"weight": inf}} + ) + pytest.raises(nx.NetworkXError, nx.network_simplex, G) + + +def test_nonzero_net_demand_raise(simple_flow_graph): + G = simple_flow_graph + nx.set_node_attributes(G, {"b": {"demand": -4}}) + pytest.raises(nx.NetworkXUnfeasible, nx.network_simplex, G) + + +def test_negative_capacity_raise(simple_flow_graph): + G = simple_flow_graph + nx.set_edge_attributes(G, {("a", "b"): {"weight": 1}, ("b", "d"): {"capacity": -9}}) + pytest.raises(nx.NetworkXUnfeasible, nx.network_simplex, G) + + +def test_no_flow_satisfying_demands(simple_no_flow_graph): + G = simple_no_flow_graph + pytest.raises(nx.NetworkXUnfeasible, nx.network_simplex, G) + + +def test_sum_demands_not_zero(simple_no_flow_graph): + G = simple_no_flow_graph + nx.set_node_attributes(G, {"t": {"demand": 4}}) + pytest.raises(nx.NetworkXUnfeasible, nx.network_simplex, G) + + +def test_google_or_tools_example(): + """ + https://developers.google.com/optimization/flow/mincostflow + """ + G = nx.DiGraph() + start_nodes = [0, 0, 1, 1, 1, 2, 2, 3, 4] + end_nodes = [1, 2, 2, 3, 4, 3, 4, 4, 2] + capacities = [15, 8, 20, 4, 10, 15, 4, 20, 5] + unit_costs = [4, 4, 2, 2, 6, 1, 3, 2, 3] + supplies = [20, 0, 0, -5, -15] + answer = 150 + + for i in range(len(supplies)): + G.add_node(i, demand=(-1) * supplies[i]) # supplies are negative of demand + + for i in range(len(start_nodes)): + G.add_edge( + start_nodes[i], end_nodes[i], weight=unit_costs[i], capacity=capacities[i] + ) + + flowCost, flowDict = nx.network_simplex(G) + assert flowCost == answer + assert flowCost == get_flowcost_from_flowdict(G, flowDict) + + +def test_google_or_tools_example2(): + """ + https://developers.google.com/optimization/flow/mincostflow + """ + G = nx.DiGraph() + start_nodes = [0, 0, 1, 1, 1, 2, 2, 3, 4, 3] + end_nodes = [1, 2, 2, 3, 4, 3, 4, 4, 2, 5] + capacities = [15, 8, 20, 4, 10, 15, 4, 20, 5, 10] + unit_costs = [4, 4, 2, 2, 6, 1, 3, 2, 3, 4] + supplies = [23, 0, 0, -5, -15, -3] + answer = 183 + + for i in range(len(supplies)): + G.add_node(i, demand=(-1) * supplies[i]) # supplies are negative of demand + + for i in range(len(start_nodes)): + G.add_edge( + start_nodes[i], end_nodes[i], weight=unit_costs[i], capacity=capacities[i] + ) + + flowCost, flowDict = nx.network_simplex(G) + assert flowCost == answer + assert flowCost == get_flowcost_from_flowdict(G, flowDict) + + +def test_large(): + fname = ( + importlib.resources.files("networkx.algorithms.flow.tests") + / "netgen-2.gpickle.bz2" + ) + + with bz2.BZ2File(fname, "rb") as f: + G = pickle.load(f) + flowCost, flowDict = nx.network_simplex(G) + assert 6749969302 == flowCost + assert 6749969302 == nx.cost_of_flow(G, flowDict) + + +def test_simple_digraph(): + G = nx.DiGraph() + G.add_node("a", demand=-5) + G.add_node("d", demand=5) + G.add_edge("a", "b", weight=3, capacity=4) + G.add_edge("a", "c", weight=6, capacity=10) + G.add_edge("b", "d", weight=1, capacity=9) + G.add_edge("c", "d", weight=2, capacity=5) + flowCost, H = nx.network_simplex(G) + soln = {"a": {"b": 4, "c": 1}, "b": {"d": 4}, "c": {"d": 1}, "d": {}} + assert flowCost == 24 + assert nx.min_cost_flow_cost(G) == 24 + assert H == soln + + +def test_negcycle_infcap(): + G = nx.DiGraph() + G.add_node("s", demand=-5) + G.add_node("t", demand=5) + G.add_edge("s", "a", weight=1, capacity=3) + G.add_edge("a", "b", weight=3) + G.add_edge("c", "a", weight=-6) + G.add_edge("b", "d", weight=1) + G.add_edge("d", "c", weight=-2) + G.add_edge("d", "t", weight=1, capacity=3) + pytest.raises(nx.NetworkXUnfeasible, nx.network_simplex, G) + + +def test_transshipment(): + G = nx.DiGraph() + G.add_node("a", demand=1) + G.add_node("b", demand=-2) + G.add_node("c", demand=-2) + G.add_node("d", demand=3) + G.add_node("e", demand=-4) + G.add_node("f", demand=-4) + G.add_node("g", demand=3) + G.add_node("h", demand=2) + G.add_node("r", demand=3) + G.add_edge("a", "c", weight=3) + G.add_edge("r", "a", weight=2) + G.add_edge("b", "a", weight=9) + G.add_edge("r", "c", weight=0) + G.add_edge("b", "r", weight=-6) + G.add_edge("c", "d", weight=5) + G.add_edge("e", "r", weight=4) + G.add_edge("e", "f", weight=3) + G.add_edge("h", "b", weight=4) + G.add_edge("f", "d", weight=7) + G.add_edge("f", "h", weight=12) + G.add_edge("g", "d", weight=12) + G.add_edge("f", "g", weight=-1) + G.add_edge("h", "g", weight=-10) + flowCost, H = nx.network_simplex(G) + soln = { + "a": {"c": 0}, + "b": {"a": 0, "r": 2}, + "c": {"d": 3}, + "d": {}, + "e": {"r": 3, "f": 1}, + "f": {"d": 0, "g": 3, "h": 2}, + "g": {"d": 0}, + "h": {"b": 0, "g": 0}, + "r": {"a": 1, "c": 1}, + } + assert flowCost == 41 + assert H == soln + + +def test_digraph1(): + # From Bradley, S. P., Hax, A. C. and Magnanti, T. L. Applied + # Mathematical Programming. Addison-Wesley, 1977. + G = nx.DiGraph() + G.add_node(1, demand=-20) + G.add_node(4, demand=5) + G.add_node(5, demand=15) + G.add_edges_from( + [ + (1, 2, {"capacity": 15, "weight": 4}), + (1, 3, {"capacity": 8, "weight": 4}), + (2, 3, {"weight": 2}), + (2, 4, {"capacity": 4, "weight": 2}), + (2, 5, {"capacity": 10, "weight": 6}), + (3, 4, {"capacity": 15, "weight": 1}), + (3, 5, {"capacity": 5, "weight": 3}), + (4, 5, {"weight": 2}), + (5, 3, {"capacity": 4, "weight": 1}), + ] + ) + flowCost, H = nx.network_simplex(G) + soln = { + 1: {2: 12, 3: 8}, + 2: {3: 8, 4: 4, 5: 0}, + 3: {4: 11, 5: 5}, + 4: {5: 10}, + 5: {3: 0}, + } + assert flowCost == 150 + assert nx.min_cost_flow_cost(G) == 150 + assert H == soln + + +def test_zero_capacity_edges(): + """Address issue raised in ticket #617 by arv.""" + G = nx.DiGraph() + G.add_edges_from( + [ + (1, 2, {"capacity": 1, "weight": 1}), + (1, 5, {"capacity": 1, "weight": 1}), + (2, 3, {"capacity": 0, "weight": 1}), + (2, 5, {"capacity": 1, "weight": 1}), + (5, 3, {"capacity": 2, "weight": 1}), + (5, 4, {"capacity": 0, "weight": 1}), + (3, 4, {"capacity": 2, "weight": 1}), + ] + ) + G.nodes[1]["demand"] = -1 + G.nodes[2]["demand"] = -1 + G.nodes[4]["demand"] = 2 + + flowCost, H = nx.network_simplex(G) + soln = {1: {2: 0, 5: 1}, 2: {3: 0, 5: 1}, 3: {4: 2}, 4: {}, 5: {3: 2, 4: 0}} + assert flowCost == 6 + assert nx.min_cost_flow_cost(G) == 6 + assert H == soln + + +def test_digon(): + """Check if digons are handled properly. Taken from ticket + #618 by arv.""" + nodes = [(1, {}), (2, {"demand": -4}), (3, {"demand": 4})] + edges = [ + (1, 2, {"capacity": 3, "weight": 600000}), + (2, 1, {"capacity": 2, "weight": 0}), + (2, 3, {"capacity": 5, "weight": 714285}), + (3, 2, {"capacity": 2, "weight": 0}), + ] + G = nx.DiGraph(edges) + G.add_nodes_from(nodes) + flowCost, H = nx.network_simplex(G) + soln = {1: {2: 0}, 2: {1: 0, 3: 4}, 3: {2: 0}} + assert flowCost == 2857140 + + +def test_deadend(): + """Check if one-node cycles are handled properly. Taken from ticket + #2906 from @sshraven.""" + G = nx.DiGraph() + + G.add_nodes_from(range(5), demand=0) + G.nodes[4]["demand"] = -13 + G.nodes[3]["demand"] = 13 + + G.add_edges_from([(0, 2), (0, 3), (2, 1)], capacity=20, weight=0.1) + pytest.raises(nx.NetworkXUnfeasible, nx.network_simplex, G) + + +def test_infinite_capacity_neg_digon(): + """An infinite capacity negative cost digon results in an unbounded + instance.""" + nodes = [(1, {}), (2, {"demand": -4}), (3, {"demand": 4})] + edges = [ + (1, 2, {"weight": -600}), + (2, 1, {"weight": 0}), + (2, 3, {"capacity": 5, "weight": 714285}), + (3, 2, {"capacity": 2, "weight": 0}), + ] + G = nx.DiGraph(edges) + G.add_nodes_from(nodes) + pytest.raises(nx.NetworkXUnbounded, nx.network_simplex, G) + + +def test_multidigraph(): + """Multidigraphs are acceptable.""" + G = nx.MultiDiGraph() + G.add_weighted_edges_from([(1, 2, 1), (2, 3, 2)], weight="capacity") + flowCost, H = nx.network_simplex(G) + assert flowCost == 0 + assert H == {1: {2: {0: 0}}, 2: {3: {0: 0}}, 3: {}} + + +def test_negative_selfloops(): + """Negative selfloops should cause an exception if uncapacitated and + always be saturated otherwise. + """ + G = nx.DiGraph() + G.add_edge(1, 1, weight=-1) + pytest.raises(nx.NetworkXUnbounded, nx.network_simplex, G) + + G[1][1]["capacity"] = 2 + flowCost, H = nx.network_simplex(G) + assert flowCost == -2 + assert H == {1: {1: 2}} + + G = nx.MultiDiGraph() + G.add_edge(1, 1, "x", weight=-1) + G.add_edge(1, 1, "y", weight=1) + pytest.raises(nx.NetworkXUnbounded, nx.network_simplex, G) + + G[1][1]["x"]["capacity"] = 2 + flowCost, H = nx.network_simplex(G) + assert flowCost == -2 + assert H == {1: {1: {"x": 2, "y": 0}}} + + +def test_bone_shaped(): + # From #1283 + G = nx.DiGraph() + G.add_node(0, demand=-4) + G.add_node(1, demand=2) + G.add_node(2, demand=2) + G.add_node(3, demand=4) + G.add_node(4, demand=-2) + G.add_node(5, demand=-2) + G.add_edge(0, 1, capacity=4) + G.add_edge(0, 2, capacity=4) + G.add_edge(4, 3, capacity=4) + G.add_edge(5, 3, capacity=4) + G.add_edge(0, 3, capacity=0) + flowCost, H = nx.network_simplex(G) + assert flowCost == 0 + assert H == {0: {1: 2, 2: 2, 3: 0}, 1: {}, 2: {}, 3: {}, 4: {3: 2}, 5: {3: 2}} + + +def test_graphs_type_exceptions(): + G = nx.Graph() + pytest.raises(nx.NetworkXNotImplemented, nx.network_simplex, G) + G = nx.MultiGraph() + pytest.raises(nx.NetworkXNotImplemented, nx.network_simplex, G) + G = nx.DiGraph() + pytest.raises(nx.NetworkXError, nx.network_simplex, G) diff --git a/lib/python3.10/site-packages/networkx/algorithms/flow/utils.py b/lib/python3.10/site-packages/networkx/algorithms/flow/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..03f1d10f75a9f2b3d80bdea8b4c086cebe1df966 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/flow/utils.py @@ -0,0 +1,189 @@ +""" +Utility classes and functions for network flow algorithms. +""" + +from collections import deque + +import networkx as nx + +__all__ = [ + "CurrentEdge", + "Level", + "GlobalRelabelThreshold", + "build_residual_network", + "detect_unboundedness", + "build_flow_dict", +] + + +class CurrentEdge: + """Mechanism for iterating over out-edges incident to a node in a circular + manner. StopIteration exception is raised when wraparound occurs. + """ + + __slots__ = ("_edges", "_it", "_curr") + + def __init__(self, edges): + self._edges = edges + if self._edges: + self._rewind() + + def get(self): + return self._curr + + def move_to_next(self): + try: + self._curr = next(self._it) + except StopIteration: + self._rewind() + raise + + def _rewind(self): + self._it = iter(self._edges.items()) + self._curr = next(self._it) + + +class Level: + """Active and inactive nodes in a level.""" + + __slots__ = ("active", "inactive") + + def __init__(self): + self.active = set() + self.inactive = set() + + +class GlobalRelabelThreshold: + """Measurement of work before the global relabeling heuristic should be + applied. + """ + + def __init__(self, n, m, freq): + self._threshold = (n + m) / freq if freq else float("inf") + self._work = 0 + + def add_work(self, work): + self._work += work + + def is_reached(self): + return self._work >= self._threshold + + def clear_work(self): + self._work = 0 + + +@nx._dispatchable(edge_attrs={"capacity": float("inf")}, returns_graph=True) +def build_residual_network(G, capacity): + """Build a residual network and initialize a zero flow. + + The residual network :samp:`R` from an input graph :samp:`G` has the + same nodes as :samp:`G`. :samp:`R` is a DiGraph that contains a pair + of edges :samp:`(u, v)` and :samp:`(v, u)` iff :samp:`(u, v)` is not a + self-loop, and at least one of :samp:`(u, v)` and :samp:`(v, u)` exists + in :samp:`G`. + + For each edge :samp:`(u, v)` in :samp:`R`, :samp:`R[u][v]['capacity']` + is equal to the capacity of :samp:`(u, v)` in :samp:`G` if it exists + in :samp:`G` or zero otherwise. If the capacity is infinite, + :samp:`R[u][v]['capacity']` will have a high arbitrary finite value + that does not affect the solution of the problem. This value is stored in + :samp:`R.graph['inf']`. For each edge :samp:`(u, v)` in :samp:`R`, + :samp:`R[u][v]['flow']` represents the flow function of :samp:`(u, v)` and + satisfies :samp:`R[u][v]['flow'] == -R[v][u]['flow']`. + + The flow value, defined as the total flow into :samp:`t`, the sink, is + stored in :samp:`R.graph['flow_value']`. If :samp:`cutoff` is not + specified, reachability to :samp:`t` using only edges :samp:`(u, v)` such + that :samp:`R[u][v]['flow'] < R[u][v]['capacity']` induces a minimum + :samp:`s`-:samp:`t` cut. + + """ + if G.is_multigraph(): + raise nx.NetworkXError("MultiGraph and MultiDiGraph not supported (yet).") + + R = nx.DiGraph() + R.__networkx_cache__ = None # Disable caching + R.add_nodes_from(G) + + inf = float("inf") + # Extract edges with positive capacities. Self loops excluded. + edge_list = [ + (u, v, attr) + for u, v, attr in G.edges(data=True) + if u != v and attr.get(capacity, inf) > 0 + ] + # Simulate infinity with three times the sum of the finite edge capacities + # or any positive value if the sum is zero. This allows the + # infinite-capacity edges to be distinguished for unboundedness detection + # and directly participate in residual capacity calculation. If the maximum + # flow is finite, these edges cannot appear in the minimum cut and thus + # guarantee correctness. Since the residual capacity of an + # infinite-capacity edge is always at least 2/3 of inf, while that of an + # finite-capacity edge is at most 1/3 of inf, if an operation moves more + # than 1/3 of inf units of flow to t, there must be an infinite-capacity + # s-t path in G. + inf = ( + 3 + * sum( + attr[capacity] + for u, v, attr in edge_list + if capacity in attr and attr[capacity] != inf + ) + or 1 + ) + if G.is_directed(): + for u, v, attr in edge_list: + r = min(attr.get(capacity, inf), inf) + if not R.has_edge(u, v): + # Both (u, v) and (v, u) must be present in the residual + # network. + R.add_edge(u, v, capacity=r) + R.add_edge(v, u, capacity=0) + else: + # The edge (u, v) was added when (v, u) was visited. + R[u][v]["capacity"] = r + else: + for u, v, attr in edge_list: + # Add a pair of edges with equal residual capacities. + r = min(attr.get(capacity, inf), inf) + R.add_edge(u, v, capacity=r) + R.add_edge(v, u, capacity=r) + + # Record the value simulating infinity. + R.graph["inf"] = inf + + return R + + +@nx._dispatchable( + graphs="R", + preserve_edge_attrs={"R": {"capacity": float("inf")}}, + preserve_graph_attrs=True, +) +def detect_unboundedness(R, s, t): + """Detect an infinite-capacity s-t path in R.""" + q = deque([s]) + seen = {s} + inf = R.graph["inf"] + while q: + u = q.popleft() + for v, attr in R[u].items(): + if attr["capacity"] == inf and v not in seen: + if v == t: + raise nx.NetworkXUnbounded( + "Infinite capacity path, flow unbounded above." + ) + seen.add(v) + q.append(v) + + +@nx._dispatchable(graphs={"G": 0, "R": 1}, preserve_edge_attrs={"R": {"flow": None}}) +def build_flow_dict(G, R): + """Build a flow dictionary from a residual network.""" + flow_dict = {} + for u in G: + flow_dict[u] = {v: 0 for v in G[u]} + flow_dict[u].update( + (v, attr["flow"]) for v, attr in R[u].items() if attr["flow"] > 0 + ) + return flow_dict diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..58c22688660073a6abb59f7639871f711d1bd6ac --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/__init__.py @@ -0,0 +1,7 @@ +from networkx.algorithms.isomorphism.isomorph import * +from networkx.algorithms.isomorphism.vf2userfunc import * +from networkx.algorithms.isomorphism.matchhelpers import * +from networkx.algorithms.isomorphism.temporalisomorphvf2 import * +from networkx.algorithms.isomorphism.ismags import * +from networkx.algorithms.isomorphism.tree_isomorphism import * +from networkx.algorithms.isomorphism.vf2pp import * diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c946c0ad6a52b7c5969d674926421e55c202733 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/__pycache__/ismags.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/__pycache__/ismags.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..126c1ba566961a28367a3283f97dc3f018c758a2 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/__pycache__/ismags.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/__pycache__/isomorph.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/__pycache__/isomorph.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3f7f437d3f675f084249e5a4b5e6d56bdcc1a63 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/__pycache__/isomorph.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/__pycache__/isomorphvf2.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/__pycache__/isomorphvf2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20b0fb411b1c3af0df8ce961d161db0d38ce5a21 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/__pycache__/isomorphvf2.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/__pycache__/matchhelpers.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/__pycache__/matchhelpers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39f26097b2a1e29836a85e69ee15c584ffb13e24 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/__pycache__/matchhelpers.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/__pycache__/temporalisomorphvf2.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/__pycache__/temporalisomorphvf2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83cab377634c52a6f1f70164299f566a1187d487 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/__pycache__/temporalisomorphvf2.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/__pycache__/tree_isomorphism.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/__pycache__/tree_isomorphism.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59143633d40e9a4b1464132f6ce70c4d04efd096 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/__pycache__/tree_isomorphism.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/__pycache__/vf2pp.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/__pycache__/vf2pp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea96bb14c7ddbb5b8db95e6f0ba09fb555287af6 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/__pycache__/vf2pp.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/__pycache__/vf2userfunc.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/__pycache__/vf2userfunc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d987c8d1e9418590c0def997fc1b980ab410e57d Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/__pycache__/vf2userfunc.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/ismags.py b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/ismags.py new file mode 100644 index 0000000000000000000000000000000000000000..24819faf95cf42f8264960a4a6a27754301fc661 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/ismags.py @@ -0,0 +1,1163 @@ +""" +ISMAGS Algorithm +================ + +Provides a Python implementation of the ISMAGS algorithm. [1]_ + +It is capable of finding (subgraph) isomorphisms between two graphs, taking the +symmetry of the subgraph into account. In most cases the VF2 algorithm is +faster (at least on small graphs) than this implementation, but in some cases +there is an exponential number of isomorphisms that are symmetrically +equivalent. In that case, the ISMAGS algorithm will provide only one solution +per symmetry group. + +>>> petersen = nx.petersen_graph() +>>> ismags = nx.isomorphism.ISMAGS(petersen, petersen) +>>> isomorphisms = list(ismags.isomorphisms_iter(symmetry=False)) +>>> len(isomorphisms) +120 +>>> isomorphisms = list(ismags.isomorphisms_iter(symmetry=True)) +>>> answer = [{0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9}] +>>> answer == isomorphisms +True + +In addition, this implementation also provides an interface to find the +largest common induced subgraph [2]_ between any two graphs, again taking +symmetry into account. Given `graph` and `subgraph` the algorithm will remove +nodes from the `subgraph` until `subgraph` is isomorphic to a subgraph of +`graph`. Since only the symmetry of `subgraph` is taken into account it is +worth thinking about how you provide your graphs: + +>>> graph1 = nx.path_graph(4) +>>> graph2 = nx.star_graph(3) +>>> ismags = nx.isomorphism.ISMAGS(graph1, graph2) +>>> ismags.is_isomorphic() +False +>>> largest_common_subgraph = list(ismags.largest_common_subgraph()) +>>> answer = [{1: 0, 0: 1, 2: 2}, {2: 0, 1: 1, 3: 2}] +>>> answer == largest_common_subgraph +True +>>> ismags2 = nx.isomorphism.ISMAGS(graph2, graph1) +>>> largest_common_subgraph = list(ismags2.largest_common_subgraph()) +>>> answer = [ +... {1: 0, 0: 1, 2: 2}, +... {1: 0, 0: 1, 3: 2}, +... {2: 0, 0: 1, 1: 2}, +... {2: 0, 0: 1, 3: 2}, +... {3: 0, 0: 1, 1: 2}, +... {3: 0, 0: 1, 2: 2}, +... ] +>>> answer == largest_common_subgraph +True + +However, when not taking symmetry into account, it doesn't matter: + +>>> largest_common_subgraph = list(ismags.largest_common_subgraph(symmetry=False)) +>>> answer = [ +... {1: 0, 0: 1, 2: 2}, +... {1: 0, 2: 1, 0: 2}, +... {2: 0, 1: 1, 3: 2}, +... {2: 0, 3: 1, 1: 2}, +... {1: 0, 0: 1, 2: 3}, +... {1: 0, 2: 1, 0: 3}, +... {2: 0, 1: 1, 3: 3}, +... {2: 0, 3: 1, 1: 3}, +... {1: 0, 0: 2, 2: 3}, +... {1: 0, 2: 2, 0: 3}, +... {2: 0, 1: 2, 3: 3}, +... {2: 0, 3: 2, 1: 3}, +... ] +>>> answer == largest_common_subgraph +True +>>> largest_common_subgraph = list(ismags2.largest_common_subgraph(symmetry=False)) +>>> answer = [ +... {1: 0, 0: 1, 2: 2}, +... {1: 0, 0: 1, 3: 2}, +... {2: 0, 0: 1, 1: 2}, +... {2: 0, 0: 1, 3: 2}, +... {3: 0, 0: 1, 1: 2}, +... {3: 0, 0: 1, 2: 2}, +... {1: 1, 0: 2, 2: 3}, +... {1: 1, 0: 2, 3: 3}, +... {2: 1, 0: 2, 1: 3}, +... {2: 1, 0: 2, 3: 3}, +... {3: 1, 0: 2, 1: 3}, +... {3: 1, 0: 2, 2: 3}, +... ] +>>> answer == largest_common_subgraph +True + +Notes +----- +- The current implementation works for undirected graphs only. The algorithm + in general should work for directed graphs as well though. +- Node keys for both provided graphs need to be fully orderable as well as + hashable. +- Node and edge equality is assumed to be transitive: if A is equal to B, and + B is equal to C, then A is equal to C. + +References +---------- +.. [1] M. Houbraken, S. Demeyer, T. Michoel, P. Audenaert, D. Colle, + M. Pickavet, "The Index-Based Subgraph Matching Algorithm with General + Symmetries (ISMAGS): Exploiting Symmetry for Faster Subgraph + Enumeration", PLoS One 9(5): e97896, 2014. + https://doi.org/10.1371/journal.pone.0097896 +.. [2] https://en.wikipedia.org/wiki/Maximum_common_induced_subgraph +""" + +__all__ = ["ISMAGS"] + +import itertools +from collections import Counter, defaultdict +from functools import reduce, wraps + + +def are_all_equal(iterable): + """ + Returns ``True`` if and only if all elements in `iterable` are equal; and + ``False`` otherwise. + + Parameters + ---------- + iterable: collections.abc.Iterable + The container whose elements will be checked. + + Returns + ------- + bool + ``True`` iff all elements in `iterable` compare equal, ``False`` + otherwise. + """ + try: + shape = iterable.shape + except AttributeError: + pass + else: + if len(shape) > 1: + message = "The function does not works on multidimensional arrays." + raise NotImplementedError(message) from None + + iterator = iter(iterable) + first = next(iterator, None) + return all(item == first for item in iterator) + + +def make_partitions(items, test): + """ + Partitions items into sets based on the outcome of ``test(item1, item2)``. + Pairs of items for which `test` returns `True` end up in the same set. + + Parameters + ---------- + items : collections.abc.Iterable[collections.abc.Hashable] + Items to partition + test : collections.abc.Callable[collections.abc.Hashable, collections.abc.Hashable] + A function that will be called with 2 arguments, taken from items. + Should return `True` if those 2 items need to end up in the same + partition, and `False` otherwise. + + Returns + ------- + list[set] + A list of sets, with each set containing part of the items in `items`, + such that ``all(test(*pair) for pair in itertools.combinations(set, 2)) + == True`` + + Notes + ----- + The function `test` is assumed to be transitive: if ``test(a, b)`` and + ``test(b, c)`` return ``True``, then ``test(a, c)`` must also be ``True``. + """ + partitions = [] + for item in items: + for partition in partitions: + p_item = next(iter(partition)) + if test(item, p_item): + partition.add(item) + break + else: # No break + partitions.append({item}) + return partitions + + +def partition_to_color(partitions): + """ + Creates a dictionary that maps each item in each partition to the index of + the partition to which it belongs. + + Parameters + ---------- + partitions: collections.abc.Sequence[collections.abc.Iterable] + As returned by :func:`make_partitions`. + + Returns + ------- + dict + """ + colors = {} + for color, keys in enumerate(partitions): + for key in keys: + colors[key] = color + return colors + + +def intersect(collection_of_sets): + """ + Given an collection of sets, returns the intersection of those sets. + + Parameters + ---------- + collection_of_sets: collections.abc.Collection[set] + A collection of sets. + + Returns + ------- + set + An intersection of all sets in `collection_of_sets`. Will have the same + type as the item initially taken from `collection_of_sets`. + """ + collection_of_sets = list(collection_of_sets) + first = collection_of_sets.pop() + out = reduce(set.intersection, collection_of_sets, set(first)) + return type(first)(out) + + +class ISMAGS: + """ + Implements the ISMAGS subgraph matching algorithm. [1]_ ISMAGS stands for + "Index-based Subgraph Matching Algorithm with General Symmetries". As the + name implies, it is symmetry aware and will only generate non-symmetric + isomorphisms. + + Notes + ----- + The implementation imposes additional conditions compared to the VF2 + algorithm on the graphs provided and the comparison functions + (:attr:`node_equality` and :attr:`edge_equality`): + + - Node keys in both graphs must be orderable as well as hashable. + - Equality must be transitive: if A is equal to B, and B is equal to C, + then A must be equal to C. + + Attributes + ---------- + graph: networkx.Graph + subgraph: networkx.Graph + node_equality: collections.abc.Callable + The function called to see if two nodes should be considered equal. + It's signature looks like this: + ``f(graph1: networkx.Graph, node1, graph2: networkx.Graph, node2) -> bool``. + `node1` is a node in `graph1`, and `node2` a node in `graph2`. + Constructed from the argument `node_match`. + edge_equality: collections.abc.Callable + The function called to see if two edges should be considered equal. + It's signature looks like this: + ``f(graph1: networkx.Graph, edge1, graph2: networkx.Graph, edge2) -> bool``. + `edge1` is an edge in `graph1`, and `edge2` an edge in `graph2`. + Constructed from the argument `edge_match`. + + References + ---------- + .. [1] M. Houbraken, S. Demeyer, T. Michoel, P. Audenaert, D. Colle, + M. Pickavet, "The Index-Based Subgraph Matching Algorithm with General + Symmetries (ISMAGS): Exploiting Symmetry for Faster Subgraph + Enumeration", PLoS One 9(5): e97896, 2014. + https://doi.org/10.1371/journal.pone.0097896 + """ + + def __init__(self, graph, subgraph, node_match=None, edge_match=None, cache=None): + """ + Parameters + ---------- + graph: networkx.Graph + subgraph: networkx.Graph + node_match: collections.abc.Callable or None + Function used to determine whether two nodes are equivalent. Its + signature should look like ``f(n1: dict, n2: dict) -> bool``, with + `n1` and `n2` node property dicts. See also + :func:`~networkx.algorithms.isomorphism.categorical_node_match` and + friends. + If `None`, all nodes are considered equal. + edge_match: collections.abc.Callable or None + Function used to determine whether two edges are equivalent. Its + signature should look like ``f(e1: dict, e2: dict) -> bool``, with + `e1` and `e2` edge property dicts. See also + :func:`~networkx.algorithms.isomorphism.categorical_edge_match` and + friends. + If `None`, all edges are considered equal. + cache: collections.abc.Mapping + A cache used for caching graph symmetries. + """ + # TODO: graph and subgraph setter methods that invalidate the caches. + # TODO: allow for precomputed partitions and colors + self.graph = graph + self.subgraph = subgraph + self._symmetry_cache = cache + # Naming conventions are taken from the original paper. For your + # sanity: + # sg: subgraph + # g: graph + # e: edge(s) + # n: node(s) + # So: sgn means "subgraph nodes". + self._sgn_partitions_ = None + self._sge_partitions_ = None + + self._sgn_colors_ = None + self._sge_colors_ = None + + self._gn_partitions_ = None + self._ge_partitions_ = None + + self._gn_colors_ = None + self._ge_colors_ = None + + self._node_compat_ = None + self._edge_compat_ = None + + if node_match is None: + self.node_equality = self._node_match_maker(lambda n1, n2: True) + self._sgn_partitions_ = [set(self.subgraph.nodes)] + self._gn_partitions_ = [set(self.graph.nodes)] + self._node_compat_ = {0: 0} + else: + self.node_equality = self._node_match_maker(node_match) + if edge_match is None: + self.edge_equality = self._edge_match_maker(lambda e1, e2: True) + self._sge_partitions_ = [set(self.subgraph.edges)] + self._ge_partitions_ = [set(self.graph.edges)] + self._edge_compat_ = {0: 0} + else: + self.edge_equality = self._edge_match_maker(edge_match) + + @property + def _sgn_partitions(self): + if self._sgn_partitions_ is None: + + def nodematch(node1, node2): + return self.node_equality(self.subgraph, node1, self.subgraph, node2) + + self._sgn_partitions_ = make_partitions(self.subgraph.nodes, nodematch) + return self._sgn_partitions_ + + @property + def _sge_partitions(self): + if self._sge_partitions_ is None: + + def edgematch(edge1, edge2): + return self.edge_equality(self.subgraph, edge1, self.subgraph, edge2) + + self._sge_partitions_ = make_partitions(self.subgraph.edges, edgematch) + return self._sge_partitions_ + + @property + def _gn_partitions(self): + if self._gn_partitions_ is None: + + def nodematch(node1, node2): + return self.node_equality(self.graph, node1, self.graph, node2) + + self._gn_partitions_ = make_partitions(self.graph.nodes, nodematch) + return self._gn_partitions_ + + @property + def _ge_partitions(self): + if self._ge_partitions_ is None: + + def edgematch(edge1, edge2): + return self.edge_equality(self.graph, edge1, self.graph, edge2) + + self._ge_partitions_ = make_partitions(self.graph.edges, edgematch) + return self._ge_partitions_ + + @property + def _sgn_colors(self): + if self._sgn_colors_ is None: + self._sgn_colors_ = partition_to_color(self._sgn_partitions) + return self._sgn_colors_ + + @property + def _sge_colors(self): + if self._sge_colors_ is None: + self._sge_colors_ = partition_to_color(self._sge_partitions) + return self._sge_colors_ + + @property + def _gn_colors(self): + if self._gn_colors_ is None: + self._gn_colors_ = partition_to_color(self._gn_partitions) + return self._gn_colors_ + + @property + def _ge_colors(self): + if self._ge_colors_ is None: + self._ge_colors_ = partition_to_color(self._ge_partitions) + return self._ge_colors_ + + @property + def _node_compatibility(self): + if self._node_compat_ is not None: + return self._node_compat_ + self._node_compat_ = {} + for sgn_part_color, gn_part_color in itertools.product( + range(len(self._sgn_partitions)), range(len(self._gn_partitions)) + ): + sgn = next(iter(self._sgn_partitions[sgn_part_color])) + gn = next(iter(self._gn_partitions[gn_part_color])) + if self.node_equality(self.subgraph, sgn, self.graph, gn): + self._node_compat_[sgn_part_color] = gn_part_color + return self._node_compat_ + + @property + def _edge_compatibility(self): + if self._edge_compat_ is not None: + return self._edge_compat_ + self._edge_compat_ = {} + for sge_part_color, ge_part_color in itertools.product( + range(len(self._sge_partitions)), range(len(self._ge_partitions)) + ): + sge = next(iter(self._sge_partitions[sge_part_color])) + ge = next(iter(self._ge_partitions[ge_part_color])) + if self.edge_equality(self.subgraph, sge, self.graph, ge): + self._edge_compat_[sge_part_color] = ge_part_color + return self._edge_compat_ + + @staticmethod + def _node_match_maker(cmp): + @wraps(cmp) + def comparer(graph1, node1, graph2, node2): + return cmp(graph1.nodes[node1], graph2.nodes[node2]) + + return comparer + + @staticmethod + def _edge_match_maker(cmp): + @wraps(cmp) + def comparer(graph1, edge1, graph2, edge2): + return cmp(graph1.edges[edge1], graph2.edges[edge2]) + + return comparer + + def find_isomorphisms(self, symmetry=True): + """Find all subgraph isomorphisms between subgraph and graph + + Finds isomorphisms where :attr:`subgraph` <= :attr:`graph`. + + Parameters + ---------- + symmetry: bool + Whether symmetry should be taken into account. If False, found + isomorphisms may be symmetrically equivalent. + + Yields + ------ + dict + The found isomorphism mappings of {graph_node: subgraph_node}. + """ + # The networkx VF2 algorithm is slightly funny in when it yields an + # empty dict and when not. + if not self.subgraph: + yield {} + return + elif not self.graph: + return + elif len(self.graph) < len(self.subgraph): + return + + if symmetry: + _, cosets = self.analyze_symmetry( + self.subgraph, self._sgn_partitions, self._sge_colors + ) + constraints = self._make_constraints(cosets) + else: + constraints = [] + + candidates = self._find_nodecolor_candidates() + la_candidates = self._get_lookahead_candidates() + for sgn in self.subgraph: + extra_candidates = la_candidates[sgn] + if extra_candidates: + candidates[sgn] = candidates[sgn] | {frozenset(extra_candidates)} + + if any(candidates.values()): + start_sgn = min(candidates, key=lambda n: min(candidates[n], key=len)) + candidates[start_sgn] = (intersect(candidates[start_sgn]),) + yield from self._map_nodes(start_sgn, candidates, constraints) + else: + return + + @staticmethod + def _find_neighbor_color_count(graph, node, node_color, edge_color): + """ + For `node` in `graph`, count the number of edges of a specific color + it has to nodes of a specific color. + """ + counts = Counter() + neighbors = graph[node] + for neighbor in neighbors: + n_color = node_color[neighbor] + if (node, neighbor) in edge_color: + e_color = edge_color[node, neighbor] + else: + e_color = edge_color[neighbor, node] + counts[e_color, n_color] += 1 + return counts + + def _get_lookahead_candidates(self): + """ + Returns a mapping of {subgraph node: collection of graph nodes} for + which the graph nodes are feasible candidates for the subgraph node, as + determined by looking ahead one edge. + """ + g_counts = {} + for gn in self.graph: + g_counts[gn] = self._find_neighbor_color_count( + self.graph, gn, self._gn_colors, self._ge_colors + ) + candidates = defaultdict(set) + for sgn in self.subgraph: + sg_count = self._find_neighbor_color_count( + self.subgraph, sgn, self._sgn_colors, self._sge_colors + ) + new_sg_count = Counter() + for (sge_color, sgn_color), count in sg_count.items(): + try: + ge_color = self._edge_compatibility[sge_color] + gn_color = self._node_compatibility[sgn_color] + except KeyError: + pass + else: + new_sg_count[ge_color, gn_color] = count + + for gn, g_count in g_counts.items(): + if all(new_sg_count[x] <= g_count[x] for x in new_sg_count): + # Valid candidate + candidates[sgn].add(gn) + return candidates + + def largest_common_subgraph(self, symmetry=True): + """ + Find the largest common induced subgraphs between :attr:`subgraph` and + :attr:`graph`. + + Parameters + ---------- + symmetry: bool + Whether symmetry should be taken into account. If False, found + largest common subgraphs may be symmetrically equivalent. + + Yields + ------ + dict + The found isomorphism mappings of {graph_node: subgraph_node}. + """ + # The networkx VF2 algorithm is slightly funny in when it yields an + # empty dict and when not. + if not self.subgraph: + yield {} + return + elif not self.graph: + return + + if symmetry: + _, cosets = self.analyze_symmetry( + self.subgraph, self._sgn_partitions, self._sge_colors + ) + constraints = self._make_constraints(cosets) + else: + constraints = [] + + candidates = self._find_nodecolor_candidates() + + if any(candidates.values()): + yield from self._largest_common_subgraph(candidates, constraints) + else: + return + + def analyze_symmetry(self, graph, node_partitions, edge_colors): + """ + Find a minimal set of permutations and corresponding co-sets that + describe the symmetry of `graph`, given the node and edge equalities + given by `node_partitions` and `edge_colors`, respectively. + + Parameters + ---------- + graph : networkx.Graph + The graph whose symmetry should be analyzed. + node_partitions : list of sets + A list of sets containing node keys. Node keys in the same set + are considered equivalent. Every node key in `graph` should be in + exactly one of the sets. If all nodes are equivalent, this should + be ``[set(graph.nodes)]``. + edge_colors : dict mapping edges to their colors + A dict mapping every edge in `graph` to its corresponding color. + Edges with the same color are considered equivalent. If all edges + are equivalent, this should be ``{e: 0 for e in graph.edges}``. + + + Returns + ------- + set[frozenset] + The found permutations. This is a set of frozensets of pairs of node + keys which can be exchanged without changing :attr:`subgraph`. + dict[collections.abc.Hashable, set[collections.abc.Hashable]] + The found co-sets. The co-sets is a dictionary of + ``{node key: set of node keys}``. + Every key-value pair describes which ``values`` can be interchanged + without changing nodes less than ``key``. + """ + if self._symmetry_cache is not None: + key = hash( + ( + tuple(graph.nodes), + tuple(graph.edges), + tuple(map(tuple, node_partitions)), + tuple(edge_colors.items()), + ) + ) + if key in self._symmetry_cache: + return self._symmetry_cache[key] + node_partitions = list( + self._refine_node_partitions(graph, node_partitions, edge_colors) + ) + assert len(node_partitions) == 1 + node_partitions = node_partitions[0] + permutations, cosets = self._process_ordered_pair_partitions( + graph, node_partitions, node_partitions, edge_colors + ) + if self._symmetry_cache is not None: + self._symmetry_cache[key] = permutations, cosets + return permutations, cosets + + def is_isomorphic(self, symmetry=False): + """ + Returns True if :attr:`graph` is isomorphic to :attr:`subgraph` and + False otherwise. + + Returns + ------- + bool + """ + return len(self.subgraph) == len(self.graph) and self.subgraph_is_isomorphic( + symmetry + ) + + def subgraph_is_isomorphic(self, symmetry=False): + """ + Returns True if a subgraph of :attr:`graph` is isomorphic to + :attr:`subgraph` and False otherwise. + + Returns + ------- + bool + """ + # symmetry=False, since we only need to know whether there is any + # example; figuring out all symmetry elements probably costs more time + # than it gains. + isom = next(self.subgraph_isomorphisms_iter(symmetry=symmetry), None) + return isom is not None + + def isomorphisms_iter(self, symmetry=True): + """ + Does the same as :meth:`find_isomorphisms` if :attr:`graph` and + :attr:`subgraph` have the same number of nodes. + """ + if len(self.graph) == len(self.subgraph): + yield from self.subgraph_isomorphisms_iter(symmetry=symmetry) + + def subgraph_isomorphisms_iter(self, symmetry=True): + """Alternative name for :meth:`find_isomorphisms`.""" + return self.find_isomorphisms(symmetry) + + def _find_nodecolor_candidates(self): + """ + Per node in subgraph find all nodes in graph that have the same color. + """ + candidates = defaultdict(set) + for sgn in self.subgraph.nodes: + sgn_color = self._sgn_colors[sgn] + if sgn_color in self._node_compatibility: + gn_color = self._node_compatibility[sgn_color] + candidates[sgn].add(frozenset(self._gn_partitions[gn_color])) + else: + candidates[sgn].add(frozenset()) + candidates = dict(candidates) + for sgn, options in candidates.items(): + candidates[sgn] = frozenset(options) + return candidates + + @staticmethod + def _make_constraints(cosets): + """ + Turn cosets into constraints. + """ + constraints = [] + for node_i, node_ts in cosets.items(): + for node_t in node_ts: + if node_i != node_t: + # Node i must be smaller than node t. + constraints.append((node_i, node_t)) + return constraints + + @staticmethod + def _find_node_edge_color(graph, node_colors, edge_colors): + """ + For every node in graph, come up with a color that combines 1) the + color of the node, and 2) the number of edges of a color to each type + of node. + """ + counts = defaultdict(lambda: defaultdict(int)) + for node1, node2 in graph.edges: + if (node1, node2) in edge_colors: + # FIXME directed graphs + ecolor = edge_colors[node1, node2] + else: + ecolor = edge_colors[node2, node1] + # Count per node how many edges it has of what color to nodes of + # what color + counts[node1][ecolor, node_colors[node2]] += 1 + counts[node2][ecolor, node_colors[node1]] += 1 + + node_edge_colors = {} + for node in graph.nodes: + node_edge_colors[node] = node_colors[node], set(counts[node].items()) + + return node_edge_colors + + @staticmethod + def _get_permutations_by_length(items): + """ + Get all permutations of items, but only permute items with the same + length. + + >>> found = list(ISMAGS._get_permutations_by_length([[1], [2], [3, 4], [4, 5]])) + >>> answer = [ + ... (([1], [2]), ([3, 4], [4, 5])), + ... (([1], [2]), ([4, 5], [3, 4])), + ... (([2], [1]), ([3, 4], [4, 5])), + ... (([2], [1]), ([4, 5], [3, 4])), + ... ] + >>> found == answer + True + """ + by_len = defaultdict(list) + for item in items: + by_len[len(item)].append(item) + + yield from itertools.product( + *(itertools.permutations(by_len[l]) for l in sorted(by_len)) + ) + + @classmethod + def _refine_node_partitions(cls, graph, node_partitions, edge_colors, branch=False): + """ + Given a partition of nodes in graph, make the partitions smaller such + that all nodes in a partition have 1) the same color, and 2) the same + number of edges to specific other partitions. + """ + + def equal_color(node1, node2): + return node_edge_colors[node1] == node_edge_colors[node2] + + node_partitions = list(node_partitions) + node_colors = partition_to_color(node_partitions) + node_edge_colors = cls._find_node_edge_color(graph, node_colors, edge_colors) + if all( + are_all_equal(node_edge_colors[node] for node in partition) + for partition in node_partitions + ): + yield node_partitions + return + + new_partitions = [] + output = [new_partitions] + for partition in node_partitions: + if not are_all_equal(node_edge_colors[node] for node in partition): + refined = make_partitions(partition, equal_color) + if ( + branch + and len(refined) != 1 + and len({len(r) for r in refined}) != len([len(r) for r in refined]) + ): + # This is where it breaks. There are multiple new cells + # in refined with the same length, and their order + # matters. + # So option 1) Hit it with a big hammer and simply make all + # orderings. + permutations = cls._get_permutations_by_length(refined) + new_output = [] + for n_p in output: + for permutation in permutations: + new_output.append(n_p + list(permutation[0])) + output = new_output + else: + for n_p in output: + n_p.extend(sorted(refined, key=len)) + else: + for n_p in output: + n_p.append(partition) + for n_p in output: + yield from cls._refine_node_partitions(graph, n_p, edge_colors, branch) + + def _edges_of_same_color(self, sgn1, sgn2): + """ + Returns all edges in :attr:`graph` that have the same colour as the + edge between sgn1 and sgn2 in :attr:`subgraph`. + """ + if (sgn1, sgn2) in self._sge_colors: + # FIXME directed graphs + sge_color = self._sge_colors[sgn1, sgn2] + else: + sge_color = self._sge_colors[sgn2, sgn1] + if sge_color in self._edge_compatibility: + ge_color = self._edge_compatibility[sge_color] + g_edges = self._ge_partitions[ge_color] + else: + g_edges = [] + return g_edges + + def _map_nodes(self, sgn, candidates, constraints, mapping=None, to_be_mapped=None): + """ + Find all subgraph isomorphisms honoring constraints. + """ + if mapping is None: + mapping = {} + else: + mapping = mapping.copy() + if to_be_mapped is None: + to_be_mapped = set(self.subgraph.nodes) + + # Note, we modify candidates here. Doesn't seem to affect results, but + # remember this. + # candidates = candidates.copy() + sgn_candidates = intersect(candidates[sgn]) + candidates[sgn] = frozenset([sgn_candidates]) + for gn in sgn_candidates: + # We're going to try to map sgn to gn. + if gn in mapping.values() or sgn not in to_be_mapped: + # gn is already mapped to something + continue # pragma: no cover + + # REDUCTION and COMBINATION + mapping[sgn] = gn + # BASECASE + if to_be_mapped == set(mapping.keys()): + yield {v: k for k, v in mapping.items()} + continue + left_to_map = to_be_mapped - set(mapping.keys()) + + new_candidates = candidates.copy() + sgn_nbrs = set(self.subgraph[sgn]) + not_gn_nbrs = set(self.graph.nodes) - set(self.graph[gn]) + for sgn2 in left_to_map: + if sgn2 not in sgn_nbrs: + gn2_options = not_gn_nbrs + else: + # Get all edges to gn of the right color: + g_edges = self._edges_of_same_color(sgn, sgn2) + # FIXME directed graphs + # And all nodes involved in those which are connected to gn + gn2_options = {n for e in g_edges for n in e if gn in e} + # Node color compatibility should be taken care of by the + # initial candidate lists made by find_subgraphs + + # Add gn2_options to the right collection. Since new_candidates + # is a dict of frozensets of frozensets of node indices it's + # a bit clunky. We can't do .add, and + also doesn't work. We + # could do |, but I deem union to be clearer. + new_candidates[sgn2] = new_candidates[sgn2].union( + [frozenset(gn2_options)] + ) + + if (sgn, sgn2) in constraints: + gn2_options = {gn2 for gn2 in self.graph if gn2 > gn} + elif (sgn2, sgn) in constraints: + gn2_options = {gn2 for gn2 in self.graph if gn2 < gn} + else: + continue # pragma: no cover + new_candidates[sgn2] = new_candidates[sgn2].union( + [frozenset(gn2_options)] + ) + + # The next node is the one that is unmapped and has fewest + # candidates + next_sgn = min(left_to_map, key=lambda n: min(new_candidates[n], key=len)) + yield from self._map_nodes( + next_sgn, + new_candidates, + constraints, + mapping=mapping, + to_be_mapped=to_be_mapped, + ) + # Unmap sgn-gn. Strictly not necessary since it'd get overwritten + # when making a new mapping for sgn. + # del mapping[sgn] + + def _largest_common_subgraph(self, candidates, constraints, to_be_mapped=None): + """ + Find all largest common subgraphs honoring constraints. + """ + if to_be_mapped is None: + to_be_mapped = {frozenset(self.subgraph.nodes)} + + # The LCS problem is basically a repeated subgraph isomorphism problem + # with smaller and smaller subgraphs. We store the nodes that are + # "part of" the subgraph in to_be_mapped, and we make it a little + # smaller every iteration. + + current_size = len(next(iter(to_be_mapped), [])) + + found_iso = False + if current_size <= len(self.graph): + # There's no point in trying to find isomorphisms of + # graph >= subgraph if subgraph has more nodes than graph. + + # Try the isomorphism first with the nodes with lowest ID. So sort + # them. Those are more likely to be part of the final + # correspondence. This makes finding the first answer(s) faster. In + # theory. + for nodes in sorted(to_be_mapped, key=sorted): + # Find the isomorphism between subgraph[to_be_mapped] <= graph + next_sgn = min(nodes, key=lambda n: min(candidates[n], key=len)) + isomorphs = self._map_nodes( + next_sgn, candidates, constraints, to_be_mapped=nodes + ) + + # This is effectively `yield from isomorphs`, except that we look + # whether an item was yielded. + try: + item = next(isomorphs) + except StopIteration: + pass + else: + yield item + yield from isomorphs + found_iso = True + + # BASECASE + if found_iso or current_size == 1: + # Shrinking has no point because either 1) we end up with a smaller + # common subgraph (and we want the largest), or 2) there'll be no + # more subgraph. + return + + left_to_be_mapped = set() + for nodes in to_be_mapped: + for sgn in nodes: + # We're going to remove sgn from to_be_mapped, but subject to + # symmetry constraints. We know that for every constraint we + # have those subgraph nodes are equal. So whenever we would + # remove the lower part of a constraint, remove the higher + # instead. This is all dealth with by _remove_node. And because + # left_to_be_mapped is a set, we don't do double work. + + # And finally, make the subgraph one node smaller. + # REDUCTION + new_nodes = self._remove_node(sgn, nodes, constraints) + left_to_be_mapped.add(new_nodes) + # COMBINATION + yield from self._largest_common_subgraph( + candidates, constraints, to_be_mapped=left_to_be_mapped + ) + + @staticmethod + def _remove_node(node, nodes, constraints): + """ + Returns a new set where node has been removed from nodes, subject to + symmetry constraints. We know, that for every constraint we have + those subgraph nodes are equal. So whenever we would remove the + lower part of a constraint, remove the higher instead. + """ + while True: + for low, high in constraints: + if low == node and high in nodes: + node = high + break + else: # no break, couldn't find node in constraints + break + return frozenset(nodes - {node}) + + @staticmethod + def _find_permutations(top_partitions, bottom_partitions): + """ + Return the pairs of top/bottom partitions where the partitions are + different. Ensures that all partitions in both top and bottom + partitions have size 1. + """ + # Find permutations + permutations = set() + for top, bot in zip(top_partitions, bottom_partitions): + # top and bot have only one element + if len(top) != 1 or len(bot) != 1: + raise IndexError( + "Not all nodes are coupled. This is" + f" impossible: {top_partitions}, {bottom_partitions}" + ) + if top != bot: + permutations.add(frozenset((next(iter(top)), next(iter(bot))))) + return permutations + + @staticmethod + def _update_orbits(orbits, permutations): + """ + Update orbits based on permutations. Orbits is modified in place. + For every pair of items in permutations their respective orbits are + merged. + """ + for permutation in permutations: + node, node2 = permutation + # Find the orbits that contain node and node2, and replace the + # orbit containing node with the union + first = second = None + for idx, orbit in enumerate(orbits): + if first is not None and second is not None: + break + if node in orbit: + first = idx + if node2 in orbit: + second = idx + if first != second: + orbits[first].update(orbits[second]) + del orbits[second] + + def _couple_nodes( + self, + top_partitions, + bottom_partitions, + pair_idx, + t_node, + b_node, + graph, + edge_colors, + ): + """ + Generate new partitions from top and bottom_partitions where t_node is + coupled to b_node. pair_idx is the index of the partitions where t_ and + b_node can be found. + """ + t_partition = top_partitions[pair_idx] + b_partition = bottom_partitions[pair_idx] + assert t_node in t_partition and b_node in b_partition + # Couple node to node2. This means they get their own partition + new_top_partitions = [top.copy() for top in top_partitions] + new_bottom_partitions = [bot.copy() for bot in bottom_partitions] + new_t_groups = {t_node}, t_partition - {t_node} + new_b_groups = {b_node}, b_partition - {b_node} + # Replace the old partitions with the coupled ones + del new_top_partitions[pair_idx] + del new_bottom_partitions[pair_idx] + new_top_partitions[pair_idx:pair_idx] = new_t_groups + new_bottom_partitions[pair_idx:pair_idx] = new_b_groups + + new_top_partitions = self._refine_node_partitions( + graph, new_top_partitions, edge_colors + ) + new_bottom_partitions = self._refine_node_partitions( + graph, new_bottom_partitions, edge_colors, branch=True + ) + new_top_partitions = list(new_top_partitions) + assert len(new_top_partitions) == 1 + new_top_partitions = new_top_partitions[0] + for bot in new_bottom_partitions: + yield list(new_top_partitions), bot + + def _process_ordered_pair_partitions( + self, + graph, + top_partitions, + bottom_partitions, + edge_colors, + orbits=None, + cosets=None, + ): + """ + Processes ordered pair partitions as per the reference paper. Finds and + returns all permutations and cosets that leave the graph unchanged. + """ + if orbits is None: + orbits = [{node} for node in graph.nodes] + else: + # Note that we don't copy orbits when we are given one. This means + # we leak information between the recursive branches. This is + # intentional! + orbits = orbits + if cosets is None: + cosets = {} + else: + cosets = cosets.copy() + + assert all( + len(t_p) == len(b_p) for t_p, b_p in zip(top_partitions, bottom_partitions) + ) + + # BASECASE + if all(len(top) == 1 for top in top_partitions): + # All nodes are mapped + permutations = self._find_permutations(top_partitions, bottom_partitions) + self._update_orbits(orbits, permutations) + if permutations: + return [permutations], cosets + else: + return [], cosets + + permutations = [] + unmapped_nodes = { + (node, idx) + for idx, t_partition in enumerate(top_partitions) + for node in t_partition + if len(t_partition) > 1 + } + node, pair_idx = min(unmapped_nodes) + b_partition = bottom_partitions[pair_idx] + + for node2 in sorted(b_partition): + if len(b_partition) == 1: + # Can never result in symmetry + continue + if node != node2 and any( + node in orbit and node2 in orbit for orbit in orbits + ): + # Orbit prune branch + continue + # REDUCTION + # Couple node to node2 + partitions = self._couple_nodes( + top_partitions, + bottom_partitions, + pair_idx, + node, + node2, + graph, + edge_colors, + ) + for opp in partitions: + new_top_partitions, new_bottom_partitions = opp + + new_perms, new_cosets = self._process_ordered_pair_partitions( + graph, + new_top_partitions, + new_bottom_partitions, + edge_colors, + orbits, + cosets, + ) + # COMBINATION + permutations += new_perms + cosets.update(new_cosets) + + mapped = { + k + for top, bottom in zip(top_partitions, bottom_partitions) + for k in top + if len(top) == 1 and top == bottom + } + ks = {k for k in graph.nodes if k < node} + # Have all nodes with ID < node been mapped? + find_coset = ks <= mapped and node not in cosets + if find_coset: + # Find the orbit that contains node + for orbit in orbits: + if node in orbit: + cosets[node] = orbit.copy() + return permutations, cosets diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/isomorph.py b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/isomorph.py new file mode 100644 index 0000000000000000000000000000000000000000..fc3a3fc6a50bf15c49ffc7e4b2b798413cb344b3 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/isomorph.py @@ -0,0 +1,249 @@ +""" +Graph isomorphism functions. +""" + +import networkx as nx +from networkx.exception import NetworkXError + +__all__ = [ + "could_be_isomorphic", + "fast_could_be_isomorphic", + "faster_could_be_isomorphic", + "is_isomorphic", +] + + +@nx._dispatchable(graphs={"G1": 0, "G2": 1}) +def could_be_isomorphic(G1, G2): + """Returns False if graphs are definitely not isomorphic. + True does NOT guarantee isomorphism. + + Parameters + ---------- + G1, G2 : graphs + The two graphs G1 and G2 must be the same type. + + Notes + ----- + Checks for matching degree, triangle, and number of cliques sequences. + The triangle sequence contains the number of triangles each node is part of. + The clique sequence contains for each node the number of maximal cliques + involving that node. + + """ + + # Check global properties + if G1.order() != G2.order(): + return False + + # Check local properties + d1 = G1.degree() + t1 = nx.triangles(G1) + clqs_1 = list(nx.find_cliques(G1)) + c1 = {n: sum(1 for c in clqs_1 if n in c) for n in G1} # number of cliques + props1 = [[d, t1[v], c1[v]] for v, d in d1] + props1.sort() + + d2 = G2.degree() + t2 = nx.triangles(G2) + clqs_2 = list(nx.find_cliques(G2)) + c2 = {n: sum(1 for c in clqs_2 if n in c) for n in G2} # number of cliques + props2 = [[d, t2[v], c2[v]] for v, d in d2] + props2.sort() + + if props1 != props2: + return False + + # OK... + return True + + +graph_could_be_isomorphic = could_be_isomorphic + + +@nx._dispatchable(graphs={"G1": 0, "G2": 1}) +def fast_could_be_isomorphic(G1, G2): + """Returns False if graphs are definitely not isomorphic. + + True does NOT guarantee isomorphism. + + Parameters + ---------- + G1, G2 : graphs + The two graphs G1 and G2 must be the same type. + + Notes + ----- + Checks for matching degree and triangle sequences. The triangle + sequence contains the number of triangles each node is part of. + """ + # Check global properties + if G1.order() != G2.order(): + return False + + # Check local properties + d1 = G1.degree() + t1 = nx.triangles(G1) + props1 = [[d, t1[v]] for v, d in d1] + props1.sort() + + d2 = G2.degree() + t2 = nx.triangles(G2) + props2 = [[d, t2[v]] for v, d in d2] + props2.sort() + + if props1 != props2: + return False + + # OK... + return True + + +fast_graph_could_be_isomorphic = fast_could_be_isomorphic + + +@nx._dispatchable(graphs={"G1": 0, "G2": 1}) +def faster_could_be_isomorphic(G1, G2): + """Returns False if graphs are definitely not isomorphic. + + True does NOT guarantee isomorphism. + + Parameters + ---------- + G1, G2 : graphs + The two graphs G1 and G2 must be the same type. + + Notes + ----- + Checks for matching degree sequences. + """ + # Check global properties + if G1.order() != G2.order(): + return False + + # Check local properties + d1 = sorted(d for n, d in G1.degree()) + d2 = sorted(d for n, d in G2.degree()) + + if d1 != d2: + return False + + # OK... + return True + + +faster_graph_could_be_isomorphic = faster_could_be_isomorphic + + +@nx._dispatchable( + graphs={"G1": 0, "G2": 1}, + preserve_edge_attrs="edge_match", + preserve_node_attrs="node_match", +) +def is_isomorphic(G1, G2, node_match=None, edge_match=None): + """Returns True if the graphs G1 and G2 are isomorphic and False otherwise. + + Parameters + ---------- + G1, G2: graphs + The two graphs G1 and G2 must be the same type. + + node_match : callable + A function that returns True if node n1 in G1 and n2 in G2 should + be considered equal during the isomorphism test. + If node_match is not specified then node attributes are not considered. + + The function will be called like + + node_match(G1.nodes[n1], G2.nodes[n2]). + + That is, the function will receive the node attribute dictionaries + for n1 and n2 as inputs. + + edge_match : callable + A function that returns True if the edge attribute dictionary + for the pair of nodes (u1, v1) in G1 and (u2, v2) in G2 should + be considered equal during the isomorphism test. If edge_match is + not specified then edge attributes are not considered. + + The function will be called like + + edge_match(G1[u1][v1], G2[u2][v2]). + + That is, the function will receive the edge attribute dictionaries + of the edges under consideration. + + Notes + ----- + Uses the vf2 algorithm [1]_. + + Examples + -------- + >>> import networkx.algorithms.isomorphism as iso + + For digraphs G1 and G2, using 'weight' edge attribute (default: 1) + + >>> G1 = nx.DiGraph() + >>> G2 = nx.DiGraph() + >>> nx.add_path(G1, [1, 2, 3, 4], weight=1) + >>> nx.add_path(G2, [10, 20, 30, 40], weight=2) + >>> em = iso.numerical_edge_match("weight", 1) + >>> nx.is_isomorphic(G1, G2) # no weights considered + True + >>> nx.is_isomorphic(G1, G2, edge_match=em) # match weights + False + + For multidigraphs G1 and G2, using 'fill' node attribute (default: '') + + >>> G1 = nx.MultiDiGraph() + >>> G2 = nx.MultiDiGraph() + >>> G1.add_nodes_from([1, 2, 3], fill="red") + >>> G2.add_nodes_from([10, 20, 30, 40], fill="red") + >>> nx.add_path(G1, [1, 2, 3, 4], weight=3, linewidth=2.5) + >>> nx.add_path(G2, [10, 20, 30, 40], weight=3) + >>> nm = iso.categorical_node_match("fill", "red") + >>> nx.is_isomorphic(G1, G2, node_match=nm) + True + + For multidigraphs G1 and G2, using 'weight' edge attribute (default: 7) + + >>> G1.add_edge(1, 2, weight=7) + 1 + >>> G2.add_edge(10, 20) + 1 + >>> em = iso.numerical_multiedge_match("weight", 7, rtol=1e-6) + >>> nx.is_isomorphic(G1, G2, edge_match=em) + True + + For multigraphs G1 and G2, using 'weight' and 'linewidth' edge attributes + with default values 7 and 2.5. Also using 'fill' node attribute with + default value 'red'. + + >>> em = iso.numerical_multiedge_match(["weight", "linewidth"], [7, 2.5]) + >>> nm = iso.categorical_node_match("fill", "red") + >>> nx.is_isomorphic(G1, G2, edge_match=em, node_match=nm) + True + + See Also + -------- + numerical_node_match, numerical_edge_match, numerical_multiedge_match + categorical_node_match, categorical_edge_match, categorical_multiedge_match + + References + ---------- + .. [1] L. P. Cordella, P. Foggia, C. Sansone, M. Vento, + "An Improved Algorithm for Matching Large Graphs", + 3rd IAPR-TC15 Workshop on Graph-based Representations in + Pattern Recognition, Cuen, pp. 149-159, 2001. + https://www.researchgate.net/publication/200034365_An_Improved_Algorithm_for_Matching_Large_Graphs + """ + if G1.is_directed() and G2.is_directed(): + GM = nx.algorithms.isomorphism.DiGraphMatcher + elif (not G1.is_directed()) and (not G2.is_directed()): + GM = nx.algorithms.isomorphism.GraphMatcher + else: + raise NetworkXError("Graphs G1 and G2 are not of the same type.") + + gm = GM(G1, G2, node_match=node_match, edge_match=edge_match) + + return gm.is_isomorphic() diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/isomorphvf2.py b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/isomorphvf2.py new file mode 100644 index 0000000000000000000000000000000000000000..cb2f1e8f3e644221cbc19f36ff44c3d3f84a007b --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/isomorphvf2.py @@ -0,0 +1,1238 @@ +""" +************* +VF2 Algorithm +************* + +An implementation of VF2 algorithm for graph isomorphism testing. + +The simplest interface to use this module is to call the +:func:`is_isomorphic ` +function. + +Introduction +------------ + +The GraphMatcher and DiGraphMatcher are responsible for matching +graphs or directed graphs in a predetermined manner. This +usually means a check for an isomorphism, though other checks +are also possible. For example, a subgraph of one graph +can be checked for isomorphism to a second graph. + +Matching is done via syntactic feasibility. It is also possible +to check for semantic feasibility. Feasibility, then, is defined +as the logical AND of the two functions. + +To include a semantic check, the (Di)GraphMatcher class should be +subclassed, and the +:meth:`semantic_feasibility ` +function should be redefined. By default, the semantic feasibility function always +returns ``True``. The effect of this is that semantics are not +considered in the matching of G1 and G2. + +Examples +-------- + +Suppose G1 and G2 are isomorphic graphs. Verification is as follows: + +>>> from networkx.algorithms import isomorphism +>>> G1 = nx.path_graph(4) +>>> G2 = nx.path_graph(4) +>>> GM = isomorphism.GraphMatcher(G1, G2) +>>> GM.is_isomorphic() +True + +GM.mapping stores the isomorphism mapping from G1 to G2. + +>>> GM.mapping +{0: 0, 1: 1, 2: 2, 3: 3} + + +Suppose G1 and G2 are isomorphic directed graphs. +Verification is as follows: + +>>> G1 = nx.path_graph(4, create_using=nx.DiGraph) +>>> G2 = nx.path_graph(4, create_using=nx.DiGraph) +>>> DiGM = isomorphism.DiGraphMatcher(G1, G2) +>>> DiGM.is_isomorphic() +True + +DiGM.mapping stores the isomorphism mapping from G1 to G2. + +>>> DiGM.mapping +{0: 0, 1: 1, 2: 2, 3: 3} + + + +Subgraph Isomorphism +-------------------- +Graph theory literature can be ambiguous about the meaning of the +above statement, and we seek to clarify it now. + +In the VF2 literature, a mapping ``M`` is said to be a graph-subgraph +isomorphism iff ``M`` is an isomorphism between ``G2`` and a subgraph of ``G1``. +Thus, to say that ``G1`` and ``G2`` are graph-subgraph isomorphic is to say +that a subgraph of ``G1`` is isomorphic to ``G2``. + +Other literature uses the phrase 'subgraph isomorphic' as in '``G1`` does +not have a subgraph isomorphic to ``G2``'. Another use is as an in adverb +for isomorphic. Thus, to say that ``G1`` and ``G2`` are subgraph isomorphic +is to say that a subgraph of ``G1`` is isomorphic to ``G2``. + +Finally, the term 'subgraph' can have multiple meanings. In this +context, 'subgraph' always means a 'node-induced subgraph'. Edge-induced +subgraph isomorphisms are not directly supported, but one should be +able to perform the check by making use of +:func:`line_graph `. For +subgraphs which are not induced, the term 'monomorphism' is preferred +over 'isomorphism'. + +Let ``G = (N, E)`` be a graph with a set of nodes ``N`` and set of edges ``E``. + +If ``G' = (N', E')`` is a subgraph, then: + ``N'`` is a subset of ``N`` and + ``E'`` is a subset of ``E``. + +If ``G' = (N', E')`` is a node-induced subgraph, then: + ``N'`` is a subset of ``N`` and + ``E'`` is the subset of edges in ``E`` relating nodes in ``N'``. + +If ``G' = (N', E')`` is an edge-induced subgraph, then: + ``N'`` is the subset of nodes in ``N`` related by edges in ``E'`` and + ``E'`` is a subset of ``E``. + +If ``G' = (N', E')`` is a monomorphism, then: + ``N'`` is a subset of ``N`` and + ``E'`` is a subset of the set of edges in ``E`` relating nodes in ``N'``. + +Note that if ``G'`` is a node-induced subgraph of ``G``, then it is always a +subgraph monomorphism of ``G``, but the opposite is not always true, as a +monomorphism can have fewer edges. + +References +---------- +[1] Luigi P. Cordella, Pasquale Foggia, Carlo Sansone, Mario Vento, + "A (Sub)Graph Isomorphism Algorithm for Matching Large Graphs", + IEEE Transactions on Pattern Analysis and Machine Intelligence, + vol. 26, no. 10, pp. 1367-1372, Oct., 2004. + http://ieeexplore.ieee.org/iel5/34/29305/01323804.pdf + +[2] L. P. Cordella, P. Foggia, C. Sansone, M. Vento, "An Improved + Algorithm for Matching Large Graphs", 3rd IAPR-TC15 Workshop + on Graph-based Representations in Pattern Recognition, Cuen, + pp. 149-159, 2001. + https://www.researchgate.net/publication/200034365_An_Improved_Algorithm_for_Matching_Large_Graphs + +See Also +-------- +:meth:`semantic_feasibility ` +:meth:`syntactic_feasibility ` + +Notes +----- + +The implementation handles both directed and undirected graphs as well +as multigraphs. + +In general, the subgraph isomorphism problem is NP-complete whereas the +graph isomorphism problem is most likely not NP-complete (although no +polynomial-time algorithm is known to exist). + +""" + +# This work was originally coded by Christopher Ellison +# as part of the Computational Mechanics Python (CMPy) project. +# James P. Crutchfield, principal investigator. +# Complexity Sciences Center and Physics Department, UC Davis. + +import sys + +__all__ = ["GraphMatcher", "DiGraphMatcher"] + + +class GraphMatcher: + """Implementation of VF2 algorithm for matching undirected graphs. + + Suitable for Graph and MultiGraph instances. + """ + + def __init__(self, G1, G2): + """Initialize GraphMatcher. + + Parameters + ---------- + G1,G2: NetworkX Graph or MultiGraph instances. + The two graphs to check for isomorphism or monomorphism. + + Examples + -------- + To create a GraphMatcher which checks for syntactic feasibility: + + >>> from networkx.algorithms import isomorphism + >>> G1 = nx.path_graph(4) + >>> G2 = nx.path_graph(4) + >>> GM = isomorphism.GraphMatcher(G1, G2) + """ + self.G1 = G1 + self.G2 = G2 + self.G1_nodes = set(G1.nodes()) + self.G2_nodes = set(G2.nodes()) + self.G2_node_order = {n: i for i, n in enumerate(G2)} + + # Set recursion limit. + self.old_recursion_limit = sys.getrecursionlimit() + expected_max_recursion_level = len(self.G2) + if self.old_recursion_limit < 1.5 * expected_max_recursion_level: + # Give some breathing room. + sys.setrecursionlimit(int(1.5 * expected_max_recursion_level)) + + # Declare that we will be searching for a graph-graph isomorphism. + self.test = "graph" + + # Initialize state + self.initialize() + + def reset_recursion_limit(self): + """Restores the recursion limit.""" + # TODO: + # Currently, we use recursion and set the recursion level higher. + # It would be nice to restore the level, but because the + # (Di)GraphMatcher classes make use of cyclic references, garbage + # collection will never happen when we define __del__() to + # restore the recursion level. The result is a memory leak. + # So for now, we do not automatically restore the recursion level, + # and instead provide a method to do this manually. Eventually, + # we should turn this into a non-recursive implementation. + sys.setrecursionlimit(self.old_recursion_limit) + + def candidate_pairs_iter(self): + """Iterator over candidate pairs of nodes in G1 and G2.""" + + # All computations are done using the current state! + + G1_nodes = self.G1_nodes + G2_nodes = self.G2_nodes + min_key = self.G2_node_order.__getitem__ + + # First we compute the inout-terminal sets. + T1_inout = [node for node in self.inout_1 if node not in self.core_1] + T2_inout = [node for node in self.inout_2 if node not in self.core_2] + + # If T1_inout and T2_inout are both nonempty. + # P(s) = T1_inout x {min T2_inout} + if T1_inout and T2_inout: + node_2 = min(T2_inout, key=min_key) + for node_1 in T1_inout: + yield node_1, node_2 + + else: + # If T1_inout and T2_inout were both empty.... + # P(s) = (N_1 - M_1) x {min (N_2 - M_2)} + # if not (T1_inout or T2_inout): # as suggested by [2], incorrect + if 1: # as inferred from [1], correct + # First we determine the candidate node for G2 + other_node = min(G2_nodes - set(self.core_2), key=min_key) + for node in self.G1: + if node not in self.core_1: + yield node, other_node + + # For all other cases, we don't have any candidate pairs. + + def initialize(self): + """Reinitializes the state of the algorithm. + + This method should be redefined if using something other than GMState. + If only subclassing GraphMatcher, a redefinition is not necessary. + + """ + + # core_1[n] contains the index of the node paired with n, which is m, + # provided n is in the mapping. + # core_2[m] contains the index of the node paired with m, which is n, + # provided m is in the mapping. + self.core_1 = {} + self.core_2 = {} + + # See the paper for definitions of M_x and T_x^{y} + + # inout_1[n] is non-zero if n is in M_1 or in T_1^{inout} + # inout_2[m] is non-zero if m is in M_2 or in T_2^{inout} + # + # The value stored is the depth of the SSR tree when the node became + # part of the corresponding set. + self.inout_1 = {} + self.inout_2 = {} + # Practically, these sets simply store the nodes in the subgraph. + + self.state = GMState(self) + + # Provide a convenient way to access the isomorphism mapping. + self.mapping = self.core_1.copy() + + def is_isomorphic(self): + """Returns True if G1 and G2 are isomorphic graphs.""" + + # Let's do two very quick checks! + # QUESTION: Should we call faster_graph_could_be_isomorphic(G1,G2)? + # For now, I just copy the code. + + # Check global properties + if self.G1.order() != self.G2.order(): + return False + + # Check local properties + d1 = sorted(d for n, d in self.G1.degree()) + d2 = sorted(d for n, d in self.G2.degree()) + if d1 != d2: + return False + + try: + x = next(self.isomorphisms_iter()) + return True + except StopIteration: + return False + + def isomorphisms_iter(self): + """Generator over isomorphisms between G1 and G2.""" + # Declare that we are looking for a graph-graph isomorphism. + self.test = "graph" + self.initialize() + yield from self.match() + + def match(self): + """Extends the isomorphism mapping. + + This function is called recursively to determine if a complete + isomorphism can be found between G1 and G2. It cleans up the class + variables after each recursive call. If an isomorphism is found, + we yield the mapping. + + """ + if len(self.core_1) == len(self.G2): + # Save the final mapping, otherwise garbage collection deletes it. + self.mapping = self.core_1.copy() + # The mapping is complete. + yield self.mapping + else: + for G1_node, G2_node in self.candidate_pairs_iter(): + if self.syntactic_feasibility(G1_node, G2_node): + if self.semantic_feasibility(G1_node, G2_node): + # Recursive call, adding the feasible state. + newstate = self.state.__class__(self, G1_node, G2_node) + yield from self.match() + + # restore data structures + newstate.restore() + + def semantic_feasibility(self, G1_node, G2_node): + """Returns True if adding (G1_node, G2_node) is semantically feasible. + + The semantic feasibility function should return True if it is + acceptable to add the candidate pair (G1_node, G2_node) to the current + partial isomorphism mapping. The logic should focus on semantic + information contained in the edge data or a formalized node class. + + By acceptable, we mean that the subsequent mapping can still become a + complete isomorphism mapping. Thus, if adding the candidate pair + definitely makes it so that the subsequent mapping cannot become a + complete isomorphism mapping, then this function must return False. + + The default semantic feasibility function always returns True. The + effect is that semantics are not considered in the matching of G1 + and G2. + + The semantic checks might differ based on the what type of test is + being performed. A keyword description of the test is stored in + self.test. Here is a quick description of the currently implemented + tests:: + + test='graph' + Indicates that the graph matcher is looking for a graph-graph + isomorphism. + + test='subgraph' + Indicates that the graph matcher is looking for a subgraph-graph + isomorphism such that a subgraph of G1 is isomorphic to G2. + + test='mono' + Indicates that the graph matcher is looking for a subgraph-graph + monomorphism such that a subgraph of G1 is monomorphic to G2. + + Any subclass which redefines semantic_feasibility() must maintain + the above form to keep the match() method functional. Implementations + should consider multigraphs. + """ + return True + + def subgraph_is_isomorphic(self): + """Returns `True` if a subgraph of ``G1`` is isomorphic to ``G2``. + + Examples + -------- + When creating the `GraphMatcher`, the order of the arguments is important + + >>> G = nx.Graph([("A", "B"), ("B", "C"), ("A", "C")]) + >>> H = nx.Graph([(0, 1), (1, 2), (0, 2), (1, 3), (0, 4)]) + + Check whether a subgraph of G is isomorphic to H: + + >>> isomatcher = nx.isomorphism.GraphMatcher(G, H) + >>> isomatcher.subgraph_is_isomorphic() + False + + Check whether a subgraph of H is isomorphic to G: + + >>> isomatcher = nx.isomorphism.GraphMatcher(H, G) + >>> isomatcher.subgraph_is_isomorphic() + True + """ + try: + x = next(self.subgraph_isomorphisms_iter()) + return True + except StopIteration: + return False + + def subgraph_is_monomorphic(self): + """Returns `True` if a subgraph of ``G1`` is monomorphic to ``G2``. + + Examples + -------- + When creating the `GraphMatcher`, the order of the arguments is important. + + >>> G = nx.Graph([("A", "B"), ("B", "C")]) + >>> H = nx.Graph([(0, 1), (1, 2), (0, 2)]) + + Check whether a subgraph of G is monomorphic to H: + + >>> isomatcher = nx.isomorphism.GraphMatcher(G, H) + >>> isomatcher.subgraph_is_monomorphic() + False + + Check whether a subgraph of H is isomorphic to G: + + >>> isomatcher = nx.isomorphism.GraphMatcher(H, G) + >>> isomatcher.subgraph_is_monomorphic() + True + """ + try: + x = next(self.subgraph_monomorphisms_iter()) + return True + except StopIteration: + return False + + def subgraph_isomorphisms_iter(self): + """Generator over isomorphisms between a subgraph of ``G1`` and ``G2``. + + Examples + -------- + When creating the `GraphMatcher`, the order of the arguments is important + + >>> G = nx.Graph([("A", "B"), ("B", "C"), ("A", "C")]) + >>> H = nx.Graph([(0, 1), (1, 2), (0, 2), (1, 3), (0, 4)]) + + Yield isomorphic mappings between ``H`` and subgraphs of ``G``: + + >>> isomatcher = nx.isomorphism.GraphMatcher(G, H) + >>> list(isomatcher.subgraph_isomorphisms_iter()) + [] + + Yield isomorphic mappings between ``G`` and subgraphs of ``H``: + + >>> isomatcher = nx.isomorphism.GraphMatcher(H, G) + >>> next(isomatcher.subgraph_isomorphisms_iter()) + {0: 'A', 1: 'B', 2: 'C'} + + """ + # Declare that we are looking for graph-subgraph isomorphism. + self.test = "subgraph" + self.initialize() + yield from self.match() + + def subgraph_monomorphisms_iter(self): + """Generator over monomorphisms between a subgraph of ``G1`` and ``G2``. + + Examples + -------- + When creating the `GraphMatcher`, the order of the arguments is important. + + >>> G = nx.Graph([("A", "B"), ("B", "C")]) + >>> H = nx.Graph([(0, 1), (1, 2), (0, 2)]) + + Yield monomorphic mappings between ``H`` and subgraphs of ``G``: + + >>> isomatcher = nx.isomorphism.GraphMatcher(G, H) + >>> list(isomatcher.subgraph_monomorphisms_iter()) + [] + + Yield monomorphic mappings between ``G`` and subgraphs of ``H``: + + >>> isomatcher = nx.isomorphism.GraphMatcher(H, G) + >>> next(isomatcher.subgraph_monomorphisms_iter()) + {0: 'A', 1: 'B', 2: 'C'} + """ + # Declare that we are looking for graph-subgraph monomorphism. + self.test = "mono" + self.initialize() + yield from self.match() + + def syntactic_feasibility(self, G1_node, G2_node): + """Returns True if adding (G1_node, G2_node) is syntactically feasible. + + This function returns True if it is adding the candidate pair + to the current partial isomorphism/monomorphism mapping is allowable. + The addition is allowable if the inclusion of the candidate pair does + not make it impossible for an isomorphism/monomorphism to be found. + """ + + # The VF2 algorithm was designed to work with graphs having, at most, + # one edge connecting any two nodes. This is not the case when + # dealing with an MultiGraphs. + # + # Basically, when we test the look-ahead rules R_neighbor, we will + # make sure that the number of edges are checked. We also add + # a R_self check to verify that the number of selfloops is acceptable. + # + # Users might be comparing Graph instances with MultiGraph instances. + # So the generic GraphMatcher class must work with MultiGraphs. + # Care must be taken since the value in the innermost dictionary is a + # singlet for Graph instances. For MultiGraphs, the value in the + # innermost dictionary is a list. + + ### + # Test at each step to get a return value as soon as possible. + ### + + # Look ahead 0 + + # R_self + + # The number of selfloops for G1_node must equal the number of + # self-loops for G2_node. Without this check, we would fail on + # R_neighbor at the next recursion level. But it is good to prune the + # search tree now. + + if self.test == "mono": + if self.G1.number_of_edges(G1_node, G1_node) < self.G2.number_of_edges( + G2_node, G2_node + ): + return False + else: + if self.G1.number_of_edges(G1_node, G1_node) != self.G2.number_of_edges( + G2_node, G2_node + ): + return False + + # R_neighbor + + # For each neighbor n' of n in the partial mapping, the corresponding + # node m' is a neighbor of m, and vice versa. Also, the number of + # edges must be equal. + if self.test != "mono": + for neighbor in self.G1[G1_node]: + if neighbor in self.core_1: + if self.core_1[neighbor] not in self.G2[G2_node]: + return False + elif self.G1.number_of_edges( + neighbor, G1_node + ) != self.G2.number_of_edges(self.core_1[neighbor], G2_node): + return False + + for neighbor in self.G2[G2_node]: + if neighbor in self.core_2: + if self.core_2[neighbor] not in self.G1[G1_node]: + return False + elif self.test == "mono": + if self.G1.number_of_edges( + self.core_2[neighbor], G1_node + ) < self.G2.number_of_edges(neighbor, G2_node): + return False + else: + if self.G1.number_of_edges( + self.core_2[neighbor], G1_node + ) != self.G2.number_of_edges(neighbor, G2_node): + return False + + if self.test != "mono": + # Look ahead 1 + + # R_terminout + # The number of neighbors of n in T_1^{inout} is equal to the + # number of neighbors of m that are in T_2^{inout}, and vice versa. + num1 = 0 + for neighbor in self.G1[G1_node]: + if (neighbor in self.inout_1) and (neighbor not in self.core_1): + num1 += 1 + num2 = 0 + for neighbor in self.G2[G2_node]: + if (neighbor in self.inout_2) and (neighbor not in self.core_2): + num2 += 1 + if self.test == "graph": + if num1 != num2: + return False + else: # self.test == 'subgraph' + if not (num1 >= num2): + return False + + # Look ahead 2 + + # R_new + + # The number of neighbors of n that are neither in the core_1 nor + # T_1^{inout} is equal to the number of neighbors of m + # that are neither in core_2 nor T_2^{inout}. + num1 = 0 + for neighbor in self.G1[G1_node]: + if neighbor not in self.inout_1: + num1 += 1 + num2 = 0 + for neighbor in self.G2[G2_node]: + if neighbor not in self.inout_2: + num2 += 1 + if self.test == "graph": + if num1 != num2: + return False + else: # self.test == 'subgraph' + if not (num1 >= num2): + return False + + # Otherwise, this node pair is syntactically feasible! + return True + + +class DiGraphMatcher(GraphMatcher): + """Implementation of VF2 algorithm for matching directed graphs. + + Suitable for DiGraph and MultiDiGraph instances. + """ + + def __init__(self, G1, G2): + """Initialize DiGraphMatcher. + + G1 and G2 should be nx.Graph or nx.MultiGraph instances. + + Examples + -------- + To create a GraphMatcher which checks for syntactic feasibility: + + >>> from networkx.algorithms import isomorphism + >>> G1 = nx.DiGraph(nx.path_graph(4, create_using=nx.DiGraph())) + >>> G2 = nx.DiGraph(nx.path_graph(4, create_using=nx.DiGraph())) + >>> DiGM = isomorphism.DiGraphMatcher(G1, G2) + """ + super().__init__(G1, G2) + + def candidate_pairs_iter(self): + """Iterator over candidate pairs of nodes in G1 and G2.""" + + # All computations are done using the current state! + + G1_nodes = self.G1_nodes + G2_nodes = self.G2_nodes + min_key = self.G2_node_order.__getitem__ + + # First we compute the out-terminal sets. + T1_out = [node for node in self.out_1 if node not in self.core_1] + T2_out = [node for node in self.out_2 if node not in self.core_2] + + # If T1_out and T2_out are both nonempty. + # P(s) = T1_out x {min T2_out} + if T1_out and T2_out: + node_2 = min(T2_out, key=min_key) + for node_1 in T1_out: + yield node_1, node_2 + + # If T1_out and T2_out were both empty.... + # We compute the in-terminal sets. + + # elif not (T1_out or T2_out): # as suggested by [2], incorrect + else: # as suggested by [1], correct + T1_in = [node for node in self.in_1 if node not in self.core_1] + T2_in = [node for node in self.in_2 if node not in self.core_2] + + # If T1_in and T2_in are both nonempty. + # P(s) = T1_out x {min T2_out} + if T1_in and T2_in: + node_2 = min(T2_in, key=min_key) + for node_1 in T1_in: + yield node_1, node_2 + + # If all terminal sets are empty... + # P(s) = (N_1 - M_1) x {min (N_2 - M_2)} + + # elif not (T1_in or T2_in): # as suggested by [2], incorrect + else: # as inferred from [1], correct + node_2 = min(G2_nodes - set(self.core_2), key=min_key) + for node_1 in G1_nodes: + if node_1 not in self.core_1: + yield node_1, node_2 + + # For all other cases, we don't have any candidate pairs. + + def initialize(self): + """Reinitializes the state of the algorithm. + + This method should be redefined if using something other than DiGMState. + If only subclassing GraphMatcher, a redefinition is not necessary. + """ + + # core_1[n] contains the index of the node paired with n, which is m, + # provided n is in the mapping. + # core_2[m] contains the index of the node paired with m, which is n, + # provided m is in the mapping. + self.core_1 = {} + self.core_2 = {} + + # See the paper for definitions of M_x and T_x^{y} + + # in_1[n] is non-zero if n is in M_1 or in T_1^{in} + # out_1[n] is non-zero if n is in M_1 or in T_1^{out} + # + # in_2[m] is non-zero if m is in M_2 or in T_2^{in} + # out_2[m] is non-zero if m is in M_2 or in T_2^{out} + # + # The value stored is the depth of the search tree when the node became + # part of the corresponding set. + self.in_1 = {} + self.in_2 = {} + self.out_1 = {} + self.out_2 = {} + + self.state = DiGMState(self) + + # Provide a convenient way to access the isomorphism mapping. + self.mapping = self.core_1.copy() + + def syntactic_feasibility(self, G1_node, G2_node): + """Returns True if adding (G1_node, G2_node) is syntactically feasible. + + This function returns True if it is adding the candidate pair + to the current partial isomorphism/monomorphism mapping is allowable. + The addition is allowable if the inclusion of the candidate pair does + not make it impossible for an isomorphism/monomorphism to be found. + """ + + # The VF2 algorithm was designed to work with graphs having, at most, + # one edge connecting any two nodes. This is not the case when + # dealing with an MultiGraphs. + # + # Basically, when we test the look-ahead rules R_pred and R_succ, we + # will make sure that the number of edges are checked. We also add + # a R_self check to verify that the number of selfloops is acceptable. + + # Users might be comparing DiGraph instances with MultiDiGraph + # instances. So the generic DiGraphMatcher class must work with + # MultiDiGraphs. Care must be taken since the value in the innermost + # dictionary is a singlet for DiGraph instances. For MultiDiGraphs, + # the value in the innermost dictionary is a list. + + ### + # Test at each step to get a return value as soon as possible. + ### + + # Look ahead 0 + + # R_self + + # The number of selfloops for G1_node must equal the number of + # self-loops for G2_node. Without this check, we would fail on R_pred + # at the next recursion level. This should prune the tree even further. + if self.test == "mono": + if self.G1.number_of_edges(G1_node, G1_node) < self.G2.number_of_edges( + G2_node, G2_node + ): + return False + else: + if self.G1.number_of_edges(G1_node, G1_node) != self.G2.number_of_edges( + G2_node, G2_node + ): + return False + + # R_pred + + # For each predecessor n' of n in the partial mapping, the + # corresponding node m' is a predecessor of m, and vice versa. Also, + # the number of edges must be equal + if self.test != "mono": + for predecessor in self.G1.pred[G1_node]: + if predecessor in self.core_1: + if self.core_1[predecessor] not in self.G2.pred[G2_node]: + return False + elif self.G1.number_of_edges( + predecessor, G1_node + ) != self.G2.number_of_edges(self.core_1[predecessor], G2_node): + return False + + for predecessor in self.G2.pred[G2_node]: + if predecessor in self.core_2: + if self.core_2[predecessor] not in self.G1.pred[G1_node]: + return False + elif self.test == "mono": + if self.G1.number_of_edges( + self.core_2[predecessor], G1_node + ) < self.G2.number_of_edges(predecessor, G2_node): + return False + else: + if self.G1.number_of_edges( + self.core_2[predecessor], G1_node + ) != self.G2.number_of_edges(predecessor, G2_node): + return False + + # R_succ + + # For each successor n' of n in the partial mapping, the corresponding + # node m' is a successor of m, and vice versa. Also, the number of + # edges must be equal. + if self.test != "mono": + for successor in self.G1[G1_node]: + if successor in self.core_1: + if self.core_1[successor] not in self.G2[G2_node]: + return False + elif self.G1.number_of_edges( + G1_node, successor + ) != self.G2.number_of_edges(G2_node, self.core_1[successor]): + return False + + for successor in self.G2[G2_node]: + if successor in self.core_2: + if self.core_2[successor] not in self.G1[G1_node]: + return False + elif self.test == "mono": + if self.G1.number_of_edges( + G1_node, self.core_2[successor] + ) < self.G2.number_of_edges(G2_node, successor): + return False + else: + if self.G1.number_of_edges( + G1_node, self.core_2[successor] + ) != self.G2.number_of_edges(G2_node, successor): + return False + + if self.test != "mono": + # Look ahead 1 + + # R_termin + # The number of predecessors of n that are in T_1^{in} is equal to the + # number of predecessors of m that are in T_2^{in}. + num1 = 0 + for predecessor in self.G1.pred[G1_node]: + if (predecessor in self.in_1) and (predecessor not in self.core_1): + num1 += 1 + num2 = 0 + for predecessor in self.G2.pred[G2_node]: + if (predecessor in self.in_2) and (predecessor not in self.core_2): + num2 += 1 + if self.test == "graph": + if num1 != num2: + return False + else: # self.test == 'subgraph' + if not (num1 >= num2): + return False + + # The number of successors of n that are in T_1^{in} is equal to the + # number of successors of m that are in T_2^{in}. + num1 = 0 + for successor in self.G1[G1_node]: + if (successor in self.in_1) and (successor not in self.core_1): + num1 += 1 + num2 = 0 + for successor in self.G2[G2_node]: + if (successor in self.in_2) and (successor not in self.core_2): + num2 += 1 + if self.test == "graph": + if num1 != num2: + return False + else: # self.test == 'subgraph' + if not (num1 >= num2): + return False + + # R_termout + + # The number of predecessors of n that are in T_1^{out} is equal to the + # number of predecessors of m that are in T_2^{out}. + num1 = 0 + for predecessor in self.G1.pred[G1_node]: + if (predecessor in self.out_1) and (predecessor not in self.core_1): + num1 += 1 + num2 = 0 + for predecessor in self.G2.pred[G2_node]: + if (predecessor in self.out_2) and (predecessor not in self.core_2): + num2 += 1 + if self.test == "graph": + if num1 != num2: + return False + else: # self.test == 'subgraph' + if not (num1 >= num2): + return False + + # The number of successors of n that are in T_1^{out} is equal to the + # number of successors of m that are in T_2^{out}. + num1 = 0 + for successor in self.G1[G1_node]: + if (successor in self.out_1) and (successor not in self.core_1): + num1 += 1 + num2 = 0 + for successor in self.G2[G2_node]: + if (successor in self.out_2) and (successor not in self.core_2): + num2 += 1 + if self.test == "graph": + if num1 != num2: + return False + else: # self.test == 'subgraph' + if not (num1 >= num2): + return False + + # Look ahead 2 + + # R_new + + # The number of predecessors of n that are neither in the core_1 nor + # T_1^{in} nor T_1^{out} is equal to the number of predecessors of m + # that are neither in core_2 nor T_2^{in} nor T_2^{out}. + num1 = 0 + for predecessor in self.G1.pred[G1_node]: + if (predecessor not in self.in_1) and (predecessor not in self.out_1): + num1 += 1 + num2 = 0 + for predecessor in self.G2.pred[G2_node]: + if (predecessor not in self.in_2) and (predecessor not in self.out_2): + num2 += 1 + if self.test == "graph": + if num1 != num2: + return False + else: # self.test == 'subgraph' + if not (num1 >= num2): + return False + + # The number of successors of n that are neither in the core_1 nor + # T_1^{in} nor T_1^{out} is equal to the number of successors of m + # that are neither in core_2 nor T_2^{in} nor T_2^{out}. + num1 = 0 + for successor in self.G1[G1_node]: + if (successor not in self.in_1) and (successor not in self.out_1): + num1 += 1 + num2 = 0 + for successor in self.G2[G2_node]: + if (successor not in self.in_2) and (successor not in self.out_2): + num2 += 1 + if self.test == "graph": + if num1 != num2: + return False + else: # self.test == 'subgraph' + if not (num1 >= num2): + return False + + # Otherwise, this node pair is syntactically feasible! + return True + + def subgraph_is_isomorphic(self): + """Returns `True` if a subgraph of ``G1`` is isomorphic to ``G2``. + + Examples + -------- + When creating the `DiGraphMatcher`, the order of the arguments is important + + >>> G = nx.DiGraph([("A", "B"), ("B", "A"), ("B", "C"), ("C", "B")]) + >>> H = nx.DiGraph(nx.path_graph(5)) + + Check whether a subgraph of G is isomorphic to H: + + >>> isomatcher = nx.isomorphism.DiGraphMatcher(G, H) + >>> isomatcher.subgraph_is_isomorphic() + False + + Check whether a subgraph of H is isomorphic to G: + + >>> isomatcher = nx.isomorphism.DiGraphMatcher(H, G) + >>> isomatcher.subgraph_is_isomorphic() + True + """ + return super().subgraph_is_isomorphic() + + def subgraph_is_monomorphic(self): + """Returns `True` if a subgraph of ``G1`` is monomorphic to ``G2``. + + Examples + -------- + When creating the `DiGraphMatcher`, the order of the arguments is important. + + >>> G = nx.DiGraph([("A", "B"), ("C", "B"), ("D", "C")]) + >>> H = nx.DiGraph([(0, 1), (1, 2), (2, 3), (3, 2)]) + + Check whether a subgraph of G is monomorphic to H: + + >>> isomatcher = nx.isomorphism.DiGraphMatcher(G, H) + >>> isomatcher.subgraph_is_monomorphic() + False + + Check whether a subgraph of H is isomorphic to G: + + >>> isomatcher = nx.isomorphism.DiGraphMatcher(H, G) + >>> isomatcher.subgraph_is_monomorphic() + True + """ + return super().subgraph_is_monomorphic() + + def subgraph_isomorphisms_iter(self): + """Generator over isomorphisms between a subgraph of ``G1`` and ``G2``. + + Examples + -------- + When creating the `DiGraphMatcher`, the order of the arguments is important + + >>> G = nx.DiGraph([("B", "C"), ("C", "B"), ("C", "D"), ("D", "C")]) + >>> H = nx.DiGraph(nx.path_graph(5)) + + Yield isomorphic mappings between ``H`` and subgraphs of ``G``: + + >>> isomatcher = nx.isomorphism.DiGraphMatcher(G, H) + >>> list(isomatcher.subgraph_isomorphisms_iter()) + [] + + Yield isomorphic mappings between ``G`` and subgraphs of ``H``: + + >>> isomatcher = nx.isomorphism.DiGraphMatcher(H, G) + >>> next(isomatcher.subgraph_isomorphisms_iter()) + {0: 'B', 1: 'C', 2: 'D'} + """ + return super().subgraph_isomorphisms_iter() + + def subgraph_monomorphisms_iter(self): + """Generator over monomorphisms between a subgraph of ``G1`` and ``G2``. + + Examples + -------- + When creating the `DiGraphMatcher`, the order of the arguments is important. + + >>> G = nx.DiGraph([("A", "B"), ("C", "B"), ("D", "C")]) + >>> H = nx.DiGraph([(0, 1), (1, 2), (2, 3), (3, 2)]) + + Yield monomorphic mappings between ``H`` and subgraphs of ``G``: + + >>> isomatcher = nx.isomorphism.DiGraphMatcher(G, H) + >>> list(isomatcher.subgraph_monomorphisms_iter()) + [] + + Yield monomorphic mappings between ``G`` and subgraphs of ``H``: + + >>> isomatcher = nx.isomorphism.DiGraphMatcher(H, G) + >>> next(isomatcher.subgraph_monomorphisms_iter()) + {3: 'A', 2: 'B', 1: 'C', 0: 'D'} + """ + return super().subgraph_monomorphisms_iter() + + +class GMState: + """Internal representation of state for the GraphMatcher class. + + This class is used internally by the GraphMatcher class. It is used + only to store state specific data. There will be at most G2.order() of + these objects in memory at a time, due to the depth-first search + strategy employed by the VF2 algorithm. + """ + + def __init__(self, GM, G1_node=None, G2_node=None): + """Initializes GMState object. + + Pass in the GraphMatcher to which this GMState belongs and the + new node pair that will be added to the GraphMatcher's current + isomorphism mapping. + """ + self.GM = GM + + # Initialize the last stored node pair. + self.G1_node = None + self.G2_node = None + self.depth = len(GM.core_1) + + if G1_node is None or G2_node is None: + # Then we reset the class variables + GM.core_1 = {} + GM.core_2 = {} + GM.inout_1 = {} + GM.inout_2 = {} + + # Watch out! G1_node == 0 should evaluate to True. + if G1_node is not None and G2_node is not None: + # Add the node pair to the isomorphism mapping. + GM.core_1[G1_node] = G2_node + GM.core_2[G2_node] = G1_node + + # Store the node that was added last. + self.G1_node = G1_node + self.G2_node = G2_node + + # Now we must update the other two vectors. + # We will add only if it is not in there already! + self.depth = len(GM.core_1) + + # First we add the new nodes... + if G1_node not in GM.inout_1: + GM.inout_1[G1_node] = self.depth + if G2_node not in GM.inout_2: + GM.inout_2[G2_node] = self.depth + + # Now we add every other node... + + # Updates for T_1^{inout} + new_nodes = set() + for node in GM.core_1: + new_nodes.update( + [neighbor for neighbor in GM.G1[node] if neighbor not in GM.core_1] + ) + for node in new_nodes: + if node not in GM.inout_1: + GM.inout_1[node] = self.depth + + # Updates for T_2^{inout} + new_nodes = set() + for node in GM.core_2: + new_nodes.update( + [neighbor for neighbor in GM.G2[node] if neighbor not in GM.core_2] + ) + for node in new_nodes: + if node not in GM.inout_2: + GM.inout_2[node] = self.depth + + def restore(self): + """Deletes the GMState object and restores the class variables.""" + # First we remove the node that was added from the core vectors. + # Watch out! G1_node == 0 should evaluate to True. + if self.G1_node is not None and self.G2_node is not None: + del self.GM.core_1[self.G1_node] + del self.GM.core_2[self.G2_node] + + # Now we revert the other two vectors. + # Thus, we delete all entries which have this depth level. + for vector in (self.GM.inout_1, self.GM.inout_2): + for node in list(vector.keys()): + if vector[node] == self.depth: + del vector[node] + + +class DiGMState: + """Internal representation of state for the DiGraphMatcher class. + + This class is used internally by the DiGraphMatcher class. It is used + only to store state specific data. There will be at most G2.order() of + these objects in memory at a time, due to the depth-first search + strategy employed by the VF2 algorithm. + + """ + + def __init__(self, GM, G1_node=None, G2_node=None): + """Initializes DiGMState object. + + Pass in the DiGraphMatcher to which this DiGMState belongs and the + new node pair that will be added to the GraphMatcher's current + isomorphism mapping. + """ + self.GM = GM + + # Initialize the last stored node pair. + self.G1_node = None + self.G2_node = None + self.depth = len(GM.core_1) + + if G1_node is None or G2_node is None: + # Then we reset the class variables + GM.core_1 = {} + GM.core_2 = {} + GM.in_1 = {} + GM.in_2 = {} + GM.out_1 = {} + GM.out_2 = {} + + # Watch out! G1_node == 0 should evaluate to True. + if G1_node is not None and G2_node is not None: + # Add the node pair to the isomorphism mapping. + GM.core_1[G1_node] = G2_node + GM.core_2[G2_node] = G1_node + + # Store the node that was added last. + self.G1_node = G1_node + self.G2_node = G2_node + + # Now we must update the other four vectors. + # We will add only if it is not in there already! + self.depth = len(GM.core_1) + + # First we add the new nodes... + for vector in (GM.in_1, GM.out_1): + if G1_node not in vector: + vector[G1_node] = self.depth + for vector in (GM.in_2, GM.out_2): + if G2_node not in vector: + vector[G2_node] = self.depth + + # Now we add every other node... + + # Updates for T_1^{in} + new_nodes = set() + for node in GM.core_1: + new_nodes.update( + [ + predecessor + for predecessor in GM.G1.predecessors(node) + if predecessor not in GM.core_1 + ] + ) + for node in new_nodes: + if node not in GM.in_1: + GM.in_1[node] = self.depth + + # Updates for T_2^{in} + new_nodes = set() + for node in GM.core_2: + new_nodes.update( + [ + predecessor + for predecessor in GM.G2.predecessors(node) + if predecessor not in GM.core_2 + ] + ) + for node in new_nodes: + if node not in GM.in_2: + GM.in_2[node] = self.depth + + # Updates for T_1^{out} + new_nodes = set() + for node in GM.core_1: + new_nodes.update( + [ + successor + for successor in GM.G1.successors(node) + if successor not in GM.core_1 + ] + ) + for node in new_nodes: + if node not in GM.out_1: + GM.out_1[node] = self.depth + + # Updates for T_2^{out} + new_nodes = set() + for node in GM.core_2: + new_nodes.update( + [ + successor + for successor in GM.G2.successors(node) + if successor not in GM.core_2 + ] + ) + for node in new_nodes: + if node not in GM.out_2: + GM.out_2[node] = self.depth + + def restore(self): + """Deletes the DiGMState object and restores the class variables.""" + + # First we remove the node that was added from the core vectors. + # Watch out! G1_node == 0 should evaluate to True. + if self.G1_node is not None and self.G2_node is not None: + del self.GM.core_1[self.G1_node] + del self.GM.core_2[self.G2_node] + + # Now we revert the other four vectors. + # Thus, we delete all entries which have this depth level. + for vector in (self.GM.in_1, self.GM.in_2, self.GM.out_1, self.GM.out_2): + for node in list(vector.keys()): + if vector[node] == self.depth: + del vector[node] diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/matchhelpers.py b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/matchhelpers.py new file mode 100644 index 0000000000000000000000000000000000000000..b48820d4d1896a8be1153f3e82feb2c3a5239761 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/matchhelpers.py @@ -0,0 +1,352 @@ +"""Functions which help end users define customize node_match and +edge_match functions to use during isomorphism checks. +""" + +import math +import types +from itertools import permutations + +__all__ = [ + "categorical_node_match", + "categorical_edge_match", + "categorical_multiedge_match", + "numerical_node_match", + "numerical_edge_match", + "numerical_multiedge_match", + "generic_node_match", + "generic_edge_match", + "generic_multiedge_match", +] + + +def copyfunc(f, name=None): + """Returns a deepcopy of a function.""" + return types.FunctionType( + f.__code__, f.__globals__, name or f.__name__, f.__defaults__, f.__closure__ + ) + + +def allclose(x, y, rtol=1.0000000000000001e-05, atol=1e-08): + """Returns True if x and y are sufficiently close, elementwise. + + Parameters + ---------- + rtol : float + The relative error tolerance. + atol : float + The absolute error tolerance. + + """ + # assume finite weights, see numpy.allclose() for reference + return all(math.isclose(xi, yi, rel_tol=rtol, abs_tol=atol) for xi, yi in zip(x, y)) + + +categorical_doc = """ +Returns a comparison function for a categorical node attribute. + +The value(s) of the attr(s) must be hashable and comparable via the == +operator since they are placed into a set([]) object. If the sets from +G1 and G2 are the same, then the constructed function returns True. + +Parameters +---------- +attr : string | list + The categorical node attribute to compare, or a list of categorical + node attributes to compare. +default : value | list + The default value for the categorical node attribute, or a list of + default values for the categorical node attributes. + +Returns +------- +match : function + The customized, categorical `node_match` function. + +Examples +-------- +>>> import networkx.algorithms.isomorphism as iso +>>> nm = iso.categorical_node_match("size", 1) +>>> nm = iso.categorical_node_match(["color", "size"], ["red", 2]) + +""" + + +def categorical_node_match(attr, default): + if isinstance(attr, str): + + def match(data1, data2): + return data1.get(attr, default) == data2.get(attr, default) + + else: + attrs = list(zip(attr, default)) # Python 3 + + def match(data1, data2): + return all(data1.get(attr, d) == data2.get(attr, d) for attr, d in attrs) + + return match + + +categorical_edge_match = copyfunc(categorical_node_match, "categorical_edge_match") + + +def categorical_multiedge_match(attr, default): + if isinstance(attr, str): + + def match(datasets1, datasets2): + values1 = {data.get(attr, default) for data in datasets1.values()} + values2 = {data.get(attr, default) for data in datasets2.values()} + return values1 == values2 + + else: + attrs = list(zip(attr, default)) # Python 3 + + def match(datasets1, datasets2): + values1 = set() + for data1 in datasets1.values(): + x = tuple(data1.get(attr, d) for attr, d in attrs) + values1.add(x) + values2 = set() + for data2 in datasets2.values(): + x = tuple(data2.get(attr, d) for attr, d in attrs) + values2.add(x) + return values1 == values2 + + return match + + +# Docstrings for categorical functions. +categorical_node_match.__doc__ = categorical_doc +categorical_edge_match.__doc__ = categorical_doc.replace("node", "edge") +tmpdoc = categorical_doc.replace("node", "edge") +tmpdoc = tmpdoc.replace("categorical_edge_match", "categorical_multiedge_match") +categorical_multiedge_match.__doc__ = tmpdoc + + +numerical_doc = """ +Returns a comparison function for a numerical node attribute. + +The value(s) of the attr(s) must be numerical and sortable. If the +sorted list of values from G1 and G2 are the same within some +tolerance, then the constructed function returns True. + +Parameters +---------- +attr : string | list + The numerical node attribute to compare, or a list of numerical + node attributes to compare. +default : value | list + The default value for the numerical node attribute, or a list of + default values for the numerical node attributes. +rtol : float + The relative error tolerance. +atol : float + The absolute error tolerance. + +Returns +------- +match : function + The customized, numerical `node_match` function. + +Examples +-------- +>>> import networkx.algorithms.isomorphism as iso +>>> nm = iso.numerical_node_match("weight", 1.0) +>>> nm = iso.numerical_node_match(["weight", "linewidth"], [0.25, 0.5]) + +""" + + +def numerical_node_match(attr, default, rtol=1.0000000000000001e-05, atol=1e-08): + if isinstance(attr, str): + + def match(data1, data2): + return math.isclose( + data1.get(attr, default), + data2.get(attr, default), + rel_tol=rtol, + abs_tol=atol, + ) + + else: + attrs = list(zip(attr, default)) # Python 3 + + def match(data1, data2): + values1 = [data1.get(attr, d) for attr, d in attrs] + values2 = [data2.get(attr, d) for attr, d in attrs] + return allclose(values1, values2, rtol=rtol, atol=atol) + + return match + + +numerical_edge_match = copyfunc(numerical_node_match, "numerical_edge_match") + + +def numerical_multiedge_match(attr, default, rtol=1.0000000000000001e-05, atol=1e-08): + if isinstance(attr, str): + + def match(datasets1, datasets2): + values1 = sorted(data.get(attr, default) for data in datasets1.values()) + values2 = sorted(data.get(attr, default) for data in datasets2.values()) + return allclose(values1, values2, rtol=rtol, atol=atol) + + else: + attrs = list(zip(attr, default)) # Python 3 + + def match(datasets1, datasets2): + values1 = [] + for data1 in datasets1.values(): + x = tuple(data1.get(attr, d) for attr, d in attrs) + values1.append(x) + values2 = [] + for data2 in datasets2.values(): + x = tuple(data2.get(attr, d) for attr, d in attrs) + values2.append(x) + values1.sort() + values2.sort() + for xi, yi in zip(values1, values2): + if not allclose(xi, yi, rtol=rtol, atol=atol): + return False + else: + return True + + return match + + +# Docstrings for numerical functions. +numerical_node_match.__doc__ = numerical_doc +numerical_edge_match.__doc__ = numerical_doc.replace("node", "edge") +tmpdoc = numerical_doc.replace("node", "edge") +tmpdoc = tmpdoc.replace("numerical_edge_match", "numerical_multiedge_match") +numerical_multiedge_match.__doc__ = tmpdoc + + +generic_doc = """ +Returns a comparison function for a generic attribute. + +The value(s) of the attr(s) are compared using the specified +operators. If all the attributes are equal, then the constructed +function returns True. + +Parameters +---------- +attr : string | list + The node attribute to compare, or a list of node attributes + to compare. +default : value | list + The default value for the node attribute, or a list of + default values for the node attributes. +op : callable | list + The operator to use when comparing attribute values, or a list + of operators to use when comparing values for each attribute. + +Returns +------- +match : function + The customized, generic `node_match` function. + +Examples +-------- +>>> from operator import eq +>>> from math import isclose +>>> from networkx.algorithms.isomorphism import generic_node_match +>>> nm = generic_node_match("weight", 1.0, isclose) +>>> nm = generic_node_match("color", "red", eq) +>>> nm = generic_node_match(["weight", "color"], [1.0, "red"], [isclose, eq]) + +""" + + +def generic_node_match(attr, default, op): + if isinstance(attr, str): + + def match(data1, data2): + return op(data1.get(attr, default), data2.get(attr, default)) + + else: + attrs = list(zip(attr, default, op)) # Python 3 + + def match(data1, data2): + for attr, d, operator in attrs: + if not operator(data1.get(attr, d), data2.get(attr, d)): + return False + else: + return True + + return match + + +generic_edge_match = copyfunc(generic_node_match, "generic_edge_match") + + +def generic_multiedge_match(attr, default, op): + """Returns a comparison function for a generic attribute. + + The value(s) of the attr(s) are compared using the specified + operators. If all the attributes are equal, then the constructed + function returns True. Potentially, the constructed edge_match + function can be slow since it must verify that no isomorphism + exists between the multiedges before it returns False. + + Parameters + ---------- + attr : string | list + The edge attribute to compare, or a list of node attributes + to compare. + default : value | list + The default value for the edge attribute, or a list of + default values for the edgeattributes. + op : callable | list + The operator to use when comparing attribute values, or a list + of operators to use when comparing values for each attribute. + + Returns + ------- + match : function + The customized, generic `edge_match` function. + + Examples + -------- + >>> from operator import eq + >>> from math import isclose + >>> from networkx.algorithms.isomorphism import generic_node_match + >>> nm = generic_node_match("weight", 1.0, isclose) + >>> nm = generic_node_match("color", "red", eq) + >>> nm = generic_node_match(["weight", "color"], [1.0, "red"], [isclose, eq]) + + """ + + # This is slow, but generic. + # We must test every possible isomorphism between the edges. + if isinstance(attr, str): + attr = [attr] + default = [default] + op = [op] + attrs = list(zip(attr, default)) # Python 3 + + def match(datasets1, datasets2): + values1 = [] + for data1 in datasets1.values(): + x = tuple(data1.get(attr, d) for attr, d in attrs) + values1.append(x) + values2 = [] + for data2 in datasets2.values(): + x = tuple(data2.get(attr, d) for attr, d in attrs) + values2.append(x) + for vals2 in permutations(values2): + for xi, yi in zip(values1, vals2): + if not all(map(lambda x, y, z: z(x, y), xi, yi, op)): + # This is not an isomorphism, go to next permutation. + break + else: + # Then we found an isomorphism. + return True + else: + # Then there are no isomorphisms between the multiedges. + return False + + return match + + +# Docstrings for numerical functions. +generic_node_match.__doc__ = generic_doc +generic_edge_match.__doc__ = generic_doc.replace("node", "edge") diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/temporalisomorphvf2.py b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/temporalisomorphvf2.py new file mode 100644 index 0000000000000000000000000000000000000000..62cacc77887efa99026c117687bb9ad82cebd4dd --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/temporalisomorphvf2.py @@ -0,0 +1,308 @@ +""" +***************************** +Time-respecting VF2 Algorithm +***************************** + +An extension of the VF2 algorithm for time-respecting graph isomorphism +testing in temporal graphs. + +A temporal graph is one in which edges contain a datetime attribute, +denoting when interaction occurred between the incident nodes. A +time-respecting subgraph of a temporal graph is a subgraph such that +all interactions incident to a node occurred within a time threshold, +delta, of each other. A directed time-respecting subgraph has the +added constraint that incoming interactions to a node must precede +outgoing interactions from the same node - this enforces a sense of +directed flow. + +Introduction +------------ + +The TimeRespectingGraphMatcher and TimeRespectingDiGraphMatcher +extend the GraphMatcher and DiGraphMatcher classes, respectively, +to include temporal constraints on matches. This is achieved through +a semantic check, via the semantic_feasibility() function. + +As well as including G1 (the graph in which to seek embeddings) and +G2 (the subgraph structure of interest), the name of the temporal +attribute on the edges and the time threshold, delta, must be supplied +as arguments to the matching constructors. + +A delta of zero is the strictest temporal constraint on the match - +only embeddings in which all interactions occur at the same time will +be returned. A delta of one day will allow embeddings in which +adjacent interactions occur up to a day apart. + +Examples +-------- + +Examples will be provided when the datetime type has been incorporated. + + +Temporal Subgraph Isomorphism +----------------------------- + +A brief discussion of the somewhat diverse current literature will be +included here. + +References +---------- + +[1] Redmond, U. and Cunningham, P. Temporal subgraph isomorphism. In: +The 2013 IEEE/ACM International Conference on Advances in Social +Networks Analysis and Mining (ASONAM). Niagara Falls, Canada; 2013: +pages 1451 - 1452. [65] + +For a discussion of the literature on temporal networks: + +[3] P. Holme and J. Saramaki. Temporal networks. Physics Reports, +519(3):97–125, 2012. + +Notes +----- + +Handles directed and undirected graphs and graphs with parallel edges. + +""" + +import networkx as nx + +from .isomorphvf2 import DiGraphMatcher, GraphMatcher + +__all__ = ["TimeRespectingGraphMatcher", "TimeRespectingDiGraphMatcher"] + + +class TimeRespectingGraphMatcher(GraphMatcher): + def __init__(self, G1, G2, temporal_attribute_name, delta): + """Initialize TimeRespectingGraphMatcher. + + G1 and G2 should be nx.Graph or nx.MultiGraph instances. + + Examples + -------- + To create a TimeRespectingGraphMatcher which checks for + syntactic and semantic feasibility: + + >>> from networkx.algorithms import isomorphism + >>> from datetime import timedelta + >>> G1 = nx.Graph(nx.path_graph(4, create_using=nx.Graph())) + + >>> G2 = nx.Graph(nx.path_graph(4, create_using=nx.Graph())) + + >>> GM = isomorphism.TimeRespectingGraphMatcher( + ... G1, G2, "date", timedelta(days=1) + ... ) + """ + self.temporal_attribute_name = temporal_attribute_name + self.delta = delta + super().__init__(G1, G2) + + def one_hop(self, Gx, Gx_node, neighbors): + """ + Edges one hop out from a node in the mapping should be + time-respecting with respect to each other. + """ + dates = [] + for n in neighbors: + if isinstance(Gx, nx.Graph): # Graph G[u][v] returns the data dictionary. + dates.append(Gx[Gx_node][n][self.temporal_attribute_name]) + else: # MultiGraph G[u][v] returns a dictionary of key -> data dictionary. + for edge in Gx[Gx_node][ + n + ].values(): # Iterates all edges between node pair. + dates.append(edge[self.temporal_attribute_name]) + if any(x is None for x in dates): + raise ValueError("Datetime not supplied for at least one edge.") + return not dates or max(dates) - min(dates) <= self.delta + + def two_hop(self, Gx, core_x, Gx_node, neighbors): + """ + Paths of length 2 from Gx_node should be time-respecting. + """ + return all( + self.one_hop(Gx, v, [n for n in Gx[v] if n in core_x] + [Gx_node]) + for v in neighbors + ) + + def semantic_feasibility(self, G1_node, G2_node): + """Returns True if adding (G1_node, G2_node) is semantically + feasible. + + Any subclass which redefines semantic_feasibility() must + maintain the self.tests if needed, to keep the match() method + functional. Implementations should consider multigraphs. + """ + neighbors = [n for n in self.G1[G1_node] if n in self.core_1] + if not self.one_hop(self.G1, G1_node, neighbors): # Fail fast on first node. + return False + if not self.two_hop(self.G1, self.core_1, G1_node, neighbors): + return False + # Otherwise, this node is semantically feasible! + return True + + +class TimeRespectingDiGraphMatcher(DiGraphMatcher): + def __init__(self, G1, G2, temporal_attribute_name, delta): + """Initialize TimeRespectingDiGraphMatcher. + + G1 and G2 should be nx.DiGraph or nx.MultiDiGraph instances. + + Examples + -------- + To create a TimeRespectingDiGraphMatcher which checks for + syntactic and semantic feasibility: + + >>> from networkx.algorithms import isomorphism + >>> from datetime import timedelta + >>> G1 = nx.DiGraph(nx.path_graph(4, create_using=nx.DiGraph())) + + >>> G2 = nx.DiGraph(nx.path_graph(4, create_using=nx.DiGraph())) + + >>> GM = isomorphism.TimeRespectingDiGraphMatcher( + ... G1, G2, "date", timedelta(days=1) + ... ) + """ + self.temporal_attribute_name = temporal_attribute_name + self.delta = delta + super().__init__(G1, G2) + + def get_pred_dates(self, Gx, Gx_node, core_x, pred): + """ + Get the dates of edges from predecessors. + """ + pred_dates = [] + if isinstance(Gx, nx.DiGraph): # Graph G[u][v] returns the data dictionary. + for n in pred: + pred_dates.append(Gx[n][Gx_node][self.temporal_attribute_name]) + else: # MultiGraph G[u][v] returns a dictionary of key -> data dictionary. + for n in pred: + for edge in Gx[n][ + Gx_node + ].values(): # Iterates all edge data between node pair. + pred_dates.append(edge[self.temporal_attribute_name]) + return pred_dates + + def get_succ_dates(self, Gx, Gx_node, core_x, succ): + """ + Get the dates of edges to successors. + """ + succ_dates = [] + if isinstance(Gx, nx.DiGraph): # Graph G[u][v] returns the data dictionary. + for n in succ: + succ_dates.append(Gx[Gx_node][n][self.temporal_attribute_name]) + else: # MultiGraph G[u][v] returns a dictionary of key -> data dictionary. + for n in succ: + for edge in Gx[Gx_node][ + n + ].values(): # Iterates all edge data between node pair. + succ_dates.append(edge[self.temporal_attribute_name]) + return succ_dates + + def one_hop(self, Gx, Gx_node, core_x, pred, succ): + """ + The ego node. + """ + pred_dates = self.get_pred_dates(Gx, Gx_node, core_x, pred) + succ_dates = self.get_succ_dates(Gx, Gx_node, core_x, succ) + return self.test_one(pred_dates, succ_dates) and self.test_two( + pred_dates, succ_dates + ) + + def two_hop_pred(self, Gx, Gx_node, core_x, pred): + """ + The predecessors of the ego node. + """ + return all( + self.one_hop( + Gx, + p, + core_x, + self.preds(Gx, core_x, p), + self.succs(Gx, core_x, p, Gx_node), + ) + for p in pred + ) + + def two_hop_succ(self, Gx, Gx_node, core_x, succ): + """ + The successors of the ego node. + """ + return all( + self.one_hop( + Gx, + s, + core_x, + self.preds(Gx, core_x, s, Gx_node), + self.succs(Gx, core_x, s), + ) + for s in succ + ) + + def preds(self, Gx, core_x, v, Gx_node=None): + pred = [n for n in Gx.predecessors(v) if n in core_x] + if Gx_node: + pred.append(Gx_node) + return pred + + def succs(self, Gx, core_x, v, Gx_node=None): + succ = [n for n in Gx.successors(v) if n in core_x] + if Gx_node: + succ.append(Gx_node) + return succ + + def test_one(self, pred_dates, succ_dates): + """ + Edges one hop out from Gx_node in the mapping should be + time-respecting with respect to each other, regardless of + direction. + """ + time_respecting = True + dates = pred_dates + succ_dates + + if any(x is None for x in dates): + raise ValueError("Date or datetime not supplied for at least one edge.") + + dates.sort() # Small to large. + if 0 < len(dates) and not (dates[-1] - dates[0] <= self.delta): + time_respecting = False + return time_respecting + + def test_two(self, pred_dates, succ_dates): + """ + Edges from a dual Gx_node in the mapping should be ordered in + a time-respecting manner. + """ + time_respecting = True + pred_dates.sort() + succ_dates.sort() + # First out before last in; negative of the necessary condition for time-respect. + if ( + 0 < len(succ_dates) + and 0 < len(pred_dates) + and succ_dates[0] < pred_dates[-1] + ): + time_respecting = False + return time_respecting + + def semantic_feasibility(self, G1_node, G2_node): + """Returns True if adding (G1_node, G2_node) is semantically + feasible. + + Any subclass which redefines semantic_feasibility() must + maintain the self.tests if needed, to keep the match() method + functional. Implementations should consider multigraphs. + """ + pred, succ = ( + [n for n in self.G1.predecessors(G1_node) if n in self.core_1], + [n for n in self.G1.successors(G1_node) if n in self.core_1], + ) + if not self.one_hop( + self.G1, G1_node, self.core_1, pred, succ + ): # Fail fast on first node. + return False + if not self.two_hop_pred(self.G1, G1_node, self.core_1, pred): + return False + if not self.two_hop_succ(self.G1, G1_node, self.core_1, succ): + return False + # Otherwise, this node is semantically feasible! + return True diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4eada62dad5355d5d690ef82f4edb0046b318be3 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__pycache__/test_ismags.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__pycache__/test_ismags.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6717e62987f4eb63dc92813ea7583ba13898b805 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__pycache__/test_ismags.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__pycache__/test_isomorphism.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__pycache__/test_isomorphism.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..104c731bc169f209b06aea006a8e044a500aef03 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__pycache__/test_isomorphism.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__pycache__/test_isomorphvf2.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__pycache__/test_isomorphvf2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1213620bc3982f97e12b2be7df7b9a426c05298 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__pycache__/test_isomorphvf2.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__pycache__/test_match_helpers.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__pycache__/test_match_helpers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94d9e523662e7de152d461467e10e4196be30b75 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__pycache__/test_match_helpers.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__pycache__/test_temporalisomorphvf2.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__pycache__/test_temporalisomorphvf2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bdc5a893e7b1e754abb3c20e1ce29653745e3815 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__pycache__/test_temporalisomorphvf2.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__pycache__/test_tree_isomorphism.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__pycache__/test_tree_isomorphism.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e3c8e2c0e01a222a58c9e55eb0a5fd872ed494e Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__pycache__/test_tree_isomorphism.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__pycache__/test_vf2pp.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__pycache__/test_vf2pp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c863031a6e2d3c3266807a9eed5e2bfc2a8f0c4 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__pycache__/test_vf2pp.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__pycache__/test_vf2pp_helpers.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__pycache__/test_vf2pp_helpers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef5f90fd941c4945feb0b7f57df8bd8c5a415528 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__pycache__/test_vf2pp_helpers.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__pycache__/test_vf2userfunc.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__pycache__/test_vf2userfunc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..143c044db7c8eec80c3233b6a733340d511ff9be Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/__pycache__/test_vf2userfunc.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/iso_r01_s80.A99 b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/iso_r01_s80.A99 new file mode 100644 index 0000000000000000000000000000000000000000..dac54f0038c70e2d359ffa68de3c7641b46db21a Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/iso_r01_s80.A99 differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/iso_r01_s80.B99 b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/iso_r01_s80.B99 new file mode 100644 index 0000000000000000000000000000000000000000..6c6af680031b4f30fc6da946d0344b5c27e5f05e Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/iso_r01_s80.B99 differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/si2_b06_m200.A99 b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/si2_b06_m200.A99 new file mode 100644 index 0000000000000000000000000000000000000000..60c3a3ce1bdb54a61ead043f0adaf24b1fe24e93 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/si2_b06_m200.A99 differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/si2_b06_m200.B99 b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/si2_b06_m200.B99 new file mode 100644 index 0000000000000000000000000000000000000000..0236872094d4c73b0bb132165ce3ec4d1054f5f5 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/si2_b06_m200.B99 differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/test_ismags.py b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/test_ismags.py new file mode 100644 index 0000000000000000000000000000000000000000..bc4070ac5bd5bca33030e537d223373093170040 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/test_ismags.py @@ -0,0 +1,327 @@ +""" +Tests for ISMAGS isomorphism algorithm. +""" + +import pytest + +import networkx as nx +from networkx.algorithms import isomorphism as iso + + +def _matches_to_sets(matches): + """ + Helper function to facilitate comparing collections of dictionaries in + which order does not matter. + """ + return {frozenset(m.items()) for m in matches} + + +class TestSelfIsomorphism: + data = [ + ( + [ + (0, {"name": "a"}), + (1, {"name": "a"}), + (2, {"name": "b"}), + (3, {"name": "b"}), + (4, {"name": "a"}), + (5, {"name": "a"}), + ], + [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)], + ), + (range(1, 5), [(1, 2), (2, 4), (4, 3), (3, 1)]), + ( + [], + [ + (0, 1), + (1, 2), + (2, 3), + (3, 4), + (4, 5), + (5, 0), + (0, 6), + (6, 7), + (2, 8), + (8, 9), + (4, 10), + (10, 11), + ], + ), + ([], [(0, 1), (1, 2), (1, 4), (2, 3), (3, 5), (3, 6)]), + ] + + def test_self_isomorphism(self): + """ + For some small, symmetric graphs, make sure that 1) they are isomorphic + to themselves, and 2) that only the identity mapping is found. + """ + for node_data, edge_data in self.data: + graph = nx.Graph() + graph.add_nodes_from(node_data) + graph.add_edges_from(edge_data) + + ismags = iso.ISMAGS( + graph, graph, node_match=iso.categorical_node_match("name", None) + ) + assert ismags.is_isomorphic() + assert ismags.subgraph_is_isomorphic() + assert list(ismags.subgraph_isomorphisms_iter(symmetry=True)) == [ + {n: n for n in graph.nodes} + ] + + def test_edgecase_self_isomorphism(self): + """ + This edgecase is one of the cases in which it is hard to find all + symmetry elements. + """ + graph = nx.Graph() + nx.add_path(graph, range(5)) + graph.add_edges_from([(2, 5), (5, 6)]) + + ismags = iso.ISMAGS(graph, graph) + ismags_answer = list(ismags.find_isomorphisms(True)) + assert ismags_answer == [{n: n for n in graph.nodes}] + + graph = nx.relabel_nodes(graph, {0: 0, 1: 1, 2: 2, 3: 3, 4: 6, 5: 4, 6: 5}) + ismags = iso.ISMAGS(graph, graph) + ismags_answer = list(ismags.find_isomorphisms(True)) + assert ismags_answer == [{n: n for n in graph.nodes}] + + def test_directed_self_isomorphism(self): + """ + For some small, directed, symmetric graphs, make sure that 1) they are + isomorphic to themselves, and 2) that only the identity mapping is + found. + """ + for node_data, edge_data in self.data: + graph = nx.Graph() + graph.add_nodes_from(node_data) + graph.add_edges_from(edge_data) + + ismags = iso.ISMAGS( + graph, graph, node_match=iso.categorical_node_match("name", None) + ) + assert ismags.is_isomorphic() + assert ismags.subgraph_is_isomorphic() + assert list(ismags.subgraph_isomorphisms_iter(symmetry=True)) == [ + {n: n for n in graph.nodes} + ] + + +class TestSubgraphIsomorphism: + def test_isomorphism(self): + g1 = nx.Graph() + nx.add_cycle(g1, range(4)) + + g2 = nx.Graph() + nx.add_cycle(g2, range(4)) + g2.add_edges_from(list(zip(g2, range(4, 8)))) + ismags = iso.ISMAGS(g2, g1) + assert list(ismags.subgraph_isomorphisms_iter(symmetry=True)) == [ + {n: n for n in g1.nodes} + ] + + def test_isomorphism2(self): + g1 = nx.Graph() + nx.add_path(g1, range(3)) + + g2 = g1.copy() + g2.add_edge(1, 3) + + ismags = iso.ISMAGS(g2, g1) + matches = ismags.subgraph_isomorphisms_iter(symmetry=True) + expected_symmetric = [ + {0: 0, 1: 1, 2: 2}, + {0: 0, 1: 1, 3: 2}, + {2: 0, 1: 1, 3: 2}, + ] + assert _matches_to_sets(matches) == _matches_to_sets(expected_symmetric) + + matches = ismags.subgraph_isomorphisms_iter(symmetry=False) + expected_asymmetric = [ + {0: 2, 1: 1, 2: 0}, + {0: 2, 1: 1, 3: 0}, + {2: 2, 1: 1, 3: 0}, + ] + assert _matches_to_sets(matches) == _matches_to_sets( + expected_symmetric + expected_asymmetric + ) + + def test_labeled_nodes(self): + g1 = nx.Graph() + nx.add_cycle(g1, range(3)) + g1.nodes[1]["attr"] = True + + g2 = g1.copy() + g2.add_edge(1, 3) + ismags = iso.ISMAGS(g2, g1, node_match=lambda x, y: x == y) + matches = ismags.subgraph_isomorphisms_iter(symmetry=True) + expected_symmetric = [{0: 0, 1: 1, 2: 2}] + assert _matches_to_sets(matches) == _matches_to_sets(expected_symmetric) + + matches = ismags.subgraph_isomorphisms_iter(symmetry=False) + expected_asymmetric = [{0: 2, 1: 1, 2: 0}] + assert _matches_to_sets(matches) == _matches_to_sets( + expected_symmetric + expected_asymmetric + ) + + def test_labeled_edges(self): + g1 = nx.Graph() + nx.add_cycle(g1, range(3)) + g1.edges[1, 2]["attr"] = True + + g2 = g1.copy() + g2.add_edge(1, 3) + ismags = iso.ISMAGS(g2, g1, edge_match=lambda x, y: x == y) + matches = ismags.subgraph_isomorphisms_iter(symmetry=True) + expected_symmetric = [{0: 0, 1: 1, 2: 2}] + assert _matches_to_sets(matches) == _matches_to_sets(expected_symmetric) + + matches = ismags.subgraph_isomorphisms_iter(symmetry=False) + expected_asymmetric = [{1: 2, 0: 0, 2: 1}] + assert _matches_to_sets(matches) == _matches_to_sets( + expected_symmetric + expected_asymmetric + ) + + +class TestWikipediaExample: + # Nodes 'a', 'b', 'c' and 'd' form a column. + # Nodes 'g', 'h', 'i' and 'j' form a column. + g1edges = [ + ["a", "g"], + ["a", "h"], + ["a", "i"], + ["b", "g"], + ["b", "h"], + ["b", "j"], + ["c", "g"], + ["c", "i"], + ["c", "j"], + ["d", "h"], + ["d", "i"], + ["d", "j"], + ] + + # Nodes 1,2,3,4 form the clockwise corners of a large square. + # Nodes 5,6,7,8 form the clockwise corners of a small square + g2edges = [ + [1, 2], + [2, 3], + [3, 4], + [4, 1], + [5, 6], + [6, 7], + [7, 8], + [8, 5], + [1, 5], + [2, 6], + [3, 7], + [4, 8], + ] + + def test_graph(self): + g1 = nx.Graph() + g2 = nx.Graph() + g1.add_edges_from(self.g1edges) + g2.add_edges_from(self.g2edges) + gm = iso.ISMAGS(g1, g2) + assert gm.is_isomorphic() + + +class TestLargestCommonSubgraph: + def test_mcis(self): + # Example graphs from DOI: 10.1002/spe.588 + graph1 = nx.Graph() + graph1.add_edges_from([(1, 2), (2, 3), (2, 4), (3, 4), (4, 5)]) + graph1.nodes[1]["color"] = 0 + + graph2 = nx.Graph() + graph2.add_edges_from( + [(1, 2), (2, 3), (2, 4), (3, 4), (3, 5), (5, 6), (5, 7), (6, 7)] + ) + graph2.nodes[1]["color"] = 1 + graph2.nodes[6]["color"] = 2 + graph2.nodes[7]["color"] = 2 + + ismags = iso.ISMAGS( + graph1, graph2, node_match=iso.categorical_node_match("color", None) + ) + assert list(ismags.subgraph_isomorphisms_iter(True)) == [] + assert list(ismags.subgraph_isomorphisms_iter(False)) == [] + found_mcis = _matches_to_sets(ismags.largest_common_subgraph()) + expected = _matches_to_sets( + [{2: 2, 3: 4, 4: 3, 5: 5}, {2: 4, 3: 2, 4: 3, 5: 5}] + ) + assert expected == found_mcis + + ismags = iso.ISMAGS( + graph2, graph1, node_match=iso.categorical_node_match("color", None) + ) + assert list(ismags.subgraph_isomorphisms_iter(True)) == [] + assert list(ismags.subgraph_isomorphisms_iter(False)) == [] + found_mcis = _matches_to_sets(ismags.largest_common_subgraph()) + # Same answer, but reversed. + expected = _matches_to_sets( + [{2: 2, 3: 4, 4: 3, 5: 5}, {4: 2, 2: 3, 3: 4, 5: 5}] + ) + assert expected == found_mcis + + def test_symmetry_mcis(self): + graph1 = nx.Graph() + nx.add_path(graph1, range(4)) + + graph2 = nx.Graph() + nx.add_path(graph2, range(3)) + graph2.add_edge(1, 3) + + # Only the symmetry of graph2 is taken into account here. + ismags1 = iso.ISMAGS( + graph1, graph2, node_match=iso.categorical_node_match("color", None) + ) + assert list(ismags1.subgraph_isomorphisms_iter(True)) == [] + found_mcis = _matches_to_sets(ismags1.largest_common_subgraph()) + expected = _matches_to_sets([{0: 0, 1: 1, 2: 2}, {1: 0, 3: 2, 2: 1}]) + assert expected == found_mcis + + # Only the symmetry of graph1 is taken into account here. + ismags2 = iso.ISMAGS( + graph2, graph1, node_match=iso.categorical_node_match("color", None) + ) + assert list(ismags2.subgraph_isomorphisms_iter(True)) == [] + found_mcis = _matches_to_sets(ismags2.largest_common_subgraph()) + expected = _matches_to_sets( + [ + {3: 2, 0: 0, 1: 1}, + {2: 0, 0: 2, 1: 1}, + {3: 0, 0: 2, 1: 1}, + {3: 0, 1: 1, 2: 2}, + {0: 0, 1: 1, 2: 2}, + {2: 0, 3: 2, 1: 1}, + ] + ) + + assert expected == found_mcis + + found_mcis1 = _matches_to_sets(ismags1.largest_common_subgraph(False)) + found_mcis2 = ismags2.largest_common_subgraph(False) + found_mcis2 = [{v: k for k, v in d.items()} for d in found_mcis2] + found_mcis2 = _matches_to_sets(found_mcis2) + + expected = _matches_to_sets( + [ + {3: 2, 1: 3, 2: 1}, + {2: 0, 0: 2, 1: 1}, + {1: 2, 3: 3, 2: 1}, + {3: 0, 1: 3, 2: 1}, + {0: 2, 2: 3, 1: 1}, + {3: 0, 1: 2, 2: 1}, + {2: 0, 0: 3, 1: 1}, + {0: 0, 2: 3, 1: 1}, + {1: 0, 3: 3, 2: 1}, + {1: 0, 3: 2, 2: 1}, + {0: 3, 1: 1, 2: 2}, + {0: 0, 1: 1, 2: 2}, + ] + ) + assert expected == found_mcis1 + assert expected == found_mcis2 diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/test_isomorphism.py b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/test_isomorphism.py new file mode 100644 index 0000000000000000000000000000000000000000..548af808ffd93a32dfe91b7b372b6c3e45a206bf --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/test_isomorphism.py @@ -0,0 +1,48 @@ +import pytest + +import networkx as nx +from networkx.algorithms import isomorphism as iso + + +class TestIsomorph: + @classmethod + def setup_class(cls): + cls.G1 = nx.Graph() + cls.G2 = nx.Graph() + cls.G3 = nx.Graph() + cls.G4 = nx.Graph() + cls.G5 = nx.Graph() + cls.G6 = nx.Graph() + cls.G1.add_edges_from([[1, 2], [1, 3], [1, 5], [2, 3]]) + cls.G2.add_edges_from([[10, 20], [20, 30], [10, 30], [10, 50]]) + cls.G3.add_edges_from([[1, 2], [1, 3], [1, 5], [2, 5]]) + cls.G4.add_edges_from([[1, 2], [1, 3], [1, 5], [2, 4]]) + cls.G5.add_edges_from([[1, 2], [1, 3]]) + cls.G6.add_edges_from([[10, 20], [20, 30], [10, 30], [10, 50], [20, 50]]) + + def test_could_be_isomorphic(self): + assert iso.could_be_isomorphic(self.G1, self.G2) + assert iso.could_be_isomorphic(self.G1, self.G3) + assert not iso.could_be_isomorphic(self.G1, self.G4) + assert iso.could_be_isomorphic(self.G3, self.G2) + assert not iso.could_be_isomorphic(self.G1, self.G6) + + def test_fast_could_be_isomorphic(self): + assert iso.fast_could_be_isomorphic(self.G3, self.G2) + assert not iso.fast_could_be_isomorphic(self.G3, self.G5) + assert not iso.fast_could_be_isomorphic(self.G1, self.G6) + + def test_faster_could_be_isomorphic(self): + assert iso.faster_could_be_isomorphic(self.G3, self.G2) + assert not iso.faster_could_be_isomorphic(self.G3, self.G5) + assert not iso.faster_could_be_isomorphic(self.G1, self.G6) + + def test_is_isomorphic(self): + assert iso.is_isomorphic(self.G1, self.G2) + assert not iso.is_isomorphic(self.G1, self.G4) + assert iso.is_isomorphic(self.G1.to_directed(), self.G2.to_directed()) + assert not iso.is_isomorphic(self.G1.to_directed(), self.G4.to_directed()) + with pytest.raises( + nx.NetworkXError, match="Graphs G1 and G2 are not of the same type." + ): + iso.is_isomorphic(self.G1.to_directed(), self.G1) diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/test_isomorphvf2.py b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/test_isomorphvf2.py new file mode 100644 index 0000000000000000000000000000000000000000..413dfaf3d38b5a940e5532c710cad4c4f64888bf --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/test_isomorphvf2.py @@ -0,0 +1,410 @@ +""" +Tests for VF2 isomorphism algorithm. +""" + +import importlib.resources +import os +import random +import struct + +import networkx as nx +from networkx.algorithms import isomorphism as iso + + +class TestWikipediaExample: + # Source: https://en.wikipedia.org/wiki/Graph_isomorphism + + # Nodes 'a', 'b', 'c' and 'd' form a column. + # Nodes 'g', 'h', 'i' and 'j' form a column. + g1edges = [ + ["a", "g"], + ["a", "h"], + ["a", "i"], + ["b", "g"], + ["b", "h"], + ["b", "j"], + ["c", "g"], + ["c", "i"], + ["c", "j"], + ["d", "h"], + ["d", "i"], + ["d", "j"], + ] + + # Nodes 1,2,3,4 form the clockwise corners of a large square. + # Nodes 5,6,7,8 form the clockwise corners of a small square + g2edges = [ + [1, 2], + [2, 3], + [3, 4], + [4, 1], + [5, 6], + [6, 7], + [7, 8], + [8, 5], + [1, 5], + [2, 6], + [3, 7], + [4, 8], + ] + + def test_graph(self): + g1 = nx.Graph() + g2 = nx.Graph() + g1.add_edges_from(self.g1edges) + g2.add_edges_from(self.g2edges) + gm = iso.GraphMatcher(g1, g2) + assert gm.is_isomorphic() + # Just testing some cases + assert gm.subgraph_is_monomorphic() + + mapping = sorted(gm.mapping.items()) + + # this mapping is only one of the possibilities + # so this test needs to be reconsidered + # isomap = [('a', 1), ('b', 6), ('c', 3), ('d', 8), + # ('g', 2), ('h', 5), ('i', 4), ('j', 7)] + # assert_equal(mapping, isomap) + + def test_subgraph(self): + g1 = nx.Graph() + g2 = nx.Graph() + g1.add_edges_from(self.g1edges) + g2.add_edges_from(self.g2edges) + g3 = g2.subgraph([1, 2, 3, 4]) + gm = iso.GraphMatcher(g1, g3) + assert gm.subgraph_is_isomorphic() + + def test_subgraph_mono(self): + g1 = nx.Graph() + g2 = nx.Graph() + g1.add_edges_from(self.g1edges) + g2.add_edges_from([[1, 2], [2, 3], [3, 4]]) + gm = iso.GraphMatcher(g1, g2) + assert gm.subgraph_is_monomorphic() + + +class TestVF2GraphDB: + # https://web.archive.org/web/20090303210205/http://amalfi.dis.unina.it/graph/db/ + + @staticmethod + def create_graph(filename): + """Creates a Graph instance from the filename.""" + + # The file is assumed to be in the format from the VF2 graph database. + # Each file is composed of 16-bit numbers (unsigned short int). + # So we will want to read 2 bytes at a time. + + # We can read the number as follows: + # number = struct.unpack(' 0 + assert check_isomorphism(t1, t2, isomorphism) + + +# run positive_single_tree over all the +# non-isomorphic trees for k from 4 to maxk +# k = 4 is the first level that has more than 1 non-isomorphic tree +# k = 13 takes about 2.86 seconds to run on my laptop +# larger values run slow down significantly +# as the number of trees grows rapidly +def test_positive(maxk=14): + print("positive test") + + for k in range(2, maxk + 1): + start_time = time.time() + trial = 0 + for t in nx.nonisomorphic_trees(k): + positive_single_tree(t) + trial += 1 + print(k, trial, time.time() - start_time) + + +# test the trivial case of a single node in each tree +# note that nonisomorphic_trees doesn't work for k = 1 +def test_trivial(): + print("trivial test") + + # back to an undirected graph + t1 = nx.Graph() + t1.add_node("a") + root1 = "a" + + t2 = nx.Graph() + t2.add_node("n") + root2 = "n" + + isomorphism = rooted_tree_isomorphism(t1, root1, t2, root2) + + assert isomorphism == [("a", "n")] + + assert check_isomorphism(t1, t2, isomorphism) + + +# test another trivial case where the two graphs have +# different numbers of nodes +def test_trivial_2(): + print("trivial test 2") + + edges_1 = [("a", "b"), ("a", "c")] + + edges_2 = [("v", "y")] + + t1 = nx.Graph() + t1.add_edges_from(edges_1) + + t2 = nx.Graph() + t2.add_edges_from(edges_2) + + isomorphism = tree_isomorphism(t1, t2) + + # they cannot be isomorphic, + # since they have different numbers of nodes + assert isomorphism == [] + + +# the function nonisomorphic_trees generates all the non-isomorphic +# trees of a given size. Take each pair of these and verify that +# they are not isomorphic +# k = 4 is the first level that has more than 1 non-isomorphic tree +# k = 11 takes about 4.76 seconds to run on my laptop +# larger values run slow down significantly +# as the number of trees grows rapidly +def test_negative(maxk=11): + print("negative test") + + for k in range(4, maxk + 1): + test_trees = list(nx.nonisomorphic_trees(k)) + start_time = time.time() + trial = 0 + for i in range(len(test_trees) - 1): + for j in range(i + 1, len(test_trees)): + trial += 1 + assert tree_isomorphism(test_trees[i], test_trees[j]) == [] + print(k, trial, time.time() - start_time) diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/test_vf2pp.py b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/test_vf2pp.py new file mode 100644 index 0000000000000000000000000000000000000000..5f3fb901cc8afdc2a73cfb7001538db49371eaf4 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/test_vf2pp.py @@ -0,0 +1,1608 @@ +import itertools as it + +import pytest + +import networkx as nx +from networkx import vf2pp_is_isomorphic, vf2pp_isomorphism + +labels_same = ["blue"] + +labels_many = [ + "white", + "red", + "blue", + "green", + "orange", + "black", + "purple", + "yellow", + "brown", + "cyan", + "solarized", + "pink", + "none", +] + + +class TestPreCheck: + def test_first_graph_empty(self): + G1 = nx.Graph() + G2 = nx.Graph([(0, 1), (1, 2)]) + assert not vf2pp_is_isomorphic(G1, G2) + + def test_second_graph_empty(self): + G1 = nx.Graph([(0, 1), (1, 2)]) + G2 = nx.Graph() + assert not vf2pp_is_isomorphic(G1, G2) + + def test_different_order1(self): + G1 = nx.path_graph(5) + G2 = nx.path_graph(6) + assert not vf2pp_is_isomorphic(G1, G2) + + def test_different_order2(self): + G1 = nx.barbell_graph(100, 20) + G2 = nx.barbell_graph(101, 20) + assert not vf2pp_is_isomorphic(G1, G2) + + def test_different_order3(self): + G1 = nx.complete_graph(7) + G2 = nx.complete_graph(8) + assert not vf2pp_is_isomorphic(G1, G2) + + def test_different_degree_sequences1(self): + G1 = nx.Graph([(0, 1), (0, 2), (1, 2), (1, 3), (0, 4)]) + G2 = nx.Graph([(0, 1), (0, 2), (1, 2), (1, 3), (0, 4), (2, 5)]) + assert not vf2pp_is_isomorphic(G1, G2) + + G2.remove_node(3) + nx.set_node_attributes(G1, dict(zip(G1, it.cycle(["a"]))), "label") + nx.set_node_attributes(G2, dict(zip(G2, it.cycle("a"))), "label") + + assert vf2pp_is_isomorphic(G1, G2) + + def test_different_degree_sequences2(self): + G1 = nx.Graph( + [ + (0, 1), + (1, 2), + (0, 2), + (2, 3), + (3, 4), + (4, 5), + (5, 6), + (6, 3), + (4, 7), + (7, 8), + (8, 3), + ] + ) + G2 = G1.copy() + G2.add_edge(8, 0) + assert not vf2pp_is_isomorphic(G1, G2) + + G1.add_edge(6, 1) + nx.set_node_attributes(G1, dict(zip(G1, it.cycle(["a"]))), "label") + nx.set_node_attributes(G2, dict(zip(G2, it.cycle("a"))), "label") + + assert vf2pp_is_isomorphic(G1, G2) + + def test_different_degree_sequences3(self): + G1 = nx.Graph([(0, 1), (0, 2), (1, 2), (2, 3), (2, 4), (3, 4), (2, 5), (2, 6)]) + G2 = nx.Graph( + [(0, 1), (0, 6), (0, 2), (1, 2), (2, 3), (2, 4), (3, 4), (2, 5), (2, 6)] + ) + assert not vf2pp_is_isomorphic(G1, G2) + + G1.add_edge(3, 5) + nx.set_node_attributes(G1, dict(zip(G1, it.cycle(["a"]))), "label") + nx.set_node_attributes(G2, dict(zip(G2, it.cycle("a"))), "label") + + assert vf2pp_is_isomorphic(G1, G2) + + def test_label_distribution(self): + G1 = nx.Graph([(0, 1), (0, 2), (1, 2), (2, 3), (2, 4), (3, 4), (2, 5), (2, 6)]) + G2 = nx.Graph([(0, 1), (0, 2), (1, 2), (2, 3), (2, 4), (3, 4), (2, 5), (2, 6)]) + + colors1 = ["blue", "blue", "blue", "yellow", "black", "purple", "purple"] + colors2 = ["blue", "blue", "yellow", "yellow", "black", "purple", "purple"] + + nx.set_node_attributes(G1, dict(zip(G1, it.cycle(colors1[::-1]))), "label") + nx.set_node_attributes(G2, dict(zip(G2, it.cycle(colors2[::-1]))), "label") + + assert not vf2pp_is_isomorphic(G1, G2, node_label="label") + G2.nodes[3]["label"] = "blue" + assert vf2pp_is_isomorphic(G1, G2, node_label="label") + + +class TestAllGraphTypesEdgeCases: + @pytest.mark.parametrize("graph_type", (nx.Graph, nx.MultiGraph, nx.DiGraph)) + def test_both_graphs_empty(self, graph_type): + G = graph_type() + H = graph_type() + assert vf2pp_isomorphism(G, H) is None + + G.add_node(0) + + assert vf2pp_isomorphism(G, H) is None + assert vf2pp_isomorphism(H, G) is None + + H.add_node(0) + assert vf2pp_isomorphism(G, H) == {0: 0} + + @pytest.mark.parametrize("graph_type", (nx.Graph, nx.MultiGraph, nx.DiGraph)) + def test_first_graph_empty(self, graph_type): + G = graph_type() + H = graph_type([(0, 1)]) + assert vf2pp_isomorphism(G, H) is None + + @pytest.mark.parametrize("graph_type", (nx.Graph, nx.MultiGraph, nx.DiGraph)) + def test_second_graph_empty(self, graph_type): + G = graph_type([(0, 1)]) + H = graph_type() + assert vf2pp_isomorphism(G, H) is None + + +class TestGraphISOVF2pp: + def test_custom_graph1_same_labels(self): + G1 = nx.Graph() + + mapped = {1: "A", 2: "B", 3: "C", 4: "D", 5: "Z", 6: "E"} + edges1 = [(1, 2), (1, 3), (1, 4), (2, 3), (2, 6), (3, 4), (5, 1), (5, 2)] + + G1.add_edges_from(edges1) + G2 = nx.relabel_nodes(G1, mapped) + nx.set_node_attributes(G1, dict(zip(G1, it.cycle(labels_same))), "label") + nx.set_node_attributes(G2, dict(zip(G2, it.cycle(labels_same))), "label") + assert vf2pp_isomorphism(G1, G2, node_label="label") + + # Add edge making G1 symmetrical + G1.add_edge(3, 7) + G1.nodes[7]["label"] = "blue" + assert vf2pp_isomorphism(G1, G2, node_label="label") is None + + # Make G2 isomorphic to G1 + G2.add_edges_from([(mapped[3], "X"), (mapped[6], mapped[5])]) + G1.add_edge(4, 7) + G2.nodes["X"]["label"] = "blue" + assert vf2pp_isomorphism(G1, G2, node_label="label") + + # Re-structure maintaining isomorphism + G1.remove_edges_from([(1, 4), (1, 3)]) + G2.remove_edges_from([(mapped[1], mapped[5]), (mapped[1], mapped[2])]) + assert vf2pp_isomorphism(G1, G2, node_label="label") + + def test_custom_graph1_different_labels(self): + G1 = nx.Graph() + + mapped = {1: "A", 2: "B", 3: "C", 4: "D", 5: "Z", 6: "E"} + edges1 = [(1, 2), (1, 3), (1, 4), (2, 3), (2, 6), (3, 4), (5, 1), (5, 2)] + + G1.add_edges_from(edges1) + G2 = nx.relabel_nodes(G1, mapped) + nx.set_node_attributes(G1, dict(zip(G1, it.cycle(labels_many))), "label") + nx.set_node_attributes( + G2, + dict(zip([mapped[n] for n in G1], it.cycle(labels_many))), + "label", + ) + assert vf2pp_isomorphism(G1, G2, node_label="label") == mapped + + def test_custom_graph2_same_labels(self): + G1 = nx.Graph() + + mapped = {1: "A", 2: "C", 3: "D", 4: "E", 5: "G", 7: "B", 6: "F"} + edges1 = [(1, 2), (1, 5), (5, 6), (2, 3), (2, 4), (3, 4), (4, 5), (2, 7)] + + G1.add_edges_from(edges1) + G2 = nx.relabel_nodes(G1, mapped) + nx.set_node_attributes(G1, dict(zip(G1, it.cycle(labels_same))), "label") + nx.set_node_attributes(G2, dict(zip(G2, it.cycle(labels_same))), "label") + + assert vf2pp_isomorphism(G1, G2, node_label="label") + + # Obtain two isomorphic subgraphs from the graph + G2.remove_edge(mapped[1], mapped[2]) + G2.add_edge(mapped[1], mapped[4]) + H1 = nx.Graph(G1.subgraph([2, 3, 4, 7])) + H2 = nx.Graph(G2.subgraph([mapped[1], mapped[4], mapped[5], mapped[6]])) + assert vf2pp_isomorphism(H1, H2, node_label="label") + + # Add edges maintaining isomorphism + H1.add_edges_from([(3, 7), (4, 7)]) + H2.add_edges_from([(mapped[1], mapped[6]), (mapped[4], mapped[6])]) + assert vf2pp_isomorphism(H1, H2, node_label="label") + + def test_custom_graph2_different_labels(self): + G1 = nx.Graph() + + mapped = {1: "A", 2: "C", 3: "D", 4: "E", 5: "G", 7: "B", 6: "F"} + edges1 = [(1, 2), (1, 5), (5, 6), (2, 3), (2, 4), (3, 4), (4, 5), (2, 7)] + + G1.add_edges_from(edges1) + G2 = nx.relabel_nodes(G1, mapped) + nx.set_node_attributes(G1, dict(zip(G1, it.cycle(labels_many))), "label") + nx.set_node_attributes( + G2, + dict(zip([mapped[n] for n in G1], it.cycle(labels_many))), + "label", + ) + + # Adding new nodes + G1.add_node(0) + G2.add_node("Z") + G1.nodes[0]["label"] = G1.nodes[1]["label"] + G2.nodes["Z"]["label"] = G1.nodes[1]["label"] + mapped.update({0: "Z"}) + + assert vf2pp_isomorphism(G1, G2, node_label="label") == mapped + + # Change the color of one of the nodes + G2.nodes["Z"]["label"] = G1.nodes[2]["label"] + assert vf2pp_isomorphism(G1, G2, node_label="label") is None + + # Add an extra edge + G1.nodes[0]["label"] = "blue" + G2.nodes["Z"]["label"] = "blue" + G1.add_edge(0, 1) + + assert vf2pp_isomorphism(G1, G2, node_label="label") is None + + # Add extra edge to both + G2.add_edge("Z", "A") + assert vf2pp_isomorphism(G1, G2, node_label="label") == mapped + + def test_custom_graph3_same_labels(self): + G1 = nx.Graph() + + mapped = {1: 9, 2: 8, 3: 7, 4: 6, 5: 3, 8: 5, 9: 4, 7: 1, 6: 2} + edges1 = [ + (1, 2), + (1, 3), + (2, 3), + (3, 4), + (4, 5), + (4, 7), + (4, 9), + (5, 8), + (8, 9), + (5, 6), + (6, 7), + (5, 2), + ] + G1.add_edges_from(edges1) + G2 = nx.relabel_nodes(G1, mapped) + nx.set_node_attributes(G1, dict(zip(G1, it.cycle(labels_same))), "label") + nx.set_node_attributes(G2, dict(zip(G2, it.cycle(labels_same))), "label") + assert vf2pp_isomorphism(G1, G2, node_label="label") + + # Connect nodes maintaining symmetry + G1.add_edges_from([(6, 9), (7, 8)]) + G2.add_edges_from([(mapped[6], mapped[8]), (mapped[7], mapped[9])]) + assert vf2pp_isomorphism(G1, G2, node_label="label") is None + + # Make isomorphic + G1.add_edges_from([(6, 8), (7, 9)]) + G2.add_edges_from([(mapped[6], mapped[9]), (mapped[7], mapped[8])]) + assert vf2pp_isomorphism(G1, G2, node_label="label") + + # Connect more nodes + G1.add_edges_from([(2, 7), (3, 6)]) + G2.add_edges_from([(mapped[2], mapped[7]), (mapped[3], mapped[6])]) + G1.add_node(10) + G2.add_node("Z") + G1.nodes[10]["label"] = "blue" + G2.nodes["Z"]["label"] = "blue" + + assert vf2pp_isomorphism(G1, G2, node_label="label") + + # Connect the newly added node, to opposite sides of the graph + G1.add_edges_from([(10, 1), (10, 5), (10, 8)]) + G2.add_edges_from([("Z", mapped[1]), ("Z", mapped[4]), ("Z", mapped[9])]) + assert vf2pp_isomorphism(G1, G2, node_label="label") + + # Get two subgraphs that are not isomorphic but are easy to make + H1 = nx.Graph(G1.subgraph([2, 3, 4, 5, 6, 7, 10])) + H2 = nx.Graph( + G2.subgraph( + [mapped[4], mapped[5], mapped[6], mapped[7], mapped[8], mapped[9], "Z"] + ) + ) + assert vf2pp_isomorphism(H1, H2, node_label="label") is None + + # Restructure both to make them isomorphic + H1.add_edges_from([(10, 2), (10, 6), (3, 6), (2, 7), (2, 6), (3, 7)]) + H2.add_edges_from( + [("Z", mapped[7]), (mapped[6], mapped[9]), (mapped[7], mapped[8])] + ) + assert vf2pp_isomorphism(H1, H2, node_label="label") + + # Add edges with opposite direction in each Graph + H1.add_edge(3, 5) + H2.add_edge(mapped[5], mapped[7]) + assert vf2pp_isomorphism(H1, H2, node_label="label") is None + + def test_custom_graph3_different_labels(self): + G1 = nx.Graph() + + mapped = {1: 9, 2: 8, 3: 7, 4: 6, 5: 3, 8: 5, 9: 4, 7: 1, 6: 2} + edges1 = [ + (1, 2), + (1, 3), + (2, 3), + (3, 4), + (4, 5), + (4, 7), + (4, 9), + (5, 8), + (8, 9), + (5, 6), + (6, 7), + (5, 2), + ] + G1.add_edges_from(edges1) + G2 = nx.relabel_nodes(G1, mapped) + nx.set_node_attributes(G1, dict(zip(G1, it.cycle(labels_many))), "label") + nx.set_node_attributes( + G2, + dict(zip([mapped[n] for n in G1], it.cycle(labels_many))), + "label", + ) + assert vf2pp_isomorphism(G1, G2, node_label="label") == mapped + + # Add extra edge to G1 + G1.add_edge(1, 7) + assert vf2pp_isomorphism(G1, G2, node_label="label") is None + + # Compensate in G2 + G2.add_edge(9, 1) + assert vf2pp_isomorphism(G1, G2, node_label="label") == mapped + + # Add extra node + G1.add_node("A") + G2.add_node("K") + G1.nodes["A"]["label"] = "green" + G2.nodes["K"]["label"] = "green" + mapped.update({"A": "K"}) + + assert vf2pp_isomorphism(G1, G2, node_label="label") == mapped + + # Connect A to one side of G1 and K to the opposite + G1.add_edge("A", 6) + G2.add_edge("K", 5) + assert vf2pp_isomorphism(G1, G2, node_label="label") is None + + # Make the graphs symmetrical + G1.add_edge(1, 5) + G1.add_edge(2, 9) + G2.add_edge(9, 3) + G2.add_edge(8, 4) + assert vf2pp_isomorphism(G1, G2, node_label="label") is None + + # Assign same colors so the two opposite sides are identical + for node in G1.nodes(): + color = "red" + G1.nodes[node]["label"] = color + G2.nodes[mapped[node]]["label"] = color + + assert vf2pp_isomorphism(G1, G2, node_label="label") + + def test_custom_graph4_different_labels(self): + G1 = nx.Graph() + edges1 = [ + (1, 2), + (2, 3), + (3, 8), + (3, 4), + (4, 5), + (4, 6), + (3, 6), + (8, 7), + (8, 9), + (5, 9), + (10, 11), + (11, 12), + (12, 13), + (11, 13), + ] + + mapped = { + 1: "n", + 2: "m", + 3: "l", + 4: "j", + 5: "k", + 6: "i", + 7: "g", + 8: "h", + 9: "f", + 10: "b", + 11: "a", + 12: "d", + 13: "e", + } + + G1.add_edges_from(edges1) + G2 = nx.relabel_nodes(G1, mapped) + nx.set_node_attributes(G1, dict(zip(G1, it.cycle(labels_many))), "label") + nx.set_node_attributes( + G2, + dict(zip([mapped[n] for n in G1], it.cycle(labels_many))), + "label", + ) + assert vf2pp_isomorphism(G1, G2, node_label="label") == mapped + + def test_custom_graph4_same_labels(self): + G1 = nx.Graph() + edges1 = [ + (1, 2), + (2, 3), + (3, 8), + (3, 4), + (4, 5), + (4, 6), + (3, 6), + (8, 7), + (8, 9), + (5, 9), + (10, 11), + (11, 12), + (12, 13), + (11, 13), + ] + + mapped = { + 1: "n", + 2: "m", + 3: "l", + 4: "j", + 5: "k", + 6: "i", + 7: "g", + 8: "h", + 9: "f", + 10: "b", + 11: "a", + 12: "d", + 13: "e", + } + + G1.add_edges_from(edges1) + G2 = nx.relabel_nodes(G1, mapped) + nx.set_node_attributes(G1, dict(zip(G1, it.cycle(labels_same))), "label") + nx.set_node_attributes(G2, dict(zip(G2, it.cycle(labels_same))), "label") + assert vf2pp_isomorphism(G1, G2, node_label="label") + + # Add nodes of different label + G1.add_node(0) + G2.add_node("z") + G1.nodes[0]["label"] = "green" + G2.nodes["z"]["label"] = "blue" + + assert vf2pp_isomorphism(G1, G2, node_label="label") is None + + # Make the labels identical + G2.nodes["z"]["label"] = "green" + assert vf2pp_isomorphism(G1, G2, node_label="label") + + # Change the structure of the graphs, keeping them isomorphic + G1.add_edge(2, 5) + G2.remove_edge("i", "l") + G2.add_edge("g", "l") + G2.add_edge("m", "f") + assert vf2pp_isomorphism(G1, G2, node_label="label") + + # Change the structure of the disconnected sub-graph, keeping it isomorphic + G1.remove_node(13) + G2.remove_node("d") + assert vf2pp_isomorphism(G1, G2, node_label="label") + + # Connect the newly added node to the disconnected graph, which now is just a path of size 3 + G1.add_edge(0, 10) + G2.add_edge("e", "z") + assert vf2pp_isomorphism(G1, G2, node_label="label") + + # Connect the two disconnected sub-graphs, forming a single graph + G1.add_edge(11, 3) + G1.add_edge(0, 8) + G2.add_edge("a", "l") + G2.add_edge("z", "j") + assert vf2pp_isomorphism(G1, G2, node_label="label") + + def test_custom_graph5_same_labels(self): + G1 = nx.Graph() + edges1 = [ + (1, 5), + (1, 2), + (1, 4), + (2, 3), + (2, 6), + (3, 4), + (3, 7), + (4, 8), + (5, 8), + (5, 6), + (6, 7), + (7, 8), + ] + mapped = {1: "a", 2: "h", 3: "d", 4: "i", 5: "g", 6: "b", 7: "j", 8: "c"} + + G1.add_edges_from(edges1) + G2 = nx.relabel_nodes(G1, mapped) + nx.set_node_attributes(G1, dict(zip(G1, it.cycle(labels_same))), "label") + nx.set_node_attributes(G2, dict(zip(G2, it.cycle(labels_same))), "label") + assert vf2pp_isomorphism(G1, G2, node_label="label") + + # Add different edges in each graph, maintaining symmetry + G1.add_edges_from([(3, 6), (2, 7), (2, 5), (1, 3), (4, 7), (6, 8)]) + G2.add_edges_from( + [ + (mapped[6], mapped[3]), + (mapped[2], mapped[7]), + (mapped[1], mapped[6]), + (mapped[5], mapped[7]), + (mapped[3], mapped[8]), + (mapped[2], mapped[4]), + ] + ) + assert vf2pp_isomorphism(G1, G2, node_label="label") + + # Obtain two different but isomorphic subgraphs from G1 and G2 + H1 = nx.Graph(G1.subgraph([1, 5, 8, 6, 7, 3])) + H2 = nx.Graph( + G2.subgraph( + [mapped[1], mapped[4], mapped[8], mapped[7], mapped[3], mapped[5]] + ) + ) + assert vf2pp_isomorphism(H1, H2, node_label="label") + + # Delete corresponding node from the two graphs + H1.remove_node(8) + H2.remove_node(mapped[7]) + assert vf2pp_isomorphism(H1, H2, node_label="label") + + # Re-orient, maintaining isomorphism + H1.add_edge(1, 6) + H1.remove_edge(3, 6) + assert vf2pp_isomorphism(H1, H2, node_label="label") + + def test_custom_graph5_different_labels(self): + G1 = nx.Graph() + edges1 = [ + (1, 5), + (1, 2), + (1, 4), + (2, 3), + (2, 6), + (3, 4), + (3, 7), + (4, 8), + (5, 8), + (5, 6), + (6, 7), + (7, 8), + ] + mapped = {1: "a", 2: "h", 3: "d", 4: "i", 5: "g", 6: "b", 7: "j", 8: "c"} + + G1.add_edges_from(edges1) + G2 = nx.relabel_nodes(G1, mapped) + + colors = ["red", "blue", "grey", "none", "brown", "solarized", "yellow", "pink"] + nx.set_node_attributes(G1, dict(zip(G1, it.cycle(labels_many))), "label") + nx.set_node_attributes( + G2, + dict(zip([mapped[n] for n in G1], it.cycle(labels_many))), + "label", + ) + assert vf2pp_isomorphism(G1, G2, node_label="label") == mapped + + # Assign different colors to matching nodes + c = 0 + for node in G1.nodes(): + color1 = colors[c] + color2 = colors[(c + 3) % len(colors)] + G1.nodes[node]["label"] = color1 + G2.nodes[mapped[node]]["label"] = color2 + c += 1 + + assert vf2pp_isomorphism(G1, G2, node_label="label") is None + + # Get symmetrical sub-graphs of G1,G2 and compare them + H1 = G1.subgraph([1, 5]) + H2 = G2.subgraph(["i", "c"]) + c = 0 + for node1, node2 in zip(H1.nodes(), H2.nodes()): + H1.nodes[node1]["label"] = "red" + H2.nodes[node2]["label"] = "red" + c += 1 + + assert vf2pp_isomorphism(H1, H2, node_label="label") + + def test_disconnected_graph_all_same_labels(self): + G1 = nx.Graph() + G1.add_nodes_from(list(range(10))) + + mapped = {0: 9, 1: 8, 2: 7, 3: 6, 4: 5, 5: 4, 6: 3, 7: 2, 8: 1, 9: 0} + G2 = nx.relabel_nodes(G1, mapped) + nx.set_node_attributes(G1, dict(zip(G1, it.cycle(labels_same))), "label") + nx.set_node_attributes(G2, dict(zip(G2, it.cycle(labels_same))), "label") + assert vf2pp_isomorphism(G1, G2, node_label="label") + + def test_disconnected_graph_all_different_labels(self): + G1 = nx.Graph() + G1.add_nodes_from(list(range(10))) + + mapped = {0: 9, 1: 8, 2: 7, 3: 6, 4: 5, 5: 4, 6: 3, 7: 2, 8: 1, 9: 0} + G2 = nx.relabel_nodes(G1, mapped) + + nx.set_node_attributes(G1, dict(zip(G1, it.cycle(labels_many))), "label") + nx.set_node_attributes( + G2, + dict(zip([mapped[n] for n in G1], it.cycle(labels_many))), + "label", + ) + assert vf2pp_isomorphism(G1, G2, node_label="label") == mapped + + def test_disconnected_graph_some_same_labels(self): + G1 = nx.Graph() + G1.add_nodes_from(list(range(10))) + + mapped = {0: 9, 1: 8, 2: 7, 3: 6, 4: 5, 5: 4, 6: 3, 7: 2, 8: 1, 9: 0} + G2 = nx.relabel_nodes(G1, mapped) + + colors = [ + "white", + "white", + "white", + "purple", + "purple", + "red", + "red", + "pink", + "pink", + "pink", + ] + + nx.set_node_attributes(G1, dict(zip(G1, it.cycle(colors))), "label") + nx.set_node_attributes( + G2, dict(zip([mapped[n] for n in G1], it.cycle(colors))), "label" + ) + + assert vf2pp_isomorphism(G1, G2, node_label="label") + + +class TestMultiGraphISOVF2pp: + def test_custom_multigraph1_same_labels(self): + G1 = nx.MultiGraph() + + mapped = {1: "A", 2: "B", 3: "C", 4: "D", 5: "Z", 6: "E"} + edges1 = [ + (1, 2), + (1, 3), + (1, 4), + (1, 4), + (1, 4), + (2, 3), + (2, 6), + (2, 6), + (3, 4), + (3, 4), + (5, 1), + (5, 1), + (5, 2), + (5, 2), + ] + + G1.add_edges_from(edges1) + G2 = nx.relabel_nodes(G1, mapped) + + nx.set_node_attributes(G1, dict(zip(G1, it.cycle(labels_same))), "label") + nx.set_node_attributes(G2, dict(zip(G2, it.cycle(labels_same))), "label") + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m + + # Transfer the 2-clique to the right side of G1 + G1.remove_edges_from([(2, 6), (2, 6)]) + G1.add_edges_from([(3, 6), (3, 6)]) + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert not m + + # Delete an edges, making them symmetrical, so the position of the 2-clique doesn't matter + G2.remove_edge(mapped[1], mapped[4]) + G1.remove_edge(1, 4) + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m + + # Add self-loops + G1.add_edges_from([(5, 5), (5, 5), (1, 1)]) + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert not m + + # Compensate in G2 + G2.add_edges_from( + [(mapped[1], mapped[1]), (mapped[4], mapped[4]), (mapped[4], mapped[4])] + ) + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m + + def test_custom_multigraph1_different_labels(self): + G1 = nx.MultiGraph() + + mapped = {1: "A", 2: "B", 3: "C", 4: "D", 5: "Z", 6: "E"} + edges1 = [ + (1, 2), + (1, 3), + (1, 4), + (1, 4), + (1, 4), + (2, 3), + (2, 6), + (2, 6), + (3, 4), + (3, 4), + (5, 1), + (5, 1), + (5, 2), + (5, 2), + ] + + G1.add_edges_from(edges1) + G2 = nx.relabel_nodes(G1, mapped) + + nx.set_node_attributes(G1, dict(zip(G1, it.cycle(labels_many))), "label") + nx.set_node_attributes( + G2, + dict(zip([mapped[n] for n in G1], it.cycle(labels_many))), + "label", + ) + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m + assert m == mapped + + # Re-structure G1, maintaining the degree sequence + G1.remove_edge(1, 4) + G1.add_edge(1, 5) + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert not m + + # Restructure G2, making it isomorphic to G1 + G2.remove_edge("A", "D") + G2.add_edge("A", "Z") + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m + assert m == mapped + + # Add edge from node to itself + G1.add_edges_from([(6, 6), (6, 6), (6, 6)]) + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert not m + + # Same for G2 + G2.add_edges_from([("E", "E"), ("E", "E"), ("E", "E")]) + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m + assert m == mapped + + def test_custom_multigraph2_same_labels(self): + G1 = nx.MultiGraph() + + mapped = {1: "A", 2: "C", 3: "D", 4: "E", 5: "G", 7: "B", 6: "F"} + edges1 = [ + (1, 2), + (1, 2), + (1, 5), + (1, 5), + (1, 5), + (5, 6), + (2, 3), + (2, 3), + (2, 4), + (3, 4), + (3, 4), + (4, 5), + (4, 5), + (4, 5), + (2, 7), + (2, 7), + (2, 7), + ] + + G1.add_edges_from(edges1) + G2 = nx.relabel_nodes(G1, mapped) + + nx.set_node_attributes(G1, dict(zip(G1, it.cycle(labels_same))), "label") + nx.set_node_attributes(G2, dict(zip(G2, it.cycle(labels_same))), "label") + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m + + # Obtain two non-isomorphic subgraphs from the graph + G2.remove_edges_from([(mapped[1], mapped[2]), (mapped[1], mapped[2])]) + G2.add_edge(mapped[1], mapped[4]) + H1 = nx.MultiGraph(G1.subgraph([2, 3, 4, 7])) + H2 = nx.MultiGraph(G2.subgraph([mapped[1], mapped[4], mapped[5], mapped[6]])) + + m = vf2pp_isomorphism(H1, H2, node_label="label") + assert not m + + # Make them isomorphic + H1.remove_edge(3, 4) + H1.add_edges_from([(2, 3), (2, 4), (2, 4)]) + H2.add_edges_from([(mapped[5], mapped[6]), (mapped[5], mapped[6])]) + m = vf2pp_isomorphism(H1, H2, node_label="label") + assert m + + # Remove triangle edge + H1.remove_edges_from([(2, 3), (2, 3), (2, 3)]) + H2.remove_edges_from([(mapped[5], mapped[4])] * 3) + m = vf2pp_isomorphism(H1, H2, node_label="label") + assert m + + # Change the edge orientation such that H1 is rotated H2 + H1.remove_edges_from([(2, 7), (2, 7)]) + H1.add_edges_from([(3, 4), (3, 4)]) + m = vf2pp_isomorphism(H1, H2, node_label="label") + assert m + + # Add extra edges maintaining degree sequence, but in a non-symmetrical manner + H2.add_edge(mapped[5], mapped[1]) + H1.add_edge(3, 4) + m = vf2pp_isomorphism(H1, H2, node_label="label") + assert not m + + def test_custom_multigraph2_different_labels(self): + G1 = nx.MultiGraph() + + mapped = {1: "A", 2: "C", 3: "D", 4: "E", 5: "G", 7: "B", 6: "F"} + edges1 = [ + (1, 2), + (1, 2), + (1, 5), + (1, 5), + (1, 5), + (5, 6), + (2, 3), + (2, 3), + (2, 4), + (3, 4), + (3, 4), + (4, 5), + (4, 5), + (4, 5), + (2, 7), + (2, 7), + (2, 7), + ] + + G1.add_edges_from(edges1) + G2 = nx.relabel_nodes(G1, mapped) + + nx.set_node_attributes(G1, dict(zip(G1, it.cycle(labels_many))), "label") + nx.set_node_attributes( + G2, + dict(zip([mapped[n] for n in G1], it.cycle(labels_many))), + "label", + ) + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m + assert m == mapped + + # Re-structure G1 + G1.remove_edge(2, 7) + G1.add_edge(5, 6) + + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert not m + + # Same for G2 + G2.remove_edge("B", "C") + G2.add_edge("G", "F") + + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m + assert m == mapped + + # Delete node from G1 and G2, keeping them isomorphic + G1.remove_node(3) + G2.remove_node("D") + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m + + # Change G1 edges + G1.remove_edge(1, 2) + G1.remove_edge(2, 7) + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert not m + + # Make G2 identical to G1, but with different edge orientation and different labels + G2.add_edges_from([("A", "C"), ("C", "E"), ("C", "E")]) + G2.remove_edges_from( + [("A", "G"), ("A", "G"), ("F", "G"), ("E", "G"), ("E", "G")] + ) + + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert not m + + # Make all labels the same, so G1 and G2 are also isomorphic + for n1, n2 in zip(G1.nodes(), G2.nodes()): + G1.nodes[n1]["label"] = "blue" + G2.nodes[n2]["label"] = "blue" + + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m + + def test_custom_multigraph3_same_labels(self): + G1 = nx.MultiGraph() + + mapped = {1: 9, 2: 8, 3: 7, 4: 6, 5: 3, 8: 5, 9: 4, 7: 1, 6: 2} + edges1 = [ + (1, 2), + (1, 3), + (1, 3), + (2, 3), + (2, 3), + (3, 4), + (4, 5), + (4, 7), + (4, 9), + (4, 9), + (4, 9), + (5, 8), + (5, 8), + (8, 9), + (8, 9), + (5, 6), + (6, 7), + (6, 7), + (6, 7), + (5, 2), + ] + G1.add_edges_from(edges1) + G2 = nx.relabel_nodes(G1, mapped) + + nx.set_node_attributes(G1, dict(zip(G1, it.cycle(labels_same))), "label") + nx.set_node_attributes(G2, dict(zip(G2, it.cycle(labels_same))), "label") + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m + + # Connect nodes maintaining symmetry + G1.add_edges_from([(6, 9), (7, 8), (5, 8), (4, 9), (4, 9)]) + G2.add_edges_from( + [ + (mapped[6], mapped[8]), + (mapped[7], mapped[9]), + (mapped[5], mapped[8]), + (mapped[4], mapped[9]), + (mapped[4], mapped[9]), + ] + ) + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert not m + + # Make isomorphic + G1.add_edges_from([(6, 8), (6, 8), (7, 9), (7, 9), (7, 9)]) + G2.add_edges_from( + [ + (mapped[6], mapped[8]), + (mapped[6], mapped[9]), + (mapped[7], mapped[8]), + (mapped[7], mapped[9]), + (mapped[7], mapped[9]), + ] + ) + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m + + # Connect more nodes + G1.add_edges_from([(2, 7), (2, 7), (3, 6), (3, 6)]) + G2.add_edges_from( + [ + (mapped[2], mapped[7]), + (mapped[2], mapped[7]), + (mapped[3], mapped[6]), + (mapped[3], mapped[6]), + ] + ) + G1.add_node(10) + G2.add_node("Z") + G1.nodes[10]["label"] = "blue" + G2.nodes["Z"]["label"] = "blue" + + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m + + # Connect the newly added node, to opposite sides of the graph + G1.add_edges_from([(10, 1), (10, 5), (10, 8), (10, 10), (10, 10)]) + G2.add_edges_from( + [ + ("Z", mapped[1]), + ("Z", mapped[4]), + ("Z", mapped[9]), + ("Z", "Z"), + ("Z", "Z"), + ] + ) + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert not m + + # We connected the new node to opposite sides, so G1 must be symmetrical to G2. Re-structure them to be so + G1.remove_edges_from([(1, 3), (4, 9), (4, 9), (7, 9)]) + G2.remove_edges_from( + [ + (mapped[1], mapped[3]), + (mapped[4], mapped[9]), + (mapped[4], mapped[9]), + (mapped[7], mapped[9]), + ] + ) + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m + + # Get two subgraphs that are not isomorphic but are easy to make + H1 = nx.Graph(G1.subgraph([2, 3, 4, 5, 6, 7, 10])) + H2 = nx.Graph( + G2.subgraph( + [mapped[4], mapped[5], mapped[6], mapped[7], mapped[8], mapped[9], "Z"] + ) + ) + + m = vf2pp_isomorphism(H1, H2, node_label="label") + assert not m + + # Restructure both to make them isomorphic + H1.add_edges_from([(10, 2), (10, 6), (3, 6), (2, 7), (2, 6), (3, 7)]) + H2.add_edges_from( + [("Z", mapped[7]), (mapped[6], mapped[9]), (mapped[7], mapped[8])] + ) + m = vf2pp_isomorphism(H1, H2, node_label="label") + assert m + + # Remove one self-loop in H2 + H2.remove_edge("Z", "Z") + m = vf2pp_isomorphism(H1, H2, node_label="label") + assert not m + + # Compensate in H1 + H1.remove_edge(10, 10) + m = vf2pp_isomorphism(H1, H2, node_label="label") + assert m + + def test_custom_multigraph3_different_labels(self): + G1 = nx.MultiGraph() + + mapped = {1: 9, 2: 8, 3: 7, 4: 6, 5: 3, 8: 5, 9: 4, 7: 1, 6: 2} + edges1 = [ + (1, 2), + (1, 3), + (1, 3), + (2, 3), + (2, 3), + (3, 4), + (4, 5), + (4, 7), + (4, 9), + (4, 9), + (4, 9), + (5, 8), + (5, 8), + (8, 9), + (8, 9), + (5, 6), + (6, 7), + (6, 7), + (6, 7), + (5, 2), + ] + + G1.add_edges_from(edges1) + G2 = nx.relabel_nodes(G1, mapped) + + nx.set_node_attributes(G1, dict(zip(G1, it.cycle(labels_many))), "label") + nx.set_node_attributes( + G2, + dict(zip([mapped[n] for n in G1], it.cycle(labels_many))), + "label", + ) + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m + assert m == mapped + + # Delete edge maintaining isomorphism + G1.remove_edge(4, 9) + G2.remove_edge(4, 6) + + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m + assert m == mapped + + # Change edge orientation such that G1 mirrors G2 + G1.add_edges_from([(4, 9), (1, 2), (1, 2)]) + G1.remove_edges_from([(1, 3), (1, 3)]) + G2.add_edges_from([(3, 5), (7, 9)]) + G2.remove_edge(8, 9) + + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert not m + + # Make all labels the same, so G1 and G2 are also isomorphic + for n1, n2 in zip(G1.nodes(), G2.nodes()): + G1.nodes[n1]["label"] = "blue" + G2.nodes[n2]["label"] = "blue" + + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m + + G1.add_node(10) + G2.add_node("Z") + G1.nodes[10]["label"] = "green" + G2.nodes["Z"]["label"] = "green" + + # Add different number of edges between the new nodes and themselves + G1.add_edges_from([(10, 10), (10, 10)]) + G2.add_edges_from([("Z", "Z")]) + + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert not m + + # Make the number of self-edges equal + G1.remove_edge(10, 10) + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m + + # Connect the new node to the graph + G1.add_edges_from([(10, 3), (10, 4)]) + G2.add_edges_from([("Z", 8), ("Z", 3)]) + + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m + + # Remove central node + G1.remove_node(4) + G2.remove_node(3) + G1.add_edges_from([(5, 6), (5, 6), (5, 7)]) + G2.add_edges_from([(1, 6), (1, 6), (6, 2)]) + + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m + + def test_custom_multigraph4_same_labels(self): + G1 = nx.MultiGraph() + edges1 = [ + (1, 2), + (1, 2), + (2, 2), + (2, 3), + (3, 8), + (3, 8), + (3, 4), + (4, 5), + (4, 5), + (4, 5), + (4, 6), + (3, 6), + (3, 6), + (6, 6), + (8, 7), + (7, 7), + (8, 9), + (9, 9), + (8, 9), + (8, 9), + (5, 9), + (10, 11), + (11, 12), + (12, 13), + (11, 13), + (10, 10), + (10, 11), + (11, 13), + ] + + mapped = { + 1: "n", + 2: "m", + 3: "l", + 4: "j", + 5: "k", + 6: "i", + 7: "g", + 8: "h", + 9: "f", + 10: "b", + 11: "a", + 12: "d", + 13: "e", + } + + G1.add_edges_from(edges1) + G2 = nx.relabel_nodes(G1, mapped) + + nx.set_node_attributes(G1, dict(zip(G1, it.cycle(labels_same))), "label") + nx.set_node_attributes(G2, dict(zip(G2, it.cycle(labels_same))), "label") + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m + + # Add extra but corresponding edges to both graphs + G1.add_edges_from([(2, 2), (2, 3), (2, 8), (3, 4)]) + G2.add_edges_from([("m", "m"), ("m", "l"), ("m", "h"), ("l", "j")]) + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m + + # Obtain subgraphs + H1 = nx.MultiGraph(G1.subgraph([2, 3, 4, 6, 10, 11, 12, 13])) + H2 = nx.MultiGraph( + G2.subgraph( + [ + mapped[2], + mapped[3], + mapped[8], + mapped[9], + mapped[10], + mapped[11], + mapped[12], + mapped[13], + ] + ) + ) + + m = vf2pp_isomorphism(H1, H2, node_label="label") + assert not m + + # Make them isomorphic + H2.remove_edges_from( + [(mapped[3], mapped[2]), (mapped[9], mapped[8]), (mapped[2], mapped[2])] + ) + H2.add_edges_from([(mapped[9], mapped[9]), (mapped[2], mapped[8])]) + m = vf2pp_isomorphism(H1, H2, node_label="label") + assert m + + # Re-structure the disconnected sub-graph + H1.remove_node(12) + H2.remove_node(mapped[12]) + H1.add_edge(13, 13) + H2.add_edge(mapped[13], mapped[13]) + + # Connect the two disconnected components, forming a single graph + H1.add_edges_from([(3, 13), (6, 11)]) + H2.add_edges_from([(mapped[8], mapped[10]), (mapped[2], mapped[11])]) + m = vf2pp_isomorphism(H1, H2, node_label="label") + assert m + + # Change orientation of self-loops in one graph, maintaining the degree sequence + H1.remove_edges_from([(2, 2), (3, 6)]) + H1.add_edges_from([(6, 6), (2, 3)]) + m = vf2pp_isomorphism(H1, H2, node_label="label") + assert not m + + def test_custom_multigraph4_different_labels(self): + G1 = nx.MultiGraph() + edges1 = [ + (1, 2), + (1, 2), + (2, 2), + (2, 3), + (3, 8), + (3, 8), + (3, 4), + (4, 5), + (4, 5), + (4, 5), + (4, 6), + (3, 6), + (3, 6), + (6, 6), + (8, 7), + (7, 7), + (8, 9), + (9, 9), + (8, 9), + (8, 9), + (5, 9), + (10, 11), + (11, 12), + (12, 13), + (11, 13), + ] + + mapped = { + 1: "n", + 2: "m", + 3: "l", + 4: "j", + 5: "k", + 6: "i", + 7: "g", + 8: "h", + 9: "f", + 10: "b", + 11: "a", + 12: "d", + 13: "e", + } + + G1.add_edges_from(edges1) + G2 = nx.relabel_nodes(G1, mapped) + + nx.set_node_attributes(G1, dict(zip(G1, it.cycle(labels_many))), "label") + nx.set_node_attributes( + G2, + dict(zip([mapped[n] for n in G1], it.cycle(labels_many))), + "label", + ) + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m == mapped + + # Add extra but corresponding edges to both graphs + G1.add_edges_from([(2, 2), (2, 3), (2, 8), (3, 4)]) + G2.add_edges_from([("m", "m"), ("m", "l"), ("m", "h"), ("l", "j")]) + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m == mapped + + # Obtain isomorphic subgraphs + H1 = nx.MultiGraph(G1.subgraph([2, 3, 4, 6])) + H2 = nx.MultiGraph(G2.subgraph(["m", "l", "j", "i"])) + + m = vf2pp_isomorphism(H1, H2, node_label="label") + assert m + + # Delete the 3-clique, keeping only the path-graph. Also, H1 mirrors H2 + H1.remove_node(4) + H2.remove_node("j") + H1.remove_edges_from([(2, 2), (2, 3), (6, 6)]) + H2.remove_edges_from([("l", "i"), ("m", "m"), ("m", "m")]) + + m = vf2pp_isomorphism(H1, H2, node_label="label") + assert not m + + # Assign the same labels so that mirroring means isomorphic + for n1, n2 in zip(H1.nodes(), H2.nodes()): + H1.nodes[n1]["label"] = "red" + H2.nodes[n2]["label"] = "red" + + m = vf2pp_isomorphism(H1, H2, node_label="label") + assert m + + # Leave only one node with self-loop + H1.remove_nodes_from([3, 6]) + H2.remove_nodes_from(["m", "l"]) + m = vf2pp_isomorphism(H1, H2, node_label="label") + assert m + + # Remove one self-loop from H1 + H1.remove_edge(2, 2) + m = vf2pp_isomorphism(H1, H2, node_label="label") + assert not m + + # Same for H2 + H2.remove_edge("i", "i") + m = vf2pp_isomorphism(H1, H2, node_label="label") + assert m + + # Compose H1 with the disconnected sub-graph of G1. Same for H2 + S1 = nx.compose(H1, nx.MultiGraph(G1.subgraph([10, 11, 12, 13]))) + S2 = nx.compose(H2, nx.MultiGraph(G2.subgraph(["a", "b", "d", "e"]))) + + m = vf2pp_isomorphism(H1, H2, node_label="label") + assert m + + # Connect the two components + S1.add_edges_from([(13, 13), (13, 13), (2, 13)]) + S2.add_edges_from([("a", "a"), ("a", "a"), ("i", "e")]) + m = vf2pp_isomorphism(H1, H2, node_label="label") + assert m + + def test_custom_multigraph5_same_labels(self): + G1 = nx.MultiGraph() + + edges1 = [ + (1, 5), + (1, 2), + (1, 4), + (2, 3), + (2, 6), + (3, 4), + (3, 7), + (4, 8), + (5, 8), + (5, 6), + (6, 7), + (7, 8), + ] + mapped = {1: "a", 2: "h", 3: "d", 4: "i", 5: "g", 6: "b", 7: "j", 8: "c"} + + G1.add_edges_from(edges1) + G2 = nx.relabel_nodes(G1, mapped) + nx.set_node_attributes(G1, dict(zip(G1, it.cycle(labels_same))), "label") + nx.set_node_attributes(G2, dict(zip(G2, it.cycle(labels_same))), "label") + + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m + + # Add multiple edges and self-loops, maintaining isomorphism + G1.add_edges_from( + [(1, 2), (1, 2), (3, 7), (8, 8), (8, 8), (7, 8), (2, 3), (5, 6)] + ) + G2.add_edges_from( + [ + ("a", "h"), + ("a", "h"), + ("d", "j"), + ("c", "c"), + ("c", "c"), + ("j", "c"), + ("d", "h"), + ("g", "b"), + ] + ) + + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m + + # Make G2 to be the rotated G1 + G2.remove_edges_from( + [ + ("a", "h"), + ("a", "h"), + ("d", "j"), + ("c", "c"), + ("c", "c"), + ("j", "c"), + ("d", "h"), + ("g", "b"), + ] + ) + G2.add_edges_from( + [ + ("d", "i"), + ("a", "h"), + ("g", "b"), + ("g", "b"), + ("i", "i"), + ("i", "i"), + ("b", "j"), + ("d", "j"), + ] + ) + + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m + + def test_disconnected_multigraph_all_same_labels(self): + G1 = nx.MultiGraph() + G1.add_nodes_from(list(range(10))) + G1.add_edges_from([(i, i) for i in range(10)]) + + mapped = {0: 9, 1: 8, 2: 7, 3: 6, 4: 5, 5: 4, 6: 3, 7: 2, 8: 1, 9: 0} + G2 = nx.relabel_nodes(G1, mapped) + + nx.set_node_attributes(G1, dict(zip(G1, it.cycle(labels_same))), "label") + nx.set_node_attributes(G2, dict(zip(G2, it.cycle(labels_same))), "label") + + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m + + # Add self-loops to non-mapped nodes. Should be the same, as the graph is disconnected. + G1.add_edges_from([(i, i) for i in range(5, 8)] * 3) + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert not m + + # Compensate in G2 + G2.add_edges_from([(i, i) for i in range(3)] * 3) + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m + + # Add one more self-loop in G2 + G2.add_edges_from([(0, 0), (1, 1), (1, 1)]) + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert not m + + # Compensate in G1 + G1.add_edges_from([(5, 5), (7, 7), (7, 7)]) + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m + + def test_disconnected_multigraph_all_different_labels(self): + G1 = nx.MultiGraph() + G1.add_nodes_from(list(range(10))) + G1.add_edges_from([(i, i) for i in range(10)]) + + mapped = {0: 9, 1: 8, 2: 7, 3: 6, 4: 5, 5: 4, 6: 3, 7: 2, 8: 1, 9: 0} + G2 = nx.relabel_nodes(G1, mapped) + + nx.set_node_attributes(G1, dict(zip(G1, it.cycle(labels_many))), "label") + nx.set_node_attributes( + G2, + dict(zip([mapped[n] for n in G1], it.cycle(labels_many))), + "label", + ) + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m + assert m == mapped + + # Add self-loops to non-mapped nodes. Now it is not the same, as there are different labels + G1.add_edges_from([(i, i) for i in range(5, 8)] * 3) + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert not m + + # Add self-loops to non mapped nodes in G2 as well + G2.add_edges_from([(mapped[i], mapped[i]) for i in range(3)] * 7) + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert not m + + # Add self-loops to mapped nodes in G2 + G2.add_edges_from([(mapped[i], mapped[i]) for i in range(5, 8)] * 3) + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert not m + + # Add self-loops to G1 so that they are even in both graphs + G1.add_edges_from([(i, i) for i in range(3)] * 7) + m = vf2pp_isomorphism(G1, G2, node_label="label") + assert m + + +class TestDiGraphISOVF2pp: + def test_wikipedia_graph(self): + edges1 = [ + (1, 5), + (1, 2), + (1, 4), + (3, 2), + (6, 2), + (3, 4), + (7, 3), + (4, 8), + (5, 8), + (6, 5), + (6, 7), + (7, 8), + ] + mapped = {1: "a", 2: "h", 3: "d", 4: "i", 5: "g", 6: "b", 7: "j", 8: "c"} + + G1 = nx.DiGraph(edges1) + G2 = nx.relabel_nodes(G1, mapped) + + assert vf2pp_isomorphism(G1, G2) == mapped + + # Change the direction of an edge + G1.remove_edge(1, 5) + G1.add_edge(5, 1) + assert vf2pp_isomorphism(G1, G2) is None + + def test_non_isomorphic_same_degree_sequence(self): + r""" + G1 G2 + x--------------x x--------------x + | \ | | \ | + | x-------x | | x-------x | + | | | | | | | | + | x-------x | | x-------x | + | / | | \ | + x--------------x x--------------x + """ + edges1 = [ + (1, 5), + (1, 2), + (4, 1), + (3, 2), + (3, 4), + (4, 8), + (5, 8), + (6, 5), + (6, 7), + (7, 8), + ] + edges2 = [ + (1, 5), + (1, 2), + (4, 1), + (3, 2), + (4, 3), + (5, 8), + (6, 5), + (6, 7), + (3, 7), + (8, 7), + ] + + G1 = nx.DiGraph(edges1) + G2 = nx.DiGraph(edges2) + assert vf2pp_isomorphism(G1, G2) is None diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/test_vf2pp_helpers.py b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/test_vf2pp_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..0e29b1be61734603537f2431eb2de9f1ebfab459 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/test_vf2pp_helpers.py @@ -0,0 +1,3106 @@ +import itertools as it + +import pytest + +import networkx as nx +from networkx import vf2pp_is_isomorphic, vf2pp_isomorphism +from networkx.algorithms.isomorphism.vf2pp import ( + _consistent_PT, + _cut_PT, + _feasibility, + _find_candidates, + _find_candidates_Di, + _GraphParameters, + _initialize_parameters, + _matching_order, + _restore_Tinout, + _restore_Tinout_Di, + _StateParameters, + _update_Tinout, +) + +labels_same = ["blue"] + +labels_many = [ + "white", + "red", + "blue", + "green", + "orange", + "black", + "purple", + "yellow", + "brown", + "cyan", + "solarized", + "pink", + "none", +] + + +class TestNodeOrdering: + def test_empty_graph(self): + G1 = nx.Graph() + G2 = nx.Graph() + gparams = _GraphParameters(G1, G2, None, None, None, None, None) + assert len(set(_matching_order(gparams))) == 0 + + def test_single_node(self): + G1 = nx.Graph() + G2 = nx.Graph() + G1.add_node(1) + G2.add_node(1) + + nx.set_node_attributes(G1, dict(zip(G1, it.cycle(labels_many))), "label") + nx.set_node_attributes( + G2, + dict(zip(G2, it.cycle(labels_many))), + "label", + ) + l1, l2 = ( + nx.get_node_attributes(G1, "label"), + nx.get_node_attributes(G2, "label"), + ) + + gparams = _GraphParameters( + G1, + G2, + l1, + l2, + nx.utils.groups(l1), + nx.utils.groups(l2), + nx.utils.groups(dict(G2.degree())), + ) + m = _matching_order(gparams) + assert m == [1] + + def test_matching_order(self): + labels = [ + "blue", + "blue", + "red", + "red", + "red", + "red", + "green", + "green", + "green", + "yellow", + "purple", + "purple", + "blue", + "blue", + ] + G1 = nx.Graph( + [ + (0, 1), + (0, 2), + (1, 2), + (2, 5), + (2, 4), + (1, 3), + (1, 4), + (3, 6), + (4, 6), + (6, 7), + (7, 8), + (9, 10), + (9, 11), + (11, 12), + (11, 13), + (12, 13), + (10, 13), + ] + ) + G2 = G1.copy() + nx.set_node_attributes(G1, dict(zip(G1, it.cycle(labels))), "label") + nx.set_node_attributes( + G2, + dict(zip(G2, it.cycle(labels))), + "label", + ) + l1, l2 = ( + nx.get_node_attributes(G1, "label"), + nx.get_node_attributes(G2, "label"), + ) + gparams = _GraphParameters( + G1, + G2, + l1, + l2, + nx.utils.groups(l1), + nx.utils.groups(l2), + nx.utils.groups(dict(G2.degree())), + ) + + expected = [9, 11, 10, 13, 12, 1, 2, 4, 0, 3, 6, 5, 7, 8] + assert _matching_order(gparams) == expected + + def test_matching_order_all_branches(self): + G1 = nx.Graph( + [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 4), (3, 4)] + ) + G1.add_node(5) + G2 = G1.copy() + + G1.nodes[0]["label"] = "black" + G1.nodes[1]["label"] = "blue" + G1.nodes[2]["label"] = "blue" + G1.nodes[3]["label"] = "red" + G1.nodes[4]["label"] = "red" + G1.nodes[5]["label"] = "blue" + + G2.nodes[0]["label"] = "black" + G2.nodes[1]["label"] = "blue" + G2.nodes[2]["label"] = "blue" + G2.nodes[3]["label"] = "red" + G2.nodes[4]["label"] = "red" + G2.nodes[5]["label"] = "blue" + + l1, l2 = ( + nx.get_node_attributes(G1, "label"), + nx.get_node_attributes(G2, "label"), + ) + gparams = _GraphParameters( + G1, + G2, + l1, + l2, + nx.utils.groups(l1), + nx.utils.groups(l2), + nx.utils.groups(dict(G2.degree())), + ) + + expected = [0, 4, 1, 3, 2, 5] + assert _matching_order(gparams) == expected + + +class TestGraphCandidateSelection: + G1_edges = [ + (1, 2), + (1, 4), + (1, 5), + (2, 3), + (2, 4), + (3, 4), + (4, 5), + (1, 6), + (6, 7), + (6, 8), + (8, 9), + (7, 9), + ] + mapped = { + 0: "x", + 1: "a", + 2: "b", + 3: "c", + 4: "d", + 5: "e", + 6: "f", + 7: "g", + 8: "h", + 9: "i", + } + + def test_no_covered_neighbors_no_labels(self): + G1 = nx.Graph() + G1.add_edges_from(self.G1_edges) + G1.add_node(0) + G2 = nx.relabel_nodes(G1, self.mapped) + + G1_degree = dict(G1.degree) + l1 = dict(G1.nodes(data="label", default=-1)) + l2 = dict(G2.nodes(data="label", default=-1)) + gparams = _GraphParameters( + G1, + G2, + l1, + l2, + nx.utils.groups(l1), + nx.utils.groups(l2), + nx.utils.groups(dict(G2.degree())), + ) + + m = {9: self.mapped[9], 1: self.mapped[1]} + m_rev = {self.mapped[9]: 9, self.mapped[1]: 1} + + T1 = {7, 8, 2, 4, 5} + T1_tilde = {0, 3, 6} + T2 = {"g", "h", "b", "d", "e"} + T2_tilde = {"x", "c", "f"} + + sparams = _StateParameters( + m, m_rev, T1, None, T1_tilde, None, T2, None, T2_tilde, None + ) + + u = 3 + candidates = _find_candidates(u, gparams, sparams, G1_degree) + assert candidates == {self.mapped[u]} + + u = 0 + candidates = _find_candidates(u, gparams, sparams, G1_degree) + assert candidates == {self.mapped[u]} + + m.pop(9) + m_rev.pop(self.mapped[9]) + + T1 = {2, 4, 5, 6} + T1_tilde = {0, 3, 7, 8, 9} + T2 = {"g", "h", "b", "d", "e", "f"} + T2_tilde = {"x", "c", "g", "h", "i"} + + sparams = _StateParameters( + m, m_rev, T1, None, T1_tilde, None, T2, None, T2_tilde, None + ) + + u = 7 + candidates = _find_candidates(u, gparams, sparams, G1_degree) + assert candidates == { + self.mapped[u], + self.mapped[8], + self.mapped[3], + self.mapped[9], + } + + def test_no_covered_neighbors_with_labels(self): + G1 = nx.Graph() + G1.add_edges_from(self.G1_edges) + G1.add_node(0) + G2 = nx.relabel_nodes(G1, self.mapped) + + G1_degree = dict(G1.degree) + nx.set_node_attributes( + G1, + dict(zip(G1, it.cycle(labels_many))), + "label", + ) + nx.set_node_attributes( + G2, + dict( + zip( + [self.mapped[n] for n in G1], + it.cycle(labels_many), + ) + ), + "label", + ) + l1 = dict(G1.nodes(data="label", default=-1)) + l2 = dict(G2.nodes(data="label", default=-1)) + gparams = _GraphParameters( + G1, + G2, + l1, + l2, + nx.utils.groups(l1), + nx.utils.groups(l2), + nx.utils.groups(dict(G2.degree())), + ) + + m = {9: self.mapped[9], 1: self.mapped[1]} + m_rev = {self.mapped[9]: 9, self.mapped[1]: 1} + + T1 = {7, 8, 2, 4, 5, 6} + T1_tilde = {0, 3} + T2 = {"g", "h", "b", "d", "e", "f"} + T2_tilde = {"x", "c"} + + sparams = _StateParameters( + m, m_rev, T1, None, T1_tilde, None, T2, None, T2_tilde, None + ) + + u = 3 + candidates = _find_candidates(u, gparams, sparams, G1_degree) + assert candidates == {self.mapped[u]} + + u = 0 + candidates = _find_candidates(u, gparams, sparams, G1_degree) + assert candidates == {self.mapped[u]} + + # Change label of disconnected node + G1.nodes[u]["label"] = "blue" + l1 = dict(G1.nodes(data="label", default=-1)) + l2 = dict(G2.nodes(data="label", default=-1)) + gparams = _GraphParameters( + G1, + G2, + l1, + l2, + nx.utils.groups(l1), + nx.utils.groups(l2), + nx.utils.groups(dict(G2.degree())), + ) + + # No candidate + candidates = _find_candidates(u, gparams, sparams, G1_degree) + assert candidates == set() + + m.pop(9) + m_rev.pop(self.mapped[9]) + + T1 = {2, 4, 5, 6} + T1_tilde = {0, 3, 7, 8, 9} + T2 = {"b", "d", "e", "f"} + T2_tilde = {"x", "c", "g", "h", "i"} + + sparams = _StateParameters( + m, m_rev, T1, None, T1_tilde, None, T2, None, T2_tilde, None + ) + + u = 7 + candidates = _find_candidates(u, gparams, sparams, G1_degree) + assert candidates == {self.mapped[u]} + + G1.nodes[8]["label"] = G1.nodes[7]["label"] + G2.nodes[self.mapped[8]]["label"] = G1.nodes[7]["label"] + l1 = dict(G1.nodes(data="label", default=-1)) + l2 = dict(G2.nodes(data="label", default=-1)) + gparams = _GraphParameters( + G1, + G2, + l1, + l2, + nx.utils.groups(l1), + nx.utils.groups(l2), + nx.utils.groups(dict(G2.degree())), + ) + + candidates = _find_candidates(u, gparams, sparams, G1_degree) + assert candidates == {self.mapped[u], self.mapped[8]} + + def test_covered_neighbors_no_labels(self): + G1 = nx.Graph() + G1.add_edges_from(self.G1_edges) + G1.add_node(0) + G2 = nx.relabel_nodes(G1, self.mapped) + + G1_degree = dict(G1.degree) + l1 = dict(G1.nodes(data=None, default=-1)) + l2 = dict(G2.nodes(data=None, default=-1)) + gparams = _GraphParameters( + G1, + G2, + l1, + l2, + nx.utils.groups(l1), + nx.utils.groups(l2), + nx.utils.groups(dict(G2.degree())), + ) + + m = {9: self.mapped[9], 1: self.mapped[1]} + m_rev = {self.mapped[9]: 9, self.mapped[1]: 1} + + T1 = {7, 8, 2, 4, 5, 6} + T1_tilde = {0, 3} + T2 = {"g", "h", "b", "d", "e", "f"} + T2_tilde = {"x", "c"} + + sparams = _StateParameters( + m, m_rev, T1, None, T1_tilde, None, T2, None, T2_tilde, None + ) + + u = 5 + candidates = _find_candidates(u, gparams, sparams, G1_degree) + assert candidates == {self.mapped[u]} + + u = 6 + candidates = _find_candidates(u, gparams, sparams, G1_degree) + assert candidates == {self.mapped[u], self.mapped[2]} + + def test_covered_neighbors_with_labels(self): + G1 = nx.Graph() + G1.add_edges_from(self.G1_edges) + G1.add_node(0) + G2 = nx.relabel_nodes(G1, self.mapped) + + G1_degree = dict(G1.degree) + nx.set_node_attributes( + G1, + dict(zip(G1, it.cycle(labels_many))), + "label", + ) + nx.set_node_attributes( + G2, + dict( + zip( + [self.mapped[n] for n in G1], + it.cycle(labels_many), + ) + ), + "label", + ) + l1 = dict(G1.nodes(data="label", default=-1)) + l2 = dict(G2.nodes(data="label", default=-1)) + gparams = _GraphParameters( + G1, + G2, + l1, + l2, + nx.utils.groups(l1), + nx.utils.groups(l2), + nx.utils.groups(dict(G2.degree())), + ) + + m = {9: self.mapped[9], 1: self.mapped[1]} + m_rev = {self.mapped[9]: 9, self.mapped[1]: 1} + + T1 = {7, 8, 2, 4, 5, 6} + T1_tilde = {0, 3} + T2 = {"g", "h", "b", "d", "e", "f"} + T2_tilde = {"x", "c"} + + sparams = _StateParameters( + m, m_rev, T1, None, T1_tilde, None, T2, None, T2_tilde, None + ) + + u = 5 + candidates = _find_candidates(u, gparams, sparams, G1_degree) + assert candidates == {self.mapped[u]} + + u = 6 + candidates = _find_candidates(u, gparams, sparams, G1_degree) + assert candidates == {self.mapped[u]} + + # Assign to 2, the same label as 6 + G1.nodes[2]["label"] = G1.nodes[u]["label"] + G2.nodes[self.mapped[2]]["label"] = G1.nodes[u]["label"] + l1 = dict(G1.nodes(data="label", default=-1)) + l2 = dict(G2.nodes(data="label", default=-1)) + gparams = _GraphParameters( + G1, + G2, + l1, + l2, + nx.utils.groups(l1), + nx.utils.groups(l2), + nx.utils.groups(dict(G2.degree())), + ) + + candidates = _find_candidates(u, gparams, sparams, G1_degree) + assert candidates == {self.mapped[u], self.mapped[2]} + + +class TestDiGraphCandidateSelection: + G1_edges = [ + (1, 2), + (1, 4), + (5, 1), + (2, 3), + (4, 2), + (3, 4), + (4, 5), + (1, 6), + (6, 7), + (6, 8), + (8, 9), + (7, 9), + ] + mapped = { + 0: "x", + 1: "a", + 2: "b", + 3: "c", + 4: "d", + 5: "e", + 6: "f", + 7: "g", + 8: "h", + 9: "i", + } + + def test_no_covered_neighbors_no_labels(self): + G1 = nx.DiGraph() + G1.add_edges_from(self.G1_edges) + G1.add_node(0) + G2 = nx.relabel_nodes(G1, self.mapped) + + G1_degree = { + n: (in_degree, out_degree) + for (n, in_degree), (_, out_degree) in zip(G1.in_degree, G1.out_degree) + } + + l1 = dict(G1.nodes(data="label", default=-1)) + l2 = dict(G2.nodes(data="label", default=-1)) + gparams = _GraphParameters( + G1, + G2, + l1, + l2, + nx.utils.groups(l1), + nx.utils.groups(l2), + nx.utils.groups( + { + node: (in_degree, out_degree) + for (node, in_degree), (_, out_degree) in zip( + G2.in_degree(), G2.out_degree() + ) + } + ), + ) + + m = {9: self.mapped[9], 1: self.mapped[1]} + m_rev = {self.mapped[9]: 9, self.mapped[1]: 1} + + T1_out = {2, 4, 6} + T1_in = {5, 7, 8} + T1_tilde = {0, 3} + T2_out = {"b", "d", "f"} + T2_in = {"e", "g", "h"} + T2_tilde = {"x", "c"} + + sparams = _StateParameters( + m, m_rev, T1_out, T1_in, T1_tilde, None, T2_out, T2_in, T2_tilde, None + ) + + u = 3 + candidates = _find_candidates_Di(u, gparams, sparams, G1_degree) + assert candidates == {self.mapped[u]} + + u = 0 + candidates = _find_candidates_Di(u, gparams, sparams, G1_degree) + assert candidates == {self.mapped[u]} + + m.pop(9) + m_rev.pop(self.mapped[9]) + + T1_out = {2, 4, 6} + T1_in = {5} + T1_tilde = {0, 3, 7, 8, 9} + T2_out = {"b", "d", "f"} + T2_in = {"e"} + T2_tilde = {"x", "c", "g", "h", "i"} + + sparams = _StateParameters( + m, m_rev, T1_out, T1_in, T1_tilde, None, T2_out, T2_in, T2_tilde, None + ) + + u = 7 + candidates = _find_candidates_Di(u, gparams, sparams, G1_degree) + assert candidates == {self.mapped[u], self.mapped[8], self.mapped[3]} + + def test_no_covered_neighbors_with_labels(self): + G1 = nx.DiGraph() + G1.add_edges_from(self.G1_edges) + G1.add_node(0) + G2 = nx.relabel_nodes(G1, self.mapped) + + G1_degree = { + n: (in_degree, out_degree) + for (n, in_degree), (_, out_degree) in zip(G1.in_degree, G1.out_degree) + } + nx.set_node_attributes( + G1, + dict(zip(G1, it.cycle(labels_many))), + "label", + ) + nx.set_node_attributes( + G2, + dict( + zip( + [self.mapped[n] for n in G1], + it.cycle(labels_many), + ) + ), + "label", + ) + l1 = dict(G1.nodes(data="label", default=-1)) + l2 = dict(G2.nodes(data="label", default=-1)) + gparams = _GraphParameters( + G1, + G2, + l1, + l2, + nx.utils.groups(l1), + nx.utils.groups(l2), + nx.utils.groups( + { + node: (in_degree, out_degree) + for (node, in_degree), (_, out_degree) in zip( + G2.in_degree(), G2.out_degree() + ) + } + ), + ) + + m = {9: self.mapped[9], 1: self.mapped[1]} + m_rev = {self.mapped[9]: 9, self.mapped[1]: 1} + + T1_out = {2, 4, 6} + T1_in = {5, 7, 8} + T1_tilde = {0, 3} + T2_out = {"b", "d", "f"} + T2_in = {"e", "g", "h"} + T2_tilde = {"x", "c"} + + sparams = _StateParameters( + m, m_rev, T1_out, T1_in, T1_tilde, None, T2_out, T2_in, T2_tilde, None + ) + + u = 3 + candidates = _find_candidates_Di(u, gparams, sparams, G1_degree) + assert candidates == {self.mapped[u]} + + u = 0 + candidates = _find_candidates_Di(u, gparams, sparams, G1_degree) + assert candidates == {self.mapped[u]} + + # Change label of disconnected node + G1.nodes[u]["label"] = "blue" + l1 = dict(G1.nodes(data="label", default=-1)) + l2 = dict(G2.nodes(data="label", default=-1)) + gparams = _GraphParameters( + G1, + G2, + l1, + l2, + nx.utils.groups(l1), + nx.utils.groups(l2), + nx.utils.groups( + { + node: (in_degree, out_degree) + for (node, in_degree), (_, out_degree) in zip( + G2.in_degree(), G2.out_degree() + ) + } + ), + ) + + # No candidate + candidates = _find_candidates_Di(u, gparams, sparams, G1_degree) + assert candidates == set() + + m.pop(9) + m_rev.pop(self.mapped[9]) + + T1_out = {2, 4, 6} + T1_in = {5} + T1_tilde = {0, 3, 7, 8, 9} + T2_out = {"b", "d", "f"} + T2_in = {"e"} + T2_tilde = {"x", "c", "g", "h", "i"} + + sparams = _StateParameters( + m, m_rev, T1_out, T1_in, T1_tilde, None, T2_out, T2_in, T2_tilde, None + ) + + u = 7 + candidates = _find_candidates_Di(u, gparams, sparams, G1_degree) + assert candidates == {self.mapped[u]} + + G1.nodes[8]["label"] = G1.nodes[7]["label"] + G2.nodes[self.mapped[8]]["label"] = G1.nodes[7]["label"] + l1 = dict(G1.nodes(data="label", default=-1)) + l2 = dict(G2.nodes(data="label", default=-1)) + gparams = _GraphParameters( + G1, + G2, + l1, + l2, + nx.utils.groups(l1), + nx.utils.groups(l2), + nx.utils.groups( + { + node: (in_degree, out_degree) + for (node, in_degree), (_, out_degree) in zip( + G2.in_degree(), G2.out_degree() + ) + } + ), + ) + + candidates = _find_candidates_Di(u, gparams, sparams, G1_degree) + assert candidates == {self.mapped[u], self.mapped[8]} + + def test_covered_neighbors_no_labels(self): + G1 = nx.DiGraph() + G1.add_edges_from(self.G1_edges) + G1.add_node(0) + G2 = nx.relabel_nodes(G1, self.mapped) + + G1_degree = { + n: (in_degree, out_degree) + for (n, in_degree), (_, out_degree) in zip(G1.in_degree, G1.out_degree) + } + + l1 = dict(G1.nodes(data=None, default=-1)) + l2 = dict(G2.nodes(data=None, default=-1)) + gparams = _GraphParameters( + G1, + G2, + l1, + l2, + nx.utils.groups(l1), + nx.utils.groups(l2), + nx.utils.groups( + { + node: (in_degree, out_degree) + for (node, in_degree), (_, out_degree) in zip( + G2.in_degree(), G2.out_degree() + ) + } + ), + ) + + m = {9: self.mapped[9], 1: self.mapped[1]} + m_rev = {self.mapped[9]: 9, self.mapped[1]: 1} + + T1_out = {2, 4, 6} + T1_in = {5, 7, 8} + T1_tilde = {0, 3} + T2_out = {"b", "d", "f"} + T2_in = {"e", "g", "h"} + T2_tilde = {"x", "c"} + + sparams = _StateParameters( + m, m_rev, T1_out, T1_in, T1_tilde, None, T2_out, T2_in, T2_tilde, None + ) + + u = 5 + candidates = _find_candidates_Di(u, gparams, sparams, G1_degree) + assert candidates == {self.mapped[u]} + + u = 6 + candidates = _find_candidates_Di(u, gparams, sparams, G1_degree) + assert candidates == {self.mapped[u]} + + # Change the direction of an edge to make the degree orientation same as first candidate of u. + G1.remove_edge(4, 2) + G1.add_edge(2, 4) + G2.remove_edge("d", "b") + G2.add_edge("b", "d") + + gparams = _GraphParameters( + G1, + G2, + l1, + l2, + nx.utils.groups(l1), + nx.utils.groups(l2), + nx.utils.groups( + { + node: (in_degree, out_degree) + for (node, in_degree), (_, out_degree) in zip( + G2.in_degree(), G2.out_degree() + ) + } + ), + ) + + candidates = _find_candidates_Di(u, gparams, sparams, G1_degree) + assert candidates == {self.mapped[u], self.mapped[2]} + + def test_covered_neighbors_with_labels(self): + G1 = nx.DiGraph() + G1.add_edges_from(self.G1_edges) + G1.add_node(0) + G2 = nx.relabel_nodes(G1, self.mapped) + + G1.remove_edge(4, 2) + G1.add_edge(2, 4) + G2.remove_edge("d", "b") + G2.add_edge("b", "d") + + G1_degree = { + n: (in_degree, out_degree) + for (n, in_degree), (_, out_degree) in zip(G1.in_degree, G1.out_degree) + } + + nx.set_node_attributes( + G1, + dict(zip(G1, it.cycle(labels_many))), + "label", + ) + nx.set_node_attributes( + G2, + dict( + zip( + [self.mapped[n] for n in G1], + it.cycle(labels_many), + ) + ), + "label", + ) + l1 = dict(G1.nodes(data="label", default=-1)) + l2 = dict(G2.nodes(data="label", default=-1)) + gparams = _GraphParameters( + G1, + G2, + l1, + l2, + nx.utils.groups(l1), + nx.utils.groups(l2), + nx.utils.groups( + { + node: (in_degree, out_degree) + for (node, in_degree), (_, out_degree) in zip( + G2.in_degree(), G2.out_degree() + ) + } + ), + ) + + m = {9: self.mapped[9], 1: self.mapped[1]} + m_rev = {self.mapped[9]: 9, self.mapped[1]: 1} + + T1_out = {2, 4, 6} + T1_in = {5, 7, 8} + T1_tilde = {0, 3} + T2_out = {"b", "d", "f"} + T2_in = {"e", "g", "h"} + T2_tilde = {"x", "c"} + + sparams = _StateParameters( + m, m_rev, T1_out, T1_in, T1_tilde, None, T2_out, T2_in, T2_tilde, None + ) + + u = 5 + candidates = _find_candidates_Di(u, gparams, sparams, G1_degree) + assert candidates == {self.mapped[u]} + + u = 6 + candidates = _find_candidates_Di(u, gparams, sparams, G1_degree) + assert candidates == {self.mapped[u]} + + # Assign to 2, the same label as 6 + G1.nodes[2]["label"] = G1.nodes[u]["label"] + G2.nodes[self.mapped[2]]["label"] = G1.nodes[u]["label"] + l1 = dict(G1.nodes(data="label", default=-1)) + l2 = dict(G2.nodes(data="label", default=-1)) + gparams = _GraphParameters( + G1, + G2, + l1, + l2, + nx.utils.groups(l1), + nx.utils.groups(l2), + nx.utils.groups( + { + node: (in_degree, out_degree) + for (node, in_degree), (_, out_degree) in zip( + G2.in_degree(), G2.out_degree() + ) + } + ), + ) + + candidates = _find_candidates_Di(u, gparams, sparams, G1_degree) + assert candidates == {self.mapped[u], self.mapped[2]} + + # Change the direction of an edge to make the degree orientation same as first candidate of u. + G1.remove_edge(2, 4) + G1.add_edge(4, 2) + G2.remove_edge("b", "d") + G2.add_edge("d", "b") + + gparams = _GraphParameters( + G1, + G2, + l1, + l2, + nx.utils.groups(l1), + nx.utils.groups(l2), + nx.utils.groups( + { + node: (in_degree, out_degree) + for (node, in_degree), (_, out_degree) in zip( + G2.in_degree(), G2.out_degree() + ) + } + ), + ) + + candidates = _find_candidates_Di(u, gparams, sparams, G1_degree) + assert candidates == {self.mapped[u]} + + def test_same_in_out_degrees_no_candidate(self): + g1 = nx.DiGraph([(4, 1), (4, 2), (3, 4), (5, 4), (6, 4)]) + g2 = nx.DiGraph([(1, 4), (2, 4), (3, 4), (4, 5), (4, 6)]) + + l1 = dict(g1.nodes(data=None, default=-1)) + l2 = dict(g2.nodes(data=None, default=-1)) + gparams = _GraphParameters( + g1, + g2, + l1, + l2, + nx.utils.groups(l1), + nx.utils.groups(l2), + nx.utils.groups( + { + node: (in_degree, out_degree) + for (node, in_degree), (_, out_degree) in zip( + g2.in_degree(), g2.out_degree() + ) + } + ), + ) + + g1_degree = { + n: (in_degree, out_degree) + for (n, in_degree), (_, out_degree) in zip(g1.in_degree, g1.out_degree) + } + + m = {1: 1, 2: 2, 3: 3} + m_rev = m.copy() + + T1_out = {4} + T1_in = {4} + T1_tilde = {5, 6} + T2_out = {4} + T2_in = {4} + T2_tilde = {5, 6} + + sparams = _StateParameters( + m, m_rev, T1_out, T1_in, T1_tilde, None, T2_out, T2_in, T2_tilde, None + ) + + u = 4 + # despite the same in and out degree, there's no candidate for u=4 + candidates = _find_candidates_Di(u, gparams, sparams, g1_degree) + assert candidates == set() + # Notice how the regular candidate selection method returns wrong result. + assert _find_candidates(u, gparams, sparams, g1_degree) == {4} + + +class TestGraphISOFeasibility: + def test_const_covered_neighbors(self): + G1 = nx.Graph([(0, 1), (1, 2), (3, 0), (3, 2)]) + G2 = nx.Graph([("a", "b"), ("b", "c"), ("k", "a"), ("k", "c")]) + gparams = _GraphParameters(G1, G2, None, None, None, None, None) + sparams = _StateParameters( + {0: "a", 1: "b", 2: "c"}, + {"a": 0, "b": 1, "c": 2}, + None, + None, + None, + None, + None, + None, + None, + None, + ) + u, v = 3, "k" + assert _consistent_PT(u, v, gparams, sparams) + + def test_const_no_covered_neighbors(self): + G1 = nx.Graph([(0, 1), (1, 2), (3, 4), (3, 5)]) + G2 = nx.Graph([("a", "b"), ("b", "c"), ("k", "w"), ("k", "z")]) + gparams = _GraphParameters(G1, G2, None, None, None, None, None) + sparams = _StateParameters( + {0: "a", 1: "b", 2: "c"}, + {"a": 0, "b": 1, "c": 2}, + None, + None, + None, + None, + None, + None, + None, + None, + ) + u, v = 3, "k" + assert _consistent_PT(u, v, gparams, sparams) + + def test_const_mixed_covered_uncovered_neighbors(self): + G1 = nx.Graph([(0, 1), (1, 2), (3, 0), (3, 2), (3, 4), (3, 5)]) + G2 = nx.Graph( + [("a", "b"), ("b", "c"), ("k", "a"), ("k", "c"), ("k", "w"), ("k", "z")] + ) + gparams = _GraphParameters(G1, G2, None, None, None, None, None) + sparams = _StateParameters( + {0: "a", 1: "b", 2: "c"}, + {"a": 0, "b": 1, "c": 2}, + None, + None, + None, + None, + None, + None, + None, + None, + ) + u, v = 3, "k" + assert _consistent_PT(u, v, gparams, sparams) + + def test_const_fail_cases(self): + G1 = nx.Graph( + [ + (0, 1), + (1, 2), + (10, 0), + (10, 3), + (10, 4), + (10, 5), + (10, 6), + (4, 1), + (5, 3), + ] + ) + G2 = nx.Graph( + [ + ("a", "b"), + ("b", "c"), + ("k", "a"), + ("k", "d"), + ("k", "e"), + ("k", "f"), + ("k", "g"), + ("e", "b"), + ("f", "d"), + ] + ) + gparams = _GraphParameters(G1, G2, None, None, None, None, None) + sparams = _StateParameters( + {0: "a", 1: "b", 2: "c", 3: "d"}, + {"a": 0, "b": 1, "c": 2, "d": 3}, + None, + None, + None, + None, + None, + None, + None, + None, + ) + u, v = 10, "k" + assert _consistent_PT(u, v, gparams, sparams) + + # Delete one uncovered neighbor of u. Notice how it still passes the test. + # Two reasons for this: + # 1. If u, v had different degrees from the beginning, they wouldn't + # be selected as candidates in the first place. + # 2. Even if they are selected, consistency is basically 1-look-ahead, + # meaning that we take into consideration the relation of the + # candidates with their mapped neighbors. The node we deleted is + # not a covered neighbor. + # Such nodes will be checked by the cut_PT function, which is + # basically the 2-look-ahead, checking the relation of the + # candidates with T1, T2 (in which belongs the node we just deleted). + G1.remove_node(6) + assert _consistent_PT(u, v, gparams, sparams) + + # Add one more covered neighbor of u in G1 + G1.add_edge(u, 2) + assert not _consistent_PT(u, v, gparams, sparams) + + # Compensate in G2 + G2.add_edge(v, "c") + assert _consistent_PT(u, v, gparams, sparams) + + # Add one more covered neighbor of v in G2 + G2.add_edge(v, "x") + G1.add_node(7) + sparams.mapping.update({7: "x"}) + sparams.reverse_mapping.update({"x": 7}) + assert not _consistent_PT(u, v, gparams, sparams) + + # Compendate in G1 + G1.add_edge(u, 7) + assert _consistent_PT(u, v, gparams, sparams) + + @pytest.mark.parametrize("graph_type", (nx.Graph, nx.DiGraph)) + def test_cut_inconsistent_labels(self, graph_type): + G1 = graph_type( + [ + (0, 1), + (1, 2), + (10, 0), + (10, 3), + (10, 4), + (10, 5), + (10, 6), + (4, 1), + (5, 3), + ] + ) + G2 = graph_type( + [ + ("a", "b"), + ("b", "c"), + ("k", "a"), + ("k", "d"), + ("k", "e"), + ("k", "f"), + ("k", "g"), + ("e", "b"), + ("f", "d"), + ] + ) + + l1 = {n: "blue" for n in G1.nodes()} + l2 = {n: "blue" for n in G2.nodes()} + l1.update({6: "green"}) # Change the label of one neighbor of u + + gparams = _GraphParameters( + G1, G2, l1, l2, nx.utils.groups(l1), nx.utils.groups(l2), None + ) + sparams = _StateParameters( + {0: "a", 1: "b", 2: "c", 3: "d"}, + {"a": 0, "b": 1, "c": 2, "d": 3}, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + u, v = 10, "k" + assert _cut_PT(u, v, gparams, sparams) + + def test_cut_consistent_labels(self): + G1 = nx.Graph( + [ + (0, 1), + (1, 2), + (10, 0), + (10, 3), + (10, 4), + (10, 5), + (10, 6), + (4, 1), + (5, 3), + ] + ) + G2 = nx.Graph( + [ + ("a", "b"), + ("b", "c"), + ("k", "a"), + ("k", "d"), + ("k", "e"), + ("k", "f"), + ("k", "g"), + ("e", "b"), + ("f", "d"), + ] + ) + + l1 = {n: "blue" for n in G1.nodes()} + l2 = {n: "blue" for n in G2.nodes()} + + gparams = _GraphParameters( + G1, G2, l1, l2, nx.utils.groups(l1), nx.utils.groups(l2), None + ) + sparams = _StateParameters( + {0: "a", 1: "b", 2: "c", 3: "d"}, + {"a": 0, "b": 1, "c": 2, "d": 3}, + {4, 5}, + None, + {6}, + None, + {"e", "f"}, + None, + {"g"}, + None, + ) + + u, v = 10, "k" + assert not _cut_PT(u, v, gparams, sparams) + + def test_cut_same_labels(self): + G1 = nx.Graph( + [ + (0, 1), + (1, 2), + (10, 0), + (10, 3), + (10, 4), + (10, 5), + (10, 6), + (4, 1), + (5, 3), + ] + ) + mapped = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e", 5: "f", 6: "g", 10: "k"} + G2 = nx.relabel_nodes(G1, mapped) + l1 = {n: "blue" for n in G1.nodes()} + l2 = {n: "blue" for n in G2.nodes()} + + gparams = _GraphParameters( + G1, G2, l1, l2, nx.utils.groups(l1), nx.utils.groups(l2), None + ) + sparams = _StateParameters( + {0: "a", 1: "b", 2: "c", 3: "d"}, + {"a": 0, "b": 1, "c": 2, "d": 3}, + {4, 5}, + None, + {6}, + None, + {"e", "f"}, + None, + {"g"}, + None, + ) + + u, v = 10, "k" + assert not _cut_PT(u, v, gparams, sparams) + + # Change intersection between G1[u] and T1, so it's not the same as the one between G2[v] and T2 + G1.remove_edge(u, 4) + assert _cut_PT(u, v, gparams, sparams) + + # Compensate in G2 + G2.remove_edge(v, mapped[4]) + assert not _cut_PT(u, v, gparams, sparams) + + # Change intersection between G2[v] and T2_tilde, so it's not the same as the one between G1[u] and T1_tilde + G2.remove_edge(v, mapped[6]) + assert _cut_PT(u, v, gparams, sparams) + + # Compensate in G1 + G1.remove_edge(u, 6) + assert not _cut_PT(u, v, gparams, sparams) + + # Add disconnected nodes, which will form the new Ti_out + G1.add_nodes_from([6, 7, 8]) + G2.add_nodes_from(["g", "y", "z"]) + sparams.T1_tilde.update({6, 7, 8}) + sparams.T2_tilde.update({"g", "y", "z"}) + + l1 = {n: "blue" for n in G1.nodes()} + l2 = {n: "blue" for n in G2.nodes()} + gparams = _GraphParameters( + G1, G2, l1, l2, nx.utils.groups(l1), nx.utils.groups(l2), None + ) + + assert not _cut_PT(u, v, gparams, sparams) + + # Add some new nodes to the mapping + sparams.mapping.update({6: "g", 7: "y"}) + sparams.reverse_mapping.update({"g": 6, "y": 7}) + + # Add more nodes to T1, T2. + G1.add_edges_from([(6, 20), (7, 20), (6, 21)]) + G2.add_edges_from([("g", "i"), ("g", "j"), ("y", "j")]) + + sparams.mapping.update({20: "j", 21: "i"}) + sparams.reverse_mapping.update({"j": 20, "i": 21}) + sparams.T1.update({20, 21}) + sparams.T2.update({"i", "j"}) + sparams.T1_tilde.difference_update({6, 7}) + sparams.T2_tilde.difference_update({"g", "y"}) + + assert not _cut_PT(u, v, gparams, sparams) + + # Add nodes from the new T1 and T2, as neighbors of u and v respectively + G1.add_edges_from([(u, 20), (u, 21)]) + G2.add_edges_from([(v, "i"), (v, "j")]) + l1 = {n: "blue" for n in G1.nodes()} + l2 = {n: "blue" for n in G2.nodes()} + gparams = _GraphParameters( + G1, G2, l1, l2, nx.utils.groups(l1), nx.utils.groups(l2), None + ) + + assert not _cut_PT(u, v, gparams, sparams) + + # Change the edges, maintaining the G1[u]-T1 intersection + G1.remove_edge(u, 20) + G1.add_edge(u, 4) + assert not _cut_PT(u, v, gparams, sparams) + + # Connect u to 8 which is still in T1_tilde + G1.add_edge(u, 8) + assert _cut_PT(u, v, gparams, sparams) + + # Same for v and z, so that inters(G1[u], T1out) == inters(G2[v], T2out) + G2.add_edge(v, "z") + assert not _cut_PT(u, v, gparams, sparams) + + def test_cut_different_labels(self): + G1 = nx.Graph( + [ + (0, 1), + (1, 2), + (1, 14), + (0, 4), + (1, 5), + (2, 6), + (3, 7), + (3, 6), + (4, 10), + (4, 9), + (6, 10), + (20, 9), + (20, 15), + (20, 12), + (20, 11), + (12, 13), + (11, 13), + (20, 8), + (20, 3), + (20, 5), + (20, 0), + ] + ) + mapped = { + 0: "a", + 1: "b", + 2: "c", + 3: "d", + 4: "e", + 5: "f", + 6: "g", + 7: "h", + 8: "i", + 9: "j", + 10: "k", + 11: "l", + 12: "m", + 13: "n", + 14: "o", + 15: "p", + 20: "x", + } + G2 = nx.relabel_nodes(G1, mapped) + + l1 = {n: "none" for n in G1.nodes()} + l2 = {} + + l1.update( + { + 9: "blue", + 15: "blue", + 12: "blue", + 11: "green", + 3: "green", + 8: "red", + 0: "red", + 5: "yellow", + } + ) + l2.update({mapped[n]: l for n, l in l1.items()}) + + gparams = _GraphParameters( + G1, G2, l1, l2, nx.utils.groups(l1), nx.utils.groups(l2), None + ) + sparams = _StateParameters( + {0: "a", 1: "b", 2: "c", 3: "d"}, + {"a": 0, "b": 1, "c": 2, "d": 3}, + {4, 5, 6, 7, 14}, + None, + {9, 10, 15, 12, 11, 13, 8}, + None, + {"e", "f", "g", "h", "o"}, + None, + {"j", "k", "l", "m", "n", "i", "p"}, + None, + ) + + u, v = 20, "x" + assert not _cut_PT(u, v, gparams, sparams) + + # Change the orientation of the labels on neighbors of u compared to neighbors of v. Leave the structure intact + l1.update({9: "red"}) + assert _cut_PT(u, v, gparams, sparams) + + # compensate in G2 + l2.update({mapped[9]: "red"}) + assert not _cut_PT(u, v, gparams, sparams) + + # Change the intersection of G1[u] and T1 + G1.add_edge(u, 4) + assert _cut_PT(u, v, gparams, sparams) + + # Same for G2[v] and T2 + G2.add_edge(v, mapped[4]) + assert not _cut_PT(u, v, gparams, sparams) + + # Change the intersection of G2[v] and T2_tilde + G2.remove_edge(v, mapped[8]) + assert _cut_PT(u, v, gparams, sparams) + + # Same for G1[u] and T1_tilde + G1.remove_edge(u, 8) + assert not _cut_PT(u, v, gparams, sparams) + + # Place 8 and mapped[8] in T1 and T2 respectively, by connecting it to covered nodes + G1.add_edge(8, 3) + G2.add_edge(mapped[8], mapped[3]) + sparams.T1.add(8) + sparams.T2.add(mapped[8]) + sparams.T1_tilde.remove(8) + sparams.T2_tilde.remove(mapped[8]) + + assert not _cut_PT(u, v, gparams, sparams) + + # Remove neighbor of u from T1 + G1.remove_node(5) + l1.pop(5) + sparams.T1.remove(5) + assert _cut_PT(u, v, gparams, sparams) + + # Same in G2 + G2.remove_node(mapped[5]) + l2.pop(mapped[5]) + sparams.T2.remove(mapped[5]) + assert not _cut_PT(u, v, gparams, sparams) + + def test_feasibility_same_labels(self): + G1 = nx.Graph( + [ + (0, 1), + (1, 2), + (1, 14), + (0, 4), + (1, 5), + (2, 6), + (3, 7), + (3, 6), + (4, 10), + (4, 9), + (6, 10), + (20, 9), + (20, 15), + (20, 12), + (20, 11), + (12, 13), + (11, 13), + (20, 8), + (20, 2), + (20, 5), + (20, 0), + ] + ) + mapped = { + 0: "a", + 1: "b", + 2: "c", + 3: "d", + 4: "e", + 5: "f", + 6: "g", + 7: "h", + 8: "i", + 9: "j", + 10: "k", + 11: "l", + 12: "m", + 13: "n", + 14: "o", + 15: "p", + 20: "x", + } + G2 = nx.relabel_nodes(G1, mapped) + + l1 = {n: "blue" for n in G1.nodes()} + l2 = {mapped[n]: "blue" for n in G1.nodes()} + + gparams = _GraphParameters( + G1, G2, l1, l2, nx.utils.groups(l1), nx.utils.groups(l2), None + ) + sparams = _StateParameters( + {0: "a", 1: "b", 2: "c", 3: "d"}, + {"a": 0, "b": 1, "c": 2, "d": 3}, + {4, 5, 6, 7, 14}, + None, + {9, 10, 15, 12, 11, 13, 8}, + None, + {"e", "f", "g", "h", "o"}, + None, + {"j", "k", "l", "m", "n", "i", "p"}, + None, + ) + + u, v = 20, "x" + assert not _cut_PT(u, v, gparams, sparams) + + # Change structure in G2 such that, ONLY consistency is harmed + G2.remove_edge(mapped[20], mapped[2]) + G2.add_edge(mapped[20], mapped[3]) + + # Consistency check fails, while the cutting rules are satisfied! + assert not _cut_PT(u, v, gparams, sparams) + assert not _consistent_PT(u, v, gparams, sparams) + + # Compensate in G1 and make it consistent + G1.remove_edge(20, 2) + G1.add_edge(20, 3) + assert not _cut_PT(u, v, gparams, sparams) + assert _consistent_PT(u, v, gparams, sparams) + + # ONLY fail the cutting check + G2.add_edge(v, mapped[10]) + assert _cut_PT(u, v, gparams, sparams) + assert _consistent_PT(u, v, gparams, sparams) + + def test_feasibility_different_labels(self): + G1 = nx.Graph( + [ + (0, 1), + (1, 2), + (1, 14), + (0, 4), + (1, 5), + (2, 6), + (3, 7), + (3, 6), + (4, 10), + (4, 9), + (6, 10), + (20, 9), + (20, 15), + (20, 12), + (20, 11), + (12, 13), + (11, 13), + (20, 8), + (20, 2), + (20, 5), + (20, 0), + ] + ) + mapped = { + 0: "a", + 1: "b", + 2: "c", + 3: "d", + 4: "e", + 5: "f", + 6: "g", + 7: "h", + 8: "i", + 9: "j", + 10: "k", + 11: "l", + 12: "m", + 13: "n", + 14: "o", + 15: "p", + 20: "x", + } + G2 = nx.relabel_nodes(G1, mapped) + + l1 = {n: "none" for n in G1.nodes()} + l2 = {} + + l1.update( + { + 9: "blue", + 15: "blue", + 12: "blue", + 11: "green", + 2: "green", + 8: "red", + 0: "red", + 5: "yellow", + } + ) + l2.update({mapped[n]: l for n, l in l1.items()}) + + gparams = _GraphParameters( + G1, G2, l1, l2, nx.utils.groups(l1), nx.utils.groups(l2), None + ) + sparams = _StateParameters( + {0: "a", 1: "b", 2: "c", 3: "d"}, + {"a": 0, "b": 1, "c": 2, "d": 3}, + {4, 5, 6, 7, 14}, + None, + {9, 10, 15, 12, 11, 13, 8}, + None, + {"e", "f", "g", "h", "o"}, + None, + {"j", "k", "l", "m", "n", "i", "p"}, + None, + ) + + u, v = 20, "x" + assert not _cut_PT(u, v, gparams, sparams) + + # Change structure in G2 such that, ONLY consistency is harmed + G2.remove_edge(mapped[20], mapped[2]) + G2.add_edge(mapped[20], mapped[3]) + l2.update({mapped[3]: "green"}) + + # Consistency check fails, while the cutting rules are satisfied! + assert not _cut_PT(u, v, gparams, sparams) + assert not _consistent_PT(u, v, gparams, sparams) + + # Compensate in G1 and make it consistent + G1.remove_edge(20, 2) + G1.add_edge(20, 3) + l1.update({3: "green"}) + assert not _cut_PT(u, v, gparams, sparams) + assert _consistent_PT(u, v, gparams, sparams) + + # ONLY fail the cutting check + l1.update({5: "red"}) + assert _cut_PT(u, v, gparams, sparams) + assert _consistent_PT(u, v, gparams, sparams) + + +class TestMultiGraphISOFeasibility: + def test_const_covered_neighbors(self): + G1 = nx.MultiGraph( + [(0, 1), (0, 1), (1, 2), (3, 0), (3, 0), (3, 0), (3, 2), (3, 2)] + ) + G2 = nx.MultiGraph( + [ + ("a", "b"), + ("a", "b"), + ("b", "c"), + ("k", "a"), + ("k", "a"), + ("k", "a"), + ("k", "c"), + ("k", "c"), + ] + ) + gparams = _GraphParameters(G1, G2, None, None, None, None, None) + sparams = _StateParameters( + {0: "a", 1: "b", 2: "c"}, + {"a": 0, "b": 1, "c": 2}, + None, + None, + None, + None, + None, + None, + None, + None, + ) + u, v = 3, "k" + assert _consistent_PT(u, v, gparams, sparams) + + def test_const_no_covered_neighbors(self): + G1 = nx.MultiGraph([(0, 1), (0, 1), (1, 2), (3, 4), (3, 4), (3, 5)]) + G2 = nx.MultiGraph([("a", "b"), ("b", "c"), ("k", "w"), ("k", "w"), ("k", "z")]) + gparams = _GraphParameters(G1, G2, None, None, None, None, None) + sparams = _StateParameters( + {0: "a", 1: "b", 2: "c"}, + {"a": 0, "b": 1, "c": 2}, + None, + None, + None, + None, + None, + None, + None, + None, + ) + u, v = 3, "k" + assert _consistent_PT(u, v, gparams, sparams) + + def test_const_mixed_covered_uncovered_neighbors(self): + G1 = nx.MultiGraph( + [(0, 1), (1, 2), (3, 0), (3, 0), (3, 0), (3, 2), (3, 2), (3, 4), (3, 5)] + ) + G2 = nx.MultiGraph( + [ + ("a", "b"), + ("b", "c"), + ("k", "a"), + ("k", "a"), + ("k", "a"), + ("k", "c"), + ("k", "c"), + ("k", "w"), + ("k", "z"), + ] + ) + gparams = _GraphParameters(G1, G2, None, None, None, None, None) + sparams = _StateParameters( + {0: "a", 1: "b", 2: "c"}, + {"a": 0, "b": 1, "c": 2}, + None, + None, + None, + None, + None, + None, + None, + None, + ) + u, v = 3, "k" + assert _consistent_PT(u, v, gparams, sparams) + + def test_const_fail_cases(self): + G1 = nx.MultiGraph( + [ + (0, 1), + (1, 2), + (10, 0), + (10, 0), + (10, 0), + (10, 3), + (10, 3), + (10, 4), + (10, 5), + (10, 6), + (10, 6), + (4, 1), + (5, 3), + ] + ) + mapped = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e", 5: "f", 6: "g", 10: "k"} + G2 = nx.relabel_nodes(G1, mapped) + + gparams = _GraphParameters(G1, G2, None, None, None, None, None) + sparams = _StateParameters( + {0: "a", 1: "b", 2: "c", 3: "d"}, + {"a": 0, "b": 1, "c": 2, "d": 3}, + None, + None, + None, + None, + None, + None, + None, + None, + ) + u, v = 10, "k" + assert _consistent_PT(u, v, gparams, sparams) + + # Delete one uncovered neighbor of u. Notice how it still passes the test. Two reasons for this: + # 1. If u, v had different degrees from the beginning, they wouldn't be selected as candidates in the first + # place. + # 2. Even if they are selected, consistency is basically 1-look-ahead, meaning that we take into consideration + # the relation of the candidates with their mapped neighbors. The node we deleted is not a covered neighbor. + # Such nodes will be checked by the cut_PT function, which is basically the 2-look-ahead, checking the + # relation of the candidates with T1, T2 (in which belongs the node we just deleted). + G1.remove_node(6) + assert _consistent_PT(u, v, gparams, sparams) + + # Add one more covered neighbor of u in G1 + G1.add_edge(u, 2) + assert not _consistent_PT(u, v, gparams, sparams) + + # Compensate in G2 + G2.add_edge(v, "c") + assert _consistent_PT(u, v, gparams, sparams) + + # Add one more covered neighbor of v in G2 + G2.add_edge(v, "x") + G1.add_node(7) + sparams.mapping.update({7: "x"}) + sparams.reverse_mapping.update({"x": 7}) + assert not _consistent_PT(u, v, gparams, sparams) + + # Compendate in G1 + G1.add_edge(u, 7) + assert _consistent_PT(u, v, gparams, sparams) + + # Delete an edge between u and a covered neighbor + G1.remove_edges_from([(u, 0), (u, 0)]) + assert not _consistent_PT(u, v, gparams, sparams) + + # Compensate in G2 + G2.remove_edges_from([(v, mapped[0]), (v, mapped[0])]) + assert _consistent_PT(u, v, gparams, sparams) + + # Remove an edge between v and a covered neighbor + G2.remove_edge(v, mapped[3]) + assert not _consistent_PT(u, v, gparams, sparams) + + # Compensate in G1 + G1.remove_edge(u, 3) + assert _consistent_PT(u, v, gparams, sparams) + + def test_cut_same_labels(self): + G1 = nx.MultiGraph( + [ + (0, 1), + (1, 2), + (10, 0), + (10, 0), + (10, 0), + (10, 3), + (10, 3), + (10, 4), + (10, 4), + (10, 5), + (10, 5), + (10, 5), + (10, 5), + (10, 6), + (10, 6), + (4, 1), + (5, 3), + ] + ) + mapped = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e", 5: "f", 6: "g", 10: "k"} + G2 = nx.relabel_nodes(G1, mapped) + l1 = {n: "blue" for n in G1.nodes()} + l2 = {n: "blue" for n in G2.nodes()} + + gparams = _GraphParameters( + G1, G2, l1, l2, nx.utils.groups(l1), nx.utils.groups(l2), None + ) + sparams = _StateParameters( + {0: "a", 1: "b", 2: "c", 3: "d"}, + {"a": 0, "b": 1, "c": 2, "d": 3}, + {4, 5}, + None, + {6}, + None, + {"e", "f"}, + None, + {"g"}, + None, + ) + + u, v = 10, "k" + assert not _cut_PT(u, v, gparams, sparams) + + # Remove one of the multiple edges between u and a neighbor + G1.remove_edge(u, 4) + assert _cut_PT(u, v, gparams, sparams) + + # Compensate in G2 + G1.remove_edge(u, 4) + G2.remove_edges_from([(v, mapped[4]), (v, mapped[4])]) + assert not _cut_PT(u, v, gparams, sparams) + + # Change intersection between G2[v] and T2_tilde, so it's not the same as the one between G1[u] and T1_tilde + G2.remove_edge(v, mapped[6]) + assert _cut_PT(u, v, gparams, sparams) + + # Compensate in G1 + G1.remove_edge(u, 6) + assert not _cut_PT(u, v, gparams, sparams) + + # Add more edges between u and neighbor which belongs in T1_tilde + G1.add_edges_from([(u, 5), (u, 5), (u, 5)]) + assert _cut_PT(u, v, gparams, sparams) + + # Compensate in G2 + G2.add_edges_from([(v, mapped[5]), (v, mapped[5]), (v, mapped[5])]) + assert not _cut_PT(u, v, gparams, sparams) + + # Add disconnected nodes, which will form the new Ti_out + G1.add_nodes_from([6, 7, 8]) + G2.add_nodes_from(["g", "y", "z"]) + G1.add_edges_from([(u, 6), (u, 6), (u, 6), (u, 8)]) + G2.add_edges_from([(v, "g"), (v, "g"), (v, "g"), (v, "z")]) + + sparams.T1_tilde.update({6, 7, 8}) + sparams.T2_tilde.update({"g", "y", "z"}) + + l1 = {n: "blue" for n in G1.nodes()} + l2 = {n: "blue" for n in G2.nodes()} + gparams = _GraphParameters( + G1, G2, l1, l2, nx.utils.groups(l1), nx.utils.groups(l2), None + ) + + assert not _cut_PT(u, v, gparams, sparams) + + # Add some new nodes to the mapping + sparams.mapping.update({6: "g", 7: "y"}) + sparams.reverse_mapping.update({"g": 6, "y": 7}) + + # Add more nodes to T1, T2. + G1.add_edges_from([(6, 20), (7, 20), (6, 21)]) + G2.add_edges_from([("g", "i"), ("g", "j"), ("y", "j")]) + + sparams.T1.update({20, 21}) + sparams.T2.update({"i", "j"}) + sparams.T1_tilde.difference_update({6, 7}) + sparams.T2_tilde.difference_update({"g", "y"}) + + assert not _cut_PT(u, v, gparams, sparams) + + # Remove some edges + G2.remove_edge(v, "g") + assert _cut_PT(u, v, gparams, sparams) + + G1.remove_edge(u, 6) + G1.add_edge(u, 8) + G2.add_edge(v, "z") + assert not _cut_PT(u, v, gparams, sparams) + + # Add nodes from the new T1 and T2, as neighbors of u and v respectively + G1.add_edges_from([(u, 20), (u, 20), (u, 20), (u, 21)]) + G2.add_edges_from([(v, "i"), (v, "i"), (v, "i"), (v, "j")]) + l1 = {n: "blue" for n in G1.nodes()} + l2 = {n: "blue" for n in G2.nodes()} + gparams = _GraphParameters( + G1, G2, l1, l2, nx.utils.groups(l1), nx.utils.groups(l2), None + ) + + assert not _cut_PT(u, v, gparams, sparams) + + # Change the edges + G1.remove_edge(u, 20) + G1.add_edge(u, 4) + assert _cut_PT(u, v, gparams, sparams) + + G2.remove_edge(v, "i") + G2.add_edge(v, mapped[4]) + assert not _cut_PT(u, v, gparams, sparams) + + def test_cut_different_labels(self): + G1 = nx.MultiGraph( + [ + (0, 1), + (0, 1), + (1, 2), + (1, 2), + (1, 14), + (0, 4), + (1, 5), + (2, 6), + (3, 7), + (3, 6), + (4, 10), + (4, 9), + (6, 10), + (20, 9), + (20, 9), + (20, 9), + (20, 15), + (20, 15), + (20, 12), + (20, 11), + (20, 11), + (20, 11), + (12, 13), + (11, 13), + (20, 8), + (20, 8), + (20, 3), + (20, 3), + (20, 5), + (20, 5), + (20, 5), + (20, 0), + (20, 0), + (20, 0), + ] + ) + mapped = { + 0: "a", + 1: "b", + 2: "c", + 3: "d", + 4: "e", + 5: "f", + 6: "g", + 7: "h", + 8: "i", + 9: "j", + 10: "k", + 11: "l", + 12: "m", + 13: "n", + 14: "o", + 15: "p", + 20: "x", + } + G2 = nx.relabel_nodes(G1, mapped) + + l1 = {n: "none" for n in G1.nodes()} + l2 = {} + + l1.update( + { + 9: "blue", + 15: "blue", + 12: "blue", + 11: "green", + 3: "green", + 8: "red", + 0: "red", + 5: "yellow", + } + ) + l2.update({mapped[n]: l for n, l in l1.items()}) + + gparams = _GraphParameters( + G1, G2, l1, l2, nx.utils.groups(l1), nx.utils.groups(l2), None + ) + sparams = _StateParameters( + {0: "a", 1: "b", 2: "c", 3: "d"}, + {"a": 0, "b": 1, "c": 2, "d": 3}, + {4, 5, 6, 7, 14}, + None, + {9, 10, 15, 12, 11, 13, 8}, + None, + {"e", "f", "g", "h", "o"}, + None, + {"j", "k", "l", "m", "n", "i", "p"}, + None, + ) + + u, v = 20, "x" + assert not _cut_PT(u, v, gparams, sparams) + + # Change the orientation of the labels on neighbors of u compared to neighbors of v. Leave the structure intact + l1.update({9: "red"}) + assert _cut_PT(u, v, gparams, sparams) + + # compensate in G2 + l2.update({mapped[9]: "red"}) + assert not _cut_PT(u, v, gparams, sparams) + + # Change the intersection of G1[u] and T1 + G1.add_edge(u, 4) + assert _cut_PT(u, v, gparams, sparams) + + # Same for G2[v] and T2 + G2.add_edge(v, mapped[4]) + assert not _cut_PT(u, v, gparams, sparams) + + # Delete one from the multiple edges + G2.remove_edge(v, mapped[8]) + assert _cut_PT(u, v, gparams, sparams) + + # Same for G1[u] and T1_tilde + G1.remove_edge(u, 8) + assert not _cut_PT(u, v, gparams, sparams) + + # Place 8 and mapped[8] in T1 and T2 respectively, by connecting it to covered nodes + G1.add_edges_from([(8, 3), (8, 3), (8, u)]) + G2.add_edges_from([(mapped[8], mapped[3]), (mapped[8], mapped[3])]) + sparams.T1.add(8) + sparams.T2.add(mapped[8]) + sparams.T1_tilde.remove(8) + sparams.T2_tilde.remove(mapped[8]) + + assert _cut_PT(u, v, gparams, sparams) + + # Fix uneven edges + G1.remove_edge(8, u) + assert not _cut_PT(u, v, gparams, sparams) + + # Remove neighbor of u from T1 + G1.remove_node(5) + l1.pop(5) + sparams.T1.remove(5) + assert _cut_PT(u, v, gparams, sparams) + + # Same in G2 + G2.remove_node(mapped[5]) + l2.pop(mapped[5]) + sparams.T2.remove(mapped[5]) + assert not _cut_PT(u, v, gparams, sparams) + + def test_feasibility_same_labels(self): + G1 = nx.MultiGraph( + [ + (0, 1), + (0, 1), + (1, 2), + (1, 2), + (1, 14), + (0, 4), + (1, 5), + (2, 6), + (3, 7), + (3, 6), + (4, 10), + (4, 9), + (6, 10), + (20, 9), + (20, 9), + (20, 9), + (20, 15), + (20, 15), + (20, 12), + (20, 11), + (20, 11), + (20, 11), + (12, 13), + (11, 13), + (20, 8), + (20, 8), + (20, 3), + (20, 3), + (20, 5), + (20, 5), + (20, 5), + (20, 0), + (20, 0), + (20, 0), + ] + ) + mapped = { + 0: "a", + 1: "b", + 2: "c", + 3: "d", + 4: "e", + 5: "f", + 6: "g", + 7: "h", + 8: "i", + 9: "j", + 10: "k", + 11: "l", + 12: "m", + 13: "n", + 14: "o", + 15: "p", + 20: "x", + } + G2 = nx.relabel_nodes(G1, mapped) + l1 = {n: "blue" for n in G1.nodes()} + l2 = {mapped[n]: "blue" for n in G1.nodes()} + + gparams = _GraphParameters( + G1, G2, l1, l2, nx.utils.groups(l1), nx.utils.groups(l2), None + ) + sparams = _StateParameters( + {0: "a", 1: "b", 2: "c", 3: "d"}, + {"a": 0, "b": 1, "c": 2, "d": 3}, + {4, 5, 6, 7, 14}, + None, + {9, 10, 15, 12, 11, 13, 8}, + None, + {"e", "f", "g", "h", "o"}, + None, + {"j", "k", "l", "m", "n", "i", "p"}, + None, + ) + + u, v = 20, "x" + assert not _cut_PT(u, v, gparams, sparams) + + # Change structure in G2 such that, ONLY consistency is harmed + G2.remove_edges_from([(mapped[20], mapped[3]), (mapped[20], mapped[3])]) + G2.add_edges_from([(mapped[20], mapped[2]), (mapped[20], mapped[2])]) + + # Consistency check fails, while the cutting rules are satisfied! + assert not _cut_PT(u, v, gparams, sparams) + assert not _consistent_PT(u, v, gparams, sparams) + + # Compensate in G1 and make it consistent + G1.remove_edges_from([(20, 3), (20, 3)]) + G1.add_edges_from([(20, 2), (20, 2)]) + assert not _cut_PT(u, v, gparams, sparams) + assert _consistent_PT(u, v, gparams, sparams) + + # ONLY fail the cutting check + G2.add_edges_from([(v, mapped[10])] * 5) + assert _cut_PT(u, v, gparams, sparams) + assert _consistent_PT(u, v, gparams, sparams) + + # Pass all tests + G1.add_edges_from([(u, 10)] * 5) + assert not _cut_PT(u, v, gparams, sparams) + assert _consistent_PT(u, v, gparams, sparams) + + def test_feasibility_different_labels(self): + G1 = nx.MultiGraph( + [ + (0, 1), + (0, 1), + (1, 2), + (1, 2), + (1, 14), + (0, 4), + (1, 5), + (2, 6), + (3, 7), + (3, 6), + (4, 10), + (4, 9), + (6, 10), + (20, 9), + (20, 9), + (20, 9), + (20, 15), + (20, 15), + (20, 12), + (20, 11), + (20, 11), + (20, 11), + (12, 13), + (11, 13), + (20, 8), + (20, 8), + (20, 2), + (20, 2), + (20, 5), + (20, 5), + (20, 5), + (20, 0), + (20, 0), + (20, 0), + ] + ) + mapped = { + 0: "a", + 1: "b", + 2: "c", + 3: "d", + 4: "e", + 5: "f", + 6: "g", + 7: "h", + 8: "i", + 9: "j", + 10: "k", + 11: "l", + 12: "m", + 13: "n", + 14: "o", + 15: "p", + 20: "x", + } + G2 = nx.relabel_nodes(G1, mapped) + l1 = {n: "none" for n in G1.nodes()} + l2 = {} + + l1.update( + { + 9: "blue", + 15: "blue", + 12: "blue", + 11: "green", + 2: "green", + 8: "red", + 0: "red", + 5: "yellow", + } + ) + l2.update({mapped[n]: l for n, l in l1.items()}) + + gparams = _GraphParameters( + G1, G2, l1, l2, nx.utils.groups(l1), nx.utils.groups(l2), None + ) + sparams = _StateParameters( + {0: "a", 1: "b", 2: "c", 3: "d"}, + {"a": 0, "b": 1, "c": 2, "d": 3}, + {4, 5, 6, 7, 14}, + None, + {9, 10, 15, 12, 11, 13, 8}, + None, + {"e", "f", "g", "h", "o"}, + None, + {"j", "k", "l", "m", "n", "i", "p"}, + None, + ) + + u, v = 20, "x" + assert not _cut_PT(u, v, gparams, sparams) + + # Change structure in G2 such that, ONLY consistency is harmed + G2.remove_edges_from([(mapped[20], mapped[2]), (mapped[20], mapped[2])]) + G2.add_edges_from([(mapped[20], mapped[3]), (mapped[20], mapped[3])]) + l2.update({mapped[3]: "green"}) + + # Consistency check fails, while the cutting rules are satisfied! + assert not _cut_PT(u, v, gparams, sparams) + assert not _consistent_PT(u, v, gparams, sparams) + + # Compensate in G1 and make it consistent + G1.remove_edges_from([(20, 2), (20, 2)]) + G1.add_edges_from([(20, 3), (20, 3)]) + l1.update({3: "green"}) + assert not _cut_PT(u, v, gparams, sparams) + assert _consistent_PT(u, v, gparams, sparams) + + # ONLY fail the cutting check + l1.update({5: "red"}) + assert _cut_PT(u, v, gparams, sparams) + assert _consistent_PT(u, v, gparams, sparams) + + +class TestDiGraphISOFeasibility: + def test_const_covered_neighbors(self): + G1 = nx.DiGraph([(0, 1), (1, 2), (0, 3), (2, 3)]) + G2 = nx.DiGraph([("a", "b"), ("b", "c"), ("a", "k"), ("c", "k")]) + gparams = _GraphParameters(G1, G2, None, None, None, None, None) + sparams = _StateParameters( + {0: "a", 1: "b", 2: "c"}, + {"a": 0, "b": 1, "c": 2}, + None, + None, + None, + None, + None, + None, + None, + None, + ) + u, v = 3, "k" + assert _consistent_PT(u, v, gparams, sparams) + + def test_const_no_covered_neighbors(self): + G1 = nx.DiGraph([(0, 1), (1, 2), (3, 4), (3, 5)]) + G2 = nx.DiGraph([("a", "b"), ("b", "c"), ("k", "w"), ("k", "z")]) + gparams = _GraphParameters(G1, G2, None, None, None, None, None) + sparams = _StateParameters( + {0: "a", 1: "b", 2: "c"}, + {"a": 0, "b": 1, "c": 2}, + None, + None, + None, + None, + None, + None, + None, + None, + ) + u, v = 3, "k" + assert _consistent_PT(u, v, gparams, sparams) + + def test_const_mixed_covered_uncovered_neighbors(self): + G1 = nx.DiGraph([(0, 1), (1, 2), (3, 0), (3, 2), (3, 4), (3, 5)]) + G2 = nx.DiGraph( + [("a", "b"), ("b", "c"), ("k", "a"), ("k", "c"), ("k", "w"), ("k", "z")] + ) + gparams = _GraphParameters(G1, G2, None, None, None, None, None) + sparams = _StateParameters( + {0: "a", 1: "b", 2: "c"}, + {"a": 0, "b": 1, "c": 2}, + None, + None, + None, + None, + None, + None, + None, + None, + ) + u, v = 3, "k" + assert _consistent_PT(u, v, gparams, sparams) + + def test_const_fail_cases(self): + G1 = nx.DiGraph( + [ + (0, 1), + (2, 1), + (10, 0), + (10, 3), + (10, 4), + (5, 10), + (10, 6), + (1, 4), + (5, 3), + ] + ) + G2 = nx.DiGraph( + [ + ("a", "b"), + ("c", "b"), + ("k", "a"), + ("k", "d"), + ("k", "e"), + ("f", "k"), + ("k", "g"), + ("b", "e"), + ("f", "d"), + ] + ) + gparams = _GraphParameters(G1, G2, None, None, None, None, None) + sparams = _StateParameters( + {0: "a", 1: "b", 2: "c", 3: "d"}, + {"a": 0, "b": 1, "c": 2, "d": 3}, + None, + None, + None, + None, + None, + None, + None, + None, + ) + u, v = 10, "k" + assert _consistent_PT(u, v, gparams, sparams) + + # Delete one uncovered neighbor of u. Notice how it still passes the + # test. Two reasons for this: + # 1. If u, v had different degrees from the beginning, they wouldn't + # be selected as candidates in the first place. + # 2. Even if they are selected, consistency is basically + # 1-look-ahead, meaning that we take into consideration the + # relation of the candidates with their mapped neighbors. + # The node we deleted is not a covered neighbor. + # Such nodes will be checked by the cut_PT function, which is + # basically the 2-look-ahead, checking the relation of the + # candidates with T1, T2 (in which belongs the node we just deleted). + G1.remove_node(6) + assert _consistent_PT(u, v, gparams, sparams) + + # Add one more covered neighbor of u in G1 + G1.add_edge(u, 2) + assert not _consistent_PT(u, v, gparams, sparams) + + # Compensate in G2 + G2.add_edge(v, "c") + assert _consistent_PT(u, v, gparams, sparams) + + # Add one more covered neighbor of v in G2 + G2.add_edge(v, "x") + G1.add_node(7) + sparams.mapping.update({7: "x"}) + sparams.reverse_mapping.update({"x": 7}) + assert not _consistent_PT(u, v, gparams, sparams) + + # Compensate in G1 + G1.add_edge(u, 7) + assert _consistent_PT(u, v, gparams, sparams) + + def test_cut_inconsistent_labels(self): + G1 = nx.DiGraph( + [ + (0, 1), + (2, 1), + (10, 0), + (10, 3), + (10, 4), + (5, 10), + (10, 6), + (1, 4), + (5, 3), + ] + ) + G2 = nx.DiGraph( + [ + ("a", "b"), + ("c", "b"), + ("k", "a"), + ("k", "d"), + ("k", "e"), + ("f", "k"), + ("k", "g"), + ("b", "e"), + ("f", "d"), + ] + ) + + l1 = {n: "blue" for n in G1.nodes()} + l2 = {n: "blue" for n in G2.nodes()} + l1.update({5: "green"}) # Change the label of one neighbor of u + + gparams = _GraphParameters( + G1, G2, l1, l2, nx.utils.groups(l1), nx.utils.groups(l2), None + ) + sparams = _StateParameters( + {0: "a", 1: "b", 2: "c", 3: "d"}, + {"a": 0, "b": 1, "c": 2, "d": 3}, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + u, v = 10, "k" + assert _cut_PT(u, v, gparams, sparams) + + def test_cut_consistent_labels(self): + G1 = nx.DiGraph( + [ + (0, 1), + (2, 1), + (10, 0), + (10, 3), + (10, 4), + (5, 10), + (10, 6), + (1, 4), + (5, 3), + ] + ) + G2 = nx.DiGraph( + [ + ("a", "b"), + ("c", "b"), + ("k", "a"), + ("k", "d"), + ("k", "e"), + ("f", "k"), + ("k", "g"), + ("b", "e"), + ("f", "d"), + ] + ) + + l1 = {n: "blue" for n in G1.nodes()} + l2 = {n: "blue" for n in G2.nodes()} + + gparams = _GraphParameters( + G1, G2, l1, l2, nx.utils.groups(l1), nx.utils.groups(l2), None + ) + sparams = _StateParameters( + {0: "a", 1: "b", 2: "c", 3: "d"}, + {"a": 0, "b": 1, "c": 2, "d": 3}, + {4}, + {5, 10}, + {6}, + None, + {"e"}, + {"f", "k"}, + {"g"}, + None, + ) + + u, v = 10, "k" + assert not _cut_PT(u, v, gparams, sparams) + + def test_cut_same_labels(self): + G1 = nx.DiGraph( + [ + (0, 1), + (2, 1), + (10, 0), + (10, 3), + (10, 4), + (5, 10), + (10, 6), + (1, 4), + (5, 3), + ] + ) + mapped = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e", 5: "f", 6: "g", 10: "k"} + G2 = nx.relabel_nodes(G1, mapped) + l1 = {n: "blue" for n in G1.nodes()} + l2 = {n: "blue" for n in G2.nodes()} + + gparams = _GraphParameters( + G1, G2, l1, l2, nx.utils.groups(l1), nx.utils.groups(l2), None + ) + sparams = _StateParameters( + {0: "a", 1: "b", 2: "c", 3: "d"}, + {"a": 0, "b": 1, "c": 2, "d": 3}, + {4}, + {5, 10}, + {6}, + None, + {"e"}, + {"f", "k"}, + {"g"}, + None, + ) + + u, v = 10, "k" + assert not _cut_PT(u, v, gparams, sparams) + + # Change intersection between G1[u] and T1_out, so it's not the same as the one between G2[v] and T2_out + G1.remove_edge(u, 4) + assert _cut_PT(u, v, gparams, sparams) + + # Compensate in G2 + G2.remove_edge(v, mapped[4]) + assert not _cut_PT(u, v, gparams, sparams) + + # Change intersection between G1[u] and T1_in, so it's not the same as the one between G2[v] and T2_in + G1.remove_edge(5, u) + assert _cut_PT(u, v, gparams, sparams) + + # Compensate in G2 + G2.remove_edge(mapped[5], v) + assert not _cut_PT(u, v, gparams, sparams) + + # Change intersection between G2[v] and T2_tilde, so it's not the same as the one between G1[u] and T1_tilde + G2.remove_edge(v, mapped[6]) + assert _cut_PT(u, v, gparams, sparams) + + # Compensate in G1 + G1.remove_edge(u, 6) + assert not _cut_PT(u, v, gparams, sparams) + + # Add disconnected nodes, which will form the new Ti_tilde + G1.add_nodes_from([6, 7, 8]) + G2.add_nodes_from(["g", "y", "z"]) + sparams.T1_tilde.update({6, 7, 8}) + sparams.T2_tilde.update({"g", "y", "z"}) + + l1 = {n: "blue" for n in G1.nodes()} + l2 = {n: "blue" for n in G2.nodes()} + gparams = _GraphParameters( + G1, G2, l1, l2, nx.utils.groups(l1), nx.utils.groups(l2), None + ) + + assert not _cut_PT(u, v, gparams, sparams) + + def test_cut_different_labels(self): + G1 = nx.DiGraph( + [ + (0, 1), + (1, 2), + (14, 1), + (0, 4), + (1, 5), + (2, 6), + (3, 7), + (3, 6), + (10, 4), + (4, 9), + (6, 10), + (20, 9), + (20, 15), + (20, 12), + (20, 11), + (12, 13), + (11, 13), + (20, 8), + (20, 3), + (20, 5), + (0, 20), + ] + ) + mapped = { + 0: "a", + 1: "b", + 2: "c", + 3: "d", + 4: "e", + 5: "f", + 6: "g", + 7: "h", + 8: "i", + 9: "j", + 10: "k", + 11: "l", + 12: "m", + 13: "n", + 14: "o", + 15: "p", + 20: "x", + } + G2 = nx.relabel_nodes(G1, mapped) + + l1 = {n: "none" for n in G1.nodes()} + l2 = {} + + l1.update( + { + 9: "blue", + 15: "blue", + 12: "blue", + 11: "green", + 3: "green", + 8: "red", + 0: "red", + 5: "yellow", + } + ) + l2.update({mapped[n]: l for n, l in l1.items()}) + + gparams = _GraphParameters( + G1, G2, l1, l2, nx.utils.groups(l1), nx.utils.groups(l2), None + ) + sparams = _StateParameters( + {0: "a", 1: "b", 2: "c", 3: "d"}, + {"a": 0, "b": 1, "c": 2, "d": 3}, + {4, 5, 6, 7, 20}, + {14, 20}, + {9, 10, 15, 12, 11, 13, 8}, + None, + {"e", "f", "g", "x"}, + {"o", "x"}, + {"j", "k", "l", "m", "n", "i", "p"}, + None, + ) + + u, v = 20, "x" + assert not _cut_PT(u, v, gparams, sparams) + + # Change the orientation of the labels on neighbors of u compared to neighbors of v. Leave the structure intact + l1.update({9: "red"}) + assert _cut_PT(u, v, gparams, sparams) + + # compensate in G2 + l2.update({mapped[9]: "red"}) + assert not _cut_PT(u, v, gparams, sparams) + + # Change the intersection of G1[u] and T1_out + G1.add_edge(u, 4) + assert _cut_PT(u, v, gparams, sparams) + + # Same for G2[v] and T2_out + G2.add_edge(v, mapped[4]) + assert not _cut_PT(u, v, gparams, sparams) + + # Change the intersection of G1[u] and T1_in + G1.add_edge(u, 14) + assert _cut_PT(u, v, gparams, sparams) + + # Same for G2[v] and T2_in + G2.add_edge(v, mapped[14]) + assert not _cut_PT(u, v, gparams, sparams) + + # Change the intersection of G2[v] and T2_tilde + G2.remove_edge(v, mapped[8]) + assert _cut_PT(u, v, gparams, sparams) + + # Same for G1[u] and T1_tilde + G1.remove_edge(u, 8) + assert not _cut_PT(u, v, gparams, sparams) + + # Place 8 and mapped[8] in T1 and T2 respectively, by connecting it to covered nodes + G1.add_edge(8, 3) + G2.add_edge(mapped[8], mapped[3]) + sparams.T1.add(8) + sparams.T2.add(mapped[8]) + sparams.T1_tilde.remove(8) + sparams.T2_tilde.remove(mapped[8]) + + assert not _cut_PT(u, v, gparams, sparams) + + # Remove neighbor of u from T1 + G1.remove_node(5) + l1.pop(5) + sparams.T1.remove(5) + assert _cut_PT(u, v, gparams, sparams) + + # Same in G2 + G2.remove_node(mapped[5]) + l2.pop(mapped[5]) + sparams.T2.remove(mapped[5]) + assert not _cut_PT(u, v, gparams, sparams) + + def test_predecessor_T1_in_fail(self): + G1 = nx.DiGraph( + [(0, 1), (0, 3), (4, 0), (1, 5), (5, 2), (3, 6), (4, 6), (6, 5)] + ) + mapped = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e", 5: "f", 6: "g"} + G2 = nx.relabel_nodes(G1, mapped) + l1 = {n: "blue" for n in G1.nodes()} + l2 = {n: "blue" for n in G2.nodes()} + + gparams = _GraphParameters( + G1, G2, l1, l2, nx.utils.groups(l1), nx.utils.groups(l2), None + ) + sparams = _StateParameters( + {0: "a", 1: "b", 2: "c"}, + {"a": 0, "b": 1, "c": 2}, + {3, 5}, + {4, 5}, + {6}, + None, + {"d", "f"}, + {"f"}, # mapped[4] is missing from T2_in + {"g"}, + None, + ) + + u, v = 6, "g" + assert _cut_PT(u, v, gparams, sparams) + + sparams.T2_in.add("e") + assert not _cut_PT(u, v, gparams, sparams) + + +class TestGraphTinoutUpdating: + edges = [ + (1, 3), + (2, 3), + (3, 4), + (4, 9), + (4, 5), + (3, 9), + (5, 8), + (5, 7), + (8, 7), + (6, 7), + ] + mapped = { + 0: "x", + 1: "a", + 2: "b", + 3: "c", + 4: "d", + 5: "e", + 6: "f", + 7: "g", + 8: "h", + 9: "i", + } + G1 = nx.Graph() + G1.add_edges_from(edges) + G1.add_node(0) + G2 = nx.relabel_nodes(G1, mapping=mapped) + + def test_updating(self): + G2_degree = dict(self.G2.degree) + gparams, sparams = _initialize_parameters(self.G1, self.G2, G2_degree) + m, m_rev, T1, _, T1_tilde, _, T2, _, T2_tilde, _ = sparams + + # Add node to the mapping + m[4] = self.mapped[4] + m_rev[self.mapped[4]] = 4 + _update_Tinout(4, self.mapped[4], gparams, sparams) + + assert T1 == {3, 5, 9} + assert T2 == {"c", "i", "e"} + assert T1_tilde == {0, 1, 2, 6, 7, 8} + assert T2_tilde == {"x", "a", "b", "f", "g", "h"} + + # Add node to the mapping + m[5] = self.mapped[5] + m_rev.update({self.mapped[5]: 5}) + _update_Tinout(5, self.mapped[5], gparams, sparams) + + assert T1 == {3, 9, 8, 7} + assert T2 == {"c", "i", "h", "g"} + assert T1_tilde == {0, 1, 2, 6} + assert T2_tilde == {"x", "a", "b", "f"} + + # Add node to the mapping + m[6] = self.mapped[6] + m_rev.update({self.mapped[6]: 6}) + _update_Tinout(6, self.mapped[6], gparams, sparams) + + assert T1 == {3, 9, 8, 7} + assert T2 == {"c", "i", "h", "g"} + assert T1_tilde == {0, 1, 2} + assert T2_tilde == {"x", "a", "b"} + + # Add node to the mapping + m[3] = self.mapped[3] + m_rev.update({self.mapped[3]: 3}) + _update_Tinout(3, self.mapped[3], gparams, sparams) + + assert T1 == {1, 2, 9, 8, 7} + assert T2 == {"a", "b", "i", "h", "g"} + assert T1_tilde == {0} + assert T2_tilde == {"x"} + + # Add node to the mapping + m[0] = self.mapped[0] + m_rev.update({self.mapped[0]: 0}) + _update_Tinout(0, self.mapped[0], gparams, sparams) + + assert T1 == {1, 2, 9, 8, 7} + assert T2 == {"a", "b", "i", "h", "g"} + assert T1_tilde == set() + assert T2_tilde == set() + + def test_restoring(self): + m = {0: "x", 3: "c", 4: "d", 5: "e", 6: "f"} + m_rev = {"x": 0, "c": 3, "d": 4, "e": 5, "f": 6} + + T1 = {1, 2, 7, 9, 8} + T2 = {"a", "b", "g", "i", "h"} + T1_tilde = set() + T2_tilde = set() + + gparams = _GraphParameters(self.G1, self.G2, {}, {}, {}, {}, {}) + sparams = _StateParameters( + m, m_rev, T1, None, T1_tilde, None, T2, None, T2_tilde, None + ) + + # Remove a node from the mapping + m.pop(0) + m_rev.pop("x") + _restore_Tinout(0, self.mapped[0], gparams, sparams) + + assert T1 == {1, 2, 7, 9, 8} + assert T2 == {"a", "b", "g", "i", "h"} + assert T1_tilde == {0} + assert T2_tilde == {"x"} + + # Remove a node from the mapping + m.pop(6) + m_rev.pop("f") + _restore_Tinout(6, self.mapped[6], gparams, sparams) + + assert T1 == {1, 2, 7, 9, 8} + assert T2 == {"a", "b", "g", "i", "h"} + assert T1_tilde == {0, 6} + assert T2_tilde == {"x", "f"} + + # Remove a node from the mapping + m.pop(3) + m_rev.pop("c") + _restore_Tinout(3, self.mapped[3], gparams, sparams) + + assert T1 == {7, 9, 8, 3} + assert T2 == {"g", "i", "h", "c"} + assert T1_tilde == {0, 6, 1, 2} + assert T2_tilde == {"x", "f", "a", "b"} + + # Remove a node from the mapping + m.pop(5) + m_rev.pop("e") + _restore_Tinout(5, self.mapped[5], gparams, sparams) + + assert T1 == {9, 3, 5} + assert T2 == {"i", "c", "e"} + assert T1_tilde == {0, 6, 1, 2, 7, 8} + assert T2_tilde == {"x", "f", "a", "b", "g", "h"} + + # Remove a node from the mapping + m.pop(4) + m_rev.pop("d") + _restore_Tinout(4, self.mapped[4], gparams, sparams) + + assert T1 == set() + assert T2 == set() + assert T1_tilde == set(self.G1.nodes()) + assert T2_tilde == set(self.G2.nodes()) + + +class TestDiGraphTinoutUpdating: + edges = [ + (1, 3), + (3, 2), + (3, 4), + (4, 9), + (4, 5), + (3, 9), + (5, 8), + (5, 7), + (8, 7), + (7, 6), + ] + mapped = { + 0: "x", + 1: "a", + 2: "b", + 3: "c", + 4: "d", + 5: "e", + 6: "f", + 7: "g", + 8: "h", + 9: "i", + } + G1 = nx.DiGraph(edges) + G1.add_node(0) + G2 = nx.relabel_nodes(G1, mapping=mapped) + + def test_updating(self): + G2_degree = { + n: (in_degree, out_degree) + for (n, in_degree), (_, out_degree) in zip( + self.G2.in_degree, self.G2.out_degree + ) + } + gparams, sparams = _initialize_parameters(self.G1, self.G2, G2_degree) + m, m_rev, T1_out, T1_in, T1_tilde, _, T2_out, T2_in, T2_tilde, _ = sparams + + # Add node to the mapping + m[4] = self.mapped[4] + m_rev[self.mapped[4]] = 4 + _update_Tinout(4, self.mapped[4], gparams, sparams) + + assert T1_out == {5, 9} + assert T1_in == {3} + assert T2_out == {"i", "e"} + assert T2_in == {"c"} + assert T1_tilde == {0, 1, 2, 6, 7, 8} + assert T2_tilde == {"x", "a", "b", "f", "g", "h"} + + # Add node to the mapping + m[5] = self.mapped[5] + m_rev[self.mapped[5]] = 5 + _update_Tinout(5, self.mapped[5], gparams, sparams) + + assert T1_out == {9, 8, 7} + assert T1_in == {3} + assert T2_out == {"i", "g", "h"} + assert T2_in == {"c"} + assert T1_tilde == {0, 1, 2, 6} + assert T2_tilde == {"x", "a", "b", "f"} + + # Add node to the mapping + m[6] = self.mapped[6] + m_rev[self.mapped[6]] = 6 + _update_Tinout(6, self.mapped[6], gparams, sparams) + + assert T1_out == {9, 8, 7} + assert T1_in == {3, 7} + assert T2_out == {"i", "g", "h"} + assert T2_in == {"c", "g"} + assert T1_tilde == {0, 1, 2} + assert T2_tilde == {"x", "a", "b"} + + # Add node to the mapping + m[3] = self.mapped[3] + m_rev[self.mapped[3]] = 3 + _update_Tinout(3, self.mapped[3], gparams, sparams) + + assert T1_out == {9, 8, 7, 2} + assert T1_in == {7, 1} + assert T2_out == {"i", "g", "h", "b"} + assert T2_in == {"g", "a"} + assert T1_tilde == {0} + assert T2_tilde == {"x"} + + # Add node to the mapping + m[0] = self.mapped[0] + m_rev[self.mapped[0]] = 0 + _update_Tinout(0, self.mapped[0], gparams, sparams) + + assert T1_out == {9, 8, 7, 2} + assert T1_in == {7, 1} + assert T2_out == {"i", "g", "h", "b"} + assert T2_in == {"g", "a"} + assert T1_tilde == set() + assert T2_tilde == set() + + def test_restoring(self): + m = {0: "x", 3: "c", 4: "d", 5: "e", 6: "f"} + m_rev = {"x": 0, "c": 3, "d": 4, "e": 5, "f": 6} + + T1_out = {2, 7, 9, 8} + T1_in = {1, 7} + T2_out = {"b", "g", "i", "h"} + T2_in = {"a", "g"} + T1_tilde = set() + T2_tilde = set() + + gparams = _GraphParameters(self.G1, self.G2, {}, {}, {}, {}, {}) + sparams = _StateParameters( + m, m_rev, T1_out, T1_in, T1_tilde, None, T2_out, T2_in, T2_tilde, None + ) + + # Remove a node from the mapping + m.pop(0) + m_rev.pop("x") + _restore_Tinout_Di(0, self.mapped[0], gparams, sparams) + + assert T1_out == {2, 7, 9, 8} + assert T1_in == {1, 7} + assert T2_out == {"b", "g", "i", "h"} + assert T2_in == {"a", "g"} + assert T1_tilde == {0} + assert T2_tilde == {"x"} + + # Remove a node from the mapping + m.pop(6) + m_rev.pop("f") + _restore_Tinout_Di(6, self.mapped[6], gparams, sparams) + + assert T1_out == {2, 9, 8, 7} + assert T1_in == {1} + assert T2_out == {"b", "i", "h", "g"} + assert T2_in == {"a"} + assert T1_tilde == {0, 6} + assert T2_tilde == {"x", "f"} + + # Remove a node from the mapping + m.pop(3) + m_rev.pop("c") + _restore_Tinout_Di(3, self.mapped[3], gparams, sparams) + + assert T1_out == {9, 8, 7} + assert T1_in == {3} + assert T2_out == {"i", "h", "g"} + assert T2_in == {"c"} + assert T1_tilde == {0, 6, 1, 2} + assert T2_tilde == {"x", "f", "a", "b"} + + # Remove a node from the mapping + m.pop(5) + m_rev.pop("e") + _restore_Tinout_Di(5, self.mapped[5], gparams, sparams) + + assert T1_out == {9, 5} + assert T1_in == {3} + assert T2_out == {"i", "e"} + assert T2_in == {"c"} + assert T1_tilde == {0, 6, 1, 2, 8, 7} + assert T2_tilde == {"x", "f", "a", "b", "h", "g"} + + # Remove a node from the mapping + m.pop(4) + m_rev.pop("d") + _restore_Tinout_Di(4, self.mapped[4], gparams, sparams) + + assert T1_out == set() + assert T1_in == set() + assert T2_out == set() + assert T2_in == set() + assert T1_tilde == set(self.G1.nodes()) + assert T2_tilde == set(self.G2.nodes()) diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/test_vf2userfunc.py b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/test_vf2userfunc.py new file mode 100644 index 0000000000000000000000000000000000000000..b44f45888519df1fb9d897ae7572fdbce6a8cb9a --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tests/test_vf2userfunc.py @@ -0,0 +1,200 @@ +""" +Tests for VF2 isomorphism algorithm for weighted graphs. +""" + +import math +from operator import eq + +import networkx as nx +import networkx.algorithms.isomorphism as iso + + +def test_simple(): + # 16 simple tests + w = "weight" + edges = [(0, 0, 1), (0, 0, 1.5), (0, 1, 2), (1, 0, 3)] + for g1 in [nx.Graph(), nx.DiGraph(), nx.MultiGraph(), nx.MultiDiGraph()]: + g1.add_weighted_edges_from(edges) + g2 = g1.subgraph(g1.nodes()) + if g1.is_multigraph(): + em = iso.numerical_multiedge_match("weight", 1) + else: + em = iso.numerical_edge_match("weight", 1) + assert nx.is_isomorphic(g1, g2, edge_match=em) + + for mod1, mod2 in [(False, True), (True, False), (True, True)]: + # mod1 tests a regular edge + # mod2 tests a selfloop + if g2.is_multigraph(): + if mod1: + data1 = {0: {"weight": 10}} + if mod2: + data2 = {0: {"weight": 1}, 1: {"weight": 2.5}} + else: + if mod1: + data1 = {"weight": 10} + if mod2: + data2 = {"weight": 2.5} + + g2 = g1.subgraph(g1.nodes()).copy() + if mod1: + if not g1.is_directed(): + g2._adj[1][0] = data1 + g2._adj[0][1] = data1 + else: + g2._succ[1][0] = data1 + g2._pred[0][1] = data1 + if mod2: + if not g1.is_directed(): + g2._adj[0][0] = data2 + else: + g2._succ[0][0] = data2 + g2._pred[0][0] = data2 + + assert not nx.is_isomorphic(g1, g2, edge_match=em) + + +def test_weightkey(): + g1 = nx.DiGraph() + g2 = nx.DiGraph() + + g1.add_edge("A", "B", weight=1) + g2.add_edge("C", "D", weight=0) + + assert nx.is_isomorphic(g1, g2) + em = iso.numerical_edge_match("nonexistent attribute", 1) + assert nx.is_isomorphic(g1, g2, edge_match=em) + em = iso.numerical_edge_match("weight", 1) + assert not nx.is_isomorphic(g1, g2, edge_match=em) + + g2 = nx.DiGraph() + g2.add_edge("C", "D") + assert nx.is_isomorphic(g1, g2, edge_match=em) + + +class TestNodeMatch_Graph: + def setup_method(self): + self.g1 = nx.Graph() + self.g2 = nx.Graph() + self.build() + + def build(self): + self.nm = iso.categorical_node_match("color", "") + self.em = iso.numerical_edge_match("weight", 1) + + self.g1.add_node("A", color="red") + self.g2.add_node("C", color="blue") + + self.g1.add_edge("A", "B", weight=1) + self.g2.add_edge("C", "D", weight=1) + + def test_noweight_nocolor(self): + assert nx.is_isomorphic(self.g1, self.g2) + + def test_color1(self): + assert not nx.is_isomorphic(self.g1, self.g2, node_match=self.nm) + + def test_color2(self): + self.g1.nodes["A"]["color"] = "blue" + assert nx.is_isomorphic(self.g1, self.g2, node_match=self.nm) + + def test_weight1(self): + assert nx.is_isomorphic(self.g1, self.g2, edge_match=self.em) + + def test_weight2(self): + self.g1.add_edge("A", "B", weight=2) + assert not nx.is_isomorphic(self.g1, self.g2, edge_match=self.em) + + def test_colorsandweights1(self): + iso = nx.is_isomorphic(self.g1, self.g2, node_match=self.nm, edge_match=self.em) + assert not iso + + def test_colorsandweights2(self): + self.g1.nodes["A"]["color"] = "blue" + iso = nx.is_isomorphic(self.g1, self.g2, node_match=self.nm, edge_match=self.em) + assert iso + + def test_colorsandweights3(self): + # make the weights disagree + self.g1.add_edge("A", "B", weight=2) + assert not nx.is_isomorphic( + self.g1, self.g2, node_match=self.nm, edge_match=self.em + ) + + +class TestEdgeMatch_MultiGraph: + def setup_method(self): + self.g1 = nx.MultiGraph() + self.g2 = nx.MultiGraph() + self.GM = iso.MultiGraphMatcher + self.build() + + def build(self): + g1 = self.g1 + g2 = self.g2 + + # We will assume integer weights only. + g1.add_edge("A", "B", color="green", weight=0, size=0.5) + g1.add_edge("A", "B", color="red", weight=1, size=0.35) + g1.add_edge("A", "B", color="red", weight=2, size=0.65) + + g2.add_edge("C", "D", color="green", weight=1, size=0.5) + g2.add_edge("C", "D", color="red", weight=0, size=0.45) + g2.add_edge("C", "D", color="red", weight=2, size=0.65) + + if g1.is_multigraph(): + self.em = iso.numerical_multiedge_match("weight", 1) + self.emc = iso.categorical_multiedge_match("color", "") + self.emcm = iso.categorical_multiedge_match(["color", "weight"], ["", 1]) + self.emg1 = iso.generic_multiedge_match("color", "red", eq) + self.emg2 = iso.generic_multiedge_match( + ["color", "weight", "size"], + ["red", 1, 0.5], + [eq, eq, math.isclose], + ) + else: + self.em = iso.numerical_edge_match("weight", 1) + self.emc = iso.categorical_edge_match("color", "") + self.emcm = iso.categorical_edge_match(["color", "weight"], ["", 1]) + self.emg1 = iso.generic_multiedge_match("color", "red", eq) + self.emg2 = iso.generic_edge_match( + ["color", "weight", "size"], + ["red", 1, 0.5], + [eq, eq, math.isclose], + ) + + def test_weights_only(self): + assert nx.is_isomorphic(self.g1, self.g2, edge_match=self.em) + + def test_colors_only(self): + gm = self.GM(self.g1, self.g2, edge_match=self.emc) + assert gm.is_isomorphic() + + def test_colorsandweights(self): + gm = self.GM(self.g1, self.g2, edge_match=self.emcm) + assert not gm.is_isomorphic() + + def test_generic1(self): + gm = self.GM(self.g1, self.g2, edge_match=self.emg1) + assert gm.is_isomorphic() + + def test_generic2(self): + gm = self.GM(self.g1, self.g2, edge_match=self.emg2) + assert not gm.is_isomorphic() + + +class TestEdgeMatch_DiGraph(TestNodeMatch_Graph): + def setup_method(self): + TestNodeMatch_Graph.setup_method(self) + self.g1 = nx.DiGraph() + self.g2 = nx.DiGraph() + self.build() + + +class TestEdgeMatch_MultiDiGraph(TestEdgeMatch_MultiGraph): + def setup_method(self): + TestEdgeMatch_MultiGraph.setup_method(self) + self.g1 = nx.MultiDiGraph() + self.g2 = nx.MultiDiGraph() + self.GM = iso.MultiDiGraphMatcher + self.build() diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tree_isomorphism.py b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tree_isomorphism.py new file mode 100644 index 0000000000000000000000000000000000000000..e409d515f1ce3c3058efa4edb26ea0c02cc6ae5d --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/tree_isomorphism.py @@ -0,0 +1,284 @@ +""" +An algorithm for finding if two undirected trees are isomorphic, +and if so returns an isomorphism between the two sets of nodes. + +This algorithm uses a routine to tell if two rooted trees (trees with a +specified root node) are isomorphic, which may be independently useful. + +This implements an algorithm from: +The Design and Analysis of Computer Algorithms +by Aho, Hopcroft, and Ullman +Addison-Wesley Publishing 1974 +Example 3.2 pp. 84-86. + +A more understandable version of this algorithm is described in: +Homework Assignment 5 +McGill University SOCS 308-250B, Winter 2002 +by Matthew Suderman +http://crypto.cs.mcgill.ca/~crepeau/CS250/2004/HW5+.pdf +""" + +import networkx as nx +from networkx.utils.decorators import not_implemented_for + +__all__ = ["rooted_tree_isomorphism", "tree_isomorphism"] + + +@nx._dispatchable(graphs={"t1": 0, "t2": 2}, returns_graph=True) +def root_trees(t1, root1, t2, root2): + """Create a single digraph dT of free trees t1 and t2 + # with roots root1 and root2 respectively + # rename the nodes with consecutive integers + # so that all nodes get a unique name between both trees + + # our new "fake" root node is 0 + # t1 is numbers from 1 ... n + # t2 is numbered from n+1 to 2n + """ + + dT = nx.DiGraph() + + newroot1 = 1 # left root will be 1 + newroot2 = nx.number_of_nodes(t1) + 1 # right will be n+1 + + # may be overlap in node names here so need separate maps + # given the old name, what is the new + namemap1 = {root1: newroot1} + namemap2 = {root2: newroot2} + + # add an edge from our new root to root1 and root2 + dT.add_edge(0, namemap1[root1]) + dT.add_edge(0, namemap2[root2]) + + for i, (v1, v2) in enumerate(nx.bfs_edges(t1, root1)): + namemap1[v2] = i + namemap1[root1] + 1 + dT.add_edge(namemap1[v1], namemap1[v2]) + + for i, (v1, v2) in enumerate(nx.bfs_edges(t2, root2)): + namemap2[v2] = i + namemap2[root2] + 1 + dT.add_edge(namemap2[v1], namemap2[v2]) + + # now we really want the inverse of namemap1 and namemap2 + # giving the old name given the new + # since the values of namemap1 and namemap2 are unique + # there won't be collisions + namemap = {} + for old, new in namemap1.items(): + namemap[new] = old + for old, new in namemap2.items(): + namemap[new] = old + + return (dT, namemap, newroot1, newroot2) + + +# figure out the level of each node, with 0 at root +@nx._dispatchable +def assign_levels(G, root): + level = {} + level[root] = 0 + for v1, v2 in nx.bfs_edges(G, root): + level[v2] = level[v1] + 1 + + return level + + +# now group the nodes at each level +def group_by_levels(levels): + L = {} + for n, lev in levels.items(): + if lev not in L: + L[lev] = [] + L[lev].append(n) + + return L + + +# now lets get the isomorphism by walking the ordered_children +def generate_isomorphism(v, w, M, ordered_children): + # make sure tree1 comes first + assert v < w + M.append((v, w)) + for i, (x, y) in enumerate(zip(ordered_children[v], ordered_children[w])): + generate_isomorphism(x, y, M, ordered_children) + + +@nx._dispatchable(graphs={"t1": 0, "t2": 2}) +def rooted_tree_isomorphism(t1, root1, t2, root2): + """ + Given two rooted trees `t1` and `t2`, + with roots `root1` and `root2` respectively + this routine will determine if they are isomorphic. + + These trees may be either directed or undirected, + but if they are directed, all edges should flow from the root. + + It returns the isomorphism, a mapping of the nodes of `t1` onto the nodes + of `t2`, such that two trees are then identical. + + Note that two trees may have more than one isomorphism, and this + routine just returns one valid mapping. + + Parameters + ---------- + `t1` : NetworkX graph + One of the trees being compared + + `root1` : a node of `t1` which is the root of the tree + + `t2` : undirected NetworkX graph + The other tree being compared + + `root2` : a node of `t2` which is the root of the tree + + This is a subroutine used to implement `tree_isomorphism`, but will + be somewhat faster if you already have rooted trees. + + Returns + ------- + isomorphism : list + A list of pairs in which the left element is a node in `t1` + and the right element is a node in `t2`. The pairs are in + arbitrary order. If the nodes in one tree is mapped to the names in + the other, then trees will be identical. Note that an isomorphism + will not necessarily be unique. + + If `t1` and `t2` are not isomorphic, then it returns the empty list. + """ + + assert nx.is_tree(t1) + assert nx.is_tree(t2) + + # get the rooted tree formed by combining them + # with unique names + (dT, namemap, newroot1, newroot2) = root_trees(t1, root1, t2, root2) + + # compute the distance from the root, with 0 for our + levels = assign_levels(dT, 0) + + # height + h = max(levels.values()) + + # collect nodes into a dict by level + L = group_by_levels(levels) + + # each node has a label, initially set to 0 + label = {v: 0 for v in dT} + # and also ordered_labels and ordered_children + # which will store ordered tuples + ordered_labels = {v: () for v in dT} + ordered_children = {v: () for v in dT} + + # nothing to do on last level so start on h-1 + # also nothing to do for our fake level 0, so skip that + for i in range(h - 1, 0, -1): + # update the ordered_labels and ordered_children + # for any children + for v in L[i]: + # nothing to do if no children + if dT.out_degree(v) > 0: + # get all the pairs of labels and nodes of children + # and sort by labels + s = sorted((label[u], u) for u in dT.successors(v)) + + # invert to give a list of two tuples + # the sorted labels, and the corresponding children + ordered_labels[v], ordered_children[v] = list(zip(*s)) + + # now collect and sort the sorted ordered_labels + # for all nodes in L[i], carrying along the node + forlabel = sorted((ordered_labels[v], v) for v in L[i]) + + # now assign labels to these nodes, according to the sorted order + # starting from 0, where identical ordered_labels get the same label + current = 0 + for i, (ol, v) in enumerate(forlabel): + # advance to next label if not 0, and different from previous + if (i != 0) and (ol != forlabel[i - 1][0]): + current += 1 + label[v] = current + + # they are isomorphic if the labels of newroot1 and newroot2 are 0 + isomorphism = [] + if label[newroot1] == 0 and label[newroot2] == 0: + generate_isomorphism(newroot1, newroot2, isomorphism, ordered_children) + + # get the mapping back in terms of the old names + # return in sorted order for neatness + isomorphism = [(namemap[u], namemap[v]) for (u, v) in isomorphism] + + return isomorphism + + +@not_implemented_for("directed") +@not_implemented_for("multigraph") +@nx._dispatchable(graphs={"t1": 0, "t2": 1}) +def tree_isomorphism(t1, t2): + """ + Given two undirected (or free) trees `t1` and `t2`, + this routine will determine if they are isomorphic. + It returns the isomorphism, a mapping of the nodes of `t1` onto the nodes + of `t2`, such that two trees are then identical. + + Note that two trees may have more than one isomorphism, and this + routine just returns one valid mapping. + + Parameters + ---------- + t1 : undirected NetworkX graph + One of the trees being compared + + t2 : undirected NetworkX graph + The other tree being compared + + Returns + ------- + isomorphism : list + A list of pairs in which the left element is a node in `t1` + and the right element is a node in `t2`. The pairs are in + arbitrary order. If the nodes in one tree is mapped to the names in + the other, then trees will be identical. Note that an isomorphism + will not necessarily be unique. + + If `t1` and `t2` are not isomorphic, then it returns the empty list. + + Notes + ----- + This runs in O(n*log(n)) time for trees with n nodes. + """ + + assert nx.is_tree(t1) + assert nx.is_tree(t2) + + # To be isomorphic, t1 and t2 must have the same number of nodes. + if nx.number_of_nodes(t1) != nx.number_of_nodes(t2): + return [] + + # Another shortcut is that the sorted degree sequences need to be the same. + degree_sequence1 = sorted(d for (n, d) in t1.degree()) + degree_sequence2 = sorted(d for (n, d) in t2.degree()) + + if degree_sequence1 != degree_sequence2: + return [] + + # A tree can have either 1 or 2 centers. + # If the number doesn't match then t1 and t2 are not isomorphic. + center1 = nx.center(t1) + center2 = nx.center(t2) + + if len(center1) != len(center2): + return [] + + # If there is only 1 center in each, then use it. + if len(center1) == 1: + return rooted_tree_isomorphism(t1, center1[0], t2, center2[0]) + + # If there both have 2 centers, then try the first for t1 + # with the first for t2. + attempts = rooted_tree_isomorphism(t1, center1[0], t2, center2[0]) + + # If that worked we're done. + if len(attempts) > 0: + return attempts + + # Otherwise, try center1[0] with the center2[1], and see if that works + return rooted_tree_isomorphism(t1, center1[0], t2, center2[1]) diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/vf2pp.py b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/vf2pp.py new file mode 100644 index 0000000000000000000000000000000000000000..3093d9c977eed4458458783424fad10284176e2f --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/vf2pp.py @@ -0,0 +1,1075 @@ +""" +*************** +VF2++ Algorithm +*************** + +An implementation of the VF2++ algorithm [1]_ for Graph Isomorphism testing. + +The simplest interface to use this module is to call: + +`vf2pp_is_isomorphic`: to check whether two graphs are isomorphic. +`vf2pp_isomorphism`: to obtain the node mapping between two graphs, +in case they are isomorphic. +`vf2pp_all_isomorphisms`: to generate all possible mappings between two graphs, +if isomorphic. + +Introduction +------------ +The VF2++ algorithm, follows a similar logic to that of VF2, while also +introducing new easy-to-check cutting rules and determining the optimal access +order of nodes. It is also implemented in a non-recursive manner, which saves +both time and space, when compared to its previous counterpart. + +The optimal node ordering is obtained after taking into consideration both the +degree but also the label rarity of each node. +This way we place the nodes that are more likely to match, first in the order, +thus examining the most promising branches in the beginning. +The rules also consider node labels, making it easier to prune unfruitful +branches early in the process. + +Examples +-------- + +Suppose G1 and G2 are Isomorphic Graphs. Verification is as follows: + +Without node labels: + +>>> import networkx as nx +>>> G1 = nx.path_graph(4) +>>> G2 = nx.path_graph(4) +>>> nx.vf2pp_is_isomorphic(G1, G2, node_label=None) +True +>>> nx.vf2pp_isomorphism(G1, G2, node_label=None) +{1: 1, 2: 2, 0: 0, 3: 3} + +With node labels: + +>>> G1 = nx.path_graph(4) +>>> G2 = nx.path_graph(4) +>>> mapped = {1: 1, 2: 2, 3: 3, 0: 0} +>>> nx.set_node_attributes( +... G1, dict(zip(G1, ["blue", "red", "green", "yellow"])), "label" +... ) +>>> nx.set_node_attributes( +... G2, +... dict(zip([mapped[u] for u in G1], ["blue", "red", "green", "yellow"])), +... "label", +... ) +>>> nx.vf2pp_is_isomorphic(G1, G2, node_label="label") +True +>>> nx.vf2pp_isomorphism(G1, G2, node_label="label") +{1: 1, 2: 2, 0: 0, 3: 3} + +References +---------- +.. [1] Jüttner, Alpár & Madarasi, Péter. (2018). "VF2++—An improved subgraph + isomorphism algorithm". Discrete Applied Mathematics. 242. + https://doi.org/10.1016/j.dam.2018.02.018 + +""" + +import collections + +import networkx as nx + +__all__ = ["vf2pp_isomorphism", "vf2pp_is_isomorphic", "vf2pp_all_isomorphisms"] + +_GraphParameters = collections.namedtuple( + "_GraphParameters", + [ + "G1", + "G2", + "G1_labels", + "G2_labels", + "nodes_of_G1Labels", + "nodes_of_G2Labels", + "G2_nodes_of_degree", + ], +) + +_StateParameters = collections.namedtuple( + "_StateParameters", + [ + "mapping", + "reverse_mapping", + "T1", + "T1_in", + "T1_tilde", + "T1_tilde_in", + "T2", + "T2_in", + "T2_tilde", + "T2_tilde_in", + ], +) + + +@nx._dispatchable(graphs={"G1": 0, "G2": 1}, node_attrs={"node_label": "default_label"}) +def vf2pp_isomorphism(G1, G2, node_label=None, default_label=None): + """Return an isomorphic mapping between `G1` and `G2` if it exists. + + Parameters + ---------- + G1, G2 : NetworkX Graph or MultiGraph instances. + The two graphs to check for isomorphism. + + node_label : str, optional + The name of the node attribute to be used when comparing nodes. + The default is `None`, meaning node attributes are not considered + in the comparison. Any node that doesn't have the `node_label` + attribute uses `default_label` instead. + + default_label : scalar + Default value to use when a node doesn't have an attribute + named `node_label`. Default is `None`. + + Returns + ------- + dict or None + Node mapping if the two graphs are isomorphic. None otherwise. + """ + try: + mapping = next(vf2pp_all_isomorphisms(G1, G2, node_label, default_label)) + return mapping + except StopIteration: + return None + + +@nx._dispatchable(graphs={"G1": 0, "G2": 1}, node_attrs={"node_label": "default_label"}) +def vf2pp_is_isomorphic(G1, G2, node_label=None, default_label=None): + """Examines whether G1 and G2 are isomorphic. + + Parameters + ---------- + G1, G2 : NetworkX Graph or MultiGraph instances. + The two graphs to check for isomorphism. + + node_label : str, optional + The name of the node attribute to be used when comparing nodes. + The default is `None`, meaning node attributes are not considered + in the comparison. Any node that doesn't have the `node_label` + attribute uses `default_label` instead. + + default_label : scalar + Default value to use when a node doesn't have an attribute + named `node_label`. Default is `None`. + + Returns + ------- + bool + True if the two graphs are isomorphic, False otherwise. + """ + if vf2pp_isomorphism(G1, G2, node_label, default_label) is not None: + return True + return False + + +@nx._dispatchable(graphs={"G1": 0, "G2": 1}, node_attrs={"node_label": "default_label"}) +def vf2pp_all_isomorphisms(G1, G2, node_label=None, default_label=None): + """Yields all the possible mappings between G1 and G2. + + Parameters + ---------- + G1, G2 : NetworkX Graph or MultiGraph instances. + The two graphs to check for isomorphism. + + node_label : str, optional + The name of the node attribute to be used when comparing nodes. + The default is `None`, meaning node attributes are not considered + in the comparison. Any node that doesn't have the `node_label` + attribute uses `default_label` instead. + + default_label : scalar + Default value to use when a node doesn't have an attribute + named `node_label`. Default is `None`. + + Yields + ------ + dict + Isomorphic mapping between the nodes in `G1` and `G2`. + """ + if G1.number_of_nodes() == 0 or G2.number_of_nodes() == 0: + return False + + # Create the degree dicts based on graph type + if G1.is_directed(): + G1_degree = { + n: (in_degree, out_degree) + for (n, in_degree), (_, out_degree) in zip(G1.in_degree, G1.out_degree) + } + G2_degree = { + n: (in_degree, out_degree) + for (n, in_degree), (_, out_degree) in zip(G2.in_degree, G2.out_degree) + } + else: + G1_degree = dict(G1.degree) + G2_degree = dict(G2.degree) + + if not G1.is_directed(): + find_candidates = _find_candidates + restore_Tinout = _restore_Tinout + else: + find_candidates = _find_candidates_Di + restore_Tinout = _restore_Tinout_Di + + # Check that both graphs have the same number of nodes and degree sequence + if G1.order() != G2.order(): + return False + if sorted(G1_degree.values()) != sorted(G2_degree.values()): + return False + + # Initialize parameters and cache necessary information about degree and labels + graph_params, state_params = _initialize_parameters( + G1, G2, G2_degree, node_label, default_label + ) + + # Check if G1 and G2 have the same labels, and that number of nodes per label is equal between the two graphs + if not _precheck_label_properties(graph_params): + return False + + # Calculate the optimal node ordering + node_order = _matching_order(graph_params) + + # Initialize the stack + stack = [] + candidates = iter( + find_candidates(node_order[0], graph_params, state_params, G1_degree) + ) + stack.append((node_order[0], candidates)) + + mapping = state_params.mapping + reverse_mapping = state_params.reverse_mapping + + # Index of the node from the order, currently being examined + matching_node = 1 + + while stack: + current_node, candidate_nodes = stack[-1] + + try: + candidate = next(candidate_nodes) + except StopIteration: + # If no remaining candidates, return to a previous state, and follow another branch + stack.pop() + matching_node -= 1 + if stack: + # Pop the previously added u-v pair, and look for a different candidate _v for u + popped_node1, _ = stack[-1] + popped_node2 = mapping[popped_node1] + mapping.pop(popped_node1) + reverse_mapping.pop(popped_node2) + restore_Tinout(popped_node1, popped_node2, graph_params, state_params) + continue + + if _feasibility(current_node, candidate, graph_params, state_params): + # Terminate if mapping is extended to its full + if len(mapping) == G2.number_of_nodes() - 1: + cp_mapping = mapping.copy() + cp_mapping[current_node] = candidate + yield cp_mapping + continue + + # Feasibility rules pass, so extend the mapping and update the parameters + mapping[current_node] = candidate + reverse_mapping[candidate] = current_node + _update_Tinout(current_node, candidate, graph_params, state_params) + # Append the next node and its candidates to the stack + candidates = iter( + find_candidates( + node_order[matching_node], graph_params, state_params, G1_degree + ) + ) + stack.append((node_order[matching_node], candidates)) + matching_node += 1 + + +def _precheck_label_properties(graph_params): + G1, G2, G1_labels, G2_labels, nodes_of_G1Labels, nodes_of_G2Labels, _ = graph_params + if any( + label not in nodes_of_G1Labels or len(nodes_of_G1Labels[label]) != len(nodes) + for label, nodes in nodes_of_G2Labels.items() + ): + return False + return True + + +def _initialize_parameters(G1, G2, G2_degree, node_label=None, default_label=-1): + """Initializes all the necessary parameters for VF2++ + + Parameters + ---------- + G1,G2: NetworkX Graph or MultiGraph instances. + The two graphs to check for isomorphism or monomorphism + + G1_labels,G2_labels: dict + The label of every node in G1 and G2 respectively + + Returns + ------- + graph_params: namedtuple + Contains all the Graph-related parameters: + + G1,G2 + G1_labels,G2_labels: dict + + state_params: namedtuple + Contains all the State-related parameters: + + mapping: dict + The mapping as extended so far. Maps nodes of G1 to nodes of G2 + + reverse_mapping: dict + The reverse mapping as extended so far. Maps nodes from G2 to nodes of G1. It's basically "mapping" reversed + + T1, T2: set + Ti contains uncovered neighbors of covered nodes from Gi, i.e. nodes that are not in the mapping, but are + neighbors of nodes that are. + + T1_out, T2_out: set + Ti_out contains all the nodes from Gi, that are neither in the mapping nor in Ti + """ + G1_labels = dict(G1.nodes(data=node_label, default=default_label)) + G2_labels = dict(G2.nodes(data=node_label, default=default_label)) + + graph_params = _GraphParameters( + G1, + G2, + G1_labels, + G2_labels, + nx.utils.groups(G1_labels), + nx.utils.groups(G2_labels), + nx.utils.groups(G2_degree), + ) + + T1, T1_in = set(), set() + T2, T2_in = set(), set() + if G1.is_directed(): + T1_tilde, T1_tilde_in = ( + set(G1.nodes()), + set(), + ) # todo: do we need Ti_tilde_in? What nodes does it have? + T2_tilde, T2_tilde_in = set(G2.nodes()), set() + else: + T1_tilde, T1_tilde_in = set(G1.nodes()), set() + T2_tilde, T2_tilde_in = set(G2.nodes()), set() + + state_params = _StateParameters( + {}, + {}, + T1, + T1_in, + T1_tilde, + T1_tilde_in, + T2, + T2_in, + T2_tilde, + T2_tilde_in, + ) + + return graph_params, state_params + + +def _matching_order(graph_params): + """The node ordering as introduced in VF2++. + + Notes + ----- + Taking into account the structure of the Graph and the node labeling, the nodes are placed in an order such that, + most of the unfruitful/infeasible branches of the search space can be pruned on high levels, significantly + decreasing the number of visited states. The premise is that, the algorithm will be able to recognize + inconsistencies early, proceeding to go deep into the search tree only if it's needed. + + Parameters + ---------- + graph_params: namedtuple + Contains: + + G1,G2: NetworkX Graph or MultiGraph instances. + The two graphs to check for isomorphism or monomorphism. + + G1_labels,G2_labels: dict + The label of every node in G1 and G2 respectively. + + Returns + ------- + node_order: list + The ordering of the nodes. + """ + G1, G2, G1_labels, _, _, nodes_of_G2Labels, _ = graph_params + if not G1 and not G2: + return {} + + if G1.is_directed(): + G1 = G1.to_undirected(as_view=True) + + V1_unordered = set(G1.nodes()) + label_rarity = {label: len(nodes) for label, nodes in nodes_of_G2Labels.items()} + used_degrees = {node: 0 for node in G1} + node_order = [] + + while V1_unordered: + max_rarity = min(label_rarity[G1_labels[x]] for x in V1_unordered) + rarest_nodes = [ + n for n in V1_unordered if label_rarity[G1_labels[n]] == max_rarity + ] + max_node = max(rarest_nodes, key=G1.degree) + + for dlevel_nodes in nx.bfs_layers(G1, max_node): + nodes_to_add = dlevel_nodes.copy() + while nodes_to_add: + max_used_degree = max(used_degrees[n] for n in nodes_to_add) + max_used_degree_nodes = [ + n for n in nodes_to_add if used_degrees[n] == max_used_degree + ] + max_degree = max(G1.degree[n] for n in max_used_degree_nodes) + max_degree_nodes = [ + n for n in max_used_degree_nodes if G1.degree[n] == max_degree + ] + next_node = min( + max_degree_nodes, key=lambda x: label_rarity[G1_labels[x]] + ) + + node_order.append(next_node) + for node in G1.neighbors(next_node): + used_degrees[node] += 1 + + nodes_to_add.remove(next_node) + label_rarity[G1_labels[next_node]] -= 1 + V1_unordered.discard(next_node) + + return node_order + + +def _find_candidates( + u, graph_params, state_params, G1_degree +): # todo: make the 4th argument the degree of u + """Given node u of G1, finds the candidates of u from G2. + + Parameters + ---------- + u: Graph node + The node from G1 for which to find the candidates from G2. + + graph_params: namedtuple + Contains all the Graph-related parameters: + + G1,G2: NetworkX Graph or MultiGraph instances. + The two graphs to check for isomorphism or monomorphism + + G1_labels,G2_labels: dict + The label of every node in G1 and G2 respectively + + state_params: namedtuple + Contains all the State-related parameters: + + mapping: dict + The mapping as extended so far. Maps nodes of G1 to nodes of G2 + + reverse_mapping: dict + The reverse mapping as extended so far. Maps nodes from G2 to nodes of G1. It's basically "mapping" reversed + + T1, T2: set + Ti contains uncovered neighbors of covered nodes from Gi, i.e. nodes that are not in the mapping, but are + neighbors of nodes that are. + + T1_tilde, T2_tilde: set + Ti_tilde contains all the nodes from Gi, that are neither in the mapping nor in Ti + + Returns + ------- + candidates: set + The nodes from G2 which are candidates for u. + """ + G1, G2, G1_labels, _, _, nodes_of_G2Labels, G2_nodes_of_degree = graph_params + mapping, reverse_mapping, _, _, _, _, _, _, T2_tilde, _ = state_params + + covered_nbrs = [nbr for nbr in G1[u] if nbr in mapping] + if not covered_nbrs: + candidates = set(nodes_of_G2Labels[G1_labels[u]]) + candidates.intersection_update(G2_nodes_of_degree[G1_degree[u]]) + candidates.intersection_update(T2_tilde) + candidates.difference_update(reverse_mapping) + if G1.is_multigraph(): + candidates.difference_update( + { + node + for node in candidates + if G1.number_of_edges(u, u) != G2.number_of_edges(node, node) + } + ) + return candidates + + nbr1 = covered_nbrs[0] + common_nodes = set(G2[mapping[nbr1]]) + + for nbr1 in covered_nbrs[1:]: + common_nodes.intersection_update(G2[mapping[nbr1]]) + + common_nodes.difference_update(reverse_mapping) + common_nodes.intersection_update(G2_nodes_of_degree[G1_degree[u]]) + common_nodes.intersection_update(nodes_of_G2Labels[G1_labels[u]]) + if G1.is_multigraph(): + common_nodes.difference_update( + { + node + for node in common_nodes + if G1.number_of_edges(u, u) != G2.number_of_edges(node, node) + } + ) + return common_nodes + + +def _find_candidates_Di(u, graph_params, state_params, G1_degree): + G1, G2, G1_labels, _, _, nodes_of_G2Labels, G2_nodes_of_degree = graph_params + mapping, reverse_mapping, _, _, _, _, _, _, T2_tilde, _ = state_params + + covered_successors = [succ for succ in G1[u] if succ in mapping] + covered_predecessors = [pred for pred in G1.pred[u] if pred in mapping] + + if not (covered_successors or covered_predecessors): + candidates = set(nodes_of_G2Labels[G1_labels[u]]) + candidates.intersection_update(G2_nodes_of_degree[G1_degree[u]]) + candidates.intersection_update(T2_tilde) + candidates.difference_update(reverse_mapping) + if G1.is_multigraph(): + candidates.difference_update( + { + node + for node in candidates + if G1.number_of_edges(u, u) != G2.number_of_edges(node, node) + } + ) + return candidates + + if covered_successors: + succ1 = covered_successors[0] + common_nodes = set(G2.pred[mapping[succ1]]) + + for succ1 in covered_successors[1:]: + common_nodes.intersection_update(G2.pred[mapping[succ1]]) + else: + pred1 = covered_predecessors.pop() + common_nodes = set(G2[mapping[pred1]]) + + for pred1 in covered_predecessors: + common_nodes.intersection_update(G2[mapping[pred1]]) + + common_nodes.difference_update(reverse_mapping) + common_nodes.intersection_update(G2_nodes_of_degree[G1_degree[u]]) + common_nodes.intersection_update(nodes_of_G2Labels[G1_labels[u]]) + if G1.is_multigraph(): + common_nodes.difference_update( + { + node + for node in common_nodes + if G1.number_of_edges(u, u) != G2.number_of_edges(node, node) + } + ) + return common_nodes + + +def _feasibility(node1, node2, graph_params, state_params): + """Given a candidate pair of nodes u and v from G1 and G2 respectively, checks if it's feasible to extend the + mapping, i.e. if u and v can be matched. + + Notes + ----- + This function performs all the necessary checking by applying both consistency and cutting rules. + + Parameters + ---------- + node1, node2: Graph node + The candidate pair of nodes being checked for matching + + graph_params: namedtuple + Contains all the Graph-related parameters: + + G1,G2: NetworkX Graph or MultiGraph instances. + The two graphs to check for isomorphism or monomorphism + + G1_labels,G2_labels: dict + The label of every node in G1 and G2 respectively + + state_params: namedtuple + Contains all the State-related parameters: + + mapping: dict + The mapping as extended so far. Maps nodes of G1 to nodes of G2 + + reverse_mapping: dict + The reverse mapping as extended so far. Maps nodes from G2 to nodes of G1. It's basically "mapping" reversed + + T1, T2: set + Ti contains uncovered neighbors of covered nodes from Gi, i.e. nodes that are not in the mapping, but are + neighbors of nodes that are. + + T1_out, T2_out: set + Ti_out contains all the nodes from Gi, that are neither in the mapping nor in Ti + + Returns + ------- + True if all checks are successful, False otherwise. + """ + G1 = graph_params.G1 + + if _cut_PT(node1, node2, graph_params, state_params): + return False + + if G1.is_multigraph(): + if not _consistent_PT(node1, node2, graph_params, state_params): + return False + + return True + + +def _cut_PT(u, v, graph_params, state_params): + """Implements the cutting rules for the ISO problem. + + Parameters + ---------- + u, v: Graph node + The two candidate nodes being examined. + + graph_params: namedtuple + Contains all the Graph-related parameters: + + G1,G2: NetworkX Graph or MultiGraph instances. + The two graphs to check for isomorphism or monomorphism + + G1_labels,G2_labels: dict + The label of every node in G1 and G2 respectively + + state_params: namedtuple + Contains all the State-related parameters: + + mapping: dict + The mapping as extended so far. Maps nodes of G1 to nodes of G2 + + reverse_mapping: dict + The reverse mapping as extended so far. Maps nodes from G2 to nodes of G1. It's basically "mapping" reversed + + T1, T2: set + Ti contains uncovered neighbors of covered nodes from Gi, i.e. nodes that are not in the mapping, but are + neighbors of nodes that are. + + T1_tilde, T2_tilde: set + Ti_out contains all the nodes from Gi, that are neither in the mapping nor in Ti + + Returns + ------- + True if we should prune this branch, i.e. the node pair failed the cutting checks. False otherwise. + """ + G1, G2, G1_labels, G2_labels, _, _, _ = graph_params + ( + _, + _, + T1, + T1_in, + T1_tilde, + _, + T2, + T2_in, + T2_tilde, + _, + ) = state_params + + u_labels_predecessors, v_labels_predecessors = {}, {} + if G1.is_directed(): + u_labels_predecessors = nx.utils.groups( + {n1: G1_labels[n1] for n1 in G1.pred[u]} + ) + v_labels_predecessors = nx.utils.groups( + {n2: G2_labels[n2] for n2 in G2.pred[v]} + ) + + if set(u_labels_predecessors.keys()) != set(v_labels_predecessors.keys()): + return True + + u_labels_successors = nx.utils.groups({n1: G1_labels[n1] for n1 in G1[u]}) + v_labels_successors = nx.utils.groups({n2: G2_labels[n2] for n2 in G2[v]}) + + # if the neighbors of u, do not have the same labels as those of v, NOT feasible. + if set(u_labels_successors.keys()) != set(v_labels_successors.keys()): + return True + + for label, G1_nbh in u_labels_successors.items(): + G2_nbh = v_labels_successors[label] + + if G1.is_multigraph(): + # Check for every neighbor in the neighborhood, if u-nbr1 has same edges as v-nbr2 + u_nbrs_edges = sorted(G1.number_of_edges(u, x) for x in G1_nbh) + v_nbrs_edges = sorted(G2.number_of_edges(v, x) for x in G2_nbh) + if any( + u_nbr_edges != v_nbr_edges + for u_nbr_edges, v_nbr_edges in zip(u_nbrs_edges, v_nbrs_edges) + ): + return True + + if len(T1.intersection(G1_nbh)) != len(T2.intersection(G2_nbh)): + return True + if len(T1_tilde.intersection(G1_nbh)) != len(T2_tilde.intersection(G2_nbh)): + return True + if G1.is_directed() and len(T1_in.intersection(G1_nbh)) != len( + T2_in.intersection(G2_nbh) + ): + return True + + if not G1.is_directed(): + return False + + for label, G1_pred in u_labels_predecessors.items(): + G2_pred = v_labels_predecessors[label] + + if G1.is_multigraph(): + # Check for every neighbor in the neighborhood, if u-nbr1 has same edges as v-nbr2 + u_pred_edges = sorted(G1.number_of_edges(u, x) for x in G1_pred) + v_pred_edges = sorted(G2.number_of_edges(v, x) for x in G2_pred) + if any( + u_nbr_edges != v_nbr_edges + for u_nbr_edges, v_nbr_edges in zip(u_pred_edges, v_pred_edges) + ): + return True + + if len(T1.intersection(G1_pred)) != len(T2.intersection(G2_pred)): + return True + if len(T1_tilde.intersection(G1_pred)) != len(T2_tilde.intersection(G2_pred)): + return True + if len(T1_in.intersection(G1_pred)) != len(T2_in.intersection(G2_pred)): + return True + + return False + + +def _consistent_PT(u, v, graph_params, state_params): + """Checks the consistency of extending the mapping using the current node pair. + + Parameters + ---------- + u, v: Graph node + The two candidate nodes being examined. + + graph_params: namedtuple + Contains all the Graph-related parameters: + + G1,G2: NetworkX Graph or MultiGraph instances. + The two graphs to check for isomorphism or monomorphism + + G1_labels,G2_labels: dict + The label of every node in G1 and G2 respectively + + state_params: namedtuple + Contains all the State-related parameters: + + mapping: dict + The mapping as extended so far. Maps nodes of G1 to nodes of G2 + + reverse_mapping: dict + The reverse mapping as extended so far. Maps nodes from G2 to nodes of G1. It's basically "mapping" reversed + + T1, T2: set + Ti contains uncovered neighbors of covered nodes from Gi, i.e. nodes that are not in the mapping, but are + neighbors of nodes that are. + + T1_out, T2_out: set + Ti_out contains all the nodes from Gi, that are neither in the mapping nor in Ti + + Returns + ------- + True if the pair passes all the consistency checks successfully. False otherwise. + """ + G1, G2 = graph_params.G1, graph_params.G2 + mapping, reverse_mapping = state_params.mapping, state_params.reverse_mapping + + for neighbor in G1[u]: + if neighbor in mapping: + if G1.number_of_edges(u, neighbor) != G2.number_of_edges( + v, mapping[neighbor] + ): + return False + + for neighbor in G2[v]: + if neighbor in reverse_mapping: + if G1.number_of_edges(u, reverse_mapping[neighbor]) != G2.number_of_edges( + v, neighbor + ): + return False + + if not G1.is_directed(): + return True + + for predecessor in G1.pred[u]: + if predecessor in mapping: + if G1.number_of_edges(predecessor, u) != G2.number_of_edges( + mapping[predecessor], v + ): + return False + + for predecessor in G2.pred[v]: + if predecessor in reverse_mapping: + if G1.number_of_edges( + reverse_mapping[predecessor], u + ) != G2.number_of_edges(predecessor, v): + return False + + return True + + +def _update_Tinout(new_node1, new_node2, graph_params, state_params): + """Updates the Ti/Ti_out (i=1,2) when a new node pair u-v is added to the mapping. + + Notes + ----- + This function should be called right after the feasibility checks are passed, and node1 is mapped to node2. The + purpose of this function is to avoid brute force computing of Ti/Ti_out by iterating over all nodes of the graph + and checking which nodes satisfy the necessary conditions. Instead, in every step of the algorithm we focus + exclusively on the two nodes that are being added to the mapping, incrementally updating Ti/Ti_out. + + Parameters + ---------- + new_node1, new_node2: Graph node + The two new nodes, added to the mapping. + + graph_params: namedtuple + Contains all the Graph-related parameters: + + G1,G2: NetworkX Graph or MultiGraph instances. + The two graphs to check for isomorphism or monomorphism + + G1_labels,G2_labels: dict + The label of every node in G1 and G2 respectively + + state_params: namedtuple + Contains all the State-related parameters: + + mapping: dict + The mapping as extended so far. Maps nodes of G1 to nodes of G2 + + reverse_mapping: dict + The reverse mapping as extended so far. Maps nodes from G2 to nodes of G1. It's basically "mapping" reversed + + T1, T2: set + Ti contains uncovered neighbors of covered nodes from Gi, i.e. nodes that are not in the mapping, but are + neighbors of nodes that are. + + T1_tilde, T2_tilde: set + Ti_out contains all the nodes from Gi, that are neither in the mapping nor in Ti + """ + G1, G2, _, _, _, _, _ = graph_params + ( + mapping, + reverse_mapping, + T1, + T1_in, + T1_tilde, + T1_tilde_in, + T2, + T2_in, + T2_tilde, + T2_tilde_in, + ) = state_params + + uncovered_successors_G1 = {succ for succ in G1[new_node1] if succ not in mapping} + uncovered_successors_G2 = { + succ for succ in G2[new_node2] if succ not in reverse_mapping + } + + # Add the uncovered neighbors of node1 and node2 in T1 and T2 respectively + T1.update(uncovered_successors_G1) + T2.update(uncovered_successors_G2) + T1.discard(new_node1) + T2.discard(new_node2) + + T1_tilde.difference_update(uncovered_successors_G1) + T2_tilde.difference_update(uncovered_successors_G2) + T1_tilde.discard(new_node1) + T2_tilde.discard(new_node2) + + if not G1.is_directed(): + return + + uncovered_predecessors_G1 = { + pred for pred in G1.pred[new_node1] if pred not in mapping + } + uncovered_predecessors_G2 = { + pred for pred in G2.pred[new_node2] if pred not in reverse_mapping + } + + T1_in.update(uncovered_predecessors_G1) + T2_in.update(uncovered_predecessors_G2) + T1_in.discard(new_node1) + T2_in.discard(new_node2) + + T1_tilde.difference_update(uncovered_predecessors_G1) + T2_tilde.difference_update(uncovered_predecessors_G2) + T1_tilde.discard(new_node1) + T2_tilde.discard(new_node2) + + +def _restore_Tinout(popped_node1, popped_node2, graph_params, state_params): + """Restores the previous version of Ti/Ti_out when a node pair is deleted from the mapping. + + Parameters + ---------- + popped_node1, popped_node2: Graph node + The two nodes deleted from the mapping. + + graph_params: namedtuple + Contains all the Graph-related parameters: + + G1,G2: NetworkX Graph or MultiGraph instances. + The two graphs to check for isomorphism or monomorphism + + G1_labels,G2_labels: dict + The label of every node in G1 and G2 respectively + + state_params: namedtuple + Contains all the State-related parameters: + + mapping: dict + The mapping as extended so far. Maps nodes of G1 to nodes of G2 + + reverse_mapping: dict + The reverse mapping as extended so far. Maps nodes from G2 to nodes of G1. It's basically "mapping" reversed + + T1, T2: set + Ti contains uncovered neighbors of covered nodes from Gi, i.e. nodes that are not in the mapping, but are + neighbors of nodes that are. + + T1_tilde, T2_tilde: set + Ti_out contains all the nodes from Gi, that are neither in the mapping nor in Ti + """ + # If the node we want to remove from the mapping, has at least one covered neighbor, add it to T1. + G1, G2, _, _, _, _, _ = graph_params + ( + mapping, + reverse_mapping, + T1, + T1_in, + T1_tilde, + T1_tilde_in, + T2, + T2_in, + T2_tilde, + T2_tilde_in, + ) = state_params + + is_added = False + for neighbor in G1[popped_node1]: + if neighbor in mapping: + # if a neighbor of the excluded node1 is in the mapping, keep node1 in T1 + is_added = True + T1.add(popped_node1) + else: + # check if its neighbor has another connection with a covered node. If not, only then exclude it from T1 + if any(nbr in mapping for nbr in G1[neighbor]): + continue + T1.discard(neighbor) + T1_tilde.add(neighbor) + + # Case where the node is not present in neither the mapping nor T1. By definition, it should belong to T1_tilde + if not is_added: + T1_tilde.add(popped_node1) + + is_added = False + for neighbor in G2[popped_node2]: + if neighbor in reverse_mapping: + is_added = True + T2.add(popped_node2) + else: + if any(nbr in reverse_mapping for nbr in G2[neighbor]): + continue + T2.discard(neighbor) + T2_tilde.add(neighbor) + + if not is_added: + T2_tilde.add(popped_node2) + + +def _restore_Tinout_Di(popped_node1, popped_node2, graph_params, state_params): + # If the node we want to remove from the mapping, has at least one covered neighbor, add it to T1. + G1, G2, _, _, _, _, _ = graph_params + ( + mapping, + reverse_mapping, + T1, + T1_in, + T1_tilde, + T1_tilde_in, + T2, + T2_in, + T2_tilde, + T2_tilde_in, + ) = state_params + + is_added = False + for successor in G1[popped_node1]: + if successor in mapping: + # if a neighbor of the excluded node1 is in the mapping, keep node1 in T1 + is_added = True + T1_in.add(popped_node1) + else: + # check if its neighbor has another connection with a covered node. If not, only then exclude it from T1 + if not any(pred in mapping for pred in G1.pred[successor]): + T1.discard(successor) + + if not any(succ in mapping for succ in G1[successor]): + T1_in.discard(successor) + + if successor not in T1: + if successor not in T1_in: + T1_tilde.add(successor) + + for predecessor in G1.pred[popped_node1]: + if predecessor in mapping: + # if a neighbor of the excluded node1 is in the mapping, keep node1 in T1 + is_added = True + T1.add(popped_node1) + else: + # check if its neighbor has another connection with a covered node. If not, only then exclude it from T1 + if not any(pred in mapping for pred in G1.pred[predecessor]): + T1.discard(predecessor) + + if not any(succ in mapping for succ in G1[predecessor]): + T1_in.discard(predecessor) + + if not (predecessor in T1 or predecessor in T1_in): + T1_tilde.add(predecessor) + + # Case where the node is not present in neither the mapping nor T1. By definition it should belong to T1_tilde + if not is_added: + T1_tilde.add(popped_node1) + + is_added = False + for successor in G2[popped_node2]: + if successor in reverse_mapping: + is_added = True + T2_in.add(popped_node2) + else: + if not any(pred in reverse_mapping for pred in G2.pred[successor]): + T2.discard(successor) + + if not any(succ in reverse_mapping for succ in G2[successor]): + T2_in.discard(successor) + + if successor not in T2: + if successor not in T2_in: + T2_tilde.add(successor) + + for predecessor in G2.pred[popped_node2]: + if predecessor in reverse_mapping: + # if a neighbor of the excluded node1 is in the mapping, keep node1 in T1 + is_added = True + T2.add(popped_node2) + else: + # check if its neighbor has another connection with a covered node. If not, only then exclude it from T1 + if not any(pred in reverse_mapping for pred in G2.pred[predecessor]): + T2.discard(predecessor) + + if not any(succ in reverse_mapping for succ in G2[predecessor]): + T2_in.discard(predecessor) + + if not (predecessor in T2 or predecessor in T2_in): + T2_tilde.add(predecessor) + + if not is_added: + T2_tilde.add(popped_node2) diff --git a/lib/python3.10/site-packages/networkx/algorithms/isomorphism/vf2userfunc.py b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/vf2userfunc.py new file mode 100644 index 0000000000000000000000000000000000000000..6fcf8a15f6ec0ef517d225a9d0095cfe5dc26ab2 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/isomorphism/vf2userfunc.py @@ -0,0 +1,192 @@ +""" +Module to simplify the specification of user-defined equality functions for +node and edge attributes during isomorphism checks. + +During the construction of an isomorphism, the algorithm considers two +candidate nodes n1 in G1 and n2 in G2. The graphs G1 and G2 are then +compared with respect to properties involving n1 and n2, and if the outcome +is good, then the candidate nodes are considered isomorphic. NetworkX +provides a simple mechanism for users to extend the comparisons to include +node and edge attributes. + +Node attributes are handled by the node_match keyword. When considering +n1 and n2, the algorithm passes their node attribute dictionaries to +node_match, and if it returns False, then n1 and n2 cannot be +considered to be isomorphic. + +Edge attributes are handled by the edge_match keyword. When considering +n1 and n2, the algorithm must verify that outgoing edges from n1 are +commensurate with the outgoing edges for n2. If the graph is directed, +then a similar check is also performed for incoming edges. + +Focusing only on outgoing edges, we consider pairs of nodes (n1, v1) from +G1 and (n2, v2) from G2. For graphs and digraphs, there is only one edge +between (n1, v1) and only one edge between (n2, v2). Those edge attribute +dictionaries are passed to edge_match, and if it returns False, then +n1 and n2 cannot be considered isomorphic. For multigraphs and +multidigraphs, there can be multiple edges between (n1, v1) and also +multiple edges between (n2, v2). Now, there must exist an isomorphism +from "all the edges between (n1, v1)" to "all the edges between (n2, v2)". +So, all of the edge attribute dictionaries are passed to edge_match, and +it must determine if there is an isomorphism between the two sets of edges. +""" + +from . import isomorphvf2 as vf2 + +__all__ = ["GraphMatcher", "DiGraphMatcher", "MultiGraphMatcher", "MultiDiGraphMatcher"] + + +def _semantic_feasibility(self, G1_node, G2_node): + """Returns True if mapping G1_node to G2_node is semantically feasible.""" + # Make sure the nodes match + if self.node_match is not None: + nm = self.node_match(self.G1.nodes[G1_node], self.G2.nodes[G2_node]) + if not nm: + return False + + # Make sure the edges match + if self.edge_match is not None: + # Cached lookups + G1nbrs = self.G1_adj[G1_node] + G2nbrs = self.G2_adj[G2_node] + core_1 = self.core_1 + edge_match = self.edge_match + + for neighbor in G1nbrs: + # G1_node is not in core_1, so we must handle R_self separately + if neighbor == G1_node: + if G2_node in G2nbrs and not edge_match( + G1nbrs[G1_node], G2nbrs[G2_node] + ): + return False + elif neighbor in core_1: + G2_nbr = core_1[neighbor] + if G2_nbr in G2nbrs and not edge_match( + G1nbrs[neighbor], G2nbrs[G2_nbr] + ): + return False + # syntactic check has already verified that neighbors are symmetric + + return True + + +class GraphMatcher(vf2.GraphMatcher): + """VF2 isomorphism checker for undirected graphs.""" + + def __init__(self, G1, G2, node_match=None, edge_match=None): + """Initialize graph matcher. + + Parameters + ---------- + G1, G2: graph + The graphs to be tested. + + node_match: callable + A function that returns True iff node n1 in G1 and n2 in G2 + should be considered equal during the isomorphism test. The + function will be called like:: + + node_match(G1.nodes[n1], G2.nodes[n2]) + + That is, the function will receive the node attribute dictionaries + of the nodes under consideration. If None, then no attributes are + considered when testing for an isomorphism. + + edge_match: callable + A function that returns True iff the edge attribute dictionary for + the pair of nodes (u1, v1) in G1 and (u2, v2) in G2 should be + considered equal during the isomorphism test. The function will be + called like:: + + edge_match(G1[u1][v1], G2[u2][v2]) + + That is, the function will receive the edge attribute dictionaries + of the edges under consideration. If None, then no attributes are + considered when testing for an isomorphism. + + """ + vf2.GraphMatcher.__init__(self, G1, G2) + + self.node_match = node_match + self.edge_match = edge_match + + # These will be modified during checks to minimize code repeat. + self.G1_adj = self.G1.adj + self.G2_adj = self.G2.adj + + semantic_feasibility = _semantic_feasibility + + +class DiGraphMatcher(vf2.DiGraphMatcher): + """VF2 isomorphism checker for directed graphs.""" + + def __init__(self, G1, G2, node_match=None, edge_match=None): + """Initialize graph matcher. + + Parameters + ---------- + G1, G2 : graph + The graphs to be tested. + + node_match : callable + A function that returns True iff node n1 in G1 and n2 in G2 + should be considered equal during the isomorphism test. The + function will be called like:: + + node_match(G1.nodes[n1], G2.nodes[n2]) + + That is, the function will receive the node attribute dictionaries + of the nodes under consideration. If None, then no attributes are + considered when testing for an isomorphism. + + edge_match : callable + A function that returns True iff the edge attribute dictionary for + the pair of nodes (u1, v1) in G1 and (u2, v2) in G2 should be + considered equal during the isomorphism test. The function will be + called like:: + + edge_match(G1[u1][v1], G2[u2][v2]) + + That is, the function will receive the edge attribute dictionaries + of the edges under consideration. If None, then no attributes are + considered when testing for an isomorphism. + + """ + vf2.DiGraphMatcher.__init__(self, G1, G2) + + self.node_match = node_match + self.edge_match = edge_match + + # These will be modified during checks to minimize code repeat. + self.G1_adj = self.G1.adj + self.G2_adj = self.G2.adj + + def semantic_feasibility(self, G1_node, G2_node): + """Returns True if mapping G1_node to G2_node is semantically feasible.""" + + # Test node_match and also test edge_match on successors + feasible = _semantic_feasibility(self, G1_node, G2_node) + if not feasible: + return False + + # Test edge_match on predecessors + self.G1_adj = self.G1.pred + self.G2_adj = self.G2.pred + feasible = _semantic_feasibility(self, G1_node, G2_node) + self.G1_adj = self.G1.adj + self.G2_adj = self.G2.adj + + return feasible + + +# The "semantics" of edge_match are different for multi(di)graphs, but +# the implementation is the same. So, technically we do not need to +# provide "multi" versions, but we do so to match NetworkX's base classes. + + +class MultiGraphMatcher(GraphMatcher): + """VF2 isomorphism checker for undirected multigraphs.""" + + +class MultiDiGraphMatcher(DiGraphMatcher): + """VF2 isomorphism checker for directed multigraphs.""" diff --git a/lib/python3.10/site-packages/networkx/algorithms/link_analysis/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/link_analysis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6009f000814753ab436278e2d2cc38e961e80f3f --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/link_analysis/__init__.py @@ -0,0 +1,2 @@ +from networkx.algorithms.link_analysis.hits_alg import * +from networkx.algorithms.link_analysis.pagerank_alg import * diff --git a/lib/python3.10/site-packages/networkx/algorithms/link_analysis/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/link_analysis/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a385a1f7e3d908ab99ead918db9321ab552285ee Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/link_analysis/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/link_analysis/__pycache__/hits_alg.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/link_analysis/__pycache__/hits_alg.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b492c59f85204e10e2836be577b12b0c0f90e2de Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/link_analysis/__pycache__/hits_alg.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/link_analysis/__pycache__/pagerank_alg.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/link_analysis/__pycache__/pagerank_alg.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e076b05feb009d09b4b4e6607f3db93dee7e23c Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/link_analysis/__pycache__/pagerank_alg.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/link_analysis/hits_alg.py b/lib/python3.10/site-packages/networkx/algorithms/link_analysis/hits_alg.py new file mode 100644 index 0000000000000000000000000000000000000000..d9e3069d86dcb0ad0149ce66c72161b5340e01e0 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/link_analysis/hits_alg.py @@ -0,0 +1,337 @@ +"""Hubs and authorities analysis of graph structure.""" + +import networkx as nx + +__all__ = ["hits"] + + +@nx._dispatchable(preserve_edge_attrs={"G": {"weight": 1}}) +def hits(G, max_iter=100, tol=1.0e-8, nstart=None, normalized=True): + """Returns HITS hubs and authorities values for nodes. + + The HITS algorithm computes two numbers for a node. + Authorities estimates the node value based on the incoming links. + Hubs estimates the node value based on outgoing links. + + Parameters + ---------- + G : graph + A NetworkX graph + + max_iter : integer, optional + Maximum number of iterations in power method. + + tol : float, optional + Error tolerance used to check convergence in power method iteration. + + nstart : dictionary, optional + Starting value of each node for power method iteration. + + normalized : bool (default=True) + Normalize results by the sum of all of the values. + + Returns + ------- + (hubs,authorities) : two-tuple of dictionaries + Two dictionaries keyed by node containing the hub and authority + values. + + Raises + ------ + PowerIterationFailedConvergence + If the algorithm fails to converge to the specified tolerance + within the specified number of iterations of the power iteration + method. + + Examples + -------- + >>> G = nx.path_graph(4) + >>> h, a = nx.hits(G) + + Notes + ----- + The eigenvector calculation is done by the power iteration method + and has no guarantee of convergence. The iteration will stop + after max_iter iterations or an error tolerance of + number_of_nodes(G)*tol has been reached. + + The HITS algorithm was designed for directed graphs but this + algorithm does not check if the input graph is directed and will + execute on undirected graphs. + + References + ---------- + .. [1] A. Langville and C. Meyer, + "A survey of eigenvector methods of web information retrieval." + http://citeseer.ist.psu.edu/713792.html + .. [2] Jon Kleinberg, + Authoritative sources in a hyperlinked environment + Journal of the ACM 46 (5): 604-32, 1999. + doi:10.1145/324133.324140. + http://www.cs.cornell.edu/home/kleinber/auth.pdf. + """ + import numpy as np + import scipy as sp + + if len(G) == 0: + return {}, {} + A = nx.adjacency_matrix(G, nodelist=list(G), dtype=float) + + if nstart is not None: + nstart = np.array(list(nstart.values())) + if max_iter <= 0: + raise nx.PowerIterationFailedConvergence(max_iter) + try: + _, _, vt = sp.sparse.linalg.svds(A, k=1, v0=nstart, maxiter=max_iter, tol=tol) + except sp.sparse.linalg.ArpackNoConvergence as exc: + raise nx.PowerIterationFailedConvergence(max_iter) from exc + + a = vt.flatten().real + h = A @ a + if normalized: + h /= h.sum() + a /= a.sum() + hubs = dict(zip(G, map(float, h))) + authorities = dict(zip(G, map(float, a))) + return hubs, authorities + + +def _hits_python(G, max_iter=100, tol=1.0e-8, nstart=None, normalized=True): + if isinstance(G, nx.MultiGraph | nx.MultiDiGraph): + raise Exception("hits() not defined for graphs with multiedges.") + if len(G) == 0: + return {}, {} + # choose fixed starting vector if not given + if nstart is None: + h = dict.fromkeys(G, 1.0 / G.number_of_nodes()) + else: + h = nstart + # normalize starting vector + s = 1.0 / sum(h.values()) + for k in h: + h[k] *= s + for _ in range(max_iter): # power iteration: make up to max_iter iterations + hlast = h + h = dict.fromkeys(hlast.keys(), 0) + a = dict.fromkeys(hlast.keys(), 0) + # this "matrix multiply" looks odd because it is + # doing a left multiply a^T=hlast^T*G + for n in h: + for nbr in G[n]: + a[nbr] += hlast[n] * G[n][nbr].get("weight", 1) + # now multiply h=Ga + for n in h: + for nbr in G[n]: + h[n] += a[nbr] * G[n][nbr].get("weight", 1) + # normalize vector + s = 1.0 / max(h.values()) + for n in h: + h[n] *= s + # normalize vector + s = 1.0 / max(a.values()) + for n in a: + a[n] *= s + # check convergence, l1 norm + err = sum(abs(h[n] - hlast[n]) for n in h) + if err < tol: + break + else: + raise nx.PowerIterationFailedConvergence(max_iter) + if normalized: + s = 1.0 / sum(a.values()) + for n in a: + a[n] *= s + s = 1.0 / sum(h.values()) + for n in h: + h[n] *= s + return h, a + + +def _hits_numpy(G, normalized=True): + """Returns HITS hubs and authorities values for nodes. + + The HITS algorithm computes two numbers for a node. + Authorities estimates the node value based on the incoming links. + Hubs estimates the node value based on outgoing links. + + Parameters + ---------- + G : graph + A NetworkX graph + + normalized : bool (default=True) + Normalize results by the sum of all of the values. + + Returns + ------- + (hubs,authorities) : two-tuple of dictionaries + Two dictionaries keyed by node containing the hub and authority + values. + + Examples + -------- + >>> G = nx.path_graph(4) + + The `hubs` and `authorities` are given by the eigenvectors corresponding to the + maximum eigenvalues of the hubs_matrix and the authority_matrix, respectively. + + The ``hubs`` and ``authority`` matrices are computed from the adjacency + matrix: + + >>> adj_ary = nx.to_numpy_array(G) + >>> hubs_matrix = adj_ary @ adj_ary.T + >>> authority_matrix = adj_ary.T @ adj_ary + + `_hits_numpy` maps the eigenvector corresponding to the maximum eigenvalue + of the respective matrices to the nodes in `G`: + + >>> from networkx.algorithms.link_analysis.hits_alg import _hits_numpy + >>> hubs, authority = _hits_numpy(G) + + Notes + ----- + The eigenvector calculation uses NumPy's interface to LAPACK. + + The HITS algorithm was designed for directed graphs but this + algorithm does not check if the input graph is directed and will + execute on undirected graphs. + + References + ---------- + .. [1] A. Langville and C. Meyer, + "A survey of eigenvector methods of web information retrieval." + http://citeseer.ist.psu.edu/713792.html + .. [2] Jon Kleinberg, + Authoritative sources in a hyperlinked environment + Journal of the ACM 46 (5): 604-32, 1999. + doi:10.1145/324133.324140. + http://www.cs.cornell.edu/home/kleinber/auth.pdf. + """ + import numpy as np + + if len(G) == 0: + return {}, {} + adj_ary = nx.to_numpy_array(G) + # Hub matrix + H = adj_ary @ adj_ary.T + e, ev = np.linalg.eig(H) + h = ev[:, np.argmax(e)] # eigenvector corresponding to the maximum eigenvalue + # Authority matrix + A = adj_ary.T @ adj_ary + e, ev = np.linalg.eig(A) + a = ev[:, np.argmax(e)] # eigenvector corresponding to the maximum eigenvalue + if normalized: + h /= h.sum() + a /= a.sum() + else: + h /= h.max() + a /= a.max() + hubs = dict(zip(G, map(float, h))) + authorities = dict(zip(G, map(float, a))) + return hubs, authorities + + +def _hits_scipy(G, max_iter=100, tol=1.0e-6, nstart=None, normalized=True): + """Returns HITS hubs and authorities values for nodes. + + + The HITS algorithm computes two numbers for a node. + Authorities estimates the node value based on the incoming links. + Hubs estimates the node value based on outgoing links. + + Parameters + ---------- + G : graph + A NetworkX graph + + max_iter : integer, optional + Maximum number of iterations in power method. + + tol : float, optional + Error tolerance used to check convergence in power method iteration. + + nstart : dictionary, optional + Starting value of each node for power method iteration. + + normalized : bool (default=True) + Normalize results by the sum of all of the values. + + Returns + ------- + (hubs,authorities) : two-tuple of dictionaries + Two dictionaries keyed by node containing the hub and authority + values. + + Examples + -------- + >>> from networkx.algorithms.link_analysis.hits_alg import _hits_scipy + >>> G = nx.path_graph(4) + >>> h, a = _hits_scipy(G) + + Notes + ----- + This implementation uses SciPy sparse matrices. + + The eigenvector calculation is done by the power iteration method + and has no guarantee of convergence. The iteration will stop + after max_iter iterations or an error tolerance of + number_of_nodes(G)*tol has been reached. + + The HITS algorithm was designed for directed graphs but this + algorithm does not check if the input graph is directed and will + execute on undirected graphs. + + Raises + ------ + PowerIterationFailedConvergence + If the algorithm fails to converge to the specified tolerance + within the specified number of iterations of the power iteration + method. + + References + ---------- + .. [1] A. Langville and C. Meyer, + "A survey of eigenvector methods of web information retrieval." + http://citeseer.ist.psu.edu/713792.html + .. [2] Jon Kleinberg, + Authoritative sources in a hyperlinked environment + Journal of the ACM 46 (5): 604-632, 1999. + doi:10.1145/324133.324140. + http://www.cs.cornell.edu/home/kleinber/auth.pdf. + """ + import numpy as np + + if len(G) == 0: + return {}, {} + A = nx.to_scipy_sparse_array(G, nodelist=list(G)) + (n, _) = A.shape # should be square + ATA = A.T @ A # authority matrix + # choose fixed starting vector if not given + if nstart is None: + x = np.ones((n, 1)) / n + else: + x = np.array([nstart.get(n, 0) for n in list(G)], dtype=float) + x /= x.sum() + + # power iteration on authority matrix + i = 0 + while True: + xlast = x + x = ATA @ x + x /= x.max() + # check convergence, l1 norm + err = np.absolute(x - xlast).sum() + if err < tol: + break + if i > max_iter: + raise nx.PowerIterationFailedConvergence(max_iter) + i += 1 + + a = x.flatten() + h = A @ a + if normalized: + h /= h.sum() + a /= a.sum() + hubs = dict(zip(G, map(float, h))) + authorities = dict(zip(G, map(float, a))) + return hubs, authorities diff --git a/lib/python3.10/site-packages/networkx/algorithms/link_analysis/pagerank_alg.py b/lib/python3.10/site-packages/networkx/algorithms/link_analysis/pagerank_alg.py new file mode 100644 index 0000000000000000000000000000000000000000..de9f95bad50ca30ce7aec650df9110cbaaf9e283 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/link_analysis/pagerank_alg.py @@ -0,0 +1,500 @@ +"""PageRank analysis of graph structure.""" + +from warnings import warn + +import networkx as nx + +__all__ = ["pagerank", "google_matrix"] + + +@nx._dispatchable(edge_attrs="weight") +def pagerank( + G, + alpha=0.85, + personalization=None, + max_iter=100, + tol=1.0e-6, + nstart=None, + weight="weight", + dangling=None, +): + """Returns the PageRank of the nodes in the graph. + + PageRank computes a ranking of the nodes in the graph G based on + the structure of the incoming links. It was originally designed as + an algorithm to rank web pages. + + Parameters + ---------- + G : graph + A NetworkX graph. Undirected graphs will be converted to a directed + graph with two directed edges for each undirected edge. + + alpha : float, optional + Damping parameter for PageRank, default=0.85. + + personalization: dict, optional + The "personalization vector" consisting of a dictionary with a + key some subset of graph nodes and personalization value each of those. + At least one personalization value must be non-zero. + If not specified, a nodes personalization value will be zero. + By default, a uniform distribution is used. + + max_iter : integer, optional + Maximum number of iterations in power method eigenvalue solver. + + tol : float, optional + Error tolerance used to check convergence in power method solver. + The iteration will stop after a tolerance of ``len(G) * tol`` is reached. + + nstart : dictionary, optional + Starting value of PageRank iteration for each node. + + weight : key, optional + Edge data key to use as weight. If None weights are set to 1. + + dangling: dict, optional + The outedges to be assigned to any "dangling" nodes, i.e., nodes without + any outedges. The dict key is the node the outedge points to and the dict + value is the weight of that outedge. By default, dangling nodes are given + outedges according to the personalization vector (uniform if not + specified). This must be selected to result in an irreducible transition + matrix (see notes under google_matrix). It may be common to have the + dangling dict to be the same as the personalization dict. + + + Returns + ------- + pagerank : dictionary + Dictionary of nodes with PageRank as value + + Examples + -------- + >>> G = nx.DiGraph(nx.path_graph(4)) + >>> pr = nx.pagerank(G, alpha=0.9) + + Notes + ----- + The eigenvector calculation is done by the power iteration method + and has no guarantee of convergence. The iteration will stop after + an error tolerance of ``len(G) * tol`` has been reached. If the + number of iterations exceed `max_iter`, a + :exc:`networkx.exception.PowerIterationFailedConvergence` exception + is raised. + + The PageRank algorithm was designed for directed graphs but this + algorithm does not check if the input graph is directed and will + execute on undirected graphs by converting each edge in the + directed graph to two edges. + + See Also + -------- + google_matrix + + Raises + ------ + PowerIterationFailedConvergence + If the algorithm fails to converge to the specified tolerance + within the specified number of iterations of the power iteration + method. + + References + ---------- + .. [1] A. Langville and C. Meyer, + "A survey of eigenvector methods of web information retrieval." + http://citeseer.ist.psu.edu/713792.html + .. [2] Page, Lawrence; Brin, Sergey; Motwani, Rajeev and Winograd, Terry, + The PageRank citation ranking: Bringing order to the Web. 1999 + http://dbpubs.stanford.edu:8090/pub/showDoc.Fulltext?lang=en&doc=1999-66&format=pdf + + """ + return _pagerank_scipy( + G, alpha, personalization, max_iter, tol, nstart, weight, dangling + ) + + +def _pagerank_python( + G, + alpha=0.85, + personalization=None, + max_iter=100, + tol=1.0e-6, + nstart=None, + weight="weight", + dangling=None, +): + if len(G) == 0: + return {} + + D = G.to_directed() + + # Create a copy in (right) stochastic form + W = nx.stochastic_graph(D, weight=weight) + N = W.number_of_nodes() + + # Choose fixed starting vector if not given + if nstart is None: + x = dict.fromkeys(W, 1.0 / N) + else: + # Normalized nstart vector + s = sum(nstart.values()) + x = {k: v / s for k, v in nstart.items()} + + if personalization is None: + # Assign uniform personalization vector if not given + p = dict.fromkeys(W, 1.0 / N) + else: + s = sum(personalization.values()) + p = {k: v / s for k, v in personalization.items()} + + if dangling is None: + # Use personalization vector if dangling vector not specified + dangling_weights = p + else: + s = sum(dangling.values()) + dangling_weights = {k: v / s for k, v in dangling.items()} + dangling_nodes = [n for n in W if W.out_degree(n, weight=weight) == 0.0] + + # power iteration: make up to max_iter iterations + for _ in range(max_iter): + xlast = x + x = dict.fromkeys(xlast.keys(), 0) + danglesum = alpha * sum(xlast[n] for n in dangling_nodes) + for n in x: + # this matrix multiply looks odd because it is + # doing a left multiply x^T=xlast^T*W + for _, nbr, wt in W.edges(n, data=weight): + x[nbr] += alpha * xlast[n] * wt + x[n] += danglesum * dangling_weights.get(n, 0) + (1.0 - alpha) * p.get(n, 0) + # check convergence, l1 norm + err = sum(abs(x[n] - xlast[n]) for n in x) + if err < N * tol: + return x + raise nx.PowerIterationFailedConvergence(max_iter) + + +@nx._dispatchable(edge_attrs="weight") +def google_matrix( + G, alpha=0.85, personalization=None, nodelist=None, weight="weight", dangling=None +): + """Returns the Google matrix of the graph. + + Parameters + ---------- + G : graph + A NetworkX graph. Undirected graphs will be converted to a directed + graph with two directed edges for each undirected edge. + + alpha : float + The damping factor. + + personalization: dict, optional + The "personalization vector" consisting of a dictionary with a + key some subset of graph nodes and personalization value each of those. + At least one personalization value must be non-zero. + If not specified, a nodes personalization value will be zero. + By default, a uniform distribution is used. + + nodelist : list, optional + The rows and columns are ordered according to the nodes in nodelist. + If nodelist is None, then the ordering is produced by G.nodes(). + + weight : key, optional + Edge data key to use as weight. If None weights are set to 1. + + dangling: dict, optional + The outedges to be assigned to any "dangling" nodes, i.e., nodes without + any outedges. The dict key is the node the outedge points to and the dict + value is the weight of that outedge. By default, dangling nodes are given + outedges according to the personalization vector (uniform if not + specified) This must be selected to result in an irreducible transition + matrix (see notes below). It may be common to have the dangling dict to + be the same as the personalization dict. + + Returns + ------- + A : 2D NumPy ndarray + Google matrix of the graph + + Notes + ----- + The array returned represents the transition matrix that describes the + Markov chain used in PageRank. For PageRank to converge to a unique + solution (i.e., a unique stationary distribution in a Markov chain), the + transition matrix must be irreducible. In other words, it must be that + there exists a path between every pair of nodes in the graph, or else there + is the potential of "rank sinks." + + This implementation works with Multi(Di)Graphs. For multigraphs the + weight between two nodes is set to be the sum of all edge weights + between those nodes. + + See Also + -------- + pagerank + """ + import numpy as np + + if nodelist is None: + nodelist = list(G) + + A = nx.to_numpy_array(G, nodelist=nodelist, weight=weight) + N = len(G) + if N == 0: + return A + + # Personalization vector + if personalization is None: + p = np.repeat(1.0 / N, N) + else: + p = np.array([personalization.get(n, 0) for n in nodelist], dtype=float) + if p.sum() == 0: + raise ZeroDivisionError + p /= p.sum() + + # Dangling nodes + if dangling is None: + dangling_weights = p + else: + # Convert the dangling dictionary into an array in nodelist order + dangling_weights = np.array([dangling.get(n, 0) for n in nodelist], dtype=float) + dangling_weights /= dangling_weights.sum() + dangling_nodes = np.where(A.sum(axis=1) == 0)[0] + + # Assign dangling_weights to any dangling nodes (nodes with no out links) + A[dangling_nodes] = dangling_weights + + A /= A.sum(axis=1)[:, np.newaxis] # Normalize rows to sum to 1 + + return alpha * A + (1 - alpha) * p + + +def _pagerank_numpy( + G, alpha=0.85, personalization=None, weight="weight", dangling=None +): + """Returns the PageRank of the nodes in the graph. + + PageRank computes a ranking of the nodes in the graph G based on + the structure of the incoming links. It was originally designed as + an algorithm to rank web pages. + + Parameters + ---------- + G : graph + A NetworkX graph. Undirected graphs will be converted to a directed + graph with two directed edges for each undirected edge. + + alpha : float, optional + Damping parameter for PageRank, default=0.85. + + personalization: dict, optional + The "personalization vector" consisting of a dictionary with a + key some subset of graph nodes and personalization value each of those. + At least one personalization value must be non-zero. + If not specified, a nodes personalization value will be zero. + By default, a uniform distribution is used. + + weight : key, optional + Edge data key to use as weight. If None weights are set to 1. + + dangling: dict, optional + The outedges to be assigned to any "dangling" nodes, i.e., nodes without + any outedges. The dict key is the node the outedge points to and the dict + value is the weight of that outedge. By default, dangling nodes are given + outedges according to the personalization vector (uniform if not + specified) This must be selected to result in an irreducible transition + matrix (see notes under google_matrix). It may be common to have the + dangling dict to be the same as the personalization dict. + + Returns + ------- + pagerank : dictionary + Dictionary of nodes with PageRank as value. + + Examples + -------- + >>> from networkx.algorithms.link_analysis.pagerank_alg import _pagerank_numpy + >>> G = nx.DiGraph(nx.path_graph(4)) + >>> pr = _pagerank_numpy(G, alpha=0.9) + + Notes + ----- + The eigenvector calculation uses NumPy's interface to the LAPACK + eigenvalue solvers. This will be the fastest and most accurate + for small graphs. + + This implementation works with Multi(Di)Graphs. For multigraphs the + weight between two nodes is set to be the sum of all edge weights + between those nodes. + + See Also + -------- + pagerank, google_matrix + + References + ---------- + .. [1] A. Langville and C. Meyer, + "A survey of eigenvector methods of web information retrieval." + http://citeseer.ist.psu.edu/713792.html + .. [2] Page, Lawrence; Brin, Sergey; Motwani, Rajeev and Winograd, Terry, + The PageRank citation ranking: Bringing order to the Web. 1999 + http://dbpubs.stanford.edu:8090/pub/showDoc.Fulltext?lang=en&doc=1999-66&format=pdf + """ + import numpy as np + + if len(G) == 0: + return {} + M = google_matrix( + G, alpha, personalization=personalization, weight=weight, dangling=dangling + ) + # use numpy LAPACK solver + eigenvalues, eigenvectors = np.linalg.eig(M.T) + ind = np.argmax(eigenvalues) + # eigenvector of largest eigenvalue is at ind, normalized + largest = np.array(eigenvectors[:, ind]).flatten().real + norm = largest.sum() + return dict(zip(G, map(float, largest / norm))) + + +def _pagerank_scipy( + G, + alpha=0.85, + personalization=None, + max_iter=100, + tol=1.0e-6, + nstart=None, + weight="weight", + dangling=None, +): + """Returns the PageRank of the nodes in the graph. + + PageRank computes a ranking of the nodes in the graph G based on + the structure of the incoming links. It was originally designed as + an algorithm to rank web pages. + + Parameters + ---------- + G : graph + A NetworkX graph. Undirected graphs will be converted to a directed + graph with two directed edges for each undirected edge. + + alpha : float, optional + Damping parameter for PageRank, default=0.85. + + personalization: dict, optional + The "personalization vector" consisting of a dictionary with a + key some subset of graph nodes and personalization value each of those. + At least one personalization value must be non-zero. + If not specified, a nodes personalization value will be zero. + By default, a uniform distribution is used. + + max_iter : integer, optional + Maximum number of iterations in power method eigenvalue solver. + + tol : float, optional + Error tolerance used to check convergence in power method solver. + The iteration will stop after a tolerance of ``len(G) * tol`` is reached. + + nstart : dictionary, optional + Starting value of PageRank iteration for each node. + + weight : key, optional + Edge data key to use as weight. If None weights are set to 1. + + dangling: dict, optional + The outedges to be assigned to any "dangling" nodes, i.e., nodes without + any outedges. The dict key is the node the outedge points to and the dict + value is the weight of that outedge. By default, dangling nodes are given + outedges according to the personalization vector (uniform if not + specified) This must be selected to result in an irreducible transition + matrix (see notes under google_matrix). It may be common to have the + dangling dict to be the same as the personalization dict. + + Returns + ------- + pagerank : dictionary + Dictionary of nodes with PageRank as value + + Examples + -------- + >>> from networkx.algorithms.link_analysis.pagerank_alg import _pagerank_scipy + >>> G = nx.DiGraph(nx.path_graph(4)) + >>> pr = _pagerank_scipy(G, alpha=0.9) + + Notes + ----- + The eigenvector calculation uses power iteration with a SciPy + sparse matrix representation. + + This implementation works with Multi(Di)Graphs. For multigraphs the + weight between two nodes is set to be the sum of all edge weights + between those nodes. + + See Also + -------- + pagerank + + Raises + ------ + PowerIterationFailedConvergence + If the algorithm fails to converge to the specified tolerance + within the specified number of iterations of the power iteration + method. + + References + ---------- + .. [1] A. Langville and C. Meyer, + "A survey of eigenvector methods of web information retrieval." + http://citeseer.ist.psu.edu/713792.html + .. [2] Page, Lawrence; Brin, Sergey; Motwani, Rajeev and Winograd, Terry, + The PageRank citation ranking: Bringing order to the Web. 1999 + http://dbpubs.stanford.edu:8090/pub/showDoc.Fulltext?lang=en&doc=1999-66&format=pdf + """ + import numpy as np + import scipy as sp + + N = len(G) + if N == 0: + return {} + + nodelist = list(G) + A = nx.to_scipy_sparse_array(G, nodelist=nodelist, weight=weight, dtype=float) + S = A.sum(axis=1) + S[S != 0] = 1.0 / S[S != 0] + # TODO: csr_array + Q = sp.sparse.csr_array(sp.sparse.spdiags(S.T, 0, *A.shape)) + A = Q @ A + + # initial vector + if nstart is None: + x = np.repeat(1.0 / N, N) + else: + x = np.array([nstart.get(n, 0) for n in nodelist], dtype=float) + x /= x.sum() + + # Personalization vector + if personalization is None: + p = np.repeat(1.0 / N, N) + else: + p = np.array([personalization.get(n, 0) for n in nodelist], dtype=float) + if p.sum() == 0: + raise ZeroDivisionError + p /= p.sum() + # Dangling nodes + if dangling is None: + dangling_weights = p + else: + # Convert the dangling dictionary into an array in nodelist order + dangling_weights = np.array([dangling.get(n, 0) for n in nodelist], dtype=float) + dangling_weights /= dangling_weights.sum() + is_dangling = np.where(S == 0)[0] + + # power iteration: make up to max_iter iterations + for _ in range(max_iter): + xlast = x + x = alpha * (x @ A + sum(x[is_dangling]) * dangling_weights) + (1 - alpha) * p + # check convergence, l1 norm + err = np.absolute(x - xlast).sum() + if err < N * tol: + return dict(zip(nodelist, map(float, x))) + raise nx.PowerIterationFailedConvergence(max_iter) diff --git a/lib/python3.10/site-packages/networkx/algorithms/link_analysis/tests/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/link_analysis/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/networkx/algorithms/link_analysis/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/link_analysis/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be6b6748075a3027a6ca516a99d729304229166a Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/link_analysis/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/link_analysis/tests/__pycache__/test_hits.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/link_analysis/tests/__pycache__/test_hits.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d351863eab9ea75fbe66bb941e90d0ef14ee6e9 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/link_analysis/tests/__pycache__/test_hits.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/link_analysis/tests/__pycache__/test_pagerank.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/link_analysis/tests/__pycache__/test_pagerank.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..082c348707ca2391ef825bfc32b46929be8ce81e Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/link_analysis/tests/__pycache__/test_pagerank.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/link_analysis/tests/test_hits.py b/lib/python3.10/site-packages/networkx/algorithms/link_analysis/tests/test_hits.py new file mode 100644 index 0000000000000000000000000000000000000000..54713eb4e2e160e13372f462ef37aff2b4c0f179 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/link_analysis/tests/test_hits.py @@ -0,0 +1,78 @@ +import pytest + +import networkx as nx + +np = pytest.importorskip("numpy") +sp = pytest.importorskip("scipy") + +from networkx.algorithms.link_analysis.hits_alg import ( + _hits_numpy, + _hits_python, + _hits_scipy, +) + +# Example from +# A. Langville and C. Meyer, "A survey of eigenvector methods of web +# information retrieval." http://citeseer.ist.psu.edu/713792.html + + +class TestHITS: + @classmethod + def setup_class(cls): + G = nx.DiGraph() + + edges = [(1, 3), (1, 5), (2, 1), (3, 5), (5, 4), (5, 3), (6, 5)] + + G.add_edges_from(edges, weight=1) + cls.G = G + cls.G.a = dict( + zip(sorted(G), [0.000000, 0.000000, 0.366025, 0.133975, 0.500000, 0.000000]) + ) + cls.G.h = dict( + zip(sorted(G), [0.366025, 0.000000, 0.211325, 0.000000, 0.211325, 0.211325]) + ) + + def test_hits_numpy(self): + G = self.G + h, a = _hits_numpy(G) + for n in G: + assert h[n] == pytest.approx(G.h[n], abs=1e-4) + for n in G: + assert a[n] == pytest.approx(G.a[n], abs=1e-4) + + @pytest.mark.parametrize("hits_alg", (nx.hits, _hits_python, _hits_scipy)) + def test_hits(self, hits_alg): + G = self.G + h, a = hits_alg(G, tol=1.0e-08) + for n in G: + assert h[n] == pytest.approx(G.h[n], abs=1e-4) + for n in G: + assert a[n] == pytest.approx(G.a[n], abs=1e-4) + nstart = {i: 1.0 / 2 for i in G} + h, a = hits_alg(G, nstart=nstart) + for n in G: + assert h[n] == pytest.approx(G.h[n], abs=1e-4) + for n in G: + assert a[n] == pytest.approx(G.a[n], abs=1e-4) + + def test_empty(self): + G = nx.Graph() + assert nx.hits(G) == ({}, {}) + assert _hits_numpy(G) == ({}, {}) + assert _hits_python(G) == ({}, {}) + assert _hits_scipy(G) == ({}, {}) + + def test_hits_not_convergent(self): + G = nx.path_graph(50) + with pytest.raises(nx.PowerIterationFailedConvergence): + _hits_scipy(G, max_iter=1) + with pytest.raises(nx.PowerIterationFailedConvergence): + _hits_python(G, max_iter=1) + with pytest.raises(nx.PowerIterationFailedConvergence): + _hits_scipy(G, max_iter=0) + with pytest.raises(nx.PowerIterationFailedConvergence): + _hits_python(G, max_iter=0) + with pytest.raises(nx.PowerIterationFailedConvergence): + nx.hits(G, max_iter=0) + with pytest.raises(nx.PowerIterationFailedConvergence): + nx.hits(G, max_iter=1) diff --git a/lib/python3.10/site-packages/networkx/algorithms/link_analysis/tests/test_pagerank.py b/lib/python3.10/site-packages/networkx/algorithms/link_analysis/tests/test_pagerank.py new file mode 100644 index 0000000000000000000000000000000000000000..44038fd47f4408c6fcf68c78b4ef21fe1b548149 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/link_analysis/tests/test_pagerank.py @@ -0,0 +1,214 @@ +import random + +import pytest + +import networkx as nx +from networkx.classes.tests import dispatch_interface + +np = pytest.importorskip("numpy") +pytest.importorskip("scipy") + +from networkx.algorithms.link_analysis.pagerank_alg import ( + _pagerank_numpy, + _pagerank_python, + _pagerank_scipy, +) + +# Example from +# A. Langville and C. Meyer, "A survey of eigenvector methods of web +# information retrieval." http://citeseer.ist.psu.edu/713792.html + + +class TestPageRank: + @classmethod + def setup_class(cls): + G = nx.DiGraph() + edges = [ + (1, 2), + (1, 3), + # 2 is a dangling node + (3, 1), + (3, 2), + (3, 5), + (4, 5), + (4, 6), + (5, 4), + (5, 6), + (6, 4), + ] + G.add_edges_from(edges) + cls.G = G + cls.G.pagerank = dict( + zip( + sorted(G), + [ + 0.03721197, + 0.05395735, + 0.04150565, + 0.37508082, + 0.20599833, + 0.28624589, + ], + ) + ) + cls.dangling_node_index = 1 + cls.dangling_edges = {1: 2, 2: 3, 3: 0, 4: 0, 5: 0, 6: 0} + cls.G.dangling_pagerank = dict( + zip( + sorted(G), + [0.10844518, 0.18618601, 0.0710892, 0.2683668, 0.15919783, 0.20671497], + ) + ) + + @pytest.mark.parametrize("alg", (nx.pagerank, _pagerank_python)) + def test_pagerank(self, alg): + G = self.G + p = alg(G, alpha=0.9, tol=1.0e-08) + for n in G: + assert p[n] == pytest.approx(G.pagerank[n], abs=1e-4) + + nstart = {n: random.random() for n in G} + p = alg(G, alpha=0.9, tol=1.0e-08, nstart=nstart) + for n in G: + assert p[n] == pytest.approx(G.pagerank[n], abs=1e-4) + + @pytest.mark.parametrize("alg", (nx.pagerank, _pagerank_python)) + def test_pagerank_max_iter(self, alg): + with pytest.raises(nx.PowerIterationFailedConvergence): + alg(self.G, max_iter=0) + + def test_numpy_pagerank(self): + G = self.G + p = _pagerank_numpy(G, alpha=0.9) + for n in G: + assert p[n] == pytest.approx(G.pagerank[n], abs=1e-4) + + def test_google_matrix(self): + G = self.G + M = nx.google_matrix(G, alpha=0.9, nodelist=sorted(G)) + _, ev = np.linalg.eig(M.T) + p = ev[:, 0] / ev[:, 0].sum() + for a, b in zip(p, self.G.pagerank.values()): + assert a == pytest.approx(b, abs=1e-7) + + @pytest.mark.parametrize("alg", (nx.pagerank, _pagerank_python, _pagerank_numpy)) + def test_personalization(self, alg): + G = nx.complete_graph(4) + personalize = {0: 1, 1: 1, 2: 4, 3: 4} + answer = { + 0: 0.23246732615667579, + 1: 0.23246732615667579, + 2: 0.267532673843324, + 3: 0.2675326738433241, + } + p = alg(G, alpha=0.85, personalization=personalize) + for n in G: + assert p[n] == pytest.approx(answer[n], abs=1e-4) + + @pytest.mark.parametrize("alg", (nx.pagerank, _pagerank_python, nx.google_matrix)) + def test_zero_personalization_vector(self, alg): + G = nx.complete_graph(4) + personalize = {0: 0, 1: 0, 2: 0, 3: 0} + pytest.raises(ZeroDivisionError, alg, G, personalization=personalize) + + @pytest.mark.parametrize("alg", (nx.pagerank, _pagerank_python)) + def test_one_nonzero_personalization_value(self, alg): + G = nx.complete_graph(4) + personalize = {0: 0, 1: 0, 2: 0, 3: 1} + answer = { + 0: 0.22077931820379187, + 1: 0.22077931820379187, + 2: 0.22077931820379187, + 3: 0.3376620453886241, + } + p = alg(G, alpha=0.85, personalization=personalize) + for n in G: + assert p[n] == pytest.approx(answer[n], abs=1e-4) + + @pytest.mark.parametrize("alg", (nx.pagerank, _pagerank_python)) + def test_incomplete_personalization(self, alg): + G = nx.complete_graph(4) + personalize = {3: 1} + answer = { + 0: 0.22077931820379187, + 1: 0.22077931820379187, + 2: 0.22077931820379187, + 3: 0.3376620453886241, + } + p = alg(G, alpha=0.85, personalization=personalize) + for n in G: + assert p[n] == pytest.approx(answer[n], abs=1e-4) + + def test_dangling_matrix(self): + """ + Tests that the google_matrix doesn't change except for the dangling + nodes. + """ + G = self.G + dangling = self.dangling_edges + dangling_sum = sum(dangling.values()) + M1 = nx.google_matrix(G, personalization=dangling) + M2 = nx.google_matrix(G, personalization=dangling, dangling=dangling) + for i in range(len(G)): + for j in range(len(G)): + if i == self.dangling_node_index and (j + 1) in dangling: + assert M2[i, j] == pytest.approx( + dangling[j + 1] / dangling_sum, abs=1e-4 + ) + else: + assert M2[i, j] == pytest.approx(M1[i, j], abs=1e-4) + + @pytest.mark.parametrize("alg", (nx.pagerank, _pagerank_python, _pagerank_numpy)) + def test_dangling_pagerank(self, alg): + pr = alg(self.G, dangling=self.dangling_edges) + for n in self.G: + assert pr[n] == pytest.approx(self.G.dangling_pagerank[n], abs=1e-4) + + def test_empty(self): + G = nx.Graph() + assert nx.pagerank(G) == {} + assert _pagerank_python(G) == {} + assert _pagerank_numpy(G) == {} + assert nx.google_matrix(G).shape == (0, 0) + + @pytest.mark.parametrize("alg", (nx.pagerank, _pagerank_python)) + def test_multigraph(self, alg): + G = nx.MultiGraph() + G.add_edges_from([(1, 2), (1, 2), (1, 2), (2, 3), (2, 3), ("3", 3), ("3", 3)]) + answer = { + 1: 0.21066048614468322, + 2: 0.3395308825985378, + 3: 0.28933951385531687, + "3": 0.16046911740146227, + } + p = alg(G) + for n in G: + assert p[n] == pytest.approx(answer[n], abs=1e-4) + + +class TestPageRankScipy(TestPageRank): + def test_scipy_pagerank(self): + G = self.G + p = _pagerank_scipy(G, alpha=0.9, tol=1.0e-08) + for n in G: + assert p[n] == pytest.approx(G.pagerank[n], abs=1e-4) + personalize = {n: random.random() for n in G} + p = _pagerank_scipy(G, alpha=0.9, tol=1.0e-08, personalization=personalize) + + nstart = {n: random.random() for n in G} + p = _pagerank_scipy(G, alpha=0.9, tol=1.0e-08, nstart=nstart) + for n in G: + assert p[n] == pytest.approx(G.pagerank[n], abs=1e-4) + + def test_scipy_pagerank_max_iter(self): + with pytest.raises(nx.PowerIterationFailedConvergence): + _pagerank_scipy(self.G, max_iter=0) + + def test_dangling_scipy_pagerank(self): + pr = _pagerank_scipy(self.G, dangling=self.dangling_edges) + for n in self.G: + assert pr[n] == pytest.approx(self.G.dangling_pagerank[n], abs=1e-4) + + def test_empty_scipy(self): + G = nx.Graph() + assert _pagerank_scipy(G) == {} diff --git a/lib/python3.10/site-packages/networkx/algorithms/minors/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/minors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cf15ddb592541a959149842a4d581cf9f0a3e5e1 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/minors/__init__.py @@ -0,0 +1,27 @@ +""" +Subpackages related to graph-minor problems. + +In graph theory, an undirected graph H is called a minor of the graph G if H +can be formed from G by deleting edges and vertices and by contracting edges +[1]_. + +References +---------- +.. [1] https://en.wikipedia.org/wiki/Graph_minor +""" + +from networkx.algorithms.minors.contraction import ( + contracted_edge, + contracted_nodes, + equivalence_classes, + identified_nodes, + quotient_graph, +) + +__all__ = [ + "contracted_edge", + "contracted_nodes", + "equivalence_classes", + "identified_nodes", + "quotient_graph", +] diff --git a/lib/python3.10/site-packages/networkx/algorithms/minors/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/minors/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b95e746d4078785cc062a55a92be1ad64b9003f1 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/minors/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/minors/__pycache__/contraction.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/minors/__pycache__/contraction.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7afaf8530d9cc90eca5ec85073a1270bb589918 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/minors/__pycache__/contraction.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/minors/contraction.py b/lib/python3.10/site-packages/networkx/algorithms/minors/contraction.py new file mode 100644 index 0000000000000000000000000000000000000000..e85b577843650df3f7f02fd1e60c6fc0800409c6 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/minors/contraction.py @@ -0,0 +1,634 @@ +"""Provides functions for computing minors of a graph.""" + +from itertools import chain, combinations, permutations, product + +import networkx as nx +from networkx import density +from networkx.exception import NetworkXException +from networkx.utils import arbitrary_element + +__all__ = [ + "contracted_edge", + "contracted_nodes", + "equivalence_classes", + "identified_nodes", + "quotient_graph", +] + +chaini = chain.from_iterable + + +def equivalence_classes(iterable, relation): + """Returns equivalence classes of `relation` when applied to `iterable`. + + The equivalence classes, or blocks, consist of objects from `iterable` + which are all equivalent. They are defined to be equivalent if the + `relation` function returns `True` when passed any two objects from that + class, and `False` otherwise. To define an equivalence relation the + function must be reflexive, symmetric and transitive. + + Parameters + ---------- + iterable : list, tuple, or set + An iterable of elements/nodes. + + relation : function + A Boolean-valued function that implements an equivalence relation + (reflexive, symmetric, transitive binary relation) on the elements + of `iterable` - it must take two elements and return `True` if + they are related, or `False` if not. + + Returns + ------- + set of frozensets + A set of frozensets representing the partition induced by the equivalence + relation function `relation` on the elements of `iterable`. Each + member set in the return set represents an equivalence class, or + block, of the partition. + + Duplicate elements will be ignored so it makes the most sense for + `iterable` to be a :class:`set`. + + Notes + ----- + This function does not check that `relation` represents an equivalence + relation. You can check that your equivalence classes provide a partition + using `is_partition`. + + Examples + -------- + Let `X` be the set of integers from `0` to `9`, and consider an equivalence + relation `R` on `X` of congruence modulo `3`: this means that two integers + `x` and `y` in `X` are equivalent under `R` if they leave the same + remainder when divided by `3`, i.e. `(x - y) mod 3 = 0`. + + The equivalence classes of this relation are `{0, 3, 6, 9}`, `{1, 4, 7}`, + `{2, 5, 8}`: `0`, `3`, `6`, `9` are all divisible by `3` and leave zero + remainder; `1`, `4`, `7` leave remainder `1`; while `2`, `5` and `8` leave + remainder `2`. We can see this by calling `equivalence_classes` with + `X` and a function implementation of `R`. + + >>> X = set(range(10)) + >>> def mod3(x, y): + ... return (x - y) % 3 == 0 + >>> equivalence_classes(X, mod3) # doctest: +SKIP + {frozenset({1, 4, 7}), frozenset({8, 2, 5}), frozenset({0, 9, 3, 6})} + """ + # For simplicity of implementation, we initialize the return value as a + # list of lists, then convert it to a set of sets at the end of the + # function. + blocks = [] + # Determine the equivalence class for each element of the iterable. + for y in iterable: + # Each element y must be in *exactly one* equivalence class. + # + # Each block is guaranteed to be non-empty + for block in blocks: + x = arbitrary_element(block) + if relation(x, y): + block.append(y) + break + else: + # If the element y is not part of any known equivalence class, it + # must be in its own, so we create a new singleton equivalence + # class for it. + blocks.append([y]) + return {frozenset(block) for block in blocks} + + +@nx._dispatchable(edge_attrs="weight", returns_graph=True) +def quotient_graph( + G, + partition, + edge_relation=None, + node_data=None, + edge_data=None, + weight="weight", + relabel=False, + create_using=None, +): + """Returns the quotient graph of `G` under the specified equivalence + relation on nodes. + + Parameters + ---------- + G : NetworkX graph + The graph for which to return the quotient graph with the + specified node relation. + + partition : function, or dict or list of lists, tuples or sets + If a function, this function must represent an equivalence + relation on the nodes of `G`. It must take two arguments *u* + and *v* and return True exactly when *u* and *v* are in the + same equivalence class. The equivalence classes form the nodes + in the returned graph. + + If a dict of lists/tuples/sets, the keys can be any meaningful + block labels, but the values must be the block lists/tuples/sets + (one list/tuple/set per block), and the blocks must form a valid + partition of the nodes of the graph. That is, each node must be + in exactly one block of the partition. + + If a list of sets, the list must form a valid partition of + the nodes of the graph. That is, each node must be in exactly + one block of the partition. + + edge_relation : Boolean function with two arguments + This function must represent an edge relation on the *blocks* of + the `partition` of `G`. It must take two arguments, *B* and *C*, + each one a set of nodes, and return True exactly when there should be + an edge joining block *B* to block *C* in the returned graph. + + If `edge_relation` is not specified, it is assumed to be the + following relation. Block *B* is related to block *C* if and + only if some node in *B* is adjacent to some node in *C*, + according to the edge set of `G`. + + node_data : function + This function takes one argument, *B*, a set of nodes in `G`, + and must return a dictionary representing the node data + attributes to set on the node representing *B* in the quotient graph. + If None, the following node attributes will be set: + + * 'graph', the subgraph of the graph `G` that this block + represents, + * 'nnodes', the number of nodes in this block, + * 'nedges', the number of edges within this block, + * 'density', the density of the subgraph of `G` that this + block represents. + + edge_data : function + This function takes two arguments, *B* and *C*, each one a set + of nodes, and must return a dictionary representing the edge + data attributes to set on the edge joining *B* and *C*, should + there be an edge joining *B* and *C* in the quotient graph (if + no such edge occurs in the quotient graph as determined by + `edge_relation`, then the output of this function is ignored). + + If the quotient graph would be a multigraph, this function is + not applied, since the edge data from each edge in the graph + `G` appears in the edges of the quotient graph. + + weight : string or None, optional (default="weight") + The name of an edge attribute that holds the numerical value + used as a weight. If None then each edge has weight 1. + + relabel : bool + If True, relabel the nodes of the quotient graph to be + nonnegative integers. Otherwise, the nodes are identified with + :class:`frozenset` instances representing the blocks given in + `partition`. + + create_using : NetworkX graph constructor, optional (default=nx.Graph) + Graph type to create. If graph instance, then cleared before populated. + + Returns + ------- + NetworkX graph + The quotient graph of `G` under the equivalence relation + specified by `partition`. If the partition were given as a + list of :class:`set` instances and `relabel` is False, + each node will be a :class:`frozenset` corresponding to the same + :class:`set`. + + Raises + ------ + NetworkXException + If the given partition is not a valid partition of the nodes of + `G`. + + Examples + -------- + The quotient graph of the complete bipartite graph under the "same + neighbors" equivalence relation is `K_2`. Under this relation, two nodes + are equivalent if they are not adjacent but have the same neighbor set. + + >>> G = nx.complete_bipartite_graph(2, 3) + >>> same_neighbors = lambda u, v: (u not in G[v] and v not in G[u] and G[u] == G[v]) + >>> Q = nx.quotient_graph(G, same_neighbors) + >>> K2 = nx.complete_graph(2) + >>> nx.is_isomorphic(Q, K2) + True + + The quotient graph of a directed graph under the "same strongly connected + component" equivalence relation is the condensation of the graph (see + :func:`condensation`). This example comes from the Wikipedia article + *`Strongly connected component`_*. + + >>> G = nx.DiGraph() + >>> edges = [ + ... "ab", + ... "be", + ... "bf", + ... "bc", + ... "cg", + ... "cd", + ... "dc", + ... "dh", + ... "ea", + ... "ef", + ... "fg", + ... "gf", + ... "hd", + ... "hf", + ... ] + >>> G.add_edges_from(tuple(x) for x in edges) + >>> components = list(nx.strongly_connected_components(G)) + >>> sorted(sorted(component) for component in components) + [['a', 'b', 'e'], ['c', 'd', 'h'], ['f', 'g']] + >>> + >>> C = nx.condensation(G, components) + >>> component_of = C.graph["mapping"] + >>> same_component = lambda u, v: component_of[u] == component_of[v] + >>> Q = nx.quotient_graph(G, same_component) + >>> nx.is_isomorphic(C, Q) + True + + Node identification can be represented as the quotient of a graph under the + equivalence relation that places the two nodes in one block and each other + node in its own singleton block. + + >>> K24 = nx.complete_bipartite_graph(2, 4) + >>> K34 = nx.complete_bipartite_graph(3, 4) + >>> C = nx.contracted_nodes(K34, 1, 2) + >>> nodes = {1, 2} + >>> is_contracted = lambda u, v: u in nodes and v in nodes + >>> Q = nx.quotient_graph(K34, is_contracted) + >>> nx.is_isomorphic(Q, C) + True + >>> nx.is_isomorphic(Q, K24) + True + + The blockmodeling technique described in [1]_ can be implemented as a + quotient graph. + + >>> G = nx.path_graph(6) + >>> partition = [{0, 1}, {2, 3}, {4, 5}] + >>> M = nx.quotient_graph(G, partition, relabel=True) + >>> list(M.edges()) + [(0, 1), (1, 2)] + + Here is the sample example but using partition as a dict of block sets. + + >>> G = nx.path_graph(6) + >>> partition = {0: {0, 1}, 2: {2, 3}, 4: {4, 5}} + >>> M = nx.quotient_graph(G, partition, relabel=True) + >>> list(M.edges()) + [(0, 1), (1, 2)] + + Partitions can be represented in various ways: + + 0. a list/tuple/set of block lists/tuples/sets + 1. a dict with block labels as keys and blocks lists/tuples/sets as values + 2. a dict with block lists/tuples/sets as keys and block labels as values + 3. a function from nodes in the original iterable to block labels + 4. an equivalence relation function on the target iterable + + As `quotient_graph` is designed to accept partitions represented as (0), (1) or + (4) only, the `equivalence_classes` function can be used to get the partitions + in the right form, in order to call `quotient_graph`. + + .. _Strongly connected component: https://en.wikipedia.org/wiki/Strongly_connected_component + + References + ---------- + .. [1] Patrick Doreian, Vladimir Batagelj, and Anuska Ferligoj. + *Generalized Blockmodeling*. + Cambridge University Press, 2004. + + """ + # If the user provided an equivalence relation as a function to compute + # the blocks of the partition on the nodes of G induced by the + # equivalence relation. + if callable(partition): + # equivalence_classes always return partition of whole G. + partition = equivalence_classes(G, partition) + if not nx.community.is_partition(G, partition): + raise nx.NetworkXException( + "Input `partition` is not an equivalence relation for nodes of G" + ) + return _quotient_graph( + G, + partition, + edge_relation, + node_data, + edge_data, + weight, + relabel, + create_using, + ) + + # If the partition is a dict, it is assumed to be one where the keys are + # user-defined block labels, and values are block lists, tuples or sets. + if isinstance(partition, dict): + partition = list(partition.values()) + + # If the user provided partition as a collection of sets. Then we + # need to check if partition covers all of G nodes. If the answer + # is 'No' then we need to prepare suitable subgraph view. + partition_nodes = set().union(*partition) + if len(partition_nodes) != len(G): + G = G.subgraph(partition_nodes) + # Each node in the graph/subgraph must be in exactly one block. + if not nx.community.is_partition(G, partition): + raise NetworkXException("each node must be in exactly one part of `partition`") + return _quotient_graph( + G, + partition, + edge_relation, + node_data, + edge_data, + weight, + relabel, + create_using, + ) + + +def _quotient_graph( + G, partition, edge_relation, node_data, edge_data, weight, relabel, create_using +): + """Construct the quotient graph assuming input has been checked""" + if create_using is None: + H = G.__class__() + else: + H = nx.empty_graph(0, create_using) + # By default set some basic information about the subgraph that each block + # represents on the nodes in the quotient graph. + if node_data is None: + + def node_data(b): + S = G.subgraph(b) + return { + "graph": S, + "nnodes": len(S), + "nedges": S.number_of_edges(), + "density": density(S), + } + + # Each block of the partition becomes a node in the quotient graph. + partition = [frozenset(b) for b in partition] + H.add_nodes_from((b, node_data(b)) for b in partition) + # By default, the edge relation is the relation defined as follows. B is + # adjacent to C if a node in B is adjacent to a node in C, according to the + # edge set of G. + # + # This is not a particularly efficient implementation of this relation: + # there are O(n^2) pairs to check and each check may require O(log n) time + # (to check set membership). This can certainly be parallelized. + if edge_relation is None: + + def edge_relation(b, c): + return any(v in G[u] for u, v in product(b, c)) + + # By default, sum the weights of the edges joining pairs of nodes across + # blocks to get the weight of the edge joining those two blocks. + if edge_data is None: + + def edge_data(b, c): + edgedata = ( + d + for u, v, d in G.edges(b | c, data=True) + if (u in b and v in c) or (u in c and v in b) + ) + return {"weight": sum(d.get(weight, 1) for d in edgedata)} + + block_pairs = permutations(H, 2) if H.is_directed() else combinations(H, 2) + # In a multigraph, add one edge in the quotient graph for each edge + # in the original graph. + if H.is_multigraph(): + edges = chaini( + ( + (b, c, G.get_edge_data(u, v, default={})) + for u, v in product(b, c) + if v in G[u] + ) + for b, c in block_pairs + if edge_relation(b, c) + ) + # In a simple graph, apply the edge data function to each pair of + # blocks to determine the edge data attributes to apply to each edge + # in the quotient graph. + else: + edges = ( + (b, c, edge_data(b, c)) for (b, c) in block_pairs if edge_relation(b, c) + ) + H.add_edges_from(edges) + # If requested by the user, relabel the nodes to be integers, + # numbered in increasing order from zero in the same order as the + # iteration order of `partition`. + if relabel: + # Can't use nx.convert_node_labels_to_integers() here since we + # want the order of iteration to be the same for backward + # compatibility with the nx.blockmodel() function. + labels = {b: i for i, b in enumerate(partition)} + H = nx.relabel_nodes(H, labels) + return H + + +@nx._dispatchable( + preserve_all_attrs=True, mutates_input={"not copy": 4}, returns_graph=True +) +def contracted_nodes(G, u, v, self_loops=True, copy=True): + """Returns the graph that results from contracting `u` and `v`. + + Node contraction identifies the two nodes as a single node incident to any + edge that was incident to the original two nodes. + + Parameters + ---------- + G : NetworkX graph + The graph whose nodes will be contracted. + + u, v : nodes + Must be nodes in `G`. + + self_loops : Boolean + If this is True, any edges joining `u` and `v` in `G` become + self-loops on the new node in the returned graph. + + copy : Boolean + If this is True (default True), make a copy of + `G` and return that instead of directly changing `G`. + + + Returns + ------- + Networkx graph + If Copy is True, + A new graph object of the same type as `G` (leaving `G` unmodified) + with `u` and `v` identified in a single node. The right node `v` + will be merged into the node `u`, so only `u` will appear in the + returned graph. + If copy is False, + Modifies `G` with `u` and `v` identified in a single node. + The right node `v` will be merged into the node `u`, so + only `u` will appear in the returned graph. + + Notes + ----- + For multigraphs, the edge keys for the realigned edges may + not be the same as the edge keys for the old edges. This is + natural because edge keys are unique only within each pair of nodes. + + For non-multigraphs where `u` and `v` are adjacent to a third node + `w`, the edge (`v`, `w`) will be contracted into the edge (`u`, + `w`) with its attributes stored into a "contraction" attribute. + + This function is also available as `identified_nodes`. + + Examples + -------- + Contracting two nonadjacent nodes of the cycle graph on four nodes `C_4` + yields the path graph (ignoring parallel edges): + + >>> G = nx.cycle_graph(4) + >>> M = nx.contracted_nodes(G, 1, 3) + >>> P3 = nx.path_graph(3) + >>> nx.is_isomorphic(M, P3) + True + + >>> G = nx.MultiGraph(P3) + >>> M = nx.contracted_nodes(G, 0, 2) + >>> M.edges + MultiEdgeView([(0, 1, 0), (0, 1, 1)]) + + >>> G = nx.Graph([(1, 2), (2, 2)]) + >>> H = nx.contracted_nodes(G, 1, 2, self_loops=False) + >>> list(H.nodes()) + [1] + >>> list(H.edges()) + [(1, 1)] + + In a ``MultiDiGraph`` with a self loop, the in and out edges will + be treated separately as edges, so while contracting a node which + has a self loop the contraction will add multiple edges: + + >>> G = nx.MultiDiGraph([(1, 2), (2, 2)]) + >>> H = nx.contracted_nodes(G, 1, 2) + >>> list(H.edges()) # edge 1->2, 2->2, 2<-2 from the original Graph G + [(1, 1), (1, 1), (1, 1)] + >>> H = nx.contracted_nodes(G, 1, 2, self_loops=False) + >>> list(H.edges()) # edge 2->2, 2<-2 from the original Graph G + [(1, 1), (1, 1)] + + See Also + -------- + contracted_edge + quotient_graph + + """ + # Copying has significant overhead and can be disabled if needed + if copy: + H = G.copy() + else: + H = G + + # edge code uses G.edges(v) instead of G.adj[v] to handle multiedges + if H.is_directed(): + edges_to_remap = chain(G.in_edges(v, data=True), G.out_edges(v, data=True)) + else: + edges_to_remap = G.edges(v, data=True) + + # If the H=G, the generators change as H changes + # This makes the edges_to_remap independent of H + if not copy: + edges_to_remap = list(edges_to_remap) + + v_data = H.nodes[v] + H.remove_node(v) + + for prev_w, prev_x, d in edges_to_remap: + w = prev_w if prev_w != v else u + x = prev_x if prev_x != v else u + + if ({prev_w, prev_x} == {u, v}) and not self_loops: + continue + + if not H.has_edge(w, x) or G.is_multigraph(): + H.add_edge(w, x, **d) + else: + if "contraction" in H.edges[(w, x)]: + H.edges[(w, x)]["contraction"][(prev_w, prev_x)] = d + else: + H.edges[(w, x)]["contraction"] = {(prev_w, prev_x): d} + + if "contraction" in H.nodes[u]: + H.nodes[u]["contraction"][v] = v_data + else: + H.nodes[u]["contraction"] = {v: v_data} + return H + + +identified_nodes = contracted_nodes + + +@nx._dispatchable( + preserve_edge_attrs=True, mutates_input={"not copy": 3}, returns_graph=True +) +def contracted_edge(G, edge, self_loops=True, copy=True): + """Returns the graph that results from contracting the specified edge. + + Edge contraction identifies the two endpoints of the edge as a single node + incident to any edge that was incident to the original two nodes. A graph + that results from edge contraction is called a *minor* of the original + graph. + + Parameters + ---------- + G : NetworkX graph + The graph whose edge will be contracted. + + edge : tuple + Must be a pair of nodes in `G`. + + self_loops : Boolean + If this is True, any edges (including `edge`) joining the + endpoints of `edge` in `G` become self-loops on the new node in the + returned graph. + + copy : Boolean (default True) + If this is True, a the contraction will be performed on a copy of `G`, + otherwise the contraction will happen in place. + + Returns + ------- + Networkx graph + A new graph object of the same type as `G` (leaving `G` unmodified) + with endpoints of `edge` identified in a single node. The right node + of `edge` will be merged into the left one, so only the left one will + appear in the returned graph. + + Raises + ------ + ValueError + If `edge` is not an edge in `G`. + + Examples + -------- + Attempting to contract two nonadjacent nodes yields an error: + + >>> G = nx.cycle_graph(4) + >>> nx.contracted_edge(G, (1, 3)) + Traceback (most recent call last): + ... + ValueError: Edge (1, 3) does not exist in graph G; cannot contract it + + Contracting two adjacent nodes in the cycle graph on *n* nodes yields the + cycle graph on *n - 1* nodes: + + >>> C5 = nx.cycle_graph(5) + >>> C4 = nx.cycle_graph(4) + >>> M = nx.contracted_edge(C5, (0, 1), self_loops=False) + >>> nx.is_isomorphic(M, C4) + True + + See also + -------- + contracted_nodes + quotient_graph + + """ + u, v = edge[:2] + if not G.has_edge(u, v): + raise ValueError(f"Edge {edge} does not exist in graph G; cannot contract it") + return contracted_nodes(G, u, v, self_loops=self_loops, copy=copy) diff --git a/lib/python3.10/site-packages/networkx/algorithms/minors/tests/__pycache__/test_contraction.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/minors/tests/__pycache__/test_contraction.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08682101ea5bdd40bd0c7cfedd7e3ed74bc95797 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/minors/tests/__pycache__/test_contraction.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/minors/tests/test_contraction.py b/lib/python3.10/site-packages/networkx/algorithms/minors/tests/test_contraction.py new file mode 100644 index 0000000000000000000000000000000000000000..2246886767f80c0a6132d61acc931e51b3034a54 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/minors/tests/test_contraction.py @@ -0,0 +1,446 @@ +"""Unit tests for the :mod:`networkx.algorithms.minors.contraction` module.""" + +import pytest + +import networkx as nx +from networkx.utils import arbitrary_element, edges_equal, nodes_equal + + +def test_quotient_graph_complete_multipartite(): + """Tests that the quotient graph of the complete *n*-partite graph + under the "same neighbors" node relation is the complete graph on *n* + nodes. + + """ + G = nx.complete_multipartite_graph(2, 3, 4) + # Two nodes are equivalent if they are not adjacent but have the same + # neighbor set. + + def same_neighbors(u, v): + return u not in G[v] and v not in G[u] and G[u] == G[v] + + expected = nx.complete_graph(3) + actual = nx.quotient_graph(G, same_neighbors) + # It won't take too long to run a graph isomorphism algorithm on such + # small graphs. + assert nx.is_isomorphic(expected, actual) + + +def test_quotient_graph_complete_bipartite(): + """Tests that the quotient graph of the complete bipartite graph under + the "same neighbors" node relation is `K_2`. + + """ + G = nx.complete_bipartite_graph(2, 3) + # Two nodes are equivalent if they are not adjacent but have the same + # neighbor set. + + def same_neighbors(u, v): + return u not in G[v] and v not in G[u] and G[u] == G[v] + + expected = nx.complete_graph(2) + actual = nx.quotient_graph(G, same_neighbors) + # It won't take too long to run a graph isomorphism algorithm on such + # small graphs. + assert nx.is_isomorphic(expected, actual) + + +def test_quotient_graph_edge_relation(): + """Tests for specifying an alternate edge relation for the quotient + graph. + + """ + G = nx.path_graph(5) + + def identity(u, v): + return u == v + + def same_parity(b, c): + return arbitrary_element(b) % 2 == arbitrary_element(c) % 2 + + actual = nx.quotient_graph(G, identity, same_parity) + expected = nx.Graph() + expected.add_edges_from([(0, 2), (0, 4), (2, 4)]) + expected.add_edge(1, 3) + assert nx.is_isomorphic(actual, expected) + + +def test_condensation_as_quotient(): + """This tests that the condensation of a graph can be viewed as the + quotient graph under the "in the same connected component" equivalence + relation. + + """ + # This example graph comes from the file `test_strongly_connected.py`. + G = nx.DiGraph() + G.add_edges_from( + [ + (1, 2), + (2, 3), + (2, 11), + (2, 12), + (3, 4), + (4, 3), + (4, 5), + (5, 6), + (6, 5), + (6, 7), + (7, 8), + (7, 9), + (7, 10), + (8, 9), + (9, 7), + (10, 6), + (11, 2), + (11, 4), + (11, 6), + (12, 6), + (12, 11), + ] + ) + scc = list(nx.strongly_connected_components(G)) + C = nx.condensation(G, scc) + component_of = C.graph["mapping"] + # Two nodes are equivalent if they are in the same connected component. + + def same_component(u, v): + return component_of[u] == component_of[v] + + Q = nx.quotient_graph(G, same_component) + assert nx.is_isomorphic(C, Q) + + +def test_path(): + G = nx.path_graph(6) + partition = [{0, 1}, {2, 3}, {4, 5}] + M = nx.quotient_graph(G, partition, relabel=True) + assert nodes_equal(M, [0, 1, 2]) + assert edges_equal(M.edges(), [(0, 1), (1, 2)]) + for n in M: + assert M.nodes[n]["nedges"] == 1 + assert M.nodes[n]["nnodes"] == 2 + assert M.nodes[n]["density"] == 1 + + +def test_path__partition_provided_as_dict_of_lists(): + G = nx.path_graph(6) + partition = {0: [0, 1], 2: [2, 3], 4: [4, 5]} + M = nx.quotient_graph(G, partition, relabel=True) + assert nodes_equal(M, [0, 1, 2]) + assert edges_equal(M.edges(), [(0, 1), (1, 2)]) + for n in M: + assert M.nodes[n]["nedges"] == 1 + assert M.nodes[n]["nnodes"] == 2 + assert M.nodes[n]["density"] == 1 + + +def test_path__partition_provided_as_dict_of_tuples(): + G = nx.path_graph(6) + partition = {0: (0, 1), 2: (2, 3), 4: (4, 5)} + M = nx.quotient_graph(G, partition, relabel=True) + assert nodes_equal(M, [0, 1, 2]) + assert edges_equal(M.edges(), [(0, 1), (1, 2)]) + for n in M: + assert M.nodes[n]["nedges"] == 1 + assert M.nodes[n]["nnodes"] == 2 + assert M.nodes[n]["density"] == 1 + + +def test_path__partition_provided_as_dict_of_sets(): + G = nx.path_graph(6) + partition = {0: {0, 1}, 2: {2, 3}, 4: {4, 5}} + M = nx.quotient_graph(G, partition, relabel=True) + assert nodes_equal(M, [0, 1, 2]) + assert edges_equal(M.edges(), [(0, 1), (1, 2)]) + for n in M: + assert M.nodes[n]["nedges"] == 1 + assert M.nodes[n]["nnodes"] == 2 + assert M.nodes[n]["density"] == 1 + + +def test_multigraph_path(): + G = nx.MultiGraph(nx.path_graph(6)) + partition = [{0, 1}, {2, 3}, {4, 5}] + M = nx.quotient_graph(G, partition, relabel=True) + assert nodes_equal(M, [0, 1, 2]) + assert edges_equal(M.edges(), [(0, 1), (1, 2)]) + for n in M: + assert M.nodes[n]["nedges"] == 1 + assert M.nodes[n]["nnodes"] == 2 + assert M.nodes[n]["density"] == 1 + + +def test_directed_path(): + G = nx.DiGraph() + nx.add_path(G, range(6)) + partition = [{0, 1}, {2, 3}, {4, 5}] + M = nx.quotient_graph(G, partition, relabel=True) + assert nodes_equal(M, [0, 1, 2]) + assert edges_equal(M.edges(), [(0, 1), (1, 2)]) + for n in M: + assert M.nodes[n]["nedges"] == 1 + assert M.nodes[n]["nnodes"] == 2 + assert M.nodes[n]["density"] == 0.5 + + +def test_directed_multigraph_path(): + G = nx.MultiDiGraph() + nx.add_path(G, range(6)) + partition = [{0, 1}, {2, 3}, {4, 5}] + M = nx.quotient_graph(G, partition, relabel=True) + assert nodes_equal(M, [0, 1, 2]) + assert edges_equal(M.edges(), [(0, 1), (1, 2)]) + for n in M: + assert M.nodes[n]["nedges"] == 1 + assert M.nodes[n]["nnodes"] == 2 + assert M.nodes[n]["density"] == 0.5 + + +def test_overlapping_blocks(): + with pytest.raises(nx.NetworkXException): + G = nx.path_graph(6) + partition = [{0, 1, 2}, {2, 3}, {4, 5}] + nx.quotient_graph(G, partition) + + +def test_weighted_path(): + G = nx.path_graph(6) + for i in range(5): + G[i][i + 1]["w"] = i + 1 + partition = [{0, 1}, {2, 3}, {4, 5}] + M = nx.quotient_graph(G, partition, weight="w", relabel=True) + assert nodes_equal(M, [0, 1, 2]) + assert edges_equal(M.edges(), [(0, 1), (1, 2)]) + assert M[0][1]["weight"] == 2 + assert M[1][2]["weight"] == 4 + for n in M: + assert M.nodes[n]["nedges"] == 1 + assert M.nodes[n]["nnodes"] == 2 + assert M.nodes[n]["density"] == 1 + + +def test_barbell(): + G = nx.barbell_graph(3, 0) + partition = [{0, 1, 2}, {3, 4, 5}] + M = nx.quotient_graph(G, partition, relabel=True) + assert nodes_equal(M, [0, 1]) + assert edges_equal(M.edges(), [(0, 1)]) + for n in M: + assert M.nodes[n]["nedges"] == 3 + assert M.nodes[n]["nnodes"] == 3 + assert M.nodes[n]["density"] == 1 + + +def test_barbell_plus(): + G = nx.barbell_graph(3, 0) + # Add an extra edge joining the bells. + G.add_edge(0, 5) + partition = [{0, 1, 2}, {3, 4, 5}] + M = nx.quotient_graph(G, partition, relabel=True) + assert nodes_equal(M, [0, 1]) + assert edges_equal(M.edges(), [(0, 1)]) + assert M[0][1]["weight"] == 2 + for n in M: + assert M.nodes[n]["nedges"] == 3 + assert M.nodes[n]["nnodes"] == 3 + assert M.nodes[n]["density"] == 1 + + +def test_blockmodel(): + G = nx.path_graph(6) + partition = [[0, 1], [2, 3], [4, 5]] + M = nx.quotient_graph(G, partition, relabel=True) + assert nodes_equal(M.nodes(), [0, 1, 2]) + assert edges_equal(M.edges(), [(0, 1), (1, 2)]) + for n in M.nodes(): + assert M.nodes[n]["nedges"] == 1 + assert M.nodes[n]["nnodes"] == 2 + assert M.nodes[n]["density"] == 1.0 + + +def test_multigraph_blockmodel(): + G = nx.MultiGraph(nx.path_graph(6)) + partition = [[0, 1], [2, 3], [4, 5]] + M = nx.quotient_graph(G, partition, create_using=nx.MultiGraph(), relabel=True) + assert nodes_equal(M.nodes(), [0, 1, 2]) + assert edges_equal(M.edges(), [(0, 1), (1, 2)]) + for n in M.nodes(): + assert M.nodes[n]["nedges"] == 1 + assert M.nodes[n]["nnodes"] == 2 + assert M.nodes[n]["density"] == 1.0 + + +def test_quotient_graph_incomplete_partition(): + G = nx.path_graph(6) + partition = [] + H = nx.quotient_graph(G, partition, relabel=True) + assert nodes_equal(H.nodes(), []) + assert edges_equal(H.edges(), []) + + partition = [[0, 1], [2, 3], [5]] + H = nx.quotient_graph(G, partition, relabel=True) + assert nodes_equal(H.nodes(), [0, 1, 2]) + assert edges_equal(H.edges(), [(0, 1)]) + + +def test_undirected_node_contraction(): + """Tests for node contraction in an undirected graph.""" + G = nx.cycle_graph(4) + actual = nx.contracted_nodes(G, 0, 1) + expected = nx.cycle_graph(3) + expected.add_edge(0, 0) + assert nx.is_isomorphic(actual, expected) + + +def test_directed_node_contraction(): + """Tests for node contraction in a directed graph.""" + G = nx.DiGraph(nx.cycle_graph(4)) + actual = nx.contracted_nodes(G, 0, 1) + expected = nx.DiGraph(nx.cycle_graph(3)) + expected.add_edge(0, 0) + expected.add_edge(0, 0) + assert nx.is_isomorphic(actual, expected) + + +def test_undirected_node_contraction_no_copy(): + """Tests for node contraction in an undirected graph + by making changes in place.""" + G = nx.cycle_graph(4) + actual = nx.contracted_nodes(G, 0, 1, copy=False) + expected = nx.cycle_graph(3) + expected.add_edge(0, 0) + assert nx.is_isomorphic(actual, G) + assert nx.is_isomorphic(actual, expected) + + +def test_directed_node_contraction_no_copy(): + """Tests for node contraction in a directed graph + by making changes in place.""" + G = nx.DiGraph(nx.cycle_graph(4)) + actual = nx.contracted_nodes(G, 0, 1, copy=False) + expected = nx.DiGraph(nx.cycle_graph(3)) + expected.add_edge(0, 0) + expected.add_edge(0, 0) + assert nx.is_isomorphic(actual, G) + assert nx.is_isomorphic(actual, expected) + + +def test_create_multigraph(): + """Tests that using a MultiGraph creates multiple edges.""" + G = nx.path_graph(3, create_using=nx.MultiGraph()) + G.add_edge(0, 1) + G.add_edge(0, 0) + G.add_edge(0, 2) + actual = nx.contracted_nodes(G, 0, 2) + expected = nx.MultiGraph() + expected.add_edge(0, 1) + expected.add_edge(0, 1) + expected.add_edge(0, 1) + expected.add_edge(0, 0) + expected.add_edge(0, 0) + assert edges_equal(actual.edges, expected.edges) + + +def test_multigraph_keys(): + """Tests that multiedge keys are reset in new graph.""" + G = nx.path_graph(3, create_using=nx.MultiGraph()) + G.add_edge(0, 1, 5) + G.add_edge(0, 0, 0) + G.add_edge(0, 2, 5) + actual = nx.contracted_nodes(G, 0, 2) + expected = nx.MultiGraph() + expected.add_edge(0, 1, 0) + expected.add_edge(0, 1, 5) + expected.add_edge(0, 1, 2) # keyed as 2 b/c 2 edges already in G + expected.add_edge(0, 0, 0) + expected.add_edge(0, 0, 1) # this comes from (0, 2, 5) + assert edges_equal(actual.edges, expected.edges) + + +def test_node_attributes(): + """Tests that node contraction preserves node attributes.""" + G = nx.cycle_graph(4) + # Add some data to the two nodes being contracted. + G.nodes[0]["foo"] = "bar" + G.nodes[1]["baz"] = "xyzzy" + actual = nx.contracted_nodes(G, 0, 1) + # We expect that contracting the nodes 0 and 1 in C_4 yields K_3, but + # with nodes labeled 0, 2, and 3, and with a -loop on 0. + expected = nx.complete_graph(3) + expected = nx.relabel_nodes(expected, {1: 2, 2: 3}) + expected.add_edge(0, 0) + cdict = {1: {"baz": "xyzzy"}} + expected.nodes[0].update({"foo": "bar", "contraction": cdict}) + assert nx.is_isomorphic(actual, expected) + assert actual.nodes == expected.nodes + + +def test_edge_attributes(): + """Tests that node contraction preserves edge attributes.""" + # Shape: src1 --> dest <-- src2 + G = nx.DiGraph([("src1", "dest"), ("src2", "dest")]) + G["src1"]["dest"]["value"] = "src1-->dest" + G["src2"]["dest"]["value"] = "src2-->dest" + H = nx.MultiDiGraph(G) + + G = nx.contracted_nodes(G, "src1", "src2") # New Shape: src1 --> dest + assert G.edges[("src1", "dest")]["value"] == "src1-->dest" + assert ( + G.edges[("src1", "dest")]["contraction"][("src2", "dest")]["value"] + == "src2-->dest" + ) + + H = nx.contracted_nodes(H, "src1", "src2") # New Shape: src1 -(x2)-> dest + assert len(H.edges(("src1", "dest"))) == 2 + + +def test_without_self_loops(): + """Tests for node contraction without preserving -loops.""" + G = nx.cycle_graph(4) + actual = nx.contracted_nodes(G, 0, 1, self_loops=False) + expected = nx.complete_graph(3) + assert nx.is_isomorphic(actual, expected) + + +def test_contract_loop_graph(): + """Tests for node contraction when nodes have loops.""" + G = nx.cycle_graph(4) + G.add_edge(0, 0) + actual = nx.contracted_nodes(G, 0, 1) + expected = nx.complete_graph([0, 2, 3]) + expected.add_edge(0, 0) + expected.add_edge(0, 0) + assert edges_equal(actual.edges, expected.edges) + actual = nx.contracted_nodes(G, 1, 0) + expected = nx.complete_graph([1, 2, 3]) + expected.add_edge(1, 1) + expected.add_edge(1, 1) + assert edges_equal(actual.edges, expected.edges) + + +def test_undirected_edge_contraction(): + """Tests for edge contraction in an undirected graph.""" + G = nx.cycle_graph(4) + actual = nx.contracted_edge(G, (0, 1)) + expected = nx.complete_graph(3) + expected.add_edge(0, 0) + assert nx.is_isomorphic(actual, expected) + + +def test_multigraph_edge_contraction(): + """Tests for edge contraction in a multigraph""" + G = nx.cycle_graph(4) + actual = nx.contracted_edge(G, (0, 1, 0)) + expected = nx.complete_graph(3) + expected.add_edge(0, 0) + assert nx.is_isomorphic(actual, expected) + + +def test_nonexistent_edge(): + """Tests that attempting to contract a nonexistent edge raises an + exception. + + """ + with pytest.raises(ValueError): + G = nx.cycle_graph(4) + nx.contracted_edge(G, (0, 2)) diff --git a/lib/python3.10/site-packages/networkx/algorithms/operators/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/operators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0ebc6ab9998db144234c2601c24861b2c48fa339 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/operators/__init__.py @@ -0,0 +1,4 @@ +from networkx.algorithms.operators.all import * +from networkx.algorithms.operators.binary import * +from networkx.algorithms.operators.product import * +from networkx.algorithms.operators.unary import * diff --git a/lib/python3.10/site-packages/networkx/algorithms/operators/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/operators/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..616660ed7c2bb1fb070f8eab1869a14034987907 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/operators/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/operators/__pycache__/all.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/operators/__pycache__/all.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..603b8975ead5426de12fb93fcf27dc007a78f7c6 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/operators/__pycache__/all.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/operators/__pycache__/binary.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/operators/__pycache__/binary.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..febac14c5ce0ddf2cc0a26a72a7d2600369969e6 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/operators/__pycache__/binary.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/operators/__pycache__/product.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/operators/__pycache__/product.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..439ad7b5e9ff989a67bda836812c963b4ff05b68 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/operators/__pycache__/product.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/operators/__pycache__/unary.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/operators/__pycache__/unary.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c61239148b31852242f6d3bfd70dcb4ebc9cb75e Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/operators/__pycache__/unary.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/operators/all.py b/lib/python3.10/site-packages/networkx/algorithms/operators/all.py new file mode 100644 index 0000000000000000000000000000000000000000..549d335d27452a878a3aaf74d688fdaec85543b2 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/operators/all.py @@ -0,0 +1,321 @@ +"""Operations on many graphs.""" + +from itertools import chain, repeat + +import networkx as nx + +__all__ = ["union_all", "compose_all", "disjoint_union_all", "intersection_all"] + + +@nx._dispatchable(graphs="[graphs]", preserve_all_attrs=True, returns_graph=True) +def union_all(graphs, rename=()): + """Returns the union of all graphs. + + The graphs must be disjoint, otherwise an exception is raised. + + Parameters + ---------- + graphs : iterable + Iterable of NetworkX graphs + + rename : iterable , optional + Node names of graphs can be changed by specifying the tuple + rename=('G-','H-') (for example). Node "u" in G is then renamed + "G-u" and "v" in H is renamed "H-v". Infinite generators (like itertools.count) + are also supported. + + Returns + ------- + U : a graph with the same type as the first graph in list + + Raises + ------ + ValueError + If `graphs` is an empty list. + + NetworkXError + In case of mixed type graphs, like MultiGraph and Graph, or directed and undirected graphs. + + Notes + ----- + For operating on mixed type graphs, they should be converted to the same type. + >>> G = nx.Graph() + >>> H = nx.DiGraph() + >>> GH = union_all([nx.DiGraph(G), H]) + + To force a disjoint union with node relabeling, use + disjoint_union_all(G,H) or convert_node_labels_to integers(). + + Graph, edge, and node attributes are propagated to the union graph. + If a graph attribute is present in multiple graphs, then the value + from the last graph in the list with that attribute is used. + + Examples + -------- + >>> G1 = nx.Graph([(1, 2), (2, 3)]) + >>> G2 = nx.Graph([(4, 5), (5, 6)]) + >>> result_graph = nx.union_all([G1, G2]) + >>> result_graph.nodes() + NodeView((1, 2, 3, 4, 5, 6)) + >>> result_graph.edges() + EdgeView([(1, 2), (2, 3), (4, 5), (5, 6)]) + + See Also + -------- + union + disjoint_union_all + """ + R = None + seen_nodes = set() + + # rename graph to obtain disjoint node labels + def add_prefix(graph, prefix): + if prefix is None: + return graph + + def label(x): + return f"{prefix}{x}" + + return nx.relabel_nodes(graph, label) + + rename = chain(rename, repeat(None)) + graphs = (add_prefix(G, name) for G, name in zip(graphs, rename)) + + for i, G in enumerate(graphs): + G_nodes_set = set(G.nodes) + if i == 0: + # Union is the same type as first graph + R = G.__class__() + elif G.is_directed() != R.is_directed(): + raise nx.NetworkXError("All graphs must be directed or undirected.") + elif G.is_multigraph() != R.is_multigraph(): + raise nx.NetworkXError("All graphs must be graphs or multigraphs.") + elif not seen_nodes.isdisjoint(G_nodes_set): + raise nx.NetworkXError( + "The node sets of the graphs are not disjoint.\n" + "Use `rename` to specify prefixes for the graphs or use\n" + "disjoint_union(G1, G2, ..., GN)." + ) + + seen_nodes |= G_nodes_set + R.graph.update(G.graph) + R.add_nodes_from(G.nodes(data=True)) + R.add_edges_from( + G.edges(keys=True, data=True) if G.is_multigraph() else G.edges(data=True) + ) + + if R is None: + raise ValueError("cannot apply union_all to an empty list") + + return R + + +@nx._dispatchable(graphs="[graphs]", preserve_all_attrs=True, returns_graph=True) +def disjoint_union_all(graphs): + """Returns the disjoint union of all graphs. + + This operation forces distinct integer node labels starting with 0 + for the first graph in the list and numbering consecutively. + + Parameters + ---------- + graphs : iterable + Iterable of NetworkX graphs + + Returns + ------- + U : A graph with the same type as the first graph in list + + Raises + ------ + ValueError + If `graphs` is an empty list. + + NetworkXError + In case of mixed type graphs, like MultiGraph and Graph, or directed and undirected graphs. + + Examples + -------- + >>> G1 = nx.Graph([(1, 2), (2, 3)]) + >>> G2 = nx.Graph([(4, 5), (5, 6)]) + >>> U = nx.disjoint_union_all([G1, G2]) + >>> list(U.nodes()) + [0, 1, 2, 3, 4, 5] + >>> list(U.edges()) + [(0, 1), (1, 2), (3, 4), (4, 5)] + + Notes + ----- + For operating on mixed type graphs, they should be converted to the same type. + + Graph, edge, and node attributes are propagated to the union graph. + If a graph attribute is present in multiple graphs, then the value + from the last graph in the list with that attribute is used. + """ + + def yield_relabeled(graphs): + first_label = 0 + for G in graphs: + yield nx.convert_node_labels_to_integers(G, first_label=first_label) + first_label += len(G) + + R = union_all(yield_relabeled(graphs)) + + return R + + +@nx._dispatchable(graphs="[graphs]", preserve_all_attrs=True, returns_graph=True) +def compose_all(graphs): + """Returns the composition of all graphs. + + Composition is the simple union of the node sets and edge sets. + The node sets of the supplied graphs need not be disjoint. + + Parameters + ---------- + graphs : iterable + Iterable of NetworkX graphs + + Returns + ------- + C : A graph with the same type as the first graph in list + + Raises + ------ + ValueError + If `graphs` is an empty list. + + NetworkXError + In case of mixed type graphs, like MultiGraph and Graph, or directed and undirected graphs. + + Examples + -------- + >>> G1 = nx.Graph([(1, 2), (2, 3)]) + >>> G2 = nx.Graph([(3, 4), (5, 6)]) + >>> C = nx.compose_all([G1, G2]) + >>> list(C.nodes()) + [1, 2, 3, 4, 5, 6] + >>> list(C.edges()) + [(1, 2), (2, 3), (3, 4), (5, 6)] + + Notes + ----- + For operating on mixed type graphs, they should be converted to the same type. + + Graph, edge, and node attributes are propagated to the union graph. + If a graph attribute is present in multiple graphs, then the value + from the last graph in the list with that attribute is used. + """ + R = None + + # add graph attributes, H attributes take precedent over G attributes + for i, G in enumerate(graphs): + if i == 0: + # create new graph + R = G.__class__() + elif G.is_directed() != R.is_directed(): + raise nx.NetworkXError("All graphs must be directed or undirected.") + elif G.is_multigraph() != R.is_multigraph(): + raise nx.NetworkXError("All graphs must be graphs or multigraphs.") + + R.graph.update(G.graph) + R.add_nodes_from(G.nodes(data=True)) + R.add_edges_from( + G.edges(keys=True, data=True) if G.is_multigraph() else G.edges(data=True) + ) + + if R is None: + raise ValueError("cannot apply compose_all to an empty list") + + return R + + +@nx._dispatchable(graphs="[graphs]", returns_graph=True) +def intersection_all(graphs): + """Returns a new graph that contains only the nodes and the edges that exist in + all graphs. + + Parameters + ---------- + graphs : iterable + Iterable of NetworkX graphs + + Returns + ------- + R : A new graph with the same type as the first graph in list + + Raises + ------ + ValueError + If `graphs` is an empty list. + + NetworkXError + In case of mixed type graphs, like MultiGraph and Graph, or directed and undirected graphs. + + Notes + ----- + For operating on mixed type graphs, they should be converted to the same type. + + Attributes from the graph, nodes, and edges are not copied to the new + graph. + + The resulting graph can be updated with attributes if desired. For example, code which adds the minimum attribute for each node across all graphs could work. + >>> g = nx.Graph() + >>> g.add_node(0, capacity=4) + >>> g.add_node(1, capacity=3) + >>> g.add_edge(0, 1) + + >>> h = g.copy() + >>> h.nodes[0]["capacity"] = 2 + + >>> gh = nx.intersection_all([g, h]) + + >>> new_node_attr = { + ... n: min(*(anyG.nodes[n].get("capacity", float("inf")) for anyG in [g, h])) + ... for n in gh + ... } + >>> nx.set_node_attributes(gh, new_node_attr, "new_capacity") + >>> gh.nodes(data=True) + NodeDataView({0: {'new_capacity': 2}, 1: {'new_capacity': 3}}) + + Examples + -------- + >>> G1 = nx.Graph([(1, 2), (2, 3)]) + >>> G2 = nx.Graph([(2, 3), (3, 4)]) + >>> R = nx.intersection_all([G1, G2]) + >>> list(R.nodes()) + [2, 3] + >>> list(R.edges()) + [(2, 3)] + + """ + R = None + + for i, G in enumerate(graphs): + G_nodes_set = set(G.nodes) + G_edges_set = set(G.edges) + if not G.is_directed(): + if G.is_multigraph(): + G_edges_set.update((v, u, k) for u, v, k in list(G_edges_set)) + else: + G_edges_set.update((v, u) for u, v in list(G_edges_set)) + if i == 0: + # create new graph + R = G.__class__() + node_intersection = G_nodes_set + edge_intersection = G_edges_set + elif G.is_directed() != R.is_directed(): + raise nx.NetworkXError("All graphs must be directed or undirected.") + elif G.is_multigraph() != R.is_multigraph(): + raise nx.NetworkXError("All graphs must be graphs or multigraphs.") + else: + node_intersection &= G_nodes_set + edge_intersection &= G_edges_set + + if R is None: + raise ValueError("cannot apply intersection_all to an empty list") + + R.add_nodes_from(node_intersection) + R.add_edges_from(edge_intersection) + + return R diff --git a/lib/python3.10/site-packages/networkx/algorithms/operators/binary.py b/lib/python3.10/site-packages/networkx/algorithms/operators/binary.py new file mode 100644 index 0000000000000000000000000000000000000000..08907bf6e7cacb676020426f8cbfa4a257f03262 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/operators/binary.py @@ -0,0 +1,450 @@ +""" +Operations on graphs including union, intersection, difference. +""" + +import networkx as nx + +__all__ = [ + "union", + "compose", + "disjoint_union", + "intersection", + "difference", + "symmetric_difference", + "full_join", +] +_G_H = {"G": 0, "H": 1} + + +@nx._dispatchable(graphs=_G_H, preserve_all_attrs=True, returns_graph=True) +def union(G, H, rename=()): + """Combine graphs G and H. The names of nodes must be unique. + + A name collision between the graphs will raise an exception. + + A renaming facility is provided to avoid name collisions. + + + Parameters + ---------- + G, H : graph + A NetworkX graph + + rename : iterable , optional + Node names of G and H can be changed by specifying the tuple + rename=('G-','H-') (for example). Node "u" in G is then renamed + "G-u" and "v" in H is renamed "H-v". + + Returns + ------- + U : A union graph with the same type as G. + + See Also + -------- + compose + :func:`~networkx.Graph.update` + disjoint_union + + Notes + ----- + To combine graphs that have common nodes, consider compose(G, H) + or the method, Graph.update(). + + disjoint_union() is similar to union() except that it avoids name clashes + by relabeling the nodes with sequential integers. + + Edge and node attributes are propagated from G and H to the union graph. + Graph attributes are also propagated, but if they are present in both G and H, + then the value from H is used. + + Examples + -------- + >>> G = nx.Graph([(0, 1), (0, 2), (1, 2)]) + >>> H = nx.Graph([(0, 1), (0, 3), (1, 3), (1, 2)]) + >>> U = nx.union(G, H, rename=("G", "H")) + >>> U.nodes + NodeView(('G0', 'G1', 'G2', 'H0', 'H1', 'H3', 'H2')) + >>> U.edges + EdgeView([('G0', 'G1'), ('G0', 'G2'), ('G1', 'G2'), ('H0', 'H1'), ('H0', 'H3'), ('H1', 'H3'), ('H1', 'H2')]) + + + """ + return nx.union_all([G, H], rename) + + +@nx._dispatchable(graphs=_G_H, preserve_all_attrs=True, returns_graph=True) +def disjoint_union(G, H): + """Combine graphs G and H. The nodes are assumed to be unique (disjoint). + + This algorithm automatically relabels nodes to avoid name collisions. + + Parameters + ---------- + G,H : graph + A NetworkX graph + + Returns + ------- + U : A union graph with the same type as G. + + See Also + -------- + union + compose + :func:`~networkx.Graph.update` + + Notes + ----- + A new graph is created, of the same class as G. It is recommended + that G and H be either both directed or both undirected. + + The nodes of G are relabeled 0 to len(G)-1, and the nodes of H are + relabeled len(G) to len(G)+len(H)-1. + + Renumbering forces G and H to be disjoint, so no exception is ever raised for a name collision. + To preserve the check for common nodes, use union(). + + Edge and node attributes are propagated from G and H to the union graph. + Graph attributes are also propagated, but if they are present in both G and H, + then the value from H is used. + + To combine graphs that have common nodes, consider compose(G, H) + or the method, Graph.update(). + + Examples + -------- + >>> G = nx.Graph([(0, 1), (0, 2), (1, 2)]) + >>> H = nx.Graph([(0, 3), (1, 2), (2, 3)]) + >>> G.nodes[0]["key1"] = 5 + >>> H.nodes[0]["key2"] = 10 + >>> U = nx.disjoint_union(G, H) + >>> U.nodes(data=True) + NodeDataView({0: {'key1': 5}, 1: {}, 2: {}, 3: {'key2': 10}, 4: {}, 5: {}, 6: {}}) + >>> U.edges + EdgeView([(0, 1), (0, 2), (1, 2), (3, 4), (4, 6), (5, 6)]) + """ + return nx.disjoint_union_all([G, H]) + + +@nx._dispatchable(graphs=_G_H, returns_graph=True) +def intersection(G, H): + """Returns a new graph that contains only the nodes and the edges that exist in + both G and H. + + Parameters + ---------- + G,H : graph + A NetworkX graph. G and H can have different node sets but must be both graphs or both multigraphs. + + Raises + ------ + NetworkXError + If one is a MultiGraph and the other one is a graph. + + Returns + ------- + GH : A new graph with the same type as G. + + Notes + ----- + Attributes from the graph, nodes, and edges are not copied to the new + graph. If you want a new graph of the intersection of G and H + with the attributes (including edge data) from G use remove_nodes_from() + as follows + + >>> G = nx.path_graph(3) + >>> H = nx.path_graph(5) + >>> R = G.copy() + >>> R.remove_nodes_from(n for n in G if n not in H) + >>> R.remove_edges_from(e for e in G.edges if e not in H.edges) + + Examples + -------- + >>> G = nx.Graph([(0, 1), (0, 2), (1, 2)]) + >>> H = nx.Graph([(0, 3), (1, 2), (2, 3)]) + >>> R = nx.intersection(G, H) + >>> R.nodes + NodeView((0, 1, 2)) + >>> R.edges + EdgeView([(1, 2)]) + """ + return nx.intersection_all([G, H]) + + +@nx._dispatchable(graphs=_G_H, returns_graph=True) +def difference(G, H): + """Returns a new graph that contains the edges that exist in G but not in H. + + The node sets of H and G must be the same. + + Parameters + ---------- + G,H : graph + A NetworkX graph. G and H must have the same node sets. + + Returns + ------- + D : A new graph with the same type as G. + + Notes + ----- + Attributes from the graph, nodes, and edges are not copied to the new + graph. If you want a new graph of the difference of G and H with + the attributes (including edge data) from G use remove_nodes_from() + as follows: + + >>> G = nx.path_graph(3) + >>> H = nx.path_graph(5) + >>> R = G.copy() + >>> R.remove_nodes_from(n for n in G if n in H) + + Examples + -------- + >>> G = nx.Graph([(0, 1), (0, 2), (1, 2), (1, 3)]) + >>> H = nx.Graph([(0, 1), (1, 2), (0, 3)]) + >>> R = nx.difference(G, H) + >>> R.nodes + NodeView((0, 1, 2, 3)) + >>> R.edges + EdgeView([(0, 2), (1, 3)]) + """ + # create new graph + if not G.is_multigraph() == H.is_multigraph(): + raise nx.NetworkXError("G and H must both be graphs or multigraphs.") + R = nx.create_empty_copy(G, with_data=False) + + if set(G) != set(H): + raise nx.NetworkXError("Node sets of graphs not equal") + + if G.is_multigraph(): + edges = G.edges(keys=True) + else: + edges = G.edges() + for e in edges: + if not H.has_edge(*e): + R.add_edge(*e) + return R + + +@nx._dispatchable(graphs=_G_H, returns_graph=True) +def symmetric_difference(G, H): + """Returns new graph with edges that exist in either G or H but not both. + + The node sets of H and G must be the same. + + Parameters + ---------- + G,H : graph + A NetworkX graph. G and H must have the same node sets. + + Returns + ------- + D : A new graph with the same type as G. + + Notes + ----- + Attributes from the graph, nodes, and edges are not copied to the new + graph. + + Examples + -------- + >>> G = nx.Graph([(0, 1), (0, 2), (1, 2), (1, 3)]) + >>> H = nx.Graph([(0, 1), (1, 2), (0, 3)]) + >>> R = nx.symmetric_difference(G, H) + >>> R.nodes + NodeView((0, 1, 2, 3)) + >>> R.edges + EdgeView([(0, 2), (0, 3), (1, 3)]) + """ + # create new graph + if not G.is_multigraph() == H.is_multigraph(): + raise nx.NetworkXError("G and H must both be graphs or multigraphs.") + R = nx.create_empty_copy(G, with_data=False) + + if set(G) != set(H): + raise nx.NetworkXError("Node sets of graphs not equal") + + gnodes = set(G) # set of nodes in G + hnodes = set(H) # set of nodes in H + nodes = gnodes.symmetric_difference(hnodes) + R.add_nodes_from(nodes) + + if G.is_multigraph(): + edges = G.edges(keys=True) + else: + edges = G.edges() + # we could copy the data here but then this function doesn't + # match intersection and difference + for e in edges: + if not H.has_edge(*e): + R.add_edge(*e) + + if H.is_multigraph(): + edges = H.edges(keys=True) + else: + edges = H.edges() + for e in edges: + if not G.has_edge(*e): + R.add_edge(*e) + return R + + +@nx._dispatchable(graphs=_G_H, preserve_all_attrs=True, returns_graph=True) +def compose(G, H): + """Compose graph G with H by combining nodes and edges into a single graph. + + The node sets and edges sets do not need to be disjoint. + + Composing preserves the attributes of nodes and edges. + Attribute values from H take precedent over attribute values from G. + + Parameters + ---------- + G, H : graph + A NetworkX graph + + Returns + ------- + C: A new graph with the same type as G + + See Also + -------- + :func:`~networkx.Graph.update` + union + disjoint_union + + Notes + ----- + It is recommended that G and H be either both directed or both undirected. + + For MultiGraphs, the edges are identified by incident nodes AND edge-key. + This can cause surprises (i.e., edge `(1, 2)` may or may not be the same + in two graphs) if you use MultiGraph without keeping track of edge keys. + + If combining the attributes of common nodes is not desired, consider union(), + which raises an exception for name collisions. + + Examples + -------- + >>> G = nx.Graph([(0, 1), (0, 2)]) + >>> H = nx.Graph([(0, 1), (1, 2)]) + >>> R = nx.compose(G, H) + >>> R.nodes + NodeView((0, 1, 2)) + >>> R.edges + EdgeView([(0, 1), (0, 2), (1, 2)]) + + By default, the attributes from `H` take precedent over attributes from `G`. + If you prefer another way of combining attributes, you can update them after the compose operation: + + >>> G = nx.Graph([(0, 1, {"weight": 2.0}), (3, 0, {"weight": 100.0})]) + >>> H = nx.Graph([(0, 1, {"weight": 10.0}), (1, 2, {"weight": -1.0})]) + >>> nx.set_node_attributes(G, {0: "dark", 1: "light", 3: "black"}, name="color") + >>> nx.set_node_attributes(H, {0: "green", 1: "orange", 2: "yellow"}, name="color") + >>> GcomposeH = nx.compose(G, H) + + Normally, color attribute values of nodes of GcomposeH come from H. We can workaround this as follows: + + >>> node_data = { + ... n: G.nodes[n]["color"] + " " + H.nodes[n]["color"] + ... for n in G.nodes & H.nodes + ... } + >>> nx.set_node_attributes(GcomposeH, node_data, "color") + >>> print(GcomposeH.nodes[0]["color"]) + dark green + + >>> print(GcomposeH.nodes[3]["color"]) + black + + Similarly, we can update edge attributes after the compose operation in a way we prefer: + + >>> edge_data = { + ... e: G.edges[e]["weight"] * H.edges[e]["weight"] for e in G.edges & H.edges + ... } + >>> nx.set_edge_attributes(GcomposeH, edge_data, "weight") + >>> print(GcomposeH.edges[(0, 1)]["weight"]) + 20.0 + + >>> print(GcomposeH.edges[(3, 0)]["weight"]) + 100.0 + """ + return nx.compose_all([G, H]) + + +@nx._dispatchable(graphs=_G_H, preserve_all_attrs=True, returns_graph=True) +def full_join(G, H, rename=(None, None)): + """Returns the full join of graphs G and H. + + Full join is the union of G and H in which all edges between + G and H are added. + The node sets of G and H must be disjoint, + otherwise an exception is raised. + + Parameters + ---------- + G, H : graph + A NetworkX graph + + rename : tuple , default=(None, None) + Node names of G and H can be changed by specifying the tuple + rename=('G-','H-') (for example). Node "u" in G is then renamed + "G-u" and "v" in H is renamed "H-v". + + Returns + ------- + U : The full join graph with the same type as G. + + Notes + ----- + It is recommended that G and H be either both directed or both undirected. + + If G is directed, then edges from G to H are added as well as from H to G. + + Note that full_join() does not produce parallel edges for MultiGraphs. + + The full join operation of graphs G and H is the same as getting + their complement, performing a disjoint union, and finally getting + the complement of the resulting graph. + + Graph, edge, and node attributes are propagated from G and H + to the union graph. If a graph attribute is present in both + G and H the value from H is used. + + Examples + -------- + >>> G = nx.Graph([(0, 1), (0, 2)]) + >>> H = nx.Graph([(3, 4)]) + >>> R = nx.full_join(G, H, rename=("G", "H")) + >>> R.nodes + NodeView(('G0', 'G1', 'G2', 'H3', 'H4')) + >>> R.edges + EdgeView([('G0', 'G1'), ('G0', 'G2'), ('G0', 'H3'), ('G0', 'H4'), ('G1', 'H3'), ('G1', 'H4'), ('G2', 'H3'), ('G2', 'H4'), ('H3', 'H4')]) + + See Also + -------- + union + disjoint_union + """ + R = union(G, H, rename) + + def add_prefix(graph, prefix): + if prefix is None: + return graph + + def label(x): + return f"{prefix}{x}" + + return nx.relabel_nodes(graph, label) + + G = add_prefix(G, rename[0]) + H = add_prefix(H, rename[1]) + + for i in G: + for j in H: + R.add_edge(i, j) + if R.is_directed(): + for i in H: + for j in G: + R.add_edge(i, j) + + return R diff --git a/lib/python3.10/site-packages/networkx/algorithms/operators/product.py b/lib/python3.10/site-packages/networkx/algorithms/operators/product.py new file mode 100644 index 0000000000000000000000000000000000000000..28ca78bf4deb45ffa422d2792b966adfa112692f --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/operators/product.py @@ -0,0 +1,633 @@ +""" +Graph products. +""" + +from itertools import product + +import networkx as nx +from networkx.utils import not_implemented_for + +__all__ = [ + "tensor_product", + "cartesian_product", + "lexicographic_product", + "strong_product", + "power", + "rooted_product", + "corona_product", + "modular_product", +] +_G_H = {"G": 0, "H": 1} + + +def _dict_product(d1, d2): + return {k: (d1.get(k), d2.get(k)) for k in set(d1) | set(d2)} + + +# Generators for producing graph products +def _node_product(G, H): + for u, v in product(G, H): + yield ((u, v), _dict_product(G.nodes[u], H.nodes[v])) + + +def _directed_edges_cross_edges(G, H): + if not G.is_multigraph() and not H.is_multigraph(): + for u, v, c in G.edges(data=True): + for x, y, d in H.edges(data=True): + yield (u, x), (v, y), _dict_product(c, d) + if not G.is_multigraph() and H.is_multigraph(): + for u, v, c in G.edges(data=True): + for x, y, k, d in H.edges(data=True, keys=True): + yield (u, x), (v, y), k, _dict_product(c, d) + if G.is_multigraph() and not H.is_multigraph(): + for u, v, k, c in G.edges(data=True, keys=True): + for x, y, d in H.edges(data=True): + yield (u, x), (v, y), k, _dict_product(c, d) + if G.is_multigraph() and H.is_multigraph(): + for u, v, j, c in G.edges(data=True, keys=True): + for x, y, k, d in H.edges(data=True, keys=True): + yield (u, x), (v, y), (j, k), _dict_product(c, d) + + +def _undirected_edges_cross_edges(G, H): + if not G.is_multigraph() and not H.is_multigraph(): + for u, v, c in G.edges(data=True): + for x, y, d in H.edges(data=True): + yield (v, x), (u, y), _dict_product(c, d) + if not G.is_multigraph() and H.is_multigraph(): + for u, v, c in G.edges(data=True): + for x, y, k, d in H.edges(data=True, keys=True): + yield (v, x), (u, y), k, _dict_product(c, d) + if G.is_multigraph() and not H.is_multigraph(): + for u, v, k, c in G.edges(data=True, keys=True): + for x, y, d in H.edges(data=True): + yield (v, x), (u, y), k, _dict_product(c, d) + if G.is_multigraph() and H.is_multigraph(): + for u, v, j, c in G.edges(data=True, keys=True): + for x, y, k, d in H.edges(data=True, keys=True): + yield (v, x), (u, y), (j, k), _dict_product(c, d) + + +def _edges_cross_nodes(G, H): + if G.is_multigraph(): + for u, v, k, d in G.edges(data=True, keys=True): + for x in H: + yield (u, x), (v, x), k, d + else: + for u, v, d in G.edges(data=True): + for x in H: + if H.is_multigraph(): + yield (u, x), (v, x), None, d + else: + yield (u, x), (v, x), d + + +def _nodes_cross_edges(G, H): + if H.is_multigraph(): + for x in G: + for u, v, k, d in H.edges(data=True, keys=True): + yield (x, u), (x, v), k, d + else: + for x in G: + for u, v, d in H.edges(data=True): + if G.is_multigraph(): + yield (x, u), (x, v), None, d + else: + yield (x, u), (x, v), d + + +def _edges_cross_nodes_and_nodes(G, H): + if G.is_multigraph(): + for u, v, k, d in G.edges(data=True, keys=True): + for x in H: + for y in H: + yield (u, x), (v, y), k, d + else: + for u, v, d in G.edges(data=True): + for x in H: + for y in H: + if H.is_multigraph(): + yield (u, x), (v, y), None, d + else: + yield (u, x), (v, y), d + + +def _init_product_graph(G, H): + if G.is_directed() != H.is_directed(): + msg = "G and H must be both directed or both undirected" + raise nx.NetworkXError(msg) + if G.is_multigraph() or H.is_multigraph(): + GH = nx.MultiGraph() + else: + GH = nx.Graph() + if G.is_directed(): + GH = GH.to_directed() + return GH + + +@nx._dispatchable(graphs=_G_H, preserve_node_attrs=True, returns_graph=True) +def tensor_product(G, H): + r"""Returns the tensor product of G and H. + + The tensor product $P$ of the graphs $G$ and $H$ has a node set that + is the Cartesian product of the node sets, $V(P)=V(G) \times V(H)$. + $P$ has an edge $((u,v), (x,y))$ if and only if $(u,x)$ is an edge in $G$ + and $(v,y)$ is an edge in $H$. + + Tensor product is sometimes also referred to as the categorical product, + direct product, cardinal product or conjunction. + + + Parameters + ---------- + G, H: graphs + Networkx graphs. + + Returns + ------- + P: NetworkX graph + The tensor product of G and H. P will be a multi-graph if either G + or H is a multi-graph, will be a directed if G and H are directed, + and undirected if G and H are undirected. + + Raises + ------ + NetworkXError + If G and H are not both directed or both undirected. + + Notes + ----- + Node attributes in P are two-tuple of the G and H node attributes. + Missing attributes are assigned None. + + Examples + -------- + >>> G = nx.Graph() + >>> H = nx.Graph() + >>> G.add_node(0, a1=True) + >>> H.add_node("a", a2="Spam") + >>> P = nx.tensor_product(G, H) + >>> list(P) + [(0, 'a')] + + Edge attributes and edge keys (for multigraphs) are also copied to the + new product graph + """ + GH = _init_product_graph(G, H) + GH.add_nodes_from(_node_product(G, H)) + GH.add_edges_from(_directed_edges_cross_edges(G, H)) + if not GH.is_directed(): + GH.add_edges_from(_undirected_edges_cross_edges(G, H)) + return GH + + +@nx._dispatchable(graphs=_G_H, preserve_node_attrs=True, returns_graph=True) +def cartesian_product(G, H): + r"""Returns the Cartesian product of G and H. + + The Cartesian product $P$ of the graphs $G$ and $H$ has a node set that + is the Cartesian product of the node sets, $V(P)=V(G) \times V(H)$. + $P$ has an edge $((u,v),(x,y))$ if and only if either $u$ is equal to $x$ + and both $v$ and $y$ are adjacent in $H$ or if $v$ is equal to $y$ and + both $u$ and $x$ are adjacent in $G$. + + Parameters + ---------- + G, H: graphs + Networkx graphs. + + Returns + ------- + P: NetworkX graph + The Cartesian product of G and H. P will be a multi-graph if either G + or H is a multi-graph. Will be a directed if G and H are directed, + and undirected if G and H are undirected. + + Raises + ------ + NetworkXError + If G and H are not both directed or both undirected. + + Notes + ----- + Node attributes in P are two-tuple of the G and H node attributes. + Missing attributes are assigned None. + + Examples + -------- + >>> G = nx.Graph() + >>> H = nx.Graph() + >>> G.add_node(0, a1=True) + >>> H.add_node("a", a2="Spam") + >>> P = nx.cartesian_product(G, H) + >>> list(P) + [(0, 'a')] + + Edge attributes and edge keys (for multigraphs) are also copied to the + new product graph + """ + GH = _init_product_graph(G, H) + GH.add_nodes_from(_node_product(G, H)) + GH.add_edges_from(_edges_cross_nodes(G, H)) + GH.add_edges_from(_nodes_cross_edges(G, H)) + return GH + + +@nx._dispatchable(graphs=_G_H, preserve_node_attrs=True, returns_graph=True) +def lexicographic_product(G, H): + r"""Returns the lexicographic product of G and H. + + The lexicographical product $P$ of the graphs $G$ and $H$ has a node set + that is the Cartesian product of the node sets, $V(P)=V(G) \times V(H)$. + $P$ has an edge $((u,v), (x,y))$ if and only if $(u,v)$ is an edge in $G$ + or $u==v$ and $(x,y)$ is an edge in $H$. + + Parameters + ---------- + G, H: graphs + Networkx graphs. + + Returns + ------- + P: NetworkX graph + The Cartesian product of G and H. P will be a multi-graph if either G + or H is a multi-graph. Will be a directed if G and H are directed, + and undirected if G and H are undirected. + + Raises + ------ + NetworkXError + If G and H are not both directed or both undirected. + + Notes + ----- + Node attributes in P are two-tuple of the G and H node attributes. + Missing attributes are assigned None. + + Examples + -------- + >>> G = nx.Graph() + >>> H = nx.Graph() + >>> G.add_node(0, a1=True) + >>> H.add_node("a", a2="Spam") + >>> P = nx.lexicographic_product(G, H) + >>> list(P) + [(0, 'a')] + + Edge attributes and edge keys (for multigraphs) are also copied to the + new product graph + """ + GH = _init_product_graph(G, H) + GH.add_nodes_from(_node_product(G, H)) + # Edges in G regardless of H designation + GH.add_edges_from(_edges_cross_nodes_and_nodes(G, H)) + # For each x in G, only if there is an edge in H + GH.add_edges_from(_nodes_cross_edges(G, H)) + return GH + + +@nx._dispatchable(graphs=_G_H, preserve_node_attrs=True, returns_graph=True) +def strong_product(G, H): + r"""Returns the strong product of G and H. + + The strong product $P$ of the graphs $G$ and $H$ has a node set that + is the Cartesian product of the node sets, $V(P)=V(G) \times V(H)$. + $P$ has an edge $((u,x), (v,y))$ if any of the following conditions + are met: + + - $u=v$ and $(x,y)$ is an edge in $H$ + - $x=y$ and $(u,v)$ is an edge in $G$ + - $(u,v)$ is an edge in $G$ and $(x,y)$ is an edge in $H$ + + Parameters + ---------- + G, H: graphs + Networkx graphs. + + Returns + ------- + P: NetworkX graph + The Cartesian product of G and H. P will be a multi-graph if either G + or H is a multi-graph. Will be a directed if G and H are directed, + and undirected if G and H are undirected. + + Raises + ------ + NetworkXError + If G and H are not both directed or both undirected. + + Notes + ----- + Node attributes in P are two-tuple of the G and H node attributes. + Missing attributes are assigned None. + + Examples + -------- + >>> G = nx.Graph() + >>> H = nx.Graph() + >>> G.add_node(0, a1=True) + >>> H.add_node("a", a2="Spam") + >>> P = nx.strong_product(G, H) + >>> list(P) + [(0, 'a')] + + Edge attributes and edge keys (for multigraphs) are also copied to the + new product graph + """ + GH = _init_product_graph(G, H) + GH.add_nodes_from(_node_product(G, H)) + GH.add_edges_from(_nodes_cross_edges(G, H)) + GH.add_edges_from(_edges_cross_nodes(G, H)) + GH.add_edges_from(_directed_edges_cross_edges(G, H)) + if not GH.is_directed(): + GH.add_edges_from(_undirected_edges_cross_edges(G, H)) + return GH + + +@not_implemented_for("directed") +@not_implemented_for("multigraph") +@nx._dispatchable(returns_graph=True) +def power(G, k): + """Returns the specified power of a graph. + + The $k$th power of a simple graph $G$, denoted $G^k$, is a + graph on the same set of nodes in which two distinct nodes $u$ and + $v$ are adjacent in $G^k$ if and only if the shortest path + distance between $u$ and $v$ in $G$ is at most $k$. + + Parameters + ---------- + G : graph + A NetworkX simple graph object. + + k : positive integer + The power to which to raise the graph `G`. + + Returns + ------- + NetworkX simple graph + `G` to the power `k`. + + Raises + ------ + ValueError + If the exponent `k` is not positive. + + NetworkXNotImplemented + If `G` is not a simple graph. + + Examples + -------- + The number of edges will never decrease when taking successive + powers: + + >>> G = nx.path_graph(4) + >>> list(nx.power(G, 2).edges) + [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3)] + >>> list(nx.power(G, 3).edges) + [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)] + + The `k` th power of a cycle graph on *n* nodes is the complete graph + on *n* nodes, if `k` is at least ``n // 2``: + + >>> G = nx.cycle_graph(5) + >>> H = nx.complete_graph(5) + >>> nx.is_isomorphic(nx.power(G, 2), H) + True + >>> G = nx.cycle_graph(8) + >>> H = nx.complete_graph(8) + >>> nx.is_isomorphic(nx.power(G, 4), H) + True + + References + ---------- + .. [1] J. A. Bondy, U. S. R. Murty, *Graph Theory*. Springer, 2008. + + Notes + ----- + This definition of "power graph" comes from Exercise 3.1.6 of + *Graph Theory* by Bondy and Murty [1]_. + + """ + if k <= 0: + raise ValueError("k must be a positive integer") + H = nx.Graph() + H.add_nodes_from(G) + # update BFS code to ignore self loops. + for n in G: + seen = {} # level (number of hops) when seen in BFS + level = 1 # the current level + nextlevel = G[n] + while nextlevel: + thislevel = nextlevel # advance to next level + nextlevel = {} # and start a new list (fringe) + for v in thislevel: + if v == n: # avoid self loop + continue + if v not in seen: + seen[v] = level # set the level of vertex v + nextlevel.update(G[v]) # add neighbors of v + if k <= level: + break + level += 1 + H.add_edges_from((n, nbr) for nbr in seen) + return H + + +@not_implemented_for("multigraph") +@nx._dispatchable(graphs=_G_H, returns_graph=True) +def rooted_product(G, H, root): + """Return the rooted product of graphs G and H rooted at root in H. + + A new graph is constructed representing the rooted product of + the inputted graphs, G and H, with a root in H. + A rooted product duplicates H for each nodes in G with the root + of H corresponding to the node in G. Nodes are renamed as the direct + product of G and H. The result is a subgraph of the cartesian product. + + Parameters + ---------- + G,H : graph + A NetworkX graph + root : node + A node in H + + Returns + ------- + R : The rooted product of G and H with a specified root in H + + Notes + ----- + The nodes of R are the Cartesian Product of the nodes of G and H. + The nodes of G and H are not relabeled. + """ + if root not in H: + raise nx.NodeNotFound("root must be a vertex in H") + + R = nx.Graph() + R.add_nodes_from(product(G, H)) + + R.add_edges_from(((e[0], root), (e[1], root)) for e in G.edges()) + R.add_edges_from(((g, e[0]), (g, e[1])) for g in G for e in H.edges()) + + return R + + +@not_implemented_for("directed") +@not_implemented_for("multigraph") +@nx._dispatchable(graphs=_G_H, returns_graph=True) +def corona_product(G, H): + r"""Returns the Corona product of G and H. + + The corona product of $G$ and $H$ is the graph $C = G \circ H$ obtained by + taking one copy of $G$, called the center graph, $|V(G)|$ copies of $H$, + called the outer graph, and making the $i$-th vertex of $G$ adjacent to + every vertex of the $i$-th copy of $H$, where $1 ≤ i ≤ |V(G)|$. + + Parameters + ---------- + G, H: NetworkX graphs + The graphs to take the carona product of. + `G` is the center graph and `H` is the outer graph + + Returns + ------- + C: NetworkX graph + The Corona product of G and H. + + Raises + ------ + NetworkXError + If G and H are not both directed or both undirected. + + Examples + -------- + >>> G = nx.cycle_graph(4) + >>> H = nx.path_graph(2) + >>> C = nx.corona_product(G, H) + >>> list(C) + [0, 1, 2, 3, (0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (3, 0), (3, 1)] + >>> print(C) + Graph with 12 nodes and 16 edges + + References + ---------- + [1] M. Tavakoli, F. Rahbarnia, and A. R. Ashrafi, + "Studying the corona product of graphs under some graph invariants," + Transactions on Combinatorics, vol. 3, no. 3, pp. 43–49, Sep. 2014, + doi: 10.22108/toc.2014.5542. + [2] A. Faraji, "Corona Product in Graph Theory," Ali Faraji, May 11, 2021. + https://blog.alifaraji.ir/math/graph-theory/corona-product.html (accessed Dec. 07, 2021). + """ + GH = _init_product_graph(G, H) + GH.add_nodes_from(G) + GH.add_edges_from(G.edges) + + for G_node in G: + # copy nodes of H in GH, call it H_i + GH.add_nodes_from((G_node, v) for v in H) + + # copy edges of H_i based on H + GH.add_edges_from( + ((G_node, e0), (G_node, e1), d) for e0, e1, d in H.edges.data() + ) + + # creating new edges between H_i and a G's node + GH.add_edges_from((G_node, (G_node, H_node)) for H_node in H) + + return GH + + +@nx._dispatchable( + graphs=_G_H, preserve_edge_attrs=True, preserve_node_attrs=True, returns_graph=True +) +def modular_product(G, H): + r"""Returns the Modular product of G and H. + + The modular product of `G` and `H` is the graph $M = G \nabla H$, + consisting of the node set $V(M) = V(G) \times V(H)$ that is the Cartesian + product of the node sets of `G` and `H`. Further, M contains an edge ((u, v), (x, y)): + + - if u is adjacent to x in `G` and v is adjacent to y in `H`, or + - if u is not adjacent to x in `G` and v is not adjacent to y in `H`. + + More formally:: + + E(M) = {((u, v), (x, y)) | ((u, x) in E(G) and (v, y) in E(H)) or + ((u, x) not in E(G) and (v, y) not in E(H))} + + Parameters + ---------- + G, H: NetworkX graphs + The graphs to take the modular product of. + + Returns + ------- + M: NetworkX graph + The Modular product of `G` and `H`. + + Raises + ------ + NetworkXNotImplemented + If `G` is not a simple graph. + + Examples + -------- + >>> G = nx.cycle_graph(4) + >>> H = nx.path_graph(2) + >>> M = nx.modular_product(G, H) + >>> list(M) + [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (3, 0), (3, 1)] + >>> print(M) + Graph with 8 nodes and 8 edges + + Notes + ----- + The *modular product* is defined in [1]_ and was first + introduced as the *weak modular product*. + + The modular product reduces the problem of counting isomorphic subgraphs + in `G` and `H` to the problem of counting cliques in M. The subgraphs of + `G` and `H` that are induced by the nodes of a clique in M are + isomorphic [2]_ [3]_. + + References + ---------- + .. [1] R. Hammack, W. Imrich, and S. Klavžar, + "Handbook of Product Graphs", CRC Press, 2011. + + .. [2] H. G. Barrow and R. M. Burstall, + "Subgraph isomorphism, matching relational structures and maximal + cliques", Information Processing Letters, vol. 4, issue 4, pp. 83-84, + 1976, https://doi.org/10.1016/0020-0190(76)90049-1. + + .. [3] V. G. Vizing, "Reduction of the problem of isomorphism and isomorphic + entrance to the task of finding the nondensity of a graph." Proc. Third + All-Union Conference on Problems of Theoretical Cybernetics. 1974. + """ + if G.is_directed() or H.is_directed(): + raise nx.NetworkXNotImplemented( + "Modular product not implemented for directed graphs" + ) + if G.is_multigraph() or H.is_multigraph(): + raise nx.NetworkXNotImplemented( + "Modular product not implemented for multigraphs" + ) + + GH = _init_product_graph(G, H) + GH.add_nodes_from(_node_product(G, H)) + + for u, v, c in G.edges(data=True): + for x, y, d in H.edges(data=True): + GH.add_edge((u, x), (v, y), **_dict_product(c, d)) + GH.add_edge((v, x), (u, y), **_dict_product(c, d)) + + G = nx.complement(G) + H = nx.complement(H) + + for u, v, c in G.edges(data=True): + for x, y, d in H.edges(data=True): + GH.add_edge((u, x), (v, y), **_dict_product(c, d)) + GH.add_edge((v, x), (u, y), **_dict_product(c, d)) + + return GH diff --git a/lib/python3.10/site-packages/networkx/algorithms/operators/tests/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/operators/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/networkx/algorithms/operators/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/operators/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dce85588fba62e920b2721b0f082c8a9f6e99c7b Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/operators/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/operators/tests/__pycache__/test_all.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/operators/tests/__pycache__/test_all.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ee127c0a9d9da16bd87e30a3ba0f90eb4ab287e Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/operators/tests/__pycache__/test_all.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/operators/tests/__pycache__/test_binary.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/operators/tests/__pycache__/test_binary.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3ea76edad3f2ca05eb79e320b62901ff81fa94d Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/operators/tests/__pycache__/test_binary.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/operators/tests/__pycache__/test_product.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/operators/tests/__pycache__/test_product.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06fa305c71b2fc92ca3f5cf041a1f2cc7c93d790 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/operators/tests/__pycache__/test_product.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/operators/tests/__pycache__/test_unary.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/operators/tests/__pycache__/test_unary.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3cf7ca53b784f5bffedde47993200e8a56912be4 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/operators/tests/__pycache__/test_unary.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/operators/tests/test_all.py b/lib/python3.10/site-packages/networkx/algorithms/operators/tests/test_all.py new file mode 100644 index 0000000000000000000000000000000000000000..8ec29c150306080d536c1b1dc785209d4a113f6d --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/operators/tests/test_all.py @@ -0,0 +1,328 @@ +import pytest + +import networkx as nx +from networkx.utils import edges_equal + + +def test_union_all_attributes(): + g = nx.Graph() + g.add_node(0, x=4) + g.add_node(1, x=5) + g.add_edge(0, 1, size=5) + g.graph["name"] = "g" + + h = g.copy() + h.graph["name"] = "h" + h.graph["attr"] = "attr" + h.nodes[0]["x"] = 7 + + j = g.copy() + j.graph["name"] = "j" + j.graph["attr"] = "attr" + j.nodes[0]["x"] = 7 + + ghj = nx.union_all([g, h, j], rename=("g", "h", "j")) + assert set(ghj.nodes()) == {"h0", "h1", "g0", "g1", "j0", "j1"} + for n in ghj: + graph, node = n + assert ghj.nodes[n] == eval(graph).nodes[int(node)] + + assert ghj.graph["attr"] == "attr" + assert ghj.graph["name"] == "j" # j graph attributes take precedent + + +def test_intersection_all(): + G = nx.Graph() + H = nx.Graph() + R = nx.Graph(awesome=True) + G.add_nodes_from([1, 2, 3, 4]) + G.add_edge(1, 2) + G.add_edge(2, 3) + H.add_nodes_from([1, 2, 3, 4]) + H.add_edge(2, 3) + H.add_edge(3, 4) + R.add_nodes_from([1, 2, 3, 4]) + R.add_edge(2, 3) + R.add_edge(4, 1) + I = nx.intersection_all([G, H, R]) + assert set(I.nodes()) == {1, 2, 3, 4} + assert sorted(I.edges()) == [(2, 3)] + assert I.graph == {} + + +def test_intersection_all_different_node_sets(): + G = nx.Graph() + H = nx.Graph() + R = nx.Graph() + G.add_nodes_from([1, 2, 3, 4, 6, 7]) + G.add_edge(1, 2) + G.add_edge(2, 3) + G.add_edge(6, 7) + H.add_nodes_from([1, 2, 3, 4]) + H.add_edge(2, 3) + H.add_edge(3, 4) + R.add_nodes_from([1, 2, 3, 4, 8, 9]) + R.add_edge(2, 3) + R.add_edge(4, 1) + R.add_edge(8, 9) + I = nx.intersection_all([G, H, R]) + assert set(I.nodes()) == {1, 2, 3, 4} + assert sorted(I.edges()) == [(2, 3)] + + +def test_intersection_all_attributes(): + g = nx.Graph() + g.add_node(0, x=4) + g.add_node(1, x=5) + g.add_edge(0, 1, size=5) + g.graph["name"] = "g" + + h = g.copy() + h.graph["name"] = "h" + h.graph["attr"] = "attr" + h.nodes[0]["x"] = 7 + + gh = nx.intersection_all([g, h]) + assert set(gh.nodes()) == set(g.nodes()) + assert set(gh.nodes()) == set(h.nodes()) + assert sorted(gh.edges()) == sorted(g.edges()) + + +def test_intersection_all_attributes_different_node_sets(): + g = nx.Graph() + g.add_node(0, x=4) + g.add_node(1, x=5) + g.add_edge(0, 1, size=5) + g.graph["name"] = "g" + + h = g.copy() + g.add_node(2) + h.graph["name"] = "h" + h.graph["attr"] = "attr" + h.nodes[0]["x"] = 7 + + gh = nx.intersection_all([g, h]) + assert set(gh.nodes()) == set(h.nodes()) + assert sorted(gh.edges()) == sorted(g.edges()) + + +def test_intersection_all_multigraph_attributes(): + g = nx.MultiGraph() + g.add_edge(0, 1, key=0) + g.add_edge(0, 1, key=1) + g.add_edge(0, 1, key=2) + h = nx.MultiGraph() + h.add_edge(0, 1, key=0) + h.add_edge(0, 1, key=3) + gh = nx.intersection_all([g, h]) + assert set(gh.nodes()) == set(g.nodes()) + assert set(gh.nodes()) == set(h.nodes()) + assert sorted(gh.edges()) == [(0, 1)] + assert sorted(gh.edges(keys=True)) == [(0, 1, 0)] + + +def test_intersection_all_multigraph_attributes_different_node_sets(): + g = nx.MultiGraph() + g.add_edge(0, 1, key=0) + g.add_edge(0, 1, key=1) + g.add_edge(0, 1, key=2) + g.add_edge(1, 2, key=1) + g.add_edge(1, 2, key=2) + h = nx.MultiGraph() + h.add_edge(0, 1, key=0) + h.add_edge(0, 1, key=2) + h.add_edge(0, 1, key=3) + gh = nx.intersection_all([g, h]) + assert set(gh.nodes()) == set(h.nodes()) + assert sorted(gh.edges()) == [(0, 1), (0, 1)] + assert sorted(gh.edges(keys=True)) == [(0, 1, 0), (0, 1, 2)] + + +def test_intersection_all_digraph(): + g = nx.DiGraph() + g.add_edges_from([(1, 2), (2, 3)]) + h = nx.DiGraph() + h.add_edges_from([(2, 1), (2, 3)]) + gh = nx.intersection_all([g, h]) + assert sorted(gh.edges()) == [(2, 3)] + + +def test_union_all_and_compose_all(): + K3 = nx.complete_graph(3) + P3 = nx.path_graph(3) + + G1 = nx.DiGraph() + G1.add_edge("A", "B") + G1.add_edge("A", "C") + G1.add_edge("A", "D") + G2 = nx.DiGraph() + G2.add_edge("1", "2") + G2.add_edge("1", "3") + G2.add_edge("1", "4") + + G = nx.union_all([G1, G2]) + H = nx.compose_all([G1, G2]) + assert edges_equal(G.edges(), H.edges()) + assert not G.has_edge("A", "1") + pytest.raises(nx.NetworkXError, nx.union, K3, P3) + H1 = nx.union_all([H, G1], rename=("H", "G1")) + assert sorted(H1.nodes()) == [ + "G1A", + "G1B", + "G1C", + "G1D", + "H1", + "H2", + "H3", + "H4", + "HA", + "HB", + "HC", + "HD", + ] + + H2 = nx.union_all([H, G2], rename=("H", "")) + assert sorted(H2.nodes()) == [ + "1", + "2", + "3", + "4", + "H1", + "H2", + "H3", + "H4", + "HA", + "HB", + "HC", + "HD", + ] + + assert not H1.has_edge("NB", "NA") + + G = nx.compose_all([G, G]) + assert edges_equal(G.edges(), H.edges()) + + G2 = nx.union_all([G2, G2], rename=("", "copy")) + assert sorted(G2.nodes()) == [ + "1", + "2", + "3", + "4", + "copy1", + "copy2", + "copy3", + "copy4", + ] + + assert sorted(G2.neighbors("copy4")) == [] + assert sorted(G2.neighbors("copy1")) == ["copy2", "copy3", "copy4"] + assert len(G) == 8 + assert nx.number_of_edges(G) == 6 + + E = nx.disjoint_union_all([G, G]) + assert len(E) == 16 + assert nx.number_of_edges(E) == 12 + + E = nx.disjoint_union_all([G1, G2]) + assert sorted(E.nodes()) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] + + G1 = nx.DiGraph() + G1.add_edge("A", "B") + G2 = nx.DiGraph() + G2.add_edge(1, 2) + G3 = nx.DiGraph() + G3.add_edge(11, 22) + G4 = nx.union_all([G1, G2, G3], rename=("G1", "G2", "G3")) + assert sorted(G4.nodes()) == ["G1A", "G1B", "G21", "G22", "G311", "G322"] + + +def test_union_all_multigraph(): + G = nx.MultiGraph() + G.add_edge(1, 2, key=0) + G.add_edge(1, 2, key=1) + H = nx.MultiGraph() + H.add_edge(3, 4, key=0) + H.add_edge(3, 4, key=1) + GH = nx.union_all([G, H]) + assert set(GH) == set(G) | set(H) + assert set(GH.edges(keys=True)) == set(G.edges(keys=True)) | set(H.edges(keys=True)) + + +def test_input_output(): + l = [nx.Graph([(1, 2)]), nx.Graph([(3, 4)], awesome=True)] + U = nx.disjoint_union_all(l) + assert len(l) == 2 + assert U.graph["awesome"] + C = nx.compose_all(l) + assert len(l) == 2 + l = [nx.Graph([(1, 2)]), nx.Graph([(1, 2)])] + R = nx.intersection_all(l) + assert len(l) == 2 + + +def test_mixed_type_union(): + with pytest.raises(nx.NetworkXError): + G = nx.Graph() + H = nx.MultiGraph() + I = nx.Graph() + U = nx.union_all([G, H, I]) + with pytest.raises(nx.NetworkXError): + X = nx.Graph() + Y = nx.DiGraph() + XY = nx.union_all([X, Y]) + + +def test_mixed_type_disjoint_union(): + with pytest.raises(nx.NetworkXError): + G = nx.Graph() + H = nx.MultiGraph() + I = nx.Graph() + U = nx.disjoint_union_all([G, H, I]) + with pytest.raises(nx.NetworkXError): + X = nx.Graph() + Y = nx.DiGraph() + XY = nx.disjoint_union_all([X, Y]) + + +def test_mixed_type_intersection(): + with pytest.raises(nx.NetworkXError): + G = nx.Graph() + H = nx.MultiGraph() + I = nx.Graph() + U = nx.intersection_all([G, H, I]) + with pytest.raises(nx.NetworkXError): + X = nx.Graph() + Y = nx.DiGraph() + XY = nx.intersection_all([X, Y]) + + +def test_mixed_type_compose(): + with pytest.raises(nx.NetworkXError): + G = nx.Graph() + H = nx.MultiGraph() + I = nx.Graph() + U = nx.compose_all([G, H, I]) + with pytest.raises(nx.NetworkXError): + X = nx.Graph() + Y = nx.DiGraph() + XY = nx.compose_all([X, Y]) + + +def test_empty_union(): + with pytest.raises(ValueError): + nx.union_all([]) + + +def test_empty_disjoint_union(): + with pytest.raises(ValueError): + nx.disjoint_union_all([]) + + +def test_empty_compose_all(): + with pytest.raises(ValueError): + nx.compose_all([]) + + +def test_empty_intersection_all(): + with pytest.raises(ValueError): + nx.intersection_all([]) diff --git a/lib/python3.10/site-packages/networkx/algorithms/operators/tests/test_binary.py b/lib/python3.10/site-packages/networkx/algorithms/operators/tests/test_binary.py new file mode 100644 index 0000000000000000000000000000000000000000..c907cd6f05167f4eadb0f51e238a1283ad677697 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/operators/tests/test_binary.py @@ -0,0 +1,453 @@ +import os + +import pytest + +import networkx as nx +from networkx.utils import edges_equal + + +def test_union_attributes(): + g = nx.Graph() + g.add_node(0, x=4) + g.add_node(1, x=5) + g.add_edge(0, 1, size=5) + g.graph["name"] = "g" + + h = g.copy() + h.graph["name"] = "h" + h.graph["attr"] = "attr" + h.nodes[0]["x"] = 7 + + gh = nx.union(g, h, rename=("g", "h")) + assert set(gh.nodes()) == {"h0", "h1", "g0", "g1"} + for n in gh: + graph, node = n + assert gh.nodes[n] == eval(graph).nodes[int(node)] + + assert gh.graph["attr"] == "attr" + assert gh.graph["name"] == "h" # h graph attributes take precedent + + +def test_intersection(): + G = nx.Graph() + H = nx.Graph() + G.add_nodes_from([1, 2, 3, 4]) + G.add_edge(1, 2) + G.add_edge(2, 3) + H.add_nodes_from([1, 2, 3, 4]) + H.add_edge(2, 3) + H.add_edge(3, 4) + I = nx.intersection(G, H) + assert set(I.nodes()) == {1, 2, 3, 4} + assert sorted(I.edges()) == [(2, 3)] + + +def test_intersection_node_sets_different(): + G = nx.Graph() + H = nx.Graph() + G.add_nodes_from([1, 2, 3, 4, 7]) + G.add_edge(1, 2) + G.add_edge(2, 3) + H.add_nodes_from([1, 2, 3, 4, 5, 6]) + H.add_edge(2, 3) + H.add_edge(3, 4) + H.add_edge(5, 6) + I = nx.intersection(G, H) + assert set(I.nodes()) == {1, 2, 3, 4} + assert sorted(I.edges()) == [(2, 3)] + + +def test_intersection_attributes(): + g = nx.Graph() + g.add_node(0, x=4) + g.add_node(1, x=5) + g.add_edge(0, 1, size=5) + g.graph["name"] = "g" + + h = g.copy() + h.graph["name"] = "h" + h.graph["attr"] = "attr" + h.nodes[0]["x"] = 7 + gh = nx.intersection(g, h) + + assert set(gh.nodes()) == set(g.nodes()) + assert set(gh.nodes()) == set(h.nodes()) + assert sorted(gh.edges()) == sorted(g.edges()) + + +def test_intersection_attributes_node_sets_different(): + g = nx.Graph() + g.add_node(0, x=4) + g.add_node(1, x=5) + g.add_node(2, x=3) + g.add_edge(0, 1, size=5) + g.graph["name"] = "g" + + h = g.copy() + h.graph["name"] = "h" + h.graph["attr"] = "attr" + h.nodes[0]["x"] = 7 + h.remove_node(2) + + gh = nx.intersection(g, h) + assert set(gh.nodes()) == set(h.nodes()) + assert sorted(gh.edges()) == sorted(g.edges()) + + +def test_intersection_multigraph_attributes(): + g = nx.MultiGraph() + g.add_edge(0, 1, key=0) + g.add_edge(0, 1, key=1) + g.add_edge(0, 1, key=2) + h = nx.MultiGraph() + h.add_edge(0, 1, key=0) + h.add_edge(0, 1, key=3) + gh = nx.intersection(g, h) + assert set(gh.nodes()) == set(g.nodes()) + assert set(gh.nodes()) == set(h.nodes()) + assert sorted(gh.edges()) == [(0, 1)] + assert sorted(gh.edges(keys=True)) == [(0, 1, 0)] + + +def test_intersection_multigraph_attributes_node_set_different(): + g = nx.MultiGraph() + g.add_edge(0, 1, key=0) + g.add_edge(0, 1, key=1) + g.add_edge(0, 1, key=2) + g.add_edge(0, 2, key=2) + g.add_edge(0, 2, key=1) + h = nx.MultiGraph() + h.add_edge(0, 1, key=0) + h.add_edge(0, 1, key=3) + gh = nx.intersection(g, h) + assert set(gh.nodes()) == set(h.nodes()) + assert sorted(gh.edges()) == [(0, 1)] + assert sorted(gh.edges(keys=True)) == [(0, 1, 0)] + + +def test_difference(): + G = nx.Graph() + H = nx.Graph() + G.add_nodes_from([1, 2, 3, 4]) + G.add_edge(1, 2) + G.add_edge(2, 3) + H.add_nodes_from([1, 2, 3, 4]) + H.add_edge(2, 3) + H.add_edge(3, 4) + D = nx.difference(G, H) + assert set(D.nodes()) == {1, 2, 3, 4} + assert sorted(D.edges()) == [(1, 2)] + D = nx.difference(H, G) + assert set(D.nodes()) == {1, 2, 3, 4} + assert sorted(D.edges()) == [(3, 4)] + D = nx.symmetric_difference(G, H) + assert set(D.nodes()) == {1, 2, 3, 4} + assert sorted(D.edges()) == [(1, 2), (3, 4)] + + +def test_difference2(): + G = nx.Graph() + H = nx.Graph() + G.add_nodes_from([1, 2, 3, 4]) + H.add_nodes_from([1, 2, 3, 4]) + G.add_edge(1, 2) + H.add_edge(1, 2) + G.add_edge(2, 3) + D = nx.difference(G, H) + assert set(D.nodes()) == {1, 2, 3, 4} + assert sorted(D.edges()) == [(2, 3)] + D = nx.difference(H, G) + assert set(D.nodes()) == {1, 2, 3, 4} + assert sorted(D.edges()) == [] + H.add_edge(3, 4) + D = nx.difference(H, G) + assert set(D.nodes()) == {1, 2, 3, 4} + assert sorted(D.edges()) == [(3, 4)] + + +def test_difference_attributes(): + g = nx.Graph() + g.add_node(0, x=4) + g.add_node(1, x=5) + g.add_edge(0, 1, size=5) + g.graph["name"] = "g" + + h = g.copy() + h.graph["name"] = "h" + h.graph["attr"] = "attr" + h.nodes[0]["x"] = 7 + + gh = nx.difference(g, h) + assert set(gh.nodes()) == set(g.nodes()) + assert set(gh.nodes()) == set(h.nodes()) + assert sorted(gh.edges()) == [] + # node and graph data should not be copied over + assert gh.nodes.data() != g.nodes.data() + assert gh.graph != g.graph + + +def test_difference_multigraph_attributes(): + g = nx.MultiGraph() + g.add_edge(0, 1, key=0) + g.add_edge(0, 1, key=1) + g.add_edge(0, 1, key=2) + h = nx.MultiGraph() + h.add_edge(0, 1, key=0) + h.add_edge(0, 1, key=3) + gh = nx.difference(g, h) + assert set(gh.nodes()) == set(g.nodes()) + assert set(gh.nodes()) == set(h.nodes()) + assert sorted(gh.edges()) == [(0, 1), (0, 1)] + assert sorted(gh.edges(keys=True)) == [(0, 1, 1), (0, 1, 2)] + + +def test_difference_raise(): + G = nx.path_graph(4) + H = nx.path_graph(3) + pytest.raises(nx.NetworkXError, nx.difference, G, H) + pytest.raises(nx.NetworkXError, nx.symmetric_difference, G, H) + + +def test_symmetric_difference_multigraph(): + g = nx.MultiGraph() + g.add_edge(0, 1, key=0) + g.add_edge(0, 1, key=1) + g.add_edge(0, 1, key=2) + h = nx.MultiGraph() + h.add_edge(0, 1, key=0) + h.add_edge(0, 1, key=3) + gh = nx.symmetric_difference(g, h) + assert set(gh.nodes()) == set(g.nodes()) + assert set(gh.nodes()) == set(h.nodes()) + assert sorted(gh.edges()) == 3 * [(0, 1)] + assert sorted(sorted(e) for e in gh.edges(keys=True)) == [ + [0, 1, 1], + [0, 1, 2], + [0, 1, 3], + ] + + +def test_union_and_compose(): + K3 = nx.complete_graph(3) + P3 = nx.path_graph(3) + + G1 = nx.DiGraph() + G1.add_edge("A", "B") + G1.add_edge("A", "C") + G1.add_edge("A", "D") + G2 = nx.DiGraph() + G2.add_edge("1", "2") + G2.add_edge("1", "3") + G2.add_edge("1", "4") + + G = nx.union(G1, G2) + H = nx.compose(G1, G2) + assert edges_equal(G.edges(), H.edges()) + assert not G.has_edge("A", 1) + pytest.raises(nx.NetworkXError, nx.union, K3, P3) + H1 = nx.union(H, G1, rename=("H", "G1")) + assert sorted(H1.nodes()) == [ + "G1A", + "G1B", + "G1C", + "G1D", + "H1", + "H2", + "H3", + "H4", + "HA", + "HB", + "HC", + "HD", + ] + + H2 = nx.union(H, G2, rename=("H", "")) + assert sorted(H2.nodes()) == [ + "1", + "2", + "3", + "4", + "H1", + "H2", + "H3", + "H4", + "HA", + "HB", + "HC", + "HD", + ] + + assert not H1.has_edge("NB", "NA") + + G = nx.compose(G, G) + assert edges_equal(G.edges(), H.edges()) + + G2 = nx.union(G2, G2, rename=("", "copy")) + assert sorted(G2.nodes()) == [ + "1", + "2", + "3", + "4", + "copy1", + "copy2", + "copy3", + "copy4", + ] + + assert sorted(G2.neighbors("copy4")) == [] + assert sorted(G2.neighbors("copy1")) == ["copy2", "copy3", "copy4"] + assert len(G) == 8 + assert nx.number_of_edges(G) == 6 + + E = nx.disjoint_union(G, G) + assert len(E) == 16 + assert nx.number_of_edges(E) == 12 + + E = nx.disjoint_union(G1, G2) + assert sorted(E.nodes()) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] + + G = nx.Graph() + H = nx.Graph() + G.add_nodes_from([(1, {"a1": 1})]) + H.add_nodes_from([(1, {"b1": 1})]) + R = nx.compose(G, H) + assert R.nodes == {1: {"a1": 1, "b1": 1}} + + +def test_union_multigraph(): + G = nx.MultiGraph() + G.add_edge(1, 2, key=0) + G.add_edge(1, 2, key=1) + H = nx.MultiGraph() + H.add_edge(3, 4, key=0) + H.add_edge(3, 4, key=1) + GH = nx.union(G, H) + assert set(GH) == set(G) | set(H) + assert set(GH.edges(keys=True)) == set(G.edges(keys=True)) | set(H.edges(keys=True)) + + +def test_disjoint_union_multigraph(): + G = nx.MultiGraph() + G.add_edge(0, 1, key=0) + G.add_edge(0, 1, key=1) + H = nx.MultiGraph() + H.add_edge(2, 3, key=0) + H.add_edge(2, 3, key=1) + GH = nx.disjoint_union(G, H) + assert set(GH) == set(G) | set(H) + assert set(GH.edges(keys=True)) == set(G.edges(keys=True)) | set(H.edges(keys=True)) + + +def test_compose_multigraph(): + G = nx.MultiGraph() + G.add_edge(1, 2, key=0) + G.add_edge(1, 2, key=1) + H = nx.MultiGraph() + H.add_edge(3, 4, key=0) + H.add_edge(3, 4, key=1) + GH = nx.compose(G, H) + assert set(GH) == set(G) | set(H) + assert set(GH.edges(keys=True)) == set(G.edges(keys=True)) | set(H.edges(keys=True)) + H.add_edge(1, 2, key=2) + GH = nx.compose(G, H) + assert set(GH) == set(G) | set(H) + assert set(GH.edges(keys=True)) == set(G.edges(keys=True)) | set(H.edges(keys=True)) + + +def test_full_join_graph(): + # Simple Graphs + G = nx.Graph() + G.add_node(0) + G.add_edge(1, 2) + H = nx.Graph() + H.add_edge(3, 4) + + U = nx.full_join(G, H) + assert set(U) == set(G) | set(H) + assert len(U) == len(G) + len(H) + assert len(U.edges()) == len(G.edges()) + len(H.edges()) + len(G) * len(H) + + # Rename + U = nx.full_join(G, H, rename=("g", "h")) + assert set(U) == {"g0", "g1", "g2", "h3", "h4"} + assert len(U) == len(G) + len(H) + assert len(U.edges()) == len(G.edges()) + len(H.edges()) + len(G) * len(H) + + # Rename graphs with string-like nodes + G = nx.Graph() + G.add_node("a") + G.add_edge("b", "c") + H = nx.Graph() + H.add_edge("d", "e") + + U = nx.full_join(G, H, rename=("g", "h")) + assert set(U) == {"ga", "gb", "gc", "hd", "he"} + assert len(U) == len(G) + len(H) + assert len(U.edges()) == len(G.edges()) + len(H.edges()) + len(G) * len(H) + + # DiGraphs + G = nx.DiGraph() + G.add_node(0) + G.add_edge(1, 2) + H = nx.DiGraph() + H.add_edge(3, 4) + + U = nx.full_join(G, H) + assert set(U) == set(G) | set(H) + assert len(U) == len(G) + len(H) + assert len(U.edges()) == len(G.edges()) + len(H.edges()) + len(G) * len(H) * 2 + + # DiGraphs Rename + U = nx.full_join(G, H, rename=("g", "h")) + assert set(U) == {"g0", "g1", "g2", "h3", "h4"} + assert len(U) == len(G) + len(H) + assert len(U.edges()) == len(G.edges()) + len(H.edges()) + len(G) * len(H) * 2 + + +def test_full_join_multigraph(): + # MultiGraphs + G = nx.MultiGraph() + G.add_node(0) + G.add_edge(1, 2) + H = nx.MultiGraph() + H.add_edge(3, 4) + + U = nx.full_join(G, H) + assert set(U) == set(G) | set(H) + assert len(U) == len(G) + len(H) + assert len(U.edges()) == len(G.edges()) + len(H.edges()) + len(G) * len(H) + + # MultiGraphs rename + U = nx.full_join(G, H, rename=("g", "h")) + assert set(U) == {"g0", "g1", "g2", "h3", "h4"} + assert len(U) == len(G) + len(H) + assert len(U.edges()) == len(G.edges()) + len(H.edges()) + len(G) * len(H) + + # MultiDiGraphs + G = nx.MultiDiGraph() + G.add_node(0) + G.add_edge(1, 2) + H = nx.MultiDiGraph() + H.add_edge(3, 4) + + U = nx.full_join(G, H) + assert set(U) == set(G) | set(H) + assert len(U) == len(G) + len(H) + assert len(U.edges()) == len(G.edges()) + len(H.edges()) + len(G) * len(H) * 2 + + # MultiDiGraphs rename + U = nx.full_join(G, H, rename=("g", "h")) + assert set(U) == {"g0", "g1", "g2", "h3", "h4"} + assert len(U) == len(G) + len(H) + assert len(U.edges()) == len(G.edges()) + len(H.edges()) + len(G) * len(H) * 2 + + +def test_mixed_type_union(): + G = nx.Graph() + H = nx.MultiGraph() + pytest.raises(nx.NetworkXError, nx.union, G, H) + pytest.raises(nx.NetworkXError, nx.disjoint_union, G, H) + pytest.raises(nx.NetworkXError, nx.intersection, G, H) + pytest.raises(nx.NetworkXError, nx.difference, G, H) + pytest.raises(nx.NetworkXError, nx.symmetric_difference, G, H) + pytest.raises(nx.NetworkXError, nx.compose, G, H) diff --git a/lib/python3.10/site-packages/networkx/algorithms/operators/tests/test_product.py b/lib/python3.10/site-packages/networkx/algorithms/operators/tests/test_product.py new file mode 100644 index 0000000000000000000000000000000000000000..8ee54b93012c79531f2732da282072754da82046 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/operators/tests/test_product.py @@ -0,0 +1,491 @@ +import pytest + +import networkx as nx +from networkx.utils import edges_equal + + +def test_tensor_product_raises(): + with pytest.raises(nx.NetworkXError): + P = nx.tensor_product(nx.DiGraph(), nx.Graph()) + + +def test_tensor_product_null(): + null = nx.null_graph() + empty10 = nx.empty_graph(10) + K3 = nx.complete_graph(3) + K10 = nx.complete_graph(10) + P3 = nx.path_graph(3) + P10 = nx.path_graph(10) + # null graph + G = nx.tensor_product(null, null) + assert nx.is_isomorphic(G, null) + # null_graph X anything = null_graph and v.v. + G = nx.tensor_product(null, empty10) + assert nx.is_isomorphic(G, null) + G = nx.tensor_product(null, K3) + assert nx.is_isomorphic(G, null) + G = nx.tensor_product(null, K10) + assert nx.is_isomorphic(G, null) + G = nx.tensor_product(null, P3) + assert nx.is_isomorphic(G, null) + G = nx.tensor_product(null, P10) + assert nx.is_isomorphic(G, null) + G = nx.tensor_product(empty10, null) + assert nx.is_isomorphic(G, null) + G = nx.tensor_product(K3, null) + assert nx.is_isomorphic(G, null) + G = nx.tensor_product(K10, null) + assert nx.is_isomorphic(G, null) + G = nx.tensor_product(P3, null) + assert nx.is_isomorphic(G, null) + G = nx.tensor_product(P10, null) + assert nx.is_isomorphic(G, null) + + +def test_tensor_product_size(): + P5 = nx.path_graph(5) + K3 = nx.complete_graph(3) + K5 = nx.complete_graph(5) + + G = nx.tensor_product(P5, K3) + assert nx.number_of_nodes(G) == 5 * 3 + G = nx.tensor_product(K3, K5) + assert nx.number_of_nodes(G) == 3 * 5 + + +def test_tensor_product_combinations(): + # basic smoke test, more realistic tests would be useful + P5 = nx.path_graph(5) + K3 = nx.complete_graph(3) + G = nx.tensor_product(P5, K3) + assert nx.number_of_nodes(G) == 5 * 3 + G = nx.tensor_product(P5, nx.MultiGraph(K3)) + assert nx.number_of_nodes(G) == 5 * 3 + G = nx.tensor_product(nx.MultiGraph(P5), K3) + assert nx.number_of_nodes(G) == 5 * 3 + G = nx.tensor_product(nx.MultiGraph(P5), nx.MultiGraph(K3)) + assert nx.number_of_nodes(G) == 5 * 3 + + G = nx.tensor_product(nx.DiGraph(P5), nx.DiGraph(K3)) + assert nx.number_of_nodes(G) == 5 * 3 + + +def test_tensor_product_classic_result(): + K2 = nx.complete_graph(2) + G = nx.petersen_graph() + G = nx.tensor_product(G, K2) + assert nx.is_isomorphic(G, nx.desargues_graph()) + + G = nx.cycle_graph(5) + G = nx.tensor_product(G, K2) + assert nx.is_isomorphic(G, nx.cycle_graph(10)) + + G = nx.tetrahedral_graph() + G = nx.tensor_product(G, K2) + assert nx.is_isomorphic(G, nx.cubical_graph()) + + +def test_tensor_product_random(): + G = nx.erdos_renyi_graph(10, 2 / 10.0) + H = nx.erdos_renyi_graph(10, 2 / 10.0) + GH = nx.tensor_product(G, H) + + for u_G, u_H in GH.nodes(): + for v_G, v_H in GH.nodes(): + if H.has_edge(u_H, v_H) and G.has_edge(u_G, v_G): + assert GH.has_edge((u_G, u_H), (v_G, v_H)) + else: + assert not GH.has_edge((u_G, u_H), (v_G, v_H)) + + +def test_cartesian_product_multigraph(): + G = nx.MultiGraph() + G.add_edge(1, 2, key=0) + G.add_edge(1, 2, key=1) + H = nx.MultiGraph() + H.add_edge(3, 4, key=0) + H.add_edge(3, 4, key=1) + GH = nx.cartesian_product(G, H) + assert set(GH) == {(1, 3), (2, 3), (2, 4), (1, 4)} + assert {(frozenset([u, v]), k) for u, v, k in GH.edges(keys=True)} == { + (frozenset([u, v]), k) + for u, v, k in [ + ((1, 3), (2, 3), 0), + ((1, 3), (2, 3), 1), + ((1, 3), (1, 4), 0), + ((1, 3), (1, 4), 1), + ((2, 3), (2, 4), 0), + ((2, 3), (2, 4), 1), + ((2, 4), (1, 4), 0), + ((2, 4), (1, 4), 1), + ] + } + + +def test_cartesian_product_raises(): + with pytest.raises(nx.NetworkXError): + P = nx.cartesian_product(nx.DiGraph(), nx.Graph()) + + +def test_cartesian_product_null(): + null = nx.null_graph() + empty10 = nx.empty_graph(10) + K3 = nx.complete_graph(3) + K10 = nx.complete_graph(10) + P3 = nx.path_graph(3) + P10 = nx.path_graph(10) + # null graph + G = nx.cartesian_product(null, null) + assert nx.is_isomorphic(G, null) + # null_graph X anything = null_graph and v.v. + G = nx.cartesian_product(null, empty10) + assert nx.is_isomorphic(G, null) + G = nx.cartesian_product(null, K3) + assert nx.is_isomorphic(G, null) + G = nx.cartesian_product(null, K10) + assert nx.is_isomorphic(G, null) + G = nx.cartesian_product(null, P3) + assert nx.is_isomorphic(G, null) + G = nx.cartesian_product(null, P10) + assert nx.is_isomorphic(G, null) + G = nx.cartesian_product(empty10, null) + assert nx.is_isomorphic(G, null) + G = nx.cartesian_product(K3, null) + assert nx.is_isomorphic(G, null) + G = nx.cartesian_product(K10, null) + assert nx.is_isomorphic(G, null) + G = nx.cartesian_product(P3, null) + assert nx.is_isomorphic(G, null) + G = nx.cartesian_product(P10, null) + assert nx.is_isomorphic(G, null) + + +def test_cartesian_product_size(): + # order(GXH)=order(G)*order(H) + K5 = nx.complete_graph(5) + P5 = nx.path_graph(5) + K3 = nx.complete_graph(3) + G = nx.cartesian_product(P5, K3) + assert nx.number_of_nodes(G) == 5 * 3 + assert nx.number_of_edges(G) == nx.number_of_edges(P5) * nx.number_of_nodes( + K3 + ) + nx.number_of_edges(K3) * nx.number_of_nodes(P5) + G = nx.cartesian_product(K3, K5) + assert nx.number_of_nodes(G) == 3 * 5 + assert nx.number_of_edges(G) == nx.number_of_edges(K5) * nx.number_of_nodes( + K3 + ) + nx.number_of_edges(K3) * nx.number_of_nodes(K5) + + +def test_cartesian_product_classic(): + # test some classic product graphs + P2 = nx.path_graph(2) + P3 = nx.path_graph(3) + # cube = 2-path X 2-path + G = nx.cartesian_product(P2, P2) + G = nx.cartesian_product(P2, G) + assert nx.is_isomorphic(G, nx.cubical_graph()) + + # 3x3 grid + G = nx.cartesian_product(P3, P3) + assert nx.is_isomorphic(G, nx.grid_2d_graph(3, 3)) + + +def test_cartesian_product_random(): + G = nx.erdos_renyi_graph(10, 2 / 10.0) + H = nx.erdos_renyi_graph(10, 2 / 10.0) + GH = nx.cartesian_product(G, H) + + for u_G, u_H in GH.nodes(): + for v_G, v_H in GH.nodes(): + if (u_G == v_G and H.has_edge(u_H, v_H)) or ( + u_H == v_H and G.has_edge(u_G, v_G) + ): + assert GH.has_edge((u_G, u_H), (v_G, v_H)) + else: + assert not GH.has_edge((u_G, u_H), (v_G, v_H)) + + +def test_lexicographic_product_raises(): + with pytest.raises(nx.NetworkXError): + P = nx.lexicographic_product(nx.DiGraph(), nx.Graph()) + + +def test_lexicographic_product_null(): + null = nx.null_graph() + empty10 = nx.empty_graph(10) + K3 = nx.complete_graph(3) + K10 = nx.complete_graph(10) + P3 = nx.path_graph(3) + P10 = nx.path_graph(10) + # null graph + G = nx.lexicographic_product(null, null) + assert nx.is_isomorphic(G, null) + # null_graph X anything = null_graph and v.v. + G = nx.lexicographic_product(null, empty10) + assert nx.is_isomorphic(G, null) + G = nx.lexicographic_product(null, K3) + assert nx.is_isomorphic(G, null) + G = nx.lexicographic_product(null, K10) + assert nx.is_isomorphic(G, null) + G = nx.lexicographic_product(null, P3) + assert nx.is_isomorphic(G, null) + G = nx.lexicographic_product(null, P10) + assert nx.is_isomorphic(G, null) + G = nx.lexicographic_product(empty10, null) + assert nx.is_isomorphic(G, null) + G = nx.lexicographic_product(K3, null) + assert nx.is_isomorphic(G, null) + G = nx.lexicographic_product(K10, null) + assert nx.is_isomorphic(G, null) + G = nx.lexicographic_product(P3, null) + assert nx.is_isomorphic(G, null) + G = nx.lexicographic_product(P10, null) + assert nx.is_isomorphic(G, null) + + +def test_lexicographic_product_size(): + K5 = nx.complete_graph(5) + P5 = nx.path_graph(5) + K3 = nx.complete_graph(3) + G = nx.lexicographic_product(P5, K3) + assert nx.number_of_nodes(G) == 5 * 3 + G = nx.lexicographic_product(K3, K5) + assert nx.number_of_nodes(G) == 3 * 5 + + +def test_lexicographic_product_combinations(): + P5 = nx.path_graph(5) + K3 = nx.complete_graph(3) + G = nx.lexicographic_product(P5, K3) + assert nx.number_of_nodes(G) == 5 * 3 + G = nx.lexicographic_product(nx.MultiGraph(P5), K3) + assert nx.number_of_nodes(G) == 5 * 3 + G = nx.lexicographic_product(P5, nx.MultiGraph(K3)) + assert nx.number_of_nodes(G) == 5 * 3 + G = nx.lexicographic_product(nx.MultiGraph(P5), nx.MultiGraph(K3)) + assert nx.number_of_nodes(G) == 5 * 3 + + # No classic easily found classic results for lexicographic product + + +def test_lexicographic_product_random(): + G = nx.erdos_renyi_graph(10, 2 / 10.0) + H = nx.erdos_renyi_graph(10, 2 / 10.0) + GH = nx.lexicographic_product(G, H) + + for u_G, u_H in GH.nodes(): + for v_G, v_H in GH.nodes(): + if G.has_edge(u_G, v_G) or (u_G == v_G and H.has_edge(u_H, v_H)): + assert GH.has_edge((u_G, u_H), (v_G, v_H)) + else: + assert not GH.has_edge((u_G, u_H), (v_G, v_H)) + + +def test_strong_product_raises(): + with pytest.raises(nx.NetworkXError): + P = nx.strong_product(nx.DiGraph(), nx.Graph()) + + +def test_strong_product_null(): + null = nx.null_graph() + empty10 = nx.empty_graph(10) + K3 = nx.complete_graph(3) + K10 = nx.complete_graph(10) + P3 = nx.path_graph(3) + P10 = nx.path_graph(10) + # null graph + G = nx.strong_product(null, null) + assert nx.is_isomorphic(G, null) + # null_graph X anything = null_graph and v.v. + G = nx.strong_product(null, empty10) + assert nx.is_isomorphic(G, null) + G = nx.strong_product(null, K3) + assert nx.is_isomorphic(G, null) + G = nx.strong_product(null, K10) + assert nx.is_isomorphic(G, null) + G = nx.strong_product(null, P3) + assert nx.is_isomorphic(G, null) + G = nx.strong_product(null, P10) + assert nx.is_isomorphic(G, null) + G = nx.strong_product(empty10, null) + assert nx.is_isomorphic(G, null) + G = nx.strong_product(K3, null) + assert nx.is_isomorphic(G, null) + G = nx.strong_product(K10, null) + assert nx.is_isomorphic(G, null) + G = nx.strong_product(P3, null) + assert nx.is_isomorphic(G, null) + G = nx.strong_product(P10, null) + assert nx.is_isomorphic(G, null) + + +def test_strong_product_size(): + K5 = nx.complete_graph(5) + P5 = nx.path_graph(5) + K3 = nx.complete_graph(3) + G = nx.strong_product(P5, K3) + assert nx.number_of_nodes(G) == 5 * 3 + G = nx.strong_product(K3, K5) + assert nx.number_of_nodes(G) == 3 * 5 + + +def test_strong_product_combinations(): + P5 = nx.path_graph(5) + K3 = nx.complete_graph(3) + G = nx.strong_product(P5, K3) + assert nx.number_of_nodes(G) == 5 * 3 + G = nx.strong_product(nx.MultiGraph(P5), K3) + assert nx.number_of_nodes(G) == 5 * 3 + G = nx.strong_product(P5, nx.MultiGraph(K3)) + assert nx.number_of_nodes(G) == 5 * 3 + G = nx.strong_product(nx.MultiGraph(P5), nx.MultiGraph(K3)) + assert nx.number_of_nodes(G) == 5 * 3 + + # No classic easily found classic results for strong product + + +def test_strong_product_random(): + G = nx.erdos_renyi_graph(10, 2 / 10.0) + H = nx.erdos_renyi_graph(10, 2 / 10.0) + GH = nx.strong_product(G, H) + + for u_G, u_H in GH.nodes(): + for v_G, v_H in GH.nodes(): + if ( + (u_G == v_G and H.has_edge(u_H, v_H)) + or (u_H == v_H and G.has_edge(u_G, v_G)) + or (G.has_edge(u_G, v_G) and H.has_edge(u_H, v_H)) + ): + assert GH.has_edge((u_G, u_H), (v_G, v_H)) + else: + assert not GH.has_edge((u_G, u_H), (v_G, v_H)) + + +def test_graph_power_raises(): + with pytest.raises(nx.NetworkXNotImplemented): + nx.power(nx.MultiDiGraph(), 2) + + +def test_graph_power(): + # wikipedia example for graph power + G = nx.cycle_graph(7) + G.add_edge(6, 7) + G.add_edge(7, 8) + G.add_edge(8, 9) + G.add_edge(9, 2) + H = nx.power(G, 2) + + assert edges_equal( + list(H.edges()), + [ + (0, 1), + (0, 2), + (0, 5), + (0, 6), + (0, 7), + (1, 9), + (1, 2), + (1, 3), + (1, 6), + (2, 3), + (2, 4), + (2, 8), + (2, 9), + (3, 4), + (3, 5), + (3, 9), + (4, 5), + (4, 6), + (5, 6), + (5, 7), + (6, 7), + (6, 8), + (7, 8), + (7, 9), + (8, 9), + ], + ) + + +def test_graph_power_negative(): + with pytest.raises(ValueError): + nx.power(nx.Graph(), -1) + + +def test_rooted_product_raises(): + with pytest.raises(nx.NodeNotFound): + nx.rooted_product(nx.Graph(), nx.path_graph(2), 10) + + +def test_rooted_product(): + G = nx.cycle_graph(5) + H = nx.Graph() + H.add_edges_from([("a", "b"), ("b", "c"), ("b", "d")]) + R = nx.rooted_product(G, H, "a") + assert len(R) == len(G) * len(H) + assert R.size() == G.size() + len(G) * H.size() + + +def test_corona_product(): + G = nx.cycle_graph(3) + H = nx.path_graph(2) + C = nx.corona_product(G, H) + assert len(C) == (len(G) * len(H)) + len(G) + assert C.size() == G.size() + len(G) * H.size() + len(G) * len(H) + + +def test_modular_product(): + G = nx.path_graph(3) + H = nx.path_graph(4) + M = nx.modular_product(G, H) + assert len(M) == len(G) * len(H) + + assert edges_equal( + list(M.edges()), + [ + ((0, 0), (1, 1)), + ((0, 0), (2, 2)), + ((0, 0), (2, 3)), + ((0, 1), (1, 0)), + ((0, 1), (1, 2)), + ((0, 1), (2, 3)), + ((0, 2), (1, 1)), + ((0, 2), (1, 3)), + ((0, 2), (2, 0)), + ((0, 3), (1, 2)), + ((0, 3), (2, 0)), + ((0, 3), (2, 1)), + ((1, 0), (2, 1)), + ((1, 1), (2, 0)), + ((1, 1), (2, 2)), + ((1, 2), (2, 1)), + ((1, 2), (2, 3)), + ((1, 3), (2, 2)), + ], + ) + + +def test_modular_product_raises(): + G = nx.Graph([(0, 1), (1, 2), (2, 0)]) + H = nx.Graph([(0, 1), (1, 2), (2, 0)]) + DG = nx.DiGraph([(0, 1), (1, 2), (2, 0)]) + DH = nx.DiGraph([(0, 1), (1, 2), (2, 0)]) + with pytest.raises(nx.NetworkXNotImplemented): + nx.modular_product(G, DH) + with pytest.raises(nx.NetworkXNotImplemented): + nx.modular_product(DG, H) + with pytest.raises(nx.NetworkXNotImplemented): + nx.modular_product(DG, DH) + + MG = nx.MultiGraph([(0, 1), (1, 2), (2, 0), (0, 1)]) + MH = nx.MultiGraph([(0, 1), (1, 2), (2, 0), (0, 1)]) + with pytest.raises(nx.NetworkXNotImplemented): + nx.modular_product(G, MH) + with pytest.raises(nx.NetworkXNotImplemented): + nx.modular_product(MG, H) + with pytest.raises(nx.NetworkXNotImplemented): + nx.modular_product(MG, MH) + with pytest.raises(nx.NetworkXNotImplemented): + # check multigraph with no multiedges + nx.modular_product(nx.MultiGraph(G), H) diff --git a/lib/python3.10/site-packages/networkx/algorithms/operators/tests/test_unary.py b/lib/python3.10/site-packages/networkx/algorithms/operators/tests/test_unary.py new file mode 100644 index 0000000000000000000000000000000000000000..d68e55cd9c9fa37459b497c32a7a095576c306c3 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/operators/tests/test_unary.py @@ -0,0 +1,55 @@ +import pytest + +import networkx as nx + + +def test_complement(): + null = nx.null_graph() + empty1 = nx.empty_graph(1) + empty10 = nx.empty_graph(10) + K3 = nx.complete_graph(3) + K5 = nx.complete_graph(5) + K10 = nx.complete_graph(10) + P2 = nx.path_graph(2) + P3 = nx.path_graph(3) + P5 = nx.path_graph(5) + P10 = nx.path_graph(10) + # complement of the complete graph is empty + + G = nx.complement(K3) + assert nx.is_isomorphic(G, nx.empty_graph(3)) + G = nx.complement(K5) + assert nx.is_isomorphic(G, nx.empty_graph(5)) + # for any G, G=complement(complement(G)) + P3cc = nx.complement(nx.complement(P3)) + assert nx.is_isomorphic(P3, P3cc) + nullcc = nx.complement(nx.complement(null)) + assert nx.is_isomorphic(null, nullcc) + b = nx.bull_graph() + bcc = nx.complement(nx.complement(b)) + assert nx.is_isomorphic(b, bcc) + + +def test_complement_2(): + G1 = nx.DiGraph() + G1.add_edge("A", "B") + G1.add_edge("A", "C") + G1.add_edge("A", "D") + G1C = nx.complement(G1) + assert sorted(G1C.edges()) == [ + ("B", "A"), + ("B", "C"), + ("B", "D"), + ("C", "A"), + ("C", "B"), + ("C", "D"), + ("D", "A"), + ("D", "B"), + ("D", "C"), + ] + + +def test_reverse1(): + # Other tests for reverse are done by the DiGraph and MultiDigraph. + G1 = nx.Graph() + pytest.raises(nx.NetworkXError, nx.reverse, G1) diff --git a/lib/python3.10/site-packages/networkx/algorithms/operators/unary.py b/lib/python3.10/site-packages/networkx/algorithms/operators/unary.py new file mode 100644 index 0000000000000000000000000000000000000000..79e44d1cc04cff72c5c87d1852544514a6f53246 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/operators/unary.py @@ -0,0 +1,77 @@ +"""Unary operations on graphs""" + +import networkx as nx + +__all__ = ["complement", "reverse"] + + +@nx._dispatchable(returns_graph=True) +def complement(G): + """Returns the graph complement of G. + + Parameters + ---------- + G : graph + A NetworkX graph + + Returns + ------- + GC : A new graph. + + Notes + ----- + Note that `complement` does not create self-loops and also + does not produce parallel edges for MultiGraphs. + + Graph, node, and edge data are not propagated to the new graph. + + Examples + -------- + >>> G = nx.Graph([(1, 2), (1, 3), (2, 3), (3, 4), (3, 5)]) + >>> G_complement = nx.complement(G) + >>> G_complement.edges() # This shows the edges of the complemented graph + EdgeView([(1, 4), (1, 5), (2, 4), (2, 5), (4, 5)]) + + """ + R = G.__class__() + R.add_nodes_from(G) + R.add_edges_from( + ((n, n2) for n, nbrs in G.adjacency() for n2 in G if n2 not in nbrs if n != n2) + ) + return R + + +@nx._dispatchable(returns_graph=True) +def reverse(G, copy=True): + """Returns the reverse directed graph of G. + + Parameters + ---------- + G : directed graph + A NetworkX directed graph + copy : bool + If True, then a new graph is returned. If False, then the graph is + reversed in place. + + Returns + ------- + H : directed graph + The reversed G. + + Raises + ------ + NetworkXError + If graph is undirected. + + Examples + -------- + >>> G = nx.DiGraph([(1, 2), (1, 3), (2, 3), (3, 4), (3, 5)]) + >>> G_reversed = nx.reverse(G) + >>> G_reversed.edges() + OutEdgeView([(2, 1), (3, 1), (3, 2), (4, 3), (5, 3)]) + + """ + if not G.is_directed(): + raise nx.NetworkXError("Cannot reverse an undirected graph.") + else: + return G.reverse(copy=copy) diff --git a/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eb0d91cecc902f6390cb8309c017cb1558f7753f --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/__init__.py @@ -0,0 +1,5 @@ +from networkx.algorithms.shortest_paths.generic import * +from networkx.algorithms.shortest_paths.unweighted import * +from networkx.algorithms.shortest_paths.weighted import * +from networkx.algorithms.shortest_paths.astar import * +from networkx.algorithms.shortest_paths.dense import * diff --git a/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd923121f3316f88137c66bdf79c089ae8ecea4b Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/__pycache__/astar.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/__pycache__/astar.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b90784a44920aeb907e43f79717facc3837cfcf Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/__pycache__/astar.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/__pycache__/dense.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/__pycache__/dense.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c1e9cf80a80c869e4c92ef5611112557930977a Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/__pycache__/dense.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/__pycache__/generic.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/__pycache__/generic.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3145f97a10df5aa50620e9bd82b36e0aa98af91b Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/__pycache__/generic.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/__pycache__/unweighted.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/__pycache__/unweighted.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa4c05c456fba4d02b66b68864ffae4a4739932d Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/__pycache__/unweighted.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/__pycache__/weighted.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/__pycache__/weighted.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12878ce898df9df51e4a7b9d0dce03f86dbf5dd2 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/__pycache__/weighted.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/astar.py b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/astar.py new file mode 100644 index 0000000000000000000000000000000000000000..8d988477b2c1034cacd99cec70d932055ba96a8e --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/astar.py @@ -0,0 +1,241 @@ +"""Shortest paths and path lengths using the A* ("A star") algorithm.""" + +from heapq import heappop, heappush +from itertools import count + +import networkx as nx +from networkx.algorithms.shortest_paths.weighted import _weight_function + +__all__ = ["astar_path", "astar_path_length"] + + +@nx._dispatchable(edge_attrs="weight", preserve_node_attrs="heuristic") +def astar_path(G, source, target, heuristic=None, weight="weight", *, cutoff=None): + """Returns a list of nodes in a shortest path between source and target + using the A* ("A-star") algorithm. + + There may be more than one shortest path. This returns only one. + + Parameters + ---------- + G : NetworkX graph + + source : node + Starting node for path + + target : node + Ending node for path + + heuristic : function + A function to evaluate the estimate of the distance + from the a node to the target. The function takes + two nodes arguments and must return a number. + If the heuristic is inadmissible (if it might + overestimate the cost of reaching the goal from a node), + the result may not be a shortest path. + The algorithm does not support updating heuristic + values for the same node due to caching the first + heuristic calculation per node. + + weight : string or function + If this is a string, then edge weights will be accessed via the + edge attribute with this key (that is, the weight of the edge + joining `u` to `v` will be ``G.edges[u, v][weight]``). If no + such edge attribute exists, the weight of the edge is assumed to + be one. + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly three + positional arguments: the two endpoints of an edge and the + dictionary of edge attributes for that edge. The function must + return a number or None to indicate a hidden edge. + + cutoff : float, optional + If this is provided, the search will be bounded to this value. I.e. if + the evaluation function surpasses this value for a node n, the node will not + be expanded further and will be ignored. More formally, let h'(n) be the + heuristic function, and g(n) be the cost of reaching n from the source node. Then, + if g(n) + h'(n) > cutoff, the node will not be explored further. + Note that if the heuristic is inadmissible, it is possible that paths + are ignored even though they satisfy the cutoff. + + Raises + ------ + NetworkXNoPath + If no path exists between source and target. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> print(nx.astar_path(G, 0, 4)) + [0, 1, 2, 3, 4] + >>> G = nx.grid_graph(dim=[3, 3]) # nodes are two-tuples (x,y) + >>> nx.set_edge_attributes(G, {e: e[1][0] * 2 for e in G.edges()}, "cost") + >>> def dist(a, b): + ... (x1, y1) = a + ... (x2, y2) = b + ... return ((x1 - x2) ** 2 + (y1 - y2) ** 2) ** 0.5 + >>> print(nx.astar_path(G, (0, 0), (2, 2), heuristic=dist, weight="cost")) + [(0, 0), (0, 1), (0, 2), (1, 2), (2, 2)] + + Notes + ----- + Edge weight attributes must be numerical. + Distances are calculated as sums of weighted edges traversed. + + The weight function can be used to hide edges by returning None. + So ``weight = lambda u, v, d: 1 if d['color']=="red" else None`` + will find the shortest red path. + + See Also + -------- + shortest_path, dijkstra_path + + """ + if source not in G: + raise nx.NodeNotFound(f"Source {source} is not in G") + + if target not in G: + raise nx.NodeNotFound(f"Target {target} is not in G") + + if heuristic is None: + # The default heuristic is h=0 - same as Dijkstra's algorithm + def heuristic(u, v): + return 0 + + push = heappush + pop = heappop + weight = _weight_function(G, weight) + + G_succ = G._adj # For speed-up (and works for both directed and undirected graphs) + + # The queue stores priority, node, cost to reach, and parent. + # Uses Python heapq to keep in priority order. + # Add a counter to the queue to prevent the underlying heap from + # attempting to compare the nodes themselves. The hash breaks ties in the + # priority and is guaranteed unique for all nodes in the graph. + c = count() + queue = [(0, next(c), source, 0, None)] + + # Maps enqueued nodes to distance of discovered paths and the + # computed heuristics to target. We avoid computing the heuristics + # more than once and inserting the node into the queue too many times. + enqueued = {} + # Maps explored nodes to parent closest to the source. + explored = {} + + while queue: + # Pop the smallest item from queue. + _, __, curnode, dist, parent = pop(queue) + + if curnode == target: + path = [curnode] + node = parent + while node is not None: + path.append(node) + node = explored[node] + path.reverse() + return path + + if curnode in explored: + # Do not override the parent of starting node + if explored[curnode] is None: + continue + + # Skip bad paths that were enqueued before finding a better one + qcost, h = enqueued[curnode] + if qcost < dist: + continue + + explored[curnode] = parent + + for neighbor, w in G_succ[curnode].items(): + cost = weight(curnode, neighbor, w) + if cost is None: + continue + ncost = dist + cost + if neighbor in enqueued: + qcost, h = enqueued[neighbor] + # if qcost <= ncost, a less costly path from the + # neighbor to the source was already determined. + # Therefore, we won't attempt to push this neighbor + # to the queue + if qcost <= ncost: + continue + else: + h = heuristic(neighbor, target) + + if cutoff and ncost + h > cutoff: + continue + + enqueued[neighbor] = ncost, h + push(queue, (ncost + h, next(c), neighbor, ncost, curnode)) + + raise nx.NetworkXNoPath(f"Node {target} not reachable from {source}") + + +@nx._dispatchable(edge_attrs="weight", preserve_node_attrs="heuristic") +def astar_path_length( + G, source, target, heuristic=None, weight="weight", *, cutoff=None +): + """Returns the length of the shortest path between source and target using + the A* ("A-star") algorithm. + + Parameters + ---------- + G : NetworkX graph + + source : node + Starting node for path + + target : node + Ending node for path + + heuristic : function + A function to evaluate the estimate of the distance + from the a node to the target. The function takes + two nodes arguments and must return a number. + If the heuristic is inadmissible (if it might + overestimate the cost of reaching the goal from a node), + the result may not be a shortest path. + The algorithm does not support updating heuristic + values for the same node due to caching the first + heuristic calculation per node. + + weight : string or function + If this is a string, then edge weights will be accessed via the + edge attribute with this key (that is, the weight of the edge + joining `u` to `v` will be ``G.edges[u, v][weight]``). If no + such edge attribute exists, the weight of the edge is assumed to + be one. + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly three + positional arguments: the two endpoints of an edge and the + dictionary of edge attributes for that edge. The function must + return a number or None to indicate a hidden edge. + + cutoff : float, optional + If this is provided, the search will be bounded to this value. I.e. if + the evaluation function surpasses this value for a node n, the node will not + be expanded further and will be ignored. More formally, let h'(n) be the + heuristic function, and g(n) be the cost of reaching n from the source node. Then, + if g(n) + h'(n) > cutoff, the node will not be explored further. + Note that if the heuristic is inadmissible, it is possible that paths + are ignored even though they satisfy the cutoff. + + Raises + ------ + NetworkXNoPath + If no path exists between source and target. + + See Also + -------- + astar_path + + """ + if source not in G or target not in G: + msg = f"Either source {source} or target {target} is not in G" + raise nx.NodeNotFound(msg) + + weight = _weight_function(G, weight) + path = astar_path(G, source, target, heuristic, weight, cutoff=cutoff) + return sum(weight(u, v, G[u][v]) for u, v in zip(path[:-1], path[1:])) diff --git a/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/dense.py b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/dense.py new file mode 100644 index 0000000000000000000000000000000000000000..107b920884a0cde29b77257e941cd2e172f140e6 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/dense.py @@ -0,0 +1,260 @@ +"""Floyd-Warshall algorithm for shortest paths.""" + +import networkx as nx + +__all__ = [ + "floyd_warshall", + "floyd_warshall_predecessor_and_distance", + "reconstruct_path", + "floyd_warshall_numpy", +] + + +@nx._dispatchable(edge_attrs="weight") +def floyd_warshall_numpy(G, nodelist=None, weight="weight"): + """Find all-pairs shortest path lengths using Floyd's algorithm. + + This algorithm for finding shortest paths takes advantage of + matrix representations of a graph and works well for dense + graphs where all-pairs shortest path lengths are desired. + The results are returned as a NumPy array, distance[i, j], + where i and j are the indexes of two nodes in nodelist. + The entry distance[i, j] is the distance along a shortest + path from i to j. If no path exists the distance is Inf. + + Parameters + ---------- + G : NetworkX graph + + nodelist : list, optional (default=G.nodes) + The rows and columns are ordered by the nodes in nodelist. + If nodelist is None then the ordering is produced by G.nodes. + Nodelist should include all nodes in G. + + weight: string, optional (default='weight') + Edge data key corresponding to the edge weight. + + Returns + ------- + distance : 2D numpy.ndarray + A numpy array of shortest path distances between nodes. + If there is no path between two nodes the value is Inf. + + Examples + -------- + >>> G = nx.DiGraph() + >>> G.add_weighted_edges_from( + ... [(0, 1, 5), (1, 2, 2), (2, 3, -3), (1, 3, 10), (3, 2, 8)] + ... ) + >>> nx.floyd_warshall_numpy(G) + array([[ 0., 5., 7., 4.], + [inf, 0., 2., -1.], + [inf, inf, 0., -3.], + [inf, inf, 8., 0.]]) + + Notes + ----- + Floyd's algorithm is appropriate for finding shortest paths in + dense graphs or graphs with negative weights when Dijkstra's + algorithm fails. This algorithm can still fail if there are negative + cycles. It has running time $O(n^3)$ with running space of $O(n^2)$. + + Raises + ------ + NetworkXError + If nodelist is not a list of the nodes in G. + """ + import numpy as np + + if nodelist is not None: + if not (len(nodelist) == len(G) == len(set(nodelist))): + raise nx.NetworkXError( + "nodelist must contain every node in G with no repeats." + "If you wanted a subgraph of G use G.subgraph(nodelist)" + ) + + # To handle cases when an edge has weight=0, we must make sure that + # nonedges are not given the value 0 as well. + A = nx.to_numpy_array( + G, nodelist, multigraph_weight=min, weight=weight, nonedge=np.inf + ) + n, m = A.shape + np.fill_diagonal(A, 0) # diagonal elements should be zero + for i in range(n): + # The second term has the same shape as A due to broadcasting + A = np.minimum(A, A[i, :][np.newaxis, :] + A[:, i][:, np.newaxis]) + return A + + +@nx._dispatchable(edge_attrs="weight") +def floyd_warshall_predecessor_and_distance(G, weight="weight"): + """Find all-pairs shortest path lengths using Floyd's algorithm. + + Parameters + ---------- + G : NetworkX graph + + weight: string, optional (default= 'weight') + Edge data key corresponding to the edge weight. + + Returns + ------- + predecessor,distance : dictionaries + Dictionaries, keyed by source and target, of predecessors and distances + in the shortest path. + + Examples + -------- + >>> G = nx.DiGraph() + >>> G.add_weighted_edges_from( + ... [ + ... ("s", "u", 10), + ... ("s", "x", 5), + ... ("u", "v", 1), + ... ("u", "x", 2), + ... ("v", "y", 1), + ... ("x", "u", 3), + ... ("x", "v", 5), + ... ("x", "y", 2), + ... ("y", "s", 7), + ... ("y", "v", 6), + ... ] + ... ) + >>> predecessors, _ = nx.floyd_warshall_predecessor_and_distance(G) + >>> print(nx.reconstruct_path("s", "v", predecessors)) + ['s', 'x', 'u', 'v'] + + Notes + ----- + Floyd's algorithm is appropriate for finding shortest paths + in dense graphs or graphs with negative weights when Dijkstra's algorithm + fails. This algorithm can still fail if there are negative cycles. + It has running time $O(n^3)$ with running space of $O(n^2)$. + + See Also + -------- + floyd_warshall + floyd_warshall_numpy + all_pairs_shortest_path + all_pairs_shortest_path_length + """ + from collections import defaultdict + + # dictionary-of-dictionaries representation for dist and pred + # use some defaultdict magick here + # for dist the default is the floating point inf value + dist = defaultdict(lambda: defaultdict(lambda: float("inf"))) + for u in G: + dist[u][u] = 0 + pred = defaultdict(dict) + # initialize path distance dictionary to be the adjacency matrix + # also set the distance to self to 0 (zero diagonal) + undirected = not G.is_directed() + for u, v, d in G.edges(data=True): + e_weight = d.get(weight, 1.0) + dist[u][v] = min(e_weight, dist[u][v]) + pred[u][v] = u + if undirected: + dist[v][u] = min(e_weight, dist[v][u]) + pred[v][u] = v + for w in G: + dist_w = dist[w] # save recomputation + for u in G: + dist_u = dist[u] # save recomputation + for v in G: + d = dist_u[w] + dist_w[v] + if dist_u[v] > d: + dist_u[v] = d + pred[u][v] = pred[w][v] + return dict(pred), dict(dist) + + +@nx._dispatchable(graphs=None) +def reconstruct_path(source, target, predecessors): + """Reconstruct a path from source to target using the predecessors + dict as returned by floyd_warshall_predecessor_and_distance + + Parameters + ---------- + source : node + Starting node for path + + target : node + Ending node for path + + predecessors: dictionary + Dictionary, keyed by source and target, of predecessors in the + shortest path, as returned by floyd_warshall_predecessor_and_distance + + Returns + ------- + path : list + A list of nodes containing the shortest path from source to target + + If source and target are the same, an empty list is returned + + Notes + ----- + This function is meant to give more applicability to the + floyd_warshall_predecessor_and_distance function + + See Also + -------- + floyd_warshall_predecessor_and_distance + """ + if source == target: + return [] + prev = predecessors[source] + curr = prev[target] + path = [target, curr] + while curr != source: + curr = prev[curr] + path.append(curr) + return list(reversed(path)) + + +@nx._dispatchable(edge_attrs="weight") +def floyd_warshall(G, weight="weight"): + """Find all-pairs shortest path lengths using Floyd's algorithm. + + Parameters + ---------- + G : NetworkX graph + + weight: string, optional (default= 'weight') + Edge data key corresponding to the edge weight. + + + Returns + ------- + distance : dict + A dictionary, keyed by source and target, of shortest paths distances + between nodes. + + Examples + -------- + >>> G = nx.DiGraph() + >>> G.add_weighted_edges_from( + ... [(0, 1, 5), (1, 2, 2), (2, 3, -3), (1, 3, 10), (3, 2, 8)] + ... ) + >>> fw = nx.floyd_warshall(G, weight="weight") + >>> results = {a: dict(b) for a, b in fw.items()} + >>> print(results) + {0: {0: 0, 1: 5, 2: 7, 3: 4}, 1: {1: 0, 2: 2, 3: -1, 0: inf}, 2: {2: 0, 3: -3, 0: inf, 1: inf}, 3: {3: 0, 2: 8, 0: inf, 1: inf}} + + Notes + ----- + Floyd's algorithm is appropriate for finding shortest paths + in dense graphs or graphs with negative weights when Dijkstra's algorithm + fails. This algorithm can still fail if there are negative cycles. + It has running time $O(n^3)$ with running space of $O(n^2)$. + + See Also + -------- + floyd_warshall_predecessor_and_distance + floyd_warshall_numpy + all_pairs_shortest_path + all_pairs_shortest_path_length + """ + # could make this its own function to reduce memory costs + return floyd_warshall_predecessor_and_distance(G, weight=weight)[1] diff --git a/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/generic.py b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/generic.py new file mode 100644 index 0000000000000000000000000000000000000000..9ac48c90fe015402239e878529ee33ba6091743d --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/generic.py @@ -0,0 +1,730 @@ +""" +Compute the shortest paths and path lengths between nodes in the graph. + +These algorithms work with undirected and directed graphs. + +""" + +import warnings + +import networkx as nx + +__all__ = [ + "shortest_path", + "all_shortest_paths", + "single_source_all_shortest_paths", + "all_pairs_all_shortest_paths", + "shortest_path_length", + "average_shortest_path_length", + "has_path", +] + + +@nx._dispatchable +def has_path(G, source, target): + """Returns *True* if *G* has a path from *source* to *target*. + + Parameters + ---------- + G : NetworkX graph + + source : node + Starting node for path + + target : node + Ending node for path + """ + try: + nx.shortest_path(G, source, target) + except nx.NetworkXNoPath: + return False + return True + + +@nx._dispatchable(edge_attrs="weight") +def shortest_path(G, source=None, target=None, weight=None, method="dijkstra"): + """Compute shortest paths in the graph. + + Parameters + ---------- + G : NetworkX graph + + source : node, optional + Starting node for path. If not specified, compute shortest + paths for each possible starting node. + + target : node, optional + Ending node for path. If not specified, compute shortest + paths to all possible nodes. + + weight : None, string or function, optional (default = None) + If None, every edge has weight/distance/cost 1. + If a string, use this edge attribute as the edge weight. + Any edge attribute not present defaults to 1. + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly + three positional arguments: the two endpoints of an edge and + the dictionary of edge attributes for that edge. + The function must return a number. + + method : string, optional (default = 'dijkstra') + The algorithm to use to compute the path. + Supported options: 'dijkstra', 'bellman-ford'. + Other inputs produce a ValueError. + If `weight` is None, unweighted graph methods are used, and this + suggestion is ignored. + + Returns + ------- + path: list or dictionary + All returned paths include both the source and target in the path. + + If the source and target are both specified, return a single list + of nodes in a shortest path from the source to the target. + + If only the source is specified, return a dictionary keyed by + targets with a list of nodes in a shortest path from the source + to one of the targets. + + If only the target is specified, return a dictionary keyed by + sources with a list of nodes in a shortest path from one of the + sources to the target. + + If neither the source nor target are specified return a dictionary + of dictionaries with path[source][target]=[list of nodes in path]. + + Raises + ------ + NodeNotFound + If `source` is not in `G`. + + ValueError + If `method` is not among the supported options. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> print(nx.shortest_path(G, source=0, target=4)) + [0, 1, 2, 3, 4] + >>> p = nx.shortest_path(G, source=0) # target not specified + >>> p[3] # shortest path from source=0 to target=3 + [0, 1, 2, 3] + >>> p = nx.shortest_path(G, target=4) # source not specified + >>> p[1] # shortest path from source=1 to target=4 + [1, 2, 3, 4] + >>> p = dict(nx.shortest_path(G)) # source, target not specified + >>> p[2][4] # shortest path from source=2 to target=4 + [2, 3, 4] + + Notes + ----- + There may be more than one shortest path between a source and target. + This returns only one of them. + + See Also + -------- + all_pairs_shortest_path + all_pairs_dijkstra_path + all_pairs_bellman_ford_path + single_source_shortest_path + single_source_dijkstra_path + single_source_bellman_ford_path + """ + if method not in ("dijkstra", "bellman-ford"): + # so we don't need to check in each branch later + raise ValueError(f"method not supported: {method}") + method = "unweighted" if weight is None else method + if source is None: + if target is None: + warnings.warn( + ( + "\n\nshortest_path will return an iterator that yields\n" + "(node, path) pairs instead of a dictionary when source\n" + "and target are unspecified beginning in version 3.5\n\n" + "To keep the current behavior, use:\n\n" + "\tdict(nx.shortest_path(G))" + ), + FutureWarning, + stacklevel=3, + ) + + # Find paths between all pairs. + if method == "unweighted": + paths = dict(nx.all_pairs_shortest_path(G)) + elif method == "dijkstra": + paths = dict(nx.all_pairs_dijkstra_path(G, weight=weight)) + else: # method == 'bellman-ford': + paths = dict(nx.all_pairs_bellman_ford_path(G, weight=weight)) + else: + # Find paths from all nodes co-accessible to the target. + if G.is_directed(): + G = G.reverse(copy=False) + if method == "unweighted": + paths = nx.single_source_shortest_path(G, target) + elif method == "dijkstra": + paths = nx.single_source_dijkstra_path(G, target, weight=weight) + else: # method == 'bellman-ford': + paths = nx.single_source_bellman_ford_path(G, target, weight=weight) + # Now flip the paths so they go from a source to the target. + for target in paths: + paths[target] = list(reversed(paths[target])) + else: + if target is None: + # Find paths to all nodes accessible from the source. + if method == "unweighted": + paths = nx.single_source_shortest_path(G, source) + elif method == "dijkstra": + paths = nx.single_source_dijkstra_path(G, source, weight=weight) + else: # method == 'bellman-ford': + paths = nx.single_source_bellman_ford_path(G, source, weight=weight) + else: + # Find shortest source-target path. + if method == "unweighted": + paths = nx.bidirectional_shortest_path(G, source, target) + elif method == "dijkstra": + _, paths = nx.bidirectional_dijkstra(G, source, target, weight) + else: # method == 'bellman-ford': + paths = nx.bellman_ford_path(G, source, target, weight) + return paths + + +@nx._dispatchable(edge_attrs="weight") +def shortest_path_length(G, source=None, target=None, weight=None, method="dijkstra"): + """Compute shortest path lengths in the graph. + + Parameters + ---------- + G : NetworkX graph + + source : node, optional + Starting node for path. + If not specified, compute shortest path lengths using all nodes as + source nodes. + + target : node, optional + Ending node for path. + If not specified, compute shortest path lengths using all nodes as + target nodes. + + weight : None, string or function, optional (default = None) + If None, every edge has weight/distance/cost 1. + If a string, use this edge attribute as the edge weight. + Any edge attribute not present defaults to 1. + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly + three positional arguments: the two endpoints of an edge and + the dictionary of edge attributes for that edge. + The function must return a number. + + method : string, optional (default = 'dijkstra') + The algorithm to use to compute the path length. + Supported options: 'dijkstra', 'bellman-ford'. + Other inputs produce a ValueError. + If `weight` is None, unweighted graph methods are used, and this + suggestion is ignored. + + Returns + ------- + length: number or iterator + If the source and target are both specified, return the length of + the shortest path from the source to the target. + + If only the source is specified, return a dict keyed by target + to the shortest path length from the source to that target. + + If only the target is specified, return a dict keyed by source + to the shortest path length from that source to the target. + + If neither the source nor target are specified, return an iterator + over (source, dictionary) where dictionary is keyed by target to + shortest path length from source to that target. + + Raises + ------ + NodeNotFound + If `source` is not in `G`. + + NetworkXNoPath + If no path exists between source and target. + + ValueError + If `method` is not among the supported options. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> nx.shortest_path_length(G, source=0, target=4) + 4 + >>> p = nx.shortest_path_length(G, source=0) # target not specified + >>> p[4] + 4 + >>> p = nx.shortest_path_length(G, target=4) # source not specified + >>> p[0] + 4 + >>> p = dict(nx.shortest_path_length(G)) # source,target not specified + >>> p[0][4] + 4 + + Notes + ----- + The length of the path is always 1 less than the number of nodes involved + in the path since the length measures the number of edges followed. + + For digraphs this returns the shortest directed path length. To find path + lengths in the reverse direction use G.reverse(copy=False) first to flip + the edge orientation. + + See Also + -------- + all_pairs_shortest_path_length + all_pairs_dijkstra_path_length + all_pairs_bellman_ford_path_length + single_source_shortest_path_length + single_source_dijkstra_path_length + single_source_bellman_ford_path_length + """ + if method not in ("dijkstra", "bellman-ford"): + # so we don't need to check in each branch later + raise ValueError(f"method not supported: {method}") + method = "unweighted" if weight is None else method + if source is None: + if target is None: + # Find paths between all pairs. + if method == "unweighted": + paths = nx.all_pairs_shortest_path_length(G) + elif method == "dijkstra": + paths = nx.all_pairs_dijkstra_path_length(G, weight=weight) + else: # method == 'bellman-ford': + paths = nx.all_pairs_bellman_ford_path_length(G, weight=weight) + else: + # Find paths from all nodes co-accessible to the target. + if G.is_directed(): + G = G.reverse(copy=False) + if method == "unweighted": + path_length = nx.single_source_shortest_path_length + paths = path_length(G, target) + elif method == "dijkstra": + path_length = nx.single_source_dijkstra_path_length + paths = path_length(G, target, weight=weight) + else: # method == 'bellman-ford': + path_length = nx.single_source_bellman_ford_path_length + paths = path_length(G, target, weight=weight) + else: + if target is None: + # Find paths to all nodes accessible from the source. + if method == "unweighted": + paths = nx.single_source_shortest_path_length(G, source) + elif method == "dijkstra": + path_length = nx.single_source_dijkstra_path_length + paths = path_length(G, source, weight=weight) + else: # method == 'bellman-ford': + path_length = nx.single_source_bellman_ford_path_length + paths = path_length(G, source, weight=weight) + else: + # Find shortest source-target path. + if method == "unweighted": + p = nx.bidirectional_shortest_path(G, source, target) + paths = len(p) - 1 + elif method == "dijkstra": + paths = nx.dijkstra_path_length(G, source, target, weight) + else: # method == 'bellman-ford': + paths = nx.bellman_ford_path_length(G, source, target, weight) + return paths + + +@nx._dispatchable(edge_attrs="weight") +def average_shortest_path_length(G, weight=None, method=None): + r"""Returns the average shortest path length. + + The average shortest path length is + + .. math:: + + a =\sum_{\substack{s,t \in V \\ s\neq t}} \frac{d(s, t)}{n(n-1)} + + where `V` is the set of nodes in `G`, + `d(s, t)` is the shortest path from `s` to `t`, + and `n` is the number of nodes in `G`. + + .. versionchanged:: 3.0 + An exception is raised for directed graphs that are not strongly + connected. + + Parameters + ---------- + G : NetworkX graph + + weight : None, string or function, optional (default = None) + If None, every edge has weight/distance/cost 1. + If a string, use this edge attribute as the edge weight. + Any edge attribute not present defaults to 1. + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly + three positional arguments: the two endpoints of an edge and + the dictionary of edge attributes for that edge. + The function must return a number. + + method : string, optional (default = 'unweighted' or 'dijkstra') + The algorithm to use to compute the path lengths. + Supported options are 'unweighted', 'dijkstra', 'bellman-ford', + 'floyd-warshall' and 'floyd-warshall-numpy'. + Other method values produce a ValueError. + The default method is 'unweighted' if `weight` is None, + otherwise the default method is 'dijkstra'. + + Raises + ------ + NetworkXPointlessConcept + If `G` is the null graph (that is, the graph on zero nodes). + + NetworkXError + If `G` is not connected (or not strongly connected, in the case + of a directed graph). + + ValueError + If `method` is not among the supported options. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> nx.average_shortest_path_length(G) + 2.0 + + For disconnected graphs, you can compute the average shortest path + length for each component + + >>> G = nx.Graph([(1, 2), (3, 4)]) + >>> for C in (G.subgraph(c).copy() for c in nx.connected_components(G)): + ... print(nx.average_shortest_path_length(C)) + 1.0 + 1.0 + + """ + single_source_methods = ["unweighted", "dijkstra", "bellman-ford"] + all_pairs_methods = ["floyd-warshall", "floyd-warshall-numpy"] + supported_methods = single_source_methods + all_pairs_methods + + if method is None: + method = "unweighted" if weight is None else "dijkstra" + if method not in supported_methods: + raise ValueError(f"method not supported: {method}") + + n = len(G) + # For the special case of the null graph, raise an exception, since + # there are no paths in the null graph. + if n == 0: + msg = ( + "the null graph has no paths, thus there is no average " + "shortest path length" + ) + raise nx.NetworkXPointlessConcept(msg) + # For the special case of the trivial graph, return zero immediately. + if n == 1: + return 0 + # Shortest path length is undefined if the graph is not strongly connected. + if G.is_directed() and not nx.is_strongly_connected(G): + raise nx.NetworkXError("Graph is not strongly connected.") + # Shortest path length is undefined if the graph is not connected. + if not G.is_directed() and not nx.is_connected(G): + raise nx.NetworkXError("Graph is not connected.") + + # Compute all-pairs shortest paths. + def path_length(v): + if method == "unweighted": + return nx.single_source_shortest_path_length(G, v) + elif method == "dijkstra": + return nx.single_source_dijkstra_path_length(G, v, weight=weight) + elif method == "bellman-ford": + return nx.single_source_bellman_ford_path_length(G, v, weight=weight) + + if method in single_source_methods: + # Sum the distances for each (ordered) pair of source and target node. + s = sum(l for u in G for l in path_length(u).values()) + else: + if method == "floyd-warshall": + all_pairs = nx.floyd_warshall(G, weight=weight) + s = sum(sum(t.values()) for t in all_pairs.values()) + elif method == "floyd-warshall-numpy": + s = float(nx.floyd_warshall_numpy(G, weight=weight).sum()) + return s / (n * (n - 1)) + + +@nx._dispatchable(edge_attrs="weight") +def all_shortest_paths(G, source, target, weight=None, method="dijkstra"): + """Compute all shortest simple paths in the graph. + + Parameters + ---------- + G : NetworkX graph + + source : node + Starting node for path. + + target : node + Ending node for path. + + weight : None, string or function, optional (default = None) + If None, every edge has weight/distance/cost 1. + If a string, use this edge attribute as the edge weight. + Any edge attribute not present defaults to 1. + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly + three positional arguments: the two endpoints of an edge and + the dictionary of edge attributes for that edge. + The function must return a number. + + method : string, optional (default = 'dijkstra') + The algorithm to use to compute the path lengths. + Supported options: 'dijkstra', 'bellman-ford'. + Other inputs produce a ValueError. + If `weight` is None, unweighted graph methods are used, and this + suggestion is ignored. + + Returns + ------- + paths : generator of lists + A generator of all paths between source and target. + + Raises + ------ + ValueError + If `method` is not among the supported options. + + NetworkXNoPath + If `target` cannot be reached from `source`. + + Examples + -------- + >>> G = nx.Graph() + >>> nx.add_path(G, [0, 1, 2]) + >>> nx.add_path(G, [0, 10, 2]) + >>> print([p for p in nx.all_shortest_paths(G, source=0, target=2)]) + [[0, 1, 2], [0, 10, 2]] + + Notes + ----- + There may be many shortest paths between the source and target. If G + contains zero-weight cycles, this function will not produce all shortest + paths because doing so would produce infinitely many paths of unbounded + length -- instead, we only produce the shortest simple paths. + + See Also + -------- + shortest_path + single_source_shortest_path + all_pairs_shortest_path + """ + method = "unweighted" if weight is None else method + if method == "unweighted": + pred = nx.predecessor(G, source) + elif method == "dijkstra": + pred, dist = nx.dijkstra_predecessor_and_distance(G, source, weight=weight) + elif method == "bellman-ford": + pred, dist = nx.bellman_ford_predecessor_and_distance(G, source, weight=weight) + else: + raise ValueError(f"method not supported: {method}") + + return _build_paths_from_predecessors({source}, target, pred) + + +@nx._dispatchable(edge_attrs="weight") +def single_source_all_shortest_paths(G, source, weight=None, method="dijkstra"): + """Compute all shortest simple paths from the given source in the graph. + + Parameters + ---------- + G : NetworkX graph + + source : node + Starting node for path. + + weight : None, string or function, optional (default = None) + If None, every edge has weight/distance/cost 1. + If a string, use this edge attribute as the edge weight. + Any edge attribute not present defaults to 1. + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly + three positional arguments: the two endpoints of an edge and + the dictionary of edge attributes for that edge. + The function must return a number. + + method : string, optional (default = 'dijkstra') + The algorithm to use to compute the path lengths. + Supported options: 'dijkstra', 'bellman-ford'. + Other inputs produce a ValueError. + If `weight` is None, unweighted graph methods are used, and this + suggestion is ignored. + + Returns + ------- + paths : generator of dictionary + A generator of all paths between source and all nodes in the graph. + + Raises + ------ + ValueError + If `method` is not among the supported options. + + Examples + -------- + >>> G = nx.Graph() + >>> nx.add_path(G, [0, 1, 2, 3, 0]) + >>> dict(nx.single_source_all_shortest_paths(G, source=0)) + {0: [[0]], 1: [[0, 1]], 2: [[0, 1, 2], [0, 3, 2]], 3: [[0, 3]]} + + Notes + ----- + There may be many shortest paths between the source and target. If G + contains zero-weight cycles, this function will not produce all shortest + paths because doing so would produce infinitely many paths of unbounded + length -- instead, we only produce the shortest simple paths. + + See Also + -------- + shortest_path + all_shortest_paths + single_source_shortest_path + all_pairs_shortest_path + all_pairs_all_shortest_paths + """ + method = "unweighted" if weight is None else method + if method == "unweighted": + pred = nx.predecessor(G, source) + elif method == "dijkstra": + pred, dist = nx.dijkstra_predecessor_and_distance(G, source, weight=weight) + elif method == "bellman-ford": + pred, dist = nx.bellman_ford_predecessor_and_distance(G, source, weight=weight) + else: + raise ValueError(f"method not supported: {method}") + for n in G: + try: + yield n, list(_build_paths_from_predecessors({source}, n, pred)) + except nx.NetworkXNoPath: + pass + + +@nx._dispatchable(edge_attrs="weight") +def all_pairs_all_shortest_paths(G, weight=None, method="dijkstra"): + """Compute all shortest paths between all nodes. + + Parameters + ---------- + G : NetworkX graph + + weight : None, string or function, optional (default = None) + If None, every edge has weight/distance/cost 1. + If a string, use this edge attribute as the edge weight. + Any edge attribute not present defaults to 1. + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly + three positional arguments: the two endpoints of an edge and + the dictionary of edge attributes for that edge. + The function must return a number. + + method : string, optional (default = 'dijkstra') + The algorithm to use to compute the path lengths. + Supported options: 'dijkstra', 'bellman-ford'. + Other inputs produce a ValueError. + If `weight` is None, unweighted graph methods are used, and this + suggestion is ignored. + + Returns + ------- + paths : generator of dictionary + Dictionary of arrays, keyed by source and target, of all shortest paths. + + Raises + ------ + ValueError + If `method` is not among the supported options. + + Examples + -------- + >>> G = nx.cycle_graph(4) + >>> dict(nx.all_pairs_all_shortest_paths(G))[0][2] + [[0, 1, 2], [0, 3, 2]] + >>> dict(nx.all_pairs_all_shortest_paths(G))[0][3] + [[0, 3]] + + Notes + ----- + There may be multiple shortest paths with equal lengths. Unlike + all_pairs_shortest_path, this method returns all shortest paths. + + See Also + -------- + all_pairs_shortest_path + single_source_all_shortest_paths + """ + for n in G: + yield ( + n, + dict(single_source_all_shortest_paths(G, n, weight=weight, method=method)), + ) + + +def _build_paths_from_predecessors(sources, target, pred): + """Compute all simple paths to target, given the predecessors found in + pred, terminating when any source in sources is found. + + Parameters + ---------- + sources : set + Starting nodes for path. + + target : node + Ending node for path. + + pred : dict + A dictionary of predecessor lists, keyed by node + + Returns + ------- + paths : generator of lists + A generator of all paths between source and target. + + Raises + ------ + NetworkXNoPath + If `target` cannot be reached from `source`. + + Notes + ----- + There may be many paths between the sources and target. If there are + cycles among the predecessors, this function will not produce all + possible paths because doing so would produce infinitely many paths + of unbounded length -- instead, we only produce simple paths. + + See Also + -------- + shortest_path + single_source_shortest_path + all_pairs_shortest_path + all_shortest_paths + bellman_ford_path + """ + if target not in pred: + raise nx.NetworkXNoPath(f"Target {target} cannot be reached from given sources") + + seen = {target} + stack = [[target, 0]] + top = 0 + while top >= 0: + node, i = stack[top] + if node in sources: + yield [p for p, n in reversed(stack[: top + 1])] + if len(pred[node]) > i: + stack[top][1] = i + 1 + next = pred[node][i] + if next in seen: + continue + else: + seen.add(next) + top += 1 + if top == len(stack): + stack.append([next, 0]) + else: + stack[top][:] = [next, 0] + else: + seen.discard(node) + top -= 1 diff --git a/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f22e13e512eea13266e5fb40ba42e823bae3d769 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/__pycache__/test_astar.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/__pycache__/test_astar.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..913e225f91a28e9c8e171f942117e8e22e0775e8 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/__pycache__/test_astar.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/__pycache__/test_dense.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/__pycache__/test_dense.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ed00d8684fb4a4a336ea81d36a8e11ce36ac01c Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/__pycache__/test_dense.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/__pycache__/test_dense_numpy.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/__pycache__/test_dense_numpy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29fe994ec2ece92dbcdb451f8d0666b51c0a81cc Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/__pycache__/test_dense_numpy.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/__pycache__/test_generic.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/__pycache__/test_generic.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d90ee4a61111a8a9b8478d68b00954a44820fc0c Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/__pycache__/test_generic.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/__pycache__/test_unweighted.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/__pycache__/test_unweighted.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84a024901da60c8a2b695e47457cd5f764037838 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/__pycache__/test_unweighted.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/__pycache__/test_weighted.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/__pycache__/test_weighted.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ed0537563b3a345f6cb1f50f304b755471878a5 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/__pycache__/test_weighted.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/test_astar.py b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/test_astar.py new file mode 100644 index 0000000000000000000000000000000000000000..40a7d4e86e9909a48ae2f3b5c7210bd63a3ae993 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/test_astar.py @@ -0,0 +1,248 @@ +import pytest + +import networkx as nx +from networkx.utils import pairwise + + +class TestAStar: + @classmethod + def setup_class(cls): + edges = [ + ("s", "u", 10), + ("s", "x", 5), + ("u", "v", 1), + ("u", "x", 2), + ("v", "y", 1), + ("x", "u", 3), + ("x", "v", 5), + ("x", "y", 2), + ("y", "s", 7), + ("y", "v", 6), + ] + cls.XG = nx.DiGraph() + cls.XG.add_weighted_edges_from(edges) + + def test_multiple_optimal_paths(self): + """Tests that A* algorithm finds any of multiple optimal paths""" + heuristic_values = {"a": 1.35, "b": 1.18, "c": 0.67, "d": 0} + + def h(u, v): + return heuristic_values[u] + + graph = nx.Graph() + points = ["a", "b", "c", "d"] + edges = [("a", "b", 0.18), ("a", "c", 0.68), ("b", "c", 0.50), ("c", "d", 0.67)] + + graph.add_nodes_from(points) + graph.add_weighted_edges_from(edges) + + path1 = ["a", "c", "d"] + path2 = ["a", "b", "c", "d"] + assert nx.astar_path(graph, "a", "d", h) in (path1, path2) + + def test_astar_directed(self): + assert nx.astar_path(self.XG, "s", "v") == ["s", "x", "u", "v"] + assert nx.astar_path_length(self.XG, "s", "v") == 9 + + def test_astar_directed_weight_function(self): + w1 = lambda u, v, d: d["weight"] + assert nx.astar_path(self.XG, "x", "u", weight=w1) == ["x", "u"] + assert nx.astar_path_length(self.XG, "x", "u", weight=w1) == 3 + assert nx.astar_path(self.XG, "s", "v", weight=w1) == ["s", "x", "u", "v"] + assert nx.astar_path_length(self.XG, "s", "v", weight=w1) == 9 + + w2 = lambda u, v, d: None if (u, v) == ("x", "u") else d["weight"] + assert nx.astar_path(self.XG, "x", "u", weight=w2) == ["x", "y", "s", "u"] + assert nx.astar_path_length(self.XG, "x", "u", weight=w2) == 19 + assert nx.astar_path(self.XG, "s", "v", weight=w2) == ["s", "x", "v"] + assert nx.astar_path_length(self.XG, "s", "v", weight=w2) == 10 + + w3 = lambda u, v, d: d["weight"] + 10 + assert nx.astar_path(self.XG, "x", "u", weight=w3) == ["x", "u"] + assert nx.astar_path_length(self.XG, "x", "u", weight=w3) == 13 + assert nx.astar_path(self.XG, "s", "v", weight=w3) == ["s", "x", "v"] + assert nx.astar_path_length(self.XG, "s", "v", weight=w3) == 30 + + def test_astar_multigraph(self): + G = nx.MultiDiGraph(self.XG) + G.add_weighted_edges_from((u, v, 1000) for (u, v) in list(G.edges())) + assert nx.astar_path(G, "s", "v") == ["s", "x", "u", "v"] + assert nx.astar_path_length(G, "s", "v") == 9 + + def test_astar_undirected(self): + GG = self.XG.to_undirected() + # make sure we get lower weight + # to_undirected might choose either edge with weight 2 or weight 3 + GG["u"]["x"]["weight"] = 2 + GG["y"]["v"]["weight"] = 2 + assert nx.astar_path(GG, "s", "v") == ["s", "x", "u", "v"] + assert nx.astar_path_length(GG, "s", "v") == 8 + + def test_astar_directed2(self): + XG2 = nx.DiGraph() + edges = [ + (1, 4, 1), + (4, 5, 1), + (5, 6, 1), + (6, 3, 1), + (1, 3, 50), + (1, 2, 100), + (2, 3, 100), + ] + XG2.add_weighted_edges_from(edges) + assert nx.astar_path(XG2, 1, 3) == [1, 4, 5, 6, 3] + + def test_astar_undirected2(self): + XG3 = nx.Graph() + edges = [(0, 1, 2), (1, 2, 12), (2, 3, 1), (3, 4, 5), (4, 5, 1), (5, 0, 10)] + XG3.add_weighted_edges_from(edges) + assert nx.astar_path(XG3, 0, 3) == [0, 1, 2, 3] + assert nx.astar_path_length(XG3, 0, 3) == 15 + + def test_astar_undirected3(self): + XG4 = nx.Graph() + edges = [ + (0, 1, 2), + (1, 2, 2), + (2, 3, 1), + (3, 4, 1), + (4, 5, 1), + (5, 6, 1), + (6, 7, 1), + (7, 0, 1), + ] + XG4.add_weighted_edges_from(edges) + assert nx.astar_path(XG4, 0, 2) == [0, 1, 2] + assert nx.astar_path_length(XG4, 0, 2) == 4 + + """ Tests that A* finds correct path when multiple paths exist + and the best one is not expanded first (GH issue #3464) + """ + + def test_astar_directed3(self): + heuristic_values = {"n5": 36, "n2": 4, "n1": 0, "n0": 0} + + def h(u, v): + return heuristic_values[u] + + edges = [("n5", "n1", 11), ("n5", "n2", 9), ("n2", "n1", 1), ("n1", "n0", 32)] + graph = nx.DiGraph() + graph.add_weighted_edges_from(edges) + answer = ["n5", "n2", "n1", "n0"] + assert nx.astar_path(graph, "n5", "n0", h) == answer + + """ Tests that parent is not wrongly overridden when a node + is re-explored multiple times. + """ + + def test_astar_directed4(self): + edges = [ + ("a", "b", 1), + ("a", "c", 1), + ("b", "d", 2), + ("c", "d", 1), + ("d", "e", 1), + ] + graph = nx.DiGraph() + graph.add_weighted_edges_from(edges) + assert nx.astar_path(graph, "a", "e") == ["a", "c", "d", "e"] + + # >>> MXG4=NX.MultiGraph(XG4) + # >>> MXG4.add_edge(0,1,3) + # >>> NX.dijkstra_path(MXG4,0,2) + # [0, 1, 2] + + def test_astar_w1(self): + G = nx.DiGraph() + G.add_edges_from( + [ + ("s", "u"), + ("s", "x"), + ("u", "v"), + ("u", "x"), + ("v", "y"), + ("x", "u"), + ("x", "w"), + ("w", "v"), + ("x", "y"), + ("y", "s"), + ("y", "v"), + ] + ) + assert nx.astar_path(G, "s", "v") == ["s", "u", "v"] + assert nx.astar_path_length(G, "s", "v") == 2 + + def test_astar_nopath(self): + with pytest.raises(nx.NodeNotFound): + nx.astar_path(self.XG, "s", "moon") + + def test_astar_cutoff(self): + with pytest.raises(nx.NetworkXNoPath): + # optimal path_length in XG is 9 + nx.astar_path(self.XG, "s", "v", cutoff=8.0) + with pytest.raises(nx.NetworkXNoPath): + nx.astar_path_length(self.XG, "s", "v", cutoff=8.0) + + def test_astar_admissible_heuristic_with_cutoff(self): + heuristic_values = {"s": 36, "y": 4, "x": 0, "u": 0, "v": 0} + + def h(u, v): + return heuristic_values[u] + + assert nx.astar_path_length(self.XG, "s", "v") == 9 + assert nx.astar_path_length(self.XG, "s", "v", heuristic=h) == 9 + assert nx.astar_path_length(self.XG, "s", "v", heuristic=h, cutoff=12) == 9 + assert nx.astar_path_length(self.XG, "s", "v", heuristic=h, cutoff=9) == 9 + with pytest.raises(nx.NetworkXNoPath): + nx.astar_path_length(self.XG, "s", "v", heuristic=h, cutoff=8) + + def test_astar_inadmissible_heuristic_with_cutoff(self): + heuristic_values = {"s": 36, "y": 14, "x": 10, "u": 10, "v": 0} + + def h(u, v): + return heuristic_values[u] + + # optimal path_length in XG is 9. This heuristic gives over-estimate. + assert nx.astar_path_length(self.XG, "s", "v", heuristic=h) == 10 + assert nx.astar_path_length(self.XG, "s", "v", heuristic=h, cutoff=15) == 10 + with pytest.raises(nx.NetworkXNoPath): + nx.astar_path_length(self.XG, "s", "v", heuristic=h, cutoff=9) + with pytest.raises(nx.NetworkXNoPath): + nx.astar_path_length(self.XG, "s", "v", heuristic=h, cutoff=12) + + def test_astar_cutoff2(self): + assert nx.astar_path(self.XG, "s", "v", cutoff=10.0) == ["s", "x", "u", "v"] + assert nx.astar_path_length(self.XG, "s", "v") == 9 + + def test_cycle(self): + C = nx.cycle_graph(7) + assert nx.astar_path(C, 0, 3) == [0, 1, 2, 3] + assert nx.dijkstra_path(C, 0, 4) == [0, 6, 5, 4] + + def test_unorderable_nodes(self): + """Tests that A* accommodates nodes that are not orderable. + + For more information, see issue #554. + + """ + # Create the cycle graph on four nodes, with nodes represented + # as (unorderable) Python objects. + nodes = [object() for n in range(4)] + G = nx.Graph() + G.add_edges_from(pairwise(nodes, cyclic=True)) + path = nx.astar_path(G, nodes[0], nodes[2]) + assert len(path) == 3 + + def test_astar_NetworkXNoPath(self): + """Tests that exception is raised when there exists no + path between source and target""" + G = nx.gnp_random_graph(10, 0.2, seed=10) + with pytest.raises(nx.NetworkXNoPath): + nx.astar_path(G, 4, 9) + + def test_astar_NodeNotFound(self): + """Tests that exception is raised when either + source or target is not in graph""" + G = nx.gnp_random_graph(10, 0.2, seed=10) + with pytest.raises(nx.NodeNotFound): + nx.astar_path_length(G, 11, 9) diff --git a/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/test_dense.py b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/test_dense.py new file mode 100644 index 0000000000000000000000000000000000000000..6923bfef856c83bd3e65573b97fe96ff16cdbc71 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/test_dense.py @@ -0,0 +1,212 @@ +import pytest + +import networkx as nx + + +class TestFloyd: + @classmethod + def setup_class(cls): + pass + + def test_floyd_warshall_predecessor_and_distance(self): + XG = nx.DiGraph() + XG.add_weighted_edges_from( + [ + ("s", "u", 10), + ("s", "x", 5), + ("u", "v", 1), + ("u", "x", 2), + ("v", "y", 1), + ("x", "u", 3), + ("x", "v", 5), + ("x", "y", 2), + ("y", "s", 7), + ("y", "v", 6), + ] + ) + path, dist = nx.floyd_warshall_predecessor_and_distance(XG) + assert dist["s"]["v"] == 9 + assert path["s"]["v"] == "u" + assert dist == { + "y": {"y": 0, "x": 12, "s": 7, "u": 15, "v": 6}, + "x": {"y": 2, "x": 0, "s": 9, "u": 3, "v": 4}, + "s": {"y": 7, "x": 5, "s": 0, "u": 8, "v": 9}, + "u": {"y": 2, "x": 2, "s": 9, "u": 0, "v": 1}, + "v": {"y": 1, "x": 13, "s": 8, "u": 16, "v": 0}, + } + + GG = XG.to_undirected() + # make sure we get lower weight + # to_undirected might choose either edge with weight 2 or weight 3 + GG["u"]["x"]["weight"] = 2 + path, dist = nx.floyd_warshall_predecessor_and_distance(GG) + assert dist["s"]["v"] == 8 + # skip this test, could be alternate path s-u-v + # assert_equal(path['s']['v'],'y') + + G = nx.DiGraph() # no weights + G.add_edges_from( + [ + ("s", "u"), + ("s", "x"), + ("u", "v"), + ("u", "x"), + ("v", "y"), + ("x", "u"), + ("x", "v"), + ("x", "y"), + ("y", "s"), + ("y", "v"), + ] + ) + path, dist = nx.floyd_warshall_predecessor_and_distance(G) + assert dist["s"]["v"] == 2 + # skip this test, could be alternate path s-u-v + # assert_equal(path['s']['v'],'x') + + # alternate interface + dist = nx.floyd_warshall(G) + assert dist["s"]["v"] == 2 + + # floyd_warshall_predecessor_and_distance returns + # dicts-of-defautdicts + # make sure we don't get empty dictionary + XG = nx.DiGraph() + XG.add_weighted_edges_from( + [("v", "x", 5.0), ("y", "x", 5.0), ("v", "y", 6.0), ("x", "u", 2.0)] + ) + path, dist = nx.floyd_warshall_predecessor_and_distance(XG) + inf = float("inf") + assert dist == { + "v": {"v": 0, "x": 5.0, "y": 6.0, "u": 7.0}, + "x": {"x": 0, "u": 2.0, "v": inf, "y": inf}, + "y": {"y": 0, "x": 5.0, "v": inf, "u": 7.0}, + "u": {"u": 0, "v": inf, "x": inf, "y": inf}, + } + assert path == { + "v": {"x": "v", "y": "v", "u": "x"}, + "x": {"u": "x"}, + "y": {"x": "y", "u": "x"}, + } + + def test_reconstruct_path(self): + with pytest.raises(KeyError): + XG = nx.DiGraph() + XG.add_weighted_edges_from( + [ + ("s", "u", 10), + ("s", "x", 5), + ("u", "v", 1), + ("u", "x", 2), + ("v", "y", 1), + ("x", "u", 3), + ("x", "v", 5), + ("x", "y", 2), + ("y", "s", 7), + ("y", "v", 6), + ] + ) + predecessors, _ = nx.floyd_warshall_predecessor_and_distance(XG) + + path = nx.reconstruct_path("s", "v", predecessors) + assert path == ["s", "x", "u", "v"] + + path = nx.reconstruct_path("s", "s", predecessors) + assert path == [] + + # this part raises the keyError + nx.reconstruct_path("1", "2", predecessors) + + def test_cycle(self): + path, dist = nx.floyd_warshall_predecessor_and_distance(nx.cycle_graph(7)) + assert dist[0][3] == 3 + assert path[0][3] == 2 + assert dist[0][4] == 3 + + def test_weighted(self): + XG3 = nx.Graph() + XG3.add_weighted_edges_from( + [[0, 1, 2], [1, 2, 12], [2, 3, 1], [3, 4, 5], [4, 5, 1], [5, 0, 10]] + ) + path, dist = nx.floyd_warshall_predecessor_and_distance(XG3) + assert dist[0][3] == 15 + assert path[0][3] == 2 + + def test_weighted2(self): + XG4 = nx.Graph() + XG4.add_weighted_edges_from( + [ + [0, 1, 2], + [1, 2, 2], + [2, 3, 1], + [3, 4, 1], + [4, 5, 1], + [5, 6, 1], + [6, 7, 1], + [7, 0, 1], + ] + ) + path, dist = nx.floyd_warshall_predecessor_and_distance(XG4) + assert dist[0][2] == 4 + assert path[0][2] == 1 + + def test_weight_parameter(self): + XG4 = nx.Graph() + XG4.add_edges_from( + [ + (0, 1, {"heavy": 2}), + (1, 2, {"heavy": 2}), + (2, 3, {"heavy": 1}), + (3, 4, {"heavy": 1}), + (4, 5, {"heavy": 1}), + (5, 6, {"heavy": 1}), + (6, 7, {"heavy": 1}), + (7, 0, {"heavy": 1}), + ] + ) + path, dist = nx.floyd_warshall_predecessor_and_distance(XG4, weight="heavy") + assert dist[0][2] == 4 + assert path[0][2] == 1 + + def test_zero_distance(self): + XG = nx.DiGraph() + XG.add_weighted_edges_from( + [ + ("s", "u", 10), + ("s", "x", 5), + ("u", "v", 1), + ("u", "x", 2), + ("v", "y", 1), + ("x", "u", 3), + ("x", "v", 5), + ("x", "y", 2), + ("y", "s", 7), + ("y", "v", 6), + ] + ) + path, dist = nx.floyd_warshall_predecessor_and_distance(XG) + + for u in XG: + assert dist[u][u] == 0 + + GG = XG.to_undirected() + # make sure we get lower weight + # to_undirected might choose either edge with weight 2 or weight 3 + GG["u"]["x"]["weight"] = 2 + path, dist = nx.floyd_warshall_predecessor_and_distance(GG) + + for u in GG: + dist[u][u] = 0 + + def test_zero_weight(self): + G = nx.DiGraph() + edges = [(1, 2, -2), (2, 3, -4), (1, 5, 1), (5, 4, 0), (4, 3, -5), (2, 5, -7)] + G.add_weighted_edges_from(edges) + dist = nx.floyd_warshall(G) + assert dist[1][3] == -14 + + G = nx.MultiDiGraph() + edges.append((2, 5, -7)) + G.add_weighted_edges_from(edges) + dist = nx.floyd_warshall(G) + assert dist[1][3] == -14 diff --git a/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/test_dense_numpy.py b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/test_dense_numpy.py new file mode 100644 index 0000000000000000000000000000000000000000..1316e23e654a775e949cdbd34a86a474597f993a --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/test_dense_numpy.py @@ -0,0 +1,89 @@ +import pytest + +np = pytest.importorskip("numpy") + + +import networkx as nx + + +def test_cycle_numpy(): + dist = nx.floyd_warshall_numpy(nx.cycle_graph(7)) + assert dist[0, 3] == 3 + assert dist[0, 4] == 3 + + +def test_weighted_numpy_three_edges(): + XG3 = nx.Graph() + XG3.add_weighted_edges_from( + [[0, 1, 2], [1, 2, 12], [2, 3, 1], [3, 4, 5], [4, 5, 1], [5, 0, 10]] + ) + dist = nx.floyd_warshall_numpy(XG3) + assert dist[0, 3] == 15 + + +def test_weighted_numpy_two_edges(): + XG4 = nx.Graph() + XG4.add_weighted_edges_from( + [ + [0, 1, 2], + [1, 2, 2], + [2, 3, 1], + [3, 4, 1], + [4, 5, 1], + [5, 6, 1], + [6, 7, 1], + [7, 0, 1], + ] + ) + dist = nx.floyd_warshall_numpy(XG4) + assert dist[0, 2] == 4 + + +def test_weight_parameter_numpy(): + XG4 = nx.Graph() + XG4.add_edges_from( + [ + (0, 1, {"heavy": 2}), + (1, 2, {"heavy": 2}), + (2, 3, {"heavy": 1}), + (3, 4, {"heavy": 1}), + (4, 5, {"heavy": 1}), + (5, 6, {"heavy": 1}), + (6, 7, {"heavy": 1}), + (7, 0, {"heavy": 1}), + ] + ) + dist = nx.floyd_warshall_numpy(XG4, weight="heavy") + assert dist[0, 2] == 4 + + +def test_directed_cycle_numpy(): + G = nx.DiGraph() + nx.add_cycle(G, [0, 1, 2, 3]) + pred, dist = nx.floyd_warshall_predecessor_and_distance(G) + D = nx.utils.dict_to_numpy_array(dist) + np.testing.assert_equal(nx.floyd_warshall_numpy(G), D) + + +def test_zero_weight(): + G = nx.DiGraph() + edges = [(1, 2, -2), (2, 3, -4), (1, 5, 1), (5, 4, 0), (4, 3, -5), (2, 5, -7)] + G.add_weighted_edges_from(edges) + dist = nx.floyd_warshall_numpy(G) + assert int(np.min(dist)) == -14 + + G = nx.MultiDiGraph() + edges.append((2, 5, -7)) + G.add_weighted_edges_from(edges) + dist = nx.floyd_warshall_numpy(G) + assert int(np.min(dist)) == -14 + + +def test_nodelist(): + G = nx.path_graph(7) + dist = nx.floyd_warshall_numpy(G, nodelist=[3, 5, 4, 6, 2, 1, 0]) + assert dist[0, 3] == 3 + assert dist[0, 1] == 2 + assert dist[6, 2] == 4 + pytest.raises(nx.NetworkXError, nx.floyd_warshall_numpy, G, [1, 3]) + pytest.raises(nx.NetworkXError, nx.floyd_warshall_numpy, G, list(range(9))) diff --git a/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/test_generic.py b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/test_generic.py new file mode 100644 index 0000000000000000000000000000000000000000..e30de51771eb8654b9f31c283306b207d80e8bce --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/test_generic.py @@ -0,0 +1,450 @@ +import pytest + +import networkx as nx + + +def validate_grid_path(r, c, s, t, p): + assert isinstance(p, list) + assert p[0] == s + assert p[-1] == t + s = ((s - 1) // c, (s - 1) % c) + t = ((t - 1) // c, (t - 1) % c) + assert len(p) == abs(t[0] - s[0]) + abs(t[1] - s[1]) + 1 + p = [((u - 1) // c, (u - 1) % c) for u in p] + for u in p: + assert 0 <= u[0] < r + assert 0 <= u[1] < c + for u, v in zip(p[:-1], p[1:]): + assert (abs(v[0] - u[0]), abs(v[1] - u[1])) in [(0, 1), (1, 0)] + + +class TestGenericPath: + @classmethod + def setup_class(cls): + from networkx import convert_node_labels_to_integers as cnlti + + cls.grid = cnlti(nx.grid_2d_graph(4, 4), first_label=1, ordering="sorted") + cls.cycle = nx.cycle_graph(7) + cls.directed_cycle = nx.cycle_graph(7, create_using=nx.DiGraph()) + cls.neg_weights = nx.DiGraph() + cls.neg_weights.add_edge(0, 1, weight=1) + cls.neg_weights.add_edge(0, 2, weight=3) + cls.neg_weights.add_edge(1, 3, weight=1) + cls.neg_weights.add_edge(2, 3, weight=-2) + + def test_shortest_path(self): + assert nx.shortest_path(self.cycle, 0, 3) == [0, 1, 2, 3] + assert nx.shortest_path(self.cycle, 0, 4) == [0, 6, 5, 4] + validate_grid_path(4, 4, 1, 12, nx.shortest_path(self.grid, 1, 12)) + assert nx.shortest_path(self.directed_cycle, 0, 3) == [0, 1, 2, 3] + # now with weights + assert nx.shortest_path(self.cycle, 0, 3, weight="weight") == [0, 1, 2, 3] + assert nx.shortest_path(self.cycle, 0, 4, weight="weight") == [0, 6, 5, 4] + validate_grid_path( + 4, 4, 1, 12, nx.shortest_path(self.grid, 1, 12, weight="weight") + ) + assert nx.shortest_path(self.directed_cycle, 0, 3, weight="weight") == [ + 0, + 1, + 2, + 3, + ] + # weights and method specified + assert nx.shortest_path( + self.directed_cycle, 0, 3, weight="weight", method="dijkstra" + ) == [0, 1, 2, 3] + assert nx.shortest_path( + self.directed_cycle, 0, 3, weight="weight", method="bellman-ford" + ) == [0, 1, 2, 3] + # when Dijkstra's will probably (depending on precise implementation) + # incorrectly return [0, 1, 3] instead + assert nx.shortest_path( + self.neg_weights, 0, 3, weight="weight", method="bellman-ford" + ) == [0, 2, 3] + # confirm bad method rejection + pytest.raises(ValueError, nx.shortest_path, self.cycle, method="SPAM") + # confirm absent source rejection + pytest.raises(nx.NodeNotFound, nx.shortest_path, self.cycle, 8) + + def test_shortest_path_target(self): + answer = {0: [0, 1], 1: [1], 2: [2, 1]} + sp = nx.shortest_path(nx.path_graph(3), target=1) + assert sp == answer + # with weights + sp = nx.shortest_path(nx.path_graph(3), target=1, weight="weight") + assert sp == answer + # weights and method specified + sp = nx.shortest_path( + nx.path_graph(3), target=1, weight="weight", method="dijkstra" + ) + assert sp == answer + sp = nx.shortest_path( + nx.path_graph(3), target=1, weight="weight", method="bellman-ford" + ) + assert sp == answer + + def test_shortest_path_length(self): + assert nx.shortest_path_length(self.cycle, 0, 3) == 3 + assert nx.shortest_path_length(self.grid, 1, 12) == 5 + assert nx.shortest_path_length(self.directed_cycle, 0, 4) == 4 + # now with weights + assert nx.shortest_path_length(self.cycle, 0, 3, weight="weight") == 3 + assert nx.shortest_path_length(self.grid, 1, 12, weight="weight") == 5 + assert nx.shortest_path_length(self.directed_cycle, 0, 4, weight="weight") == 4 + # weights and method specified + assert ( + nx.shortest_path_length( + self.cycle, 0, 3, weight="weight", method="dijkstra" + ) + == 3 + ) + assert ( + nx.shortest_path_length( + self.cycle, 0, 3, weight="weight", method="bellman-ford" + ) + == 3 + ) + # confirm bad method rejection + pytest.raises(ValueError, nx.shortest_path_length, self.cycle, method="SPAM") + # confirm absent source rejection + pytest.raises(nx.NodeNotFound, nx.shortest_path_length, self.cycle, 8) + + def test_shortest_path_length_target(self): + answer = {0: 1, 1: 0, 2: 1} + sp = dict(nx.shortest_path_length(nx.path_graph(3), target=1)) + assert sp == answer + # with weights + sp = nx.shortest_path_length(nx.path_graph(3), target=1, weight="weight") + assert sp == answer + # weights and method specified + sp = nx.shortest_path_length( + nx.path_graph(3), target=1, weight="weight", method="dijkstra" + ) + assert sp == answer + sp = nx.shortest_path_length( + nx.path_graph(3), target=1, weight="weight", method="bellman-ford" + ) + assert sp == answer + + def test_single_source_shortest_path(self): + p = nx.shortest_path(self.cycle, 0) + assert p[3] == [0, 1, 2, 3] + assert p == nx.single_source_shortest_path(self.cycle, 0) + p = nx.shortest_path(self.grid, 1) + validate_grid_path(4, 4, 1, 12, p[12]) + # now with weights + p = nx.shortest_path(self.cycle, 0, weight="weight") + assert p[3] == [0, 1, 2, 3] + assert p == nx.single_source_dijkstra_path(self.cycle, 0) + p = nx.shortest_path(self.grid, 1, weight="weight") + validate_grid_path(4, 4, 1, 12, p[12]) + # weights and method specified + p = nx.shortest_path(self.cycle, 0, method="dijkstra", weight="weight") + assert p[3] == [0, 1, 2, 3] + assert p == nx.single_source_shortest_path(self.cycle, 0) + p = nx.shortest_path(self.cycle, 0, method="bellman-ford", weight="weight") + assert p[3] == [0, 1, 2, 3] + assert p == nx.single_source_shortest_path(self.cycle, 0) + + def test_single_source_shortest_path_length(self): + ans = dict(nx.shortest_path_length(self.cycle, 0)) + assert ans == {0: 0, 1: 1, 2: 2, 3: 3, 4: 3, 5: 2, 6: 1} + assert ans == dict(nx.single_source_shortest_path_length(self.cycle, 0)) + ans = dict(nx.shortest_path_length(self.grid, 1)) + assert ans[16] == 6 + # now with weights + ans = dict(nx.shortest_path_length(self.cycle, 0, weight="weight")) + assert ans == {0: 0, 1: 1, 2: 2, 3: 3, 4: 3, 5: 2, 6: 1} + assert ans == dict(nx.single_source_dijkstra_path_length(self.cycle, 0)) + ans = dict(nx.shortest_path_length(self.grid, 1, weight="weight")) + assert ans[16] == 6 + # weights and method specified + ans = dict( + nx.shortest_path_length(self.cycle, 0, weight="weight", method="dijkstra") + ) + assert ans == {0: 0, 1: 1, 2: 2, 3: 3, 4: 3, 5: 2, 6: 1} + assert ans == dict(nx.single_source_dijkstra_path_length(self.cycle, 0)) + ans = dict( + nx.shortest_path_length( + self.cycle, 0, weight="weight", method="bellman-ford" + ) + ) + assert ans == {0: 0, 1: 1, 2: 2, 3: 3, 4: 3, 5: 2, 6: 1} + assert ans == dict(nx.single_source_bellman_ford_path_length(self.cycle, 0)) + + def test_single_source_all_shortest_paths(self): + cycle_ans = {0: [[0]], 1: [[0, 1]], 2: [[0, 1, 2], [0, 3, 2]], 3: [[0, 3]]} + ans = dict(nx.single_source_all_shortest_paths(nx.cycle_graph(4), 0)) + assert sorted(ans[2]) == cycle_ans[2] + ans = dict(nx.single_source_all_shortest_paths(self.grid, 1)) + grid_ans = [ + [1, 2, 3, 7, 11], + [1, 2, 6, 7, 11], + [1, 2, 6, 10, 11], + [1, 5, 6, 7, 11], + [1, 5, 6, 10, 11], + [1, 5, 9, 10, 11], + ] + assert sorted(ans[11]) == grid_ans + ans = dict( + nx.single_source_all_shortest_paths(nx.cycle_graph(4), 0, weight="weight") + ) + assert sorted(ans[2]) == cycle_ans[2] + ans = dict( + nx.single_source_all_shortest_paths( + nx.cycle_graph(4), 0, method="bellman-ford", weight="weight" + ) + ) + assert sorted(ans[2]) == cycle_ans[2] + ans = dict(nx.single_source_all_shortest_paths(self.grid, 1, weight="weight")) + assert sorted(ans[11]) == grid_ans + ans = dict( + nx.single_source_all_shortest_paths( + self.grid, 1, method="bellman-ford", weight="weight" + ) + ) + assert sorted(ans[11]) == grid_ans + G = nx.cycle_graph(4) + G.add_node(4) + ans = dict(nx.single_source_all_shortest_paths(G, 0)) + assert sorted(ans[2]) == [[0, 1, 2], [0, 3, 2]] + ans = dict(nx.single_source_all_shortest_paths(G, 4)) + assert sorted(ans[4]) == [[4]] + + def test_all_pairs_shortest_path(self): + # shortest_path w/o source and target will return a generator instead of + # a dict beginning in version 3.5. Only the first call needs changed here. + p = nx.shortest_path(self.cycle) + assert p[0][3] == [0, 1, 2, 3] + assert p == dict(nx.all_pairs_shortest_path(self.cycle)) + p = dict(nx.shortest_path(self.grid)) + validate_grid_path(4, 4, 1, 12, p[1][12]) + # now with weights + p = dict(nx.shortest_path(self.cycle, weight="weight")) + assert p[0][3] == [0, 1, 2, 3] + assert p == dict(nx.all_pairs_dijkstra_path(self.cycle)) + p = dict(nx.shortest_path(self.grid, weight="weight")) + validate_grid_path(4, 4, 1, 12, p[1][12]) + # weights and method specified + p = dict(nx.shortest_path(self.cycle, weight="weight", method="dijkstra")) + assert p[0][3] == [0, 1, 2, 3] + assert p == dict(nx.all_pairs_dijkstra_path(self.cycle)) + p = dict(nx.shortest_path(self.cycle, weight="weight", method="bellman-ford")) + assert p[0][3] == [0, 1, 2, 3] + assert p == dict(nx.all_pairs_bellman_ford_path(self.cycle)) + + def test_all_pairs_shortest_path_length(self): + ans = dict(nx.shortest_path_length(self.cycle)) + assert ans[0] == {0: 0, 1: 1, 2: 2, 3: 3, 4: 3, 5: 2, 6: 1} + assert ans == dict(nx.all_pairs_shortest_path_length(self.cycle)) + ans = dict(nx.shortest_path_length(self.grid)) + assert ans[1][16] == 6 + # now with weights + ans = dict(nx.shortest_path_length(self.cycle, weight="weight")) + assert ans[0] == {0: 0, 1: 1, 2: 2, 3: 3, 4: 3, 5: 2, 6: 1} + assert ans == dict(nx.all_pairs_dijkstra_path_length(self.cycle)) + ans = dict(nx.shortest_path_length(self.grid, weight="weight")) + assert ans[1][16] == 6 + # weights and method specified + ans = dict( + nx.shortest_path_length(self.cycle, weight="weight", method="dijkstra") + ) + assert ans[0] == {0: 0, 1: 1, 2: 2, 3: 3, 4: 3, 5: 2, 6: 1} + assert ans == dict(nx.all_pairs_dijkstra_path_length(self.cycle)) + ans = dict( + nx.shortest_path_length(self.cycle, weight="weight", method="bellman-ford") + ) + assert ans[0] == {0: 0, 1: 1, 2: 2, 3: 3, 4: 3, 5: 2, 6: 1} + assert ans == dict(nx.all_pairs_bellman_ford_path_length(self.cycle)) + + def test_all_pairs_all_shortest_paths(self): + ans = dict(nx.all_pairs_all_shortest_paths(nx.cycle_graph(4))) + assert sorted(ans[1][3]) == [[1, 0, 3], [1, 2, 3]] + ans = dict(nx.all_pairs_all_shortest_paths(nx.cycle_graph(4)), weight="weight") + assert sorted(ans[1][3]) == [[1, 0, 3], [1, 2, 3]] + ans = dict( + nx.all_pairs_all_shortest_paths(nx.cycle_graph(4)), + method="bellman-ford", + weight="weight", + ) + assert sorted(ans[1][3]) == [[1, 0, 3], [1, 2, 3]] + G = nx.cycle_graph(4) + G.add_node(4) + ans = dict(nx.all_pairs_all_shortest_paths(G)) + assert sorted(ans[4][4]) == [[4]] + + def test_has_path(self): + G = nx.Graph() + nx.add_path(G, range(3)) + nx.add_path(G, range(3, 5)) + assert nx.has_path(G, 0, 2) + assert not nx.has_path(G, 0, 4) + + def test_has_path_singleton(self): + G = nx.empty_graph(1) + assert nx.has_path(G, 0, 0) + + def test_all_shortest_paths(self): + G = nx.Graph() + nx.add_path(G, [0, 1, 2, 3]) + nx.add_path(G, [0, 10, 20, 3]) + assert [[0, 1, 2, 3], [0, 10, 20, 3]] == sorted(nx.all_shortest_paths(G, 0, 3)) + # with weights + G = nx.Graph() + nx.add_path(G, [0, 1, 2, 3]) + nx.add_path(G, [0, 10, 20, 3]) + assert [[0, 1, 2, 3], [0, 10, 20, 3]] == sorted( + nx.all_shortest_paths(G, 0, 3, weight="weight") + ) + # weights and method specified + G = nx.Graph() + nx.add_path(G, [0, 1, 2, 3]) + nx.add_path(G, [0, 10, 20, 3]) + assert [[0, 1, 2, 3], [0, 10, 20, 3]] == sorted( + nx.all_shortest_paths(G, 0, 3, weight="weight", method="dijkstra") + ) + G = nx.Graph() + nx.add_path(G, [0, 1, 2, 3]) + nx.add_path(G, [0, 10, 20, 3]) + assert [[0, 1, 2, 3], [0, 10, 20, 3]] == sorted( + nx.all_shortest_paths(G, 0, 3, weight="weight", method="bellman-ford") + ) + + def test_all_shortest_paths_raise(self): + with pytest.raises(nx.NetworkXNoPath): + G = nx.path_graph(4) + G.add_node(4) + list(nx.all_shortest_paths(G, 0, 4)) + + def test_bad_method(self): + with pytest.raises(ValueError): + G = nx.path_graph(2) + list(nx.all_shortest_paths(G, 0, 1, weight="weight", method="SPAM")) + + def test_single_source_all_shortest_paths_bad_method(self): + with pytest.raises(ValueError): + G = nx.path_graph(2) + dict( + nx.single_source_all_shortest_paths( + G, 0, weight="weight", method="SPAM" + ) + ) + + def test_all_shortest_paths_zero_weight_edge(self): + g = nx.Graph() + nx.add_path(g, [0, 1, 3]) + nx.add_path(g, [0, 1, 2, 3]) + g.edges[1, 2]["weight"] = 0 + paths30d = list( + nx.all_shortest_paths(g, 3, 0, weight="weight", method="dijkstra") + ) + paths03d = list( + nx.all_shortest_paths(g, 0, 3, weight="weight", method="dijkstra") + ) + paths30b = list( + nx.all_shortest_paths(g, 3, 0, weight="weight", method="bellman-ford") + ) + paths03b = list( + nx.all_shortest_paths(g, 0, 3, weight="weight", method="bellman-ford") + ) + assert sorted(paths03d) == sorted(p[::-1] for p in paths30d) + assert sorted(paths03d) == sorted(p[::-1] for p in paths30b) + assert sorted(paths03b) == sorted(p[::-1] for p in paths30b) + + +class TestAverageShortestPathLength: + def test_cycle_graph(self): + ans = nx.average_shortest_path_length(nx.cycle_graph(7)) + assert ans == pytest.approx(2, abs=1e-7) + + def test_path_graph(self): + ans = nx.average_shortest_path_length(nx.path_graph(5)) + assert ans == pytest.approx(2, abs=1e-7) + + def test_weighted(self): + G = nx.Graph() + nx.add_cycle(G, range(7), weight=2) + ans = nx.average_shortest_path_length(G, weight="weight") + assert ans == pytest.approx(4, abs=1e-7) + G = nx.Graph() + nx.add_path(G, range(5), weight=2) + ans = nx.average_shortest_path_length(G, weight="weight") + assert ans == pytest.approx(4, abs=1e-7) + + def test_specified_methods(self): + G = nx.Graph() + nx.add_cycle(G, range(7), weight=2) + ans = nx.average_shortest_path_length(G, weight="weight", method="dijkstra") + assert ans == pytest.approx(4, abs=1e-7) + ans = nx.average_shortest_path_length(G, weight="weight", method="bellman-ford") + assert ans == pytest.approx(4, abs=1e-7) + ans = nx.average_shortest_path_length( + G, weight="weight", method="floyd-warshall" + ) + assert ans == pytest.approx(4, abs=1e-7) + + G = nx.Graph() + nx.add_path(G, range(5), weight=2) + ans = nx.average_shortest_path_length(G, weight="weight", method="dijkstra") + assert ans == pytest.approx(4, abs=1e-7) + ans = nx.average_shortest_path_length(G, weight="weight", method="bellman-ford") + assert ans == pytest.approx(4, abs=1e-7) + ans = nx.average_shortest_path_length( + G, weight="weight", method="floyd-warshall" + ) + assert ans == pytest.approx(4, abs=1e-7) + + def test_directed_not_strongly_connected(self): + G = nx.DiGraph([(0, 1)]) + with pytest.raises(nx.NetworkXError, match="Graph is not strongly connected"): + nx.average_shortest_path_length(G) + + def test_undirected_not_connected(self): + g = nx.Graph() + g.add_nodes_from(range(3)) + g.add_edge(0, 1) + pytest.raises(nx.NetworkXError, nx.average_shortest_path_length, g) + + def test_trivial_graph(self): + """Tests that the trivial graph has average path length zero, + since there is exactly one path of length zero in the trivial + graph. + + For more information, see issue #1960. + + """ + G = nx.trivial_graph() + assert nx.average_shortest_path_length(G) == 0 + + def test_null_graph(self): + with pytest.raises(nx.NetworkXPointlessConcept): + nx.average_shortest_path_length(nx.null_graph()) + + def test_bad_method(self): + with pytest.raises(ValueError): + G = nx.path_graph(2) + nx.average_shortest_path_length(G, weight="weight", method="SPAM") + + +class TestAverageShortestPathLengthNumpy: + @classmethod + def setup_class(cls): + global np + import pytest + + np = pytest.importorskip("numpy") + + def test_specified_methods_numpy(self): + G = nx.Graph() + nx.add_cycle(G, range(7), weight=2) + ans = nx.average_shortest_path_length( + G, weight="weight", method="floyd-warshall-numpy" + ) + np.testing.assert_almost_equal(ans, 4) + + G = nx.Graph() + nx.add_path(G, range(5), weight=2) + ans = nx.average_shortest_path_length( + G, weight="weight", method="floyd-warshall-numpy" + ) + np.testing.assert_almost_equal(ans, 4) diff --git a/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/test_unweighted.py b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/test_unweighted.py new file mode 100644 index 0000000000000000000000000000000000000000..f09c8b104266afee86db0836a60cb30ff96e0bb1 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/test_unweighted.py @@ -0,0 +1,149 @@ +import pytest + +import networkx as nx + + +def validate_grid_path(r, c, s, t, p): + assert isinstance(p, list) + assert p[0] == s + assert p[-1] == t + s = ((s - 1) // c, (s - 1) % c) + t = ((t - 1) // c, (t - 1) % c) + assert len(p) == abs(t[0] - s[0]) + abs(t[1] - s[1]) + 1 + p = [((u - 1) // c, (u - 1) % c) for u in p] + for u in p: + assert 0 <= u[0] < r + assert 0 <= u[1] < c + for u, v in zip(p[:-1], p[1:]): + assert (abs(v[0] - u[0]), abs(v[1] - u[1])) in [(0, 1), (1, 0)] + + +class TestUnweightedPath: + @classmethod + def setup_class(cls): + from networkx import convert_node_labels_to_integers as cnlti + + cls.grid = cnlti(nx.grid_2d_graph(4, 4), first_label=1, ordering="sorted") + cls.cycle = nx.cycle_graph(7) + cls.directed_cycle = nx.cycle_graph(7, create_using=nx.DiGraph()) + + def test_bidirectional_shortest_path(self): + assert nx.bidirectional_shortest_path(self.cycle, 0, 3) == [0, 1, 2, 3] + assert nx.bidirectional_shortest_path(self.cycle, 0, 4) == [0, 6, 5, 4] + validate_grid_path( + 4, 4, 1, 12, nx.bidirectional_shortest_path(self.grid, 1, 12) + ) + assert nx.bidirectional_shortest_path(self.directed_cycle, 0, 3) == [0, 1, 2, 3] + # test source = target + assert nx.bidirectional_shortest_path(self.cycle, 3, 3) == [3] + + @pytest.mark.parametrize( + ("src", "tgt"), + ( + (8, 3), # source not in graph + (3, 8), # target not in graph + (8, 10), # neither source nor target in graph + (8, 8), # src == tgt, neither in graph - tests order of input checks + ), + ) + def test_bidirectional_shortest_path_src_tgt_not_in_graph(self, src, tgt): + with pytest.raises( + nx.NodeNotFound, + match=f"(Source {src}|Target {tgt}) is not in G", + ): + nx.bidirectional_shortest_path(self.cycle, src, tgt) + + def test_shortest_path_length(self): + assert nx.shortest_path_length(self.cycle, 0, 3) == 3 + assert nx.shortest_path_length(self.grid, 1, 12) == 5 + assert nx.shortest_path_length(self.directed_cycle, 0, 4) == 4 + # now with weights + assert nx.shortest_path_length(self.cycle, 0, 3, weight=True) == 3 + assert nx.shortest_path_length(self.grid, 1, 12, weight=True) == 5 + assert nx.shortest_path_length(self.directed_cycle, 0, 4, weight=True) == 4 + + def test_single_source_shortest_path(self): + p = nx.single_source_shortest_path(self.directed_cycle, 3) + assert p[0] == [3, 4, 5, 6, 0] + p = nx.single_source_shortest_path(self.cycle, 0) + assert p[3] == [0, 1, 2, 3] + p = nx.single_source_shortest_path(self.cycle, 0, cutoff=0) + assert p == {0: [0]} + + def test_single_source_shortest_path_length(self): + pl = nx.single_source_shortest_path_length + lengths = {0: 0, 1: 1, 2: 2, 3: 3, 4: 3, 5: 2, 6: 1} + assert dict(pl(self.cycle, 0)) == lengths + lengths = {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6} + assert dict(pl(self.directed_cycle, 0)) == lengths + + def test_single_target_shortest_path(self): + p = nx.single_target_shortest_path(self.directed_cycle, 0) + assert p[3] == [3, 4, 5, 6, 0] + p = nx.single_target_shortest_path(self.cycle, 0) + assert p[3] == [3, 2, 1, 0] + p = nx.single_target_shortest_path(self.cycle, 0, cutoff=0) + assert p == {0: [0]} + # test missing targets + target = 8 + with pytest.raises(nx.NodeNotFound, match=f"Target {target} not in G"): + nx.single_target_shortest_path(self.cycle, target) + + def test_single_target_shortest_path_length(self): + pl = nx.single_target_shortest_path_length + lengths = {0: 0, 1: 1, 2: 2, 3: 3, 4: 3, 5: 2, 6: 1} + assert dict(pl(self.cycle, 0)) == lengths + lengths = {0: 0, 1: 6, 2: 5, 3: 4, 4: 3, 5: 2, 6: 1} + assert dict(pl(self.directed_cycle, 0)) == lengths + # test missing targets + target = 8 + with pytest.raises(nx.NodeNotFound, match=f"Target {target} is not in G"): + nx.single_target_shortest_path_length(self.cycle, target) + + def test_all_pairs_shortest_path(self): + p = dict(nx.all_pairs_shortest_path(self.cycle)) + assert p[0][3] == [0, 1, 2, 3] + p = dict(nx.all_pairs_shortest_path(self.grid)) + validate_grid_path(4, 4, 1, 12, p[1][12]) + + def test_all_pairs_shortest_path_length(self): + l = dict(nx.all_pairs_shortest_path_length(self.cycle)) + assert l[0] == {0: 0, 1: 1, 2: 2, 3: 3, 4: 3, 5: 2, 6: 1} + l = dict(nx.all_pairs_shortest_path_length(self.grid)) + assert l[1][16] == 6 + + def test_predecessor_path(self): + G = nx.path_graph(4) + assert nx.predecessor(G, 0) == {0: [], 1: [0], 2: [1], 3: [2]} + assert nx.predecessor(G, 0, 3) == [2] + + def test_predecessor_cycle(self): + G = nx.cycle_graph(4) + pred = nx.predecessor(G, 0) + assert pred[0] == [] + assert pred[1] == [0] + assert pred[2] in [[1, 3], [3, 1]] + assert pred[3] == [0] + + def test_predecessor_cutoff(self): + G = nx.path_graph(4) + p = nx.predecessor(G, 0, 3) + assert 4 not in p + + def test_predecessor_target(self): + G = nx.path_graph(4) + p = nx.predecessor(G, 0, 3) + assert p == [2] + p = nx.predecessor(G, 0, 3, cutoff=2) + assert p == [] + p, s = nx.predecessor(G, 0, 3, return_seen=True) + assert p == [2] + assert s == 3 + p, s = nx.predecessor(G, 0, 3, cutoff=2, return_seen=True) + assert p == [] + assert s == -1 + + def test_predecessor_missing_source(self): + source = 8 + with pytest.raises(nx.NodeNotFound, match=f"Source {source} not in G"): + nx.predecessor(self.cycle, source) diff --git a/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/test_weighted.py b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/test_weighted.py new file mode 100644 index 0000000000000000000000000000000000000000..dc88572d3571f7d94015b12d9ab870aee316516f --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/tests/test_weighted.py @@ -0,0 +1,972 @@ +import pytest + +import networkx as nx +from networkx.utils import pairwise + + +def validate_path(G, s, t, soln_len, path, weight="weight"): + assert path[0] == s + assert path[-1] == t + + if callable(weight): + weight_f = weight + else: + if G.is_multigraph(): + + def weight_f(u, v, d): + return min(e.get(weight, 1) for e in d.values()) + + else: + + def weight_f(u, v, d): + return d.get(weight, 1) + + computed = sum(weight_f(u, v, G[u][v]) for u, v in pairwise(path)) + assert soln_len == computed + + +def validate_length_path(G, s, t, soln_len, length, path, weight="weight"): + assert soln_len == length + validate_path(G, s, t, length, path, weight=weight) + + +class WeightedTestBase: + """Base class for test classes that test functions for computing + shortest paths in weighted graphs. + + """ + + def setup_method(self): + """Creates some graphs for use in the unit tests.""" + cnlti = nx.convert_node_labels_to_integers + self.grid = cnlti(nx.grid_2d_graph(4, 4), first_label=1, ordering="sorted") + self.cycle = nx.cycle_graph(7) + self.directed_cycle = nx.cycle_graph(7, create_using=nx.DiGraph()) + self.XG = nx.DiGraph() + self.XG.add_weighted_edges_from( + [ + ("s", "u", 10), + ("s", "x", 5), + ("u", "v", 1), + ("u", "x", 2), + ("v", "y", 1), + ("x", "u", 3), + ("x", "v", 5), + ("x", "y", 2), + ("y", "s", 7), + ("y", "v", 6), + ] + ) + self.MXG = nx.MultiDiGraph(self.XG) + self.MXG.add_edge("s", "u", weight=15) + self.XG2 = nx.DiGraph() + self.XG2.add_weighted_edges_from( + [ + [1, 4, 1], + [4, 5, 1], + [5, 6, 1], + [6, 3, 1], + [1, 3, 50], + [1, 2, 100], + [2, 3, 100], + ] + ) + + self.XG3 = nx.Graph() + self.XG3.add_weighted_edges_from( + [[0, 1, 2], [1, 2, 12], [2, 3, 1], [3, 4, 5], [4, 5, 1], [5, 0, 10]] + ) + + self.XG4 = nx.Graph() + self.XG4.add_weighted_edges_from( + [ + [0, 1, 2], + [1, 2, 2], + [2, 3, 1], + [3, 4, 1], + [4, 5, 1], + [5, 6, 1], + [6, 7, 1], + [7, 0, 1], + ] + ) + self.MXG4 = nx.MultiGraph(self.XG4) + self.MXG4.add_edge(0, 1, weight=3) + self.G = nx.DiGraph() # no weights + self.G.add_edges_from( + [ + ("s", "u"), + ("s", "x"), + ("u", "v"), + ("u", "x"), + ("v", "y"), + ("x", "u"), + ("x", "v"), + ("x", "y"), + ("y", "s"), + ("y", "v"), + ] + ) + + +class TestWeightedPath(WeightedTestBase): + def test_dijkstra(self): + (D, P) = nx.single_source_dijkstra(self.XG, "s") + validate_path(self.XG, "s", "v", 9, P["v"]) + assert D["v"] == 9 + + validate_path( + self.XG, "s", "v", 9, nx.single_source_dijkstra_path(self.XG, "s")["v"] + ) + assert dict(nx.single_source_dijkstra_path_length(self.XG, "s"))["v"] == 9 + + validate_path( + self.XG, "s", "v", 9, nx.single_source_dijkstra(self.XG, "s")[1]["v"] + ) + validate_path( + self.MXG, "s", "v", 9, nx.single_source_dijkstra_path(self.MXG, "s")["v"] + ) + + GG = self.XG.to_undirected() + # make sure we get lower weight + # to_undirected might choose either edge with weight 2 or weight 3 + GG["u"]["x"]["weight"] = 2 + (D, P) = nx.single_source_dijkstra(GG, "s") + validate_path(GG, "s", "v", 8, P["v"]) + assert D["v"] == 8 # uses lower weight of 2 on u<->x edge + validate_path(GG, "s", "v", 8, nx.dijkstra_path(GG, "s", "v")) + assert nx.dijkstra_path_length(GG, "s", "v") == 8 + + validate_path(self.XG2, 1, 3, 4, nx.dijkstra_path(self.XG2, 1, 3)) + validate_path(self.XG3, 0, 3, 15, nx.dijkstra_path(self.XG3, 0, 3)) + assert nx.dijkstra_path_length(self.XG3, 0, 3) == 15 + validate_path(self.XG4, 0, 2, 4, nx.dijkstra_path(self.XG4, 0, 2)) + assert nx.dijkstra_path_length(self.XG4, 0, 2) == 4 + validate_path(self.MXG4, 0, 2, 4, nx.dijkstra_path(self.MXG4, 0, 2)) + validate_path( + self.G, "s", "v", 2, nx.single_source_dijkstra(self.G, "s", "v")[1] + ) + validate_path( + self.G, "s", "v", 2, nx.single_source_dijkstra(self.G, "s")[1]["v"] + ) + + validate_path(self.G, "s", "v", 2, nx.dijkstra_path(self.G, "s", "v")) + assert nx.dijkstra_path_length(self.G, "s", "v") == 2 + + # NetworkXError: node s not reachable from moon + pytest.raises(nx.NetworkXNoPath, nx.dijkstra_path, self.G, "s", "moon") + pytest.raises(nx.NetworkXNoPath, nx.dijkstra_path_length, self.G, "s", "moon") + + validate_path(self.cycle, 0, 3, 3, nx.dijkstra_path(self.cycle, 0, 3)) + validate_path(self.cycle, 0, 4, 3, nx.dijkstra_path(self.cycle, 0, 4)) + + assert nx.single_source_dijkstra(self.cycle, 0, 0) == (0, [0]) + + def test_bidirectional_dijkstra(self): + validate_length_path( + self.XG, "s", "v", 9, *nx.bidirectional_dijkstra(self.XG, "s", "v") + ) + validate_length_path( + self.G, "s", "v", 2, *nx.bidirectional_dijkstra(self.G, "s", "v") + ) + validate_length_path( + self.cycle, 0, 3, 3, *nx.bidirectional_dijkstra(self.cycle, 0, 3) + ) + validate_length_path( + self.cycle, 0, 4, 3, *nx.bidirectional_dijkstra(self.cycle, 0, 4) + ) + validate_length_path( + self.XG3, 0, 3, 15, *nx.bidirectional_dijkstra(self.XG3, 0, 3) + ) + validate_length_path( + self.XG4, 0, 2, 4, *nx.bidirectional_dijkstra(self.XG4, 0, 2) + ) + + # need more tests here + P = nx.single_source_dijkstra_path(self.XG, "s")["v"] + validate_path( + self.XG, + "s", + "v", + sum(self.XG[u][v]["weight"] for u, v in zip(P[:-1], P[1:])), + nx.dijkstra_path(self.XG, "s", "v"), + ) + + # check absent source + G = nx.path_graph(2) + pytest.raises(nx.NodeNotFound, nx.bidirectional_dijkstra, G, 3, 0) + + def test_weight_functions(self): + def heuristic(*z): + return sum(val**2 for val in z) + + def getpath(pred, v, s): + return [v] if v == s else getpath(pred, pred[v], s) + [v] + + def goldberg_radzik(g, s, t, weight="weight"): + pred, dist = nx.goldberg_radzik(g, s, weight=weight) + dist = dist[t] + return dist, getpath(pred, t, s) + + def astar(g, s, t, weight="weight"): + path = nx.astar_path(g, s, t, heuristic, weight=weight) + dist = nx.astar_path_length(g, s, t, heuristic, weight=weight) + return dist, path + + def vlp(G, s, t, l, F, w): + res = F(G, s, t, weight=w) + validate_length_path(G, s, t, l, *res, weight=w) + + G = self.cycle + s = 6 + t = 4 + path = [6] + list(range(t + 1)) + + def weight(u, v, _): + return 1 + v**2 + + length = sum(weight(u, v, None) for u, v in pairwise(path)) + vlp(G, s, t, length, nx.bidirectional_dijkstra, weight) + vlp(G, s, t, length, nx.single_source_dijkstra, weight) + vlp(G, s, t, length, nx.single_source_bellman_ford, weight) + vlp(G, s, t, length, goldberg_radzik, weight) + vlp(G, s, t, length, astar, weight) + + def weight(u, v, _): + return 2 ** (u * v) + + length = sum(weight(u, v, None) for u, v in pairwise(path)) + vlp(G, s, t, length, nx.bidirectional_dijkstra, weight) + vlp(G, s, t, length, nx.single_source_dijkstra, weight) + vlp(G, s, t, length, nx.single_source_bellman_ford, weight) + vlp(G, s, t, length, goldberg_radzik, weight) + vlp(G, s, t, length, astar, weight) + + def test_bidirectional_dijkstra_no_path(self): + with pytest.raises(nx.NetworkXNoPath): + G = nx.Graph() + nx.add_path(G, [1, 2, 3]) + nx.add_path(G, [4, 5, 6]) + path = nx.bidirectional_dijkstra(G, 1, 6) + + @pytest.mark.parametrize( + "fn", + ( + nx.dijkstra_path, + nx.dijkstra_path_length, + nx.single_source_dijkstra_path, + nx.single_source_dijkstra_path_length, + nx.single_source_dijkstra, + nx.dijkstra_predecessor_and_distance, + ), + ) + def test_absent_source(self, fn): + G = nx.path_graph(2) + with pytest.raises(nx.NodeNotFound): + fn(G, 3, 0) + # Test when source == target, which is handled specially by some functions + with pytest.raises(nx.NodeNotFound): + fn(G, 3, 3) + + def test_dijkstra_predecessor1(self): + G = nx.path_graph(4) + assert nx.dijkstra_predecessor_and_distance(G, 0) == ( + {0: [], 1: [0], 2: [1], 3: [2]}, + {0: 0, 1: 1, 2: 2, 3: 3}, + ) + + def test_dijkstra_predecessor2(self): + # 4-cycle + G = nx.Graph([(0, 1), (1, 2), (2, 3), (3, 0)]) + pred, dist = nx.dijkstra_predecessor_and_distance(G, (0)) + assert pred[0] == [] + assert pred[1] == [0] + assert pred[2] in [[1, 3], [3, 1]] + assert pred[3] == [0] + assert dist == {0: 0, 1: 1, 2: 2, 3: 1} + + def test_dijkstra_predecessor3(self): + XG = nx.DiGraph() + XG.add_weighted_edges_from( + [ + ("s", "u", 10), + ("s", "x", 5), + ("u", "v", 1), + ("u", "x", 2), + ("v", "y", 1), + ("x", "u", 3), + ("x", "v", 5), + ("x", "y", 2), + ("y", "s", 7), + ("y", "v", 6), + ] + ) + (P, D) = nx.dijkstra_predecessor_and_distance(XG, "s") + assert P["v"] == ["u"] + assert D["v"] == 9 + (P, D) = nx.dijkstra_predecessor_and_distance(XG, "s", cutoff=8) + assert "v" not in D + + def test_single_source_dijkstra_path_length(self): + pl = nx.single_source_dijkstra_path_length + assert dict(pl(self.MXG4, 0))[2] == 4 + spl = pl(self.MXG4, 0, cutoff=2) + assert 2 not in spl + + def test_bidirectional_dijkstra_multigraph(self): + G = nx.MultiGraph() + G.add_edge("a", "b", weight=10) + G.add_edge("a", "b", weight=100) + dp = nx.bidirectional_dijkstra(G, "a", "b") + assert dp == (10, ["a", "b"]) + + def test_dijkstra_pred_distance_multigraph(self): + G = nx.MultiGraph() + G.add_edge("a", "b", key="short", foo=5, weight=100) + G.add_edge("a", "b", key="long", bar=1, weight=110) + p, d = nx.dijkstra_predecessor_and_distance(G, "a") + assert p == {"a": [], "b": ["a"]} + assert d == {"a": 0, "b": 100} + + def test_negative_edge_cycle(self): + G = nx.cycle_graph(5, create_using=nx.DiGraph()) + assert not nx.negative_edge_cycle(G) + G.add_edge(8, 9, weight=-7) + G.add_edge(9, 8, weight=3) + graph_size = len(G) + assert nx.negative_edge_cycle(G) + assert graph_size == len(G) + pytest.raises(ValueError, nx.single_source_dijkstra_path_length, G, 8) + pytest.raises(ValueError, nx.single_source_dijkstra, G, 8) + pytest.raises(ValueError, nx.dijkstra_predecessor_and_distance, G, 8) + G.add_edge(9, 10) + pytest.raises(ValueError, nx.bidirectional_dijkstra, G, 8, 10) + G = nx.MultiDiGraph() + G.add_edge(2, 2, weight=-1) + assert nx.negative_edge_cycle(G) + + def test_negative_edge_cycle_empty(self): + G = nx.DiGraph() + assert not nx.negative_edge_cycle(G) + + def test_negative_edge_cycle_custom_weight_key(self): + d = nx.DiGraph() + d.add_edge("a", "b", w=-2) + d.add_edge("b", "a", w=-1) + assert nx.negative_edge_cycle(d, weight="w") + + def test_weight_function(self): + """Tests that a callable weight is interpreted as a weight + function instead of an edge attribute. + + """ + # Create a triangle in which the edge from node 0 to node 2 has + # a large weight and the other two edges have a small weight. + G = nx.complete_graph(3) + G.adj[0][2]["weight"] = 10 + G.adj[0][1]["weight"] = 1 + G.adj[1][2]["weight"] = 1 + + # The weight function will take the multiplicative inverse of + # the weights on the edges. This way, weights that were large + # before now become small and vice versa. + + def weight(u, v, d): + return 1 / d["weight"] + + # The shortest path from 0 to 2 using the actual weights on the + # edges should be [0, 1, 2]. + distance, path = nx.single_source_dijkstra(G, 0, 2) + assert distance == 2 + assert path == [0, 1, 2] + # However, with the above weight function, the shortest path + # should be [0, 2], since that has a very small weight. + distance, path = nx.single_source_dijkstra(G, 0, 2, weight=weight) + assert distance == 1 / 10 + assert path == [0, 2] + + def test_all_pairs_dijkstra_path(self): + cycle = nx.cycle_graph(7) + p = dict(nx.all_pairs_dijkstra_path(cycle)) + assert p[0][3] == [0, 1, 2, 3] + + cycle[1][2]["weight"] = 10 + p = dict(nx.all_pairs_dijkstra_path(cycle)) + assert p[0][3] == [0, 6, 5, 4, 3] + + def test_all_pairs_dijkstra_path_length(self): + cycle = nx.cycle_graph(7) + pl = dict(nx.all_pairs_dijkstra_path_length(cycle)) + assert pl[0] == {0: 0, 1: 1, 2: 2, 3: 3, 4: 3, 5: 2, 6: 1} + + cycle[1][2]["weight"] = 10 + pl = dict(nx.all_pairs_dijkstra_path_length(cycle)) + assert pl[0] == {0: 0, 1: 1, 2: 5, 3: 4, 4: 3, 5: 2, 6: 1} + + def test_all_pairs_dijkstra(self): + cycle = nx.cycle_graph(7) + out = dict(nx.all_pairs_dijkstra(cycle)) + assert out[0][0] == {0: 0, 1: 1, 2: 2, 3: 3, 4: 3, 5: 2, 6: 1} + assert out[0][1][3] == [0, 1, 2, 3] + + cycle[1][2]["weight"] = 10 + out = dict(nx.all_pairs_dijkstra(cycle)) + assert out[0][0] == {0: 0, 1: 1, 2: 5, 3: 4, 4: 3, 5: 2, 6: 1} + assert out[0][1][3] == [0, 6, 5, 4, 3] + + +class TestDijkstraPathLength: + """Unit tests for the :func:`networkx.dijkstra_path_length` + function. + + """ + + def test_weight_function(self): + """Tests for computing the length of the shortest path using + Dijkstra's algorithm with a user-defined weight function. + + """ + # Create a triangle in which the edge from node 0 to node 2 has + # a large weight and the other two edges have a small weight. + G = nx.complete_graph(3) + G.adj[0][2]["weight"] = 10 + G.adj[0][1]["weight"] = 1 + G.adj[1][2]["weight"] = 1 + + # The weight function will take the multiplicative inverse of + # the weights on the edges. This way, weights that were large + # before now become small and vice versa. + + def weight(u, v, d): + return 1 / d["weight"] + + # The shortest path from 0 to 2 using the actual weights on the + # edges should be [0, 1, 2]. However, with the above weight + # function, the shortest path should be [0, 2], since that has a + # very small weight. + length = nx.dijkstra_path_length(G, 0, 2, weight=weight) + assert length == 1 / 10 + + +class TestMultiSourceDijkstra: + """Unit tests for the multi-source dialect of Dijkstra's shortest + path algorithms. + + """ + + def test_no_sources(self): + with pytest.raises(ValueError): + nx.multi_source_dijkstra(nx.Graph(), {}) + + def test_path_no_sources(self): + with pytest.raises(ValueError): + nx.multi_source_dijkstra_path(nx.Graph(), {}) + + def test_path_length_no_sources(self): + with pytest.raises(ValueError): + nx.multi_source_dijkstra_path_length(nx.Graph(), {}) + + @pytest.mark.parametrize( + "fn", + ( + nx.multi_source_dijkstra_path, + nx.multi_source_dijkstra_path_length, + nx.multi_source_dijkstra, + ), + ) + def test_absent_source(self, fn): + G = nx.path_graph(2) + with pytest.raises(nx.NodeNotFound): + fn(G, [3], 0) + with pytest.raises(nx.NodeNotFound): + fn(G, [3], 3) + + def test_two_sources(self): + edges = [(0, 1, 1), (1, 2, 1), (2, 3, 10), (3, 4, 1)] + G = nx.Graph() + G.add_weighted_edges_from(edges) + sources = {0, 4} + distances, paths = nx.multi_source_dijkstra(G, sources) + expected_distances = {0: 0, 1: 1, 2: 2, 3: 1, 4: 0} + expected_paths = {0: [0], 1: [0, 1], 2: [0, 1, 2], 3: [4, 3], 4: [4]} + assert distances == expected_distances + assert paths == expected_paths + + def test_simple_paths(self): + G = nx.path_graph(4) + lengths = nx.multi_source_dijkstra_path_length(G, [0]) + assert lengths == {n: n for n in G} + paths = nx.multi_source_dijkstra_path(G, [0]) + assert paths == {n: list(range(n + 1)) for n in G} + + +class TestBellmanFordAndGoldbergRadzik(WeightedTestBase): + def test_single_node_graph(self): + G = nx.DiGraph() + G.add_node(0) + assert nx.single_source_bellman_ford_path(G, 0) == {0: [0]} + assert nx.single_source_bellman_ford_path_length(G, 0) == {0: 0} + assert nx.single_source_bellman_ford(G, 0) == ({0: 0}, {0: [0]}) + assert nx.bellman_ford_predecessor_and_distance(G, 0) == ({0: []}, {0: 0}) + assert nx.goldberg_radzik(G, 0) == ({0: None}, {0: 0}) + + def test_absent_source_bellman_ford(self): + # the check is in _bellman_ford; this provides regression testing + # against later changes to "client" Bellman-Ford functions + G = nx.path_graph(2) + for fn in ( + nx.bellman_ford_predecessor_and_distance, + nx.bellman_ford_path, + nx.bellman_ford_path_length, + nx.single_source_bellman_ford_path, + nx.single_source_bellman_ford_path_length, + nx.single_source_bellman_ford, + ): + pytest.raises(nx.NodeNotFound, fn, G, 3, 0) + pytest.raises(nx.NodeNotFound, fn, G, 3, 3) + + def test_absent_source_goldberg_radzik(self): + with pytest.raises(nx.NodeNotFound): + G = nx.path_graph(2) + nx.goldberg_radzik(G, 3, 0) + + def test_negative_cycle_heuristic(self): + G = nx.DiGraph() + G.add_edge(0, 1, weight=-1) + G.add_edge(1, 2, weight=-1) + G.add_edge(2, 3, weight=-1) + G.add_edge(3, 0, weight=3) + assert not nx.negative_edge_cycle(G, heuristic=True) + G.add_edge(2, 0, weight=1.999) + assert nx.negative_edge_cycle(G, heuristic=True) + G.edges[2, 0]["weight"] = 2 + assert not nx.negative_edge_cycle(G, heuristic=True) + + def test_negative_cycle_consistency(self): + import random + + unif = random.uniform + for random_seed in range(2): # range(20): + random.seed(random_seed) + for density in [0.1, 0.9]: # .3, .7, .9]: + for N in [1, 10, 20]: # range(1, 60 - int(30 * density)): + for max_cost in [1, 90]: # [1, 10, 40, 90]: + G = nx.binomial_graph(N, density, seed=4, directed=True) + edges = ((u, v, unif(-1, max_cost)) for u, v in G.edges) + G.add_weighted_edges_from(edges) + + no_heuristic = nx.negative_edge_cycle(G, heuristic=False) + with_heuristic = nx.negative_edge_cycle(G, heuristic=True) + assert no_heuristic == with_heuristic + + def test_negative_cycle(self): + G = nx.cycle_graph(5, create_using=nx.DiGraph()) + G.add_edge(1, 2, weight=-7) + for i in range(5): + pytest.raises( + nx.NetworkXUnbounded, nx.single_source_bellman_ford_path, G, i + ) + pytest.raises( + nx.NetworkXUnbounded, nx.single_source_bellman_ford_path_length, G, i + ) + pytest.raises(nx.NetworkXUnbounded, nx.single_source_bellman_ford, G, i) + pytest.raises( + nx.NetworkXUnbounded, nx.bellman_ford_predecessor_and_distance, G, i + ) + pytest.raises(nx.NetworkXUnbounded, nx.goldberg_radzik, G, i) + G = nx.cycle_graph(5) # undirected Graph + G.add_edge(1, 2, weight=-3) + for i in range(5): + pytest.raises( + nx.NetworkXUnbounded, nx.single_source_bellman_ford_path, G, i + ) + pytest.raises( + nx.NetworkXUnbounded, nx.single_source_bellman_ford_path_length, G, i + ) + pytest.raises(nx.NetworkXUnbounded, nx.single_source_bellman_ford, G, i) + pytest.raises( + nx.NetworkXUnbounded, nx.bellman_ford_predecessor_and_distance, G, i + ) + pytest.raises(nx.NetworkXUnbounded, nx.goldberg_radzik, G, i) + G = nx.DiGraph([(1, 1, {"weight": -1})]) + pytest.raises(nx.NetworkXUnbounded, nx.single_source_bellman_ford_path, G, 1) + pytest.raises( + nx.NetworkXUnbounded, nx.single_source_bellman_ford_path_length, G, 1 + ) + pytest.raises(nx.NetworkXUnbounded, nx.single_source_bellman_ford, G, 1) + pytest.raises( + nx.NetworkXUnbounded, nx.bellman_ford_predecessor_and_distance, G, 1 + ) + pytest.raises(nx.NetworkXUnbounded, nx.goldberg_radzik, G, 1) + G = nx.MultiDiGraph([(1, 1, {"weight": -1})]) + pytest.raises(nx.NetworkXUnbounded, nx.single_source_bellman_ford_path, G, 1) + pytest.raises( + nx.NetworkXUnbounded, nx.single_source_bellman_ford_path_length, G, 1 + ) + pytest.raises(nx.NetworkXUnbounded, nx.single_source_bellman_ford, G, 1) + pytest.raises( + nx.NetworkXUnbounded, nx.bellman_ford_predecessor_and_distance, G, 1 + ) + pytest.raises(nx.NetworkXUnbounded, nx.goldberg_radzik, G, 1) + + def test_zero_cycle(self): + G = nx.cycle_graph(5, create_using=nx.DiGraph()) + G.add_edge(2, 3, weight=-4) + # check that zero cycle doesn't raise + nx.goldberg_radzik(G, 1) + nx.bellman_ford_predecessor_and_distance(G, 1) + + G.add_edge(2, 3, weight=-4.0001) + # check that negative cycle does raise + pytest.raises( + nx.NetworkXUnbounded, nx.bellman_ford_predecessor_and_distance, G, 1 + ) + pytest.raises(nx.NetworkXUnbounded, nx.goldberg_radzik, G, 1) + + def test_find_negative_cycle_longer_cycle(self): + G = nx.cycle_graph(5, create_using=nx.DiGraph()) + nx.add_cycle(G, [3, 5, 6, 7, 8, 9]) + G.add_edge(1, 2, weight=-30) + assert nx.find_negative_cycle(G, 1) == [0, 1, 2, 3, 4, 0] + assert nx.find_negative_cycle(G, 7) == [2, 3, 4, 0, 1, 2] + + def test_find_negative_cycle_no_cycle(self): + G = nx.path_graph(5, create_using=nx.DiGraph()) + pytest.raises(nx.NetworkXError, nx.find_negative_cycle, G, 3) + + def test_find_negative_cycle_single_edge(self): + G = nx.Graph() + G.add_edge(0, 1, weight=-1) + assert nx.find_negative_cycle(G, 1) == [1, 0, 1] + + def test_negative_weight(self): + G = nx.cycle_graph(5, create_using=nx.DiGraph()) + G.add_edge(1, 2, weight=-3) + assert nx.single_source_bellman_ford_path(G, 0) == { + 0: [0], + 1: [0, 1], + 2: [0, 1, 2], + 3: [0, 1, 2, 3], + 4: [0, 1, 2, 3, 4], + } + assert nx.single_source_bellman_ford_path_length(G, 0) == { + 0: 0, + 1: 1, + 2: -2, + 3: -1, + 4: 0, + } + assert nx.single_source_bellman_ford(G, 0) == ( + {0: 0, 1: 1, 2: -2, 3: -1, 4: 0}, + {0: [0], 1: [0, 1], 2: [0, 1, 2], 3: [0, 1, 2, 3], 4: [0, 1, 2, 3, 4]}, + ) + assert nx.bellman_ford_predecessor_and_distance(G, 0) == ( + {0: [], 1: [0], 2: [1], 3: [2], 4: [3]}, + {0: 0, 1: 1, 2: -2, 3: -1, 4: 0}, + ) + assert nx.goldberg_radzik(G, 0) == ( + {0: None, 1: 0, 2: 1, 3: 2, 4: 3}, + {0: 0, 1: 1, 2: -2, 3: -1, 4: 0}, + ) + + def test_not_connected(self): + G = nx.complete_graph(6) + G.add_edge(10, 11) + G.add_edge(10, 12) + assert nx.single_source_bellman_ford_path(G, 0) == { + 0: [0], + 1: [0, 1], + 2: [0, 2], + 3: [0, 3], + 4: [0, 4], + 5: [0, 5], + } + assert nx.single_source_bellman_ford_path_length(G, 0) == { + 0: 0, + 1: 1, + 2: 1, + 3: 1, + 4: 1, + 5: 1, + } + assert nx.single_source_bellman_ford(G, 0) == ( + {0: 0, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1}, + {0: [0], 1: [0, 1], 2: [0, 2], 3: [0, 3], 4: [0, 4], 5: [0, 5]}, + ) + assert nx.bellman_ford_predecessor_and_distance(G, 0) == ( + {0: [], 1: [0], 2: [0], 3: [0], 4: [0], 5: [0]}, + {0: 0, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1}, + ) + assert nx.goldberg_radzik(G, 0) == ( + {0: None, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0}, + {0: 0, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1}, + ) + + # not connected, with a component not containing the source that + # contains a negative cycle. + G = nx.complete_graph(6) + G.add_edges_from( + [ + ("A", "B", {"load": 3}), + ("B", "C", {"load": -10}), + ("C", "A", {"load": 2}), + ] + ) + assert nx.single_source_bellman_ford_path(G, 0, weight="load") == { + 0: [0], + 1: [0, 1], + 2: [0, 2], + 3: [0, 3], + 4: [0, 4], + 5: [0, 5], + } + assert nx.single_source_bellman_ford_path_length(G, 0, weight="load") == { + 0: 0, + 1: 1, + 2: 1, + 3: 1, + 4: 1, + 5: 1, + } + assert nx.single_source_bellman_ford(G, 0, weight="load") == ( + {0: 0, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1}, + {0: [0], 1: [0, 1], 2: [0, 2], 3: [0, 3], 4: [0, 4], 5: [0, 5]}, + ) + assert nx.bellman_ford_predecessor_and_distance(G, 0, weight="load") == ( + {0: [], 1: [0], 2: [0], 3: [0], 4: [0], 5: [0]}, + {0: 0, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1}, + ) + assert nx.goldberg_radzik(G, 0, weight="load") == ( + {0: None, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0}, + {0: 0, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1}, + ) + + def test_multigraph(self): + assert nx.bellman_ford_path(self.MXG, "s", "v") == ["s", "x", "u", "v"] + assert nx.bellman_ford_path_length(self.MXG, "s", "v") == 9 + assert nx.single_source_bellman_ford_path(self.MXG, "s")["v"] == [ + "s", + "x", + "u", + "v", + ] + assert nx.single_source_bellman_ford_path_length(self.MXG, "s")["v"] == 9 + D, P = nx.single_source_bellman_ford(self.MXG, "s", target="v") + assert D == 9 + assert P == ["s", "x", "u", "v"] + P, D = nx.bellman_ford_predecessor_and_distance(self.MXG, "s") + assert P["v"] == ["u"] + assert D["v"] == 9 + P, D = nx.goldberg_radzik(self.MXG, "s") + assert P["v"] == "u" + assert D["v"] == 9 + assert nx.bellman_ford_path(self.MXG4, 0, 2) == [0, 1, 2] + assert nx.bellman_ford_path_length(self.MXG4, 0, 2) == 4 + assert nx.single_source_bellman_ford_path(self.MXG4, 0)[2] == [0, 1, 2] + assert nx.single_source_bellman_ford_path_length(self.MXG4, 0)[2] == 4 + D, P = nx.single_source_bellman_ford(self.MXG4, 0, target=2) + assert D == 4 + assert P == [0, 1, 2] + P, D = nx.bellman_ford_predecessor_and_distance(self.MXG4, 0) + assert P[2] == [1] + assert D[2] == 4 + P, D = nx.goldberg_radzik(self.MXG4, 0) + assert P[2] == 1 + assert D[2] == 4 + + def test_others(self): + assert nx.bellman_ford_path(self.XG, "s", "v") == ["s", "x", "u", "v"] + assert nx.bellman_ford_path_length(self.XG, "s", "v") == 9 + assert nx.single_source_bellman_ford_path(self.XG, "s")["v"] == [ + "s", + "x", + "u", + "v", + ] + assert nx.single_source_bellman_ford_path_length(self.XG, "s")["v"] == 9 + D, P = nx.single_source_bellman_ford(self.XG, "s", target="v") + assert D == 9 + assert P == ["s", "x", "u", "v"] + (P, D) = nx.bellman_ford_predecessor_and_distance(self.XG, "s") + assert P["v"] == ["u"] + assert D["v"] == 9 + (P, D) = nx.goldberg_radzik(self.XG, "s") + assert P["v"] == "u" + assert D["v"] == 9 + + def test_path_graph(self): + G = nx.path_graph(4) + assert nx.single_source_bellman_ford_path(G, 0) == { + 0: [0], + 1: [0, 1], + 2: [0, 1, 2], + 3: [0, 1, 2, 3], + } + assert nx.single_source_bellman_ford_path_length(G, 0) == { + 0: 0, + 1: 1, + 2: 2, + 3: 3, + } + assert nx.single_source_bellman_ford(G, 0) == ( + {0: 0, 1: 1, 2: 2, 3: 3}, + {0: [0], 1: [0, 1], 2: [0, 1, 2], 3: [0, 1, 2, 3]}, + ) + assert nx.bellman_ford_predecessor_and_distance(G, 0) == ( + {0: [], 1: [0], 2: [1], 3: [2]}, + {0: 0, 1: 1, 2: 2, 3: 3}, + ) + assert nx.goldberg_radzik(G, 0) == ( + {0: None, 1: 0, 2: 1, 3: 2}, + {0: 0, 1: 1, 2: 2, 3: 3}, + ) + assert nx.single_source_bellman_ford_path(G, 3) == { + 0: [3, 2, 1, 0], + 1: [3, 2, 1], + 2: [3, 2], + 3: [3], + } + assert nx.single_source_bellman_ford_path_length(G, 3) == { + 0: 3, + 1: 2, + 2: 1, + 3: 0, + } + assert nx.single_source_bellman_ford(G, 3) == ( + {0: 3, 1: 2, 2: 1, 3: 0}, + {0: [3, 2, 1, 0], 1: [3, 2, 1], 2: [3, 2], 3: [3]}, + ) + assert nx.bellman_ford_predecessor_and_distance(G, 3) == ( + {0: [1], 1: [2], 2: [3], 3: []}, + {0: 3, 1: 2, 2: 1, 3: 0}, + ) + assert nx.goldberg_radzik(G, 3) == ( + {0: 1, 1: 2, 2: 3, 3: None}, + {0: 3, 1: 2, 2: 1, 3: 0}, + ) + + def test_4_cycle(self): + # 4-cycle + G = nx.Graph([(0, 1), (1, 2), (2, 3), (3, 0)]) + dist, path = nx.single_source_bellman_ford(G, 0) + assert dist == {0: 0, 1: 1, 2: 2, 3: 1} + assert path[0] == [0] + assert path[1] == [0, 1] + assert path[2] in [[0, 1, 2], [0, 3, 2]] + assert path[3] == [0, 3] + + pred, dist = nx.bellman_ford_predecessor_and_distance(G, 0) + assert pred[0] == [] + assert pred[1] == [0] + assert pred[2] in [[1, 3], [3, 1]] + assert pred[3] == [0] + assert dist == {0: 0, 1: 1, 2: 2, 3: 1} + + pred, dist = nx.goldberg_radzik(G, 0) + assert pred[0] is None + assert pred[1] == 0 + assert pred[2] in [1, 3] + assert pred[3] == 0 + assert dist == {0: 0, 1: 1, 2: 2, 3: 1} + + def test_negative_weight_bf_path(self): + G = nx.DiGraph() + G.add_nodes_from("abcd") + G.add_edge("a", "d", weight=0) + G.add_edge("a", "b", weight=1) + G.add_edge("b", "c", weight=-3) + G.add_edge("c", "d", weight=1) + + assert nx.bellman_ford_path(G, "a", "d") == ["a", "b", "c", "d"] + assert nx.bellman_ford_path_length(G, "a", "d") == -1 + + def test_zero_cycle_smoke(self): + D = nx.DiGraph() + D.add_weighted_edges_from([(0, 1, 1), (1, 2, 1), (2, 3, 1), (3, 1, -2)]) + + nx.bellman_ford_path(D, 1, 3) + nx.dijkstra_path(D, 1, 3) + nx.bidirectional_dijkstra(D, 1, 3) + # FIXME nx.goldberg_radzik(D, 1) + + +class TestJohnsonAlgorithm(WeightedTestBase): + def test_single_node_graph(self): + G = nx.DiGraph() + G.add_node(0) + assert nx.johnson(G) == {0: {0: [0]}} + + def test_negative_cycle(self): + G = nx.DiGraph() + G.add_weighted_edges_from( + [ + ("0", "3", 3), + ("0", "1", -5), + ("1", "0", -5), + ("0", "2", 2), + ("1", "2", 4), + ("2", "3", 1), + ] + ) + pytest.raises(nx.NetworkXUnbounded, nx.johnson, G) + G = nx.Graph() + G.add_weighted_edges_from( + [ + ("0", "3", 3), + ("0", "1", -5), + ("1", "0", -5), + ("0", "2", 2), + ("1", "2", 4), + ("2", "3", 1), + ] + ) + pytest.raises(nx.NetworkXUnbounded, nx.johnson, G) + + def test_negative_weights(self): + G = nx.DiGraph() + G.add_weighted_edges_from( + [("0", "3", 3), ("0", "1", -5), ("0", "2", 2), ("1", "2", 4), ("2", "3", 1)] + ) + paths = nx.johnson(G) + assert paths == { + "1": {"1": ["1"], "3": ["1", "2", "3"], "2": ["1", "2"]}, + "0": { + "1": ["0", "1"], + "0": ["0"], + "3": ["0", "1", "2", "3"], + "2": ["0", "1", "2"], + }, + "3": {"3": ["3"]}, + "2": {"3": ["2", "3"], "2": ["2"]}, + } + + def test_unweighted_graph(self): + G = nx.Graph() + G.add_edges_from([(1, 0), (2, 1)]) + H = G.copy() + nx.set_edge_attributes(H, values=1, name="weight") + assert nx.johnson(G) == nx.johnson(H) + + def test_partially_weighted_graph_with_negative_edges(self): + G = nx.DiGraph() + G.add_edges_from([(0, 1), (1, 2), (2, 0), (1, 0)]) + G[1][0]["weight"] = -2 + G[0][1]["weight"] = 3 + G[1][2]["weight"] = -4 + + H = G.copy() + H[2][0]["weight"] = 1 + + I = G.copy() + I[2][0]["weight"] = 8 + + assert nx.johnson(G) == nx.johnson(H) + assert nx.johnson(G) != nx.johnson(I) + + def test_graphs(self): + validate_path(self.XG, "s", "v", 9, nx.johnson(self.XG)["s"]["v"]) + validate_path(self.MXG, "s", "v", 9, nx.johnson(self.MXG)["s"]["v"]) + validate_path(self.XG2, 1, 3, 4, nx.johnson(self.XG2)[1][3]) + validate_path(self.XG3, 0, 3, 15, nx.johnson(self.XG3)[0][3]) + validate_path(self.XG4, 0, 2, 4, nx.johnson(self.XG4)[0][2]) + validate_path(self.MXG4, 0, 2, 4, nx.johnson(self.MXG4)[0][2]) diff --git a/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/unweighted.py b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/unweighted.py new file mode 100644 index 0000000000000000000000000000000000000000..3aeef85478f94d5cfff34f9b7aa2bb4595adef9d --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/unweighted.py @@ -0,0 +1,579 @@ +""" +Shortest path algorithms for unweighted graphs. +""" + +import warnings + +import networkx as nx + +__all__ = [ + "bidirectional_shortest_path", + "single_source_shortest_path", + "single_source_shortest_path_length", + "single_target_shortest_path", + "single_target_shortest_path_length", + "all_pairs_shortest_path", + "all_pairs_shortest_path_length", + "predecessor", +] + + +@nx._dispatchable +def single_source_shortest_path_length(G, source, cutoff=None): + """Compute the shortest path lengths from source to all reachable nodes. + + Parameters + ---------- + G : NetworkX graph + + source : node + Starting node for path + + cutoff : integer, optional + Depth to stop the search. Only paths of length <= cutoff are returned. + + Returns + ------- + lengths : dict + Dict keyed by node to shortest path length to source. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> length = nx.single_source_shortest_path_length(G, 0) + >>> length[4] + 4 + >>> for node in length: + ... print(f"{node}: {length[node]}") + 0: 0 + 1: 1 + 2: 2 + 3: 3 + 4: 4 + + See Also + -------- + shortest_path_length + """ + if source not in G: + raise nx.NodeNotFound(f"Source {source} is not in G") + if cutoff is None: + cutoff = float("inf") + nextlevel = [source] + return dict(_single_shortest_path_length(G._adj, nextlevel, cutoff)) + + +def _single_shortest_path_length(adj, firstlevel, cutoff): + """Yields (node, level) in a breadth first search + + Shortest Path Length helper function + Parameters + ---------- + adj : dict + Adjacency dict or view + firstlevel : list + starting nodes, e.g. [source] or [target] + cutoff : int or float + level at which we stop the process + """ + seen = set(firstlevel) + nextlevel = firstlevel + level = 0 + n = len(adj) + for v in nextlevel: + yield (v, level) + while nextlevel and cutoff > level: + level += 1 + thislevel = nextlevel + nextlevel = [] + for v in thislevel: + for w in adj[v]: + if w not in seen: + seen.add(w) + nextlevel.append(w) + yield (w, level) + if len(seen) == n: + return + + +@nx._dispatchable +def single_target_shortest_path_length(G, target, cutoff=None): + """Compute the shortest path lengths to target from all reachable nodes. + + Parameters + ---------- + G : NetworkX graph + + target : node + Target node for path + + cutoff : integer, optional + Depth to stop the search. Only paths of length <= cutoff are returned. + + Returns + ------- + lengths : iterator + (source, shortest path length) iterator + + Examples + -------- + >>> G = nx.path_graph(5, create_using=nx.DiGraph()) + >>> length = dict(nx.single_target_shortest_path_length(G, 4)) + >>> length[0] + 4 + >>> for node in range(5): + ... print(f"{node}: {length[node]}") + 0: 4 + 1: 3 + 2: 2 + 3: 1 + 4: 0 + + See Also + -------- + single_source_shortest_path_length, shortest_path_length + """ + if target not in G: + raise nx.NodeNotFound(f"Target {target} is not in G") + + warnings.warn( + ( + "\n\nsingle_target_shortest_path_length will return a dict instead of" + "\nan iterator in version 3.5" + ), + FutureWarning, + stacklevel=3, + ) + + if cutoff is None: + cutoff = float("inf") + # handle either directed or undirected + adj = G._pred if G.is_directed() else G._adj + nextlevel = [target] + # for version 3.3 we will return a dict like this: + # return dict(_single_shortest_path_length(adj, nextlevel, cutoff)) + return _single_shortest_path_length(adj, nextlevel, cutoff) + + +@nx._dispatchable +def all_pairs_shortest_path_length(G, cutoff=None): + """Computes the shortest path lengths between all nodes in `G`. + + Parameters + ---------- + G : NetworkX graph + + cutoff : integer, optional + Depth at which to stop the search. Only paths of length at most + `cutoff` are returned. + + Returns + ------- + lengths : iterator + (source, dictionary) iterator with dictionary keyed by target and + shortest path length as the key value. + + Notes + ----- + The iterator returned only has reachable node pairs. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> length = dict(nx.all_pairs_shortest_path_length(G)) + >>> for node in [0, 1, 2, 3, 4]: + ... print(f"1 - {node}: {length[1][node]}") + 1 - 0: 1 + 1 - 1: 0 + 1 - 2: 1 + 1 - 3: 2 + 1 - 4: 3 + >>> length[3][2] + 1 + >>> length[2][2] + 0 + + """ + length = single_source_shortest_path_length + # TODO This can be trivially parallelized. + for n in G: + yield (n, length(G, n, cutoff=cutoff)) + + +@nx._dispatchable +def bidirectional_shortest_path(G, source, target): + """Returns a list of nodes in a shortest path between source and target. + + Parameters + ---------- + G : NetworkX graph + + source : node label + starting node for path + + target : node label + ending node for path + + Returns + ------- + path: list + List of nodes in a path from source to target. + + Raises + ------ + NetworkXNoPath + If no path exists between source and target. + + Examples + -------- + >>> G = nx.Graph() + >>> nx.add_path(G, [0, 1, 2, 3, 0, 4, 5, 6, 7, 4]) + >>> nx.bidirectional_shortest_path(G, 2, 6) + [2, 1, 0, 4, 5, 6] + + See Also + -------- + shortest_path + + Notes + ----- + This algorithm is used by shortest_path(G, source, target). + """ + + if source not in G: + raise nx.NodeNotFound(f"Source {source} is not in G") + + if target not in G: + raise nx.NodeNotFound(f"Target {target} is not in G") + + # call helper to do the real work + results = _bidirectional_pred_succ(G, source, target) + pred, succ, w = results + + # build path from pred+w+succ + path = [] + # from source to w + while w is not None: + path.append(w) + w = pred[w] + path.reverse() + # from w to target + w = succ[path[-1]] + while w is not None: + path.append(w) + w = succ[w] + + return path + + +def _bidirectional_pred_succ(G, source, target): + """Bidirectional shortest path helper. + + Returns (pred, succ, w) where + pred is a dictionary of predecessors from w to the source, and + succ is a dictionary of successors from w to the target. + """ + # does BFS from both source and target and meets in the middle + if target == source: + return ({target: None}, {source: None}, source) + + # handle either directed or undirected + if G.is_directed(): + Gpred = G.pred + Gsucc = G.succ + else: + Gpred = G.adj + Gsucc = G.adj + + # predecessor and successors in search + pred = {source: None} + succ = {target: None} + + # initialize fringes, start with forward + forward_fringe = [source] + reverse_fringe = [target] + + while forward_fringe and reverse_fringe: + if len(forward_fringe) <= len(reverse_fringe): + this_level = forward_fringe + forward_fringe = [] + for v in this_level: + for w in Gsucc[v]: + if w not in pred: + forward_fringe.append(w) + pred[w] = v + if w in succ: # path found + return pred, succ, w + else: + this_level = reverse_fringe + reverse_fringe = [] + for v in this_level: + for w in Gpred[v]: + if w not in succ: + succ[w] = v + reverse_fringe.append(w) + if w in pred: # found path + return pred, succ, w + + raise nx.NetworkXNoPath(f"No path between {source} and {target}.") + + +@nx._dispatchable +def single_source_shortest_path(G, source, cutoff=None): + """Compute shortest path between source + and all other nodes reachable from source. + + Parameters + ---------- + G : NetworkX graph + + source : node label + Starting node for path + + cutoff : integer, optional + Depth to stop the search. Only paths of length <= cutoff are returned. + + Returns + ------- + paths : dictionary + Dictionary, keyed by target, of shortest paths. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> path = nx.single_source_shortest_path(G, 0) + >>> path[4] + [0, 1, 2, 3, 4] + + Notes + ----- + The shortest path is not necessarily unique. So there can be multiple + paths between the source and each target node, all of which have the + same 'shortest' length. For each target node, this function returns + only one of those paths. + + See Also + -------- + shortest_path + """ + if source not in G: + raise nx.NodeNotFound(f"Source {source} not in G") + + def join(p1, p2): + return p1 + p2 + + if cutoff is None: + cutoff = float("inf") + nextlevel = {source: 1} # list of nodes to check at next level + paths = {source: [source]} # paths dictionary (paths to key from source) + return dict(_single_shortest_path(G.adj, nextlevel, paths, cutoff, join)) + + +def _single_shortest_path(adj, firstlevel, paths, cutoff, join): + """Returns shortest paths + + Shortest Path helper function + Parameters + ---------- + adj : dict + Adjacency dict or view + firstlevel : dict + starting nodes, e.g. {source: 1} or {target: 1} + paths : dict + paths for starting nodes, e.g. {source: [source]} + cutoff : int or float + level at which we stop the process + join : function + function to construct a path from two partial paths. Requires two + list inputs `p1` and `p2`, and returns a list. Usually returns + `p1 + p2` (forward from source) or `p2 + p1` (backward from target) + """ + level = 0 # the current level + nextlevel = firstlevel + while nextlevel and cutoff > level: + thislevel = nextlevel + nextlevel = {} + for v in thislevel: + for w in adj[v]: + if w not in paths: + paths[w] = join(paths[v], [w]) + nextlevel[w] = 1 + level += 1 + return paths + + +@nx._dispatchable +def single_target_shortest_path(G, target, cutoff=None): + """Compute shortest path to target from all nodes that reach target. + + Parameters + ---------- + G : NetworkX graph + + target : node label + Target node for path + + cutoff : integer, optional + Depth to stop the search. Only paths of length <= cutoff are returned. + + Returns + ------- + paths : dictionary + Dictionary, keyed by target, of shortest paths. + + Examples + -------- + >>> G = nx.path_graph(5, create_using=nx.DiGraph()) + >>> path = nx.single_target_shortest_path(G, 4) + >>> path[0] + [0, 1, 2, 3, 4] + + Notes + ----- + The shortest path is not necessarily unique. So there can be multiple + paths between the source and each target node, all of which have the + same 'shortest' length. For each target node, this function returns + only one of those paths. + + See Also + -------- + shortest_path, single_source_shortest_path + """ + if target not in G: + raise nx.NodeNotFound(f"Target {target} not in G") + + def join(p1, p2): + return p2 + p1 + + # handle undirected graphs + adj = G.pred if G.is_directed() else G.adj + if cutoff is None: + cutoff = float("inf") + nextlevel = {target: 1} # list of nodes to check at next level + paths = {target: [target]} # paths dictionary (paths to key from source) + return dict(_single_shortest_path(adj, nextlevel, paths, cutoff, join)) + + +@nx._dispatchable +def all_pairs_shortest_path(G, cutoff=None): + """Compute shortest paths between all nodes. + + Parameters + ---------- + G : NetworkX graph + + cutoff : integer, optional + Depth at which to stop the search. Only paths of length at most + `cutoff` are returned. + + Returns + ------- + paths : iterator + Dictionary, keyed by source and target, of shortest paths. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> path = dict(nx.all_pairs_shortest_path(G)) + >>> print(path[0][4]) + [0, 1, 2, 3, 4] + + Notes + ----- + There may be multiple shortest paths with the same length between + two nodes. For each pair, this function returns only one of those paths. + + See Also + -------- + floyd_warshall + all_pairs_all_shortest_paths + + """ + # TODO This can be trivially parallelized. + for n in G: + yield (n, single_source_shortest_path(G, n, cutoff=cutoff)) + + +@nx._dispatchable +def predecessor(G, source, target=None, cutoff=None, return_seen=None): + """Returns dict of predecessors for the path from source to all nodes in G. + + Parameters + ---------- + G : NetworkX graph + + source : node label + Starting node for path + + target : node label, optional + Ending node for path. If provided only predecessors between + source and target are returned + + cutoff : integer, optional + Depth to stop the search. Only paths of length <= cutoff are returned. + + return_seen : bool, optional (default=None) + Whether to return a dictionary, keyed by node, of the level (number of + hops) to reach the node (as seen during breadth-first-search). + + Returns + ------- + pred : dictionary + Dictionary, keyed by node, of predecessors in the shortest path. + + + (pred, seen): tuple of dictionaries + If `return_seen` argument is set to `True`, then a tuple of dictionaries + is returned. The first element is the dictionary, keyed by node, of + predecessors in the shortest path. The second element is the dictionary, + keyed by node, of the level (number of hops) to reach the node (as seen + during breadth-first-search). + + Examples + -------- + >>> G = nx.path_graph(4) + >>> list(G) + [0, 1, 2, 3] + >>> nx.predecessor(G, 0) + {0: [], 1: [0], 2: [1], 3: [2]} + >>> nx.predecessor(G, 0, return_seen=True) + ({0: [], 1: [0], 2: [1], 3: [2]}, {0: 0, 1: 1, 2: 2, 3: 3}) + + + """ + if source not in G: + raise nx.NodeNotFound(f"Source {source} not in G") + + level = 0 # the current level + nextlevel = [source] # list of nodes to check at next level + seen = {source: level} # level (number of hops) when seen in BFS + pred = {source: []} # predecessor dictionary + while nextlevel: + level = level + 1 + thislevel = nextlevel + nextlevel = [] + for v in thislevel: + for w in G[v]: + if w not in seen: + pred[w] = [v] + seen[w] = level + nextlevel.append(w) + elif seen[w] == level: # add v to predecessor list if it + pred[w].append(v) # is at the correct level + if cutoff and cutoff <= level: + break + + if target is not None: + if return_seen: + if target not in pred: + return ([], -1) # No predecessor + return (pred[target], seen[target]) + else: + if target not in pred: + return [] # No predecessor + return pred[target] + else: + if return_seen: + return (pred, seen) + else: + return pred diff --git a/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/weighted.py b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/weighted.py new file mode 100644 index 0000000000000000000000000000000000000000..f8421d42b5d65b23dea79bb16cc8280940f83127 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/shortest_paths/weighted.py @@ -0,0 +1,2520 @@ +""" +Shortest path algorithms for weighted graphs. +""" + +from collections import deque +from heapq import heappop, heappush +from itertools import count + +import networkx as nx +from networkx.algorithms.shortest_paths.generic import _build_paths_from_predecessors + +__all__ = [ + "dijkstra_path", + "dijkstra_path_length", + "bidirectional_dijkstra", + "single_source_dijkstra", + "single_source_dijkstra_path", + "single_source_dijkstra_path_length", + "multi_source_dijkstra", + "multi_source_dijkstra_path", + "multi_source_dijkstra_path_length", + "all_pairs_dijkstra", + "all_pairs_dijkstra_path", + "all_pairs_dijkstra_path_length", + "dijkstra_predecessor_and_distance", + "bellman_ford_path", + "bellman_ford_path_length", + "single_source_bellman_ford", + "single_source_bellman_ford_path", + "single_source_bellman_ford_path_length", + "all_pairs_bellman_ford_path", + "all_pairs_bellman_ford_path_length", + "bellman_ford_predecessor_and_distance", + "negative_edge_cycle", + "find_negative_cycle", + "goldberg_radzik", + "johnson", +] + + +def _weight_function(G, weight): + """Returns a function that returns the weight of an edge. + + The returned function is specifically suitable for input to + functions :func:`_dijkstra` and :func:`_bellman_ford_relaxation`. + + Parameters + ---------- + G : NetworkX graph. + + weight : string or function + If it is callable, `weight` itself is returned. If it is a string, + it is assumed to be the name of the edge attribute that represents + the weight of an edge. In that case, a function is returned that + gets the edge weight according to the specified edge attribute. + + Returns + ------- + function + This function returns a callable that accepts exactly three inputs: + a node, an node adjacent to the first one, and the edge attribute + dictionary for the eedge joining those nodes. That function returns + a number representing the weight of an edge. + + If `G` is a multigraph, and `weight` is not callable, the + minimum edge weight over all parallel edges is returned. If any edge + does not have an attribute with key `weight`, it is assumed to + have weight one. + + """ + if callable(weight): + return weight + # If the weight keyword argument is not callable, we assume it is a + # string representing the edge attribute containing the weight of + # the edge. + if G.is_multigraph(): + return lambda u, v, d: min(attr.get(weight, 1) for attr in d.values()) + return lambda u, v, data: data.get(weight, 1) + + +@nx._dispatchable(edge_attrs="weight") +def dijkstra_path(G, source, target, weight="weight"): + """Returns the shortest weighted path from source to target in G. + + Uses Dijkstra's Method to compute the shortest weighted path + between two nodes in a graph. + + Parameters + ---------- + G : NetworkX graph + + source : node + Starting node + + target : node + Ending node + + weight : string or function + If this is a string, then edge weights will be accessed via the + edge attribute with this key (that is, the weight of the edge + joining `u` to `v` will be ``G.edges[u, v][weight]``). If no + such edge attribute exists, the weight of the edge is assumed to + be one. + + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly three + positional arguments: the two endpoints of an edge and the + dictionary of edge attributes for that edge. The function must + return a number or None to indicate a hidden edge. + + Returns + ------- + path : list + List of nodes in a shortest path. + + Raises + ------ + NodeNotFound + If `source` is not in `G`. + + NetworkXNoPath + If no path exists between source and target. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> print(nx.dijkstra_path(G, 0, 4)) + [0, 1, 2, 3, 4] + + Find edges of shortest path in Multigraph + + >>> G = nx.MultiDiGraph() + >>> G.add_weighted_edges_from([(1, 2, 0.75), (1, 2, 0.5), (2, 3, 0.5), (1, 3, 1.5)]) + >>> nodes = nx.dijkstra_path(G, 1, 3) + >>> edges = nx.utils.pairwise(nodes) + >>> list( + ... (u, v, min(G[u][v], key=lambda k: G[u][v][k].get("weight", 1))) + ... for u, v in edges + ... ) + [(1, 2, 1), (2, 3, 0)] + + Notes + ----- + Edge weight attributes must be numerical. + Distances are calculated as sums of weighted edges traversed. + + The weight function can be used to hide edges by returning None. + So ``weight = lambda u, v, d: 1 if d['color']=="red" else None`` + will find the shortest red path. + + The weight function can be used to include node weights. + + >>> def func(u, v, d): + ... node_u_wt = G.nodes[u].get("node_weight", 1) + ... node_v_wt = G.nodes[v].get("node_weight", 1) + ... edge_wt = d.get("weight", 1) + ... return node_u_wt / 2 + node_v_wt / 2 + edge_wt + + In this example we take the average of start and end node + weights of an edge and add it to the weight of the edge. + + The function :func:`single_source_dijkstra` computes both + path and length-of-path if you need both, use that. + + See Also + -------- + bidirectional_dijkstra + bellman_ford_path + single_source_dijkstra + """ + (length, path) = single_source_dijkstra(G, source, target=target, weight=weight) + return path + + +@nx._dispatchable(edge_attrs="weight") +def dijkstra_path_length(G, source, target, weight="weight"): + """Returns the shortest weighted path length in G from source to target. + + Uses Dijkstra's Method to compute the shortest weighted path length + between two nodes in a graph. + + Parameters + ---------- + G : NetworkX graph + + source : node label + starting node for path + + target : node label + ending node for path + + weight : string or function + If this is a string, then edge weights will be accessed via the + edge attribute with this key (that is, the weight of the edge + joining `u` to `v` will be ``G.edges[u, v][weight]``). If no + such edge attribute exists, the weight of the edge is assumed to + be one. + + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly three + positional arguments: the two endpoints of an edge and the + dictionary of edge attributes for that edge. The function must + return a number or None to indicate a hidden edge. + + Returns + ------- + length : number + Shortest path length. + + Raises + ------ + NodeNotFound + If `source` is not in `G`. + + NetworkXNoPath + If no path exists between source and target. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> nx.dijkstra_path_length(G, 0, 4) + 4 + + Notes + ----- + Edge weight attributes must be numerical. + Distances are calculated as sums of weighted edges traversed. + + The weight function can be used to hide edges by returning None. + So ``weight = lambda u, v, d: 1 if d['color']=="red" else None`` + will find the shortest red path. + + The function :func:`single_source_dijkstra` computes both + path and length-of-path if you need both, use that. + + See Also + -------- + bidirectional_dijkstra + bellman_ford_path_length + single_source_dijkstra + + """ + if source not in G: + raise nx.NodeNotFound(f"Node {source} not found in graph") + if source == target: + return 0 + weight = _weight_function(G, weight) + length = _dijkstra(G, source, weight, target=target) + try: + return length[target] + except KeyError as err: + raise nx.NetworkXNoPath(f"Node {target} not reachable from {source}") from err + + +@nx._dispatchable(edge_attrs="weight") +def single_source_dijkstra_path(G, source, cutoff=None, weight="weight"): + """Find shortest weighted paths in G from a source node. + + Compute shortest path between source and all other reachable + nodes for a weighted graph. + + Parameters + ---------- + G : NetworkX graph + + source : node + Starting node for path. + + cutoff : integer or float, optional + Length (sum of edge weights) at which the search is stopped. + If cutoff is provided, only return paths with summed weight <= cutoff. + + weight : string or function + If this is a string, then edge weights will be accessed via the + edge attribute with this key (that is, the weight of the edge + joining `u` to `v` will be ``G.edges[u, v][weight]``). If no + such edge attribute exists, the weight of the edge is assumed to + be one. + + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly three + positional arguments: the two endpoints of an edge and the + dictionary of edge attributes for that edge. The function must + return a number or None to indicate a hidden edge. + + Returns + ------- + paths : dictionary + Dictionary of shortest path lengths keyed by target. + + Raises + ------ + NodeNotFound + If `source` is not in `G`. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> path = nx.single_source_dijkstra_path(G, 0) + >>> path[4] + [0, 1, 2, 3, 4] + + Notes + ----- + Edge weight attributes must be numerical. + Distances are calculated as sums of weighted edges traversed. + + The weight function can be used to hide edges by returning None. + So ``weight = lambda u, v, d: 1 if d['color']=="red" else None`` + will find the shortest red path. + + See Also + -------- + single_source_dijkstra, single_source_bellman_ford + + """ + return multi_source_dijkstra_path(G, {source}, cutoff=cutoff, weight=weight) + + +@nx._dispatchable(edge_attrs="weight") +def single_source_dijkstra_path_length(G, source, cutoff=None, weight="weight"): + """Find shortest weighted path lengths in G from a source node. + + Compute the shortest path length between source and all other + reachable nodes for a weighted graph. + + Parameters + ---------- + G : NetworkX graph + + source : node label + Starting node for path + + cutoff : integer or float, optional + Length (sum of edge weights) at which the search is stopped. + If cutoff is provided, only return paths with summed weight <= cutoff. + + weight : string or function + If this is a string, then edge weights will be accessed via the + edge attribute with this key (that is, the weight of the edge + joining `u` to `v` will be ``G.edges[u, v][weight]``). If no + such edge attribute exists, the weight of the edge is assumed to + be one. + + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly three + positional arguments: the two endpoints of an edge and the + dictionary of edge attributes for that edge. The function must + return a number or None to indicate a hidden edge. + + Returns + ------- + length : dict + Dict keyed by node to shortest path length from source. + + Raises + ------ + NodeNotFound + If `source` is not in `G`. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> length = nx.single_source_dijkstra_path_length(G, 0) + >>> length[4] + 4 + >>> for node in [0, 1, 2, 3, 4]: + ... print(f"{node}: {length[node]}") + 0: 0 + 1: 1 + 2: 2 + 3: 3 + 4: 4 + + Notes + ----- + Edge weight attributes must be numerical. + Distances are calculated as sums of weighted edges traversed. + + The weight function can be used to hide edges by returning None. + So ``weight = lambda u, v, d: 1 if d['color']=="red" else None`` + will find the shortest red path. + + See Also + -------- + single_source_dijkstra, single_source_bellman_ford_path_length + + """ + return multi_source_dijkstra_path_length(G, {source}, cutoff=cutoff, weight=weight) + + +@nx._dispatchable(edge_attrs="weight") +def single_source_dijkstra(G, source, target=None, cutoff=None, weight="weight"): + """Find shortest weighted paths and lengths from a source node. + + Compute the shortest path length between source and all other + reachable nodes for a weighted graph. + + Uses Dijkstra's algorithm to compute shortest paths and lengths + between a source and all other reachable nodes in a weighted graph. + + Parameters + ---------- + G : NetworkX graph + + source : node label + Starting node for path + + target : node label, optional + Ending node for path + + cutoff : integer or float, optional + Length (sum of edge weights) at which the search is stopped. + If cutoff is provided, only return paths with summed weight <= cutoff. + + + weight : string or function + If this is a string, then edge weights will be accessed via the + edge attribute with this key (that is, the weight of the edge + joining `u` to `v` will be ``G.edges[u, v][weight]``). If no + such edge attribute exists, the weight of the edge is assumed to + be one. + + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly three + positional arguments: the two endpoints of an edge and the + dictionary of edge attributes for that edge. The function must + return a number or None to indicate a hidden edge. + + Returns + ------- + distance, path : pair of dictionaries, or numeric and list. + If target is None, paths and lengths to all nodes are computed. + The return value is a tuple of two dictionaries keyed by target nodes. + The first dictionary stores distance to each target node. + The second stores the path to each target node. + If target is not None, returns a tuple (distance, path), where + distance is the distance from source to target and path is a list + representing the path from source to target. + + Raises + ------ + NodeNotFound + If `source` is not in `G`. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> length, path = nx.single_source_dijkstra(G, 0) + >>> length[4] + 4 + >>> for node in [0, 1, 2, 3, 4]: + ... print(f"{node}: {length[node]}") + 0: 0 + 1: 1 + 2: 2 + 3: 3 + 4: 4 + >>> path[4] + [0, 1, 2, 3, 4] + >>> length, path = nx.single_source_dijkstra(G, 0, 1) + >>> length + 1 + >>> path + [0, 1] + + Notes + ----- + Edge weight attributes must be numerical. + Distances are calculated as sums of weighted edges traversed. + + The weight function can be used to hide edges by returning None. + So ``weight = lambda u, v, d: 1 if d['color']=="red" else None`` + will find the shortest red path. + + Based on the Python cookbook recipe (119466) at + https://code.activestate.com/recipes/119466/ + + This algorithm is not guaranteed to work if edge weights + are negative or are floating point numbers + (overflows and roundoff errors can cause problems). + + See Also + -------- + single_source_dijkstra_path + single_source_dijkstra_path_length + single_source_bellman_ford + """ + return multi_source_dijkstra( + G, {source}, cutoff=cutoff, target=target, weight=weight + ) + + +@nx._dispatchable(edge_attrs="weight") +def multi_source_dijkstra_path(G, sources, cutoff=None, weight="weight"): + """Find shortest weighted paths in G from a given set of source + nodes. + + Compute shortest path between any of the source nodes and all other + reachable nodes for a weighted graph. + + Parameters + ---------- + G : NetworkX graph + + sources : non-empty set of nodes + Starting nodes for paths. If this is just a set containing a + single node, then all paths computed by this function will start + from that node. If there are two or more nodes in the set, the + computed paths may begin from any one of the start nodes. + + cutoff : integer or float, optional + Length (sum of edge weights) at which the search is stopped. + If cutoff is provided, only return paths with summed weight <= cutoff. + + weight : string or function + If this is a string, then edge weights will be accessed via the + edge attribute with this key (that is, the weight of the edge + joining `u` to `v` will be ``G.edges[u, v][weight]``). If no + such edge attribute exists, the weight of the edge is assumed to + be one. + + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly three + positional arguments: the two endpoints of an edge and the + dictionary of edge attributes for that edge. The function must + return a number or None to indicate a hidden edge. + + Returns + ------- + paths : dictionary + Dictionary of shortest paths keyed by target. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> path = nx.multi_source_dijkstra_path(G, {0, 4}) + >>> path[1] + [0, 1] + >>> path[3] + [4, 3] + + Notes + ----- + Edge weight attributes must be numerical. + Distances are calculated as sums of weighted edges traversed. + + The weight function can be used to hide edges by returning None. + So ``weight = lambda u, v, d: 1 if d['color']=="red" else None`` + will find the shortest red path. + + Raises + ------ + ValueError + If `sources` is empty. + NodeNotFound + If any of `sources` is not in `G`. + + See Also + -------- + multi_source_dijkstra, multi_source_bellman_ford + + """ + length, path = multi_source_dijkstra(G, sources, cutoff=cutoff, weight=weight) + return path + + +@nx._dispatchable(edge_attrs="weight") +def multi_source_dijkstra_path_length(G, sources, cutoff=None, weight="weight"): + """Find shortest weighted path lengths in G from a given set of + source nodes. + + Compute the shortest path length between any of the source nodes and + all other reachable nodes for a weighted graph. + + Parameters + ---------- + G : NetworkX graph + + sources : non-empty set of nodes + Starting nodes for paths. If this is just a set containing a + single node, then all paths computed by this function will start + from that node. If there are two or more nodes in the set, the + computed paths may begin from any one of the start nodes. + + cutoff : integer or float, optional + Length (sum of edge weights) at which the search is stopped. + If cutoff is provided, only return paths with summed weight <= cutoff. + + weight : string or function + If this is a string, then edge weights will be accessed via the + edge attribute with this key (that is, the weight of the edge + joining `u` to `v` will be ``G.edges[u, v][weight]``). If no + such edge attribute exists, the weight of the edge is assumed to + be one. + + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly three + positional arguments: the two endpoints of an edge and the + dictionary of edge attributes for that edge. The function must + return a number or None to indicate a hidden edge. + + Returns + ------- + length : dict + Dict keyed by node to shortest path length to nearest source. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> length = nx.multi_source_dijkstra_path_length(G, {0, 4}) + >>> for node in [0, 1, 2, 3, 4]: + ... print(f"{node}: {length[node]}") + 0: 0 + 1: 1 + 2: 2 + 3: 1 + 4: 0 + + Notes + ----- + Edge weight attributes must be numerical. + Distances are calculated as sums of weighted edges traversed. + + The weight function can be used to hide edges by returning None. + So ``weight = lambda u, v, d: 1 if d['color']=="red" else None`` + will find the shortest red path. + + Raises + ------ + ValueError + If `sources` is empty. + NodeNotFound + If any of `sources` is not in `G`. + + See Also + -------- + multi_source_dijkstra + + """ + if not sources: + raise ValueError("sources must not be empty") + for s in sources: + if s not in G: + raise nx.NodeNotFound(f"Node {s} not found in graph") + weight = _weight_function(G, weight) + return _dijkstra_multisource(G, sources, weight, cutoff=cutoff) + + +@nx._dispatchable(edge_attrs="weight") +def multi_source_dijkstra(G, sources, target=None, cutoff=None, weight="weight"): + """Find shortest weighted paths and lengths from a given set of + source nodes. + + Uses Dijkstra's algorithm to compute the shortest paths and lengths + between one of the source nodes and the given `target`, or all other + reachable nodes if not specified, for a weighted graph. + + Parameters + ---------- + G : NetworkX graph + + sources : non-empty set of nodes + Starting nodes for paths. If this is just a set containing a + single node, then all paths computed by this function will start + from that node. If there are two or more nodes in the set, the + computed paths may begin from any one of the start nodes. + + target : node label, optional + Ending node for path + + cutoff : integer or float, optional + Length (sum of edge weights) at which the search is stopped. + If cutoff is provided, only return paths with summed weight <= cutoff. + + weight : string or function + If this is a string, then edge weights will be accessed via the + edge attribute with this key (that is, the weight of the edge + joining `u` to `v` will be ``G.edges[u, v][weight]``). If no + such edge attribute exists, the weight of the edge is assumed to + be one. + + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly three + positional arguments: the two endpoints of an edge and the + dictionary of edge attributes for that edge. The function must + return a number or None to indicate a hidden edge. + + Returns + ------- + distance, path : pair of dictionaries, or numeric and list + If target is None, returns a tuple of two dictionaries keyed by node. + The first dictionary stores distance from one of the source nodes. + The second stores the path from one of the sources to that node. + If target is not None, returns a tuple of (distance, path) where + distance is the distance from source to target and path is a list + representing the path from source to target. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> length, path = nx.multi_source_dijkstra(G, {0, 4}) + >>> for node in [0, 1, 2, 3, 4]: + ... print(f"{node}: {length[node]}") + 0: 0 + 1: 1 + 2: 2 + 3: 1 + 4: 0 + >>> path[1] + [0, 1] + >>> path[3] + [4, 3] + + >>> length, path = nx.multi_source_dijkstra(G, {0, 4}, 1) + >>> length + 1 + >>> path + [0, 1] + + Notes + ----- + Edge weight attributes must be numerical. + Distances are calculated as sums of weighted edges traversed. + + The weight function can be used to hide edges by returning None. + So ``weight = lambda u, v, d: 1 if d['color']=="red" else None`` + will find the shortest red path. + + Based on the Python cookbook recipe (119466) at + https://code.activestate.com/recipes/119466/ + + This algorithm is not guaranteed to work if edge weights + are negative or are floating point numbers + (overflows and roundoff errors can cause problems). + + Raises + ------ + ValueError + If `sources` is empty. + NodeNotFound + If any of `sources` is not in `G`. + + See Also + -------- + multi_source_dijkstra_path + multi_source_dijkstra_path_length + + """ + if not sources: + raise ValueError("sources must not be empty") + for s in sources: + if s not in G: + raise nx.NodeNotFound(f"Node {s} not found in graph") + if target in sources: + return (0, [target]) + weight = _weight_function(G, weight) + paths = {source: [source] for source in sources} # dictionary of paths + dist = _dijkstra_multisource( + G, sources, weight, paths=paths, cutoff=cutoff, target=target + ) + if target is None: + return (dist, paths) + try: + return (dist[target], paths[target]) + except KeyError as err: + raise nx.NetworkXNoPath(f"No path to {target}.") from err + + +def _dijkstra(G, source, weight, pred=None, paths=None, cutoff=None, target=None): + """Uses Dijkstra's algorithm to find shortest weighted paths from a + single source. + + This is a convenience function for :func:`_dijkstra_multisource` + with all the arguments the same, except the keyword argument + `sources` set to ``[source]``. + + """ + return _dijkstra_multisource( + G, [source], weight, pred=pred, paths=paths, cutoff=cutoff, target=target + ) + + +def _dijkstra_multisource( + G, sources, weight, pred=None, paths=None, cutoff=None, target=None +): + """Uses Dijkstra's algorithm to find shortest weighted paths + + Parameters + ---------- + G : NetworkX graph + + sources : non-empty iterable of nodes + Starting nodes for paths. If this is just an iterable containing + a single node, then all paths computed by this function will + start from that node. If there are two or more nodes in this + iterable, the computed paths may begin from any one of the start + nodes. + + weight: function + Function with (u, v, data) input that returns that edge's weight + or None to indicate a hidden edge + + pred: dict of lists, optional(default=None) + dict to store a list of predecessors keyed by that node + If None, predecessors are not stored. + + paths: dict, optional (default=None) + dict to store the path list from source to each node, keyed by node. + If None, paths are not stored. + + target : node label, optional + Ending node for path. Search is halted when target is found. + + cutoff : integer or float, optional + Length (sum of edge weights) at which the search is stopped. + If cutoff is provided, only return paths with summed weight <= cutoff. + + Returns + ------- + distance : dictionary + A mapping from node to shortest distance to that node from one + of the source nodes. + + Raises + ------ + NodeNotFound + If any of `sources` is not in `G`. + + Notes + ----- + The optional predecessor and path dictionaries can be accessed by + the caller through the original pred and paths objects passed + as arguments. No need to explicitly return pred or paths. + + """ + G_succ = G._adj # For speed-up (and works for both directed and undirected graphs) + + push = heappush + pop = heappop + dist = {} # dictionary of final distances + seen = {} + # fringe is heapq with 3-tuples (distance,c,node) + # use the count c to avoid comparing nodes (may not be able to) + c = count() + fringe = [] + for source in sources: + seen[source] = 0 + push(fringe, (0, next(c), source)) + while fringe: + (d, _, v) = pop(fringe) + if v in dist: + continue # already searched this node. + dist[v] = d + if v == target: + break + for u, e in G_succ[v].items(): + cost = weight(v, u, e) + if cost is None: + continue + vu_dist = dist[v] + cost + if cutoff is not None: + if vu_dist > cutoff: + continue + if u in dist: + u_dist = dist[u] + if vu_dist < u_dist: + raise ValueError("Contradictory paths found:", "negative weights?") + elif pred is not None and vu_dist == u_dist: + pred[u].append(v) + elif u not in seen or vu_dist < seen[u]: + seen[u] = vu_dist + push(fringe, (vu_dist, next(c), u)) + if paths is not None: + paths[u] = paths[v] + [u] + if pred is not None: + pred[u] = [v] + elif vu_dist == seen[u]: + if pred is not None: + pred[u].append(v) + + # The optional predecessor and path dictionaries can be accessed + # by the caller via the pred and paths objects passed as arguments. + return dist + + +@nx._dispatchable(edge_attrs="weight") +def dijkstra_predecessor_and_distance(G, source, cutoff=None, weight="weight"): + """Compute weighted shortest path length and predecessors. + + Uses Dijkstra's Method to obtain the shortest weighted paths + and return dictionaries of predecessors for each node and + distance for each node from the `source`. + + Parameters + ---------- + G : NetworkX graph + + source : node label + Starting node for path + + cutoff : integer or float, optional + Length (sum of edge weights) at which the search is stopped. + If cutoff is provided, only return paths with summed weight <= cutoff. + + weight : string or function + If this is a string, then edge weights will be accessed via the + edge attribute with this key (that is, the weight of the edge + joining `u` to `v` will be ``G.edges[u, v][weight]``). If no + such edge attribute exists, the weight of the edge is assumed to + be one. + + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly three + positional arguments: the two endpoints of an edge and the + dictionary of edge attributes for that edge. The function must + return a number or None to indicate a hidden edge. + + Returns + ------- + pred, distance : dictionaries + Returns two dictionaries representing a list of predecessors + of a node and the distance to each node. + + Raises + ------ + NodeNotFound + If `source` is not in `G`. + + Notes + ----- + Edge weight attributes must be numerical. + Distances are calculated as sums of weighted edges traversed. + + The list of predecessors contains more than one element only when + there are more than one shortest paths to the key node. + + Examples + -------- + >>> G = nx.path_graph(5, create_using=nx.DiGraph()) + >>> pred, dist = nx.dijkstra_predecessor_and_distance(G, 0) + >>> sorted(pred.items()) + [(0, []), (1, [0]), (2, [1]), (3, [2]), (4, [3])] + >>> sorted(dist.items()) + [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)] + + >>> pred, dist = nx.dijkstra_predecessor_and_distance(G, 0, 1) + >>> sorted(pred.items()) + [(0, []), (1, [0])] + >>> sorted(dist.items()) + [(0, 0), (1, 1)] + """ + if source not in G: + raise nx.NodeNotFound(f"Node {source} is not found in the graph") + weight = _weight_function(G, weight) + pred = {source: []} # dictionary of predecessors + return (pred, _dijkstra(G, source, weight, pred=pred, cutoff=cutoff)) + + +@nx._dispatchable(edge_attrs="weight") +def all_pairs_dijkstra(G, cutoff=None, weight="weight"): + """Find shortest weighted paths and lengths between all nodes. + + Parameters + ---------- + G : NetworkX graph + + cutoff : integer or float, optional + Length (sum of edge weights) at which the search is stopped. + If cutoff is provided, only return paths with summed weight <= cutoff. + + weight : string or function + If this is a string, then edge weights will be accessed via the + edge attribute with this key (that is, the weight of the edge + joining `u` to `v` will be ``G.edge[u][v][weight]``). If no + such edge attribute exists, the weight of the edge is assumed to + be one. + + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly three + positional arguments: the two endpoints of an edge and the + dictionary of edge attributes for that edge. The function must + return a number or None to indicate a hidden edge. + + Yields + ------ + (node, (distance, path)) : (node obj, (dict, dict)) + Each source node has two associated dicts. The first holds distance + keyed by target and the second holds paths keyed by target. + (See single_source_dijkstra for the source/target node terminology.) + If desired you can apply `dict()` to this function to create a dict + keyed by source node to the two dicts. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> len_path = dict(nx.all_pairs_dijkstra(G)) + >>> len_path[3][0][1] + 2 + >>> for node in [0, 1, 2, 3, 4]: + ... print(f"3 - {node}: {len_path[3][0][node]}") + 3 - 0: 3 + 3 - 1: 2 + 3 - 2: 1 + 3 - 3: 0 + 3 - 4: 1 + >>> len_path[3][1][1] + [3, 2, 1] + >>> for n, (dist, path) in nx.all_pairs_dijkstra(G): + ... print(path[1]) + [0, 1] + [1] + [2, 1] + [3, 2, 1] + [4, 3, 2, 1] + + Notes + ----- + Edge weight attributes must be numerical. + Distances are calculated as sums of weighted edges traversed. + + The yielded dicts only have keys for reachable nodes. + """ + for n in G: + dist, path = single_source_dijkstra(G, n, cutoff=cutoff, weight=weight) + yield (n, (dist, path)) + + +@nx._dispatchable(edge_attrs="weight") +def all_pairs_dijkstra_path_length(G, cutoff=None, weight="weight"): + """Compute shortest path lengths between all nodes in a weighted graph. + + Parameters + ---------- + G : NetworkX graph + + cutoff : integer or float, optional + Length (sum of edge weights) at which the search is stopped. + If cutoff is provided, only return paths with summed weight <= cutoff. + + weight : string or function + If this is a string, then edge weights will be accessed via the + edge attribute with this key (that is, the weight of the edge + joining `u` to `v` will be ``G.edges[u, v][weight]``). If no + such edge attribute exists, the weight of the edge is assumed to + be one. + + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly three + positional arguments: the two endpoints of an edge and the + dictionary of edge attributes for that edge. The function must + return a number or None to indicate a hidden edge. + + Returns + ------- + distance : iterator + (source, dictionary) iterator with dictionary keyed by target and + shortest path length as the key value. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> length = dict(nx.all_pairs_dijkstra_path_length(G)) + >>> for node in [0, 1, 2, 3, 4]: + ... print(f"1 - {node}: {length[1][node]}") + 1 - 0: 1 + 1 - 1: 0 + 1 - 2: 1 + 1 - 3: 2 + 1 - 4: 3 + >>> length[3][2] + 1 + >>> length[2][2] + 0 + + Notes + ----- + Edge weight attributes must be numerical. + Distances are calculated as sums of weighted edges traversed. + + The dictionary returned only has keys for reachable node pairs. + """ + length = single_source_dijkstra_path_length + for n in G: + yield (n, length(G, n, cutoff=cutoff, weight=weight)) + + +@nx._dispatchable(edge_attrs="weight") +def all_pairs_dijkstra_path(G, cutoff=None, weight="weight"): + """Compute shortest paths between all nodes in a weighted graph. + + Parameters + ---------- + G : NetworkX graph + + cutoff : integer or float, optional + Length (sum of edge weights) at which the search is stopped. + If cutoff is provided, only return paths with summed weight <= cutoff. + + weight : string or function + If this is a string, then edge weights will be accessed via the + edge attribute with this key (that is, the weight of the edge + joining `u` to `v` will be ``G.edges[u, v][weight]``). If no + such edge attribute exists, the weight of the edge is assumed to + be one. + + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly three + positional arguments: the two endpoints of an edge and the + dictionary of edge attributes for that edge. The function must + return a number or None to indicate a hidden edge. + + Returns + ------- + paths : iterator + (source, dictionary) iterator with dictionary keyed by target and + shortest path as the key value. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> path = dict(nx.all_pairs_dijkstra_path(G)) + >>> path[0][4] + [0, 1, 2, 3, 4] + + Notes + ----- + Edge weight attributes must be numerical. + Distances are calculated as sums of weighted edges traversed. + + See Also + -------- + floyd_warshall, all_pairs_bellman_ford_path + + """ + path = single_source_dijkstra_path + # TODO This can be trivially parallelized. + for n in G: + yield (n, path(G, n, cutoff=cutoff, weight=weight)) + + +@nx._dispatchable(edge_attrs="weight") +def bellman_ford_predecessor_and_distance( + G, source, target=None, weight="weight", heuristic=False +): + """Compute shortest path lengths and predecessors on shortest paths + in weighted graphs. + + The algorithm has a running time of $O(mn)$ where $n$ is the number of + nodes and $m$ is the number of edges. It is slower than Dijkstra but + can handle negative edge weights. + + If a negative cycle is detected, you can use :func:`find_negative_cycle` + to return the cycle and examine it. Shortest paths are not defined when + a negative cycle exists because once reached, the path can cycle forever + to build up arbitrarily low weights. + + Parameters + ---------- + G : NetworkX graph + The algorithm works for all types of graphs, including directed + graphs and multigraphs. + + source: node label + Starting node for path + + target : node label, optional + Ending node for path + + weight : string or function + If this is a string, then edge weights will be accessed via the + edge attribute with this key (that is, the weight of the edge + joining `u` to `v` will be ``G.edges[u, v][weight]``). If no + such edge attribute exists, the weight of the edge is assumed to + be one. + + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly three + positional arguments: the two endpoints of an edge and the + dictionary of edge attributes for that edge. The function must + return a number. + + heuristic : bool + Determines whether to use a heuristic to early detect negative + cycles at a hopefully negligible cost. + + Returns + ------- + pred, dist : dictionaries + Returns two dictionaries keyed by node to predecessor in the + path and to the distance from the source respectively. + + Raises + ------ + NodeNotFound + If `source` is not in `G`. + + NetworkXUnbounded + If the (di)graph contains a negative (di)cycle, the + algorithm raises an exception to indicate the presence of the + negative (di)cycle. Note: any negative weight edge in an + undirected graph is a negative cycle. + + Examples + -------- + >>> G = nx.path_graph(5, create_using=nx.DiGraph()) + >>> pred, dist = nx.bellman_ford_predecessor_and_distance(G, 0) + >>> sorted(pred.items()) + [(0, []), (1, [0]), (2, [1]), (3, [2]), (4, [3])] + >>> sorted(dist.items()) + [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)] + + >>> pred, dist = nx.bellman_ford_predecessor_and_distance(G, 0, 1) + >>> sorted(pred.items()) + [(0, []), (1, [0]), (2, [1]), (3, [2]), (4, [3])] + >>> sorted(dist.items()) + [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)] + + >>> G = nx.cycle_graph(5, create_using=nx.DiGraph()) + >>> G[1][2]["weight"] = -7 + >>> nx.bellman_ford_predecessor_and_distance(G, 0) + Traceback (most recent call last): + ... + networkx.exception.NetworkXUnbounded: Negative cycle detected. + + See Also + -------- + find_negative_cycle + + Notes + ----- + Edge weight attributes must be numerical. + Distances are calculated as sums of weighted edges traversed. + + The dictionaries returned only have keys for nodes reachable from + the source. + + In the case where the (di)graph is not connected, if a component + not containing the source contains a negative (di)cycle, it + will not be detected. + + In NetworkX v2.1 and prior, the source node had predecessor `[None]`. + In NetworkX v2.2 this changed to the source node having predecessor `[]` + """ + if source not in G: + raise nx.NodeNotFound(f"Node {source} is not found in the graph") + weight = _weight_function(G, weight) + if G.is_multigraph(): + if any( + weight(u, v, {k: d}) < 0 + for u, v, k, d in nx.selfloop_edges(G, keys=True, data=True) + ): + raise nx.NetworkXUnbounded("Negative cycle detected.") + else: + if any(weight(u, v, d) < 0 for u, v, d in nx.selfloop_edges(G, data=True)): + raise nx.NetworkXUnbounded("Negative cycle detected.") + + dist = {source: 0} + pred = {source: []} + + if len(G) == 1: + return pred, dist + + weight = _weight_function(G, weight) + + dist = _bellman_ford( + G, [source], weight, pred=pred, dist=dist, target=target, heuristic=heuristic + ) + return (pred, dist) + + +def _bellman_ford( + G, + source, + weight, + pred=None, + paths=None, + dist=None, + target=None, + heuristic=True, +): + """Calls relaxation loop for Bellman–Ford algorithm and builds paths + + This is an implementation of the SPFA variant. + See https://en.wikipedia.org/wiki/Shortest_Path_Faster_Algorithm + + Parameters + ---------- + G : NetworkX graph + + source: list + List of source nodes. The shortest path from any of the source + nodes will be found if multiple sources are provided. + + weight : function + The weight of an edge is the value returned by the function. The + function must accept exactly three positional arguments: the two + endpoints of an edge and the dictionary of edge attributes for + that edge. The function must return a number. + + pred: dict of lists, optional (default=None) + dict to store a list of predecessors keyed by that node + If None, predecessors are not stored + + paths: dict, optional (default=None) + dict to store the path list from source to each node, keyed by node + If None, paths are not stored + + dist: dict, optional (default=None) + dict to store distance from source to the keyed node + If None, returned dist dict contents default to 0 for every node in the + source list + + target: node label, optional + Ending node for path. Path lengths to other destinations may (and + probably will) be incorrect. + + heuristic : bool + Determines whether to use a heuristic to early detect negative + cycles at a hopefully negligible cost. + + Returns + ------- + dist : dict + Returns a dict keyed by node to the distance from the source. + Dicts for paths and pred are in the mutated input dicts by those names. + + Raises + ------ + NodeNotFound + If any of `source` is not in `G`. + + NetworkXUnbounded + If the (di)graph contains a negative (di)cycle, the + algorithm raises an exception to indicate the presence of the + negative (di)cycle. Note: any negative weight edge in an + undirected graph is a negative cycle + """ + if pred is None: + pred = {v: [] for v in source} + + if dist is None: + dist = {v: 0 for v in source} + + negative_cycle_found = _inner_bellman_ford( + G, + source, + weight, + pred, + dist, + heuristic, + ) + if negative_cycle_found is not None: + raise nx.NetworkXUnbounded("Negative cycle detected.") + + if paths is not None: + sources = set(source) + dsts = [target] if target is not None else pred + for dst in dsts: + gen = _build_paths_from_predecessors(sources, dst, pred) + paths[dst] = next(gen) + + return dist + + +def _inner_bellman_ford( + G, + sources, + weight, + pred, + dist=None, + heuristic=True, +): + """Inner Relaxation loop for Bellman–Ford algorithm. + + This is an implementation of the SPFA variant. + See https://en.wikipedia.org/wiki/Shortest_Path_Faster_Algorithm + + Parameters + ---------- + G : NetworkX graph + + source: list + List of source nodes. The shortest path from any of the source + nodes will be found if multiple sources are provided. + + weight : function + The weight of an edge is the value returned by the function. The + function must accept exactly three positional arguments: the two + endpoints of an edge and the dictionary of edge attributes for + that edge. The function must return a number. + + pred: dict of lists + dict to store a list of predecessors keyed by that node + + dist: dict, optional (default=None) + dict to store distance from source to the keyed node + If None, returned dist dict contents default to 0 for every node in the + source list + + heuristic : bool + Determines whether to use a heuristic to early detect negative + cycles at a hopefully negligible cost. + + Returns + ------- + node or None + Return a node `v` where processing discovered a negative cycle. + If no negative cycle found, return None. + + Raises + ------ + NodeNotFound + If any of `source` is not in `G`. + """ + for s in sources: + if s not in G: + raise nx.NodeNotFound(f"Source {s} not in G") + + if pred is None: + pred = {v: [] for v in sources} + + if dist is None: + dist = {v: 0 for v in sources} + + # Heuristic Storage setup. Note: use None because nodes cannot be None + nonexistent_edge = (None, None) + pred_edge = {v: None for v in sources} + recent_update = {v: nonexistent_edge for v in sources} + + G_succ = G._adj # For speed-up (and works for both directed and undirected graphs) + inf = float("inf") + n = len(G) + + count = {} + q = deque(sources) + in_q = set(sources) + while q: + u = q.popleft() + in_q.remove(u) + + # Skip relaxations if any of the predecessors of u is in the queue. + if all(pred_u not in in_q for pred_u in pred[u]): + dist_u = dist[u] + for v, e in G_succ[u].items(): + dist_v = dist_u + weight(u, v, e) + + if dist_v < dist.get(v, inf): + # In this conditional branch we are updating the path with v. + # If it happens that some earlier update also added node v + # that implies the existence of a negative cycle since + # after the update node v would lie on the update path twice. + # The update path is stored up to one of the source nodes, + # therefore u is always in the dict recent_update + if heuristic: + if v in recent_update[u]: + # Negative cycle found! + pred[v].append(u) + return v + + # Transfer the recent update info from u to v if the + # same source node is the head of the update path. + # If the source node is responsible for the cost update, + # then clear the history and use it instead. + if v in pred_edge and pred_edge[v] == u: + recent_update[v] = recent_update[u] + else: + recent_update[v] = (u, v) + + if v not in in_q: + q.append(v) + in_q.add(v) + count_v = count.get(v, 0) + 1 + if count_v == n: + # Negative cycle found! + return v + + count[v] = count_v + dist[v] = dist_v + pred[v] = [u] + pred_edge[v] = u + + elif dist.get(v) is not None and dist_v == dist.get(v): + pred[v].append(u) + + # successfully found shortest_path. No negative cycles found. + return None + + +@nx._dispatchable(edge_attrs="weight") +def bellman_ford_path(G, source, target, weight="weight"): + """Returns the shortest path from source to target in a weighted graph G. + + Parameters + ---------- + G : NetworkX graph + + source : node + Starting node + + target : node + Ending node + + weight : string or function (default="weight") + If this is a string, then edge weights will be accessed via the + edge attribute with this key (that is, the weight of the edge + joining `u` to `v` will be ``G.edges[u, v][weight]``). If no + such edge attribute exists, the weight of the edge is assumed to + be one. + + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly three + positional arguments: the two endpoints of an edge and the + dictionary of edge attributes for that edge. The function must + return a number. + + Returns + ------- + path : list + List of nodes in a shortest path. + + Raises + ------ + NodeNotFound + If `source` is not in `G`. + + NetworkXNoPath + If no path exists between source and target. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> nx.bellman_ford_path(G, 0, 4) + [0, 1, 2, 3, 4] + + Notes + ----- + Edge weight attributes must be numerical. + Distances are calculated as sums of weighted edges traversed. + + See Also + -------- + dijkstra_path, bellman_ford_path_length + """ + length, path = single_source_bellman_ford(G, source, target=target, weight=weight) + return path + + +@nx._dispatchable(edge_attrs="weight") +def bellman_ford_path_length(G, source, target, weight="weight"): + """Returns the shortest path length from source to target + in a weighted graph. + + Parameters + ---------- + G : NetworkX graph + + source : node label + starting node for path + + target : node label + ending node for path + + weight : string or function (default="weight") + If this is a string, then edge weights will be accessed via the + edge attribute with this key (that is, the weight of the edge + joining `u` to `v` will be ``G.edges[u, v][weight]``). If no + such edge attribute exists, the weight of the edge is assumed to + be one. + + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly three + positional arguments: the two endpoints of an edge and the + dictionary of edge attributes for that edge. The function must + return a number. + + Returns + ------- + length : number + Shortest path length. + + Raises + ------ + NodeNotFound + If `source` is not in `G`. + + NetworkXNoPath + If no path exists between source and target. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> nx.bellman_ford_path_length(G, 0, 4) + 4 + + Notes + ----- + Edge weight attributes must be numerical. + Distances are calculated as sums of weighted edges traversed. + + See Also + -------- + dijkstra_path_length, bellman_ford_path + """ + if source == target: + if source not in G: + raise nx.NodeNotFound(f"Node {source} not found in graph") + return 0 + + weight = _weight_function(G, weight) + + length = _bellman_ford(G, [source], weight, target=target) + + try: + return length[target] + except KeyError as err: + raise nx.NetworkXNoPath(f"node {target} not reachable from {source}") from err + + +@nx._dispatchable(edge_attrs="weight") +def single_source_bellman_ford_path(G, source, weight="weight"): + """Compute shortest path between source and all other reachable + nodes for a weighted graph. + + Parameters + ---------- + G : NetworkX graph + + source : node + Starting node for path. + + weight : string or function (default="weight") + If this is a string, then edge weights will be accessed via the + edge attribute with this key (that is, the weight of the edge + joining `u` to `v` will be ``G.edges[u, v][weight]``). If no + such edge attribute exists, the weight of the edge is assumed to + be one. + + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly three + positional arguments: the two endpoints of an edge and the + dictionary of edge attributes for that edge. The function must + return a number. + + Returns + ------- + paths : dictionary + Dictionary of shortest path lengths keyed by target. + + Raises + ------ + NodeNotFound + If `source` is not in `G`. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> path = nx.single_source_bellman_ford_path(G, 0) + >>> path[4] + [0, 1, 2, 3, 4] + + Notes + ----- + Edge weight attributes must be numerical. + Distances are calculated as sums of weighted edges traversed. + + See Also + -------- + single_source_dijkstra, single_source_bellman_ford + + """ + (length, path) = single_source_bellman_ford(G, source, weight=weight) + return path + + +@nx._dispatchable(edge_attrs="weight") +def single_source_bellman_ford_path_length(G, source, weight="weight"): + """Compute the shortest path length between source and all other + reachable nodes for a weighted graph. + + Parameters + ---------- + G : NetworkX graph + + source : node label + Starting node for path + + weight : string or function (default="weight") + If this is a string, then edge weights will be accessed via the + edge attribute with this key (that is, the weight of the edge + joining `u` to `v` will be ``G.edges[u, v][weight]``). If no + such edge attribute exists, the weight of the edge is assumed to + be one. + + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly three + positional arguments: the two endpoints of an edge and the + dictionary of edge attributes for that edge. The function must + return a number. + + Returns + ------- + length : dictionary + Dictionary of shortest path length keyed by target + + Raises + ------ + NodeNotFound + If `source` is not in `G`. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> length = nx.single_source_bellman_ford_path_length(G, 0) + >>> length[4] + 4 + >>> for node in [0, 1, 2, 3, 4]: + ... print(f"{node}: {length[node]}") + 0: 0 + 1: 1 + 2: 2 + 3: 3 + 4: 4 + + Notes + ----- + Edge weight attributes must be numerical. + Distances are calculated as sums of weighted edges traversed. + + See Also + -------- + single_source_dijkstra, single_source_bellman_ford + + """ + weight = _weight_function(G, weight) + return _bellman_ford(G, [source], weight) + + +@nx._dispatchable(edge_attrs="weight") +def single_source_bellman_ford(G, source, target=None, weight="weight"): + """Compute shortest paths and lengths in a weighted graph G. + + Uses Bellman-Ford algorithm for shortest paths. + + Parameters + ---------- + G : NetworkX graph + + source : node label + Starting node for path + + target : node label, optional + Ending node for path + + weight : string or function + If this is a string, then edge weights will be accessed via the + edge attribute with this key (that is, the weight of the edge + joining `u` to `v` will be ``G.edges[u, v][weight]``). If no + such edge attribute exists, the weight of the edge is assumed to + be one. + + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly three + positional arguments: the two endpoints of an edge and the + dictionary of edge attributes for that edge. The function must + return a number. + + Returns + ------- + distance, path : pair of dictionaries, or numeric and list + If target is None, returns a tuple of two dictionaries keyed by node. + The first dictionary stores distance from one of the source nodes. + The second stores the path from one of the sources to that node. + If target is not None, returns a tuple of (distance, path) where + distance is the distance from source to target and path is a list + representing the path from source to target. + + Raises + ------ + NodeNotFound + If `source` is not in `G`. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> length, path = nx.single_source_bellman_ford(G, 0) + >>> length[4] + 4 + >>> for node in [0, 1, 2, 3, 4]: + ... print(f"{node}: {length[node]}") + 0: 0 + 1: 1 + 2: 2 + 3: 3 + 4: 4 + >>> path[4] + [0, 1, 2, 3, 4] + >>> length, path = nx.single_source_bellman_ford(G, 0, 1) + >>> length + 1 + >>> path + [0, 1] + + Notes + ----- + Edge weight attributes must be numerical. + Distances are calculated as sums of weighted edges traversed. + + See Also + -------- + single_source_dijkstra + single_source_bellman_ford_path + single_source_bellman_ford_path_length + """ + if source == target: + if source not in G: + raise nx.NodeNotFound(f"Node {source} is not found in the graph") + return (0, [source]) + + weight = _weight_function(G, weight) + + paths = {source: [source]} # dictionary of paths + dist = _bellman_ford(G, [source], weight, paths=paths, target=target) + if target is None: + return (dist, paths) + try: + return (dist[target], paths[target]) + except KeyError as err: + msg = f"Node {target} not reachable from {source}" + raise nx.NetworkXNoPath(msg) from err + + +@nx._dispatchable(edge_attrs="weight") +def all_pairs_bellman_ford_path_length(G, weight="weight"): + """Compute shortest path lengths between all nodes in a weighted graph. + + Parameters + ---------- + G : NetworkX graph + + weight : string or function (default="weight") + If this is a string, then edge weights will be accessed via the + edge attribute with this key (that is, the weight of the edge + joining `u` to `v` will be ``G.edges[u, v][weight]``). If no + such edge attribute exists, the weight of the edge is assumed to + be one. + + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly three + positional arguments: the two endpoints of an edge and the + dictionary of edge attributes for that edge. The function must + return a number. + + Returns + ------- + distance : iterator + (source, dictionary) iterator with dictionary keyed by target and + shortest path length as the key value. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> length = dict(nx.all_pairs_bellman_ford_path_length(G)) + >>> for node in [0, 1, 2, 3, 4]: + ... print(f"1 - {node}: {length[1][node]}") + 1 - 0: 1 + 1 - 1: 0 + 1 - 2: 1 + 1 - 3: 2 + 1 - 4: 3 + >>> length[3][2] + 1 + >>> length[2][2] + 0 + + Notes + ----- + Edge weight attributes must be numerical. + Distances are calculated as sums of weighted edges traversed. + + The dictionary returned only has keys for reachable node pairs. + """ + length = single_source_bellman_ford_path_length + for n in G: + yield (n, dict(length(G, n, weight=weight))) + + +@nx._dispatchable(edge_attrs="weight") +def all_pairs_bellman_ford_path(G, weight="weight"): + """Compute shortest paths between all nodes in a weighted graph. + + Parameters + ---------- + G : NetworkX graph + + weight : string or function (default="weight") + If this is a string, then edge weights will be accessed via the + edge attribute with this key (that is, the weight of the edge + joining `u` to `v` will be ``G.edges[u, v][weight]``). If no + such edge attribute exists, the weight of the edge is assumed to + be one. + + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly three + positional arguments: the two endpoints of an edge and the + dictionary of edge attributes for that edge. The function must + return a number. + + Returns + ------- + paths : iterator + (source, dictionary) iterator with dictionary keyed by target and + shortest path as the key value. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> path = dict(nx.all_pairs_bellman_ford_path(G)) + >>> path[0][4] + [0, 1, 2, 3, 4] + + Notes + ----- + Edge weight attributes must be numerical. + Distances are calculated as sums of weighted edges traversed. + + See Also + -------- + floyd_warshall, all_pairs_dijkstra_path + + """ + path = single_source_bellman_ford_path + for n in G: + yield (n, path(G, n, weight=weight)) + + +@nx._dispatchable(edge_attrs="weight") +def goldberg_radzik(G, source, weight="weight"): + """Compute shortest path lengths and predecessors on shortest paths + in weighted graphs. + + The algorithm has a running time of $O(mn)$ where $n$ is the number of + nodes and $m$ is the number of edges. It is slower than Dijkstra but + can handle negative edge weights. + + Parameters + ---------- + G : NetworkX graph + The algorithm works for all types of graphs, including directed + graphs and multigraphs. + + source: node label + Starting node for path + + weight : string or function + If this is a string, then edge weights will be accessed via the + edge attribute with this key (that is, the weight of the edge + joining `u` to `v` will be ``G.edges[u, v][weight]``). If no + such edge attribute exists, the weight of the edge is assumed to + be one. + + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly three + positional arguments: the two endpoints of an edge and the + dictionary of edge attributes for that edge. The function must + return a number. + + Returns + ------- + pred, dist : dictionaries + Returns two dictionaries keyed by node to predecessor in the + path and to the distance from the source respectively. + + Raises + ------ + NodeNotFound + If `source` is not in `G`. + + NetworkXUnbounded + If the (di)graph contains a negative (di)cycle, the + algorithm raises an exception to indicate the presence of the + negative (di)cycle. Note: any negative weight edge in an + undirected graph is a negative cycle. + + As of NetworkX v3.2, a zero weight cycle is no longer + incorrectly reported as a negative weight cycle. + + + Examples + -------- + >>> G = nx.path_graph(5, create_using=nx.DiGraph()) + >>> pred, dist = nx.goldberg_radzik(G, 0) + >>> sorted(pred.items()) + [(0, None), (1, 0), (2, 1), (3, 2), (4, 3)] + >>> sorted(dist.items()) + [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)] + + >>> G = nx.cycle_graph(5, create_using=nx.DiGraph()) + >>> G[1][2]["weight"] = -7 + >>> nx.goldberg_radzik(G, 0) + Traceback (most recent call last): + ... + networkx.exception.NetworkXUnbounded: Negative cycle detected. + + Notes + ----- + Edge weight attributes must be numerical. + Distances are calculated as sums of weighted edges traversed. + + The dictionaries returned only have keys for nodes reachable from + the source. + + In the case where the (di)graph is not connected, if a component + not containing the source contains a negative (di)cycle, it + will not be detected. + + """ + if source not in G: + raise nx.NodeNotFound(f"Node {source} is not found in the graph") + weight = _weight_function(G, weight) + if G.is_multigraph(): + if any( + weight(u, v, {k: d}) < 0 + for u, v, k, d in nx.selfloop_edges(G, keys=True, data=True) + ): + raise nx.NetworkXUnbounded("Negative cycle detected.") + else: + if any(weight(u, v, d) < 0 for u, v, d in nx.selfloop_edges(G, data=True)): + raise nx.NetworkXUnbounded("Negative cycle detected.") + + if len(G) == 1: + return {source: None}, {source: 0} + + G_succ = G._adj # For speed-up (and works for both directed and undirected graphs) + + inf = float("inf") + d = {u: inf for u in G} + d[source] = 0 + pred = {source: None} + + def topo_sort(relabeled): + """Topologically sort nodes relabeled in the previous round and detect + negative cycles. + """ + # List of nodes to scan in this round. Denoted by A in Goldberg and + # Radzik's paper. + to_scan = [] + # In the DFS in the loop below, neg_count records for each node the + # number of edges of negative reduced costs on the path from a DFS root + # to the node in the DFS forest. The reduced cost of an edge (u, v) is + # defined as d[u] + weight[u][v] - d[v]. + # + # neg_count also doubles as the DFS visit marker array. + neg_count = {} + for u in relabeled: + # Skip visited nodes. + if u in neg_count: + continue + d_u = d[u] + # Skip nodes without out-edges of negative reduced costs. + if all(d_u + weight(u, v, e) >= d[v] for v, e in G_succ[u].items()): + continue + # Nonrecursive DFS that inserts nodes reachable from u via edges of + # nonpositive reduced costs into to_scan in (reverse) topological + # order. + stack = [(u, iter(G_succ[u].items()))] + in_stack = {u} + neg_count[u] = 0 + while stack: + u, it = stack[-1] + try: + v, e = next(it) + except StopIteration: + to_scan.append(u) + stack.pop() + in_stack.remove(u) + continue + t = d[u] + weight(u, v, e) + d_v = d[v] + if t < d_v: + is_neg = t < d_v + d[v] = t + pred[v] = u + if v not in neg_count: + neg_count[v] = neg_count[u] + int(is_neg) + stack.append((v, iter(G_succ[v].items()))) + in_stack.add(v) + elif v in in_stack and neg_count[u] + int(is_neg) > neg_count[v]: + # (u, v) is a back edge, and the cycle formed by the + # path v to u and (u, v) contains at least one edge of + # negative reduced cost. The cycle must be of negative + # cost. + raise nx.NetworkXUnbounded("Negative cycle detected.") + to_scan.reverse() + return to_scan + + def relax(to_scan): + """Relax out-edges of relabeled nodes.""" + relabeled = set() + # Scan nodes in to_scan in topological order and relax incident + # out-edges. Add the relabled nodes to labeled. + for u in to_scan: + d_u = d[u] + for v, e in G_succ[u].items(): + w_e = weight(u, v, e) + if d_u + w_e < d[v]: + d[v] = d_u + w_e + pred[v] = u + relabeled.add(v) + return relabeled + + # Set of nodes relabled in the last round of scan operations. Denoted by B + # in Goldberg and Radzik's paper. + relabeled = {source} + + while relabeled: + to_scan = topo_sort(relabeled) + relabeled = relax(to_scan) + + d = {u: d[u] for u in pred} + return pred, d + + +@nx._dispatchable(edge_attrs="weight") +def negative_edge_cycle(G, weight="weight", heuristic=True): + """Returns True if there exists a negative edge cycle anywhere in G. + + Parameters + ---------- + G : NetworkX graph + + weight : string or function + If this is a string, then edge weights will be accessed via the + edge attribute with this key (that is, the weight of the edge + joining `u` to `v` will be ``G.edges[u, v][weight]``). If no + such edge attribute exists, the weight of the edge is assumed to + be one. + + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly three + positional arguments: the two endpoints of an edge and the + dictionary of edge attributes for that edge. The function must + return a number. + + heuristic : bool + Determines whether to use a heuristic to early detect negative + cycles at a negligible cost. In case of graphs with a negative cycle, + the performance of detection increases by at least an order of magnitude. + + Returns + ------- + negative_cycle : bool + True if a negative edge cycle exists, otherwise False. + + Examples + -------- + >>> G = nx.cycle_graph(5, create_using=nx.DiGraph()) + >>> print(nx.negative_edge_cycle(G)) + False + >>> G[1][2]["weight"] = -7 + >>> print(nx.negative_edge_cycle(G)) + True + + Notes + ----- + Edge weight attributes must be numerical. + Distances are calculated as sums of weighted edges traversed. + + This algorithm uses bellman_ford_predecessor_and_distance() but finds + negative cycles on any component by first adding a new node connected to + every node, and starting bellman_ford_predecessor_and_distance on that + node. It then removes that extra node. + """ + if G.size() == 0: + return False + + # find unused node to use temporarily + newnode = -1 + while newnode in G: + newnode -= 1 + # connect it to all nodes + G.add_edges_from([(newnode, n) for n in G]) + + try: + bellman_ford_predecessor_and_distance( + G, newnode, weight=weight, heuristic=heuristic + ) + except nx.NetworkXUnbounded: + return True + finally: + G.remove_node(newnode) + return False + + +@nx._dispatchable(edge_attrs="weight") +def find_negative_cycle(G, source, weight="weight"): + """Returns a cycle with negative total weight if it exists. + + Bellman-Ford is used to find shortest_paths. That algorithm + stops if there exists a negative cycle. This algorithm + picks up from there and returns the found negative cycle. + + The cycle consists of a list of nodes in the cycle order. The last + node equals the first to make it a cycle. + You can look up the edge weights in the original graph. In the case + of multigraphs the relevant edge is the minimal weight edge between + the nodes in the 2-tuple. + + If the graph has no negative cycle, a NetworkXError is raised. + + Parameters + ---------- + G : NetworkX graph + + source: node label + The search for the negative cycle will start from this node. + + weight : string or function + If this is a string, then edge weights will be accessed via the + edge attribute with this key (that is, the weight of the edge + joining `u` to `v` will be ``G.edges[u, v][weight]``). If no + such edge attribute exists, the weight of the edge is assumed to + be one. + + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly three + positional arguments: the two endpoints of an edge and the + dictionary of edge attributes for that edge. The function must + return a number. + + Examples + -------- + >>> G = nx.DiGraph() + >>> G.add_weighted_edges_from( + ... [(0, 1, 2), (1, 2, 2), (2, 0, 1), (1, 4, 2), (4, 0, -5)] + ... ) + >>> nx.find_negative_cycle(G, 0) + [4, 0, 1, 4] + + Returns + ------- + cycle : list + A list of nodes in the order of the cycle found. The last node + equals the first to indicate a cycle. + + Raises + ------ + NetworkXError + If no negative cycle is found. + """ + weight = _weight_function(G, weight) + pred = {source: []} + + v = _inner_bellman_ford(G, [source], weight, pred=pred) + if v is None: + raise nx.NetworkXError("No negative cycles detected.") + + # negative cycle detected... find it + neg_cycle = [] + stack = [(v, list(pred[v]))] + seen = {v} + while stack: + node, preds = stack[-1] + if v in preds: + # found the cycle + neg_cycle.extend([node, v]) + neg_cycle = list(reversed(neg_cycle)) + return neg_cycle + + if preds: + nbr = preds.pop() + if nbr not in seen: + stack.append((nbr, list(pred[nbr]))) + neg_cycle.append(node) + seen.add(nbr) + else: + stack.pop() + if neg_cycle: + neg_cycle.pop() + else: + if v in G[v] and weight(G, v, v) < 0: + return [v, v] + # should not reach here + raise nx.NetworkXError("Negative cycle is detected but not found") + # should not get here... + msg = "negative cycle detected but not identified" + raise nx.NetworkXUnbounded(msg) + + +@nx._dispatchable(edge_attrs="weight") +def bidirectional_dijkstra(G, source, target, weight="weight"): + r"""Dijkstra's algorithm for shortest paths using bidirectional search. + + Parameters + ---------- + G : NetworkX graph + + source : node + Starting node. + + target : node + Ending node. + + weight : string or function + If this is a string, then edge weights will be accessed via the + edge attribute with this key (that is, the weight of the edge + joining `u` to `v` will be ``G.edges[u, v][weight]``). If no + such edge attribute exists, the weight of the edge is assumed to + be one. + + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly three + positional arguments: the two endpoints of an edge and the + dictionary of edge attributes for that edge. The function must + return a number or None to indicate a hidden edge. + + Returns + ------- + length, path : number and list + length is the distance from source to target. + path is a list of nodes on a path from source to target. + + Raises + ------ + NodeNotFound + If `source` or `target` is not in `G`. + + NetworkXNoPath + If no path exists between source and target. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> length, path = nx.bidirectional_dijkstra(G, 0, 4) + >>> print(length) + 4 + >>> print(path) + [0, 1, 2, 3, 4] + + Notes + ----- + Edge weight attributes must be numerical. + Distances are calculated as sums of weighted edges traversed. + + The weight function can be used to hide edges by returning None. + So ``weight = lambda u, v, d: 1 if d['color']=="red" else None`` + will find the shortest red path. + + In practice bidirectional Dijkstra is much more than twice as fast as + ordinary Dijkstra. + + Ordinary Dijkstra expands nodes in a sphere-like manner from the + source. The radius of this sphere will eventually be the length + of the shortest path. Bidirectional Dijkstra will expand nodes + from both the source and the target, making two spheres of half + this radius. Volume of the first sphere is `\pi*r*r` while the + others are `2*\pi*r/2*r/2`, making up half the volume. + + This algorithm is not guaranteed to work if edge weights + are negative or are floating point numbers + (overflows and roundoff errors can cause problems). + + See Also + -------- + shortest_path + shortest_path_length + """ + if source not in G: + raise nx.NodeNotFound(f"Source {source} is not in G") + + if target not in G: + raise nx.NodeNotFound(f"Target {target} is not in G") + + if source == target: + return (0, [source]) + + weight = _weight_function(G, weight) + push = heappush + pop = heappop + # Init: [Forward, Backward] + dists = [{}, {}] # dictionary of final distances + paths = [{source: [source]}, {target: [target]}] # dictionary of paths + fringe = [[], []] # heap of (distance, node) for choosing node to expand + seen = [{source: 0}, {target: 0}] # dict of distances to seen nodes + c = count() + # initialize fringe heap + push(fringe[0], (0, next(c), source)) + push(fringe[1], (0, next(c), target)) + # neighs for extracting correct neighbor information + if G.is_directed(): + neighs = [G._succ, G._pred] + else: + neighs = [G._adj, G._adj] + # variables to hold shortest discovered path + # finaldist = 1e30000 + finalpath = [] + dir = 1 + while fringe[0] and fringe[1]: + # choose direction + # dir == 0 is forward direction and dir == 1 is back + dir = 1 - dir + # extract closest to expand + (dist, _, v) = pop(fringe[dir]) + if v in dists[dir]: + # Shortest path to v has already been found + continue + # update distance + dists[dir][v] = dist # equal to seen[dir][v] + if v in dists[1 - dir]: + # if we have scanned v in both directions we are done + # we have now discovered the shortest path + return (finaldist, finalpath) + + for w, d in neighs[dir][v].items(): + # weight(v, w, d) for forward and weight(w, v, d) for back direction + cost = weight(v, w, d) if dir == 0 else weight(w, v, d) + if cost is None: + continue + vwLength = dists[dir][v] + cost + if w in dists[dir]: + if vwLength < dists[dir][w]: + raise ValueError("Contradictory paths found: negative weights?") + elif w not in seen[dir] or vwLength < seen[dir][w]: + # relaxing + seen[dir][w] = vwLength + push(fringe[dir], (vwLength, next(c), w)) + paths[dir][w] = paths[dir][v] + [w] + if w in seen[0] and w in seen[1]: + # see if this path is better than the already + # discovered shortest path + totaldist = seen[0][w] + seen[1][w] + if finalpath == [] or finaldist > totaldist: + finaldist = totaldist + revpath = paths[1][w][:] + revpath.reverse() + finalpath = paths[0][w] + revpath[1:] + raise nx.NetworkXNoPath(f"No path between {source} and {target}.") + + +@nx._dispatchable(edge_attrs="weight") +def johnson(G, weight="weight"): + r"""Uses Johnson's Algorithm to compute shortest paths. + + Johnson's Algorithm finds a shortest path between each pair of + nodes in a weighted graph even if negative weights are present. + + Parameters + ---------- + G : NetworkX graph + + weight : string or function + If this is a string, then edge weights will be accessed via the + edge attribute with this key (that is, the weight of the edge + joining `u` to `v` will be ``G.edges[u, v][weight]``). If no + such edge attribute exists, the weight of the edge is assumed to + be one. + + If this is a function, the weight of an edge is the value + returned by the function. The function must accept exactly three + positional arguments: the two endpoints of an edge and the + dictionary of edge attributes for that edge. The function must + return a number. + + Returns + ------- + distance : dictionary + Dictionary, keyed by source and target, of shortest paths. + + Examples + -------- + >>> graph = nx.DiGraph() + >>> graph.add_weighted_edges_from( + ... [("0", "3", 3), ("0", "1", -5), ("0", "2", 2), ("1", "2", 4), ("2", "3", 1)] + ... ) + >>> paths = nx.johnson(graph, weight="weight") + >>> paths["0"]["2"] + ['0', '1', '2'] + + Notes + ----- + Johnson's algorithm is suitable even for graphs with negative weights. It + works by using the Bellman–Ford algorithm to compute a transformation of + the input graph that removes all negative weights, allowing Dijkstra's + algorithm to be used on the transformed graph. + + The time complexity of this algorithm is $O(n^2 \log n + n m)$, + where $n$ is the number of nodes and $m$ the number of edges in the + graph. For dense graphs, this may be faster than the Floyd–Warshall + algorithm. + + See Also + -------- + floyd_warshall_predecessor_and_distance + floyd_warshall_numpy + all_pairs_shortest_path + all_pairs_shortest_path_length + all_pairs_dijkstra_path + bellman_ford_predecessor_and_distance + all_pairs_bellman_ford_path + all_pairs_bellman_ford_path_length + + """ + dist = {v: 0 for v in G} + pred = {v: [] for v in G} + weight = _weight_function(G, weight) + + # Calculate distance of shortest paths + dist_bellman = _bellman_ford(G, list(G), weight, pred=pred, dist=dist) + + # Update the weight function to take into account the Bellman--Ford + # relaxation distances. + def new_weight(u, v, d): + return weight(u, v, d) + dist_bellman[u] - dist_bellman[v] + + def dist_path(v): + paths = {v: [v]} + _dijkstra(G, v, new_weight, paths=paths) + return paths + + return {v: dist_path(v) for v in G} diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3f6c27bab317939c084cfa6c37e944d76dff3eb Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_asteroidal.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_asteroidal.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d212ba9353067ba7890bd6614f98bf4cc15a173 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_asteroidal.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_boundary.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_boundary.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7dd9edbcdf76b35833222a1034debc43b91804c1 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_boundary.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_bridges.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_bridges.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6c3a67aeca3e6eb9c8dfca1716d2f6f0cf890e0 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_bridges.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_broadcasting.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_broadcasting.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68af7162b912523ff1f23e23240dd88e03ea83d6 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_broadcasting.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_chains.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_chains.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40bb33012fd53b2009f2938708c8be23295fb2c3 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_chains.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_chordal.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_chordal.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e44358e716d84b3a2ca05b28d8e700fcd961536 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_chordal.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_clique.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_clique.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84ab44b95f9ef7ad30a8f70e0c0ac2fe32982e82 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_clique.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_cluster.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_cluster.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5552174745869fb9c241da0d001bc44ce204ee42 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_cluster.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_communicability.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_communicability.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18f7dcee91c170d9b51a671400720aa1ecbd2c29 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_communicability.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_core.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_core.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b757f661a4086ea79ec0810e3ab50948faf38a22 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_core.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_covering.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_covering.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6df7dc404113a3a91e735911a8a7f12bcf44a37 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_covering.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_cuts.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_cuts.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a190c8ba02635f751ce6e4036ec058b43e645722 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_cuts.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_cycles.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_cycles.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bda86104faae821d2c21ef66898fef2da3247602 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_cycles.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_d_separation.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_d_separation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d47bf4af41fcb442b1f730355a7e6d7349c2bc88 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_d_separation.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_dag.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_dag.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebd47ea2bc4d8e2f26192741df2329108870bac4 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_dag.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_distance_measures.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_distance_measures.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c9f43dd2c3396bc3afec9c2ad6d54b432d80ec6 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_distance_measures.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_distance_regular.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_distance_regular.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e6bb5a4c8a5de51b907f09598aea1ee29cbdce8 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_distance_regular.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_dominance.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_dominance.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a8e322c73647eba8a2a7af2d34df65c7441aa73 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_dominance.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_dominating.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_dominating.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a08d6cf6d5f8f38a581a9df4a15e497f385feacb Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_dominating.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_efficiency.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_efficiency.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6fc64aaebe0d708537ef4fe983644d6bf9c60bac Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_efficiency.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_euler.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_euler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85909005905c08b63eb9532376ca26070610e9c4 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_euler.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_graph_hashing.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_graph_hashing.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c79ab814b77bf7d3fb8481d815b4da0e0707a6e Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_graph_hashing.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_graphical.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_graphical.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20aea346928a6fad277efa05a8ea02b8c0b55b2a Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_graphical.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_hierarchy.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_hierarchy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9384ec640f7fe8755db6d5aff66dcd976ae81e6 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_hierarchy.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_hybrid.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_hybrid.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..492c3377930670b0b9dab524cabb65f59371eec0 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_hybrid.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_isolate.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_isolate.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eac21dce61103361ea5df875440633330eb578e8 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_isolate.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_link_prediction.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_link_prediction.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a251de469401c94544ef5bdab8c91fd270598f3 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_link_prediction.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_lowest_common_ancestors.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_lowest_common_ancestors.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1fd8a4b6c7edf640e29fb3c18a4d7f69c10bea3d Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_lowest_common_ancestors.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_matching.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_matching.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c31bf2daa5572c1172820494e01361272c7b164 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_matching.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_max_weight_clique.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_max_weight_clique.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..acba97e6a7289ccf98b827cafd6224fb60e5e7a6 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_max_weight_clique.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_mis.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_mis.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f37db49daf4f19bbd5f3efc68d75494da216783 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_mis.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_moral.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_moral.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98d0b01c9b74d882311d82e3ca82f21699fc0a0c Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_moral.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_node_classification.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_node_classification.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3e0d8572ec543727c1fa657a34670c121382519 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_node_classification.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_non_randomness.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_non_randomness.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a74f5ca8676091e95fcade3721a812c3f3ae9fd Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_non_randomness.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_planar_drawing.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_planar_drawing.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..052f7d3b55b9787fa101a94670de1e0db29ef3f5 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_planar_drawing.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_planarity.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_planarity.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8641fc4d4e4af66ee6a2f224655b162a81fb3c03 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_planarity.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_polynomials.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_polynomials.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22e2de949a3152590582c47f04481626c383b956 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_polynomials.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_reciprocity.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_reciprocity.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4488e338508ca7b60990d36c5d534fdbbd4ade9d Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_reciprocity.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_regular.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_regular.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aeb1b81a23d86c01fe0cc1a5c6bb04bbcc0de80c Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_regular.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_richclub.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_richclub.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e3d80ff4ab5747c8c307473d604dbf0898833c7 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_richclub.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_similarity.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_similarity.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6b466121c5bdcdb421fc0140ad56cd37c2fb689 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_similarity.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_simple_paths.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_simple_paths.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47f8d0c9c6576add1f1025c4317ff5e07a53c5f1 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_simple_paths.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_smallworld.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_smallworld.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d18f2c86bed170039566419228be56679973f90 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_smallworld.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_smetric.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_smetric.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03b6ea2562b1d747a5a087758d05ed783a74e26a Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_smetric.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_sparsifiers.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_sparsifiers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..949b57758b7273a85af2fd62127f9bf382642a11 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_sparsifiers.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_structuralholes.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_structuralholes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf9843021c03b1f302856a34604df9ff31367ad6 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_structuralholes.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_summarization.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_summarization.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bdec90e3ede4cda670e8b15a36cc4e433ce32aa Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_summarization.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_swap.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_swap.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..508c4cd0b9c5e22c2150e9aa6e8ec2ef3315aeea Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_swap.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_threshold.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_threshold.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ced48eb2bf393c980f753b6471a435e1b6908e6f Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_threshold.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_time_dependent.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_time_dependent.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54ab0f0499ccd00e5852b022149d2fcc6dfd86fe Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_time_dependent.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_tournament.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_tournament.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e983b96792f116ba16f4d85dcacf88be2c6d9a00 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_tournament.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_triads.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_triads.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..010bfa21249a4fea970f4ef8e4292e94f6b022ef Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_triads.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_vitality.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_vitality.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce8387f24945984149c707b5e4593deeb4cedcfb Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_vitality.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_voronoi.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_voronoi.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..816de1170316621d5a2ffcb756167592b1d9a8a0 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_voronoi.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_walks.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_walks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..447f09ad5059874c62bd00e5ab762ea43df0bfd9 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_walks.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_wiener.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_wiener.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a4aaca9a6f8a4098cd47497aa78a4d6e37f782d Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tests/__pycache__/test_wiener.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_asteroidal.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_asteroidal.py new file mode 100644 index 0000000000000000000000000000000000000000..67131b2d05026317b496d06e6b382836c8c26367 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_asteroidal.py @@ -0,0 +1,23 @@ +import networkx as nx + + +def test_is_at_free(): + is_at_free = nx.asteroidal.is_at_free + + cycle = nx.cycle_graph(6) + assert not is_at_free(cycle) + + path = nx.path_graph(6) + assert is_at_free(path) + + small_graph = nx.complete_graph(2) + assert is_at_free(small_graph) + + petersen = nx.petersen_graph() + assert not is_at_free(petersen) + + clique = nx.complete_graph(6) + assert is_at_free(clique) + + line_clique = nx.line_graph(clique) + assert not is_at_free(line_clique) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_boundary.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_boundary.py new file mode 100644 index 0000000000000000000000000000000000000000..856be465556941fe6f2bfc2c8bab6d4b508cf999 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_boundary.py @@ -0,0 +1,154 @@ +"""Unit tests for the :mod:`networkx.algorithms.boundary` module.""" + +from itertools import combinations + +import pytest + +import networkx as nx +from networkx import convert_node_labels_to_integers as cnlti +from networkx.utils import edges_equal + + +class TestNodeBoundary: + """Unit tests for the :func:`~networkx.node_boundary` function.""" + + def test_null_graph(self): + """Tests that the null graph has empty node boundaries.""" + null = nx.null_graph() + assert nx.node_boundary(null, []) == set() + assert nx.node_boundary(null, [], []) == set() + assert nx.node_boundary(null, [1, 2, 3]) == set() + assert nx.node_boundary(null, [1, 2, 3], [4, 5, 6]) == set() + assert nx.node_boundary(null, [1, 2, 3], [3, 4, 5]) == set() + + def test_path_graph(self): + P10 = cnlti(nx.path_graph(10), first_label=1) + assert nx.node_boundary(P10, []) == set() + assert nx.node_boundary(P10, [], []) == set() + assert nx.node_boundary(P10, [1, 2, 3]) == {4} + assert nx.node_boundary(P10, [4, 5, 6]) == {3, 7} + assert nx.node_boundary(P10, [3, 4, 5, 6, 7]) == {2, 8} + assert nx.node_boundary(P10, [8, 9, 10]) == {7} + assert nx.node_boundary(P10, [4, 5, 6], [9, 10]) == set() + + def test_complete_graph(self): + K10 = cnlti(nx.complete_graph(10), first_label=1) + assert nx.node_boundary(K10, []) == set() + assert nx.node_boundary(K10, [], []) == set() + assert nx.node_boundary(K10, [1, 2, 3]) == {4, 5, 6, 7, 8, 9, 10} + assert nx.node_boundary(K10, [4, 5, 6]) == {1, 2, 3, 7, 8, 9, 10} + assert nx.node_boundary(K10, [3, 4, 5, 6, 7]) == {1, 2, 8, 9, 10} + assert nx.node_boundary(K10, [4, 5, 6], []) == set() + assert nx.node_boundary(K10, K10) == set() + assert nx.node_boundary(K10, [1, 2, 3], [3, 4, 5]) == {4, 5} + + def test_petersen(self): + """Check boundaries in the petersen graph + + cheeger(G,k)=min(|bdy(S)|/|S| for |S|=k, 0>> list(cycles("abc")) + [('a', 'b', 'c'), ('b', 'c', 'a'), ('c', 'a', 'b')] + + """ + n = len(seq) + cycled_seq = cycle(seq) + for x in seq: + yield tuple(islice(cycled_seq, n)) + next(cycled_seq) + + +def cyclic_equals(seq1, seq2): + """Decide whether two sequences are equal up to cyclic permutations. + + For example:: + + >>> cyclic_equals("xyz", "zxy") + True + >>> cyclic_equals("xyz", "zyx") + False + + """ + # Cast seq2 to a tuple since `cycles()` yields tuples. + seq2 = tuple(seq2) + return any(x == tuple(seq2) for x in cycles(seq1)) + + +class TestChainDecomposition: + """Unit tests for the chain decomposition function.""" + + def assertContainsChain(self, chain, expected): + # A cycle could be expressed in two different orientations, one + # forward and one backward, so we need to check for cyclic + # equality in both orientations. + reversed_chain = list(reversed([tuple(reversed(e)) for e in chain])) + for candidate in expected: + if cyclic_equals(chain, candidate): + break + if cyclic_equals(reversed_chain, candidate): + break + else: + self.fail("chain not found") + + def test_decomposition(self): + edges = [ + # DFS tree edges. + (1, 2), + (2, 3), + (3, 4), + (3, 5), + (5, 6), + (6, 7), + (7, 8), + (5, 9), + (9, 10), + # Nontree edges. + (1, 3), + (1, 4), + (2, 5), + (5, 10), + (6, 8), + ] + G = nx.Graph(edges) + expected = [ + [(1, 3), (3, 2), (2, 1)], + [(1, 4), (4, 3)], + [(2, 5), (5, 3)], + [(5, 10), (10, 9), (9, 5)], + [(6, 8), (8, 7), (7, 6)], + ] + chains = list(nx.chain_decomposition(G, root=1)) + assert len(chains) == len(expected) + + # This chain decomposition isn't unique + # for chain in chains: + # print(chain) + # self.assertContainsChain(chain, expected) + + def test_barbell_graph(self): + # The (3, 0) barbell graph has two triangles joined by a single edge. + G = nx.barbell_graph(3, 0) + chains = list(nx.chain_decomposition(G, root=0)) + expected = [[(0, 1), (1, 2), (2, 0)], [(3, 4), (4, 5), (5, 3)]] + assert len(chains) == len(expected) + for chain in chains: + self.assertContainsChain(chain, expected) + + def test_disconnected_graph(self): + """Test for a graph with multiple connected components.""" + G = nx.barbell_graph(3, 0) + H = nx.barbell_graph(3, 0) + mapping = dict(zip(range(6), "abcdef")) + nx.relabel_nodes(H, mapping, copy=False) + G = nx.union(G, H) + chains = list(nx.chain_decomposition(G)) + expected = [ + [(0, 1), (1, 2), (2, 0)], + [(3, 4), (4, 5), (5, 3)], + [("a", "b"), ("b", "c"), ("c", "a")], + [("d", "e"), ("e", "f"), ("f", "d")], + ] + assert len(chains) == len(expected) + for chain in chains: + self.assertContainsChain(chain, expected) + + def test_disconnected_graph_root_node(self): + """Test for a single component of a disconnected graph.""" + G = nx.barbell_graph(3, 0) + H = nx.barbell_graph(3, 0) + mapping = dict(zip(range(6), "abcdef")) + nx.relabel_nodes(H, mapping, copy=False) + G = nx.union(G, H) + chains = list(nx.chain_decomposition(G, root="a")) + expected = [ + [("a", "b"), ("b", "c"), ("c", "a")], + [("d", "e"), ("e", "f"), ("f", "d")], + ] + assert len(chains) == len(expected) + for chain in chains: + self.assertContainsChain(chain, expected) + + def test_chain_decomposition_root_not_in_G(self): + """Test chain decomposition when root is not in graph""" + G = nx.Graph() + G.add_nodes_from([1, 2, 3]) + with pytest.raises(nx.NodeNotFound): + nx.has_bridges(G, root=6) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_chordal.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_chordal.py new file mode 100644 index 0000000000000000000000000000000000000000..148b22f2632d722522483b556f11285a8e823126 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_chordal.py @@ -0,0 +1,129 @@ +import pytest + +import networkx as nx + + +class TestMCS: + @classmethod + def setup_class(cls): + # simple graph + connected_chordal_G = nx.Graph() + connected_chordal_G.add_edges_from( + [ + (1, 2), + (1, 3), + (2, 3), + (2, 4), + (3, 4), + (3, 5), + (3, 6), + (4, 5), + (4, 6), + (5, 6), + ] + ) + cls.connected_chordal_G = connected_chordal_G + + chordal_G = nx.Graph() + chordal_G.add_edges_from( + [ + (1, 2), + (1, 3), + (2, 3), + (2, 4), + (3, 4), + (3, 5), + (3, 6), + (4, 5), + (4, 6), + (5, 6), + (7, 8), + ] + ) + chordal_G.add_node(9) + cls.chordal_G = chordal_G + + non_chordal_G = nx.Graph() + non_chordal_G.add_edges_from([(1, 2), (1, 3), (2, 4), (2, 5), (3, 4), (3, 5)]) + cls.non_chordal_G = non_chordal_G + + self_loop_G = nx.Graph() + self_loop_G.add_edges_from([(1, 1)]) + cls.self_loop_G = self_loop_G + + @pytest.mark.parametrize("G", (nx.DiGraph(), nx.MultiGraph(), nx.MultiDiGraph())) + def test_is_chordal_not_implemented(self, G): + with pytest.raises(nx.NetworkXNotImplemented): + nx.is_chordal(G) + + def test_is_chordal(self): + assert not nx.is_chordal(self.non_chordal_G) + assert nx.is_chordal(self.chordal_G) + assert nx.is_chordal(self.connected_chordal_G) + assert nx.is_chordal(nx.Graph()) + assert nx.is_chordal(nx.complete_graph(3)) + assert nx.is_chordal(nx.cycle_graph(3)) + assert not nx.is_chordal(nx.cycle_graph(5)) + assert nx.is_chordal(self.self_loop_G) + + def test_induced_nodes(self): + G = nx.generators.classic.path_graph(10) + Induced_nodes = nx.find_induced_nodes(G, 1, 9, 2) + assert Induced_nodes == {1, 2, 3, 4, 5, 6, 7, 8, 9} + pytest.raises( + nx.NetworkXTreewidthBoundExceeded, nx.find_induced_nodes, G, 1, 9, 1 + ) + Induced_nodes = nx.find_induced_nodes(self.chordal_G, 1, 6) + assert Induced_nodes == {1, 2, 4, 6} + pytest.raises(nx.NetworkXError, nx.find_induced_nodes, self.non_chordal_G, 1, 5) + + def test_graph_treewidth(self): + with pytest.raises(nx.NetworkXError, match="Input graph is not chordal"): + nx.chordal_graph_treewidth(self.non_chordal_G) + + def test_chordal_find_cliques(self): + cliques = { + frozenset([9]), + frozenset([7, 8]), + frozenset([1, 2, 3]), + frozenset([2, 3, 4]), + frozenset([3, 4, 5, 6]), + } + assert set(nx.chordal_graph_cliques(self.chordal_G)) == cliques + with pytest.raises(nx.NetworkXError, match="Input graph is not chordal"): + set(nx.chordal_graph_cliques(self.non_chordal_G)) + with pytest.raises(nx.NetworkXError, match="Input graph is not chordal"): + set(nx.chordal_graph_cliques(self.self_loop_G)) + + def test_chordal_find_cliques_path(self): + G = nx.path_graph(10) + cliqueset = nx.chordal_graph_cliques(G) + for u, v in G.edges(): + assert frozenset([u, v]) in cliqueset or frozenset([v, u]) in cliqueset + + def test_chordal_find_cliquesCC(self): + cliques = {frozenset([1, 2, 3]), frozenset([2, 3, 4]), frozenset([3, 4, 5, 6])} + cgc = nx.chordal_graph_cliques + assert set(cgc(self.connected_chordal_G)) == cliques + + def test_complete_to_chordal_graph(self): + fgrg = nx.fast_gnp_random_graph + test_graphs = [ + nx.barbell_graph(6, 2), + nx.cycle_graph(15), + nx.wheel_graph(20), + nx.grid_graph([10, 4]), + nx.ladder_graph(15), + nx.star_graph(5), + nx.bull_graph(), + fgrg(20, 0.3, seed=1), + ] + for G in test_graphs: + H, a = nx.complete_to_chordal_graph(G) + assert nx.is_chordal(H) + assert len(a) == H.number_of_nodes() + if nx.is_chordal(G): + assert G.number_of_edges() == H.number_of_edges() + assert set(a.values()) == {0} + else: + assert len(set(a.values())) == H.number_of_nodes() diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_clique.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_clique.py new file mode 100644 index 0000000000000000000000000000000000000000..3bee210982888a142f07a043bbde24bdad80fae9 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_clique.py @@ -0,0 +1,291 @@ +import pytest + +import networkx as nx +from networkx import convert_node_labels_to_integers as cnlti + + +class TestCliques: + def setup_method(self): + z = [3, 4, 3, 4, 2, 4, 2, 1, 1, 1, 1] + self.G = cnlti(nx.generators.havel_hakimi_graph(z), first_label=1) + self.cl = list(nx.find_cliques(self.G)) + H = nx.complete_graph(6) + H = nx.relabel_nodes(H, {i: i + 1 for i in range(6)}) + H.remove_edges_from([(2, 6), (2, 5), (2, 4), (1, 3), (5, 3)]) + self.H = H + + def test_find_cliques1(self): + cl = list(nx.find_cliques(self.G)) + rcl = nx.find_cliques_recursive(self.G) + expected = [[2, 6, 1, 3], [2, 6, 4], [5, 4, 7], [8, 9], [10, 11]] + assert sorted(map(sorted, cl)) == sorted(map(sorted, rcl)) + assert sorted(map(sorted, cl)) == sorted(map(sorted, expected)) + + def test_selfloops(self): + self.G.add_edge(1, 1) + cl = list(nx.find_cliques(self.G)) + rcl = list(nx.find_cliques_recursive(self.G)) + assert set(map(frozenset, cl)) == set(map(frozenset, rcl)) + answer = [{2, 6, 1, 3}, {2, 6, 4}, {5, 4, 7}, {8, 9}, {10, 11}] + assert len(answer) == len(cl) + assert all(set(c) in answer for c in cl) + + def test_find_cliques2(self): + hcl = list(nx.find_cliques(self.H)) + assert sorted(map(sorted, hcl)) == [[1, 2], [1, 4, 5, 6], [2, 3], [3, 4, 6]] + + def test_find_cliques3(self): + # all cliques are [[2, 6, 1, 3], [2, 6, 4], [5, 4, 7], [8, 9], [10, 11]] + + cl = list(nx.find_cliques(self.G, [2])) + rcl = nx.find_cliques_recursive(self.G, [2]) + expected = [[2, 6, 1, 3], [2, 6, 4]] + assert sorted(map(sorted, rcl)) == sorted(map(sorted, expected)) + assert sorted(map(sorted, cl)) == sorted(map(sorted, expected)) + + cl = list(nx.find_cliques(self.G, [2, 3])) + rcl = nx.find_cliques_recursive(self.G, [2, 3]) + expected = [[2, 6, 1, 3]] + assert sorted(map(sorted, rcl)) == sorted(map(sorted, expected)) + assert sorted(map(sorted, cl)) == sorted(map(sorted, expected)) + + cl = list(nx.find_cliques(self.G, [2, 6, 4])) + rcl = nx.find_cliques_recursive(self.G, [2, 6, 4]) + expected = [[2, 6, 4]] + assert sorted(map(sorted, rcl)) == sorted(map(sorted, expected)) + assert sorted(map(sorted, cl)) == sorted(map(sorted, expected)) + + cl = list(nx.find_cliques(self.G, [2, 6, 4])) + rcl = nx.find_cliques_recursive(self.G, [2, 6, 4]) + expected = [[2, 6, 4]] + assert sorted(map(sorted, rcl)) == sorted(map(sorted, expected)) + assert sorted(map(sorted, cl)) == sorted(map(sorted, expected)) + + with pytest.raises(ValueError): + list(nx.find_cliques(self.G, [2, 6, 4, 1])) + + with pytest.raises(ValueError): + list(nx.find_cliques_recursive(self.G, [2, 6, 4, 1])) + + def test_number_of_cliques(self): + G = self.G + assert nx.number_of_cliques(G, 1) == 1 + assert list(nx.number_of_cliques(G, [1]).values()) == [1] + assert list(nx.number_of_cliques(G, [1, 2]).values()) == [1, 2] + assert nx.number_of_cliques(G, [1, 2]) == {1: 1, 2: 2} + assert nx.number_of_cliques(G, 2) == 2 + assert nx.number_of_cliques(G) == { + 1: 1, + 2: 2, + 3: 1, + 4: 2, + 5: 1, + 6: 2, + 7: 1, + 8: 1, + 9: 1, + 10: 1, + 11: 1, + } + assert nx.number_of_cliques(G, nodes=list(G)) == { + 1: 1, + 2: 2, + 3: 1, + 4: 2, + 5: 1, + 6: 2, + 7: 1, + 8: 1, + 9: 1, + 10: 1, + 11: 1, + } + assert nx.number_of_cliques(G, nodes=[2, 3, 4]) == {2: 2, 3: 1, 4: 2} + assert nx.number_of_cliques(G, cliques=self.cl) == { + 1: 1, + 2: 2, + 3: 1, + 4: 2, + 5: 1, + 6: 2, + 7: 1, + 8: 1, + 9: 1, + 10: 1, + 11: 1, + } + assert nx.number_of_cliques(G, list(G), cliques=self.cl) == { + 1: 1, + 2: 2, + 3: 1, + 4: 2, + 5: 1, + 6: 2, + 7: 1, + 8: 1, + 9: 1, + 10: 1, + 11: 1, + } + + def test_node_clique_number(self): + G = self.G + assert nx.node_clique_number(G, 1) == 4 + assert list(nx.node_clique_number(G, [1]).values()) == [4] + assert list(nx.node_clique_number(G, [1, 2]).values()) == [4, 4] + assert nx.node_clique_number(G, [1, 2]) == {1: 4, 2: 4} + assert nx.node_clique_number(G, 1) == 4 + assert nx.node_clique_number(G) == { + 1: 4, + 2: 4, + 3: 4, + 4: 3, + 5: 3, + 6: 4, + 7: 3, + 8: 2, + 9: 2, + 10: 2, + 11: 2, + } + assert nx.node_clique_number(G, cliques=self.cl) == { + 1: 4, + 2: 4, + 3: 4, + 4: 3, + 5: 3, + 6: 4, + 7: 3, + 8: 2, + 9: 2, + 10: 2, + 11: 2, + } + assert nx.node_clique_number(G, [1, 2], cliques=self.cl) == {1: 4, 2: 4} + assert nx.node_clique_number(G, 1, cliques=self.cl) == 4 + + def test_make_clique_bipartite(self): + G = self.G + B = nx.make_clique_bipartite(G) + assert sorted(B) == [-5, -4, -3, -2, -1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] + # Project onto the nodes of the original graph. + H = nx.projected_graph(B, range(1, 12)) + assert H.adj == G.adj + # Project onto the nodes representing the cliques. + H1 = nx.projected_graph(B, range(-5, 0)) + # Relabel the negative numbers as positive ones. + H1 = nx.relabel_nodes(H1, {-v: v for v in range(1, 6)}) + assert sorted(H1) == [1, 2, 3, 4, 5] + + def test_make_max_clique_graph(self): + """Tests that the maximal clique graph is the same as the bipartite + clique graph after being projected onto the nodes representing the + cliques. + + """ + G = self.G + B = nx.make_clique_bipartite(G) + # Project onto the nodes representing the cliques. + H1 = nx.projected_graph(B, range(-5, 0)) + # Relabel the negative numbers as nonnegative ones, starting at + # 0. + H1 = nx.relabel_nodes(H1, {-v: v - 1 for v in range(1, 6)}) + H2 = nx.make_max_clique_graph(G) + assert H1.adj == H2.adj + + def test_directed(self): + with pytest.raises(nx.NetworkXNotImplemented): + next(nx.find_cliques(nx.DiGraph())) + + def test_find_cliques_trivial(self): + G = nx.Graph() + assert sorted(nx.find_cliques(G)) == [] + assert sorted(nx.find_cliques_recursive(G)) == [] + + def test_make_max_clique_graph_create_using(self): + G = nx.Graph([(1, 2), (3, 1), (4, 1), (5, 6)]) + E = nx.Graph([(0, 1), (0, 2), (1, 2)]) + E.add_node(3) + assert nx.is_isomorphic(nx.make_max_clique_graph(G, create_using=nx.Graph), E) + + +class TestEnumerateAllCliques: + def test_paper_figure_4(self): + # Same graph as given in Fig. 4 of paper enumerate_all_cliques is + # based on. + # http://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=1559964&isnumber=33129 + G = nx.Graph() + edges_fig_4 = [ + ("a", "b"), + ("a", "c"), + ("a", "d"), + ("a", "e"), + ("b", "c"), + ("b", "d"), + ("b", "e"), + ("c", "d"), + ("c", "e"), + ("d", "e"), + ("f", "b"), + ("f", "c"), + ("f", "g"), + ("g", "f"), + ("g", "c"), + ("g", "d"), + ("g", "e"), + ] + G.add_edges_from(edges_fig_4) + + cliques = list(nx.enumerate_all_cliques(G)) + clique_sizes = list(map(len, cliques)) + assert sorted(clique_sizes) == clique_sizes + + expected_cliques = [ + ["a"], + ["b"], + ["c"], + ["d"], + ["e"], + ["f"], + ["g"], + ["a", "b"], + ["a", "b", "d"], + ["a", "b", "d", "e"], + ["a", "b", "e"], + ["a", "c"], + ["a", "c", "d"], + ["a", "c", "d", "e"], + ["a", "c", "e"], + ["a", "d"], + ["a", "d", "e"], + ["a", "e"], + ["b", "c"], + ["b", "c", "d"], + ["b", "c", "d", "e"], + ["b", "c", "e"], + ["b", "c", "f"], + ["b", "d"], + ["b", "d", "e"], + ["b", "e"], + ["b", "f"], + ["c", "d"], + ["c", "d", "e"], + ["c", "d", "e", "g"], + ["c", "d", "g"], + ["c", "e"], + ["c", "e", "g"], + ["c", "f"], + ["c", "f", "g"], + ["c", "g"], + ["d", "e"], + ["d", "e", "g"], + ["d", "g"], + ["e", "g"], + ["f", "g"], + ["a", "b", "c"], + ["a", "b", "c", "d"], + ["a", "b", "c", "d", "e"], + ["a", "b", "c", "e"], + ] + + assert sorted(map(sorted, cliques)) == sorted(map(sorted, expected_cliques)) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_cluster.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_cluster.py new file mode 100644 index 0000000000000000000000000000000000000000..b656ba81553bc9bdca1fe828c5e9c01427c43555 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_cluster.py @@ -0,0 +1,549 @@ +import pytest + +import networkx as nx + + +class TestTriangles: + def test_empty(self): + G = nx.Graph() + assert list(nx.triangles(G).values()) == [] + + def test_path(self): + G = nx.path_graph(10) + assert list(nx.triangles(G).values()) == [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + assert nx.triangles(G) == { + 0: 0, + 1: 0, + 2: 0, + 3: 0, + 4: 0, + 5: 0, + 6: 0, + 7: 0, + 8: 0, + 9: 0, + } + + def test_cubical(self): + G = nx.cubical_graph() + assert list(nx.triangles(G).values()) == [0, 0, 0, 0, 0, 0, 0, 0] + assert nx.triangles(G, 1) == 0 + assert list(nx.triangles(G, [1, 2]).values()) == [0, 0] + assert nx.triangles(G, 1) == 0 + assert nx.triangles(G, [1, 2]) == {1: 0, 2: 0} + + def test_k5(self): + G = nx.complete_graph(5) + assert list(nx.triangles(G).values()) == [6, 6, 6, 6, 6] + assert sum(nx.triangles(G).values()) / 3 == 10 + assert nx.triangles(G, 1) == 6 + G.remove_edge(1, 2) + assert list(nx.triangles(G).values()) == [5, 3, 3, 5, 5] + assert nx.triangles(G, 1) == 3 + G.add_edge(3, 3) # ignore self-edges + assert list(nx.triangles(G).values()) == [5, 3, 3, 5, 5] + assert nx.triangles(G, 3) == 5 + + +class TestDirectedClustering: + def test_clustering(self): + G = nx.DiGraph() + assert list(nx.clustering(G).values()) == [] + assert nx.clustering(G) == {} + + def test_path(self): + G = nx.path_graph(10, create_using=nx.DiGraph()) + assert list(nx.clustering(G).values()) == [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ] + assert nx.clustering(G) == { + 0: 0, + 1: 0, + 2: 0, + 3: 0, + 4: 0, + 5: 0, + 6: 0, + 7: 0, + 8: 0, + 9: 0, + } + assert nx.clustering(G, 0) == 0 + + def test_k5(self): + G = nx.complete_graph(5, create_using=nx.DiGraph()) + assert list(nx.clustering(G).values()) == [1, 1, 1, 1, 1] + assert nx.average_clustering(G) == 1 + G.remove_edge(1, 2) + assert list(nx.clustering(G).values()) == [ + 11 / 12, + 1, + 1, + 11 / 12, + 11 / 12, + ] + assert nx.clustering(G, [1, 4]) == {1: 1, 4: 11 / 12} + G.remove_edge(2, 1) + assert list(nx.clustering(G).values()) == [ + 5 / 6, + 1, + 1, + 5 / 6, + 5 / 6, + ] + assert nx.clustering(G, [1, 4]) == {1: 1, 4: 0.83333333333333337} + assert nx.clustering(G, 4) == 5 / 6 + + def test_triangle_and_edge(self): + G = nx.cycle_graph(3, create_using=nx.DiGraph()) + G.add_edge(0, 4) + assert nx.clustering(G)[0] == 1 / 6 + + +class TestDirectedWeightedClustering: + @classmethod + def setup_class(cls): + global np + np = pytest.importorskip("numpy") + + def test_clustering(self): + G = nx.DiGraph() + assert list(nx.clustering(G, weight="weight").values()) == [] + assert nx.clustering(G) == {} + + def test_path(self): + G = nx.path_graph(10, create_using=nx.DiGraph()) + assert list(nx.clustering(G, weight="weight").values()) == [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ] + assert nx.clustering(G, weight="weight") == { + 0: 0, + 1: 0, + 2: 0, + 3: 0, + 4: 0, + 5: 0, + 6: 0, + 7: 0, + 8: 0, + 9: 0, + } + + def test_k5(self): + G = nx.complete_graph(5, create_using=nx.DiGraph()) + assert list(nx.clustering(G, weight="weight").values()) == [1, 1, 1, 1, 1] + assert nx.average_clustering(G, weight="weight") == 1 + G.remove_edge(1, 2) + assert list(nx.clustering(G, weight="weight").values()) == [ + 11 / 12, + 1, + 1, + 11 / 12, + 11 / 12, + ] + assert nx.clustering(G, [1, 4], weight="weight") == {1: 1, 4: 11 / 12} + G.remove_edge(2, 1) + assert list(nx.clustering(G, weight="weight").values()) == [ + 5 / 6, + 1, + 1, + 5 / 6, + 5 / 6, + ] + assert nx.clustering(G, [1, 4], weight="weight") == { + 1: 1, + 4: 0.83333333333333337, + } + + def test_triangle_and_edge(self): + G = nx.cycle_graph(3, create_using=nx.DiGraph()) + G.add_edge(0, 4, weight=2) + assert nx.clustering(G)[0] == 1 / 6 + # Relaxed comparisons to allow graphblas-algorithms to pass tests + np.testing.assert_allclose(nx.clustering(G, weight="weight")[0], 1 / 12) + np.testing.assert_allclose(nx.clustering(G, 0, weight="weight"), 1 / 12) + + +class TestWeightedClustering: + @classmethod + def setup_class(cls): + global np + np = pytest.importorskip("numpy") + + def test_clustering(self): + G = nx.Graph() + assert list(nx.clustering(G, weight="weight").values()) == [] + assert nx.clustering(G) == {} + + def test_path(self): + G = nx.path_graph(10) + assert list(nx.clustering(G, weight="weight").values()) == [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ] + assert nx.clustering(G, weight="weight") == { + 0: 0, + 1: 0, + 2: 0, + 3: 0, + 4: 0, + 5: 0, + 6: 0, + 7: 0, + 8: 0, + 9: 0, + } + + def test_cubical(self): + G = nx.cubical_graph() + assert list(nx.clustering(G, weight="weight").values()) == [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ] + assert nx.clustering(G, 1) == 0 + assert list(nx.clustering(G, [1, 2], weight="weight").values()) == [0, 0] + assert nx.clustering(G, 1, weight="weight") == 0 + assert nx.clustering(G, [1, 2], weight="weight") == {1: 0, 2: 0} + + def test_k5(self): + G = nx.complete_graph(5) + assert list(nx.clustering(G, weight="weight").values()) == [1, 1, 1, 1, 1] + assert nx.average_clustering(G, weight="weight") == 1 + G.remove_edge(1, 2) + assert list(nx.clustering(G, weight="weight").values()) == [ + 5 / 6, + 1, + 1, + 5 / 6, + 5 / 6, + ] + assert nx.clustering(G, [1, 4], weight="weight") == { + 1: 1, + 4: 0.83333333333333337, + } + + def test_triangle_and_edge(self): + G = nx.cycle_graph(3) + G.add_edge(0, 4, weight=2) + assert nx.clustering(G)[0] == 1 / 3 + np.testing.assert_allclose(nx.clustering(G, weight="weight")[0], 1 / 6) + np.testing.assert_allclose(nx.clustering(G, 0, weight="weight"), 1 / 6) + + def test_triangle_and_signed_edge(self): + G = nx.cycle_graph(3) + G.add_edge(0, 1, weight=-1) + G.add_edge(3, 0, weight=0) + assert nx.clustering(G)[0] == 1 / 3 + assert nx.clustering(G, weight="weight")[0] == -1 / 3 + + +class TestClustering: + @classmethod + def setup_class(cls): + pytest.importorskip("numpy") + + def test_clustering(self): + G = nx.Graph() + assert list(nx.clustering(G).values()) == [] + assert nx.clustering(G) == {} + + def test_path(self): + G = nx.path_graph(10) + assert list(nx.clustering(G).values()) == [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ] + assert nx.clustering(G) == { + 0: 0, + 1: 0, + 2: 0, + 3: 0, + 4: 0, + 5: 0, + 6: 0, + 7: 0, + 8: 0, + 9: 0, + } + + def test_cubical(self): + G = nx.cubical_graph() + assert list(nx.clustering(G).values()) == [0, 0, 0, 0, 0, 0, 0, 0] + assert nx.clustering(G, 1) == 0 + assert list(nx.clustering(G, [1, 2]).values()) == [0, 0] + assert nx.clustering(G, 1) == 0 + assert nx.clustering(G, [1, 2]) == {1: 0, 2: 0} + + def test_k5(self): + G = nx.complete_graph(5) + assert list(nx.clustering(G).values()) == [1, 1, 1, 1, 1] + assert nx.average_clustering(G) == 1 + G.remove_edge(1, 2) + assert list(nx.clustering(G).values()) == [ + 5 / 6, + 1, + 1, + 5 / 6, + 5 / 6, + ] + assert nx.clustering(G, [1, 4]) == {1: 1, 4: 0.83333333333333337} + + def test_k5_signed(self): + G = nx.complete_graph(5) + assert list(nx.clustering(G).values()) == [1, 1, 1, 1, 1] + assert nx.average_clustering(G) == 1 + G.remove_edge(1, 2) + G.add_edge(0, 1, weight=-1) + assert list(nx.clustering(G, weight="weight").values()) == [ + 1 / 6, + -1 / 3, + 1, + 3 / 6, + 3 / 6, + ] + + +class TestTransitivity: + def test_transitivity(self): + G = nx.Graph() + assert nx.transitivity(G) == 0 + + def test_path(self): + G = nx.path_graph(10) + assert nx.transitivity(G) == 0 + + def test_cubical(self): + G = nx.cubical_graph() + assert nx.transitivity(G) == 0 + + def test_k5(self): + G = nx.complete_graph(5) + assert nx.transitivity(G) == 1 + G.remove_edge(1, 2) + assert nx.transitivity(G) == 0.875 + + +class TestSquareClustering: + def test_clustering(self): + G = nx.Graph() + assert list(nx.square_clustering(G).values()) == [] + assert nx.square_clustering(G) == {} + + def test_path(self): + G = nx.path_graph(10) + assert list(nx.square_clustering(G).values()) == [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ] + assert nx.square_clustering(G) == { + 0: 0, + 1: 0, + 2: 0, + 3: 0, + 4: 0, + 5: 0, + 6: 0, + 7: 0, + 8: 0, + 9: 0, + } + + def test_cubical(self): + G = nx.cubical_graph() + assert list(nx.square_clustering(G).values()) == [ + 1 / 3, + 1 / 3, + 1 / 3, + 1 / 3, + 1 / 3, + 1 / 3, + 1 / 3, + 1 / 3, + ] + assert list(nx.square_clustering(G, [1, 2]).values()) == [1 / 3, 1 / 3] + assert nx.square_clustering(G, [1])[1] == 1 / 3 + assert nx.square_clustering(G, 1) == 1 / 3 + assert nx.square_clustering(G, [1, 2]) == {1: 1 / 3, 2: 1 / 3} + + def test_k5(self): + G = nx.complete_graph(5) + assert list(nx.square_clustering(G).values()) == [1, 1, 1, 1, 1] + + def test_bipartite_k5(self): + G = nx.complete_bipartite_graph(5, 5) + assert list(nx.square_clustering(G).values()) == [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] + + def test_lind_square_clustering(self): + """Test C4 for figure 1 Lind et al (2005)""" + G = nx.Graph( + [ + (1, 2), + (1, 3), + (1, 6), + (1, 7), + (2, 4), + (2, 5), + (3, 4), + (3, 5), + (6, 7), + (7, 8), + (6, 8), + (7, 9), + (7, 10), + (6, 11), + (6, 12), + (2, 13), + (2, 14), + (3, 15), + (3, 16), + ] + ) + G1 = G.subgraph([1, 2, 3, 4, 5, 13, 14, 15, 16]) + G2 = G.subgraph([1, 6, 7, 8, 9, 10, 11, 12]) + assert nx.square_clustering(G, [1])[1] == 3 / 43 + assert nx.square_clustering(G1, [1])[1] == 2 / 6 + assert nx.square_clustering(G2, [1])[1] == 1 / 5 + + def test_peng_square_clustering(self): + """Test eq2 for figure 1 Peng et al (2008)""" + G = nx.Graph([(1, 2), (1, 3), (2, 4), (3, 4), (3, 5), (3, 6)]) + assert nx.square_clustering(G, [1])[1] == 1 / 3 + + def test_self_loops_square_clustering(self): + G = nx.path_graph(5) + assert nx.square_clustering(G) == {0: 0, 1: 0.0, 2: 0.0, 3: 0.0, 4: 0} + G.add_edges_from([(0, 0), (1, 1), (2, 2)]) + assert nx.square_clustering(G) == {0: 1, 1: 0.5, 2: 0.2, 3: 0.0, 4: 0} + + +class TestAverageClustering: + @classmethod + def setup_class(cls): + pytest.importorskip("numpy") + + def test_empty(self): + G = nx.Graph() + with pytest.raises(ZeroDivisionError): + nx.average_clustering(G) + + def test_average_clustering(self): + G = nx.cycle_graph(3) + G.add_edge(2, 3) + assert nx.average_clustering(G) == (1 + 1 + 1 / 3) / 4 + assert nx.average_clustering(G, count_zeros=True) == (1 + 1 + 1 / 3) / 4 + assert nx.average_clustering(G, count_zeros=False) == (1 + 1 + 1 / 3) / 3 + assert nx.average_clustering(G, [1, 2, 3]) == (1 + 1 / 3) / 3 + assert nx.average_clustering(G, [1, 2, 3], count_zeros=True) == (1 + 1 / 3) / 3 + assert nx.average_clustering(G, [1, 2, 3], count_zeros=False) == (1 + 1 / 3) / 2 + + def test_average_clustering_signed(self): + G = nx.cycle_graph(3) + G.add_edge(2, 3) + G.add_edge(0, 1, weight=-1) + assert nx.average_clustering(G, weight="weight") == (-1 - 1 - 1 / 3) / 4 + assert ( + nx.average_clustering(G, weight="weight", count_zeros=True) + == (-1 - 1 - 1 / 3) / 4 + ) + assert ( + nx.average_clustering(G, weight="weight", count_zeros=False) + == (-1 - 1 - 1 / 3) / 3 + ) + + +class TestDirectedAverageClustering: + @classmethod + def setup_class(cls): + pytest.importorskip("numpy") + + def test_empty(self): + G = nx.DiGraph() + with pytest.raises(ZeroDivisionError): + nx.average_clustering(G) + + def test_average_clustering(self): + G = nx.cycle_graph(3, create_using=nx.DiGraph()) + G.add_edge(2, 3) + assert nx.average_clustering(G) == (1 + 1 + 1 / 3) / 8 + assert nx.average_clustering(G, count_zeros=True) == (1 + 1 + 1 / 3) / 8 + assert nx.average_clustering(G, count_zeros=False) == (1 + 1 + 1 / 3) / 6 + assert nx.average_clustering(G, [1, 2, 3]) == (1 + 1 / 3) / 6 + assert nx.average_clustering(G, [1, 2, 3], count_zeros=True) == (1 + 1 / 3) / 6 + assert nx.average_clustering(G, [1, 2, 3], count_zeros=False) == (1 + 1 / 3) / 4 + + +class TestGeneralizedDegree: + def test_generalized_degree(self): + G = nx.Graph() + assert nx.generalized_degree(G) == {} + + def test_path(self): + G = nx.path_graph(5) + assert nx.generalized_degree(G, 0) == {0: 1} + assert nx.generalized_degree(G, 1) == {0: 2} + + def test_cubical(self): + G = nx.cubical_graph() + assert nx.generalized_degree(G, 0) == {0: 3} + + def test_k5(self): + G = nx.complete_graph(5) + assert nx.generalized_degree(G, 0) == {3: 4} + G.remove_edge(0, 1) + assert nx.generalized_degree(G, 0) == {2: 3} + assert nx.generalized_degree(G, [1, 2]) == {1: {2: 3}, 2: {2: 2, 3: 2}} + assert nx.generalized_degree(G) == { + 0: {2: 3}, + 1: {2: 3}, + 2: {2: 2, 3: 2}, + 3: {2: 2, 3: 2}, + 4: {2: 2, 3: 2}, + } diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_communicability.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_communicability.py new file mode 100644 index 0000000000000000000000000000000000000000..0f447094548415c089710b9b62ac4d73a27efeb5 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_communicability.py @@ -0,0 +1,80 @@ +from collections import defaultdict + +import pytest + +pytest.importorskip("numpy") +pytest.importorskip("scipy") + +import networkx as nx +from networkx.algorithms.communicability_alg import communicability, communicability_exp + + +class TestCommunicability: + def test_communicability(self): + answer = { + 0: {0: 1.5430806348152435, 1: 1.1752011936438012}, + 1: {0: 1.1752011936438012, 1: 1.5430806348152435}, + } + # answer={(0, 0): 1.5430806348152435, + # (0, 1): 1.1752011936438012, + # (1, 0): 1.1752011936438012, + # (1, 1): 1.5430806348152435} + + result = communicability(nx.path_graph(2)) + for k1, val in result.items(): + for k2 in val: + assert answer[k1][k2] == pytest.approx(result[k1][k2], abs=1e-7) + + def test_communicability2(self): + answer_orig = { + ("1", "1"): 1.6445956054135658, + ("1", "Albert"): 0.7430186221096251, + ("1", "Aric"): 0.7430186221096251, + ("1", "Dan"): 1.6208126320442937, + ("1", "Franck"): 0.42639707170035257, + ("Albert", "1"): 0.7430186221096251, + ("Albert", "Albert"): 2.4368257358712189, + ("Albert", "Aric"): 1.4368257358712191, + ("Albert", "Dan"): 2.0472097037446453, + ("Albert", "Franck"): 1.8340111678944691, + ("Aric", "1"): 0.7430186221096251, + ("Aric", "Albert"): 1.4368257358712191, + ("Aric", "Aric"): 2.4368257358712193, + ("Aric", "Dan"): 2.0472097037446457, + ("Aric", "Franck"): 1.8340111678944691, + ("Dan", "1"): 1.6208126320442937, + ("Dan", "Albert"): 2.0472097037446453, + ("Dan", "Aric"): 2.0472097037446457, + ("Dan", "Dan"): 3.1306328496328168, + ("Dan", "Franck"): 1.4860372442192515, + ("Franck", "1"): 0.42639707170035257, + ("Franck", "Albert"): 1.8340111678944691, + ("Franck", "Aric"): 1.8340111678944691, + ("Franck", "Dan"): 1.4860372442192515, + ("Franck", "Franck"): 2.3876142275231915, + } + + answer = defaultdict(dict) + for (k1, k2), v in answer_orig.items(): + answer[k1][k2] = v + + G1 = nx.Graph( + [ + ("Franck", "Aric"), + ("Aric", "Dan"), + ("Dan", "Albert"), + ("Albert", "Franck"), + ("Dan", "1"), + ("Franck", "Albert"), + ] + ) + + result = communicability(G1) + for k1, val in result.items(): + for k2 in val: + assert answer[k1][k2] == pytest.approx(result[k1][k2], abs=1e-7) + + result = communicability_exp(G1) + for k1, val in result.items(): + for k2 in val: + assert answer[k1][k2] == pytest.approx(result[k1][k2], abs=1e-7) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_core.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_core.py new file mode 100644 index 0000000000000000000000000000000000000000..726e98a70033e6320a031889aac24a03af82b441 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_core.py @@ -0,0 +1,266 @@ +import pytest + +import networkx as nx +from networkx.utils import nodes_equal + + +class TestCore: + @classmethod + def setup_class(cls): + # G is the example graph in Figure 1 from Batagelj and + # Zaversnik's paper titled An O(m) Algorithm for Cores + # Decomposition of Networks, 2003, + # http://arXiv.org/abs/cs/0310049. With nodes labeled as + # shown, the 3-core is given by nodes 1-8, the 2-core by nodes + # 9-16, the 1-core by nodes 17-20 and node 21 is in the + # 0-core. + t1 = nx.convert_node_labels_to_integers(nx.tetrahedral_graph(), 1) + t2 = nx.convert_node_labels_to_integers(t1, 5) + G = nx.union(t1, t2) + G.add_edges_from( + [ + (3, 7), + (2, 11), + (11, 5), + (11, 12), + (5, 12), + (12, 19), + (12, 18), + (3, 9), + (7, 9), + (7, 10), + (9, 10), + (9, 20), + (17, 13), + (13, 14), + (14, 15), + (15, 16), + (16, 13), + ] + ) + G.add_node(21) + cls.G = G + + # Create the graph H resulting from the degree sequence + # [0, 1, 2, 2, 2, 2, 3] when using the Havel-Hakimi algorithm. + + degseq = [0, 1, 2, 2, 2, 2, 3] + H = nx.havel_hakimi_graph(degseq) + mapping = {6: 0, 0: 1, 4: 3, 5: 6, 3: 4, 1: 2, 2: 5} + cls.H = nx.relabel_nodes(H, mapping) + + def test_trivial(self): + """Empty graph""" + G = nx.Graph() + assert nx.core_number(G) == {} + + def test_core_number(self): + core = nx.core_number(self.G) + nodes_by_core = [sorted(n for n in core if core[n] == val) for val in range(4)] + assert nodes_equal(nodes_by_core[0], [21]) + assert nodes_equal(nodes_by_core[1], [17, 18, 19, 20]) + assert nodes_equal(nodes_by_core[2], [9, 10, 11, 12, 13, 14, 15, 16]) + assert nodes_equal(nodes_by_core[3], [1, 2, 3, 4, 5, 6, 7, 8]) + + def test_core_number2(self): + core = nx.core_number(self.H) + nodes_by_core = [sorted(n for n in core if core[n] == val) for val in range(3)] + assert nodes_equal(nodes_by_core[0], [0]) + assert nodes_equal(nodes_by_core[1], [1, 3]) + assert nodes_equal(nodes_by_core[2], [2, 4, 5, 6]) + + def test_core_number_multigraph(self): + G = nx.complete_graph(3) + G = nx.MultiGraph(G) + G.add_edge(1, 2) + with pytest.raises( + nx.NetworkXNotImplemented, match="not implemented for multigraph type" + ): + nx.core_number(G) + + def test_core_number_self_loop(self): + G = nx.cycle_graph(3) + G.add_edge(0, 0) + with pytest.raises( + nx.NetworkXNotImplemented, match="Input graph has self loops" + ): + nx.core_number(G) + + def test_directed_core_number(self): + """core number had a bug for directed graphs found in issue #1959""" + # small example where too timid edge removal can make cn[2] = 3 + G = nx.DiGraph() + edges = [(1, 2), (2, 1), (2, 3), (2, 4), (3, 4), (4, 3)] + G.add_edges_from(edges) + assert nx.core_number(G) == {1: 2, 2: 2, 3: 2, 4: 2} + # small example where too aggressive edge removal can make cn[2] = 2 + more_edges = [(1, 5), (3, 5), (4, 5), (3, 6), (4, 6), (5, 6)] + G.add_edges_from(more_edges) + assert nx.core_number(G) == {1: 3, 2: 3, 3: 3, 4: 3, 5: 3, 6: 3} + + def test_main_core(self): + main_core_subgraph = nx.k_core(self.H) + assert sorted(main_core_subgraph.nodes()) == [2, 4, 5, 6] + + def test_k_core(self): + # k=0 + k_core_subgraph = nx.k_core(self.H, k=0) + assert sorted(k_core_subgraph.nodes()) == sorted(self.H.nodes()) + # k=1 + k_core_subgraph = nx.k_core(self.H, k=1) + assert sorted(k_core_subgraph.nodes()) == [1, 2, 3, 4, 5, 6] + # k = 2 + k_core_subgraph = nx.k_core(self.H, k=2) + assert sorted(k_core_subgraph.nodes()) == [2, 4, 5, 6] + + def test_k_core_multigraph(self): + core_number = nx.core_number(self.H) + H = nx.MultiGraph(self.H) + with pytest.deprecated_call(): + nx.k_core(H, k=0, core_number=core_number) + + def test_main_crust(self): + main_crust_subgraph = nx.k_crust(self.H) + assert sorted(main_crust_subgraph.nodes()) == [0, 1, 3] + + def test_k_crust(self): + # k = 0 + k_crust_subgraph = nx.k_crust(self.H, k=2) + assert sorted(k_crust_subgraph.nodes()) == sorted(self.H.nodes()) + # k=1 + k_crust_subgraph = nx.k_crust(self.H, k=1) + assert sorted(k_crust_subgraph.nodes()) == [0, 1, 3] + # k=2 + k_crust_subgraph = nx.k_crust(self.H, k=0) + assert sorted(k_crust_subgraph.nodes()) == [0] + + def test_k_crust_multigraph(self): + core_number = nx.core_number(self.H) + H = nx.MultiGraph(self.H) + with pytest.deprecated_call(): + nx.k_crust(H, k=0, core_number=core_number) + + def test_main_shell(self): + main_shell_subgraph = nx.k_shell(self.H) + assert sorted(main_shell_subgraph.nodes()) == [2, 4, 5, 6] + + def test_k_shell(self): + # k=0 + k_shell_subgraph = nx.k_shell(self.H, k=2) + assert sorted(k_shell_subgraph.nodes()) == [2, 4, 5, 6] + # k=1 + k_shell_subgraph = nx.k_shell(self.H, k=1) + assert sorted(k_shell_subgraph.nodes()) == [1, 3] + # k=2 + k_shell_subgraph = nx.k_shell(self.H, k=0) + assert sorted(k_shell_subgraph.nodes()) == [0] + + def test_k_shell_multigraph(self): + core_number = nx.core_number(self.H) + H = nx.MultiGraph(self.H) + with pytest.deprecated_call(): + nx.k_shell(H, k=0, core_number=core_number) + + def test_k_corona(self): + # k=0 + k_corona_subgraph = nx.k_corona(self.H, k=2) + assert sorted(k_corona_subgraph.nodes()) == [2, 4, 5, 6] + # k=1 + k_corona_subgraph = nx.k_corona(self.H, k=1) + assert sorted(k_corona_subgraph.nodes()) == [1] + # k=2 + k_corona_subgraph = nx.k_corona(self.H, k=0) + assert sorted(k_corona_subgraph.nodes()) == [0] + + def test_k_corona_multigraph(self): + core_number = nx.core_number(self.H) + H = nx.MultiGraph(self.H) + with pytest.deprecated_call(): + nx.k_corona(H, k=0, core_number=core_number) + + def test_k_truss(self): + # k=-1 + k_truss_subgraph = nx.k_truss(self.G, -1) + assert sorted(k_truss_subgraph.nodes()) == list(range(1, 21)) + # k=0 + k_truss_subgraph = nx.k_truss(self.G, 0) + assert sorted(k_truss_subgraph.nodes()) == list(range(1, 21)) + # k=1 + k_truss_subgraph = nx.k_truss(self.G, 1) + assert sorted(k_truss_subgraph.nodes()) == list(range(1, 21)) + # k=2 + k_truss_subgraph = nx.k_truss(self.G, 2) + assert sorted(k_truss_subgraph.nodes()) == list(range(1, 21)) + # k=3 + k_truss_subgraph = nx.k_truss(self.G, 3) + assert sorted(k_truss_subgraph.nodes()) == list(range(1, 13)) + + k_truss_subgraph = nx.k_truss(self.G, 4) + assert sorted(k_truss_subgraph.nodes()) == list(range(1, 9)) + + k_truss_subgraph = nx.k_truss(self.G, 5) + assert sorted(k_truss_subgraph.nodes()) == [] + + def test_k_truss_digraph(self): + G = nx.complete_graph(3) + G = nx.DiGraph(G) + G.add_edge(2, 1) + with pytest.raises( + nx.NetworkXNotImplemented, match="not implemented for directed type" + ): + nx.k_truss(G, k=1) + + def test_k_truss_multigraph(self): + G = nx.complete_graph(3) + G = nx.MultiGraph(G) + G.add_edge(1, 2) + with pytest.raises( + nx.NetworkXNotImplemented, match="not implemented for multigraph type" + ): + nx.k_truss(G, k=1) + + def test_k_truss_self_loop(self): + G = nx.cycle_graph(3) + G.add_edge(0, 0) + with pytest.raises( + nx.NetworkXNotImplemented, match="Input graph has self loops" + ): + nx.k_truss(G, k=1) + + def test_onion_layers(self): + layers = nx.onion_layers(self.G) + nodes_by_layer = [ + sorted(n for n in layers if layers[n] == val) for val in range(1, 7) + ] + assert nodes_equal(nodes_by_layer[0], [21]) + assert nodes_equal(nodes_by_layer[1], [17, 18, 19, 20]) + assert nodes_equal(nodes_by_layer[2], [10, 12, 13, 14, 15, 16]) + assert nodes_equal(nodes_by_layer[3], [9, 11]) + assert nodes_equal(nodes_by_layer[4], [1, 2, 4, 5, 6, 8]) + assert nodes_equal(nodes_by_layer[5], [3, 7]) + + def test_onion_digraph(self): + G = nx.complete_graph(3) + G = nx.DiGraph(G) + G.add_edge(2, 1) + with pytest.raises( + nx.NetworkXNotImplemented, match="not implemented for directed type" + ): + nx.onion_layers(G) + + def test_onion_multigraph(self): + G = nx.complete_graph(3) + G = nx.MultiGraph(G) + G.add_edge(1, 2) + with pytest.raises( + nx.NetworkXNotImplemented, match="not implemented for multigraph type" + ): + nx.onion_layers(G) + + def test_onion_self_loop(self): + G = nx.cycle_graph(3) + G.add_edge(0, 0) + with pytest.raises( + nx.NetworkXNotImplemented, match="Input graph contains self loops" + ): + nx.onion_layers(G) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_covering.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_covering.py new file mode 100644 index 0000000000000000000000000000000000000000..b2f97a866b0e09c199c2edb9f40f20986caa8fbc --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_covering.py @@ -0,0 +1,85 @@ +import pytest + +import networkx as nx + + +class TestMinEdgeCover: + """Tests for :func:`networkx.algorithms.min_edge_cover`""" + + def test_empty_graph(self): + G = nx.Graph() + assert nx.min_edge_cover(G) == set() + + def test_graph_with_loop(self): + G = nx.Graph() + G.add_edge(0, 0) + assert nx.min_edge_cover(G) == {(0, 0)} + + def test_graph_with_isolated_v(self): + G = nx.Graph() + G.add_node(1) + with pytest.raises( + nx.NetworkXException, + match="Graph has a node with no edge incident on it, so no edge cover exists.", + ): + nx.min_edge_cover(G) + + def test_graph_single_edge(self): + G = nx.Graph([(0, 1)]) + assert nx.min_edge_cover(G) in ({(0, 1)}, {(1, 0)}) + + def test_graph_two_edge_path(self): + G = nx.path_graph(3) + min_cover = nx.min_edge_cover(G) + assert len(min_cover) == 2 + for u, v in G.edges: + assert (u, v) in min_cover or (v, u) in min_cover + + def test_bipartite_explicit(self): + G = nx.Graph() + G.add_nodes_from([1, 2, 3, 4], bipartite=0) + G.add_nodes_from(["a", "b", "c"], bipartite=1) + G.add_edges_from([(1, "a"), (1, "b"), (2, "b"), (2, "c"), (3, "c"), (4, "a")]) + # Use bipartite method by prescribing the algorithm + min_cover = nx.min_edge_cover( + G, nx.algorithms.bipartite.matching.eppstein_matching + ) + assert nx.is_edge_cover(G, min_cover) + assert len(min_cover) == 8 + # Use the default method which is not specialized for bipartite + min_cover2 = nx.min_edge_cover(G) + assert nx.is_edge_cover(G, min_cover2) + assert len(min_cover2) == 4 + + def test_complete_graph_even(self): + G = nx.complete_graph(10) + min_cover = nx.min_edge_cover(G) + assert nx.is_edge_cover(G, min_cover) + assert len(min_cover) == 5 + + def test_complete_graph_odd(self): + G = nx.complete_graph(11) + min_cover = nx.min_edge_cover(G) + assert nx.is_edge_cover(G, min_cover) + assert len(min_cover) == 6 + + +class TestIsEdgeCover: + """Tests for :func:`networkx.algorithms.is_edge_cover`""" + + def test_empty_graph(self): + G = nx.Graph() + assert nx.is_edge_cover(G, set()) + + def test_graph_with_loop(self): + G = nx.Graph() + G.add_edge(1, 1) + assert nx.is_edge_cover(G, {(1, 1)}) + + def test_graph_single_edge(self): + G = nx.Graph() + G.add_edge(0, 1) + assert nx.is_edge_cover(G, {(0, 0), (1, 1)}) + assert nx.is_edge_cover(G, {(0, 1), (1, 0)}) + assert nx.is_edge_cover(G, {(0, 1)}) + assert not nx.is_edge_cover(G, {(0, 0)}) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_cuts.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_cuts.py new file mode 100644 index 0000000000000000000000000000000000000000..923efa502acc623650f36ff41e72884e5e508bc9 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_cuts.py @@ -0,0 +1,171 @@ +"""Unit tests for the :mod:`networkx.algorithms.cuts` module.""" + +import networkx as nx + + +class TestCutSize: + """Unit tests for the :func:`~networkx.cut_size` function.""" + + def test_symmetric(self): + """Tests that the cut size is symmetric.""" + G = nx.barbell_graph(3, 0) + S = {0, 1, 4} + T = {2, 3, 5} + assert nx.cut_size(G, S, T) == 4 + assert nx.cut_size(G, T, S) == 4 + + def test_single_edge(self): + """Tests for a cut of a single edge.""" + G = nx.barbell_graph(3, 0) + S = {0, 1, 2} + T = {3, 4, 5} + assert nx.cut_size(G, S, T) == 1 + assert nx.cut_size(G, T, S) == 1 + + def test_directed(self): + """Tests that each directed edge is counted once in the cut.""" + G = nx.barbell_graph(3, 0).to_directed() + S = {0, 1, 2} + T = {3, 4, 5} + assert nx.cut_size(G, S, T) == 2 + assert nx.cut_size(G, T, S) == 2 + + def test_directed_symmetric(self): + """Tests that a cut in a directed graph is symmetric.""" + G = nx.barbell_graph(3, 0).to_directed() + S = {0, 1, 4} + T = {2, 3, 5} + assert nx.cut_size(G, S, T) == 8 + assert nx.cut_size(G, T, S) == 8 + + def test_multigraph(self): + """Tests that parallel edges are each counted for a cut.""" + G = nx.MultiGraph(["ab", "ab"]) + assert nx.cut_size(G, {"a"}, {"b"}) == 2 + + +class TestVolume: + """Unit tests for the :func:`~networkx.volume` function.""" + + def test_graph(self): + G = nx.cycle_graph(4) + assert nx.volume(G, {0, 1}) == 4 + + def test_digraph(self): + G = nx.DiGraph([(0, 1), (1, 2), (2, 3), (3, 0)]) + assert nx.volume(G, {0, 1}) == 2 + + def test_multigraph(self): + edges = list(nx.cycle_graph(4).edges()) + G = nx.MultiGraph(edges * 2) + assert nx.volume(G, {0, 1}) == 8 + + def test_multidigraph(self): + edges = [(0, 1), (1, 2), (2, 3), (3, 0)] + G = nx.MultiDiGraph(edges * 2) + assert nx.volume(G, {0, 1}) == 4 + + def test_barbell(self): + G = nx.barbell_graph(3, 0) + assert nx.volume(G, {0, 1, 2}) == 7 + assert nx.volume(G, {3, 4, 5}) == 7 + + +class TestNormalizedCutSize: + """Unit tests for the :func:`~networkx.normalized_cut_size` function.""" + + def test_graph(self): + G = nx.path_graph(4) + S = {1, 2} + T = set(G) - S + size = nx.normalized_cut_size(G, S, T) + # The cut looks like this: o-{-o--o-}-o + expected = 2 * ((1 / 4) + (1 / 2)) + assert expected == size + # Test with no input T + assert expected == nx.normalized_cut_size(G, S) + + def test_directed(self): + G = nx.DiGraph([(0, 1), (1, 2), (2, 3)]) + S = {1, 2} + T = set(G) - S + size = nx.normalized_cut_size(G, S, T) + # The cut looks like this: o-{->o-->o-}->o + expected = 2 * ((1 / 2) + (1 / 1)) + assert expected == size + # Test with no input T + assert expected == nx.normalized_cut_size(G, S) + + +class TestConductance: + """Unit tests for the :func:`~networkx.conductance` function.""" + + def test_graph(self): + G = nx.barbell_graph(5, 0) + # Consider the singleton sets containing the "bridge" nodes. + # There is only one cut edge, and each set has volume five. + S = {4} + T = {5} + conductance = nx.conductance(G, S, T) + expected = 1 / 5 + assert expected == conductance + # Test with no input T + G2 = nx.barbell_graph(3, 0) + # There is only one cut edge, and each set has volume seven. + S2 = {0, 1, 2} + assert nx.conductance(G2, S2) == 1 / 7 + + +class TestEdgeExpansion: + """Unit tests for the :func:`~networkx.edge_expansion` function.""" + + def test_graph(self): + G = nx.barbell_graph(5, 0) + S = set(range(5)) + T = set(G) - S + expansion = nx.edge_expansion(G, S, T) + expected = 1 / 5 + assert expected == expansion + # Test with no input T + assert expected == nx.edge_expansion(G, S) + + +class TestNodeExpansion: + """Unit tests for the :func:`~networkx.node_expansion` function.""" + + def test_graph(self): + G = nx.path_graph(8) + S = {3, 4, 5} + expansion = nx.node_expansion(G, S) + # The neighborhood of S has cardinality five, and S has + # cardinality three. + expected = 5 / 3 + assert expected == expansion + + +class TestBoundaryExpansion: + """Unit tests for the :func:`~networkx.boundary_expansion` function.""" + + def test_graph(self): + G = nx.complete_graph(10) + S = set(range(4)) + expansion = nx.boundary_expansion(G, S) + # The node boundary of S has cardinality six, and S has + # cardinality three. + expected = 6 / 4 + assert expected == expansion + + +class TestMixingExpansion: + """Unit tests for the :func:`~networkx.mixing_expansion` function.""" + + def test_graph(self): + G = nx.barbell_graph(5, 0) + S = set(range(5)) + T = set(G) - S + expansion = nx.mixing_expansion(G, S, T) + # There is one cut edge, and the total number of edges in the + # graph is twice the total number of edges in a clique of size + # five, plus one more for the bridge. + expected = 1 / (2 * (5 * 4 + 1)) + assert expected == expansion diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_cycles.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_cycles.py new file mode 100644 index 0000000000000000000000000000000000000000..dd21405ff394d22bc463b61df3bf50cd1a504b40 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_cycles.py @@ -0,0 +1,974 @@ +from itertools import chain, islice, tee +from math import inf +from random import shuffle + +import pytest + +import networkx as nx +from networkx.algorithms.traversal.edgedfs import FORWARD, REVERSE + + +def check_independent(basis): + if len(basis) == 0: + return + + np = pytest.importorskip("numpy") + sp = pytest.importorskip("scipy") # Required by incidence_matrix + + H = nx.Graph() + for b in basis: + nx.add_cycle(H, b) + inc = nx.incidence_matrix(H, oriented=True) + rank = np.linalg.matrix_rank(inc.toarray(), tol=None, hermitian=False) + assert inc.shape[1] - rank == len(basis) + + +class TestCycles: + @classmethod + def setup_class(cls): + G = nx.Graph() + nx.add_cycle(G, [0, 1, 2, 3]) + nx.add_cycle(G, [0, 3, 4, 5]) + nx.add_cycle(G, [0, 1, 6, 7, 8]) + G.add_edge(8, 9) + cls.G = G + + def is_cyclic_permutation(self, a, b): + n = len(a) + if len(b) != n: + return False + l = a + a + return any(l[i : i + n] == b for i in range(n)) + + def test_cycle_basis(self): + G = self.G + cy = nx.cycle_basis(G, 0) + sort_cy = sorted(sorted(c) for c in cy) + assert sort_cy == [[0, 1, 2, 3], [0, 1, 6, 7, 8], [0, 3, 4, 5]] + cy = nx.cycle_basis(G, 1) + sort_cy = sorted(sorted(c) for c in cy) + assert sort_cy == [[0, 1, 2, 3], [0, 1, 6, 7, 8], [0, 3, 4, 5]] + cy = nx.cycle_basis(G, 9) + sort_cy = sorted(sorted(c) for c in cy) + assert sort_cy == [[0, 1, 2, 3], [0, 1, 6, 7, 8], [0, 3, 4, 5]] + # test disconnected graphs + nx.add_cycle(G, "ABC") + cy = nx.cycle_basis(G, 9) + sort_cy = sorted(sorted(c) for c in cy[:-1]) + [sorted(cy[-1])] + assert sort_cy == [[0, 1, 2, 3], [0, 1, 6, 7, 8], [0, 3, 4, 5], ["A", "B", "C"]] + + def test_cycle_basis2(self): + with pytest.raises(nx.NetworkXNotImplemented): + G = nx.DiGraph() + cy = nx.cycle_basis(G, 0) + + def test_cycle_basis3(self): + with pytest.raises(nx.NetworkXNotImplemented): + G = nx.MultiGraph() + cy = nx.cycle_basis(G, 0) + + def test_cycle_basis_ordered(self): + # see gh-6654 replace sets with (ordered) dicts + G = nx.cycle_graph(5) + G.update(nx.cycle_graph(range(3, 8))) + cbG = nx.cycle_basis(G) + + perm = {1: 0, 0: 1} # switch 0 and 1 + H = nx.relabel_nodes(G, perm) + cbH = [[perm.get(n, n) for n in cyc] for cyc in nx.cycle_basis(H)] + assert cbG == cbH + + def test_cycle_basis_self_loop(self): + """Tests the function for graphs with self loops""" + G = nx.Graph() + nx.add_cycle(G, [0, 1, 2, 3]) + nx.add_cycle(G, [0, 0, 6, 2]) + cy = nx.cycle_basis(G) + sort_cy = sorted(sorted(c) for c in cy) + assert sort_cy == [[0], [0, 1, 2], [0, 2, 3], [0, 2, 6]] + + def test_simple_cycles(self): + edges = [(0, 0), (0, 1), (0, 2), (1, 2), (2, 0), (2, 1), (2, 2)] + G = nx.DiGraph(edges) + cc = sorted(nx.simple_cycles(G)) + ca = [[0], [0, 1, 2], [0, 2], [1, 2], [2]] + assert len(cc) == len(ca) + for c in cc: + assert any(self.is_cyclic_permutation(c, rc) for rc in ca) + + def test_simple_cycles_singleton(self): + G = nx.Graph([(0, 0)]) # self-loop + assert list(nx.simple_cycles(G)) == [[0]] + + def test_unsortable(self): + # this test ensures that graphs whose nodes without an intrinsic + # ordering do not cause issues + G = nx.DiGraph() + nx.add_cycle(G, ["a", 1]) + c = list(nx.simple_cycles(G)) + assert len(c) == 1 + + def test_simple_cycles_small(self): + G = nx.DiGraph() + nx.add_cycle(G, [1, 2, 3]) + c = sorted(nx.simple_cycles(G)) + assert len(c) == 1 + assert self.is_cyclic_permutation(c[0], [1, 2, 3]) + nx.add_cycle(G, [10, 20, 30]) + cc = sorted(nx.simple_cycles(G)) + assert len(cc) == 2 + ca = [[1, 2, 3], [10, 20, 30]] + for c in cc: + assert any(self.is_cyclic_permutation(c, rc) for rc in ca) + + def test_simple_cycles_empty(self): + G = nx.DiGraph() + assert list(nx.simple_cycles(G)) == [] + + def worst_case_graph(self, k): + # see figure 1 in Johnson's paper + # this graph has exactly 3k simple cycles + G = nx.DiGraph() + for n in range(2, k + 2): + G.add_edge(1, n) + G.add_edge(n, k + 2) + G.add_edge(2 * k + 1, 1) + for n in range(k + 2, 2 * k + 2): + G.add_edge(n, 2 * k + 2) + G.add_edge(n, n + 1) + G.add_edge(2 * k + 3, k + 2) + for n in range(2 * k + 3, 3 * k + 3): + G.add_edge(2 * k + 2, n) + G.add_edge(n, 3 * k + 3) + G.add_edge(3 * k + 3, 2 * k + 2) + return G + + def test_worst_case_graph(self): + # see figure 1 in Johnson's paper + for k in range(3, 10): + G = self.worst_case_graph(k) + l = len(list(nx.simple_cycles(G))) + assert l == 3 * k + + def test_recursive_simple_and_not(self): + for k in range(2, 10): + G = self.worst_case_graph(k) + cc = sorted(nx.simple_cycles(G)) + rcc = sorted(nx.recursive_simple_cycles(G)) + assert len(cc) == len(rcc) + for c in cc: + assert any(self.is_cyclic_permutation(c, r) for r in rcc) + for rc in rcc: + assert any(self.is_cyclic_permutation(rc, c) for c in cc) + + def test_simple_graph_with_reported_bug(self): + G = nx.DiGraph() + edges = [ + (0, 2), + (0, 3), + (1, 0), + (1, 3), + (2, 1), + (2, 4), + (3, 2), + (3, 4), + (4, 0), + (4, 1), + (4, 5), + (5, 0), + (5, 1), + (5, 2), + (5, 3), + ] + G.add_edges_from(edges) + cc = sorted(nx.simple_cycles(G)) + assert len(cc) == 26 + rcc = sorted(nx.recursive_simple_cycles(G)) + assert len(cc) == len(rcc) + for c in cc: + assert any(self.is_cyclic_permutation(c, rc) for rc in rcc) + for rc in rcc: + assert any(self.is_cyclic_permutation(rc, c) for c in cc) + + +def pairwise(iterable): + a, b = tee(iterable) + next(b, None) + return zip(a, b) + + +def cycle_edges(c): + return pairwise(chain(c, islice(c, 1))) + + +def directed_cycle_edgeset(c): + return frozenset(cycle_edges(c)) + + +def undirected_cycle_edgeset(c): + if len(c) == 1: + return frozenset(cycle_edges(c)) + return frozenset(map(frozenset, cycle_edges(c))) + + +def multigraph_cycle_edgeset(c): + if len(c) <= 2: + return frozenset(cycle_edges(c)) + else: + return frozenset(map(frozenset, cycle_edges(c))) + + +class TestCycleEnumeration: + @staticmethod + def K(n): + return nx.complete_graph(n) + + @staticmethod + def D(n): + return nx.complete_graph(n).to_directed() + + @staticmethod + def edgeset_function(g): + if g.is_directed(): + return directed_cycle_edgeset + elif g.is_multigraph(): + return multigraph_cycle_edgeset + else: + return undirected_cycle_edgeset + + def check_cycle(self, g, c, es, cache, source, original_c, length_bound, chordless): + if length_bound is not None and len(c) > length_bound: + raise RuntimeError( + f"computed cycle {original_c} exceeds length bound {length_bound}" + ) + if source == "computed": + if es in cache: + raise RuntimeError( + f"computed cycle {original_c} has already been found!" + ) + else: + cache[es] = tuple(original_c) + else: + if es in cache: + cache.pop(es) + else: + raise RuntimeError(f"expected cycle {original_c} was not computed") + + if not all(g.has_edge(*e) for e in es): + raise RuntimeError( + f"{source} claimed cycle {original_c} is not a cycle of g" + ) + if chordless and len(g.subgraph(c).edges) > len(c): + raise RuntimeError(f"{source} cycle {original_c} is not chordless") + + def check_cycle_algorithm( + self, + g, + expected_cycles, + length_bound=None, + chordless=False, + algorithm=None, + ): + if algorithm is None: + algorithm = nx.chordless_cycles if chordless else nx.simple_cycles + + # note: we shuffle the labels of g to rule out accidentally-correct + # behavior which occurred during the development of chordless cycle + # enumeration algorithms + + relabel = list(range(len(g))) + shuffle(relabel) + label = dict(zip(g, relabel)) + unlabel = dict(zip(relabel, g)) + h = nx.relabel_nodes(g, label, copy=True) + + edgeset = self.edgeset_function(h) + + params = {} + if length_bound is not None: + params["length_bound"] = length_bound + + cycle_cache = {} + for c in algorithm(h, **params): + original_c = [unlabel[x] for x in c] + es = edgeset(c) + self.check_cycle( + h, c, es, cycle_cache, "computed", original_c, length_bound, chordless + ) + + if isinstance(expected_cycles, int): + if len(cycle_cache) != expected_cycles: + raise RuntimeError( + f"expected {expected_cycles} cycles, got {len(cycle_cache)}" + ) + return + for original_c in expected_cycles: + c = [label[x] for x in original_c] + es = edgeset(c) + self.check_cycle( + h, c, es, cycle_cache, "expected", original_c, length_bound, chordless + ) + + if len(cycle_cache): + for c in cycle_cache.values(): + raise RuntimeError( + f"computed cycle {c} is valid but not in the expected cycle set!" + ) + + def check_cycle_enumeration_integer_sequence( + self, + g_family, + cycle_counts, + length_bound=None, + chordless=False, + algorithm=None, + ): + for g, num_cycles in zip(g_family, cycle_counts): + self.check_cycle_algorithm( + g, + num_cycles, + length_bound=length_bound, + chordless=chordless, + algorithm=algorithm, + ) + + def test_directed_chordless_cycle_digons(self): + g = nx.DiGraph() + nx.add_cycle(g, range(5)) + nx.add_cycle(g, range(5)[::-1]) + g.add_edge(0, 0) + expected_cycles = [(0,), (1, 2), (2, 3), (3, 4)] + self.check_cycle_algorithm(g, expected_cycles, chordless=True) + + self.check_cycle_algorithm(g, expected_cycles, chordless=True, length_bound=2) + + expected_cycles = [c for c in expected_cycles if len(c) < 2] + self.check_cycle_algorithm(g, expected_cycles, chordless=True, length_bound=1) + + def test_directed_chordless_cycle_undirected(self): + g = nx.DiGraph([(1, 2), (2, 3), (3, 4), (4, 5), (5, 0), (5, 1), (0, 2)]) + expected_cycles = [(0, 2, 3, 4, 5), (1, 2, 3, 4, 5)] + self.check_cycle_algorithm(g, expected_cycles, chordless=True) + + g = nx.DiGraph() + nx.add_cycle(g, range(5)) + nx.add_cycle(g, range(4, 9)) + g.add_edge(7, 3) + expected_cycles = [(0, 1, 2, 3, 4), (3, 4, 5, 6, 7), (4, 5, 6, 7, 8)] + self.check_cycle_algorithm(g, expected_cycles, chordless=True) + + g.add_edge(3, 7) + expected_cycles = [(0, 1, 2, 3, 4), (3, 7), (4, 5, 6, 7, 8)] + self.check_cycle_algorithm(g, expected_cycles, chordless=True) + + expected_cycles = [(3, 7)] + self.check_cycle_algorithm(g, expected_cycles, chordless=True, length_bound=4) + + g.remove_edge(7, 3) + expected_cycles = [(0, 1, 2, 3, 4), (4, 5, 6, 7, 8)] + self.check_cycle_algorithm(g, expected_cycles, chordless=True) + + g = nx.DiGraph((i, j) for i in range(10) for j in range(i)) + expected_cycles = [] + self.check_cycle_algorithm(g, expected_cycles, chordless=True) + + def test_chordless_cycles_directed(self): + G = nx.DiGraph() + nx.add_cycle(G, range(5)) + nx.add_cycle(G, range(4, 12)) + expected = [[*range(5)], [*range(4, 12)]] + self.check_cycle_algorithm(G, expected, chordless=True) + self.check_cycle_algorithm( + G, [c for c in expected if len(c) <= 5], length_bound=5, chordless=True + ) + + G.add_edge(7, 3) + expected.append([*range(3, 8)]) + self.check_cycle_algorithm(G, expected, chordless=True) + self.check_cycle_algorithm( + G, [c for c in expected if len(c) <= 5], length_bound=5, chordless=True + ) + + G.add_edge(3, 7) + expected[-1] = [7, 3] + self.check_cycle_algorithm(G, expected, chordless=True) + self.check_cycle_algorithm( + G, [c for c in expected if len(c) <= 5], length_bound=5, chordless=True + ) + + expected.pop() + G.remove_edge(7, 3) + self.check_cycle_algorithm(G, expected, chordless=True) + self.check_cycle_algorithm( + G, [c for c in expected if len(c) <= 5], length_bound=5, chordless=True + ) + + def test_directed_chordless_cycle_diclique(self): + g_family = [self.D(n) for n in range(10)] + expected_cycles = [(n * n - n) // 2 for n in range(10)] + self.check_cycle_enumeration_integer_sequence( + g_family, expected_cycles, chordless=True + ) + + expected_cycles = [(n * n - n) // 2 for n in range(10)] + self.check_cycle_enumeration_integer_sequence( + g_family, expected_cycles, length_bound=2 + ) + + def test_directed_chordless_loop_blockade(self): + g = nx.DiGraph((i, i) for i in range(10)) + nx.add_cycle(g, range(10)) + expected_cycles = [(i,) for i in range(10)] + self.check_cycle_algorithm(g, expected_cycles, chordless=True) + + self.check_cycle_algorithm(g, expected_cycles, length_bound=1) + + g = nx.MultiDiGraph(g) + g.add_edges_from((i, i) for i in range(0, 10, 2)) + expected_cycles = [(i,) for i in range(1, 10, 2)] + self.check_cycle_algorithm(g, expected_cycles, chordless=True) + + def test_simple_cycles_notable_clique_sequences(self): + # A000292: Number of labeled graphs on n+3 nodes that are triangles. + g_family = [self.K(n) for n in range(2, 12)] + expected = [0, 1, 4, 10, 20, 35, 56, 84, 120, 165, 220] + self.check_cycle_enumeration_integer_sequence( + g_family, expected, length_bound=3 + ) + + def triangles(g, **kwargs): + yield from (c for c in nx.simple_cycles(g, **kwargs) if len(c) == 3) + + # directed complete graphs have twice as many triangles thanks to reversal + g_family = [self.D(n) for n in range(2, 12)] + expected = [2 * e for e in expected] + self.check_cycle_enumeration_integer_sequence( + g_family, expected, length_bound=3, algorithm=triangles + ) + + def four_cycles(g, **kwargs): + yield from (c for c in nx.simple_cycles(g, **kwargs) if len(c) == 4) + + # A050534: the number of 4-cycles in the complete graph K_{n+1} + expected = [0, 0, 0, 3, 15, 45, 105, 210, 378, 630, 990] + g_family = [self.K(n) for n in range(1, 12)] + self.check_cycle_enumeration_integer_sequence( + g_family, expected, length_bound=4, algorithm=four_cycles + ) + + # directed complete graphs have twice as many 4-cycles thanks to reversal + expected = [2 * e for e in expected] + g_family = [self.D(n) for n in range(1, 15)] + self.check_cycle_enumeration_integer_sequence( + g_family, expected, length_bound=4, algorithm=four_cycles + ) + + # A006231: the number of elementary circuits in a complete directed graph with n nodes + expected = [0, 1, 5, 20, 84, 409, 2365] + g_family = [self.D(n) for n in range(1, 8)] + self.check_cycle_enumeration_integer_sequence(g_family, expected) + + # A002807: Number of cycles in the complete graph on n nodes K_{n}. + expected = [0, 0, 0, 1, 7, 37, 197, 1172] + g_family = [self.K(n) for n in range(8)] + self.check_cycle_enumeration_integer_sequence(g_family, expected) + + def test_directed_chordless_cycle_parallel_multiedges(self): + g = nx.MultiGraph() + + nx.add_cycle(g, range(5)) + expected = [[*range(5)]] + self.check_cycle_algorithm(g, expected, chordless=True) + + nx.add_cycle(g, range(5)) + expected = [*cycle_edges(range(5))] + self.check_cycle_algorithm(g, expected, chordless=True) + + nx.add_cycle(g, range(5)) + expected = [] + self.check_cycle_algorithm(g, expected, chordless=True) + + g = nx.MultiDiGraph() + + nx.add_cycle(g, range(5)) + expected = [[*range(5)]] + self.check_cycle_algorithm(g, expected, chordless=True) + + nx.add_cycle(g, range(5)) + self.check_cycle_algorithm(g, [], chordless=True) + + nx.add_cycle(g, range(5)) + self.check_cycle_algorithm(g, [], chordless=True) + + g = nx.MultiDiGraph() + + nx.add_cycle(g, range(5)) + nx.add_cycle(g, range(5)[::-1]) + expected = [*cycle_edges(range(5))] + self.check_cycle_algorithm(g, expected, chordless=True) + + nx.add_cycle(g, range(5)) + self.check_cycle_algorithm(g, [], chordless=True) + + def test_chordless_cycles_graph(self): + G = nx.Graph() + nx.add_cycle(G, range(5)) + nx.add_cycle(G, range(4, 12)) + expected = [[*range(5)], [*range(4, 12)]] + self.check_cycle_algorithm(G, expected, chordless=True) + self.check_cycle_algorithm( + G, [c for c in expected if len(c) <= 5], length_bound=5, chordless=True + ) + + G.add_edge(7, 3) + expected.append([*range(3, 8)]) + expected.append([4, 3, 7, 8, 9, 10, 11]) + self.check_cycle_algorithm(G, expected, chordless=True) + self.check_cycle_algorithm( + G, [c for c in expected if len(c) <= 5], length_bound=5, chordless=True + ) + + def test_chordless_cycles_giant_hamiltonian(self): + # ... o - e - o - e - o ... # o = odd, e = even + # ... ---/ \-----/ \--- ... # <-- "long" edges + # + # each long edge belongs to exactly one triangle, and one giant cycle + # of length n/2. The remaining edges each belong to a triangle + + n = 1000 + assert n % 2 == 0 + G = nx.Graph() + for v in range(n): + if not v % 2: + G.add_edge(v, (v + 2) % n) + G.add_edge(v, (v + 1) % n) + + expected = [[*range(0, n, 2)]] + [ + [x % n for x in range(i, i + 3)] for i in range(0, n, 2) + ] + self.check_cycle_algorithm(G, expected, chordless=True) + self.check_cycle_algorithm( + G, [c for c in expected if len(c) <= 3], length_bound=3, chordless=True + ) + + # ... o -> e -> o -> e -> o ... # o = odd, e = even + # ... <---/ \---<---/ \---< ... # <-- "long" edges + # + # this time, we orient the short and long edges in opposition + # the cycle structure of this graph is the same, but we need to reverse + # the long one in our representation. Also, we need to drop the size + # because our partitioning algorithm uses strongly connected components + # instead of separating graphs by their strong articulation points + + n = 100 + assert n % 2 == 0 + G = nx.DiGraph() + for v in range(n): + G.add_edge(v, (v + 1) % n) + if not v % 2: + G.add_edge((v + 2) % n, v) + + expected = [[*range(n - 2, -2, -2)]] + [ + [x % n for x in range(i, i + 3)] for i in range(0, n, 2) + ] + self.check_cycle_algorithm(G, expected, chordless=True) + self.check_cycle_algorithm( + G, [c for c in expected if len(c) <= 3], length_bound=3, chordless=True + ) + + def test_simple_cycles_acyclic_tournament(self): + n = 10 + G = nx.DiGraph((x, y) for x in range(n) for y in range(x)) + self.check_cycle_algorithm(G, []) + self.check_cycle_algorithm(G, [], chordless=True) + + for k in range(n + 1): + self.check_cycle_algorithm(G, [], length_bound=k) + self.check_cycle_algorithm(G, [], length_bound=k, chordless=True) + + def test_simple_cycles_graph(self): + testG = nx.cycle_graph(8) + cyc1 = tuple(range(8)) + self.check_cycle_algorithm(testG, [cyc1]) + + testG.add_edge(4, -1) + nx.add_path(testG, [3, -2, -3, -4]) + self.check_cycle_algorithm(testG, [cyc1]) + + testG.update(nx.cycle_graph(range(8, 16))) + cyc2 = tuple(range(8, 16)) + self.check_cycle_algorithm(testG, [cyc1, cyc2]) + + testG.update(nx.cycle_graph(range(4, 12))) + cyc3 = tuple(range(4, 12)) + expected = { + (0, 1, 2, 3, 4, 5, 6, 7), # cyc1 + (8, 9, 10, 11, 12, 13, 14, 15), # cyc2 + (4, 5, 6, 7, 8, 9, 10, 11), # cyc3 + (4, 5, 6, 7, 8, 15, 14, 13, 12, 11), # cyc2 + cyc3 + (0, 1, 2, 3, 4, 11, 10, 9, 8, 7), # cyc1 + cyc3 + (0, 1, 2, 3, 4, 11, 12, 13, 14, 15, 8, 7), # cyc1 + cyc2 + cyc3 + } + self.check_cycle_algorithm(testG, expected) + assert len(expected) == (2**3 - 1) - 1 # 1 disjoint comb: cyc1 + cyc2 + + # Basis size = 5 (2 loops overlapping gives 5 small loops + # E + # / \ Note: A-F = 10-15 + # 1-2-3-4-5 + # / | | \ cyc1=012DAB -- left + # 0 D F 6 cyc2=234E -- top + # \ | | / cyc3=45678F -- right + # B-A-9-8-7 cyc4=89AC -- bottom + # \ / cyc5=234F89AD -- middle + # C + # + # combinations of 5 basis elements: 2^5 - 1 (one includes no cycles) + # + # disjoint combs: (11 total) not simple cycles + # Any pair not including cyc5 => choose(4, 2) = 6 + # Any triple not including cyc5 => choose(4, 3) = 4 + # Any quad not including cyc5 => choose(4, 4) = 1 + # + # we expect 31 - 11 = 20 simple cycles + # + testG = nx.cycle_graph(12) + testG.update(nx.cycle_graph([12, 10, 13, 2, 14, 4, 15, 8]).edges) + expected = (2**5 - 1) - 11 # 11 disjoint combinations + self.check_cycle_algorithm(testG, expected) + + def test_simple_cycles_bounded(self): + # iteratively construct a cluster of nested cycles running in the same direction + # there should be one cycle of every length + d = nx.DiGraph() + expected = [] + for n in range(10): + nx.add_cycle(d, range(n)) + expected.append(n) + for k, e in enumerate(expected): + self.check_cycle_algorithm(d, e, length_bound=k) + + # iteratively construct a path of undirected cycles, connected at articulation + # points. there should be one cycle of every length except 2: no digons + g = nx.Graph() + top = 0 + expected = [] + for n in range(10): + expected.append(n if n < 2 else n - 1) + if n == 2: + # no digons in undirected graphs + continue + nx.add_cycle(g, range(top, top + n)) + top += n + for k, e in enumerate(expected): + self.check_cycle_algorithm(g, e, length_bound=k) + + def test_simple_cycles_bound_corner_cases(self): + G = nx.cycle_graph(4) + DG = nx.cycle_graph(4, create_using=nx.DiGraph) + assert list(nx.simple_cycles(G, length_bound=0)) == [] + assert list(nx.simple_cycles(DG, length_bound=0)) == [] + assert list(nx.chordless_cycles(G, length_bound=0)) == [] + assert list(nx.chordless_cycles(DG, length_bound=0)) == [] + + def test_simple_cycles_bound_error(self): + with pytest.raises(ValueError): + G = nx.DiGraph() + for c in nx.simple_cycles(G, -1): + assert False + + with pytest.raises(ValueError): + G = nx.Graph() + for c in nx.simple_cycles(G, -1): + assert False + + with pytest.raises(ValueError): + G = nx.Graph() + for c in nx.chordless_cycles(G, -1): + assert False + + with pytest.raises(ValueError): + G = nx.DiGraph() + for c in nx.chordless_cycles(G, -1): + assert False + + def test_chordless_cycles_clique(self): + g_family = [self.K(n) for n in range(2, 15)] + expected = [0, 1, 4, 10, 20, 35, 56, 84, 120, 165, 220, 286, 364] + self.check_cycle_enumeration_integer_sequence( + g_family, expected, chordless=True + ) + + # directed cliques have as many digons as undirected graphs have edges + expected = [(n * n - n) // 2 for n in range(15)] + g_family = [self.D(n) for n in range(15)] + self.check_cycle_enumeration_integer_sequence( + g_family, expected, chordless=True + ) + + +# These tests might fail with hash randomization since they depend on +# edge_dfs. For more information, see the comments in: +# networkx/algorithms/traversal/tests/test_edgedfs.py + + +class TestFindCycle: + @classmethod + def setup_class(cls): + cls.nodes = [0, 1, 2, 3] + cls.edges = [(-1, 0), (0, 1), (1, 0), (1, 0), (2, 1), (3, 1)] + + def test_graph_nocycle(self): + G = nx.Graph(self.edges) + pytest.raises(nx.exception.NetworkXNoCycle, nx.find_cycle, G, self.nodes) + + def test_graph_cycle(self): + G = nx.Graph(self.edges) + G.add_edge(2, 0) + x = list(nx.find_cycle(G, self.nodes)) + x_ = [(0, 1), (1, 2), (2, 0)] + assert x == x_ + + def test_graph_orientation_none(self): + G = nx.Graph(self.edges) + G.add_edge(2, 0) + x = list(nx.find_cycle(G, self.nodes, orientation=None)) + x_ = [(0, 1), (1, 2), (2, 0)] + assert x == x_ + + def test_graph_orientation_original(self): + G = nx.Graph(self.edges) + G.add_edge(2, 0) + x = list(nx.find_cycle(G, self.nodes, orientation="original")) + x_ = [(0, 1, FORWARD), (1, 2, FORWARD), (2, 0, FORWARD)] + assert x == x_ + + def test_digraph(self): + G = nx.DiGraph(self.edges) + x = list(nx.find_cycle(G, self.nodes)) + x_ = [(0, 1), (1, 0)] + assert x == x_ + + def test_digraph_orientation_none(self): + G = nx.DiGraph(self.edges) + x = list(nx.find_cycle(G, self.nodes, orientation=None)) + x_ = [(0, 1), (1, 0)] + assert x == x_ + + def test_digraph_orientation_original(self): + G = nx.DiGraph(self.edges) + x = list(nx.find_cycle(G, self.nodes, orientation="original")) + x_ = [(0, 1, FORWARD), (1, 0, FORWARD)] + assert x == x_ + + def test_multigraph(self): + G = nx.MultiGraph(self.edges) + x = list(nx.find_cycle(G, self.nodes)) + x_ = [(0, 1, 0), (1, 0, 1)] # or (1, 0, 2) + # Hash randomization...could be any edge. + assert x[0] == x_[0] + assert x[1][:2] == x_[1][:2] + + def test_multidigraph(self): + G = nx.MultiDiGraph(self.edges) + x = list(nx.find_cycle(G, self.nodes)) + x_ = [(0, 1, 0), (1, 0, 0)] # (1, 0, 1) + assert x[0] == x_[0] + assert x[1][:2] == x_[1][:2] + + def test_digraph_ignore(self): + G = nx.DiGraph(self.edges) + x = list(nx.find_cycle(G, self.nodes, orientation="ignore")) + x_ = [(0, 1, FORWARD), (1, 0, FORWARD)] + assert x == x_ + + def test_digraph_reverse(self): + G = nx.DiGraph(self.edges) + x = list(nx.find_cycle(G, self.nodes, orientation="reverse")) + x_ = [(1, 0, REVERSE), (0, 1, REVERSE)] + assert x == x_ + + def test_multidigraph_ignore(self): + G = nx.MultiDiGraph(self.edges) + x = list(nx.find_cycle(G, self.nodes, orientation="ignore")) + x_ = [(0, 1, 0, FORWARD), (1, 0, 0, FORWARD)] # or (1, 0, 1, 1) + assert x[0] == x_[0] + assert x[1][:2] == x_[1][:2] + assert x[1][3] == x_[1][3] + + def test_multidigraph_ignore2(self): + # Loop traversed an edge while ignoring its orientation. + G = nx.MultiDiGraph([(0, 1), (1, 2), (1, 2)]) + x = list(nx.find_cycle(G, [0, 1, 2], orientation="ignore")) + x_ = [(1, 2, 0, FORWARD), (1, 2, 1, REVERSE)] + assert x == x_ + + def test_multidigraph_original(self): + # Node 2 doesn't need to be searched again from visited from 4. + # The goal here is to cover the case when 2 to be researched from 4, + # when 4 is visited from the first time (so we must make sure that 4 + # is not visited from 2, and hence, we respect the edge orientation). + G = nx.MultiDiGraph([(0, 1), (1, 2), (2, 3), (4, 2)]) + pytest.raises( + nx.exception.NetworkXNoCycle, + nx.find_cycle, + G, + [0, 1, 2, 3, 4], + orientation="original", + ) + + def test_dag(self): + G = nx.DiGraph([(0, 1), (0, 2), (1, 2)]) + pytest.raises( + nx.exception.NetworkXNoCycle, nx.find_cycle, G, orientation="original" + ) + x = list(nx.find_cycle(G, orientation="ignore")) + assert x == [(0, 1, FORWARD), (1, 2, FORWARD), (0, 2, REVERSE)] + + def test_prev_explored(self): + # https://github.com/networkx/networkx/issues/2323 + + G = nx.DiGraph() + G.add_edges_from([(1, 0), (2, 0), (1, 2), (2, 1)]) + pytest.raises(nx.NetworkXNoCycle, nx.find_cycle, G, source=0) + x = list(nx.find_cycle(G, 1)) + x_ = [(1, 2), (2, 1)] + assert x == x_ + + x = list(nx.find_cycle(G, 2)) + x_ = [(2, 1), (1, 2)] + assert x == x_ + + x = list(nx.find_cycle(G)) + x_ = [(1, 2), (2, 1)] + assert x == x_ + + def test_no_cycle(self): + # https://github.com/networkx/networkx/issues/2439 + + G = nx.DiGraph() + G.add_edges_from([(1, 2), (2, 0), (3, 1), (3, 2)]) + pytest.raises(nx.NetworkXNoCycle, nx.find_cycle, G, source=0) + pytest.raises(nx.NetworkXNoCycle, nx.find_cycle, G) + + +def assert_basis_equal(a, b): + assert sorted(a) == sorted(b) + + +class TestMinimumCycleBasis: + @classmethod + def setup_class(cls): + T = nx.Graph() + nx.add_cycle(T, [1, 2, 3, 4], weight=1) + T.add_edge(2, 4, weight=5) + cls.diamond_graph = T + + def test_unweighted_diamond(self): + mcb = nx.minimum_cycle_basis(self.diamond_graph) + assert_basis_equal(mcb, [[2, 4, 1], [3, 4, 2]]) + + def test_weighted_diamond(self): + mcb = nx.minimum_cycle_basis(self.diamond_graph, weight="weight") + assert_basis_equal(mcb, [[2, 4, 1], [4, 3, 2, 1]]) + + def test_dimensionality(self): + # checks |MCB|=|E|-|V|+|NC| + ntrial = 10 + for seed in range(1234, 1234 + ntrial): + rg = nx.erdos_renyi_graph(10, 0.3, seed=seed) + nnodes = rg.number_of_nodes() + nedges = rg.number_of_edges() + ncomp = nx.number_connected_components(rg) + + mcb = nx.minimum_cycle_basis(rg) + assert len(mcb) == nedges - nnodes + ncomp + check_independent(mcb) + + def test_complete_graph(self): + cg = nx.complete_graph(5) + mcb = nx.minimum_cycle_basis(cg) + assert all(len(cycle) == 3 for cycle in mcb) + check_independent(mcb) + + def test_tree_graph(self): + tg = nx.balanced_tree(3, 3) + assert not nx.minimum_cycle_basis(tg) + + def test_petersen_graph(self): + G = nx.petersen_graph() + mcb = list(nx.minimum_cycle_basis(G)) + expected = [ + [4, 9, 7, 5, 0], + [1, 2, 3, 4, 0], + [1, 6, 8, 5, 0], + [4, 3, 8, 5, 0], + [1, 6, 9, 4, 0], + [1, 2, 7, 5, 0], + ] + assert len(mcb) == len(expected) + assert all(c in expected for c in mcb) + + # check that order of the nodes is a path + for c in mcb: + assert all(G.has_edge(u, v) for u, v in nx.utils.pairwise(c, cyclic=True)) + # check independence of the basis + check_independent(mcb) + + def test_gh6787_variable_weighted_complete_graph(self): + N = 8 + cg = nx.complete_graph(N) + cg.add_weighted_edges_from([(u, v, 9) for u, v in cg.edges]) + cg.add_weighted_edges_from([(u, v, 1) for u, v in nx.cycle_graph(N).edges]) + mcb = nx.minimum_cycle_basis(cg, weight="weight") + check_independent(mcb) + + def test_gh6787_and_edge_attribute_names(self): + G = nx.cycle_graph(4) + G.add_weighted_edges_from([(0, 2, 10), (1, 3, 10)], weight="dist") + expected = [[1, 3, 0], [3, 2, 1, 0], [1, 2, 0]] + mcb = list(nx.minimum_cycle_basis(G, weight="dist")) + assert len(mcb) == len(expected) + assert all(c in expected for c in mcb) + + # test not using a weight with weight attributes + expected = [[1, 3, 0], [1, 2, 0], [3, 2, 0]] + mcb = list(nx.minimum_cycle_basis(G)) + assert len(mcb) == len(expected) + assert all(c in expected for c in mcb) + + +class TestGirth: + @pytest.mark.parametrize( + ("G", "expected"), + ( + (nx.chvatal_graph(), 4), + (nx.tutte_graph(), 4), + (nx.petersen_graph(), 5), + (nx.heawood_graph(), 6), + (nx.pappus_graph(), 6), + (nx.random_labeled_tree(10, seed=42), inf), + (nx.empty_graph(10), inf), + (nx.Graph(chain(cycle_edges(range(5)), cycle_edges(range(6, 10)))), 4), + ( + nx.Graph( + [ + (0, 6), + (0, 8), + (0, 9), + (1, 8), + (2, 8), + (2, 9), + (4, 9), + (5, 9), + (6, 8), + (6, 9), + (7, 8), + ] + ), + 3, + ), + ), + ) + def test_girth(self, G, expected): + assert nx.girth(G) == expected diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_d_separation.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_d_separation.py new file mode 100644 index 0000000000000000000000000000000000000000..6f62971301b9b51c967bf773dec6c267b5df24a9 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_d_separation.py @@ -0,0 +1,348 @@ +from itertools import combinations + +import pytest + +import networkx as nx + + +def path_graph(): + """Return a path graph of length three.""" + G = nx.path_graph(3, create_using=nx.DiGraph) + G.graph["name"] = "path" + nx.freeze(G) + return G + + +def fork_graph(): + """Return a three node fork graph.""" + G = nx.DiGraph(name="fork") + G.add_edges_from([(0, 1), (0, 2)]) + nx.freeze(G) + return G + + +def collider_graph(): + """Return a collider/v-structure graph with three nodes.""" + G = nx.DiGraph(name="collider") + G.add_edges_from([(0, 2), (1, 2)]) + nx.freeze(G) + return G + + +def naive_bayes_graph(): + """Return a simply Naive Bayes PGM graph.""" + G = nx.DiGraph(name="naive_bayes") + G.add_edges_from([(0, 1), (0, 2), (0, 3), (0, 4)]) + nx.freeze(G) + return G + + +def asia_graph(): + """Return the 'Asia' PGM graph.""" + G = nx.DiGraph(name="asia") + G.add_edges_from( + [ + ("asia", "tuberculosis"), + ("smoking", "cancer"), + ("smoking", "bronchitis"), + ("tuberculosis", "either"), + ("cancer", "either"), + ("either", "xray"), + ("either", "dyspnea"), + ("bronchitis", "dyspnea"), + ] + ) + nx.freeze(G) + return G + + +@pytest.fixture(name="path_graph") +def path_graph_fixture(): + return path_graph() + + +@pytest.fixture(name="fork_graph") +def fork_graph_fixture(): + return fork_graph() + + +@pytest.fixture(name="collider_graph") +def collider_graph_fixture(): + return collider_graph() + + +@pytest.fixture(name="naive_bayes_graph") +def naive_bayes_graph_fixture(): + return naive_bayes_graph() + + +@pytest.fixture(name="asia_graph") +def asia_graph_fixture(): + return asia_graph() + + +@pytest.fixture() +def large_collider_graph(): + edge_list = [("A", "B"), ("C", "B"), ("B", "D"), ("D", "E"), ("B", "F"), ("G", "E")] + G = nx.DiGraph(edge_list) + return G + + +@pytest.fixture() +def chain_and_fork_graph(): + edge_list = [("A", "B"), ("B", "C"), ("B", "D"), ("D", "C")] + G = nx.DiGraph(edge_list) + return G + + +@pytest.fixture() +def no_separating_set_graph(): + edge_list = [("A", "B")] + G = nx.DiGraph(edge_list) + return G + + +@pytest.fixture() +def large_no_separating_set_graph(): + edge_list = [("A", "B"), ("C", "A"), ("C", "B")] + G = nx.DiGraph(edge_list) + return G + + +@pytest.fixture() +def collider_trek_graph(): + edge_list = [("A", "B"), ("C", "B"), ("C", "D")] + G = nx.DiGraph(edge_list) + return G + + +@pytest.mark.parametrize( + "graph", + [path_graph(), fork_graph(), collider_graph(), naive_bayes_graph(), asia_graph()], +) +def test_markov_condition(graph): + """Test that the Markov condition holds for each PGM graph.""" + for node in graph.nodes: + parents = set(graph.predecessors(node)) + non_descendants = graph.nodes - nx.descendants(graph, node) - {node} - parents + assert nx.is_d_separator(graph, {node}, non_descendants, parents) + + +def test_path_graph_dsep(path_graph): + """Example-based test of d-separation for path_graph.""" + assert nx.is_d_separator(path_graph, {0}, {2}, {1}) + assert not nx.is_d_separator(path_graph, {0}, {2}, set()) + + +def test_fork_graph_dsep(fork_graph): + """Example-based test of d-separation for fork_graph.""" + assert nx.is_d_separator(fork_graph, {1}, {2}, {0}) + assert not nx.is_d_separator(fork_graph, {1}, {2}, set()) + + +def test_collider_graph_dsep(collider_graph): + """Example-based test of d-separation for collider_graph.""" + assert nx.is_d_separator(collider_graph, {0}, {1}, set()) + assert not nx.is_d_separator(collider_graph, {0}, {1}, {2}) + + +def test_naive_bayes_dsep(naive_bayes_graph): + """Example-based test of d-separation for naive_bayes_graph.""" + for u, v in combinations(range(1, 5), 2): + assert nx.is_d_separator(naive_bayes_graph, {u}, {v}, {0}) + assert not nx.is_d_separator(naive_bayes_graph, {u}, {v}, set()) + + +def test_asia_graph_dsep(asia_graph): + """Example-based test of d-separation for asia_graph.""" + assert nx.is_d_separator( + asia_graph, {"asia", "smoking"}, {"dyspnea", "xray"}, {"bronchitis", "either"} + ) + assert nx.is_d_separator( + asia_graph, {"tuberculosis", "cancer"}, {"bronchitis"}, {"smoking", "xray"} + ) + + +def test_undirected_graphs_are_not_supported(): + """ + Test that undirected graphs are not supported. + + d-separation and its related algorithms do not apply in + the case of undirected graphs. + """ + g = nx.path_graph(3, nx.Graph) + with pytest.raises(nx.NetworkXNotImplemented): + nx.is_d_separator(g, {0}, {1}, {2}) + with pytest.raises(nx.NetworkXNotImplemented): + nx.is_minimal_d_separator(g, {0}, {1}, {2}) + with pytest.raises(nx.NetworkXNotImplemented): + nx.find_minimal_d_separator(g, {0}, {1}) + + +def test_cyclic_graphs_raise_error(): + """ + Test that cycle graphs should cause erroring. + + This is because PGMs assume a directed acyclic graph. + """ + g = nx.cycle_graph(3, nx.DiGraph) + with pytest.raises(nx.NetworkXError): + nx.is_d_separator(g, {0}, {1}, {2}) + with pytest.raises(nx.NetworkXError): + nx.find_minimal_d_separator(g, {0}, {1}) + with pytest.raises(nx.NetworkXError): + nx.is_minimal_d_separator(g, {0}, {1}, {2}) + + +def test_invalid_nodes_raise_error(asia_graph): + """ + Test that graphs that have invalid nodes passed in raise errors. + """ + # Check both set and node arguments + with pytest.raises(nx.NodeNotFound): + nx.is_d_separator(asia_graph, {0}, {1}, {2}) + with pytest.raises(nx.NodeNotFound): + nx.is_d_separator(asia_graph, 0, 1, 2) + with pytest.raises(nx.NodeNotFound): + nx.is_minimal_d_separator(asia_graph, {0}, {1}, {2}) + with pytest.raises(nx.NodeNotFound): + nx.is_minimal_d_separator(asia_graph, 0, 1, 2) + with pytest.raises(nx.NodeNotFound): + nx.find_minimal_d_separator(asia_graph, {0}, {1}) + with pytest.raises(nx.NodeNotFound): + nx.find_minimal_d_separator(asia_graph, 0, 1) + + +def test_nondisjoint_node_sets_raise_error(collider_graph): + """ + Test that error is raised when node sets aren't disjoint. + """ + with pytest.raises(nx.NetworkXError): + nx.is_d_separator(collider_graph, 0, 1, 0) + with pytest.raises(nx.NetworkXError): + nx.is_d_separator(collider_graph, 0, 2, 0) + with pytest.raises(nx.NetworkXError): + nx.is_d_separator(collider_graph, 0, 0, 1) + with pytest.raises(nx.NetworkXError): + nx.is_d_separator(collider_graph, 1, 0, 0) + with pytest.raises(nx.NetworkXError): + nx.find_minimal_d_separator(collider_graph, 0, 0) + with pytest.raises(nx.NetworkXError): + nx.find_minimal_d_separator(collider_graph, 0, 1, included=0) + with pytest.raises(nx.NetworkXError): + nx.find_minimal_d_separator(collider_graph, 1, 0, included=0) + with pytest.raises(nx.NetworkXError): + nx.is_minimal_d_separator(collider_graph, 0, 0, set()) + with pytest.raises(nx.NetworkXError): + nx.is_minimal_d_separator(collider_graph, 0, 1, set(), included=0) + with pytest.raises(nx.NetworkXError): + nx.is_minimal_d_separator(collider_graph, 1, 0, set(), included=0) + + +def test_is_minimal_d_separator( + large_collider_graph, + chain_and_fork_graph, + no_separating_set_graph, + large_no_separating_set_graph, + collider_trek_graph, +): + # Case 1: + # create a graph A -> B <- C + # B -> D -> E; + # B -> F; + # G -> E; + assert not nx.is_d_separator(large_collider_graph, {"B"}, {"E"}, set()) + + # minimal set of the corresponding graph + # for B and E should be (D,) + Zmin = nx.find_minimal_d_separator(large_collider_graph, "B", "E") + # check that the minimal d-separator is a d-separating set + assert nx.is_d_separator(large_collider_graph, "B", "E", Zmin) + # the minimal separating set should also pass the test for minimality + assert nx.is_minimal_d_separator(large_collider_graph, "B", "E", Zmin) + # function should also work with set arguments + assert nx.is_minimal_d_separator(large_collider_graph, {"A", "B"}, {"G", "E"}, Zmin) + assert Zmin == {"D"} + + # Case 2: + # create a graph A -> B -> C + # B -> D -> C; + assert not nx.is_d_separator(chain_and_fork_graph, {"A"}, {"C"}, set()) + Zmin = nx.find_minimal_d_separator(chain_and_fork_graph, "A", "C") + + # the minimal separating set should pass the test for minimality + assert nx.is_minimal_d_separator(chain_and_fork_graph, "A", "C", Zmin) + assert Zmin == {"B"} + Znotmin = Zmin.union({"D"}) + assert not nx.is_minimal_d_separator(chain_and_fork_graph, "A", "C", Znotmin) + + # Case 3: + # create a graph A -> B + + # there is no m-separating set between A and B at all, so + # no minimal m-separating set can exist + assert not nx.is_d_separator(no_separating_set_graph, {"A"}, {"B"}, set()) + assert nx.find_minimal_d_separator(no_separating_set_graph, "A", "B") is None + + # Case 4: + # create a graph A -> B with A <- C -> B + + # there is no m-separating set between A and B at all, so + # no minimal m-separating set can exist + # however, the algorithm will initially propose C as a + # minimal (but invalid) separating set + assert not nx.is_d_separator(large_no_separating_set_graph, {"A"}, {"B"}, {"C"}) + assert nx.find_minimal_d_separator(large_no_separating_set_graph, "A", "B") is None + + # Test `included` and `excluded` args + # create graph A -> B <- C -> D + assert nx.find_minimal_d_separator(collider_trek_graph, "A", "D", included="B") == { + "B", + "C", + } + assert ( + nx.find_minimal_d_separator( + collider_trek_graph, "A", "D", included="B", restricted="B" + ) + is None + ) + + +def test_is_minimal_d_separator_checks_dsep(): + """Test that is_minimal_d_separator checks for d-separation as well.""" + g = nx.DiGraph() + g.add_edges_from( + [ + ("A", "B"), + ("A", "E"), + ("B", "C"), + ("B", "D"), + ("D", "C"), + ("D", "F"), + ("E", "D"), + ("E", "F"), + ] + ) + + assert not nx.is_d_separator(g, {"C"}, {"F"}, {"D"}) + + # since {'D'} and {} are not d-separators, we return false + assert not nx.is_minimal_d_separator(g, "C", "F", {"D"}) + assert not nx.is_minimal_d_separator(g, "C", "F", set()) + + +def test__reachable(large_collider_graph): + reachable = nx.algorithms.d_separation._reachable + g = large_collider_graph + x = {"F", "D"} + ancestors = {"A", "B", "C", "D", "F"} + assert reachable(g, x, ancestors, {"B"}) == {"B", "F", "D"} + assert reachable(g, x, ancestors, set()) == ancestors + + +def test_deprecations(): + G = nx.DiGraph([(0, 1), (1, 2)]) + with pytest.deprecated_call(): + nx.d_separated(G, 0, 2, {1}) + with pytest.deprecated_call(): + z = nx.minimal_d_separator(G, 0, 2) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_dag.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_dag.py new file mode 100644 index 0000000000000000000000000000000000000000..e98a6a0ef537c71075ff2690716cffec81b22d7f --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_dag.py @@ -0,0 +1,835 @@ +from collections import deque +from itertools import combinations, permutations + +import pytest + +import networkx as nx +from networkx.utils import edges_equal, pairwise + + +# Recipe from the itertools documentation. +def _consume(iterator): + "Consume the iterator entirely." + # Feed the entire iterator into a zero-length deque. + deque(iterator, maxlen=0) + + +class TestDagLongestPath: + """Unit tests computing the longest path in a directed acyclic graph.""" + + def test_empty(self): + G = nx.DiGraph() + assert nx.dag_longest_path(G) == [] + + def test_unweighted1(self): + edges = [(1, 2), (2, 3), (2, 4), (3, 5), (5, 6), (3, 7)] + G = nx.DiGraph(edges) + assert nx.dag_longest_path(G) == [1, 2, 3, 5, 6] + + def test_unweighted2(self): + edges = [(1, 2), (2, 3), (3, 4), (4, 5), (1, 3), (1, 5), (3, 5)] + G = nx.DiGraph(edges) + assert nx.dag_longest_path(G) == [1, 2, 3, 4, 5] + + def test_weighted(self): + G = nx.DiGraph() + edges = [(1, 2, -5), (2, 3, 1), (3, 4, 1), (4, 5, 0), (3, 5, 4), (1, 6, 2)] + G.add_weighted_edges_from(edges) + assert nx.dag_longest_path(G) == [2, 3, 5] + + def test_undirected_not_implemented(self): + G = nx.Graph() + pytest.raises(nx.NetworkXNotImplemented, nx.dag_longest_path, G) + + def test_unorderable_nodes(self): + """Tests that computing the longest path does not depend on + nodes being orderable. + + For more information, see issue #1989. + + """ + # Create the directed path graph on four nodes in a diamond shape, + # with nodes represented as (unorderable) Python objects. + nodes = [object() for n in range(4)] + G = nx.DiGraph() + G.add_edge(nodes[0], nodes[1]) + G.add_edge(nodes[0], nodes[2]) + G.add_edge(nodes[2], nodes[3]) + G.add_edge(nodes[1], nodes[3]) + + # this will raise NotImplementedError when nodes need to be ordered + nx.dag_longest_path(G) + + def test_multigraph_unweighted(self): + edges = [(1, 2), (2, 3), (2, 3), (3, 4), (4, 5), (1, 3), (1, 5), (3, 5)] + G = nx.MultiDiGraph(edges) + assert nx.dag_longest_path(G) == [1, 2, 3, 4, 5] + + def test_multigraph_weighted(self): + G = nx.MultiDiGraph() + edges = [ + (1, 2, 2), + (2, 3, 2), + (1, 3, 1), + (1, 3, 5), + (1, 3, 2), + ] + G.add_weighted_edges_from(edges) + assert nx.dag_longest_path(G) == [1, 3] + + def test_multigraph_weighted_default_weight(self): + G = nx.MultiDiGraph([(1, 2), (2, 3)]) # Unweighted edges + G.add_weighted_edges_from([(1, 3, 1), (1, 3, 5), (1, 3, 2)]) + + # Default value for default weight is 1 + assert nx.dag_longest_path(G) == [1, 3] + assert nx.dag_longest_path(G, default_weight=3) == [1, 2, 3] + + +class TestDagLongestPathLength: + """Unit tests for computing the length of a longest path in a + directed acyclic graph. + + """ + + def test_unweighted(self): + edges = [(1, 2), (2, 3), (2, 4), (3, 5), (5, 6), (5, 7)] + G = nx.DiGraph(edges) + assert nx.dag_longest_path_length(G) == 4 + + edges = [(1, 2), (2, 3), (3, 4), (4, 5), (1, 3), (1, 5), (3, 5)] + G = nx.DiGraph(edges) + assert nx.dag_longest_path_length(G) == 4 + + # test degenerate graphs + G = nx.DiGraph() + G.add_node(1) + assert nx.dag_longest_path_length(G) == 0 + + def test_undirected_not_implemented(self): + G = nx.Graph() + pytest.raises(nx.NetworkXNotImplemented, nx.dag_longest_path_length, G) + + def test_weighted(self): + edges = [(1, 2, -5), (2, 3, 1), (3, 4, 1), (4, 5, 0), (3, 5, 4), (1, 6, 2)] + G = nx.DiGraph() + G.add_weighted_edges_from(edges) + assert nx.dag_longest_path_length(G) == 5 + + def test_multigraph_unweighted(self): + edges = [(1, 2), (2, 3), (2, 3), (3, 4), (4, 5), (1, 3), (1, 5), (3, 5)] + G = nx.MultiDiGraph(edges) + assert nx.dag_longest_path_length(G) == 4 + + def test_multigraph_weighted(self): + G = nx.MultiDiGraph() + edges = [ + (1, 2, 2), + (2, 3, 2), + (1, 3, 1), + (1, 3, 5), + (1, 3, 2), + ] + G.add_weighted_edges_from(edges) + assert nx.dag_longest_path_length(G) == 5 + + +class TestDAG: + @classmethod + def setup_class(cls): + pass + + def test_topological_sort1(self): + DG = nx.DiGraph([(1, 2), (1, 3), (2, 3)]) + + for algorithm in [nx.topological_sort, nx.lexicographical_topological_sort]: + assert tuple(algorithm(DG)) == (1, 2, 3) + + DG.add_edge(3, 2) + + for algorithm in [nx.topological_sort, nx.lexicographical_topological_sort]: + pytest.raises(nx.NetworkXUnfeasible, _consume, algorithm(DG)) + + DG.remove_edge(2, 3) + + for algorithm in [nx.topological_sort, nx.lexicographical_topological_sort]: + assert tuple(algorithm(DG)) == (1, 3, 2) + + DG.remove_edge(3, 2) + + assert tuple(nx.topological_sort(DG)) in {(1, 2, 3), (1, 3, 2)} + assert tuple(nx.lexicographical_topological_sort(DG)) == (1, 2, 3) + + def test_is_directed_acyclic_graph(self): + G = nx.generators.complete_graph(2) + assert not nx.is_directed_acyclic_graph(G) + assert not nx.is_directed_acyclic_graph(G.to_directed()) + assert not nx.is_directed_acyclic_graph(nx.Graph([(3, 4), (4, 5)])) + assert nx.is_directed_acyclic_graph(nx.DiGraph([(3, 4), (4, 5)])) + + def test_topological_sort2(self): + DG = nx.DiGraph( + { + 1: [2], + 2: [3], + 3: [4], + 4: [5], + 5: [1], + 11: [12], + 12: [13], + 13: [14], + 14: [15], + } + ) + pytest.raises(nx.NetworkXUnfeasible, _consume, nx.topological_sort(DG)) + + assert not nx.is_directed_acyclic_graph(DG) + + DG.remove_edge(1, 2) + _consume(nx.topological_sort(DG)) + assert nx.is_directed_acyclic_graph(DG) + + def test_topological_sort3(self): + DG = nx.DiGraph() + DG.add_edges_from([(1, i) for i in range(2, 5)]) + DG.add_edges_from([(2, i) for i in range(5, 9)]) + DG.add_edges_from([(6, i) for i in range(9, 12)]) + DG.add_edges_from([(4, i) for i in range(12, 15)]) + + def validate(order): + assert isinstance(order, list) + assert set(order) == set(DG) + for u, v in combinations(order, 2): + assert not nx.has_path(DG, v, u) + + validate(list(nx.topological_sort(DG))) + + DG.add_edge(14, 1) + pytest.raises(nx.NetworkXUnfeasible, _consume, nx.topological_sort(DG)) + + def test_topological_sort4(self): + G = nx.Graph() + G.add_edge(1, 2) + # Only directed graphs can be topologically sorted. + pytest.raises(nx.NetworkXError, _consume, nx.topological_sort(G)) + + def test_topological_sort5(self): + G = nx.DiGraph() + G.add_edge(0, 1) + assert list(nx.topological_sort(G)) == [0, 1] + + def test_topological_sort6(self): + for algorithm in [nx.topological_sort, nx.lexicographical_topological_sort]: + + def runtime_error(): + DG = nx.DiGraph([(1, 2), (2, 3), (3, 4)]) + first = True + for x in algorithm(DG): + if first: + first = False + DG.add_edge(5 - x, 5) + + def unfeasible_error(): + DG = nx.DiGraph([(1, 2), (2, 3), (3, 4)]) + first = True + for x in algorithm(DG): + if first: + first = False + DG.remove_node(4) + + def runtime_error2(): + DG = nx.DiGraph([(1, 2), (2, 3), (3, 4)]) + first = True + for x in algorithm(DG): + if first: + first = False + DG.remove_node(2) + + pytest.raises(RuntimeError, runtime_error) + pytest.raises(RuntimeError, runtime_error2) + pytest.raises(nx.NetworkXUnfeasible, unfeasible_error) + + def test_all_topological_sorts_1(self): + DG = nx.DiGraph([(1, 2), (2, 3), (3, 4), (4, 5)]) + assert list(nx.all_topological_sorts(DG)) == [[1, 2, 3, 4, 5]] + + def test_all_topological_sorts_2(self): + DG = nx.DiGraph([(1, 3), (2, 1), (2, 4), (4, 3), (4, 5)]) + assert sorted(nx.all_topological_sorts(DG)) == [ + [2, 1, 4, 3, 5], + [2, 1, 4, 5, 3], + [2, 4, 1, 3, 5], + [2, 4, 1, 5, 3], + [2, 4, 5, 1, 3], + ] + + def test_all_topological_sorts_3(self): + def unfeasible(): + DG = nx.DiGraph([(1, 2), (2, 3), (3, 4), (4, 2), (4, 5)]) + # convert to list to execute generator + list(nx.all_topological_sorts(DG)) + + def not_implemented(): + G = nx.Graph([(1, 2), (2, 3)]) + # convert to list to execute generator + list(nx.all_topological_sorts(G)) + + def not_implemented_2(): + G = nx.MultiGraph([(1, 2), (1, 2), (2, 3)]) + list(nx.all_topological_sorts(G)) + + pytest.raises(nx.NetworkXUnfeasible, unfeasible) + pytest.raises(nx.NetworkXNotImplemented, not_implemented) + pytest.raises(nx.NetworkXNotImplemented, not_implemented_2) + + def test_all_topological_sorts_4(self): + DG = nx.DiGraph() + for i in range(7): + DG.add_node(i) + assert sorted(map(list, permutations(DG.nodes))) == sorted( + nx.all_topological_sorts(DG) + ) + + def test_all_topological_sorts_multigraph_1(self): + DG = nx.MultiDiGraph([(1, 2), (1, 2), (2, 3), (3, 4), (3, 5), (3, 5), (3, 5)]) + assert sorted(nx.all_topological_sorts(DG)) == sorted( + [[1, 2, 3, 4, 5], [1, 2, 3, 5, 4]] + ) + + def test_all_topological_sorts_multigraph_2(self): + N = 9 + edges = [] + for i in range(1, N): + edges.extend([(i, i + 1)] * i) + DG = nx.MultiDiGraph(edges) + assert list(nx.all_topological_sorts(DG)) == [list(range(1, N + 1))] + + def test_ancestors(self): + G = nx.DiGraph() + ancestors = nx.algorithms.dag.ancestors + G.add_edges_from([(1, 2), (1, 3), (4, 2), (4, 3), (4, 5), (2, 6), (5, 6)]) + assert ancestors(G, 6) == {1, 2, 4, 5} + assert ancestors(G, 3) == {1, 4} + assert ancestors(G, 1) == set() + pytest.raises(nx.NetworkXError, ancestors, G, 8) + + def test_descendants(self): + G = nx.DiGraph() + descendants = nx.algorithms.dag.descendants + G.add_edges_from([(1, 2), (1, 3), (4, 2), (4, 3), (4, 5), (2, 6), (5, 6)]) + assert descendants(G, 1) == {2, 3, 6} + assert descendants(G, 4) == {2, 3, 5, 6} + assert descendants(G, 3) == set() + pytest.raises(nx.NetworkXError, descendants, G, 8) + + def test_transitive_closure(self): + G = nx.DiGraph([(1, 2), (2, 3), (3, 4)]) + solution = [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)] + assert edges_equal(nx.transitive_closure(G).edges(), solution) + G = nx.DiGraph([(1, 2), (2, 3), (2, 4)]) + solution = [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4)] + assert edges_equal(nx.transitive_closure(G).edges(), solution) + G = nx.DiGraph([(1, 2), (2, 3), (3, 1)]) + solution = [(1, 2), (2, 1), (2, 3), (3, 2), (1, 3), (3, 1)] + soln = sorted(solution + [(n, n) for n in G]) + assert edges_equal(sorted(nx.transitive_closure(G).edges()), soln) + + G = nx.Graph([(1, 2), (2, 3), (3, 4)]) + solution = [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)] + assert edges_equal(sorted(nx.transitive_closure(G).edges()), solution) + + G = nx.MultiGraph([(1, 2), (2, 3), (3, 4)]) + solution = [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)] + assert edges_equal(sorted(nx.transitive_closure(G).edges()), solution) + + G = nx.MultiDiGraph([(1, 2), (2, 3), (3, 4)]) + solution = [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)] + assert edges_equal(sorted(nx.transitive_closure(G).edges()), solution) + + # test if edge data is copied + G = nx.DiGraph([(1, 2, {"a": 3}), (2, 3, {"b": 0}), (3, 4)]) + H = nx.transitive_closure(G) + for u, v in G.edges(): + assert G.get_edge_data(u, v) == H.get_edge_data(u, v) + + k = 10 + G = nx.DiGraph((i, i + 1, {"f": "b", "weight": i}) for i in range(k)) + H = nx.transitive_closure(G) + for u, v in G.edges(): + assert G.get_edge_data(u, v) == H.get_edge_data(u, v) + + G = nx.Graph() + with pytest.raises(nx.NetworkXError): + nx.transitive_closure(G, reflexive="wrong input") + + def test_reflexive_transitive_closure(self): + G = nx.DiGraph([(1, 2), (2, 3), (3, 4)]) + solution = [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)] + soln = sorted(solution + [(n, n) for n in G]) + assert edges_equal(nx.transitive_closure(G).edges(), solution) + assert edges_equal(nx.transitive_closure(G, False).edges(), solution) + assert edges_equal(nx.transitive_closure(G, True).edges(), soln) + assert edges_equal(nx.transitive_closure(G, None).edges(), solution) + + G = nx.DiGraph([(1, 2), (2, 3), (2, 4)]) + solution = [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4)] + soln = sorted(solution + [(n, n) for n in G]) + assert edges_equal(nx.transitive_closure(G).edges(), solution) + assert edges_equal(nx.transitive_closure(G, False).edges(), solution) + assert edges_equal(nx.transitive_closure(G, True).edges(), soln) + assert edges_equal(nx.transitive_closure(G, None).edges(), solution) + + G = nx.DiGraph([(1, 2), (2, 3), (3, 1)]) + solution = sorted([(1, 2), (2, 1), (2, 3), (3, 2), (1, 3), (3, 1)]) + soln = sorted(solution + [(n, n) for n in G]) + assert edges_equal(sorted(nx.transitive_closure(G).edges()), soln) + assert edges_equal(sorted(nx.transitive_closure(G, False).edges()), soln) + assert edges_equal(sorted(nx.transitive_closure(G, None).edges()), solution) + assert edges_equal(sorted(nx.transitive_closure(G, True).edges()), soln) + + G = nx.Graph([(1, 2), (2, 3), (3, 4)]) + solution = [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)] + soln = sorted(solution + [(n, n) for n in G]) + assert edges_equal(nx.transitive_closure(G).edges(), solution) + assert edges_equal(nx.transitive_closure(G, False).edges(), solution) + assert edges_equal(nx.transitive_closure(G, True).edges(), soln) + assert edges_equal(nx.transitive_closure(G, None).edges(), solution) + + G = nx.MultiGraph([(1, 2), (2, 3), (3, 4)]) + solution = [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)] + soln = sorted(solution + [(n, n) for n in G]) + assert edges_equal(nx.transitive_closure(G).edges(), solution) + assert edges_equal(nx.transitive_closure(G, False).edges(), solution) + assert edges_equal(nx.transitive_closure(G, True).edges(), soln) + assert edges_equal(nx.transitive_closure(G, None).edges(), solution) + + G = nx.MultiDiGraph([(1, 2), (2, 3), (3, 4)]) + solution = [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)] + soln = sorted(solution + [(n, n) for n in G]) + assert edges_equal(nx.transitive_closure(G).edges(), solution) + assert edges_equal(nx.transitive_closure(G, False).edges(), solution) + assert edges_equal(nx.transitive_closure(G, True).edges(), soln) + assert edges_equal(nx.transitive_closure(G, None).edges(), solution) + + def test_transitive_closure_dag(self): + G = nx.DiGraph([(1, 2), (2, 3), (3, 4)]) + transitive_closure = nx.algorithms.dag.transitive_closure_dag + solution = [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)] + assert edges_equal(transitive_closure(G).edges(), solution) + G = nx.DiGraph([(1, 2), (2, 3), (2, 4)]) + solution = [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4)] + assert edges_equal(transitive_closure(G).edges(), solution) + G = nx.Graph([(1, 2), (2, 3), (3, 4)]) + pytest.raises(nx.NetworkXNotImplemented, transitive_closure, G) + + # test if edge data is copied + G = nx.DiGraph([(1, 2, {"a": 3}), (2, 3, {"b": 0}), (3, 4)]) + H = transitive_closure(G) + for u, v in G.edges(): + assert G.get_edge_data(u, v) == H.get_edge_data(u, v) + + k = 10 + G = nx.DiGraph((i, i + 1, {"foo": "bar", "weight": i}) for i in range(k)) + H = transitive_closure(G) + for u, v in G.edges(): + assert G.get_edge_data(u, v) == H.get_edge_data(u, v) + + def test_transitive_reduction(self): + G = nx.DiGraph([(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]) + transitive_reduction = nx.algorithms.dag.transitive_reduction + solution = [(1, 2), (2, 3), (3, 4)] + assert edges_equal(transitive_reduction(G).edges(), solution) + G = nx.DiGraph([(1, 2), (1, 3), (1, 4), (2, 3), (2, 4)]) + transitive_reduction = nx.algorithms.dag.transitive_reduction + solution = [(1, 2), (2, 3), (2, 4)] + assert edges_equal(transitive_reduction(G).edges(), solution) + G = nx.Graph([(1, 2), (2, 3), (3, 4)]) + pytest.raises(nx.NetworkXNotImplemented, transitive_reduction, G) + + def _check_antichains(self, solution, result): + sol = [frozenset(a) for a in solution] + res = [frozenset(a) for a in result] + assert set(sol) == set(res) + + def test_antichains(self): + antichains = nx.algorithms.dag.antichains + G = nx.DiGraph([(1, 2), (2, 3), (3, 4)]) + solution = [[], [4], [3], [2], [1]] + self._check_antichains(list(antichains(G)), solution) + G = nx.DiGraph([(1, 2), (2, 3), (2, 4), (3, 5), (5, 6), (5, 7)]) + solution = [ + [], + [4], + [7], + [7, 4], + [6], + [6, 4], + [6, 7], + [6, 7, 4], + [5], + [5, 4], + [3], + [3, 4], + [2], + [1], + ] + self._check_antichains(list(antichains(G)), solution) + G = nx.DiGraph([(1, 2), (1, 3), (3, 4), (3, 5), (5, 6)]) + solution = [ + [], + [6], + [5], + [4], + [4, 6], + [4, 5], + [3], + [2], + [2, 6], + [2, 5], + [2, 4], + [2, 4, 6], + [2, 4, 5], + [2, 3], + [1], + ] + self._check_antichains(list(antichains(G)), solution) + G = nx.DiGraph({0: [1, 2], 1: [4], 2: [3], 3: [4]}) + solution = [[], [4], [3], [2], [1], [1, 3], [1, 2], [0]] + self._check_antichains(list(antichains(G)), solution) + G = nx.DiGraph() + self._check_antichains(list(antichains(G)), [[]]) + G = nx.DiGraph() + G.add_nodes_from([0, 1, 2]) + solution = [[], [0], [1], [1, 0], [2], [2, 0], [2, 1], [2, 1, 0]] + self._check_antichains(list(antichains(G)), solution) + + def f(x): + return list(antichains(x)) + + G = nx.Graph([(1, 2), (2, 3), (3, 4)]) + pytest.raises(nx.NetworkXNotImplemented, f, G) + G = nx.DiGraph([(1, 2), (2, 3), (3, 1)]) + pytest.raises(nx.NetworkXUnfeasible, f, G) + + def test_lexicographical_topological_sort(self): + G = nx.DiGraph([(1, 2), (2, 3), (1, 4), (1, 5), (2, 6)]) + assert list(nx.lexicographical_topological_sort(G)) == [1, 2, 3, 4, 5, 6] + assert list(nx.lexicographical_topological_sort(G, key=lambda x: x)) == [ + 1, + 2, + 3, + 4, + 5, + 6, + ] + assert list(nx.lexicographical_topological_sort(G, key=lambda x: -x)) == [ + 1, + 5, + 4, + 2, + 6, + 3, + ] + + def test_lexicographical_topological_sort2(self): + """ + Check the case of two or more nodes with same key value. + Want to avoid exception raised due to comparing nodes directly. + See Issue #3493 + """ + + class Test_Node: + def __init__(self, n): + self.label = n + self.priority = 1 + + def __repr__(self): + return f"Node({self.label})" + + def sorting_key(node): + return node.priority + + test_nodes = [Test_Node(n) for n in range(4)] + G = nx.DiGraph() + edges = [(0, 1), (0, 2), (0, 3), (2, 3)] + G.add_edges_from((test_nodes[a], test_nodes[b]) for a, b in edges) + + sorting = list(nx.lexicographical_topological_sort(G, key=sorting_key)) + assert sorting == test_nodes + + +def test_topological_generations(): + G = nx.DiGraph( + {1: [2, 3], 2: [4, 5], 3: [7], 4: [], 5: [6, 7], 6: [], 7: []} + ).reverse() + # order within each generation is inconsequential + generations = [sorted(gen) for gen in nx.topological_generations(G)] + expected = [[4, 6, 7], [3, 5], [2], [1]] + assert generations == expected + + MG = nx.MultiDiGraph(G.edges) + MG.add_edge(2, 1) + generations = [sorted(gen) for gen in nx.topological_generations(MG)] + assert generations == expected + + +def test_topological_generations_empty(): + G = nx.DiGraph() + assert list(nx.topological_generations(G)) == [] + + +def test_topological_generations_cycle(): + G = nx.DiGraph([[2, 1], [3, 1], [1, 2]]) + with pytest.raises(nx.NetworkXUnfeasible): + list(nx.topological_generations(G)) + + +def test_is_aperiodic_cycle(): + G = nx.DiGraph() + nx.add_cycle(G, [1, 2, 3, 4]) + assert not nx.is_aperiodic(G) + + +def test_is_aperiodic_cycle2(): + G = nx.DiGraph() + nx.add_cycle(G, [1, 2, 3, 4]) + nx.add_cycle(G, [3, 4, 5, 6, 7]) + assert nx.is_aperiodic(G) + + +def test_is_aperiodic_cycle3(): + G = nx.DiGraph() + nx.add_cycle(G, [1, 2, 3, 4]) + nx.add_cycle(G, [3, 4, 5, 6]) + assert not nx.is_aperiodic(G) + + +def test_is_aperiodic_cycle4(): + G = nx.DiGraph() + nx.add_cycle(G, [1, 2, 3, 4]) + G.add_edge(1, 3) + assert nx.is_aperiodic(G) + + +def test_is_aperiodic_selfloop(): + G = nx.DiGraph() + nx.add_cycle(G, [1, 2, 3, 4]) + G.add_edge(1, 1) + assert nx.is_aperiodic(G) + + +def test_is_aperiodic_undirected_raises(): + G = nx.Graph() + pytest.raises(nx.NetworkXError, nx.is_aperiodic, G) + + +def test_is_aperiodic_empty_graph(): + G = nx.empty_graph(create_using=nx.DiGraph) + with pytest.raises(nx.NetworkXPointlessConcept, match="Graph has no nodes."): + nx.is_aperiodic(G) + + +def test_is_aperiodic_bipartite(): + # Bipartite graph + G = nx.DiGraph(nx.davis_southern_women_graph()) + assert not nx.is_aperiodic(G) + + +def test_is_aperiodic_rary_tree(): + G = nx.full_rary_tree(3, 27, create_using=nx.DiGraph()) + assert not nx.is_aperiodic(G) + + +def test_is_aperiodic_disconnected(): + # disconnected graph + G = nx.DiGraph() + nx.add_cycle(G, [1, 2, 3, 4]) + nx.add_cycle(G, [5, 6, 7, 8]) + assert not nx.is_aperiodic(G) + G.add_edge(1, 3) + G.add_edge(5, 7) + assert nx.is_aperiodic(G) + + +def test_is_aperiodic_disconnected2(): + G = nx.DiGraph() + nx.add_cycle(G, [0, 1, 2]) + G.add_edge(3, 3) + assert not nx.is_aperiodic(G) + + +class TestDagToBranching: + """Unit tests for the :func:`networkx.dag_to_branching` function.""" + + def test_single_root(self): + """Tests that a directed acyclic graph with a single degree + zero node produces an arborescence. + + """ + G = nx.DiGraph([(0, 1), (0, 2), (1, 3), (2, 3)]) + B = nx.dag_to_branching(G) + expected = nx.DiGraph([(0, 1), (1, 3), (0, 2), (2, 4)]) + assert nx.is_arborescence(B) + assert nx.is_isomorphic(B, expected) + + def test_multiple_roots(self): + """Tests that a directed acyclic graph with multiple degree zero + nodes creates an arborescence with multiple (weakly) connected + components. + + """ + G = nx.DiGraph([(0, 1), (0, 2), (1, 3), (2, 3), (5, 2)]) + B = nx.dag_to_branching(G) + expected = nx.DiGraph([(0, 1), (1, 3), (0, 2), (2, 4), (5, 6), (6, 7)]) + assert nx.is_branching(B) + assert not nx.is_arborescence(B) + assert nx.is_isomorphic(B, expected) + + # # Attributes are not copied by this function. If they were, this would + # # be a good test to uncomment. + # def test_copy_attributes(self): + # """Tests that node attributes are copied in the branching.""" + # G = nx.DiGraph([(0, 1), (0, 2), (1, 3), (2, 3)]) + # for v in G: + # G.node[v]['label'] = str(v) + # B = nx.dag_to_branching(G) + # # Determine the root node of the branching. + # root = next(v for v, d in B.in_degree() if d == 0) + # assert_equal(B.node[root]['label'], '0') + # children = B[root] + # # Get the left and right children, nodes 1 and 2, respectively. + # left, right = sorted(children, key=lambda v: B.node[v]['label']) + # assert_equal(B.node[left]['label'], '1') + # assert_equal(B.node[right]['label'], '2') + # # Get the left grandchild. + # children = B[left] + # assert_equal(len(children), 1) + # left_grandchild = arbitrary_element(children) + # assert_equal(B.node[left_grandchild]['label'], '3') + # # Get the right grandchild. + # children = B[right] + # assert_equal(len(children), 1) + # right_grandchild = arbitrary_element(children) + # assert_equal(B.node[right_grandchild]['label'], '3') + + def test_already_arborescence(self): + """Tests that a directed acyclic graph that is already an + arborescence produces an isomorphic arborescence as output. + + """ + A = nx.balanced_tree(2, 2, create_using=nx.DiGraph()) + B = nx.dag_to_branching(A) + assert nx.is_isomorphic(A, B) + + def test_already_branching(self): + """Tests that a directed acyclic graph that is already a + branching produces an isomorphic branching as output. + + """ + T1 = nx.balanced_tree(2, 2, create_using=nx.DiGraph()) + T2 = nx.balanced_tree(2, 2, create_using=nx.DiGraph()) + G = nx.disjoint_union(T1, T2) + B = nx.dag_to_branching(G) + assert nx.is_isomorphic(G, B) + + def test_not_acyclic(self): + """Tests that a non-acyclic graph causes an exception.""" + with pytest.raises(nx.HasACycle): + G = nx.DiGraph(pairwise("abc", cyclic=True)) + nx.dag_to_branching(G) + + def test_undirected(self): + with pytest.raises(nx.NetworkXNotImplemented): + nx.dag_to_branching(nx.Graph()) + + def test_multigraph(self): + with pytest.raises(nx.NetworkXNotImplemented): + nx.dag_to_branching(nx.MultiGraph()) + + def test_multidigraph(self): + with pytest.raises(nx.NetworkXNotImplemented): + nx.dag_to_branching(nx.MultiDiGraph()) + + +def test_ancestors_descendants_undirected(): + """Regression test to ensure ancestors and descendants work as expected on + undirected graphs.""" + G = nx.path_graph(5) + nx.ancestors(G, 2) == nx.descendants(G, 2) == {0, 1, 3, 4} + + +def test_compute_v_structures_raise(): + G = nx.Graph() + with pytest.raises(nx.NetworkXNotImplemented, match="for undirected type"): + nx.compute_v_structures(G) + + +def test_compute_v_structures(): + edges = [(0, 1), (0, 2), (3, 2)] + G = nx.DiGraph(edges) + + v_structs = set(nx.compute_v_structures(G)) + assert len(v_structs) == 1 + assert (0, 2, 3) in v_structs + + edges = [("A", "B"), ("C", "B"), ("B", "D"), ("D", "E"), ("G", "E")] + G = nx.DiGraph(edges) + v_structs = set(nx.compute_v_structures(G)) + assert len(v_structs) == 2 + + +def test_compute_v_structures_deprecated(): + G = nx.DiGraph() + with pytest.deprecated_call(): + nx.compute_v_structures(G) + + +def test_v_structures_raise(): + G = nx.Graph() + with pytest.raises(nx.NetworkXNotImplemented, match="for undirected type"): + nx.dag.v_structures(G) + + +@pytest.mark.parametrize( + ("edgelist", "expected"), + ( + ( + [(0, 1), (0, 2), (3, 2)], + {(0, 2, 3)}, + ), + ( + [("A", "B"), ("C", "B"), ("D", "G"), ("D", "E"), ("G", "E")], + {("A", "B", "C")}, + ), + ([(0, 1), (2, 1), (0, 2)], set()), # adjacent parents case: see gh-7385 + ), +) +def test_v_structures(edgelist, expected): + G = nx.DiGraph(edgelist) + v_structs = set(nx.dag.v_structures(G)) + assert v_structs == expected + + +def test_colliders_raise(): + G = nx.Graph() + with pytest.raises(nx.NetworkXNotImplemented, match="for undirected type"): + nx.dag.colliders(G) + + +@pytest.mark.parametrize( + ("edgelist", "expected"), + ( + ( + [(0, 1), (0, 2), (3, 2)], + {(0, 2, 3)}, + ), + ( + [("A", "B"), ("C", "B"), ("D", "G"), ("D", "E"), ("G", "E")], + {("A", "B", "C"), ("D", "E", "G")}, + ), + ), +) +def test_colliders(edgelist, expected): + G = nx.DiGraph(edgelist) + colliders = set(nx.dag.colliders(G)) + assert colliders == expected diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_distance_measures.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_distance_measures.py new file mode 100644 index 0000000000000000000000000000000000000000..0b3840fd23e0bca32d708a276d4b3e328b9dc82d --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_distance_measures.py @@ -0,0 +1,774 @@ +import math +from random import Random + +import pytest + +import networkx as nx +from networkx import convert_node_labels_to_integers as cnlti +from networkx.algorithms.distance_measures import _extrema_bounding + + +def test__extrema_bounding_invalid_compute_kwarg(): + G = nx.path_graph(3) + with pytest.raises(ValueError, match="compute must be one of"): + _extrema_bounding(G, compute="spam") + + +class TestDistance: + def setup_method(self): + G = cnlti(nx.grid_2d_graph(4, 4), first_label=1, ordering="sorted") + self.G = G + + def test_eccentricity(self): + assert nx.eccentricity(self.G, 1) == 6 + e = nx.eccentricity(self.G) + assert e[1] == 6 + + sp = dict(nx.shortest_path_length(self.G)) + e = nx.eccentricity(self.G, sp=sp) + assert e[1] == 6 + + e = nx.eccentricity(self.G, v=1) + assert e == 6 + + # This behavior changed in version 1.8 (ticket #739) + e = nx.eccentricity(self.G, v=[1, 1]) + assert e[1] == 6 + e = nx.eccentricity(self.G, v=[1, 2]) + assert e[1] == 6 + + # test against graph with one node + G = nx.path_graph(1) + e = nx.eccentricity(G) + assert e[0] == 0 + e = nx.eccentricity(G, v=0) + assert e == 0 + pytest.raises(nx.NetworkXError, nx.eccentricity, G, 1) + + # test against empty graph + G = nx.empty_graph() + e = nx.eccentricity(G) + assert e == {} + + def test_diameter(self): + assert nx.diameter(self.G) == 6 + + def test_harmonic_diameter(self): + assert abs(nx.harmonic_diameter(self.G) - 2.0477815699658715) < 1e-12 + + def test_harmonic_diameter_empty(self): + assert math.isnan(nx.harmonic_diameter(nx.empty_graph())) + + def test_harmonic_diameter_single_node(self): + assert math.isnan(nx.harmonic_diameter(nx.empty_graph(1))) + + def test_harmonic_diameter_discrete(self): + assert math.isinf(nx.harmonic_diameter(nx.empty_graph(3))) + + def test_harmonic_diameter_not_strongly_connected(self): + DG = nx.DiGraph() + DG.add_edge(0, 1) + assert nx.harmonic_diameter(DG) == 2 + + def test_radius(self): + assert nx.radius(self.G) == 4 + + def test_periphery(self): + assert set(nx.periphery(self.G)) == {1, 4, 13, 16} + + def test_center(self): + assert set(nx.center(self.G)) == {6, 7, 10, 11} + + def test_bound_diameter(self): + assert nx.diameter(self.G, usebounds=True) == 6 + + def test_bound_radius(self): + assert nx.radius(self.G, usebounds=True) == 4 + + def test_bound_periphery(self): + result = {1, 4, 13, 16} + assert set(nx.periphery(self.G, usebounds=True)) == result + + def test_bound_center(self): + result = {6, 7, 10, 11} + assert set(nx.center(self.G, usebounds=True)) == result + + def test_radius_exception(self): + G = nx.Graph() + G.add_edge(1, 2) + G.add_edge(3, 4) + pytest.raises(nx.NetworkXError, nx.diameter, G) + + def test_eccentricity_infinite(self): + with pytest.raises(nx.NetworkXError): + G = nx.Graph([(1, 2), (3, 4)]) + e = nx.eccentricity(G) + + def test_eccentricity_undirected_not_connected(self): + with pytest.raises(nx.NetworkXError): + G = nx.Graph([(1, 2), (3, 4)]) + e = nx.eccentricity(G, sp=1) + + def test_eccentricity_directed_weakly_connected(self): + with pytest.raises(nx.NetworkXError): + DG = nx.DiGraph([(1, 2), (1, 3)]) + nx.eccentricity(DG) + + +class TestWeightedDistance: + def setup_method(self): + G = nx.Graph() + G.add_edge(0, 1, weight=0.6, cost=0.6, high_cost=6) + G.add_edge(0, 2, weight=0.2, cost=0.2, high_cost=2) + G.add_edge(2, 3, weight=0.1, cost=0.1, high_cost=1) + G.add_edge(2, 4, weight=0.7, cost=0.7, high_cost=7) + G.add_edge(2, 5, weight=0.9, cost=0.9, high_cost=9) + G.add_edge(1, 5, weight=0.3, cost=0.3, high_cost=3) + self.G = G + self.weight_fn = lambda v, u, e: 2 + + def test_eccentricity_weight_None(self): + assert nx.eccentricity(self.G, 1, weight=None) == 3 + e = nx.eccentricity(self.G, weight=None) + assert e[1] == 3 + + e = nx.eccentricity(self.G, v=1, weight=None) + assert e == 3 + + # This behavior changed in version 1.8 (ticket #739) + e = nx.eccentricity(self.G, v=[1, 1], weight=None) + assert e[1] == 3 + e = nx.eccentricity(self.G, v=[1, 2], weight=None) + assert e[1] == 3 + + def test_eccentricity_weight_attr(self): + assert nx.eccentricity(self.G, 1, weight="weight") == 1.5 + e = nx.eccentricity(self.G, weight="weight") + assert ( + e + == nx.eccentricity(self.G, weight="cost") + != nx.eccentricity(self.G, weight="high_cost") + ) + assert e[1] == 1.5 + + e = nx.eccentricity(self.G, v=1, weight="weight") + assert e == 1.5 + + # This behavior changed in version 1.8 (ticket #739) + e = nx.eccentricity(self.G, v=[1, 1], weight="weight") + assert e[1] == 1.5 + e = nx.eccentricity(self.G, v=[1, 2], weight="weight") + assert e[1] == 1.5 + + def test_eccentricity_weight_fn(self): + assert nx.eccentricity(self.G, 1, weight=self.weight_fn) == 6 + e = nx.eccentricity(self.G, weight=self.weight_fn) + assert e[1] == 6 + + e = nx.eccentricity(self.G, v=1, weight=self.weight_fn) + assert e == 6 + + # This behavior changed in version 1.8 (ticket #739) + e = nx.eccentricity(self.G, v=[1, 1], weight=self.weight_fn) + assert e[1] == 6 + e = nx.eccentricity(self.G, v=[1, 2], weight=self.weight_fn) + assert e[1] == 6 + + def test_diameter_weight_None(self): + assert nx.diameter(self.G, weight=None) == 3 + + def test_diameter_weight_attr(self): + assert ( + nx.diameter(self.G, weight="weight") + == nx.diameter(self.G, weight="cost") + == 1.6 + != nx.diameter(self.G, weight="high_cost") + ) + + def test_diameter_weight_fn(self): + assert nx.diameter(self.G, weight=self.weight_fn) == 6 + + def test_radius_weight_None(self): + assert pytest.approx(nx.radius(self.G, weight=None)) == 2 + + def test_radius_weight_attr(self): + assert ( + pytest.approx(nx.radius(self.G, weight="weight")) + == pytest.approx(nx.radius(self.G, weight="cost")) + == 0.9 + != nx.radius(self.G, weight="high_cost") + ) + + def test_radius_weight_fn(self): + assert nx.radius(self.G, weight=self.weight_fn) == 4 + + def test_periphery_weight_None(self): + for v in set(nx.periphery(self.G, weight=None)): + assert nx.eccentricity(self.G, v, weight=None) == nx.diameter( + self.G, weight=None + ) + + def test_periphery_weight_attr(self): + periphery = set(nx.periphery(self.G, weight="weight")) + assert ( + periphery + == set(nx.periphery(self.G, weight="cost")) + == set(nx.periphery(self.G, weight="high_cost")) + ) + for v in periphery: + assert ( + nx.eccentricity(self.G, v, weight="high_cost") + != nx.eccentricity(self.G, v, weight="weight") + == nx.eccentricity(self.G, v, weight="cost") + == nx.diameter(self.G, weight="weight") + == nx.diameter(self.G, weight="cost") + != nx.diameter(self.G, weight="high_cost") + ) + assert nx.eccentricity(self.G, v, weight="high_cost") == nx.diameter( + self.G, weight="high_cost" + ) + + def test_periphery_weight_fn(self): + for v in set(nx.periphery(self.G, weight=self.weight_fn)): + assert nx.eccentricity(self.G, v, weight=self.weight_fn) == nx.diameter( + self.G, weight=self.weight_fn + ) + + def test_center_weight_None(self): + for v in set(nx.center(self.G, weight=None)): + assert pytest.approx(nx.eccentricity(self.G, v, weight=None)) == nx.radius( + self.G, weight=None + ) + + def test_center_weight_attr(self): + center = set(nx.center(self.G, weight="weight")) + assert ( + center + == set(nx.center(self.G, weight="cost")) + != set(nx.center(self.G, weight="high_cost")) + ) + for v in center: + assert ( + nx.eccentricity(self.G, v, weight="high_cost") + != pytest.approx(nx.eccentricity(self.G, v, weight="weight")) + == pytest.approx(nx.eccentricity(self.G, v, weight="cost")) + == nx.radius(self.G, weight="weight") + == nx.radius(self.G, weight="cost") + != nx.radius(self.G, weight="high_cost") + ) + assert nx.eccentricity(self.G, v, weight="high_cost") == nx.radius( + self.G, weight="high_cost" + ) + + def test_center_weight_fn(self): + for v in set(nx.center(self.G, weight=self.weight_fn)): + assert nx.eccentricity(self.G, v, weight=self.weight_fn) == nx.radius( + self.G, weight=self.weight_fn + ) + + def test_bound_diameter_weight_None(self): + assert nx.diameter(self.G, usebounds=True, weight=None) == 3 + + def test_bound_diameter_weight_attr(self): + assert ( + nx.diameter(self.G, usebounds=True, weight="high_cost") + != nx.diameter(self.G, usebounds=True, weight="weight") + == nx.diameter(self.G, usebounds=True, weight="cost") + == 1.6 + != nx.diameter(self.G, usebounds=True, weight="high_cost") + ) + assert nx.diameter(self.G, usebounds=True, weight="high_cost") == nx.diameter( + self.G, usebounds=True, weight="high_cost" + ) + + def test_bound_diameter_weight_fn(self): + assert nx.diameter(self.G, usebounds=True, weight=self.weight_fn) == 6 + + def test_bound_radius_weight_None(self): + assert pytest.approx(nx.radius(self.G, usebounds=True, weight=None)) == 2 + + def test_bound_radius_weight_attr(self): + assert ( + nx.radius(self.G, usebounds=True, weight="high_cost") + != pytest.approx(nx.radius(self.G, usebounds=True, weight="weight")) + == pytest.approx(nx.radius(self.G, usebounds=True, weight="cost")) + == 0.9 + != nx.radius(self.G, usebounds=True, weight="high_cost") + ) + assert nx.radius(self.G, usebounds=True, weight="high_cost") == nx.radius( + self.G, usebounds=True, weight="high_cost" + ) + + def test_bound_radius_weight_fn(self): + assert nx.radius(self.G, usebounds=True, weight=self.weight_fn) == 4 + + def test_bound_periphery_weight_None(self): + result = {1, 3, 4} + assert set(nx.periphery(self.G, usebounds=True, weight=None)) == result + + def test_bound_periphery_weight_attr(self): + result = {4, 5} + assert ( + set(nx.periphery(self.G, usebounds=True, weight="weight")) + == set(nx.periphery(self.G, usebounds=True, weight="cost")) + == result + ) + + def test_bound_periphery_weight_fn(self): + result = {1, 3, 4} + assert ( + set(nx.periphery(self.G, usebounds=True, weight=self.weight_fn)) == result + ) + + def test_bound_center_weight_None(self): + result = {0, 2, 5} + assert set(nx.center(self.G, usebounds=True, weight=None)) == result + + def test_bound_center_weight_attr(self): + result = {0} + assert ( + set(nx.center(self.G, usebounds=True, weight="weight")) + == set(nx.center(self.G, usebounds=True, weight="cost")) + == result + ) + + def test_bound_center_weight_fn(self): + result = {0, 2, 5} + assert set(nx.center(self.G, usebounds=True, weight=self.weight_fn)) == result + + +class TestResistanceDistance: + @classmethod + def setup_class(cls): + global np + np = pytest.importorskip("numpy") + sp = pytest.importorskip("scipy") + + def setup_method(self): + G = nx.Graph() + G.add_edge(1, 2, weight=2) + G.add_edge(2, 3, weight=4) + G.add_edge(3, 4, weight=1) + G.add_edge(1, 4, weight=3) + self.G = G + + def test_resistance_distance_directed_graph(self): + G = nx.DiGraph() + with pytest.raises(nx.NetworkXNotImplemented): + nx.resistance_distance(G) + + def test_resistance_distance_empty(self): + G = nx.Graph() + with pytest.raises(nx.NetworkXError): + nx.resistance_distance(G) + + def test_resistance_distance_not_connected(self): + with pytest.raises(nx.NetworkXError): + self.G.add_node(5) + nx.resistance_distance(self.G, 1, 5) + + def test_resistance_distance_nodeA_not_in_graph(self): + with pytest.raises(nx.NetworkXError): + nx.resistance_distance(self.G, 9, 1) + + def test_resistance_distance_nodeB_not_in_graph(self): + with pytest.raises(nx.NetworkXError): + nx.resistance_distance(self.G, 1, 9) + + def test_resistance_distance(self): + rd = nx.resistance_distance(self.G, 1, 3, "weight", True) + test_data = 1 / (1 / (2 + 4) + 1 / (1 + 3)) + assert round(rd, 5) == round(test_data, 5) + + def test_resistance_distance_noinv(self): + rd = nx.resistance_distance(self.G, 1, 3, "weight", False) + test_data = 1 / (1 / (1 / 2 + 1 / 4) + 1 / (1 / 1 + 1 / 3)) + assert round(rd, 5) == round(test_data, 5) + + def test_resistance_distance_no_weight(self): + rd = nx.resistance_distance(self.G, 1, 3) + assert round(rd, 5) == 1 + + def test_resistance_distance_neg_weight(self): + self.G[2][3]["weight"] = -4 + rd = nx.resistance_distance(self.G, 1, 3, "weight", True) + test_data = 1 / (1 / (2 + -4) + 1 / (1 + 3)) + assert round(rd, 5) == round(test_data, 5) + + def test_multigraph(self): + G = nx.MultiGraph() + G.add_edge(1, 2, weight=2) + G.add_edge(2, 3, weight=4) + G.add_edge(3, 4, weight=1) + G.add_edge(1, 4, weight=3) + rd = nx.resistance_distance(G, 1, 3, "weight", True) + assert np.isclose(rd, 1 / (1 / (2 + 4) + 1 / (1 + 3))) + + def test_resistance_distance_div0(self): + with pytest.raises(ZeroDivisionError): + self.G[1][2]["weight"] = 0 + nx.resistance_distance(self.G, 1, 3, "weight") + + def test_resistance_distance_same_node(self): + assert nx.resistance_distance(self.G, 1, 1) == 0 + + def test_resistance_distance_only_nodeA(self): + rd = nx.resistance_distance(self.G, nodeA=1) + test_data = {} + test_data[1] = 0 + test_data[2] = 0.75 + test_data[3] = 1 + test_data[4] = 0.75 + assert type(rd) == dict + assert sorted(rd.keys()) == sorted(test_data.keys()) + for key in rd: + assert np.isclose(rd[key], test_data[key]) + + def test_resistance_distance_only_nodeB(self): + rd = nx.resistance_distance(self.G, nodeB=1) + test_data = {} + test_data[1] = 0 + test_data[2] = 0.75 + test_data[3] = 1 + test_data[4] = 0.75 + assert type(rd) == dict + assert sorted(rd.keys()) == sorted(test_data.keys()) + for key in rd: + assert np.isclose(rd[key], test_data[key]) + + def test_resistance_distance_all(self): + rd = nx.resistance_distance(self.G) + assert type(rd) == dict + assert round(rd[1][3], 5) == 1 + + +class TestEffectiveGraphResistance: + @classmethod + def setup_class(cls): + global np + np = pytest.importorskip("numpy") + sp = pytest.importorskip("scipy") + + def setup_method(self): + G = nx.Graph() + G.add_edge(1, 2, weight=2) + G.add_edge(1, 3, weight=1) + G.add_edge(2, 3, weight=4) + self.G = G + + def test_effective_graph_resistance_directed_graph(self): + G = nx.DiGraph() + with pytest.raises(nx.NetworkXNotImplemented): + nx.effective_graph_resistance(G) + + def test_effective_graph_resistance_empty(self): + G = nx.Graph() + with pytest.raises(nx.NetworkXError): + nx.effective_graph_resistance(G) + + def test_effective_graph_resistance_not_connected(self): + G = nx.Graph([(1, 2), (3, 4)]) + RG = nx.effective_graph_resistance(G) + assert np.isinf(RG) + + def test_effective_graph_resistance(self): + RG = nx.effective_graph_resistance(self.G, "weight", True) + rd12 = 1 / (1 / (1 + 4) + 1 / 2) + rd13 = 1 / (1 / (1 + 2) + 1 / 4) + rd23 = 1 / (1 / (2 + 4) + 1 / 1) + assert np.isclose(RG, rd12 + rd13 + rd23) + + def test_effective_graph_resistance_noinv(self): + RG = nx.effective_graph_resistance(self.G, "weight", False) + rd12 = 1 / (1 / (1 / 1 + 1 / 4) + 1 / (1 / 2)) + rd13 = 1 / (1 / (1 / 1 + 1 / 2) + 1 / (1 / 4)) + rd23 = 1 / (1 / (1 / 2 + 1 / 4) + 1 / (1 / 1)) + assert np.isclose(RG, rd12 + rd13 + rd23) + + def test_effective_graph_resistance_no_weight(self): + RG = nx.effective_graph_resistance(self.G) + assert np.isclose(RG, 2) + + def test_effective_graph_resistance_neg_weight(self): + self.G[2][3]["weight"] = -4 + RG = nx.effective_graph_resistance(self.G, "weight", True) + rd12 = 1 / (1 / (1 + -4) + 1 / 2) + rd13 = 1 / (1 / (1 + 2) + 1 / (-4)) + rd23 = 1 / (1 / (2 + -4) + 1 / 1) + assert np.isclose(RG, rd12 + rd13 + rd23) + + def test_effective_graph_resistance_multigraph(self): + G = nx.MultiGraph() + G.add_edge(1, 2, weight=2) + G.add_edge(1, 3, weight=1) + G.add_edge(2, 3, weight=1) + G.add_edge(2, 3, weight=3) + RG = nx.effective_graph_resistance(G, "weight", True) + edge23 = 1 / (1 / 1 + 1 / 3) + rd12 = 1 / (1 / (1 + edge23) + 1 / 2) + rd13 = 1 / (1 / (1 + 2) + 1 / edge23) + rd23 = 1 / (1 / (2 + edge23) + 1 / 1) + assert np.isclose(RG, rd12 + rd13 + rd23) + + def test_effective_graph_resistance_div0(self): + with pytest.raises(ZeroDivisionError): + self.G[1][2]["weight"] = 0 + nx.effective_graph_resistance(self.G, "weight") + + def test_effective_graph_resistance_complete_graph(self): + N = 10 + G = nx.complete_graph(N) + RG = nx.effective_graph_resistance(G) + assert np.isclose(RG, N - 1) + + def test_effective_graph_resistance_path_graph(self): + N = 10 + G = nx.path_graph(N) + RG = nx.effective_graph_resistance(G) + assert np.isclose(RG, (N - 1) * N * (N + 1) // 6) + + +class TestBarycenter: + """Test :func:`networkx.algorithms.distance_measures.barycenter`.""" + + def barycenter_as_subgraph(self, g, **kwargs): + """Return the subgraph induced on the barycenter of g""" + b = nx.barycenter(g, **kwargs) + assert isinstance(b, list) + assert set(b) <= set(g) + return g.subgraph(b) + + def test_must_be_connected(self): + pytest.raises(nx.NetworkXNoPath, nx.barycenter, nx.empty_graph(5)) + + def test_sp_kwarg(self): + # Complete graph K_5. Normally it works... + K_5 = nx.complete_graph(5) + sp = dict(nx.shortest_path_length(K_5)) + assert nx.barycenter(K_5, sp=sp) == list(K_5) + + # ...but not with the weight argument + for u, v, data in K_5.edges.data(): + data["weight"] = 1 + pytest.raises(ValueError, nx.barycenter, K_5, sp=sp, weight="weight") + + # ...and a corrupted sp can make it seem like K_5 is disconnected + del sp[0][1] + pytest.raises(nx.NetworkXNoPath, nx.barycenter, K_5, sp=sp) + + def test_trees(self): + """The barycenter of a tree is a single vertex or an edge. + + See [West01]_, p. 78. + """ + prng = Random(0xDEADBEEF) + for i in range(50): + RT = nx.random_labeled_tree(prng.randint(1, 75), seed=prng) + b = self.barycenter_as_subgraph(RT) + if len(b) == 2: + assert b.size() == 1 + else: + assert len(b) == 1 + assert b.size() == 0 + + def test_this_one_specific_tree(self): + """Test the tree pictured at the bottom of [West01]_, p. 78.""" + g = nx.Graph( + { + "a": ["b"], + "b": ["a", "x"], + "x": ["b", "y"], + "y": ["x", "z"], + "z": ["y", 0, 1, 2, 3, 4], + 0: ["z"], + 1: ["z"], + 2: ["z"], + 3: ["z"], + 4: ["z"], + } + ) + b = self.barycenter_as_subgraph(g, attr="barycentricity") + assert list(b) == ["z"] + assert not b.edges + expected_barycentricity = { + 0: 23, + 1: 23, + 2: 23, + 3: 23, + 4: 23, + "a": 35, + "b": 27, + "x": 21, + "y": 17, + "z": 15, + } + for node, barycentricity in expected_barycentricity.items(): + assert g.nodes[node]["barycentricity"] == barycentricity + + # Doubling weights should do nothing but double the barycentricities + for edge in g.edges: + g.edges[edge]["weight"] = 2 + b = self.barycenter_as_subgraph(g, weight="weight", attr="barycentricity2") + assert list(b) == ["z"] + assert not b.edges + for node, barycentricity in expected_barycentricity.items(): + assert g.nodes[node]["barycentricity2"] == barycentricity * 2 + + +class TestKemenyConstant: + @classmethod + def setup_class(cls): + global np + np = pytest.importorskip("numpy") + sp = pytest.importorskip("scipy") + + def setup_method(self): + G = nx.Graph() + w12 = 2 + w13 = 3 + w23 = 4 + G.add_edge(1, 2, weight=w12) + G.add_edge(1, 3, weight=w13) + G.add_edge(2, 3, weight=w23) + self.G = G + + def test_kemeny_constant_directed(self): + G = nx.DiGraph() + G.add_edge(1, 2) + G.add_edge(1, 3) + G.add_edge(2, 3) + with pytest.raises(nx.NetworkXNotImplemented): + nx.kemeny_constant(G) + + def test_kemeny_constant_not_connected(self): + self.G.add_node(5) + with pytest.raises(nx.NetworkXError): + nx.kemeny_constant(self.G) + + def test_kemeny_constant_no_nodes(self): + G = nx.Graph() + with pytest.raises(nx.NetworkXError): + nx.kemeny_constant(G) + + def test_kemeny_constant_negative_weight(self): + G = nx.Graph() + w12 = 2 + w13 = 3 + w23 = -10 + G.add_edge(1, 2, weight=w12) + G.add_edge(1, 3, weight=w13) + G.add_edge(2, 3, weight=w23) + with pytest.raises(nx.NetworkXError): + nx.kemeny_constant(G, weight="weight") + + def test_kemeny_constant(self): + K = nx.kemeny_constant(self.G, weight="weight") + w12 = 2 + w13 = 3 + w23 = 4 + test_data = ( + 3 + / 2 + * (w12 + w13) + * (w12 + w23) + * (w13 + w23) + / ( + w12**2 * (w13 + w23) + + w13**2 * (w12 + w23) + + w23**2 * (w12 + w13) + + 3 * w12 * w13 * w23 + ) + ) + assert np.isclose(K, test_data) + + def test_kemeny_constant_no_weight(self): + K = nx.kemeny_constant(self.G) + assert np.isclose(K, 4 / 3) + + def test_kemeny_constant_multigraph(self): + G = nx.MultiGraph() + w12_1 = 2 + w12_2 = 1 + w13 = 3 + w23 = 4 + G.add_edge(1, 2, weight=w12_1) + G.add_edge(1, 2, weight=w12_2) + G.add_edge(1, 3, weight=w13) + G.add_edge(2, 3, weight=w23) + K = nx.kemeny_constant(G, weight="weight") + w12 = w12_1 + w12_2 + test_data = ( + 3 + / 2 + * (w12 + w13) + * (w12 + w23) + * (w13 + w23) + / ( + w12**2 * (w13 + w23) + + w13**2 * (w12 + w23) + + w23**2 * (w12 + w13) + + 3 * w12 * w13 * w23 + ) + ) + assert np.isclose(K, test_data) + + def test_kemeny_constant_weight0(self): + G = nx.Graph() + w12 = 0 + w13 = 3 + w23 = 4 + G.add_edge(1, 2, weight=w12) + G.add_edge(1, 3, weight=w13) + G.add_edge(2, 3, weight=w23) + K = nx.kemeny_constant(G, weight="weight") + test_data = ( + 3 + / 2 + * (w12 + w13) + * (w12 + w23) + * (w13 + w23) + / ( + w12**2 * (w13 + w23) + + w13**2 * (w12 + w23) + + w23**2 * (w12 + w13) + + 3 * w12 * w13 * w23 + ) + ) + assert np.isclose(K, test_data) + + def test_kemeny_constant_selfloop(self): + G = nx.Graph() + w11 = 1 + w12 = 2 + w13 = 3 + w23 = 4 + G.add_edge(1, 1, weight=w11) + G.add_edge(1, 2, weight=w12) + G.add_edge(1, 3, weight=w13) + G.add_edge(2, 3, weight=w23) + K = nx.kemeny_constant(G, weight="weight") + test_data = ( + (2 * w11 + 3 * w12 + 3 * w13) + * (w12 + w23) + * (w13 + w23) + / ( + (w12 * w13 + w12 * w23 + w13 * w23) + * (w11 + 2 * w12 + 2 * w13 + 2 * w23) + ) + ) + assert np.isclose(K, test_data) + + def test_kemeny_constant_complete_bipartite_graph(self): + # Theorem 1 in https://www.sciencedirect.com/science/article/pii/S0166218X20302912 + n1 = 5 + n2 = 4 + G = nx.complete_bipartite_graph(n1, n2) + K = nx.kemeny_constant(G) + assert np.isclose(K, n1 + n2 - 3 / 2) + + def test_kemeny_constant_path_graph(self): + # Theorem 2 in https://www.sciencedirect.com/science/article/pii/S0166218X20302912 + n = 10 + G = nx.path_graph(n) + K = nx.kemeny_constant(G) + assert np.isclose(K, n**2 / 3 - 2 * n / 3 + 1 / 2) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_distance_regular.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_distance_regular.py new file mode 100644 index 0000000000000000000000000000000000000000..545fb6dee6a915230971cf4b5a141e47adc2cc15 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_distance_regular.py @@ -0,0 +1,85 @@ +import pytest + +import networkx as nx +from networkx import is_strongly_regular + + +@pytest.mark.parametrize( + "f", (nx.is_distance_regular, nx.intersection_array, nx.is_strongly_regular) +) +@pytest.mark.parametrize("graph_constructor", (nx.DiGraph, nx.MultiGraph)) +def test_raises_on_directed_and_multigraphs(f, graph_constructor): + G = graph_constructor([(0, 1), (1, 2)]) + with pytest.raises(nx.NetworkXNotImplemented): + f(G) + + +class TestDistanceRegular: + def test_is_distance_regular(self): + assert nx.is_distance_regular(nx.icosahedral_graph()) + assert nx.is_distance_regular(nx.petersen_graph()) + assert nx.is_distance_regular(nx.cubical_graph()) + assert nx.is_distance_regular(nx.complete_bipartite_graph(3, 3)) + assert nx.is_distance_regular(nx.tetrahedral_graph()) + assert nx.is_distance_regular(nx.dodecahedral_graph()) + assert nx.is_distance_regular(nx.pappus_graph()) + assert nx.is_distance_regular(nx.heawood_graph()) + assert nx.is_distance_regular(nx.cycle_graph(3)) + # no distance regular + assert not nx.is_distance_regular(nx.path_graph(4)) + + def test_not_connected(self): + G = nx.cycle_graph(4) + nx.add_cycle(G, [5, 6, 7]) + assert not nx.is_distance_regular(G) + + def test_global_parameters(self): + b, c = nx.intersection_array(nx.cycle_graph(5)) + g = nx.global_parameters(b, c) + assert list(g) == [(0, 0, 2), (1, 0, 1), (1, 1, 0)] + b, c = nx.intersection_array(nx.cycle_graph(3)) + g = nx.global_parameters(b, c) + assert list(g) == [(0, 0, 2), (1, 1, 0)] + + def test_intersection_array(self): + b, c = nx.intersection_array(nx.cycle_graph(5)) + assert b == [2, 1] + assert c == [1, 1] + b, c = nx.intersection_array(nx.dodecahedral_graph()) + assert b == [3, 2, 1, 1, 1] + assert c == [1, 1, 1, 2, 3] + b, c = nx.intersection_array(nx.icosahedral_graph()) + assert b == [5, 2, 1] + assert c == [1, 2, 5] + + +@pytest.mark.parametrize("f", (nx.is_distance_regular, nx.is_strongly_regular)) +def test_empty_graph_raises(f): + G = nx.Graph() + with pytest.raises(nx.NetworkXPointlessConcept, match="Graph has no nodes"): + f(G) + + +class TestStronglyRegular: + """Unit tests for the :func:`~networkx.is_strongly_regular` + function. + + """ + + def test_cycle_graph(self): + """Tests that the cycle graph on five vertices is strongly + regular. + + """ + G = nx.cycle_graph(5) + assert is_strongly_regular(G) + + def test_petersen_graph(self): + """Tests that the Petersen graph is strongly regular.""" + G = nx.petersen_graph() + assert is_strongly_regular(G) + + def test_path_graph(self): + """Tests that the path graph is not strongly regular.""" + G = nx.path_graph(4) + assert not is_strongly_regular(G) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_dominance.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_dominance.py new file mode 100644 index 0000000000000000000000000000000000000000..9b804c2f6eaafd3fb9a3dba556d43729ea0db70a --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_dominance.py @@ -0,0 +1,286 @@ +import pytest + +import networkx as nx + + +class TestImmediateDominators: + def test_exceptions(self): + G = nx.Graph() + G.add_node(0) + pytest.raises(nx.NetworkXNotImplemented, nx.immediate_dominators, G, 0) + G = nx.MultiGraph(G) + pytest.raises(nx.NetworkXNotImplemented, nx.immediate_dominators, G, 0) + G = nx.DiGraph([[0, 0]]) + pytest.raises(nx.NetworkXError, nx.immediate_dominators, G, 1) + + def test_singleton(self): + G = nx.DiGraph() + G.add_node(0) + assert nx.immediate_dominators(G, 0) == {0: 0} + G.add_edge(0, 0) + assert nx.immediate_dominators(G, 0) == {0: 0} + + def test_path(self): + n = 5 + G = nx.path_graph(n, create_using=nx.DiGraph()) + assert nx.immediate_dominators(G, 0) == {i: max(i - 1, 0) for i in range(n)} + + def test_cycle(self): + n = 5 + G = nx.cycle_graph(n, create_using=nx.DiGraph()) + assert nx.immediate_dominators(G, 0) == {i: max(i - 1, 0) for i in range(n)} + + def test_unreachable(self): + n = 5 + assert n > 1 + G = nx.path_graph(n, create_using=nx.DiGraph()) + assert nx.immediate_dominators(G, n // 2) == { + i: max(i - 1, n // 2) for i in range(n // 2, n) + } + + def test_irreducible1(self): + """ + Graph taken from figure 2 of "A simple, fast dominance algorithm." (2006). + https://hdl.handle.net/1911/96345 + """ + edges = [(1, 2), (2, 1), (3, 2), (4, 1), (5, 3), (5, 4)] + G = nx.DiGraph(edges) + assert nx.immediate_dominators(G, 5) == {i: 5 for i in range(1, 6)} + + def test_irreducible2(self): + """ + Graph taken from figure 4 of "A simple, fast dominance algorithm." (2006). + https://hdl.handle.net/1911/96345 + """ + + edges = [(1, 2), (2, 1), (2, 3), (3, 2), (4, 2), (4, 3), (5, 1), (6, 4), (6, 5)] + G = nx.DiGraph(edges) + result = nx.immediate_dominators(G, 6) + assert result == {i: 6 for i in range(1, 7)} + + def test_domrel_png(self): + # Graph taken from https://commons.wikipedia.org/wiki/File:Domrel.png + edges = [(1, 2), (2, 3), (2, 4), (2, 6), (3, 5), (4, 5), (5, 2)] + G = nx.DiGraph(edges) + result = nx.immediate_dominators(G, 1) + assert result == {1: 1, 2: 1, 3: 2, 4: 2, 5: 2, 6: 2} + # Test postdominance. + result = nx.immediate_dominators(G.reverse(copy=False), 6) + assert result == {1: 2, 2: 6, 3: 5, 4: 5, 5: 2, 6: 6} + + def test_boost_example(self): + # Graph taken from Figure 1 of + # http://www.boost.org/doc/libs/1_56_0/libs/graph/doc/lengauer_tarjan_dominator.htm + edges = [(0, 1), (1, 2), (1, 3), (2, 7), (3, 4), (4, 5), (4, 6), (5, 7), (6, 4)] + G = nx.DiGraph(edges) + result = nx.immediate_dominators(G, 0) + assert result == {0: 0, 1: 0, 2: 1, 3: 1, 4: 3, 5: 4, 6: 4, 7: 1} + # Test postdominance. + result = nx.immediate_dominators(G.reverse(copy=False), 7) + assert result == {0: 1, 1: 7, 2: 7, 3: 4, 4: 5, 5: 7, 6: 4, 7: 7} + + +class TestDominanceFrontiers: + def test_exceptions(self): + G = nx.Graph() + G.add_node(0) + pytest.raises(nx.NetworkXNotImplemented, nx.dominance_frontiers, G, 0) + G = nx.MultiGraph(G) + pytest.raises(nx.NetworkXNotImplemented, nx.dominance_frontiers, G, 0) + G = nx.DiGraph([[0, 0]]) + pytest.raises(nx.NetworkXError, nx.dominance_frontiers, G, 1) + + def test_singleton(self): + G = nx.DiGraph() + G.add_node(0) + assert nx.dominance_frontiers(G, 0) == {0: set()} + G.add_edge(0, 0) + assert nx.dominance_frontiers(G, 0) == {0: set()} + + def test_path(self): + n = 5 + G = nx.path_graph(n, create_using=nx.DiGraph()) + assert nx.dominance_frontiers(G, 0) == {i: set() for i in range(n)} + + def test_cycle(self): + n = 5 + G = nx.cycle_graph(n, create_using=nx.DiGraph()) + assert nx.dominance_frontiers(G, 0) == {i: set() for i in range(n)} + + def test_unreachable(self): + n = 5 + assert n > 1 + G = nx.path_graph(n, create_using=nx.DiGraph()) + assert nx.dominance_frontiers(G, n // 2) == {i: set() for i in range(n // 2, n)} + + def test_irreducible1(self): + """ + Graph taken from figure 2 of "A simple, fast dominance algorithm." (2006). + https://hdl.handle.net/1911/96345 + """ + edges = [(1, 2), (2, 1), (3, 2), (4, 1), (5, 3), (5, 4)] + G = nx.DiGraph(edges) + assert dict(nx.dominance_frontiers(G, 5).items()) == { + 1: {2}, + 2: {1}, + 3: {2}, + 4: {1}, + 5: set(), + } + + def test_irreducible2(self): + """ + Graph taken from figure 4 of "A simple, fast dominance algorithm." (2006). + https://hdl.handle.net/1911/96345 + """ + edges = [(1, 2), (2, 1), (2, 3), (3, 2), (4, 2), (4, 3), (5, 1), (6, 4), (6, 5)] + G = nx.DiGraph(edges) + assert nx.dominance_frontiers(G, 6) == { + 1: {2}, + 2: {1, 3}, + 3: {2}, + 4: {2, 3}, + 5: {1}, + 6: set(), + } + + def test_domrel_png(self): + # Graph taken from https://commons.wikipedia.org/wiki/File:Domrel.png + edges = [(1, 2), (2, 3), (2, 4), (2, 6), (3, 5), (4, 5), (5, 2)] + G = nx.DiGraph(edges) + assert nx.dominance_frontiers(G, 1) == { + 1: set(), + 2: {2}, + 3: {5}, + 4: {5}, + 5: {2}, + 6: set(), + } + # Test postdominance. + result = nx.dominance_frontiers(G.reverse(copy=False), 6) + assert result == {1: set(), 2: {2}, 3: {2}, 4: {2}, 5: {2}, 6: set()} + + def test_boost_example(self): + # Graph taken from Figure 1 of + # http://www.boost.org/doc/libs/1_56_0/libs/graph/doc/lengauer_tarjan_dominator.htm + edges = [(0, 1), (1, 2), (1, 3), (2, 7), (3, 4), (4, 5), (4, 6), (5, 7), (6, 4)] + G = nx.DiGraph(edges) + assert nx.dominance_frontiers(G, 0) == { + 0: set(), + 1: set(), + 2: {7}, + 3: {7}, + 4: {4, 7}, + 5: {7}, + 6: {4}, + 7: set(), + } + # Test postdominance. + result = nx.dominance_frontiers(G.reverse(copy=False), 7) + expected = { + 0: set(), + 1: set(), + 2: {1}, + 3: {1}, + 4: {1, 4}, + 5: {1}, + 6: {4}, + 7: set(), + } + assert result == expected + + def test_discard_issue(self): + # https://github.com/networkx/networkx/issues/2071 + g = nx.DiGraph() + g.add_edges_from( + [ + ("b0", "b1"), + ("b1", "b2"), + ("b2", "b3"), + ("b3", "b1"), + ("b1", "b5"), + ("b5", "b6"), + ("b5", "b8"), + ("b6", "b7"), + ("b8", "b7"), + ("b7", "b3"), + ("b3", "b4"), + ] + ) + df = nx.dominance_frontiers(g, "b0") + assert df == { + "b4": set(), + "b5": {"b3"}, + "b6": {"b7"}, + "b7": {"b3"}, + "b0": set(), + "b1": {"b1"}, + "b2": {"b3"}, + "b3": {"b1"}, + "b8": {"b7"}, + } + + def test_loop(self): + g = nx.DiGraph() + g.add_edges_from([("a", "b"), ("b", "c"), ("b", "a")]) + df = nx.dominance_frontiers(g, "a") + assert df == {"a": set(), "b": set(), "c": set()} + + def test_missing_immediate_doms(self): + # see https://github.com/networkx/networkx/issues/2070 + g = nx.DiGraph() + edges = [ + ("entry_1", "b1"), + ("b1", "b2"), + ("b2", "b3"), + ("b3", "exit"), + ("entry_2", "b3"), + ] + + # entry_1 + # | + # b1 + # | + # b2 entry_2 + # | / + # b3 + # | + # exit + + g.add_edges_from(edges) + # formerly raised KeyError on entry_2 when parsing b3 + # because entry_2 does not have immediate doms (no path) + nx.dominance_frontiers(g, "entry_1") + + def test_loops_larger(self): + # from + # http://ecee.colorado.edu/~waite/Darmstadt/motion.html + g = nx.DiGraph() + edges = [ + ("entry", "exit"), + ("entry", "1"), + ("1", "2"), + ("2", "3"), + ("3", "4"), + ("4", "5"), + ("5", "6"), + ("6", "exit"), + ("6", "2"), + ("5", "3"), + ("4", "4"), + ] + + g.add_edges_from(edges) + df = nx.dominance_frontiers(g, "entry") + answer = { + "entry": set(), + "1": {"exit"}, + "2": {"exit", "2"}, + "3": {"exit", "3", "2"}, + "4": {"exit", "4", "3", "2"}, + "5": {"exit", "3", "2"}, + "6": {"exit", "2"}, + "exit": set(), + } + for n in df: + assert set(df[n]) == set(answer[n]) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_dominating.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_dominating.py new file mode 100644 index 0000000000000000000000000000000000000000..b945c7386374d7076ee08db67631cc7d845e6762 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_dominating.py @@ -0,0 +1,46 @@ +import pytest + +import networkx as nx + + +def test_dominating_set(): + G = nx.gnp_random_graph(100, 0.1) + D = nx.dominating_set(G) + assert nx.is_dominating_set(G, D) + D = nx.dominating_set(G, start_with=0) + assert nx.is_dominating_set(G, D) + + +def test_complete(): + """In complete graphs each node is a dominating set. + Thus the dominating set has to be of cardinality 1. + """ + K4 = nx.complete_graph(4) + assert len(nx.dominating_set(K4)) == 1 + K5 = nx.complete_graph(5) + assert len(nx.dominating_set(K5)) == 1 + + +def test_raise_dominating_set(): + with pytest.raises(nx.NetworkXError): + G = nx.path_graph(4) + D = nx.dominating_set(G, start_with=10) + + +def test_is_dominating_set(): + G = nx.path_graph(4) + d = {1, 3} + assert nx.is_dominating_set(G, d) + d = {0, 2} + assert nx.is_dominating_set(G, d) + d = {1} + assert not nx.is_dominating_set(G, d) + + +def test_wikipedia_is_dominating_set(): + """Example from https://en.wikipedia.org/wiki/Dominating_set""" + G = nx.cycle_graph(4) + G.add_edges_from([(0, 4), (1, 4), (2, 5)]) + assert nx.is_dominating_set(G, {4, 3, 5}) + assert nx.is_dominating_set(G, {0, 2}) + assert nx.is_dominating_set(G, {1, 2}) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_efficiency.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_efficiency.py new file mode 100644 index 0000000000000000000000000000000000000000..9a2e7d0463b3a0abeb8395df4ab870456faa64b7 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_efficiency.py @@ -0,0 +1,58 @@ +"""Unit tests for the :mod:`networkx.algorithms.efficiency` module.""" + +import networkx as nx + + +class TestEfficiency: + def setup_method(self): + # G1 is a disconnected graph + self.G1 = nx.Graph() + self.G1.add_nodes_from([1, 2, 3]) + # G2 is a cycle graph + self.G2 = nx.cycle_graph(4) + # G3 is the triangle graph with one additional edge + self.G3 = nx.lollipop_graph(3, 1) + + def test_efficiency_disconnected_nodes(self): + """ + When nodes are disconnected, efficiency is 0 + """ + assert nx.efficiency(self.G1, 1, 2) == 0 + + def test_local_efficiency_disconnected_graph(self): + """ + In a disconnected graph the efficiency is 0 + """ + assert nx.local_efficiency(self.G1) == 0 + + def test_efficiency(self): + assert nx.efficiency(self.G2, 0, 1) == 1 + assert nx.efficiency(self.G2, 0, 2) == 1 / 2 + + def test_global_efficiency(self): + assert nx.global_efficiency(self.G2) == 5 / 6 + + def test_global_efficiency_complete_graph(self): + """ + Tests that the average global efficiency of the complete graph is one. + """ + for n in range(2, 10): + G = nx.complete_graph(n) + assert nx.global_efficiency(G) == 1 + + def test_local_efficiency_complete_graph(self): + """ + Test that the local efficiency for a complete graph with at least 3 + nodes should be one. For a graph with only 2 nodes, the induced + subgraph has no edges. + """ + for n in range(3, 10): + G = nx.complete_graph(n) + assert nx.local_efficiency(G) == 1 + + def test_using_ego_graph(self): + """ + Test that the ego graph is used when computing local efficiency. + For more information, see GitHub issue #2710. + """ + assert nx.local_efficiency(self.G3) == 7 / 12 diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_euler.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_euler.py new file mode 100644 index 0000000000000000000000000000000000000000..b5871f09b5a309df2bb00d9945ca9cf662e6f656 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_euler.py @@ -0,0 +1,314 @@ +import collections + +import pytest + +import networkx as nx + + +@pytest.mark.parametrize("f", (nx.is_eulerian, nx.is_semieulerian)) +def test_empty_graph_raises(f): + G = nx.Graph() + with pytest.raises(nx.NetworkXPointlessConcept, match="Connectivity is undefined"): + f(G) + + +class TestIsEulerian: + def test_is_eulerian(self): + assert nx.is_eulerian(nx.complete_graph(5)) + assert nx.is_eulerian(nx.complete_graph(7)) + assert nx.is_eulerian(nx.hypercube_graph(4)) + assert nx.is_eulerian(nx.hypercube_graph(6)) + + assert not nx.is_eulerian(nx.complete_graph(4)) + assert not nx.is_eulerian(nx.complete_graph(6)) + assert not nx.is_eulerian(nx.hypercube_graph(3)) + assert not nx.is_eulerian(nx.hypercube_graph(5)) + + assert not nx.is_eulerian(nx.petersen_graph()) + assert not nx.is_eulerian(nx.path_graph(4)) + + def test_is_eulerian2(self): + # not connected + G = nx.Graph() + G.add_nodes_from([1, 2, 3]) + assert not nx.is_eulerian(G) + # not strongly connected + G = nx.DiGraph() + G.add_nodes_from([1, 2, 3]) + assert not nx.is_eulerian(G) + G = nx.MultiDiGraph() + G.add_edge(1, 2) + G.add_edge(2, 3) + G.add_edge(2, 3) + G.add_edge(3, 1) + assert not nx.is_eulerian(G) + + +class TestEulerianCircuit: + def test_eulerian_circuit_cycle(self): + G = nx.cycle_graph(4) + + edges = list(nx.eulerian_circuit(G, source=0)) + nodes = [u for u, v in edges] + assert nodes == [0, 3, 2, 1] + assert edges == [(0, 3), (3, 2), (2, 1), (1, 0)] + + edges = list(nx.eulerian_circuit(G, source=1)) + nodes = [u for u, v in edges] + assert nodes == [1, 2, 3, 0] + assert edges == [(1, 2), (2, 3), (3, 0), (0, 1)] + + G = nx.complete_graph(3) + + edges = list(nx.eulerian_circuit(G, source=0)) + nodes = [u for u, v in edges] + assert nodes == [0, 2, 1] + assert edges == [(0, 2), (2, 1), (1, 0)] + + edges = list(nx.eulerian_circuit(G, source=1)) + nodes = [u for u, v in edges] + assert nodes == [1, 2, 0] + assert edges == [(1, 2), (2, 0), (0, 1)] + + def test_eulerian_circuit_digraph(self): + G = nx.DiGraph() + nx.add_cycle(G, [0, 1, 2, 3]) + + edges = list(nx.eulerian_circuit(G, source=0)) + nodes = [u for u, v in edges] + assert nodes == [0, 1, 2, 3] + assert edges == [(0, 1), (1, 2), (2, 3), (3, 0)] + + edges = list(nx.eulerian_circuit(G, source=1)) + nodes = [u for u, v in edges] + assert nodes == [1, 2, 3, 0] + assert edges == [(1, 2), (2, 3), (3, 0), (0, 1)] + + def test_multigraph(self): + G = nx.MultiGraph() + nx.add_cycle(G, [0, 1, 2, 3]) + G.add_edge(1, 2) + G.add_edge(1, 2) + edges = list(nx.eulerian_circuit(G, source=0)) + nodes = [u for u, v in edges] + assert nodes == [0, 3, 2, 1, 2, 1] + assert edges == [(0, 3), (3, 2), (2, 1), (1, 2), (2, 1), (1, 0)] + + def test_multigraph_with_keys(self): + G = nx.MultiGraph() + nx.add_cycle(G, [0, 1, 2, 3]) + G.add_edge(1, 2) + G.add_edge(1, 2) + edges = list(nx.eulerian_circuit(G, source=0, keys=True)) + nodes = [u for u, v, k in edges] + assert nodes == [0, 3, 2, 1, 2, 1] + assert edges[:2] == [(0, 3, 0), (3, 2, 0)] + assert collections.Counter(edges[2:5]) == collections.Counter( + [(2, 1, 0), (1, 2, 1), (2, 1, 2)] + ) + assert edges[5:] == [(1, 0, 0)] + + def test_not_eulerian(self): + with pytest.raises(nx.NetworkXError): + f = list(nx.eulerian_circuit(nx.complete_graph(4))) + + +class TestIsSemiEulerian: + def test_is_semieulerian(self): + # Test graphs with Eulerian paths but no cycles return True. + assert nx.is_semieulerian(nx.path_graph(4)) + G = nx.path_graph(6, create_using=nx.DiGraph) + assert nx.is_semieulerian(G) + + # Test graphs with Eulerian cycles return False. + assert not nx.is_semieulerian(nx.complete_graph(5)) + assert not nx.is_semieulerian(nx.complete_graph(7)) + assert not nx.is_semieulerian(nx.hypercube_graph(4)) + assert not nx.is_semieulerian(nx.hypercube_graph(6)) + + +class TestHasEulerianPath: + def test_has_eulerian_path_cyclic(self): + # Test graphs with Eulerian cycles return True. + assert nx.has_eulerian_path(nx.complete_graph(5)) + assert nx.has_eulerian_path(nx.complete_graph(7)) + assert nx.has_eulerian_path(nx.hypercube_graph(4)) + assert nx.has_eulerian_path(nx.hypercube_graph(6)) + + def test_has_eulerian_path_non_cyclic(self): + # Test graphs with Eulerian paths but no cycles return True. + assert nx.has_eulerian_path(nx.path_graph(4)) + G = nx.path_graph(6, create_using=nx.DiGraph) + assert nx.has_eulerian_path(G) + + def test_has_eulerian_path_directed_graph(self): + # Test directed graphs and returns False + G = nx.DiGraph() + G.add_edges_from([(0, 1), (1, 2), (0, 2)]) + assert not nx.has_eulerian_path(G) + + # Test directed graphs without isolated node returns True + G = nx.DiGraph() + G.add_edges_from([(0, 1), (1, 2), (2, 0)]) + assert nx.has_eulerian_path(G) + + # Test directed graphs with isolated node returns False + G.add_node(3) + assert not nx.has_eulerian_path(G) + + @pytest.mark.parametrize("G", (nx.Graph(), nx.DiGraph())) + def test_has_eulerian_path_not_weakly_connected(self, G): + G.add_edges_from([(0, 1), (2, 3), (3, 2)]) + assert not nx.has_eulerian_path(G) + + @pytest.mark.parametrize("G", (nx.Graph(), nx.DiGraph())) + def test_has_eulerian_path_unbalancedins_more_than_one(self, G): + G.add_edges_from([(0, 1), (2, 3)]) + assert not nx.has_eulerian_path(G) + + +class TestFindPathStart: + def testfind_path_start(self): + find_path_start = nx.algorithms.euler._find_path_start + # Test digraphs return correct starting node. + G = nx.path_graph(6, create_using=nx.DiGraph) + assert find_path_start(G) == 0 + edges = [(0, 1), (1, 2), (2, 0), (4, 0)] + assert find_path_start(nx.DiGraph(edges)) == 4 + + # Test graph with no Eulerian path return None. + edges = [(0, 1), (1, 2), (2, 3), (2, 4)] + assert find_path_start(nx.DiGraph(edges)) is None + + +class TestEulerianPath: + def test_eulerian_path(self): + x = [(4, 0), (0, 1), (1, 2), (2, 0)] + for e1, e2 in zip(x, nx.eulerian_path(nx.DiGraph(x))): + assert e1 == e2 + + def test_eulerian_path_straight_link(self): + G = nx.DiGraph() + result = [(1, 2), (2, 3), (3, 4), (4, 5)] + G.add_edges_from(result) + assert result == list(nx.eulerian_path(G)) + assert result == list(nx.eulerian_path(G, source=1)) + with pytest.raises(nx.NetworkXError): + list(nx.eulerian_path(G, source=3)) + with pytest.raises(nx.NetworkXError): + list(nx.eulerian_path(G, source=4)) + with pytest.raises(nx.NetworkXError): + list(nx.eulerian_path(G, source=5)) + + def test_eulerian_path_multigraph(self): + G = nx.MultiDiGraph() + result = [(2, 1), (1, 2), (2, 1), (1, 2), (2, 3), (3, 4), (4, 3)] + G.add_edges_from(result) + assert result == list(nx.eulerian_path(G)) + assert result == list(nx.eulerian_path(G, source=2)) + with pytest.raises(nx.NetworkXError): + list(nx.eulerian_path(G, source=3)) + with pytest.raises(nx.NetworkXError): + list(nx.eulerian_path(G, source=4)) + + def test_eulerian_path_eulerian_circuit(self): + G = nx.DiGraph() + result = [(1, 2), (2, 3), (3, 4), (4, 1)] + result2 = [(2, 3), (3, 4), (4, 1), (1, 2)] + result3 = [(3, 4), (4, 1), (1, 2), (2, 3)] + G.add_edges_from(result) + assert result == list(nx.eulerian_path(G)) + assert result == list(nx.eulerian_path(G, source=1)) + assert result2 == list(nx.eulerian_path(G, source=2)) + assert result3 == list(nx.eulerian_path(G, source=3)) + + def test_eulerian_path_undirected(self): + G = nx.Graph() + result = [(1, 2), (2, 3), (3, 4), (4, 5)] + result2 = [(5, 4), (4, 3), (3, 2), (2, 1)] + G.add_edges_from(result) + assert list(nx.eulerian_path(G)) in (result, result2) + assert result == list(nx.eulerian_path(G, source=1)) + assert result2 == list(nx.eulerian_path(G, source=5)) + with pytest.raises(nx.NetworkXError): + list(nx.eulerian_path(G, source=3)) + with pytest.raises(nx.NetworkXError): + list(nx.eulerian_path(G, source=2)) + + def test_eulerian_path_multigraph_undirected(self): + G = nx.MultiGraph() + result = [(2, 1), (1, 2), (2, 1), (1, 2), (2, 3), (3, 4)] + G.add_edges_from(result) + assert result == list(nx.eulerian_path(G)) + assert result == list(nx.eulerian_path(G, source=2)) + with pytest.raises(nx.NetworkXError): + list(nx.eulerian_path(G, source=3)) + with pytest.raises(nx.NetworkXError): + list(nx.eulerian_path(G, source=1)) + + @pytest.mark.parametrize( + ("graph_type", "result"), + ( + (nx.MultiGraph, [(0, 1, 0), (1, 0, 1)]), + (nx.MultiDiGraph, [(0, 1, 0), (1, 0, 0)]), + ), + ) + def test_eulerian_with_keys(self, graph_type, result): + G = graph_type([(0, 1), (1, 0)]) + answer = nx.eulerian_path(G, keys=True) + assert list(answer) == result + + +class TestEulerize: + def test_disconnected(self): + with pytest.raises(nx.NetworkXError): + G = nx.from_edgelist([(0, 1), (2, 3)]) + nx.eulerize(G) + + def test_null_graph(self): + with pytest.raises(nx.NetworkXPointlessConcept): + nx.eulerize(nx.Graph()) + + def test_null_multigraph(self): + with pytest.raises(nx.NetworkXPointlessConcept): + nx.eulerize(nx.MultiGraph()) + + def test_on_empty_graph(self): + with pytest.raises(nx.NetworkXError): + nx.eulerize(nx.empty_graph(3)) + + def test_on_eulerian(self): + G = nx.cycle_graph(3) + H = nx.eulerize(G) + assert nx.is_isomorphic(G, H) + + def test_on_eulerian_multigraph(self): + G = nx.MultiGraph(nx.cycle_graph(3)) + G.add_edge(0, 1) + H = nx.eulerize(G) + assert nx.is_eulerian(H) + + def test_on_complete_graph(self): + G = nx.complete_graph(4) + assert nx.is_eulerian(nx.eulerize(G)) + assert nx.is_eulerian(nx.eulerize(nx.MultiGraph(G))) + + def test_on_non_eulerian_graph(self): + G = nx.cycle_graph(18) + G.add_edge(0, 18) + G.add_edge(18, 19) + G.add_edge(17, 19) + G.add_edge(4, 20) + G.add_edge(20, 21) + G.add_edge(21, 22) + G.add_edge(22, 23) + G.add_edge(23, 24) + G.add_edge(24, 25) + G.add_edge(25, 26) + G.add_edge(26, 27) + G.add_edge(27, 28) + G.add_edge(28, 13) + assert not nx.is_eulerian(G) + G = nx.eulerize(G) + assert nx.is_eulerian(G) + assert nx.number_of_edges(G) == 39 diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_graph_hashing.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_graph_hashing.py new file mode 100644 index 0000000000000000000000000000000000000000..0828069d1c3c821a0eaeae844fb6182470aadb25 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_graph_hashing.py @@ -0,0 +1,686 @@ +import pytest + +import networkx as nx +from networkx.generators import directed + +# Unit tests for the :func:`~networkx.weisfeiler_lehman_graph_hash` function + + +def test_empty_graph_hash(): + """ + empty graphs should give hashes regardless of other params + """ + G1 = nx.empty_graph() + G2 = nx.empty_graph() + + h1 = nx.weisfeiler_lehman_graph_hash(G1) + h2 = nx.weisfeiler_lehman_graph_hash(G2) + h3 = nx.weisfeiler_lehman_graph_hash(G2, edge_attr="edge_attr1") + h4 = nx.weisfeiler_lehman_graph_hash(G2, node_attr="node_attr1") + h5 = nx.weisfeiler_lehman_graph_hash( + G2, edge_attr="edge_attr1", node_attr="node_attr1" + ) + h6 = nx.weisfeiler_lehman_graph_hash(G2, iterations=10) + + assert h1 == h2 + assert h1 == h3 + assert h1 == h4 + assert h1 == h5 + assert h1 == h6 + + +def test_directed(): + """ + A directed graph with no bi-directional edges should yield different a graph hash + to the same graph taken as undirected if there are no hash collisions. + """ + r = 10 + for i in range(r): + G_directed = nx.gn_graph(10 + r, seed=100 + i) + G_undirected = nx.to_undirected(G_directed) + + h_directed = nx.weisfeiler_lehman_graph_hash(G_directed) + h_undirected = nx.weisfeiler_lehman_graph_hash(G_undirected) + + assert h_directed != h_undirected + + +def test_reversed(): + """ + A directed graph with no bi-directional edges should yield different a graph hash + to the same graph taken with edge directions reversed if there are no hash collisions. + Here we test a cycle graph which is the minimal counterexample + """ + G = nx.cycle_graph(5, create_using=nx.DiGraph) + nx.set_node_attributes(G, {n: str(n) for n in G.nodes()}, name="label") + + G_reversed = G.reverse() + + h = nx.weisfeiler_lehman_graph_hash(G, node_attr="label") + h_reversed = nx.weisfeiler_lehman_graph_hash(G_reversed, node_attr="label") + + assert h != h_reversed + + +def test_isomorphic(): + """ + graph hashes should be invariant to node-relabeling (when the output is reindexed + by the same mapping) + """ + n, r = 100, 10 + p = 1.0 / r + for i in range(1, r + 1): + G1 = nx.erdos_renyi_graph(n, p * i, seed=200 + i) + G2 = nx.relabel_nodes(G1, {u: -1 * u for u in G1.nodes()}) + + g1_hash = nx.weisfeiler_lehman_graph_hash(G1) + g2_hash = nx.weisfeiler_lehman_graph_hash(G2) + + assert g1_hash == g2_hash + + +def test_isomorphic_edge_attr(): + """ + Isomorphic graphs with differing edge attributes should yield different graph + hashes if the 'edge_attr' argument is supplied and populated in the graph, + and there are no hash collisions. + The output should still be invariant to node-relabeling + """ + n, r = 100, 10 + p = 1.0 / r + for i in range(1, r + 1): + G1 = nx.erdos_renyi_graph(n, p * i, seed=300 + i) + + for a, b in G1.edges: + G1[a][b]["edge_attr1"] = f"{a}-{b}-1" + G1[a][b]["edge_attr2"] = f"{a}-{b}-2" + + g1_hash_with_edge_attr1 = nx.weisfeiler_lehman_graph_hash( + G1, edge_attr="edge_attr1" + ) + g1_hash_with_edge_attr2 = nx.weisfeiler_lehman_graph_hash( + G1, edge_attr="edge_attr2" + ) + g1_hash_no_edge_attr = nx.weisfeiler_lehman_graph_hash(G1, edge_attr=None) + + assert g1_hash_with_edge_attr1 != g1_hash_no_edge_attr + assert g1_hash_with_edge_attr2 != g1_hash_no_edge_attr + assert g1_hash_with_edge_attr1 != g1_hash_with_edge_attr2 + + G2 = nx.relabel_nodes(G1, {u: -1 * u for u in G1.nodes()}) + + g2_hash_with_edge_attr1 = nx.weisfeiler_lehman_graph_hash( + G2, edge_attr="edge_attr1" + ) + g2_hash_with_edge_attr2 = nx.weisfeiler_lehman_graph_hash( + G2, edge_attr="edge_attr2" + ) + + assert g1_hash_with_edge_attr1 == g2_hash_with_edge_attr1 + assert g1_hash_with_edge_attr2 == g2_hash_with_edge_attr2 + + +def test_missing_edge_attr(): + """ + If the 'edge_attr' argument is supplied but is missing from an edge in the graph, + we should raise a KeyError + """ + G = nx.Graph() + G.add_edges_from([(1, 2, {"edge_attr1": "a"}), (1, 3, {})]) + pytest.raises(KeyError, nx.weisfeiler_lehman_graph_hash, G, edge_attr="edge_attr1") + + +def test_isomorphic_node_attr(): + """ + Isomorphic graphs with differing node attributes should yield different graph + hashes if the 'node_attr' argument is supplied and populated in the graph, and + there are no hash collisions. + The output should still be invariant to node-relabeling + """ + n, r = 100, 10 + p = 1.0 / r + for i in range(1, r + 1): + G1 = nx.erdos_renyi_graph(n, p * i, seed=400 + i) + + for u in G1.nodes(): + G1.nodes[u]["node_attr1"] = f"{u}-1" + G1.nodes[u]["node_attr2"] = f"{u}-2" + + g1_hash_with_node_attr1 = nx.weisfeiler_lehman_graph_hash( + G1, node_attr="node_attr1" + ) + g1_hash_with_node_attr2 = nx.weisfeiler_lehman_graph_hash( + G1, node_attr="node_attr2" + ) + g1_hash_no_node_attr = nx.weisfeiler_lehman_graph_hash(G1, node_attr=None) + + assert g1_hash_with_node_attr1 != g1_hash_no_node_attr + assert g1_hash_with_node_attr2 != g1_hash_no_node_attr + assert g1_hash_with_node_attr1 != g1_hash_with_node_attr2 + + G2 = nx.relabel_nodes(G1, {u: -1 * u for u in G1.nodes()}) + + g2_hash_with_node_attr1 = nx.weisfeiler_lehman_graph_hash( + G2, node_attr="node_attr1" + ) + g2_hash_with_node_attr2 = nx.weisfeiler_lehman_graph_hash( + G2, node_attr="node_attr2" + ) + + assert g1_hash_with_node_attr1 == g2_hash_with_node_attr1 + assert g1_hash_with_node_attr2 == g2_hash_with_node_attr2 + + +def test_missing_node_attr(): + """ + If the 'node_attr' argument is supplied but is missing from a node in the graph, + we should raise a KeyError + """ + G = nx.Graph() + G.add_nodes_from([(1, {"node_attr1": "a"}), (2, {})]) + G.add_edges_from([(1, 2), (2, 3), (3, 1), (1, 4)]) + pytest.raises(KeyError, nx.weisfeiler_lehman_graph_hash, G, node_attr="node_attr1") + + +def test_isomorphic_edge_attr_and_node_attr(): + """ + Isomorphic graphs with differing node attributes should yield different graph + hashes if the 'node_attr' and 'edge_attr' argument is supplied and populated in + the graph, and there are no hash collisions. + The output should still be invariant to node-relabeling + """ + n, r = 100, 10 + p = 1.0 / r + for i in range(1, r + 1): + G1 = nx.erdos_renyi_graph(n, p * i, seed=500 + i) + + for u in G1.nodes(): + G1.nodes[u]["node_attr1"] = f"{u}-1" + G1.nodes[u]["node_attr2"] = f"{u}-2" + + for a, b in G1.edges: + G1[a][b]["edge_attr1"] = f"{a}-{b}-1" + G1[a][b]["edge_attr2"] = f"{a}-{b}-2" + + g1_hash_edge1_node1 = nx.weisfeiler_lehman_graph_hash( + G1, edge_attr="edge_attr1", node_attr="node_attr1" + ) + g1_hash_edge2_node2 = nx.weisfeiler_lehman_graph_hash( + G1, edge_attr="edge_attr2", node_attr="node_attr2" + ) + g1_hash_edge1_node2 = nx.weisfeiler_lehman_graph_hash( + G1, edge_attr="edge_attr1", node_attr="node_attr2" + ) + g1_hash_no_attr = nx.weisfeiler_lehman_graph_hash(G1) + + assert g1_hash_edge1_node1 != g1_hash_no_attr + assert g1_hash_edge2_node2 != g1_hash_no_attr + assert g1_hash_edge1_node1 != g1_hash_edge2_node2 + assert g1_hash_edge1_node2 != g1_hash_edge2_node2 + assert g1_hash_edge1_node2 != g1_hash_edge1_node1 + + G2 = nx.relabel_nodes(G1, {u: -1 * u for u in G1.nodes()}) + + g2_hash_edge1_node1 = nx.weisfeiler_lehman_graph_hash( + G2, edge_attr="edge_attr1", node_attr="node_attr1" + ) + g2_hash_edge2_node2 = nx.weisfeiler_lehman_graph_hash( + G2, edge_attr="edge_attr2", node_attr="node_attr2" + ) + + assert g1_hash_edge1_node1 == g2_hash_edge1_node1 + assert g1_hash_edge2_node2 == g2_hash_edge2_node2 + + +def test_digest_size(): + """ + The hash string lengths should be as expected for a variety of graphs and + digest sizes + """ + n, r = 100, 10 + p = 1.0 / r + for i in range(1, r + 1): + G = nx.erdos_renyi_graph(n, p * i, seed=1000 + i) + + h16 = nx.weisfeiler_lehman_graph_hash(G) + h32 = nx.weisfeiler_lehman_graph_hash(G, digest_size=32) + + assert h16 != h32 + assert len(h16) == 16 * 2 + assert len(h32) == 32 * 2 + + +# Unit tests for the :func:`~networkx.weisfeiler_lehman_hash_subgraphs` function + + +def is_subiteration(a, b): + """ + returns True if that each hash sequence in 'a' is a prefix for + the corresponding sequence indexed by the same node in 'b'. + """ + return all(b[node][: len(hashes)] == hashes for node, hashes in a.items()) + + +def hexdigest_sizes_correct(a, digest_size): + """ + returns True if all hex digest sizes are the expected length in a node:subgraph-hashes + dictionary. Hex digest string length == 2 * bytes digest length since each pair of hex + digits encodes 1 byte (https://docs.python.org/3/library/hashlib.html) + """ + hexdigest_size = digest_size * 2 + list_digest_sizes_correct = lambda l: all(len(x) == hexdigest_size for x in l) + return all(list_digest_sizes_correct(hashes) for hashes in a.values()) + + +def test_empty_graph_subgraph_hash(): + """ " + empty graphs should give empty dict subgraph hashes regardless of other params + """ + G = nx.empty_graph() + + subgraph_hashes1 = nx.weisfeiler_lehman_subgraph_hashes(G) + subgraph_hashes2 = nx.weisfeiler_lehman_subgraph_hashes(G, edge_attr="edge_attr") + subgraph_hashes3 = nx.weisfeiler_lehman_subgraph_hashes(G, node_attr="edge_attr") + subgraph_hashes4 = nx.weisfeiler_lehman_subgraph_hashes(G, iterations=2) + subgraph_hashes5 = nx.weisfeiler_lehman_subgraph_hashes(G, digest_size=64) + + assert subgraph_hashes1 == {} + assert subgraph_hashes2 == {} + assert subgraph_hashes3 == {} + assert subgraph_hashes4 == {} + assert subgraph_hashes5 == {} + + +def test_directed_subgraph_hash(): + """ + A directed graph with no bi-directional edges should yield different subgraph hashes + to the same graph taken as undirected, if all hashes don't collide. + """ + r = 10 + for i in range(r): + G_directed = nx.gn_graph(10 + r, seed=100 + i) + G_undirected = nx.to_undirected(G_directed) + + directed_subgraph_hashes = nx.weisfeiler_lehman_subgraph_hashes(G_directed) + undirected_subgraph_hashes = nx.weisfeiler_lehman_subgraph_hashes(G_undirected) + + assert directed_subgraph_hashes != undirected_subgraph_hashes + + +def test_reversed_subgraph_hash(): + """ + A directed graph with no bi-directional edges should yield different subgraph hashes + to the same graph taken with edge directions reversed if there are no hash collisions. + Here we test a cycle graph which is the minimal counterexample + """ + G = nx.cycle_graph(5, create_using=nx.DiGraph) + nx.set_node_attributes(G, {n: str(n) for n in G.nodes()}, name="label") + + G_reversed = G.reverse() + + h = nx.weisfeiler_lehman_subgraph_hashes(G, node_attr="label") + h_reversed = nx.weisfeiler_lehman_subgraph_hashes(G_reversed, node_attr="label") + + assert h != h_reversed + + +def test_isomorphic_subgraph_hash(): + """ + the subgraph hashes should be invariant to node-relabeling when the output is reindexed + by the same mapping and all hashes don't collide. + """ + n, r = 100, 10 + p = 1.0 / r + for i in range(1, r + 1): + G1 = nx.erdos_renyi_graph(n, p * i, seed=200 + i) + G2 = nx.relabel_nodes(G1, {u: -1 * u for u in G1.nodes()}) + + g1_subgraph_hashes = nx.weisfeiler_lehman_subgraph_hashes(G1) + g2_subgraph_hashes = nx.weisfeiler_lehman_subgraph_hashes(G2) + + assert g1_subgraph_hashes == {-1 * k: v for k, v in g2_subgraph_hashes.items()} + + +def test_isomorphic_edge_attr_subgraph_hash(): + """ + Isomorphic graphs with differing edge attributes should yield different subgraph + hashes if the 'edge_attr' argument is supplied and populated in the graph, and + all hashes don't collide. + The output should still be invariant to node-relabeling + """ + n, r = 100, 10 + p = 1.0 / r + for i in range(1, r + 1): + G1 = nx.erdos_renyi_graph(n, p * i, seed=300 + i) + + for a, b in G1.edges: + G1[a][b]["edge_attr1"] = f"{a}-{b}-1" + G1[a][b]["edge_attr2"] = f"{a}-{b}-2" + + g1_hash_with_edge_attr1 = nx.weisfeiler_lehman_subgraph_hashes( + G1, edge_attr="edge_attr1" + ) + g1_hash_with_edge_attr2 = nx.weisfeiler_lehman_subgraph_hashes( + G1, edge_attr="edge_attr2" + ) + g1_hash_no_edge_attr = nx.weisfeiler_lehman_subgraph_hashes(G1, edge_attr=None) + + assert g1_hash_with_edge_attr1 != g1_hash_no_edge_attr + assert g1_hash_with_edge_attr2 != g1_hash_no_edge_attr + assert g1_hash_with_edge_attr1 != g1_hash_with_edge_attr2 + + G2 = nx.relabel_nodes(G1, {u: -1 * u for u in G1.nodes()}) + + g2_hash_with_edge_attr1 = nx.weisfeiler_lehman_subgraph_hashes( + G2, edge_attr="edge_attr1" + ) + g2_hash_with_edge_attr2 = nx.weisfeiler_lehman_subgraph_hashes( + G2, edge_attr="edge_attr2" + ) + + assert g1_hash_with_edge_attr1 == { + -1 * k: v for k, v in g2_hash_with_edge_attr1.items() + } + assert g1_hash_with_edge_attr2 == { + -1 * k: v for k, v in g2_hash_with_edge_attr2.items() + } + + +def test_missing_edge_attr_subgraph_hash(): + """ + If the 'edge_attr' argument is supplied but is missing from an edge in the graph, + we should raise a KeyError + """ + G = nx.Graph() + G.add_edges_from([(1, 2, {"edge_attr1": "a"}), (1, 3, {})]) + pytest.raises( + KeyError, nx.weisfeiler_lehman_subgraph_hashes, G, edge_attr="edge_attr1" + ) + + +def test_isomorphic_node_attr_subgraph_hash(): + """ + Isomorphic graphs with differing node attributes should yield different subgraph + hashes if the 'node_attr' argument is supplied and populated in the graph, and + all hashes don't collide. + The output should still be invariant to node-relabeling + """ + n, r = 100, 10 + p = 1.0 / r + for i in range(1, r + 1): + G1 = nx.erdos_renyi_graph(n, p * i, seed=400 + i) + + for u in G1.nodes(): + G1.nodes[u]["node_attr1"] = f"{u}-1" + G1.nodes[u]["node_attr2"] = f"{u}-2" + + g1_hash_with_node_attr1 = nx.weisfeiler_lehman_subgraph_hashes( + G1, node_attr="node_attr1" + ) + g1_hash_with_node_attr2 = nx.weisfeiler_lehman_subgraph_hashes( + G1, node_attr="node_attr2" + ) + g1_hash_no_node_attr = nx.weisfeiler_lehman_subgraph_hashes(G1, node_attr=None) + + assert g1_hash_with_node_attr1 != g1_hash_no_node_attr + assert g1_hash_with_node_attr2 != g1_hash_no_node_attr + assert g1_hash_with_node_attr1 != g1_hash_with_node_attr2 + + G2 = nx.relabel_nodes(G1, {u: -1 * u for u in G1.nodes()}) + + g2_hash_with_node_attr1 = nx.weisfeiler_lehman_subgraph_hashes( + G2, node_attr="node_attr1" + ) + g2_hash_with_node_attr2 = nx.weisfeiler_lehman_subgraph_hashes( + G2, node_attr="node_attr2" + ) + + assert g1_hash_with_node_attr1 == { + -1 * k: v for k, v in g2_hash_with_node_attr1.items() + } + assert g1_hash_with_node_attr2 == { + -1 * k: v for k, v in g2_hash_with_node_attr2.items() + } + + +def test_missing_node_attr_subgraph_hash(): + """ + If the 'node_attr' argument is supplied but is missing from a node in the graph, + we should raise a KeyError + """ + G = nx.Graph() + G.add_nodes_from([(1, {"node_attr1": "a"}), (2, {})]) + G.add_edges_from([(1, 2), (2, 3), (3, 1), (1, 4)]) + pytest.raises( + KeyError, nx.weisfeiler_lehman_subgraph_hashes, G, node_attr="node_attr1" + ) + + +def test_isomorphic_edge_attr_and_node_attr_subgraph_hash(): + """ + Isomorphic graphs with differing node attributes should yield different subgraph + hashes if the 'node_attr' and 'edge_attr' argument is supplied and populated in + the graph, and all hashes don't collide + The output should still be invariant to node-relabeling + """ + n, r = 100, 10 + p = 1.0 / r + for i in range(1, r + 1): + G1 = nx.erdos_renyi_graph(n, p * i, seed=500 + i) + + for u in G1.nodes(): + G1.nodes[u]["node_attr1"] = f"{u}-1" + G1.nodes[u]["node_attr2"] = f"{u}-2" + + for a, b in G1.edges: + G1[a][b]["edge_attr1"] = f"{a}-{b}-1" + G1[a][b]["edge_attr2"] = f"{a}-{b}-2" + + g1_hash_edge1_node1 = nx.weisfeiler_lehman_subgraph_hashes( + G1, edge_attr="edge_attr1", node_attr="node_attr1" + ) + g1_hash_edge2_node2 = nx.weisfeiler_lehman_subgraph_hashes( + G1, edge_attr="edge_attr2", node_attr="node_attr2" + ) + g1_hash_edge1_node2 = nx.weisfeiler_lehman_subgraph_hashes( + G1, edge_attr="edge_attr1", node_attr="node_attr2" + ) + g1_hash_no_attr = nx.weisfeiler_lehman_subgraph_hashes(G1) + + assert g1_hash_edge1_node1 != g1_hash_no_attr + assert g1_hash_edge2_node2 != g1_hash_no_attr + assert g1_hash_edge1_node1 != g1_hash_edge2_node2 + assert g1_hash_edge1_node2 != g1_hash_edge2_node2 + assert g1_hash_edge1_node2 != g1_hash_edge1_node1 + + G2 = nx.relabel_nodes(G1, {u: -1 * u for u in G1.nodes()}) + + g2_hash_edge1_node1 = nx.weisfeiler_lehman_subgraph_hashes( + G2, edge_attr="edge_attr1", node_attr="node_attr1" + ) + g2_hash_edge2_node2 = nx.weisfeiler_lehman_subgraph_hashes( + G2, edge_attr="edge_attr2", node_attr="node_attr2" + ) + + assert g1_hash_edge1_node1 == { + -1 * k: v for k, v in g2_hash_edge1_node1.items() + } + assert g1_hash_edge2_node2 == { + -1 * k: v for k, v in g2_hash_edge2_node2.items() + } + + +def test_iteration_depth(): + """ + All nodes should have the correct number of subgraph hashes in the output when + using degree as initial node labels + Subsequent iteration depths for the same graph should be additive for each node + """ + n, r = 100, 10 + p = 1.0 / r + for i in range(1, r + 1): + G = nx.erdos_renyi_graph(n, p * i, seed=600 + i) + + depth3 = nx.weisfeiler_lehman_subgraph_hashes(G, iterations=3) + depth4 = nx.weisfeiler_lehman_subgraph_hashes(G, iterations=4) + depth5 = nx.weisfeiler_lehman_subgraph_hashes(G, iterations=5) + + assert all(len(hashes) == 3 for hashes in depth3.values()) + assert all(len(hashes) == 4 for hashes in depth4.values()) + assert all(len(hashes) == 5 for hashes in depth5.values()) + + assert is_subiteration(depth3, depth4) + assert is_subiteration(depth4, depth5) + assert is_subiteration(depth3, depth5) + + +def test_iteration_depth_edge_attr(): + """ + All nodes should have the correct number of subgraph hashes in the output when + setting initial node labels empty and using an edge attribute when aggregating + neighborhoods. + Subsequent iteration depths for the same graph should be additive for each node + """ + n, r = 100, 10 + p = 1.0 / r + for i in range(1, r + 1): + G = nx.erdos_renyi_graph(n, p * i, seed=700 + i) + + for a, b in G.edges: + G[a][b]["edge_attr1"] = f"{a}-{b}-1" + + depth3 = nx.weisfeiler_lehman_subgraph_hashes( + G, edge_attr="edge_attr1", iterations=3 + ) + depth4 = nx.weisfeiler_lehman_subgraph_hashes( + G, edge_attr="edge_attr1", iterations=4 + ) + depth5 = nx.weisfeiler_lehman_subgraph_hashes( + G, edge_attr="edge_attr1", iterations=5 + ) + + assert all(len(hashes) == 3 for hashes in depth3.values()) + assert all(len(hashes) == 4 for hashes in depth4.values()) + assert all(len(hashes) == 5 for hashes in depth5.values()) + + assert is_subiteration(depth3, depth4) + assert is_subiteration(depth4, depth5) + assert is_subiteration(depth3, depth5) + + +def test_iteration_depth_node_attr(): + """ + All nodes should have the correct number of subgraph hashes in the output when + setting initial node labels to an attribute. + Subsequent iteration depths for the same graph should be additive for each node + """ + n, r = 100, 10 + p = 1.0 / r + for i in range(1, r + 1): + G = nx.erdos_renyi_graph(n, p * i, seed=800 + i) + + for u in G.nodes(): + G.nodes[u]["node_attr1"] = f"{u}-1" + + depth3 = nx.weisfeiler_lehman_subgraph_hashes( + G, node_attr="node_attr1", iterations=3 + ) + depth4 = nx.weisfeiler_lehman_subgraph_hashes( + G, node_attr="node_attr1", iterations=4 + ) + depth5 = nx.weisfeiler_lehman_subgraph_hashes( + G, node_attr="node_attr1", iterations=5 + ) + + assert all(len(hashes) == 3 for hashes in depth3.values()) + assert all(len(hashes) == 4 for hashes in depth4.values()) + assert all(len(hashes) == 5 for hashes in depth5.values()) + + assert is_subiteration(depth3, depth4) + assert is_subiteration(depth4, depth5) + assert is_subiteration(depth3, depth5) + + +def test_iteration_depth_node_edge_attr(): + """ + All nodes should have the correct number of subgraph hashes in the output when + setting initial node labels to an attribute and also using an edge attribute when + aggregating neighborhoods. + Subsequent iteration depths for the same graph should be additive for each node + """ + n, r = 100, 10 + p = 1.0 / r + for i in range(1, r + 1): + G = nx.erdos_renyi_graph(n, p * i, seed=900 + i) + + for u in G.nodes(): + G.nodes[u]["node_attr1"] = f"{u}-1" + + for a, b in G.edges: + G[a][b]["edge_attr1"] = f"{a}-{b}-1" + + depth3 = nx.weisfeiler_lehman_subgraph_hashes( + G, edge_attr="edge_attr1", node_attr="node_attr1", iterations=3 + ) + depth4 = nx.weisfeiler_lehman_subgraph_hashes( + G, edge_attr="edge_attr1", node_attr="node_attr1", iterations=4 + ) + depth5 = nx.weisfeiler_lehman_subgraph_hashes( + G, edge_attr="edge_attr1", node_attr="node_attr1", iterations=5 + ) + + assert all(len(hashes) == 3 for hashes in depth3.values()) + assert all(len(hashes) == 4 for hashes in depth4.values()) + assert all(len(hashes) == 5 for hashes in depth5.values()) + + assert is_subiteration(depth3, depth4) + assert is_subiteration(depth4, depth5) + assert is_subiteration(depth3, depth5) + + +def test_digest_size_subgraph_hash(): + """ + The hash string lengths should be as expected for a variety of graphs and + digest sizes + """ + n, r = 100, 10 + p = 1.0 / r + for i in range(1, r + 1): + G = nx.erdos_renyi_graph(n, p * i, seed=1000 + i) + + digest_size16_hashes = nx.weisfeiler_lehman_subgraph_hashes(G) + digest_size32_hashes = nx.weisfeiler_lehman_subgraph_hashes(G, digest_size=32) + + assert digest_size16_hashes != digest_size32_hashes + + assert hexdigest_sizes_correct(digest_size16_hashes, 16) + assert hexdigest_sizes_correct(digest_size32_hashes, 32) + + +def test_initial_node_labels_subgraph_hash(): + """ + Including the hashed initial label prepends an extra hash to the lists + """ + G = nx.path_graph(5) + nx.set_node_attributes(G, {i: int(0 < i < 4) for i in G}, "label") + # initial node labels: + # 0--1--1--1--0 + + without_initial_label = nx.weisfeiler_lehman_subgraph_hashes(G, node_attr="label") + assert all(len(v) == 3 for v in without_initial_label.values()) + # 3 different 1 hop nhds + assert len({v[0] for v in without_initial_label.values()}) == 3 + + with_initial_label = nx.weisfeiler_lehman_subgraph_hashes( + G, node_attr="label", include_initial_labels=True + ) + assert all(len(v) == 4 for v in with_initial_label.values()) + # 2 different initial labels + assert len({v[0] for v in with_initial_label.values()}) == 2 + + # check hashes match otherwise + for u in G: + for a, b in zip( + with_initial_label[u][1:], without_initial_label[u], strict=True + ): + assert a == b diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_graphical.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_graphical.py new file mode 100644 index 0000000000000000000000000000000000000000..99f766f799d8573e80d905482f4b685a2d16bcc0 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_graphical.py @@ -0,0 +1,163 @@ +import pytest + +import networkx as nx + + +def test_valid_degree_sequence1(): + n = 100 + p = 0.3 + for i in range(10): + G = nx.erdos_renyi_graph(n, p) + deg = (d for n, d in G.degree()) + assert nx.is_graphical(deg, method="eg") + assert nx.is_graphical(deg, method="hh") + + +def test_valid_degree_sequence2(): + n = 100 + for i in range(10): + G = nx.barabasi_albert_graph(n, 1) + deg = (d for n, d in G.degree()) + assert nx.is_graphical(deg, method="eg") + assert nx.is_graphical(deg, method="hh") + + +def test_string_input(): + pytest.raises(nx.NetworkXException, nx.is_graphical, [], "foo") + pytest.raises(nx.NetworkXException, nx.is_graphical, ["red"], "hh") + pytest.raises(nx.NetworkXException, nx.is_graphical, ["red"], "eg") + + +def test_non_integer_input(): + pytest.raises(nx.NetworkXException, nx.is_graphical, [72.5], "eg") + pytest.raises(nx.NetworkXException, nx.is_graphical, [72.5], "hh") + + +def test_negative_input(): + assert not nx.is_graphical([-1], "hh") + assert not nx.is_graphical([-1], "eg") + + +class TestAtlas: + @classmethod + def setup_class(cls): + global atlas + from networkx.generators import atlas + + cls.GAG = atlas.graph_atlas_g() + + def test_atlas(self): + for graph in self.GAG: + deg = (d for n, d in graph.degree()) + assert nx.is_graphical(deg, method="eg") + assert nx.is_graphical(deg, method="hh") + + +def test_small_graph_true(): + z = [5, 3, 3, 3, 3, 2, 2, 2, 1, 1, 1] + assert nx.is_graphical(z, method="hh") + assert nx.is_graphical(z, method="eg") + z = [10, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2] + assert nx.is_graphical(z, method="hh") + assert nx.is_graphical(z, method="eg") + z = [1, 1, 1, 1, 1, 2, 2, 2, 3, 4] + assert nx.is_graphical(z, method="hh") + assert nx.is_graphical(z, method="eg") + + +def test_small_graph_false(): + z = [1000, 3, 3, 3, 3, 2, 2, 2, 1, 1, 1] + assert not nx.is_graphical(z, method="hh") + assert not nx.is_graphical(z, method="eg") + z = [6, 5, 4, 4, 2, 1, 1, 1] + assert not nx.is_graphical(z, method="hh") + assert not nx.is_graphical(z, method="eg") + z = [1, 1, 1, 1, 1, 1, 2, 2, 2, 3, 4] + assert not nx.is_graphical(z, method="hh") + assert not nx.is_graphical(z, method="eg") + + +def test_directed_degree_sequence(): + # Test a range of valid directed degree sequences + n, r = 100, 10 + p = 1.0 / r + for i in range(r): + G = nx.erdos_renyi_graph(n, p * (i + 1), None, True) + din = (d for n, d in G.in_degree()) + dout = (d for n, d in G.out_degree()) + assert nx.is_digraphical(din, dout) + + +def test_small_directed_sequences(): + dout = [5, 3, 3, 3, 3, 2, 2, 2, 1, 1, 1] + din = [3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1] + assert nx.is_digraphical(din, dout) + # Test nongraphical directed sequence + dout = [1000, 3, 3, 3, 3, 2, 2, 2, 1, 1, 1] + din = [103, 102, 102, 102, 102, 102, 102, 102, 102, 102] + assert not nx.is_digraphical(din, dout) + # Test digraphical small sequence + dout = [1, 1, 1, 1, 1, 2, 2, 2, 3, 4] + din = [2, 2, 2, 2, 2, 2, 2, 2, 1, 1] + assert nx.is_digraphical(din, dout) + # Test nonmatching sum + din = [2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1] + assert not nx.is_digraphical(din, dout) + # Test for negative integer in sequence + din = [2, 2, 2, -2, 2, 2, 2, 2, 1, 1, 4] + assert not nx.is_digraphical(din, dout) + # Test for noninteger + din = dout = [1, 1, 1.1, 1] + assert not nx.is_digraphical(din, dout) + din = dout = [1, 1, "rer", 1] + assert not nx.is_digraphical(din, dout) + + +def test_multi_sequence(): + # Test nongraphical multi sequence + seq = [1000, 3, 3, 3, 3, 2, 2, 2, 1, 1] + assert not nx.is_multigraphical(seq) + # Test small graphical multi sequence + seq = [6, 5, 4, 4, 2, 1, 1, 1] + assert nx.is_multigraphical(seq) + # Test for negative integer in sequence + seq = [6, 5, 4, -4, 2, 1, 1, 1] + assert not nx.is_multigraphical(seq) + # Test for sequence with odd sum + seq = [1, 1, 1, 1, 1, 1, 2, 2, 2, 3, 4] + assert not nx.is_multigraphical(seq) + # Test for noninteger + seq = [1, 1, 1.1, 1] + assert not nx.is_multigraphical(seq) + seq = [1, 1, "rer", 1] + assert not nx.is_multigraphical(seq) + + +def test_pseudo_sequence(): + # Test small valid pseudo sequence + seq = [1000, 3, 3, 3, 3, 2, 2, 2, 1, 1] + assert nx.is_pseudographical(seq) + # Test for sequence with odd sum + seq = [1000, 3, 3, 3, 3, 2, 2, 2, 1, 1, 1] + assert not nx.is_pseudographical(seq) + # Test for negative integer in sequence + seq = [1000, 3, 3, 3, 3, 2, 2, -2, 1, 1] + assert not nx.is_pseudographical(seq) + # Test for noninteger + seq = [1, 1, 1.1, 1] + assert not nx.is_pseudographical(seq) + seq = [1, 1, "rer", 1] + assert not nx.is_pseudographical(seq) + + +def test_numpy_degree_sequence(): + np = pytest.importorskip("numpy") + ds = np.array([1, 2, 2, 2, 1], dtype=np.int64) + assert nx.is_graphical(ds, "eg") + assert nx.is_graphical(ds, "hh") + ds = np.array([1, 2, 2, 2, 1], dtype=np.float64) + assert nx.is_graphical(ds, "eg") + assert nx.is_graphical(ds, "hh") + ds = np.array([1.1, 2, 2, 2, 1], dtype=np.float64) + pytest.raises(nx.NetworkXException, nx.is_graphical, ds, "eg") + pytest.raises(nx.NetworkXException, nx.is_graphical, ds, "hh") diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_hierarchy.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_hierarchy.py new file mode 100644 index 0000000000000000000000000000000000000000..eaa6a67b8b7f048719aa189b8365ef8e4c65951c --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_hierarchy.py @@ -0,0 +1,46 @@ +import pytest + +import networkx as nx + + +def test_hierarchy_undirected(): + G = nx.cycle_graph(5) + pytest.raises(nx.NetworkXError, nx.flow_hierarchy, G) + + +def test_hierarchy_cycle(): + G = nx.cycle_graph(5, create_using=nx.DiGraph()) + assert nx.flow_hierarchy(G) == 0.0 + + +def test_hierarchy_tree(): + G = nx.full_rary_tree(2, 16, create_using=nx.DiGraph()) + assert nx.flow_hierarchy(G) == 1.0 + + +def test_hierarchy_1(): + G = nx.DiGraph() + G.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 1), (3, 4), (0, 4)]) + assert nx.flow_hierarchy(G) == 0.5 + + +def test_hierarchy_weight(): + G = nx.DiGraph() + G.add_edges_from( + [ + (0, 1, {"weight": 0.3}), + (1, 2, {"weight": 0.1}), + (2, 3, {"weight": 0.1}), + (3, 1, {"weight": 0.1}), + (3, 4, {"weight": 0.3}), + (0, 4, {"weight": 0.3}), + ] + ) + assert nx.flow_hierarchy(G, weight="weight") == 0.75 + + +@pytest.mark.parametrize("n", (0, 1, 3)) +def test_hierarchy_empty_graph(n): + G = nx.empty_graph(n, create_using=nx.DiGraph) + with pytest.raises(nx.NetworkXError, match=".*not applicable to empty graphs"): + nx.flow_hierarchy(G) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_hybrid.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_hybrid.py new file mode 100644 index 0000000000000000000000000000000000000000..6af0016498549caed58772e304c93113a8b693d9 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_hybrid.py @@ -0,0 +1,24 @@ +import networkx as nx + + +def test_2d_grid_graph(): + # FC article claims 2d grid graph of size n is (3,3)-connected + # and (5,9)-connected, but I don't think it is (5,9)-connected + G = nx.grid_2d_graph(8, 8, periodic=True) + assert nx.is_kl_connected(G, 3, 3) + assert not nx.is_kl_connected(G, 5, 9) + (H, graphOK) = nx.kl_connected_subgraph(G, 5, 9, same_as_graph=True) + assert not graphOK + + +def test_small_graph(): + G = nx.Graph() + G.add_edge(1, 2) + G.add_edge(1, 3) + G.add_edge(2, 3) + assert nx.is_kl_connected(G, 2, 2) + H = nx.kl_connected_subgraph(G, 2, 2) + (H, graphOK) = nx.kl_connected_subgraph( + G, 2, 2, low_memory=True, same_as_graph=True + ) + assert graphOK diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_isolate.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_isolate.py new file mode 100644 index 0000000000000000000000000000000000000000..d29b306d2b13c2457905c41218e5c60793b309ba --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_isolate.py @@ -0,0 +1,26 @@ +"""Unit tests for the :mod:`networkx.algorithms.isolates` module.""" + +import networkx as nx + + +def test_is_isolate(): + G = nx.Graph() + G.add_edge(0, 1) + G.add_node(2) + assert not nx.is_isolate(G, 0) + assert not nx.is_isolate(G, 1) + assert nx.is_isolate(G, 2) + + +def test_isolates(): + G = nx.Graph() + G.add_edge(0, 1) + G.add_nodes_from([2, 3]) + assert sorted(nx.isolates(G)) == [2, 3] + + +def test_number_of_isolates(): + G = nx.Graph() + G.add_edge(0, 1) + G.add_nodes_from([2, 3]) + assert nx.number_of_isolates(G) == 2 diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_link_prediction.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_link_prediction.py new file mode 100644 index 0000000000000000000000000000000000000000..0878496bc2aa1b81b45fe36bbc5d86c1cd4d204f --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_link_prediction.py @@ -0,0 +1,586 @@ +import math +from functools import partial + +import pytest + +import networkx as nx + + +def _test_func(G, ebunch, expected, predict_func, **kwargs): + result = predict_func(G, ebunch, **kwargs) + exp_dict = {tuple(sorted([u, v])): score for u, v, score in expected} + res_dict = {tuple(sorted([u, v])): score for u, v, score in result} + + assert len(exp_dict) == len(res_dict) + for p in exp_dict: + assert exp_dict[p] == pytest.approx(res_dict[p], abs=1e-7) + + +class TestResourceAllocationIndex: + @classmethod + def setup_class(cls): + cls.func = staticmethod(nx.resource_allocation_index) + cls.test = partial(_test_func, predict_func=cls.func) + + def test_K5(self): + G = nx.complete_graph(5) + self.test(G, [(0, 1)], [(0, 1, 0.75)]) + + def test_P3(self): + G = nx.path_graph(3) + self.test(G, [(0, 2)], [(0, 2, 0.5)]) + + def test_S4(self): + G = nx.star_graph(4) + self.test(G, [(1, 2)], [(1, 2, 0.25)]) + + @pytest.mark.parametrize("graph_type", (nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph)) + def test_notimplemented(self, graph_type): + assert pytest.raises( + nx.NetworkXNotImplemented, self.func, graph_type([(0, 1), (1, 2)]), [(0, 2)] + ) + + def test_node_not_found(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (0, 2), (2, 3)]) + assert pytest.raises(nx.NodeNotFound, self.func, G, [(0, 4)]) + + def test_no_common_neighbor(self): + G = nx.Graph() + G.add_nodes_from([0, 1]) + self.test(G, [(0, 1)], [(0, 1, 0)]) + + def test_equal_nodes(self): + G = nx.complete_graph(4) + self.test(G, [(0, 0)], [(0, 0, 1)]) + + def test_all_nonexistent_edges(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (0, 2), (2, 3)]) + self.test(G, None, [(0, 3, 0.5), (1, 2, 0.5), (1, 3, 0)]) + + +class TestJaccardCoefficient: + @classmethod + def setup_class(cls): + cls.func = staticmethod(nx.jaccard_coefficient) + cls.test = partial(_test_func, predict_func=cls.func) + + def test_K5(self): + G = nx.complete_graph(5) + self.test(G, [(0, 1)], [(0, 1, 0.6)]) + + def test_P4(self): + G = nx.path_graph(4) + self.test(G, [(0, 2)], [(0, 2, 0.5)]) + + @pytest.mark.parametrize("graph_type", (nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph)) + def test_notimplemented(self, graph_type): + assert pytest.raises( + nx.NetworkXNotImplemented, self.func, graph_type([(0, 1), (1, 2)]), [(0, 2)] + ) + + def test_node_not_found(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (0, 2), (2, 3)]) + assert pytest.raises(nx.NodeNotFound, self.func, G, [(0, 4)]) + + def test_no_common_neighbor(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (2, 3)]) + self.test(G, [(0, 2)], [(0, 2, 0)]) + + def test_isolated_nodes(self): + G = nx.Graph() + G.add_nodes_from([0, 1]) + self.test(G, [(0, 1)], [(0, 1, 0)]) + + def test_all_nonexistent_edges(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (0, 2), (2, 3)]) + self.test(G, None, [(0, 3, 0.5), (1, 2, 0.5), (1, 3, 0)]) + + +class TestAdamicAdarIndex: + @classmethod + def setup_class(cls): + cls.func = staticmethod(nx.adamic_adar_index) + cls.test = partial(_test_func, predict_func=cls.func) + + def test_K5(self): + G = nx.complete_graph(5) + self.test(G, [(0, 1)], [(0, 1, 3 / math.log(4))]) + + def test_P3(self): + G = nx.path_graph(3) + self.test(G, [(0, 2)], [(0, 2, 1 / math.log(2))]) + + def test_S4(self): + G = nx.star_graph(4) + self.test(G, [(1, 2)], [(1, 2, 1 / math.log(4))]) + + @pytest.mark.parametrize("graph_type", (nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph)) + def test_notimplemented(self, graph_type): + assert pytest.raises( + nx.NetworkXNotImplemented, self.func, graph_type([(0, 1), (1, 2)]), [(0, 2)] + ) + + def test_node_not_found(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (0, 2), (2, 3)]) + assert pytest.raises(nx.NodeNotFound, self.func, G, [(0, 4)]) + + def test_no_common_neighbor(self): + G = nx.Graph() + G.add_nodes_from([0, 1]) + self.test(G, [(0, 1)], [(0, 1, 0)]) + + def test_equal_nodes(self): + G = nx.complete_graph(4) + self.test(G, [(0, 0)], [(0, 0, 3 / math.log(3))]) + + def test_all_nonexistent_edges(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (0, 2), (2, 3)]) + self.test( + G, None, [(0, 3, 1 / math.log(2)), (1, 2, 1 / math.log(2)), (1, 3, 0)] + ) + + +class TestCommonNeighborCentrality: + @classmethod + def setup_class(cls): + cls.func = staticmethod(nx.common_neighbor_centrality) + cls.test = partial(_test_func, predict_func=cls.func) + + def test_K5(self): + G = nx.complete_graph(5) + self.test(G, [(0, 1)], [(0, 1, 3.0)], alpha=1) + self.test(G, [(0, 1)], [(0, 1, 5.0)], alpha=0) + + def test_P3(self): + G = nx.path_graph(3) + self.test(G, [(0, 2)], [(0, 2, 1.25)], alpha=0.5) + + def test_S4(self): + G = nx.star_graph(4) + self.test(G, [(1, 2)], [(1, 2, 1.75)], alpha=0.5) + + @pytest.mark.parametrize("graph_type", (nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph)) + def test_notimplemented(self, graph_type): + assert pytest.raises( + nx.NetworkXNotImplemented, self.func, graph_type([(0, 1), (1, 2)]), [(0, 2)] + ) + + def test_node_u_not_found(self): + G = nx.Graph() + G.add_edges_from([(1, 3), (2, 3)]) + assert pytest.raises(nx.NodeNotFound, self.func, G, [(0, 1)]) + + def test_node_v_not_found(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (0, 2), (2, 3)]) + assert pytest.raises(nx.NodeNotFound, self.func, G, [(0, 4)]) + + def test_no_common_neighbor(self): + G = nx.Graph() + G.add_nodes_from([0, 1]) + self.test(G, [(0, 1)], [(0, 1, 0)]) + + def test_equal_nodes(self): + G = nx.complete_graph(4) + assert pytest.raises(nx.NetworkXAlgorithmError, self.test, G, [(0, 0)], []) + + def test_equal_nodes_with_alpha_one_raises_error(self): + G = nx.complete_graph(4) + assert pytest.raises( + nx.NetworkXAlgorithmError, self.test, G, [(0, 0)], [], alpha=1.0 + ) + + def test_all_nonexistent_edges(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (0, 2), (2, 3)]) + self.test(G, None, [(0, 3, 1.5), (1, 2, 1.5), (1, 3, 2 / 3)], alpha=0.5) + + +class TestPreferentialAttachment: + @classmethod + def setup_class(cls): + cls.func = staticmethod(nx.preferential_attachment) + cls.test = partial(_test_func, predict_func=cls.func) + + def test_K5(self): + G = nx.complete_graph(5) + self.test(G, [(0, 1)], [(0, 1, 16)]) + + def test_P3(self): + G = nx.path_graph(3) + self.test(G, [(0, 1)], [(0, 1, 2)]) + + def test_S4(self): + G = nx.star_graph(4) + self.test(G, [(0, 2)], [(0, 2, 4)]) + + @pytest.mark.parametrize("graph_type", (nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph)) + def test_notimplemented(self, graph_type): + assert pytest.raises( + nx.NetworkXNotImplemented, self.func, graph_type([(0, 1), (1, 2)]), [(0, 2)] + ) + + def test_node_not_found(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (0, 2), (2, 3)]) + assert pytest.raises(nx.NodeNotFound, self.func, G, [(0, 4)]) + + def test_zero_degrees(self): + G = nx.Graph() + G.add_nodes_from([0, 1]) + self.test(G, [(0, 1)], [(0, 1, 0)]) + + def test_all_nonexistent_edges(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (0, 2), (2, 3)]) + self.test(G, None, [(0, 3, 2), (1, 2, 2), (1, 3, 1)]) + + +class TestCNSoundarajanHopcroft: + @classmethod + def setup_class(cls): + cls.func = staticmethod(nx.cn_soundarajan_hopcroft) + cls.test = partial(_test_func, predict_func=cls.func, community="community") + + def test_K5(self): + G = nx.complete_graph(5) + G.nodes[0]["community"] = 0 + G.nodes[1]["community"] = 0 + G.nodes[2]["community"] = 0 + G.nodes[3]["community"] = 0 + G.nodes[4]["community"] = 1 + self.test(G, [(0, 1)], [(0, 1, 5)]) + + def test_P3(self): + G = nx.path_graph(3) + G.nodes[0]["community"] = 0 + G.nodes[1]["community"] = 1 + G.nodes[2]["community"] = 0 + self.test(G, [(0, 2)], [(0, 2, 1)]) + + def test_S4(self): + G = nx.star_graph(4) + G.nodes[0]["community"] = 1 + G.nodes[1]["community"] = 1 + G.nodes[2]["community"] = 1 + G.nodes[3]["community"] = 0 + G.nodes[4]["community"] = 0 + self.test(G, [(1, 2)], [(1, 2, 2)]) + + @pytest.mark.parametrize("graph_type", (nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph)) + def test_notimplemented(self, graph_type): + G = graph_type([(0, 1), (1, 2)]) + G.add_nodes_from([0, 1, 2], community=0) + assert pytest.raises(nx.NetworkXNotImplemented, self.func, G, [(0, 2)]) + + def test_node_not_found(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (0, 2), (2, 3)]) + G.nodes[0]["community"] = 0 + G.nodes[1]["community"] = 1 + G.nodes[2]["community"] = 0 + G.nodes[3]["community"] = 0 + assert pytest.raises(nx.NodeNotFound, self.func, G, [(0, 4)]) + + def test_no_common_neighbor(self): + G = nx.Graph() + G.add_nodes_from([0, 1]) + G.nodes[0]["community"] = 0 + G.nodes[1]["community"] = 0 + self.test(G, [(0, 1)], [(0, 1, 0)]) + + def test_equal_nodes(self): + G = nx.complete_graph(3) + G.nodes[0]["community"] = 0 + G.nodes[1]["community"] = 0 + G.nodes[2]["community"] = 0 + self.test(G, [(0, 0)], [(0, 0, 4)]) + + def test_different_community(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (0, 2), (1, 3), (2, 3)]) + G.nodes[0]["community"] = 0 + G.nodes[1]["community"] = 0 + G.nodes[2]["community"] = 0 + G.nodes[3]["community"] = 1 + self.test(G, [(0, 3)], [(0, 3, 2)]) + + def test_no_community_information(self): + G = nx.complete_graph(5) + assert pytest.raises(nx.NetworkXAlgorithmError, list, self.func(G, [(0, 1)])) + + def test_insufficient_community_information(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (0, 2), (1, 3), (2, 3)]) + G.nodes[0]["community"] = 0 + G.nodes[1]["community"] = 0 + G.nodes[3]["community"] = 0 + assert pytest.raises(nx.NetworkXAlgorithmError, list, self.func(G, [(0, 3)])) + + def test_sufficient_community_information(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (1, 2), (1, 3), (2, 4), (3, 4), (4, 5)]) + G.nodes[1]["community"] = 0 + G.nodes[2]["community"] = 0 + G.nodes[3]["community"] = 0 + G.nodes[4]["community"] = 0 + self.test(G, [(1, 4)], [(1, 4, 4)]) + + def test_custom_community_attribute_name(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (0, 2), (1, 3), (2, 3)]) + G.nodes[0]["cmty"] = 0 + G.nodes[1]["cmty"] = 0 + G.nodes[2]["cmty"] = 0 + G.nodes[3]["cmty"] = 1 + self.test(G, [(0, 3)], [(0, 3, 2)], community="cmty") + + def test_all_nonexistent_edges(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (0, 2), (2, 3)]) + G.nodes[0]["community"] = 0 + G.nodes[1]["community"] = 1 + G.nodes[2]["community"] = 0 + G.nodes[3]["community"] = 0 + self.test(G, None, [(0, 3, 2), (1, 2, 1), (1, 3, 0)]) + + +class TestRAIndexSoundarajanHopcroft: + @classmethod + def setup_class(cls): + cls.func = staticmethod(nx.ra_index_soundarajan_hopcroft) + cls.test = partial(_test_func, predict_func=cls.func, community="community") + + def test_K5(self): + G = nx.complete_graph(5) + G.nodes[0]["community"] = 0 + G.nodes[1]["community"] = 0 + G.nodes[2]["community"] = 0 + G.nodes[3]["community"] = 0 + G.nodes[4]["community"] = 1 + self.test(G, [(0, 1)], [(0, 1, 0.5)]) + + def test_P3(self): + G = nx.path_graph(3) + G.nodes[0]["community"] = 0 + G.nodes[1]["community"] = 1 + G.nodes[2]["community"] = 0 + self.test(G, [(0, 2)], [(0, 2, 0)]) + + def test_S4(self): + G = nx.star_graph(4) + G.nodes[0]["community"] = 1 + G.nodes[1]["community"] = 1 + G.nodes[2]["community"] = 1 + G.nodes[3]["community"] = 0 + G.nodes[4]["community"] = 0 + self.test(G, [(1, 2)], [(1, 2, 0.25)]) + + @pytest.mark.parametrize("graph_type", (nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph)) + def test_notimplemented(self, graph_type): + G = graph_type([(0, 1), (1, 2)]) + G.add_nodes_from([0, 1, 2], community=0) + assert pytest.raises(nx.NetworkXNotImplemented, self.func, G, [(0, 2)]) + + def test_node_not_found(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (0, 2), (2, 3)]) + G.nodes[0]["community"] = 0 + G.nodes[1]["community"] = 1 + G.nodes[2]["community"] = 0 + G.nodes[3]["community"] = 0 + assert pytest.raises(nx.NodeNotFound, self.func, G, [(0, 4)]) + + def test_no_common_neighbor(self): + G = nx.Graph() + G.add_nodes_from([0, 1]) + G.nodes[0]["community"] = 0 + G.nodes[1]["community"] = 0 + self.test(G, [(0, 1)], [(0, 1, 0)]) + + def test_equal_nodes(self): + G = nx.complete_graph(3) + G.nodes[0]["community"] = 0 + G.nodes[1]["community"] = 0 + G.nodes[2]["community"] = 0 + self.test(G, [(0, 0)], [(0, 0, 1)]) + + def test_different_community(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (0, 2), (1, 3), (2, 3)]) + G.nodes[0]["community"] = 0 + G.nodes[1]["community"] = 0 + G.nodes[2]["community"] = 0 + G.nodes[3]["community"] = 1 + self.test(G, [(0, 3)], [(0, 3, 0)]) + + def test_no_community_information(self): + G = nx.complete_graph(5) + assert pytest.raises(nx.NetworkXAlgorithmError, list, self.func(G, [(0, 1)])) + + def test_insufficient_community_information(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (0, 2), (1, 3), (2, 3)]) + G.nodes[0]["community"] = 0 + G.nodes[1]["community"] = 0 + G.nodes[3]["community"] = 0 + assert pytest.raises(nx.NetworkXAlgorithmError, list, self.func(G, [(0, 3)])) + + def test_sufficient_community_information(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (1, 2), (1, 3), (2, 4), (3, 4), (4, 5)]) + G.nodes[1]["community"] = 0 + G.nodes[2]["community"] = 0 + G.nodes[3]["community"] = 0 + G.nodes[4]["community"] = 0 + self.test(G, [(1, 4)], [(1, 4, 1)]) + + def test_custom_community_attribute_name(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (0, 2), (1, 3), (2, 3)]) + G.nodes[0]["cmty"] = 0 + G.nodes[1]["cmty"] = 0 + G.nodes[2]["cmty"] = 0 + G.nodes[3]["cmty"] = 1 + self.test(G, [(0, 3)], [(0, 3, 0)], community="cmty") + + def test_all_nonexistent_edges(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (0, 2), (2, 3)]) + G.nodes[0]["community"] = 0 + G.nodes[1]["community"] = 1 + G.nodes[2]["community"] = 0 + G.nodes[3]["community"] = 0 + self.test(G, None, [(0, 3, 0.5), (1, 2, 0), (1, 3, 0)]) + + +class TestWithinInterCluster: + @classmethod + def setup_class(cls): + cls.delta = 0.001 + cls.func = staticmethod(nx.within_inter_cluster) + cls.test = partial( + _test_func, predict_func=cls.func, delta=cls.delta, community="community" + ) + + def test_K5(self): + G = nx.complete_graph(5) + G.nodes[0]["community"] = 0 + G.nodes[1]["community"] = 0 + G.nodes[2]["community"] = 0 + G.nodes[3]["community"] = 0 + G.nodes[4]["community"] = 1 + self.test(G, [(0, 1)], [(0, 1, 2 / (1 + self.delta))]) + + def test_P3(self): + G = nx.path_graph(3) + G.nodes[0]["community"] = 0 + G.nodes[1]["community"] = 1 + G.nodes[2]["community"] = 0 + self.test(G, [(0, 2)], [(0, 2, 0)]) + + def test_S4(self): + G = nx.star_graph(4) + G.nodes[0]["community"] = 1 + G.nodes[1]["community"] = 1 + G.nodes[2]["community"] = 1 + G.nodes[3]["community"] = 0 + G.nodes[4]["community"] = 0 + self.test(G, [(1, 2)], [(1, 2, 1 / self.delta)]) + + @pytest.mark.parametrize("graph_type", (nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph)) + def test_notimplemented(self, graph_type): + G = graph_type([(0, 1), (1, 2)]) + G.add_nodes_from([0, 1, 2], community=0) + assert pytest.raises(nx.NetworkXNotImplemented, self.func, G, [(0, 2)]) + + def test_node_not_found(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (0, 2), (2, 3)]) + G.nodes[0]["community"] = 0 + G.nodes[1]["community"] = 1 + G.nodes[2]["community"] = 0 + G.nodes[3]["community"] = 0 + assert pytest.raises(nx.NodeNotFound, self.func, G, [(0, 4)]) + + def test_no_common_neighbor(self): + G = nx.Graph() + G.add_nodes_from([0, 1]) + G.nodes[0]["community"] = 0 + G.nodes[1]["community"] = 0 + self.test(G, [(0, 1)], [(0, 1, 0)]) + + def test_equal_nodes(self): + G = nx.complete_graph(3) + G.nodes[0]["community"] = 0 + G.nodes[1]["community"] = 0 + G.nodes[2]["community"] = 0 + self.test(G, [(0, 0)], [(0, 0, 2 / self.delta)]) + + def test_different_community(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (0, 2), (1, 3), (2, 3)]) + G.nodes[0]["community"] = 0 + G.nodes[1]["community"] = 0 + G.nodes[2]["community"] = 0 + G.nodes[3]["community"] = 1 + self.test(G, [(0, 3)], [(0, 3, 0)]) + + def test_no_inter_cluster_common_neighbor(self): + G = nx.complete_graph(4) + G.nodes[0]["community"] = 0 + G.nodes[1]["community"] = 0 + G.nodes[2]["community"] = 0 + G.nodes[3]["community"] = 0 + self.test(G, [(0, 3)], [(0, 3, 2 / self.delta)]) + + def test_no_community_information(self): + G = nx.complete_graph(5) + assert pytest.raises(nx.NetworkXAlgorithmError, list, self.func(G, [(0, 1)])) + + def test_insufficient_community_information(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (0, 2), (1, 3), (2, 3)]) + G.nodes[0]["community"] = 0 + G.nodes[1]["community"] = 0 + G.nodes[3]["community"] = 0 + assert pytest.raises(nx.NetworkXAlgorithmError, list, self.func(G, [(0, 3)])) + + def test_sufficient_community_information(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (1, 2), (1, 3), (2, 4), (3, 4), (4, 5)]) + G.nodes[1]["community"] = 0 + G.nodes[2]["community"] = 0 + G.nodes[3]["community"] = 0 + G.nodes[4]["community"] = 0 + self.test(G, [(1, 4)], [(1, 4, 2 / self.delta)]) + + def test_invalid_delta(self): + G = nx.complete_graph(3) + G.add_nodes_from([0, 1, 2], community=0) + assert pytest.raises(nx.NetworkXAlgorithmError, self.func, G, [(0, 1)], 0) + assert pytest.raises(nx.NetworkXAlgorithmError, self.func, G, [(0, 1)], -0.5) + + def test_custom_community_attribute_name(self): + G = nx.complete_graph(4) + G.nodes[0]["cmty"] = 0 + G.nodes[1]["cmty"] = 0 + G.nodes[2]["cmty"] = 0 + G.nodes[3]["cmty"] = 0 + self.test(G, [(0, 3)], [(0, 3, 2 / self.delta)], community="cmty") + + def test_all_nonexistent_edges(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (0, 2), (2, 3)]) + G.nodes[0]["community"] = 0 + G.nodes[1]["community"] = 1 + G.nodes[2]["community"] = 0 + G.nodes[3]["community"] = 0 + self.test(G, None, [(0, 3, 1 / self.delta), (1, 2, 0), (1, 3, 0)]) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_lowest_common_ancestors.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_lowest_common_ancestors.py new file mode 100644 index 0000000000000000000000000000000000000000..66d75220327cb27c8b378505aea2780ea96021af --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_lowest_common_ancestors.py @@ -0,0 +1,427 @@ +from itertools import chain, combinations, product + +import pytest + +import networkx as nx + +tree_all_pairs_lca = nx.tree_all_pairs_lowest_common_ancestor +all_pairs_lca = nx.all_pairs_lowest_common_ancestor + + +def get_pair(dictionary, n1, n2): + if (n1, n2) in dictionary: + return dictionary[n1, n2] + else: + return dictionary[n2, n1] + + +class TestTreeLCA: + @classmethod + def setup_class(cls): + cls.DG = nx.DiGraph() + edges = [(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)] + cls.DG.add_edges_from(edges) + cls.ans = dict(tree_all_pairs_lca(cls.DG, 0)) + gold = {(n, n): n for n in cls.DG} + gold.update({(0, i): 0 for i in range(1, 7)}) + gold.update( + { + (1, 2): 0, + (1, 3): 1, + (1, 4): 1, + (1, 5): 0, + (1, 6): 0, + (2, 3): 0, + (2, 4): 0, + (2, 5): 2, + (2, 6): 2, + (3, 4): 1, + (3, 5): 0, + (3, 6): 0, + (4, 5): 0, + (4, 6): 0, + (5, 6): 2, + } + ) + + cls.gold = gold + + @staticmethod + def assert_has_same_pairs(d1, d2): + for a, b in ((min(pair), max(pair)) for pair in chain(d1, d2)): + assert get_pair(d1, a, b) == get_pair(d2, a, b) + + def test_tree_all_pairs_lca_default_root(self): + assert dict(tree_all_pairs_lca(self.DG)) == self.ans + + def test_tree_all_pairs_lca_return_subset(self): + test_pairs = [(0, 1), (0, 1), (1, 0)] + ans = dict(tree_all_pairs_lca(self.DG, 0, test_pairs)) + assert (0, 1) in ans and (1, 0) in ans + assert len(ans) == 2 + + def test_tree_all_pairs_lca(self): + all_pairs = chain(combinations(self.DG, 2), ((node, node) for node in self.DG)) + + ans = dict(tree_all_pairs_lca(self.DG, 0, all_pairs)) + self.assert_has_same_pairs(ans, self.ans) + + def test_tree_all_pairs_gold_example(self): + ans = dict(tree_all_pairs_lca(self.DG)) + self.assert_has_same_pairs(self.gold, ans) + + def test_tree_all_pairs_lca_invalid_input(self): + empty_digraph = tree_all_pairs_lca(nx.DiGraph()) + pytest.raises(nx.NetworkXPointlessConcept, list, empty_digraph) + + bad_pairs_digraph = tree_all_pairs_lca(self.DG, pairs=[(-1, -2)]) + pytest.raises(nx.NodeNotFound, list, bad_pairs_digraph) + + def test_tree_all_pairs_lca_subtrees(self): + ans = dict(tree_all_pairs_lca(self.DG, 1)) + gold = { + pair: lca + for (pair, lca) in self.gold.items() + if all(n in (1, 3, 4) for n in pair) + } + self.assert_has_same_pairs(gold, ans) + + def test_tree_all_pairs_lca_disconnected_nodes(self): + G = nx.DiGraph() + G.add_node(1) + assert {(1, 1): 1} == dict(tree_all_pairs_lca(G)) + + G.add_node(0) + assert {(1, 1): 1} == dict(tree_all_pairs_lca(G, 1)) + assert {(0, 0): 0} == dict(tree_all_pairs_lca(G, 0)) + + pytest.raises(nx.NetworkXError, list, tree_all_pairs_lca(G)) + + def test_tree_all_pairs_lca_error_if_input_not_tree(self): + # Cycle + G = nx.DiGraph([(1, 2), (2, 1)]) + pytest.raises(nx.NetworkXError, list, tree_all_pairs_lca(G)) + # DAG + G = nx.DiGraph([(0, 2), (1, 2)]) + pytest.raises(nx.NetworkXError, list, tree_all_pairs_lca(G)) + + def test_tree_all_pairs_lca_generator(self): + pairs = iter([(0, 1), (0, 1), (1, 0)]) + some_pairs = dict(tree_all_pairs_lca(self.DG, 0, pairs)) + assert (0, 1) in some_pairs and (1, 0) in some_pairs + assert len(some_pairs) == 2 + + def test_tree_all_pairs_lca_nonexisting_pairs_exception(self): + lca = tree_all_pairs_lca(self.DG, 0, [(-1, -1)]) + pytest.raises(nx.NodeNotFound, list, lca) + # check if node is None + lca = tree_all_pairs_lca(self.DG, None, [(-1, -1)]) + pytest.raises(nx.NodeNotFound, list, lca) + + def test_tree_all_pairs_lca_routine_bails_on_DAGs(self): + G = nx.DiGraph([(3, 4), (5, 4)]) + pytest.raises(nx.NetworkXError, list, tree_all_pairs_lca(G)) + + def test_tree_all_pairs_lca_not_implemented(self): + NNI = nx.NetworkXNotImplemented + G = nx.Graph([(0, 1)]) + with pytest.raises(NNI): + next(tree_all_pairs_lca(G)) + with pytest.raises(NNI): + next(all_pairs_lca(G)) + pytest.raises(NNI, nx.lowest_common_ancestor, G, 0, 1) + G = nx.MultiGraph([(0, 1)]) + with pytest.raises(NNI): + next(tree_all_pairs_lca(G)) + with pytest.raises(NNI): + next(all_pairs_lca(G)) + pytest.raises(NNI, nx.lowest_common_ancestor, G, 0, 1) + + def test_tree_all_pairs_lca_trees_without_LCAs(self): + G = nx.DiGraph() + G.add_node(3) + ans = list(tree_all_pairs_lca(G)) + assert ans == [((3, 3), 3)] + + +class TestMultiTreeLCA(TestTreeLCA): + @classmethod + def setup_class(cls): + cls.DG = nx.MultiDiGraph() + edges = [(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)] + cls.DG.add_edges_from(edges) + cls.ans = dict(tree_all_pairs_lca(cls.DG, 0)) + # add multiedges + cls.DG.add_edges_from(edges) + + gold = {(n, n): n for n in cls.DG} + gold.update({(0, i): 0 for i in range(1, 7)}) + gold.update( + { + (1, 2): 0, + (1, 3): 1, + (1, 4): 1, + (1, 5): 0, + (1, 6): 0, + (2, 3): 0, + (2, 4): 0, + (2, 5): 2, + (2, 6): 2, + (3, 4): 1, + (3, 5): 0, + (3, 6): 0, + (4, 5): 0, + (4, 6): 0, + (5, 6): 2, + } + ) + + cls.gold = gold + + +class TestDAGLCA: + @classmethod + def setup_class(cls): + cls.DG = nx.DiGraph() + nx.add_path(cls.DG, (0, 1, 2, 3)) + nx.add_path(cls.DG, (0, 4, 3)) + nx.add_path(cls.DG, (0, 5, 6, 8, 3)) + nx.add_path(cls.DG, (5, 7, 8)) + cls.DG.add_edge(6, 2) + cls.DG.add_edge(7, 2) + + cls.root_distance = nx.shortest_path_length(cls.DG, source=0) + + cls.gold = { + (1, 1): 1, + (1, 2): 1, + (1, 3): 1, + (1, 4): 0, + (1, 5): 0, + (1, 6): 0, + (1, 7): 0, + (1, 8): 0, + (2, 2): 2, + (2, 3): 2, + (2, 4): 0, + (2, 5): 5, + (2, 6): 6, + (2, 7): 7, + (2, 8): 7, + (3, 3): 3, + (3, 4): 4, + (3, 5): 5, + (3, 6): 6, + (3, 7): 7, + (3, 8): 8, + (4, 4): 4, + (4, 5): 0, + (4, 6): 0, + (4, 7): 0, + (4, 8): 0, + (5, 5): 5, + (5, 6): 5, + (5, 7): 5, + (5, 8): 5, + (6, 6): 6, + (6, 7): 5, + (6, 8): 6, + (7, 7): 7, + (7, 8): 7, + (8, 8): 8, + } + cls.gold.update(((0, n), 0) for n in cls.DG) + + def assert_lca_dicts_same(self, d1, d2, G=None): + """Checks if d1 and d2 contain the same pairs and + have a node at the same distance from root for each. + If G is None use self.DG.""" + if G is None: + G = self.DG + root_distance = self.root_distance + else: + roots = [n for n, deg in G.in_degree if deg == 0] + assert len(roots) == 1 + root_distance = nx.shortest_path_length(G, source=roots[0]) + + for a, b in ((min(pair), max(pair)) for pair in chain(d1, d2)): + assert ( + root_distance[get_pair(d1, a, b)] == root_distance[get_pair(d2, a, b)] + ) + + def test_all_pairs_lca_gold_example(self): + self.assert_lca_dicts_same(dict(all_pairs_lca(self.DG)), self.gold) + + def test_all_pairs_lca_all_pairs_given(self): + all_pairs = list(product(self.DG.nodes(), self.DG.nodes())) + ans = all_pairs_lca(self.DG, pairs=all_pairs) + self.assert_lca_dicts_same(dict(ans), self.gold) + + def test_all_pairs_lca_generator(self): + all_pairs = product(self.DG.nodes(), self.DG.nodes()) + ans = all_pairs_lca(self.DG, pairs=all_pairs) + self.assert_lca_dicts_same(dict(ans), self.gold) + + def test_all_pairs_lca_input_graph_with_two_roots(self): + G = self.DG.copy() + G.add_edge(9, 10) + G.add_edge(9, 4) + gold = self.gold.copy() + gold[9, 9] = 9 + gold[9, 10] = 9 + gold[9, 4] = 9 + gold[9, 3] = 9 + gold[10, 4] = 9 + gold[10, 3] = 9 + gold[10, 10] = 10 + + testing = dict(all_pairs_lca(G)) + + G.add_edge(-1, 9) + G.add_edge(-1, 0) + self.assert_lca_dicts_same(testing, gold, G) + + def test_all_pairs_lca_nonexisting_pairs_exception(self): + pytest.raises(nx.NodeNotFound, all_pairs_lca, self.DG, [(-1, -1)]) + + def test_all_pairs_lca_pairs_without_lca(self): + G = self.DG.copy() + G.add_node(-1) + gen = all_pairs_lca(G, [(-1, -1), (-1, 0)]) + assert dict(gen) == {(-1, -1): -1} + + def test_all_pairs_lca_null_graph(self): + pytest.raises(nx.NetworkXPointlessConcept, all_pairs_lca, nx.DiGraph()) + + def test_all_pairs_lca_non_dags(self): + pytest.raises(nx.NetworkXError, all_pairs_lca, nx.DiGraph([(3, 4), (4, 3)])) + + def test_all_pairs_lca_nonempty_graph_without_lca(self): + G = nx.DiGraph() + G.add_node(3) + ans = list(all_pairs_lca(G)) + assert ans == [((3, 3), 3)] + + def test_all_pairs_lca_bug_gh4942(self): + G = nx.DiGraph([(0, 2), (1, 2), (2, 3)]) + ans = list(all_pairs_lca(G)) + assert len(ans) == 9 + + def test_all_pairs_lca_default_kwarg(self): + G = nx.DiGraph([(0, 1), (2, 1)]) + sentinel = object() + assert nx.lowest_common_ancestor(G, 0, 2, default=sentinel) is sentinel + + def test_all_pairs_lca_identity(self): + G = nx.DiGraph() + G.add_node(3) + assert nx.lowest_common_ancestor(G, 3, 3) == 3 + + def test_all_pairs_lca_issue_4574(self): + G = nx.DiGraph() + G.add_nodes_from(range(17)) + G.add_edges_from( + [ + (2, 0), + (1, 2), + (3, 2), + (5, 2), + (8, 2), + (11, 2), + (4, 5), + (6, 5), + (7, 8), + (10, 8), + (13, 11), + (14, 11), + (15, 11), + (9, 10), + (12, 13), + (16, 15), + ] + ) + + assert nx.lowest_common_ancestor(G, 7, 9) == None + + def test_all_pairs_lca_one_pair_gh4942(self): + G = nx.DiGraph() + # Note: order edge addition is critical to the test + G.add_edge(0, 1) + G.add_edge(2, 0) + G.add_edge(2, 3) + G.add_edge(4, 0) + G.add_edge(5, 2) + + assert nx.lowest_common_ancestor(G, 1, 3) == 2 + + +class TestMultiDiGraph_DAGLCA(TestDAGLCA): + @classmethod + def setup_class(cls): + cls.DG = nx.MultiDiGraph() + nx.add_path(cls.DG, (0, 1, 2, 3)) + # add multiedges + nx.add_path(cls.DG, (0, 1, 2, 3)) + nx.add_path(cls.DG, (0, 4, 3)) + nx.add_path(cls.DG, (0, 5, 6, 8, 3)) + nx.add_path(cls.DG, (5, 7, 8)) + cls.DG.add_edge(6, 2) + cls.DG.add_edge(7, 2) + + cls.root_distance = nx.shortest_path_length(cls.DG, source=0) + + cls.gold = { + (1, 1): 1, + (1, 2): 1, + (1, 3): 1, + (1, 4): 0, + (1, 5): 0, + (1, 6): 0, + (1, 7): 0, + (1, 8): 0, + (2, 2): 2, + (2, 3): 2, + (2, 4): 0, + (2, 5): 5, + (2, 6): 6, + (2, 7): 7, + (2, 8): 7, + (3, 3): 3, + (3, 4): 4, + (3, 5): 5, + (3, 6): 6, + (3, 7): 7, + (3, 8): 8, + (4, 4): 4, + (4, 5): 0, + (4, 6): 0, + (4, 7): 0, + (4, 8): 0, + (5, 5): 5, + (5, 6): 5, + (5, 7): 5, + (5, 8): 5, + (6, 6): 6, + (6, 7): 5, + (6, 8): 6, + (7, 7): 7, + (7, 8): 7, + (8, 8): 8, + } + cls.gold.update(((0, n), 0) for n in cls.DG) + + +def test_all_pairs_lca_self_ancestors(): + """Self-ancestors should always be the node itself, i.e. lca of (0, 0) is 0. + See gh-4458.""" + # DAG for test - note order of node/edge addition is relevant + G = nx.DiGraph() + G.add_nodes_from(range(5)) + G.add_edges_from([(1, 0), (2, 0), (3, 2), (4, 1), (4, 3)]) + + ap_lca = nx.all_pairs_lowest_common_ancestor + assert all(u == v == a for (u, v), a in ap_lca(G) if u == v) + MG = nx.MultiDiGraph(G) + assert all(u == v == a for (u, v), a in ap_lca(MG) if u == v) + MG.add_edges_from([(1, 0), (2, 0)]) + assert all(u == v == a for (u, v), a in ap_lca(MG) if u == v) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_matching.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..37853e3896c0fd6bcac1f46524a844ae2e2fb518 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_matching.py @@ -0,0 +1,605 @@ +import math +from itertools import permutations + +from pytest import raises + +import networkx as nx +from networkx.algorithms.matching import matching_dict_to_set +from networkx.utils import edges_equal + + +class TestMaxWeightMatching: + """Unit tests for the + :func:`~networkx.algorithms.matching.max_weight_matching` function. + + """ + + def test_trivial1(self): + """Empty graph""" + G = nx.Graph() + assert nx.max_weight_matching(G) == set() + assert nx.min_weight_matching(G) == set() + + def test_selfloop(self): + G = nx.Graph() + G.add_edge(0, 0, weight=100) + assert nx.max_weight_matching(G) == set() + assert nx.min_weight_matching(G) == set() + + def test_single_edge(self): + G = nx.Graph() + G.add_edge(0, 1) + assert edges_equal( + nx.max_weight_matching(G), matching_dict_to_set({0: 1, 1: 0}) + ) + assert edges_equal( + nx.min_weight_matching(G), matching_dict_to_set({0: 1, 1: 0}) + ) + + def test_two_path(self): + G = nx.Graph() + G.add_edge("one", "two", weight=10) + G.add_edge("two", "three", weight=11) + assert edges_equal( + nx.max_weight_matching(G), + matching_dict_to_set({"three": "two", "two": "three"}), + ) + assert edges_equal( + nx.min_weight_matching(G), + matching_dict_to_set({"one": "two", "two": "one"}), + ) + + def test_path(self): + G = nx.Graph() + G.add_edge(1, 2, weight=5) + G.add_edge(2, 3, weight=11) + G.add_edge(3, 4, weight=5) + assert edges_equal( + nx.max_weight_matching(G), matching_dict_to_set({2: 3, 3: 2}) + ) + assert edges_equal( + nx.max_weight_matching(G, 1), matching_dict_to_set({1: 2, 2: 1, 3: 4, 4: 3}) + ) + assert edges_equal( + nx.min_weight_matching(G), matching_dict_to_set({1: 2, 3: 4}) + ) + assert edges_equal( + nx.min_weight_matching(G, 1), matching_dict_to_set({1: 2, 3: 4}) + ) + + def test_square(self): + G = nx.Graph() + G.add_edge(1, 4, weight=2) + G.add_edge(2, 3, weight=2) + G.add_edge(1, 2, weight=1) + G.add_edge(3, 4, weight=4) + assert edges_equal( + nx.max_weight_matching(G), matching_dict_to_set({1: 2, 3: 4}) + ) + assert edges_equal( + nx.min_weight_matching(G), matching_dict_to_set({1: 4, 2: 3}) + ) + + def test_edge_attribute_name(self): + G = nx.Graph() + G.add_edge("one", "two", weight=10, abcd=11) + G.add_edge("two", "three", weight=11, abcd=10) + assert edges_equal( + nx.max_weight_matching(G, weight="abcd"), + matching_dict_to_set({"one": "two", "two": "one"}), + ) + assert edges_equal( + nx.min_weight_matching(G, weight="abcd"), + matching_dict_to_set({"three": "two"}), + ) + + def test_floating_point_weights(self): + G = nx.Graph() + G.add_edge(1, 2, weight=math.pi) + G.add_edge(2, 3, weight=math.exp(1)) + G.add_edge(1, 3, weight=3.0) + G.add_edge(1, 4, weight=math.sqrt(2.0)) + assert edges_equal( + nx.max_weight_matching(G), matching_dict_to_set({1: 4, 2: 3, 3: 2, 4: 1}) + ) + assert edges_equal( + nx.min_weight_matching(G), matching_dict_to_set({1: 4, 2: 3, 3: 2, 4: 1}) + ) + + def test_negative_weights(self): + G = nx.Graph() + G.add_edge(1, 2, weight=2) + G.add_edge(1, 3, weight=-2) + G.add_edge(2, 3, weight=1) + G.add_edge(2, 4, weight=-1) + G.add_edge(3, 4, weight=-6) + assert edges_equal( + nx.max_weight_matching(G), matching_dict_to_set({1: 2, 2: 1}) + ) + assert edges_equal( + nx.max_weight_matching(G, maxcardinality=True), + matching_dict_to_set({1: 3, 2: 4, 3: 1, 4: 2}), + ) + assert edges_equal( + nx.min_weight_matching(G), matching_dict_to_set({1: 2, 3: 4}) + ) + + def test_s_blossom(self): + """Create S-blossom and use it for augmentation:""" + G = nx.Graph() + G.add_weighted_edges_from([(1, 2, 8), (1, 3, 9), (2, 3, 10), (3, 4, 7)]) + answer = matching_dict_to_set({1: 2, 2: 1, 3: 4, 4: 3}) + assert edges_equal(nx.max_weight_matching(G), answer) + assert edges_equal(nx.min_weight_matching(G), answer) + + G.add_weighted_edges_from([(1, 6, 5), (4, 5, 6)]) + answer = matching_dict_to_set({1: 6, 2: 3, 3: 2, 4: 5, 5: 4, 6: 1}) + assert edges_equal(nx.max_weight_matching(G), answer) + assert edges_equal(nx.min_weight_matching(G), answer) + + def test_s_t_blossom(self): + """Create S-blossom, relabel as T-blossom, use for augmentation:""" + G = nx.Graph() + G.add_weighted_edges_from( + [(1, 2, 9), (1, 3, 8), (2, 3, 10), (1, 4, 5), (4, 5, 4), (1, 6, 3)] + ) + answer = matching_dict_to_set({1: 6, 2: 3, 3: 2, 4: 5, 5: 4, 6: 1}) + assert edges_equal(nx.max_weight_matching(G), answer) + assert edges_equal(nx.min_weight_matching(G), answer) + + G.add_edge(4, 5, weight=3) + G.add_edge(1, 6, weight=4) + assert edges_equal(nx.max_weight_matching(G), answer) + assert edges_equal(nx.min_weight_matching(G), answer) + + G.remove_edge(1, 6) + G.add_edge(3, 6, weight=4) + answer = matching_dict_to_set({1: 2, 2: 1, 3: 6, 4: 5, 5: 4, 6: 3}) + assert edges_equal(nx.max_weight_matching(G), answer) + assert edges_equal(nx.min_weight_matching(G), answer) + + def test_nested_s_blossom(self): + """Create nested S-blossom, use for augmentation:""" + + G = nx.Graph() + G.add_weighted_edges_from( + [ + (1, 2, 9), + (1, 3, 9), + (2, 3, 10), + (2, 4, 8), + (3, 5, 8), + (4, 5, 10), + (5, 6, 6), + ] + ) + dict_format = {1: 3, 2: 4, 3: 1, 4: 2, 5: 6, 6: 5} + expected = {frozenset(e) for e in matching_dict_to_set(dict_format)} + answer = {frozenset(e) for e in nx.max_weight_matching(G)} + assert answer == expected + answer = {frozenset(e) for e in nx.min_weight_matching(G)} + assert answer == expected + + def test_nested_s_blossom_relabel(self): + """Create S-blossom, relabel as S, include in nested S-blossom:""" + G = nx.Graph() + G.add_weighted_edges_from( + [ + (1, 2, 10), + (1, 7, 10), + (2, 3, 12), + (3, 4, 20), + (3, 5, 20), + (4, 5, 25), + (5, 6, 10), + (6, 7, 10), + (7, 8, 8), + ] + ) + answer = matching_dict_to_set({1: 2, 2: 1, 3: 4, 4: 3, 5: 6, 6: 5, 7: 8, 8: 7}) + assert edges_equal(nx.max_weight_matching(G), answer) + assert edges_equal(nx.min_weight_matching(G), answer) + + def test_nested_s_blossom_expand(self): + """Create nested S-blossom, augment, expand recursively:""" + G = nx.Graph() + G.add_weighted_edges_from( + [ + (1, 2, 8), + (1, 3, 8), + (2, 3, 10), + (2, 4, 12), + (3, 5, 12), + (4, 5, 14), + (4, 6, 12), + (5, 7, 12), + (6, 7, 14), + (7, 8, 12), + ] + ) + answer = matching_dict_to_set({1: 2, 2: 1, 3: 5, 4: 6, 5: 3, 6: 4, 7: 8, 8: 7}) + assert edges_equal(nx.max_weight_matching(G), answer) + assert edges_equal(nx.min_weight_matching(G), answer) + + def test_s_blossom_relabel_expand(self): + """Create S-blossom, relabel as T, expand:""" + G = nx.Graph() + G.add_weighted_edges_from( + [ + (1, 2, 23), + (1, 5, 22), + (1, 6, 15), + (2, 3, 25), + (3, 4, 22), + (4, 5, 25), + (4, 8, 14), + (5, 7, 13), + ] + ) + answer = matching_dict_to_set({1: 6, 2: 3, 3: 2, 4: 8, 5: 7, 6: 1, 7: 5, 8: 4}) + assert edges_equal(nx.max_weight_matching(G), answer) + assert edges_equal(nx.min_weight_matching(G), answer) + + def test_nested_s_blossom_relabel_expand(self): + """Create nested S-blossom, relabel as T, expand:""" + G = nx.Graph() + G.add_weighted_edges_from( + [ + (1, 2, 19), + (1, 3, 20), + (1, 8, 8), + (2, 3, 25), + (2, 4, 18), + (3, 5, 18), + (4, 5, 13), + (4, 7, 7), + (5, 6, 7), + ] + ) + answer = matching_dict_to_set({1: 8, 2: 3, 3: 2, 4: 7, 5: 6, 6: 5, 7: 4, 8: 1}) + assert edges_equal(nx.max_weight_matching(G), answer) + assert edges_equal(nx.min_weight_matching(G), answer) + + def test_nasty_blossom1(self): + """Create blossom, relabel as T in more than one way, expand, + augment: + """ + G = nx.Graph() + G.add_weighted_edges_from( + [ + (1, 2, 45), + (1, 5, 45), + (2, 3, 50), + (3, 4, 45), + (4, 5, 50), + (1, 6, 30), + (3, 9, 35), + (4, 8, 35), + (5, 7, 26), + (9, 10, 5), + ] + ) + ansdict = {1: 6, 2: 3, 3: 2, 4: 8, 5: 7, 6: 1, 7: 5, 8: 4, 9: 10, 10: 9} + answer = matching_dict_to_set(ansdict) + assert edges_equal(nx.max_weight_matching(G), answer) + assert edges_equal(nx.min_weight_matching(G), answer) + + def test_nasty_blossom2(self): + """Again but slightly different:""" + G = nx.Graph() + G.add_weighted_edges_from( + [ + (1, 2, 45), + (1, 5, 45), + (2, 3, 50), + (3, 4, 45), + (4, 5, 50), + (1, 6, 30), + (3, 9, 35), + (4, 8, 26), + (5, 7, 40), + (9, 10, 5), + ] + ) + ans = {1: 6, 2: 3, 3: 2, 4: 8, 5: 7, 6: 1, 7: 5, 8: 4, 9: 10, 10: 9} + answer = matching_dict_to_set(ans) + assert edges_equal(nx.max_weight_matching(G), answer) + assert edges_equal(nx.min_weight_matching(G), answer) + + def test_nasty_blossom_least_slack(self): + """Create blossom, relabel as T, expand such that a new + least-slack S-to-free dge is produced, augment: + """ + G = nx.Graph() + G.add_weighted_edges_from( + [ + (1, 2, 45), + (1, 5, 45), + (2, 3, 50), + (3, 4, 45), + (4, 5, 50), + (1, 6, 30), + (3, 9, 35), + (4, 8, 28), + (5, 7, 26), + (9, 10, 5), + ] + ) + ans = {1: 6, 2: 3, 3: 2, 4: 8, 5: 7, 6: 1, 7: 5, 8: 4, 9: 10, 10: 9} + answer = matching_dict_to_set(ans) + assert edges_equal(nx.max_weight_matching(G), answer) + assert edges_equal(nx.min_weight_matching(G), answer) + + def test_nasty_blossom_augmenting(self): + """Create nested blossom, relabel as T in more than one way""" + # expand outer blossom such that inner blossom ends up on an + # augmenting path: + G = nx.Graph() + G.add_weighted_edges_from( + [ + (1, 2, 45), + (1, 7, 45), + (2, 3, 50), + (3, 4, 45), + (4, 5, 95), + (4, 6, 94), + (5, 6, 94), + (6, 7, 50), + (1, 8, 30), + (3, 11, 35), + (5, 9, 36), + (7, 10, 26), + (11, 12, 5), + ] + ) + ans = { + 1: 8, + 2: 3, + 3: 2, + 4: 6, + 5: 9, + 6: 4, + 7: 10, + 8: 1, + 9: 5, + 10: 7, + 11: 12, + 12: 11, + } + answer = matching_dict_to_set(ans) + assert edges_equal(nx.max_weight_matching(G), answer) + assert edges_equal(nx.min_weight_matching(G), answer) + + def test_nasty_blossom_expand_recursively(self): + """Create nested S-blossom, relabel as S, expand recursively:""" + G = nx.Graph() + G.add_weighted_edges_from( + [ + (1, 2, 40), + (1, 3, 40), + (2, 3, 60), + (2, 4, 55), + (3, 5, 55), + (4, 5, 50), + (1, 8, 15), + (5, 7, 30), + (7, 6, 10), + (8, 10, 10), + (4, 9, 30), + ] + ) + ans = {1: 2, 2: 1, 3: 5, 4: 9, 5: 3, 6: 7, 7: 6, 8: 10, 9: 4, 10: 8} + answer = matching_dict_to_set(ans) + assert edges_equal(nx.max_weight_matching(G), answer) + assert edges_equal(nx.min_weight_matching(G), answer) + + def test_wrong_graph_type(self): + error = nx.NetworkXNotImplemented + raises(error, nx.max_weight_matching, nx.MultiGraph()) + raises(error, nx.max_weight_matching, nx.MultiDiGraph()) + raises(error, nx.max_weight_matching, nx.DiGraph()) + raises(error, nx.min_weight_matching, nx.DiGraph()) + + +class TestIsMatching: + """Unit tests for the + :func:`~networkx.algorithms.matching.is_matching` function. + + """ + + def test_dict(self): + G = nx.path_graph(4) + assert nx.is_matching(G, {0: 1, 1: 0, 2: 3, 3: 2}) + + def test_empty_matching(self): + G = nx.path_graph(4) + assert nx.is_matching(G, set()) + + def test_single_edge(self): + G = nx.path_graph(4) + assert nx.is_matching(G, {(1, 2)}) + + def test_edge_order(self): + G = nx.path_graph(4) + assert nx.is_matching(G, {(0, 1), (2, 3)}) + assert nx.is_matching(G, {(1, 0), (2, 3)}) + assert nx.is_matching(G, {(0, 1), (3, 2)}) + assert nx.is_matching(G, {(1, 0), (3, 2)}) + + def test_valid_matching(self): + G = nx.path_graph(4) + assert nx.is_matching(G, {(0, 1), (2, 3)}) + + def test_invalid_input(self): + error = nx.NetworkXError + G = nx.path_graph(4) + # edge to node not in G + raises(error, nx.is_matching, G, {(0, 5), (2, 3)}) + # edge not a 2-tuple + raises(error, nx.is_matching, G, {(0, 1, 2), (2, 3)}) + raises(error, nx.is_matching, G, {(0,), (2, 3)}) + + def test_selfloops(self): + error = nx.NetworkXError + G = nx.path_graph(4) + # selfloop for node not in G + raises(error, nx.is_matching, G, {(5, 5), (2, 3)}) + # selfloop edge not in G + assert not nx.is_matching(G, {(0, 0), (1, 2), (2, 3)}) + # selfloop edge in G + G.add_edge(0, 0) + assert not nx.is_matching(G, {(0, 0), (1, 2)}) + + def test_invalid_matching(self): + G = nx.path_graph(4) + assert not nx.is_matching(G, {(0, 1), (1, 2), (2, 3)}) + + def test_invalid_edge(self): + G = nx.path_graph(4) + assert not nx.is_matching(G, {(0, 3), (1, 2)}) + raises(nx.NetworkXError, nx.is_matching, G, {(0, 55)}) + + G = nx.DiGraph(G.edges) + assert nx.is_matching(G, {(0, 1)}) + assert not nx.is_matching(G, {(1, 0)}) + + +class TestIsMaximalMatching: + """Unit tests for the + :func:`~networkx.algorithms.matching.is_maximal_matching` function. + + """ + + def test_dict(self): + G = nx.path_graph(4) + assert nx.is_maximal_matching(G, {0: 1, 1: 0, 2: 3, 3: 2}) + + def test_invalid_input(self): + error = nx.NetworkXError + G = nx.path_graph(4) + # edge to node not in G + raises(error, nx.is_maximal_matching, G, {(0, 5)}) + raises(error, nx.is_maximal_matching, G, {(5, 0)}) + # edge not a 2-tuple + raises(error, nx.is_maximal_matching, G, {(0, 1, 2), (2, 3)}) + raises(error, nx.is_maximal_matching, G, {(0,), (2, 3)}) + + def test_valid(self): + G = nx.path_graph(4) + assert nx.is_maximal_matching(G, {(0, 1), (2, 3)}) + + def test_not_matching(self): + G = nx.path_graph(4) + assert not nx.is_maximal_matching(G, {(0, 1), (1, 2), (2, 3)}) + assert not nx.is_maximal_matching(G, {(0, 3)}) + G.add_edge(0, 0) + assert not nx.is_maximal_matching(G, {(0, 0)}) + + def test_not_maximal(self): + G = nx.path_graph(4) + assert not nx.is_maximal_matching(G, {(0, 1)}) + + +class TestIsPerfectMatching: + """Unit tests for the + :func:`~networkx.algorithms.matching.is_perfect_matching` function. + + """ + + def test_dict(self): + G = nx.path_graph(4) + assert nx.is_perfect_matching(G, {0: 1, 1: 0, 2: 3, 3: 2}) + + def test_valid(self): + G = nx.path_graph(4) + assert nx.is_perfect_matching(G, {(0, 1), (2, 3)}) + + def test_valid_not_path(self): + G = nx.cycle_graph(4) + G.add_edge(0, 4) + G.add_edge(1, 4) + G.add_edge(5, 2) + + assert nx.is_perfect_matching(G, {(1, 4), (0, 3), (5, 2)}) + + def test_invalid_input(self): + error = nx.NetworkXError + G = nx.path_graph(4) + # edge to node not in G + raises(error, nx.is_perfect_matching, G, {(0, 5)}) + raises(error, nx.is_perfect_matching, G, {(5, 0)}) + # edge not a 2-tuple + raises(error, nx.is_perfect_matching, G, {(0, 1, 2), (2, 3)}) + raises(error, nx.is_perfect_matching, G, {(0,), (2, 3)}) + + def test_selfloops(self): + error = nx.NetworkXError + G = nx.path_graph(4) + # selfloop for node not in G + raises(error, nx.is_perfect_matching, G, {(5, 5), (2, 3)}) + # selfloop edge not in G + assert not nx.is_perfect_matching(G, {(0, 0), (1, 2), (2, 3)}) + # selfloop edge in G + G.add_edge(0, 0) + assert not nx.is_perfect_matching(G, {(0, 0), (1, 2)}) + + def test_not_matching(self): + G = nx.path_graph(4) + assert not nx.is_perfect_matching(G, {(0, 3)}) + assert not nx.is_perfect_matching(G, {(0, 1), (1, 2), (2, 3)}) + + def test_maximal_but_not_perfect(self): + G = nx.cycle_graph(4) + G.add_edge(0, 4) + G.add_edge(1, 4) + + assert not nx.is_perfect_matching(G, {(1, 4), (0, 3)}) + + +class TestMaximalMatching: + """Unit tests for the + :func:`~networkx.algorithms.matching.maximal_matching`. + + """ + + def test_valid_matching(self): + edges = [(1, 2), (1, 5), (2, 3), (2, 5), (3, 4), (3, 6), (5, 6)] + G = nx.Graph(edges) + matching = nx.maximal_matching(G) + assert nx.is_maximal_matching(G, matching) + + def test_single_edge_matching(self): + # In the star graph, any maximal matching has just one edge. + G = nx.star_graph(5) + matching = nx.maximal_matching(G) + assert 1 == len(matching) + assert nx.is_maximal_matching(G, matching) + + def test_self_loops(self): + # Create the path graph with two self-loops. + G = nx.path_graph(3) + G.add_edges_from([(0, 0), (1, 1)]) + matching = nx.maximal_matching(G) + assert len(matching) == 1 + # The matching should never include self-loops. + assert not any(u == v for u, v in matching) + assert nx.is_maximal_matching(G, matching) + + def test_ordering(self): + """Tests that a maximal matching is computed correctly + regardless of the order in which nodes are added to the graph. + + """ + for nodes in permutations(range(3)): + G = nx.Graph() + G.add_nodes_from(nodes) + G.add_edges_from([(0, 1), (0, 2)]) + matching = nx.maximal_matching(G) + assert len(matching) == 1 + assert nx.is_maximal_matching(G, matching) + + def test_wrong_graph_type(self): + error = nx.NetworkXNotImplemented + raises(error, nx.maximal_matching, nx.MultiGraph()) + raises(error, nx.maximal_matching, nx.MultiDiGraph()) + raises(error, nx.maximal_matching, nx.DiGraph()) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_max_weight_clique.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_max_weight_clique.py new file mode 100644 index 0000000000000000000000000000000000000000..6cd8584ebd5c2ab04741234018a976472a92ef91 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_max_weight_clique.py @@ -0,0 +1,179 @@ +"""Maximum weight clique test suite.""" + +import pytest + +import networkx as nx + + +class TestMaximumWeightClique: + def test_basic_cases(self): + def check_basic_case(graph_func, expected_weight, weight_accessor): + graph = graph_func() + clique, weight = nx.algorithms.max_weight_clique(graph, weight_accessor) + assert verify_clique( + graph, clique, weight, expected_weight, weight_accessor + ) + + for graph_func, (expected_weight, expected_size) in TEST_CASES.items(): + check_basic_case(graph_func, expected_weight, "weight") + check_basic_case(graph_func, expected_size, None) + + def test_key_error(self): + graph = two_node_graph() + with pytest.raises(KeyError): + nx.algorithms.max_weight_clique(graph, "nonexistent-key") + + def test_error_on_non_integer_weight(self): + graph = two_node_graph() + graph.nodes[2]["weight"] = 1.5 + with pytest.raises(ValueError): + nx.algorithms.max_weight_clique(graph) + + def test_unaffected_by_self_loops(self): + graph = two_node_graph() + graph.add_edge(1, 1) + graph.add_edge(2, 2) + clique, weight = nx.algorithms.max_weight_clique(graph, "weight") + assert verify_clique(graph, clique, weight, 30, "weight") + graph = three_node_independent_set() + graph.add_edge(1, 1) + clique, weight = nx.algorithms.max_weight_clique(graph, "weight") + assert verify_clique(graph, clique, weight, 20, "weight") + + def test_30_node_prob(self): + G = nx.Graph() + G.add_nodes_from(range(1, 31)) + for i in range(1, 31): + G.nodes[i]["weight"] = i + 1 + # fmt: off + G.add_edges_from( + [ + (1, 12), (1, 13), (1, 15), (1, 16), (1, 18), (1, 19), (1, 20), + (1, 23), (1, 26), (1, 28), (1, 29), (1, 30), (2, 3), (2, 4), + (2, 5), (2, 8), (2, 9), (2, 10), (2, 14), (2, 17), (2, 18), + (2, 21), (2, 22), (2, 23), (2, 27), (3, 9), (3, 15), (3, 21), + (3, 22), (3, 23), (3, 24), (3, 27), (3, 28), (3, 29), (4, 5), + (4, 6), (4, 8), (4, 21), (4, 22), (4, 23), (4, 26), (4, 28), + (4, 30), (5, 6), (5, 8), (5, 9), (5, 13), (5, 14), (5, 15), + (5, 16), (5, 20), (5, 21), (5, 22), (5, 25), (5, 28), (5, 29), + (6, 7), (6, 8), (6, 13), (6, 17), (6, 18), (6, 19), (6, 24), + (6, 26), (6, 27), (6, 28), (6, 29), (7, 12), (7, 14), (7, 15), + (7, 16), (7, 17), (7, 20), (7, 25), (7, 27), (7, 29), (7, 30), + (8, 10), (8, 15), (8, 16), (8, 18), (8, 20), (8, 22), (8, 24), + (8, 26), (8, 27), (8, 28), (8, 30), (9, 11), (9, 12), (9, 13), + (9, 14), (9, 15), (9, 16), (9, 19), (9, 20), (9, 21), (9, 24), + (9, 30), (10, 12), (10, 15), (10, 18), (10, 19), (10, 20), + (10, 22), (10, 23), (10, 24), (10, 26), (10, 27), (10, 29), + (10, 30), (11, 13), (11, 15), (11, 16), (11, 17), (11, 18), + (11, 19), (11, 20), (11, 22), (11, 29), (11, 30), (12, 14), + (12, 17), (12, 18), (12, 19), (12, 20), (12, 21), (12, 23), + (12, 25), (12, 26), (12, 30), (13, 20), (13, 22), (13, 23), + (13, 24), (13, 30), (14, 16), (14, 20), (14, 21), (14, 22), + (14, 23), (14, 25), (14, 26), (14, 27), (14, 29), (14, 30), + (15, 17), (15, 18), (15, 20), (15, 21), (15, 26), (15, 27), + (15, 28), (16, 17), (16, 18), (16, 19), (16, 20), (16, 21), + (16, 29), (16, 30), (17, 18), (17, 21), (17, 22), (17, 25), + (17, 27), (17, 28), (17, 30), (18, 19), (18, 20), (18, 21), + (18, 22), (18, 23), (18, 24), (19, 20), (19, 22), (19, 23), + (19, 24), (19, 25), (19, 27), (19, 30), (20, 21), (20, 23), + (20, 24), (20, 26), (20, 28), (20, 29), (21, 23), (21, 26), + (21, 27), (21, 29), (22, 24), (22, 25), (22, 26), (22, 29), + (23, 25), (23, 30), (24, 25), (24, 26), (25, 27), (25, 29), + (26, 27), (26, 28), (26, 30), (28, 29), (29, 30), + ] + ) + # fmt: on + clique, weight = nx.algorithms.max_weight_clique(G) + assert verify_clique(G, clique, weight, 111, "weight") + + +# ############################ Utility functions ############################ +def verify_clique( + graph, clique, reported_clique_weight, expected_clique_weight, weight_accessor +): + for node1 in clique: + for node2 in clique: + if node1 == node2: + continue + if not graph.has_edge(node1, node2): + return False + + if weight_accessor is None: + clique_weight = len(clique) + else: + clique_weight = sum(graph.nodes[v]["weight"] for v in clique) + + if clique_weight != expected_clique_weight: + return False + if clique_weight != reported_clique_weight: + return False + + return True + + +# ############################ Graph Generation ############################ + + +def empty_graph(): + return nx.Graph() + + +def one_node_graph(): + graph = nx.Graph() + graph.add_nodes_from([1]) + graph.nodes[1]["weight"] = 10 + return graph + + +def two_node_graph(): + graph = nx.Graph() + graph.add_nodes_from([1, 2]) + graph.add_edges_from([(1, 2)]) + graph.nodes[1]["weight"] = 10 + graph.nodes[2]["weight"] = 20 + return graph + + +def three_node_clique(): + graph = nx.Graph() + graph.add_nodes_from([1, 2, 3]) + graph.add_edges_from([(1, 2), (1, 3), (2, 3)]) + graph.nodes[1]["weight"] = 10 + graph.nodes[2]["weight"] = 20 + graph.nodes[3]["weight"] = 5 + return graph + + +def three_node_independent_set(): + graph = nx.Graph() + graph.add_nodes_from([1, 2, 3]) + graph.nodes[1]["weight"] = 10 + graph.nodes[2]["weight"] = 20 + graph.nodes[3]["weight"] = 5 + return graph + + +def disconnected(): + graph = nx.Graph() + graph.add_edges_from([(1, 2), (2, 3), (4, 5), (5, 6)]) + graph.nodes[1]["weight"] = 10 + graph.nodes[2]["weight"] = 20 + graph.nodes[3]["weight"] = 5 + graph.nodes[4]["weight"] = 100 + graph.nodes[5]["weight"] = 200 + graph.nodes[6]["weight"] = 50 + return graph + + +# -------------------------------------------------------------------------- +# Basic tests for all strategies +# For each basic graph function, specify expected weight of max weight clique +# and expected size of maximum clique +TEST_CASES = { + empty_graph: (0, 0), + one_node_graph: (10, 1), + two_node_graph: (30, 2), + three_node_clique: (35, 3), + three_node_independent_set: (20, 1), + disconnected: (300, 2), +} diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_mis.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_mis.py new file mode 100644 index 0000000000000000000000000000000000000000..02be02d4c33f233d27d2838e5e3d361c4212c40b --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_mis.py @@ -0,0 +1,62 @@ +""" +Tests for maximal (not maximum) independent sets. + +""" + +import random + +import pytest + +import networkx as nx + + +def test_random_seed(): + G = nx.empty_graph(5) + assert nx.maximal_independent_set(G, seed=1) == [1, 0, 3, 2, 4] + + +@pytest.mark.parametrize("graph", [nx.complete_graph(5), nx.complete_graph(55)]) +def test_K5(graph): + """Maximal independent set for complete graphs""" + assert all(nx.maximal_independent_set(graph, [n]) == [n] for n in graph) + + +def test_exceptions(): + """Bad input should raise exception.""" + G = nx.florentine_families_graph() + pytest.raises(nx.NetworkXUnfeasible, nx.maximal_independent_set, G, ["Smith"]) + pytest.raises( + nx.NetworkXUnfeasible, nx.maximal_independent_set, G, ["Salviati", "Pazzi"] + ) + # MaximalIndependentSet is not implemented for directed graphs + pytest.raises(nx.NetworkXNotImplemented, nx.maximal_independent_set, nx.DiGraph(G)) + + +def test_florentine_family(): + G = nx.florentine_families_graph() + indep = nx.maximal_independent_set(G, ["Medici", "Bischeri"]) + assert set(indep) == { + "Medici", + "Bischeri", + "Castellani", + "Pazzi", + "Ginori", + "Lamberteschi", + } + + +def test_bipartite(): + G = nx.complete_bipartite_graph(12, 34) + indep = nx.maximal_independent_set(G, [4, 5, 9, 10]) + assert sorted(indep) == list(range(12)) + + +def test_random_graphs(): + """Generate 5 random graphs of different types and sizes and + make sure that all sets are independent and maximal.""" + for i in range(0, 50, 10): + G = nx.erdos_renyi_graph(i * 10 + 1, random.random()) + IS = nx.maximal_independent_set(G) + assert G.subgraph(IS).number_of_edges() == 0 + nbrs_of_MIS = set.union(*(set(G.neighbors(v)) for v in IS)) + assert all(v in nbrs_of_MIS for v in set(G.nodes()).difference(IS)) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_moral.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_moral.py new file mode 100644 index 0000000000000000000000000000000000000000..fc98c9729a95897857013ae22333e3b8c17202fb --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_moral.py @@ -0,0 +1,15 @@ +import networkx as nx +from networkx.algorithms.moral import moral_graph + + +def test_get_moral_graph(): + graph = nx.DiGraph() + graph.add_nodes_from([1, 2, 3, 4, 5, 6, 7]) + graph.add_edges_from([(1, 2), (3, 2), (4, 1), (4, 5), (6, 5), (7, 5)]) + H = moral_graph(graph) + assert not H.is_directed() + assert H.has_edge(1, 3) + assert H.has_edge(4, 6) + assert H.has_edge(6, 7) + assert H.has_edge(4, 7) + assert not H.has_edge(1, 5) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_node_classification.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_node_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..2e1fc79d48ae830625c3528f52e805d2e0d183ad --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_node_classification.py @@ -0,0 +1,140 @@ +import pytest + +pytest.importorskip("numpy") +pytest.importorskip("scipy") + +import networkx as nx +from networkx.algorithms import node_classification + + +class TestHarmonicFunction: + def test_path_graph(self): + G = nx.path_graph(4) + label_name = "label" + G.nodes[0][label_name] = "A" + G.nodes[3][label_name] = "B" + predicted = node_classification.harmonic_function(G, label_name=label_name) + assert predicted[0] == "A" + assert predicted[1] == "A" + assert predicted[2] == "B" + assert predicted[3] == "B" + + def test_no_labels(self): + with pytest.raises(nx.NetworkXError): + G = nx.path_graph(4) + node_classification.harmonic_function(G) + + def test_no_nodes(self): + with pytest.raises(nx.NetworkXError): + G = nx.Graph() + node_classification.harmonic_function(G) + + def test_no_edges(self): + with pytest.raises(nx.NetworkXError): + G = nx.Graph() + G.add_node(1) + G.add_node(2) + node_classification.harmonic_function(G) + + def test_digraph(self): + with pytest.raises(nx.NetworkXNotImplemented): + G = nx.DiGraph() + G.add_edge(0, 1) + G.add_edge(1, 2) + G.add_edge(2, 3) + label_name = "label" + G.nodes[0][label_name] = "A" + G.nodes[3][label_name] = "B" + node_classification.harmonic_function(G) + + def test_one_labeled_node(self): + G = nx.path_graph(4) + label_name = "label" + G.nodes[0][label_name] = "A" + predicted = node_classification.harmonic_function(G, label_name=label_name) + assert predicted[0] == "A" + assert predicted[1] == "A" + assert predicted[2] == "A" + assert predicted[3] == "A" + + def test_nodes_all_labeled(self): + G = nx.karate_club_graph() + label_name = "club" + predicted = node_classification.harmonic_function(G, label_name=label_name) + for i in range(len(G)): + assert predicted[i] == G.nodes[i][label_name] + + def test_labeled_nodes_are_not_changed(self): + G = nx.karate_club_graph() + label_name = "club" + label_removed = {0, 1, 2, 3, 4, 5, 6, 7} + for i in label_removed: + del G.nodes[i][label_name] + predicted = node_classification.harmonic_function(G, label_name=label_name) + label_not_removed = set(range(len(G))) - label_removed + for i in label_not_removed: + assert predicted[i] == G.nodes[i][label_name] + + +class TestLocalAndGlobalConsistency: + def test_path_graph(self): + G = nx.path_graph(4) + label_name = "label" + G.nodes[0][label_name] = "A" + G.nodes[3][label_name] = "B" + predicted = node_classification.local_and_global_consistency( + G, label_name=label_name + ) + assert predicted[0] == "A" + assert predicted[1] == "A" + assert predicted[2] == "B" + assert predicted[3] == "B" + + def test_no_labels(self): + with pytest.raises(nx.NetworkXError): + G = nx.path_graph(4) + node_classification.local_and_global_consistency(G) + + def test_no_nodes(self): + with pytest.raises(nx.NetworkXError): + G = nx.Graph() + node_classification.local_and_global_consistency(G) + + def test_no_edges(self): + with pytest.raises(nx.NetworkXError): + G = nx.Graph() + G.add_node(1) + G.add_node(2) + node_classification.local_and_global_consistency(G) + + def test_digraph(self): + with pytest.raises(nx.NetworkXNotImplemented): + G = nx.DiGraph() + G.add_edge(0, 1) + G.add_edge(1, 2) + G.add_edge(2, 3) + label_name = "label" + G.nodes[0][label_name] = "A" + G.nodes[3][label_name] = "B" + node_classification.harmonic_function(G) + + def test_one_labeled_node(self): + G = nx.path_graph(4) + label_name = "label" + G.nodes[0][label_name] = "A" + predicted = node_classification.local_and_global_consistency( + G, label_name=label_name + ) + assert predicted[0] == "A" + assert predicted[1] == "A" + assert predicted[2] == "A" + assert predicted[3] == "A" + + def test_nodes_all_labeled(self): + G = nx.karate_club_graph() + label_name = "club" + predicted = node_classification.local_and_global_consistency( + G, alpha=0, label_name=label_name + ) + for i in range(len(G)): + assert predicted[i] == G.nodes[i][label_name] diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_non_randomness.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_non_randomness.py new file mode 100644 index 0000000000000000000000000000000000000000..2f495be28d57db2368b9412e93e1d137f2cac86b --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_non_randomness.py @@ -0,0 +1,42 @@ +import pytest + +import networkx as nx + +np = pytest.importorskip("numpy") + + +@pytest.mark.parametrize( + "k, weight, expected", + [ + (None, None, 7.21), # infers 3 communities + (2, None, 11.7), + (None, "weight", 25.45), + (2, "weight", 38.8), + ], +) +def test_non_randomness(k, weight, expected): + G = nx.karate_club_graph() + np.testing.assert_almost_equal( + nx.non_randomness(G, k, weight)[0], expected, decimal=2 + ) + + +def test_non_connected(): + G = nx.Graph([(1, 2)]) + G.add_node(3) + with pytest.raises(nx.NetworkXException, match="Non connected"): + nx.non_randomness(G) + + +def test_self_loops(): + G = nx.Graph() + G.add_edge(1, 2) + G.add_edge(1, 1) + with pytest.raises(nx.NetworkXError, match="Graph must not contain self-loops"): + nx.non_randomness(G) + + +def test_empty_graph(): + G = nx.empty_graph(1) + with pytest.raises(nx.NetworkXError, match=".*not applicable to empty graphs"): + nx.non_randomness(G) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_planar_drawing.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_planar_drawing.py new file mode 100644 index 0000000000000000000000000000000000000000..a5de0e0324c49ab2e194a9d25ca712c2de1e4947 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_planar_drawing.py @@ -0,0 +1,274 @@ +import math + +import pytest + +import networkx as nx +from networkx.algorithms.planar_drawing import triangulate_embedding + + +def test_graph1(): + embedding_data = {0: [1, 2, 3], 1: [2, 0], 2: [3, 0, 1], 3: [2, 0]} + check_embedding_data(embedding_data) + + +def test_graph2(): + embedding_data = { + 0: [8, 6], + 1: [2, 6, 9], + 2: [8, 1, 7, 9, 6, 4], + 3: [9], + 4: [2], + 5: [6, 8], + 6: [9, 1, 0, 5, 2], + 7: [9, 2], + 8: [0, 2, 5], + 9: [1, 6, 2, 7, 3], + } + check_embedding_data(embedding_data) + + +def test_circle_graph(): + embedding_data = { + 0: [1, 9], + 1: [0, 2], + 2: [1, 3], + 3: [2, 4], + 4: [3, 5], + 5: [4, 6], + 6: [5, 7], + 7: [6, 8], + 8: [7, 9], + 9: [8, 0], + } + check_embedding_data(embedding_data) + + +def test_grid_graph(): + embedding_data = { + (0, 1): [(0, 0), (1, 1), (0, 2)], + (1, 2): [(1, 1), (2, 2), (0, 2)], + (0, 0): [(0, 1), (1, 0)], + (2, 1): [(2, 0), (2, 2), (1, 1)], + (1, 1): [(2, 1), (1, 2), (0, 1), (1, 0)], + (2, 0): [(1, 0), (2, 1)], + (2, 2): [(1, 2), (2, 1)], + (1, 0): [(0, 0), (2, 0), (1, 1)], + (0, 2): [(1, 2), (0, 1)], + } + check_embedding_data(embedding_data) + + +def test_one_node_graph(): + embedding_data = {0: []} + check_embedding_data(embedding_data) + + +def test_two_node_graph(): + embedding_data = {0: [1], 1: [0]} + check_embedding_data(embedding_data) + + +def test_three_node_graph(): + embedding_data = {0: [1, 2], 1: [0, 2], 2: [0, 1]} + check_embedding_data(embedding_data) + + +def test_multiple_component_graph1(): + embedding_data = {0: [], 1: []} + check_embedding_data(embedding_data) + + +def test_multiple_component_graph2(): + embedding_data = {0: [1, 2], 1: [0, 2], 2: [0, 1], 3: [4, 5], 4: [3, 5], 5: [3, 4]} + check_embedding_data(embedding_data) + + +def test_invalid_half_edge(): + with pytest.raises(nx.NetworkXException): + embedding_data = {1: [2, 3, 4], 2: [1, 3, 4], 3: [1, 2, 4], 4: [1, 2, 3]} + embedding = nx.PlanarEmbedding() + embedding.set_data(embedding_data) + nx.combinatorial_embedding_to_pos(embedding) + + +def test_triangulate_embedding1(): + embedding = nx.PlanarEmbedding() + embedding.add_node(1) + expected_embedding = {1: []} + check_triangulation(embedding, expected_embedding) + + +def test_triangulate_embedding2(): + embedding = nx.PlanarEmbedding() + embedding.connect_components(1, 2) + expected_embedding = {1: [2], 2: [1]} + check_triangulation(embedding, expected_embedding) + + +def check_triangulation(embedding, expected_embedding): + res_embedding, _ = triangulate_embedding(embedding, True) + assert ( + res_embedding.get_data() == expected_embedding + ), "Expected embedding incorrect" + res_embedding, _ = triangulate_embedding(embedding, False) + assert ( + res_embedding.get_data() == expected_embedding + ), "Expected embedding incorrect" + + +def check_embedding_data(embedding_data): + """Checks that the planar embedding of the input is correct""" + embedding = nx.PlanarEmbedding() + embedding.set_data(embedding_data) + pos_fully = nx.combinatorial_embedding_to_pos(embedding, False) + msg = "Planar drawing does not conform to the embedding (fully triangulation)" + assert planar_drawing_conforms_to_embedding(embedding, pos_fully), msg + check_edge_intersections(embedding, pos_fully) + pos_internally = nx.combinatorial_embedding_to_pos(embedding, True) + msg = "Planar drawing does not conform to the embedding (internal triangulation)" + assert planar_drawing_conforms_to_embedding(embedding, pos_internally), msg + check_edge_intersections(embedding, pos_internally) + + +def is_close(a, b, rel_tol=1e-09, abs_tol=0.0): + # Check if float numbers are basically equal, for python >=3.5 there is + # function for that in the standard library + return abs(a - b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol) + + +def point_in_between(a, b, p): + # checks if p is on the line between a and b + x1, y1 = a + x2, y2 = b + px, py = p + dist_1_2 = math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2) + dist_1_p = math.sqrt((x1 - px) ** 2 + (y1 - py) ** 2) + dist_2_p = math.sqrt((x2 - px) ** 2 + (y2 - py) ** 2) + return is_close(dist_1_p + dist_2_p, dist_1_2) + + +def check_edge_intersections(G, pos): + """Check all edges in G for intersections. + + Raises an exception if an intersection is found. + + Parameters + ---------- + G : NetworkX graph + pos : dict + Maps every node to a tuple (x, y) representing its position + + """ + for a, b in G.edges(): + for c, d in G.edges(): + # Check if end points are different + if a != c and b != d and b != c and a != d: + x1, y1 = pos[a] + x2, y2 = pos[b] + x3, y3 = pos[c] + x4, y4 = pos[d] + determinant = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4) + if determinant != 0: # the lines are not parallel + # calculate intersection point, see: + # https://en.wikipedia.org/wiki/Line%E2%80%93line_intersection + px = (x1 * y2 - y1 * x2) * (x3 - x4) - (x1 - x2) * ( + x3 * y4 - y3 * x4 + ) / determinant + py = (x1 * y2 - y1 * x2) * (y3 - y4) - (y1 - y2) * ( + x3 * y4 - y3 * x4 + ) / determinant + + # Check if intersection lies between the points + if point_in_between(pos[a], pos[b], (px, py)) and point_in_between( + pos[c], pos[d], (px, py) + ): + msg = f"There is an intersection at {px},{py}" + raise nx.NetworkXException(msg) + + # Check overlap + msg = "A node lies on a edge connecting two other nodes" + if ( + point_in_between(pos[a], pos[b], pos[c]) + or point_in_between(pos[a], pos[b], pos[d]) + or point_in_between(pos[c], pos[d], pos[a]) + or point_in_between(pos[c], pos[d], pos[b]) + ): + raise nx.NetworkXException(msg) + # No edge intersection found + + +class Vector: + """Compare vectors by their angle without loss of precision + + All vectors in direction [0, 1] are the smallest. + The vectors grow in clockwise direction. + """ + + __slots__ = ["x", "y", "node", "quadrant"] + + def __init__(self, x, y, node): + self.x = x + self.y = y + self.node = node + if self.x >= 0 and self.y > 0: + self.quadrant = 1 + elif self.x > 0 and self.y <= 0: + self.quadrant = 2 + elif self.x <= 0 and self.y < 0: + self.quadrant = 3 + else: + self.quadrant = 4 + + def __eq__(self, other): + return self.quadrant == other.quadrant and self.x * other.y == self.y * other.x + + def __lt__(self, other): + if self.quadrant < other.quadrant: + return True + elif self.quadrant > other.quadrant: + return False + else: + return self.x * other.y < self.y * other.x + + def __ne__(self, other): + return self != other + + def __le__(self, other): + return not other < self + + def __gt__(self, other): + return other < self + + def __ge__(self, other): + return not self < other + + +def planar_drawing_conforms_to_embedding(embedding, pos): + """Checks if pos conforms to the planar embedding + + Returns true iff the neighbors are actually oriented in the orientation + specified of the embedding + """ + for v in embedding: + nbr_vectors = [] + v_pos = pos[v] + for nbr in embedding[v]: + new_vector = Vector(pos[nbr][0] - v_pos[0], pos[nbr][1] - v_pos[1], nbr) + nbr_vectors.append(new_vector) + # Sort neighbors according to their phi angle + nbr_vectors.sort() + for idx, nbr_vector in enumerate(nbr_vectors): + cw_vector = nbr_vectors[(idx + 1) % len(nbr_vectors)] + ccw_vector = nbr_vectors[idx - 1] + if ( + embedding[v][nbr_vector.node]["cw"] != cw_vector.node + or embedding[v][nbr_vector.node]["ccw"] != ccw_vector.node + ): + return False + if cw_vector.node != nbr_vector.node and cw_vector == nbr_vector: + # Lines overlap + return False + if ccw_vector.node != nbr_vector.node and ccw_vector == nbr_vector: + # Lines overlap + return False + return True diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_planarity.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_planarity.py new file mode 100644 index 0000000000000000000000000000000000000000..99bcff4184a9d28b5d68c5787a4ca6664d78a9dd --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_planarity.py @@ -0,0 +1,535 @@ +import pytest + +import networkx as nx +from networkx.algorithms.planarity import ( + check_planarity_recursive, + get_counterexample, + get_counterexample_recursive, +) + + +class TestLRPlanarity: + """Nose Unit tests for the :mod:`networkx.algorithms.planarity` module. + + Tests three things: + 1. Check that the result is correct + (returns planar if and only if the graph is actually planar) + 2. In case a counter example is returned: Check if it is correct + 3. In case an embedding is returned: Check if its actually an embedding + """ + + @staticmethod + def check_graph(G, is_planar=None): + """Raises an exception if the lr_planarity check returns a wrong result + + Parameters + ---------- + G : NetworkX graph + is_planar : bool + The expected result of the planarity check. + If set to None only counter example or embedding are verified. + + """ + + # obtain results of planarity check + is_planar_lr, result = nx.check_planarity(G, True) + is_planar_lr_rec, result_rec = check_planarity_recursive(G, True) + + if is_planar is not None: + # set a message for the assert + if is_planar: + msg = "Wrong planarity check result. Should be planar." + else: + msg = "Wrong planarity check result. Should be non-planar." + + # check if the result is as expected + assert is_planar == is_planar_lr, msg + assert is_planar == is_planar_lr_rec, msg + + if is_planar_lr: + # check embedding + check_embedding(G, result) + check_embedding(G, result_rec) + else: + # check counter example + check_counterexample(G, result) + check_counterexample(G, result_rec) + + def test_simple_planar_graph(self): + e = [ + (1, 2), + (2, 3), + (3, 4), + (4, 6), + (6, 7), + (7, 1), + (1, 5), + (5, 2), + (2, 4), + (4, 5), + (5, 7), + ] + self.check_graph(nx.Graph(e), is_planar=True) + + def test_planar_with_selfloop(self): + e = [ + (1, 1), + (2, 2), + (3, 3), + (4, 4), + (5, 5), + (1, 2), + (1, 3), + (1, 5), + (2, 5), + (2, 4), + (3, 4), + (3, 5), + (4, 5), + ] + self.check_graph(nx.Graph(e), is_planar=True) + + def test_k3_3(self): + self.check_graph(nx.complete_bipartite_graph(3, 3), is_planar=False) + + def test_k5(self): + self.check_graph(nx.complete_graph(5), is_planar=False) + + def test_multiple_components_planar(self): + e = [(1, 2), (2, 3), (3, 1), (4, 5), (5, 6), (6, 4)] + self.check_graph(nx.Graph(e), is_planar=True) + + def test_multiple_components_non_planar(self): + G = nx.complete_graph(5) + # add another planar component to the non planar component + # G stays non planar + G.add_edges_from([(6, 7), (7, 8), (8, 6)]) + self.check_graph(G, is_planar=False) + + def test_non_planar_with_selfloop(self): + G = nx.complete_graph(5) + # add self loops + for i in range(5): + G.add_edge(i, i) + self.check_graph(G, is_planar=False) + + def test_non_planar1(self): + # tests a graph that has no subgraph directly isomorph to K5 or K3_3 + e = [ + (1, 5), + (1, 6), + (1, 7), + (2, 6), + (2, 3), + (3, 5), + (3, 7), + (4, 5), + (4, 6), + (4, 7), + ] + self.check_graph(nx.Graph(e), is_planar=False) + + def test_loop(self): + # test a graph with a selfloop + e = [(1, 2), (2, 2)] + G = nx.Graph(e) + self.check_graph(G, is_planar=True) + + def test_comp(self): + # test multiple component graph + e = [(1, 2), (3, 4)] + G = nx.Graph(e) + G.remove_edge(1, 2) + self.check_graph(G, is_planar=True) + + def test_goldner_harary(self): + # test goldner-harary graph (a maximal planar graph) + e = [ + (1, 2), + (1, 3), + (1, 4), + (1, 5), + (1, 7), + (1, 8), + (1, 10), + (1, 11), + (2, 3), + (2, 4), + (2, 6), + (2, 7), + (2, 9), + (2, 10), + (2, 11), + (3, 4), + (4, 5), + (4, 6), + (4, 7), + (5, 7), + (6, 7), + (7, 8), + (7, 9), + (7, 10), + (8, 10), + (9, 10), + (10, 11), + ] + G = nx.Graph(e) + self.check_graph(G, is_planar=True) + + def test_planar_multigraph(self): + G = nx.MultiGraph([(1, 2), (1, 2), (1, 2), (1, 2), (2, 3), (3, 1)]) + self.check_graph(G, is_planar=True) + + def test_non_planar_multigraph(self): + G = nx.MultiGraph(nx.complete_graph(5)) + G.add_edges_from([(1, 2)] * 5) + self.check_graph(G, is_planar=False) + + def test_planar_digraph(self): + G = nx.DiGraph([(1, 2), (2, 3), (2, 4), (4, 1), (4, 2), (1, 4), (3, 2)]) + self.check_graph(G, is_planar=True) + + def test_non_planar_digraph(self): + G = nx.DiGraph(nx.complete_graph(5)) + G.remove_edge(1, 2) + G.remove_edge(4, 1) + self.check_graph(G, is_planar=False) + + def test_single_component(self): + # Test a graph with only a single node + G = nx.Graph() + G.add_node(1) + self.check_graph(G, is_planar=True) + + def test_graph1(self): + G = nx.Graph( + [ + (3, 10), + (2, 13), + (1, 13), + (7, 11), + (0, 8), + (8, 13), + (0, 2), + (0, 7), + (0, 10), + (1, 7), + ] + ) + self.check_graph(G, is_planar=True) + + def test_graph2(self): + G = nx.Graph( + [ + (1, 2), + (4, 13), + (0, 13), + (4, 5), + (7, 10), + (1, 7), + (0, 3), + (2, 6), + (5, 6), + (7, 13), + (4, 8), + (0, 8), + (0, 9), + (2, 13), + (6, 7), + (3, 6), + (2, 8), + ] + ) + self.check_graph(G, is_planar=False) + + def test_graph3(self): + G = nx.Graph( + [ + (0, 7), + (3, 11), + (3, 4), + (8, 9), + (4, 11), + (1, 7), + (1, 13), + (1, 11), + (3, 5), + (5, 7), + (1, 3), + (0, 4), + (5, 11), + (5, 13), + ] + ) + self.check_graph(G, is_planar=False) + + def test_counterexample_planar(self): + with pytest.raises(nx.NetworkXException): + # Try to get a counterexample of a planar graph + G = nx.Graph() + G.add_node(1) + get_counterexample(G) + + def test_counterexample_planar_recursive(self): + with pytest.raises(nx.NetworkXException): + # Try to get a counterexample of a planar graph + G = nx.Graph() + G.add_node(1) + get_counterexample_recursive(G) + + def test_edge_removal_from_planar_embedding(self): + # PlanarEmbedding.check_structure() must succeed after edge removal + edges = ((0, 1), (1, 2), (2, 3), (3, 4), (4, 0), (0, 2), (0, 3)) + G = nx.Graph(edges) + cert, P = nx.check_planarity(G) + assert cert is True + P.remove_edge(0, 2) + self.check_graph(P, is_planar=True) + P.add_half_edge_ccw(1, 3, 2) + P.add_half_edge_cw(3, 1, 2) + self.check_graph(P, is_planar=True) + P.remove_edges_from(((0, 3), (1, 3))) + self.check_graph(P, is_planar=True) + + +def check_embedding(G, embedding): + """Raises an exception if the combinatorial embedding is not correct + + Parameters + ---------- + G : NetworkX graph + embedding : a dict mapping nodes to a list of edges + This specifies the ordering of the outgoing edges from a node for + a combinatorial embedding + + Notes + ----- + Checks the following things: + - The type of the embedding is correct + - The nodes and edges match the original graph + - Every half edge has its matching opposite half edge + - No intersections of edges (checked by Euler's formula) + """ + + if not isinstance(embedding, nx.PlanarEmbedding): + raise nx.NetworkXException("Bad embedding. Not of type nx.PlanarEmbedding") + + # Check structure + embedding.check_structure() + + # Check that graphs are equivalent + + assert set(G.nodes) == set( + embedding.nodes + ), "Bad embedding. Nodes don't match the original graph." + + # Check that the edges are equal + g_edges = set() + for edge in G.edges: + if edge[0] != edge[1]: + g_edges.add((edge[0], edge[1])) + g_edges.add((edge[1], edge[0])) + assert g_edges == set( + embedding.edges + ), "Bad embedding. Edges don't match the original graph." + + +def check_counterexample(G, sub_graph): + """Raises an exception if the counterexample is wrong. + + Parameters + ---------- + G : NetworkX graph + subdivision_nodes : set + A set of nodes inducing a subgraph as a counterexample + """ + # 1. Create the sub graph + sub_graph = nx.Graph(sub_graph) + + # 2. Remove self loops + for u in sub_graph: + if sub_graph.has_edge(u, u): + sub_graph.remove_edge(u, u) + + # keep track of nodes we might need to contract + contract = list(sub_graph) + + # 3. Contract Edges + while len(contract) > 0: + contract_node = contract.pop() + if contract_node not in sub_graph: + # Node was already contracted + continue + degree = sub_graph.degree[contract_node] + # Check if we can remove the node + if degree == 2: + # Get the two neighbors + neighbors = iter(sub_graph[contract_node]) + u = next(neighbors) + v = next(neighbors) + # Save nodes for later + contract.append(u) + contract.append(v) + # Contract edge + sub_graph.remove_node(contract_node) + sub_graph.add_edge(u, v) + + # 4. Check for isomorphism with K5 or K3_3 graphs + if len(sub_graph) == 5: + if not nx.is_isomorphic(nx.complete_graph(5), sub_graph): + raise nx.NetworkXException("Bad counter example.") + elif len(sub_graph) == 6: + if not nx.is_isomorphic(nx.complete_bipartite_graph(3, 3), sub_graph): + raise nx.NetworkXException("Bad counter example.") + else: + raise nx.NetworkXException("Bad counter example.") + + +class TestPlanarEmbeddingClass: + def test_add_half_edge(self): + embedding = nx.PlanarEmbedding() + embedding.add_half_edge(0, 1) + with pytest.raises( + nx.NetworkXException, match="Invalid clockwise reference node." + ): + embedding.add_half_edge(0, 2, cw=3) + with pytest.raises( + nx.NetworkXException, match="Invalid counterclockwise reference node." + ): + embedding.add_half_edge(0, 2, ccw=3) + with pytest.raises( + nx.NetworkXException, match="Only one of cw/ccw can be specified." + ): + embedding.add_half_edge(0, 2, cw=1, ccw=1) + with pytest.raises( + nx.NetworkXException, + match=( + r"Node already has out-half-edge\(s\), either" + " cw or ccw reference node required." + ), + ): + embedding.add_half_edge(0, 2) + # these should work + embedding.add_half_edge(0, 2, cw=1) + embedding.add_half_edge(0, 3, ccw=1) + assert sorted(embedding.edges(data=True)) == [ + (0, 1, {"ccw": 2, "cw": 3}), + (0, 2, {"cw": 1, "ccw": 3}), + (0, 3, {"cw": 2, "ccw": 1}), + ] + + def test_get_data(self): + embedding = self.get_star_embedding(4) + data = embedding.get_data() + data_cmp = {0: [3, 2, 1], 1: [0], 2: [0], 3: [0]} + assert data == data_cmp + + def test_edge_removal(self): + embedding = nx.PlanarEmbedding() + embedding.set_data( + { + 1: [2, 5, 7], + 2: [1, 3, 4, 5], + 3: [2, 4], + 4: [3, 6, 5, 2], + 5: [7, 1, 2, 4], + 6: [4, 7], + 7: [6, 1, 5], + } + ) + # remove_edges_from() calls remove_edge(), so both are tested here + embedding.remove_edges_from(((5, 4), (1, 5))) + embedding.check_structure() + embedding_expected = nx.PlanarEmbedding() + embedding_expected.set_data( + { + 1: [2, 7], + 2: [1, 3, 4, 5], + 3: [2, 4], + 4: [3, 6, 2], + 5: [7, 2], + 6: [4, 7], + 7: [6, 1, 5], + } + ) + assert nx.utils.graphs_equal(embedding, embedding_expected) + + def test_missing_edge_orientation(self): + embedding = nx.PlanarEmbedding({1: {2: {}}, 2: {1: {}}}) + with pytest.raises(nx.NetworkXException): + # Invalid structure because the orientation of the edge was not set + embedding.check_structure() + + def test_invalid_edge_orientation(self): + embedding = nx.PlanarEmbedding( + { + 1: {2: {"cw": 2, "ccw": 2}}, + 2: {1: {"cw": 1, "ccw": 1}}, + 1: {3: {}}, + 3: {1: {}}, + } + ) + with pytest.raises(nx.NetworkXException): + embedding.check_structure() + + def test_missing_half_edge(self): + embedding = nx.PlanarEmbedding() + embedding.add_half_edge(1, 2) + with pytest.raises(nx.NetworkXException): + # Invalid structure because other half edge is missing + embedding.check_structure() + + def test_not_fulfilling_euler_formula(self): + embedding = nx.PlanarEmbedding() + for i in range(5): + ref = None + for j in range(5): + if i != j: + embedding.add_half_edge(i, j, cw=ref) + ref = j + with pytest.raises(nx.NetworkXException): + embedding.check_structure() + + def test_missing_reference(self): + embedding = nx.PlanarEmbedding() + with pytest.raises(nx.NetworkXException, match="Invalid reference node."): + embedding.add_half_edge(1, 2, ccw=3) + + def test_connect_components(self): + embedding = nx.PlanarEmbedding() + embedding.connect_components(1, 2) + + def test_successful_face_traversal(self): + embedding = nx.PlanarEmbedding() + embedding.add_half_edge(1, 2) + embedding.add_half_edge(2, 1) + face = embedding.traverse_face(1, 2) + assert face == [1, 2] + + def test_unsuccessful_face_traversal(self): + embedding = nx.PlanarEmbedding( + {1: {2: {"cw": 3, "ccw": 2}}, 2: {1: {"cw": 3, "ccw": 1}}} + ) + with pytest.raises(nx.NetworkXException): + embedding.traverse_face(1, 2) + + def test_forbidden_methods(self): + embedding = nx.PlanarEmbedding() + embedding.add_node(42) # no exception + embedding.add_nodes_from([(23, 24)]) # no exception + with pytest.raises(NotImplementedError): + embedding.add_edge(1, 3) + with pytest.raises(NotImplementedError): + embedding.add_edges_from([(0, 2), (1, 4)]) + with pytest.raises(NotImplementedError): + embedding.add_weighted_edges_from([(0, 2, 350), (1, 4, 125)]) + + @staticmethod + def get_star_embedding(n): + embedding = nx.PlanarEmbedding() + ref = None + for i in range(1, n): + embedding.add_half_edge(0, i, cw=ref) + ref = i + embedding.add_half_edge(i, 0) + return embedding diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_polynomials.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_polynomials.py new file mode 100644 index 0000000000000000000000000000000000000000..a81d6a69551ead74d3335fda408111a0b580bf6a --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_polynomials.py @@ -0,0 +1,57 @@ +"""Unit tests for the :mod:`networkx.algorithms.polynomials` module.""" + +import pytest + +import networkx as nx + +sympy = pytest.importorskip("sympy") + + +# Mapping of input graphs to a string representation of their tutte polynomials +_test_tutte_graphs = { + nx.complete_graph(1): "1", + nx.complete_graph(4): "x**3 + 3*x**2 + 4*x*y + 2*x + y**3 + 3*y**2 + 2*y", + nx.cycle_graph(5): "x**4 + x**3 + x**2 + x + y", + nx.diamond_graph(): "x**3 + 2*x**2 + 2*x*y + x + y**2 + y", +} + +_test_chromatic_graphs = { + nx.complete_graph(1): "x", + nx.complete_graph(4): "x**4 - 6*x**3 + 11*x**2 - 6*x", + nx.cycle_graph(5): "x**5 - 5*x**4 + 10*x**3 - 10*x**2 + 4*x", + nx.diamond_graph(): "x**4 - 5*x**3 + 8*x**2 - 4*x", + nx.path_graph(5): "x**5 - 4*x**4 + 6*x**3 - 4*x**2 + x", +} + + +@pytest.mark.parametrize(("G", "expected"), _test_tutte_graphs.items()) +def test_tutte_polynomial(G, expected): + assert nx.tutte_polynomial(G).equals(expected) + + +@pytest.mark.parametrize("G", _test_tutte_graphs.keys()) +def test_tutte_polynomial_disjoint(G): + """Tutte polynomial factors into the Tutte polynomials of its components. + Verify this property with the disjoint union of two copies of the input graph. + """ + t_g = nx.tutte_polynomial(G) + H = nx.disjoint_union(G, G) + t_h = nx.tutte_polynomial(H) + assert sympy.simplify(t_g * t_g).equals(t_h) + + +@pytest.mark.parametrize(("G", "expected"), _test_chromatic_graphs.items()) +def test_chromatic_polynomial(G, expected): + assert nx.chromatic_polynomial(G).equals(expected) + + +@pytest.mark.parametrize("G", _test_chromatic_graphs.keys()) +def test_chromatic_polynomial_disjoint(G): + """Chromatic polynomial factors into the Chromatic polynomials of its + components. Verify this property with the disjoint union of two copies of + the input graph. + """ + x_g = nx.chromatic_polynomial(G) + H = nx.disjoint_union(G, G) + x_h = nx.chromatic_polynomial(H) + assert sympy.simplify(x_g * x_g).equals(x_h) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_reciprocity.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_reciprocity.py new file mode 100644 index 0000000000000000000000000000000000000000..e713bc4303f9bfea1199f01d8369c6bdab1a221f --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_reciprocity.py @@ -0,0 +1,37 @@ +import pytest + +import networkx as nx + + +class TestReciprocity: + # test overall reciprocity by passing whole graph + def test_reciprocity_digraph(self): + DG = nx.DiGraph([(1, 2), (2, 1)]) + reciprocity = nx.reciprocity(DG) + assert reciprocity == 1.0 + + # test empty graph's overall reciprocity which will throw an error + def test_overall_reciprocity_empty_graph(self): + with pytest.raises(nx.NetworkXError): + DG = nx.DiGraph() + nx.overall_reciprocity(DG) + + # test for reciprocity for a list of nodes + def test_reciprocity_graph_nodes(self): + DG = nx.DiGraph([(1, 2), (2, 3), (3, 2)]) + reciprocity = nx.reciprocity(DG, [1, 2]) + expected_reciprocity = {1: 0.0, 2: 0.6666666666666666} + assert reciprocity == expected_reciprocity + + # test for reciprocity for a single node + def test_reciprocity_graph_node(self): + DG = nx.DiGraph([(1, 2), (2, 3), (3, 2)]) + reciprocity = nx.reciprocity(DG, 2) + assert reciprocity == 0.6666666666666666 + + # test for reciprocity for an isolated node + def test_reciprocity_graph_isolated_nodes(self): + with pytest.raises(nx.NetworkXError): + DG = nx.DiGraph([(1, 2)]) + DG.add_node(4) + nx.reciprocity(DG, 4) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_regular.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_regular.py new file mode 100644 index 0000000000000000000000000000000000000000..a8b4c3a30de612f91b4739fd35bc9ba06ab292ce --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_regular.py @@ -0,0 +1,92 @@ +import pytest + +import networkx +import networkx as nx +import networkx.algorithms.regular as reg +import networkx.generators as gen + + +class TestKFactor: + def test_k_factor_trivial(self): + g = gen.cycle_graph(4) + f = reg.k_factor(g, 2) + assert g.edges == f.edges + + def test_k_factor1(self): + g = gen.grid_2d_graph(4, 4) + g_kf = reg.k_factor(g, 2) + for edge in g_kf.edges(): + assert g.has_edge(edge[0], edge[1]) + for _, degree in g_kf.degree(): + assert degree == 2 + + def test_k_factor2(self): + g = gen.complete_graph(6) + g_kf = reg.k_factor(g, 3) + for edge in g_kf.edges(): + assert g.has_edge(edge[0], edge[1]) + for _, degree in g_kf.degree(): + assert degree == 3 + + def test_k_factor3(self): + g = gen.grid_2d_graph(4, 4) + with pytest.raises(nx.NetworkXUnfeasible): + reg.k_factor(g, 3) + + def test_k_factor4(self): + g = gen.lattice.hexagonal_lattice_graph(4, 4) + # Perfect matching doesn't exist for 4,4 hexagonal lattice graph + with pytest.raises(nx.NetworkXUnfeasible): + reg.k_factor(g, 2) + + def test_k_factor5(self): + g = gen.complete_graph(6) + # small k to exercise SmallKGadget + g_kf = reg.k_factor(g, 2) + for edge in g_kf.edges(): + assert g.has_edge(edge[0], edge[1]) + for _, degree in g_kf.degree(): + assert degree == 2 + + +class TestIsRegular: + def test_is_regular1(self): + g = gen.cycle_graph(4) + assert reg.is_regular(g) + + def test_is_regular2(self): + g = gen.complete_graph(5) + assert reg.is_regular(g) + + def test_is_regular3(self): + g = gen.lollipop_graph(5, 5) + assert not reg.is_regular(g) + + def test_is_regular4(self): + g = nx.DiGraph() + g.add_edges_from([(0, 1), (1, 2), (2, 0)]) + assert reg.is_regular(g) + + +def test_is_regular_empty_graph_raises(): + G = nx.Graph() + with pytest.raises(nx.NetworkXPointlessConcept, match="Graph has no nodes"): + nx.is_regular(G) + + +class TestIsKRegular: + def test_is_k_regular1(self): + g = gen.cycle_graph(4) + assert reg.is_k_regular(g, 2) + assert not reg.is_k_regular(g, 3) + + def test_is_k_regular2(self): + g = gen.complete_graph(5) + assert reg.is_k_regular(g, 4) + assert not reg.is_k_regular(g, 3) + assert not reg.is_k_regular(g, 6) + + def test_is_k_regular3(self): + g = gen.lollipop_graph(5, 5) + assert not reg.is_k_regular(g, 5) + assert not reg.is_k_regular(g, 6) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_richclub.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_richclub.py new file mode 100644 index 0000000000000000000000000000000000000000..1bdb66847fdfe5d3e6ad398aa76279b85b2c811a --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_richclub.py @@ -0,0 +1,149 @@ +import pytest + +import networkx as nx + + +def test_richclub(): + G = nx.Graph([(0, 1), (0, 2), (1, 2), (1, 3), (1, 4), (4, 5)]) + rc = nx.richclub.rich_club_coefficient(G, normalized=False) + assert rc == {0: 12.0 / 30, 1: 8.0 / 12} + + # test single value + rc0 = nx.richclub.rich_club_coefficient(G, normalized=False)[0] + assert rc0 == 12.0 / 30.0 + + +def test_richclub_seed(): + G = nx.Graph([(0, 1), (0, 2), (1, 2), (1, 3), (1, 4), (4, 5)]) + rcNorm = nx.richclub.rich_club_coefficient(G, Q=2, seed=1) + assert rcNorm == {0: 1.0, 1: 1.0} + + +def test_richclub_normalized(): + G = nx.Graph([(0, 1), (0, 2), (1, 2), (1, 3), (1, 4), (4, 5)]) + rcNorm = nx.richclub.rich_club_coefficient(G, Q=2, seed=42) + assert rcNorm == {0: 1.0, 1: 1.0} + + +def test_richclub2(): + T = nx.balanced_tree(2, 10) + rc = nx.richclub.rich_club_coefficient(T, normalized=False) + assert rc == { + 0: 4092 / (2047 * 2046.0), + 1: (2044.0 / (1023 * 1022)), + 2: (2040.0 / (1022 * 1021)), + } + + +def test_richclub3(): + # tests edgecase + G = nx.karate_club_graph() + rc = nx.rich_club_coefficient(G, normalized=False) + assert rc == { + 0: 156.0 / 1122, + 1: 154.0 / 1056, + 2: 110.0 / 462, + 3: 78.0 / 240, + 4: 44.0 / 90, + 5: 22.0 / 42, + 6: 10.0 / 20, + 7: 10.0 / 20, + 8: 10.0 / 20, + 9: 6.0 / 12, + 10: 2.0 / 6, + 11: 2.0 / 6, + 12: 0.0, + 13: 0.0, + 14: 0.0, + 15: 0.0, + } + + +def test_richclub4(): + G = nx.Graph() + G.add_edges_from( + [(0, 1), (0, 2), (0, 3), (0, 4), (4, 5), (5, 9), (6, 9), (7, 9), (8, 9)] + ) + rc = nx.rich_club_coefficient(G, normalized=False) + assert rc == {0: 18 / 90.0, 1: 6 / 12.0, 2: 0.0, 3: 0.0} + + +def test_richclub_exception(): + with pytest.raises(nx.NetworkXNotImplemented): + G = nx.DiGraph() + nx.rich_club_coefficient(G) + + +def test_rich_club_exception2(): + with pytest.raises(nx.NetworkXNotImplemented): + G = nx.MultiGraph() + nx.rich_club_coefficient(G) + + +def test_rich_club_selfloop(): + G = nx.Graph() # or DiGraph, MultiGraph, MultiDiGraph, etc + G.add_edge(1, 1) # self loop + G.add_edge(1, 2) + with pytest.raises( + Exception, + match="rich_club_coefficient is not implemented for " "graphs with self loops.", + ): + nx.rich_club_coefficient(G) + + +def test_rich_club_leq_3_nodes_unnormalized(): + # edgeless graphs upto 3 nodes + G = nx.Graph() + rc = nx.rich_club_coefficient(G, normalized=False) + assert rc == {} + + for i in range(3): + G.add_node(i) + rc = nx.rich_club_coefficient(G, normalized=False) + assert rc == {} + + # 2 nodes, single edge + G = nx.Graph() + G.add_edge(0, 1) + rc = nx.rich_club_coefficient(G, normalized=False) + assert rc == {0: 1} + + # 3 nodes, single edge + G = nx.Graph() + G.add_nodes_from([0, 1, 2]) + G.add_edge(0, 1) + rc = nx.rich_club_coefficient(G, normalized=False) + assert rc == {0: 1} + + # 3 nodes, 2 edges + G.add_edge(1, 2) + rc = nx.rich_club_coefficient(G, normalized=False) + assert rc == {0: 2 / 3} + + # 3 nodes, 3 edges + G.add_edge(0, 2) + rc = nx.rich_club_coefficient(G, normalized=False) + assert rc == {0: 1, 1: 1} + + +def test_rich_club_leq_3_nodes_normalized(): + G = nx.Graph() + with pytest.raises( + nx.exception.NetworkXError, + match="Graph has fewer than four nodes", + ): + rc = nx.rich_club_coefficient(G, normalized=True) + + for i in range(3): + G.add_node(i) + with pytest.raises( + nx.exception.NetworkXError, + match="Graph has fewer than four nodes", + ): + rc = nx.rich_club_coefficient(G, normalized=True) + + +# def test_richclub2_normalized(): +# T = nx.balanced_tree(2,10) +# rcNorm = nx.richclub.rich_club_coefficient(T,Q=2) +# assert_true(rcNorm[0] ==1.0 and rcNorm[1] < 0.9 and rcNorm[2] < 0.9) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_similarity.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_similarity.py new file mode 100644 index 0000000000000000000000000000000000000000..3836ccfe182fd58e96c2e0212e8aca55d7668b9d --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_similarity.py @@ -0,0 +1,946 @@ +import pytest + +import networkx as nx +from networkx.algorithms.similarity import ( + graph_edit_distance, + optimal_edit_paths, + optimize_graph_edit_distance, +) +from networkx.generators.classic import ( + circular_ladder_graph, + cycle_graph, + path_graph, + wheel_graph, +) + + +def nmatch(n1, n2): + return n1 == n2 + + +def ematch(e1, e2): + return e1 == e2 + + +def getCanonical(): + G = nx.Graph() + G.add_node("A", label="A") + G.add_node("B", label="B") + G.add_node("C", label="C") + G.add_node("D", label="D") + G.add_edge("A", "B", label="a-b") + G.add_edge("B", "C", label="b-c") + G.add_edge("B", "D", label="b-d") + return G + + +class TestSimilarity: + @classmethod + def setup_class(cls): + global np + np = pytest.importorskip("numpy") + pytest.importorskip("scipy") + + def test_graph_edit_distance_roots_and_timeout(self): + G0 = nx.star_graph(5) + G1 = G0.copy() + pytest.raises(ValueError, graph_edit_distance, G0, G1, roots=[2]) + pytest.raises(ValueError, graph_edit_distance, G0, G1, roots=[2, 3, 4]) + pytest.raises(nx.NodeNotFound, graph_edit_distance, G0, G1, roots=(9, 3)) + pytest.raises(nx.NodeNotFound, graph_edit_distance, G0, G1, roots=(3, 9)) + pytest.raises(nx.NodeNotFound, graph_edit_distance, G0, G1, roots=(9, 9)) + assert graph_edit_distance(G0, G1, roots=(1, 2)) == 0 + assert graph_edit_distance(G0, G1, roots=(0, 1)) == 8 + assert graph_edit_distance(G0, G1, roots=(1, 2), timeout=5) == 0 + assert graph_edit_distance(G0, G1, roots=(0, 1), timeout=5) == 8 + assert graph_edit_distance(G0, G1, roots=(0, 1), timeout=0.0001) is None + # test raise on 0 timeout + pytest.raises(nx.NetworkXError, graph_edit_distance, G0, G1, timeout=0) + + def test_graph_edit_distance(self): + G0 = nx.Graph() + G1 = path_graph(6) + G2 = cycle_graph(6) + G3 = wheel_graph(7) + + assert graph_edit_distance(G0, G0) == 0 + assert graph_edit_distance(G0, G1) == 11 + assert graph_edit_distance(G1, G0) == 11 + assert graph_edit_distance(G0, G2) == 12 + assert graph_edit_distance(G2, G0) == 12 + assert graph_edit_distance(G0, G3) == 19 + assert graph_edit_distance(G3, G0) == 19 + + assert graph_edit_distance(G1, G1) == 0 + assert graph_edit_distance(G1, G2) == 1 + assert graph_edit_distance(G2, G1) == 1 + assert graph_edit_distance(G1, G3) == 8 + assert graph_edit_distance(G3, G1) == 8 + + assert graph_edit_distance(G2, G2) == 0 + assert graph_edit_distance(G2, G3) == 7 + assert graph_edit_distance(G3, G2) == 7 + + assert graph_edit_distance(G3, G3) == 0 + + def test_graph_edit_distance_node_match(self): + G1 = cycle_graph(5) + G2 = cycle_graph(5) + for n, attr in G1.nodes.items(): + attr["color"] = "red" if n % 2 == 0 else "blue" + for n, attr in G2.nodes.items(): + attr["color"] = "red" if n % 2 == 1 else "blue" + assert graph_edit_distance(G1, G2) == 0 + assert ( + graph_edit_distance( + G1, G2, node_match=lambda n1, n2: n1["color"] == n2["color"] + ) + == 1 + ) + + def test_graph_edit_distance_edge_match(self): + G1 = path_graph(6) + G2 = path_graph(6) + for e, attr in G1.edges.items(): + attr["color"] = "red" if min(e) % 2 == 0 else "blue" + for e, attr in G2.edges.items(): + attr["color"] = "red" if min(e) // 3 == 0 else "blue" + assert graph_edit_distance(G1, G2) == 0 + assert ( + graph_edit_distance( + G1, G2, edge_match=lambda e1, e2: e1["color"] == e2["color"] + ) + == 2 + ) + + def test_graph_edit_distance_node_cost(self): + G1 = path_graph(6) + G2 = path_graph(6) + for n, attr in G1.nodes.items(): + attr["color"] = "red" if n % 2 == 0 else "blue" + for n, attr in G2.nodes.items(): + attr["color"] = "red" if n % 2 == 1 else "blue" + + def node_subst_cost(uattr, vattr): + if uattr["color"] == vattr["color"]: + return 1 + else: + return 10 + + def node_del_cost(attr): + if attr["color"] == "blue": + return 20 + else: + return 50 + + def node_ins_cost(attr): + if attr["color"] == "blue": + return 40 + else: + return 100 + + assert ( + graph_edit_distance( + G1, + G2, + node_subst_cost=node_subst_cost, + node_del_cost=node_del_cost, + node_ins_cost=node_ins_cost, + ) + == 6 + ) + + def test_graph_edit_distance_edge_cost(self): + G1 = path_graph(6) + G2 = path_graph(6) + for e, attr in G1.edges.items(): + attr["color"] = "red" if min(e) % 2 == 0 else "blue" + for e, attr in G2.edges.items(): + attr["color"] = "red" if min(e) // 3 == 0 else "blue" + + def edge_subst_cost(gattr, hattr): + if gattr["color"] == hattr["color"]: + return 0.01 + else: + return 0.1 + + def edge_del_cost(attr): + if attr["color"] == "blue": + return 0.2 + else: + return 0.5 + + def edge_ins_cost(attr): + if attr["color"] == "blue": + return 0.4 + else: + return 1.0 + + assert ( + graph_edit_distance( + G1, + G2, + edge_subst_cost=edge_subst_cost, + edge_del_cost=edge_del_cost, + edge_ins_cost=edge_ins_cost, + ) + == 0.23 + ) + + def test_graph_edit_distance_upper_bound(self): + G1 = circular_ladder_graph(2) + G2 = circular_ladder_graph(6) + assert graph_edit_distance(G1, G2, upper_bound=5) is None + assert graph_edit_distance(G1, G2, upper_bound=24) == 22 + assert graph_edit_distance(G1, G2) == 22 + + def test_optimal_edit_paths(self): + G1 = path_graph(3) + G2 = cycle_graph(3) + paths, cost = optimal_edit_paths(G1, G2) + assert cost == 1 + assert len(paths) == 6 + + def canonical(vertex_path, edge_path): + return ( + tuple(sorted(vertex_path)), + tuple(sorted(edge_path, key=lambda x: (None in x, x))), + ) + + expected_paths = [ + ( + [(0, 0), (1, 1), (2, 2)], + [((0, 1), (0, 1)), ((1, 2), (1, 2)), (None, (0, 2))], + ), + ( + [(0, 0), (1, 2), (2, 1)], + [((0, 1), (0, 2)), ((1, 2), (1, 2)), (None, (0, 1))], + ), + ( + [(0, 1), (1, 0), (2, 2)], + [((0, 1), (0, 1)), ((1, 2), (0, 2)), (None, (1, 2))], + ), + ( + [(0, 1), (1, 2), (2, 0)], + [((0, 1), (1, 2)), ((1, 2), (0, 2)), (None, (0, 1))], + ), + ( + [(0, 2), (1, 0), (2, 1)], + [((0, 1), (0, 2)), ((1, 2), (0, 1)), (None, (1, 2))], + ), + ( + [(0, 2), (1, 1), (2, 0)], + [((0, 1), (1, 2)), ((1, 2), (0, 1)), (None, (0, 2))], + ), + ] + assert {canonical(*p) for p in paths} == {canonical(*p) for p in expected_paths} + + def test_optimize_graph_edit_distance(self): + G1 = circular_ladder_graph(2) + G2 = circular_ladder_graph(6) + bestcost = 1000 + for cost in optimize_graph_edit_distance(G1, G2): + assert cost < bestcost + bestcost = cost + assert bestcost == 22 + + # def test_graph_edit_distance_bigger(self): + # G1 = circular_ladder_graph(12) + # G2 = circular_ladder_graph(16) + # assert_equal(graph_edit_distance(G1, G2), 22) + + def test_selfloops(self): + G0 = nx.Graph() + G1 = nx.Graph() + G1.add_edges_from((("A", "A"), ("A", "B"))) + G2 = nx.Graph() + G2.add_edges_from((("A", "B"), ("B", "B"))) + G3 = nx.Graph() + G3.add_edges_from((("A", "A"), ("A", "B"), ("B", "B"))) + + assert graph_edit_distance(G0, G0) == 0 + assert graph_edit_distance(G0, G1) == 4 + assert graph_edit_distance(G1, G0) == 4 + assert graph_edit_distance(G0, G2) == 4 + assert graph_edit_distance(G2, G0) == 4 + assert graph_edit_distance(G0, G3) == 5 + assert graph_edit_distance(G3, G0) == 5 + + assert graph_edit_distance(G1, G1) == 0 + assert graph_edit_distance(G1, G2) == 0 + assert graph_edit_distance(G2, G1) == 0 + assert graph_edit_distance(G1, G3) == 1 + assert graph_edit_distance(G3, G1) == 1 + + assert graph_edit_distance(G2, G2) == 0 + assert graph_edit_distance(G2, G3) == 1 + assert graph_edit_distance(G3, G2) == 1 + + assert graph_edit_distance(G3, G3) == 0 + + def test_digraph(self): + G0 = nx.DiGraph() + G1 = nx.DiGraph() + G1.add_edges_from((("A", "B"), ("B", "C"), ("C", "D"), ("D", "A"))) + G2 = nx.DiGraph() + G2.add_edges_from((("A", "B"), ("B", "C"), ("C", "D"), ("A", "D"))) + G3 = nx.DiGraph() + G3.add_edges_from((("A", "B"), ("A", "C"), ("B", "D"), ("C", "D"))) + + assert graph_edit_distance(G0, G0) == 0 + assert graph_edit_distance(G0, G1) == 8 + assert graph_edit_distance(G1, G0) == 8 + assert graph_edit_distance(G0, G2) == 8 + assert graph_edit_distance(G2, G0) == 8 + assert graph_edit_distance(G0, G3) == 8 + assert graph_edit_distance(G3, G0) == 8 + + assert graph_edit_distance(G1, G1) == 0 + assert graph_edit_distance(G1, G2) == 2 + assert graph_edit_distance(G2, G1) == 2 + assert graph_edit_distance(G1, G3) == 4 + assert graph_edit_distance(G3, G1) == 4 + + assert graph_edit_distance(G2, G2) == 0 + assert graph_edit_distance(G2, G3) == 2 + assert graph_edit_distance(G3, G2) == 2 + + assert graph_edit_distance(G3, G3) == 0 + + def test_multigraph(self): + G0 = nx.MultiGraph() + G1 = nx.MultiGraph() + G1.add_edges_from((("A", "B"), ("B", "C"), ("A", "C"))) + G2 = nx.MultiGraph() + G2.add_edges_from((("A", "B"), ("B", "C"), ("B", "C"), ("A", "C"))) + G3 = nx.MultiGraph() + G3.add_edges_from((("A", "B"), ("B", "C"), ("A", "C"), ("A", "C"), ("A", "C"))) + + assert graph_edit_distance(G0, G0) == 0 + assert graph_edit_distance(G0, G1) == 6 + assert graph_edit_distance(G1, G0) == 6 + assert graph_edit_distance(G0, G2) == 7 + assert graph_edit_distance(G2, G0) == 7 + assert graph_edit_distance(G0, G3) == 8 + assert graph_edit_distance(G3, G0) == 8 + + assert graph_edit_distance(G1, G1) == 0 + assert graph_edit_distance(G1, G2) == 1 + assert graph_edit_distance(G2, G1) == 1 + assert graph_edit_distance(G1, G3) == 2 + assert graph_edit_distance(G3, G1) == 2 + + assert graph_edit_distance(G2, G2) == 0 + assert graph_edit_distance(G2, G3) == 1 + assert graph_edit_distance(G3, G2) == 1 + + assert graph_edit_distance(G3, G3) == 0 + + def test_multidigraph(self): + G1 = nx.MultiDiGraph() + G1.add_edges_from( + ( + ("hardware", "kernel"), + ("kernel", "hardware"), + ("kernel", "userspace"), + ("userspace", "kernel"), + ) + ) + G2 = nx.MultiDiGraph() + G2.add_edges_from( + ( + ("winter", "spring"), + ("spring", "summer"), + ("summer", "autumn"), + ("autumn", "winter"), + ) + ) + + assert graph_edit_distance(G1, G2) == 5 + assert graph_edit_distance(G2, G1) == 5 + + # by https://github.com/jfbeaumont + def testCopy(self): + G = nx.Graph() + G.add_node("A", label="A") + G.add_node("B", label="B") + G.add_edge("A", "B", label="a-b") + assert ( + graph_edit_distance(G, G.copy(), node_match=nmatch, edge_match=ematch) == 0 + ) + + def testSame(self): + G1 = nx.Graph() + G1.add_node("A", label="A") + G1.add_node("B", label="B") + G1.add_edge("A", "B", label="a-b") + G2 = nx.Graph() + G2.add_node("A", label="A") + G2.add_node("B", label="B") + G2.add_edge("A", "B", label="a-b") + assert graph_edit_distance(G1, G2, node_match=nmatch, edge_match=ematch) == 0 + + def testOneEdgeLabelDiff(self): + G1 = nx.Graph() + G1.add_node("A", label="A") + G1.add_node("B", label="B") + G1.add_edge("A", "B", label="a-b") + G2 = nx.Graph() + G2.add_node("A", label="A") + G2.add_node("B", label="B") + G2.add_edge("A", "B", label="bad") + assert graph_edit_distance(G1, G2, node_match=nmatch, edge_match=ematch) == 1 + + def testOneNodeLabelDiff(self): + G1 = nx.Graph() + G1.add_node("A", label="A") + G1.add_node("B", label="B") + G1.add_edge("A", "B", label="a-b") + G2 = nx.Graph() + G2.add_node("A", label="Z") + G2.add_node("B", label="B") + G2.add_edge("A", "B", label="a-b") + assert graph_edit_distance(G1, G2, node_match=nmatch, edge_match=ematch) == 1 + + def testOneExtraNode(self): + G1 = nx.Graph() + G1.add_node("A", label="A") + G1.add_node("B", label="B") + G1.add_edge("A", "B", label="a-b") + G2 = nx.Graph() + G2.add_node("A", label="A") + G2.add_node("B", label="B") + G2.add_edge("A", "B", label="a-b") + G2.add_node("C", label="C") + assert graph_edit_distance(G1, G2, node_match=nmatch, edge_match=ematch) == 1 + + def testOneExtraEdge(self): + G1 = nx.Graph() + G1.add_node("A", label="A") + G1.add_node("B", label="B") + G1.add_node("C", label="C") + G1.add_node("C", label="C") + G1.add_edge("A", "B", label="a-b") + G2 = nx.Graph() + G2.add_node("A", label="A") + G2.add_node("B", label="B") + G2.add_node("C", label="C") + G2.add_edge("A", "B", label="a-b") + G2.add_edge("A", "C", label="a-c") + assert graph_edit_distance(G1, G2, node_match=nmatch, edge_match=ematch) == 1 + + def testOneExtraNodeAndEdge(self): + G1 = nx.Graph() + G1.add_node("A", label="A") + G1.add_node("B", label="B") + G1.add_edge("A", "B", label="a-b") + G2 = nx.Graph() + G2.add_node("A", label="A") + G2.add_node("B", label="B") + G2.add_node("C", label="C") + G2.add_edge("A", "B", label="a-b") + G2.add_edge("A", "C", label="a-c") + assert graph_edit_distance(G1, G2, node_match=nmatch, edge_match=ematch) == 2 + + def testGraph1(self): + G1 = getCanonical() + G2 = nx.Graph() + G2.add_node("A", label="A") + G2.add_node("B", label="B") + G2.add_node("D", label="D") + G2.add_node("E", label="E") + G2.add_edge("A", "B", label="a-b") + G2.add_edge("B", "D", label="b-d") + G2.add_edge("D", "E", label="d-e") + assert graph_edit_distance(G1, G2, node_match=nmatch, edge_match=ematch) == 3 + + def testGraph2(self): + G1 = getCanonical() + G2 = nx.Graph() + G2.add_node("A", label="A") + G2.add_node("B", label="B") + G2.add_node("C", label="C") + G2.add_node("D", label="D") + G2.add_node("E", label="E") + G2.add_edge("A", "B", label="a-b") + G2.add_edge("B", "C", label="b-c") + G2.add_edge("C", "D", label="c-d") + G2.add_edge("C", "E", label="c-e") + assert graph_edit_distance(G1, G2, node_match=nmatch, edge_match=ematch) == 4 + + def testGraph3(self): + G1 = getCanonical() + G2 = nx.Graph() + G2.add_node("A", label="A") + G2.add_node("B", label="B") + G2.add_node("C", label="C") + G2.add_node("D", label="D") + G2.add_node("E", label="E") + G2.add_node("F", label="F") + G2.add_node("G", label="G") + G2.add_edge("A", "C", label="a-c") + G2.add_edge("A", "D", label="a-d") + G2.add_edge("D", "E", label="d-e") + G2.add_edge("D", "F", label="d-f") + G2.add_edge("D", "G", label="d-g") + G2.add_edge("E", "B", label="e-b") + assert graph_edit_distance(G1, G2, node_match=nmatch, edge_match=ematch) == 12 + + def testGraph4(self): + G1 = getCanonical() + G2 = nx.Graph() + G2.add_node("A", label="A") + G2.add_node("B", label="B") + G2.add_node("C", label="C") + G2.add_node("D", label="D") + G2.add_edge("A", "B", label="a-b") + G2.add_edge("B", "C", label="b-c") + G2.add_edge("C", "D", label="c-d") + assert graph_edit_distance(G1, G2, node_match=nmatch, edge_match=ematch) == 2 + + def testGraph4_a(self): + G1 = getCanonical() + G2 = nx.Graph() + G2.add_node("A", label="A") + G2.add_node("B", label="B") + G2.add_node("C", label="C") + G2.add_node("D", label="D") + G2.add_edge("A", "B", label="a-b") + G2.add_edge("B", "C", label="b-c") + G2.add_edge("A", "D", label="a-d") + assert graph_edit_distance(G1, G2, node_match=nmatch, edge_match=ematch) == 2 + + def testGraph4_b(self): + G1 = getCanonical() + G2 = nx.Graph() + G2.add_node("A", label="A") + G2.add_node("B", label="B") + G2.add_node("C", label="C") + G2.add_node("D", label="D") + G2.add_edge("A", "B", label="a-b") + G2.add_edge("B", "C", label="b-c") + G2.add_edge("B", "D", label="bad") + assert graph_edit_distance(G1, G2, node_match=nmatch, edge_match=ematch) == 1 + + # note: nx.simrank_similarity_numpy not included because returns np.array + simrank_algs = [ + nx.simrank_similarity, + nx.algorithms.similarity._simrank_similarity_python, + ] + + @pytest.mark.parametrize("simrank_similarity", simrank_algs) + def test_simrank_no_source_no_target(self, simrank_similarity): + G = nx.cycle_graph(5) + expected = { + 0: { + 0: 1, + 1: 0.3951219505902448, + 2: 0.5707317069281646, + 3: 0.5707317069281646, + 4: 0.3951219505902449, + }, + 1: { + 0: 0.3951219505902448, + 1: 1, + 2: 0.3951219505902449, + 3: 0.5707317069281646, + 4: 0.5707317069281646, + }, + 2: { + 0: 0.5707317069281646, + 1: 0.3951219505902449, + 2: 1, + 3: 0.3951219505902449, + 4: 0.5707317069281646, + }, + 3: { + 0: 0.5707317069281646, + 1: 0.5707317069281646, + 2: 0.3951219505902449, + 3: 1, + 4: 0.3951219505902449, + }, + 4: { + 0: 0.3951219505902449, + 1: 0.5707317069281646, + 2: 0.5707317069281646, + 3: 0.3951219505902449, + 4: 1, + }, + } + actual = simrank_similarity(G) + for k, v in expected.items(): + assert v == pytest.approx(actual[k], abs=1e-2) + + # For a DiGraph test, use the first graph from the paper cited in + # the docs: https://dl.acm.org/doi/pdf/10.1145/775047.775126 + G = nx.DiGraph() + G.add_node(0, label="Univ") + G.add_node(1, label="ProfA") + G.add_node(2, label="ProfB") + G.add_node(3, label="StudentA") + G.add_node(4, label="StudentB") + G.add_edges_from([(0, 1), (0, 2), (1, 3), (2, 4), (4, 2), (3, 0)]) + + expected = { + 0: {0: 1, 1: 0.0, 2: 0.1323363991265798, 3: 0.0, 4: 0.03387811817640443}, + 1: {0: 0.0, 1: 1, 2: 0.4135512472705618, 3: 0.0, 4: 0.10586911930126384}, + 2: { + 0: 0.1323363991265798, + 1: 0.4135512472705618, + 2: 1, + 3: 0.04234764772050554, + 4: 0.08822426608438655, + }, + 3: {0: 0.0, 1: 0.0, 2: 0.04234764772050554, 3: 1, 4: 0.3308409978164495}, + 4: { + 0: 0.03387811817640443, + 1: 0.10586911930126384, + 2: 0.08822426608438655, + 3: 0.3308409978164495, + 4: 1, + }, + } + # Use the importance_factor from the paper to get the same numbers. + actual = simrank_similarity(G, importance_factor=0.8) + for k, v in expected.items(): + assert v == pytest.approx(actual[k], abs=1e-2) + + @pytest.mark.parametrize("simrank_similarity", simrank_algs) + def test_simrank_source_no_target(self, simrank_similarity): + G = nx.cycle_graph(5) + expected = { + 0: 1, + 1: 0.3951219505902448, + 2: 0.5707317069281646, + 3: 0.5707317069281646, + 4: 0.3951219505902449, + } + actual = simrank_similarity(G, source=0) + assert expected == pytest.approx(actual, abs=1e-2) + + # For a DiGraph test, use the first graph from the paper cited in + # the docs: https://dl.acm.org/doi/pdf/10.1145/775047.775126 + G = nx.DiGraph() + G.add_node(0, label="Univ") + G.add_node(1, label="ProfA") + G.add_node(2, label="ProfB") + G.add_node(3, label="StudentA") + G.add_node(4, label="StudentB") + G.add_edges_from([(0, 1), (0, 2), (1, 3), (2, 4), (4, 2), (3, 0)]) + + expected = {0: 1, 1: 0.0, 2: 0.1323363991265798, 3: 0.0, 4: 0.03387811817640443} + # Use the importance_factor from the paper to get the same numbers. + actual = simrank_similarity(G, importance_factor=0.8, source=0) + assert expected == pytest.approx(actual, abs=1e-2) + + @pytest.mark.parametrize("simrank_similarity", simrank_algs) + def test_simrank_noninteger_nodes(self, simrank_similarity): + G = nx.cycle_graph(5) + G = nx.relabel_nodes(G, dict(enumerate("abcde"))) + expected = { + "a": 1, + "b": 0.3951219505902448, + "c": 0.5707317069281646, + "d": 0.5707317069281646, + "e": 0.3951219505902449, + } + actual = simrank_similarity(G, source="a") + assert expected == pytest.approx(actual, abs=1e-2) + + # For a DiGraph test, use the first graph from the paper cited in + # the docs: https://dl.acm.org/doi/pdf/10.1145/775047.775126 + G = nx.DiGraph() + G.add_node(0, label="Univ") + G.add_node(1, label="ProfA") + G.add_node(2, label="ProfB") + G.add_node(3, label="StudentA") + G.add_node(4, label="StudentB") + G.add_edges_from([(0, 1), (0, 2), (1, 3), (2, 4), (4, 2), (3, 0)]) + node_labels = dict(enumerate(nx.get_node_attributes(G, "label").values())) + G = nx.relabel_nodes(G, node_labels) + + expected = { + "Univ": 1, + "ProfA": 0.0, + "ProfB": 0.1323363991265798, + "StudentA": 0.0, + "StudentB": 0.03387811817640443, + } + # Use the importance_factor from the paper to get the same numbers. + actual = simrank_similarity(G, importance_factor=0.8, source="Univ") + assert expected == pytest.approx(actual, abs=1e-2) + + @pytest.mark.parametrize("simrank_similarity", simrank_algs) + def test_simrank_source_and_target(self, simrank_similarity): + G = nx.cycle_graph(5) + expected = 1 + actual = simrank_similarity(G, source=0, target=0) + assert expected == pytest.approx(actual, abs=1e-2) + + # For a DiGraph test, use the first graph from the paper cited in + # the docs: https://dl.acm.org/doi/pdf/10.1145/775047.775126 + G = nx.DiGraph() + G.add_node(0, label="Univ") + G.add_node(1, label="ProfA") + G.add_node(2, label="ProfB") + G.add_node(3, label="StudentA") + G.add_node(4, label="StudentB") + G.add_edges_from([(0, 1), (0, 2), (1, 3), (2, 4), (4, 2), (3, 0)]) + + expected = 0.1323363991265798 + # Use the importance_factor from the paper to get the same numbers. + # Use the pair (0,2) because (0,0) and (0,1) have trivial results. + actual = simrank_similarity(G, importance_factor=0.8, source=0, target=2) + assert expected == pytest.approx(actual, abs=1e-5) + + @pytest.mark.parametrize("alg", simrank_algs) + def test_simrank_max_iterations(self, alg): + G = nx.cycle_graph(5) + pytest.raises(nx.ExceededMaxIterations, alg, G, max_iterations=10) + + def test_simrank_source_not_found(self): + G = nx.cycle_graph(5) + with pytest.raises(nx.NodeNotFound, match="Source node 10 not in G"): + nx.simrank_similarity(G, source=10) + + def test_simrank_target_not_found(self): + G = nx.cycle_graph(5) + with pytest.raises(nx.NodeNotFound, match="Target node 10 not in G"): + nx.simrank_similarity(G, target=10) + + def test_simrank_between_versions(self): + G = nx.cycle_graph(5) + # _python tolerance 1e-4 + expected_python_tol4 = { + 0: 1, + 1: 0.394512499239852, + 2: 0.5703550452791322, + 3: 0.5703550452791323, + 4: 0.394512499239852, + } + # _numpy tolerance 1e-4 + expected_numpy_tol4 = { + 0: 1.0, + 1: 0.3947180735764555, + 2: 0.570482097206368, + 3: 0.570482097206368, + 4: 0.3947180735764555, + } + actual = nx.simrank_similarity(G, source=0) + assert expected_numpy_tol4 == pytest.approx(actual, abs=1e-7) + # versions differ at 1e-4 level but equal at 1e-3 + assert expected_python_tol4 != pytest.approx(actual, abs=1e-4) + assert expected_python_tol4 == pytest.approx(actual, abs=1e-3) + + actual = nx.similarity._simrank_similarity_python(G, source=0) + assert expected_python_tol4 == pytest.approx(actual, abs=1e-7) + # versions differ at 1e-4 level but equal at 1e-3 + assert expected_numpy_tol4 != pytest.approx(actual, abs=1e-4) + assert expected_numpy_tol4 == pytest.approx(actual, abs=1e-3) + + def test_simrank_numpy_no_source_no_target(self): + G = nx.cycle_graph(5) + expected = np.array( + [ + [ + 1.0, + 0.3947180735764555, + 0.570482097206368, + 0.570482097206368, + 0.3947180735764555, + ], + [ + 0.3947180735764555, + 1.0, + 0.3947180735764555, + 0.570482097206368, + 0.570482097206368, + ], + [ + 0.570482097206368, + 0.3947180735764555, + 1.0, + 0.3947180735764555, + 0.570482097206368, + ], + [ + 0.570482097206368, + 0.570482097206368, + 0.3947180735764555, + 1.0, + 0.3947180735764555, + ], + [ + 0.3947180735764555, + 0.570482097206368, + 0.570482097206368, + 0.3947180735764555, + 1.0, + ], + ] + ) + actual = nx.similarity._simrank_similarity_numpy(G) + np.testing.assert_allclose(expected, actual, atol=1e-7) + + def test_simrank_numpy_source_no_target(self): + G = nx.cycle_graph(5) + expected = np.array( + [ + 1.0, + 0.3947180735764555, + 0.570482097206368, + 0.570482097206368, + 0.3947180735764555, + ] + ) + actual = nx.similarity._simrank_similarity_numpy(G, source=0) + np.testing.assert_allclose(expected, actual, atol=1e-7) + + def test_simrank_numpy_source_and_target(self): + G = nx.cycle_graph(5) + expected = 1.0 + actual = nx.similarity._simrank_similarity_numpy(G, source=0, target=0) + np.testing.assert_allclose(expected, actual, atol=1e-7) + + def test_panther_similarity_unweighted(self): + np.random.seed(42) + + G = nx.Graph() + G.add_edge(0, 1) + G.add_edge(0, 2) + G.add_edge(0, 3) + G.add_edge(1, 2) + G.add_edge(2, 4) + expected = {3: 0.5, 2: 0.5, 1: 0.5, 4: 0.125} + sim = nx.panther_similarity(G, 0, path_length=2) + assert sim == expected + + def test_panther_similarity_weighted(self): + np.random.seed(42) + + G = nx.Graph() + G.add_edge("v1", "v2", w=5) + G.add_edge("v1", "v3", w=1) + G.add_edge("v1", "v4", w=2) + G.add_edge("v2", "v3", w=0.1) + G.add_edge("v3", "v5", w=1) + expected = {"v3": 0.75, "v4": 0.5, "v2": 0.5, "v5": 0.25} + sim = nx.panther_similarity(G, "v1", path_length=2, weight="w") + assert sim == expected + + def test_panther_similarity_source_not_found(self): + G = nx.Graph() + G.add_edges_from([(0, 1), (0, 2), (0, 3), (1, 2), (2, 4)]) + with pytest.raises(nx.NodeNotFound, match="Source node 10 not in G"): + nx.panther_similarity(G, source=10) + + def test_panther_similarity_isolated(self): + G = nx.Graph() + G.add_nodes_from(range(5)) + with pytest.raises( + nx.NetworkXUnfeasible, + match="Panther similarity is not defined for the isolated source node 1.", + ): + nx.panther_similarity(G, source=1) + + def test_generate_random_paths_unweighted(self): + index_map = {} + num_paths = 10 + path_length = 2 + G = nx.Graph() + G.add_edge(0, 1) + G.add_edge(0, 2) + G.add_edge(0, 3) + G.add_edge(1, 2) + G.add_edge(2, 4) + paths = nx.generate_random_paths( + G, num_paths, path_length=path_length, index_map=index_map, seed=42 + ) + expected_paths = [ + [3, 0, 3], + [4, 2, 1], + [2, 1, 0], + [2, 0, 3], + [3, 0, 1], + [3, 0, 1], + [4, 2, 0], + [2, 1, 0], + [3, 0, 2], + [2, 1, 2], + ] + expected_map = { + 0: {0, 2, 3, 4, 5, 6, 7, 8}, + 1: {1, 2, 4, 5, 7, 9}, + 2: {1, 2, 3, 6, 7, 8, 9}, + 3: {0, 3, 4, 5, 8}, + 4: {1, 6}, + } + + assert expected_paths == list(paths) + assert expected_map == index_map + + def test_generate_random_paths_weighted(self): + np.random.seed(42) + + index_map = {} + num_paths = 10 + path_length = 6 + G = nx.Graph() + G.add_edge("a", "b", weight=0.6) + G.add_edge("a", "c", weight=0.2) + G.add_edge("c", "d", weight=0.1) + G.add_edge("c", "e", weight=0.7) + G.add_edge("c", "f", weight=0.9) + G.add_edge("a", "d", weight=0.3) + paths = nx.generate_random_paths( + G, num_paths, path_length=path_length, index_map=index_map + ) + + expected_paths = [ + ["d", "c", "f", "c", "d", "a", "b"], + ["e", "c", "f", "c", "f", "c", "e"], + ["d", "a", "b", "a", "b", "a", "c"], + ["b", "a", "d", "a", "b", "a", "b"], + ["d", "a", "b", "a", "b", "a", "d"], + ["d", "a", "b", "a", "b", "a", "c"], + ["d", "a", "b", "a", "b", "a", "b"], + ["f", "c", "f", "c", "f", "c", "e"], + ["d", "a", "d", "a", "b", "a", "b"], + ["e", "c", "f", "c", "e", "c", "d"], + ] + expected_map = { + "d": {0, 2, 3, 4, 5, 6, 8, 9}, + "c": {0, 1, 2, 5, 7, 9}, + "f": {0, 1, 9, 7}, + "a": {0, 2, 3, 4, 5, 6, 8}, + "b": {0, 2, 3, 4, 5, 6, 8}, + "e": {1, 9, 7}, + } + + assert expected_paths == list(paths) + assert expected_map == index_map + + def test_symmetry_with_custom_matching(self): + print("G2 is edge (a,b) and G3 is edge (a,a)") + print("but node order for G2 is (a,b) while for G3 it is (b,a)") + + a, b = "A", "B" + G2 = nx.Graph() + G2.add_nodes_from((a, b)) + G2.add_edges_from([(a, b)]) + G3 = nx.Graph() + G3.add_nodes_from((b, a)) + G3.add_edges_from([(a, a)]) + for G in (G2, G3): + for n in G: + G.nodes[n]["attr"] = n + for e in G.edges: + G.edges[e]["attr"] = e + match = lambda x, y: x == y + + print("Starting G2 to G3 GED calculation") + assert nx.graph_edit_distance(G2, G3, node_match=match, edge_match=match) == 1 + + print("Starting G3 to G2 GED calculation") + assert nx.graph_edit_distance(G3, G2, node_match=match, edge_match=match) == 1 diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_simple_paths.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_simple_paths.py new file mode 100644 index 0000000000000000000000000000000000000000..7855bbad27b896750faa932a74062aa2bc8ca143 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_simple_paths.py @@ -0,0 +1,803 @@ +import random + +import pytest + +import networkx as nx +from networkx import convert_node_labels_to_integers as cnlti +from networkx.algorithms.simple_paths import ( + _bidirectional_dijkstra, + _bidirectional_shortest_path, +) +from networkx.utils import arbitrary_element, pairwise + + +class TestIsSimplePath: + """Unit tests for the + :func:`networkx.algorithms.simple_paths.is_simple_path` function. + + """ + + def test_empty_list(self): + """Tests that the empty list is not a valid path, since there + should be a one-to-one correspondence between paths as lists of + nodes and paths as lists of edges. + + """ + G = nx.trivial_graph() + assert not nx.is_simple_path(G, []) + + def test_trivial_path(self): + """Tests that the trivial path, a path of length one, is + considered a simple path in a graph. + + """ + G = nx.trivial_graph() + assert nx.is_simple_path(G, [0]) + + def test_trivial_nonpath(self): + """Tests that a list whose sole element is an object not in the + graph is not considered a simple path. + + """ + G = nx.trivial_graph() + assert not nx.is_simple_path(G, ["not a node"]) + + def test_simple_path(self): + G = nx.path_graph(2) + assert nx.is_simple_path(G, [0, 1]) + + def test_non_simple_path(self): + G = nx.path_graph(2) + assert not nx.is_simple_path(G, [0, 1, 0]) + + def test_cycle(self): + G = nx.cycle_graph(3) + assert not nx.is_simple_path(G, [0, 1, 2, 0]) + + def test_missing_node(self): + G = nx.path_graph(2) + assert not nx.is_simple_path(G, [0, 2]) + + def test_missing_starting_node(self): + G = nx.path_graph(2) + assert not nx.is_simple_path(G, [2, 0]) + + def test_directed_path(self): + G = nx.DiGraph([(0, 1), (1, 2)]) + assert nx.is_simple_path(G, [0, 1, 2]) + + def test_directed_non_path(self): + G = nx.DiGraph([(0, 1), (1, 2)]) + assert not nx.is_simple_path(G, [2, 1, 0]) + + def test_directed_cycle(self): + G = nx.DiGraph([(0, 1), (1, 2), (2, 0)]) + assert not nx.is_simple_path(G, [0, 1, 2, 0]) + + def test_multigraph(self): + G = nx.MultiGraph([(0, 1), (0, 1)]) + assert nx.is_simple_path(G, [0, 1]) + + def test_multidigraph(self): + G = nx.MultiDiGraph([(0, 1), (0, 1), (1, 0), (1, 0)]) + assert nx.is_simple_path(G, [0, 1]) + + +# Tests for all_simple_paths +def test_all_simple_paths(): + G = nx.path_graph(4) + paths = nx.all_simple_paths(G, 0, 3) + assert {tuple(p) for p in paths} == {(0, 1, 2, 3)} + + +def test_all_simple_paths_with_two_targets_emits_two_paths(): + G = nx.path_graph(4) + G.add_edge(2, 4) + paths = nx.all_simple_paths(G, 0, [3, 4]) + assert {tuple(p) for p in paths} == {(0, 1, 2, 3), (0, 1, 2, 4)} + + +def test_digraph_all_simple_paths_with_two_targets_emits_two_paths(): + G = nx.path_graph(4, create_using=nx.DiGraph()) + G.add_edge(2, 4) + paths = nx.all_simple_paths(G, 0, [3, 4]) + assert {tuple(p) for p in paths} == {(0, 1, 2, 3), (0, 1, 2, 4)} + + +def test_all_simple_paths_with_two_targets_cutoff(): + G = nx.path_graph(4) + G.add_edge(2, 4) + paths = nx.all_simple_paths(G, 0, [3, 4], cutoff=3) + assert {tuple(p) for p in paths} == {(0, 1, 2, 3), (0, 1, 2, 4)} + + +def test_digraph_all_simple_paths_with_two_targets_cutoff(): + G = nx.path_graph(4, create_using=nx.DiGraph()) + G.add_edge(2, 4) + paths = nx.all_simple_paths(G, 0, [3, 4], cutoff=3) + assert {tuple(p) for p in paths} == {(0, 1, 2, 3), (0, 1, 2, 4)} + + +def test_all_simple_paths_with_two_targets_in_line_emits_two_paths(): + G = nx.path_graph(4) + paths = nx.all_simple_paths(G, 0, [2, 3]) + assert {tuple(p) for p in paths} == {(0, 1, 2), (0, 1, 2, 3)} + + +def test_all_simple_paths_ignores_cycle(): + G = nx.cycle_graph(3, create_using=nx.DiGraph()) + G.add_edge(1, 3) + paths = nx.all_simple_paths(G, 0, 3) + assert {tuple(p) for p in paths} == {(0, 1, 3)} + + +def test_all_simple_paths_with_two_targets_inside_cycle_emits_two_paths(): + G = nx.cycle_graph(3, create_using=nx.DiGraph()) + G.add_edge(1, 3) + paths = nx.all_simple_paths(G, 0, [2, 3]) + assert {tuple(p) for p in paths} == {(0, 1, 2), (0, 1, 3)} + + +def test_all_simple_paths_source_target(): + G = nx.path_graph(4) + assert list(nx.all_simple_paths(G, 1, 1)) == [[1]] + + +def test_all_simple_paths_cutoff(): + G = nx.complete_graph(4) + paths = nx.all_simple_paths(G, 0, 1, cutoff=1) + assert {tuple(p) for p in paths} == {(0, 1)} + paths = nx.all_simple_paths(G, 0, 1, cutoff=2) + assert {tuple(p) for p in paths} == {(0, 1), (0, 2, 1), (0, 3, 1)} + + +def test_all_simple_paths_on_non_trivial_graph(): + """you may need to draw this graph to make sure it is reasonable""" + G = nx.path_graph(5, create_using=nx.DiGraph()) + G.add_edges_from([(0, 5), (1, 5), (1, 3), (5, 4), (4, 2), (4, 3)]) + paths = nx.all_simple_paths(G, 1, [2, 3]) + assert {tuple(p) for p in paths} == { + (1, 2), + (1, 3, 4, 2), + (1, 5, 4, 2), + (1, 3), + (1, 2, 3), + (1, 5, 4, 3), + (1, 5, 4, 2, 3), + } + paths = nx.all_simple_paths(G, 1, [2, 3], cutoff=3) + assert {tuple(p) for p in paths} == { + (1, 2), + (1, 3, 4, 2), + (1, 5, 4, 2), + (1, 3), + (1, 2, 3), + (1, 5, 4, 3), + } + paths = nx.all_simple_paths(G, 1, [2, 3], cutoff=2) + assert {tuple(p) for p in paths} == {(1, 2), (1, 3), (1, 2, 3)} + + +def test_all_simple_paths_multigraph(): + G = nx.MultiGraph([(1, 2), (1, 2)]) + assert list(nx.all_simple_paths(G, 1, 1)) == [[1]] + nx.add_path(G, [3, 1, 10, 2]) + paths = list(nx.all_simple_paths(G, 1, 2)) + assert len(paths) == 3 + assert {tuple(p) for p in paths} == {(1, 2), (1, 2), (1, 10, 2)} + + +def test_all_simple_paths_multigraph_with_cutoff(): + G = nx.MultiGraph([(1, 2), (1, 2), (1, 10), (10, 2)]) + paths = list(nx.all_simple_paths(G, 1, 2, cutoff=1)) + assert len(paths) == 2 + assert {tuple(p) for p in paths} == {(1, 2), (1, 2)} + + # See GitHub issue #6732. + G = nx.MultiGraph([(0, 1), (0, 2)]) + assert list(nx.all_simple_paths(G, 0, {1, 2}, cutoff=1)) == [[0, 1], [0, 2]] + + +def test_all_simple_paths_directed(): + G = nx.DiGraph() + nx.add_path(G, [1, 2, 3]) + nx.add_path(G, [3, 2, 1]) + paths = nx.all_simple_paths(G, 1, 3) + assert {tuple(p) for p in paths} == {(1, 2, 3)} + + +def test_all_simple_paths_empty(): + G = nx.path_graph(4) + paths = nx.all_simple_paths(G, 0, 3, cutoff=2) + assert list(paths) == [] + + +def test_all_simple_paths_corner_cases(): + assert list(nx.all_simple_paths(nx.empty_graph(2), 0, 0)) == [[0]] + assert list(nx.all_simple_paths(nx.empty_graph(2), 0, 1)) == [] + assert list(nx.all_simple_paths(nx.path_graph(9), 0, 8, 0)) == [] + + +def test_all_simple_paths_source_in_targets(): + # See GitHub issue #6690. + G = nx.path_graph(3) + assert list(nx.all_simple_paths(G, 0, {0, 1, 2})) == [[0], [0, 1], [0, 1, 2]] + + +def hamiltonian_path(G, source): + source = arbitrary_element(G) + neighbors = set(G[source]) - {source} + n = len(G) + for target in neighbors: + for path in nx.all_simple_paths(G, source, target): + if len(path) == n: + yield path + + +def test_hamiltonian_path(): + from itertools import permutations + + G = nx.complete_graph(4) + paths = [list(p) for p in hamiltonian_path(G, 0)] + exact = [[0] + list(p) for p in permutations([1, 2, 3], 3)] + assert sorted(paths) == sorted(exact) + + +def test_cutoff_zero(): + G = nx.complete_graph(4) + paths = nx.all_simple_paths(G, 0, 3, cutoff=0) + assert [list(p) for p in paths] == [] + paths = nx.all_simple_paths(nx.MultiGraph(G), 0, 3, cutoff=0) + assert [list(p) for p in paths] == [] + + +def test_source_missing(): + with pytest.raises(nx.NodeNotFound): + G = nx.Graph() + nx.add_path(G, [1, 2, 3]) + list(nx.all_simple_paths(nx.MultiGraph(G), 0, 3)) + + +def test_target_missing(): + with pytest.raises(nx.NodeNotFound): + G = nx.Graph() + nx.add_path(G, [1, 2, 3]) + list(nx.all_simple_paths(nx.MultiGraph(G), 1, 4)) + + +# Tests for all_simple_edge_paths +def test_all_simple_edge_paths(): + G = nx.path_graph(4) + paths = nx.all_simple_edge_paths(G, 0, 3) + assert {tuple(p) for p in paths} == {((0, 1), (1, 2), (2, 3))} + + +def test_all_simple_edge_paths_empty_path(): + G = nx.empty_graph(1) + assert list(nx.all_simple_edge_paths(G, 0, 0)) == [[]] + + +def test_all_simple_edge_paths_with_two_targets_emits_two_paths(): + G = nx.path_graph(4) + G.add_edge(2, 4) + paths = nx.all_simple_edge_paths(G, 0, [3, 4]) + assert {tuple(p) for p in paths} == { + ((0, 1), (1, 2), (2, 3)), + ((0, 1), (1, 2), (2, 4)), + } + + +def test_digraph_all_simple_edge_paths_with_two_targets_emits_two_paths(): + G = nx.path_graph(4, create_using=nx.DiGraph()) + G.add_edge(2, 4) + paths = nx.all_simple_edge_paths(G, 0, [3, 4]) + assert {tuple(p) for p in paths} == { + ((0, 1), (1, 2), (2, 3)), + ((0, 1), (1, 2), (2, 4)), + } + + +def test_all_simple_edge_paths_with_two_targets_cutoff(): + G = nx.path_graph(4) + G.add_edge(2, 4) + paths = nx.all_simple_edge_paths(G, 0, [3, 4], cutoff=3) + assert {tuple(p) for p in paths} == { + ((0, 1), (1, 2), (2, 3)), + ((0, 1), (1, 2), (2, 4)), + } + + +def test_digraph_all_simple_edge_paths_with_two_targets_cutoff(): + G = nx.path_graph(4, create_using=nx.DiGraph()) + G.add_edge(2, 4) + paths = nx.all_simple_edge_paths(G, 0, [3, 4], cutoff=3) + assert {tuple(p) for p in paths} == { + ((0, 1), (1, 2), (2, 3)), + ((0, 1), (1, 2), (2, 4)), + } + + +def test_all_simple_edge_paths_with_two_targets_in_line_emits_two_paths(): + G = nx.path_graph(4) + paths = nx.all_simple_edge_paths(G, 0, [2, 3]) + assert {tuple(p) for p in paths} == {((0, 1), (1, 2)), ((0, 1), (1, 2), (2, 3))} + + +def test_all_simple_edge_paths_ignores_cycle(): + G = nx.cycle_graph(3, create_using=nx.DiGraph()) + G.add_edge(1, 3) + paths = nx.all_simple_edge_paths(G, 0, 3) + assert {tuple(p) for p in paths} == {((0, 1), (1, 3))} + + +def test_all_simple_edge_paths_with_two_targets_inside_cycle_emits_two_paths(): + G = nx.cycle_graph(3, create_using=nx.DiGraph()) + G.add_edge(1, 3) + paths = nx.all_simple_edge_paths(G, 0, [2, 3]) + assert {tuple(p) for p in paths} == {((0, 1), (1, 2)), ((0, 1), (1, 3))} + + +def test_all_simple_edge_paths_source_target(): + G = nx.path_graph(4) + paths = nx.all_simple_edge_paths(G, 1, 1) + assert list(paths) == [[]] + + +def test_all_simple_edge_paths_cutoff(): + G = nx.complete_graph(4) + paths = nx.all_simple_edge_paths(G, 0, 1, cutoff=1) + assert {tuple(p) for p in paths} == {((0, 1),)} + paths = nx.all_simple_edge_paths(G, 0, 1, cutoff=2) + assert {tuple(p) for p in paths} == {((0, 1),), ((0, 2), (2, 1)), ((0, 3), (3, 1))} + + +def test_all_simple_edge_paths_on_non_trivial_graph(): + """you may need to draw this graph to make sure it is reasonable""" + G = nx.path_graph(5, create_using=nx.DiGraph()) + G.add_edges_from([(0, 5), (1, 5), (1, 3), (5, 4), (4, 2), (4, 3)]) + paths = nx.all_simple_edge_paths(G, 1, [2, 3]) + assert {tuple(p) for p in paths} == { + ((1, 2),), + ((1, 3), (3, 4), (4, 2)), + ((1, 5), (5, 4), (4, 2)), + ((1, 3),), + ((1, 2), (2, 3)), + ((1, 5), (5, 4), (4, 3)), + ((1, 5), (5, 4), (4, 2), (2, 3)), + } + paths = nx.all_simple_edge_paths(G, 1, [2, 3], cutoff=3) + assert {tuple(p) for p in paths} == { + ((1, 2),), + ((1, 3), (3, 4), (4, 2)), + ((1, 5), (5, 4), (4, 2)), + ((1, 3),), + ((1, 2), (2, 3)), + ((1, 5), (5, 4), (4, 3)), + } + paths = nx.all_simple_edge_paths(G, 1, [2, 3], cutoff=2) + assert {tuple(p) for p in paths} == {((1, 2),), ((1, 3),), ((1, 2), (2, 3))} + + +def test_all_simple_edge_paths_multigraph(): + G = nx.MultiGraph([(1, 2), (1, 2)]) + paths = nx.all_simple_edge_paths(G, 1, 1) + assert list(paths) == [[]] + nx.add_path(G, [3, 1, 10, 2]) + paths = list(nx.all_simple_edge_paths(G, 1, 2)) + assert len(paths) == 3 + assert {tuple(p) for p in paths} == { + ((1, 2, 0),), + ((1, 2, 1),), + ((1, 10, 0), (10, 2, 0)), + } + + +def test_all_simple_edge_paths_multigraph_with_cutoff(): + G = nx.MultiGraph([(1, 2), (1, 2), (1, 10), (10, 2)]) + paths = list(nx.all_simple_edge_paths(G, 1, 2, cutoff=1)) + assert len(paths) == 2 + assert {tuple(p) for p in paths} == {((1, 2, 0),), ((1, 2, 1),)} + + +def test_all_simple_edge_paths_directed(): + G = nx.DiGraph() + nx.add_path(G, [1, 2, 3]) + nx.add_path(G, [3, 2, 1]) + paths = nx.all_simple_edge_paths(G, 1, 3) + assert {tuple(p) for p in paths} == {((1, 2), (2, 3))} + + +def test_all_simple_edge_paths_empty(): + G = nx.path_graph(4) + paths = nx.all_simple_edge_paths(G, 0, 3, cutoff=2) + assert list(paths) == [] + + +def test_all_simple_edge_paths_corner_cases(): + assert list(nx.all_simple_edge_paths(nx.empty_graph(2), 0, 0)) == [[]] + assert list(nx.all_simple_edge_paths(nx.empty_graph(2), 0, 1)) == [] + assert list(nx.all_simple_edge_paths(nx.path_graph(9), 0, 8, 0)) == [] + + +def test_all_simple_edge_paths_ignores_self_loop(): + G = nx.Graph([(0, 0), (0, 1), (1, 1), (1, 2)]) + assert list(nx.all_simple_edge_paths(G, 0, 2)) == [[(0, 1), (1, 2)]] + + +def hamiltonian_edge_path(G, source): + source = arbitrary_element(G) + neighbors = set(G[source]) - {source} + n = len(G) + for target in neighbors: + for path in nx.all_simple_edge_paths(G, source, target): + if len(path) == n - 1: + yield path + + +def test_hamiltonian__edge_path(): + from itertools import permutations + + G = nx.complete_graph(4) + paths = hamiltonian_edge_path(G, 0) + exact = [list(pairwise([0] + list(p))) for p in permutations([1, 2, 3], 3)] + assert sorted(exact) == sorted(paths) + + +def test_edge_cutoff_zero(): + G = nx.complete_graph(4) + paths = nx.all_simple_edge_paths(G, 0, 3, cutoff=0) + assert [list(p) for p in paths] == [] + paths = nx.all_simple_edge_paths(nx.MultiGraph(G), 0, 3, cutoff=0) + assert [list(p) for p in paths] == [] + + +def test_edge_source_missing(): + with pytest.raises(nx.NodeNotFound): + G = nx.Graph() + nx.add_path(G, [1, 2, 3]) + list(nx.all_simple_edge_paths(nx.MultiGraph(G), 0, 3)) + + +def test_edge_target_missing(): + with pytest.raises(nx.NodeNotFound): + G = nx.Graph() + nx.add_path(G, [1, 2, 3]) + list(nx.all_simple_edge_paths(nx.MultiGraph(G), 1, 4)) + + +# Tests for shortest_simple_paths +def test_shortest_simple_paths(): + G = cnlti(nx.grid_2d_graph(4, 4), first_label=1, ordering="sorted") + paths = nx.shortest_simple_paths(G, 1, 12) + assert next(paths) == [1, 2, 3, 4, 8, 12] + assert next(paths) == [1, 5, 6, 7, 8, 12] + assert [len(path) for path in nx.shortest_simple_paths(G, 1, 12)] == sorted( + len(path) for path in nx.all_simple_paths(G, 1, 12) + ) + + +def test_shortest_simple_paths_singleton_path(): + G = nx.empty_graph(3) + assert list(nx.shortest_simple_paths(G, 0, 0)) == [[0]] + + +def test_shortest_simple_paths_directed(): + G = nx.cycle_graph(7, create_using=nx.DiGraph()) + paths = nx.shortest_simple_paths(G, 0, 3) + assert list(paths) == [[0, 1, 2, 3]] + + +def test_shortest_simple_paths_directed_with_weight_function(): + def cost(u, v, x): + return 1 + + G = cnlti(nx.grid_2d_graph(4, 4), first_label=1, ordering="sorted") + paths = nx.shortest_simple_paths(G, 1, 12) + assert next(paths) == [1, 2, 3, 4, 8, 12] + assert next(paths) == [1, 5, 6, 7, 8, 12] + assert [ + len(path) for path in nx.shortest_simple_paths(G, 1, 12, weight=cost) + ] == sorted(len(path) for path in nx.all_simple_paths(G, 1, 12)) + + +def test_shortest_simple_paths_with_weight_function(): + def cost(u, v, x): + return 1 + + G = nx.cycle_graph(7, create_using=nx.DiGraph()) + paths = nx.shortest_simple_paths(G, 0, 3, weight=cost) + assert list(paths) == [[0, 1, 2, 3]] + + +def test_shortest_simple_paths_with_none_weight_function(): + def cost(u, v, x): + delta = abs(u - v) + # ignore interior edges + return 1 if (delta == 1 or delta == 4) else None + + G = nx.complete_graph(5) + paths = nx.shortest_simple_paths(G, 0, 2, weight=cost) + assert list(paths) == [[0, 1, 2], [0, 4, 3, 2]] + + +def test_Greg_Bernstein(): + g1 = nx.Graph() + g1.add_nodes_from(["N0", "N1", "N2", "N3", "N4"]) + g1.add_edge("N4", "N1", weight=10.0, capacity=50, name="L5") + g1.add_edge("N4", "N0", weight=7.0, capacity=40, name="L4") + g1.add_edge("N0", "N1", weight=10.0, capacity=45, name="L1") + g1.add_edge("N3", "N0", weight=10.0, capacity=50, name="L0") + g1.add_edge("N2", "N3", weight=12.0, capacity=30, name="L2") + g1.add_edge("N1", "N2", weight=15.0, capacity=42, name="L3") + solution = [["N1", "N0", "N3"], ["N1", "N2", "N3"], ["N1", "N4", "N0", "N3"]] + result = list(nx.shortest_simple_paths(g1, "N1", "N3", weight="weight")) + assert result == solution + + +def test_weighted_shortest_simple_path(): + def cost_func(path): + return sum(G.adj[u][v]["weight"] for (u, v) in zip(path, path[1:])) + + G = nx.complete_graph(5) + weight = {(u, v): random.randint(1, 100) for (u, v) in G.edges()} + nx.set_edge_attributes(G, weight, "weight") + cost = 0 + for path in nx.shortest_simple_paths(G, 0, 3, weight="weight"): + this_cost = cost_func(path) + assert cost <= this_cost + cost = this_cost + + +def test_directed_weighted_shortest_simple_path(): + def cost_func(path): + return sum(G.adj[u][v]["weight"] for (u, v) in zip(path, path[1:])) + + G = nx.complete_graph(5) + G = G.to_directed() + weight = {(u, v): random.randint(1, 100) for (u, v) in G.edges()} + nx.set_edge_attributes(G, weight, "weight") + cost = 0 + for path in nx.shortest_simple_paths(G, 0, 3, weight="weight"): + this_cost = cost_func(path) + assert cost <= this_cost + cost = this_cost + + +def test_weighted_shortest_simple_path_issue2427(): + G = nx.Graph() + G.add_edge("IN", "OUT", weight=2) + G.add_edge("IN", "A", weight=1) + G.add_edge("IN", "B", weight=2) + G.add_edge("B", "OUT", weight=2) + assert list(nx.shortest_simple_paths(G, "IN", "OUT", weight="weight")) == [ + ["IN", "OUT"], + ["IN", "B", "OUT"], + ] + G = nx.Graph() + G.add_edge("IN", "OUT", weight=10) + G.add_edge("IN", "A", weight=1) + G.add_edge("IN", "B", weight=1) + G.add_edge("B", "OUT", weight=1) + assert list(nx.shortest_simple_paths(G, "IN", "OUT", weight="weight")) == [ + ["IN", "B", "OUT"], + ["IN", "OUT"], + ] + + +def test_directed_weighted_shortest_simple_path_issue2427(): + G = nx.DiGraph() + G.add_edge("IN", "OUT", weight=2) + G.add_edge("IN", "A", weight=1) + G.add_edge("IN", "B", weight=2) + G.add_edge("B", "OUT", weight=2) + assert list(nx.shortest_simple_paths(G, "IN", "OUT", weight="weight")) == [ + ["IN", "OUT"], + ["IN", "B", "OUT"], + ] + G = nx.DiGraph() + G.add_edge("IN", "OUT", weight=10) + G.add_edge("IN", "A", weight=1) + G.add_edge("IN", "B", weight=1) + G.add_edge("B", "OUT", weight=1) + assert list(nx.shortest_simple_paths(G, "IN", "OUT", weight="weight")) == [ + ["IN", "B", "OUT"], + ["IN", "OUT"], + ] + + +def test_weight_name(): + G = nx.cycle_graph(7) + nx.set_edge_attributes(G, 1, "weight") + nx.set_edge_attributes(G, 1, "foo") + G.adj[1][2]["foo"] = 7 + paths = list(nx.shortest_simple_paths(G, 0, 3, weight="foo")) + solution = [[0, 6, 5, 4, 3], [0, 1, 2, 3]] + assert paths == solution + + +def test_ssp_source_missing(): + with pytest.raises(nx.NodeNotFound): + G = nx.Graph() + nx.add_path(G, [1, 2, 3]) + list(nx.shortest_simple_paths(G, 0, 3)) + + +def test_ssp_target_missing(): + with pytest.raises(nx.NodeNotFound): + G = nx.Graph() + nx.add_path(G, [1, 2, 3]) + list(nx.shortest_simple_paths(G, 1, 4)) + + +def test_ssp_multigraph(): + with pytest.raises(nx.NetworkXNotImplemented): + G = nx.MultiGraph() + nx.add_path(G, [1, 2, 3]) + list(nx.shortest_simple_paths(G, 1, 4)) + + +def test_ssp_source_missing2(): + with pytest.raises(nx.NetworkXNoPath): + G = nx.Graph() + nx.add_path(G, [0, 1, 2]) + nx.add_path(G, [3, 4, 5]) + list(nx.shortest_simple_paths(G, 0, 3)) + + +def test_bidirectional_shortest_path_restricted_cycle(): + cycle = nx.cycle_graph(7) + length, path = _bidirectional_shortest_path(cycle, 0, 3) + assert path == [0, 1, 2, 3] + length, path = _bidirectional_shortest_path(cycle, 0, 3, ignore_nodes=[1]) + assert path == [0, 6, 5, 4, 3] + + +def test_bidirectional_shortest_path_restricted_wheel(): + wheel = nx.wheel_graph(6) + length, path = _bidirectional_shortest_path(wheel, 1, 3) + assert path in [[1, 0, 3], [1, 2, 3]] + length, path = _bidirectional_shortest_path(wheel, 1, 3, ignore_nodes=[0]) + assert path == [1, 2, 3] + length, path = _bidirectional_shortest_path(wheel, 1, 3, ignore_nodes=[0, 2]) + assert path == [1, 5, 4, 3] + length, path = _bidirectional_shortest_path( + wheel, 1, 3, ignore_edges=[(1, 0), (5, 0), (2, 3)] + ) + assert path in [[1, 2, 0, 3], [1, 5, 4, 3]] + + +def test_bidirectional_shortest_path_restricted_directed_cycle(): + directed_cycle = nx.cycle_graph(7, create_using=nx.DiGraph()) + length, path = _bidirectional_shortest_path(directed_cycle, 0, 3) + assert path == [0, 1, 2, 3] + pytest.raises( + nx.NetworkXNoPath, + _bidirectional_shortest_path, + directed_cycle, + 0, + 3, + ignore_nodes=[1], + ) + length, path = _bidirectional_shortest_path( + directed_cycle, 0, 3, ignore_edges=[(2, 1)] + ) + assert path == [0, 1, 2, 3] + pytest.raises( + nx.NetworkXNoPath, + _bidirectional_shortest_path, + directed_cycle, + 0, + 3, + ignore_edges=[(1, 2)], + ) + + +def test_bidirectional_shortest_path_ignore(): + G = nx.Graph() + nx.add_path(G, [1, 2]) + nx.add_path(G, [1, 3]) + nx.add_path(G, [1, 4]) + pytest.raises( + nx.NetworkXNoPath, _bidirectional_shortest_path, G, 1, 2, ignore_nodes=[1] + ) + pytest.raises( + nx.NetworkXNoPath, _bidirectional_shortest_path, G, 1, 2, ignore_nodes=[2] + ) + G = nx.Graph() + nx.add_path(G, [1, 3]) + nx.add_path(G, [1, 4]) + nx.add_path(G, [3, 2]) + pytest.raises( + nx.NetworkXNoPath, _bidirectional_shortest_path, G, 1, 2, ignore_nodes=[1, 2] + ) + + +def validate_path(G, s, t, soln_len, path): + assert path[0] == s + assert path[-1] == t + assert soln_len == sum( + G[u][v].get("weight", 1) for u, v in zip(path[:-1], path[1:]) + ) + + +def validate_length_path(G, s, t, soln_len, length, path): + assert soln_len == length + validate_path(G, s, t, length, path) + + +def test_bidirectional_dijkstra_restricted(): + XG = nx.DiGraph() + XG.add_weighted_edges_from( + [ + ("s", "u", 10), + ("s", "x", 5), + ("u", "v", 1), + ("u", "x", 2), + ("v", "y", 1), + ("x", "u", 3), + ("x", "v", 5), + ("x", "y", 2), + ("y", "s", 7), + ("y", "v", 6), + ] + ) + + XG3 = nx.Graph() + XG3.add_weighted_edges_from( + [[0, 1, 2], [1, 2, 12], [2, 3, 1], [3, 4, 5], [4, 5, 1], [5, 0, 10]] + ) + validate_length_path(XG, "s", "v", 9, *_bidirectional_dijkstra(XG, "s", "v")) + validate_length_path( + XG, "s", "v", 10, *_bidirectional_dijkstra(XG, "s", "v", ignore_nodes=["u"]) + ) + validate_length_path( + XG, + "s", + "v", + 11, + *_bidirectional_dijkstra(XG, "s", "v", ignore_edges=[("s", "x")]), + ) + pytest.raises( + nx.NetworkXNoPath, + _bidirectional_dijkstra, + XG, + "s", + "v", + ignore_nodes=["u"], + ignore_edges=[("s", "x")], + ) + validate_length_path(XG3, 0, 3, 15, *_bidirectional_dijkstra(XG3, 0, 3)) + validate_length_path( + XG3, 0, 3, 16, *_bidirectional_dijkstra(XG3, 0, 3, ignore_nodes=[1]) + ) + validate_length_path( + XG3, 0, 3, 16, *_bidirectional_dijkstra(XG3, 0, 3, ignore_edges=[(2, 3)]) + ) + pytest.raises( + nx.NetworkXNoPath, + _bidirectional_dijkstra, + XG3, + 0, + 3, + ignore_nodes=[1], + ignore_edges=[(5, 4)], + ) + + +def test_bidirectional_dijkstra_no_path(): + with pytest.raises(nx.NetworkXNoPath): + G = nx.Graph() + nx.add_path(G, [1, 2, 3]) + nx.add_path(G, [4, 5, 6]) + _bidirectional_dijkstra(G, 1, 6) + + +def test_bidirectional_dijkstra_ignore(): + G = nx.Graph() + nx.add_path(G, [1, 2, 10]) + nx.add_path(G, [1, 3, 10]) + pytest.raises(nx.NetworkXNoPath, _bidirectional_dijkstra, G, 1, 2, ignore_nodes=[1]) + pytest.raises(nx.NetworkXNoPath, _bidirectional_dijkstra, G, 1, 2, ignore_nodes=[2]) + pytest.raises( + nx.NetworkXNoPath, _bidirectional_dijkstra, G, 1, 2, ignore_nodes=[1, 2] + ) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_smallworld.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_smallworld.py new file mode 100644 index 0000000000000000000000000000000000000000..d115dd99b796fc256341f1e8ff75fd4bc01b9b17 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_smallworld.py @@ -0,0 +1,78 @@ +import pytest + +pytest.importorskip("numpy") + +import random + +import networkx as nx +from networkx import lattice_reference, omega, random_reference, sigma + +rng = 42 + + +def test_random_reference(): + G = nx.connected_watts_strogatz_graph(50, 6, 0.1, seed=rng) + Gr = random_reference(G, niter=1, seed=rng) + C = nx.average_clustering(G) + Cr = nx.average_clustering(Gr) + assert C > Cr + + with pytest.raises(nx.NetworkXError): + next(random_reference(nx.Graph())) + with pytest.raises(nx.NetworkXNotImplemented): + next(random_reference(nx.DiGraph())) + + H = nx.Graph(((0, 1), (2, 3))) + Hl = random_reference(H, niter=1, seed=rng) + + +def test_lattice_reference(): + G = nx.connected_watts_strogatz_graph(50, 6, 1, seed=rng) + Gl = lattice_reference(G, niter=1, seed=rng) + L = nx.average_shortest_path_length(G) + Ll = nx.average_shortest_path_length(Gl) + assert Ll > L + + pytest.raises(nx.NetworkXError, lattice_reference, nx.Graph()) + pytest.raises(nx.NetworkXNotImplemented, lattice_reference, nx.DiGraph()) + + H = nx.Graph(((0, 1), (2, 3))) + Hl = lattice_reference(H, niter=1) + + +def test_sigma(): + Gs = nx.connected_watts_strogatz_graph(50, 6, 0.1, seed=rng) + Gr = nx.connected_watts_strogatz_graph(50, 6, 1, seed=rng) + sigmas = sigma(Gs, niter=1, nrand=2, seed=rng) + sigmar = sigma(Gr, niter=1, nrand=2, seed=rng) + assert sigmar < sigmas + + +def test_omega(): + Gl = nx.connected_watts_strogatz_graph(50, 6, 0, seed=rng) + Gr = nx.connected_watts_strogatz_graph(50, 6, 1, seed=rng) + Gs = nx.connected_watts_strogatz_graph(50, 6, 0.1, seed=rng) + omegal = omega(Gl, niter=1, nrand=1, seed=rng) + omegar = omega(Gr, niter=1, nrand=1, seed=rng) + omegas = omega(Gs, niter=1, nrand=1, seed=rng) + assert omegal < omegas and omegas < omegar + + # Test that omega lies within the [-1, 1] bounds + G_barbell = nx.barbell_graph(5, 1) + G_karate = nx.karate_club_graph() + + omega_barbell = nx.omega(G_barbell) + omega_karate = nx.omega(G_karate, nrand=2) + + omegas = (omegal, omegar, omegas, omega_barbell, omega_karate) + + for o in omegas: + assert -1 <= o <= 1 + + +@pytest.mark.parametrize("f", (nx.random_reference, nx.lattice_reference)) +def test_graph_no_edges(f): + G = nx.Graph() + G.add_nodes_from([0, 1, 2, 3]) + with pytest.raises(nx.NetworkXError, match="Graph has fewer that 2 edges"): + f(G) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_smetric.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_smetric.py new file mode 100644 index 0000000000000000000000000000000000000000..528dbc8d69bf5dca221c17cd16118cf3ba01a2a9 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_smetric.py @@ -0,0 +1,8 @@ +import pytest + +import networkx as nx + + +def test_smetric(): + G = nx.Graph([(1, 2), (2, 3), (2, 4), (1, 4)]) + assert nx.s_metric(G) == 19.0 diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_sparsifiers.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_sparsifiers.py new file mode 100644 index 0000000000000000000000000000000000000000..e8604e61ae45aca9092226a793a02b082b126738 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_sparsifiers.py @@ -0,0 +1,138 @@ +"""Unit tests for the sparsifier computation functions.""" + +import pytest + +import networkx as nx +from networkx.utils import py_random_state + +_seed = 2 + + +def _test_spanner(G, spanner, stretch, weight=None): + """Test whether a spanner is valid. + + This function tests whether the given spanner is a subgraph of the + given graph G with the same node set. It also tests for all shortest + paths whether they adhere to the given stretch. + + Parameters + ---------- + G : NetworkX graph + The original graph for which the spanner was constructed. + + spanner : NetworkX graph + The spanner to be tested. + + stretch : float + The proclaimed stretch of the spanner. + + weight : object + The edge attribute to use as distance. + """ + # check node set + assert set(G.nodes()) == set(spanner.nodes()) + + # check edge set and weights + for u, v in spanner.edges(): + assert G.has_edge(u, v) + if weight: + assert spanner[u][v][weight] == G[u][v][weight] + + # check connectivity and stretch + original_length = dict(nx.shortest_path_length(G, weight=weight)) + spanner_length = dict(nx.shortest_path_length(spanner, weight=weight)) + for u in G.nodes(): + for v in G.nodes(): + if u in original_length and v in original_length[u]: + assert spanner_length[u][v] <= stretch * original_length[u][v] + + +@py_random_state(1) +def _assign_random_weights(G, seed=None): + """Assigns random weights to the edges of a graph. + + Parameters + ---------- + + G : NetworkX graph + The original graph for which the spanner was constructed. + + seed : integer, random_state, or None (default) + Indicator of random number generation state. + See :ref:`Randomness`. + """ + for u, v in G.edges(): + G[u][v]["weight"] = seed.random() + + +def test_spanner_trivial(): + """Test a trivial spanner with stretch 1.""" + G = nx.complete_graph(20) + spanner = nx.spanner(G, 1, seed=_seed) + + for u, v in G.edges: + assert spanner.has_edge(u, v) + + +def test_spanner_unweighted_complete_graph(): + """Test spanner construction on a complete unweighted graph.""" + G = nx.complete_graph(20) + + spanner = nx.spanner(G, 4, seed=_seed) + _test_spanner(G, spanner, 4) + + spanner = nx.spanner(G, 10, seed=_seed) + _test_spanner(G, spanner, 10) + + +def test_spanner_weighted_complete_graph(): + """Test spanner construction on a complete weighted graph.""" + G = nx.complete_graph(20) + _assign_random_weights(G, seed=_seed) + + spanner = nx.spanner(G, 4, weight="weight", seed=_seed) + _test_spanner(G, spanner, 4, weight="weight") + + spanner = nx.spanner(G, 10, weight="weight", seed=_seed) + _test_spanner(G, spanner, 10, weight="weight") + + +def test_spanner_unweighted_gnp_graph(): + """Test spanner construction on an unweighted gnp graph.""" + G = nx.gnp_random_graph(20, 0.4, seed=_seed) + + spanner = nx.spanner(G, 4, seed=_seed) + _test_spanner(G, spanner, 4) + + spanner = nx.spanner(G, 10, seed=_seed) + _test_spanner(G, spanner, 10) + + +def test_spanner_weighted_gnp_graph(): + """Test spanner construction on an weighted gnp graph.""" + G = nx.gnp_random_graph(20, 0.4, seed=_seed) + _assign_random_weights(G, seed=_seed) + + spanner = nx.spanner(G, 4, weight="weight", seed=_seed) + _test_spanner(G, spanner, 4, weight="weight") + + spanner = nx.spanner(G, 10, weight="weight", seed=_seed) + _test_spanner(G, spanner, 10, weight="weight") + + +def test_spanner_unweighted_disconnected_graph(): + """Test spanner construction on a disconnected graph.""" + G = nx.disjoint_union(nx.complete_graph(10), nx.complete_graph(10)) + + spanner = nx.spanner(G, 4, seed=_seed) + _test_spanner(G, spanner, 4) + + spanner = nx.spanner(G, 10, seed=_seed) + _test_spanner(G, spanner, 10) + + +def test_spanner_invalid_stretch(): + """Check whether an invalid stretch is caught.""" + with pytest.raises(ValueError): + G = nx.empty_graph() + nx.spanner(G, 0) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_structuralholes.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_structuralholes.py new file mode 100644 index 0000000000000000000000000000000000000000..1e5952b2347d22f7e25c055426d414ac26632c33 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_structuralholes.py @@ -0,0 +1,137 @@ +"""Unit tests for the :mod:`networkx.algorithms.structuralholes` module.""" + +import math + +import pytest + +import networkx as nx +from networkx.classes.tests import dispatch_interface + + +class TestStructuralHoles: + """Unit tests for computing measures of structural holes. + + The expected values for these functions were originally computed using the + proprietary software `UCINET`_ and the free software `IGraph`_ , and then + computed by hand to make sure that the results are correct. + + .. _UCINET: https://sites.google.com/site/ucinetsoftware/home + .. _IGraph: http://igraph.org/ + + """ + + def setup_method(self): + self.D = nx.DiGraph() + self.D.add_edges_from([(0, 1), (0, 2), (1, 0), (2, 1)]) + self.D_weights = {(0, 1): 2, (0, 2): 2, (1, 0): 1, (2, 1): 1} + # Example from http://www.analytictech.com/connections/v20(1)/holes.htm + self.G = nx.Graph() + self.G.add_edges_from( + [ + ("A", "B"), + ("A", "F"), + ("A", "G"), + ("A", "E"), + ("E", "G"), + ("F", "G"), + ("B", "G"), + ("B", "D"), + ("D", "G"), + ("G", "C"), + ] + ) + self.G_weights = { + ("A", "B"): 2, + ("A", "F"): 3, + ("A", "G"): 5, + ("A", "E"): 2, + ("E", "G"): 8, + ("F", "G"): 3, + ("B", "G"): 4, + ("B", "D"): 1, + ("D", "G"): 3, + ("G", "C"): 10, + } + + def test_constraint_directed(self): + constraint = nx.constraint(self.D) + assert constraint[0] == pytest.approx(1.003, abs=1e-3) + assert constraint[1] == pytest.approx(1.003, abs=1e-3) + assert constraint[2] == pytest.approx(1.389, abs=1e-3) + + def test_effective_size_directed(self): + effective_size = nx.effective_size(self.D) + assert effective_size[0] == pytest.approx(1.167, abs=1e-3) + assert effective_size[1] == pytest.approx(1.167, abs=1e-3) + assert effective_size[2] == pytest.approx(1, abs=1e-3) + + def test_constraint_weighted_directed(self): + D = self.D.copy() + nx.set_edge_attributes(D, self.D_weights, "weight") + constraint = nx.constraint(D, weight="weight") + assert constraint[0] == pytest.approx(0.840, abs=1e-3) + assert constraint[1] == pytest.approx(1.143, abs=1e-3) + assert constraint[2] == pytest.approx(1.378, abs=1e-3) + + def test_effective_size_weighted_directed(self): + D = self.D.copy() + nx.set_edge_attributes(D, self.D_weights, "weight") + effective_size = nx.effective_size(D, weight="weight") + assert effective_size[0] == pytest.approx(1.567, abs=1e-3) + assert effective_size[1] == pytest.approx(1.083, abs=1e-3) + assert effective_size[2] == pytest.approx(1, abs=1e-3) + + def test_constraint_undirected(self): + constraint = nx.constraint(self.G) + assert constraint["G"] == pytest.approx(0.400, abs=1e-3) + assert constraint["A"] == pytest.approx(0.595, abs=1e-3) + assert constraint["C"] == pytest.approx(1, abs=1e-3) + + def test_effective_size_undirected_borgatti(self): + effective_size = nx.effective_size(self.G) + assert effective_size["G"] == pytest.approx(4.67, abs=1e-2) + assert effective_size["A"] == pytest.approx(2.50, abs=1e-2) + assert effective_size["C"] == pytest.approx(1, abs=1e-2) + + def test_effective_size_undirected(self): + G = self.G.copy() + nx.set_edge_attributes(G, 1, "weight") + effective_size = nx.effective_size(G, weight="weight") + assert effective_size["G"] == pytest.approx(4.67, abs=1e-2) + assert effective_size["A"] == pytest.approx(2.50, abs=1e-2) + assert effective_size["C"] == pytest.approx(1, abs=1e-2) + + def test_constraint_weighted_undirected(self): + G = self.G.copy() + nx.set_edge_attributes(G, self.G_weights, "weight") + constraint = nx.constraint(G, weight="weight") + assert constraint["G"] == pytest.approx(0.299, abs=1e-3) + assert constraint["A"] == pytest.approx(0.795, abs=1e-3) + assert constraint["C"] == pytest.approx(1, abs=1e-3) + + def test_effective_size_weighted_undirected(self): + G = self.G.copy() + nx.set_edge_attributes(G, self.G_weights, "weight") + effective_size = nx.effective_size(G, weight="weight") + assert effective_size["G"] == pytest.approx(5.47, abs=1e-2) + assert effective_size["A"] == pytest.approx(2.47, abs=1e-2) + assert effective_size["C"] == pytest.approx(1, abs=1e-2) + + def test_constraint_isolated(self): + G = self.G.copy() + G.add_node(1) + constraint = nx.constraint(G) + assert math.isnan(constraint[1]) + + def test_effective_size_isolated(self): + G = self.G.copy() + G.add_node(1) + nx.set_edge_attributes(G, self.G_weights, "weight") + effective_size = nx.effective_size(G, weight="weight") + assert math.isnan(effective_size[1]) + + def test_effective_size_borgatti_isolated(self): + G = self.G.copy() + G.add_node(1) + effective_size = nx.effective_size(G) + assert math.isnan(effective_size[1]) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_summarization.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_summarization.py new file mode 100644 index 0000000000000000000000000000000000000000..c3bf82fa53b2564b13e555d994bd73b5885e1915 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_summarization.py @@ -0,0 +1,642 @@ +""" +Unit tests for dedensification and graph summarization +""" + +import pytest + +import networkx as nx + + +class TestDirectedDedensification: + def build_original_graph(self): + original_matrix = [ + ("1", "BC"), + ("2", "ABC"), + ("3", ["A", "B", "6"]), + ("4", "ABC"), + ("5", "AB"), + ("6", ["5"]), + ("A", ["6"]), + ] + graph = nx.DiGraph() + for source, targets in original_matrix: + for target in targets: + graph.add_edge(source, target) + return graph + + def build_compressed_graph(self): + compressed_matrix = [ + ("1", "BC"), + ("2", ["ABC"]), + ("3", ["A", "B", "6"]), + ("4", ["ABC"]), + ("5", "AB"), + ("6", ["5"]), + ("A", ["6"]), + ("ABC", "ABC"), + ] + compressed_graph = nx.DiGraph() + for source, targets in compressed_matrix: + for target in targets: + compressed_graph.add_edge(source, target) + return compressed_graph + + def test_empty(self): + """ + Verify that an empty directed graph results in no compressor nodes + """ + G = nx.DiGraph() + compressed_graph, c_nodes = nx.dedensify(G, threshold=2) + assert c_nodes == set() + + @staticmethod + def densify(G, compressor_nodes, copy=True): + """ + Reconstructs the original graph from a dedensified, directed graph + + Parameters + ---------- + G: dedensified graph + A networkx graph + compressor_nodes: iterable + Iterable of compressor nodes in the dedensified graph + inplace: bool, optional (default: False) + Indicates if densification should be done inplace + + Returns + ------- + G: graph + A densified networkx graph + """ + if copy: + G = G.copy() + for compressor_node in compressor_nodes: + all_neighbors = set(nx.all_neighbors(G, compressor_node)) + out_neighbors = set(G.neighbors(compressor_node)) + for out_neighbor in out_neighbors: + G.remove_edge(compressor_node, out_neighbor) + in_neighbors = all_neighbors - out_neighbors + for in_neighbor in in_neighbors: + G.remove_edge(in_neighbor, compressor_node) + for out_neighbor in out_neighbors: + G.add_edge(in_neighbor, out_neighbor) + G.remove_node(compressor_node) + return G + + def setup_method(self): + self.c_nodes = ("ABC",) + + def test_dedensify_edges(self): + """ + Verifies that dedensify produced the correct edges to/from compressor + nodes in a directed graph + """ + G = self.build_original_graph() + compressed_G = self.build_compressed_graph() + compressed_graph, c_nodes = nx.dedensify(G, threshold=2) + for s, t in compressed_graph.edges(): + o_s = "".join(sorted(s)) + o_t = "".join(sorted(t)) + compressed_graph_exists = compressed_graph.has_edge(s, t) + verified_compressed_exists = compressed_G.has_edge(o_s, o_t) + assert compressed_graph_exists == verified_compressed_exists + assert len(c_nodes) == len(self.c_nodes) + + def test_dedensify_edge_count(self): + """ + Verifies that dedensify produced the correct number of compressor nodes + in a directed graph + """ + G = self.build_original_graph() + original_edge_count = len(G.edges()) + c_G, c_nodes = nx.dedensify(G, threshold=2) + compressed_edge_count = len(c_G.edges()) + assert compressed_edge_count <= original_edge_count + compressed_G = self.build_compressed_graph() + assert compressed_edge_count == len(compressed_G.edges()) + + def test_densify_edges(self): + """ + Verifies that densification produces the correct edges from the + original directed graph + """ + compressed_G = self.build_compressed_graph() + original_graph = self.densify(compressed_G, self.c_nodes, copy=True) + G = self.build_original_graph() + for s, t in G.edges(): + assert G.has_edge(s, t) == original_graph.has_edge(s, t) + + def test_densify_edge_count(self): + """ + Verifies that densification produces the correct number of edges in the + original directed graph + """ + compressed_G = self.build_compressed_graph() + compressed_edge_count = len(compressed_G.edges()) + original_graph = self.densify(compressed_G, self.c_nodes) + original_edge_count = len(original_graph.edges()) + assert compressed_edge_count <= original_edge_count + G = self.build_original_graph() + assert original_edge_count == len(G.edges()) + + +class TestUnDirectedDedensification: + def build_original_graph(self): + """ + Builds graph shown in the original research paper + """ + original_matrix = [ + ("1", "CB"), + ("2", "ABC"), + ("3", ["A", "B", "6"]), + ("4", "ABC"), + ("5", "AB"), + ("6", ["5"]), + ("A", ["6"]), + ] + graph = nx.Graph() + for source, targets in original_matrix: + for target in targets: + graph.add_edge(source, target) + return graph + + def test_empty(self): + """ + Verify that an empty undirected graph results in no compressor nodes + """ + G = nx.Graph() + compressed_G, c_nodes = nx.dedensify(G, threshold=2) + assert c_nodes == set() + + def setup_method(self): + self.c_nodes = ("6AB", "ABC") + + def build_compressed_graph(self): + compressed_matrix = [ + ("1", ["B", "C"]), + ("2", ["ABC"]), + ("3", ["6AB"]), + ("4", ["ABC"]), + ("5", ["6AB"]), + ("6", ["6AB", "A"]), + ("A", ["6AB", "ABC"]), + ("B", ["ABC", "6AB"]), + ("C", ["ABC"]), + ] + compressed_graph = nx.Graph() + for source, targets in compressed_matrix: + for target in targets: + compressed_graph.add_edge(source, target) + return compressed_graph + + def test_dedensify_edges(self): + """ + Verifies that dedensify produced correct compressor nodes and the + correct edges to/from the compressor nodes in an undirected graph + """ + G = self.build_original_graph() + c_G, c_nodes = nx.dedensify(G, threshold=2) + v_compressed_G = self.build_compressed_graph() + for s, t in c_G.edges(): + o_s = "".join(sorted(s)) + o_t = "".join(sorted(t)) + has_compressed_edge = c_G.has_edge(s, t) + verified_has_compressed_edge = v_compressed_G.has_edge(o_s, o_t) + assert has_compressed_edge == verified_has_compressed_edge + assert len(c_nodes) == len(self.c_nodes) + + def test_dedensify_edge_count(self): + """ + Verifies that dedensify produced the correct number of edges in an + undirected graph + """ + G = self.build_original_graph() + c_G, c_nodes = nx.dedensify(G, threshold=2, copy=True) + compressed_edge_count = len(c_G.edges()) + verified_original_edge_count = len(G.edges()) + assert compressed_edge_count <= verified_original_edge_count + verified_compressed_G = self.build_compressed_graph() + verified_compressed_edge_count = len(verified_compressed_G.edges()) + assert compressed_edge_count == verified_compressed_edge_count + + +@pytest.mark.parametrize( + "graph_type", [nx.Graph, nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph] +) +def test_summarization_empty(graph_type): + G = graph_type() + summary_graph = nx.snap_aggregation(G, node_attributes=("color",)) + assert nx.is_isomorphic(summary_graph, G) + + +class AbstractSNAP: + node_attributes = ("color",) + + def build_original_graph(self): + pass + + def build_summary_graph(self): + pass + + def test_summary_graph(self): + original_graph = self.build_original_graph() + summary_graph = self.build_summary_graph() + + relationship_attributes = ("type",) + generated_summary_graph = nx.snap_aggregation( + original_graph, self.node_attributes, relationship_attributes + ) + relabeled_summary_graph = self.deterministic_labels(generated_summary_graph) + assert nx.is_isomorphic(summary_graph, relabeled_summary_graph) + + def deterministic_labels(self, G): + node_labels = list(G.nodes) + node_labels = sorted(node_labels, key=lambda n: sorted(G.nodes[n]["group"])[0]) + node_labels.sort() + + label_mapping = {} + for index, node in enumerate(node_labels): + label = f"Supernode-{index}" + label_mapping[node] = label + + return nx.relabel_nodes(G, label_mapping) + + +class TestSNAPNoEdgeTypes(AbstractSNAP): + relationship_attributes = () + + def test_summary_graph(self): + original_graph = self.build_original_graph() + summary_graph = self.build_summary_graph() + + relationship_attributes = ("type",) + generated_summary_graph = nx.snap_aggregation( + original_graph, self.node_attributes + ) + relabeled_summary_graph = self.deterministic_labels(generated_summary_graph) + assert nx.is_isomorphic(summary_graph, relabeled_summary_graph) + + def build_original_graph(self): + nodes = { + "A": {"color": "Red"}, + "B": {"color": "Red"}, + "C": {"color": "Red"}, + "D": {"color": "Red"}, + "E": {"color": "Blue"}, + "F": {"color": "Blue"}, + "G": {"color": "Blue"}, + "H": {"color": "Blue"}, + "I": {"color": "Yellow"}, + "J": {"color": "Yellow"}, + "K": {"color": "Yellow"}, + "L": {"color": "Yellow"}, + } + edges = [ + ("A", "B"), + ("A", "C"), + ("A", "E"), + ("A", "I"), + ("B", "D"), + ("B", "J"), + ("B", "F"), + ("C", "G"), + ("D", "H"), + ("I", "J"), + ("J", "K"), + ("I", "L"), + ] + G = nx.Graph() + for node in nodes: + attributes = nodes[node] + G.add_node(node, **attributes) + + for source, target in edges: + G.add_edge(source, target) + + return G + + def build_summary_graph(self): + nodes = { + "Supernode-0": {"color": "Red"}, + "Supernode-1": {"color": "Red"}, + "Supernode-2": {"color": "Blue"}, + "Supernode-3": {"color": "Blue"}, + "Supernode-4": {"color": "Yellow"}, + "Supernode-5": {"color": "Yellow"}, + } + edges = [ + ("Supernode-0", "Supernode-0"), + ("Supernode-0", "Supernode-1"), + ("Supernode-0", "Supernode-2"), + ("Supernode-0", "Supernode-4"), + ("Supernode-1", "Supernode-3"), + ("Supernode-4", "Supernode-4"), + ("Supernode-4", "Supernode-5"), + ] + G = nx.Graph() + for node in nodes: + attributes = nodes[node] + G.add_node(node, **attributes) + + for source, target in edges: + G.add_edge(source, target) + + supernodes = { + "Supernode-0": {"A", "B"}, + "Supernode-1": {"C", "D"}, + "Supernode-2": {"E", "F"}, + "Supernode-3": {"G", "H"}, + "Supernode-4": {"I", "J"}, + "Supernode-5": {"K", "L"}, + } + nx.set_node_attributes(G, supernodes, "group") + return G + + +class TestSNAPUndirected(AbstractSNAP): + def build_original_graph(self): + nodes = { + "A": {"color": "Red"}, + "B": {"color": "Red"}, + "C": {"color": "Red"}, + "D": {"color": "Red"}, + "E": {"color": "Blue"}, + "F": {"color": "Blue"}, + "G": {"color": "Blue"}, + "H": {"color": "Blue"}, + "I": {"color": "Yellow"}, + "J": {"color": "Yellow"}, + "K": {"color": "Yellow"}, + "L": {"color": "Yellow"}, + } + edges = [ + ("A", "B", "Strong"), + ("A", "C", "Weak"), + ("A", "E", "Strong"), + ("A", "I", "Weak"), + ("B", "D", "Weak"), + ("B", "J", "Weak"), + ("B", "F", "Strong"), + ("C", "G", "Weak"), + ("D", "H", "Weak"), + ("I", "J", "Strong"), + ("J", "K", "Strong"), + ("I", "L", "Strong"), + ] + G = nx.Graph() + for node in nodes: + attributes = nodes[node] + G.add_node(node, **attributes) + + for source, target, type in edges: + G.add_edge(source, target, type=type) + + return G + + def build_summary_graph(self): + nodes = { + "Supernode-0": {"color": "Red"}, + "Supernode-1": {"color": "Red"}, + "Supernode-2": {"color": "Blue"}, + "Supernode-3": {"color": "Blue"}, + "Supernode-4": {"color": "Yellow"}, + "Supernode-5": {"color": "Yellow"}, + } + edges = [ + ("Supernode-0", "Supernode-0", "Strong"), + ("Supernode-0", "Supernode-1", "Weak"), + ("Supernode-0", "Supernode-2", "Strong"), + ("Supernode-0", "Supernode-4", "Weak"), + ("Supernode-1", "Supernode-3", "Weak"), + ("Supernode-4", "Supernode-4", "Strong"), + ("Supernode-4", "Supernode-5", "Strong"), + ] + G = nx.Graph() + for node in nodes: + attributes = nodes[node] + G.add_node(node, **attributes) + + for source, target, type in edges: + G.add_edge(source, target, types=[{"type": type}]) + + supernodes = { + "Supernode-0": {"A", "B"}, + "Supernode-1": {"C", "D"}, + "Supernode-2": {"E", "F"}, + "Supernode-3": {"G", "H"}, + "Supernode-4": {"I", "J"}, + "Supernode-5": {"K", "L"}, + } + nx.set_node_attributes(G, supernodes, "group") + return G + + +class TestSNAPDirected(AbstractSNAP): + def build_original_graph(self): + nodes = { + "A": {"color": "Red"}, + "B": {"color": "Red"}, + "C": {"color": "Green"}, + "D": {"color": "Green"}, + "E": {"color": "Blue"}, + "F": {"color": "Blue"}, + "G": {"color": "Yellow"}, + "H": {"color": "Yellow"}, + } + edges = [ + ("A", "C", "Strong"), + ("A", "E", "Strong"), + ("A", "F", "Weak"), + ("B", "D", "Strong"), + ("B", "E", "Weak"), + ("B", "F", "Strong"), + ("C", "G", "Strong"), + ("C", "F", "Strong"), + ("D", "E", "Strong"), + ("D", "H", "Strong"), + ("G", "E", "Strong"), + ("H", "F", "Strong"), + ] + G = nx.DiGraph() + for node in nodes: + attributes = nodes[node] + G.add_node(node, **attributes) + + for source, target, type in edges: + G.add_edge(source, target, type=type) + + return G + + def build_summary_graph(self): + nodes = { + "Supernode-0": {"color": "Red"}, + "Supernode-1": {"color": "Green"}, + "Supernode-2": {"color": "Blue"}, + "Supernode-3": {"color": "Yellow"}, + } + edges = [ + ("Supernode-0", "Supernode-1", [{"type": "Strong"}]), + ("Supernode-0", "Supernode-2", [{"type": "Weak"}, {"type": "Strong"}]), + ("Supernode-1", "Supernode-2", [{"type": "Strong"}]), + ("Supernode-1", "Supernode-3", [{"type": "Strong"}]), + ("Supernode-3", "Supernode-2", [{"type": "Strong"}]), + ] + G = nx.DiGraph() + for node in nodes: + attributes = nodes[node] + G.add_node(node, **attributes) + + for source, target, types in edges: + G.add_edge(source, target, types=types) + + supernodes = { + "Supernode-0": {"A", "B"}, + "Supernode-1": {"C", "D"}, + "Supernode-2": {"E", "F"}, + "Supernode-3": {"G", "H"}, + "Supernode-4": {"I", "J"}, + "Supernode-5": {"K", "L"}, + } + nx.set_node_attributes(G, supernodes, "group") + return G + + +class TestSNAPUndirectedMulti(AbstractSNAP): + def build_original_graph(self): + nodes = { + "A": {"color": "Red"}, + "B": {"color": "Red"}, + "C": {"color": "Red"}, + "D": {"color": "Blue"}, + "E": {"color": "Blue"}, + "F": {"color": "Blue"}, + "G": {"color": "Yellow"}, + "H": {"color": "Yellow"}, + "I": {"color": "Yellow"}, + } + edges = [ + ("A", "D", ["Weak", "Strong"]), + ("B", "E", ["Weak", "Strong"]), + ("D", "I", ["Strong"]), + ("E", "H", ["Strong"]), + ("F", "G", ["Weak"]), + ("I", "G", ["Weak", "Strong"]), + ("I", "H", ["Weak", "Strong"]), + ("G", "H", ["Weak", "Strong"]), + ] + G = nx.MultiGraph() + for node in nodes: + attributes = nodes[node] + G.add_node(node, **attributes) + + for source, target, types in edges: + for type in types: + G.add_edge(source, target, type=type) + + return G + + def build_summary_graph(self): + nodes = { + "Supernode-0": {"color": "Red"}, + "Supernode-1": {"color": "Blue"}, + "Supernode-2": {"color": "Yellow"}, + "Supernode-3": {"color": "Blue"}, + "Supernode-4": {"color": "Yellow"}, + "Supernode-5": {"color": "Red"}, + } + edges = [ + ("Supernode-1", "Supernode-2", [{"type": "Weak"}]), + ("Supernode-2", "Supernode-4", [{"type": "Weak"}, {"type": "Strong"}]), + ("Supernode-3", "Supernode-4", [{"type": "Strong"}]), + ("Supernode-3", "Supernode-5", [{"type": "Weak"}, {"type": "Strong"}]), + ("Supernode-4", "Supernode-4", [{"type": "Weak"}, {"type": "Strong"}]), + ] + G = nx.MultiGraph() + for node in nodes: + attributes = nodes[node] + G.add_node(node, **attributes) + + for source, target, types in edges: + for type in types: + G.add_edge(source, target, type=type) + + supernodes = { + "Supernode-0": {"A", "B"}, + "Supernode-1": {"C", "D"}, + "Supernode-2": {"E", "F"}, + "Supernode-3": {"G", "H"}, + "Supernode-4": {"I", "J"}, + "Supernode-5": {"K", "L"}, + } + nx.set_node_attributes(G, supernodes, "group") + return G + + +class TestSNAPDirectedMulti(AbstractSNAP): + def build_original_graph(self): + nodes = { + "A": {"color": "Red"}, + "B": {"color": "Red"}, + "C": {"color": "Green"}, + "D": {"color": "Green"}, + "E": {"color": "Blue"}, + "F": {"color": "Blue"}, + "G": {"color": "Yellow"}, + "H": {"color": "Yellow"}, + } + edges = [ + ("A", "C", ["Weak", "Strong"]), + ("A", "E", ["Strong"]), + ("A", "F", ["Weak"]), + ("B", "D", ["Weak", "Strong"]), + ("B", "E", ["Weak"]), + ("B", "F", ["Strong"]), + ("C", "G", ["Weak", "Strong"]), + ("C", "F", ["Strong"]), + ("D", "E", ["Strong"]), + ("D", "H", ["Weak", "Strong"]), + ("G", "E", ["Strong"]), + ("H", "F", ["Strong"]), + ] + G = nx.MultiDiGraph() + for node in nodes: + attributes = nodes[node] + G.add_node(node, **attributes) + + for source, target, types in edges: + for type in types: + G.add_edge(source, target, type=type) + + return G + + def build_summary_graph(self): + nodes = { + "Supernode-0": {"color": "Red"}, + "Supernode-1": {"color": "Blue"}, + "Supernode-2": {"color": "Yellow"}, + "Supernode-3": {"color": "Blue"}, + } + edges = [ + ("Supernode-0", "Supernode-1", ["Weak", "Strong"]), + ("Supernode-0", "Supernode-2", ["Weak", "Strong"]), + ("Supernode-1", "Supernode-2", ["Strong"]), + ("Supernode-1", "Supernode-3", ["Weak", "Strong"]), + ("Supernode-3", "Supernode-2", ["Strong"]), + ] + G = nx.MultiDiGraph() + for node in nodes: + attributes = nodes[node] + G.add_node(node, **attributes) + + for source, target, types in edges: + for type in types: + G.add_edge(source, target, type=type) + + supernodes = { + "Supernode-0": {"A", "B"}, + "Supernode-1": {"C", "D"}, + "Supernode-2": {"E", "F"}, + "Supernode-3": {"G", "H"}, + } + nx.set_node_attributes(G, supernodes, "group") + return G diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_swap.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_swap.py new file mode 100644 index 0000000000000000000000000000000000000000..e765bd5e11496841072990aa792b90ca8772b4d3 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_swap.py @@ -0,0 +1,179 @@ +import pytest + +import networkx as nx + +cycle = nx.cycle_graph(5, create_using=nx.DiGraph) +tree = nx.DiGraph() +tree.add_edges_from(nx.random_labeled_tree(10, seed=42).edges) +path = nx.path_graph(5, create_using=nx.DiGraph) +binomial = nx.binomial_tree(3, create_using=nx.DiGraph) +HH = nx.directed_havel_hakimi_graph([1, 2, 1, 2, 2, 2], [3, 1, 0, 1, 2, 3]) +balanced_tree = nx.balanced_tree(2, 3, create_using=nx.DiGraph) + + +@pytest.mark.parametrize("G", [path, binomial, HH, cycle, tree, balanced_tree]) +def test_directed_edge_swap(G): + in_degree = set(G.in_degree) + out_degree = set(G.out_degree) + edges = set(G.edges) + nx.directed_edge_swap(G, nswap=1, max_tries=100, seed=1) + assert in_degree == set(G.in_degree) + assert out_degree == set(G.out_degree) + assert edges != set(G.edges) + assert 3 == sum(e not in edges for e in G.edges) + + +def test_directed_edge_swap_undo_previous_swap(): + G = nx.DiGraph(nx.path_graph(4).edges) # only 1 swap possible + edges = set(G.edges) + nx.directed_edge_swap(G, nswap=2, max_tries=100) + assert edges == set(G.edges) + + nx.directed_edge_swap(G, nswap=1, max_tries=100, seed=1) + assert {(0, 2), (1, 3), (2, 1)} == set(G.edges) + nx.directed_edge_swap(G, nswap=1, max_tries=100, seed=1) + assert edges == set(G.edges) + + +def test_edge_cases_directed_edge_swap(): + # Tests cases when swaps are impossible, either too few edges exist, or self loops/cycles are unavoidable + # TODO: Rewrite function to explicitly check for impossible swaps and raise error + e = ( + "Maximum number of swap attempts \\(11\\) exceeded " + "before desired swaps achieved \\(\\d\\)." + ) + graph = nx.DiGraph([(0, 0), (0, 1), (1, 0), (2, 3), (3, 2)]) + with pytest.raises(nx.NetworkXAlgorithmError, match=e): + nx.directed_edge_swap(graph, nswap=1, max_tries=10, seed=1) + + +def test_double_edge_swap(): + graph = nx.barabasi_albert_graph(200, 1) + degrees = sorted(d for n, d in graph.degree()) + G = nx.double_edge_swap(graph, 40) + assert degrees == sorted(d for n, d in graph.degree()) + + +def test_double_edge_swap_seed(): + graph = nx.barabasi_albert_graph(200, 1) + degrees = sorted(d for n, d in graph.degree()) + G = nx.double_edge_swap(graph, 40, seed=1) + assert degrees == sorted(d for n, d in graph.degree()) + + +def test_connected_double_edge_swap(): + graph = nx.barabasi_albert_graph(200, 1) + degrees = sorted(d for n, d in graph.degree()) + G = nx.connected_double_edge_swap(graph, 40, seed=1) + assert nx.is_connected(graph) + assert degrees == sorted(d for n, d in graph.degree()) + + +def test_connected_double_edge_swap_low_window_threshold(): + graph = nx.barabasi_albert_graph(200, 1) + degrees = sorted(d for n, d in graph.degree()) + G = nx.connected_double_edge_swap(graph, 40, _window_threshold=0, seed=1) + assert nx.is_connected(graph) + assert degrees == sorted(d for n, d in graph.degree()) + + +def test_connected_double_edge_swap_star(): + # Testing ui==xi in connected_double_edge_swap + graph = nx.star_graph(40) + degrees = sorted(d for n, d in graph.degree()) + G = nx.connected_double_edge_swap(graph, 1, seed=4) + assert nx.is_connected(graph) + assert degrees == sorted(d for n, d in graph.degree()) + + +def test_connected_double_edge_swap_star_low_window_threshold(): + # Testing ui==xi in connected_double_edge_swap with low window threshold + graph = nx.star_graph(40) + degrees = sorted(d for n, d in graph.degree()) + G = nx.connected_double_edge_swap(graph, 1, _window_threshold=0, seed=4) + assert nx.is_connected(graph) + assert degrees == sorted(d for n, d in graph.degree()) + + +def test_directed_edge_swap_small(): + with pytest.raises(nx.NetworkXError): + G = nx.directed_edge_swap(nx.path_graph(3, create_using=nx.DiGraph)) + + +def test_directed_edge_swap_tries(): + with pytest.raises(nx.NetworkXError): + G = nx.directed_edge_swap( + nx.path_graph(3, create_using=nx.DiGraph), nswap=1, max_tries=0 + ) + + +def test_directed_exception_undirected(): + graph = nx.Graph([(0, 1), (2, 3)]) + with pytest.raises(nx.NetworkXNotImplemented): + G = nx.directed_edge_swap(graph) + + +def test_directed_edge_max_tries(): + with pytest.raises(nx.NetworkXAlgorithmError): + G = nx.directed_edge_swap( + nx.complete_graph(4, nx.DiGraph()), nswap=1, max_tries=5 + ) + + +def test_double_edge_swap_small(): + with pytest.raises(nx.NetworkXError): + G = nx.double_edge_swap(nx.path_graph(3)) + + +def test_double_edge_swap_tries(): + with pytest.raises(nx.NetworkXError): + G = nx.double_edge_swap(nx.path_graph(10), nswap=1, max_tries=0) + + +def test_double_edge_directed(): + graph = nx.DiGraph([(0, 1), (2, 3)]) + with pytest.raises(nx.NetworkXError, match="not defined for directed graphs."): + G = nx.double_edge_swap(graph) + + +def test_double_edge_max_tries(): + with pytest.raises(nx.NetworkXAlgorithmError): + G = nx.double_edge_swap(nx.complete_graph(4), nswap=1, max_tries=5) + + +def test_connected_double_edge_swap_small(): + with pytest.raises(nx.NetworkXError): + G = nx.connected_double_edge_swap(nx.path_graph(3)) + + +def test_connected_double_edge_swap_not_connected(): + with pytest.raises(nx.NetworkXError): + G = nx.path_graph(3) + nx.add_path(G, [10, 11, 12]) + G = nx.connected_double_edge_swap(G) + + +def test_degree_seq_c4(): + G = nx.cycle_graph(4) + degrees = sorted(d for n, d in G.degree()) + G = nx.double_edge_swap(G, 1, 100) + assert degrees == sorted(d for n, d in G.degree()) + + +def test_fewer_than_4_nodes(): + G = nx.DiGraph() + G.add_nodes_from([0, 1, 2]) + with pytest.raises(nx.NetworkXError, match=".*fewer than four nodes."): + nx.directed_edge_swap(G) + + +def test_less_than_3_edges(): + G = nx.DiGraph([(0, 1), (1, 2)]) + G.add_nodes_from([3, 4]) + with pytest.raises(nx.NetworkXError, match=".*fewer than 3 edges"): + nx.directed_edge_swap(G) + + G = nx.Graph() + G.add_nodes_from([0, 1, 2, 3]) + with pytest.raises(nx.NetworkXError, match=".*fewer than 2 edges"): + nx.double_edge_swap(G) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_threshold.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_threshold.py new file mode 100644 index 0000000000000000000000000000000000000000..07aad44bb268a42944260b4217bce15b1278ebfd --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_threshold.py @@ -0,0 +1,269 @@ +""" +Threshold Graphs +================ +""" + +import pytest + +import networkx as nx +import networkx.algorithms.threshold as nxt +from networkx.algorithms.isomorphism.isomorph import graph_could_be_isomorphic + +cnlti = nx.convert_node_labels_to_integers + + +class TestGeneratorThreshold: + def test_threshold_sequence_graph_test(self): + G = nx.star_graph(10) + assert nxt.is_threshold_graph(G) + assert nxt.is_threshold_sequence([d for n, d in G.degree()]) + + G = nx.complete_graph(10) + assert nxt.is_threshold_graph(G) + assert nxt.is_threshold_sequence([d for n, d in G.degree()]) + + deg = [3, 2, 2, 1, 1, 1] + assert not nxt.is_threshold_sequence(deg) + + deg = [3, 2, 2, 1] + assert nxt.is_threshold_sequence(deg) + + G = nx.generators.havel_hakimi_graph(deg) + assert nxt.is_threshold_graph(G) + + def test_creation_sequences(self): + deg = [3, 2, 2, 1] + G = nx.generators.havel_hakimi_graph(deg) + + with pytest.raises(ValueError): + nxt.creation_sequence(deg, with_labels=True, compact=True) + + cs0 = nxt.creation_sequence(deg) + H0 = nxt.threshold_graph(cs0) + assert "".join(cs0) == "ddid" + + cs1 = nxt.creation_sequence(deg, with_labels=True) + H1 = nxt.threshold_graph(cs1) + assert cs1 == [(1, "d"), (2, "d"), (3, "i"), (0, "d")] + + cs2 = nxt.creation_sequence(deg, compact=True) + H2 = nxt.threshold_graph(cs2) + assert cs2 == [2, 1, 1] + assert "".join(nxt.uncompact(cs2)) == "ddid" + assert graph_could_be_isomorphic(H0, G) + assert graph_could_be_isomorphic(H0, H1) + assert graph_could_be_isomorphic(H0, H2) + + def test_make_compact(self): + assert nxt.make_compact(["d", "d", "d", "i", "d", "d"]) == [3, 1, 2] + assert nxt.make_compact([3, 1, 2]) == [3, 1, 2] + assert pytest.raises(TypeError, nxt.make_compact, [3.0, 1.0, 2.0]) + + def test_uncompact(self): + assert nxt.uncompact([3, 1, 2]) == ["d", "d", "d", "i", "d", "d"] + assert nxt.uncompact(["d", "d", "i", "d"]) == ["d", "d", "i", "d"] + assert nxt.uncompact( + nxt.uncompact([(1, "d"), (2, "d"), (3, "i"), (0, "d")]) + ) == nxt.uncompact([(1, "d"), (2, "d"), (3, "i"), (0, "d")]) + assert pytest.raises(TypeError, nxt.uncompact, [3.0, 1.0, 2.0]) + + def test_creation_sequence_to_weights(self): + assert nxt.creation_sequence_to_weights([3, 1, 2]) == [ + 0.5, + 0.5, + 0.5, + 0.25, + 0.75, + 0.75, + ] + assert pytest.raises( + TypeError, nxt.creation_sequence_to_weights, [3.0, 1.0, 2.0] + ) + + def test_weights_to_creation_sequence(self): + deg = [3, 2, 2, 1] + with pytest.raises(ValueError): + nxt.weights_to_creation_sequence(deg, with_labels=True, compact=True) + assert nxt.weights_to_creation_sequence(deg, with_labels=True) == [ + (3, "d"), + (1, "d"), + (2, "d"), + (0, "d"), + ] + assert nxt.weights_to_creation_sequence(deg, compact=True) == [4] + + def test_find_alternating_4_cycle(self): + G = nx.Graph() + G.add_edge(1, 2) + assert not nxt.find_alternating_4_cycle(G) + + def test_shortest_path(self): + deg = [3, 2, 2, 1] + G = nx.generators.havel_hakimi_graph(deg) + cs1 = nxt.creation_sequence(deg, with_labels=True) + for n, m in [(3, 0), (0, 3), (0, 2), (0, 1), (1, 3), (3, 1), (1, 2), (2, 3)]: + assert nxt.shortest_path(cs1, n, m) == nx.shortest_path(G, n, m) + + spl = nxt.shortest_path_length(cs1, 3) + spl2 = nxt.shortest_path_length([t for v, t in cs1], 2) + assert spl == spl2 + + spld = {} + for j, pl in enumerate(spl): + n = cs1[j][0] + spld[n] = pl + assert spld == nx.single_source_shortest_path_length(G, 3) + + assert nxt.shortest_path(["d", "d", "d", "i", "d", "d"], 1, 2) == [1, 2] + assert nxt.shortest_path([3, 1, 2], 1, 2) == [1, 2] + assert pytest.raises(TypeError, nxt.shortest_path, [3.0, 1.0, 2.0], 1, 2) + assert pytest.raises(ValueError, nxt.shortest_path, [3, 1, 2], "a", 2) + assert pytest.raises(ValueError, nxt.shortest_path, [3, 1, 2], 1, "b") + assert nxt.shortest_path([3, 1, 2], 1, 1) == [1] + + def test_shortest_path_length(self): + assert nxt.shortest_path_length([3, 1, 2], 1) == [1, 0, 1, 2, 1, 1] + assert nxt.shortest_path_length(["d", "d", "d", "i", "d", "d"], 1) == [ + 1, + 0, + 1, + 2, + 1, + 1, + ] + assert nxt.shortest_path_length(("d", "d", "d", "i", "d", "d"), 1) == [ + 1, + 0, + 1, + 2, + 1, + 1, + ] + assert pytest.raises(TypeError, nxt.shortest_path, [3.0, 1.0, 2.0], 1) + + def test_random_threshold_sequence(self): + assert len(nxt.random_threshold_sequence(10, 0.5)) == 10 + assert nxt.random_threshold_sequence(10, 0.5, seed=42) == [ + "d", + "i", + "d", + "d", + "d", + "i", + "i", + "i", + "d", + "d", + ] + assert pytest.raises(ValueError, nxt.random_threshold_sequence, 10, 1.5) + + def test_right_d_threshold_sequence(self): + assert nxt.right_d_threshold_sequence(3, 2) == ["d", "i", "d"] + assert pytest.raises(ValueError, nxt.right_d_threshold_sequence, 2, 3) + + def test_left_d_threshold_sequence(self): + assert nxt.left_d_threshold_sequence(3, 2) == ["d", "i", "d"] + assert pytest.raises(ValueError, nxt.left_d_threshold_sequence, 2, 3) + + def test_weights_thresholds(self): + wseq = [3, 4, 3, 3, 5, 6, 5, 4, 5, 6] + cs = nxt.weights_to_creation_sequence(wseq, threshold=10) + wseq = nxt.creation_sequence_to_weights(cs) + cs2 = nxt.weights_to_creation_sequence(wseq) + assert cs == cs2 + + wseq = nxt.creation_sequence_to_weights(nxt.uncompact([3, 1, 2, 3, 3, 2, 3])) + assert wseq == [ + s * 0.125 for s in [4, 4, 4, 3, 5, 5, 2, 2, 2, 6, 6, 6, 1, 1, 7, 7, 7] + ] + + wseq = nxt.creation_sequence_to_weights([3, 1, 2, 3, 3, 2, 3]) + assert wseq == [ + s * 0.125 for s in [4, 4, 4, 3, 5, 5, 2, 2, 2, 6, 6, 6, 1, 1, 7, 7, 7] + ] + + wseq = nxt.creation_sequence_to_weights(list(enumerate("ddidiiidididi"))) + assert wseq == [s * 0.1 for s in [5, 5, 4, 6, 3, 3, 3, 7, 2, 8, 1, 9, 0]] + + wseq = nxt.creation_sequence_to_weights("ddidiiidididi") + assert wseq == [s * 0.1 for s in [5, 5, 4, 6, 3, 3, 3, 7, 2, 8, 1, 9, 0]] + + wseq = nxt.creation_sequence_to_weights("ddidiiidididid") + ws = [s / 12 for s in [6, 6, 5, 7, 4, 4, 4, 8, 3, 9, 2, 10, 1, 11]] + assert sum(abs(c - d) for c, d in zip(wseq, ws)) < 1e-14 + + def test_finding_routines(self): + G = nx.Graph({1: [2], 2: [3], 3: [4], 4: [5], 5: [6]}) + G.add_edge(2, 4) + G.add_edge(2, 5) + G.add_edge(2, 7) + G.add_edge(3, 6) + G.add_edge(4, 6) + + # Alternating 4 cycle + assert nxt.find_alternating_4_cycle(G) == [1, 2, 3, 6] + + # Threshold graph + TG = nxt.find_threshold_graph(G) + assert nxt.is_threshold_graph(TG) + assert sorted(TG.nodes()) == [1, 2, 3, 4, 5, 7] + + cs = nxt.creation_sequence(dict(TG.degree()), with_labels=True) + assert nxt.find_creation_sequence(G) == cs + + def test_fast_versions_properties_threshold_graphs(self): + cs = "ddiiddid" + G = nxt.threshold_graph(cs) + assert nxt.density("ddiiddid") == nx.density(G) + assert sorted(nxt.degree_sequence(cs)) == sorted(d for n, d in G.degree()) + + ts = nxt.triangle_sequence(cs) + assert ts == list(nx.triangles(G).values()) + assert sum(ts) // 3 == nxt.triangles(cs) + + c1 = nxt.cluster_sequence(cs) + c2 = list(nx.clustering(G).values()) + assert sum(abs(c - d) for c, d in zip(c1, c2)) == pytest.approx(0, abs=1e-7) + + b1 = nx.betweenness_centrality(G).values() + b2 = nxt.betweenness_sequence(cs) + assert sum(abs(c - d) for c, d in zip(b1, b2)) < 1e-7 + + assert nxt.eigenvalues(cs) == [0, 1, 3, 3, 5, 7, 7, 8] + + # Degree Correlation + assert abs(nxt.degree_correlation(cs) + 0.593038821954) < 1e-12 + assert nxt.degree_correlation("diiiddi") == -0.8 + assert nxt.degree_correlation("did") == -1.0 + assert nxt.degree_correlation("ddd") == 1.0 + assert nxt.eigenvalues("dddiii") == [0, 0, 0, 0, 3, 3] + assert nxt.eigenvalues("dddiiid") == [0, 1, 1, 1, 4, 4, 7] + + def test_tg_creation_routines(self): + s = nxt.left_d_threshold_sequence(5, 7) + s = nxt.right_d_threshold_sequence(5, 7) + s1 = nxt.swap_d(s, 1.0, 1.0) + s1 = nxt.swap_d(s, 1.0, 1.0, seed=1) + + def test_eigenvectors(self): + np = pytest.importorskip("numpy") + eigenval = np.linalg.eigvals + pytest.importorskip("scipy") + + cs = "ddiiddid" + G = nxt.threshold_graph(cs) + (tgeval, tgevec) = nxt.eigenvectors(cs) + np.testing.assert_allclose([np.dot(lv, lv) for lv in tgevec], 1.0, rtol=1e-9) + lapl = nx.laplacian_matrix(G) + + def test_create_using(self): + cs = "ddiiddid" + G = nxt.threshold_graph(cs) + assert pytest.raises( + nx.exception.NetworkXError, + nxt.threshold_graph, + cs, + create_using=nx.DiGraph(), + ) + MG = nxt.threshold_graph(cs, create_using=nx.MultiGraph()) + assert sorted(MG.edges()) == sorted(G.edges()) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_time_dependent.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_time_dependent.py new file mode 100644 index 0000000000000000000000000000000000000000..1e256f4bc69389464cfa164f209bc2db713b79ee --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_time_dependent.py @@ -0,0 +1,431 @@ +"""Unit testing for time dependent algorithms.""" + +from datetime import datetime, timedelta + +import pytest + +import networkx as nx + +_delta = timedelta(days=5 * 365) + + +class TestCdIndex: + """Unit testing for the cd index function.""" + + def test_common_graph(self): + G = nx.DiGraph() + G.add_nodes_from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + G.add_edge(4, 2) + G.add_edge(4, 0) + G.add_edge(4, 1) + G.add_edge(4, 3) + G.add_edge(5, 2) + G.add_edge(6, 2) + G.add_edge(6, 4) + G.add_edge(7, 4) + G.add_edge(8, 4) + G.add_edge(9, 4) + G.add_edge(9, 1) + G.add_edge(9, 3) + G.add_edge(10, 4) + + node_attrs = { + 0: {"time": datetime(1992, 1, 1)}, + 1: {"time": datetime(1992, 1, 1)}, + 2: {"time": datetime(1993, 1, 1)}, + 3: {"time": datetime(1993, 1, 1)}, + 4: {"time": datetime(1995, 1, 1)}, + 5: {"time": datetime(1997, 1, 1)}, + 6: {"time": datetime(1998, 1, 1)}, + 7: {"time": datetime(1999, 1, 1)}, + 8: {"time": datetime(1999, 1, 1)}, + 9: {"time": datetime(1998, 1, 1)}, + 10: {"time": datetime(1997, 4, 1)}, + } + + nx.set_node_attributes(G, node_attrs) + + assert nx.cd_index(G, 4, time_delta=_delta) == 0.17 + + def test_common_graph_with_given_attributes(self): + G = nx.DiGraph() + G.add_nodes_from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + G.add_edge(4, 2) + G.add_edge(4, 0) + G.add_edge(4, 1) + G.add_edge(4, 3) + G.add_edge(5, 2) + G.add_edge(6, 2) + G.add_edge(6, 4) + G.add_edge(7, 4) + G.add_edge(8, 4) + G.add_edge(9, 4) + G.add_edge(9, 1) + G.add_edge(9, 3) + G.add_edge(10, 4) + + node_attrs = { + 0: {"date": datetime(1992, 1, 1)}, + 1: {"date": datetime(1992, 1, 1)}, + 2: {"date": datetime(1993, 1, 1)}, + 3: {"date": datetime(1993, 1, 1)}, + 4: {"date": datetime(1995, 1, 1)}, + 5: {"date": datetime(1997, 1, 1)}, + 6: {"date": datetime(1998, 1, 1)}, + 7: {"date": datetime(1999, 1, 1)}, + 8: {"date": datetime(1999, 1, 1)}, + 9: {"date": datetime(1998, 1, 1)}, + 10: {"date": datetime(1997, 4, 1)}, + } + + nx.set_node_attributes(G, node_attrs) + + assert nx.cd_index(G, 4, time_delta=_delta, time="date") == 0.17 + + def test_common_graph_with_int_attributes(self): + G = nx.DiGraph() + G.add_nodes_from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + G.add_edge(4, 2) + G.add_edge(4, 0) + G.add_edge(4, 1) + G.add_edge(4, 3) + G.add_edge(5, 2) + G.add_edge(6, 2) + G.add_edge(6, 4) + G.add_edge(7, 4) + G.add_edge(8, 4) + G.add_edge(9, 4) + G.add_edge(9, 1) + G.add_edge(9, 3) + G.add_edge(10, 4) + + node_attrs = { + 0: {"time": 20}, + 1: {"time": 20}, + 2: {"time": 30}, + 3: {"time": 30}, + 4: {"time": 50}, + 5: {"time": 70}, + 6: {"time": 80}, + 7: {"time": 90}, + 8: {"time": 90}, + 9: {"time": 80}, + 10: {"time": 74}, + } + + nx.set_node_attributes(G, node_attrs) + + assert nx.cd_index(G, 4, time_delta=50) == 0.17 + + def test_common_graph_with_float_attributes(self): + G = nx.DiGraph() + G.add_nodes_from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + G.add_edge(4, 2) + G.add_edge(4, 0) + G.add_edge(4, 1) + G.add_edge(4, 3) + G.add_edge(5, 2) + G.add_edge(6, 2) + G.add_edge(6, 4) + G.add_edge(7, 4) + G.add_edge(8, 4) + G.add_edge(9, 4) + G.add_edge(9, 1) + G.add_edge(9, 3) + G.add_edge(10, 4) + + node_attrs = { + 0: {"time": 20.2}, + 1: {"time": 20.2}, + 2: {"time": 30.7}, + 3: {"time": 30.7}, + 4: {"time": 50.9}, + 5: {"time": 70.1}, + 6: {"time": 80.6}, + 7: {"time": 90.7}, + 8: {"time": 90.7}, + 9: {"time": 80.6}, + 10: {"time": 74.2}, + } + + nx.set_node_attributes(G, node_attrs) + + assert nx.cd_index(G, 4, time_delta=50) == 0.17 + + def test_common_graph_with_weights(self): + G = nx.DiGraph() + G.add_nodes_from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + G.add_edge(4, 2) + G.add_edge(4, 0) + G.add_edge(4, 1) + G.add_edge(4, 3) + G.add_edge(5, 2) + G.add_edge(6, 2) + G.add_edge(6, 4) + G.add_edge(7, 4) + G.add_edge(8, 4) + G.add_edge(9, 4) + G.add_edge(9, 1) + G.add_edge(9, 3) + G.add_edge(10, 4) + + node_attrs = { + 0: {"time": datetime(1992, 1, 1)}, + 1: {"time": datetime(1992, 1, 1)}, + 2: {"time": datetime(1993, 1, 1)}, + 3: {"time": datetime(1993, 1, 1)}, + 4: {"time": datetime(1995, 1, 1)}, + 5: {"time": datetime(1997, 1, 1)}, + 6: {"time": datetime(1998, 1, 1), "weight": 5}, + 7: {"time": datetime(1999, 1, 1), "weight": 2}, + 8: {"time": datetime(1999, 1, 1), "weight": 6}, + 9: {"time": datetime(1998, 1, 1), "weight": 3}, + 10: {"time": datetime(1997, 4, 1), "weight": 10}, + } + + nx.set_node_attributes(G, node_attrs) + assert nx.cd_index(G, 4, time_delta=_delta, weight="weight") == 0.04 + + def test_node_with_no_predecessors(self): + G = nx.DiGraph() + G.add_nodes_from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + G.add_edge(4, 2) + G.add_edge(4, 0) + G.add_edge(4, 3) + G.add_edge(5, 2) + G.add_edge(6, 2) + G.add_edge(6, 4) + G.add_edge(7, 4) + G.add_edge(8, 4) + G.add_edge(9, 4) + G.add_edge(9, 1) + G.add_edge(9, 3) + G.add_edge(10, 4) + + node_attrs = { + 0: {"time": datetime(1992, 1, 1)}, + 1: {"time": datetime(1992, 1, 1)}, + 2: {"time": datetime(1993, 1, 1)}, + 3: {"time": datetime(1993, 1, 1)}, + 4: {"time": datetime(1995, 1, 1)}, + 5: {"time": datetime(2005, 1, 1)}, + 6: {"time": datetime(2010, 1, 1)}, + 7: {"time": datetime(2001, 1, 1)}, + 8: {"time": datetime(2020, 1, 1)}, + 9: {"time": datetime(2017, 1, 1)}, + 10: {"time": datetime(2004, 4, 1)}, + } + + nx.set_node_attributes(G, node_attrs) + assert nx.cd_index(G, 4, time_delta=_delta) == 0.0 + + def test_node_with_no_successors(self): + G = nx.DiGraph() + G.add_nodes_from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + G.add_edge(8, 2) + G.add_edge(6, 0) + G.add_edge(6, 3) + G.add_edge(5, 2) + G.add_edge(6, 2) + G.add_edge(6, 4) + G.add_edge(7, 4) + G.add_edge(8, 4) + G.add_edge(9, 4) + G.add_edge(9, 1) + G.add_edge(9, 3) + G.add_edge(10, 4) + + node_attrs = { + 0: {"time": datetime(1992, 1, 1)}, + 1: {"time": datetime(1992, 1, 1)}, + 2: {"time": datetime(1993, 1, 1)}, + 3: {"time": datetime(1993, 1, 1)}, + 4: {"time": datetime(1995, 1, 1)}, + 5: {"time": datetime(1997, 1, 1)}, + 6: {"time": datetime(1998, 1, 1)}, + 7: {"time": datetime(1999, 1, 1)}, + 8: {"time": datetime(1999, 1, 1)}, + 9: {"time": datetime(1998, 1, 1)}, + 10: {"time": datetime(1997, 4, 1)}, + } + + nx.set_node_attributes(G, node_attrs) + assert nx.cd_index(G, 4, time_delta=_delta) == 1.0 + + def test_n_equals_zero(self): + G = nx.DiGraph() + G.add_nodes_from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + G.add_edge(4, 2) + G.add_edge(4, 0) + G.add_edge(4, 3) + G.add_edge(6, 4) + G.add_edge(7, 4) + G.add_edge(8, 4) + G.add_edge(9, 4) + G.add_edge(9, 1) + G.add_edge(10, 4) + + node_attrs = { + 0: {"time": datetime(1992, 1, 1)}, + 1: {"time": datetime(1992, 1, 1)}, + 2: {"time": datetime(1993, 1, 1)}, + 3: {"time": datetime(1993, 1, 1)}, + 4: {"time": datetime(1995, 1, 1)}, + 5: {"time": datetime(2005, 1, 1)}, + 6: {"time": datetime(2010, 1, 1)}, + 7: {"time": datetime(2001, 1, 1)}, + 8: {"time": datetime(2020, 1, 1)}, + 9: {"time": datetime(2017, 1, 1)}, + 10: {"time": datetime(2004, 4, 1)}, + } + + nx.set_node_attributes(G, node_attrs) + + with pytest.raises( + nx.NetworkXError, match="The cd index cannot be defined." + ) as ve: + nx.cd_index(G, 4, time_delta=_delta) + + def test_time_timedelta_compatibility(self): + G = nx.DiGraph() + G.add_nodes_from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + G.add_edge(4, 2) + G.add_edge(4, 0) + G.add_edge(4, 3) + G.add_edge(6, 4) + G.add_edge(7, 4) + G.add_edge(8, 4) + G.add_edge(9, 4) + G.add_edge(9, 1) + G.add_edge(10, 4) + + node_attrs = { + 0: {"time": 20.2}, + 1: {"time": 20.2}, + 2: {"time": 30.7}, + 3: {"time": 30.7}, + 4: {"time": 50.9}, + 5: {"time": 70.1}, + 6: {"time": 80.6}, + 7: {"time": 90.7}, + 8: {"time": 90.7}, + 9: {"time": 80.6}, + 10: {"time": 74.2}, + } + + nx.set_node_attributes(G, node_attrs) + + with pytest.raises( + nx.NetworkXError, + match="Addition and comparison are not supported between", + ) as ve: + nx.cd_index(G, 4, time_delta=_delta) + + def test_node_with_no_time(self): + G = nx.DiGraph() + G.add_nodes_from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + G.add_edge(8, 2) + G.add_edge(6, 0) + G.add_edge(6, 3) + G.add_edge(5, 2) + G.add_edge(6, 2) + G.add_edge(6, 4) + G.add_edge(7, 4) + G.add_edge(8, 4) + G.add_edge(9, 4) + G.add_edge(9, 1) + G.add_edge(9, 3) + G.add_edge(10, 4) + + node_attrs = { + 0: {"time": datetime(1992, 1, 1)}, + 1: {"time": datetime(1992, 1, 1)}, + 2: {"time": datetime(1993, 1, 1)}, + 3: {"time": datetime(1993, 1, 1)}, + 4: {"time": datetime(1995, 1, 1)}, + 6: {"time": datetime(1998, 1, 1)}, + 7: {"time": datetime(1999, 1, 1)}, + 8: {"time": datetime(1999, 1, 1)}, + 9: {"time": datetime(1998, 1, 1)}, + 10: {"time": datetime(1997, 4, 1)}, + } + + nx.set_node_attributes(G, node_attrs) + + with pytest.raises( + nx.NetworkXError, match="Not all nodes have a 'time' attribute." + ) as ve: + nx.cd_index(G, 4, time_delta=_delta) + + def test_maximally_consolidating(self): + G = nx.DiGraph() + G.add_nodes_from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) + G.add_edge(5, 1) + G.add_edge(5, 2) + G.add_edge(5, 3) + G.add_edge(5, 4) + G.add_edge(6, 1) + G.add_edge(6, 5) + G.add_edge(7, 1) + G.add_edge(7, 5) + G.add_edge(8, 2) + G.add_edge(8, 5) + G.add_edge(9, 5) + G.add_edge(9, 3) + G.add_edge(10, 5) + G.add_edge(10, 3) + G.add_edge(10, 4) + G.add_edge(11, 5) + G.add_edge(11, 4) + + node_attrs = { + 0: {"time": datetime(1992, 1, 1)}, + 1: {"time": datetime(1992, 1, 1)}, + 2: {"time": datetime(1993, 1, 1)}, + 3: {"time": datetime(1993, 1, 1)}, + 4: {"time": datetime(1995, 1, 1)}, + 5: {"time": datetime(1997, 1, 1)}, + 6: {"time": datetime(1998, 1, 1)}, + 7: {"time": datetime(1999, 1, 1)}, + 8: {"time": datetime(1999, 1, 1)}, + 9: {"time": datetime(1998, 1, 1)}, + 10: {"time": datetime(1997, 4, 1)}, + 11: {"time": datetime(1998, 5, 1)}, + } + + nx.set_node_attributes(G, node_attrs) + + assert nx.cd_index(G, 5, time_delta=_delta) == -1 + + def test_maximally_destabilizing(self): + G = nx.DiGraph() + G.add_nodes_from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) + G.add_edge(5, 1) + G.add_edge(5, 2) + G.add_edge(5, 3) + G.add_edge(5, 4) + G.add_edge(6, 5) + G.add_edge(7, 5) + G.add_edge(8, 5) + G.add_edge(9, 5) + G.add_edge(10, 5) + G.add_edge(11, 5) + + node_attrs = { + 0: {"time": datetime(1992, 1, 1)}, + 1: {"time": datetime(1992, 1, 1)}, + 2: {"time": datetime(1993, 1, 1)}, + 3: {"time": datetime(1993, 1, 1)}, + 4: {"time": datetime(1995, 1, 1)}, + 5: {"time": datetime(1997, 1, 1)}, + 6: {"time": datetime(1998, 1, 1)}, + 7: {"time": datetime(1999, 1, 1)}, + 8: {"time": datetime(1999, 1, 1)}, + 9: {"time": datetime(1998, 1, 1)}, + 10: {"time": datetime(1997, 4, 1)}, + 11: {"time": datetime(1998, 5, 1)}, + } + + nx.set_node_attributes(G, node_attrs) + + assert nx.cd_index(G, 5, time_delta=_delta) == 1 diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_tournament.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_tournament.py new file mode 100644 index 0000000000000000000000000000000000000000..e75abf8fcc6f6af0bdf9186bf3bb6fec433d5500 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_tournament.py @@ -0,0 +1,163 @@ +"""Unit tests for the :mod:`networkx.algorithms.tournament` module.""" + +from itertools import combinations + +import pytest + +from networkx import DiGraph +from networkx.algorithms.tournament import ( + hamiltonian_path, + index_satisfying, + is_reachable, + is_strongly_connected, + is_tournament, + random_tournament, + score_sequence, + tournament_matrix, +) + + +def test_condition_not_satisfied(): + condition = lambda x: x > 0 + iter_in = [0] + assert index_satisfying(iter_in, condition) == 1 + + +def test_empty_iterable(): + condition = lambda x: x > 0 + with pytest.raises(ValueError): + index_satisfying([], condition) + + +def test_is_tournament(): + G = DiGraph() + G.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 0), (1, 3), (0, 2)]) + assert is_tournament(G) + + +def test_self_loops(): + """A tournament must have no self-loops.""" + G = DiGraph() + G.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 0), (1, 3), (0, 2)]) + G.add_edge(0, 0) + assert not is_tournament(G) + + +def test_missing_edges(): + """A tournament must not have any pair of nodes without at least + one edge joining the pair. + + """ + G = DiGraph() + G.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 0), (1, 3)]) + assert not is_tournament(G) + + +def test_bidirectional_edges(): + """A tournament must not have any pair of nodes with greater + than one edge joining the pair. + + """ + G = DiGraph() + G.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 0), (1, 3), (0, 2)]) + G.add_edge(1, 0) + assert not is_tournament(G) + + +def test_graph_is_tournament(): + for _ in range(10): + G = random_tournament(5) + assert is_tournament(G) + + +def test_graph_is_tournament_seed(): + for _ in range(10): + G = random_tournament(5, seed=1) + assert is_tournament(G) + + +def test_graph_is_tournament_one_node(): + G = random_tournament(1) + assert is_tournament(G) + + +def test_graph_is_tournament_zero_node(): + G = random_tournament(0) + assert is_tournament(G) + + +def test_hamiltonian_empty_graph(): + path = hamiltonian_path(DiGraph()) + assert len(path) == 0 + + +def test_path_is_hamiltonian(): + G = DiGraph() + G.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 0), (1, 3), (0, 2)]) + path = hamiltonian_path(G) + assert len(path) == 4 + assert all(v in G[u] for u, v in zip(path, path[1:])) + + +def test_hamiltonian_cycle(): + """Tests that :func:`networkx.tournament.hamiltonian_path` + returns a Hamiltonian cycle when provided a strongly connected + tournament. + + """ + G = DiGraph() + G.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 0), (1, 3), (0, 2)]) + path = hamiltonian_path(G) + assert len(path) == 4 + assert all(v in G[u] for u, v in zip(path, path[1:])) + assert path[0] in G[path[-1]] + + +def test_score_sequence_edge(): + G = DiGraph([(0, 1)]) + assert score_sequence(G) == [0, 1] + + +def test_score_sequence_triangle(): + G = DiGraph([(0, 1), (1, 2), (2, 0)]) + assert score_sequence(G) == [1, 1, 1] + + +def test_tournament_matrix(): + np = pytest.importorskip("numpy") + pytest.importorskip("scipy") + npt = np.testing + G = DiGraph([(0, 1)]) + m = tournament_matrix(G) + npt.assert_array_equal(m.todense(), np.array([[0, 1], [-1, 0]])) + + +def test_reachable_pair(): + """Tests for a reachable pair of nodes.""" + G = DiGraph([(0, 1), (1, 2), (2, 0)]) + assert is_reachable(G, 0, 2) + + +def test_same_node_is_reachable(): + """Tests that a node is always reachable from it.""" + # G is an arbitrary tournament on ten nodes. + G = DiGraph(sorted(p) for p in combinations(range(10), 2)) + assert all(is_reachable(G, v, v) for v in G) + + +def test_unreachable_pair(): + """Tests for an unreachable pair of nodes.""" + G = DiGraph([(0, 1), (0, 2), (1, 2)]) + assert not is_reachable(G, 1, 0) + + +def test_is_strongly_connected(): + """Tests for a strongly connected tournament.""" + G = DiGraph([(0, 1), (1, 2), (2, 0)]) + assert is_strongly_connected(G) + + +def test_not_strongly_connected(): + """Tests for a tournament that is not strongly connected.""" + G = DiGraph([(0, 1), (0, 2), (1, 2)]) + assert not is_strongly_connected(G) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_triads.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_triads.py new file mode 100644 index 0000000000000000000000000000000000000000..62670351e84dc05347e397987726e687cbedbae7 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_triads.py @@ -0,0 +1,289 @@ +"""Tests for the :mod:`networkx.algorithms.triads` module.""" + +import itertools +from collections import defaultdict +from random import sample + +import pytest + +import networkx as nx + + +def test_all_triplets_deprecated(): + G = nx.DiGraph([(1, 2), (2, 3), (3, 4)]) + with pytest.deprecated_call(): + nx.all_triplets(G) + + +def test_random_triad_deprecated(): + G = nx.path_graph(3, create_using=nx.DiGraph) + with pytest.deprecated_call(): + nx.random_triad(G) + + +def test_triadic_census(): + """Tests the triadic_census function.""" + G = nx.DiGraph() + G.add_edges_from(["01", "02", "03", "04", "05", "12", "16", "51", "56", "65"]) + expected = { + "030T": 2, + "120C": 1, + "210": 0, + "120U": 0, + "012": 9, + "102": 3, + "021U": 0, + "111U": 0, + "003": 8, + "030C": 0, + "021D": 9, + "201": 0, + "111D": 1, + "300": 0, + "120D": 0, + "021C": 2, + } + actual = nx.triadic_census(G) + assert expected == actual + + +def test_is_triad(): + """Tests the is_triad function""" + G = nx.karate_club_graph() + G = G.to_directed() + for i in range(100): + nodes = sample(sorted(G.nodes()), 3) + G2 = G.subgraph(nodes) + assert nx.is_triad(G2) + + +def test_all_triplets(): + """Tests the all_triplets function.""" + G = nx.DiGraph() + G.add_edges_from(["01", "02", "03", "04", "05", "12", "16", "51", "56", "65"]) + expected = [ + f"{i},{j},{k}" + for i in range(7) + for j in range(i + 1, 7) + for k in range(j + 1, 7) + ] + expected = [set(x.split(",")) for x in expected] + actual = [set(x) for x in nx.all_triplets(G)] + assert all(any(s1 == s2 for s1 in expected) for s2 in actual) + + +def test_all_triads(): + """Tests the all_triplets function.""" + G = nx.DiGraph() + G.add_edges_from(["01", "02", "03", "04", "05", "12", "16", "51", "56", "65"]) + expected = [ + f"{i},{j},{k}" + for i in range(7) + for j in range(i + 1, 7) + for k in range(j + 1, 7) + ] + expected = [G.subgraph(x.split(",")) for x in expected] + actual = list(nx.all_triads(G)) + assert all(any(nx.is_isomorphic(G1, G2) for G1 in expected) for G2 in actual) + + +def test_triad_type(): + """Tests the triad_type function.""" + # 0 edges (1 type) + G = nx.DiGraph({0: [], 1: [], 2: []}) + assert nx.triad_type(G) == "003" + # 1 edge (1 type) + G = nx.DiGraph({0: [1], 1: [], 2: []}) + assert nx.triad_type(G) == "012" + # 2 edges (4 types) + G = nx.DiGraph([(0, 1), (0, 2)]) + assert nx.triad_type(G) == "021D" + G = nx.DiGraph({0: [1], 1: [0], 2: []}) + assert nx.triad_type(G) == "102" + G = nx.DiGraph([(0, 1), (2, 1)]) + assert nx.triad_type(G) == "021U" + G = nx.DiGraph([(0, 1), (1, 2)]) + assert nx.triad_type(G) == "021C" + # 3 edges (4 types) + G = nx.DiGraph([(0, 1), (1, 0), (2, 1)]) + assert nx.triad_type(G) == "111D" + G = nx.DiGraph([(0, 1), (1, 0), (1, 2)]) + assert nx.triad_type(G) == "111U" + G = nx.DiGraph([(0, 1), (1, 2), (0, 2)]) + assert nx.triad_type(G) == "030T" + G = nx.DiGraph([(0, 1), (1, 2), (2, 0)]) + assert nx.triad_type(G) == "030C" + # 4 edges (4 types) + G = nx.DiGraph([(0, 1), (1, 0), (2, 0), (0, 2)]) + assert nx.triad_type(G) == "201" + G = nx.DiGraph([(0, 1), (1, 0), (2, 0), (2, 1)]) + assert nx.triad_type(G) == "120D" + G = nx.DiGraph([(0, 1), (1, 0), (0, 2), (1, 2)]) + assert nx.triad_type(G) == "120U" + G = nx.DiGraph([(0, 1), (1, 0), (0, 2), (2, 1)]) + assert nx.triad_type(G) == "120C" + # 5 edges (1 type) + G = nx.DiGraph([(0, 1), (1, 0), (2, 1), (1, 2), (0, 2)]) + assert nx.triad_type(G) == "210" + # 6 edges (1 type) + G = nx.DiGraph([(0, 1), (1, 0), (1, 2), (2, 1), (0, 2), (2, 0)]) + assert nx.triad_type(G) == "300" + + +def test_triads_by_type(): + """Tests the all_triplets function.""" + G = nx.DiGraph() + G.add_edges_from(["01", "02", "03", "04", "05", "12", "16", "51", "56", "65"]) + all_triads = nx.all_triads(G) + expected = defaultdict(list) + for triad in all_triads: + name = nx.triad_type(triad) + expected[name].append(triad) + actual = nx.triads_by_type(G) + assert set(actual.keys()) == set(expected.keys()) + for tri_type, actual_Gs in actual.items(): + expected_Gs = expected[tri_type] + for a in actual_Gs: + assert any(nx.is_isomorphic(a, e) for e in expected_Gs) + + +def test_random_triad(): + """Tests the random_triad function""" + G = nx.karate_club_graph() + G = G.to_directed() + for i in range(100): + assert nx.is_triad(nx.random_triad(G)) + + G = nx.DiGraph() + msg = "at least 3 nodes to form a triad" + with pytest.raises(nx.NetworkXError, match=msg): + nx.random_triad(G) + + +def test_triadic_census_short_path_nodelist(): + G = nx.path_graph("abc", create_using=nx.DiGraph) + expected = {"021C": 1} + for nl in ["a", "b", "c", "ab", "ac", "bc", "abc"]: + triad_census = nx.triadic_census(G, nodelist=nl) + assert expected == {typ: cnt for typ, cnt in triad_census.items() if cnt > 0} + + +def test_triadic_census_correct_nodelist_values(): + G = nx.path_graph(5, create_using=nx.DiGraph) + msg = r"nodelist includes duplicate nodes or nodes not in G" + with pytest.raises(ValueError, match=msg): + nx.triadic_census(G, [1, 2, 2, 3]) + with pytest.raises(ValueError, match=msg): + nx.triadic_census(G, [1, 2, "a", 3]) + + +def test_triadic_census_tiny_graphs(): + tc = nx.triadic_census(nx.empty_graph(0, create_using=nx.DiGraph)) + assert {} == {typ: cnt for typ, cnt in tc.items() if cnt > 0} + tc = nx.triadic_census(nx.empty_graph(1, create_using=nx.DiGraph)) + assert {} == {typ: cnt for typ, cnt in tc.items() if cnt > 0} + tc = nx.triadic_census(nx.empty_graph(2, create_using=nx.DiGraph)) + assert {} == {typ: cnt for typ, cnt in tc.items() if cnt > 0} + tc = nx.triadic_census(nx.DiGraph([(1, 2)])) + assert {} == {typ: cnt for typ, cnt in tc.items() if cnt > 0} + + +def test_triadic_census_selfloops(): + GG = nx.path_graph("abc", create_using=nx.DiGraph) + expected = {"021C": 1} + for n in GG: + G = GG.copy() + G.add_edge(n, n) + tc = nx.triadic_census(G) + assert expected == {typ: cnt for typ, cnt in tc.items() if cnt > 0} + + GG = nx.path_graph("abcde", create_using=nx.DiGraph) + tbt = nx.triads_by_type(GG) + for n in GG: + GG.add_edge(n, n) + tc = nx.triadic_census(GG) + assert tc == {tt: len(tbt[tt]) for tt in tc} + + +def test_triadic_census_four_path(): + G = nx.path_graph("abcd", create_using=nx.DiGraph) + expected = {"012": 2, "021C": 2} + triad_census = nx.triadic_census(G) + assert expected == {typ: cnt for typ, cnt in triad_census.items() if cnt > 0} + + +def test_triadic_census_four_path_nodelist(): + G = nx.path_graph("abcd", create_using=nx.DiGraph) + expected_end = {"012": 2, "021C": 1} + expected_mid = {"012": 1, "021C": 2} + a_triad_census = nx.triadic_census(G, nodelist=["a"]) + assert expected_end == {typ: cnt for typ, cnt in a_triad_census.items() if cnt > 0} + b_triad_census = nx.triadic_census(G, nodelist=["b"]) + assert expected_mid == {typ: cnt for typ, cnt in b_triad_census.items() if cnt > 0} + c_triad_census = nx.triadic_census(G, nodelist=["c"]) + assert expected_mid == {typ: cnt for typ, cnt in c_triad_census.items() if cnt > 0} + d_triad_census = nx.triadic_census(G, nodelist=["d"]) + assert expected_end == {typ: cnt for typ, cnt in d_triad_census.items() if cnt > 0} + + +def test_triadic_census_nodelist(): + """Tests the triadic_census function.""" + G = nx.DiGraph() + G.add_edges_from(["01", "02", "03", "04", "05", "12", "16", "51", "56", "65"]) + expected = { + "030T": 2, + "120C": 1, + "210": 0, + "120U": 0, + "012": 9, + "102": 3, + "021U": 0, + "111U": 0, + "003": 8, + "030C": 0, + "021D": 9, + "201": 0, + "111D": 1, + "300": 0, + "120D": 0, + "021C": 2, + } + actual = {k: 0 for k in expected} + for node in G.nodes(): + node_triad_census = nx.triadic_census(G, nodelist=[node]) + for triad_key in expected: + actual[triad_key] += node_triad_census[triad_key] + # Divide all counts by 3 + for k, v in actual.items(): + actual[k] //= 3 + assert expected == actual + + +@pytest.mark.parametrize("N", [5, 10]) +def test_triadic_census_on_random_graph(N): + G = nx.binomial_graph(N, 0.3, directed=True, seed=42) + tc1 = nx.triadic_census(G) + tbt = nx.triads_by_type(G) + tc2 = {tt: len(tbt[tt]) for tt in tc1} + assert tc1 == tc2 + + for n in G: + tc1 = nx.triadic_census(G, nodelist={n}) + tc2 = {tt: sum(1 for t in tbt.get(tt, []) if n in t) for tt in tc1} + assert tc1 == tc2 + + for ns in itertools.combinations(G, 2): + ns = set(ns) + tc1 = nx.triadic_census(G, nodelist=ns) + tc2 = { + tt: sum(1 for t in tbt.get(tt, []) if any(n in ns for n in t)) for tt in tc1 + } + assert tc1 == tc2 + + for ns in itertools.combinations(G, 3): + ns = set(ns) + tc1 = nx.triadic_census(G, nodelist=ns) + tc2 = { + tt: sum(1 for t in tbt.get(tt, []) if any(n in ns for n in t)) for tt in tc1 + } + assert tc1 == tc2 diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_vitality.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_vitality.py new file mode 100644 index 0000000000000000000000000000000000000000..248206e670fa911f62177bb6727d6a7a6df1e6b9 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_vitality.py @@ -0,0 +1,41 @@ +import networkx as nx + + +class TestClosenessVitality: + def test_unweighted(self): + G = nx.cycle_graph(3) + vitality = nx.closeness_vitality(G) + assert vitality == {0: 2, 1: 2, 2: 2} + + def test_weighted(self): + G = nx.Graph() + nx.add_cycle(G, [0, 1, 2], weight=2) + vitality = nx.closeness_vitality(G, weight="weight") + assert vitality == {0: 4, 1: 4, 2: 4} + + def test_unweighted_digraph(self): + G = nx.DiGraph(nx.cycle_graph(3)) + vitality = nx.closeness_vitality(G) + assert vitality == {0: 4, 1: 4, 2: 4} + + def test_weighted_digraph(self): + G = nx.DiGraph() + nx.add_cycle(G, [0, 1, 2], weight=2) + nx.add_cycle(G, [2, 1, 0], weight=2) + vitality = nx.closeness_vitality(G, weight="weight") + assert vitality == {0: 8, 1: 8, 2: 8} + + def test_weighted_multidigraph(self): + G = nx.MultiDiGraph() + nx.add_cycle(G, [0, 1, 2], weight=2) + nx.add_cycle(G, [2, 1, 0], weight=2) + vitality = nx.closeness_vitality(G, weight="weight") + assert vitality == {0: 8, 1: 8, 2: 8} + + def test_disconnecting_graph(self): + """Tests that the closeness vitality of a node whose removal + disconnects the graph is negative infinity. + + """ + G = nx.path_graph(3) + assert nx.closeness_vitality(G, node=1) == -float("inf") diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_voronoi.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_voronoi.py new file mode 100644 index 0000000000000000000000000000000000000000..3269ae62a023ff0cf9fdc55122cb6e7c8d2ba319 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_voronoi.py @@ -0,0 +1,103 @@ +import networkx as nx +from networkx.utils import pairwise + + +class TestVoronoiCells: + """Unit tests for the Voronoi cells function.""" + + def test_isolates(self): + """Tests that a graph with isolated nodes has all isolates in + one block of the partition. + + """ + G = nx.empty_graph(5) + cells = nx.voronoi_cells(G, {0, 2, 4}) + expected = {0: {0}, 2: {2}, 4: {4}, "unreachable": {1, 3}} + assert expected == cells + + def test_undirected_unweighted(self): + G = nx.cycle_graph(6) + cells = nx.voronoi_cells(G, {0, 3}) + expected = {0: {0, 1, 5}, 3: {2, 3, 4}} + assert expected == cells + + def test_directed_unweighted(self): + # This is the singly-linked directed cycle graph on six nodes. + G = nx.DiGraph(pairwise(range(6), cyclic=True)) + cells = nx.voronoi_cells(G, {0, 3}) + expected = {0: {0, 1, 2}, 3: {3, 4, 5}} + assert expected == cells + + def test_directed_inward(self): + """Tests that reversing the graph gives the "inward" Voronoi + partition. + + """ + # This is the singly-linked reverse directed cycle graph on six nodes. + G = nx.DiGraph(pairwise(range(6), cyclic=True)) + G = G.reverse(copy=False) + cells = nx.voronoi_cells(G, {0, 3}) + expected = {0: {0, 4, 5}, 3: {1, 2, 3}} + assert expected == cells + + def test_undirected_weighted(self): + edges = [(0, 1, 10), (1, 2, 1), (2, 3, 1)] + G = nx.Graph() + G.add_weighted_edges_from(edges) + cells = nx.voronoi_cells(G, {0, 3}) + expected = {0: {0}, 3: {1, 2, 3}} + assert expected == cells + + def test_directed_weighted(self): + edges = [(0, 1, 10), (1, 2, 1), (2, 3, 1), (3, 2, 1), (2, 1, 1)] + G = nx.DiGraph() + G.add_weighted_edges_from(edges) + cells = nx.voronoi_cells(G, {0, 3}) + expected = {0: {0}, 3: {1, 2, 3}} + assert expected == cells + + def test_multigraph_unweighted(self): + """Tests that the Voronoi cells for a multigraph are the same as + for a simple graph. + + """ + edges = [(0, 1), (1, 2), (2, 3)] + G = nx.MultiGraph(2 * edges) + H = nx.Graph(G) + G_cells = nx.voronoi_cells(G, {0, 3}) + H_cells = nx.voronoi_cells(H, {0, 3}) + assert G_cells == H_cells + + def test_multidigraph_unweighted(self): + # This is the twice-singly-linked directed cycle graph on six nodes. + edges = list(pairwise(range(6), cyclic=True)) + G = nx.MultiDiGraph(2 * edges) + H = nx.DiGraph(G) + G_cells = nx.voronoi_cells(G, {0, 3}) + H_cells = nx.voronoi_cells(H, {0, 3}) + assert G_cells == H_cells + + def test_multigraph_weighted(self): + edges = [(0, 1, 10), (0, 1, 10), (1, 2, 1), (1, 2, 100), (2, 3, 1), (2, 3, 100)] + G = nx.MultiGraph() + G.add_weighted_edges_from(edges) + cells = nx.voronoi_cells(G, {0, 3}) + expected = {0: {0}, 3: {1, 2, 3}} + assert expected == cells + + def test_multidigraph_weighted(self): + edges = [ + (0, 1, 10), + (0, 1, 10), + (1, 2, 1), + (2, 3, 1), + (3, 2, 10), + (3, 2, 1), + (2, 1, 10), + (2, 1, 1), + ] + G = nx.MultiDiGraph() + G.add_weighted_edges_from(edges) + cells = nx.voronoi_cells(G, {0, 3}) + expected = {0: {0}, 3: {1, 2, 3}} + assert expected == cells diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_walks.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_walks.py new file mode 100644 index 0000000000000000000000000000000000000000..7a6b323932988e1b9513118162df62e9613ee65b --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_walks.py @@ -0,0 +1,54 @@ +"""Unit tests for the :mod:`networkx.algorithms.walks` module.""" + +import pytest + +import networkx as nx + +pytest.importorskip("numpy") +pytest.importorskip("scipy") + + +def test_directed(): + G = nx.DiGraph([(0, 1), (1, 2), (2, 0)]) + num_walks = nx.number_of_walks(G, 3) + expected = {0: {0: 1, 1: 0, 2: 0}, 1: {0: 0, 1: 1, 2: 0}, 2: {0: 0, 1: 0, 2: 1}} + assert num_walks == expected + + +def test_undirected(): + G = nx.cycle_graph(3) + num_walks = nx.number_of_walks(G, 3) + expected = {0: {0: 2, 1: 3, 2: 3}, 1: {0: 3, 1: 2, 2: 3}, 2: {0: 3, 1: 3, 2: 2}} + assert num_walks == expected + + +def test_non_integer_nodes(): + G = nx.DiGraph([("A", "B"), ("B", "C"), ("C", "A")]) + num_walks = nx.number_of_walks(G, 2) + expected = { + "A": {"A": 0, "B": 0, "C": 1}, + "B": {"A": 1, "B": 0, "C": 0}, + "C": {"A": 0, "B": 1, "C": 0}, + } + assert num_walks == expected + + +def test_zero_length(): + G = nx.cycle_graph(3) + num_walks = nx.number_of_walks(G, 0) + expected = {0: {0: 1, 1: 0, 2: 0}, 1: {0: 0, 1: 1, 2: 0}, 2: {0: 0, 1: 0, 2: 1}} + assert num_walks == expected + + +def test_negative_length_exception(): + G = nx.cycle_graph(3) + with pytest.raises(ValueError): + nx.number_of_walks(G, -1) + + +def test_hidden_weight_attr(): + G = nx.cycle_graph(3) + G.add_edge(1, 2, weight=5) + num_walks = nx.number_of_walks(G, 3) + expected = {0: {0: 2, 1: 3, 2: 3}, 1: {0: 3, 1: 2, 2: 3}, 2: {0: 3, 1: 3, 2: 2}} + assert num_walks == expected diff --git a/lib/python3.10/site-packages/networkx/algorithms/tests/test_wiener.py b/lib/python3.10/site-packages/networkx/algorithms/tests/test_wiener.py new file mode 100644 index 0000000000000000000000000000000000000000..aded95143ca53e0031189dfabeacf5df0f887120 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tests/test_wiener.py @@ -0,0 +1,123 @@ +import networkx as nx + + +def test_wiener_index_of_disconnected_graph(): + assert nx.wiener_index(nx.empty_graph(2)) == float("inf") + + +def test_wiener_index_of_directed_graph(): + G = nx.complete_graph(3) + H = nx.DiGraph(G) + assert (2 * nx.wiener_index(G)) == nx.wiener_index(H) + + +def test_wiener_index_of_complete_graph(): + n = 10 + G = nx.complete_graph(n) + assert nx.wiener_index(G) == (n * (n - 1) / 2) + + +def test_wiener_index_of_path_graph(): + # In P_n, there are n - 1 pairs of vertices at distance one, n - + # 2 pairs at distance two, n - 3 at distance three, ..., 1 at + # distance n - 1, so the Wiener index should be + # + # 1 * (n - 1) + 2 * (n - 2) + ... + (n - 2) * 2 + (n - 1) * 1 + # + # For example, in P_5, + # + # 1 * 4 + 2 * 3 + 3 * 2 + 4 * 1 = 2 (1 * 4 + 2 * 3) + # + # and in P_6, + # + # 1 * 5 + 2 * 4 + 3 * 3 + 4 * 2 + 5 * 1 = 2 (1 * 5 + 2 * 4) + 3 * 3 + # + # assuming n is *odd*, this gives the formula + # + # 2 \sum_{i = 1}^{(n - 1) / 2} [i * (n - i)] + # + # assuming n is *even*, this gives the formula + # + # 2 \sum_{i = 1}^{n / 2} [i * (n - i)] - (n / 2) ** 2 + # + n = 9 + G = nx.path_graph(n) + expected = 2 * sum(i * (n - i) for i in range(1, (n // 2) + 1)) + actual = nx.wiener_index(G) + assert expected == actual + + +def test_schultz_and_gutman_index_of_disconnected_graph(): + n = 4 + G = nx.Graph() + G.add_nodes_from(list(range(1, n + 1))) + expected = float("inf") + + G.add_edge(1, 2) + G.add_edge(3, 4) + + actual_1 = nx.schultz_index(G) + actual_2 = nx.gutman_index(G) + + assert expected == actual_1 + assert expected == actual_2 + + +def test_schultz_and_gutman_index_of_complete_bipartite_graph_1(): + n = 3 + m = 3 + cbg = nx.complete_bipartite_graph(n, m) + + expected_1 = n * m * (n + m) + 2 * n * (n - 1) * m + 2 * m * (m - 1) * n + actual_1 = nx.schultz_index(cbg) + + expected_2 = n * m * (n * m) + n * (n - 1) * m * m + m * (m - 1) * n * n + actual_2 = nx.gutman_index(cbg) + + assert expected_1 == actual_1 + assert expected_2 == actual_2 + + +def test_schultz_and_gutman_index_of_complete_bipartite_graph_2(): + n = 2 + m = 5 + cbg = nx.complete_bipartite_graph(n, m) + + expected_1 = n * m * (n + m) + 2 * n * (n - 1) * m + 2 * m * (m - 1) * n + actual_1 = nx.schultz_index(cbg) + + expected_2 = n * m * (n * m) + n * (n - 1) * m * m + m * (m - 1) * n * n + actual_2 = nx.gutman_index(cbg) + + assert expected_1 == actual_1 + assert expected_2 == actual_2 + + +def test_schultz_and_gutman_index_of_complete_graph(): + n = 5 + cg = nx.complete_graph(n) + + expected_1 = n * (n - 1) * (n - 1) + actual_1 = nx.schultz_index(cg) + + assert expected_1 == actual_1 + + expected_2 = n * (n - 1) * (n - 1) * (n - 1) / 2 + actual_2 = nx.gutman_index(cg) + + assert expected_2 == actual_2 + + +def test_schultz_and_gutman_index_of_odd_cycle_graph(): + k = 5 + n = 2 * k + 1 + ocg = nx.cycle_graph(n) + + expected_1 = 2 * n * k * (k + 1) + actual_1 = nx.schultz_index(ocg) + + expected_2 = 2 * n * k * (k + 1) + actual_2 = nx.gutman_index(ocg) + + assert expected_1 == actual_1 + assert expected_2 == actual_2 diff --git a/lib/python3.10/site-packages/networkx/algorithms/traversal/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/traversal/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..93e6cdd08ca959835cc5c5f9a3e6dc353f4b217d --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/traversal/__init__.py @@ -0,0 +1,5 @@ +from .beamsearch import * +from .breadth_first_search import * +from .depth_first_search import * +from .edgedfs import * +from .edgebfs import * diff --git a/lib/python3.10/site-packages/networkx/algorithms/traversal/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/traversal/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ce23f3167df03df381412d84b47879a4c986c6d Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/traversal/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/traversal/__pycache__/beamsearch.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/traversal/__pycache__/beamsearch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e0f7312d698726cd06ceabc67b6aa8472e2ac87 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/traversal/__pycache__/beamsearch.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/traversal/__pycache__/breadth_first_search.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/traversal/__pycache__/breadth_first_search.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb5a5425c0cfc522ed3c2b4915fde11b3128f942 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/traversal/__pycache__/breadth_first_search.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/traversal/__pycache__/depth_first_search.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/traversal/__pycache__/depth_first_search.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59b0da4bc03b782a54a68a3a00515fe5122a54f3 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/traversal/__pycache__/depth_first_search.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/traversal/__pycache__/edgebfs.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/traversal/__pycache__/edgebfs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eed9c4e7755a55b51ea9d1472da220d9e176967e Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/traversal/__pycache__/edgebfs.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/traversal/__pycache__/edgedfs.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/traversal/__pycache__/edgedfs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e0d42eaa22c6310a82e983e43984f6f33e2955a Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/traversal/__pycache__/edgedfs.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/traversal/beamsearch.py b/lib/python3.10/site-packages/networkx/algorithms/traversal/beamsearch.py new file mode 100644 index 0000000000000000000000000000000000000000..23fbe7bbb3f037eca4f7ede25b7edf6305356587 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/traversal/beamsearch.py @@ -0,0 +1,90 @@ +"""Basic algorithms for breadth-first searching the nodes of a graph.""" + +import networkx as nx + +__all__ = ["bfs_beam_edges"] + + +@nx._dispatchable +def bfs_beam_edges(G, source, value, width=None): + """Iterates over edges in a beam search. + + The beam search is a generalized breadth-first search in which only + the "best" *w* neighbors of the current node are enqueued, where *w* + is the beam width and "best" is an application-specific + heuristic. In general, a beam search with a small beam width might + not visit each node in the graph. + + .. note:: + + With the default value of ``width=None`` or `width` greater than the + maximum degree of the graph, this function equates to a slower + version of `~networkx.algorithms.traversal.breadth_first_search.bfs_edges`. + All nodes will be visited, though the order of the reported edges may + vary. In such cases, `value` has no effect - consider using `bfs_edges` + directly instead. + + Parameters + ---------- + G : NetworkX graph + + source : node + Starting node for the breadth-first search; this function + iterates over only those edges in the component reachable from + this node. + + value : function + A function that takes a node of the graph as input and returns a + real number indicating how "good" it is. A higher value means it + is more likely to be visited sooner during the search. When + visiting a new node, only the `width` neighbors with the highest + `value` are enqueued (in decreasing order of `value`). + + width : int (default = None) + The beam width for the search. This is the number of neighbors + (ordered by `value`) to enqueue when visiting each new node. + + Yields + ------ + edge + Edges in the beam search starting from `source`, given as a pair + of nodes. + + Examples + -------- + To give nodes with, for example, a higher centrality precedence + during the search, set the `value` function to return the centrality + value of the node: + + >>> G = nx.karate_club_graph() + >>> centrality = nx.eigenvector_centrality(G) + >>> list(nx.bfs_beam_edges(G, source=0, value=centrality.get, width=3)) + [(0, 2), (0, 1), (0, 8), (2, 32), (1, 13), (8, 33)] + """ + + if width is None: + width = len(G) + + def successors(v): + """Returns a list of the best neighbors of a node. + + `v` is a node in the graph `G`. + + The "best" neighbors are chosen according to the `value` + function (higher is better). Only the `width` best neighbors of + `v` are returned. + """ + # TODO The Python documentation states that for small values, it + # is better to use `heapq.nlargest`. We should determine the + # threshold at which its better to use `heapq.nlargest()` + # instead of `sorted()[:]` and apply that optimization here. + # + # If `width` is greater than the number of neighbors of `v`, all + # neighbors are returned by the semantics of slicing in + # Python. This occurs in the special case that the user did not + # specify a `width`: in this case all neighbors are always + # returned, so this is just a (slower) implementation of + # `bfs_edges(G, source)` but with a sorted enqueue step. + return iter(sorted(G.neighbors(v), key=value, reverse=True)[:width]) + + yield from nx.generic_bfs_edges(G, source, successors) diff --git a/lib/python3.10/site-packages/networkx/algorithms/traversal/breadth_first_search.py b/lib/python3.10/site-packages/networkx/algorithms/traversal/breadth_first_search.py new file mode 100644 index 0000000000000000000000000000000000000000..899dc92b723665f88eb76dc90a0c1aee87dde1f4 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/traversal/breadth_first_search.py @@ -0,0 +1,575 @@ +"""Basic algorithms for breadth-first searching the nodes of a graph.""" + +from collections import deque + +import networkx as nx + +__all__ = [ + "bfs_edges", + "bfs_tree", + "bfs_predecessors", + "bfs_successors", + "descendants_at_distance", + "bfs_layers", + "bfs_labeled_edges", + "generic_bfs_edges", +] + + +@nx._dispatchable +def generic_bfs_edges(G, source, neighbors=None, depth_limit=None): + """Iterate over edges in a breadth-first search. + + The breadth-first search begins at `source` and enqueues the + neighbors of newly visited nodes specified by the `neighbors` + function. + + Parameters + ---------- + G : NetworkX graph + + source : node + Starting node for the breadth-first search; this function + iterates over only those edges in the component reachable from + this node. + + neighbors : function + A function that takes a newly visited node of the graph as input + and returns an *iterator* (not just a list) of nodes that are + neighbors of that node with custom ordering. If not specified, this is + just the ``G.neighbors`` method, but in general it can be any function + that returns an iterator over some or all of the neighbors of a + given node, in any order. + + depth_limit : int, optional(default=len(G)) + Specify the maximum search depth. + + Yields + ------ + edge + Edges in the breadth-first search starting from `source`. + + Examples + -------- + >>> G = nx.path_graph(7) + >>> list(nx.generic_bfs_edges(G, source=0)) + [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6)] + >>> list(nx.generic_bfs_edges(G, source=2)) + [(2, 1), (2, 3), (1, 0), (3, 4), (4, 5), (5, 6)] + >>> list(nx.generic_bfs_edges(G, source=2, depth_limit=2)) + [(2, 1), (2, 3), (1, 0), (3, 4)] + + The `neighbors` param can be used to specify the visitation order of each + node's neighbors generically. In the following example, we modify the default + neighbor to return *odd* nodes first: + + >>> def odd_first(n): + ... return sorted(G.neighbors(n), key=lambda x: x % 2, reverse=True) + + >>> G = nx.star_graph(5) + >>> list(nx.generic_bfs_edges(G, source=0)) # Default neighbor ordering + [(0, 1), (0, 2), (0, 3), (0, 4), (0, 5)] + >>> list(nx.generic_bfs_edges(G, source=0, neighbors=odd_first)) + [(0, 1), (0, 3), (0, 5), (0, 2), (0, 4)] + + Notes + ----- + This implementation is from `PADS`_, which was in the public domain + when it was first accessed in July, 2004. The modifications + to allow depth limits are based on the Wikipedia article + "`Depth-limited-search`_". + + .. _PADS: http://www.ics.uci.edu/~eppstein/PADS/BFS.py + .. _Depth-limited-search: https://en.wikipedia.org/wiki/Depth-limited_search + """ + if neighbors is None: + neighbors = G.neighbors + if depth_limit is None: + depth_limit = len(G) + + seen = {source} + n = len(G) + depth = 0 + next_parents_children = [(source, neighbors(source))] + while next_parents_children and depth < depth_limit: + this_parents_children = next_parents_children + next_parents_children = [] + for parent, children in this_parents_children: + for child in children: + if child not in seen: + seen.add(child) + next_parents_children.append((child, neighbors(child))) + yield parent, child + if len(seen) == n: + return + depth += 1 + + +@nx._dispatchable +def bfs_edges(G, source, reverse=False, depth_limit=None, sort_neighbors=None): + """Iterate over edges in a breadth-first-search starting at source. + + Parameters + ---------- + G : NetworkX graph + + source : node + Specify starting node for breadth-first search; this function + iterates over only those edges in the component reachable from + this node. + + reverse : bool, optional + If True traverse a directed graph in the reverse direction + + depth_limit : int, optional(default=len(G)) + Specify the maximum search depth + + sort_neighbors : function (default=None) + A function that takes an iterator over nodes as the input, and + returns an iterable of the same nodes with a custom ordering. + For example, `sorted` will sort the nodes in increasing order. + + Yields + ------ + edge: 2-tuple of nodes + Yields edges resulting from the breadth-first search. + + Examples + -------- + To get the edges in a breadth-first search:: + + >>> G = nx.path_graph(3) + >>> list(nx.bfs_edges(G, 0)) + [(0, 1), (1, 2)] + >>> list(nx.bfs_edges(G, source=0, depth_limit=1)) + [(0, 1)] + + To get the nodes in a breadth-first search order:: + + >>> G = nx.path_graph(3) + >>> root = 2 + >>> edges = nx.bfs_edges(G, root) + >>> nodes = [root] + [v for u, v in edges] + >>> nodes + [2, 1, 0] + + Notes + ----- + The naming of this function is very similar to + :func:`~networkx.algorithms.traversal.edgebfs.edge_bfs`. The difference + is that ``edge_bfs`` yields edges even if they extend back to an already + explored node while this generator yields the edges of the tree that results + from a breadth-first-search (BFS) so no edges are reported if they extend + to already explored nodes. That means ``edge_bfs`` reports all edges while + ``bfs_edges`` only reports those traversed by a node-based BFS. Yet another + description is that ``bfs_edges`` reports the edges traversed during BFS + while ``edge_bfs`` reports all edges in the order they are explored. + + Based on the breadth-first search implementation in PADS [1]_ + by D. Eppstein, July 2004; with modifications to allow depth limits + as described in [2]_. + + References + ---------- + .. [1] http://www.ics.uci.edu/~eppstein/PADS/BFS.py. + .. [2] https://en.wikipedia.org/wiki/Depth-limited_search + + See Also + -------- + bfs_tree + :func:`~networkx.algorithms.traversal.depth_first_search.dfs_edges` + :func:`~networkx.algorithms.traversal.edgebfs.edge_bfs` + + """ + if reverse and G.is_directed(): + successors = G.predecessors + else: + successors = G.neighbors + + if sort_neighbors is not None: + yield from generic_bfs_edges( + G, source, lambda node: iter(sort_neighbors(successors(node))), depth_limit + ) + else: + yield from generic_bfs_edges(G, source, successors, depth_limit) + + +@nx._dispatchable(returns_graph=True) +def bfs_tree(G, source, reverse=False, depth_limit=None, sort_neighbors=None): + """Returns an oriented tree constructed from of a breadth-first-search + starting at source. + + Parameters + ---------- + G : NetworkX graph + + source : node + Specify starting node for breadth-first search + + reverse : bool, optional + If True traverse a directed graph in the reverse direction + + depth_limit : int, optional(default=len(G)) + Specify the maximum search depth + + sort_neighbors : function (default=None) + A function that takes an iterator over nodes as the input, and + returns an iterable of the same nodes with a custom ordering. + For example, `sorted` will sort the nodes in increasing order. + + Returns + ------- + T: NetworkX DiGraph + An oriented tree + + Examples + -------- + >>> G = nx.path_graph(3) + >>> list(nx.bfs_tree(G, 1).edges()) + [(1, 0), (1, 2)] + >>> H = nx.Graph() + >>> nx.add_path(H, [0, 1, 2, 3, 4, 5, 6]) + >>> nx.add_path(H, [2, 7, 8, 9, 10]) + >>> sorted(list(nx.bfs_tree(H, source=3, depth_limit=3).edges())) + [(1, 0), (2, 1), (2, 7), (3, 2), (3, 4), (4, 5), (5, 6), (7, 8)] + + + Notes + ----- + Based on http://www.ics.uci.edu/~eppstein/PADS/BFS.py + by D. Eppstein, July 2004. The modifications + to allow depth limits based on the Wikipedia article + "`Depth-limited-search`_". + + .. _Depth-limited-search: https://en.wikipedia.org/wiki/Depth-limited_search + + See Also + -------- + dfs_tree + bfs_edges + edge_bfs + """ + T = nx.DiGraph() + T.add_node(source) + edges_gen = bfs_edges( + G, + source, + reverse=reverse, + depth_limit=depth_limit, + sort_neighbors=sort_neighbors, + ) + T.add_edges_from(edges_gen) + return T + + +@nx._dispatchable +def bfs_predecessors(G, source, depth_limit=None, sort_neighbors=None): + """Returns an iterator of predecessors in breadth-first-search from source. + + Parameters + ---------- + G : NetworkX graph + + source : node + Specify starting node for breadth-first search + + depth_limit : int, optional(default=len(G)) + Specify the maximum search depth + + sort_neighbors : function (default=None) + A function that takes an iterator over nodes as the input, and + returns an iterable of the same nodes with a custom ordering. + For example, `sorted` will sort the nodes in increasing order. + + Returns + ------- + pred: iterator + (node, predecessor) iterator where `predecessor` is the predecessor of + `node` in a breadth first search starting from `source`. + + Examples + -------- + >>> G = nx.path_graph(3) + >>> dict(nx.bfs_predecessors(G, 0)) + {1: 0, 2: 1} + >>> H = nx.Graph() + >>> H.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + >>> dict(nx.bfs_predecessors(H, 0)) + {1: 0, 2: 0, 3: 1, 4: 1, 5: 2, 6: 2} + >>> M = nx.Graph() + >>> nx.add_path(M, [0, 1, 2, 3, 4, 5, 6]) + >>> nx.add_path(M, [2, 7, 8, 9, 10]) + >>> sorted(nx.bfs_predecessors(M, source=1, depth_limit=3)) + [(0, 1), (2, 1), (3, 2), (4, 3), (7, 2), (8, 7)] + >>> N = nx.DiGraph() + >>> nx.add_path(N, [0, 1, 2, 3, 4, 7]) + >>> nx.add_path(N, [3, 5, 6, 7]) + >>> sorted(nx.bfs_predecessors(N, source=2)) + [(3, 2), (4, 3), (5, 3), (6, 5), (7, 4)] + + Notes + ----- + Based on http://www.ics.uci.edu/~eppstein/PADS/BFS.py + by D. Eppstein, July 2004. The modifications + to allow depth limits based on the Wikipedia article + "`Depth-limited-search`_". + + .. _Depth-limited-search: https://en.wikipedia.org/wiki/Depth-limited_search + + See Also + -------- + bfs_tree + bfs_edges + edge_bfs + """ + for s, t in bfs_edges( + G, source, depth_limit=depth_limit, sort_neighbors=sort_neighbors + ): + yield (t, s) + + +@nx._dispatchable +def bfs_successors(G, source, depth_limit=None, sort_neighbors=None): + """Returns an iterator of successors in breadth-first-search from source. + + Parameters + ---------- + G : NetworkX graph + + source : node + Specify starting node for breadth-first search + + depth_limit : int, optional(default=len(G)) + Specify the maximum search depth + + sort_neighbors : function (default=None) + A function that takes an iterator over nodes as the input, and + returns an iterable of the same nodes with a custom ordering. + For example, `sorted` will sort the nodes in increasing order. + + Returns + ------- + succ: iterator + (node, successors) iterator where `successors` is the non-empty list of + successors of `node` in a breadth first search from `source`. + To appear in the iterator, `node` must have successors. + + Examples + -------- + >>> G = nx.path_graph(3) + >>> dict(nx.bfs_successors(G, 0)) + {0: [1], 1: [2]} + >>> H = nx.Graph() + >>> H.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + >>> dict(nx.bfs_successors(H, 0)) + {0: [1, 2], 1: [3, 4], 2: [5, 6]} + >>> G = nx.Graph() + >>> nx.add_path(G, [0, 1, 2, 3, 4, 5, 6]) + >>> nx.add_path(G, [2, 7, 8, 9, 10]) + >>> dict(nx.bfs_successors(G, source=1, depth_limit=3)) + {1: [0, 2], 2: [3, 7], 3: [4], 7: [8]} + >>> G = nx.DiGraph() + >>> nx.add_path(G, [0, 1, 2, 3, 4, 5]) + >>> dict(nx.bfs_successors(G, source=3)) + {3: [4], 4: [5]} + + Notes + ----- + Based on http://www.ics.uci.edu/~eppstein/PADS/BFS.py + by D. Eppstein, July 2004.The modifications + to allow depth limits based on the Wikipedia article + "`Depth-limited-search`_". + + .. _Depth-limited-search: https://en.wikipedia.org/wiki/Depth-limited_search + + See Also + -------- + bfs_tree + bfs_edges + edge_bfs + """ + parent = source + children = [] + for p, c in bfs_edges( + G, source, depth_limit=depth_limit, sort_neighbors=sort_neighbors + ): + if p == parent: + children.append(c) + continue + yield (parent, children) + children = [c] + parent = p + yield (parent, children) + + +@nx._dispatchable +def bfs_layers(G, sources): + """Returns an iterator of all the layers in breadth-first search traversal. + + Parameters + ---------- + G : NetworkX graph + A graph over which to find the layers using breadth-first search. + + sources : node in `G` or list of nodes in `G` + Specify starting nodes for single source or multiple sources breadth-first search + + Yields + ------ + layer: list of nodes + Yields list of nodes at the same distance from sources + + Examples + -------- + >>> G = nx.path_graph(5) + >>> dict(enumerate(nx.bfs_layers(G, [0, 4]))) + {0: [0, 4], 1: [1, 3], 2: [2]} + >>> H = nx.Graph() + >>> H.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + >>> dict(enumerate(nx.bfs_layers(H, [1]))) + {0: [1], 1: [0, 3, 4], 2: [2], 3: [5, 6]} + >>> dict(enumerate(nx.bfs_layers(H, [1, 6]))) + {0: [1, 6], 1: [0, 3, 4, 2], 2: [5]} + """ + if sources in G: + sources = [sources] + + current_layer = list(sources) + visited = set(sources) + + for source in current_layer: + if source not in G: + raise nx.NetworkXError(f"The node {source} is not in the graph.") + + # this is basically BFS, except that the current layer only stores the nodes at + # same distance from sources at each iteration + while current_layer: + yield current_layer + next_layer = [] + for node in current_layer: + for child in G[node]: + if child not in visited: + visited.add(child) + next_layer.append(child) + current_layer = next_layer + + +REVERSE_EDGE = "reverse" +TREE_EDGE = "tree" +FORWARD_EDGE = "forward" +LEVEL_EDGE = "level" + + +@nx._dispatchable +def bfs_labeled_edges(G, sources): + """Iterate over edges in a breadth-first search (BFS) labeled by type. + + We generate triple of the form (*u*, *v*, *d*), where (*u*, *v*) is the + edge being explored in the breadth-first search and *d* is one of the + strings 'tree', 'forward', 'level', or 'reverse'. A 'tree' edge is one in + which *v* is first discovered and placed into the layer below *u*. A + 'forward' edge is one in which *u* is on the layer above *v* and *v* has + already been discovered. A 'level' edge is one in which both *u* and *v* + occur on the same layer. A 'reverse' edge is one in which *u* is on a layer + below *v*. + + We emit each edge exactly once. In an undirected graph, 'reverse' edges do + not occur, because each is discovered either as a 'tree' or 'forward' edge. + + Parameters + ---------- + G : NetworkX graph + A graph over which to find the layers using breadth-first search. + + sources : node in `G` or list of nodes in `G` + Starting nodes for single source or multiple sources breadth-first search + + Yields + ------ + edges: generator + A generator of triples (*u*, *v*, *d*) where (*u*, *v*) is the edge being + explored and *d* is described above. + + Examples + -------- + >>> G = nx.cycle_graph(4, create_using=nx.DiGraph) + >>> list(nx.bfs_labeled_edges(G, 0)) + [(0, 1, 'tree'), (1, 2, 'tree'), (2, 3, 'tree'), (3, 0, 'reverse')] + >>> G = nx.complete_graph(3) + >>> list(nx.bfs_labeled_edges(G, 0)) + [(0, 1, 'tree'), (0, 2, 'tree'), (1, 2, 'level')] + >>> list(nx.bfs_labeled_edges(G, [0, 1])) + [(0, 1, 'level'), (0, 2, 'tree'), (1, 2, 'forward')] + """ + if sources in G: + sources = [sources] + + neighbors = G._adj + directed = G.is_directed() + visited = set() + visit = visited.discard if directed else visited.add + # We use visited in a negative sense, so the visited set stays empty for the + # directed case and level edges are reported on their first occurrence in + # the undirected case. Note our use of visited.discard -- this is built-in + # thus somewhat faster than a python-defined def nop(x): pass + depth = {s: 0 for s in sources} + queue = deque(depth.items()) + push = queue.append + pop = queue.popleft + while queue: + u, du = pop() + for v in neighbors[u]: + if v not in depth: + depth[v] = dv = du + 1 + push((v, dv)) + yield u, v, TREE_EDGE + else: + dv = depth[v] + if du == dv: + if v not in visited: + yield u, v, LEVEL_EDGE + elif du < dv: + yield u, v, FORWARD_EDGE + elif directed: + yield u, v, REVERSE_EDGE + visit(u) + + +@nx._dispatchable +def descendants_at_distance(G, source, distance): + """Returns all nodes at a fixed `distance` from `source` in `G`. + + Parameters + ---------- + G : NetworkX graph + A graph + source : node in `G` + distance : the distance of the wanted nodes from `source` + + Returns + ------- + set() + The descendants of `source` in `G` at the given `distance` from `source` + + Examples + -------- + >>> G = nx.path_graph(5) + >>> nx.descendants_at_distance(G, 2, 2) + {0, 4} + >>> H = nx.DiGraph() + >>> H.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + >>> nx.descendants_at_distance(H, 0, 2) + {3, 4, 5, 6} + >>> nx.descendants_at_distance(H, 5, 0) + {5} + >>> nx.descendants_at_distance(H, 5, 1) + set() + """ + if source not in G: + raise nx.NetworkXError(f"The node {source} is not in the graph.") + + bfs_generator = nx.bfs_layers(G, source) + for i, layer in enumerate(bfs_generator): + if i == distance: + return set(layer) + return set() diff --git a/lib/python3.10/site-packages/networkx/algorithms/traversal/depth_first_search.py b/lib/python3.10/site-packages/networkx/algorithms/traversal/depth_first_search.py new file mode 100644 index 0000000000000000000000000000000000000000..5bac5ecfd1cbefcba5707cac2885ef32987ee98b --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/traversal/depth_first_search.py @@ -0,0 +1,529 @@ +"""Basic algorithms for depth-first searching the nodes of a graph.""" + +from collections import defaultdict + +import networkx as nx + +__all__ = [ + "dfs_edges", + "dfs_tree", + "dfs_predecessors", + "dfs_successors", + "dfs_preorder_nodes", + "dfs_postorder_nodes", + "dfs_labeled_edges", +] + + +@nx._dispatchable +def dfs_edges(G, source=None, depth_limit=None, *, sort_neighbors=None): + """Iterate over edges in a depth-first-search (DFS). + + Perform a depth-first-search over the nodes of `G` and yield + the edges in order. This may not generate all edges in `G` + (see `~networkx.algorithms.traversal.edgedfs.edge_dfs`). + + Parameters + ---------- + G : NetworkX graph + + source : node, optional + Specify starting node for depth-first search and yield edges in + the component reachable from source. + + depth_limit : int, optional (default=len(G)) + Specify the maximum search depth. + + sort_neighbors : function (default=None) + A function that takes an iterator over nodes as the input, and + returns an iterable of the same nodes with a custom ordering. + For example, `sorted` will sort the nodes in increasing order. + + Yields + ------ + edge: 2-tuple of nodes + Yields edges resulting from the depth-first-search. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> list(nx.dfs_edges(G, source=0)) + [(0, 1), (1, 2), (2, 3), (3, 4)] + >>> list(nx.dfs_edges(G, source=0, depth_limit=2)) + [(0, 1), (1, 2)] + + Notes + ----- + If a source is not specified then a source is chosen arbitrarily and + repeatedly until all components in the graph are searched. + + The implementation of this function is adapted from David Eppstein's + depth-first search function in PADS [1]_, with modifications + to allow depth limits based on the Wikipedia article + "Depth-limited search" [2]_. + + See Also + -------- + dfs_preorder_nodes + dfs_postorder_nodes + dfs_labeled_edges + :func:`~networkx.algorithms.traversal.edgedfs.edge_dfs` + :func:`~networkx.algorithms.traversal.breadth_first_search.bfs_edges` + + References + ---------- + .. [1] http://www.ics.uci.edu/~eppstein/PADS + .. [2] https://en.wikipedia.org/wiki/Depth-limited_search + """ + if source is None: + # edges for all components + nodes = G + else: + # edges for components with source + nodes = [source] + if depth_limit is None: + depth_limit = len(G) + + get_children = ( + G.neighbors + if sort_neighbors is None + else lambda n: iter(sort_neighbors(G.neighbors(n))) + ) + + visited = set() + for start in nodes: + if start in visited: + continue + visited.add(start) + stack = [(start, get_children(start))] + depth_now = 1 + while stack: + parent, children = stack[-1] + for child in children: + if child not in visited: + yield parent, child + visited.add(child) + if depth_now < depth_limit: + stack.append((child, get_children(child))) + depth_now += 1 + break + else: + stack.pop() + depth_now -= 1 + + +@nx._dispatchable(returns_graph=True) +def dfs_tree(G, source=None, depth_limit=None, *, sort_neighbors=None): + """Returns oriented tree constructed from a depth-first-search from source. + + Parameters + ---------- + G : NetworkX graph + + source : node, optional + Specify starting node for depth-first search. + + depth_limit : int, optional (default=len(G)) + Specify the maximum search depth. + + sort_neighbors : function (default=None) + A function that takes an iterator over nodes as the input, and + returns an iterable of the same nodes with a custom ordering. + For example, `sorted` will sort the nodes in increasing order. + + Returns + ------- + T : NetworkX DiGraph + An oriented tree + + Examples + -------- + >>> G = nx.path_graph(5) + >>> T = nx.dfs_tree(G, source=0, depth_limit=2) + >>> list(T.edges()) + [(0, 1), (1, 2)] + >>> T = nx.dfs_tree(G, source=0) + >>> list(T.edges()) + [(0, 1), (1, 2), (2, 3), (3, 4)] + + See Also + -------- + dfs_preorder_nodes + dfs_postorder_nodes + dfs_labeled_edges + :func:`~networkx.algorithms.traversal.edgedfs.edge_dfs` + :func:`~networkx.algorithms.traversal.breadth_first_search.bfs_tree` + """ + T = nx.DiGraph() + if source is None: + T.add_nodes_from(G) + else: + T.add_node(source) + T.add_edges_from(dfs_edges(G, source, depth_limit, sort_neighbors=sort_neighbors)) + return T + + +@nx._dispatchable +def dfs_predecessors(G, source=None, depth_limit=None, *, sort_neighbors=None): + """Returns dictionary of predecessors in depth-first-search from source. + + Parameters + ---------- + G : NetworkX graph + + source : node, optional + Specify starting node for depth-first search. + Note that you will get predecessors for all nodes in the + component containing `source`. This input only specifies + where the DFS starts. + + depth_limit : int, optional (default=len(G)) + Specify the maximum search depth. + + sort_neighbors : function (default=None) + A function that takes an iterator over nodes as the input, and + returns an iterable of the same nodes with a custom ordering. + For example, `sorted` will sort the nodes in increasing order. + + Returns + ------- + pred: dict + A dictionary with nodes as keys and predecessor nodes as values. + + Examples + -------- + >>> G = nx.path_graph(4) + >>> nx.dfs_predecessors(G, source=0) + {1: 0, 2: 1, 3: 2} + >>> nx.dfs_predecessors(G, source=0, depth_limit=2) + {1: 0, 2: 1} + + Notes + ----- + If a source is not specified then a source is chosen arbitrarily and + repeatedly until all components in the graph are searched. + + The implementation of this function is adapted from David Eppstein's + depth-first search function in `PADS`_, with modifications + to allow depth limits based on the Wikipedia article + "`Depth-limited search`_". + + .. _PADS: http://www.ics.uci.edu/~eppstein/PADS + .. _Depth-limited search: https://en.wikipedia.org/wiki/Depth-limited_search + + See Also + -------- + dfs_preorder_nodes + dfs_postorder_nodes + dfs_labeled_edges + :func:`~networkx.algorithms.traversal.edgedfs.edge_dfs` + :func:`~networkx.algorithms.traversal.breadth_first_search.bfs_tree` + """ + return { + t: s + for s, t in dfs_edges(G, source, depth_limit, sort_neighbors=sort_neighbors) + } + + +@nx._dispatchable +def dfs_successors(G, source=None, depth_limit=None, *, sort_neighbors=None): + """Returns dictionary of successors in depth-first-search from source. + + Parameters + ---------- + G : NetworkX graph + + source : node, optional + Specify starting node for depth-first search. + Note that you will get successors for all nodes in the + component containing `source`. This input only specifies + where the DFS starts. + + depth_limit : int, optional (default=len(G)) + Specify the maximum search depth. + + sort_neighbors : function (default=None) + A function that takes an iterator over nodes as the input, and + returns an iterable of the same nodes with a custom ordering. + For example, `sorted` will sort the nodes in increasing order. + + Returns + ------- + succ: dict + A dictionary with nodes as keys and list of successor nodes as values. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> nx.dfs_successors(G, source=0) + {0: [1], 1: [2], 2: [3], 3: [4]} + >>> nx.dfs_successors(G, source=0, depth_limit=2) + {0: [1], 1: [2]} + + Notes + ----- + If a source is not specified then a source is chosen arbitrarily and + repeatedly until all components in the graph are searched. + + The implementation of this function is adapted from David Eppstein's + depth-first search function in `PADS`_, with modifications + to allow depth limits based on the Wikipedia article + "`Depth-limited search`_". + + .. _PADS: http://www.ics.uci.edu/~eppstein/PADS + .. _Depth-limited search: https://en.wikipedia.org/wiki/Depth-limited_search + + See Also + -------- + dfs_preorder_nodes + dfs_postorder_nodes + dfs_labeled_edges + :func:`~networkx.algorithms.traversal.edgedfs.edge_dfs` + :func:`~networkx.algorithms.traversal.breadth_first_search.bfs_tree` + """ + d = defaultdict(list) + for s, t in dfs_edges( + G, + source=source, + depth_limit=depth_limit, + sort_neighbors=sort_neighbors, + ): + d[s].append(t) + return dict(d) + + +@nx._dispatchable +def dfs_postorder_nodes(G, source=None, depth_limit=None, *, sort_neighbors=None): + """Generate nodes in a depth-first-search post-ordering starting at source. + + Parameters + ---------- + G : NetworkX graph + + source : node, optional + Specify starting node for depth-first search. + + depth_limit : int, optional (default=len(G)) + Specify the maximum search depth. + + sort_neighbors : function (default=None) + A function that takes an iterator over nodes as the input, and + returns an iterable of the same nodes with a custom ordering. + For example, `sorted` will sort the nodes in increasing order. + + Returns + ------- + nodes: generator + A generator of nodes in a depth-first-search post-ordering. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> list(nx.dfs_postorder_nodes(G, source=0)) + [4, 3, 2, 1, 0] + >>> list(nx.dfs_postorder_nodes(G, source=0, depth_limit=2)) + [1, 0] + + Notes + ----- + If a source is not specified then a source is chosen arbitrarily and + repeatedly until all components in the graph are searched. + + The implementation of this function is adapted from David Eppstein's + depth-first search function in `PADS`_, with modifications + to allow depth limits based on the Wikipedia article + "`Depth-limited search`_". + + .. _PADS: http://www.ics.uci.edu/~eppstein/PADS + .. _Depth-limited search: https://en.wikipedia.org/wiki/Depth-limited_search + + See Also + -------- + dfs_edges + dfs_preorder_nodes + dfs_labeled_edges + :func:`~networkx.algorithms.traversal.edgedfs.edge_dfs` + :func:`~networkx.algorithms.traversal.breadth_first_search.bfs_tree` + """ + edges = nx.dfs_labeled_edges( + G, source=source, depth_limit=depth_limit, sort_neighbors=sort_neighbors + ) + return (v for u, v, d in edges if d == "reverse") + + +@nx._dispatchable +def dfs_preorder_nodes(G, source=None, depth_limit=None, *, sort_neighbors=None): + """Generate nodes in a depth-first-search pre-ordering starting at source. + + Parameters + ---------- + G : NetworkX graph + + source : node, optional + Specify starting node for depth-first search and return nodes in + the component reachable from source. + + depth_limit : int, optional (default=len(G)) + Specify the maximum search depth. + + sort_neighbors : function (default=None) + A function that takes an iterator over nodes as the input, and + returns an iterable of the same nodes with a custom ordering. + For example, `sorted` will sort the nodes in increasing order. + + Returns + ------- + nodes: generator + A generator of nodes in a depth-first-search pre-ordering. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> list(nx.dfs_preorder_nodes(G, source=0)) + [0, 1, 2, 3, 4] + >>> list(nx.dfs_preorder_nodes(G, source=0, depth_limit=2)) + [0, 1, 2] + + Notes + ----- + If a source is not specified then a source is chosen arbitrarily and + repeatedly until all components in the graph are searched. + + The implementation of this function is adapted from David Eppstein's + depth-first search function in `PADS`_, with modifications + to allow depth limits based on the Wikipedia article + "`Depth-limited search`_". + + .. _PADS: http://www.ics.uci.edu/~eppstein/PADS + .. _Depth-limited search: https://en.wikipedia.org/wiki/Depth-limited_search + + See Also + -------- + dfs_edges + dfs_postorder_nodes + dfs_labeled_edges + :func:`~networkx.algorithms.traversal.breadth_first_search.bfs_edges` + """ + edges = nx.dfs_labeled_edges( + G, source=source, depth_limit=depth_limit, sort_neighbors=sort_neighbors + ) + return (v for u, v, d in edges if d == "forward") + + +@nx._dispatchable +def dfs_labeled_edges(G, source=None, depth_limit=None, *, sort_neighbors=None): + """Iterate over edges in a depth-first-search (DFS) labeled by type. + + Parameters + ---------- + G : NetworkX graph + + source : node, optional + Specify starting node for depth-first search and return edges in + the component reachable from source. + + depth_limit : int, optional (default=len(G)) + Specify the maximum search depth. + + sort_neighbors : function (default=None) + A function that takes an iterator over nodes as the input, and + returns an iterable of the same nodes with a custom ordering. + For example, `sorted` will sort the nodes in increasing order. + + Returns + ------- + edges: generator + A generator of triples of the form (*u*, *v*, *d*), where (*u*, + *v*) is the edge being explored in the depth-first search and *d* + is one of the strings 'forward', 'nontree', 'reverse', or 'reverse-depth_limit'. + A 'forward' edge is one in which *u* has been visited but *v* has + not. A 'nontree' edge is one in which both *u* and *v* have been + visited but the edge is not in the DFS tree. A 'reverse' edge is + one in which both *u* and *v* have been visited and the edge is in + the DFS tree. When the `depth_limit` is reached via a 'forward' edge, + a 'reverse' edge is immediately generated rather than the subtree + being explored. To indicate this flavor of 'reverse' edge, the string + yielded is 'reverse-depth_limit'. + + Examples + -------- + + The labels reveal the complete transcript of the depth-first search + algorithm in more detail than, for example, :func:`dfs_edges`:: + + >>> from pprint import pprint + >>> + >>> G = nx.DiGraph([(0, 1), (1, 2), (2, 1)]) + >>> pprint(list(nx.dfs_labeled_edges(G, source=0))) + [(0, 0, 'forward'), + (0, 1, 'forward'), + (1, 2, 'forward'), + (2, 1, 'nontree'), + (1, 2, 'reverse'), + (0, 1, 'reverse'), + (0, 0, 'reverse')] + + Notes + ----- + If a source is not specified then a source is chosen arbitrarily and + repeatedly until all components in the graph are searched. + + The implementation of this function is adapted from David Eppstein's + depth-first search function in `PADS`_, with modifications + to allow depth limits based on the Wikipedia article + "`Depth-limited search`_". + + .. _PADS: http://www.ics.uci.edu/~eppstein/PADS + .. _Depth-limited search: https://en.wikipedia.org/wiki/Depth-limited_search + + See Also + -------- + dfs_edges + dfs_preorder_nodes + dfs_postorder_nodes + """ + # Based on http://www.ics.uci.edu/~eppstein/PADS/DFS.py + # by D. Eppstein, July 2004. + if source is None: + # edges for all components + nodes = G + else: + # edges for components with source + nodes = [source] + if depth_limit is None: + depth_limit = len(G) + + get_children = ( + G.neighbors + if sort_neighbors is None + else lambda n: iter(sort_neighbors(G.neighbors(n))) + ) + + visited = set() + for start in nodes: + if start in visited: + continue + yield start, start, "forward" + visited.add(start) + stack = [(start, get_children(start))] + depth_now = 1 + while stack: + parent, children = stack[-1] + for child in children: + if child in visited: + yield parent, child, "nontree" + else: + yield parent, child, "forward" + visited.add(child) + if depth_now < depth_limit: + stack.append((child, iter(get_children(child)))) + depth_now += 1 + break + else: + yield parent, child, "reverse-depth_limit" + else: + stack.pop() + depth_now -= 1 + if stack: + yield stack[-1][0], parent, "reverse" + yield start, start, "reverse" diff --git a/lib/python3.10/site-packages/networkx/algorithms/traversal/edgebfs.py b/lib/python3.10/site-packages/networkx/algorithms/traversal/edgebfs.py new file mode 100644 index 0000000000000000000000000000000000000000..6320ddc2a683187136103dd1cb18036ae3088d03 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/traversal/edgebfs.py @@ -0,0 +1,178 @@ +""" +============================= +Breadth First Search on Edges +============================= + +Algorithms for a breadth-first traversal of edges in a graph. + +""" + +from collections import deque + +import networkx as nx + +FORWARD = "forward" +REVERSE = "reverse" + +__all__ = ["edge_bfs"] + + +@nx._dispatchable +def edge_bfs(G, source=None, orientation=None): + """A directed, breadth-first-search of edges in `G`, beginning at `source`. + + Yield the edges of G in a breadth-first-search order continuing until + all edges are generated. + + Parameters + ---------- + G : graph + A directed/undirected graph/multigraph. + + source : node, list of nodes + The node from which the traversal begins. If None, then a source + is chosen arbitrarily and repeatedly until all edges from each node in + the graph are searched. + + orientation : None | 'original' | 'reverse' | 'ignore' (default: None) + For directed graphs and directed multigraphs, edge traversals need not + respect the original orientation of the edges. + When set to 'reverse' every edge is traversed in the reverse direction. + When set to 'ignore', every edge is treated as undirected. + When set to 'original', every edge is treated as directed. + In all three cases, the yielded edge tuples add a last entry to + indicate the direction in which that edge was traversed. + If orientation is None, the yielded edge has no direction indicated. + The direction is respected, but not reported. + + Yields + ------ + edge : directed edge + A directed edge indicating the path taken by the breadth-first-search. + For graphs, `edge` is of the form `(u, v)` where `u` and `v` + are the tail and head of the edge as determined by the traversal. + For multigraphs, `edge` is of the form `(u, v, key)`, where `key` is + the key of the edge. When the graph is directed, then `u` and `v` + are always in the order of the actual directed edge. + If orientation is not None then the edge tuple is extended to include + the direction of traversal ('forward' or 'reverse') on that edge. + + Examples + -------- + >>> nodes = [0, 1, 2, 3] + >>> edges = [(0, 1), (1, 0), (1, 0), (2, 0), (2, 1), (3, 1)] + + >>> list(nx.edge_bfs(nx.Graph(edges), nodes)) + [(0, 1), (0, 2), (1, 2), (1, 3)] + + >>> list(nx.edge_bfs(nx.DiGraph(edges), nodes)) + [(0, 1), (1, 0), (2, 0), (2, 1), (3, 1)] + + >>> list(nx.edge_bfs(nx.MultiGraph(edges), nodes)) + [(0, 1, 0), (0, 1, 1), (0, 1, 2), (0, 2, 0), (1, 2, 0), (1, 3, 0)] + + >>> list(nx.edge_bfs(nx.MultiDiGraph(edges), nodes)) + [(0, 1, 0), (1, 0, 0), (1, 0, 1), (2, 0, 0), (2, 1, 0), (3, 1, 0)] + + >>> list(nx.edge_bfs(nx.DiGraph(edges), nodes, orientation="ignore")) + [(0, 1, 'forward'), (1, 0, 'reverse'), (2, 0, 'reverse'), (2, 1, 'reverse'), (3, 1, 'reverse')] + + >>> list(nx.edge_bfs(nx.MultiDiGraph(edges), nodes, orientation="ignore")) + [(0, 1, 0, 'forward'), (1, 0, 0, 'reverse'), (1, 0, 1, 'reverse'), (2, 0, 0, 'reverse'), (2, 1, 0, 'reverse'), (3, 1, 0, 'reverse')] + + Notes + ----- + The goal of this function is to visit edges. It differs from the more + familiar breadth-first-search of nodes, as provided by + :func:`networkx.algorithms.traversal.breadth_first_search.bfs_edges`, in + that it does not stop once every node has been visited. In a directed graph + with edges [(0, 1), (1, 2), (2, 1)], the edge (2, 1) would not be visited + if not for the functionality provided by this function. + + The naming of this function is very similar to bfs_edges. The difference + is that 'edge_bfs' yields edges even if they extend back to an already + explored node while 'bfs_edges' yields the edges of the tree that results + from a breadth-first-search (BFS) so no edges are reported if they extend + to already explored nodes. That means 'edge_bfs' reports all edges while + 'bfs_edges' only report those traversed by a node-based BFS. Yet another + description is that 'bfs_edges' reports the edges traversed during BFS + while 'edge_bfs' reports all edges in the order they are explored. + + See Also + -------- + bfs_edges + bfs_tree + edge_dfs + + """ + nodes = list(G.nbunch_iter(source)) + if not nodes: + return + + directed = G.is_directed() + kwds = {"data": False} + if G.is_multigraph() is True: + kwds["keys"] = True + + # set up edge lookup + if orientation is None: + + def edges_from(node): + return iter(G.edges(node, **kwds)) + + elif not directed or orientation == "original": + + def edges_from(node): + for e in G.edges(node, **kwds): + yield e + (FORWARD,) + + elif orientation == "reverse": + + def edges_from(node): + for e in G.in_edges(node, **kwds): + yield e + (REVERSE,) + + elif orientation == "ignore": + + def edges_from(node): + for e in G.edges(node, **kwds): + yield e + (FORWARD,) + for e in G.in_edges(node, **kwds): + yield e + (REVERSE,) + + else: + raise nx.NetworkXError("invalid orientation argument.") + + if directed: + neighbors = G.successors + + def edge_id(edge): + # remove direction indicator + return edge[:-1] if orientation is not None else edge + + else: + neighbors = G.neighbors + + def edge_id(edge): + return (frozenset(edge[:2]),) + edge[2:] + + check_reverse = directed and orientation in ("reverse", "ignore") + + # start BFS + visited_nodes = set(nodes) + visited_edges = set() + queue = deque([(n, edges_from(n)) for n in nodes]) + while queue: + parent, children_edges = queue.popleft() + for edge in children_edges: + if check_reverse and edge[-1] == REVERSE: + child = edge[0] + else: + child = edge[1] + if child not in visited_nodes: + visited_nodes.add(child) + queue.append((child, edges_from(child))) + edgeid = edge_id(edge) + if edgeid not in visited_edges: + visited_edges.add(edgeid) + yield edge diff --git a/lib/python3.10/site-packages/networkx/algorithms/traversal/edgedfs.py b/lib/python3.10/site-packages/networkx/algorithms/traversal/edgedfs.py new file mode 100644 index 0000000000000000000000000000000000000000..8f657f39fdd8a24660772d7cc3cef0f641ed61c5 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/traversal/edgedfs.py @@ -0,0 +1,176 @@ +""" +=========================== +Depth First Search on Edges +=========================== + +Algorithms for a depth-first traversal of edges in a graph. + +""" + +import networkx as nx + +FORWARD = "forward" +REVERSE = "reverse" + +__all__ = ["edge_dfs"] + + +@nx._dispatchable +def edge_dfs(G, source=None, orientation=None): + """A directed, depth-first-search of edges in `G`, beginning at `source`. + + Yield the edges of G in a depth-first-search order continuing until + all edges are generated. + + Parameters + ---------- + G : graph + A directed/undirected graph/multigraph. + + source : node, list of nodes + The node from which the traversal begins. If None, then a source + is chosen arbitrarily and repeatedly until all edges from each node in + the graph are searched. + + orientation : None | 'original' | 'reverse' | 'ignore' (default: None) + For directed graphs and directed multigraphs, edge traversals need not + respect the original orientation of the edges. + When set to 'reverse' every edge is traversed in the reverse direction. + When set to 'ignore', every edge is treated as undirected. + When set to 'original', every edge is treated as directed. + In all three cases, the yielded edge tuples add a last entry to + indicate the direction in which that edge was traversed. + If orientation is None, the yielded edge has no direction indicated. + The direction is respected, but not reported. + + Yields + ------ + edge : directed edge + A directed edge indicating the path taken by the depth-first traversal. + For graphs, `edge` is of the form `(u, v)` where `u` and `v` + are the tail and head of the edge as determined by the traversal. + For multigraphs, `edge` is of the form `(u, v, key)`, where `key` is + the key of the edge. When the graph is directed, then `u` and `v` + are always in the order of the actual directed edge. + If orientation is not None then the edge tuple is extended to include + the direction of traversal ('forward' or 'reverse') on that edge. + + Examples + -------- + >>> nodes = [0, 1, 2, 3] + >>> edges = [(0, 1), (1, 0), (1, 0), (2, 1), (3, 1)] + + >>> list(nx.edge_dfs(nx.Graph(edges), nodes)) + [(0, 1), (1, 2), (1, 3)] + + >>> list(nx.edge_dfs(nx.DiGraph(edges), nodes)) + [(0, 1), (1, 0), (2, 1), (3, 1)] + + >>> list(nx.edge_dfs(nx.MultiGraph(edges), nodes)) + [(0, 1, 0), (1, 0, 1), (0, 1, 2), (1, 2, 0), (1, 3, 0)] + + >>> list(nx.edge_dfs(nx.MultiDiGraph(edges), nodes)) + [(0, 1, 0), (1, 0, 0), (1, 0, 1), (2, 1, 0), (3, 1, 0)] + + >>> list(nx.edge_dfs(nx.DiGraph(edges), nodes, orientation="ignore")) + [(0, 1, 'forward'), (1, 0, 'forward'), (2, 1, 'reverse'), (3, 1, 'reverse')] + + >>> list(nx.edge_dfs(nx.MultiDiGraph(edges), nodes, orientation="ignore")) + [(0, 1, 0, 'forward'), (1, 0, 0, 'forward'), (1, 0, 1, 'reverse'), (2, 1, 0, 'reverse'), (3, 1, 0, 'reverse')] + + Notes + ----- + The goal of this function is to visit edges. It differs from the more + familiar depth-first traversal of nodes, as provided by + :func:`~networkx.algorithms.traversal.depth_first_search.dfs_edges`, in + that it does not stop once every node has been visited. In a directed graph + with edges [(0, 1), (1, 2), (2, 1)], the edge (2, 1) would not be visited + if not for the functionality provided by this function. + + See Also + -------- + :func:`~networkx.algorithms.traversal.depth_first_search.dfs_edges` + + """ + nodes = list(G.nbunch_iter(source)) + if not nodes: + return + + directed = G.is_directed() + kwds = {"data": False} + if G.is_multigraph() is True: + kwds["keys"] = True + + # set up edge lookup + if orientation is None: + + def edges_from(node): + return iter(G.edges(node, **kwds)) + + elif not directed or orientation == "original": + + def edges_from(node): + for e in G.edges(node, **kwds): + yield e + (FORWARD,) + + elif orientation == "reverse": + + def edges_from(node): + for e in G.in_edges(node, **kwds): + yield e + (REVERSE,) + + elif orientation == "ignore": + + def edges_from(node): + for e in G.edges(node, **kwds): + yield e + (FORWARD,) + for e in G.in_edges(node, **kwds): + yield e + (REVERSE,) + + else: + raise nx.NetworkXError("invalid orientation argument.") + + # set up formation of edge_id to easily look up if edge already returned + if directed: + + def edge_id(edge): + # remove direction indicator + return edge[:-1] if orientation is not None else edge + + else: + + def edge_id(edge): + # single id for undirected requires frozenset on nodes + return (frozenset(edge[:2]),) + edge[2:] + + # Basic setup + check_reverse = directed and orientation in ("reverse", "ignore") + + visited_edges = set() + visited_nodes = set() + edges = {} + + # start DFS + for start_node in nodes: + stack = [start_node] + while stack: + current_node = stack[-1] + if current_node not in visited_nodes: + edges[current_node] = edges_from(current_node) + visited_nodes.add(current_node) + + try: + edge = next(edges[current_node]) + except StopIteration: + # No more edges from the current node. + stack.pop() + else: + edgeid = edge_id(edge) + if edgeid not in visited_edges: + visited_edges.add(edgeid) + # Mark the traversed "to" node as to-be-explored. + if check_reverse and edge[-1] == REVERSE: + stack.append(edge[0]) + else: + stack.append(edge[1]) + yield edge diff --git a/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..280ea1ab54a2a1155c8c23625151490eb5e7e5ef Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/__pycache__/test_beamsearch.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/__pycache__/test_beamsearch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b92132a1a429fc0687eb85c172605c542902fcc7 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/__pycache__/test_beamsearch.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/__pycache__/test_bfs.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/__pycache__/test_bfs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50f3a6a04b5fd9998bf0c424baaa170903ddda70 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/__pycache__/test_bfs.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/__pycache__/test_dfs.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/__pycache__/test_dfs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f619b1da8f94f909cb7f2463497c389fcab9d32 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/__pycache__/test_dfs.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/__pycache__/test_edgebfs.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/__pycache__/test_edgebfs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0631484487afe276541d5533f1b4a72c6825e829 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/__pycache__/test_edgebfs.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/__pycache__/test_edgedfs.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/__pycache__/test_edgedfs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..217978302b24e2f2cb367d6c3ce63da15b8080b0 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/__pycache__/test_edgedfs.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/test_beamsearch.py b/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/test_beamsearch.py new file mode 100644 index 0000000000000000000000000000000000000000..049f116b62fb595977d5ed94d01b9c15baf17fb3 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/test_beamsearch.py @@ -0,0 +1,25 @@ +"""Unit tests for the beam search functions.""" + +import pytest + +import networkx as nx + + +def test_narrow(): + """Tests that a narrow beam width may cause an incomplete search.""" + # In this search, we enqueue only the neighbor 3 at the first + # step, then only the neighbor 2 at the second step. Once at + # node 2, the search chooses node 3, since it has a higher value + # than node 1, but node 3 has already been visited, so the + # search terminates. + G = nx.cycle_graph(4) + edges = nx.bfs_beam_edges(G, source=0, value=lambda n: n, width=1) + assert list(edges) == [(0, 3), (3, 2)] + + +@pytest.mark.parametrize("width", (2, None)) +def test_wide(width): + """All nodes are searched when `width` is None or >= max degree""" + G = nx.cycle_graph(4) + edges = nx.bfs_beam_edges(G, source=0, value=lambda n: n, width=width) + assert list(edges) == [(0, 3), (0, 1), (3, 2)] diff --git a/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/test_bfs.py b/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/test_bfs.py new file mode 100644 index 0000000000000000000000000000000000000000..fcfbbc68dc113fe3363233a98faa3e99d44df689 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/test_bfs.py @@ -0,0 +1,203 @@ +from functools import partial + +import pytest + +import networkx as nx + + +class TestBFS: + @classmethod + def setup_class(cls): + # simple graph + G = nx.Graph() + G.add_edges_from([(0, 1), (1, 2), (1, 3), (2, 4), (3, 4)]) + cls.G = G + + def test_successor(self): + assert dict(nx.bfs_successors(self.G, source=0)) == {0: [1], 1: [2, 3], 2: [4]} + + def test_predecessor(self): + assert dict(nx.bfs_predecessors(self.G, source=0)) == {1: 0, 2: 1, 3: 1, 4: 2} + + def test_bfs_tree(self): + T = nx.bfs_tree(self.G, source=0) + assert sorted(T.nodes()) == sorted(self.G.nodes()) + assert sorted(T.edges()) == [(0, 1), (1, 2), (1, 3), (2, 4)] + + def test_bfs_edges(self): + edges = nx.bfs_edges(self.G, source=0) + assert list(edges) == [(0, 1), (1, 2), (1, 3), (2, 4)] + + def test_bfs_edges_reverse(self): + D = nx.DiGraph() + D.add_edges_from([(0, 1), (1, 2), (1, 3), (2, 4), (3, 4)]) + edges = nx.bfs_edges(D, source=4, reverse=True) + assert list(edges) == [(4, 2), (4, 3), (2, 1), (1, 0)] + + def test_bfs_edges_sorting(self): + D = nx.DiGraph() + D.add_edges_from([(0, 1), (0, 2), (1, 4), (1, 3), (2, 5)]) + sort_desc = partial(sorted, reverse=True) + edges_asc = nx.bfs_edges(D, source=0, sort_neighbors=sorted) + edges_desc = nx.bfs_edges(D, source=0, sort_neighbors=sort_desc) + assert list(edges_asc) == [(0, 1), (0, 2), (1, 3), (1, 4), (2, 5)] + assert list(edges_desc) == [(0, 2), (0, 1), (2, 5), (1, 4), (1, 3)] + + def test_bfs_tree_isolates(self): + G = nx.Graph() + G.add_node(1) + G.add_node(2) + T = nx.bfs_tree(G, source=1) + assert sorted(T.nodes()) == [1] + assert sorted(T.edges()) == [] + + def test_bfs_layers(self): + expected = { + 0: [0], + 1: [1], + 2: [2, 3], + 3: [4], + } + assert dict(enumerate(nx.bfs_layers(self.G, sources=[0]))) == expected + assert dict(enumerate(nx.bfs_layers(self.G, sources=0))) == expected + + def test_bfs_layers_missing_source(self): + with pytest.raises(nx.NetworkXError): + next(nx.bfs_layers(self.G, sources="abc")) + with pytest.raises(nx.NetworkXError): + next(nx.bfs_layers(self.G, sources=["abc"])) + + def test_descendants_at_distance(self): + for distance, descendants in enumerate([{0}, {1}, {2, 3}, {4}]): + assert nx.descendants_at_distance(self.G, 0, distance) == descendants + + def test_descendants_at_distance_missing_source(self): + with pytest.raises(nx.NetworkXError): + nx.descendants_at_distance(self.G, "abc", 0) + + def test_bfs_labeled_edges_directed(self): + D = nx.cycle_graph(5, create_using=nx.DiGraph) + expected = [ + (0, 1, "tree"), + (1, 2, "tree"), + (2, 3, "tree"), + (3, 4, "tree"), + (4, 0, "reverse"), + ] + answer = list(nx.bfs_labeled_edges(D, 0)) + assert expected == answer + + D.add_edge(4, 4) + expected.append((4, 4, "level")) + answer = list(nx.bfs_labeled_edges(D, 0)) + assert expected == answer + + D.add_edge(0, 2) + D.add_edge(1, 5) + D.add_edge(2, 5) + D.remove_edge(4, 4) + expected = [ + (0, 1, "tree"), + (0, 2, "tree"), + (1, 2, "level"), + (1, 5, "tree"), + (2, 3, "tree"), + (2, 5, "forward"), + (3, 4, "tree"), + (4, 0, "reverse"), + ] + answer = list(nx.bfs_labeled_edges(D, 0)) + assert expected == answer + + G = D.to_undirected() + G.add_edge(4, 4) + expected = [ + (0, 1, "tree"), + (0, 2, "tree"), + (0, 4, "tree"), + (1, 2, "level"), + (1, 5, "tree"), + (2, 3, "tree"), + (2, 5, "forward"), + (4, 3, "forward"), + (4, 4, "level"), + ] + answer = list(nx.bfs_labeled_edges(G, 0)) + assert expected == answer + + +class TestBreadthLimitedSearch: + @classmethod + def setup_class(cls): + # a tree + G = nx.Graph() + nx.add_path(G, [0, 1, 2, 3, 4, 5, 6]) + nx.add_path(G, [2, 7, 8, 9, 10]) + cls.G = G + # a disconnected graph + D = nx.Graph() + D.add_edges_from([(0, 1), (2, 3)]) + nx.add_path(D, [2, 7, 8, 9, 10]) + cls.D = D + + def test_limited_bfs_successor(self): + assert dict(nx.bfs_successors(self.G, source=1, depth_limit=3)) == { + 1: [0, 2], + 2: [3, 7], + 3: [4], + 7: [8], + } + result = { + n: sorted(s) for n, s in nx.bfs_successors(self.D, source=7, depth_limit=2) + } + assert result == {8: [9], 2: [3], 7: [2, 8]} + + def test_limited_bfs_predecessor(self): + assert dict(nx.bfs_predecessors(self.G, source=1, depth_limit=3)) == { + 0: 1, + 2: 1, + 3: 2, + 4: 3, + 7: 2, + 8: 7, + } + assert dict(nx.bfs_predecessors(self.D, source=7, depth_limit=2)) == { + 2: 7, + 3: 2, + 8: 7, + 9: 8, + } + + def test_limited_bfs_tree(self): + T = nx.bfs_tree(self.G, source=3, depth_limit=1) + assert sorted(T.edges()) == [(3, 2), (3, 4)] + + def test_limited_bfs_edges(self): + edges = nx.bfs_edges(self.G, source=9, depth_limit=4) + assert list(edges) == [(9, 8), (9, 10), (8, 7), (7, 2), (2, 1), (2, 3)] + + def test_limited_bfs_layers(self): + assert dict(enumerate(nx.bfs_layers(self.G, sources=[0]))) == { + 0: [0], + 1: [1], + 2: [2], + 3: [3, 7], + 4: [4, 8], + 5: [5, 9], + 6: [6, 10], + } + assert dict(enumerate(nx.bfs_layers(self.D, sources=2))) == { + 0: [2], + 1: [3, 7], + 2: [8], + 3: [9], + 4: [10], + } + + def test_limited_descendants_at_distance(self): + for distance, descendants in enumerate( + [{0}, {1}, {2}, {3, 7}, {4, 8}, {5, 9}, {6, 10}] + ): + assert nx.descendants_at_distance(self.G, 0, distance) == descendants + for distance, descendants in enumerate([{2}, {3, 7}, {8}, {9}, {10}]): + assert nx.descendants_at_distance(self.D, 2, distance) == descendants diff --git a/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/test_dfs.py b/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/test_dfs.py new file mode 100644 index 0000000000000000000000000000000000000000..e43d7d61629838aedc01b0a1624e018c0268fe11 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/test_dfs.py @@ -0,0 +1,305 @@ +import networkx as nx + + +class TestDFS: + @classmethod + def setup_class(cls): + # simple graph + G = nx.Graph() + G.add_edges_from([(0, 1), (1, 2), (1, 3), (2, 4), (3, 0), (0, 4)]) + cls.G = G + # simple graph, disconnected + D = nx.Graph() + D.add_edges_from([(0, 1), (2, 3)]) + cls.D = D + + def test_preorder_nodes(self): + assert list(nx.dfs_preorder_nodes(self.G, source=0)) == [0, 1, 2, 4, 3] + assert list(nx.dfs_preorder_nodes(self.D)) == [0, 1, 2, 3] + assert list(nx.dfs_preorder_nodes(self.D, source=2)) == [2, 3] + + def test_postorder_nodes(self): + assert list(nx.dfs_postorder_nodes(self.G, source=0)) == [4, 2, 3, 1, 0] + assert list(nx.dfs_postorder_nodes(self.D)) == [1, 0, 3, 2] + assert list(nx.dfs_postorder_nodes(self.D, source=0)) == [1, 0] + + def test_successor(self): + assert nx.dfs_successors(self.G, source=0) == {0: [1], 1: [2, 3], 2: [4]} + assert nx.dfs_successors(self.G, source=1) == {0: [3, 4], 1: [0], 4: [2]} + assert nx.dfs_successors(self.D) == {0: [1], 2: [3]} + assert nx.dfs_successors(self.D, source=1) == {1: [0]} + + def test_predecessor(self): + assert nx.dfs_predecessors(self.G, source=0) == {1: 0, 2: 1, 3: 1, 4: 2} + assert nx.dfs_predecessors(self.D) == {1: 0, 3: 2} + + def test_dfs_tree(self): + exp_nodes = sorted(self.G.nodes()) + exp_edges = [(0, 1), (1, 2), (1, 3), (2, 4)] + # Search from first node + T = nx.dfs_tree(self.G, source=0) + assert sorted(T.nodes()) == exp_nodes + assert sorted(T.edges()) == exp_edges + # Check source=None + T = nx.dfs_tree(self.G, source=None) + assert sorted(T.nodes()) == exp_nodes + assert sorted(T.edges()) == exp_edges + # Check source=None is the default + T = nx.dfs_tree(self.G) + assert sorted(T.nodes()) == exp_nodes + assert sorted(T.edges()) == exp_edges + + def test_dfs_edges(self): + edges = nx.dfs_edges(self.G, source=0) + assert list(edges) == [(0, 1), (1, 2), (2, 4), (1, 3)] + edges = nx.dfs_edges(self.D) + assert list(edges) == [(0, 1), (2, 3)] + + def test_dfs_edges_sorting(self): + G = nx.Graph([(0, 1), (1, 2), (1, 3), (2, 4), (3, 0), (0, 4)]) + edges_asc = nx.dfs_edges(G, source=0, sort_neighbors=sorted) + sorted_desc = lambda x: sorted(x, reverse=True) + edges_desc = nx.dfs_edges(G, source=0, sort_neighbors=sorted_desc) + assert list(edges_asc) == [(0, 1), (1, 2), (2, 4), (1, 3)] + assert list(edges_desc) == [(0, 4), (4, 2), (2, 1), (1, 3)] + + def test_dfs_labeled_edges(self): + edges = list(nx.dfs_labeled_edges(self.G, source=0)) + forward = [(u, v) for (u, v, d) in edges if d == "forward"] + assert forward == [(0, 0), (0, 1), (1, 2), (2, 4), (1, 3)] + assert edges == [ + (0, 0, "forward"), + (0, 1, "forward"), + (1, 0, "nontree"), + (1, 2, "forward"), + (2, 1, "nontree"), + (2, 4, "forward"), + (4, 2, "nontree"), + (4, 0, "nontree"), + (2, 4, "reverse"), + (1, 2, "reverse"), + (1, 3, "forward"), + (3, 1, "nontree"), + (3, 0, "nontree"), + (1, 3, "reverse"), + (0, 1, "reverse"), + (0, 3, "nontree"), + (0, 4, "nontree"), + (0, 0, "reverse"), + ] + + def test_dfs_labeled_edges_sorting(self): + G = nx.Graph([(0, 1), (1, 2), (1, 3), (2, 4), (3, 0), (0, 4)]) + edges_asc = nx.dfs_labeled_edges(G, source=0, sort_neighbors=sorted) + sorted_desc = lambda x: sorted(x, reverse=True) + edges_desc = nx.dfs_labeled_edges(G, source=0, sort_neighbors=sorted_desc) + assert list(edges_asc) == [ + (0, 0, "forward"), + (0, 1, "forward"), + (1, 0, "nontree"), + (1, 2, "forward"), + (2, 1, "nontree"), + (2, 4, "forward"), + (4, 0, "nontree"), + (4, 2, "nontree"), + (2, 4, "reverse"), + (1, 2, "reverse"), + (1, 3, "forward"), + (3, 0, "nontree"), + (3, 1, "nontree"), + (1, 3, "reverse"), + (0, 1, "reverse"), + (0, 3, "nontree"), + (0, 4, "nontree"), + (0, 0, "reverse"), + ] + assert list(edges_desc) == [ + (0, 0, "forward"), + (0, 4, "forward"), + (4, 2, "forward"), + (2, 4, "nontree"), + (2, 1, "forward"), + (1, 3, "forward"), + (3, 1, "nontree"), + (3, 0, "nontree"), + (1, 3, "reverse"), + (1, 2, "nontree"), + (1, 0, "nontree"), + (2, 1, "reverse"), + (4, 2, "reverse"), + (4, 0, "nontree"), + (0, 4, "reverse"), + (0, 3, "nontree"), + (0, 1, "nontree"), + (0, 0, "reverse"), + ] + + def test_dfs_labeled_disconnected_edges(self): + edges = list(nx.dfs_labeled_edges(self.D)) + forward = [(u, v) for (u, v, d) in edges if d == "forward"] + assert forward == [(0, 0), (0, 1), (2, 2), (2, 3)] + assert edges == [ + (0, 0, "forward"), + (0, 1, "forward"), + (1, 0, "nontree"), + (0, 1, "reverse"), + (0, 0, "reverse"), + (2, 2, "forward"), + (2, 3, "forward"), + (3, 2, "nontree"), + (2, 3, "reverse"), + (2, 2, "reverse"), + ] + + def test_dfs_tree_isolates(self): + G = nx.Graph() + G.add_node(1) + G.add_node(2) + T = nx.dfs_tree(G, source=1) + assert sorted(T.nodes()) == [1] + assert sorted(T.edges()) == [] + T = nx.dfs_tree(G, source=None) + assert sorted(T.nodes()) == [1, 2] + assert sorted(T.edges()) == [] + + +class TestDepthLimitedSearch: + @classmethod + def setup_class(cls): + # a tree + G = nx.Graph() + nx.add_path(G, [0, 1, 2, 3, 4, 5, 6]) + nx.add_path(G, [2, 7, 8, 9, 10]) + cls.G = G + # a disconnected graph + D = nx.Graph() + D.add_edges_from([(0, 1), (2, 3)]) + nx.add_path(D, [2, 7, 8, 9, 10]) + cls.D = D + + def test_dls_preorder_nodes(self): + assert list(nx.dfs_preorder_nodes(self.G, source=0, depth_limit=2)) == [0, 1, 2] + assert list(nx.dfs_preorder_nodes(self.D, source=1, depth_limit=2)) == ([1, 0]) + + def test_dls_postorder_nodes(self): + assert list(nx.dfs_postorder_nodes(self.G, source=3, depth_limit=3)) == [ + 1, + 7, + 2, + 5, + 4, + 3, + ] + assert list(nx.dfs_postorder_nodes(self.D, source=2, depth_limit=2)) == ( + [3, 7, 2] + ) + + def test_dls_successor(self): + result = nx.dfs_successors(self.G, source=4, depth_limit=3) + assert {n: set(v) for n, v in result.items()} == { + 2: {1, 7}, + 3: {2}, + 4: {3, 5}, + 5: {6}, + } + result = nx.dfs_successors(self.D, source=7, depth_limit=2) + assert {n: set(v) for n, v in result.items()} == {8: {9}, 2: {3}, 7: {8, 2}} + + def test_dls_predecessor(self): + assert nx.dfs_predecessors(self.G, source=0, depth_limit=3) == { + 1: 0, + 2: 1, + 3: 2, + 7: 2, + } + assert nx.dfs_predecessors(self.D, source=2, depth_limit=3) == { + 8: 7, + 9: 8, + 3: 2, + 7: 2, + } + + def test_dls_tree(self): + T = nx.dfs_tree(self.G, source=3, depth_limit=1) + assert sorted(T.edges()) == [(3, 2), (3, 4)] + + def test_dls_edges(self): + edges = nx.dfs_edges(self.G, source=9, depth_limit=4) + assert list(edges) == [(9, 8), (8, 7), (7, 2), (2, 1), (2, 3), (9, 10)] + + def test_dls_labeled_edges_depth_1(self): + edges = list(nx.dfs_labeled_edges(self.G, source=5, depth_limit=1)) + forward = [(u, v) for (u, v, d) in edges if d == "forward"] + assert forward == [(5, 5), (5, 4), (5, 6)] + # Note: reverse-depth_limit edge types were not reported before gh-6240 + assert edges == [ + (5, 5, "forward"), + (5, 4, "forward"), + (5, 4, "reverse-depth_limit"), + (5, 6, "forward"), + (5, 6, "reverse-depth_limit"), + (5, 5, "reverse"), + ] + + def test_dls_labeled_edges_depth_2(self): + edges = list(nx.dfs_labeled_edges(self.G, source=6, depth_limit=2)) + forward = [(u, v) for (u, v, d) in edges if d == "forward"] + assert forward == [(6, 6), (6, 5), (5, 4)] + assert edges == [ + (6, 6, "forward"), + (6, 5, "forward"), + (5, 4, "forward"), + (5, 4, "reverse-depth_limit"), + (5, 6, "nontree"), + (6, 5, "reverse"), + (6, 6, "reverse"), + ] + + def test_dls_labeled_disconnected_edges(self): + edges = list(nx.dfs_labeled_edges(self.D, depth_limit=1)) + assert edges == [ + (0, 0, "forward"), + (0, 1, "forward"), + (0, 1, "reverse-depth_limit"), + (0, 0, "reverse"), + (2, 2, "forward"), + (2, 3, "forward"), + (2, 3, "reverse-depth_limit"), + (2, 7, "forward"), + (2, 7, "reverse-depth_limit"), + (2, 2, "reverse"), + (8, 8, "forward"), + (8, 7, "nontree"), + (8, 9, "forward"), + (8, 9, "reverse-depth_limit"), + (8, 8, "reverse"), + (10, 10, "forward"), + (10, 9, "nontree"), + (10, 10, "reverse"), + ] + # large depth_limit has no impact + edges = list(nx.dfs_labeled_edges(self.D, depth_limit=19)) + assert edges == [ + (0, 0, "forward"), + (0, 1, "forward"), + (1, 0, "nontree"), + (0, 1, "reverse"), + (0, 0, "reverse"), + (2, 2, "forward"), + (2, 3, "forward"), + (3, 2, "nontree"), + (2, 3, "reverse"), + (2, 7, "forward"), + (7, 2, "nontree"), + (7, 8, "forward"), + (8, 7, "nontree"), + (8, 9, "forward"), + (9, 8, "nontree"), + (9, 10, "forward"), + (10, 9, "nontree"), + (9, 10, "reverse"), + (8, 9, "reverse"), + (7, 8, "reverse"), + (2, 7, "reverse"), + (2, 2, "reverse"), + ] diff --git a/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/test_edgebfs.py b/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/test_edgebfs.py new file mode 100644 index 0000000000000000000000000000000000000000..1bf3fae0bd067dd548281e3382a6125f6e50ee22 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/test_edgebfs.py @@ -0,0 +1,147 @@ +import pytest + +import networkx as nx +from networkx.algorithms.traversal.edgedfs import FORWARD, REVERSE + + +class TestEdgeBFS: + @classmethod + def setup_class(cls): + cls.nodes = [0, 1, 2, 3] + cls.edges = [(0, 1), (1, 0), (1, 0), (2, 0), (2, 1), (3, 1)] + + def test_empty(self): + G = nx.Graph() + edges = list(nx.edge_bfs(G)) + assert edges == [] + + def test_graph_single_source(self): + G = nx.Graph(self.edges) + G.add_edge(4, 5) + x = list(nx.edge_bfs(G, [0])) + x_ = [(0, 1), (0, 2), (1, 2), (1, 3)] + assert x == x_ + + def test_graph(self): + G = nx.Graph(self.edges) + x = list(nx.edge_bfs(G, self.nodes)) + x_ = [(0, 1), (0, 2), (1, 2), (1, 3)] + assert x == x_ + + def test_digraph(self): + G = nx.DiGraph(self.edges) + x = list(nx.edge_bfs(G, self.nodes)) + x_ = [(0, 1), (1, 0), (2, 0), (2, 1), (3, 1)] + assert x == x_ + + def test_digraph_orientation_invalid(self): + G = nx.DiGraph(self.edges) + edge_iterator = nx.edge_bfs(G, self.nodes, orientation="hello") + pytest.raises(nx.NetworkXError, list, edge_iterator) + + def test_digraph_orientation_none(self): + G = nx.DiGraph(self.edges) + x = list(nx.edge_bfs(G, self.nodes, orientation=None)) + x_ = [(0, 1), (1, 0), (2, 0), (2, 1), (3, 1)] + assert x == x_ + + def test_digraph_orientation_original(self): + G = nx.DiGraph(self.edges) + x = list(nx.edge_bfs(G, self.nodes, orientation="original")) + x_ = [ + (0, 1, FORWARD), + (1, 0, FORWARD), + (2, 0, FORWARD), + (2, 1, FORWARD), + (3, 1, FORWARD), + ] + assert x == x_ + + def test_digraph2(self): + G = nx.DiGraph() + nx.add_path(G, range(4)) + x = list(nx.edge_bfs(G, [0])) + x_ = [(0, 1), (1, 2), (2, 3)] + assert x == x_ + + def test_digraph_rev(self): + G = nx.DiGraph(self.edges) + x = list(nx.edge_bfs(G, self.nodes, orientation="reverse")) + x_ = [ + (1, 0, REVERSE), + (2, 0, REVERSE), + (0, 1, REVERSE), + (2, 1, REVERSE), + (3, 1, REVERSE), + ] + assert x == x_ + + def test_digraph_rev2(self): + G = nx.DiGraph() + nx.add_path(G, range(4)) + x = list(nx.edge_bfs(G, [3], orientation="reverse")) + x_ = [(2, 3, REVERSE), (1, 2, REVERSE), (0, 1, REVERSE)] + assert x == x_ + + def test_multigraph(self): + G = nx.MultiGraph(self.edges) + x = list(nx.edge_bfs(G, self.nodes)) + x_ = [(0, 1, 0), (0, 1, 1), (0, 1, 2), (0, 2, 0), (1, 2, 0), (1, 3, 0)] + # This is an example of where hash randomization can break. + # There are 3! * 2 alternative outputs, such as: + # [(0, 1, 1), (1, 0, 0), (0, 1, 2), (1, 3, 0), (1, 2, 0)] + # But note, the edges (1,2,0) and (1,3,0) always follow the (0,1,k) + # edges. So the algorithm only guarantees a partial order. A total + # order is guaranteed only if the graph data structures are ordered. + assert x == x_ + + def test_multidigraph(self): + G = nx.MultiDiGraph(self.edges) + x = list(nx.edge_bfs(G, self.nodes)) + x_ = [(0, 1, 0), (1, 0, 0), (1, 0, 1), (2, 0, 0), (2, 1, 0), (3, 1, 0)] + assert x == x_ + + def test_multidigraph_rev(self): + G = nx.MultiDiGraph(self.edges) + x = list(nx.edge_bfs(G, self.nodes, orientation="reverse")) + x_ = [ + (1, 0, 0, REVERSE), + (1, 0, 1, REVERSE), + (2, 0, 0, REVERSE), + (0, 1, 0, REVERSE), + (2, 1, 0, REVERSE), + (3, 1, 0, REVERSE), + ] + assert x == x_ + + def test_digraph_ignore(self): + G = nx.DiGraph(self.edges) + x = list(nx.edge_bfs(G, self.nodes, orientation="ignore")) + x_ = [ + (0, 1, FORWARD), + (1, 0, REVERSE), + (2, 0, REVERSE), + (2, 1, REVERSE), + (3, 1, REVERSE), + ] + assert x == x_ + + def test_digraph_ignore2(self): + G = nx.DiGraph() + nx.add_path(G, range(4)) + x = list(nx.edge_bfs(G, [0], orientation="ignore")) + x_ = [(0, 1, FORWARD), (1, 2, FORWARD), (2, 3, FORWARD)] + assert x == x_ + + def test_multidigraph_ignore(self): + G = nx.MultiDiGraph(self.edges) + x = list(nx.edge_bfs(G, self.nodes, orientation="ignore")) + x_ = [ + (0, 1, 0, FORWARD), + (1, 0, 0, REVERSE), + (1, 0, 1, REVERSE), + (2, 0, 0, REVERSE), + (2, 1, 0, REVERSE), + (3, 1, 0, REVERSE), + ] + assert x == x_ diff --git a/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/test_edgedfs.py b/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/test_edgedfs.py new file mode 100644 index 0000000000000000000000000000000000000000..7c1967cce04b3a0c9db80f9af39d7b1dfd8ef4cb --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/traversal/tests/test_edgedfs.py @@ -0,0 +1,131 @@ +import pytest + +import networkx as nx +from networkx.algorithms import edge_dfs +from networkx.algorithms.traversal.edgedfs import FORWARD, REVERSE + +# These tests can fail with hash randomization. The easiest and clearest way +# to write these unit tests is for the edges to be output in an expected total +# order, but we cannot guarantee the order amongst outgoing edges from a node, +# unless each class uses an ordered data structure for neighbors. This is +# painful to do with the current API. The alternative is that the tests are +# written (IMO confusingly) so that there is not a total order over the edges, +# but only a partial order. Due to the small size of the graphs, hopefully +# failures due to hash randomization will not occur. For an example of how +# this can fail, see TestEdgeDFS.test_multigraph. + + +class TestEdgeDFS: + @classmethod + def setup_class(cls): + cls.nodes = [0, 1, 2, 3] + cls.edges = [(0, 1), (1, 0), (1, 0), (2, 1), (3, 1)] + + def test_empty(self): + G = nx.Graph() + edges = list(edge_dfs(G)) + assert edges == [] + + def test_graph(self): + G = nx.Graph(self.edges) + x = list(edge_dfs(G, self.nodes)) + x_ = [(0, 1), (1, 2), (1, 3)] + assert x == x_ + + def test_digraph(self): + G = nx.DiGraph(self.edges) + x = list(edge_dfs(G, self.nodes)) + x_ = [(0, 1), (1, 0), (2, 1), (3, 1)] + assert x == x_ + + def test_digraph_orientation_invalid(self): + G = nx.DiGraph(self.edges) + edge_iterator = edge_dfs(G, self.nodes, orientation="hello") + pytest.raises(nx.NetworkXError, list, edge_iterator) + + def test_digraph_orientation_none(self): + G = nx.DiGraph(self.edges) + x = list(edge_dfs(G, self.nodes, orientation=None)) + x_ = [(0, 1), (1, 0), (2, 1), (3, 1)] + assert x == x_ + + def test_digraph_orientation_original(self): + G = nx.DiGraph(self.edges) + x = list(edge_dfs(G, self.nodes, orientation="original")) + x_ = [(0, 1, FORWARD), (1, 0, FORWARD), (2, 1, FORWARD), (3, 1, FORWARD)] + assert x == x_ + + def test_digraph2(self): + G = nx.DiGraph() + nx.add_path(G, range(4)) + x = list(edge_dfs(G, [0])) + x_ = [(0, 1), (1, 2), (2, 3)] + assert x == x_ + + def test_digraph_rev(self): + G = nx.DiGraph(self.edges) + x = list(edge_dfs(G, self.nodes, orientation="reverse")) + x_ = [(1, 0, REVERSE), (0, 1, REVERSE), (2, 1, REVERSE), (3, 1, REVERSE)] + assert x == x_ + + def test_digraph_rev2(self): + G = nx.DiGraph() + nx.add_path(G, range(4)) + x = list(edge_dfs(G, [3], orientation="reverse")) + x_ = [(2, 3, REVERSE), (1, 2, REVERSE), (0, 1, REVERSE)] + assert x == x_ + + def test_multigraph(self): + G = nx.MultiGraph(self.edges) + x = list(edge_dfs(G, self.nodes)) + x_ = [(0, 1, 0), (1, 0, 1), (0, 1, 2), (1, 2, 0), (1, 3, 0)] + # This is an example of where hash randomization can break. + # There are 3! * 2 alternative outputs, such as: + # [(0, 1, 1), (1, 0, 0), (0, 1, 2), (1, 3, 0), (1, 2, 0)] + # But note, the edges (1,2,0) and (1,3,0) always follow the (0,1,k) + # edges. So the algorithm only guarantees a partial order. A total + # order is guaranteed only if the graph data structures are ordered. + assert x == x_ + + def test_multidigraph(self): + G = nx.MultiDiGraph(self.edges) + x = list(edge_dfs(G, self.nodes)) + x_ = [(0, 1, 0), (1, 0, 0), (1, 0, 1), (2, 1, 0), (3, 1, 0)] + assert x == x_ + + def test_multidigraph_rev(self): + G = nx.MultiDiGraph(self.edges) + x = list(edge_dfs(G, self.nodes, orientation="reverse")) + x_ = [ + (1, 0, 0, REVERSE), + (0, 1, 0, REVERSE), + (1, 0, 1, REVERSE), + (2, 1, 0, REVERSE), + (3, 1, 0, REVERSE), + ] + assert x == x_ + + def test_digraph_ignore(self): + G = nx.DiGraph(self.edges) + x = list(edge_dfs(G, self.nodes, orientation="ignore")) + x_ = [(0, 1, FORWARD), (1, 0, FORWARD), (2, 1, REVERSE), (3, 1, REVERSE)] + assert x == x_ + + def test_digraph_ignore2(self): + G = nx.DiGraph() + nx.add_path(G, range(4)) + x = list(edge_dfs(G, [0], orientation="ignore")) + x_ = [(0, 1, FORWARD), (1, 2, FORWARD), (2, 3, FORWARD)] + assert x == x_ + + def test_multidigraph_ignore(self): + G = nx.MultiDiGraph(self.edges) + x = list(edge_dfs(G, self.nodes, orientation="ignore")) + x_ = [ + (0, 1, 0, FORWARD), + (1, 0, 0, FORWARD), + (1, 0, 1, REVERSE), + (2, 1, 0, REVERSE), + (3, 1, 0, REVERSE), + ] + assert x == x_ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tree/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/tree/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7120d4bc7ef25279b68eaa23690b6ff4574ed676 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tree/__init__.py @@ -0,0 +1,6 @@ +from .branchings import * +from .coding import * +from .mst import * +from .recognition import * +from .operations import * +from .decomposition import * diff --git a/lib/python3.10/site-packages/networkx/algorithms/tree/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tree/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..192f7f79086b25d6e1ca20aa10a672b12359dc03 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tree/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tree/__pycache__/branchings.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tree/__pycache__/branchings.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..295fe8d2ab0268edf241fd8f695c3eda76812d1a Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tree/__pycache__/branchings.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tree/__pycache__/coding.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tree/__pycache__/coding.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d3e451c0ac30b864e365bc57f0469b2039b450e Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tree/__pycache__/coding.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tree/__pycache__/decomposition.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tree/__pycache__/decomposition.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9339e719d33b06c23e40b4be9fa8d2914b7acf22 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tree/__pycache__/decomposition.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tree/__pycache__/mst.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tree/__pycache__/mst.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..867f3b190f0f46253987a98a8f75ab1491a18cca Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tree/__pycache__/mst.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tree/__pycache__/operations.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tree/__pycache__/operations.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d392287e80431d6a50875d23e0002b8f0448c5f1 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tree/__pycache__/operations.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tree/__pycache__/recognition.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tree/__pycache__/recognition.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee0c299670c72be03a059adde277ab2b46491230 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tree/__pycache__/recognition.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tree/branchings.py b/lib/python3.10/site-packages/networkx/algorithms/tree/branchings.py new file mode 100644 index 0000000000000000000000000000000000000000..cc9c7cf1189d341577a81501e5ca6760ed73a58c --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tree/branchings.py @@ -0,0 +1,1042 @@ +""" +Algorithms for finding optimum branchings and spanning arborescences. + +This implementation is based on: + + J. Edmonds, Optimum branchings, J. Res. Natl. Bur. Standards 71B (1967), + 233–240. URL: http://archive.org/details/jresv71Bn4p233 + +""" + +# TODO: Implement method from Gabow, Galil, Spence and Tarjan: +# +# @article{ +# year={1986}, +# issn={0209-9683}, +# journal={Combinatorica}, +# volume={6}, +# number={2}, +# doi={10.1007/BF02579168}, +# title={Efficient algorithms for finding minimum spanning trees in +# undirected and directed graphs}, +# url={https://doi.org/10.1007/BF02579168}, +# publisher={Springer-Verlag}, +# keywords={68 B 15; 68 C 05}, +# author={Gabow, Harold N. and Galil, Zvi and Spencer, Thomas and Tarjan, +# Robert E.}, +# pages={109-122}, +# language={English} +# } +import string +from dataclasses import dataclass, field +from operator import itemgetter +from queue import PriorityQueue + +import networkx as nx +from networkx.utils import py_random_state + +from .recognition import is_arborescence, is_branching + +__all__ = [ + "branching_weight", + "greedy_branching", + "maximum_branching", + "minimum_branching", + "minimal_branching", + "maximum_spanning_arborescence", + "minimum_spanning_arborescence", + "ArborescenceIterator", +] + +KINDS = {"max", "min"} + +STYLES = { + "branching": "branching", + "arborescence": "arborescence", + "spanning arborescence": "arborescence", +} + +INF = float("inf") + + +@py_random_state(1) +def random_string(L=15, seed=None): + return "".join([seed.choice(string.ascii_letters) for n in range(L)]) + + +def _min_weight(weight): + return -weight + + +def _max_weight(weight): + return weight + + +@nx._dispatchable(edge_attrs={"attr": "default"}) +def branching_weight(G, attr="weight", default=1): + """ + Returns the total weight of a branching. + + You must access this function through the networkx.algorithms.tree module. + + Parameters + ---------- + G : DiGraph + The directed graph. + attr : str + The attribute to use as weights. If None, then each edge will be + treated equally with a weight of 1. + default : float + When `attr` is not None, then if an edge does not have that attribute, + `default` specifies what value it should take. + + Returns + ------- + weight: int or float + The total weight of the branching. + + Examples + -------- + >>> G = nx.DiGraph() + >>> G.add_weighted_edges_from([(0, 1, 2), (1, 2, 4), (2, 3, 3), (3, 4, 2)]) + >>> nx.tree.branching_weight(G) + 11 + + """ + return sum(edge[2].get(attr, default) for edge in G.edges(data=True)) + + +@py_random_state(4) +@nx._dispatchable(edge_attrs={"attr": "default"}, returns_graph=True) +def greedy_branching(G, attr="weight", default=1, kind="max", seed=None): + """ + Returns a branching obtained through a greedy algorithm. + + This algorithm is wrong, and cannot give a proper optimal branching. + However, we include it for pedagogical reasons, as it can be helpful to + see what its outputs are. + + The output is a branching, and possibly, a spanning arborescence. However, + it is not guaranteed to be optimal in either case. + + Parameters + ---------- + G : DiGraph + The directed graph to scan. + attr : str + The attribute to use as weights. If None, then each edge will be + treated equally with a weight of 1. + default : float + When `attr` is not None, then if an edge does not have that attribute, + `default` specifies what value it should take. + kind : str + The type of optimum to search for: 'min' or 'max' greedy branching. + seed : integer, random_state, or None (default) + Indicator of random number generation state. + See :ref:`Randomness`. + + Returns + ------- + B : directed graph + The greedily obtained branching. + + """ + if kind not in KINDS: + raise nx.NetworkXException("Unknown value for `kind`.") + + if kind == "min": + reverse = False + else: + reverse = True + + if attr is None: + # Generate a random string the graph probably won't have. + attr = random_string(seed=seed) + + edges = [(u, v, data.get(attr, default)) for (u, v, data) in G.edges(data=True)] + + # We sort by weight, but also by nodes to normalize behavior across runs. + try: + edges.sort(key=itemgetter(2, 0, 1), reverse=reverse) + except TypeError: + # This will fail in Python 3.x if the nodes are of varying types. + # In that case, we use the arbitrary order. + edges.sort(key=itemgetter(2), reverse=reverse) + + # The branching begins with a forest of no edges. + B = nx.DiGraph() + B.add_nodes_from(G) + + # Now we add edges greedily so long we maintain the branching. + uf = nx.utils.UnionFind() + for i, (u, v, w) in enumerate(edges): + if uf[u] == uf[v]: + # Adding this edge would form a directed cycle. + continue + elif B.in_degree(v) == 1: + # The edge would increase the degree to be greater than one. + continue + else: + # If attr was None, then don't insert weights... + data = {} + if attr is not None: + data[attr] = w + B.add_edge(u, v, **data) + uf.union(u, v) + + return B + + +@nx._dispatchable(preserve_edge_attrs=True, returns_graph=True) +def maximum_branching( + G, + attr="weight", + default=1, + preserve_attrs=False, + partition=None, +): + ####################################### + ### Data Structure Helper Functions ### + ####################################### + + def edmonds_add_edge(G, edge_index, u, v, key, **d): + """ + Adds an edge to `G` while also updating the edge index. + + This algorithm requires the use of an external dictionary to track + the edge keys since it is possible that the source or destination + node of an edge will be changed and the default key-handling + capabilities of the MultiDiGraph class do not account for this. + + Parameters + ---------- + G : MultiDiGraph + The graph to insert an edge into. + edge_index : dict + A mapping from integers to the edges of the graph. + u : node + The source node of the new edge. + v : node + The destination node of the new edge. + key : int + The key to use from `edge_index`. + d : keyword arguments, optional + Other attributes to store on the new edge. + """ + + if key in edge_index: + uu, vv, _ = edge_index[key] + if (u != uu) or (v != vv): + raise Exception(f"Key {key!r} is already in use.") + + G.add_edge(u, v, key, **d) + edge_index[key] = (u, v, G.succ[u][v][key]) + + def edmonds_remove_node(G, edge_index, n): + """ + Remove a node from the graph, updating the edge index to match. + + Parameters + ---------- + G : MultiDiGraph + The graph to remove an edge from. + edge_index : dict + A mapping from integers to the edges of the graph. + n : node + The node to remove from `G`. + """ + keys = set() + for keydict in G.pred[n].values(): + keys.update(keydict) + for keydict in G.succ[n].values(): + keys.update(keydict) + + for key in keys: + del edge_index[key] + + G.remove_node(n) + + ####################### + ### Algorithm Setup ### + ####################### + + # Pick an attribute name that the original graph is unlikly to have + candidate_attr = "edmonds' secret candidate attribute" + new_node_base_name = "edmonds new node base name " + + G_original = G + G = nx.MultiDiGraph() + G.__networkx_cache__ = None # Disable caching + + # A dict to reliably track mutations to the edges using the key of the edge. + G_edge_index = {} + # Each edge is given an arbitrary numerical key + for key, (u, v, data) in enumerate(G_original.edges(data=True)): + d = {attr: data.get(attr, default)} + + if data.get(partition) is not None: + d[partition] = data.get(partition) + + if preserve_attrs: + for d_k, d_v in data.items(): + if d_k != attr: + d[d_k] = d_v + + edmonds_add_edge(G, G_edge_index, u, v, key, **d) + + level = 0 # Stores the number of contracted nodes + + # These are the buckets from the paper. + # + # In the paper, G^i are modified versions of the original graph. + # D^i and E^i are the nodes and edges of the maximal edges that are + # consistent with G^i. In this implementation, D^i and E^i are stored + # together as the graph B^i. We will have strictly more B^i then the + # paper will have. + # + # Note that the data in graphs and branchings are tuples with the graph as + # the first element and the edge index as the second. + B = nx.MultiDiGraph() + B_edge_index = {} + graphs = [] # G^i list + branchings = [] # B^i list + selected_nodes = set() # D^i bucket + uf = nx.utils.UnionFind() + + # A list of lists of edge indices. Each list is a circuit for graph G^i. + # Note the edge list is not required to be a circuit in G^0. + circuits = [] + + # Stores the index of the minimum edge in the circuit found in G^i and B^i. + # The ordering of the edges seems to preserver the weight ordering from + # G^0. So even if the circuit does not form a circuit in G^0, it is still + # true that the minimum edges in circuit G^0 (despite their weights being + # different) + minedge_circuit = [] + + ########################### + ### Algorithm Structure ### + ########################### + + # Each step listed in the algorithm is an inner function. Thus, the overall + # loop structure is: + # + # while True: + # step_I1() + # if cycle detected: + # step_I2() + # elif every node of G is in D and E is a branching: + # break + + ################################## + ### Algorithm Helper Functions ### + ################################## + + def edmonds_find_desired_edge(v): + """ + Find the edge directed towards v with maximal weight. + + If an edge partition exists in this graph, return the included + edge if it exists and never return any excluded edge. + + Note: There can only be one included edge for each vertex otherwise + the edge partition is empty. + + Parameters + ---------- + v : node + The node to search for the maximal weight incoming edge. + """ + edge = None + max_weight = -INF + for u, _, key, data in G.in_edges(v, data=True, keys=True): + # Skip excluded edges + if data.get(partition) == nx.EdgePartition.EXCLUDED: + continue + + new_weight = data[attr] + + # Return the included edge + if data.get(partition) == nx.EdgePartition.INCLUDED: + max_weight = new_weight + edge = (u, v, key, new_weight, data) + break + + # Find the best open edge + if new_weight > max_weight: + max_weight = new_weight + edge = (u, v, key, new_weight, data) + + return edge, max_weight + + def edmonds_step_I2(v, desired_edge, level): + """ + Perform step I2 from Edmonds' paper + + First, check if the last step I1 created a cycle. If it did not, do nothing. + If it did, store the cycle for later reference and contract it. + + Parameters + ---------- + v : node + The current node to consider + desired_edge : edge + The minimum desired edge to remove from the cycle. + level : int + The current level, i.e. the number of cycles that have already been removed. + """ + u = desired_edge[0] + + Q_nodes = nx.shortest_path(B, v, u) + Q_edges = [ + list(B[Q_nodes[i]][vv].keys())[0] for i, vv in enumerate(Q_nodes[1:]) + ] + Q_edges.append(desired_edge[2]) # Add the new edge key to complete the circuit + + # Get the edge in the circuit with the minimum weight. + # Also, save the incoming weights for each node. + minweight = INF + minedge = None + Q_incoming_weight = {} + for edge_key in Q_edges: + u, v, data = B_edge_index[edge_key] + w = data[attr] + # We cannot remove an included edge, even if it is the + # minimum edge in the circuit + Q_incoming_weight[v] = w + if data.get(partition) == nx.EdgePartition.INCLUDED: + continue + if w < minweight: + minweight = w + minedge = edge_key + + circuits.append(Q_edges) + minedge_circuit.append(minedge) + graphs.append((G.copy(), G_edge_index.copy())) + branchings.append((B.copy(), B_edge_index.copy())) + + # Mutate the graph to contract the circuit + new_node = new_node_base_name + str(level) + G.add_node(new_node) + new_edges = [] + for u, v, key, data in G.edges(data=True, keys=True): + if u in Q_incoming_weight: + if v in Q_incoming_weight: + # Circuit edge. For the moment do nothing, + # eventually it will be removed. + continue + else: + # Outgoing edge from a node in the circuit. + # Make it come from the new node instead + dd = data.copy() + new_edges.append((new_node, v, key, dd)) + else: + if v in Q_incoming_weight: + # Incoming edge to the circuit. + # Update it's weight + w = data[attr] + w += minweight - Q_incoming_weight[v] + dd = data.copy() + dd[attr] = w + new_edges.append((u, new_node, key, dd)) + else: + # Outside edge. No modification needed + continue + + for node in Q_nodes: + edmonds_remove_node(G, G_edge_index, node) + edmonds_remove_node(B, B_edge_index, node) + + selected_nodes.difference_update(set(Q_nodes)) + + for u, v, key, data in new_edges: + edmonds_add_edge(G, G_edge_index, u, v, key, **data) + if candidate_attr in data: + del data[candidate_attr] + edmonds_add_edge(B, B_edge_index, u, v, key, **data) + uf.union(u, v) + + def is_root(G, u, edgekeys): + """ + Returns True if `u` is a root node in G. + + Node `u` is a root node if its in-degree over the specified edges is zero. + + Parameters + ---------- + G : Graph + The current graph. + u : node + The node in `G` to check if it is a root. + edgekeys : iterable of edges + The edges for which to check if `u` is a root of. + """ + if u not in G: + raise Exception(f"{u!r} not in G") + + for v in G.pred[u]: + for edgekey in G.pred[u][v]: + if edgekey in edgekeys: + return False, edgekey + else: + return True, None + + nodes = iter(list(G.nodes)) + while True: + try: + v = next(nodes) + except StopIteration: + # If there are no more new nodes to consider, then we should + # meet stopping condition (b) from the paper: + # (b) every node of G^i is in D^i and E^i is a branching + assert len(G) == len(B) + if len(B): + assert is_branching(B) + + graphs.append((G.copy(), G_edge_index.copy())) + branchings.append((B.copy(), B_edge_index.copy())) + circuits.append([]) + minedge_circuit.append(None) + + break + else: + ##################### + ### BEGIN STEP I1 ### + ##################### + + # This is a very simple step, so I don't think it needs a method of it's own + if v in selected_nodes: + continue + + selected_nodes.add(v) + B.add_node(v) + desired_edge, desired_edge_weight = edmonds_find_desired_edge(v) + + # There might be no desired edge if all edges are excluded or + # v is the last node to be added to B, the ultimate root of the branching + if desired_edge is not None and desired_edge_weight > 0: + u = desired_edge[0] + # Flag adding the edge will create a circuit before merging the two + # connected components of u and v in B + circuit = uf[u] == uf[v] + dd = {attr: desired_edge_weight} + if desired_edge[4].get(partition) is not None: + dd[partition] = desired_edge[4].get(partition) + + edmonds_add_edge(B, B_edge_index, u, v, desired_edge[2], **dd) + G[u][v][desired_edge[2]][candidate_attr] = True + uf.union(u, v) + + ################### + ### END STEP I1 ### + ################### + + ##################### + ### BEGIN STEP I2 ### + ##################### + + if circuit: + edmonds_step_I2(v, desired_edge, level) + nodes = iter(list(G.nodes())) + level += 1 + + ################### + ### END STEP I2 ### + ################### + + ##################### + ### BEGIN STEP I3 ### + ##################### + + # Create a new graph of the same class as the input graph + H = G_original.__class__() + + # Start with the branching edges in the last level. + edges = set(branchings[level][1]) + while level > 0: + level -= 1 + + # The current level is i, and we start counting from 0. + # + # We need the node at level i+1 that results from merging a circuit + # at level i. basename_0 is the first merged node and this happens + # at level 1. That is basename_0 is a node at level 1 that results + # from merging a circuit at level 0. + + merged_node = new_node_base_name + str(level) + circuit = circuits[level] + isroot, edgekey = is_root(graphs[level + 1][0], merged_node, edges) + edges.update(circuit) + + if isroot: + minedge = minedge_circuit[level] + if minedge is None: + raise Exception + + # Remove the edge in the cycle with minimum weight + edges.remove(minedge) + else: + # We have identified an edge at the next higher level that + # transitions into the merged node at this level. That edge + # transitions to some corresponding node at the current level. + # + # We want to remove an edge from the cycle that transitions + # into the corresponding node, otherwise the result would not + # be a branching. + + G, G_edge_index = graphs[level] + target = G_edge_index[edgekey][1] + for edgekey in circuit: + u, v, data = G_edge_index[edgekey] + if v == target: + break + else: + raise Exception("Couldn't find edge incoming to merged node.") + + edges.remove(edgekey) + + H.add_nodes_from(G_original) + for edgekey in edges: + u, v, d = graphs[0][1][edgekey] + dd = {attr: d[attr]} + + if preserve_attrs: + for key, value in d.items(): + if key not in [attr, candidate_attr]: + dd[key] = value + + H.add_edge(u, v, **dd) + + ################### + ### END STEP I3 ### + ################### + + return H + + +@nx._dispatchable(preserve_edge_attrs=True, mutates_input=True, returns_graph=True) +def minimum_branching( + G, attr="weight", default=1, preserve_attrs=False, partition=None +): + for _, _, d in G.edges(data=True): + d[attr] = -d.get(attr, default) + nx._clear_cache(G) + + B = maximum_branching(G, attr, default, preserve_attrs, partition) + + for _, _, d in G.edges(data=True): + d[attr] = -d.get(attr, default) + nx._clear_cache(G) + + for _, _, d in B.edges(data=True): + d[attr] = -d.get(attr, default) + nx._clear_cache(B) + + return B + + +@nx._dispatchable(preserve_edge_attrs=True, mutates_input=True, returns_graph=True) +def minimal_branching( + G, /, *, attr="weight", default=1, preserve_attrs=False, partition=None +): + """ + Returns a minimal branching from `G`. + + A minimal branching is a branching similar to a minimal arborescence but + without the requirement that the result is actually a spanning arborescence. + This allows minimal branchinges to be computed over graphs which may not + have arborescence (such as multiple components). + + Parameters + ---------- + G : (multi)digraph-like + The graph to be searched. + attr : str + The edge attribute used in determining optimality. + default : float + The value of the edge attribute used if an edge does not have + the attribute `attr`. + preserve_attrs : bool + If True, preserve the other attributes of the original graph (that are not + passed to `attr`) + partition : str + The key for the edge attribute containing the partition + data on the graph. Edges can be included, excluded or open using the + `EdgePartition` enum. + + Returns + ------- + B : (multi)digraph-like + A minimal branching. + """ + max_weight = -INF + min_weight = INF + for _, _, w in G.edges(data=attr, default=default): + if w > max_weight: + max_weight = w + if w < min_weight: + min_weight = w + + for _, _, d in G.edges(data=True): + # Transform the weights so that the minimum weight is larger than + # the difference between the max and min weights. This is important + # in order to prevent the edge weights from becoming negative during + # computation + d[attr] = max_weight + 1 + (max_weight - min_weight) - d.get(attr, default) + nx._clear_cache(G) + + B = maximum_branching(G, attr, default, preserve_attrs, partition) + + # Reverse the weight transformations + for _, _, d in G.edges(data=True): + d[attr] = max_weight + 1 + (max_weight - min_weight) - d.get(attr, default) + nx._clear_cache(G) + + for _, _, d in B.edges(data=True): + d[attr] = max_weight + 1 + (max_weight - min_weight) - d.get(attr, default) + nx._clear_cache(B) + + return B + + +@nx._dispatchable(preserve_edge_attrs=True, mutates_input=True, returns_graph=True) +def maximum_spanning_arborescence( + G, attr="weight", default=1, preserve_attrs=False, partition=None +): + # In order to use the same algorithm is the maximum branching, we need to adjust + # the weights of the graph. The branching algorithm can choose to not include an + # edge if it doesn't help find a branching, mainly triggered by edges with negative + # weights. + # + # To prevent this from happening while trying to find a spanning arborescence, we + # just have to tweak the edge weights so that they are all positive and cannot + # become negative during the branching algorithm, find the maximum branching and + # then return them to their original values. + + min_weight = INF + max_weight = -INF + for _, _, w in G.edges(data=attr, default=default): + if w < min_weight: + min_weight = w + if w > max_weight: + max_weight = w + + for _, _, d in G.edges(data=True): + d[attr] = d.get(attr, default) - min_weight + 1 - (min_weight - max_weight) + nx._clear_cache(G) + + B = maximum_branching(G, attr, default, preserve_attrs, partition) + + for _, _, d in G.edges(data=True): + d[attr] = d.get(attr, default) + min_weight - 1 + (min_weight - max_weight) + nx._clear_cache(G) + + for _, _, d in B.edges(data=True): + d[attr] = d.get(attr, default) + min_weight - 1 + (min_weight - max_weight) + nx._clear_cache(B) + + if not is_arborescence(B): + raise nx.exception.NetworkXException("No maximum spanning arborescence in G.") + + return B + + +@nx._dispatchable(preserve_edge_attrs=True, mutates_input=True, returns_graph=True) +def minimum_spanning_arborescence( + G, attr="weight", default=1, preserve_attrs=False, partition=None +): + B = minimal_branching( + G, + attr=attr, + default=default, + preserve_attrs=preserve_attrs, + partition=partition, + ) + + if not is_arborescence(B): + raise nx.exception.NetworkXException("No minimum spanning arborescence in G.") + + return B + + +docstring_branching = """ +Returns a {kind} {style} from G. + +Parameters +---------- +G : (multi)digraph-like + The graph to be searched. +attr : str + The edge attribute used to in determining optimality. +default : float + The value of the edge attribute used if an edge does not have + the attribute `attr`. +preserve_attrs : bool + If True, preserve the other attributes of the original graph (that are not + passed to `attr`) +partition : str + The key for the edge attribute containing the partition + data on the graph. Edges can be included, excluded or open using the + `EdgePartition` enum. + +Returns +------- +B : (multi)digraph-like + A {kind} {style}. +""" + +docstring_arborescence = ( + docstring_branching + + """ +Raises +------ +NetworkXException + If the graph does not contain a {kind} {style}. + +""" +) + +maximum_branching.__doc__ = docstring_branching.format( + kind="maximum", style="branching" +) + +minimum_branching.__doc__ = ( + docstring_branching.format(kind="minimum", style="branching") + + """ +See Also +-------- + minimal_branching +""" +) + +maximum_spanning_arborescence.__doc__ = docstring_arborescence.format( + kind="maximum", style="spanning arborescence" +) + +minimum_spanning_arborescence.__doc__ = docstring_arborescence.format( + kind="minimum", style="spanning arborescence" +) + + +class ArborescenceIterator: + """ + Iterate over all spanning arborescences of a graph in either increasing or + decreasing cost. + + Notes + ----- + This iterator uses the partition scheme from [1]_ (included edges, + excluded edges and open edges). It generates minimum spanning + arborescences using a modified Edmonds' Algorithm which respects the + partition of edges. For arborescences with the same weight, ties are + broken arbitrarily. + + References + ---------- + .. [1] G.K. Janssens, K. Sörensen, An algorithm to generate all spanning + trees in order of increasing cost, Pesquisa Operacional, 2005-08, + Vol. 25 (2), p. 219-229, + https://www.scielo.br/j/pope/a/XHswBwRwJyrfL88dmMwYNWp/?lang=en + """ + + @dataclass(order=True) + class Partition: + """ + This dataclass represents a partition and stores a dict with the edge + data and the weight of the minimum spanning arborescence of the + partition dict. + """ + + mst_weight: float + partition_dict: dict = field(compare=False) + + def __copy__(self): + return ArborescenceIterator.Partition( + self.mst_weight, self.partition_dict.copy() + ) + + def __init__(self, G, weight="weight", minimum=True, init_partition=None): + """ + Initialize the iterator + + Parameters + ---------- + G : nx.DiGraph + The directed graph which we need to iterate trees over + + weight : String, default = "weight" + The edge attribute used to store the weight of the edge + + minimum : bool, default = True + Return the trees in increasing order while true and decreasing order + while false. + + init_partition : tuple, default = None + In the case that certain edges have to be included or excluded from + the arborescences, `init_partition` should be in the form + `(included_edges, excluded_edges)` where each edges is a + `(u, v)`-tuple inside an iterable such as a list or set. + + """ + self.G = G.copy() + self.weight = weight + self.minimum = minimum + self.method = ( + minimum_spanning_arborescence if minimum else maximum_spanning_arborescence + ) + # Randomly create a key for an edge attribute to hold the partition data + self.partition_key = ( + "ArborescenceIterators super secret partition attribute name" + ) + if init_partition is not None: + partition_dict = {} + for e in init_partition[0]: + partition_dict[e] = nx.EdgePartition.INCLUDED + for e in init_partition[1]: + partition_dict[e] = nx.EdgePartition.EXCLUDED + self.init_partition = ArborescenceIterator.Partition(0, partition_dict) + else: + self.init_partition = None + + def __iter__(self): + """ + Returns + ------- + ArborescenceIterator + The iterator object for this graph + """ + self.partition_queue = PriorityQueue() + self._clear_partition(self.G) + + # Write the initial partition if it exists. + if self.init_partition is not None: + self._write_partition(self.init_partition) + + mst_weight = self.method( + self.G, + self.weight, + partition=self.partition_key, + preserve_attrs=True, + ).size(weight=self.weight) + + self.partition_queue.put( + self.Partition( + mst_weight if self.minimum else -mst_weight, + ( + {} + if self.init_partition is None + else self.init_partition.partition_dict + ), + ) + ) + + return self + + def __next__(self): + """ + Returns + ------- + (multi)Graph + The spanning tree of next greatest weight, which ties broken + arbitrarily. + """ + if self.partition_queue.empty(): + del self.G, self.partition_queue + raise StopIteration + + partition = self.partition_queue.get() + self._write_partition(partition) + next_arborescence = self.method( + self.G, + self.weight, + partition=self.partition_key, + preserve_attrs=True, + ) + self._partition(partition, next_arborescence) + + self._clear_partition(next_arborescence) + return next_arborescence + + def _partition(self, partition, partition_arborescence): + """ + Create new partitions based of the minimum spanning tree of the + current minimum partition. + + Parameters + ---------- + partition : Partition + The Partition instance used to generate the current minimum spanning + tree. + partition_arborescence : nx.Graph + The minimum spanning arborescence of the input partition. + """ + # create two new partitions with the data from the input partition dict + p1 = self.Partition(0, partition.partition_dict.copy()) + p2 = self.Partition(0, partition.partition_dict.copy()) + for e in partition_arborescence.edges: + # determine if the edge was open or included + if e not in partition.partition_dict: + # This is an open edge + p1.partition_dict[e] = nx.EdgePartition.EXCLUDED + p2.partition_dict[e] = nx.EdgePartition.INCLUDED + + self._write_partition(p1) + try: + p1_mst = self.method( + self.G, + self.weight, + partition=self.partition_key, + preserve_attrs=True, + ) + + p1_mst_weight = p1_mst.size(weight=self.weight) + p1.mst_weight = p1_mst_weight if self.minimum else -p1_mst_weight + self.partition_queue.put(p1.__copy__()) + except nx.NetworkXException: + pass + + p1.partition_dict = p2.partition_dict.copy() + + def _write_partition(self, partition): + """ + Writes the desired partition into the graph to calculate the minimum + spanning tree. Also, if one incoming edge is included, mark all others + as excluded so that if that vertex is merged during Edmonds' algorithm + we cannot still pick another of that vertex's included edges. + + Parameters + ---------- + partition : Partition + A Partition dataclass describing a partition on the edges of the + graph. + """ + for u, v, d in self.G.edges(data=True): + if (u, v) in partition.partition_dict: + d[self.partition_key] = partition.partition_dict[(u, v)] + else: + d[self.partition_key] = nx.EdgePartition.OPEN + nx._clear_cache(self.G) + + for n in self.G: + included_count = 0 + excluded_count = 0 + for u, v, d in self.G.in_edges(nbunch=n, data=True): + if d.get(self.partition_key) == nx.EdgePartition.INCLUDED: + included_count += 1 + elif d.get(self.partition_key) == nx.EdgePartition.EXCLUDED: + excluded_count += 1 + # Check that if there is an included edges, all other incoming ones + # are excluded. If not fix it! + if included_count == 1 and excluded_count != self.G.in_degree(n) - 1: + for u, v, d in self.G.in_edges(nbunch=n, data=True): + if d.get(self.partition_key) != nx.EdgePartition.INCLUDED: + d[self.partition_key] = nx.EdgePartition.EXCLUDED + + def _clear_partition(self, G): + """ + Removes partition data from the graph + """ + for u, v, d in G.edges(data=True): + if self.partition_key in d: + del d[self.partition_key] + nx._clear_cache(self.G) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tree/coding.py b/lib/python3.10/site-packages/networkx/algorithms/tree/coding.py new file mode 100644 index 0000000000000000000000000000000000000000..f33089f76554bdb32651ddd145dc36835e260d3c --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tree/coding.py @@ -0,0 +1,413 @@ +"""Functions for encoding and decoding trees. + +Since a tree is a highly restricted form of graph, it can be represented +concisely in several ways. This module includes functions for encoding +and decoding trees in the form of nested tuples and Prüfer +sequences. The former requires a rooted tree, whereas the latter can be +applied to unrooted trees. Furthermore, there is a bijection from Prüfer +sequences to labeled trees. + +""" + +from collections import Counter +from itertools import chain + +import networkx as nx +from networkx.utils import not_implemented_for + +__all__ = [ + "from_nested_tuple", + "from_prufer_sequence", + "NotATree", + "to_nested_tuple", + "to_prufer_sequence", +] + + +class NotATree(nx.NetworkXException): + """Raised when a function expects a tree (that is, a connected + undirected graph with no cycles) but gets a non-tree graph as input + instead. + + """ + + +@not_implemented_for("directed") +@nx._dispatchable(graphs="T") +def to_nested_tuple(T, root, canonical_form=False): + """Returns a nested tuple representation of the given tree. + + The nested tuple representation of a tree is defined + recursively. The tree with one node and no edges is represented by + the empty tuple, ``()``. A tree with ``k`` subtrees is represented + by a tuple of length ``k`` in which each element is the nested tuple + representation of a subtree. + + Parameters + ---------- + T : NetworkX graph + An undirected graph object representing a tree. + + root : node + The node in ``T`` to interpret as the root of the tree. + + canonical_form : bool + If ``True``, each tuple is sorted so that the function returns + a canonical form for rooted trees. This means "lighter" subtrees + will appear as nested tuples before "heavier" subtrees. In this + way, each isomorphic rooted tree has the same nested tuple + representation. + + Returns + ------- + tuple + A nested tuple representation of the tree. + + Notes + ----- + This function is *not* the inverse of :func:`from_nested_tuple`; the + only guarantee is that the rooted trees are isomorphic. + + See also + -------- + from_nested_tuple + to_prufer_sequence + + Examples + -------- + The tree need not be a balanced binary tree:: + + >>> T = nx.Graph() + >>> T.add_edges_from([(0, 1), (0, 2), (0, 3)]) + >>> T.add_edges_from([(1, 4), (1, 5)]) + >>> T.add_edges_from([(3, 6), (3, 7)]) + >>> root = 0 + >>> nx.to_nested_tuple(T, root) + (((), ()), (), ((), ())) + + Continuing the above example, if ``canonical_form`` is ``True``, the + nested tuples will be sorted:: + + >>> nx.to_nested_tuple(T, root, canonical_form=True) + ((), ((), ()), ((), ())) + + Even the path graph can be interpreted as a tree:: + + >>> T = nx.path_graph(4) + >>> root = 0 + >>> nx.to_nested_tuple(T, root) + ((((),),),) + + """ + + def _make_tuple(T, root, _parent): + """Recursively compute the nested tuple representation of the + given rooted tree. + + ``_parent`` is the parent node of ``root`` in the supertree in + which ``T`` is a subtree, or ``None`` if ``root`` is the root of + the supertree. This argument is used to determine which + neighbors of ``root`` are children and which is the parent. + + """ + # Get the neighbors of `root` that are not the parent node. We + # are guaranteed that `root` is always in `T` by construction. + children = set(T[root]) - {_parent} + if len(children) == 0: + return () + nested = (_make_tuple(T, v, root) for v in children) + if canonical_form: + nested = sorted(nested) + return tuple(nested) + + # Do some sanity checks on the input. + if not nx.is_tree(T): + raise nx.NotATree("provided graph is not a tree") + if root not in T: + raise nx.NodeNotFound(f"Graph {T} contains no node {root}") + + return _make_tuple(T, root, None) + + +@nx._dispatchable(graphs=None, returns_graph=True) +def from_nested_tuple(sequence, sensible_relabeling=False): + """Returns the rooted tree corresponding to the given nested tuple. + + The nested tuple representation of a tree is defined + recursively. The tree with one node and no edges is represented by + the empty tuple, ``()``. A tree with ``k`` subtrees is represented + by a tuple of length ``k`` in which each element is the nested tuple + representation of a subtree. + + Parameters + ---------- + sequence : tuple + A nested tuple representing a rooted tree. + + sensible_relabeling : bool + Whether to relabel the nodes of the tree so that nodes are + labeled in increasing order according to their breadth-first + search order from the root node. + + Returns + ------- + NetworkX graph + The tree corresponding to the given nested tuple, whose root + node is node 0. If ``sensible_labeling`` is ``True``, nodes will + be labeled in breadth-first search order starting from the root + node. + + Notes + ----- + This function is *not* the inverse of :func:`to_nested_tuple`; the + only guarantee is that the rooted trees are isomorphic. + + See also + -------- + to_nested_tuple + from_prufer_sequence + + Examples + -------- + Sensible relabeling ensures that the nodes are labeled from the root + starting at 0:: + + >>> balanced = (((), ()), ((), ())) + >>> T = nx.from_nested_tuple(balanced, sensible_relabeling=True) + >>> edges = [(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)] + >>> all((u, v) in T.edges() or (v, u) in T.edges() for (u, v) in edges) + True + + """ + + def _make_tree(sequence): + """Recursively creates a tree from the given sequence of nested + tuples. + + This function employs the :func:`~networkx.tree.join` function + to recursively join subtrees into a larger tree. + + """ + # The empty sequence represents the empty tree, which is the + # (unique) graph with a single node. We mark the single node + # with an attribute that indicates that it is the root of the + # graph. + if len(sequence) == 0: + return nx.empty_graph(1) + # For a nonempty sequence, get the subtrees for each child + # sequence and join all the subtrees at their roots. After + # joining the subtrees, the root is node 0. + return nx.tree.join_trees([(_make_tree(child), 0) for child in sequence]) + + # Make the tree and remove the `is_root` node attribute added by the + # helper function. + T = _make_tree(sequence) + if sensible_relabeling: + # Relabel the nodes according to their breadth-first search + # order, starting from the root node (that is, the node 0). + bfs_nodes = chain([0], (v for u, v in nx.bfs_edges(T, 0))) + labels = {v: i for i, v in enumerate(bfs_nodes)} + # We would like to use `copy=False`, but `relabel_nodes` doesn't + # allow a relabel mapping that can't be topologically sorted. + T = nx.relabel_nodes(T, labels) + return T + + +@not_implemented_for("directed") +@nx._dispatchable(graphs="T") +def to_prufer_sequence(T): + r"""Returns the Prüfer sequence of the given tree. + + A *Prüfer sequence* is a list of *n* - 2 numbers between 0 and + *n* - 1, inclusive. The tree corresponding to a given Prüfer + sequence can be recovered by repeatedly joining a node in the + sequence with a node with the smallest potential degree according to + the sequence. + + Parameters + ---------- + T : NetworkX graph + An undirected graph object representing a tree. + + Returns + ------- + list + The Prüfer sequence of the given tree. + + Raises + ------ + NetworkXPointlessConcept + If the number of nodes in `T` is less than two. + + NotATree + If `T` is not a tree. + + KeyError + If the set of nodes in `T` is not {0, …, *n* - 1}. + + Notes + ----- + There is a bijection from labeled trees to Prüfer sequences. This + function is the inverse of the :func:`from_prufer_sequence` + function. + + Sometimes Prüfer sequences use nodes labeled from 1 to *n* instead + of from 0 to *n* - 1. This function requires nodes to be labeled in + the latter form. You can use :func:`~networkx.relabel_nodes` to + relabel the nodes of your tree to the appropriate format. + + This implementation is from [1]_ and has a running time of + $O(n)$. + + See also + -------- + to_nested_tuple + from_prufer_sequence + + References + ---------- + .. [1] Wang, Xiaodong, Lei Wang, and Yingjie Wu. + "An optimal algorithm for Prufer codes." + *Journal of Software Engineering and Applications* 2.02 (2009): 111. + + + Examples + -------- + There is a bijection between Prüfer sequences and labeled trees, so + this function is the inverse of the :func:`from_prufer_sequence` + function: + + >>> edges = [(0, 3), (1, 3), (2, 3), (3, 4), (4, 5)] + >>> tree = nx.Graph(edges) + >>> sequence = nx.to_prufer_sequence(tree) + >>> sequence + [3, 3, 3, 4] + >>> tree2 = nx.from_prufer_sequence(sequence) + >>> list(tree2.edges()) == edges + True + + """ + # Perform some sanity checks on the input. + n = len(T) + if n < 2: + msg = "Prüfer sequence undefined for trees with fewer than two nodes" + raise nx.NetworkXPointlessConcept(msg) + if not nx.is_tree(T): + raise nx.NotATree("provided graph is not a tree") + if set(T) != set(range(n)): + raise KeyError("tree must have node labels {0, ..., n - 1}") + + degree = dict(T.degree()) + + def parents(u): + return next(v for v in T[u] if degree[v] > 1) + + index = u = next(k for k in range(n) if degree[k] == 1) + result = [] + for i in range(n - 2): + v = parents(u) + result.append(v) + degree[v] -= 1 + if v < index and degree[v] == 1: + u = v + else: + index = u = next(k for k in range(index + 1, n) if degree[k] == 1) + return result + + +@nx._dispatchable(graphs=None, returns_graph=True) +def from_prufer_sequence(sequence): + r"""Returns the tree corresponding to the given Prüfer sequence. + + A *Prüfer sequence* is a list of *n* - 2 numbers between 0 and + *n* - 1, inclusive. The tree corresponding to a given Prüfer + sequence can be recovered by repeatedly joining a node in the + sequence with a node with the smallest potential degree according to + the sequence. + + Parameters + ---------- + sequence : list + A Prüfer sequence, which is a list of *n* - 2 integers between + zero and *n* - 1, inclusive. + + Returns + ------- + NetworkX graph + The tree corresponding to the given Prüfer sequence. + + Raises + ------ + NetworkXError + If the Prüfer sequence is not valid. + + Notes + ----- + There is a bijection from labeled trees to Prüfer sequences. This + function is the inverse of the :func:`from_prufer_sequence` function. + + Sometimes Prüfer sequences use nodes labeled from 1 to *n* instead + of from 0 to *n* - 1. This function requires nodes to be labeled in + the latter form. You can use :func:`networkx.relabel_nodes` to + relabel the nodes of your tree to the appropriate format. + + This implementation is from [1]_ and has a running time of + $O(n)$. + + References + ---------- + .. [1] Wang, Xiaodong, Lei Wang, and Yingjie Wu. + "An optimal algorithm for Prufer codes." + *Journal of Software Engineering and Applications* 2.02 (2009): 111. + + + See also + -------- + from_nested_tuple + to_prufer_sequence + + Examples + -------- + There is a bijection between Prüfer sequences and labeled trees, so + this function is the inverse of the :func:`to_prufer_sequence` + function: + + >>> edges = [(0, 3), (1, 3), (2, 3), (3, 4), (4, 5)] + >>> tree = nx.Graph(edges) + >>> sequence = nx.to_prufer_sequence(tree) + >>> sequence + [3, 3, 3, 4] + >>> tree2 = nx.from_prufer_sequence(sequence) + >>> list(tree2.edges()) == edges + True + + """ + n = len(sequence) + 2 + # `degree` stores the remaining degree (plus one) for each node. The + # degree of a node in the decoded tree is one more than the number + # of times it appears in the code. + degree = Counter(chain(sequence, range(n))) + T = nx.empty_graph(n) + # `not_orphaned` is the set of nodes that have a parent in the + # tree. After the loop, there should be exactly two nodes that are + # not in this set. + not_orphaned = set() + index = u = next(k for k in range(n) if degree[k] == 1) + for v in sequence: + # check the validity of the prufer sequence + if v < 0 or v > n - 1: + raise nx.NetworkXError( + f"Invalid Prufer sequence: Values must be between 0 and {n-1}, got {v}" + ) + T.add_edge(u, v) + not_orphaned.add(u) + degree[v] -= 1 + if v < index and degree[v] == 1: + u = v + else: + index = u = next(k for k in range(index + 1, n) if degree[k] == 1) + # At this point, there must be exactly two orphaned nodes; join them. + orphans = set(T) - not_orphaned + u, v = orphans + T.add_edge(u, v) + return T diff --git a/lib/python3.10/site-packages/networkx/algorithms/tree/decomposition.py b/lib/python3.10/site-packages/networkx/algorithms/tree/decomposition.py new file mode 100644 index 0000000000000000000000000000000000000000..c8b8f2477b47581cd6010aba7e3329f5044e0da4 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tree/decomposition.py @@ -0,0 +1,88 @@ +r"""Function for computing a junction tree of a graph.""" + +from itertools import combinations + +import networkx as nx +from networkx.algorithms import chordal_graph_cliques, complete_to_chordal_graph, moral +from networkx.utils import not_implemented_for + +__all__ = ["junction_tree"] + + +@not_implemented_for("multigraph") +@nx._dispatchable(returns_graph=True) +def junction_tree(G): + r"""Returns a junction tree of a given graph. + + A junction tree (or clique tree) is constructed from a (un)directed graph G. + The tree is constructed based on a moralized and triangulated version of G. + The tree's nodes consist of maximal cliques and sepsets of the revised graph. + The sepset of two cliques is the intersection of the nodes of these cliques, + e.g. the sepset of (A,B,C) and (A,C,E,F) is (A,C). These nodes are often called + "variables" in this literature. The tree is bipartite with each sepset + connected to its two cliques. + + Junction Trees are not unique as the order of clique consideration determines + which sepsets are included. + + The junction tree algorithm consists of five steps [1]_: + + 1. Moralize the graph + 2. Triangulate the graph + 3. Find maximal cliques + 4. Build the tree from cliques, connecting cliques with shared + nodes, set edge-weight to number of shared variables + 5. Find maximum spanning tree + + + Parameters + ---------- + G : networkx.Graph + Directed or undirected graph. + + Returns + ------- + junction_tree : networkx.Graph + The corresponding junction tree of `G`. + + Raises + ------ + NetworkXNotImplemented + Raised if `G` is an instance of `MultiGraph` or `MultiDiGraph`. + + References + ---------- + .. [1] Junction tree algorithm: + https://en.wikipedia.org/wiki/Junction_tree_algorithm + + .. [2] Finn V. Jensen and Frank Jensen. 1994. Optimal + junction trees. In Proceedings of the Tenth international + conference on Uncertainty in artificial intelligence (UAI’94). + Morgan Kaufmann Publishers Inc., San Francisco, CA, USA, 360–366. + """ + + clique_graph = nx.Graph() + + if G.is_directed(): + G = moral.moral_graph(G) + chordal_graph, _ = complete_to_chordal_graph(G) + + cliques = [tuple(sorted(i)) for i in chordal_graph_cliques(chordal_graph)] + clique_graph.add_nodes_from(cliques, type="clique") + + for edge in combinations(cliques, 2): + set_edge_0 = set(edge[0]) + set_edge_1 = set(edge[1]) + if not set_edge_0.isdisjoint(set_edge_1): + sepset = tuple(sorted(set_edge_0.intersection(set_edge_1))) + clique_graph.add_edge(edge[0], edge[1], weight=len(sepset), sepset=sepset) + + junction_tree = nx.maximum_spanning_tree(clique_graph) + + for edge in list(junction_tree.edges(data=True)): + junction_tree.add_node(edge[2]["sepset"], type="sepset") + junction_tree.add_edge(edge[0], edge[2]["sepset"]) + junction_tree.add_edge(edge[1], edge[2]["sepset"]) + junction_tree.remove_edge(edge[0], edge[1]) + + return junction_tree diff --git a/lib/python3.10/site-packages/networkx/algorithms/tree/mst.py b/lib/python3.10/site-packages/networkx/algorithms/tree/mst.py new file mode 100644 index 0000000000000000000000000000000000000000..554613b8f36dae63eb1ce7f4a03a646fd2dc81c4 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tree/mst.py @@ -0,0 +1,1284 @@ +""" +Algorithms for calculating min/max spanning trees/forests. + +""" + +from dataclasses import dataclass, field +from enum import Enum +from heapq import heappop, heappush +from itertools import count +from math import isnan +from operator import itemgetter +from queue import PriorityQueue + +import networkx as nx +from networkx.utils import UnionFind, not_implemented_for, py_random_state + +__all__ = [ + "minimum_spanning_edges", + "maximum_spanning_edges", + "minimum_spanning_tree", + "maximum_spanning_tree", + "number_of_spanning_trees", + "random_spanning_tree", + "partition_spanning_tree", + "EdgePartition", + "SpanningTreeIterator", +] + + +class EdgePartition(Enum): + """ + An enum to store the state of an edge partition. The enum is written to the + edges of a graph before being pasted to `kruskal_mst_edges`. Options are: + + - EdgePartition.OPEN + - EdgePartition.INCLUDED + - EdgePartition.EXCLUDED + """ + + OPEN = 0 + INCLUDED = 1 + EXCLUDED = 2 + + +@not_implemented_for("multigraph") +@nx._dispatchable(edge_attrs="weight", preserve_edge_attrs="data") +def boruvka_mst_edges( + G, minimum=True, weight="weight", keys=False, data=True, ignore_nan=False +): + """Iterate over edges of a Borůvka's algorithm min/max spanning tree. + + Parameters + ---------- + G : NetworkX Graph + The edges of `G` must have distinct weights, + otherwise the edges may not form a tree. + + minimum : bool (default: True) + Find the minimum (True) or maximum (False) spanning tree. + + weight : string (default: 'weight') + The name of the edge attribute holding the edge weights. + + keys : bool (default: True) + This argument is ignored since this function is not + implemented for multigraphs; it exists only for consistency + with the other minimum spanning tree functions. + + data : bool (default: True) + Flag for whether to yield edge attribute dicts. + If True, yield edges `(u, v, d)`, where `d` is the attribute dict. + If False, yield edges `(u, v)`. + + ignore_nan : bool (default: False) + If a NaN is found as an edge weight normally an exception is raised. + If `ignore_nan is True` then that edge is ignored instead. + + """ + # Initialize a forest, assuming initially that it is the discrete + # partition of the nodes of the graph. + forest = UnionFind(G) + + def best_edge(component): + """Returns the optimum (minimum or maximum) edge on the edge + boundary of the given set of nodes. + + A return value of ``None`` indicates an empty boundary. + + """ + sign = 1 if minimum else -1 + minwt = float("inf") + boundary = None + for e in nx.edge_boundary(G, component, data=True): + wt = e[-1].get(weight, 1) * sign + if isnan(wt): + if ignore_nan: + continue + msg = f"NaN found as an edge weight. Edge {e}" + raise ValueError(msg) + if wt < minwt: + minwt = wt + boundary = e + return boundary + + # Determine the optimum edge in the edge boundary of each component + # in the forest. + best_edges = (best_edge(component) for component in forest.to_sets()) + best_edges = [edge for edge in best_edges if edge is not None] + # If each entry was ``None``, that means the graph was disconnected, + # so we are done generating the forest. + while best_edges: + # Determine the optimum edge in the edge boundary of each + # component in the forest. + # + # This must be a sequence, not an iterator. In this list, the + # same edge may appear twice, in different orientations (but + # that's okay, since a union operation will be called on the + # endpoints the first time it is seen, but not the second time). + # + # Any ``None`` indicates that the edge boundary for that + # component was empty, so that part of the forest has been + # completed. + # + # TODO This can be parallelized, both in the outer loop over + # each component in the forest and in the computation of the + # minimum. (Same goes for the identical lines outside the loop.) + best_edges = (best_edge(component) for component in forest.to_sets()) + best_edges = [edge for edge in best_edges if edge is not None] + # Join trees in the forest using the best edges, and yield that + # edge, since it is part of the spanning tree. + # + # TODO This loop can be parallelized, to an extent (the union + # operation must be atomic). + for u, v, d in best_edges: + if forest[u] != forest[v]: + if data: + yield u, v, d + else: + yield u, v + forest.union(u, v) + + +@nx._dispatchable( + edge_attrs={"weight": None, "partition": None}, preserve_edge_attrs="data" +) +def kruskal_mst_edges( + G, minimum, weight="weight", keys=True, data=True, ignore_nan=False, partition=None +): + """ + Iterate over edge of a Kruskal's algorithm min/max spanning tree. + + Parameters + ---------- + G : NetworkX Graph + The graph holding the tree of interest. + + minimum : bool (default: True) + Find the minimum (True) or maximum (False) spanning tree. + + weight : string (default: 'weight') + The name of the edge attribute holding the edge weights. + + keys : bool (default: True) + If `G` is a multigraph, `keys` controls whether edge keys ar yielded. + Otherwise `keys` is ignored. + + data : bool (default: True) + Flag for whether to yield edge attribute dicts. + If True, yield edges `(u, v, d)`, where `d` is the attribute dict. + If False, yield edges `(u, v)`. + + ignore_nan : bool (default: False) + If a NaN is found as an edge weight normally an exception is raised. + If `ignore_nan is True` then that edge is ignored instead. + + partition : string (default: None) + The name of the edge attribute holding the partition data, if it exists. + Partition data is written to the edges using the `EdgePartition` enum. + If a partition exists, all included edges and none of the excluded edges + will appear in the final tree. Open edges may or may not be used. + + Yields + ------ + edge tuple + The edges as discovered by Kruskal's method. Each edge can + take the following forms: `(u, v)`, `(u, v, d)` or `(u, v, k, d)` + depending on the `key` and `data` parameters + """ + subtrees = UnionFind() + if G.is_multigraph(): + edges = G.edges(keys=True, data=True) + else: + edges = G.edges(data=True) + + """ + Sort the edges of the graph with respect to the partition data. + Edges are returned in the following order: + + * Included edges + * Open edges from smallest to largest weight + * Excluded edges + """ + included_edges = [] + open_edges = [] + for e in edges: + d = e[-1] + wt = d.get(weight, 1) + if isnan(wt): + if ignore_nan: + continue + raise ValueError(f"NaN found as an edge weight. Edge {e}") + + edge = (wt,) + e + if d.get(partition) == EdgePartition.INCLUDED: + included_edges.append(edge) + elif d.get(partition) == EdgePartition.EXCLUDED: + continue + else: + open_edges.append(edge) + + if minimum: + sorted_open_edges = sorted(open_edges, key=itemgetter(0)) + else: + sorted_open_edges = sorted(open_edges, key=itemgetter(0), reverse=True) + + # Condense the lists into one + included_edges.extend(sorted_open_edges) + sorted_edges = included_edges + del open_edges, sorted_open_edges, included_edges + + # Multigraphs need to handle edge keys in addition to edge data. + if G.is_multigraph(): + for wt, u, v, k, d in sorted_edges: + if subtrees[u] != subtrees[v]: + if keys: + if data: + yield u, v, k, d + else: + yield u, v, k + else: + if data: + yield u, v, d + else: + yield u, v + subtrees.union(u, v) + else: + for wt, u, v, d in sorted_edges: + if subtrees[u] != subtrees[v]: + if data: + yield u, v, d + else: + yield u, v + subtrees.union(u, v) + + +@nx._dispatchable(edge_attrs="weight", preserve_edge_attrs="data") +def prim_mst_edges(G, minimum, weight="weight", keys=True, data=True, ignore_nan=False): + """Iterate over edges of Prim's algorithm min/max spanning tree. + + Parameters + ---------- + G : NetworkX Graph + The graph holding the tree of interest. + + minimum : bool (default: True) + Find the minimum (True) or maximum (False) spanning tree. + + weight : string (default: 'weight') + The name of the edge attribute holding the edge weights. + + keys : bool (default: True) + If `G` is a multigraph, `keys` controls whether edge keys ar yielded. + Otherwise `keys` is ignored. + + data : bool (default: True) + Flag for whether to yield edge attribute dicts. + If True, yield edges `(u, v, d)`, where `d` is the attribute dict. + If False, yield edges `(u, v)`. + + ignore_nan : bool (default: False) + If a NaN is found as an edge weight normally an exception is raised. + If `ignore_nan is True` then that edge is ignored instead. + + """ + is_multigraph = G.is_multigraph() + push = heappush + pop = heappop + + nodes = set(G) + c = count() + + sign = 1 if minimum else -1 + + while nodes: + u = nodes.pop() + frontier = [] + visited = {u} + if is_multigraph: + for v, keydict in G.adj[u].items(): + for k, d in keydict.items(): + wt = d.get(weight, 1) * sign + if isnan(wt): + if ignore_nan: + continue + msg = f"NaN found as an edge weight. Edge {(u, v, k, d)}" + raise ValueError(msg) + push(frontier, (wt, next(c), u, v, k, d)) + else: + for v, d in G.adj[u].items(): + wt = d.get(weight, 1) * sign + if isnan(wt): + if ignore_nan: + continue + msg = f"NaN found as an edge weight. Edge {(u, v, d)}" + raise ValueError(msg) + push(frontier, (wt, next(c), u, v, d)) + while nodes and frontier: + if is_multigraph: + W, _, u, v, k, d = pop(frontier) + else: + W, _, u, v, d = pop(frontier) + if v in visited or v not in nodes: + continue + # Multigraphs need to handle edge keys in addition to edge data. + if is_multigraph and keys: + if data: + yield u, v, k, d + else: + yield u, v, k + else: + if data: + yield u, v, d + else: + yield u, v + # update frontier + visited.add(v) + nodes.discard(v) + if is_multigraph: + for w, keydict in G.adj[v].items(): + if w in visited: + continue + for k2, d2 in keydict.items(): + new_weight = d2.get(weight, 1) * sign + if isnan(new_weight): + if ignore_nan: + continue + msg = f"NaN found as an edge weight. Edge {(v, w, k2, d2)}" + raise ValueError(msg) + push(frontier, (new_weight, next(c), v, w, k2, d2)) + else: + for w, d2 in G.adj[v].items(): + if w in visited: + continue + new_weight = d2.get(weight, 1) * sign + if isnan(new_weight): + if ignore_nan: + continue + msg = f"NaN found as an edge weight. Edge {(v, w, d2)}" + raise ValueError(msg) + push(frontier, (new_weight, next(c), v, w, d2)) + + +ALGORITHMS = { + "boruvka": boruvka_mst_edges, + "borůvka": boruvka_mst_edges, + "kruskal": kruskal_mst_edges, + "prim": prim_mst_edges, +} + + +@not_implemented_for("directed") +@nx._dispatchable(edge_attrs="weight", preserve_edge_attrs="data") +def minimum_spanning_edges( + G, algorithm="kruskal", weight="weight", keys=True, data=True, ignore_nan=False +): + """Generate edges in a minimum spanning forest of an undirected + weighted graph. + + A minimum spanning tree is a subgraph of the graph (a tree) + with the minimum sum of edge weights. A spanning forest is a + union of the spanning trees for each connected component of the graph. + + Parameters + ---------- + G : undirected Graph + An undirected graph. If `G` is connected, then the algorithm finds a + spanning tree. Otherwise, a spanning forest is found. + + algorithm : string + The algorithm to use when finding a minimum spanning tree. Valid + choices are 'kruskal', 'prim', or 'boruvka'. The default is 'kruskal'. + + weight : string + Edge data key to use for weight (default 'weight'). + + keys : bool + Whether to yield edge key in multigraphs in addition to the edge. + If `G` is not a multigraph, this is ignored. + + data : bool, optional + If True yield the edge data along with the edge. + + ignore_nan : bool (default: False) + If a NaN is found as an edge weight normally an exception is raised. + If `ignore_nan is True` then that edge is ignored instead. + + Returns + ------- + edges : iterator + An iterator over edges in a maximum spanning tree of `G`. + Edges connecting nodes `u` and `v` are represented as tuples: + `(u, v, k, d)` or `(u, v, k)` or `(u, v, d)` or `(u, v)` + + If `G` is a multigraph, `keys` indicates whether the edge key `k` will + be reported in the third position in the edge tuple. `data` indicates + whether the edge datadict `d` will appear at the end of the edge tuple. + + If `G` is not a multigraph, the tuples are `(u, v, d)` if `data` is True + or `(u, v)` if `data` is False. + + Examples + -------- + >>> from networkx.algorithms import tree + + Find minimum spanning edges by Kruskal's algorithm + + >>> G = nx.cycle_graph(4) + >>> G.add_edge(0, 3, weight=2) + >>> mst = tree.minimum_spanning_edges(G, algorithm="kruskal", data=False) + >>> edgelist = list(mst) + >>> sorted(sorted(e) for e in edgelist) + [[0, 1], [1, 2], [2, 3]] + + Find minimum spanning edges by Prim's algorithm + + >>> G = nx.cycle_graph(4) + >>> G.add_edge(0, 3, weight=2) + >>> mst = tree.minimum_spanning_edges(G, algorithm="prim", data=False) + >>> edgelist = list(mst) + >>> sorted(sorted(e) for e in edgelist) + [[0, 1], [1, 2], [2, 3]] + + Notes + ----- + For Borůvka's algorithm, each edge must have a weight attribute, and + each edge weight must be distinct. + + For the other algorithms, if the graph edges do not have a weight + attribute a default weight of 1 will be used. + + Modified code from David Eppstein, April 2006 + http://www.ics.uci.edu/~eppstein/PADS/ + + """ + try: + algo = ALGORITHMS[algorithm] + except KeyError as err: + msg = f"{algorithm} is not a valid choice for an algorithm." + raise ValueError(msg) from err + + return algo( + G, minimum=True, weight=weight, keys=keys, data=data, ignore_nan=ignore_nan + ) + + +@not_implemented_for("directed") +@nx._dispatchable(edge_attrs="weight", preserve_edge_attrs="data") +def maximum_spanning_edges( + G, algorithm="kruskal", weight="weight", keys=True, data=True, ignore_nan=False +): + """Generate edges in a maximum spanning forest of an undirected + weighted graph. + + A maximum spanning tree is a subgraph of the graph (a tree) + with the maximum possible sum of edge weights. A spanning forest is a + union of the spanning trees for each connected component of the graph. + + Parameters + ---------- + G : undirected Graph + An undirected graph. If `G` is connected, then the algorithm finds a + spanning tree. Otherwise, a spanning forest is found. + + algorithm : string + The algorithm to use when finding a maximum spanning tree. Valid + choices are 'kruskal', 'prim', or 'boruvka'. The default is 'kruskal'. + + weight : string + Edge data key to use for weight (default 'weight'). + + keys : bool + Whether to yield edge key in multigraphs in addition to the edge. + If `G` is not a multigraph, this is ignored. + + data : bool, optional + If True yield the edge data along with the edge. + + ignore_nan : bool (default: False) + If a NaN is found as an edge weight normally an exception is raised. + If `ignore_nan is True` then that edge is ignored instead. + + Returns + ------- + edges : iterator + An iterator over edges in a maximum spanning tree of `G`. + Edges connecting nodes `u` and `v` are represented as tuples: + `(u, v, k, d)` or `(u, v, k)` or `(u, v, d)` or `(u, v)` + + If `G` is a multigraph, `keys` indicates whether the edge key `k` will + be reported in the third position in the edge tuple. `data` indicates + whether the edge datadict `d` will appear at the end of the edge tuple. + + If `G` is not a multigraph, the tuples are `(u, v, d)` if `data` is True + or `(u, v)` if `data` is False. + + Examples + -------- + >>> from networkx.algorithms import tree + + Find maximum spanning edges by Kruskal's algorithm + + >>> G = nx.cycle_graph(4) + >>> G.add_edge(0, 3, weight=2) + >>> mst = tree.maximum_spanning_edges(G, algorithm="kruskal", data=False) + >>> edgelist = list(mst) + >>> sorted(sorted(e) for e in edgelist) + [[0, 1], [0, 3], [1, 2]] + + Find maximum spanning edges by Prim's algorithm + + >>> G = nx.cycle_graph(4) + >>> G.add_edge(0, 3, weight=2) # assign weight 2 to edge 0-3 + >>> mst = tree.maximum_spanning_edges(G, algorithm="prim", data=False) + >>> edgelist = list(mst) + >>> sorted(sorted(e) for e in edgelist) + [[0, 1], [0, 3], [2, 3]] + + Notes + ----- + For Borůvka's algorithm, each edge must have a weight attribute, and + each edge weight must be distinct. + + For the other algorithms, if the graph edges do not have a weight + attribute a default weight of 1 will be used. + + Modified code from David Eppstein, April 2006 + http://www.ics.uci.edu/~eppstein/PADS/ + """ + try: + algo = ALGORITHMS[algorithm] + except KeyError as err: + msg = f"{algorithm} is not a valid choice for an algorithm." + raise ValueError(msg) from err + + return algo( + G, minimum=False, weight=weight, keys=keys, data=data, ignore_nan=ignore_nan + ) + + +@nx._dispatchable(preserve_all_attrs=True, returns_graph=True) +def minimum_spanning_tree(G, weight="weight", algorithm="kruskal", ignore_nan=False): + """Returns a minimum spanning tree or forest on an undirected graph `G`. + + Parameters + ---------- + G : undirected graph + An undirected graph. If `G` is connected, then the algorithm finds a + spanning tree. Otherwise, a spanning forest is found. + + weight : str + Data key to use for edge weights. + + algorithm : string + The algorithm to use when finding a minimum spanning tree. Valid + choices are 'kruskal', 'prim', or 'boruvka'. The default is + 'kruskal'. + + ignore_nan : bool (default: False) + If a NaN is found as an edge weight normally an exception is raised. + If `ignore_nan is True` then that edge is ignored instead. + + Returns + ------- + G : NetworkX Graph + A minimum spanning tree or forest. + + Examples + -------- + >>> G = nx.cycle_graph(4) + >>> G.add_edge(0, 3, weight=2) + >>> T = nx.minimum_spanning_tree(G) + >>> sorted(T.edges(data=True)) + [(0, 1, {}), (1, 2, {}), (2, 3, {})] + + + Notes + ----- + For Borůvka's algorithm, each edge must have a weight attribute, and + each edge weight must be distinct. + + For the other algorithms, if the graph edges do not have a weight + attribute a default weight of 1 will be used. + + There may be more than one tree with the same minimum or maximum weight. + See :mod:`networkx.tree.recognition` for more detailed definitions. + + Isolated nodes with self-loops are in the tree as edgeless isolated nodes. + + """ + edges = minimum_spanning_edges( + G, algorithm, weight, keys=True, data=True, ignore_nan=ignore_nan + ) + T = G.__class__() # Same graph class as G + T.graph.update(G.graph) + T.add_nodes_from(G.nodes.items()) + T.add_edges_from(edges) + return T + + +@nx._dispatchable(preserve_all_attrs=True, returns_graph=True) +def partition_spanning_tree( + G, minimum=True, weight="weight", partition="partition", ignore_nan=False +): + """ + Find a spanning tree while respecting a partition of edges. + + Edges can be flagged as either `INCLUDED` which are required to be in the + returned tree, `EXCLUDED`, which cannot be in the returned tree and `OPEN`. + + This is used in the SpanningTreeIterator to create new partitions following + the algorithm of Sörensen and Janssens [1]_. + + Parameters + ---------- + G : undirected graph + An undirected graph. + + minimum : bool (default: True) + Determines whether the returned tree is the minimum spanning tree of + the partition of the maximum one. + + weight : str + Data key to use for edge weights. + + partition : str + The key for the edge attribute containing the partition + data on the graph. Edges can be included, excluded or open using the + `EdgePartition` enum. + + ignore_nan : bool (default: False) + If a NaN is found as an edge weight normally an exception is raised. + If `ignore_nan is True` then that edge is ignored instead. + + + Returns + ------- + G : NetworkX Graph + A minimum spanning tree using all of the included edges in the graph and + none of the excluded edges. + + References + ---------- + .. [1] G.K. Janssens, K. Sörensen, An algorithm to generate all spanning + trees in order of increasing cost, Pesquisa Operacional, 2005-08, + Vol. 25 (2), p. 219-229, + https://www.scielo.br/j/pope/a/XHswBwRwJyrfL88dmMwYNWp/?lang=en + """ + edges = kruskal_mst_edges( + G, + minimum, + weight, + keys=True, + data=True, + ignore_nan=ignore_nan, + partition=partition, + ) + T = G.__class__() # Same graph class as G + T.graph.update(G.graph) + T.add_nodes_from(G.nodes.items()) + T.add_edges_from(edges) + return T + + +@nx._dispatchable(preserve_all_attrs=True, returns_graph=True) +def maximum_spanning_tree(G, weight="weight", algorithm="kruskal", ignore_nan=False): + """Returns a maximum spanning tree or forest on an undirected graph `G`. + + Parameters + ---------- + G : undirected graph + An undirected graph. If `G` is connected, then the algorithm finds a + spanning tree. Otherwise, a spanning forest is found. + + weight : str + Data key to use for edge weights. + + algorithm : string + The algorithm to use when finding a maximum spanning tree. Valid + choices are 'kruskal', 'prim', or 'boruvka'. The default is + 'kruskal'. + + ignore_nan : bool (default: False) + If a NaN is found as an edge weight normally an exception is raised. + If `ignore_nan is True` then that edge is ignored instead. + + + Returns + ------- + G : NetworkX Graph + A maximum spanning tree or forest. + + + Examples + -------- + >>> G = nx.cycle_graph(4) + >>> G.add_edge(0, 3, weight=2) + >>> T = nx.maximum_spanning_tree(G) + >>> sorted(T.edges(data=True)) + [(0, 1, {}), (0, 3, {'weight': 2}), (1, 2, {})] + + + Notes + ----- + For Borůvka's algorithm, each edge must have a weight attribute, and + each edge weight must be distinct. + + For the other algorithms, if the graph edges do not have a weight + attribute a default weight of 1 will be used. + + There may be more than one tree with the same minimum or maximum weight. + See :mod:`networkx.tree.recognition` for more detailed definitions. + + Isolated nodes with self-loops are in the tree as edgeless isolated nodes. + + """ + edges = maximum_spanning_edges( + G, algorithm, weight, keys=True, data=True, ignore_nan=ignore_nan + ) + edges = list(edges) + T = G.__class__() # Same graph class as G + T.graph.update(G.graph) + T.add_nodes_from(G.nodes.items()) + T.add_edges_from(edges) + return T + + +@py_random_state(3) +@nx._dispatchable(preserve_edge_attrs=True, returns_graph=True) +def random_spanning_tree(G, weight=None, *, multiplicative=True, seed=None): + """ + Sample a random spanning tree using the edges weights of `G`. + + This function supports two different methods for determining the + probability of the graph. If ``multiplicative=True``, the probability + is based on the product of edge weights, and if ``multiplicative=False`` + it is based on the sum of the edge weight. However, since it is + easier to determine the total weight of all spanning trees for the + multiplicative version, that is significantly faster and should be used if + possible. Additionally, setting `weight` to `None` will cause a spanning tree + to be selected with uniform probability. + + The function uses algorithm A8 in [1]_ . + + Parameters + ---------- + G : nx.Graph + An undirected version of the original graph. + + weight : string + The edge key for the edge attribute holding edge weight. + + multiplicative : bool, default=True + If `True`, the probability of each tree is the product of its edge weight + over the sum of the product of all the spanning trees in the graph. If + `False`, the probability is the sum of its edge weight over the sum of + the sum of weights for all spanning trees in the graph. + + seed : integer, random_state, or None (default) + Indicator of random number generation state. + See :ref:`Randomness`. + + Returns + ------- + nx.Graph + A spanning tree using the distribution defined by the weight of the tree. + + References + ---------- + .. [1] V. Kulkarni, Generating random combinatorial objects, Journal of + Algorithms, 11 (1990), pp. 185–207 + """ + + def find_node(merged_nodes, node): + """ + We can think of clusters of contracted nodes as having one + representative in the graph. Each node which is not in merged_nodes + is still its own representative. Since a representative can be later + contracted, we need to recursively search though the dict to find + the final representative, but once we know it we can use path + compression to speed up the access of the representative for next time. + + This cannot be replaced by the standard NetworkX union_find since that + data structure will merge nodes with less representing nodes into the + one with more representing nodes but this function requires we merge + them using the order that contract_edges contracts using. + + Parameters + ---------- + merged_nodes : dict + The dict storing the mapping from node to representative + node + The node whose representative we seek + + Returns + ------- + The representative of the `node` + """ + if node not in merged_nodes: + return node + else: + rep = find_node(merged_nodes, merged_nodes[node]) + merged_nodes[node] = rep + return rep + + def prepare_graph(): + """ + For the graph `G`, remove all edges not in the set `V` and then + contract all edges in the set `U`. + + Returns + ------- + A copy of `G` which has had all edges not in `V` removed and all edges + in `U` contracted. + """ + + # The result is a MultiGraph version of G so that parallel edges are + # allowed during edge contraction + result = nx.MultiGraph(incoming_graph_data=G) + + # Remove all edges not in V + edges_to_remove = set(result.edges()).difference(V) + result.remove_edges_from(edges_to_remove) + + # Contract all edges in U + # + # Imagine that you have two edges to contract and they share an + # endpoint like this: + # [0] ----- [1] ----- [2] + # If we contract (0, 1) first, the contraction function will always + # delete the second node it is passed so the resulting graph would be + # [0] ----- [2] + # and edge (1, 2) no longer exists but (0, 2) would need to be contracted + # in its place now. That is why I use the below dict as a merge-find + # data structure with path compression to track how the nodes are merged. + merged_nodes = {} + + for u, v in U: + u_rep = find_node(merged_nodes, u) + v_rep = find_node(merged_nodes, v) + # We cannot contract a node with itself + if u_rep == v_rep: + continue + nx.contracted_nodes(result, u_rep, v_rep, self_loops=False, copy=False) + merged_nodes[v_rep] = u_rep + + return merged_nodes, result + + def spanning_tree_total_weight(G, weight): + """ + Find the sum of weights of the spanning trees of `G` using the + appropriate `method`. + + This is easy if the chosen method is 'multiplicative', since we can + use Kirchhoff's Tree Matrix Theorem directly. However, with the + 'additive' method, this process is slightly more complex and less + computationally efficient as we have to find the number of spanning + trees which contain each possible edge in the graph. + + Parameters + ---------- + G : NetworkX Graph + The graph to find the total weight of all spanning trees on. + + weight : string + The key for the weight edge attribute of the graph. + + Returns + ------- + float + The sum of either the multiplicative or additive weight for all + spanning trees in the graph. + """ + if multiplicative: + return nx.total_spanning_tree_weight(G, weight) + else: + # There are two cases for the total spanning tree additive weight. + # 1. There is one edge in the graph. Then the only spanning tree is + # that edge itself, which will have a total weight of that edge + # itself. + if G.number_of_edges() == 1: + return G.edges(data=weight).__iter__().__next__()[2] + # 2. There are no edges or two or more edges in the graph. Then, we find the + # total weight of the spanning trees using the formula in the + # reference paper: take the weight of each edge and multiply it by + # the number of spanning trees which include that edge. This + # can be accomplished by contracting the edge and finding the + # multiplicative total spanning tree weight if the weight of each edge + # is assumed to be 1, which is conveniently built into networkx already, + # by calling total_spanning_tree_weight with weight=None. + # Note that with no edges the returned value is just zero. + else: + total = 0 + for u, v, w in G.edges(data=weight): + total += w * nx.total_spanning_tree_weight( + nx.contracted_edge(G, edge=(u, v), self_loops=False), None + ) + return total + + if G.number_of_nodes() < 2: + # no edges in the spanning tree + return nx.empty_graph(G.nodes) + + U = set() + st_cached_value = 0 + V = set(G.edges()) + shuffled_edges = list(G.edges()) + seed.shuffle(shuffled_edges) + + for u, v in shuffled_edges: + e_weight = G[u][v][weight] if weight is not None else 1 + node_map, prepared_G = prepare_graph() + G_total_tree_weight = spanning_tree_total_weight(prepared_G, weight) + # Add the edge to U so that we can compute the total tree weight + # assuming we include that edge + # Now, if (u, v) cannot exist in G because it is fully contracted out + # of existence, then it by definition cannot influence G_e's Kirchhoff + # value. But, we also cannot pick it. + rep_edge = (find_node(node_map, u), find_node(node_map, v)) + # Check to see if the 'representative edge' for the current edge is + # in prepared_G. If so, then we can pick it. + if rep_edge in prepared_G.edges: + prepared_G_e = nx.contracted_edge( + prepared_G, edge=rep_edge, self_loops=False + ) + G_e_total_tree_weight = spanning_tree_total_weight(prepared_G_e, weight) + if multiplicative: + threshold = e_weight * G_e_total_tree_weight / G_total_tree_weight + else: + numerator = ( + st_cached_value + e_weight + ) * nx.total_spanning_tree_weight(prepared_G_e) + G_e_total_tree_weight + denominator = ( + st_cached_value * nx.total_spanning_tree_weight(prepared_G) + + G_total_tree_weight + ) + threshold = numerator / denominator + else: + threshold = 0.0 + z = seed.uniform(0.0, 1.0) + if z > threshold: + # Remove the edge from V since we did not pick it. + V.remove((u, v)) + else: + # Add the edge to U since we picked it. + st_cached_value += e_weight + U.add((u, v)) + # If we decide to keep an edge, it may complete the spanning tree. + if len(U) == G.number_of_nodes() - 1: + spanning_tree = nx.Graph() + spanning_tree.add_edges_from(U) + return spanning_tree + raise Exception(f"Something went wrong! Only {len(U)} edges in the spanning tree!") + + +class SpanningTreeIterator: + """ + Iterate over all spanning trees of a graph in either increasing or + decreasing cost. + + Notes + ----- + This iterator uses the partition scheme from [1]_ (included edges, + excluded edges and open edges) as well as a modified Kruskal's Algorithm + to generate minimum spanning trees which respect the partition of edges. + For spanning trees with the same weight, ties are broken arbitrarily. + + References + ---------- + .. [1] G.K. Janssens, K. Sörensen, An algorithm to generate all spanning + trees in order of increasing cost, Pesquisa Operacional, 2005-08, + Vol. 25 (2), p. 219-229, + https://www.scielo.br/j/pope/a/XHswBwRwJyrfL88dmMwYNWp/?lang=en + """ + + @dataclass(order=True) + class Partition: + """ + This dataclass represents a partition and stores a dict with the edge + data and the weight of the minimum spanning tree of the partition dict. + """ + + mst_weight: float + partition_dict: dict = field(compare=False) + + def __copy__(self): + return SpanningTreeIterator.Partition( + self.mst_weight, self.partition_dict.copy() + ) + + def __init__(self, G, weight="weight", minimum=True, ignore_nan=False): + """ + Initialize the iterator + + Parameters + ---------- + G : nx.Graph + The directed graph which we need to iterate trees over + + weight : String, default = "weight" + The edge attribute used to store the weight of the edge + + minimum : bool, default = True + Return the trees in increasing order while true and decreasing order + while false. + + ignore_nan : bool, default = False + If a NaN is found as an edge weight normally an exception is raised. + If `ignore_nan is True` then that edge is ignored instead. + """ + self.G = G.copy() + self.G.__networkx_cache__ = None # Disable caching + self.weight = weight + self.minimum = minimum + self.ignore_nan = ignore_nan + # Randomly create a key for an edge attribute to hold the partition data + self.partition_key = ( + "SpanningTreeIterators super secret partition attribute name" + ) + + def __iter__(self): + """ + Returns + ------- + SpanningTreeIterator + The iterator object for this graph + """ + self.partition_queue = PriorityQueue() + self._clear_partition(self.G) + mst_weight = partition_spanning_tree( + self.G, self.minimum, self.weight, self.partition_key, self.ignore_nan + ).size(weight=self.weight) + + self.partition_queue.put( + self.Partition(mst_weight if self.minimum else -mst_weight, {}) + ) + + return self + + def __next__(self): + """ + Returns + ------- + (multi)Graph + The spanning tree of next greatest weight, which ties broken + arbitrarily. + """ + if self.partition_queue.empty(): + del self.G, self.partition_queue + raise StopIteration + + partition = self.partition_queue.get() + self._write_partition(partition) + next_tree = partition_spanning_tree( + self.G, self.minimum, self.weight, self.partition_key, self.ignore_nan + ) + self._partition(partition, next_tree) + + self._clear_partition(next_tree) + return next_tree + + def _partition(self, partition, partition_tree): + """ + Create new partitions based of the minimum spanning tree of the + current minimum partition. + + Parameters + ---------- + partition : Partition + The Partition instance used to generate the current minimum spanning + tree. + partition_tree : nx.Graph + The minimum spanning tree of the input partition. + """ + # create two new partitions with the data from the input partition dict + p1 = self.Partition(0, partition.partition_dict.copy()) + p2 = self.Partition(0, partition.partition_dict.copy()) + for e in partition_tree.edges: + # determine if the edge was open or included + if e not in partition.partition_dict: + # This is an open edge + p1.partition_dict[e] = EdgePartition.EXCLUDED + p2.partition_dict[e] = EdgePartition.INCLUDED + + self._write_partition(p1) + p1_mst = partition_spanning_tree( + self.G, + self.minimum, + self.weight, + self.partition_key, + self.ignore_nan, + ) + p1_mst_weight = p1_mst.size(weight=self.weight) + if nx.is_connected(p1_mst): + p1.mst_weight = p1_mst_weight if self.minimum else -p1_mst_weight + self.partition_queue.put(p1.__copy__()) + p1.partition_dict = p2.partition_dict.copy() + + def _write_partition(self, partition): + """ + Writes the desired partition into the graph to calculate the minimum + spanning tree. + + Parameters + ---------- + partition : Partition + A Partition dataclass describing a partition on the edges of the + graph. + """ + + partition_dict = partition.partition_dict + partition_key = self.partition_key + G = self.G + + edges = ( + G.edges(keys=True, data=True) if G.is_multigraph() else G.edges(data=True) + ) + for *e, d in edges: + d[partition_key] = partition_dict.get(tuple(e), EdgePartition.OPEN) + + def _clear_partition(self, G): + """ + Removes partition data from the graph + """ + partition_key = self.partition_key + edges = ( + G.edges(keys=True, data=True) if G.is_multigraph() else G.edges(data=True) + ) + for *e, d in edges: + if partition_key in d: + del d[partition_key] + + +@nx._dispatchable(edge_attrs="weight") +def number_of_spanning_trees(G, *, root=None, weight=None): + """Returns the number of spanning trees in `G`. + + A spanning tree for an undirected graph is a tree that connects + all nodes in the graph. For a directed graph, the analog of a + spanning tree is called a (spanning) arborescence. The arborescence + includes a unique directed path from the `root` node to each other node. + The graph must be weakly connected, and the root must be a node + that includes all nodes as successors [3]_. Note that to avoid + discussing sink-roots and reverse-arborescences, we have reversed + the edge orientation from [3]_ and use the in-degree laplacian. + + This function (when `weight` is `None`) returns the number of + spanning trees for an undirected graph and the number of + arborescences from a single root node for a directed graph. + When `weight` is the name of an edge attribute which holds the + weight value of each edge, the function returns the sum over + all trees of the multiplicative weight of each tree. That is, + the weight of the tree is the product of its edge weights. + + Kirchoff's Tree Matrix Theorem states that any cofactor of the + Laplacian matrix of a graph is the number of spanning trees in the + graph. (Here we use cofactors for a diagonal entry so that the + cofactor becomes the determinant of the matrix with one row + and its matching column removed.) For a weighted Laplacian matrix, + the cofactor is the sum across all spanning trees of the + multiplicative weight of each tree. That is, the weight of each + tree is the product of its edge weights. The theorem is also + known as Kirchhoff's theorem [1]_ and the Matrix-Tree theorem [2]_. + + For directed graphs, a similar theorem (Tutte's Theorem) holds with + the cofactor chosen to be the one with row and column removed that + correspond to the root. The cofactor is the number of arborescences + with the specified node as root. And the weighted version gives the + sum of the arborescence weights with root `root`. The arborescence + weight is the product of its edge weights. + + Parameters + ---------- + G : NetworkX graph + + root : node + A node in the directed graph `G` that has all nodes as descendants. + (This is ignored for undirected graphs.) + + weight : string or None, optional (default=None) + The name of the edge attribute holding the edge weight. + If `None`, then each edge is assumed to have a weight of 1. + + Returns + ------- + Number + Undirected graphs: + The number of spanning trees of the graph `G`. + Or the sum of all spanning tree weights of the graph `G` + where the weight of a tree is the product of its edge weights. + Directed graphs: + The number of arborescences of `G` rooted at node `root`. + Or the sum of all arborescence weights of the graph `G` with + specified root where the weight of an arborescence is the product + of its edge weights. + + Raises + ------ + NetworkXPointlessConcept + If `G` does not contain any nodes. + + NetworkXError + If the graph `G` is directed and the root node + is not specified or is not in G. + + Examples + -------- + >>> G = nx.complete_graph(5) + >>> round(nx.number_of_spanning_trees(G)) + 125 + + >>> G = nx.Graph() + >>> G.add_edge(1, 2, weight=2) + >>> G.add_edge(1, 3, weight=1) + >>> G.add_edge(2, 3, weight=1) + >>> round(nx.number_of_spanning_trees(G, weight="weight")) + 5 + + Notes + ----- + Self-loops are excluded. Multi-edges are contracted in one edge + equal to the sum of the weights. + + References + ---------- + .. [1] Wikipedia + "Kirchhoff's theorem." + https://en.wikipedia.org/wiki/Kirchhoff%27s_theorem + .. [2] Kirchhoff, G. R. + Über die Auflösung der Gleichungen, auf welche man + bei der Untersuchung der linearen Vertheilung + Galvanischer Ströme geführt wird + Annalen der Physik und Chemie, vol. 72, pp. 497-508, 1847. + .. [3] Margoliash, J. + "Matrix-Tree Theorem for Directed Graphs" + https://www.math.uchicago.edu/~may/VIGRE/VIGRE2010/REUPapers/Margoliash.pdf + """ + import numpy as np + + if len(G) == 0: + raise nx.NetworkXPointlessConcept("Graph G must contain at least one node.") + + # undirected G + if not nx.is_directed(G): + if not nx.is_connected(G): + return 0 + G_laplacian = nx.laplacian_matrix(G, weight=weight).toarray() + return float(np.linalg.det(G_laplacian[1:, 1:])) + + # directed G + if root is None: + raise nx.NetworkXError("Input `root` must be provided when G is directed") + if root not in G: + raise nx.NetworkXError("The node root is not in the graph G.") + if not nx.is_weakly_connected(G): + return 0 + + # Compute directed Laplacian matrix + nodelist = [root] + [n for n in G if n != root] + A = nx.adjacency_matrix(G, nodelist=nodelist, weight=weight) + D = np.diag(A.sum(axis=0)) + G_laplacian = D - A + + # Compute number of spanning trees + return float(np.linalg.det(G_laplacian[1:, 1:])) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tree/operations.py b/lib/python3.10/site-packages/networkx/algorithms/tree/operations.py new file mode 100644 index 0000000000000000000000000000000000000000..6c3e839453e686c80d33c94a66defa87698a066f --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tree/operations.py @@ -0,0 +1,105 @@ +"""Operations on trees.""" + +from functools import partial +from itertools import accumulate, chain + +import networkx as nx + +__all__ = ["join_trees"] + + +# Argument types don't match dispatching, but allow manual selection of backend +@nx._dispatchable(graphs=None, returns_graph=True) +def join_trees(rooted_trees, *, label_attribute=None, first_label=0): + """Returns a new rooted tree made by joining `rooted_trees` + + Constructs a new tree by joining each tree in `rooted_trees`. + A new root node is added and connected to each of the roots + of the input trees. While copying the nodes from the trees, + relabeling to integers occurs. If the `label_attribute` is provided, + the old node labels will be stored in the new tree under this attribute. + + Parameters + ---------- + rooted_trees : list + A list of pairs in which each left element is a NetworkX graph + object representing a tree and each right element is the root + node of that tree. The nodes of these trees will be relabeled to + integers. + + label_attribute : str + If provided, the old node labels will be stored in the new tree + under this node attribute. If not provided, the original labels + of the nodes in the input trees are not stored. + + first_label : int, optional (default=0) + Specifies the label for the new root node. If provided, the root node of the joined tree + will have this label. If not provided, the root node will default to a label of 0. + + Returns + ------- + NetworkX graph + The rooted tree resulting from joining the provided `rooted_trees`. The new tree has a root node + labeled as specified by `first_label` (defaulting to 0 if not provided). Subtrees from the input + `rooted_trees` are attached to this new root node. Each non-root node, if the `label_attribute` + is provided, has an attribute that indicates the original label of the node in the input tree. + + Notes + ----- + Trees are stored in NetworkX as NetworkX Graphs. There is no specific + enforcement of the fact that these are trees. Testing for each tree + can be done using :func:`networkx.is_tree`. + + Graph, edge, and node attributes are propagated from the given + rooted trees to the created tree. If there are any overlapping graph + attributes, those from later trees will overwrite those from earlier + trees in the tuple of positional arguments. + + Examples + -------- + Join two full balanced binary trees of height *h* to get a full + balanced binary tree of depth *h* + 1:: + + >>> h = 4 + >>> left = nx.balanced_tree(2, h) + >>> right = nx.balanced_tree(2, h) + >>> joined_tree = nx.join_trees([(left, 0), (right, 0)]) + >>> nx.is_isomorphic(joined_tree, nx.balanced_tree(2, h + 1)) + True + + """ + if not rooted_trees: + return nx.empty_graph(1) + + # Unzip the zipped list of (tree, root) pairs. + trees, roots = zip(*rooted_trees) + + # The join of the trees has the same type as the type of the first tree. + R = type(trees[0])() + + lengths = (len(tree) for tree in trees[:-1]) + first_labels = list(accumulate(lengths, initial=first_label + 1)) + + new_roots = [] + for tree, root, first_node in zip(trees, roots, first_labels): + new_root = first_node + list(tree.nodes()).index(root) + new_roots.append(new_root) + + # Relabel the nodes so that their union is the integers starting at first_label. + relabel = partial( + nx.convert_node_labels_to_integers, label_attribute=label_attribute + ) + new_trees = [ + relabel(tree, first_label=first_label) + for tree, first_label in zip(trees, first_labels) + ] + + # Add all sets of nodes and edges, attributes + for tree in new_trees: + R.update(tree) + + # Finally, join the subtrees at the root. We know first_label is unused by the way we relabeled the subtrees. + R.add_node(first_label) + R.add_edges_from((first_label, root) for root in new_roots) + + return R diff --git a/lib/python3.10/site-packages/networkx/algorithms/tree/recognition.py b/lib/python3.10/site-packages/networkx/algorithms/tree/recognition.py new file mode 100644 index 0000000000000000000000000000000000000000..a9eae98707a6889213ff8b93887c481ba59215a0 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tree/recognition.py @@ -0,0 +1,273 @@ +""" +Recognition Tests +================= + +A *forest* is an acyclic, undirected graph, and a *tree* is a connected forest. +Depending on the subfield, there are various conventions for generalizing these +definitions to directed graphs. + +In one convention, directed variants of forest and tree are defined in an +identical manner, except that the direction of the edges is ignored. In effect, +each directed edge is treated as a single undirected edge. Then, additional +restrictions are imposed to define *branchings* and *arborescences*. + +In another convention, directed variants of forest and tree correspond to +the previous convention's branchings and arborescences, respectively. Then two +new terms, *polyforest* and *polytree*, are defined to correspond to the other +convention's forest and tree. + +Summarizing:: + + +-----------------------------+ + | Convention A | Convention B | + +=============================+ + | forest | polyforest | + | tree | polytree | + | branching | forest | + | arborescence | tree | + +-----------------------------+ + +Each convention has its reasons. The first convention emphasizes definitional +similarity in that directed forests and trees are only concerned with +acyclicity and do not have an in-degree constraint, just as their undirected +counterparts do not. The second convention emphasizes functional similarity +in the sense that the directed analog of a spanning tree is a spanning +arborescence. That is, take any spanning tree and choose one node as the root. +Then every edge is assigned a direction such there is a directed path from the +root to every other node. The result is a spanning arborescence. + +NetworkX follows convention "A". Explicitly, these are: + +undirected forest + An undirected graph with no undirected cycles. + +undirected tree + A connected, undirected forest. + +directed forest + A directed graph with no undirected cycles. Equivalently, the underlying + graph structure (which ignores edge orientations) is an undirected forest. + In convention B, this is known as a polyforest. + +directed tree + A weakly connected, directed forest. Equivalently, the underlying graph + structure (which ignores edge orientations) is an undirected tree. In + convention B, this is known as a polytree. + +branching + A directed forest with each node having, at most, one parent. So the maximum + in-degree is equal to 1. In convention B, this is known as a forest. + +arborescence + A directed tree with each node having, at most, one parent. So the maximum + in-degree is equal to 1. In convention B, this is known as a tree. + +For trees and arborescences, the adjective "spanning" may be added to designate +that the graph, when considered as a forest/branching, consists of a single +tree/arborescence that includes all nodes in the graph. It is true, by +definition, that every tree/arborescence is spanning with respect to the nodes +that define the tree/arborescence and so, it might seem redundant to introduce +the notion of "spanning". However, the nodes may represent a subset of +nodes from a larger graph, and it is in this context that the term "spanning" +becomes a useful notion. + +""" + +import networkx as nx + +__all__ = ["is_arborescence", "is_branching", "is_forest", "is_tree"] + + +@nx.utils.not_implemented_for("undirected") +@nx._dispatchable +def is_arborescence(G): + """ + Returns True if `G` is an arborescence. + + An arborescence is a directed tree with maximum in-degree equal to 1. + + Parameters + ---------- + G : graph + The graph to test. + + Returns + ------- + b : bool + A boolean that is True if `G` is an arborescence. + + Examples + -------- + >>> G = nx.DiGraph([(0, 1), (0, 2), (2, 3), (3, 4)]) + >>> nx.is_arborescence(G) + True + >>> G.remove_edge(0, 1) + >>> G.add_edge(1, 2) # maximum in-degree is 2 + >>> nx.is_arborescence(G) + False + + Notes + ----- + In another convention, an arborescence is known as a *tree*. + + See Also + -------- + is_tree + + """ + return is_tree(G) and max(d for n, d in G.in_degree()) <= 1 + + +@nx.utils.not_implemented_for("undirected") +@nx._dispatchable +def is_branching(G): + """ + Returns True if `G` is a branching. + + A branching is a directed forest with maximum in-degree equal to 1. + + Parameters + ---------- + G : directed graph + The directed graph to test. + + Returns + ------- + b : bool + A boolean that is True if `G` is a branching. + + Examples + -------- + >>> G = nx.DiGraph([(0, 1), (1, 2), (2, 3), (3, 4)]) + >>> nx.is_branching(G) + True + >>> G.remove_edge(2, 3) + >>> G.add_edge(3, 1) # maximum in-degree is 2 + >>> nx.is_branching(G) + False + + Notes + ----- + In another convention, a branching is also known as a *forest*. + + See Also + -------- + is_forest + + """ + return is_forest(G) and max(d for n, d in G.in_degree()) <= 1 + + +@nx._dispatchable +def is_forest(G): + """ + Returns True if `G` is a forest. + + A forest is a graph with no undirected cycles. + + For directed graphs, `G` is a forest if the underlying graph is a forest. + The underlying graph is obtained by treating each directed edge as a single + undirected edge in a multigraph. + + Parameters + ---------- + G : graph + The graph to test. + + Returns + ------- + b : bool + A boolean that is True if `G` is a forest. + + Raises + ------ + NetworkXPointlessConcept + If `G` is empty. + + Examples + -------- + >>> G = nx.Graph() + >>> G.add_edges_from([(1, 2), (1, 3), (2, 4), (2, 5)]) + >>> nx.is_forest(G) + True + >>> G.add_edge(4, 1) + >>> nx.is_forest(G) + False + + Notes + ----- + In another convention, a directed forest is known as a *polyforest* and + then *forest* corresponds to a *branching*. + + See Also + -------- + is_branching + + """ + if len(G) == 0: + raise nx.exception.NetworkXPointlessConcept("G has no nodes.") + + if G.is_directed(): + components = (G.subgraph(c) for c in nx.weakly_connected_components(G)) + else: + components = (G.subgraph(c) for c in nx.connected_components(G)) + + return all(len(c) - 1 == c.number_of_edges() for c in components) + + +@nx._dispatchable +def is_tree(G): + """ + Returns True if `G` is a tree. + + A tree is a connected graph with no undirected cycles. + + For directed graphs, `G` is a tree if the underlying graph is a tree. The + underlying graph is obtained by treating each directed edge as a single + undirected edge in a multigraph. + + Parameters + ---------- + G : graph + The graph to test. + + Returns + ------- + b : bool + A boolean that is True if `G` is a tree. + + Raises + ------ + NetworkXPointlessConcept + If `G` is empty. + + Examples + -------- + >>> G = nx.Graph() + >>> G.add_edges_from([(1, 2), (1, 3), (2, 4), (2, 5)]) + >>> nx.is_tree(G) # n-1 edges + True + >>> G.add_edge(3, 4) + >>> nx.is_tree(G) # n edges + False + + Notes + ----- + In another convention, a directed tree is known as a *polytree* and then + *tree* corresponds to an *arborescence*. + + See Also + -------- + is_arborescence + + """ + if len(G) == 0: + raise nx.exception.NetworkXPointlessConcept("G has no nodes.") + + if G.is_directed(): + is_connected = nx.is_weakly_connected + else: + is_connected = nx.is_connected + + # A connected graph with no cycles has n-1 edges. + return len(G) - 1 == G.number_of_edges() and is_connected(G) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tree/tests/__init__.py b/lib/python3.10/site-packages/networkx/algorithms/tree/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/networkx/algorithms/tree/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tree/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6cec0c397c217f8246b11a9f95454df0aff107a7 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tree/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tree/tests/__pycache__/test_branchings.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tree/tests/__pycache__/test_branchings.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47674bc16491aecb4aed72e8107604c58fe86e20 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tree/tests/__pycache__/test_branchings.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tree/tests/__pycache__/test_coding.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tree/tests/__pycache__/test_coding.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f523f3ec89966880f0ac6d94da0c5f7dbd1d7ea6 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tree/tests/__pycache__/test_coding.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tree/tests/__pycache__/test_decomposition.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tree/tests/__pycache__/test_decomposition.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f10dcca5fa5a0054fafbfd35566754f460054cb Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tree/tests/__pycache__/test_decomposition.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tree/tests/__pycache__/test_mst.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tree/tests/__pycache__/test_mst.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c2d4075c608eb137c2b317c1d51405ab4050235 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tree/tests/__pycache__/test_mst.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tree/tests/__pycache__/test_operations.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tree/tests/__pycache__/test_operations.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1193e610fa4238f213655be6f26f75e49b9ee159 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tree/tests/__pycache__/test_operations.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tree/tests/__pycache__/test_recognition.cpython-310.pyc b/lib/python3.10/site-packages/networkx/algorithms/tree/tests/__pycache__/test_recognition.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f71b02493a1b2956fc0e4abc5101a58fcff22c3 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/algorithms/tree/tests/__pycache__/test_recognition.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/algorithms/tree/tests/test_branchings.py b/lib/python3.10/site-packages/networkx/algorithms/tree/tests/test_branchings.py new file mode 100644 index 0000000000000000000000000000000000000000..e19ddee332b93d8131d48fa7cfc837625b2261f4 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tree/tests/test_branchings.py @@ -0,0 +1,624 @@ +import math +from operator import itemgetter + +import pytest + +np = pytest.importorskip("numpy") + +import networkx as nx +from networkx.algorithms.tree import branchings, recognition + +# +# Explicitly discussed examples from Edmonds paper. +# + +# Used in Figures A-F. +# +# fmt: off +G_array = np.array([ + # 0 1 2 3 4 5 6 7 8 + [0, 0, 12, 0, 12, 0, 0, 0, 0], # 0 + [4, 0, 0, 0, 0, 13, 0, 0, 0], # 1 + [0, 17, 0, 21, 0, 12, 0, 0, 0], # 2 + [5, 0, 0, 0, 17, 0, 18, 0, 0], # 3 + [0, 0, 0, 0, 0, 0, 0, 12, 0], # 4 + [0, 0, 0, 0, 0, 0, 14, 0, 12], # 5 + [0, 0, 21, 0, 0, 0, 0, 0, 15], # 6 + [0, 0, 0, 19, 0, 0, 15, 0, 0], # 7 + [0, 0, 0, 0, 0, 0, 0, 18, 0], # 8 +], dtype=int) + +# Two copies of the graph from the original paper as disconnected components +G_big_array = np.zeros(np.array(G_array.shape) * 2, dtype=int) +G_big_array[:G_array.shape[0], :G_array.shape[1]] = G_array +G_big_array[G_array.shape[0]:, G_array.shape[1]:] = G_array + +# fmt: on + + +def G1(): + G = nx.from_numpy_array(G_array, create_using=nx.MultiDiGraph) + return G + + +def G2(): + # Now we shift all the weights by -10. + # Should not affect optimal arborescence, but does affect optimal branching. + Garr = G_array.copy() + Garr[np.nonzero(Garr)] -= 10 + G = nx.from_numpy_array(Garr, create_using=nx.MultiDiGraph) + return G + + +# An optimal branching for G1 that is also a spanning arborescence. So it is +# also an optimal spanning arborescence. +# +optimal_arborescence_1 = [ + (0, 2, 12), + (2, 1, 17), + (2, 3, 21), + (1, 5, 13), + (3, 4, 17), + (3, 6, 18), + (6, 8, 15), + (8, 7, 18), +] + +# For G2, the optimal branching of G1 (with shifted weights) is no longer +# an optimal branching, but it is still an optimal spanning arborescence +# (just with shifted weights). An optimal branching for G2 is similar to what +# appears in figure G (this is greedy_subopt_branching_1a below), but with the +# edge (3, 0, 5), which is now (3, 0, -5), removed. Thus, the optimal branching +# is not a spanning arborescence. The code finds optimal_branching_2a. +# An alternative and equivalent branching is optimal_branching_2b. We would +# need to modify the code to iterate through all equivalent optimal branchings. +# +# These are maximal branchings or arborescences. +optimal_branching_2a = [ + (5, 6, 4), + (6, 2, 11), + (6, 8, 5), + (8, 7, 8), + (2, 1, 7), + (2, 3, 11), + (3, 4, 7), +] +optimal_branching_2b = [ + (8, 7, 8), + (7, 3, 9), + (3, 4, 7), + (3, 6, 8), + (6, 2, 11), + (2, 1, 7), + (1, 5, 3), +] +optimal_arborescence_2 = [ + (0, 2, 2), + (2, 1, 7), + (2, 3, 11), + (1, 5, 3), + (3, 4, 7), + (3, 6, 8), + (6, 8, 5), + (8, 7, 8), +] + +# Two suboptimal maximal branchings on G1 obtained from a greedy algorithm. +# 1a matches what is shown in Figure G in Edmonds's paper. +greedy_subopt_branching_1a = [ + (5, 6, 14), + (6, 2, 21), + (6, 8, 15), + (8, 7, 18), + (2, 1, 17), + (2, 3, 21), + (3, 0, 5), + (3, 4, 17), +] +greedy_subopt_branching_1b = [ + (8, 7, 18), + (7, 6, 15), + (6, 2, 21), + (2, 1, 17), + (2, 3, 21), + (1, 5, 13), + (3, 0, 5), + (3, 4, 17), +] + + +def build_branching(edges, double=False): + G = nx.DiGraph() + for u, v, weight in edges: + G.add_edge(u, v, weight=weight) + if double: + G.add_edge(u + 9, v + 9, weight=weight) + return G + + +def sorted_edges(G, attr="weight", default=1): + edges = [(u, v, data.get(attr, default)) for (u, v, data) in G.edges(data=True)] + edges = sorted(edges, key=lambda x: (x[2], x[1], x[0])) + return edges + + +def assert_equal_branchings(G1, G2, attr="weight", default=1): + edges1 = list(G1.edges(data=True)) + edges2 = list(G2.edges(data=True)) + assert len(edges1) == len(edges2) + + # Grab the weights only. + e1 = sorted_edges(G1, attr, default) + e2 = sorted_edges(G2, attr, default) + + for a, b in zip(e1, e2): + assert a[:2] == b[:2] + np.testing.assert_almost_equal(a[2], b[2]) + + +################ + + +def test_optimal_branching1(): + G = build_branching(optimal_arborescence_1) + assert recognition.is_arborescence(G), True + assert branchings.branching_weight(G) == 131 + + +def test_optimal_branching2a(): + G = build_branching(optimal_branching_2a) + assert recognition.is_arborescence(G), True + assert branchings.branching_weight(G) == 53 + + +def test_optimal_branching2b(): + G = build_branching(optimal_branching_2b) + assert recognition.is_arborescence(G), True + assert branchings.branching_weight(G) == 53 + + +def test_optimal_arborescence2(): + G = build_branching(optimal_arborescence_2) + assert recognition.is_arborescence(G), True + assert branchings.branching_weight(G) == 51 + + +def test_greedy_suboptimal_branching1a(): + G = build_branching(greedy_subopt_branching_1a) + assert recognition.is_arborescence(G), True + assert branchings.branching_weight(G) == 128 + + +def test_greedy_suboptimal_branching1b(): + G = build_branching(greedy_subopt_branching_1b) + assert recognition.is_arborescence(G), True + assert branchings.branching_weight(G) == 127 + + +def test_greedy_max1(): + # Standard test. + # + G = G1() + B = branchings.greedy_branching(G) + # There are only two possible greedy branchings. The sorting is such + # that it should equal the second suboptimal branching: 1b. + B_ = build_branching(greedy_subopt_branching_1b) + assert_equal_branchings(B, B_) + + +def test_greedy_branching_kwarg_kind(): + G = G1() + with pytest.raises(nx.NetworkXException, match="Unknown value for `kind`."): + B = branchings.greedy_branching(G, kind="lol") + + +def test_greedy_branching_for_unsortable_nodes(): + G = nx.DiGraph() + G.add_weighted_edges_from([((2, 3), 5, 1), (3, "a", 1), (2, 4, 5)]) + edges = [(u, v, data.get("weight", 1)) for (u, v, data) in G.edges(data=True)] + with pytest.raises(TypeError): + edges.sort(key=itemgetter(2, 0, 1), reverse=True) + B = branchings.greedy_branching(G, kind="max").edges(data=True) + assert list(B) == [ + ((2, 3), 5, {"weight": 1}), + (3, "a", {"weight": 1}), + (2, 4, {"weight": 5}), + ] + + +def test_greedy_max2(): + # Different default weight. + # + G = G1() + del G[1][0][0]["weight"] + B = branchings.greedy_branching(G, default=6) + # Chosen so that edge (3,0,5) is not selected and (1,0,6) is instead. + + edges = [ + (1, 0, 6), + (1, 5, 13), + (7, 6, 15), + (2, 1, 17), + (3, 4, 17), + (8, 7, 18), + (2, 3, 21), + (6, 2, 21), + ] + B_ = build_branching(edges) + assert_equal_branchings(B, B_) + + +def test_greedy_max3(): + # All equal weights. + # + G = G1() + B = branchings.greedy_branching(G, attr=None) + + # This is mostly arbitrary...the output was generated by running the algo. + edges = [ + (2, 1, 1), + (3, 0, 1), + (3, 4, 1), + (5, 8, 1), + (6, 2, 1), + (7, 3, 1), + (7, 6, 1), + (8, 7, 1), + ] + B_ = build_branching(edges) + assert_equal_branchings(B, B_, default=1) + + +def test_greedy_min(): + G = G1() + B = branchings.greedy_branching(G, kind="min") + + edges = [ + (1, 0, 4), + (0, 2, 12), + (0, 4, 12), + (2, 5, 12), + (4, 7, 12), + (5, 8, 12), + (5, 6, 14), + (7, 3, 19), + ] + B_ = build_branching(edges) + assert_equal_branchings(B, B_) + + +def test_edmonds1_maxbranch(): + G = G1() + x = branchings.maximum_branching(G) + x_ = build_branching(optimal_arborescence_1) + assert_equal_branchings(x, x_) + + +def test_edmonds1_maxarbor(): + G = G1() + x = branchings.maximum_spanning_arborescence(G) + x_ = build_branching(optimal_arborescence_1) + assert_equal_branchings(x, x_) + + +def test_edmonds1_minimal_branching(): + # graph will have something like a minimum arborescence but no spanning one + G = nx.from_numpy_array(G_big_array, create_using=nx.DiGraph) + B = branchings.minimal_branching(G) + edges = [ + (3, 0, 5), + (0, 2, 12), + (0, 4, 12), + (2, 5, 12), + (4, 7, 12), + (5, 8, 12), + (5, 6, 14), + (2, 1, 17), + ] + B_ = build_branching(edges, double=True) + assert_equal_branchings(B, B_) + + +def test_edmonds2_maxbranch(): + G = G2() + x = branchings.maximum_branching(G) + x_ = build_branching(optimal_branching_2a) + assert_equal_branchings(x, x_) + + +def test_edmonds2_maxarbor(): + G = G2() + x = branchings.maximum_spanning_arborescence(G) + x_ = build_branching(optimal_arborescence_2) + assert_equal_branchings(x, x_) + + +def test_edmonds2_minarbor(): + G = G1() + x = branchings.minimum_spanning_arborescence(G) + # This was obtained from algorithm. Need to verify it independently. + # Branch weight is: 96 + edges = [ + (3, 0, 5), + (0, 2, 12), + (0, 4, 12), + (2, 5, 12), + (4, 7, 12), + (5, 8, 12), + (5, 6, 14), + (2, 1, 17), + ] + x_ = build_branching(edges) + assert_equal_branchings(x, x_) + + +def test_edmonds3_minbranch1(): + G = G1() + x = branchings.minimum_branching(G) + edges = [] + x_ = build_branching(edges) + assert_equal_branchings(x, x_) + + +def test_edmonds3_minbranch2(): + G = G1() + G.add_edge(8, 9, weight=-10) + x = branchings.minimum_branching(G) + edges = [(8, 9, -10)] + x_ = build_branching(edges) + assert_equal_branchings(x, x_) + + +# Need more tests + + +def test_mst(): + # Make sure we get the same results for undirected graphs. + # Example from: https://en.wikipedia.org/wiki/Kruskal's_algorithm + G = nx.Graph() + edgelist = [ + (0, 3, [("weight", 5)]), + (0, 1, [("weight", 7)]), + (1, 3, [("weight", 9)]), + (1, 2, [("weight", 8)]), + (1, 4, [("weight", 7)]), + (3, 4, [("weight", 15)]), + (3, 5, [("weight", 6)]), + (2, 4, [("weight", 5)]), + (4, 5, [("weight", 8)]), + (4, 6, [("weight", 9)]), + (5, 6, [("weight", 11)]), + ] + G.add_edges_from(edgelist) + G = G.to_directed() + x = branchings.minimum_spanning_arborescence(G) + + edges = [ + ({0, 1}, 7), + ({0, 3}, 5), + ({3, 5}, 6), + ({1, 4}, 7), + ({4, 2}, 5), + ({4, 6}, 9), + ] + + assert x.number_of_edges() == len(edges) + for u, v, d in x.edges(data=True): + assert ({u, v}, d["weight"]) in edges + + +def test_mixed_nodetypes(): + # Smoke test to make sure no TypeError is raised for mixed node types. + G = nx.Graph() + edgelist = [(0, 3, [("weight", 5)]), (0, "1", [("weight", 5)])] + G.add_edges_from(edgelist) + G = G.to_directed() + x = branchings.minimum_spanning_arborescence(G) + + +def test_edmonds1_minbranch(): + # Using -G_array and min should give the same as optimal_arborescence_1, + # but with all edges negative. + edges = [(u, v, -w) for (u, v, w) in optimal_arborescence_1] + + G = nx.from_numpy_array(-G_array, create_using=nx.DiGraph) + + # Quickly make sure max branching is empty. + x = branchings.maximum_branching(G) + x_ = build_branching([]) + assert_equal_branchings(x, x_) + + # Now test the min branching. + x = branchings.minimum_branching(G) + x_ = build_branching(edges) + assert_equal_branchings(x, x_) + + +def test_edge_attribute_preservation_normal_graph(): + # Test that edge attributes are preserved when finding an optimum graph + # using the Edmonds class for normal graphs. + G = nx.Graph() + + edgelist = [ + (0, 1, [("weight", 5), ("otherattr", 1), ("otherattr2", 3)]), + (0, 2, [("weight", 5), ("otherattr", 2), ("otherattr2", 2)]), + (1, 2, [("weight", 6), ("otherattr", 3), ("otherattr2", 1)]), + ] + G.add_edges_from(edgelist) + + B = branchings.maximum_branching(G, preserve_attrs=True) + + assert B[0][1]["otherattr"] == 1 + assert B[0][1]["otherattr2"] == 3 + + +def test_edge_attribute_preservation_multigraph(): + # Test that edge attributes are preserved when finding an optimum graph + # using the Edmonds class for multigraphs. + G = nx.MultiGraph() + + edgelist = [ + (0, 1, [("weight", 5), ("otherattr", 1), ("otherattr2", 3)]), + (0, 2, [("weight", 5), ("otherattr", 2), ("otherattr2", 2)]), + (1, 2, [("weight", 6), ("otherattr", 3), ("otherattr2", 1)]), + ] + G.add_edges_from(edgelist * 2) # Make sure we have duplicate edge paths + + B = branchings.maximum_branching(G, preserve_attrs=True) + + assert B[0][1][0]["otherattr"] == 1 + assert B[0][1][0]["otherattr2"] == 3 + + +def test_edge_attribute_discard(): + # Test that edge attributes are discarded if we do not specify to keep them + G = nx.Graph() + + edgelist = [ + (0, 1, [("weight", 5), ("otherattr", 1), ("otherattr2", 3)]), + (0, 2, [("weight", 5), ("otherattr", 2), ("otherattr2", 2)]), + (1, 2, [("weight", 6), ("otherattr", 3), ("otherattr2", 1)]), + ] + G.add_edges_from(edgelist) + + B = branchings.maximum_branching(G, preserve_attrs=False) + + edge_dict = B[0][1] + with pytest.raises(KeyError): + _ = edge_dict["otherattr"] + + +def test_partition_spanning_arborescence(): + """ + Test that we can generate minimum spanning arborescences which respect the + given partition. + """ + G = nx.from_numpy_array(G_array, create_using=nx.DiGraph) + G[3][0]["partition"] = nx.EdgePartition.EXCLUDED + G[2][3]["partition"] = nx.EdgePartition.INCLUDED + G[7][3]["partition"] = nx.EdgePartition.EXCLUDED + G[0][2]["partition"] = nx.EdgePartition.EXCLUDED + G[6][2]["partition"] = nx.EdgePartition.INCLUDED + + actual_edges = [ + (0, 4, 12), + (1, 0, 4), + (1, 5, 13), + (2, 3, 21), + (4, 7, 12), + (5, 6, 14), + (5, 8, 12), + (6, 2, 21), + ] + + B = branchings.minimum_spanning_arborescence(G, partition="partition") + assert_equal_branchings(build_branching(actual_edges), B) + + +def test_arborescence_iterator_min(): + """ + Tests the arborescence iterator. + + A brute force method found 680 arborescences in this graph. + This test will not verify all of them individually, but will check two + things + + * The iterator returns 680 arborescences + * The weight of the arborescences is non-strictly increasing + + for more information please visit + https://mjschwenne.github.io/2021/06/10/implementing-the-iterators.html + """ + G = nx.from_numpy_array(G_array, create_using=nx.DiGraph) + + arborescence_count = 0 + arborescence_weight = -math.inf + for B in branchings.ArborescenceIterator(G): + arborescence_count += 1 + new_arborescence_weight = B.size(weight="weight") + assert new_arborescence_weight >= arborescence_weight + arborescence_weight = new_arborescence_weight + + assert arborescence_count == 680 + + +def test_arborescence_iterator_max(): + """ + Tests the arborescence iterator. + + A brute force method found 680 arborescences in this graph. + This test will not verify all of them individually, but will check two + things + + * The iterator returns 680 arborescences + * The weight of the arborescences is non-strictly decreasing + + for more information please visit + https://mjschwenne.github.io/2021/06/10/implementing-the-iterators.html + """ + G = nx.from_numpy_array(G_array, create_using=nx.DiGraph) + + arborescence_count = 0 + arborescence_weight = math.inf + for B in branchings.ArborescenceIterator(G, minimum=False): + arborescence_count += 1 + new_arborescence_weight = B.size(weight="weight") + assert new_arborescence_weight <= arborescence_weight + arborescence_weight = new_arborescence_weight + + assert arborescence_count == 680 + + +def test_arborescence_iterator_initial_partition(): + """ + Tests the arborescence iterator with three included edges and three excluded + in the initial partition. + + A brute force method similar to the one used in the above tests found that + there are 16 arborescences which contain the included edges and not the + excluded edges. + """ + G = nx.from_numpy_array(G_array, create_using=nx.DiGraph) + included_edges = [(1, 0), (5, 6), (8, 7)] + excluded_edges = [(0, 2), (3, 6), (1, 5)] + + arborescence_count = 0 + arborescence_weight = -math.inf + for B in branchings.ArborescenceIterator( + G, init_partition=(included_edges, excluded_edges) + ): + arborescence_count += 1 + new_arborescence_weight = B.size(weight="weight") + assert new_arborescence_weight >= arborescence_weight + arborescence_weight = new_arborescence_weight + for e in included_edges: + assert e in B.edges + for e in excluded_edges: + assert e not in B.edges + assert arborescence_count == 16 + + +def test_branchings_with_default_weights(): + """ + Tests that various branching algorithms work on graphs without weights. + For more information, see issue #7279. + """ + graph = nx.erdos_renyi_graph(10, p=0.2, directed=True, seed=123) + + assert all( + "weight" not in d for (u, v, d) in graph.edges(data=True) + ), "test is for graphs without a weight attribute" + + # Calling these functions will modify graph inplace to add weights + # copy the graph to avoid this. + nx.minimum_spanning_arborescence(graph.copy()) + nx.maximum_spanning_arborescence(graph.copy()) + nx.minimum_branching(graph.copy()) + nx.maximum_branching(graph.copy()) + nx.algorithms.tree.minimal_branching(graph.copy()) + nx.algorithms.tree.branching_weight(graph.copy()) + nx.algorithms.tree.greedy_branching(graph.copy()) + + assert all( + "weight" not in d for (u, v, d) in graph.edges(data=True) + ), "The above calls should not modify the initial graph in-place" diff --git a/lib/python3.10/site-packages/networkx/algorithms/tree/tests/test_coding.py b/lib/python3.10/site-packages/networkx/algorithms/tree/tests/test_coding.py new file mode 100644 index 0000000000000000000000000000000000000000..26bd4083f52a0cc90b94c6de6d47b2c44e70a079 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tree/tests/test_coding.py @@ -0,0 +1,114 @@ +"""Unit tests for the :mod:`~networkx.algorithms.tree.coding` module.""" + +from itertools import product + +import pytest + +import networkx as nx +from networkx.utils import edges_equal, nodes_equal + + +class TestPruferSequence: + """Unit tests for the Prüfer sequence encoding and decoding + functions. + + """ + + def test_nontree(self): + with pytest.raises(nx.NotATree): + G = nx.cycle_graph(3) + nx.to_prufer_sequence(G) + + def test_null_graph(self): + with pytest.raises(nx.NetworkXPointlessConcept): + nx.to_prufer_sequence(nx.null_graph()) + + def test_trivial_graph(self): + with pytest.raises(nx.NetworkXPointlessConcept): + nx.to_prufer_sequence(nx.trivial_graph()) + + def test_bad_integer_labels(self): + with pytest.raises(KeyError): + T = nx.Graph(nx.utils.pairwise("abc")) + nx.to_prufer_sequence(T) + + def test_encoding(self): + """Tests for encoding a tree as a Prüfer sequence using the + iterative strategy. + + """ + # Example from Wikipedia. + tree = nx.Graph([(0, 3), (1, 3), (2, 3), (3, 4), (4, 5)]) + sequence = nx.to_prufer_sequence(tree) + assert sequence == [3, 3, 3, 4] + + def test_decoding(self): + """Tests for decoding a tree from a Prüfer sequence.""" + # Example from Wikipedia. + sequence = [3, 3, 3, 4] + tree = nx.from_prufer_sequence(sequence) + assert nodes_equal(list(tree), list(range(6))) + edges = [(0, 3), (1, 3), (2, 3), (3, 4), (4, 5)] + assert edges_equal(list(tree.edges()), edges) + + def test_decoding2(self): + # Example from "An Optimal Algorithm for Prufer Codes". + sequence = [2, 4, 0, 1, 3, 3] + tree = nx.from_prufer_sequence(sequence) + assert nodes_equal(list(tree), list(range(8))) + edges = [(0, 1), (0, 4), (1, 3), (2, 4), (2, 5), (3, 6), (3, 7)] + assert edges_equal(list(tree.edges()), edges) + + def test_inverse(self): + """Tests that the encoding and decoding functions are inverses.""" + for T in nx.nonisomorphic_trees(4): + T2 = nx.from_prufer_sequence(nx.to_prufer_sequence(T)) + assert nodes_equal(list(T), list(T2)) + assert edges_equal(list(T.edges()), list(T2.edges())) + + for seq in product(range(4), repeat=2): + seq2 = nx.to_prufer_sequence(nx.from_prufer_sequence(seq)) + assert list(seq) == seq2 + + +class TestNestedTuple: + """Unit tests for the nested tuple encoding and decoding functions.""" + + def test_nontree(self): + with pytest.raises(nx.NotATree): + G = nx.cycle_graph(3) + nx.to_nested_tuple(G, 0) + + def test_unknown_root(self): + with pytest.raises(nx.NodeNotFound): + G = nx.path_graph(2) + nx.to_nested_tuple(G, "bogus") + + def test_encoding(self): + T = nx.full_rary_tree(2, 2**3 - 1) + expected = (((), ()), ((), ())) + actual = nx.to_nested_tuple(T, 0) + assert nodes_equal(expected, actual) + + def test_canonical_form(self): + T = nx.Graph() + T.add_edges_from([(0, 1), (0, 2), (0, 3)]) + T.add_edges_from([(1, 4), (1, 5)]) + T.add_edges_from([(3, 6), (3, 7)]) + root = 0 + actual = nx.to_nested_tuple(T, root, canonical_form=True) + expected = ((), ((), ()), ((), ())) + assert actual == expected + + def test_decoding(self): + balanced = (((), ()), ((), ())) + expected = nx.full_rary_tree(2, 2**3 - 1) + actual = nx.from_nested_tuple(balanced) + assert nx.is_isomorphic(expected, actual) + + def test_sensible_relabeling(self): + balanced = (((), ()), ((), ())) + T = nx.from_nested_tuple(balanced, sensible_relabeling=True) + edges = [(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)] + assert nodes_equal(list(T), list(range(2**3 - 1))) + assert edges_equal(list(T.edges()), edges) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tree/tests/test_decomposition.py b/lib/python3.10/site-packages/networkx/algorithms/tree/tests/test_decomposition.py new file mode 100644 index 0000000000000000000000000000000000000000..8c376053794537611f46c038ed074eb92b1ba676 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tree/tests/test_decomposition.py @@ -0,0 +1,79 @@ +import networkx as nx +from networkx.algorithms.tree.decomposition import junction_tree + + +def test_junction_tree_directed_confounders(): + B = nx.DiGraph() + B.add_edges_from([("A", "C"), ("B", "C"), ("C", "D"), ("C", "E")]) + + G = junction_tree(B) + J = nx.Graph() + J.add_edges_from( + [ + (("C", "E"), ("C",)), + (("C",), ("A", "B", "C")), + (("A", "B", "C"), ("C",)), + (("C",), ("C", "D")), + ] + ) + + assert nx.is_isomorphic(G, J) + + +def test_junction_tree_directed_unconnected_nodes(): + B = nx.DiGraph() + B.add_nodes_from([("A", "B", "C", "D")]) + G = junction_tree(B) + + J = nx.Graph() + J.add_nodes_from([("A", "B", "C", "D")]) + + assert nx.is_isomorphic(G, J) + + +def test_junction_tree_directed_cascade(): + B = nx.DiGraph() + B.add_edges_from([("A", "B"), ("B", "C"), ("C", "D")]) + G = junction_tree(B) + + J = nx.Graph() + J.add_edges_from( + [ + (("A", "B"), ("B",)), + (("B",), ("B", "C")), + (("B", "C"), ("C",)), + (("C",), ("C", "D")), + ] + ) + assert nx.is_isomorphic(G, J) + + +def test_junction_tree_directed_unconnected_edges(): + B = nx.DiGraph() + B.add_edges_from([("A", "B"), ("C", "D"), ("E", "F")]) + G = junction_tree(B) + + J = nx.Graph() + J.add_nodes_from([("A", "B"), ("C", "D"), ("E", "F")]) + + assert nx.is_isomorphic(G, J) + + +def test_junction_tree_undirected(): + B = nx.Graph() + B.add_edges_from([("A", "C"), ("A", "D"), ("B", "C"), ("C", "E")]) + G = junction_tree(B) + + J = nx.Graph() + J.add_edges_from( + [ + (("A", "D"), ("A",)), + (("A",), ("A", "C")), + (("A", "C"), ("C",)), + (("C",), ("B", "C")), + (("B", "C"), ("C",)), + (("C",), ("C", "E")), + ] + ) + + assert nx.is_isomorphic(G, J) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tree/tests/test_mst.py b/lib/python3.10/site-packages/networkx/algorithms/tree/tests/test_mst.py new file mode 100644 index 0000000000000000000000000000000000000000..f8945a71835dbfa35c0c45259c8c84b653b1f49b --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tree/tests/test_mst.py @@ -0,0 +1,918 @@ +"""Unit tests for the :mod:`networkx.algorithms.tree.mst` module.""" + +import pytest + +import networkx as nx +from networkx.utils import edges_equal, nodes_equal + + +def test_unknown_algorithm(): + with pytest.raises(ValueError): + nx.minimum_spanning_tree(nx.Graph(), algorithm="random") + with pytest.raises( + ValueError, match="random is not a valid choice for an algorithm." + ): + nx.maximum_spanning_edges(nx.Graph(), algorithm="random") + + +class MinimumSpanningTreeTestBase: + """Base class for test classes for minimum spanning tree algorithms. + This class contains some common tests that will be inherited by + subclasses. Each subclass must have a class attribute + :data:`algorithm` that is a string representing the algorithm to + run, as described under the ``algorithm`` keyword argument for the + :func:`networkx.minimum_spanning_edges` function. Subclasses can + then implement any algorithm-specific tests. + """ + + def setup_method(self, method): + """Creates an example graph and stores the expected minimum and + maximum spanning tree edges. + """ + # This stores the class attribute `algorithm` in an instance attribute. + self.algo = self.algorithm + # This example graph comes from Wikipedia: + # https://en.wikipedia.org/wiki/Kruskal's_algorithm + edges = [ + (0, 1, 7), + (0, 3, 5), + (1, 2, 8), + (1, 3, 9), + (1, 4, 7), + (2, 4, 5), + (3, 4, 15), + (3, 5, 6), + (4, 5, 8), + (4, 6, 9), + (5, 6, 11), + ] + self.G = nx.Graph() + self.G.add_weighted_edges_from(edges) + self.minimum_spanning_edgelist = [ + (0, 1, {"weight": 7}), + (0, 3, {"weight": 5}), + (1, 4, {"weight": 7}), + (2, 4, {"weight": 5}), + (3, 5, {"weight": 6}), + (4, 6, {"weight": 9}), + ] + self.maximum_spanning_edgelist = [ + (0, 1, {"weight": 7}), + (1, 2, {"weight": 8}), + (1, 3, {"weight": 9}), + (3, 4, {"weight": 15}), + (4, 6, {"weight": 9}), + (5, 6, {"weight": 11}), + ] + + def test_minimum_edges(self): + edges = nx.minimum_spanning_edges(self.G, algorithm=self.algo) + # Edges from the spanning edges functions don't come in sorted + # orientation, so we need to sort each edge individually. + actual = sorted((min(u, v), max(u, v), d) for u, v, d in edges) + assert edges_equal(actual, self.minimum_spanning_edgelist) + + def test_maximum_edges(self): + edges = nx.maximum_spanning_edges(self.G, algorithm=self.algo) + # Edges from the spanning edges functions don't come in sorted + # orientation, so we need to sort each edge individually. + actual = sorted((min(u, v), max(u, v), d) for u, v, d in edges) + assert edges_equal(actual, self.maximum_spanning_edgelist) + + def test_without_data(self): + edges = nx.minimum_spanning_edges(self.G, algorithm=self.algo, data=False) + # Edges from the spanning edges functions don't come in sorted + # orientation, so we need to sort each edge individually. + actual = sorted((min(u, v), max(u, v)) for u, v in edges) + expected = [(u, v) for u, v, d in self.minimum_spanning_edgelist] + assert edges_equal(actual, expected) + + def test_nan_weights(self): + # Edge weights NaN never appear in the spanning tree. see #2164 + G = self.G + G.add_edge(0, 12, weight=float("nan")) + edges = nx.minimum_spanning_edges( + G, algorithm=self.algo, data=False, ignore_nan=True + ) + actual = sorted((min(u, v), max(u, v)) for u, v in edges) + expected = [(u, v) for u, v, d in self.minimum_spanning_edgelist] + assert edges_equal(actual, expected) + # Now test for raising exception + edges = nx.minimum_spanning_edges( + G, algorithm=self.algo, data=False, ignore_nan=False + ) + with pytest.raises(ValueError): + list(edges) + # test default for ignore_nan as False + edges = nx.minimum_spanning_edges(G, algorithm=self.algo, data=False) + with pytest.raises(ValueError): + list(edges) + + def test_nan_weights_MultiGraph(self): + G = nx.MultiGraph() + G.add_edge(0, 12, weight=float("nan")) + edges = nx.minimum_spanning_edges( + G, algorithm="prim", data=False, ignore_nan=False + ) + with pytest.raises(ValueError): + list(edges) + # test default for ignore_nan as False + edges = nx.minimum_spanning_edges(G, algorithm="prim", data=False) + with pytest.raises(ValueError): + list(edges) + + def test_nan_weights_order(self): + # now try again with a nan edge at the beginning of G.nodes + edges = [ + (0, 1, 7), + (0, 3, 5), + (1, 2, 8), + (1, 3, 9), + (1, 4, 7), + (2, 4, 5), + (3, 4, 15), + (3, 5, 6), + (4, 5, 8), + (4, 6, 9), + (5, 6, 11), + ] + G = nx.Graph() + G.add_weighted_edges_from([(u + 1, v + 1, wt) for u, v, wt in edges]) + G.add_edge(0, 7, weight=float("nan")) + edges = nx.minimum_spanning_edges( + G, algorithm=self.algo, data=False, ignore_nan=True + ) + actual = sorted((min(u, v), max(u, v)) for u, v in edges) + shift = [(u + 1, v + 1) for u, v, d in self.minimum_spanning_edgelist] + assert edges_equal(actual, shift) + + def test_isolated_node(self): + # now try again with an isolated node + edges = [ + (0, 1, 7), + (0, 3, 5), + (1, 2, 8), + (1, 3, 9), + (1, 4, 7), + (2, 4, 5), + (3, 4, 15), + (3, 5, 6), + (4, 5, 8), + (4, 6, 9), + (5, 6, 11), + ] + G = nx.Graph() + G.add_weighted_edges_from([(u + 1, v + 1, wt) for u, v, wt in edges]) + G.add_node(0) + edges = nx.minimum_spanning_edges( + G, algorithm=self.algo, data=False, ignore_nan=True + ) + actual = sorted((min(u, v), max(u, v)) for u, v in edges) + shift = [(u + 1, v + 1) for u, v, d in self.minimum_spanning_edgelist] + assert edges_equal(actual, shift) + + def test_minimum_tree(self): + T = nx.minimum_spanning_tree(self.G, algorithm=self.algo) + actual = sorted(T.edges(data=True)) + assert edges_equal(actual, self.minimum_spanning_edgelist) + + def test_maximum_tree(self): + T = nx.maximum_spanning_tree(self.G, algorithm=self.algo) + actual = sorted(T.edges(data=True)) + assert edges_equal(actual, self.maximum_spanning_edgelist) + + def test_disconnected(self): + G = nx.Graph([(0, 1, {"weight": 1}), (2, 3, {"weight": 2})]) + T = nx.minimum_spanning_tree(G, algorithm=self.algo) + assert nodes_equal(list(T), list(range(4))) + assert edges_equal(list(T.edges()), [(0, 1), (2, 3)]) + + def test_empty_graph(self): + G = nx.empty_graph(3) + T = nx.minimum_spanning_tree(G, algorithm=self.algo) + assert nodes_equal(sorted(T), list(range(3))) + assert T.number_of_edges() == 0 + + def test_attributes(self): + G = nx.Graph() + G.add_edge(1, 2, weight=1, color="red", distance=7) + G.add_edge(2, 3, weight=1, color="green", distance=2) + G.add_edge(1, 3, weight=10, color="blue", distance=1) + G.graph["foo"] = "bar" + T = nx.minimum_spanning_tree(G, algorithm=self.algo) + assert T.graph == G.graph + assert nodes_equal(T, G) + for u, v in T.edges(): + assert T.adj[u][v] == G.adj[u][v] + + def test_weight_attribute(self): + G = nx.Graph() + G.add_edge(0, 1, weight=1, distance=7) + G.add_edge(0, 2, weight=30, distance=1) + G.add_edge(1, 2, weight=1, distance=1) + G.add_node(3) + T = nx.minimum_spanning_tree(G, algorithm=self.algo, weight="distance") + assert nodes_equal(sorted(T), list(range(4))) + assert edges_equal(sorted(T.edges()), [(0, 2), (1, 2)]) + T = nx.maximum_spanning_tree(G, algorithm=self.algo, weight="distance") + assert nodes_equal(sorted(T), list(range(4))) + assert edges_equal(sorted(T.edges()), [(0, 1), (0, 2)]) + + +class TestBoruvka(MinimumSpanningTreeTestBase): + """Unit tests for computing a minimum (or maximum) spanning tree + using Borůvka's algorithm. + """ + + algorithm = "boruvka" + + def test_unicode_name(self): + """Tests that using a Unicode string can correctly indicate + Borůvka's algorithm. + """ + edges = nx.minimum_spanning_edges(self.G, algorithm="borůvka") + # Edges from the spanning edges functions don't come in sorted + # orientation, so we need to sort each edge individually. + actual = sorted((min(u, v), max(u, v), d) for u, v, d in edges) + assert edges_equal(actual, self.minimum_spanning_edgelist) + + +class MultigraphMSTTestBase(MinimumSpanningTreeTestBase): + # Abstract class + + def test_multigraph_keys_min(self): + """Tests that the minimum spanning edges of a multigraph + preserves edge keys. + """ + G = nx.MultiGraph() + G.add_edge(0, 1, key="a", weight=2) + G.add_edge(0, 1, key="b", weight=1) + min_edges = nx.minimum_spanning_edges + mst_edges = min_edges(G, algorithm=self.algo, data=False) + assert edges_equal([(0, 1, "b")], list(mst_edges)) + + def test_multigraph_keys_max(self): + """Tests that the maximum spanning edges of a multigraph + preserves edge keys. + """ + G = nx.MultiGraph() + G.add_edge(0, 1, key="a", weight=2) + G.add_edge(0, 1, key="b", weight=1) + max_edges = nx.maximum_spanning_edges + mst_edges = max_edges(G, algorithm=self.algo, data=False) + assert edges_equal([(0, 1, "a")], list(mst_edges)) + + +class TestKruskal(MultigraphMSTTestBase): + """Unit tests for computing a minimum (or maximum) spanning tree + using Kruskal's algorithm. + """ + + algorithm = "kruskal" + + def test_key_data_bool(self): + """Tests that the keys and data values are included in + MST edges based on whether keys and data parameters are + true or false""" + G = nx.MultiGraph() + G.add_edge(1, 2, key=1, weight=2) + G.add_edge(1, 2, key=2, weight=3) + G.add_edge(3, 2, key=1, weight=2) + G.add_edge(3, 1, key=1, weight=4) + + # keys are included and data is not included + mst_edges = nx.minimum_spanning_edges( + G, algorithm=self.algo, keys=True, data=False + ) + assert edges_equal([(1, 2, 1), (2, 3, 1)], list(mst_edges)) + + # keys are not included and data is included + mst_edges = nx.minimum_spanning_edges( + G, algorithm=self.algo, keys=False, data=True + ) + assert edges_equal( + [(1, 2, {"weight": 2}), (2, 3, {"weight": 2})], list(mst_edges) + ) + + # both keys and data are not included + mst_edges = nx.minimum_spanning_edges( + G, algorithm=self.algo, keys=False, data=False + ) + assert edges_equal([(1, 2), (2, 3)], list(mst_edges)) + + # both keys and data are included + mst_edges = nx.minimum_spanning_edges( + G, algorithm=self.algo, keys=True, data=True + ) + assert edges_equal( + [(1, 2, 1, {"weight": 2}), (2, 3, 1, {"weight": 2})], list(mst_edges) + ) + + +class TestPrim(MultigraphMSTTestBase): + """Unit tests for computing a minimum (or maximum) spanning tree + using Prim's algorithm. + """ + + algorithm = "prim" + + def test_prim_mst_edges_simple_graph(self): + H = nx.Graph() + H.add_edge(1, 2, key=2, weight=3) + H.add_edge(3, 2, key=1, weight=2) + H.add_edge(3, 1, key=1, weight=4) + + mst_edges = nx.minimum_spanning_edges(H, algorithm=self.algo, ignore_nan=True) + assert edges_equal( + [(1, 2, {"key": 2, "weight": 3}), (2, 3, {"key": 1, "weight": 2})], + list(mst_edges), + ) + + def test_ignore_nan(self): + """Tests that the edges with NaN weights are ignored or + raise an Error based on ignore_nan is true or false""" + H = nx.MultiGraph() + H.add_edge(1, 2, key=1, weight=float("nan")) + H.add_edge(1, 2, key=2, weight=3) + H.add_edge(3, 2, key=1, weight=2) + H.add_edge(3, 1, key=1, weight=4) + + # NaN weight edges are ignored when ignore_nan=True + mst_edges = nx.minimum_spanning_edges(H, algorithm=self.algo, ignore_nan=True) + assert edges_equal( + [(1, 2, 2, {"weight": 3}), (2, 3, 1, {"weight": 2})], list(mst_edges) + ) + + # NaN weight edges raise Error when ignore_nan=False + with pytest.raises(ValueError): + list(nx.minimum_spanning_edges(H, algorithm=self.algo, ignore_nan=False)) + + def test_multigraph_keys_tree(self): + G = nx.MultiGraph() + G.add_edge(0, 1, key="a", weight=2) + G.add_edge(0, 1, key="b", weight=1) + T = nx.minimum_spanning_tree(G, algorithm=self.algo) + assert edges_equal([(0, 1, 1)], list(T.edges(data="weight"))) + + def test_multigraph_keys_tree_max(self): + G = nx.MultiGraph() + G.add_edge(0, 1, key="a", weight=2) + G.add_edge(0, 1, key="b", weight=1) + T = nx.maximum_spanning_tree(G, algorithm=self.algo) + assert edges_equal([(0, 1, 2)], list(T.edges(data="weight"))) + + +class TestSpanningTreeIterator: + """ + Tests the spanning tree iterator on the example graph in the 2005 Sörensen + and Janssens paper An Algorithm to Generate all Spanning Trees of a Graph in + Order of Increasing Cost + """ + + def setup_method(self): + # Original Graph + edges = [(0, 1, 5), (1, 2, 4), (1, 4, 6), (2, 3, 5), (2, 4, 7), (3, 4, 3)] + self.G = nx.Graph() + self.G.add_weighted_edges_from(edges) + # List of lists of spanning trees in increasing order + self.spanning_trees = [ + # 1, MST, cost = 17 + [ + (0, 1, {"weight": 5}), + (1, 2, {"weight": 4}), + (2, 3, {"weight": 5}), + (3, 4, {"weight": 3}), + ], + # 2, cost = 18 + [ + (0, 1, {"weight": 5}), + (1, 2, {"weight": 4}), + (1, 4, {"weight": 6}), + (3, 4, {"weight": 3}), + ], + # 3, cost = 19 + [ + (0, 1, {"weight": 5}), + (1, 4, {"weight": 6}), + (2, 3, {"weight": 5}), + (3, 4, {"weight": 3}), + ], + # 4, cost = 19 + [ + (0, 1, {"weight": 5}), + (1, 2, {"weight": 4}), + (2, 4, {"weight": 7}), + (3, 4, {"weight": 3}), + ], + # 5, cost = 20 + [ + (0, 1, {"weight": 5}), + (1, 2, {"weight": 4}), + (1, 4, {"weight": 6}), + (2, 3, {"weight": 5}), + ], + # 6, cost = 21 + [ + (0, 1, {"weight": 5}), + (1, 4, {"weight": 6}), + (2, 4, {"weight": 7}), + (3, 4, {"weight": 3}), + ], + # 7, cost = 21 + [ + (0, 1, {"weight": 5}), + (1, 2, {"weight": 4}), + (2, 3, {"weight": 5}), + (2, 4, {"weight": 7}), + ], + # 8, cost = 23 + [ + (0, 1, {"weight": 5}), + (1, 4, {"weight": 6}), + (2, 3, {"weight": 5}), + (2, 4, {"weight": 7}), + ], + ] + + def test_minimum_spanning_tree_iterator(self): + """ + Tests that the spanning trees are correctly returned in increasing order + """ + tree_index = 0 + for tree in nx.SpanningTreeIterator(self.G): + actual = sorted(tree.edges(data=True)) + assert edges_equal(actual, self.spanning_trees[tree_index]) + tree_index += 1 + + def test_maximum_spanning_tree_iterator(self): + """ + Tests that the spanning trees are correctly returned in decreasing order + """ + tree_index = 7 + for tree in nx.SpanningTreeIterator(self.G, minimum=False): + actual = sorted(tree.edges(data=True)) + assert edges_equal(actual, self.spanning_trees[tree_index]) + tree_index -= 1 + + +class TestSpanningTreeMultiGraphIterator: + """ + Uses the same graph as the above class but with an added edge of twice the weight. + """ + + def setup_method(self): + # New graph + edges = [ + (0, 1, 5), + (0, 1, 10), + (1, 2, 4), + (1, 2, 8), + (1, 4, 6), + (1, 4, 12), + (2, 3, 5), + (2, 3, 10), + (2, 4, 7), + (2, 4, 14), + (3, 4, 3), + (3, 4, 6), + ] + self.G = nx.MultiGraph() + self.G.add_weighted_edges_from(edges) + + # There are 128 trees. I'd rather not list all 128 here, and computing them + # on such a small graph actually doesn't take that long. + from itertools import combinations + + self.spanning_trees = [] + for e in combinations(self.G.edges, 4): + tree = self.G.edge_subgraph(e) + if nx.is_tree(tree): + self.spanning_trees.append(sorted(tree.edges(keys=True, data=True))) + + def test_minimum_spanning_tree_iterator_multigraph(self): + """ + Tests that the spanning trees are correctly returned in increasing order + """ + tree_index = 0 + last_weight = 0 + for tree in nx.SpanningTreeIterator(self.G): + actual = sorted(tree.edges(keys=True, data=True)) + weight = sum([e[3]["weight"] for e in actual]) + assert actual in self.spanning_trees + assert weight >= last_weight + tree_index += 1 + + def test_maximum_spanning_tree_iterator_multigraph(self): + """ + Tests that the spanning trees are correctly returned in decreasing order + """ + tree_index = 127 + # Maximum weight tree is 46 + last_weight = 50 + for tree in nx.SpanningTreeIterator(self.G, minimum=False): + actual = sorted(tree.edges(keys=True, data=True)) + weight = sum([e[3]["weight"] for e in actual]) + assert actual in self.spanning_trees + assert weight <= last_weight + tree_index -= 1 + + +def test_random_spanning_tree_multiplicative_small(): + """ + Using a fixed seed, sample one tree for repeatability. + """ + from math import exp + + pytest.importorskip("scipy") + + gamma = { + (0, 1): -0.6383, + (0, 2): -0.6827, + (0, 5): 0, + (1, 2): -1.0781, + (1, 4): 0, + (2, 3): 0, + (5, 3): -0.2820, + (5, 4): -0.3327, + (4, 3): -0.9927, + } + + # The undirected support of gamma + G = nx.Graph() + for u, v in gamma: + G.add_edge(u, v, lambda_key=exp(gamma[(u, v)])) + + solution_edges = [(2, 3), (3, 4), (0, 5), (5, 4), (4, 1)] + solution = nx.Graph() + solution.add_edges_from(solution_edges) + + sampled_tree = nx.random_spanning_tree(G, "lambda_key", seed=42) + + assert nx.utils.edges_equal(solution.edges, sampled_tree.edges) + + +@pytest.mark.slow +def test_random_spanning_tree_multiplicative_large(): + """ + Sample many trees from the distribution created in the last test + """ + from math import exp + from random import Random + + pytest.importorskip("numpy") + stats = pytest.importorskip("scipy.stats") + + gamma = { + (0, 1): -0.6383, + (0, 2): -0.6827, + (0, 5): 0, + (1, 2): -1.0781, + (1, 4): 0, + (2, 3): 0, + (5, 3): -0.2820, + (5, 4): -0.3327, + (4, 3): -0.9927, + } + + # The undirected support of gamma + G = nx.Graph() + for u, v in gamma: + G.add_edge(u, v, lambda_key=exp(gamma[(u, v)])) + + # Find the multiplicative weight for each tree. + total_weight = 0 + tree_expected = {} + for t in nx.SpanningTreeIterator(G): + # Find the multiplicative weight of the spanning tree + weight = 1 + for u, v, d in t.edges(data="lambda_key"): + weight *= d + tree_expected[t] = weight + total_weight += weight + + # Assert that every tree has an entry in the expected distribution + assert len(tree_expected) == 75 + + # Set the sample size and then calculate the expected number of times we + # expect to see each tree. This test uses a near minimum sample size where + # the most unlikely tree has an expected frequency of 5.15. + # (Minimum required is 5) + # + # Here we also initialize the tree_actual dict so that we know the keys + # match between the two. We will later take advantage of the fact that since + # python 3.7 dict order is guaranteed so the expected and actual data will + # have the same order. + sample_size = 1200 + tree_actual = {} + for t in tree_expected: + tree_expected[t] = (tree_expected[t] / total_weight) * sample_size + tree_actual[t] = 0 + + # Sample the spanning trees + # + # Assert that they are actually trees and record which of the 75 trees we + # have sampled. + # + # For repeatability, we want to take advantage of the decorators in NetworkX + # to randomly sample the same sample each time. However, if we pass in a + # constant seed to sample_spanning_tree we will get the same tree each time. + # Instead, we can create our own random number generator with a fixed seed + # and pass those into sample_spanning_tree. + rng = Random(37) + for _ in range(sample_size): + sampled_tree = nx.random_spanning_tree(G, "lambda_key", seed=rng) + assert nx.is_tree(sampled_tree) + + for t in tree_expected: + if nx.utils.edges_equal(t.edges, sampled_tree.edges): + tree_actual[t] += 1 + break + + # Conduct a Chi squared test to see if the actual distribution matches the + # expected one at an alpha = 0.05 significance level. + # + # H_0: The distribution of trees in tree_actual matches the normalized product + # of the edge weights in the tree. + # + # H_a: The distribution of trees in tree_actual follows some other + # distribution of spanning trees. + _, p = stats.chisquare(list(tree_actual.values()), list(tree_expected.values())) + + # Assert that p is greater than the significance level so that we do not + # reject the null hypothesis + assert not p < 0.05 + + +def test_random_spanning_tree_additive_small(): + """ + Sample a single spanning tree from the additive method. + """ + pytest.importorskip("scipy") + + edges = { + (0, 1): 1, + (0, 2): 1, + (0, 5): 3, + (1, 2): 2, + (1, 4): 3, + (2, 3): 3, + (5, 3): 4, + (5, 4): 5, + (4, 3): 4, + } + + # Build the graph + G = nx.Graph() + for u, v in edges: + G.add_edge(u, v, weight=edges[(u, v)]) + + solution_edges = [(0, 2), (1, 2), (2, 3), (3, 4), (3, 5)] + solution = nx.Graph() + solution.add_edges_from(solution_edges) + + sampled_tree = nx.random_spanning_tree( + G, weight="weight", multiplicative=False, seed=37 + ) + + assert nx.utils.edges_equal(solution.edges, sampled_tree.edges) + + +@pytest.mark.slow +def test_random_spanning_tree_additive_large(): + """ + Sample many spanning trees from the additive method. + """ + from random import Random + + pytest.importorskip("numpy") + stats = pytest.importorskip("scipy.stats") + + edges = { + (0, 1): 1, + (0, 2): 1, + (0, 5): 3, + (1, 2): 2, + (1, 4): 3, + (2, 3): 3, + (5, 3): 4, + (5, 4): 5, + (4, 3): 4, + } + + # Build the graph + G = nx.Graph() + for u, v in edges: + G.add_edge(u, v, weight=edges[(u, v)]) + + # Find the additive weight for each tree. + total_weight = 0 + tree_expected = {} + for t in nx.SpanningTreeIterator(G): + # Find the multiplicative weight of the spanning tree + weight = 0 + for u, v, d in t.edges(data="weight"): + weight += d + tree_expected[t] = weight + total_weight += weight + + # Assert that every tree has an entry in the expected distribution + assert len(tree_expected) == 75 + + # Set the sample size and then calculate the expected number of times we + # expect to see each tree. This test uses a near minimum sample size where + # the most unlikely tree has an expected frequency of 5.07. + # (Minimum required is 5) + # + # Here we also initialize the tree_actual dict so that we know the keys + # match between the two. We will later take advantage of the fact that since + # python 3.7 dict order is guaranteed so the expected and actual data will + # have the same order. + sample_size = 500 + tree_actual = {} + for t in tree_expected: + tree_expected[t] = (tree_expected[t] / total_weight) * sample_size + tree_actual[t] = 0 + + # Sample the spanning trees + # + # Assert that they are actually trees and record which of the 75 trees we + # have sampled. + # + # For repeatability, we want to take advantage of the decorators in NetworkX + # to randomly sample the same sample each time. However, if we pass in a + # constant seed to sample_spanning_tree we will get the same tree each time. + # Instead, we can create our own random number generator with a fixed seed + # and pass those into sample_spanning_tree. + rng = Random(37) + for _ in range(sample_size): + sampled_tree = nx.random_spanning_tree( + G, "weight", multiplicative=False, seed=rng + ) + assert nx.is_tree(sampled_tree) + + for t in tree_expected: + if nx.utils.edges_equal(t.edges, sampled_tree.edges): + tree_actual[t] += 1 + break + + # Conduct a Chi squared test to see if the actual distribution matches the + # expected one at an alpha = 0.05 significance level. + # + # H_0: The distribution of trees in tree_actual matches the normalized product + # of the edge weights in the tree. + # + # H_a: The distribution of trees in tree_actual follows some other + # distribution of spanning trees. + _, p = stats.chisquare(list(tree_actual.values()), list(tree_expected.values())) + + # Assert that p is greater than the significance level so that we do not + # reject the null hypothesis + assert not p < 0.05 + + +def test_random_spanning_tree_empty_graph(): + G = nx.Graph() + rst = nx.tree.random_spanning_tree(G) + assert len(rst.nodes) == 0 + assert len(rst.edges) == 0 + + +def test_random_spanning_tree_single_node_graph(): + G = nx.Graph() + G.add_node(0) + rst = nx.tree.random_spanning_tree(G) + assert len(rst.nodes) == 1 + assert len(rst.edges) == 0 + + +def test_random_spanning_tree_single_node_loop(): + G = nx.Graph() + G.add_node(0) + G.add_edge(0, 0) + rst = nx.tree.random_spanning_tree(G) + assert len(rst.nodes) == 1 + assert len(rst.edges) == 0 + + +class TestNumberSpanningTrees: + @classmethod + def setup_class(cls): + global np + np = pytest.importorskip("numpy") + sp = pytest.importorskip("scipy") + + def test_nst_disconnected(self): + G = nx.empty_graph(2) + assert np.isclose(nx.number_of_spanning_trees(G), 0) + + def test_nst_no_nodes(self): + G = nx.Graph() + with pytest.raises(nx.NetworkXPointlessConcept): + nx.number_of_spanning_trees(G) + + def test_nst_weight(self): + G = nx.Graph() + G.add_edge(1, 2, weight=1) + G.add_edge(1, 3, weight=1) + G.add_edge(2, 3, weight=2) + # weights are ignored + assert np.isclose(nx.number_of_spanning_trees(G), 3) + # including weight + assert np.isclose(nx.number_of_spanning_trees(G, weight="weight"), 5) + + def test_nst_negative_weight(self): + G = nx.Graph() + G.add_edge(1, 2, weight=1) + G.add_edge(1, 3, weight=-1) + G.add_edge(2, 3, weight=-2) + # weights are ignored + assert np.isclose(nx.number_of_spanning_trees(G), 3) + # including weight + assert np.isclose(nx.number_of_spanning_trees(G, weight="weight"), -1) + + def test_nst_selfloop(self): + # self-loops are ignored + G = nx.complete_graph(3) + G.add_edge(1, 1) + assert np.isclose(nx.number_of_spanning_trees(G), 3) + + def test_nst_multigraph(self): + G = nx.MultiGraph() + G.add_edge(1, 2) + G.add_edge(1, 2) + G.add_edge(1, 3) + G.add_edge(2, 3) + assert np.isclose(nx.number_of_spanning_trees(G), 5) + + def test_nst_complete_graph(self): + # this is known as Cayley's formula + N = 5 + G = nx.complete_graph(N) + assert np.isclose(nx.number_of_spanning_trees(G), N ** (N - 2)) + + def test_nst_path_graph(self): + G = nx.path_graph(5) + assert np.isclose(nx.number_of_spanning_trees(G), 1) + + def test_nst_cycle_graph(self): + G = nx.cycle_graph(5) + assert np.isclose(nx.number_of_spanning_trees(G), 5) + + def test_nst_directed_noroot(self): + G = nx.empty_graph(3, create_using=nx.MultiDiGraph) + with pytest.raises(nx.NetworkXError): + nx.number_of_spanning_trees(G) + + def test_nst_directed_root_not_exist(self): + G = nx.empty_graph(3, create_using=nx.MultiDiGraph) + with pytest.raises(nx.NetworkXError): + nx.number_of_spanning_trees(G, root=42) + + def test_nst_directed_not_weak_connected(self): + G = nx.DiGraph() + G.add_edge(1, 2) + G.add_edge(3, 4) + assert np.isclose(nx.number_of_spanning_trees(G, root=1), 0) + + def test_nst_directed_cycle_graph(self): + G = nx.DiGraph() + G = nx.cycle_graph(7, G) + assert np.isclose(nx.number_of_spanning_trees(G, root=0), 1) + + def test_nst_directed_complete_graph(self): + G = nx.DiGraph() + G = nx.complete_graph(7, G) + assert np.isclose(nx.number_of_spanning_trees(G, root=0), 7**5) + + def test_nst_directed_multi(self): + G = nx.MultiDiGraph() + G = nx.cycle_graph(3, G) + G.add_edge(1, 2) + assert np.isclose(nx.number_of_spanning_trees(G, root=0), 2) + + def test_nst_directed_selfloop(self): + G = nx.MultiDiGraph() + G = nx.cycle_graph(3, G) + G.add_edge(1, 1) + assert np.isclose(nx.number_of_spanning_trees(G, root=0), 1) + + def test_nst_directed_weak_connected(self): + G = nx.MultiDiGraph() + G = nx.cycle_graph(3, G) + G.remove_edge(1, 2) + assert np.isclose(nx.number_of_spanning_trees(G, root=0), 0) + + def test_nst_directed_weighted(self): + # from root=1: + # arborescence 1: 1->2, 1->3, weight=2*1 + # arborescence 2: 1->2, 2->3, weight=2*3 + G = nx.DiGraph() + G.add_edge(1, 2, weight=2) + G.add_edge(1, 3, weight=1) + G.add_edge(2, 3, weight=3) + Nst = nx.number_of_spanning_trees(G, root=1, weight="weight") + assert np.isclose(Nst, 8) + Nst = nx.number_of_spanning_trees(G, root=2, weight="weight") + assert np.isclose(Nst, 0) + Nst = nx.number_of_spanning_trees(G, root=3, weight="weight") + assert np.isclose(Nst, 0) diff --git a/lib/python3.10/site-packages/networkx/algorithms/tree/tests/test_operations.py b/lib/python3.10/site-packages/networkx/algorithms/tree/tests/test_operations.py new file mode 100644 index 0000000000000000000000000000000000000000..284d94e2e5059de267b5ea47f6012a42c6ac4639 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tree/tests/test_operations.py @@ -0,0 +1,53 @@ +from itertools import chain + +import networkx as nx +from networkx.utils import edges_equal, nodes_equal + + +def _check_custom_label_attribute(input_trees, res_tree, label_attribute): + res_attr_dict = nx.get_node_attributes(res_tree, label_attribute) + res_attr_set = set(res_attr_dict.values()) + input_label = (tree for tree, root in input_trees) + input_label_set = set(chain.from_iterable(input_label)) + return res_attr_set == input_label_set + + +def test_empty_sequence(): + """Joining the empty sequence results in the tree with one node.""" + T = nx.join_trees([]) + assert len(T) == 1 + assert T.number_of_edges() == 0 + + +def test_single(): + """Joining just one tree yields a tree with one more node.""" + T = nx.empty_graph(1) + trees = [(T, 0)] + actual_with_label = nx.join_trees(trees, label_attribute="custom_label") + expected = nx.path_graph(2) + assert nodes_equal(list(expected), list(actual_with_label)) + assert edges_equal(list(expected.edges()), list(actual_with_label.edges())) + + +def test_basic(): + """Joining multiple subtrees at a root node.""" + trees = [(nx.full_rary_tree(2, 2**2 - 1), 0) for i in range(2)] + expected = nx.full_rary_tree(2, 2**3 - 1) + actual = nx.join_trees(trees, label_attribute="old_labels") + assert nx.is_isomorphic(actual, expected) + assert _check_custom_label_attribute(trees, actual, "old_labels") + + actual_without_label = nx.join_trees(trees) + assert nx.is_isomorphic(actual_without_label, expected) + # check that no labels were stored + assert all(not data for _, data in actual_without_label.nodes(data=True)) + + +def test_first_label(): + """Test the functionality of the first_label argument.""" + T1 = nx.path_graph(3) + T2 = nx.path_graph(2) + actual = nx.join_trees([(T1, 0), (T2, 0)], first_label=10) + expected_nodes = set(range(10, 16)) + assert set(actual.nodes()) == expected_nodes + assert set(actual.neighbors(10)) == {11, 14} diff --git a/lib/python3.10/site-packages/networkx/algorithms/tree/tests/test_recognition.py b/lib/python3.10/site-packages/networkx/algorithms/tree/tests/test_recognition.py new file mode 100644 index 0000000000000000000000000000000000000000..105f5a89e9b10d37d1cc140880a66bc860d2e9f8 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/algorithms/tree/tests/test_recognition.py @@ -0,0 +1,174 @@ +import pytest + +import networkx as nx + + +class TestTreeRecognition: + graph = nx.Graph + multigraph = nx.MultiGraph + + @classmethod + def setup_class(cls): + cls.T1 = cls.graph() + + cls.T2 = cls.graph() + cls.T2.add_node(1) + + cls.T3 = cls.graph() + cls.T3.add_nodes_from(range(5)) + edges = [(i, i + 1) for i in range(4)] + cls.T3.add_edges_from(edges) + + cls.T5 = cls.multigraph() + cls.T5.add_nodes_from(range(5)) + edges = [(i, i + 1) for i in range(4)] + cls.T5.add_edges_from(edges) + + cls.T6 = cls.graph() + cls.T6.add_nodes_from([6, 7]) + cls.T6.add_edge(6, 7) + + cls.F1 = nx.compose(cls.T6, cls.T3) + + cls.N4 = cls.graph() + cls.N4.add_node(1) + cls.N4.add_edge(1, 1) + + cls.N5 = cls.graph() + cls.N5.add_nodes_from(range(5)) + + cls.N6 = cls.graph() + cls.N6.add_nodes_from(range(3)) + cls.N6.add_edges_from([(0, 1), (1, 2), (2, 0)]) + + cls.NF1 = nx.compose(cls.T6, cls.N6) + + def test_null_tree(self): + with pytest.raises(nx.NetworkXPointlessConcept): + nx.is_tree(self.graph()) + + def test_null_tree2(self): + with pytest.raises(nx.NetworkXPointlessConcept): + nx.is_tree(self.multigraph()) + + def test_null_forest(self): + with pytest.raises(nx.NetworkXPointlessConcept): + nx.is_forest(self.graph()) + + def test_null_forest2(self): + with pytest.raises(nx.NetworkXPointlessConcept): + nx.is_forest(self.multigraph()) + + def test_is_tree(self): + assert nx.is_tree(self.T2) + assert nx.is_tree(self.T3) + assert nx.is_tree(self.T5) + + def test_is_not_tree(self): + assert not nx.is_tree(self.N4) + assert not nx.is_tree(self.N5) + assert not nx.is_tree(self.N6) + + def test_is_forest(self): + assert nx.is_forest(self.T2) + assert nx.is_forest(self.T3) + assert nx.is_forest(self.T5) + assert nx.is_forest(self.F1) + assert nx.is_forest(self.N5) + + def test_is_not_forest(self): + assert not nx.is_forest(self.N4) + assert not nx.is_forest(self.N6) + assert not nx.is_forest(self.NF1) + + +class TestDirectedTreeRecognition(TestTreeRecognition): + graph = nx.DiGraph + multigraph = nx.MultiDiGraph + + +def test_disconnected_graph(): + # https://github.com/networkx/networkx/issues/1144 + G = nx.Graph() + G.add_edges_from([(0, 1), (1, 2), (2, 0), (3, 4)]) + assert not nx.is_tree(G) + + G = nx.DiGraph() + G.add_edges_from([(0, 1), (1, 2), (2, 0), (3, 4)]) + assert not nx.is_tree(G) + + +def test_dag_nontree(): + G = nx.DiGraph() + G.add_edges_from([(0, 1), (0, 2), (1, 2)]) + assert not nx.is_tree(G) + assert nx.is_directed_acyclic_graph(G) + + +def test_multicycle(): + G = nx.MultiDiGraph() + G.add_edges_from([(0, 1), (0, 1)]) + assert not nx.is_tree(G) + assert nx.is_directed_acyclic_graph(G) + + +def test_emptybranch(): + G = nx.DiGraph() + G.add_nodes_from(range(10)) + assert nx.is_branching(G) + assert not nx.is_arborescence(G) + + +def test_is_branching_empty_graph_raises(): + G = nx.DiGraph() + with pytest.raises(nx.NetworkXPointlessConcept, match="G has no nodes."): + nx.is_branching(G) + + +def test_path(): + G = nx.DiGraph() + nx.add_path(G, range(5)) + assert nx.is_branching(G) + assert nx.is_arborescence(G) + + +def test_notbranching1(): + # Acyclic violation. + G = nx.MultiDiGraph() + G.add_nodes_from(range(10)) + G.add_edges_from([(0, 1), (1, 0)]) + assert not nx.is_branching(G) + assert not nx.is_arborescence(G) + + +def test_notbranching2(): + # In-degree violation. + G = nx.MultiDiGraph() + G.add_nodes_from(range(10)) + G.add_edges_from([(0, 1), (0, 2), (3, 2)]) + assert not nx.is_branching(G) + assert not nx.is_arborescence(G) + + +def test_notarborescence1(): + # Not an arborescence due to not spanning. + G = nx.MultiDiGraph() + G.add_nodes_from(range(10)) + G.add_edges_from([(0, 1), (0, 2), (1, 3), (5, 6)]) + assert nx.is_branching(G) + assert not nx.is_arborescence(G) + + +def test_notarborescence2(): + # Not an arborescence due to in-degree violation. + G = nx.MultiDiGraph() + nx.add_path(G, range(5)) + G.add_edge(6, 4) + assert not nx.is_branching(G) + assert not nx.is_arborescence(G) + + +def test_is_arborescense_empty_graph_raises(): + G = nx.DiGraph() + with pytest.raises(nx.NetworkXPointlessConcept, match="G has no nodes."): + nx.is_arborescence(G) diff --git a/lib/python3.10/site-packages/networkx/classes/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/classes/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0680226074badca163663ecd0bef5c132542e45 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/classes/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/classes/__pycache__/coreviews.cpython-310.pyc b/lib/python3.10/site-packages/networkx/classes/__pycache__/coreviews.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2bd1521ce6b7fe9997c9f5259fcf5f3aeec8a9d4 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/classes/__pycache__/coreviews.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/classes/__pycache__/digraph.cpython-310.pyc b/lib/python3.10/site-packages/networkx/classes/__pycache__/digraph.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5a4203a985f841bfa80a3feba1f98d6364950dd Binary files /dev/null and b/lib/python3.10/site-packages/networkx/classes/__pycache__/digraph.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/classes/__pycache__/filters.cpython-310.pyc b/lib/python3.10/site-packages/networkx/classes/__pycache__/filters.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84eef9d0a9d90bf2d66041c323966a37dd03343e Binary files /dev/null and b/lib/python3.10/site-packages/networkx/classes/__pycache__/filters.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/classes/__pycache__/function.cpython-310.pyc b/lib/python3.10/site-packages/networkx/classes/__pycache__/function.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ac1f5a028b608938cde49c5b9d20d8a7d79274d Binary files /dev/null and b/lib/python3.10/site-packages/networkx/classes/__pycache__/function.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/classes/__pycache__/graph.cpython-310.pyc b/lib/python3.10/site-packages/networkx/classes/__pycache__/graph.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a09b3a91d2d01219c109dee77036f139e615feda Binary files /dev/null and b/lib/python3.10/site-packages/networkx/classes/__pycache__/graph.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/classes/__pycache__/graphviews.cpython-310.pyc b/lib/python3.10/site-packages/networkx/classes/__pycache__/graphviews.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1f7c0843cba1c284a810c1c057b65fce902baaa Binary files /dev/null and b/lib/python3.10/site-packages/networkx/classes/__pycache__/graphviews.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/classes/__pycache__/multidigraph.cpython-310.pyc b/lib/python3.10/site-packages/networkx/classes/__pycache__/multidigraph.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0a115c4db5b285c641d95d65205b1a004e8a0ba Binary files /dev/null and b/lib/python3.10/site-packages/networkx/classes/__pycache__/multidigraph.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/classes/__pycache__/multigraph.cpython-310.pyc b/lib/python3.10/site-packages/networkx/classes/__pycache__/multigraph.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..030c380a24cc71221bcfffecb07c4ee186dfc30f Binary files /dev/null and b/lib/python3.10/site-packages/networkx/classes/__pycache__/multigraph.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/classes/__pycache__/reportviews.cpython-310.pyc b/lib/python3.10/site-packages/networkx/classes/__pycache__/reportviews.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26fa494319d263a73013f5e9bf7ba4c070af4a6b Binary files /dev/null and b/lib/python3.10/site-packages/networkx/classes/__pycache__/reportviews.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/classes/tests/__init__.py b/lib/python3.10/site-packages/networkx/classes/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26d7395c2b9531003bdb17a7878728f587bc6976 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/dispatch_interface.cpython-310.pyc b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/dispatch_interface.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..781bb14aa407ed3a5dd286fbe35ae92eb2fbcc47 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/dispatch_interface.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/historical_tests.cpython-310.pyc b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/historical_tests.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8d0f318b05a5174d65963b12b8621afb3b328ff Binary files /dev/null and b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/historical_tests.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_coreviews.cpython-310.pyc b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_coreviews.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e47cc2685a7e7659d4596f3bfe15eae4f6a261b Binary files /dev/null and b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_coreviews.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_digraph.cpython-310.pyc b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_digraph.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1153722419fd5413c972e4ab1e40a2e20d15e12 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_digraph.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_digraph_historical.cpython-310.pyc b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_digraph_historical.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42678c822990497d8fb57e61bfff24436a15e0a2 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_digraph_historical.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_filters.cpython-310.pyc b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_filters.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9052d3f696f1711c7a348a55bde0f271bf092cb4 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_filters.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_function.cpython-310.pyc b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_function.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6b73631d7bfaec257a077bfaa892d57646a193f Binary files /dev/null and b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_function.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_graph.cpython-310.pyc b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_graph.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c62f3fbb5f79999084933c329c39c32d17db189 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_graph.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_graph_historical.cpython-310.pyc b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_graph_historical.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..112ac736abbc2c9a4562a5273d51a45b65046e61 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_graph_historical.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_graphviews.cpython-310.pyc b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_graphviews.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e451b69fd95c362b7c701d27c31d389579a411fa Binary files /dev/null and b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_graphviews.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_multidigraph.cpython-310.pyc b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_multidigraph.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7acac10445a85014f055386d23da05164c948c8a Binary files /dev/null and b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_multidigraph.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_multigraph.cpython-310.pyc b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_multigraph.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d283d358422bd59c30a3814919a4e78faebe8a0 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_multigraph.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_reportviews.cpython-310.pyc b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_reportviews.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5023392af69aa5659c6170d27b223961a8567e5a Binary files /dev/null and b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_reportviews.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_special.cpython-310.pyc b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_special.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5732d35fa5245492b1d0a8c296f523fea7ecbbde Binary files /dev/null and b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_special.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_subgraphviews.cpython-310.pyc b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_subgraphviews.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f1910046e088dd152953118922998730e73b8c9 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/classes/tests/__pycache__/test_subgraphviews.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/classes/tests/dispatch_interface.py b/lib/python3.10/site-packages/networkx/classes/tests/dispatch_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..5cc908d707c8efa30ce1e334313e1f946bdb5348 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/classes/tests/dispatch_interface.py @@ -0,0 +1,185 @@ +# This file contains utilities for testing the dispatching feature + +# A full test of all dispatchable algorithms is performed by +# modifying the pytest invocation and setting an environment variable +# NETWORKX_TEST_BACKEND=nx_loopback pytest +# This is comprehensive, but only tests the `test_override_dispatch` +# function in networkx.classes.backends. + +# To test the `_dispatchable` function directly, several tests scattered throughout +# NetworkX have been augmented to test normal and dispatch mode. +# Searching for `dispatch_interface` should locate the specific tests. + +import networkx as nx +from networkx import DiGraph, Graph, MultiDiGraph, MultiGraph, PlanarEmbedding +from networkx.classes.reportviews import NodeView + + +class LoopbackGraph(Graph): + __networkx_backend__ = "nx_loopback" + + +class LoopbackDiGraph(DiGraph): + __networkx_backend__ = "nx_loopback" + + +class LoopbackMultiGraph(MultiGraph): + __networkx_backend__ = "nx_loopback" + + +class LoopbackMultiDiGraph(MultiDiGraph): + __networkx_backend__ = "nx_loopback" + + +class LoopbackPlanarEmbedding(PlanarEmbedding): + __networkx_backend__ = "nx_loopback" + + +def convert(graph): + if isinstance(graph, PlanarEmbedding): + return LoopbackPlanarEmbedding(graph) + if isinstance(graph, MultiDiGraph): + return LoopbackMultiDiGraph(graph) + if isinstance(graph, MultiGraph): + return LoopbackMultiGraph(graph) + if isinstance(graph, DiGraph): + return LoopbackDiGraph(graph) + if isinstance(graph, Graph): + return LoopbackGraph(graph) + raise TypeError(f"Unsupported type of graph: {type(graph)}") + + +class LoopbackBackendInterface: + def __getattr__(self, item): + try: + return nx.utils.backends._registered_algorithms[item].orig_func + except KeyError: + raise AttributeError(item) from None + + @staticmethod + def convert_from_nx( + graph, + *, + edge_attrs=None, + node_attrs=None, + preserve_edge_attrs=None, + preserve_node_attrs=None, + preserve_graph_attrs=None, + name=None, + graph_name=None, + ): + if name in { + # Raise if input graph changes. See test_dag.py::test_topological_sort6 + "lexicographical_topological_sort", + "topological_generations", + "topological_sort", + # Would be nice to some day avoid these cutoffs of full testing + }: + return graph + if isinstance(graph, NodeView): + # Convert to a Graph with only nodes (no edges) + new_graph = Graph() + new_graph.add_nodes_from(graph.items()) + graph = new_graph + G = LoopbackGraph() + elif not isinstance(graph, Graph): + raise TypeError( + f"Bad type for graph argument {graph_name} in {name}: {type(graph)}" + ) + elif graph.__class__ in {Graph, LoopbackGraph}: + G = LoopbackGraph() + elif graph.__class__ in {DiGraph, LoopbackDiGraph}: + G = LoopbackDiGraph() + elif graph.__class__ in {MultiGraph, LoopbackMultiGraph}: + G = LoopbackMultiGraph() + elif graph.__class__ in {MultiDiGraph, LoopbackMultiDiGraph}: + G = LoopbackMultiDiGraph() + elif graph.__class__ in {PlanarEmbedding, LoopbackPlanarEmbedding}: + G = LoopbackDiGraph() # or LoopbackPlanarEmbedding + else: + # Would be nice to handle these better some day + # nx.algorithms.approximation.kcomponents._AntiGraph + # nx.classes.tests.test_multidigraph.MultiDiGraphSubClass + # nx.classes.tests.test_multigraph.MultiGraphSubClass + G = graph.__class__() + + if preserve_graph_attrs: + G.graph.update(graph.graph) + + # add nodes + G.add_nodes_from(graph) + if preserve_node_attrs: + for n, dd in G._node.items(): + dd.update(graph.nodes[n]) + elif node_attrs: + for n, dd in G._node.items(): + dd.update( + (attr, graph._node[n].get(attr, default)) + for attr, default in node_attrs.items() + if default is not None or attr in graph._node[n] + ) + + # tools to build datadict and keydict + if preserve_edge_attrs: + + def G_new_datadict(old_dd): + return G.edge_attr_dict_factory(old_dd) + elif edge_attrs: + + def G_new_datadict(old_dd): + return G.edge_attr_dict_factory( + (attr, old_dd.get(attr, default)) + for attr, default in edge_attrs.items() + if default is not None or attr in old_dd + ) + else: + + def G_new_datadict(old_dd): + return G.edge_attr_dict_factory() + + if G.is_multigraph(): + + def G_new_inner(keydict): + kd = G.adjlist_inner_dict_factory( + (k, G_new_datadict(dd)) for k, dd in keydict.items() + ) + return kd + else: + G_new_inner = G_new_datadict + + # add edges keeping the same order in _adj and _pred + G_adj = G._adj + if G.is_directed(): + for n, nbrs in graph._adj.items(): + G_adj[n].update((nbr, G_new_inner(dd)) for nbr, dd in nbrs.items()) + # ensure same datadict for pred and adj; and pred order of graph._pred + G_pred = G._pred + for n, nbrs in graph._pred.items(): + G_pred[n].update((nbr, G_adj[nbr][n]) for nbr in nbrs) + else: # undirected + for n, nbrs in graph._adj.items(): + # ensure same datadict for both ways; and adj order of graph._adj + G_adj[n].update( + (nbr, G_adj[nbr][n] if n in G_adj[nbr] else G_new_inner(dd)) + for nbr, dd in nbrs.items() + ) + + return G + + @staticmethod + def convert_to_nx(obj, *, name=None): + return obj + + @staticmethod + def on_start_tests(items): + # Verify that items can be xfailed + for item in items: + assert hasattr(item, "add_marker") + + def can_run(self, name, args, kwargs): + # It is unnecessary to define this function if algorithms are fully supported. + # We include it for illustration purposes. + return hasattr(self, name) + + +backend_interface = LoopbackBackendInterface() diff --git a/lib/python3.10/site-packages/networkx/classes/tests/historical_tests.py b/lib/python3.10/site-packages/networkx/classes/tests/historical_tests.py new file mode 100644 index 0000000000000000000000000000000000000000..9dad24e2328408fd803cce8aa909604226184b31 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/classes/tests/historical_tests.py @@ -0,0 +1,475 @@ +"""Original NetworkX graph tests""" + +import pytest + +import networkx as nx +from networkx import convert_node_labels_to_integers as cnlti +from networkx.utils import edges_equal, nodes_equal + + +class HistoricalTests: + @classmethod + def setup_class(cls): + cls.null = nx.null_graph() + cls.P1 = cnlti(nx.path_graph(1), first_label=1) + cls.P3 = cnlti(nx.path_graph(3), first_label=1) + cls.P10 = cnlti(nx.path_graph(10), first_label=1) + cls.K1 = cnlti(nx.complete_graph(1), first_label=1) + cls.K3 = cnlti(nx.complete_graph(3), first_label=1) + cls.K4 = cnlti(nx.complete_graph(4), first_label=1) + cls.K5 = cnlti(nx.complete_graph(5), first_label=1) + cls.K10 = cnlti(nx.complete_graph(10), first_label=1) + cls.G = nx.Graph + + def test_name(self): + G = self.G(name="test") + assert G.name == "test" + H = self.G() + assert H.name == "" + + # Nodes + + def test_add_remove_node(self): + G = self.G() + G.add_node("A") + assert G.has_node("A") + G.remove_node("A") + assert not G.has_node("A") + + def test_nonhashable_node(self): + # Test if a non-hashable object is in the Graph. A python dict will + # raise a TypeError, but for a Graph class a simple False should be + # returned (see Graph __contains__). If it cannot be a node then it is + # not a node. + G = self.G() + assert not G.has_node(["A"]) + assert not G.has_node({"A": 1}) + + def test_add_nodes_from(self): + G = self.G() + G.add_nodes_from(list("ABCDEFGHIJKL")) + assert G.has_node("L") + G.remove_nodes_from(["H", "I", "J", "K", "L"]) + G.add_nodes_from([1, 2, 3, 4]) + assert sorted(G.nodes(), key=str) == [ + 1, + 2, + 3, + 4, + "A", + "B", + "C", + "D", + "E", + "F", + "G", + ] + # test __iter__ + assert sorted(G, key=str) == [1, 2, 3, 4, "A", "B", "C", "D", "E", "F", "G"] + + def test_contains(self): + G = self.G() + G.add_node("A") + assert "A" in G + assert [] not in G # never raise a Key or TypeError in this test + assert {1: 1} not in G + + def test_add_remove(self): + # Test add_node and remove_node acting for various nbunch + G = self.G() + G.add_node("m") + assert G.has_node("m") + G.add_node("m") # no complaints + pytest.raises(nx.NetworkXError, G.remove_node, "j") + G.remove_node("m") + assert list(G) == [] + + def test_nbunch_is_list(self): + G = self.G() + G.add_nodes_from(list("ABCD")) + G.add_nodes_from(self.P3) # add nbunch of nodes (nbunch=Graph) + assert sorted(G.nodes(), key=str) == [1, 2, 3, "A", "B", "C", "D"] + G.remove_nodes_from(self.P3) # remove nbunch of nodes (nbunch=Graph) + assert sorted(G.nodes(), key=str) == ["A", "B", "C", "D"] + + def test_nbunch_is_set(self): + G = self.G() + nbunch = set("ABCDEFGHIJKL") + G.add_nodes_from(nbunch) + assert G.has_node("L") + + def test_nbunch_dict(self): + # nbunch is a dict with nodes as keys + G = self.G() + nbunch = set("ABCDEFGHIJKL") + G.add_nodes_from(nbunch) + nbunch = {"I": "foo", "J": 2, "K": True, "L": "spam"} + G.remove_nodes_from(nbunch) + assert sorted(G.nodes(), key=str), ["A", "B", "C", "D", "E", "F", "G", "H"] + + def test_nbunch_iterator(self): + G = self.G() + G.add_nodes_from(["A", "B", "C", "D", "E", "F", "G", "H"]) + n_iter = self.P3.nodes() + G.add_nodes_from(n_iter) + assert sorted(G.nodes(), key=str) == [ + 1, + 2, + 3, + "A", + "B", + "C", + "D", + "E", + "F", + "G", + "H", + ] + n_iter = self.P3.nodes() # rebuild same iterator + G.remove_nodes_from(n_iter) # remove nbunch of nodes (nbunch=iterator) + assert sorted(G.nodes(), key=str) == ["A", "B", "C", "D", "E", "F", "G", "H"] + + def test_nbunch_graph(self): + G = self.G() + G.add_nodes_from(["A", "B", "C", "D", "E", "F", "G", "H"]) + nbunch = self.K3 + G.add_nodes_from(nbunch) + assert sorted(G.nodes(), key=str), [ + 1, + 2, + 3, + "A", + "B", + "C", + "D", + "E", + "F", + "G", + "H", + ] + + # Edges + + def test_add_edge(self): + G = self.G() + pytest.raises(TypeError, G.add_edge, "A") + + G.add_edge("A", "B") # testing add_edge() + G.add_edge("A", "B") # should fail silently + assert G.has_edge("A", "B") + assert not G.has_edge("A", "C") + assert G.has_edge(*("A", "B")) + if G.is_directed(): + assert not G.has_edge("B", "A") + else: + # G is undirected, so B->A is an edge + assert G.has_edge("B", "A") + + G.add_edge("A", "C") # test directedness + G.add_edge("C", "A") + G.remove_edge("C", "A") + if G.is_directed(): + assert G.has_edge("A", "C") + else: + assert not G.has_edge("A", "C") + assert not G.has_edge("C", "A") + + def test_self_loop(self): + G = self.G() + G.add_edge("A", "A") # test self loops + assert G.has_edge("A", "A") + G.remove_edge("A", "A") + G.add_edge("X", "X") + assert G.has_node("X") + G.remove_node("X") + G.add_edge("A", "Z") # should add the node silently + assert G.has_node("Z") + + def test_add_edges_from(self): + G = self.G() + G.add_edges_from([("B", "C")]) # test add_edges_from() + assert G.has_edge("B", "C") + if G.is_directed(): + assert not G.has_edge("C", "B") + else: + assert G.has_edge("C", "B") # undirected + + G.add_edges_from([("D", "F"), ("B", "D")]) + assert G.has_edge("D", "F") + assert G.has_edge("B", "D") + + if G.is_directed(): + assert not G.has_edge("D", "B") + else: + assert G.has_edge("D", "B") # undirected + + def test_add_edges_from2(self): + G = self.G() + # after failing silently, should add 2nd edge + G.add_edges_from([tuple("IJ"), list("KK"), tuple("JK")]) + assert G.has_edge(*("I", "J")) + assert G.has_edge(*("K", "K")) + assert G.has_edge(*("J", "K")) + if G.is_directed(): + assert not G.has_edge(*("K", "J")) + else: + assert G.has_edge(*("K", "J")) + + def test_add_edges_from3(self): + G = self.G() + G.add_edges_from(zip(list("ACD"), list("CDE"))) + assert G.has_edge("D", "E") + assert not G.has_edge("E", "C") + + def test_remove_edge(self): + G = self.G() + G.add_nodes_from([1, 2, 3, "A", "B", "C", "D", "E", "F", "G", "H"]) + + G.add_edges_from(zip(list("MNOP"), list("NOPM"))) + assert G.has_edge("O", "P") + assert G.has_edge("P", "M") + G.remove_node("P") # tests remove_node()'s handling of edges. + assert not G.has_edge("P", "M") + pytest.raises(TypeError, G.remove_edge, "M") + + G.add_edge("N", "M") + assert G.has_edge("M", "N") + G.remove_edge("M", "N") + assert not G.has_edge("M", "N") + + # self loop fails silently + G.remove_edges_from([list("HI"), list("DF"), tuple("KK"), tuple("JK")]) + assert not G.has_edge("H", "I") + assert not G.has_edge("J", "K") + G.remove_edges_from([list("IJ"), list("KK"), list("JK")]) + assert not G.has_edge("I", "J") + G.remove_nodes_from(set("ZEFHIMNO")) + G.add_edge("J", "K") + + def test_edges_nbunch(self): + # Test G.edges(nbunch) with various forms of nbunch + G = self.G() + G.add_edges_from([("A", "B"), ("A", "C"), ("B", "D"), ("C", "B"), ("C", "D")]) + # node not in nbunch should be quietly ignored + pytest.raises(nx.NetworkXError, G.edges, 6) + assert list(G.edges("Z")) == [] # iterable non-node + # nbunch can be an empty list + assert list(G.edges([])) == [] + if G.is_directed(): + elist = [("A", "B"), ("A", "C"), ("B", "D")] + else: + elist = [("A", "B"), ("A", "C"), ("B", "C"), ("B", "D")] + # nbunch can be a list + assert edges_equal(list(G.edges(["A", "B"])), elist) + # nbunch can be a set + assert edges_equal(G.edges({"A", "B"}), elist) + # nbunch can be a graph + G1 = self.G() + G1.add_nodes_from("AB") + assert edges_equal(G.edges(G1), elist) + # nbunch can be a dict with nodes as keys + ndict = {"A": "thing1", "B": "thing2"} + assert edges_equal(G.edges(ndict), elist) + # nbunch can be a single node + assert edges_equal(list(G.edges("A")), [("A", "B"), ("A", "C")]) + assert nodes_equal(sorted(G), ["A", "B", "C", "D"]) + + # nbunch can be nothing (whole graph) + assert edges_equal( + list(G.edges()), + [("A", "B"), ("A", "C"), ("B", "D"), ("C", "B"), ("C", "D")], + ) + + def test_degree(self): + G = self.G() + G.add_edges_from([("A", "B"), ("A", "C"), ("B", "D"), ("C", "B"), ("C", "D")]) + assert G.degree("A") == 2 + + # degree of single node in iterable container must return dict + assert list(G.degree(["A"])) == [("A", 2)] + assert sorted(d for n, d in G.degree(["A", "B"])) == [2, 3] + assert sorted(d for n, d in G.degree()) == [2, 2, 3, 3] + + def test_degree2(self): + H = self.G() + H.add_edges_from([(1, 24), (1, 2)]) + assert sorted(d for n, d in H.degree([1, 24])) == [1, 2] + + def test_degree_graph(self): + P3 = nx.path_graph(3) + P5 = nx.path_graph(5) + # silently ignore nodes not in P3 + assert dict(d for n, d in P3.degree(["A", "B"])) == {} + # nbunch can be a graph + assert sorted(d for n, d in P5.degree(P3)) == [1, 2, 2] + # nbunch can be a graph that's way too big + assert sorted(d for n, d in P3.degree(P5)) == [1, 1, 2] + assert list(P5.degree([])) == [] + assert dict(P5.degree([])) == {} + + def test_null(self): + null = nx.null_graph() + assert list(null.degree()) == [] + assert dict(null.degree()) == {} + + def test_order_size(self): + G = self.G() + G.add_edges_from([("A", "B"), ("A", "C"), ("B", "D"), ("C", "B"), ("C", "D")]) + assert G.order() == 4 + assert G.size() == 5 + assert G.number_of_edges() == 5 + assert G.number_of_edges("A", "B") == 1 + assert G.number_of_edges("A", "D") == 0 + + def test_copy(self): + G = self.G() + H = G.copy() # copy + assert H.adj == G.adj + assert H.name == G.name + assert H is not G + + def test_subgraph(self): + G = self.G() + G.add_edges_from([("A", "B"), ("A", "C"), ("B", "D"), ("C", "B"), ("C", "D")]) + SG = G.subgraph(["A", "B", "D"]) + assert nodes_equal(list(SG), ["A", "B", "D"]) + assert edges_equal(list(SG.edges()), [("A", "B"), ("B", "D")]) + + def test_to_directed(self): + G = self.G() + if not G.is_directed(): + G.add_edges_from( + [("A", "B"), ("A", "C"), ("B", "D"), ("C", "B"), ("C", "D")] + ) + + DG = G.to_directed() + assert DG is not G # directed copy or copy + + assert DG.is_directed() + assert DG.name == G.name + assert DG.adj == G.adj + assert sorted(DG.out_edges(list("AB"))) == [ + ("A", "B"), + ("A", "C"), + ("B", "A"), + ("B", "C"), + ("B", "D"), + ] + DG.remove_edge("A", "B") + assert DG.has_edge("B", "A") # this removes B-A but not A-B + assert not DG.has_edge("A", "B") + + def test_to_undirected(self): + G = self.G() + if G.is_directed(): + G.add_edges_from( + [("A", "B"), ("A", "C"), ("B", "D"), ("C", "B"), ("C", "D")] + ) + UG = G.to_undirected() # to_undirected + assert UG is not G + assert not UG.is_directed() + assert G.is_directed() + assert UG.name == G.name + assert UG.adj != G.adj + assert sorted(UG.edges(list("AB"))) == [ + ("A", "B"), + ("A", "C"), + ("B", "C"), + ("B", "D"), + ] + assert sorted(UG.edges(["A", "B"])) == [ + ("A", "B"), + ("A", "C"), + ("B", "C"), + ("B", "D"), + ] + UG.remove_edge("A", "B") + assert not UG.has_edge("B", "A") + assert not UG.has_edge("A", "B") + + def test_neighbors(self): + G = self.G() + G.add_edges_from([("A", "B"), ("A", "C"), ("B", "D"), ("C", "B"), ("C", "D")]) + G.add_nodes_from("GJK") + assert sorted(G["A"]) == ["B", "C"] + assert sorted(G.neighbors("A")) == ["B", "C"] + assert sorted(G.neighbors("A")) == ["B", "C"] + assert sorted(G.neighbors("G")) == [] + pytest.raises(nx.NetworkXError, G.neighbors, "j") + + def test_iterators(self): + G = self.G() + G.add_edges_from([("A", "B"), ("A", "C"), ("B", "D"), ("C", "B"), ("C", "D")]) + G.add_nodes_from("GJK") + assert sorted(G.nodes()) == ["A", "B", "C", "D", "G", "J", "K"] + assert edges_equal( + G.edges(), [("A", "B"), ("A", "C"), ("B", "D"), ("C", "B"), ("C", "D")] + ) + + assert sorted(v for k, v in G.degree()) == [0, 0, 0, 2, 2, 3, 3] + assert sorted(G.degree(), key=str) == [ + ("A", 2), + ("B", 3), + ("C", 3), + ("D", 2), + ("G", 0), + ("J", 0), + ("K", 0), + ] + assert sorted(G.neighbors("A")) == ["B", "C"] + pytest.raises(nx.NetworkXError, G.neighbors, "X") + G.clear() + assert nx.number_of_nodes(G) == 0 + assert nx.number_of_edges(G) == 0 + + def test_null_subgraph(self): + # Subgraph of a null graph is a null graph + nullgraph = nx.null_graph() + G = nx.null_graph() + H = G.subgraph([]) + assert nx.is_isomorphic(H, nullgraph) + + def test_empty_subgraph(self): + # Subgraph of an empty graph is an empty graph. test 1 + nullgraph = nx.null_graph() + E5 = nx.empty_graph(5) + E10 = nx.empty_graph(10) + H = E10.subgraph([]) + assert nx.is_isomorphic(H, nullgraph) + H = E10.subgraph([1, 2, 3, 4, 5]) + assert nx.is_isomorphic(H, E5) + + def test_complete_subgraph(self): + # Subgraph of a complete graph is a complete graph + K1 = nx.complete_graph(1) + K3 = nx.complete_graph(3) + K5 = nx.complete_graph(5) + H = K5.subgraph([1, 2, 3]) + assert nx.is_isomorphic(H, K3) + + def test_subgraph_nbunch(self): + nullgraph = nx.null_graph() + K1 = nx.complete_graph(1) + K3 = nx.complete_graph(3) + K5 = nx.complete_graph(5) + # Test G.subgraph(nbunch), where nbunch is a single node + H = K5.subgraph(1) + assert nx.is_isomorphic(H, K1) + # Test G.subgraph(nbunch), where nbunch is a set + H = K5.subgraph({1}) + assert nx.is_isomorphic(H, K1) + # Test G.subgraph(nbunch), where nbunch is an iterator + H = K5.subgraph(iter(K3)) + assert nx.is_isomorphic(H, K3) + # Test G.subgraph(nbunch), where nbunch is another graph + H = K5.subgraph(K3) + assert nx.is_isomorphic(H, K3) + H = K5.subgraph([9]) + assert nx.is_isomorphic(H, nullgraph) + + def test_node_tuple_issue(self): + H = self.G() + # Test error handling of tuple as a node + pytest.raises(nx.NetworkXError, H.remove_node, (1, 2)) + H.remove_nodes_from([(1, 2)]) # no error + pytest.raises(nx.NetworkXError, H.neighbors, (1, 2)) diff --git a/lib/python3.10/site-packages/networkx/classes/tests/test_coreviews.py b/lib/python3.10/site-packages/networkx/classes/tests/test_coreviews.py new file mode 100644 index 0000000000000000000000000000000000000000..24de7f2f1115b864682b261daa256eff0deef696 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/classes/tests/test_coreviews.py @@ -0,0 +1,362 @@ +import pickle + +import pytest + +import networkx as nx + + +class TestAtlasView: + # node->data + def setup_method(self): + self.d = {0: {"color": "blue", "weight": 1.2}, 1: {}, 2: {"color": 1}} + self.av = nx.classes.coreviews.AtlasView(self.d) + + def test_pickle(self): + view = self.av + pview = pickle.loads(pickle.dumps(view, -1)) + assert view == pview + assert view.__slots__ == pview.__slots__ + pview = pickle.loads(pickle.dumps(view)) + assert view == pview + assert view.__slots__ == pview.__slots__ + + def test_len(self): + assert len(self.av) == len(self.d) + + def test_iter(self): + assert list(self.av) == list(self.d) + + def test_getitem(self): + assert self.av[1] is self.d[1] + assert self.av[2]["color"] == 1 + pytest.raises(KeyError, self.av.__getitem__, 3) + + def test_copy(self): + avcopy = self.av.copy() + assert avcopy[0] == self.av[0] + assert avcopy == self.av + assert avcopy[0] is not self.av[0] + assert avcopy is not self.av + avcopy[5] = {} + assert avcopy != self.av + + avcopy[0]["ht"] = 4 + assert avcopy[0] != self.av[0] + self.av[0]["ht"] = 4 + assert avcopy[0] == self.av[0] + del self.av[0]["ht"] + + assert not hasattr(self.av, "__setitem__") + + def test_items(self): + assert sorted(self.av.items()) == sorted(self.d.items()) + + def test_str(self): + out = str(self.d) + assert str(self.av) == out + + def test_repr(self): + out = "AtlasView(" + str(self.d) + ")" + assert repr(self.av) == out + + +class TestAdjacencyView: + # node->nbr->data + def setup_method(self): + dd = {"color": "blue", "weight": 1.2} + self.nd = {0: dd, 1: {}, 2: {"color": 1}} + self.adj = {3: self.nd, 0: {3: dd}, 1: {}, 2: {3: {"color": 1}}} + self.adjview = nx.classes.coreviews.AdjacencyView(self.adj) + + def test_pickle(self): + view = self.adjview + pview = pickle.loads(pickle.dumps(view, -1)) + assert view == pview + assert view.__slots__ == pview.__slots__ + + def test_len(self): + assert len(self.adjview) == len(self.adj) + + def test_iter(self): + assert list(self.adjview) == list(self.adj) + + def test_getitem(self): + assert self.adjview[1] is not self.adj[1] + assert self.adjview[3][0] is self.adjview[0][3] + assert self.adjview[2][3]["color"] == 1 + pytest.raises(KeyError, self.adjview.__getitem__, 4) + + def test_copy(self): + avcopy = self.adjview.copy() + assert avcopy[0] == self.adjview[0] + assert avcopy[0] is not self.adjview[0] + + avcopy[2][3]["ht"] = 4 + assert avcopy[2] != self.adjview[2] + self.adjview[2][3]["ht"] = 4 + assert avcopy[2] == self.adjview[2] + del self.adjview[2][3]["ht"] + + assert not hasattr(self.adjview, "__setitem__") + + def test_items(self): + view_items = sorted((n, dict(d)) for n, d in self.adjview.items()) + assert view_items == sorted(self.adj.items()) + + def test_str(self): + out = str(dict(self.adj)) + assert str(self.adjview) == out + + def test_repr(self): + out = self.adjview.__class__.__name__ + "(" + str(self.adj) + ")" + assert repr(self.adjview) == out + + +class TestMultiAdjacencyView(TestAdjacencyView): + # node->nbr->key->data + def setup_method(self): + dd = {"color": "blue", "weight": 1.2} + self.kd = {0: dd, 1: {}, 2: {"color": 1}} + self.nd = {3: self.kd, 0: {3: dd}, 1: {0: {}}, 2: {3: {"color": 1}}} + self.adj = {3: self.nd, 0: {3: {3: dd}}, 1: {}, 2: {3: {8: {}}}} + self.adjview = nx.classes.coreviews.MultiAdjacencyView(self.adj) + + def test_getitem(self): + assert self.adjview[1] is not self.adj[1] + assert self.adjview[3][0][3] is self.adjview[0][3][3] + assert self.adjview[3][2][3]["color"] == 1 + pytest.raises(KeyError, self.adjview.__getitem__, 4) + + def test_copy(self): + avcopy = self.adjview.copy() + assert avcopy[0] == self.adjview[0] + assert avcopy[0] is not self.adjview[0] + + avcopy[2][3][8]["ht"] = 4 + assert avcopy[2] != self.adjview[2] + self.adjview[2][3][8]["ht"] = 4 + assert avcopy[2] == self.adjview[2] + del self.adjview[2][3][8]["ht"] + + assert not hasattr(self.adjview, "__setitem__") + + +class TestUnionAtlas: + # node->data + def setup_method(self): + self.s = {0: {"color": "blue", "weight": 1.2}, 1: {}, 2: {"color": 1}} + self.p = {3: {"color": "blue", "weight": 1.2}, 4: {}, 2: {"watch": 2}} + self.av = nx.classes.coreviews.UnionAtlas(self.s, self.p) + + def test_pickle(self): + view = self.av + pview = pickle.loads(pickle.dumps(view, -1)) + assert view == pview + assert view.__slots__ == pview.__slots__ + + def test_len(self): + assert len(self.av) == len(self.s.keys() | self.p.keys()) == 5 + + def test_iter(self): + assert set(self.av) == set(self.s) | set(self.p) + + def test_getitem(self): + assert self.av[0] is self.s[0] + assert self.av[4] is self.p[4] + assert self.av[2]["color"] == 1 + pytest.raises(KeyError, self.av[2].__getitem__, "watch") + pytest.raises(KeyError, self.av.__getitem__, 8) + + def test_copy(self): + avcopy = self.av.copy() + assert avcopy[0] == self.av[0] + assert avcopy[0] is not self.av[0] + assert avcopy is not self.av + avcopy[5] = {} + assert avcopy != self.av + + avcopy[0]["ht"] = 4 + assert avcopy[0] != self.av[0] + self.av[0]["ht"] = 4 + assert avcopy[0] == self.av[0] + del self.av[0]["ht"] + + assert not hasattr(self.av, "__setitem__") + + def test_items(self): + expected = dict(self.p.items()) + expected.update(self.s) + assert sorted(self.av.items()) == sorted(expected.items()) + + def test_str(self): + out = str(dict(self.av)) + assert str(self.av) == out + + def test_repr(self): + out = f"{self.av.__class__.__name__}({self.s}, {self.p})" + assert repr(self.av) == out + + +class TestUnionAdjacency: + # node->nbr->data + def setup_method(self): + dd = {"color": "blue", "weight": 1.2} + self.nd = {0: dd, 1: {}, 2: {"color": 1}} + self.s = {3: self.nd, 0: {}, 1: {}, 2: {3: {"color": 1}}} + self.p = {3: {}, 0: {3: dd}, 1: {0: {}}, 2: {1: {"color": 1}}} + self.adjview = nx.classes.coreviews.UnionAdjacency(self.s, self.p) + + def test_pickle(self): + view = self.adjview + pview = pickle.loads(pickle.dumps(view, -1)) + assert view == pview + assert view.__slots__ == pview.__slots__ + + def test_len(self): + assert len(self.adjview) == len(self.s) + + def test_iter(self): + assert sorted(self.adjview) == sorted(self.s) + + def test_getitem(self): + assert self.adjview[1] is not self.s[1] + assert self.adjview[3][0] is self.adjview[0][3] + assert self.adjview[2][3]["color"] == 1 + pytest.raises(KeyError, self.adjview.__getitem__, 4) + + def test_copy(self): + avcopy = self.adjview.copy() + assert avcopy[0] == self.adjview[0] + assert avcopy[0] is not self.adjview[0] + + avcopy[2][3]["ht"] = 4 + assert avcopy[2] != self.adjview[2] + self.adjview[2][3]["ht"] = 4 + assert avcopy[2] == self.adjview[2] + del self.adjview[2][3]["ht"] + + assert not hasattr(self.adjview, "__setitem__") + + def test_str(self): + out = str(dict(self.adjview)) + assert str(self.adjview) == out + + def test_repr(self): + clsname = self.adjview.__class__.__name__ + out = f"{clsname}({self.s}, {self.p})" + assert repr(self.adjview) == out + + +class TestUnionMultiInner(TestUnionAdjacency): + # nbr->key->data + def setup_method(self): + dd = {"color": "blue", "weight": 1.2} + self.kd = {7: {}, "ekey": {}, 9: {"color": 1}} + self.s = {3: self.kd, 0: {7: dd}, 1: {}, 2: {"key": {"color": 1}}} + self.p = {3: {}, 0: {3: dd}, 1: {}, 2: {1: {"span": 2}}} + self.adjview = nx.classes.coreviews.UnionMultiInner(self.s, self.p) + + def test_len(self): + assert len(self.adjview) == len(self.s.keys() | self.p.keys()) == 4 + + def test_getitem(self): + assert self.adjview[1] is not self.s[1] + assert self.adjview[0][7] is self.adjview[0][3] + assert self.adjview[2]["key"]["color"] == 1 + assert self.adjview[2][1]["span"] == 2 + pytest.raises(KeyError, self.adjview.__getitem__, 4) + pytest.raises(KeyError, self.adjview[1].__getitem__, "key") + + def test_copy(self): + avcopy = self.adjview.copy() + assert avcopy[0] == self.adjview[0] + assert avcopy[0] is not self.adjview[0] + + avcopy[2][1]["width"] = 8 + assert avcopy[2] != self.adjview[2] + self.adjview[2][1]["width"] = 8 + assert avcopy[2] == self.adjview[2] + del self.adjview[2][1]["width"] + + assert not hasattr(self.adjview, "__setitem__") + assert hasattr(avcopy, "__setitem__") + + +class TestUnionMultiAdjacency(TestUnionAdjacency): + # node->nbr->key->data + def setup_method(self): + dd = {"color": "blue", "weight": 1.2} + self.kd = {7: {}, 8: {}, 9: {"color": 1}} + self.nd = {3: self.kd, 0: {9: dd}, 1: {8: {}}, 2: {9: {"color": 1}}} + self.s = {3: self.nd, 0: {3: {7: dd}}, 1: {}, 2: {3: {8: {}}}} + self.p = {3: {}, 0: {3: {9: dd}}, 1: {}, 2: {1: {8: {}}}} + self.adjview = nx.classes.coreviews.UnionMultiAdjacency(self.s, self.p) + + def test_getitem(self): + assert self.adjview[1] is not self.s[1] + assert self.adjview[3][0][9] is self.adjview[0][3][9] + assert self.adjview[3][2][9]["color"] == 1 + pytest.raises(KeyError, self.adjview.__getitem__, 4) + + def test_copy(self): + avcopy = self.adjview.copy() + assert avcopy[0] == self.adjview[0] + assert avcopy[0] is not self.adjview[0] + + avcopy[2][3][8]["ht"] = 4 + assert avcopy[2] != self.adjview[2] + self.adjview[2][3][8]["ht"] = 4 + assert avcopy[2] == self.adjview[2] + del self.adjview[2][3][8]["ht"] + + assert not hasattr(self.adjview, "__setitem__") + assert hasattr(avcopy, "__setitem__") + + +class TestFilteredGraphs: + def setup_method(self): + self.Graphs = [nx.Graph, nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph] + + def test_hide_show_nodes(self): + SubGraph = nx.subgraph_view + for Graph in self.Graphs: + G = nx.path_graph(4, Graph) + SG = G.subgraph([2, 3]) + RG = SubGraph(G, filter_node=nx.filters.hide_nodes([0, 1])) + assert SG.nodes == RG.nodes + assert SG.edges == RG.edges + SGC = SG.copy() + RGC = RG.copy() + assert SGC.nodes == RGC.nodes + assert SGC.edges == RGC.edges + + def test_str_repr(self): + SubGraph = nx.subgraph_view + for Graph in self.Graphs: + G = nx.path_graph(4, Graph) + SG = G.subgraph([2, 3]) + RG = SubGraph(G, filter_node=nx.filters.hide_nodes([0, 1])) + str(SG.adj) + str(RG.adj) + repr(SG.adj) + repr(RG.adj) + str(SG.adj[2]) + str(RG.adj[2]) + repr(SG.adj[2]) + repr(RG.adj[2]) + + def test_copy(self): + SubGraph = nx.subgraph_view + for Graph in self.Graphs: + G = nx.path_graph(4, Graph) + SG = G.subgraph([2, 3]) + RG = SubGraph(G, filter_node=nx.filters.hide_nodes([0, 1])) + RsG = SubGraph(G, filter_node=nx.filters.show_nodes([2, 3])) + assert G.adj.copy() == G.adj + assert G.adj[2].copy() == G.adj[2] + assert SG.adj.copy() == SG.adj + assert SG.adj[2].copy() == SG.adj[2] + assert RG.adj.copy() == RG.adj + assert RG.adj[2].copy() == RG.adj[2] + assert RsG.adj.copy() == RsG.adj + assert RsG.adj[2].copy() == RsG.adj[2] diff --git a/lib/python3.10/site-packages/networkx/classes/tests/test_digraph.py b/lib/python3.10/site-packages/networkx/classes/tests/test_digraph.py new file mode 100644 index 0000000000000000000000000000000000000000..b9972f9a5f1ab101b9f6f2f9a1584ddafccd2ff3 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/classes/tests/test_digraph.py @@ -0,0 +1,331 @@ +import pytest + +import networkx as nx +from networkx.utils import nodes_equal + +from .test_graph import BaseAttrGraphTester, BaseGraphTester +from .test_graph import TestEdgeSubgraph as _TestGraphEdgeSubgraph +from .test_graph import TestGraph as _TestGraph + + +class BaseDiGraphTester(BaseGraphTester): + def test_has_successor(self): + G = self.K3 + assert G.has_successor(0, 1) + assert not G.has_successor(0, -1) + + def test_successors(self): + G = self.K3 + assert sorted(G.successors(0)) == [1, 2] + with pytest.raises(nx.NetworkXError): + G.successors(-1) + + def test_has_predecessor(self): + G = self.K3 + assert G.has_predecessor(0, 1) + assert not G.has_predecessor(0, -1) + + def test_predecessors(self): + G = self.K3 + assert sorted(G.predecessors(0)) == [1, 2] + with pytest.raises(nx.NetworkXError): + G.predecessors(-1) + + def test_edges(self): + G = self.K3 + assert sorted(G.edges()) == [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)] + assert sorted(G.edges(0)) == [(0, 1), (0, 2)] + assert sorted(G.edges([0, 1])) == [(0, 1), (0, 2), (1, 0), (1, 2)] + with pytest.raises(nx.NetworkXError): + G.edges(-1) + + def test_out_edges(self): + G = self.K3 + assert sorted(G.out_edges()) == [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)] + assert sorted(G.out_edges(0)) == [(0, 1), (0, 2)] + with pytest.raises(nx.NetworkXError): + G.out_edges(-1) + + def test_out_edges_dir(self): + G = self.P3 + assert sorted(G.out_edges()) == [(0, 1), (1, 2)] + assert sorted(G.out_edges(0)) == [(0, 1)] + assert sorted(G.out_edges(2)) == [] + + def test_out_edges_data(self): + G = nx.DiGraph([(0, 1, {"data": 0}), (1, 0, {})]) + assert sorted(G.out_edges(data=True)) == [(0, 1, {"data": 0}), (1, 0, {})] + assert sorted(G.out_edges(0, data=True)) == [(0, 1, {"data": 0})] + assert sorted(G.out_edges(data="data")) == [(0, 1, 0), (1, 0, None)] + assert sorted(G.out_edges(0, data="data")) == [(0, 1, 0)] + + def test_in_edges_dir(self): + G = self.P3 + assert sorted(G.in_edges()) == [(0, 1), (1, 2)] + assert sorted(G.in_edges(0)) == [] + assert sorted(G.in_edges(2)) == [(1, 2)] + + def test_in_edges_data(self): + G = nx.DiGraph([(0, 1, {"data": 0}), (1, 0, {})]) + assert sorted(G.in_edges(data=True)) == [(0, 1, {"data": 0}), (1, 0, {})] + assert sorted(G.in_edges(1, data=True)) == [(0, 1, {"data": 0})] + assert sorted(G.in_edges(data="data")) == [(0, 1, 0), (1, 0, None)] + assert sorted(G.in_edges(1, data="data")) == [(0, 1, 0)] + + def test_degree(self): + G = self.K3 + assert sorted(G.degree()) == [(0, 4), (1, 4), (2, 4)] + assert dict(G.degree()) == {0: 4, 1: 4, 2: 4} + assert G.degree(0) == 4 + assert list(G.degree(iter([0]))) == [(0, 4)] # run through iterator + + def test_in_degree(self): + G = self.K3 + assert sorted(G.in_degree()) == [(0, 2), (1, 2), (2, 2)] + assert dict(G.in_degree()) == {0: 2, 1: 2, 2: 2} + assert G.in_degree(0) == 2 + assert list(G.in_degree(iter([0]))) == [(0, 2)] # run through iterator + + def test_out_degree(self): + G = self.K3 + assert sorted(G.out_degree()) == [(0, 2), (1, 2), (2, 2)] + assert dict(G.out_degree()) == {0: 2, 1: 2, 2: 2} + assert G.out_degree(0) == 2 + assert list(G.out_degree(iter([0]))) == [(0, 2)] + + def test_size(self): + G = self.K3 + assert G.size() == 6 + assert G.number_of_edges() == 6 + + def test_to_undirected_reciprocal(self): + G = self.Graph() + G.add_edge(1, 2) + assert G.to_undirected().has_edge(1, 2) + assert not G.to_undirected(reciprocal=True).has_edge(1, 2) + G.add_edge(2, 1) + assert G.to_undirected(reciprocal=True).has_edge(1, 2) + + def test_reverse_copy(self): + G = nx.DiGraph([(0, 1), (1, 2)]) + R = G.reverse() + assert sorted(R.edges()) == [(1, 0), (2, 1)] + R.remove_edge(1, 0) + assert sorted(R.edges()) == [(2, 1)] + assert sorted(G.edges()) == [(0, 1), (1, 2)] + + def test_reverse_nocopy(self): + G = nx.DiGraph([(0, 1), (1, 2)]) + R = G.reverse(copy=False) + assert sorted(R.edges()) == [(1, 0), (2, 1)] + with pytest.raises(nx.NetworkXError): + R.remove_edge(1, 0) + + def test_reverse_hashable(self): + class Foo: + pass + + x = Foo() + y = Foo() + G = nx.DiGraph() + G.add_edge(x, y) + assert nodes_equal(G.nodes(), G.reverse().nodes()) + assert [(y, x)] == list(G.reverse().edges()) + + def test_di_cache_reset(self): + G = self.K3.copy() + old_succ = G.succ + assert id(G.succ) == id(old_succ) + old_adj = G.adj + assert id(G.adj) == id(old_adj) + + G._succ = {} + assert id(G.succ) != id(old_succ) + assert id(G.adj) != id(old_adj) + + old_pred = G.pred + assert id(G.pred) == id(old_pred) + G._pred = {} + assert id(G.pred) != id(old_pred) + + def test_di_attributes_cached(self): + G = self.K3.copy() + assert id(G.in_edges) == id(G.in_edges) + assert id(G.out_edges) == id(G.out_edges) + assert id(G.in_degree) == id(G.in_degree) + assert id(G.out_degree) == id(G.out_degree) + assert id(G.succ) == id(G.succ) + assert id(G.pred) == id(G.pred) + + +class BaseAttrDiGraphTester(BaseDiGraphTester, BaseAttrGraphTester): + def test_edges_data(self): + G = self.K3 + all_edges = [ + (0, 1, {}), + (0, 2, {}), + (1, 0, {}), + (1, 2, {}), + (2, 0, {}), + (2, 1, {}), + ] + assert sorted(G.edges(data=True)) == all_edges + assert sorted(G.edges(0, data=True)) == all_edges[:2] + assert sorted(G.edges([0, 1], data=True)) == all_edges[:4] + with pytest.raises(nx.NetworkXError): + G.edges(-1, True) + + def test_in_degree_weighted(self): + G = self.K3.copy() + G.add_edge(0, 1, weight=0.3, other=1.2) + assert sorted(G.in_degree(weight="weight")) == [(0, 2), (1, 1.3), (2, 2)] + assert dict(G.in_degree(weight="weight")) == {0: 2, 1: 1.3, 2: 2} + assert G.in_degree(1, weight="weight") == 1.3 + assert sorted(G.in_degree(weight="other")) == [(0, 2), (1, 2.2), (2, 2)] + assert dict(G.in_degree(weight="other")) == {0: 2, 1: 2.2, 2: 2} + assert G.in_degree(1, weight="other") == 2.2 + assert list(G.in_degree(iter([1]), weight="other")) == [(1, 2.2)] + + def test_out_degree_weighted(self): + G = self.K3.copy() + G.add_edge(0, 1, weight=0.3, other=1.2) + assert sorted(G.out_degree(weight="weight")) == [(0, 1.3), (1, 2), (2, 2)] + assert dict(G.out_degree(weight="weight")) == {0: 1.3, 1: 2, 2: 2} + assert G.out_degree(0, weight="weight") == 1.3 + assert sorted(G.out_degree(weight="other")) == [(0, 2.2), (1, 2), (2, 2)] + assert dict(G.out_degree(weight="other")) == {0: 2.2, 1: 2, 2: 2} + assert G.out_degree(0, weight="other") == 2.2 + assert list(G.out_degree(iter([0]), weight="other")) == [(0, 2.2)] + + +class TestDiGraph(BaseAttrDiGraphTester, _TestGraph): + """Tests specific to dict-of-dict-of-dict digraph data structure""" + + def setup_method(self): + self.Graph = nx.DiGraph + # build dict-of-dict-of-dict K3 + ed1, ed2, ed3, ed4, ed5, ed6 = ({}, {}, {}, {}, {}, {}) + self.k3adj = {0: {1: ed1, 2: ed2}, 1: {0: ed3, 2: ed4}, 2: {0: ed5, 1: ed6}} + self.k3edges = [(0, 1), (0, 2), (1, 2)] + self.k3nodes = [0, 1, 2] + self.K3 = self.Graph() + self.K3._succ = self.k3adj # K3._adj is synced with K3._succ + self.K3._pred = {0: {1: ed3, 2: ed5}, 1: {0: ed1, 2: ed6}, 2: {0: ed2, 1: ed4}} + self.K3._node = {} + self.K3._node[0] = {} + self.K3._node[1] = {} + self.K3._node[2] = {} + + ed1, ed2 = ({}, {}) + self.P3 = self.Graph() + self.P3._succ = {0: {1: ed1}, 1: {2: ed2}, 2: {}} + self.P3._pred = {0: {}, 1: {0: ed1}, 2: {1: ed2}} + # P3._adj is synced with P3._succ + self.P3._node = {} + self.P3._node[0] = {} + self.P3._node[1] = {} + self.P3._node[2] = {} + + def test_data_input(self): + G = self.Graph({1: [2], 2: [1]}, name="test") + assert G.name == "test" + assert sorted(G.adj.items()) == [(1, {2: {}}), (2, {1: {}})] + assert sorted(G.succ.items()) == [(1, {2: {}}), (2, {1: {}})] + assert sorted(G.pred.items()) == [(1, {2: {}}), (2, {1: {}})] + + def test_add_edge(self): + G = self.Graph() + G.add_edge(0, 1) + assert G.adj == {0: {1: {}}, 1: {}} + assert G.succ == {0: {1: {}}, 1: {}} + assert G.pred == {0: {}, 1: {0: {}}} + G = self.Graph() + G.add_edge(*(0, 1)) + assert G.adj == {0: {1: {}}, 1: {}} + assert G.succ == {0: {1: {}}, 1: {}} + assert G.pred == {0: {}, 1: {0: {}}} + with pytest.raises(ValueError, match="None cannot be a node"): + G.add_edge(None, 3) + + def test_add_edges_from(self): + G = self.Graph() + G.add_edges_from([(0, 1), (0, 2, {"data": 3})], data=2) + assert G.adj == {0: {1: {"data": 2}, 2: {"data": 3}}, 1: {}, 2: {}} + assert G.succ == {0: {1: {"data": 2}, 2: {"data": 3}}, 1: {}, 2: {}} + assert G.pred == {0: {}, 1: {0: {"data": 2}}, 2: {0: {"data": 3}}} + + with pytest.raises(nx.NetworkXError): + G.add_edges_from([(0,)]) # too few in tuple + with pytest.raises(nx.NetworkXError): + G.add_edges_from([(0, 1, 2, 3)]) # too many in tuple + with pytest.raises(TypeError): + G.add_edges_from([0]) # not a tuple + with pytest.raises(ValueError, match="None cannot be a node"): + G.add_edges_from([(None, 3), (3, 2)]) + + def test_remove_edge(self): + G = self.K3.copy() + G.remove_edge(0, 1) + assert G.succ == {0: {2: {}}, 1: {0: {}, 2: {}}, 2: {0: {}, 1: {}}} + assert G.pred == {0: {1: {}, 2: {}}, 1: {2: {}}, 2: {0: {}, 1: {}}} + with pytest.raises(nx.NetworkXError): + G.remove_edge(-1, 0) + + def test_remove_edges_from(self): + G = self.K3.copy() + G.remove_edges_from([(0, 1)]) + assert G.succ == {0: {2: {}}, 1: {0: {}, 2: {}}, 2: {0: {}, 1: {}}} + assert G.pred == {0: {1: {}, 2: {}}, 1: {2: {}}, 2: {0: {}, 1: {}}} + G.remove_edges_from([(0, 0)]) # silent fail + + def test_clear(self): + G = self.K3 + G.graph["name"] = "K3" + G.clear() + assert list(G.nodes) == [] + assert G.succ == {} + assert G.pred == {} + assert G.graph == {} + + def test_clear_edges(self): + G = self.K3 + G.graph["name"] = "K3" + nodes = list(G.nodes) + G.clear_edges() + assert list(G.nodes) == nodes + expected = {0: {}, 1: {}, 2: {}} + assert G.succ == expected + assert G.pred == expected + assert list(G.edges) == [] + assert G.graph["name"] == "K3" + + +class TestEdgeSubgraph(_TestGraphEdgeSubgraph): + """Unit tests for the :meth:`DiGraph.edge_subgraph` method.""" + + def setup_method(self): + # Create a doubly-linked path graph on five nodes. + G = nx.DiGraph(nx.path_graph(5)) + # Add some node, edge, and graph attributes. + for i in range(5): + G.nodes[i]["name"] = f"node{i}" + G.edges[0, 1]["name"] = "edge01" + G.edges[3, 4]["name"] = "edge34" + G.graph["name"] = "graph" + # Get the subgraph induced by the first and last edges. + self.G = G + self.H = G.edge_subgraph([(0, 1), (3, 4)]) + + def test_pred_succ(self): + """Test that nodes are added to predecessors and successors. + + For more information, see GitHub issue #2370. + + """ + G = nx.DiGraph() + G.add_edge(0, 1) + H = G.edge_subgraph([(0, 1)]) + assert list(H.predecessors(0)) == [] + assert list(H.successors(0)) == [1] + assert list(H.predecessors(1)) == [0] + assert list(H.successors(1)) == [] diff --git a/lib/python3.10/site-packages/networkx/classes/tests/test_digraph_historical.py b/lib/python3.10/site-packages/networkx/classes/tests/test_digraph_historical.py new file mode 100644 index 0000000000000000000000000000000000000000..4f2b1da90f977962e9610bd64d10e721915a0595 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/classes/tests/test_digraph_historical.py @@ -0,0 +1,111 @@ +"""Original NetworkX graph tests""" + +import pytest + +import networkx +import networkx as nx + +from .historical_tests import HistoricalTests + + +class TestDiGraphHistorical(HistoricalTests): + @classmethod + def setup_class(cls): + HistoricalTests.setup_class() + cls.G = nx.DiGraph + + def test_in_degree(self): + G = self.G() + G.add_nodes_from("GJK") + G.add_edges_from([("A", "B"), ("A", "C"), ("B", "D"), ("B", "C"), ("C", "D")]) + + assert sorted(d for n, d in G.in_degree()) == [0, 0, 0, 0, 1, 2, 2] + assert dict(G.in_degree()) == { + "A": 0, + "C": 2, + "B": 1, + "D": 2, + "G": 0, + "K": 0, + "J": 0, + } + + def test_out_degree(self): + G = self.G() + G.add_nodes_from("GJK") + G.add_edges_from([("A", "B"), ("A", "C"), ("B", "D"), ("B", "C"), ("C", "D")]) + assert sorted(v for k, v in G.in_degree()) == [0, 0, 0, 0, 1, 2, 2] + assert dict(G.out_degree()) == { + "A": 2, + "C": 1, + "B": 2, + "D": 0, + "G": 0, + "K": 0, + "J": 0, + } + + def test_degree_digraph(self): + H = nx.DiGraph() + H.add_edges_from([(1, 24), (1, 2)]) + assert sorted(d for n, d in H.in_degree([1, 24])) == [0, 1] + assert sorted(d for n, d in H.out_degree([1, 24])) == [0, 2] + assert sorted(d for n, d in H.degree([1, 24])) == [1, 2] + + def test_neighbors(self): + G = self.G() + G.add_nodes_from("GJK") + G.add_edges_from([("A", "B"), ("A", "C"), ("B", "D"), ("B", "C"), ("C", "D")]) + + assert sorted(G.neighbors("C")) == ["D"] + assert sorted(G["C"]) == ["D"] + assert sorted(G.neighbors("A")) == ["B", "C"] + pytest.raises(nx.NetworkXError, G.neighbors, "j") + pytest.raises(nx.NetworkXError, G.neighbors, "j") + + def test_successors(self): + G = self.G() + G.add_nodes_from("GJK") + G.add_edges_from([("A", "B"), ("A", "C"), ("B", "D"), ("B", "C"), ("C", "D")]) + assert sorted(G.successors("A")) == ["B", "C"] + assert sorted(G.successors("A")) == ["B", "C"] + assert sorted(G.successors("G")) == [] + assert sorted(G.successors("D")) == [] + assert sorted(G.successors("G")) == [] + pytest.raises(nx.NetworkXError, G.successors, "j") + pytest.raises(nx.NetworkXError, G.successors, "j") + + def test_predecessors(self): + G = self.G() + G.add_nodes_from("GJK") + G.add_edges_from([("A", "B"), ("A", "C"), ("B", "D"), ("B", "C"), ("C", "D")]) + assert sorted(G.predecessors("C")) == ["A", "B"] + assert sorted(G.predecessors("C")) == ["A", "B"] + assert sorted(G.predecessors("G")) == [] + assert sorted(G.predecessors("A")) == [] + assert sorted(G.predecessors("G")) == [] + assert sorted(G.predecessors("A")) == [] + assert sorted(G.successors("D")) == [] + + pytest.raises(nx.NetworkXError, G.predecessors, "j") + pytest.raises(nx.NetworkXError, G.predecessors, "j") + + def test_reverse(self): + G = nx.complete_graph(10) + H = G.to_directed() + HR = H.reverse() + assert nx.is_isomorphic(H, HR) + assert sorted(H.edges()) == sorted(HR.edges()) + + def test_reverse2(self): + H = nx.DiGraph() + foo = [H.add_edge(u, u + 1) for u in range(5)] + HR = H.reverse() + for u in range(5): + assert HR.has_edge(u + 1, u) + + def test_reverse3(self): + H = nx.DiGraph() + H.add_nodes_from([1, 2, 3, 4]) + HR = H.reverse() + assert sorted(HR.nodes()) == [1, 2, 3, 4] diff --git a/lib/python3.10/site-packages/networkx/classes/tests/test_filters.py b/lib/python3.10/site-packages/networkx/classes/tests/test_filters.py new file mode 100644 index 0000000000000000000000000000000000000000..2da59117cad0d72d5830b53c8d19c6e0ca988d54 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/classes/tests/test_filters.py @@ -0,0 +1,177 @@ +import pytest + +import networkx as nx + + +class TestFilterFactory: + def test_no_filter(self): + nf = nx.filters.no_filter + assert nf() + assert nf(1) + assert nf(2, 1) + + def test_hide_nodes(self): + f = nx.classes.filters.hide_nodes([1, 2, 3]) + assert not f(1) + assert not f(2) + assert not f(3) + assert f(4) + assert f(0) + assert f("a") + pytest.raises(TypeError, f, 1, 2) + pytest.raises(TypeError, f) + + def test_show_nodes(self): + f = nx.classes.filters.show_nodes([1, 2, 3]) + assert f(1) + assert f(2) + assert f(3) + assert not f(4) + assert not f(0) + assert not f("a") + pytest.raises(TypeError, f, 1, 2) + pytest.raises(TypeError, f) + + def test_hide_edges(self): + factory = nx.classes.filters.hide_edges + f = factory([(1, 2), (3, 4)]) + assert not f(1, 2) + assert not f(3, 4) + assert not f(4, 3) + assert f(2, 3) + assert f(0, -1) + assert f("a", "b") + pytest.raises(TypeError, f, 1, 2, 3) + pytest.raises(TypeError, f, 1) + pytest.raises(TypeError, f) + pytest.raises(TypeError, factory, [1, 2, 3]) + pytest.raises(ValueError, factory, [(1, 2, 3)]) + + def test_show_edges(self): + factory = nx.classes.filters.show_edges + f = factory([(1, 2), (3, 4)]) + assert f(1, 2) + assert f(3, 4) + assert f(4, 3) + assert not f(2, 3) + assert not f(0, -1) + assert not f("a", "b") + pytest.raises(TypeError, f, 1, 2, 3) + pytest.raises(TypeError, f, 1) + pytest.raises(TypeError, f) + pytest.raises(TypeError, factory, [1, 2, 3]) + pytest.raises(ValueError, factory, [(1, 2, 3)]) + + def test_hide_diedges(self): + factory = nx.classes.filters.hide_diedges + f = factory([(1, 2), (3, 4)]) + assert not f(1, 2) + assert not f(3, 4) + assert f(4, 3) + assert f(2, 3) + assert f(0, -1) + assert f("a", "b") + pytest.raises(TypeError, f, 1, 2, 3) + pytest.raises(TypeError, f, 1) + pytest.raises(TypeError, f) + pytest.raises(TypeError, factory, [1, 2, 3]) + pytest.raises(ValueError, factory, [(1, 2, 3)]) + + def test_show_diedges(self): + factory = nx.classes.filters.show_diedges + f = factory([(1, 2), (3, 4)]) + assert f(1, 2) + assert f(3, 4) + assert not f(4, 3) + assert not f(2, 3) + assert not f(0, -1) + assert not f("a", "b") + pytest.raises(TypeError, f, 1, 2, 3) + pytest.raises(TypeError, f, 1) + pytest.raises(TypeError, f) + pytest.raises(TypeError, factory, [1, 2, 3]) + pytest.raises(ValueError, factory, [(1, 2, 3)]) + + def test_hide_multiedges(self): + factory = nx.classes.filters.hide_multiedges + f = factory([(1, 2, 0), (3, 4, 1), (1, 2, 1)]) + assert not f(1, 2, 0) + assert not f(1, 2, 1) + assert f(1, 2, 2) + assert f(3, 4, 0) + assert not f(3, 4, 1) + assert not f(4, 3, 1) + assert f(4, 3, 0) + assert f(2, 3, 0) + assert f(0, -1, 0) + assert f("a", "b", 0) + pytest.raises(TypeError, f, 1, 2, 3, 4) + pytest.raises(TypeError, f, 1, 2) + pytest.raises(TypeError, f, 1) + pytest.raises(TypeError, f) + pytest.raises(TypeError, factory, [1, 2, 3]) + pytest.raises(ValueError, factory, [(1, 2)]) + pytest.raises(ValueError, factory, [(1, 2, 3, 4)]) + + def test_show_multiedges(self): + factory = nx.classes.filters.show_multiedges + f = factory([(1, 2, 0), (3, 4, 1), (1, 2, 1)]) + assert f(1, 2, 0) + assert f(1, 2, 1) + assert not f(1, 2, 2) + assert not f(3, 4, 0) + assert f(3, 4, 1) + assert f(4, 3, 1) + assert not f(4, 3, 0) + assert not f(2, 3, 0) + assert not f(0, -1, 0) + assert not f("a", "b", 0) + pytest.raises(TypeError, f, 1, 2, 3, 4) + pytest.raises(TypeError, f, 1, 2) + pytest.raises(TypeError, f, 1) + pytest.raises(TypeError, f) + pytest.raises(TypeError, factory, [1, 2, 3]) + pytest.raises(ValueError, factory, [(1, 2)]) + pytest.raises(ValueError, factory, [(1, 2, 3, 4)]) + + def test_hide_multidiedges(self): + factory = nx.classes.filters.hide_multidiedges + f = factory([(1, 2, 0), (3, 4, 1), (1, 2, 1)]) + assert not f(1, 2, 0) + assert not f(1, 2, 1) + assert f(1, 2, 2) + assert f(3, 4, 0) + assert not f(3, 4, 1) + assert f(4, 3, 1) + assert f(4, 3, 0) + assert f(2, 3, 0) + assert f(0, -1, 0) + assert f("a", "b", 0) + pytest.raises(TypeError, f, 1, 2, 3, 4) + pytest.raises(TypeError, f, 1, 2) + pytest.raises(TypeError, f, 1) + pytest.raises(TypeError, f) + pytest.raises(TypeError, factory, [1, 2, 3]) + pytest.raises(ValueError, factory, [(1, 2)]) + pytest.raises(ValueError, factory, [(1, 2, 3, 4)]) + + def test_show_multidiedges(self): + factory = nx.classes.filters.show_multidiedges + f = factory([(1, 2, 0), (3, 4, 1), (1, 2, 1)]) + assert f(1, 2, 0) + assert f(1, 2, 1) + assert not f(1, 2, 2) + assert not f(3, 4, 0) + assert f(3, 4, 1) + assert not f(4, 3, 1) + assert not f(4, 3, 0) + assert not f(2, 3, 0) + assert not f(0, -1, 0) + assert not f("a", "b", 0) + pytest.raises(TypeError, f, 1, 2, 3, 4) + pytest.raises(TypeError, f, 1, 2) + pytest.raises(TypeError, f, 1) + pytest.raises(TypeError, f) + pytest.raises(TypeError, factory, [1, 2, 3]) + pytest.raises(ValueError, factory, [(1, 2)]) + pytest.raises(ValueError, factory, [(1, 2, 3, 4)]) diff --git a/lib/python3.10/site-packages/networkx/classes/tests/test_function.py b/lib/python3.10/site-packages/networkx/classes/tests/test_function.py new file mode 100644 index 0000000000000000000000000000000000000000..f86890dd25268472db87fb18448d38ab11ec9946 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/classes/tests/test_function.py @@ -0,0 +1,1035 @@ +import random + +import pytest + +import networkx as nx +from networkx.utils import edges_equal, nodes_equal + + +def test_degree_histogram_empty(): + G = nx.Graph() + assert nx.degree_histogram(G) == [] + + +class TestFunction: + def setup_method(self): + self.G = nx.Graph({0: [1, 2, 3], 1: [1, 2, 0], 4: []}, name="Test") + self.Gdegree = {0: 3, 1: 2, 2: 2, 3: 1, 4: 0} + self.Gnodes = list(range(5)) + self.Gedges = [(0, 1), (0, 2), (0, 3), (1, 0), (1, 1), (1, 2)] + self.DG = nx.DiGraph({0: [1, 2, 3], 1: [1, 2, 0], 4: []}) + self.DGin_degree = {0: 1, 1: 2, 2: 2, 3: 1, 4: 0} + self.DGout_degree = {0: 3, 1: 3, 2: 0, 3: 0, 4: 0} + self.DGnodes = list(range(5)) + self.DGedges = [(0, 1), (0, 2), (0, 3), (1, 0), (1, 1), (1, 2)] + + def test_nodes(self): + assert nodes_equal(self.G.nodes(), list(nx.nodes(self.G))) + assert nodes_equal(self.DG.nodes(), list(nx.nodes(self.DG))) + + def test_edges(self): + assert edges_equal(self.G.edges(), list(nx.edges(self.G))) + assert sorted(self.DG.edges()) == sorted(nx.edges(self.DG)) + assert edges_equal( + self.G.edges(nbunch=[0, 1, 3]), list(nx.edges(self.G, nbunch=[0, 1, 3])) + ) + assert sorted(self.DG.edges(nbunch=[0, 1, 3])) == sorted( + nx.edges(self.DG, nbunch=[0, 1, 3]) + ) + + def test_degree(self): + assert edges_equal(self.G.degree(), list(nx.degree(self.G))) + assert sorted(self.DG.degree()) == sorted(nx.degree(self.DG)) + assert edges_equal( + self.G.degree(nbunch=[0, 1]), list(nx.degree(self.G, nbunch=[0, 1])) + ) + assert sorted(self.DG.degree(nbunch=[0, 1])) == sorted( + nx.degree(self.DG, nbunch=[0, 1]) + ) + assert edges_equal( + self.G.degree(weight="weight"), list(nx.degree(self.G, weight="weight")) + ) + assert sorted(self.DG.degree(weight="weight")) == sorted( + nx.degree(self.DG, weight="weight") + ) + + def test_neighbors(self): + assert list(self.G.neighbors(1)) == list(nx.neighbors(self.G, 1)) + assert list(self.DG.neighbors(1)) == list(nx.neighbors(self.DG, 1)) + + def test_number_of_nodes(self): + assert self.G.number_of_nodes() == nx.number_of_nodes(self.G) + assert self.DG.number_of_nodes() == nx.number_of_nodes(self.DG) + + def test_number_of_edges(self): + assert self.G.number_of_edges() == nx.number_of_edges(self.G) + assert self.DG.number_of_edges() == nx.number_of_edges(self.DG) + + def test_is_directed(self): + assert self.G.is_directed() == nx.is_directed(self.G) + assert self.DG.is_directed() == nx.is_directed(self.DG) + + def test_add_star(self): + G = self.G.copy() + nlist = [12, 13, 14, 15] + nx.add_star(G, nlist) + assert edges_equal(G.edges(nlist), [(12, 13), (12, 14), (12, 15)]) + + G = self.G.copy() + nx.add_star(G, nlist, weight=2.0) + assert edges_equal( + G.edges(nlist, data=True), + [ + (12, 13, {"weight": 2.0}), + (12, 14, {"weight": 2.0}), + (12, 15, {"weight": 2.0}), + ], + ) + + G = self.G.copy() + nlist = [12] + nx.add_star(G, nlist) + assert nodes_equal(G, list(self.G) + nlist) + + G = self.G.copy() + nlist = [] + nx.add_star(G, nlist) + assert nodes_equal(G.nodes, self.Gnodes) + assert edges_equal(G.edges, self.G.edges) + + def test_add_path(self): + G = self.G.copy() + nlist = [12, 13, 14, 15] + nx.add_path(G, nlist) + assert edges_equal(G.edges(nlist), [(12, 13), (13, 14), (14, 15)]) + G = self.G.copy() + nx.add_path(G, nlist, weight=2.0) + assert edges_equal( + G.edges(nlist, data=True), + [ + (12, 13, {"weight": 2.0}), + (13, 14, {"weight": 2.0}), + (14, 15, {"weight": 2.0}), + ], + ) + + G = self.G.copy() + nlist = ["node"] + nx.add_path(G, nlist) + assert edges_equal(G.edges(nlist), []) + assert nodes_equal(G, list(self.G) + ["node"]) + + G = self.G.copy() + nlist = iter(["node"]) + nx.add_path(G, nlist) + assert edges_equal(G.edges(["node"]), []) + assert nodes_equal(G, list(self.G) + ["node"]) + + G = self.G.copy() + nlist = [12] + nx.add_path(G, nlist) + assert edges_equal(G.edges(nlist), []) + assert nodes_equal(G, list(self.G) + [12]) + + G = self.G.copy() + nlist = iter([12]) + nx.add_path(G, nlist) + assert edges_equal(G.edges([12]), []) + assert nodes_equal(G, list(self.G) + [12]) + + G = self.G.copy() + nlist = [] + nx.add_path(G, nlist) + assert edges_equal(G.edges, self.G.edges) + assert nodes_equal(G, list(self.G)) + + G = self.G.copy() + nlist = iter([]) + nx.add_path(G, nlist) + assert edges_equal(G.edges, self.G.edges) + assert nodes_equal(G, list(self.G)) + + def test_add_cycle(self): + G = self.G.copy() + nlist = [12, 13, 14, 15] + oklists = [ + [(12, 13), (12, 15), (13, 14), (14, 15)], + [(12, 13), (13, 14), (14, 15), (15, 12)], + ] + nx.add_cycle(G, nlist) + assert sorted(G.edges(nlist)) in oklists + G = self.G.copy() + oklists = [ + [ + (12, 13, {"weight": 1.0}), + (12, 15, {"weight": 1.0}), + (13, 14, {"weight": 1.0}), + (14, 15, {"weight": 1.0}), + ], + [ + (12, 13, {"weight": 1.0}), + (13, 14, {"weight": 1.0}), + (14, 15, {"weight": 1.0}), + (15, 12, {"weight": 1.0}), + ], + ] + nx.add_cycle(G, nlist, weight=1.0) + assert sorted(G.edges(nlist, data=True)) in oklists + + G = self.G.copy() + nlist = [12] + nx.add_cycle(G, nlist) + assert nodes_equal(G, list(self.G) + nlist) + + G = self.G.copy() + nlist = [] + nx.add_cycle(G, nlist) + assert nodes_equal(G.nodes, self.Gnodes) + assert edges_equal(G.edges, self.G.edges) + + def test_subgraph(self): + assert ( + self.G.subgraph([0, 1, 2, 4]).adj == nx.subgraph(self.G, [0, 1, 2, 4]).adj + ) + assert ( + self.DG.subgraph([0, 1, 2, 4]).adj == nx.subgraph(self.DG, [0, 1, 2, 4]).adj + ) + assert ( + self.G.subgraph([0, 1, 2, 4]).adj + == nx.induced_subgraph(self.G, [0, 1, 2, 4]).adj + ) + assert ( + self.DG.subgraph([0, 1, 2, 4]).adj + == nx.induced_subgraph(self.DG, [0, 1, 2, 4]).adj + ) + # subgraph-subgraph chain is allowed in function interface + H = nx.induced_subgraph(self.G.subgraph([0, 1, 2, 4]), [0, 1, 4]) + assert H._graph is not self.G + assert H.adj == self.G.subgraph([0, 1, 4]).adj + + def test_edge_subgraph(self): + assert ( + self.G.edge_subgraph([(1, 2), (0, 3)]).adj + == nx.edge_subgraph(self.G, [(1, 2), (0, 3)]).adj + ) + assert ( + self.DG.edge_subgraph([(1, 2), (0, 3)]).adj + == nx.edge_subgraph(self.DG, [(1, 2), (0, 3)]).adj + ) + + def test_create_empty_copy(self): + G = nx.create_empty_copy(self.G, with_data=False) + assert nodes_equal(G, list(self.G)) + assert G.graph == {} + assert G._node == {}.fromkeys(self.G.nodes(), {}) + assert G._adj == {}.fromkeys(self.G.nodes(), {}) + G = nx.create_empty_copy(self.G) + assert nodes_equal(G, list(self.G)) + assert G.graph == self.G.graph + assert G._node == self.G._node + assert G._adj == {}.fromkeys(self.G.nodes(), {}) + + def test_degree_histogram(self): + assert nx.degree_histogram(self.G) == [1, 1, 1, 1, 1] + + def test_density(self): + assert nx.density(self.G) == 0.5 + assert nx.density(self.DG) == 0.3 + G = nx.Graph() + G.add_node(1) + assert nx.density(G) == 0.0 + + def test_density_selfloop(self): + G = nx.Graph() + G.add_edge(1, 1) + assert nx.density(G) == 0.0 + G.add_edge(1, 2) + assert nx.density(G) == 2.0 + + def test_freeze(self): + G = nx.freeze(self.G) + assert G.frozen + pytest.raises(nx.NetworkXError, G.add_node, 1) + pytest.raises(nx.NetworkXError, G.add_nodes_from, [1]) + pytest.raises(nx.NetworkXError, G.remove_node, 1) + pytest.raises(nx.NetworkXError, G.remove_nodes_from, [1]) + pytest.raises(nx.NetworkXError, G.add_edge, 1, 2) + pytest.raises(nx.NetworkXError, G.add_edges_from, [(1, 2)]) + pytest.raises(nx.NetworkXError, G.remove_edge, 1, 2) + pytest.raises(nx.NetworkXError, G.remove_edges_from, [(1, 2)]) + pytest.raises(nx.NetworkXError, G.clear_edges) + pytest.raises(nx.NetworkXError, G.clear) + + def test_is_frozen(self): + assert not nx.is_frozen(self.G) + G = nx.freeze(self.G) + assert G.frozen == nx.is_frozen(self.G) + assert G.frozen + + def test_node_attributes_are_still_mutable_on_frozen_graph(self): + G = nx.freeze(nx.path_graph(3)) + node = G.nodes[0] + node["node_attribute"] = True + assert node["node_attribute"] == True + + def test_edge_attributes_are_still_mutable_on_frozen_graph(self): + G = nx.freeze(nx.path_graph(3)) + edge = G.edges[(0, 1)] + edge["edge_attribute"] = True + assert edge["edge_attribute"] == True + + def test_neighbors_complete_graph(self): + graph = nx.complete_graph(100) + pop = random.sample(list(graph), 1) + nbors = list(nx.neighbors(graph, pop[0])) + # should be all the other vertices in the graph + assert len(nbors) == len(graph) - 1 + + graph = nx.path_graph(100) + node = random.sample(list(graph), 1)[0] + nbors = list(nx.neighbors(graph, node)) + # should be all the other vertices in the graph + if node != 0 and node != 99: + assert len(nbors) == 2 + else: + assert len(nbors) == 1 + + # create a star graph with 99 outer nodes + graph = nx.star_graph(99) + nbors = list(nx.neighbors(graph, 0)) + assert len(nbors) == 99 + + def test_non_neighbors(self): + graph = nx.complete_graph(100) + pop = random.sample(list(graph), 1) + nbors = nx.non_neighbors(graph, pop[0]) + # should be all the other vertices in the graph + assert len(nbors) == 0 + + graph = nx.path_graph(100) + node = random.sample(list(graph), 1)[0] + nbors = nx.non_neighbors(graph, node) + # should be all the other vertices in the graph + if node != 0 and node != 99: + assert len(nbors) == 97 + else: + assert len(nbors) == 98 + + # create a star graph with 99 outer nodes + graph = nx.star_graph(99) + nbors = nx.non_neighbors(graph, 0) + assert len(nbors) == 0 + + # disconnected graph + graph = nx.Graph() + graph.add_nodes_from(range(10)) + nbors = nx.non_neighbors(graph, 0) + assert len(nbors) == 9 + + def test_non_edges(self): + # All possible edges exist + graph = nx.complete_graph(5) + nedges = list(nx.non_edges(graph)) + assert len(nedges) == 0 + + graph = nx.path_graph(4) + expected = [(0, 2), (0, 3), (1, 3)] + nedges = list(nx.non_edges(graph)) + for u, v in expected: + assert (u, v) in nedges or (v, u) in nedges + + graph = nx.star_graph(4) + expected = [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)] + nedges = list(nx.non_edges(graph)) + for u, v in expected: + assert (u, v) in nedges or (v, u) in nedges + + # Directed graphs + graph = nx.DiGraph() + graph.add_edges_from([(0, 2), (2, 0), (2, 1)]) + expected = [(0, 1), (1, 0), (1, 2)] + nedges = list(nx.non_edges(graph)) + for e in expected: + assert e in nedges + + def test_is_weighted(self): + G = nx.Graph() + assert not nx.is_weighted(G) + + G = nx.path_graph(4) + assert not nx.is_weighted(G) + assert not nx.is_weighted(G, (2, 3)) + + G.add_node(4) + G.add_edge(3, 4, weight=4) + assert not nx.is_weighted(G) + assert nx.is_weighted(G, (3, 4)) + + G = nx.DiGraph() + G.add_weighted_edges_from( + [ + ("0", "3", 3), + ("0", "1", -5), + ("1", "0", -5), + ("0", "2", 2), + ("1", "2", 4), + ("2", "3", 1), + ] + ) + assert nx.is_weighted(G) + assert nx.is_weighted(G, ("1", "0")) + + G = G.to_undirected() + assert nx.is_weighted(G) + assert nx.is_weighted(G, ("1", "0")) + + pytest.raises(nx.NetworkXError, nx.is_weighted, G, (1, 2)) + + def test_is_negatively_weighted(self): + G = nx.Graph() + assert not nx.is_negatively_weighted(G) + + G.add_node(1) + G.add_nodes_from([2, 3, 4, 5]) + assert not nx.is_negatively_weighted(G) + + G.add_edge(1, 2, weight=4) + assert not nx.is_negatively_weighted(G, (1, 2)) + + G.add_edges_from([(1, 3), (2, 4), (2, 6)]) + G[1][3]["color"] = "blue" + assert not nx.is_negatively_weighted(G) + assert not nx.is_negatively_weighted(G, (1, 3)) + + G[2][4]["weight"] = -2 + assert nx.is_negatively_weighted(G, (2, 4)) + assert nx.is_negatively_weighted(G) + + G = nx.DiGraph() + G.add_weighted_edges_from( + [ + ("0", "3", 3), + ("0", "1", -5), + ("1", "0", -2), + ("0", "2", 2), + ("1", "2", -3), + ("2", "3", 1), + ] + ) + assert nx.is_negatively_weighted(G) + assert not nx.is_negatively_weighted(G, ("0", "3")) + assert nx.is_negatively_weighted(G, ("1", "0")) + + pytest.raises(nx.NetworkXError, nx.is_negatively_weighted, G, (1, 4)) + + +class TestCommonNeighbors: + @classmethod + def setup_class(cls): + cls.func = staticmethod(nx.common_neighbors) + + def test_func(G, u, v, expected): + result = sorted(cls.func(G, u, v)) + assert result == expected + + cls.test = staticmethod(test_func) + + def test_K5(self): + G = nx.complete_graph(5) + self.test(G, 0, 1, [2, 3, 4]) + + def test_P3(self): + G = nx.path_graph(3) + self.test(G, 0, 2, [1]) + + def test_S4(self): + G = nx.star_graph(4) + self.test(G, 1, 2, [0]) + + def test_digraph(self): + with pytest.raises(nx.NetworkXNotImplemented): + G = nx.DiGraph() + G.add_edges_from([(0, 1), (1, 2)]) + self.func(G, 0, 2) + + def test_nonexistent_nodes(self): + G = nx.complete_graph(5) + pytest.raises(nx.NetworkXError, nx.common_neighbors, G, 5, 4) + pytest.raises(nx.NetworkXError, nx.common_neighbors, G, 4, 5) + pytest.raises(nx.NetworkXError, nx.common_neighbors, G, 5, 6) + + def test_custom1(self): + """Case of no common neighbors.""" + G = nx.Graph() + G.add_nodes_from([0, 1]) + self.test(G, 0, 1, []) + + def test_custom2(self): + """Case of equal nodes.""" + G = nx.complete_graph(4) + self.test(G, 0, 0, [1, 2, 3]) + + +@pytest.mark.parametrize( + "graph_type", (nx.Graph, nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph) +) +def test_set_node_attributes(graph_type): + # Test single value + G = nx.path_graph(3, create_using=graph_type) + vals = 100 + attr = "hello" + nx.set_node_attributes(G, vals, attr) + assert G.nodes[0][attr] == vals + assert G.nodes[1][attr] == vals + assert G.nodes[2][attr] == vals + + # Test dictionary + G = nx.path_graph(3, create_using=graph_type) + vals = dict(zip(sorted(G.nodes()), range(len(G)))) + attr = "hi" + nx.set_node_attributes(G, vals, attr) + assert G.nodes[0][attr] == 0 + assert G.nodes[1][attr] == 1 + assert G.nodes[2][attr] == 2 + + # Test dictionary of dictionaries + G = nx.path_graph(3, create_using=graph_type) + d = {"hi": 0, "hello": 200} + vals = dict.fromkeys(G.nodes(), d) + vals.pop(0) + nx.set_node_attributes(G, vals) + assert G.nodes[0] == {} + assert G.nodes[1]["hi"] == 0 + assert G.nodes[2]["hello"] == 200 + + +@pytest.mark.parametrize( + ("values", "name"), + ( + ({0: "red", 1: "blue"}, "color"), # values dictionary + ({0: {"color": "red"}, 1: {"color": "blue"}}, None), # dict-of-dict + ), +) +def test_set_node_attributes_ignores_extra_nodes(values, name): + """ + When `values` is a dict or dict-of-dict keyed by nodes, ensure that keys + that correspond to nodes not in G are ignored. + """ + G = nx.Graph() + G.add_node(0) + nx.set_node_attributes(G, values, name) + assert G.nodes[0]["color"] == "red" + assert 1 not in G.nodes + + +@pytest.mark.parametrize("graph_type", (nx.Graph, nx.DiGraph)) +def test_set_edge_attributes(graph_type): + # Test single value + G = nx.path_graph(3, create_using=graph_type) + attr = "hello" + vals = 3 + nx.set_edge_attributes(G, vals, attr) + assert G[0][1][attr] == vals + assert G[1][2][attr] == vals + + # Test multiple values + G = nx.path_graph(3, create_using=graph_type) + attr = "hi" + edges = [(0, 1), (1, 2)] + vals = dict(zip(edges, range(len(edges)))) + nx.set_edge_attributes(G, vals, attr) + assert G[0][1][attr] == 0 + assert G[1][2][attr] == 1 + + # Test dictionary of dictionaries + G = nx.path_graph(3, create_using=graph_type) + d = {"hi": 0, "hello": 200} + edges = [(0, 1)] + vals = dict.fromkeys(edges, d) + nx.set_edge_attributes(G, vals) + assert G[0][1]["hi"] == 0 + assert G[0][1]["hello"] == 200 + assert G[1][2] == {} + + +@pytest.mark.parametrize( + ("values", "name"), + ( + ({(0, 1): 1.0, (0, 2): 2.0}, "weight"), # values dict + ({(0, 1): {"weight": 1.0}, (0, 2): {"weight": 2.0}}, None), # values dod + ), +) +def test_set_edge_attributes_ignores_extra_edges(values, name): + """If `values` is a dict or dict-of-dicts containing edges that are not in + G, data associate with these edges should be ignored. + """ + G = nx.Graph([(0, 1)]) + nx.set_edge_attributes(G, values, name) + assert G[0][1]["weight"] == 1.0 + assert (0, 2) not in G.edges + + +@pytest.mark.parametrize("graph_type", (nx.MultiGraph, nx.MultiDiGraph)) +def test_set_edge_attributes_multi(graph_type): + # Test single value + G = nx.path_graph(3, create_using=graph_type) + attr = "hello" + vals = 3 + nx.set_edge_attributes(G, vals, attr) + assert G[0][1][0][attr] == vals + assert G[1][2][0][attr] == vals + + # Test multiple values + G = nx.path_graph(3, create_using=graph_type) + attr = "hi" + edges = [(0, 1, 0), (1, 2, 0)] + vals = dict(zip(edges, range(len(edges)))) + nx.set_edge_attributes(G, vals, attr) + assert G[0][1][0][attr] == 0 + assert G[1][2][0][attr] == 1 + + # Test dictionary of dictionaries + G = nx.path_graph(3, create_using=graph_type) + d = {"hi": 0, "hello": 200} + edges = [(0, 1, 0)] + vals = dict.fromkeys(edges, d) + nx.set_edge_attributes(G, vals) + assert G[0][1][0]["hi"] == 0 + assert G[0][1][0]["hello"] == 200 + assert G[1][2][0] == {} + + +@pytest.mark.parametrize( + ("values", "name"), + ( + ({(0, 1, 0): 1.0, (0, 2, 0): 2.0}, "weight"), # values dict + ({(0, 1, 0): {"weight": 1.0}, (0, 2, 0): {"weight": 2.0}}, None), # values dod + ), +) +def test_set_edge_attributes_multi_ignores_extra_edges(values, name): + """If `values` is a dict or dict-of-dicts containing edges that are not in + G, data associate with these edges should be ignored. + """ + G = nx.MultiGraph([(0, 1, 0), (0, 1, 1)]) + nx.set_edge_attributes(G, values, name) + assert G[0][1][0]["weight"] == 1.0 + assert G[0][1][1] == {} + assert (0, 2) not in G.edges() + + +def test_get_node_attributes(): + graphs = [nx.Graph(), nx.DiGraph(), nx.MultiGraph(), nx.MultiDiGraph()] + for G in graphs: + G = nx.path_graph(3, create_using=G) + attr = "hello" + vals = 100 + nx.set_node_attributes(G, vals, attr) + attrs = nx.get_node_attributes(G, attr) + assert attrs[0] == vals + assert attrs[1] == vals + assert attrs[2] == vals + default_val = 1 + G.add_node(4) + attrs = nx.get_node_attributes(G, attr, default=default_val) + assert attrs[4] == default_val + + +def test_get_edge_attributes(): + graphs = [nx.Graph(), nx.DiGraph(), nx.MultiGraph(), nx.MultiDiGraph()] + for G in graphs: + G = nx.path_graph(3, create_using=G) + attr = "hello" + vals = 100 + nx.set_edge_attributes(G, vals, attr) + attrs = nx.get_edge_attributes(G, attr) + assert len(attrs) == 2 + + for edge in G.edges: + assert attrs[edge] == vals + + default_val = vals + G.add_edge(4, 5) + deafult_attrs = nx.get_edge_attributes(G, attr, default=default_val) + assert len(deafult_attrs) == 3 + + for edge in G.edges: + assert deafult_attrs[edge] == vals + + +@pytest.mark.parametrize( + "graph_type", (nx.Graph, nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph) +) +def test_remove_node_attributes(graph_type): + # Test removing single attribute + G = nx.path_graph(3, create_using=graph_type) + vals = 100 + attr = "hello" + nx.set_node_attributes(G, vals, attr) + nx.remove_node_attributes(G, attr) + assert attr not in G.nodes[0] + assert attr not in G.nodes[1] + assert attr not in G.nodes[2] + + # Test removing single attribute when multiple present + G = nx.path_graph(3, create_using=graph_type) + other_vals = 200 + other_attr = "other" + nx.set_node_attributes(G, vals, attr) + nx.set_node_attributes(G, other_vals, other_attr) + nx.remove_node_attributes(G, attr) + assert attr not in G.nodes[0] + assert G.nodes[0][other_attr] == other_vals + assert attr not in G.nodes[1] + assert G.nodes[1][other_attr] == other_vals + assert attr not in G.nodes[2] + assert G.nodes[2][other_attr] == other_vals + + # Test removing multiple attributes + G = nx.path_graph(3, create_using=graph_type) + nx.set_node_attributes(G, vals, attr) + nx.set_node_attributes(G, other_vals, other_attr) + nx.remove_node_attributes(G, attr, other_attr) + assert attr not in G.nodes[0] and other_attr not in G.nodes[0] + assert attr not in G.nodes[1] and other_attr not in G.nodes[1] + assert attr not in G.nodes[2] and other_attr not in G.nodes[2] + + # Test removing multiple (but not all) attributes + G = nx.path_graph(3, create_using=graph_type) + third_vals = 300 + third_attr = "three" + nx.set_node_attributes( + G, + { + n: {attr: vals, other_attr: other_vals, third_attr: third_vals} + for n in G.nodes() + }, + ) + nx.remove_node_attributes(G, other_attr, third_attr) + assert other_attr not in G.nodes[0] and third_attr not in G.nodes[0] + assert other_attr not in G.nodes[1] and third_attr not in G.nodes[1] + assert other_attr not in G.nodes[2] and third_attr not in G.nodes[2] + assert G.nodes[0][attr] == vals + assert G.nodes[1][attr] == vals + assert G.nodes[2][attr] == vals + + # Test incomplete node attributes + G = nx.path_graph(3, create_using=graph_type) + nx.set_node_attributes( + G, + { + 1: {attr: vals, other_attr: other_vals}, + 2: {attr: vals, other_attr: other_vals}, + }, + ) + nx.remove_node_attributes(G, attr) + assert attr not in G.nodes[0] + assert attr not in G.nodes[1] + assert attr not in G.nodes[2] + assert G.nodes[1][other_attr] == other_vals + assert G.nodes[2][other_attr] == other_vals + + # Test removing on a subset of nodes + G = nx.path_graph(3, create_using=graph_type) + nx.set_node_attributes( + G, + { + n: {attr: vals, other_attr: other_vals, third_attr: third_vals} + for n in G.nodes() + }, + ) + nx.remove_node_attributes(G, attr, other_attr, nbunch=[0, 1]) + assert attr not in G.nodes[0] and other_attr not in G.nodes[0] + assert attr not in G.nodes[1] and other_attr not in G.nodes[1] + assert attr in G.nodes[2] and other_attr in G.nodes[2] + assert third_attr in G.nodes[0] and G.nodes[0][third_attr] == third_vals + assert third_attr in G.nodes[1] and G.nodes[1][third_attr] == third_vals + + +@pytest.mark.parametrize("graph_type", (nx.Graph, nx.DiGraph)) +def test_remove_edge_attributes(graph_type): + # Test removing single attribute + G = nx.path_graph(3, create_using=graph_type) + attr = "hello" + vals = 100 + nx.set_edge_attributes(G, vals, attr) + nx.remove_edge_attributes(G, attr) + assert len(nx.get_edge_attributes(G, attr)) == 0 + + # Test removing only some attributes + G = nx.path_graph(3, create_using=graph_type) + other_attr = "other" + other_vals = 200 + nx.set_edge_attributes(G, vals, attr) + nx.set_edge_attributes(G, other_vals, other_attr) + nx.remove_edge_attributes(G, attr) + + assert attr not in G[0][1] + assert attr not in G[1][2] + assert G[0][1][other_attr] == 200 + assert G[1][2][other_attr] == 200 + + # Test removing multiple attributes + G = nx.path_graph(3, create_using=graph_type) + nx.set_edge_attributes(G, vals, attr) + nx.set_edge_attributes(G, other_vals, other_attr) + nx.remove_edge_attributes(G, attr, other_attr) + assert attr not in G[0][1] and other_attr not in G[0][1] + assert attr not in G[1][2] and other_attr not in G[1][2] + + # Test removing multiple (not all) attributes + G = nx.path_graph(3, create_using=graph_type) + third_attr = "third" + third_vals = 300 + nx.set_edge_attributes( + G, + { + (u, v): {attr: vals, other_attr: other_vals, third_attr: third_vals} + for u, v in G.edges() + }, + ) + nx.remove_edge_attributes(G, other_attr, third_attr) + assert other_attr not in G[0][1] and third_attr not in G[0][1] + assert other_attr not in G[1][2] and third_attr not in G[1][2] + assert G[0][1][attr] == vals + assert G[1][2][attr] == vals + + # Test removing incomplete edge attributes + G = nx.path_graph(3, create_using=graph_type) + nx.set_edge_attributes(G, {(0, 1): {attr: vals, other_attr: other_vals}}) + nx.remove_edge_attributes(G, other_attr) + assert other_attr not in G[0][1] and G[0][1][attr] == vals + assert other_attr not in G[1][2] + + # Test removing subset of edge attributes + G = nx.path_graph(3, create_using=graph_type) + nx.set_edge_attributes( + G, + { + (u, v): {attr: vals, other_attr: other_vals, third_attr: third_vals} + for u, v in G.edges() + }, + ) + nx.remove_edge_attributes(G, other_attr, third_attr, ebunch=[(0, 1)]) + assert other_attr not in G[0][1] and third_attr not in G[0][1] + assert other_attr in G[1][2] and third_attr in G[1][2] + + +@pytest.mark.parametrize("graph_type", (nx.MultiGraph, nx.MultiDiGraph)) +def test_remove_multi_edge_attributes(graph_type): + # Test removing single attribute + G = nx.path_graph(3, create_using=graph_type) + G.add_edge(1, 2) + attr = "hello" + vals = 100 + nx.set_edge_attributes(G, vals, attr) + nx.remove_edge_attributes(G, attr) + assert attr not in G[0][1][0] + assert attr not in G[1][2][0] + assert attr not in G[1][2][1] + + # Test removing only some attributes + G = nx.path_graph(3, create_using=graph_type) + G.add_edge(1, 2) + other_attr = "other" + other_vals = 200 + nx.set_edge_attributes(G, vals, attr) + nx.set_edge_attributes(G, other_vals, other_attr) + nx.remove_edge_attributes(G, attr) + assert attr not in G[0][1][0] + assert attr not in G[1][2][0] + assert attr not in G[1][2][1] + assert G[0][1][0][other_attr] == other_vals + assert G[1][2][0][other_attr] == other_vals + assert G[1][2][1][other_attr] == other_vals + + # Test removing multiple attributes + G = nx.path_graph(3, create_using=graph_type) + G.add_edge(1, 2) + nx.set_edge_attributes(G, vals, attr) + nx.set_edge_attributes(G, other_vals, other_attr) + nx.remove_edge_attributes(G, attr, other_attr) + assert attr not in G[0][1][0] and other_attr not in G[0][1][0] + assert attr not in G[1][2][0] and other_attr not in G[1][2][0] + assert attr not in G[1][2][1] and other_attr not in G[1][2][1] + + # Test removing multiple (not all) attributes + G = nx.path_graph(3, create_using=graph_type) + G.add_edge(1, 2) + third_attr = "third" + third_vals = 300 + nx.set_edge_attributes( + G, + { + (u, v, k): {attr: vals, other_attr: other_vals, third_attr: third_vals} + for u, v, k in G.edges(keys=True) + }, + ) + nx.remove_edge_attributes(G, other_attr, third_attr) + assert other_attr not in G[0][1][0] and third_attr not in G[0][1][0] + assert other_attr not in G[1][2][0] and other_attr not in G[1][2][0] + assert other_attr not in G[1][2][1] and other_attr not in G[1][2][1] + assert G[0][1][0][attr] == vals + assert G[1][2][0][attr] == vals + assert G[1][2][1][attr] == vals + + # Test removing incomplete edge attributes + G = nx.path_graph(3, create_using=graph_type) + G.add_edge(1, 2) + nx.set_edge_attributes( + G, + { + (0, 1, 0): {attr: vals, other_attr: other_vals}, + (1, 2, 1): {attr: vals, other_attr: other_vals}, + }, + ) + nx.remove_edge_attributes(G, other_attr) + assert other_attr not in G[0][1][0] and G[0][1][0][attr] == vals + assert other_attr not in G[1][2][0] + assert other_attr not in G[1][2][1] + + # Test removing subset of edge attributes + G = nx.path_graph(3, create_using=graph_type) + G.add_edge(1, 2) + nx.set_edge_attributes( + G, + { + (0, 1, 0): {attr: vals, other_attr: other_vals}, + (1, 2, 0): {attr: vals, other_attr: other_vals}, + (1, 2, 1): {attr: vals, other_attr: other_vals}, + }, + ) + nx.remove_edge_attributes(G, attr, ebunch=[(0, 1, 0), (1, 2, 0)]) + assert attr not in G[0][1][0] and other_attr in G[0][1][0] + assert attr not in G[1][2][0] and other_attr in G[1][2][0] + assert attr in G[1][2][1] and other_attr in G[1][2][1] + + +def test_is_empty(): + graphs = [nx.Graph(), nx.DiGraph(), nx.MultiGraph(), nx.MultiDiGraph()] + for G in graphs: + assert nx.is_empty(G) + G.add_nodes_from(range(5)) + assert nx.is_empty(G) + G.add_edges_from([(1, 2), (3, 4)]) + assert not nx.is_empty(G) + + +@pytest.mark.parametrize( + "graph_type", [nx.Graph, nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph] +) +def test_selfloops(graph_type): + G = nx.complete_graph(3, create_using=graph_type) + G.add_edge(0, 0) + assert nodes_equal(nx.nodes_with_selfloops(G), [0]) + assert edges_equal(nx.selfloop_edges(G), [(0, 0)]) + assert edges_equal(nx.selfloop_edges(G, data=True), [(0, 0, {})]) + assert nx.number_of_selfloops(G) == 1 + + +@pytest.mark.parametrize( + "graph_type", [nx.Graph, nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph] +) +def test_selfloop_edges_attr(graph_type): + G = nx.complete_graph(3, create_using=graph_type) + G.add_edge(0, 0) + G.add_edge(1, 1, weight=2) + assert edges_equal( + nx.selfloop_edges(G, data=True), [(0, 0, {}), (1, 1, {"weight": 2})] + ) + assert edges_equal(nx.selfloop_edges(G, data="weight"), [(0, 0, None), (1, 1, 2)]) + + +def test_selfloop_edges_multi_with_data_and_keys(): + G = nx.complete_graph(3, create_using=nx.MultiGraph) + G.add_edge(0, 0, weight=10) + G.add_edge(0, 0, weight=100) + assert edges_equal( + nx.selfloop_edges(G, data="weight", keys=True), [(0, 0, 0, 10), (0, 0, 1, 100)] + ) + + +@pytest.mark.parametrize("graph_type", [nx.Graph, nx.DiGraph]) +def test_selfloops_removal(graph_type): + G = nx.complete_graph(3, create_using=graph_type) + G.add_edge(0, 0) + G.remove_edges_from(nx.selfloop_edges(G, keys=True)) + G.add_edge(0, 0) + G.remove_edges_from(nx.selfloop_edges(G, data=True)) + G.add_edge(0, 0) + G.remove_edges_from(nx.selfloop_edges(G, keys=True, data=True)) + + +@pytest.mark.parametrize("graph_type", [nx.MultiGraph, nx.MultiDiGraph]) +def test_selfloops_removal_multi(graph_type): + """test removing selfloops behavior vis-a-vis altering a dict while iterating. + cf. gh-4068""" + G = nx.complete_graph(3, create_using=graph_type) + # Defaults - see gh-4080 + G.add_edge(0, 0) + G.add_edge(0, 0) + G.remove_edges_from(nx.selfloop_edges(G)) + assert (0, 0) not in G.edges() + # With keys + G.add_edge(0, 0) + G.add_edge(0, 0) + with pytest.raises(RuntimeError): + G.remove_edges_from(nx.selfloop_edges(G, keys=True)) + # With data + G.add_edge(0, 0) + G.add_edge(0, 0) + with pytest.raises(TypeError): + G.remove_edges_from(nx.selfloop_edges(G, data=True)) + # With keys and data + G.add_edge(0, 0) + G.add_edge(0, 0) + with pytest.raises(RuntimeError): + G.remove_edges_from(nx.selfloop_edges(G, data=True, keys=True)) + + +def test_pathweight(): + valid_path = [1, 2, 3] + invalid_path = [1, 3, 2] + graphs = [nx.Graph(), nx.DiGraph(), nx.MultiGraph(), nx.MultiDiGraph()] + edges = [ + (1, 2, {"cost": 5, "dist": 6}), + (2, 3, {"cost": 3, "dist": 4}), + (1, 2, {"cost": 1, "dist": 2}), + ] + for graph in graphs: + graph.add_edges_from(edges) + assert nx.path_weight(graph, valid_path, "cost") == 4 + assert nx.path_weight(graph, valid_path, "dist") == 6 + pytest.raises(nx.NetworkXNoPath, nx.path_weight, graph, invalid_path, "cost") + + +@pytest.mark.parametrize( + "G", (nx.Graph(), nx.DiGraph(), nx.MultiGraph(), nx.MultiDiGraph()) +) +def test_ispath(G): + G.add_edges_from([(1, 2), (2, 3), (1, 2), (3, 4)]) + valid_path = [1, 2, 3, 4] + invalid_path = [1, 2, 4, 3] # wrong node order + another_invalid_path = [1, 2, 3, 4, 5] # contains node not in G + assert nx.is_path(G, valid_path) + assert not nx.is_path(G, invalid_path) + assert not nx.is_path(G, another_invalid_path) + + +@pytest.mark.parametrize("G", (nx.Graph(), nx.DiGraph())) +def test_restricted_view(G): + G.add_edges_from([(0, 1), (0, 2), (0, 3), (1, 0), (1, 1), (1, 2)]) + G.add_node(4) + H = nx.restricted_view(G, [0, 2, 5], [(1, 2), (3, 4)]) + assert set(H.nodes()) == {1, 3, 4} + assert set(H.edges()) == {(1, 1)} + + +@pytest.mark.parametrize("G", (nx.MultiGraph(), nx.MultiDiGraph())) +def test_restricted_view_multi(G): + G.add_edges_from( + [(0, 1, 0), (0, 2, 0), (0, 3, 0), (0, 1, 1), (1, 0, 0), (1, 1, 0), (1, 2, 0)] + ) + G.add_node(4) + H = nx.restricted_view(G, [0, 2, 5], [(1, 2, 0), (3, 4, 0)]) + assert set(H.nodes()) == {1, 3, 4} + assert set(H.edges()) == {(1, 1)} diff --git a/lib/python3.10/site-packages/networkx/classes/tests/test_graph.py b/lib/python3.10/site-packages/networkx/classes/tests/test_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..b0048a31f04eb583f4bcfc32385c2335b94467ad --- /dev/null +++ b/lib/python3.10/site-packages/networkx/classes/tests/test_graph.py @@ -0,0 +1,920 @@ +import gc +import pickle +import platform +import weakref + +import pytest + +import networkx as nx +from networkx.utils import edges_equal, graphs_equal, nodes_equal + + +class BaseGraphTester: + """Tests for data-structure independent graph class features.""" + + def test_contains(self): + G = self.K3 + assert 1 in G + assert 4 not in G + assert "b" not in G + assert [] not in G # no exception for nonhashable + assert {1: 1} not in G # no exception for nonhashable + + def test_order(self): + G = self.K3 + assert len(G) == 3 + assert G.order() == 3 + assert G.number_of_nodes() == 3 + + def test_nodes(self): + G = self.K3 + assert isinstance(G._node, G.node_dict_factory) + assert isinstance(G._adj, G.adjlist_outer_dict_factory) + assert all( + isinstance(adj, G.adjlist_inner_dict_factory) for adj in G._adj.values() + ) + assert sorted(G.nodes()) == self.k3nodes + assert sorted(G.nodes(data=True)) == [(0, {}), (1, {}), (2, {})] + + def test_none_node(self): + G = self.Graph() + with pytest.raises(ValueError): + G.add_node(None) + with pytest.raises(ValueError): + G.add_nodes_from([None]) + with pytest.raises(ValueError): + G.add_edge(0, None) + with pytest.raises(ValueError): + G.add_edges_from([(0, None)]) + + def test_has_node(self): + G = self.K3 + assert G.has_node(1) + assert not G.has_node(4) + assert not G.has_node([]) # no exception for nonhashable + assert not G.has_node({1: 1}) # no exception for nonhashable + + def test_has_edge(self): + G = self.K3 + assert G.has_edge(0, 1) + assert not G.has_edge(0, -1) + + def test_neighbors(self): + G = self.K3 + assert sorted(G.neighbors(0)) == [1, 2] + with pytest.raises(nx.NetworkXError): + G.neighbors(-1) + + @pytest.mark.skipif( + platform.python_implementation() == "PyPy", reason="PyPy gc is different" + ) + def test_memory_leak(self): + G = self.Graph() + + def count_objects_of_type(_type): + # Iterating over all objects tracked by gc can include weak references + # whose weakly-referenced objects may no longer exist. Calling `isinstance` + # on such a weak reference will raise ReferenceError. There are at least + # three workarounds for this: one is to compare type names instead of using + # `isinstance` such as `type(obj).__name__ == typename`, another is to use + # `type(obj) == _type`, and the last is to ignore ProxyTypes as we do below. + # NOTE: even if this safeguard is deemed unnecessary to pass NetworkX tests, + # we should still keep it for maximum safety for other NetworkX backends. + return sum( + 1 + for obj in gc.get_objects() + if not isinstance(obj, weakref.ProxyTypes) and isinstance(obj, _type) + ) + + gc.collect() + before = count_objects_of_type(self.Graph) + G.copy() + gc.collect() + after = count_objects_of_type(self.Graph) + assert before == after + + # test a subgraph of the base class + class MyGraph(self.Graph): + pass + + gc.collect() + G = MyGraph() + before = count_objects_of_type(MyGraph) + G.copy() + gc.collect() + after = count_objects_of_type(MyGraph) + assert before == after + + def test_edges(self): + G = self.K3 + assert isinstance(G._adj, G.adjlist_outer_dict_factory) + assert edges_equal(G.edges(), [(0, 1), (0, 2), (1, 2)]) + assert edges_equal(G.edges(0), [(0, 1), (0, 2)]) + assert edges_equal(G.edges([0, 1]), [(0, 1), (0, 2), (1, 2)]) + with pytest.raises(nx.NetworkXError): + G.edges(-1) + + def test_degree(self): + G = self.K3 + assert sorted(G.degree()) == [(0, 2), (1, 2), (2, 2)] + assert dict(G.degree()) == {0: 2, 1: 2, 2: 2} + assert G.degree(0) == 2 + with pytest.raises(nx.NetworkXError): + G.degree(-1) # node not in graph + + def test_size(self): + G = self.K3 + assert G.size() == 3 + assert G.number_of_edges() == 3 + + def test_nbunch_iter(self): + G = self.K3 + assert nodes_equal(G.nbunch_iter(), self.k3nodes) # all nodes + assert nodes_equal(G.nbunch_iter(0), [0]) # single node + assert nodes_equal(G.nbunch_iter([0, 1]), [0, 1]) # sequence + # sequence with none in graph + assert nodes_equal(G.nbunch_iter([-1]), []) + # string sequence with none in graph + assert nodes_equal(G.nbunch_iter("foo"), []) + # node not in graph doesn't get caught upon creation of iterator + bunch = G.nbunch_iter(-1) + # but gets caught when iterator used + with pytest.raises(nx.NetworkXError, match="is not a node or a sequence"): + list(bunch) + # unhashable doesn't get caught upon creation of iterator + bunch = G.nbunch_iter([0, 1, 2, {}]) + # but gets caught when iterator hits the unhashable + with pytest.raises( + nx.NetworkXError, match="in sequence nbunch is not a valid node" + ): + list(bunch) + + def test_nbunch_iter_node_format_raise(self): + # Tests that a node that would have failed string formatting + # doesn't cause an error when attempting to raise a + # :exc:`nx.NetworkXError`. + + # For more information, see pull request #1813. + G = self.Graph() + nbunch = [("x", set())] + with pytest.raises(nx.NetworkXError): + list(G.nbunch_iter(nbunch)) + + def test_selfloop_degree(self): + G = self.Graph() + G.add_edge(1, 1) + assert sorted(G.degree()) == [(1, 2)] + assert dict(G.degree()) == {1: 2} + assert G.degree(1) == 2 + assert sorted(G.degree([1])) == [(1, 2)] + assert G.degree(1, weight="weight") == 2 + + def test_selfloops(self): + G = self.K3.copy() + G.add_edge(0, 0) + assert nodes_equal(nx.nodes_with_selfloops(G), [0]) + assert edges_equal(nx.selfloop_edges(G), [(0, 0)]) + assert nx.number_of_selfloops(G) == 1 + G.remove_edge(0, 0) + G.add_edge(0, 0) + G.remove_edges_from([(0, 0)]) + G.add_edge(1, 1) + G.remove_node(1) + G.add_edge(0, 0) + G.add_edge(1, 1) + G.remove_nodes_from([0, 1]) + + def test_cache_reset(self): + G = self.K3.copy() + old_adj = G.adj + assert id(G.adj) == id(old_adj) + G._adj = {} + assert id(G.adj) != id(old_adj) + + old_nodes = G.nodes + assert id(G.nodes) == id(old_nodes) + G._node = {} + assert id(G.nodes) != id(old_nodes) + + def test_attributes_cached(self): + G = self.K3.copy() + assert id(G.nodes) == id(G.nodes) + assert id(G.edges) == id(G.edges) + assert id(G.degree) == id(G.degree) + assert id(G.adj) == id(G.adj) + + +class BaseAttrGraphTester(BaseGraphTester): + """Tests of graph class attribute features.""" + + def test_weighted_degree(self): + G = self.Graph() + G.add_edge(1, 2, weight=2, other=3) + G.add_edge(2, 3, weight=3, other=4) + assert sorted(d for n, d in G.degree(weight="weight")) == [2, 3, 5] + assert dict(G.degree(weight="weight")) == {1: 2, 2: 5, 3: 3} + assert G.degree(1, weight="weight") == 2 + assert nodes_equal((G.degree([1], weight="weight")), [(1, 2)]) + + assert nodes_equal((d for n, d in G.degree(weight="other")), [3, 7, 4]) + assert dict(G.degree(weight="other")) == {1: 3, 2: 7, 3: 4} + assert G.degree(1, weight="other") == 3 + assert edges_equal((G.degree([1], weight="other")), [(1, 3)]) + + def add_attributes(self, G): + G.graph["foo"] = [] + G.nodes[0]["foo"] = [] + G.remove_edge(1, 2) + ll = [] + G.add_edge(1, 2, foo=ll) + G.add_edge(2, 1, foo=ll) + + def test_name(self): + G = self.Graph(name="") + assert G.name == "" + G = self.Graph(name="test") + assert G.name == "test" + + def test_str_unnamed(self): + G = self.Graph() + G.add_edges_from([(1, 2), (2, 3)]) + assert str(G) == f"{type(G).__name__} with 3 nodes and 2 edges" + + def test_str_named(self): + G = self.Graph(name="foo") + G.add_edges_from([(1, 2), (2, 3)]) + assert str(G) == f"{type(G).__name__} named 'foo' with 3 nodes and 2 edges" + + def test_graph_chain(self): + G = self.Graph([(0, 1), (1, 2)]) + DG = G.to_directed(as_view=True) + SDG = DG.subgraph([0, 1]) + RSDG = SDG.reverse(copy=False) + assert G is DG._graph + assert DG is SDG._graph + assert SDG is RSDG._graph + + def test_copy(self): + G = self.Graph() + G.add_node(0) + G.add_edge(1, 2) + self.add_attributes(G) + # copy edge datadict but any container attr are same + H = G.copy() + self.graphs_equal(H, G) + self.different_attrdict(H, G) + self.shallow_copy_attrdict(H, G) + + def test_class_copy(self): + G = self.Graph() + G.add_node(0) + G.add_edge(1, 2) + self.add_attributes(G) + # copy edge datadict but any container attr are same + H = G.__class__(G) + self.graphs_equal(H, G) + self.different_attrdict(H, G) + self.shallow_copy_attrdict(H, G) + + def test_fresh_copy(self): + G = self.Graph() + G.add_node(0) + G.add_edge(1, 2) + self.add_attributes(G) + # copy graph structure but use fresh datadict + H = G.__class__() + H.add_nodes_from(G) + H.add_edges_from(G.edges()) + assert len(G.nodes[0]) == 1 + ddict = G.adj[1][2][0] if G.is_multigraph() else G.adj[1][2] + assert len(ddict) == 1 + assert len(H.nodes[0]) == 0 + ddict = H.adj[1][2][0] if H.is_multigraph() else H.adj[1][2] + assert len(ddict) == 0 + + def is_deepcopy(self, H, G): + self.graphs_equal(H, G) + self.different_attrdict(H, G) + self.deep_copy_attrdict(H, G) + + def deep_copy_attrdict(self, H, G): + self.deepcopy_graph_attr(H, G) + self.deepcopy_node_attr(H, G) + self.deepcopy_edge_attr(H, G) + + def deepcopy_graph_attr(self, H, G): + assert G.graph["foo"] == H.graph["foo"] + G.graph["foo"].append(1) + assert G.graph["foo"] != H.graph["foo"] + + def deepcopy_node_attr(self, H, G): + assert G.nodes[0]["foo"] == H.nodes[0]["foo"] + G.nodes[0]["foo"].append(1) + assert G.nodes[0]["foo"] != H.nodes[0]["foo"] + + def deepcopy_edge_attr(self, H, G): + assert G[1][2]["foo"] == H[1][2]["foo"] + G[1][2]["foo"].append(1) + assert G[1][2]["foo"] != H[1][2]["foo"] + + def is_shallow_copy(self, H, G): + self.graphs_equal(H, G) + self.shallow_copy_attrdict(H, G) + + def shallow_copy_attrdict(self, H, G): + self.shallow_copy_graph_attr(H, G) + self.shallow_copy_node_attr(H, G) + self.shallow_copy_edge_attr(H, G) + + def shallow_copy_graph_attr(self, H, G): + assert G.graph["foo"] == H.graph["foo"] + G.graph["foo"].append(1) + assert G.graph["foo"] == H.graph["foo"] + + def shallow_copy_node_attr(self, H, G): + assert G.nodes[0]["foo"] == H.nodes[0]["foo"] + G.nodes[0]["foo"].append(1) + assert G.nodes[0]["foo"] == H.nodes[0]["foo"] + + def shallow_copy_edge_attr(self, H, G): + assert G[1][2]["foo"] == H[1][2]["foo"] + G[1][2]["foo"].append(1) + assert G[1][2]["foo"] == H[1][2]["foo"] + + def same_attrdict(self, H, G): + old_foo = H[1][2]["foo"] + H.adj[1][2]["foo"] = "baz" + assert G.edges == H.edges + H.adj[1][2]["foo"] = old_foo + assert G.edges == H.edges + + old_foo = H.nodes[0]["foo"] + H.nodes[0]["foo"] = "baz" + assert G.nodes == H.nodes + H.nodes[0]["foo"] = old_foo + assert G.nodes == H.nodes + + def different_attrdict(self, H, G): + old_foo = H[1][2]["foo"] + H.adj[1][2]["foo"] = "baz" + assert G._adj != H._adj + H.adj[1][2]["foo"] = old_foo + assert G._adj == H._adj + + old_foo = H.nodes[0]["foo"] + H.nodes[0]["foo"] = "baz" + assert G._node != H._node + H.nodes[0]["foo"] = old_foo + assert G._node == H._node + + def graphs_equal(self, H, G): + assert G._adj == H._adj + assert G._node == H._node + assert G.graph == H.graph + assert G.name == H.name + if not G.is_directed() and not H.is_directed(): + assert H._adj[1][2] is H._adj[2][1] + assert G._adj[1][2] is G._adj[2][1] + else: # at least one is directed + if not G.is_directed(): + G._pred = G._adj + G._succ = G._adj + if not H.is_directed(): + H._pred = H._adj + H._succ = H._adj + assert G._pred == H._pred + assert G._succ == H._succ + assert H._succ[1][2] is H._pred[2][1] + assert G._succ[1][2] is G._pred[2][1] + + def test_graph_attr(self): + G = self.K3.copy() + G.graph["foo"] = "bar" + assert isinstance(G.graph, G.graph_attr_dict_factory) + assert G.graph["foo"] == "bar" + del G.graph["foo"] + assert G.graph == {} + H = self.Graph(foo="bar") + assert H.graph["foo"] == "bar" + + def test_node_attr(self): + G = self.K3.copy() + G.add_node(1, foo="bar") + assert all( + isinstance(d, G.node_attr_dict_factory) for u, d in G.nodes(data=True) + ) + assert nodes_equal(G.nodes(), [0, 1, 2]) + assert nodes_equal(G.nodes(data=True), [(0, {}), (1, {"foo": "bar"}), (2, {})]) + G.nodes[1]["foo"] = "baz" + assert nodes_equal(G.nodes(data=True), [(0, {}), (1, {"foo": "baz"}), (2, {})]) + assert nodes_equal(G.nodes(data="foo"), [(0, None), (1, "baz"), (2, None)]) + assert nodes_equal( + G.nodes(data="foo", default="bar"), [(0, "bar"), (1, "baz"), (2, "bar")] + ) + + def test_node_attr2(self): + G = self.K3.copy() + a = {"foo": "bar"} + G.add_node(3, **a) + assert nodes_equal(G.nodes(), [0, 1, 2, 3]) + assert nodes_equal( + G.nodes(data=True), [(0, {}), (1, {}), (2, {}), (3, {"foo": "bar"})] + ) + + def test_edge_lookup(self): + G = self.Graph() + G.add_edge(1, 2, foo="bar") + assert edges_equal(G.edges[1, 2], {"foo": "bar"}) + + def test_edge_attr(self): + G = self.Graph() + G.add_edge(1, 2, foo="bar") + assert all( + isinstance(d, G.edge_attr_dict_factory) for u, v, d in G.edges(data=True) + ) + assert edges_equal(G.edges(data=True), [(1, 2, {"foo": "bar"})]) + assert edges_equal(G.edges(data="foo"), [(1, 2, "bar")]) + + def test_edge_attr2(self): + G = self.Graph() + G.add_edges_from([(1, 2), (3, 4)], foo="foo") + assert edges_equal( + G.edges(data=True), [(1, 2, {"foo": "foo"}), (3, 4, {"foo": "foo"})] + ) + assert edges_equal(G.edges(data="foo"), [(1, 2, "foo"), (3, 4, "foo")]) + + def test_edge_attr3(self): + G = self.Graph() + G.add_edges_from([(1, 2, {"weight": 32}), (3, 4, {"weight": 64})], foo="foo") + assert edges_equal( + G.edges(data=True), + [ + (1, 2, {"foo": "foo", "weight": 32}), + (3, 4, {"foo": "foo", "weight": 64}), + ], + ) + + G.remove_edges_from([(1, 2), (3, 4)]) + G.add_edge(1, 2, data=7, spam="bar", bar="foo") + assert edges_equal( + G.edges(data=True), [(1, 2, {"data": 7, "spam": "bar", "bar": "foo"})] + ) + + def test_edge_attr4(self): + G = self.Graph() + G.add_edge(1, 2, data=7, spam="bar", bar="foo") + assert edges_equal( + G.edges(data=True), [(1, 2, {"data": 7, "spam": "bar", "bar": "foo"})] + ) + G[1][2]["data"] = 10 # OK to set data like this + assert edges_equal( + G.edges(data=True), [(1, 2, {"data": 10, "spam": "bar", "bar": "foo"})] + ) + + G.adj[1][2]["data"] = 20 + assert edges_equal( + G.edges(data=True), [(1, 2, {"data": 20, "spam": "bar", "bar": "foo"})] + ) + G.edges[1, 2]["data"] = 21 # another spelling, "edge" + assert edges_equal( + G.edges(data=True), [(1, 2, {"data": 21, "spam": "bar", "bar": "foo"})] + ) + G.adj[1][2]["listdata"] = [20, 200] + G.adj[1][2]["weight"] = 20 + dd = { + "data": 21, + "spam": "bar", + "bar": "foo", + "listdata": [20, 200], + "weight": 20, + } + assert edges_equal(G.edges(data=True), [(1, 2, dd)]) + + def test_to_undirected(self): + G = self.K3 + self.add_attributes(G) + H = nx.Graph(G) + self.is_shallow_copy(H, G) + self.different_attrdict(H, G) + H = G.to_undirected() + self.is_deepcopy(H, G) + + def test_to_directed_as_view(self): + H = nx.path_graph(2, create_using=self.Graph) + H2 = H.to_directed(as_view=True) + assert H is H2._graph + assert H2.has_edge(0, 1) + assert H2.has_edge(1, 0) or H.is_directed() + pytest.raises(nx.NetworkXError, H2.add_node, -1) + pytest.raises(nx.NetworkXError, H2.add_edge, 1, 2) + H.add_edge(1, 2) + assert H2.has_edge(1, 2) + assert H2.has_edge(2, 1) or H.is_directed() + + def test_to_undirected_as_view(self): + H = nx.path_graph(2, create_using=self.Graph) + H2 = H.to_undirected(as_view=True) + assert H is H2._graph + assert H2.has_edge(0, 1) + assert H2.has_edge(1, 0) + pytest.raises(nx.NetworkXError, H2.add_node, -1) + pytest.raises(nx.NetworkXError, H2.add_edge, 1, 2) + H.add_edge(1, 2) + assert H2.has_edge(1, 2) + assert H2.has_edge(2, 1) + + def test_directed_class(self): + G = self.Graph() + + class newGraph(G.to_undirected_class()): + def to_directed_class(self): + return newDiGraph + + def to_undirected_class(self): + return newGraph + + class newDiGraph(G.to_directed_class()): + def to_directed_class(self): + return newDiGraph + + def to_undirected_class(self): + return newGraph + + G = newDiGraph() if G.is_directed() else newGraph() + H = G.to_directed() + assert isinstance(H, newDiGraph) + H = G.to_undirected() + assert isinstance(H, newGraph) + + def test_to_directed(self): + G = self.K3 + self.add_attributes(G) + H = nx.DiGraph(G) + self.is_shallow_copy(H, G) + self.different_attrdict(H, G) + H = G.to_directed() + self.is_deepcopy(H, G) + + def test_subgraph(self): + G = self.K3 + self.add_attributes(G) + H = G.subgraph([0, 1, 2, 5]) + self.graphs_equal(H, G) + self.same_attrdict(H, G) + self.shallow_copy_attrdict(H, G) + + H = G.subgraph(0) + assert H.adj == {0: {}} + H = G.subgraph([]) + assert H.adj == {} + assert G.adj != {} + + def test_selfloops_attr(self): + G = self.K3.copy() + G.add_edge(0, 0) + G.add_edge(1, 1, weight=2) + assert edges_equal( + nx.selfloop_edges(G, data=True), [(0, 0, {}), (1, 1, {"weight": 2})] + ) + assert edges_equal( + nx.selfloop_edges(G, data="weight"), [(0, 0, None), (1, 1, 2)] + ) + + +class TestGraph(BaseAttrGraphTester): + """Tests specific to dict-of-dict-of-dict graph data structure""" + + def setup_method(self): + self.Graph = nx.Graph + # build dict-of-dict-of-dict K3 + ed1, ed2, ed3 = ({}, {}, {}) + self.k3adj = {0: {1: ed1, 2: ed2}, 1: {0: ed1, 2: ed3}, 2: {0: ed2, 1: ed3}} + self.k3edges = [(0, 1), (0, 2), (1, 2)] + self.k3nodes = [0, 1, 2] + self.K3 = self.Graph() + self.K3._adj = self.k3adj + self.K3._node = {} + self.K3._node[0] = {} + self.K3._node[1] = {} + self.K3._node[2] = {} + + def test_pickle(self): + G = self.K3 + pg = pickle.loads(pickle.dumps(G, -1)) + self.graphs_equal(pg, G) + pg = pickle.loads(pickle.dumps(G)) + self.graphs_equal(pg, G) + + def test_data_input(self): + G = self.Graph({1: [2], 2: [1]}, name="test") + assert G.name == "test" + assert sorted(G.adj.items()) == [(1, {2: {}}), (2, {1: {}})] + + def test_adjacency(self): + G = self.K3 + assert dict(G.adjacency()) == { + 0: {1: {}, 2: {}}, + 1: {0: {}, 2: {}}, + 2: {0: {}, 1: {}}, + } + + def test_getitem(self): + G = self.K3 + assert G.adj[0] == {1: {}, 2: {}} + assert G[0] == {1: {}, 2: {}} + with pytest.raises(KeyError): + G.__getitem__("j") + with pytest.raises(TypeError): + G.__getitem__(["A"]) + + def test_add_node(self): + G = self.Graph() + G.add_node(0) + assert G.adj == {0: {}} + # test add attributes + G.add_node(1, c="red") + G.add_node(2, c="blue") + G.add_node(3, c="red") + assert G.nodes[1]["c"] == "red" + assert G.nodes[2]["c"] == "blue" + assert G.nodes[3]["c"] == "red" + # test updating attributes + G.add_node(1, c="blue") + G.add_node(2, c="red") + G.add_node(3, c="blue") + assert G.nodes[1]["c"] == "blue" + assert G.nodes[2]["c"] == "red" + assert G.nodes[3]["c"] == "blue" + + def test_add_nodes_from(self): + G = self.Graph() + G.add_nodes_from([0, 1, 2]) + assert G.adj == {0: {}, 1: {}, 2: {}} + # test add attributes + G.add_nodes_from([0, 1, 2], c="red") + assert G.nodes[0]["c"] == "red" + assert G.nodes[2]["c"] == "red" + # test that attribute dicts are not the same + assert G.nodes[0] is not G.nodes[1] + # test updating attributes + G.add_nodes_from([0, 1, 2], c="blue") + assert G.nodes[0]["c"] == "blue" + assert G.nodes[2]["c"] == "blue" + assert G.nodes[0] is not G.nodes[1] + # test tuple input + H = self.Graph() + H.add_nodes_from(G.nodes(data=True)) + assert H.nodes[0]["c"] == "blue" + assert H.nodes[2]["c"] == "blue" + assert H.nodes[0] is not H.nodes[1] + # specific overrides general + H.add_nodes_from([0, (1, {"c": "green"}), (3, {"c": "cyan"})], c="red") + assert H.nodes[0]["c"] == "red" + assert H.nodes[1]["c"] == "green" + assert H.nodes[2]["c"] == "blue" + assert H.nodes[3]["c"] == "cyan" + + def test_remove_node(self): + G = self.K3.copy() + G.remove_node(0) + assert G.adj == {1: {2: {}}, 2: {1: {}}} + with pytest.raises(nx.NetworkXError): + G.remove_node(-1) + + # generator here to implement list,set,string... + + def test_remove_nodes_from(self): + G = self.K3.copy() + G.remove_nodes_from([0, 1]) + assert G.adj == {2: {}} + G.remove_nodes_from([-1]) # silent fail + + def test_add_edge(self): + G = self.Graph() + G.add_edge(0, 1) + assert G.adj == {0: {1: {}}, 1: {0: {}}} + G = self.Graph() + G.add_edge(*(0, 1)) + assert G.adj == {0: {1: {}}, 1: {0: {}}} + G = self.Graph() + with pytest.raises(ValueError): + G.add_edge(None, "anything") + + def test_add_edges_from(self): + G = self.Graph() + G.add_edges_from([(0, 1), (0, 2, {"weight": 3})]) + assert G.adj == { + 0: {1: {}, 2: {"weight": 3}}, + 1: {0: {}}, + 2: {0: {"weight": 3}}, + } + G = self.Graph() + G.add_edges_from([(0, 1), (0, 2, {"weight": 3}), (1, 2, {"data": 4})], data=2) + assert G.adj == { + 0: {1: {"data": 2}, 2: {"weight": 3, "data": 2}}, + 1: {0: {"data": 2}, 2: {"data": 4}}, + 2: {0: {"weight": 3, "data": 2}, 1: {"data": 4}}, + } + + with pytest.raises(nx.NetworkXError): + G.add_edges_from([(0,)]) # too few in tuple + with pytest.raises(nx.NetworkXError): + G.add_edges_from([(0, 1, 2, 3)]) # too many in tuple + with pytest.raises(TypeError): + G.add_edges_from([0]) # not a tuple + with pytest.raises(ValueError): + G.add_edges_from([(None, 3), (3, 2)]) # None cannot be a node + + def test_remove_edge(self): + G = self.K3.copy() + G.remove_edge(0, 1) + assert G.adj == {0: {2: {}}, 1: {2: {}}, 2: {0: {}, 1: {}}} + with pytest.raises(nx.NetworkXError): + G.remove_edge(-1, 0) + + def test_remove_edges_from(self): + G = self.K3.copy() + G.remove_edges_from([(0, 1)]) + assert G.adj == {0: {2: {}}, 1: {2: {}}, 2: {0: {}, 1: {}}} + G.remove_edges_from([(0, 0)]) # silent fail + + def test_clear(self): + G = self.K3.copy() + G.graph["name"] = "K3" + G.clear() + assert list(G.nodes) == [] + assert G.adj == {} + assert G.graph == {} + + def test_clear_edges(self): + G = self.K3.copy() + G.graph["name"] = "K3" + nodes = list(G.nodes) + G.clear_edges() + assert list(G.nodes) == nodes + assert G.adj == {0: {}, 1: {}, 2: {}} + assert list(G.edges) == [] + assert G.graph["name"] == "K3" + + def test_edges_data(self): + G = self.K3 + all_edges = [(0, 1, {}), (0, 2, {}), (1, 2, {})] + assert edges_equal(G.edges(data=True), all_edges) + assert edges_equal(G.edges(0, data=True), [(0, 1, {}), (0, 2, {})]) + assert edges_equal(G.edges([0, 1], data=True), all_edges) + with pytest.raises(nx.NetworkXError): + G.edges(-1, True) + + def test_get_edge_data(self): + G = self.K3.copy() + assert G.get_edge_data(0, 1) == {} + assert G[0][1] == {} + assert G.get_edge_data(10, 20) is None + assert G.get_edge_data(-1, 0) is None + assert G.get_edge_data(-1, 0, default=1) == 1 + + def test_update(self): + # specify both edges and nodes + G = self.K3.copy() + G.update(nodes=[3, (4, {"size": 2})], edges=[(4, 5), (6, 7, {"weight": 2})]) + nlist = [ + (0, {}), + (1, {}), + (2, {}), + (3, {}), + (4, {"size": 2}), + (5, {}), + (6, {}), + (7, {}), + ] + assert sorted(G.nodes.data()) == nlist + if G.is_directed(): + elist = [ + (0, 1, {}), + (0, 2, {}), + (1, 0, {}), + (1, 2, {}), + (2, 0, {}), + (2, 1, {}), + (4, 5, {}), + (6, 7, {"weight": 2}), + ] + else: + elist = [ + (0, 1, {}), + (0, 2, {}), + (1, 2, {}), + (4, 5, {}), + (6, 7, {"weight": 2}), + ] + assert sorted(G.edges.data()) == elist + assert G.graph == {} + + # no keywords -- order is edges, nodes + G = self.K3.copy() + G.update([(4, 5), (6, 7, {"weight": 2})], [3, (4, {"size": 2})]) + assert sorted(G.nodes.data()) == nlist + assert sorted(G.edges.data()) == elist + assert G.graph == {} + + # update using only a graph + G = self.Graph() + G.graph["foo"] = "bar" + G.add_node(2, data=4) + G.add_edge(0, 1, weight=0.5) + GG = G.copy() + H = self.Graph() + GG.update(H) + assert graphs_equal(G, GG) + H.update(G) + assert graphs_equal(H, G) + + # update nodes only + H = self.Graph() + H.update(nodes=[3, 4]) + assert H.nodes ^ {3, 4} == set() + assert H.size() == 0 + + # update edges only + H = self.Graph() + H.update(edges=[(3, 4)]) + assert sorted(H.edges.data()) == [(3, 4, {})] + assert H.size() == 1 + + # No inputs -> exception + with pytest.raises(nx.NetworkXError): + nx.Graph().update() + + +class TestEdgeSubgraph: + """Unit tests for the :meth:`Graph.edge_subgraph` method.""" + + def setup_method(self): + # Create a path graph on five nodes. + G = nx.path_graph(5) + # Add some node, edge, and graph attributes. + for i in range(5): + G.nodes[i]["name"] = f"node{i}" + G.edges[0, 1]["name"] = "edge01" + G.edges[3, 4]["name"] = "edge34" + G.graph["name"] = "graph" + # Get the subgraph induced by the first and last edges. + self.G = G + self.H = G.edge_subgraph([(0, 1), (3, 4)]) + + def test_correct_nodes(self): + """Tests that the subgraph has the correct nodes.""" + assert [0, 1, 3, 4] == sorted(self.H.nodes()) + + def test_correct_edges(self): + """Tests that the subgraph has the correct edges.""" + assert [(0, 1, "edge01"), (3, 4, "edge34")] == sorted(self.H.edges(data="name")) + + def test_add_node(self): + """Tests that adding a node to the original graph does not + affect the nodes of the subgraph. + + """ + self.G.add_node(5) + assert [0, 1, 3, 4] == sorted(self.H.nodes()) + + def test_remove_node(self): + """Tests that removing a node in the original graph does + affect the nodes of the subgraph. + + """ + self.G.remove_node(0) + assert [1, 3, 4] == sorted(self.H.nodes()) + + def test_node_attr_dict(self): + """Tests that the node attribute dictionary of the two graphs is + the same object. + + """ + for v in self.H: + assert self.G.nodes[v] == self.H.nodes[v] + # Making a change to G should make a change in H and vice versa. + self.G.nodes[0]["name"] = "foo" + assert self.G.nodes[0] == self.H.nodes[0] + self.H.nodes[1]["name"] = "bar" + assert self.G.nodes[1] == self.H.nodes[1] + + def test_edge_attr_dict(self): + """Tests that the edge attribute dictionary of the two graphs is + the same object. + + """ + for u, v in self.H.edges(): + assert self.G.edges[u, v] == self.H.edges[u, v] + # Making a change to G should make a change in H and vice versa. + self.G.edges[0, 1]["name"] = "foo" + assert self.G.edges[0, 1]["name"] == self.H.edges[0, 1]["name"] + self.H.edges[3, 4]["name"] = "bar" + assert self.G.edges[3, 4]["name"] == self.H.edges[3, 4]["name"] + + def test_graph_attr_dict(self): + """Tests that the graph attribute dictionary of the two graphs + is the same object. + + """ + assert self.G.graph is self.H.graph diff --git a/lib/python3.10/site-packages/networkx/classes/tests/test_graph_historical.py b/lib/python3.10/site-packages/networkx/classes/tests/test_graph_historical.py new file mode 100644 index 0000000000000000000000000000000000000000..36aba7100e7758f6357d66470f45b0fcd0f10145 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/classes/tests/test_graph_historical.py @@ -0,0 +1,13 @@ +"""Original NetworkX graph tests""" + +import networkx +import networkx as nx + +from .historical_tests import HistoricalTests + + +class TestGraphHistorical(HistoricalTests): + @classmethod + def setup_class(cls): + HistoricalTests.setup_class() + cls.G = nx.Graph diff --git a/lib/python3.10/site-packages/networkx/classes/tests/test_graphviews.py b/lib/python3.10/site-packages/networkx/classes/tests/test_graphviews.py new file mode 100644 index 0000000000000000000000000000000000000000..591c760c9bdfa61984dd7b87e938b1c15c04be0e --- /dev/null +++ b/lib/python3.10/site-packages/networkx/classes/tests/test_graphviews.py @@ -0,0 +1,350 @@ +import pytest + +import networkx as nx +from networkx.utils import edges_equal, nodes_equal + +# Note: SubGraph views are not tested here. They have their own testing file + + +class TestReverseView: + def setup_method(self): + self.G = nx.path_graph(9, create_using=nx.DiGraph()) + self.rv = nx.reverse_view(self.G) + + def test_pickle(self): + import pickle + + rv = self.rv + prv = pickle.loads(pickle.dumps(rv, -1)) + assert rv._node == prv._node + assert rv._adj == prv._adj + assert rv.graph == prv.graph + + def test_contains(self): + assert (2, 3) in self.G.edges + assert (3, 2) not in self.G.edges + assert (2, 3) not in self.rv.edges + assert (3, 2) in self.rv.edges + + def test_iter(self): + expected = sorted(tuple(reversed(e)) for e in self.G.edges) + assert sorted(self.rv.edges) == expected + + def test_exceptions(self): + G = nx.Graph() + pytest.raises(nx.NetworkXNotImplemented, nx.reverse_view, G) + + def test_subclass(self): + class MyGraph(nx.DiGraph): + def my_method(self): + return "me" + + def to_directed_class(self): + return MyGraph() + + M = MyGraph() + M.add_edge(1, 2) + RM = nx.reverse_view(M) + print("RM class", RM.__class__) + RMC = RM.copy() + print("RMC class", RMC.__class__) + print(RMC.edges) + assert RMC.has_edge(2, 1) + assert RMC.my_method() == "me" + + +class TestMultiReverseView: + def setup_method(self): + self.G = nx.path_graph(9, create_using=nx.MultiDiGraph()) + self.G.add_edge(4, 5) + self.rv = nx.reverse_view(self.G) + + def test_pickle(self): + import pickle + + rv = self.rv + prv = pickle.loads(pickle.dumps(rv, -1)) + assert rv._node == prv._node + assert rv._adj == prv._adj + assert rv.graph == prv.graph + + def test_contains(self): + assert (2, 3, 0) in self.G.edges + assert (3, 2, 0) not in self.G.edges + assert (2, 3, 0) not in self.rv.edges + assert (3, 2, 0) in self.rv.edges + assert (5, 4, 1) in self.rv.edges + assert (4, 5, 1) not in self.rv.edges + + def test_iter(self): + expected = sorted((v, u, k) for u, v, k in self.G.edges) + assert sorted(self.rv.edges) == expected + + def test_exceptions(self): + MG = nx.MultiGraph(self.G) + pytest.raises(nx.NetworkXNotImplemented, nx.reverse_view, MG) + + +def test_generic_multitype(): + nxg = nx.graphviews + G = nx.DiGraph([(1, 2)]) + with pytest.raises(nx.NetworkXError): + nxg.generic_graph_view(G, create_using=nx.MultiGraph) + G = nx.MultiDiGraph([(1, 2)]) + with pytest.raises(nx.NetworkXError): + nxg.generic_graph_view(G, create_using=nx.DiGraph) + + +class TestToDirected: + def setup_method(self): + self.G = nx.path_graph(9) + self.dv = nx.to_directed(self.G) + self.MG = nx.path_graph(9, create_using=nx.MultiGraph()) + self.Mdv = nx.to_directed(self.MG) + + def test_directed(self): + assert not self.G.is_directed() + assert self.dv.is_directed() + + def test_already_directed(self): + dd = nx.to_directed(self.dv) + Mdd = nx.to_directed(self.Mdv) + assert edges_equal(dd.edges, self.dv.edges) + assert edges_equal(Mdd.edges, self.Mdv.edges) + + def test_pickle(self): + import pickle + + dv = self.dv + pdv = pickle.loads(pickle.dumps(dv, -1)) + assert dv._node == pdv._node + assert dv._succ == pdv._succ + assert dv._pred == pdv._pred + assert dv.graph == pdv.graph + + def test_contains(self): + assert (2, 3) in self.G.edges + assert (3, 2) in self.G.edges + assert (2, 3) in self.dv.edges + assert (3, 2) in self.dv.edges + + def test_iter(self): + revd = [tuple(reversed(e)) for e in self.G.edges] + expected = sorted(list(self.G.edges) + revd) + assert sorted(self.dv.edges) == expected + + +class TestToUndirected: + def setup_method(self): + self.DG = nx.path_graph(9, create_using=nx.DiGraph()) + self.uv = nx.to_undirected(self.DG) + self.MDG = nx.path_graph(9, create_using=nx.MultiDiGraph()) + self.Muv = nx.to_undirected(self.MDG) + + def test_directed(self): + assert self.DG.is_directed() + assert not self.uv.is_directed() + + def test_already_directed(self): + uu = nx.to_undirected(self.uv) + Muu = nx.to_undirected(self.Muv) + assert edges_equal(uu.edges, self.uv.edges) + assert edges_equal(Muu.edges, self.Muv.edges) + + def test_pickle(self): + import pickle + + uv = self.uv + puv = pickle.loads(pickle.dumps(uv, -1)) + assert uv._node == puv._node + assert uv._adj == puv._adj + assert uv.graph == puv.graph + assert hasattr(uv, "_graph") + + def test_contains(self): + assert (2, 3) in self.DG.edges + assert (3, 2) not in self.DG.edges + assert (2, 3) in self.uv.edges + assert (3, 2) in self.uv.edges + + def test_iter(self): + expected = sorted(self.DG.edges) + assert sorted(self.uv.edges) == expected + + +class TestChainsOfViews: + @classmethod + def setup_class(cls): + cls.G = nx.path_graph(9) + cls.DG = nx.path_graph(9, create_using=nx.DiGraph()) + cls.MG = nx.path_graph(9, create_using=nx.MultiGraph()) + cls.MDG = nx.path_graph(9, create_using=nx.MultiDiGraph()) + cls.Gv = nx.to_undirected(cls.DG) + cls.DGv = nx.to_directed(cls.G) + cls.MGv = nx.to_undirected(cls.MDG) + cls.MDGv = nx.to_directed(cls.MG) + cls.Rv = cls.DG.reverse() + cls.MRv = cls.MDG.reverse() + cls.graphs = [ + cls.G, + cls.DG, + cls.MG, + cls.MDG, + cls.Gv, + cls.DGv, + cls.MGv, + cls.MDGv, + cls.Rv, + cls.MRv, + ] + for G in cls.graphs: + G.edges, G.nodes, G.degree + + def test_pickle(self): + import pickle + + for G in self.graphs: + H = pickle.loads(pickle.dumps(G, -1)) + assert edges_equal(H.edges, G.edges) + assert nodes_equal(H.nodes, G.nodes) + + def test_subgraph_of_subgraph(self): + SGv = nx.subgraph(self.G, range(3, 7)) + SDGv = nx.subgraph(self.DG, range(3, 7)) + SMGv = nx.subgraph(self.MG, range(3, 7)) + SMDGv = nx.subgraph(self.MDG, range(3, 7)) + for G in self.graphs + [SGv, SDGv, SMGv, SMDGv]: + SG = nx.induced_subgraph(G, [4, 5, 6]) + assert list(SG) == [4, 5, 6] + SSG = SG.subgraph([6, 7]) + assert list(SSG) == [6] + # subgraph-subgraph chain is short-cut in base class method + assert SSG._graph is G + + def test_restricted_induced_subgraph_chains(self): + """Test subgraph chains that both restrict and show nodes/edges. + + A restricted_view subgraph should allow induced subgraphs using + G.subgraph that automagically without a chain (meaning the result + is a subgraph view of the original graph not a subgraph-of-subgraph. + """ + hide_nodes = [3, 4, 5] + hide_edges = [(6, 7)] + RG = nx.restricted_view(self.G, hide_nodes, hide_edges) + nodes = [4, 5, 6, 7, 8] + SG = nx.induced_subgraph(RG, nodes) + SSG = RG.subgraph(nodes) + assert RG._graph is self.G + assert SSG._graph is self.G + assert SG._graph is RG + assert edges_equal(SG.edges, SSG.edges) + # should be same as morphing the graph + CG = self.G.copy() + CG.remove_nodes_from(hide_nodes) + CG.remove_edges_from(hide_edges) + assert edges_equal(CG.edges(nodes), SSG.edges) + CG.remove_nodes_from([0, 1, 2, 3]) + assert edges_equal(CG.edges, SSG.edges) + # switch order: subgraph first, then restricted view + SSSG = self.G.subgraph(nodes) + RSG = nx.restricted_view(SSSG, hide_nodes, hide_edges) + assert RSG._graph is not self.G + assert edges_equal(RSG.edges, CG.edges) + + def test_subgraph_copy(self): + for origG in self.graphs: + G = nx.Graph(origG) + SG = G.subgraph([4, 5, 6]) + H = SG.copy() + assert type(G) == type(H) + + def test_subgraph_todirected(self): + SG = nx.induced_subgraph(self.G, [4, 5, 6]) + SSG = SG.to_directed() + assert sorted(SSG) == [4, 5, 6] + assert sorted(SSG.edges) == [(4, 5), (5, 4), (5, 6), (6, 5)] + + def test_subgraph_toundirected(self): + SG = nx.induced_subgraph(self.G, [4, 5, 6]) + SSG = SG.to_undirected() + assert list(SSG) == [4, 5, 6] + assert sorted(SSG.edges) == [(4, 5), (5, 6)] + + def test_reverse_subgraph_toundirected(self): + G = self.DG.reverse(copy=False) + SG = G.subgraph([4, 5, 6]) + SSG = SG.to_undirected() + assert list(SSG) == [4, 5, 6] + assert sorted(SSG.edges) == [(4, 5), (5, 6)] + + def test_reverse_reverse_copy(self): + G = self.DG.reverse(copy=False) + H = G.reverse(copy=True) + assert H.nodes == self.DG.nodes + assert H.edges == self.DG.edges + G = self.MDG.reverse(copy=False) + H = G.reverse(copy=True) + assert H.nodes == self.MDG.nodes + assert H.edges == self.MDG.edges + + def test_subgraph_edgesubgraph_toundirected(self): + G = self.G.copy() + SG = G.subgraph([4, 5, 6]) + SSG = SG.edge_subgraph([(4, 5), (5, 4)]) + USSG = SSG.to_undirected() + assert list(USSG) == [4, 5] + assert sorted(USSG.edges) == [(4, 5)] + + def test_copy_subgraph(self): + G = self.G.copy() + SG = G.subgraph([4, 5, 6]) + CSG = SG.copy(as_view=True) + DCSG = SG.copy(as_view=False) + assert hasattr(CSG, "_graph") # is a view + assert not hasattr(DCSG, "_graph") # not a view + + def test_copy_disubgraph(self): + G = self.DG.copy() + SG = G.subgraph([4, 5, 6]) + CSG = SG.copy(as_view=True) + DCSG = SG.copy(as_view=False) + assert hasattr(CSG, "_graph") # is a view + assert not hasattr(DCSG, "_graph") # not a view + + def test_copy_multidisubgraph(self): + G = self.MDG.copy() + SG = G.subgraph([4, 5, 6]) + CSG = SG.copy(as_view=True) + DCSG = SG.copy(as_view=False) + assert hasattr(CSG, "_graph") # is a view + assert not hasattr(DCSG, "_graph") # not a view + + def test_copy_multisubgraph(self): + G = self.MG.copy() + SG = G.subgraph([4, 5, 6]) + CSG = SG.copy(as_view=True) + DCSG = SG.copy(as_view=False) + assert hasattr(CSG, "_graph") # is a view + assert not hasattr(DCSG, "_graph") # not a view + + def test_copy_of_view(self): + G = nx.MultiGraph(self.MGv) + assert G.__class__.__name__ == "MultiGraph" + G = G.copy(as_view=True) + assert G.__class__.__name__ == "MultiGraph" + + def test_subclass(self): + class MyGraph(nx.DiGraph): + def my_method(self): + return "me" + + def to_directed_class(self): + return MyGraph() + + for origG in self.graphs: + G = MyGraph(origG) + SG = G.subgraph([4, 5, 6]) + H = SG.copy() + assert SG.my_method() == "me" + assert H.my_method() == "me" + assert 3 not in H or 3 in SG diff --git a/lib/python3.10/site-packages/networkx/classes/tests/test_multidigraph.py b/lib/python3.10/site-packages/networkx/classes/tests/test_multidigraph.py new file mode 100644 index 0000000000000000000000000000000000000000..fc0bd5467d0a62dc8f533af7a6c5bbc0a57fc010 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/classes/tests/test_multidigraph.py @@ -0,0 +1,459 @@ +from collections import UserDict + +import pytest + +import networkx as nx +from networkx.utils import edges_equal + +from .test_multigraph import BaseMultiGraphTester +from .test_multigraph import TestEdgeSubgraph as _TestMultiGraphEdgeSubgraph +from .test_multigraph import TestMultiGraph as _TestMultiGraph + + +class BaseMultiDiGraphTester(BaseMultiGraphTester): + def test_edges(self): + G = self.K3 + edges = [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)] + assert sorted(G.edges()) == edges + assert sorted(G.edges(0)) == [(0, 1), (0, 2)] + pytest.raises((KeyError, nx.NetworkXError), G.edges, -1) + + def test_edges_data(self): + G = self.K3 + edges = [(0, 1, {}), (0, 2, {}), (1, 0, {}), (1, 2, {}), (2, 0, {}), (2, 1, {})] + assert sorted(G.edges(data=True)) == edges + assert sorted(G.edges(0, data=True)) == [(0, 1, {}), (0, 2, {})] + pytest.raises((KeyError, nx.NetworkXError), G.neighbors, -1) + + def test_edges_multi(self): + G = self.K3 + assert sorted(G.edges()) == [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)] + assert sorted(G.edges(0)) == [(0, 1), (0, 2)] + G.add_edge(0, 1) + assert sorted(G.edges()) == [ + (0, 1), + (0, 1), + (0, 2), + (1, 0), + (1, 2), + (2, 0), + (2, 1), + ] + + def test_out_edges(self): + G = self.K3 + assert sorted(G.out_edges()) == [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)] + assert sorted(G.out_edges(0)) == [(0, 1), (0, 2)] + pytest.raises((KeyError, nx.NetworkXError), G.out_edges, -1) + assert sorted(G.out_edges(0, keys=True)) == [(0, 1, 0), (0, 2, 0)] + + def test_out_edges_multi(self): + G = self.K3 + assert sorted(G.out_edges()) == [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)] + assert sorted(G.out_edges(0)) == [(0, 1), (0, 2)] + G.add_edge(0, 1, 2) + assert sorted(G.out_edges()) == [ + (0, 1), + (0, 1), + (0, 2), + (1, 0), + (1, 2), + (2, 0), + (2, 1), + ] + + def test_out_edges_data(self): + G = self.K3 + assert sorted(G.edges(0, data=True)) == [(0, 1, {}), (0, 2, {})] + G.remove_edge(0, 1) + G.add_edge(0, 1, data=1) + assert sorted(G.edges(0, data=True)) == [(0, 1, {"data": 1}), (0, 2, {})] + assert sorted(G.edges(0, data="data")) == [(0, 1, 1), (0, 2, None)] + assert sorted(G.edges(0, data="data", default=-1)) == [(0, 1, 1), (0, 2, -1)] + + def test_in_edges(self): + G = self.K3 + assert sorted(G.in_edges()) == [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)] + assert sorted(G.in_edges(0)) == [(1, 0), (2, 0)] + pytest.raises((KeyError, nx.NetworkXError), G.in_edges, -1) + G.add_edge(0, 1, 2) + assert sorted(G.in_edges()) == [ + (0, 1), + (0, 1), + (0, 2), + (1, 0), + (1, 2), + (2, 0), + (2, 1), + ] + assert sorted(G.in_edges(0, keys=True)) == [(1, 0, 0), (2, 0, 0)] + + def test_in_edges_no_keys(self): + G = self.K3 + assert sorted(G.in_edges()) == [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)] + assert sorted(G.in_edges(0)) == [(1, 0), (2, 0)] + G.add_edge(0, 1, 2) + assert sorted(G.in_edges()) == [ + (0, 1), + (0, 1), + (0, 2), + (1, 0), + (1, 2), + (2, 0), + (2, 1), + ] + + assert sorted(G.in_edges(data=True, keys=False)) == [ + (0, 1, {}), + (0, 1, {}), + (0, 2, {}), + (1, 0, {}), + (1, 2, {}), + (2, 0, {}), + (2, 1, {}), + ] + + def test_in_edges_data(self): + G = self.K3 + assert sorted(G.in_edges(0, data=True)) == [(1, 0, {}), (2, 0, {})] + G.remove_edge(1, 0) + G.add_edge(1, 0, data=1) + assert sorted(G.in_edges(0, data=True)) == [(1, 0, {"data": 1}), (2, 0, {})] + assert sorted(G.in_edges(0, data="data")) == [(1, 0, 1), (2, 0, None)] + assert sorted(G.in_edges(0, data="data", default=-1)) == [(1, 0, 1), (2, 0, -1)] + + def is_shallow(self, H, G): + # graph + assert G.graph["foo"] == H.graph["foo"] + G.graph["foo"].append(1) + assert G.graph["foo"] == H.graph["foo"] + # node + assert G.nodes[0]["foo"] == H.nodes[0]["foo"] + G.nodes[0]["foo"].append(1) + assert G.nodes[0]["foo"] == H.nodes[0]["foo"] + # edge + assert G[1][2][0]["foo"] == H[1][2][0]["foo"] + G[1][2][0]["foo"].append(1) + assert G[1][2][0]["foo"] == H[1][2][0]["foo"] + + def is_deep(self, H, G): + # graph + assert G.graph["foo"] == H.graph["foo"] + G.graph["foo"].append(1) + assert G.graph["foo"] != H.graph["foo"] + # node + assert G.nodes[0]["foo"] == H.nodes[0]["foo"] + G.nodes[0]["foo"].append(1) + assert G.nodes[0]["foo"] != H.nodes[0]["foo"] + # edge + assert G[1][2][0]["foo"] == H[1][2][0]["foo"] + G[1][2][0]["foo"].append(1) + assert G[1][2][0]["foo"] != H[1][2][0]["foo"] + + def test_to_undirected(self): + # MultiDiGraph -> MultiGraph changes number of edges so it is + # not a copy operation... use is_shallow, not is_shallow_copy + G = self.K3 + self.add_attributes(G) + H = nx.MultiGraph(G) + # self.is_shallow(H,G) + # the result is traversal order dependent so we + # can't use the is_shallow() test here. + try: + assert edges_equal(H.edges(), [(0, 1), (1, 2), (2, 0)]) + except AssertionError: + assert edges_equal(H.edges(), [(0, 1), (1, 2), (1, 2), (2, 0)]) + H = G.to_undirected() + self.is_deep(H, G) + + def test_has_successor(self): + G = self.K3 + assert G.has_successor(0, 1) + assert not G.has_successor(0, -1) + + def test_successors(self): + G = self.K3 + assert sorted(G.successors(0)) == [1, 2] + pytest.raises((KeyError, nx.NetworkXError), G.successors, -1) + + def test_has_predecessor(self): + G = self.K3 + assert G.has_predecessor(0, 1) + assert not G.has_predecessor(0, -1) + + def test_predecessors(self): + G = self.K3 + assert sorted(G.predecessors(0)) == [1, 2] + pytest.raises((KeyError, nx.NetworkXError), G.predecessors, -1) + + def test_degree(self): + G = self.K3 + assert sorted(G.degree()) == [(0, 4), (1, 4), (2, 4)] + assert dict(G.degree()) == {0: 4, 1: 4, 2: 4} + assert G.degree(0) == 4 + assert list(G.degree(iter([0]))) == [(0, 4)] + G.add_edge(0, 1, weight=0.3, other=1.2) + assert sorted(G.degree(weight="weight")) == [(0, 4.3), (1, 4.3), (2, 4)] + assert sorted(G.degree(weight="other")) == [(0, 5.2), (1, 5.2), (2, 4)] + + def test_in_degree(self): + G = self.K3 + assert sorted(G.in_degree()) == [(0, 2), (1, 2), (2, 2)] + assert dict(G.in_degree()) == {0: 2, 1: 2, 2: 2} + assert G.in_degree(0) == 2 + assert list(G.in_degree(iter([0]))) == [(0, 2)] + assert G.in_degree(0, weight="weight") == 2 + + def test_out_degree(self): + G = self.K3 + assert sorted(G.out_degree()) == [(0, 2), (1, 2), (2, 2)] + assert dict(G.out_degree()) == {0: 2, 1: 2, 2: 2} + assert G.out_degree(0) == 2 + assert list(G.out_degree(iter([0]))) == [(0, 2)] + assert G.out_degree(0, weight="weight") == 2 + + def test_size(self): + G = self.K3 + assert G.size() == 6 + assert G.number_of_edges() == 6 + G.add_edge(0, 1, weight=0.3, other=1.2) + assert round(G.size(weight="weight"), 2) == 6.3 + assert round(G.size(weight="other"), 2) == 7.2 + + def test_to_undirected_reciprocal(self): + G = self.Graph() + G.add_edge(1, 2) + assert G.to_undirected().has_edge(1, 2) + assert not G.to_undirected(reciprocal=True).has_edge(1, 2) + G.add_edge(2, 1) + assert G.to_undirected(reciprocal=True).has_edge(1, 2) + + def test_reverse_copy(self): + G = nx.MultiDiGraph([(0, 1), (0, 1)]) + R = G.reverse() + assert sorted(R.edges()) == [(1, 0), (1, 0)] + R.remove_edge(1, 0) + assert sorted(R.edges()) == [(1, 0)] + assert sorted(G.edges()) == [(0, 1), (0, 1)] + + def test_reverse_nocopy(self): + G = nx.MultiDiGraph([(0, 1), (0, 1)]) + R = G.reverse(copy=False) + assert sorted(R.edges()) == [(1, 0), (1, 0)] + pytest.raises(nx.NetworkXError, R.remove_edge, 1, 0) + + def test_di_attributes_cached(self): + G = self.K3.copy() + assert id(G.in_edges) == id(G.in_edges) + assert id(G.out_edges) == id(G.out_edges) + assert id(G.in_degree) == id(G.in_degree) + assert id(G.out_degree) == id(G.out_degree) + assert id(G.succ) == id(G.succ) + assert id(G.pred) == id(G.pred) + + +class TestMultiDiGraph(BaseMultiDiGraphTester, _TestMultiGraph): + def setup_method(self): + self.Graph = nx.MultiDiGraph + # build K3 + self.k3edges = [(0, 1), (0, 2), (1, 2)] + self.k3nodes = [0, 1, 2] + self.K3 = self.Graph() + self.K3._succ = {0: {}, 1: {}, 2: {}} + # K3._adj is synced with K3._succ + self.K3._pred = {0: {}, 1: {}, 2: {}} + for u in self.k3nodes: + for v in self.k3nodes: + if u == v: + continue + d = {0: {}} + self.K3._succ[u][v] = d + self.K3._pred[v][u] = d + self.K3._node = {} + self.K3._node[0] = {} + self.K3._node[1] = {} + self.K3._node[2] = {} + + def test_add_edge(self): + G = self.Graph() + G.add_edge(0, 1) + assert G._adj == {0: {1: {0: {}}}, 1: {}} + assert G._succ == {0: {1: {0: {}}}, 1: {}} + assert G._pred == {0: {}, 1: {0: {0: {}}}} + G = self.Graph() + G.add_edge(*(0, 1)) + assert G._adj == {0: {1: {0: {}}}, 1: {}} + assert G._succ == {0: {1: {0: {}}}, 1: {}} + assert G._pred == {0: {}, 1: {0: {0: {}}}} + with pytest.raises(ValueError, match="None cannot be a node"): + G.add_edge(None, 3) + + def test_add_edges_from(self): + G = self.Graph() + G.add_edges_from([(0, 1), (0, 1, {"weight": 3})]) + assert G._adj == {0: {1: {0: {}, 1: {"weight": 3}}}, 1: {}} + assert G._succ == {0: {1: {0: {}, 1: {"weight": 3}}}, 1: {}} + assert G._pred == {0: {}, 1: {0: {0: {}, 1: {"weight": 3}}}} + + G.add_edges_from([(0, 1), (0, 1, {"weight": 3})], weight=2) + assert G._succ == { + 0: {1: {0: {}, 1: {"weight": 3}, 2: {"weight": 2}, 3: {"weight": 3}}}, + 1: {}, + } + assert G._pred == { + 0: {}, + 1: {0: {0: {}, 1: {"weight": 3}, 2: {"weight": 2}, 3: {"weight": 3}}}, + } + + G = self.Graph() + edges = [ + (0, 1, {"weight": 3}), + (0, 1, (("weight", 2),)), + (0, 1, 5), + (0, 1, "s"), + ] + G.add_edges_from(edges) + keydict = {0: {"weight": 3}, 1: {"weight": 2}, 5: {}, "s": {}} + assert G._succ == {0: {1: keydict}, 1: {}} + assert G._pred == {1: {0: keydict}, 0: {}} + + # too few in tuple + pytest.raises(nx.NetworkXError, G.add_edges_from, [(0,)]) + # too many in tuple + pytest.raises(nx.NetworkXError, G.add_edges_from, [(0, 1, 2, 3, 4)]) + # not a tuple + pytest.raises(TypeError, G.add_edges_from, [0]) + with pytest.raises(ValueError, match="None cannot be a node"): + G.add_edges_from([(None, 3), (3, 2)]) + + def test_remove_edge(self): + G = self.K3 + G.remove_edge(0, 1) + assert G._succ == { + 0: {2: {0: {}}}, + 1: {0: {0: {}}, 2: {0: {}}}, + 2: {0: {0: {}}, 1: {0: {}}}, + } + assert G._pred == { + 0: {1: {0: {}}, 2: {0: {}}}, + 1: {2: {0: {}}}, + 2: {0: {0: {}}, 1: {0: {}}}, + } + pytest.raises((KeyError, nx.NetworkXError), G.remove_edge, -1, 0) + pytest.raises((KeyError, nx.NetworkXError), G.remove_edge, 0, 2, key=1) + + def test_remove_multiedge(self): + G = self.K3 + G.add_edge(0, 1, key="parallel edge") + G.remove_edge(0, 1, key="parallel edge") + assert G._adj == { + 0: {1: {0: {}}, 2: {0: {}}}, + 1: {0: {0: {}}, 2: {0: {}}}, + 2: {0: {0: {}}, 1: {0: {}}}, + } + + assert G._succ == { + 0: {1: {0: {}}, 2: {0: {}}}, + 1: {0: {0: {}}, 2: {0: {}}}, + 2: {0: {0: {}}, 1: {0: {}}}, + } + + assert G._pred == { + 0: {1: {0: {}}, 2: {0: {}}}, + 1: {0: {0: {}}, 2: {0: {}}}, + 2: {0: {0: {}}, 1: {0: {}}}, + } + G.remove_edge(0, 1) + assert G._succ == { + 0: {2: {0: {}}}, + 1: {0: {0: {}}, 2: {0: {}}}, + 2: {0: {0: {}}, 1: {0: {}}}, + } + assert G._pred == { + 0: {1: {0: {}}, 2: {0: {}}}, + 1: {2: {0: {}}}, + 2: {0: {0: {}}, 1: {0: {}}}, + } + pytest.raises((KeyError, nx.NetworkXError), G.remove_edge, -1, 0) + + def test_remove_edges_from(self): + G = self.K3 + G.remove_edges_from([(0, 1)]) + assert G._succ == { + 0: {2: {0: {}}}, + 1: {0: {0: {}}, 2: {0: {}}}, + 2: {0: {0: {}}, 1: {0: {}}}, + } + assert G._pred == { + 0: {1: {0: {}}, 2: {0: {}}}, + 1: {2: {0: {}}}, + 2: {0: {0: {}}, 1: {0: {}}}, + } + G.remove_edges_from([(0, 0)]) # silent fail + + +class TestEdgeSubgraph(_TestMultiGraphEdgeSubgraph): + """Unit tests for the :meth:`MultiDiGraph.edge_subgraph` method.""" + + def setup_method(self): + # Create a quadruply-linked path graph on five nodes. + G = nx.MultiDiGraph() + nx.add_path(G, range(5)) + nx.add_path(G, range(5)) + nx.add_path(G, reversed(range(5))) + nx.add_path(G, reversed(range(5))) + # Add some node, edge, and graph attributes. + for i in range(5): + G.nodes[i]["name"] = f"node{i}" + G.adj[0][1][0]["name"] = "edge010" + G.adj[0][1][1]["name"] = "edge011" + G.adj[3][4][0]["name"] = "edge340" + G.adj[3][4][1]["name"] = "edge341" + G.graph["name"] = "graph" + # Get the subgraph induced by one of the first edges and one of + # the last edges. + self.G = G + self.H = G.edge_subgraph([(0, 1, 0), (3, 4, 1)]) + + +class CustomDictClass(UserDict): + pass + + +class MultiDiGraphSubClass(nx.MultiDiGraph): + node_dict_factory = CustomDictClass # type: ignore[assignment] + node_attr_dict_factory = CustomDictClass # type: ignore[assignment] + adjlist_outer_dict_factory = CustomDictClass # type: ignore[assignment] + adjlist_inner_dict_factory = CustomDictClass # type: ignore[assignment] + edge_key_dict_factory = CustomDictClass # type: ignore[assignment] + edge_attr_dict_factory = CustomDictClass # type: ignore[assignment] + graph_attr_dict_factory = CustomDictClass # type: ignore[assignment] + + +class TestMultiDiGraphSubclass(TestMultiDiGraph): + def setup_method(self): + self.Graph = MultiDiGraphSubClass + # build K3 + self.k3edges = [(0, 1), (0, 2), (1, 2)] + self.k3nodes = [0, 1, 2] + self.K3 = self.Graph() + self.K3._succ = self.K3.adjlist_outer_dict_factory( + { + 0: self.K3.adjlist_inner_dict_factory(), + 1: self.K3.adjlist_inner_dict_factory(), + 2: self.K3.adjlist_inner_dict_factory(), + } + ) + # K3._adj is synced with K3._succ + self.K3._pred = {0: {}, 1: {}, 2: {}} + for u in self.k3nodes: + for v in self.k3nodes: + if u == v: + continue + d = {0: {}} + self.K3._succ[u][v] = d + self.K3._pred[v][u] = d + self.K3._node = self.K3.node_dict_factory() + self.K3._node[0] = self.K3.node_attr_dict_factory() + self.K3._node[1] = self.K3.node_attr_dict_factory() + self.K3._node[2] = self.K3.node_attr_dict_factory() diff --git a/lib/python3.10/site-packages/networkx/classes/tests/test_multigraph.py b/lib/python3.10/site-packages/networkx/classes/tests/test_multigraph.py new file mode 100644 index 0000000000000000000000000000000000000000..cd912d1d7c33c056b3c9808221bf7b72cd10fcac --- /dev/null +++ b/lib/python3.10/site-packages/networkx/classes/tests/test_multigraph.py @@ -0,0 +1,528 @@ +from collections import UserDict + +import pytest + +import networkx as nx +from networkx.utils import edges_equal + +from .test_graph import BaseAttrGraphTester +from .test_graph import TestGraph as _TestGraph + + +class BaseMultiGraphTester(BaseAttrGraphTester): + def test_has_edge(self): + G = self.K3 + assert G.has_edge(0, 1) + assert not G.has_edge(0, -1) + assert G.has_edge(0, 1, 0) + assert not G.has_edge(0, 1, 1) + + def test_get_edge_data(self): + G = self.K3 + assert G.get_edge_data(0, 1) == {0: {}} + assert G[0][1] == {0: {}} + assert G[0][1][0] == {} + assert G.get_edge_data(10, 20) is None + assert G.get_edge_data(0, 1, 0) == {} + + def test_adjacency(self): + G = self.K3 + assert dict(G.adjacency()) == { + 0: {1: {0: {}}, 2: {0: {}}}, + 1: {0: {0: {}}, 2: {0: {}}}, + 2: {0: {0: {}}, 1: {0: {}}}, + } + + def deepcopy_edge_attr(self, H, G): + assert G[1][2][0]["foo"] == H[1][2][0]["foo"] + G[1][2][0]["foo"].append(1) + assert G[1][2][0]["foo"] != H[1][2][0]["foo"] + + def shallow_copy_edge_attr(self, H, G): + assert G[1][2][0]["foo"] == H[1][2][0]["foo"] + G[1][2][0]["foo"].append(1) + assert G[1][2][0]["foo"] == H[1][2][0]["foo"] + + def graphs_equal(self, H, G): + assert G._adj == H._adj + assert G._node == H._node + assert G.graph == H.graph + assert G.name == H.name + if not G.is_directed() and not H.is_directed(): + assert H._adj[1][2][0] is H._adj[2][1][0] + assert G._adj[1][2][0] is G._adj[2][1][0] + else: # at least one is directed + if not G.is_directed(): + G._pred = G._adj + G._succ = G._adj + if not H.is_directed(): + H._pred = H._adj + H._succ = H._adj + assert G._pred == H._pred + assert G._succ == H._succ + assert H._succ[1][2][0] is H._pred[2][1][0] + assert G._succ[1][2][0] is G._pred[2][1][0] + + def same_attrdict(self, H, G): + # same attrdict in the edgedata + old_foo = H[1][2][0]["foo"] + H.adj[1][2][0]["foo"] = "baz" + assert G._adj == H._adj + H.adj[1][2][0]["foo"] = old_foo + assert G._adj == H._adj + + old_foo = H.nodes[0]["foo"] + H.nodes[0]["foo"] = "baz" + assert G._node == H._node + H.nodes[0]["foo"] = old_foo + assert G._node == H._node + + def different_attrdict(self, H, G): + # used by graph_equal_but_different + old_foo = H[1][2][0]["foo"] + H.adj[1][2][0]["foo"] = "baz" + assert G._adj != H._adj + H.adj[1][2][0]["foo"] = old_foo + assert G._adj == H._adj + + old_foo = H.nodes[0]["foo"] + H.nodes[0]["foo"] = "baz" + assert G._node != H._node + H.nodes[0]["foo"] = old_foo + assert G._node == H._node + + def test_to_undirected(self): + G = self.K3 + self.add_attributes(G) + H = nx.MultiGraph(G) + self.is_shallow_copy(H, G) + H = G.to_undirected() + self.is_deepcopy(H, G) + + def test_to_directed(self): + G = self.K3 + self.add_attributes(G) + H = nx.MultiDiGraph(G) + self.is_shallow_copy(H, G) + H = G.to_directed() + self.is_deepcopy(H, G) + + def test_number_of_edges_selfloops(self): + G = self.K3 + G.add_edge(0, 0) + G.add_edge(0, 0) + G.add_edge(0, 0, key="parallel edge") + G.remove_edge(0, 0, key="parallel edge") + assert G.number_of_edges(0, 0) == 2 + G.remove_edge(0, 0) + assert G.number_of_edges(0, 0) == 1 + + def test_edge_lookup(self): + G = self.Graph() + G.add_edge(1, 2, foo="bar") + G.add_edge(1, 2, "key", foo="biz") + assert edges_equal(G.edges[1, 2, 0], {"foo": "bar"}) + assert edges_equal(G.edges[1, 2, "key"], {"foo": "biz"}) + + def test_edge_attr(self): + G = self.Graph() + G.add_edge(1, 2, key="k1", foo="bar") + G.add_edge(1, 2, key="k2", foo="baz") + assert isinstance(G.get_edge_data(1, 2), G.edge_key_dict_factory) + assert all( + isinstance(d, G.edge_attr_dict_factory) for u, v, d in G.edges(data=True) + ) + assert edges_equal( + G.edges(keys=True, data=True), + [(1, 2, "k1", {"foo": "bar"}), (1, 2, "k2", {"foo": "baz"})], + ) + assert edges_equal( + G.edges(keys=True, data="foo"), [(1, 2, "k1", "bar"), (1, 2, "k2", "baz")] + ) + + def test_edge_attr4(self): + G = self.Graph() + G.add_edge(1, 2, key=0, data=7, spam="bar", bar="foo") + assert edges_equal( + G.edges(data=True), [(1, 2, {"data": 7, "spam": "bar", "bar": "foo"})] + ) + G[1][2][0]["data"] = 10 # OK to set data like this + assert edges_equal( + G.edges(data=True), [(1, 2, {"data": 10, "spam": "bar", "bar": "foo"})] + ) + + G.adj[1][2][0]["data"] = 20 + assert edges_equal( + G.edges(data=True), [(1, 2, {"data": 20, "spam": "bar", "bar": "foo"})] + ) + G.edges[1, 2, 0]["data"] = 21 # another spelling, "edge" + assert edges_equal( + G.edges(data=True), [(1, 2, {"data": 21, "spam": "bar", "bar": "foo"})] + ) + G.adj[1][2][0]["listdata"] = [20, 200] + G.adj[1][2][0]["weight"] = 20 + assert edges_equal( + G.edges(data=True), + [ + ( + 1, + 2, + { + "data": 21, + "spam": "bar", + "bar": "foo", + "listdata": [20, 200], + "weight": 20, + }, + ) + ], + ) + + +class TestMultiGraph(BaseMultiGraphTester, _TestGraph): + def setup_method(self): + self.Graph = nx.MultiGraph + # build K3 + ed1, ed2, ed3 = ({0: {}}, {0: {}}, {0: {}}) + self.k3adj = {0: {1: ed1, 2: ed2}, 1: {0: ed1, 2: ed3}, 2: {0: ed2, 1: ed3}} + self.k3edges = [(0, 1), (0, 2), (1, 2)] + self.k3nodes = [0, 1, 2] + self.K3 = self.Graph() + self.K3._adj = self.k3adj + self.K3._node = {} + self.K3._node[0] = {} + self.K3._node[1] = {} + self.K3._node[2] = {} + + def test_data_input(self): + G = self.Graph({1: [2], 2: [1]}, name="test") + assert G.name == "test" + expected = [(1, {2: {0: {}}}), (2, {1: {0: {}}})] + assert sorted(G.adj.items()) == expected + + def test_data_multigraph_input(self): + # standard case with edge keys and edge data + edata0 = {"w": 200, "s": "foo"} + edata1 = {"w": 201, "s": "bar"} + keydict = {0: edata0, 1: edata1} + dododod = {"a": {"b": keydict}} + + multiple_edge = [("a", "b", 0, edata0), ("a", "b", 1, edata1)] + single_edge = [("a", "b", 0, keydict)] + + G = self.Graph(dododod, multigraph_input=True) + assert list(G.edges(keys=True, data=True)) == multiple_edge + G = self.Graph(dododod, multigraph_input=None) + assert list(G.edges(keys=True, data=True)) == multiple_edge + G = self.Graph(dododod, multigraph_input=False) + assert list(G.edges(keys=True, data=True)) == single_edge + + # test round-trip to_dict_of_dict and MultiGraph constructor + G = self.Graph(dododod, multigraph_input=True) + H = self.Graph(nx.to_dict_of_dicts(G)) + assert nx.is_isomorphic(G, H) is True # test that default is True + for mgi in [True, False]: + H = self.Graph(nx.to_dict_of_dicts(G), multigraph_input=mgi) + assert nx.is_isomorphic(G, H) == mgi + + # Set up cases for when incoming_graph_data is not multigraph_input + etraits = {"w": 200, "s": "foo"} + egraphics = {"color": "blue", "shape": "box"} + edata = {"traits": etraits, "graphics": egraphics} + dodod1 = {"a": {"b": edata}} + dodod2 = {"a": {"b": etraits}} + dodod3 = {"a": {"b": {"traits": etraits, "s": "foo"}}} + dol = {"a": ["b"]} + + multiple_edge = [("a", "b", "traits", etraits), ("a", "b", "graphics", egraphics)] + single_edge = [("a", "b", 0, {})] # type: ignore[var-annotated] + single_edge1 = [("a", "b", 0, edata)] + single_edge2 = [("a", "b", 0, etraits)] + single_edge3 = [("a", "b", 0, {"traits": etraits, "s": "foo"})] + + cases = [ # (dod, mgi, edges) + (dodod1, True, multiple_edge), + (dodod1, False, single_edge1), + (dodod2, False, single_edge2), + (dodod3, False, single_edge3), + (dol, False, single_edge), + ] + + @pytest.mark.parametrize("dod, mgi, edges", cases) + def test_non_multigraph_input(self, dod, mgi, edges): + G = self.Graph(dod, multigraph_input=mgi) + assert list(G.edges(keys=True, data=True)) == edges + G = nx.to_networkx_graph(dod, create_using=self.Graph, multigraph_input=mgi) + assert list(G.edges(keys=True, data=True)) == edges + + mgi_none_cases = [ + (dodod1, multiple_edge), + (dodod2, single_edge2), + (dodod3, single_edge3), + ] + + @pytest.mark.parametrize("dod, edges", mgi_none_cases) + def test_non_multigraph_input_mgi_none(self, dod, edges): + # test constructor without to_networkx_graph for mgi=None + G = self.Graph(dod) + assert list(G.edges(keys=True, data=True)) == edges + + raise_cases = [dodod2, dodod3, dol] + + @pytest.mark.parametrize("dod", raise_cases) + def test_non_multigraph_input_raise(self, dod): + # cases where NetworkXError is raised + pytest.raises(nx.NetworkXError, self.Graph, dod, multigraph_input=True) + pytest.raises( + nx.NetworkXError, + nx.to_networkx_graph, + dod, + create_using=self.Graph, + multigraph_input=True, + ) + + def test_getitem(self): + G = self.K3 + assert G[0] == {1: {0: {}}, 2: {0: {}}} + with pytest.raises(KeyError): + G.__getitem__("j") + with pytest.raises(TypeError): + G.__getitem__(["A"]) + + def test_remove_node(self): + G = self.K3 + G.remove_node(0) + assert G.adj == {1: {2: {0: {}}}, 2: {1: {0: {}}}} + with pytest.raises(nx.NetworkXError): + G.remove_node(-1) + + def test_add_edge(self): + G = self.Graph() + G.add_edge(0, 1) + assert G.adj == {0: {1: {0: {}}}, 1: {0: {0: {}}}} + G = self.Graph() + G.add_edge(*(0, 1)) + assert G.adj == {0: {1: {0: {}}}, 1: {0: {0: {}}}} + G = self.Graph() + with pytest.raises(ValueError): + G.add_edge(None, "anything") + + def test_add_edge_conflicting_key(self): + G = self.Graph() + G.add_edge(0, 1, key=1) + G.add_edge(0, 1) + assert G.number_of_edges() == 2 + G = self.Graph() + G.add_edges_from([(0, 1, 1, {})]) + G.add_edges_from([(0, 1)]) + assert G.number_of_edges() == 2 + + def test_add_edges_from(self): + G = self.Graph() + G.add_edges_from([(0, 1), (0, 1, {"weight": 3})]) + assert G.adj == { + 0: {1: {0: {}, 1: {"weight": 3}}}, + 1: {0: {0: {}, 1: {"weight": 3}}}, + } + G.add_edges_from([(0, 1), (0, 1, {"weight": 3})], weight=2) + assert G.adj == { + 0: {1: {0: {}, 1: {"weight": 3}, 2: {"weight": 2}, 3: {"weight": 3}}}, + 1: {0: {0: {}, 1: {"weight": 3}, 2: {"weight": 2}, 3: {"weight": 3}}}, + } + G = self.Graph() + edges = [ + (0, 1, {"weight": 3}), + (0, 1, (("weight", 2),)), + (0, 1, 5), + (0, 1, "s"), + ] + G.add_edges_from(edges) + keydict = {0: {"weight": 3}, 1: {"weight": 2}, 5: {}, "s": {}} + assert G._adj == {0: {1: keydict}, 1: {0: keydict}} + + # too few in tuple + with pytest.raises(nx.NetworkXError): + G.add_edges_from([(0,)]) + # too many in tuple + with pytest.raises(nx.NetworkXError): + G.add_edges_from([(0, 1, 2, 3, 4)]) + # not a tuple + with pytest.raises(TypeError): + G.add_edges_from([0]) + + def test_multigraph_add_edges_from_four_tuple_misordered(self): + """add_edges_from expects 4-tuples of the format (u, v, key, data_dict). + + Ensure 4-tuples of form (u, v, data_dict, key) raise exception. + """ + G = nx.MultiGraph() + with pytest.raises(TypeError): + # key/data values flipped in 4-tuple + G.add_edges_from([(0, 1, {"color": "red"}, 0)]) + + def test_remove_edge(self): + G = self.K3 + G.remove_edge(0, 1) + assert G.adj == {0: {2: {0: {}}}, 1: {2: {0: {}}}, 2: {0: {0: {}}, 1: {0: {}}}} + + with pytest.raises(nx.NetworkXError): + G.remove_edge(-1, 0) + with pytest.raises(nx.NetworkXError): + G.remove_edge(0, 2, key=1) + + def test_remove_edges_from(self): + G = self.K3.copy() + G.remove_edges_from([(0, 1)]) + kd = {0: {}} + assert G.adj == {0: {2: kd}, 1: {2: kd}, 2: {0: kd, 1: kd}} + G.remove_edges_from([(0, 0)]) # silent fail + self.K3.add_edge(0, 1) + G = self.K3.copy() + G.remove_edges_from(list(G.edges(data=True, keys=True))) + assert G.adj == {0: {}, 1: {}, 2: {}} + G = self.K3.copy() + G.remove_edges_from(list(G.edges(data=False, keys=True))) + assert G.adj == {0: {}, 1: {}, 2: {}} + G = self.K3.copy() + G.remove_edges_from(list(G.edges(data=False, keys=False))) + assert G.adj == {0: {}, 1: {}, 2: {}} + G = self.K3.copy() + G.remove_edges_from([(0, 1, 0), (0, 2, 0, {}), (1, 2)]) + assert G.adj == {0: {1: {1: {}}}, 1: {0: {1: {}}}, 2: {}} + + def test_remove_multiedge(self): + G = self.K3 + G.add_edge(0, 1, key="parallel edge") + G.remove_edge(0, 1, key="parallel edge") + assert G.adj == { + 0: {1: {0: {}}, 2: {0: {}}}, + 1: {0: {0: {}}, 2: {0: {}}}, + 2: {0: {0: {}}, 1: {0: {}}}, + } + G.remove_edge(0, 1) + kd = {0: {}} + assert G.adj == {0: {2: kd}, 1: {2: kd}, 2: {0: kd, 1: kd}} + with pytest.raises(nx.NetworkXError): + G.remove_edge(-1, 0) + + +class TestEdgeSubgraph: + """Unit tests for the :meth:`MultiGraph.edge_subgraph` method.""" + + def setup_method(self): + # Create a doubly-linked path graph on five nodes. + G = nx.MultiGraph() + nx.add_path(G, range(5)) + nx.add_path(G, range(5)) + # Add some node, edge, and graph attributes. + for i in range(5): + G.nodes[i]["name"] = f"node{i}" + G.adj[0][1][0]["name"] = "edge010" + G.adj[0][1][1]["name"] = "edge011" + G.adj[3][4][0]["name"] = "edge340" + G.adj[3][4][1]["name"] = "edge341" + G.graph["name"] = "graph" + # Get the subgraph induced by one of the first edges and one of + # the last edges. + self.G = G + self.H = G.edge_subgraph([(0, 1, 0), (3, 4, 1)]) + + def test_correct_nodes(self): + """Tests that the subgraph has the correct nodes.""" + assert [0, 1, 3, 4] == sorted(self.H.nodes()) + + def test_correct_edges(self): + """Tests that the subgraph has the correct edges.""" + assert [(0, 1, 0, "edge010"), (3, 4, 1, "edge341")] == sorted( + self.H.edges(keys=True, data="name") + ) + + def test_add_node(self): + """Tests that adding a node to the original graph does not + affect the nodes of the subgraph. + + """ + self.G.add_node(5) + assert [0, 1, 3, 4] == sorted(self.H.nodes()) + + def test_remove_node(self): + """Tests that removing a node in the original graph does + affect the nodes of the subgraph. + + """ + self.G.remove_node(0) + assert [1, 3, 4] == sorted(self.H.nodes()) + + def test_node_attr_dict(self): + """Tests that the node attribute dictionary of the two graphs is + the same object. + + """ + for v in self.H: + assert self.G.nodes[v] == self.H.nodes[v] + # Making a change to G should make a change in H and vice versa. + self.G.nodes[0]["name"] = "foo" + assert self.G.nodes[0] == self.H.nodes[0] + self.H.nodes[1]["name"] = "bar" + assert self.G.nodes[1] == self.H.nodes[1] + + def test_edge_attr_dict(self): + """Tests that the edge attribute dictionary of the two graphs is + the same object. + + """ + for u, v, k in self.H.edges(keys=True): + assert self.G._adj[u][v][k] == self.H._adj[u][v][k] + # Making a change to G should make a change in H and vice versa. + self.G._adj[0][1][0]["name"] = "foo" + assert self.G._adj[0][1][0]["name"] == self.H._adj[0][1][0]["name"] + self.H._adj[3][4][1]["name"] = "bar" + assert self.G._adj[3][4][1]["name"] == self.H._adj[3][4][1]["name"] + + def test_graph_attr_dict(self): + """Tests that the graph attribute dictionary of the two graphs + is the same object. + + """ + assert self.G.graph is self.H.graph + + +class CustomDictClass(UserDict): + pass + + +class MultiGraphSubClass(nx.MultiGraph): + node_dict_factory = CustomDictClass # type: ignore[assignment] + node_attr_dict_factory = CustomDictClass # type: ignore[assignment] + adjlist_outer_dict_factory = CustomDictClass # type: ignore[assignment] + adjlist_inner_dict_factory = CustomDictClass # type: ignore[assignment] + edge_key_dict_factory = CustomDictClass # type: ignore[assignment] + edge_attr_dict_factory = CustomDictClass # type: ignore[assignment] + graph_attr_dict_factory = CustomDictClass # type: ignore[assignment] + + +class TestMultiGraphSubclass(TestMultiGraph): + def setup_method(self): + self.Graph = MultiGraphSubClass + # build K3 + self.k3edges = [(0, 1), (0, 2), (1, 2)] + self.k3nodes = [0, 1, 2] + self.K3 = self.Graph() + self.K3._adj = self.K3.adjlist_outer_dict_factory( + { + 0: self.K3.adjlist_inner_dict_factory(), + 1: self.K3.adjlist_inner_dict_factory(), + 2: self.K3.adjlist_inner_dict_factory(), + } + ) + self.K3._pred = {0: {}, 1: {}, 2: {}} + for u in self.k3nodes: + for v in self.k3nodes: + if u != v: + d = {0: {}} + self.K3._adj[u][v] = d + self.K3._adj[v][u] = d + self.K3._node = self.K3.node_dict_factory() + self.K3._node[0] = self.K3.node_attr_dict_factory() + self.K3._node[1] = self.K3.node_attr_dict_factory() + self.K3._node[2] = self.K3.node_attr_dict_factory() diff --git a/lib/python3.10/site-packages/networkx/classes/tests/test_reportviews.py b/lib/python3.10/site-packages/networkx/classes/tests/test_reportviews.py new file mode 100644 index 0000000000000000000000000000000000000000..789c829f4b05fa71eff384a05ad071bc7fdebd9f --- /dev/null +++ b/lib/python3.10/site-packages/networkx/classes/tests/test_reportviews.py @@ -0,0 +1,1435 @@ +import pickle +from copy import deepcopy + +import pytest + +import networkx as nx +from networkx.classes import reportviews as rv +from networkx.classes.reportviews import NodeDataView + + +# Nodes +class TestNodeView: + @classmethod + def setup_class(cls): + cls.G = nx.path_graph(9) + cls.nv = cls.G.nodes # NodeView(G) + + def test_pickle(self): + import pickle + + nv = self.nv + pnv = pickle.loads(pickle.dumps(nv, -1)) + assert nv == pnv + assert nv.__slots__ == pnv.__slots__ + + def test_str(self): + assert str(self.nv) == "[0, 1, 2, 3, 4, 5, 6, 7, 8]" + + def test_repr(self): + assert repr(self.nv) == "NodeView((0, 1, 2, 3, 4, 5, 6, 7, 8))" + + def test_contains(self): + G = self.G.copy() + nv = G.nodes + assert 7 in nv + assert 9 not in nv + G.remove_node(7) + G.add_node(9) + assert 7 not in nv + assert 9 in nv + + def test_getitem(self): + G = self.G.copy() + nv = G.nodes + G.nodes[3]["foo"] = "bar" + assert nv[7] == {} + assert nv[3] == {"foo": "bar"} + # slicing + with pytest.raises(nx.NetworkXError): + G.nodes[0:5] + + def test_iter(self): + nv = self.nv + for i, n in enumerate(nv): + assert i == n + inv = iter(nv) + assert next(inv) == 0 + assert iter(nv) != nv + assert iter(inv) == inv + inv2 = iter(nv) + next(inv2) + assert list(inv) == list(inv2) + # odd case where NodeView calls NodeDataView with data=False + nnv = nv(data=False) + for i, n in enumerate(nnv): + assert i == n + + def test_call(self): + nodes = self.nv + assert nodes is nodes() + assert nodes is not nodes(data=True) + assert nodes is not nodes(data="weight") + + +class TestNodeDataView: + @classmethod + def setup_class(cls): + cls.G = nx.path_graph(9) + cls.nv = NodeDataView(cls.G) + cls.ndv = cls.G.nodes.data(True) + cls.nwv = cls.G.nodes.data("foo") + + def test_viewtype(self): + nv = self.G.nodes + ndvfalse = nv.data(False) + assert nv is ndvfalse + assert nv is not self.ndv + + def test_pickle(self): + import pickle + + nv = self.nv + pnv = pickle.loads(pickle.dumps(nv, -1)) + assert nv == pnv + assert nv.__slots__ == pnv.__slots__ + + def test_str(self): + msg = str([(n, {}) for n in range(9)]) + assert str(self.ndv) == msg + + def test_repr(self): + expected = "NodeDataView((0, 1, 2, 3, 4, 5, 6, 7, 8))" + assert repr(self.nv) == expected + expected = ( + "NodeDataView({0: {}, 1: {}, 2: {}, 3: {}, " + + "4: {}, 5: {}, 6: {}, 7: {}, 8: {}})" + ) + assert repr(self.ndv) == expected + expected = ( + "NodeDataView({0: None, 1: None, 2: None, 3: None, 4: None, " + + "5: None, 6: None, 7: None, 8: None}, data='foo')" + ) + assert repr(self.nwv) == expected + + def test_contains(self): + G = self.G.copy() + nv = G.nodes.data() + nwv = G.nodes.data("foo") + G.nodes[3]["foo"] = "bar" + assert (7, {}) in nv + assert (3, {"foo": "bar"}) in nv + assert (3, "bar") in nwv + assert (7, None) in nwv + # default + nwv_def = G.nodes(data="foo", default="biz") + assert (7, "biz") in nwv_def + assert (3, "bar") in nwv_def + + def test_getitem(self): + G = self.G.copy() + nv = G.nodes + G.nodes[3]["foo"] = "bar" + assert nv[3] == {"foo": "bar"} + # default + nwv_def = G.nodes(data="foo", default="biz") + assert nwv_def[7], "biz" + assert nwv_def[3] == "bar" + # slicing + with pytest.raises(nx.NetworkXError): + G.nodes.data()[0:5] + + def test_iter(self): + G = self.G.copy() + nv = G.nodes.data() + ndv = G.nodes.data(True) + nwv = G.nodes.data("foo") + for i, (n, d) in enumerate(nv): + assert i == n + assert d == {} + inv = iter(nv) + assert next(inv) == (0, {}) + G.nodes[3]["foo"] = "bar" + # default + for n, d in nv: + if n == 3: + assert d == {"foo": "bar"} + else: + assert d == {} + # data=True + for n, d in ndv: + if n == 3: + assert d == {"foo": "bar"} + else: + assert d == {} + # data='foo' + for n, d in nwv: + if n == 3: + assert d == "bar" + else: + assert d is None + # data='foo', default=1 + for n, d in G.nodes.data("foo", default=1): + if n == 3: + assert d == "bar" + else: + assert d == 1 + + +def test_nodedataview_unhashable(): + G = nx.path_graph(9) + G.nodes[3]["foo"] = "bar" + nvs = [G.nodes.data()] + nvs.append(G.nodes.data(True)) + H = G.copy() + H.nodes[4]["foo"] = {1, 2, 3} + nvs.append(H.nodes.data(True)) + # raise unhashable + for nv in nvs: + pytest.raises(TypeError, set, nv) + pytest.raises(TypeError, eval, "nv | nv", locals()) + # no raise... hashable + Gn = G.nodes.data(False) + set(Gn) + Gn | Gn + Gn = G.nodes.data("foo") + set(Gn) + Gn | Gn + + +class TestNodeViewSetOps: + @classmethod + def setup_class(cls): + cls.G = nx.path_graph(9) + cls.G.nodes[3]["foo"] = "bar" + cls.nv = cls.G.nodes + + def n_its(self, nodes): + return set(nodes) + + def test_len(self): + G = self.G.copy() + nv = G.nodes + assert len(nv) == 9 + G.remove_node(7) + assert len(nv) == 8 + G.add_node(9) + assert len(nv) == 9 + + def test_and(self): + # print("G & H nodes:", gnv & hnv) + nv = self.nv + some_nodes = self.n_its(range(5, 12)) + assert nv & some_nodes == self.n_its(range(5, 9)) + assert some_nodes & nv == self.n_its(range(5, 9)) + + def test_or(self): + # print("G | H nodes:", gnv | hnv) + nv = self.nv + some_nodes = self.n_its(range(5, 12)) + assert nv | some_nodes == self.n_its(range(12)) + assert some_nodes | nv == self.n_its(range(12)) + + def test_xor(self): + # print("G ^ H nodes:", gnv ^ hnv) + nv = self.nv + some_nodes = self.n_its(range(5, 12)) + nodes = {0, 1, 2, 3, 4, 9, 10, 11} + assert nv ^ some_nodes == self.n_its(nodes) + assert some_nodes ^ nv == self.n_its(nodes) + + def test_sub(self): + # print("G - H nodes:", gnv - hnv) + nv = self.nv + some_nodes = self.n_its(range(5, 12)) + assert nv - some_nodes == self.n_its(range(5)) + assert some_nodes - nv == self.n_its(range(9, 12)) + + +class TestNodeDataViewSetOps(TestNodeViewSetOps): + @classmethod + def setup_class(cls): + cls.G = nx.path_graph(9) + cls.G.nodes[3]["foo"] = "bar" + cls.nv = cls.G.nodes.data("foo") + + def n_its(self, nodes): + return {(node, "bar" if node == 3 else None) for node in nodes} + + +class TestNodeDataViewDefaultSetOps(TestNodeDataViewSetOps): + @classmethod + def setup_class(cls): + cls.G = nx.path_graph(9) + cls.G.nodes[3]["foo"] = "bar" + cls.nv = cls.G.nodes.data("foo", default=1) + + def n_its(self, nodes): + return {(node, "bar" if node == 3 else 1) for node in nodes} + + +# Edges Data View +class TestEdgeDataView: + @classmethod + def setup_class(cls): + cls.G = nx.path_graph(9) + cls.eview = nx.reportviews.EdgeView + + def test_pickle(self): + import pickle + + ev = self.eview(self.G)(data=True) + pev = pickle.loads(pickle.dumps(ev, -1)) + assert list(ev) == list(pev) + assert ev.__slots__ == pev.__slots__ + + def modify_edge(self, G, e, **kwds): + G._adj[e[0]][e[1]].update(kwds) + + def test_str(self): + ev = self.eview(self.G)(data=True) + rep = str([(n, n + 1, {}) for n in range(8)]) + assert str(ev) == rep + + def test_repr(self): + ev = self.eview(self.G)(data=True) + rep = ( + "EdgeDataView([(0, 1, {}), (1, 2, {}), " + + "(2, 3, {}), (3, 4, {}), " + + "(4, 5, {}), (5, 6, {}), " + + "(6, 7, {}), (7, 8, {})])" + ) + assert repr(ev) == rep + + def test_iterdata(self): + G = self.G.copy() + evr = self.eview(G) + ev = evr(data=True) + ev_def = evr(data="foo", default=1) + + for u, v, d in ev: + pass + assert d == {} + + for u, v, wt in ev_def: + pass + assert wt == 1 + + self.modify_edge(G, (2, 3), foo="bar") + for e in ev: + assert len(e) == 3 + if set(e[:2]) == {2, 3}: + assert e[2] == {"foo": "bar"} + checked = True + else: + assert e[2] == {} + assert checked + + for e in ev_def: + assert len(e) == 3 + if set(e[:2]) == {2, 3}: + assert e[2] == "bar" + checked_wt = True + else: + assert e[2] == 1 + assert checked_wt + + def test_iter(self): + evr = self.eview(self.G) + ev = evr() + for u, v in ev: + pass + iev = iter(ev) + assert next(iev) == (0, 1) + assert iter(ev) != ev + assert iter(iev) == iev + + def test_contains(self): + evr = self.eview(self.G) + ev = evr() + if self.G.is_directed(): + assert (1, 2) in ev and (2, 1) not in ev + else: + assert (1, 2) in ev and (2, 1) in ev + assert (1, 4) not in ev + assert (1, 90) not in ev + assert (90, 1) not in ev + + def test_contains_with_nbunch(self): + evr = self.eview(self.G) + ev = evr(nbunch=[0, 2]) + if self.G.is_directed(): + assert (0, 1) in ev + assert (1, 2) not in ev + assert (2, 3) in ev + else: + assert (0, 1) in ev + assert (1, 2) in ev + assert (2, 3) in ev + assert (3, 4) not in ev + assert (4, 5) not in ev + assert (5, 6) not in ev + assert (7, 8) not in ev + assert (8, 9) not in ev + + def test_len(self): + evr = self.eview(self.G) + ev = evr(data="foo") + assert len(ev) == 8 + assert len(evr(1)) == 2 + assert len(evr([1, 2, 3])) == 4 + + assert len(self.G.edges(1)) == 2 + assert len(self.G.edges()) == 8 + assert len(self.G.edges) == 8 + + H = self.G.copy() + H.add_edge(1, 1) + assert len(H.edges(1)) == 3 + assert len(H.edges()) == 9 + assert len(H.edges) == 9 + + +class TestOutEdgeDataView(TestEdgeDataView): + @classmethod + def setup_class(cls): + cls.G = nx.path_graph(9, create_using=nx.DiGraph()) + cls.eview = nx.reportviews.OutEdgeView + + def test_repr(self): + ev = self.eview(self.G)(data=True) + rep = ( + "OutEdgeDataView([(0, 1, {}), (1, 2, {}), " + + "(2, 3, {}), (3, 4, {}), " + + "(4, 5, {}), (5, 6, {}), " + + "(6, 7, {}), (7, 8, {})])" + ) + assert repr(ev) == rep + + def test_len(self): + evr = self.eview(self.G) + ev = evr(data="foo") + assert len(ev) == 8 + assert len(evr(1)) == 1 + assert len(evr([1, 2, 3])) == 3 + + assert len(self.G.edges(1)) == 1 + assert len(self.G.edges()) == 8 + assert len(self.G.edges) == 8 + + H = self.G.copy() + H.add_edge(1, 1) + assert len(H.edges(1)) == 2 + assert len(H.edges()) == 9 + assert len(H.edges) == 9 + + def test_contains_with_nbunch(self): + evr = self.eview(self.G) + ev = evr(nbunch=[0, 2]) + assert (0, 1) in ev + assert (1, 2) not in ev + assert (2, 3) in ev + assert (3, 4) not in ev + assert (4, 5) not in ev + assert (5, 6) not in ev + assert (7, 8) not in ev + assert (8, 9) not in ev + + +class TestInEdgeDataView(TestOutEdgeDataView): + @classmethod + def setup_class(cls): + cls.G = nx.path_graph(9, create_using=nx.DiGraph()) + cls.eview = nx.reportviews.InEdgeView + + def test_repr(self): + ev = self.eview(self.G)(data=True) + rep = ( + "InEdgeDataView([(0, 1, {}), (1, 2, {}), " + + "(2, 3, {}), (3, 4, {}), " + + "(4, 5, {}), (5, 6, {}), " + + "(6, 7, {}), (7, 8, {})])" + ) + assert repr(ev) == rep + + def test_contains_with_nbunch(self): + evr = self.eview(self.G) + ev = evr(nbunch=[0, 2]) + assert (0, 1) not in ev + assert (1, 2) in ev + assert (2, 3) not in ev + assert (3, 4) not in ev + assert (4, 5) not in ev + assert (5, 6) not in ev + assert (7, 8) not in ev + assert (8, 9) not in ev + + +class TestMultiEdgeDataView(TestEdgeDataView): + @classmethod + def setup_class(cls): + cls.G = nx.path_graph(9, create_using=nx.MultiGraph()) + cls.eview = nx.reportviews.MultiEdgeView + + def modify_edge(self, G, e, **kwds): + G._adj[e[0]][e[1]][0].update(kwds) + + def test_repr(self): + ev = self.eview(self.G)(data=True) + rep = ( + "MultiEdgeDataView([(0, 1, {}), (1, 2, {}), " + + "(2, 3, {}), (3, 4, {}), " + + "(4, 5, {}), (5, 6, {}), " + + "(6, 7, {}), (7, 8, {})])" + ) + assert repr(ev) == rep + + def test_contains_with_nbunch(self): + evr = self.eview(self.G) + ev = evr(nbunch=[0, 2]) + assert (0, 1) in ev + assert (1, 2) in ev + assert (2, 3) in ev + assert (3, 4) not in ev + assert (4, 5) not in ev + assert (5, 6) not in ev + assert (7, 8) not in ev + assert (8, 9) not in ev + + +class TestOutMultiEdgeDataView(TestOutEdgeDataView): + @classmethod + def setup_class(cls): + cls.G = nx.path_graph(9, create_using=nx.MultiDiGraph()) + cls.eview = nx.reportviews.OutMultiEdgeView + + def modify_edge(self, G, e, **kwds): + G._adj[e[0]][e[1]][0].update(kwds) + + def test_repr(self): + ev = self.eview(self.G)(data=True) + rep = ( + "OutMultiEdgeDataView([(0, 1, {}), (1, 2, {}), " + + "(2, 3, {}), (3, 4, {}), " + + "(4, 5, {}), (5, 6, {}), " + + "(6, 7, {}), (7, 8, {})])" + ) + assert repr(ev) == rep + + def test_contains_with_nbunch(self): + evr = self.eview(self.G) + ev = evr(nbunch=[0, 2]) + assert (0, 1) in ev + assert (1, 2) not in ev + assert (2, 3) in ev + assert (3, 4) not in ev + assert (4, 5) not in ev + assert (5, 6) not in ev + assert (7, 8) not in ev + assert (8, 9) not in ev + + +class TestInMultiEdgeDataView(TestOutMultiEdgeDataView): + @classmethod + def setup_class(cls): + cls.G = nx.path_graph(9, create_using=nx.MultiDiGraph()) + cls.eview = nx.reportviews.InMultiEdgeView + + def test_repr(self): + ev = self.eview(self.G)(data=True) + rep = ( + "InMultiEdgeDataView([(0, 1, {}), (1, 2, {}), " + + "(2, 3, {}), (3, 4, {}), " + + "(4, 5, {}), (5, 6, {}), " + + "(6, 7, {}), (7, 8, {})])" + ) + assert repr(ev) == rep + + def test_contains_with_nbunch(self): + evr = self.eview(self.G) + ev = evr(nbunch=[0, 2]) + assert (0, 1) not in ev + assert (1, 2) in ev + assert (2, 3) not in ev + assert (3, 4) not in ev + assert (4, 5) not in ev + assert (5, 6) not in ev + assert (7, 8) not in ev + assert (8, 9) not in ev + + +# Edge Views +class TestEdgeView: + @classmethod + def setup_class(cls): + cls.G = nx.path_graph(9) + cls.eview = nx.reportviews.EdgeView + + def test_pickle(self): + import pickle + + ev = self.eview(self.G) + pev = pickle.loads(pickle.dumps(ev, -1)) + assert ev == pev + assert ev.__slots__ == pev.__slots__ + + def modify_edge(self, G, e, **kwds): + G._adj[e[0]][e[1]].update(kwds) + + def test_str(self): + ev = self.eview(self.G) + rep = str([(n, n + 1) for n in range(8)]) + assert str(ev) == rep + + def test_repr(self): + ev = self.eview(self.G) + rep = ( + "EdgeView([(0, 1), (1, 2), (2, 3), (3, 4), " + + "(4, 5), (5, 6), (6, 7), (7, 8)])" + ) + assert repr(ev) == rep + + def test_getitem(self): + G = self.G.copy() + ev = G.edges + G.edges[0, 1]["foo"] = "bar" + assert ev[0, 1] == {"foo": "bar"} + + # slicing + with pytest.raises(nx.NetworkXError, match=".*does not support slicing"): + G.edges[0:5] + + # Invalid edge + with pytest.raises(KeyError, match=r".*edge.*is not in the graph."): + G.edges[0, 9] + + def test_call(self): + ev = self.eview(self.G) + assert id(ev) == id(ev()) + assert id(ev) == id(ev(data=False)) + assert id(ev) != id(ev(data=True)) + assert id(ev) != id(ev(nbunch=1)) + + def test_data(self): + ev = self.eview(self.G) + assert id(ev) != id(ev.data()) + assert id(ev) == id(ev.data(data=False)) + assert id(ev) != id(ev.data(data=True)) + assert id(ev) != id(ev.data(nbunch=1)) + + def test_iter(self): + ev = self.eview(self.G) + for u, v in ev: + pass + iev = iter(ev) + assert next(iev) == (0, 1) + assert iter(ev) != ev + assert iter(iev) == iev + + def test_contains(self): + ev = self.eview(self.G) + edv = ev() + if self.G.is_directed(): + assert (1, 2) in ev and (2, 1) not in ev + assert (1, 2) in edv and (2, 1) not in edv + else: + assert (1, 2) in ev and (2, 1) in ev + assert (1, 2) in edv and (2, 1) in edv + assert (1, 4) not in ev + assert (1, 4) not in edv + # edge not in graph + assert (1, 90) not in ev + assert (90, 1) not in ev + assert (1, 90) not in edv + assert (90, 1) not in edv + + def test_contains_with_nbunch(self): + ev = self.eview(self.G) + evn = ev(nbunch=[0, 2]) + assert (0, 1) in evn + assert (1, 2) in evn + assert (2, 3) in evn + assert (3, 4) not in evn + assert (4, 5) not in evn + assert (5, 6) not in evn + assert (7, 8) not in evn + assert (8, 9) not in evn + + def test_len(self): + ev = self.eview(self.G) + num_ed = 9 if self.G.is_multigraph() else 8 + assert len(ev) == num_ed + + H = self.G.copy() + H.add_edge(1, 1) + assert len(H.edges(1)) == 3 + H.is_multigraph() - H.is_directed() + assert len(H.edges()) == num_ed + 1 + assert len(H.edges) == num_ed + 1 + + def test_and(self): + # print("G & H edges:", gnv & hnv) + ev = self.eview(self.G) + some_edges = {(0, 1), (1, 0), (0, 2)} + if self.G.is_directed(): + assert some_edges & ev, {(0, 1)} + assert ev & some_edges, {(0, 1)} + else: + assert ev & some_edges == {(0, 1), (1, 0)} + assert some_edges & ev == {(0, 1), (1, 0)} + return + + def test_or(self): + # print("G | H edges:", gnv | hnv) + ev = self.eview(self.G) + some_edges = {(0, 1), (1, 0), (0, 2)} + result1 = {(n, n + 1) for n in range(8)} + result1.update(some_edges) + result2 = {(n + 1, n) for n in range(8)} + result2.update(some_edges) + assert (ev | some_edges) in (result1, result2) + assert (some_edges | ev) in (result1, result2) + + def test_xor(self): + # print("G ^ H edges:", gnv ^ hnv) + ev = self.eview(self.G) + some_edges = {(0, 1), (1, 0), (0, 2)} + if self.G.is_directed(): + result = {(n, n + 1) for n in range(1, 8)} + result.update({(1, 0), (0, 2)}) + assert ev ^ some_edges == result + else: + result = {(n, n + 1) for n in range(1, 8)} + result.update({(0, 2)}) + assert ev ^ some_edges == result + return + + def test_sub(self): + # print("G - H edges:", gnv - hnv) + ev = self.eview(self.G) + some_edges = {(0, 1), (1, 0), (0, 2)} + result = {(n, n + 1) for n in range(8)} + result.remove((0, 1)) + assert ev - some_edges, result + + +class TestOutEdgeView(TestEdgeView): + @classmethod + def setup_class(cls): + cls.G = nx.path_graph(9, nx.DiGraph()) + cls.eview = nx.reportviews.OutEdgeView + + def test_repr(self): + ev = self.eview(self.G) + rep = ( + "OutEdgeView([(0, 1), (1, 2), (2, 3), (3, 4), " + + "(4, 5), (5, 6), (6, 7), (7, 8)])" + ) + assert repr(ev) == rep + + def test_contains_with_nbunch(self): + ev = self.eview(self.G) + evn = ev(nbunch=[0, 2]) + assert (0, 1) in evn + assert (1, 2) not in evn + assert (2, 3) in evn + assert (3, 4) not in evn + assert (4, 5) not in evn + assert (5, 6) not in evn + assert (7, 8) not in evn + assert (8, 9) not in evn + + +class TestInEdgeView(TestEdgeView): + @classmethod + def setup_class(cls): + cls.G = nx.path_graph(9, nx.DiGraph()) + cls.eview = nx.reportviews.InEdgeView + + def test_repr(self): + ev = self.eview(self.G) + rep = ( + "InEdgeView([(0, 1), (1, 2), (2, 3), (3, 4), " + + "(4, 5), (5, 6), (6, 7), (7, 8)])" + ) + assert repr(ev) == rep + + def test_contains_with_nbunch(self): + ev = self.eview(self.G) + evn = ev(nbunch=[0, 2]) + assert (0, 1) not in evn + assert (1, 2) in evn + assert (2, 3) not in evn + assert (3, 4) not in evn + assert (4, 5) not in evn + assert (5, 6) not in evn + assert (7, 8) not in evn + assert (8, 9) not in evn + + +class TestMultiEdgeView(TestEdgeView): + @classmethod + def setup_class(cls): + cls.G = nx.path_graph(9, nx.MultiGraph()) + cls.G.add_edge(1, 2, key=3, foo="bar") + cls.eview = nx.reportviews.MultiEdgeView + + def modify_edge(self, G, e, **kwds): + if len(e) == 2: + e = e + (0,) + G._adj[e[0]][e[1]][e[2]].update(kwds) + + def test_str(self): + ev = self.eview(self.G) + replist = [(n, n + 1, 0) for n in range(8)] + replist.insert(2, (1, 2, 3)) + rep = str(replist) + assert str(ev) == rep + + def test_getitem(self): + G = self.G.copy() + ev = G.edges + G.edges[0, 1, 0]["foo"] = "bar" + assert ev[0, 1, 0] == {"foo": "bar"} + + # slicing + with pytest.raises(nx.NetworkXError): + G.edges[0:5] + + def test_repr(self): + ev = self.eview(self.G) + rep = ( + "MultiEdgeView([(0, 1, 0), (1, 2, 0), (1, 2, 3), (2, 3, 0), " + + "(3, 4, 0), (4, 5, 0), (5, 6, 0), (6, 7, 0), (7, 8, 0)])" + ) + assert repr(ev) == rep + + def test_call(self): + ev = self.eview(self.G) + assert id(ev) == id(ev(keys=True)) + assert id(ev) == id(ev(data=False, keys=True)) + assert id(ev) != id(ev(keys=False)) + assert id(ev) != id(ev(data=True)) + assert id(ev) != id(ev(nbunch=1)) + + def test_data(self): + ev = self.eview(self.G) + assert id(ev) != id(ev.data()) + assert id(ev) == id(ev.data(data=False, keys=True)) + assert id(ev) != id(ev.data(keys=False)) + assert id(ev) != id(ev.data(data=True)) + assert id(ev) != id(ev.data(nbunch=1)) + + def test_iter(self): + ev = self.eview(self.G) + for u, v, k in ev: + pass + iev = iter(ev) + assert next(iev) == (0, 1, 0) + assert iter(ev) != ev + assert iter(iev) == iev + + def test_iterkeys(self): + G = self.G + evr = self.eview(G) + ev = evr(keys=True) + for u, v, k in ev: + pass + assert k == 0 + ev = evr(keys=True, data="foo", default=1) + for u, v, k, wt in ev: + pass + assert wt == 1 + + self.modify_edge(G, (2, 3, 0), foo="bar") + ev = evr(keys=True, data=True) + for e in ev: + assert len(e) == 4 + print("edge:", e) + if set(e[:2]) == {2, 3}: + print(self.G._adj[2][3]) + assert e[2] == 0 + assert e[3] == {"foo": "bar"} + checked = True + elif set(e[:3]) == {1, 2, 3}: + assert e[2] == 3 + assert e[3] == {"foo": "bar"} + checked_multi = True + else: + assert e[2] == 0 + assert e[3] == {} + assert checked + assert checked_multi + ev = evr(keys=True, data="foo", default=1) + for e in ev: + if set(e[:2]) == {1, 2} and e[2] == 3: + assert e[3] == "bar" + if set(e[:2]) == {1, 2} and e[2] == 0: + assert e[3] == 1 + if set(e[:2]) == {2, 3}: + assert e[2] == 0 + assert e[3] == "bar" + assert len(e) == 4 + checked_wt = True + assert checked_wt + ev = evr(keys=True) + for e in ev: + assert len(e) == 3 + elist = sorted([(i, i + 1, 0) for i in range(8)] + [(1, 2, 3)]) + assert sorted(ev) == elist + # test that the keyword arguments are passed correctly + ev = evr((1, 2), "foo", keys=True, default=1) + with pytest.raises(TypeError): + evr((1, 2), "foo", True, 1) + with pytest.raises(TypeError): + evr((1, 2), "foo", True, default=1) + for e in ev: + if set(e[:2]) == {1, 2}: + assert e[2] in {0, 3} + if e[2] == 3: + assert e[3] == "bar" + else: # e[2] == 0 + assert e[3] == 1 + if G.is_directed(): + assert len(list(ev)) == 3 + else: + assert len(list(ev)) == 4 + + def test_or(self): + # print("G | H edges:", gnv | hnv) + ev = self.eview(self.G) + some_edges = {(0, 1, 0), (1, 0, 0), (0, 2, 0)} + result = {(n, n + 1, 0) for n in range(8)} + result.update(some_edges) + result.update({(1, 2, 3)}) + assert ev | some_edges == result + assert some_edges | ev == result + + def test_sub(self): + # print("G - H edges:", gnv - hnv) + ev = self.eview(self.G) + some_edges = {(0, 1, 0), (1, 0, 0), (0, 2, 0)} + result = {(n, n + 1, 0) for n in range(8)} + result.remove((0, 1, 0)) + result.update({(1, 2, 3)}) + assert ev - some_edges, result + assert some_edges - ev, result + + def test_xor(self): + # print("G ^ H edges:", gnv ^ hnv) + ev = self.eview(self.G) + some_edges = {(0, 1, 0), (1, 0, 0), (0, 2, 0)} + if self.G.is_directed(): + result = {(n, n + 1, 0) for n in range(1, 8)} + result.update({(1, 0, 0), (0, 2, 0), (1, 2, 3)}) + assert ev ^ some_edges == result + assert some_edges ^ ev == result + else: + result = {(n, n + 1, 0) for n in range(1, 8)} + result.update({(0, 2, 0), (1, 2, 3)}) + assert ev ^ some_edges == result + assert some_edges ^ ev == result + + def test_and(self): + # print("G & H edges:", gnv & hnv) + ev = self.eview(self.G) + some_edges = {(0, 1, 0), (1, 0, 0), (0, 2, 0)} + if self.G.is_directed(): + assert ev & some_edges == {(0, 1, 0)} + assert some_edges & ev == {(0, 1, 0)} + else: + assert ev & some_edges == {(0, 1, 0), (1, 0, 0)} + assert some_edges & ev == {(0, 1, 0), (1, 0, 0)} + + def test_contains_with_nbunch(self): + ev = self.eview(self.G) + evn = ev(nbunch=[0, 2]) + assert (0, 1) in evn + assert (1, 2) in evn + assert (2, 3) in evn + assert (3, 4) not in evn + assert (4, 5) not in evn + assert (5, 6) not in evn + assert (7, 8) not in evn + assert (8, 9) not in evn + + +class TestOutMultiEdgeView(TestMultiEdgeView): + @classmethod + def setup_class(cls): + cls.G = nx.path_graph(9, nx.MultiDiGraph()) + cls.G.add_edge(1, 2, key=3, foo="bar") + cls.eview = nx.reportviews.OutMultiEdgeView + + def modify_edge(self, G, e, **kwds): + if len(e) == 2: + e = e + (0,) + G._adj[e[0]][e[1]][e[2]].update(kwds) + + def test_repr(self): + ev = self.eview(self.G) + rep = ( + "OutMultiEdgeView([(0, 1, 0), (1, 2, 0), (1, 2, 3), (2, 3, 0)," + + " (3, 4, 0), (4, 5, 0), (5, 6, 0), (6, 7, 0), (7, 8, 0)])" + ) + assert repr(ev) == rep + + def test_contains_with_nbunch(self): + ev = self.eview(self.G) + evn = ev(nbunch=[0, 2]) + assert (0, 1) in evn + assert (1, 2) not in evn + assert (2, 3) in evn + assert (3, 4) not in evn + assert (4, 5) not in evn + assert (5, 6) not in evn + assert (7, 8) not in evn + assert (8, 9) not in evn + + +class TestInMultiEdgeView(TestMultiEdgeView): + @classmethod + def setup_class(cls): + cls.G = nx.path_graph(9, nx.MultiDiGraph()) + cls.G.add_edge(1, 2, key=3, foo="bar") + cls.eview = nx.reportviews.InMultiEdgeView + + def modify_edge(self, G, e, **kwds): + if len(e) == 2: + e = e + (0,) + G._adj[e[0]][e[1]][e[2]].update(kwds) + + def test_repr(self): + ev = self.eview(self.G) + rep = ( + "InMultiEdgeView([(0, 1, 0), (1, 2, 0), (1, 2, 3), (2, 3, 0), " + + "(3, 4, 0), (4, 5, 0), (5, 6, 0), (6, 7, 0), (7, 8, 0)])" + ) + assert repr(ev) == rep + + def test_contains_with_nbunch(self): + ev = self.eview(self.G) + evn = ev(nbunch=[0, 2]) + assert (0, 1) not in evn + assert (1, 2) in evn + assert (2, 3) not in evn + assert (3, 4) not in evn + assert (4, 5) not in evn + assert (5, 6) not in evn + assert (7, 8) not in evn + assert (8, 9) not in evn + + +# Degrees +class TestDegreeView: + GRAPH = nx.Graph + dview = nx.reportviews.DegreeView + + @classmethod + def setup_class(cls): + cls.G = nx.path_graph(6, cls.GRAPH()) + cls.G.add_edge(1, 3, foo=2) + cls.G.add_edge(1, 3, foo=3) + + def test_pickle(self): + import pickle + + deg = self.G.degree + pdeg = pickle.loads(pickle.dumps(deg, -1)) + assert dict(deg) == dict(pdeg) + + def test_str(self): + dv = self.dview(self.G) + rep = str([(0, 1), (1, 3), (2, 2), (3, 3), (4, 2), (5, 1)]) + assert str(dv) == rep + dv = self.G.degree() + assert str(dv) == rep + + def test_repr(self): + dv = self.dview(self.G) + rep = "DegreeView({0: 1, 1: 3, 2: 2, 3: 3, 4: 2, 5: 1})" + assert repr(dv) == rep + + def test_iter(self): + dv = self.dview(self.G) + for n, d in dv: + pass + idv = iter(dv) + assert iter(dv) != dv + assert iter(idv) == idv + assert next(idv) == (0, dv[0]) + assert next(idv) == (1, dv[1]) + # weighted + dv = self.dview(self.G, weight="foo") + for n, d in dv: + pass + idv = iter(dv) + assert iter(dv) != dv + assert iter(idv) == idv + assert next(idv) == (0, dv[0]) + assert next(idv) == (1, dv[1]) + + def test_nbunch(self): + dv = self.dview(self.G) + dvn = dv(0) + assert dvn == 1 + dvn = dv([2, 3]) + assert sorted(dvn) == [(2, 2), (3, 3)] + + def test_getitem(self): + dv = self.dview(self.G) + assert dv[0] == 1 + assert dv[1] == 3 + assert dv[2] == 2 + assert dv[3] == 3 + dv = self.dview(self.G, weight="foo") + assert dv[0] == 1 + assert dv[1] == 5 + assert dv[2] == 2 + assert dv[3] == 5 + + def test_weight(self): + dv = self.dview(self.G) + dvw = dv(0, weight="foo") + assert dvw == 1 + dvw = dv(1, weight="foo") + assert dvw == 5 + dvw = dv([2, 3], weight="foo") + assert sorted(dvw) == [(2, 2), (3, 5)] + dvd = dict(dv(weight="foo")) + assert dvd[0] == 1 + assert dvd[1] == 5 + assert dvd[2] == 2 + assert dvd[3] == 5 + + def test_len(self): + dv = self.dview(self.G) + assert len(dv) == 6 + + +class TestDiDegreeView(TestDegreeView): + GRAPH = nx.DiGraph + dview = nx.reportviews.DiDegreeView + + def test_repr(self): + dv = self.G.degree() + rep = "DiDegreeView({0: 1, 1: 3, 2: 2, 3: 3, 4: 2, 5: 1})" + assert repr(dv) == rep + + +class TestOutDegreeView(TestDegreeView): + GRAPH = nx.DiGraph + dview = nx.reportviews.OutDegreeView + + def test_str(self): + dv = self.dview(self.G) + rep = str([(0, 1), (1, 2), (2, 1), (3, 1), (4, 1), (5, 0)]) + assert str(dv) == rep + dv = self.G.out_degree() + assert str(dv) == rep + + def test_repr(self): + dv = self.G.out_degree() + rep = "OutDegreeView({0: 1, 1: 2, 2: 1, 3: 1, 4: 1, 5: 0})" + assert repr(dv) == rep + + def test_nbunch(self): + dv = self.dview(self.G) + dvn = dv(0) + assert dvn == 1 + dvn = dv([2, 3]) + assert sorted(dvn) == [(2, 1), (3, 1)] + + def test_getitem(self): + dv = self.dview(self.G) + assert dv[0] == 1 + assert dv[1] == 2 + assert dv[2] == 1 + assert dv[3] == 1 + dv = self.dview(self.G, weight="foo") + assert dv[0] == 1 + assert dv[1] == 4 + assert dv[2] == 1 + assert dv[3] == 1 + + def test_weight(self): + dv = self.dview(self.G) + dvw = dv(0, weight="foo") + assert dvw == 1 + dvw = dv(1, weight="foo") + assert dvw == 4 + dvw = dv([2, 3], weight="foo") + assert sorted(dvw) == [(2, 1), (3, 1)] + dvd = dict(dv(weight="foo")) + assert dvd[0] == 1 + assert dvd[1] == 4 + assert dvd[2] == 1 + assert dvd[3] == 1 + + +class TestInDegreeView(TestDegreeView): + GRAPH = nx.DiGraph + dview = nx.reportviews.InDegreeView + + def test_str(self): + dv = self.dview(self.G) + rep = str([(0, 0), (1, 1), (2, 1), (3, 2), (4, 1), (5, 1)]) + assert str(dv) == rep + dv = self.G.in_degree() + assert str(dv) == rep + + def test_repr(self): + dv = self.G.in_degree() + rep = "InDegreeView({0: 0, 1: 1, 2: 1, 3: 2, 4: 1, 5: 1})" + assert repr(dv) == rep + + def test_nbunch(self): + dv = self.dview(self.G) + dvn = dv(0) + assert dvn == 0 + dvn = dv([2, 3]) + assert sorted(dvn) == [(2, 1), (3, 2)] + + def test_getitem(self): + dv = self.dview(self.G) + assert dv[0] == 0 + assert dv[1] == 1 + assert dv[2] == 1 + assert dv[3] == 2 + dv = self.dview(self.G, weight="foo") + assert dv[0] == 0 + assert dv[1] == 1 + assert dv[2] == 1 + assert dv[3] == 4 + + def test_weight(self): + dv = self.dview(self.G) + dvw = dv(0, weight="foo") + assert dvw == 0 + dvw = dv(1, weight="foo") + assert dvw == 1 + dvw = dv([2, 3], weight="foo") + assert sorted(dvw) == [(2, 1), (3, 4)] + dvd = dict(dv(weight="foo")) + assert dvd[0] == 0 + assert dvd[1] == 1 + assert dvd[2] == 1 + assert dvd[3] == 4 + + +class TestMultiDegreeView(TestDegreeView): + GRAPH = nx.MultiGraph + dview = nx.reportviews.MultiDegreeView + + def test_str(self): + dv = self.dview(self.G) + rep = str([(0, 1), (1, 4), (2, 2), (3, 4), (4, 2), (5, 1)]) + assert str(dv) == rep + dv = self.G.degree() + assert str(dv) == rep + + def test_repr(self): + dv = self.G.degree() + rep = "MultiDegreeView({0: 1, 1: 4, 2: 2, 3: 4, 4: 2, 5: 1})" + assert repr(dv) == rep + + def test_nbunch(self): + dv = self.dview(self.G) + dvn = dv(0) + assert dvn == 1 + dvn = dv([2, 3]) + assert sorted(dvn) == [(2, 2), (3, 4)] + + def test_getitem(self): + dv = self.dview(self.G) + assert dv[0] == 1 + assert dv[1] == 4 + assert dv[2] == 2 + assert dv[3] == 4 + dv = self.dview(self.G, weight="foo") + assert dv[0] == 1 + assert dv[1] == 7 + assert dv[2] == 2 + assert dv[3] == 7 + + def test_weight(self): + dv = self.dview(self.G) + dvw = dv(0, weight="foo") + assert dvw == 1 + dvw = dv(1, weight="foo") + assert dvw == 7 + dvw = dv([2, 3], weight="foo") + assert sorted(dvw) == [(2, 2), (3, 7)] + dvd = dict(dv(weight="foo")) + assert dvd[0] == 1 + assert dvd[1] == 7 + assert dvd[2] == 2 + assert dvd[3] == 7 + + +class TestDiMultiDegreeView(TestMultiDegreeView): + GRAPH = nx.MultiDiGraph + dview = nx.reportviews.DiMultiDegreeView + + def test_repr(self): + dv = self.G.degree() + rep = "DiMultiDegreeView({0: 1, 1: 4, 2: 2, 3: 4, 4: 2, 5: 1})" + assert repr(dv) == rep + + +class TestOutMultiDegreeView(TestDegreeView): + GRAPH = nx.MultiDiGraph + dview = nx.reportviews.OutMultiDegreeView + + def test_str(self): + dv = self.dview(self.G) + rep = str([(0, 1), (1, 3), (2, 1), (3, 1), (4, 1), (5, 0)]) + assert str(dv) == rep + dv = self.G.out_degree() + assert str(dv) == rep + + def test_repr(self): + dv = self.G.out_degree() + rep = "OutMultiDegreeView({0: 1, 1: 3, 2: 1, 3: 1, 4: 1, 5: 0})" + assert repr(dv) == rep + + def test_nbunch(self): + dv = self.dview(self.G) + dvn = dv(0) + assert dvn == 1 + dvn = dv([2, 3]) + assert sorted(dvn) == [(2, 1), (3, 1)] + + def test_getitem(self): + dv = self.dview(self.G) + assert dv[0] == 1 + assert dv[1] == 3 + assert dv[2] == 1 + assert dv[3] == 1 + dv = self.dview(self.G, weight="foo") + assert dv[0] == 1 + assert dv[1] == 6 + assert dv[2] == 1 + assert dv[3] == 1 + + def test_weight(self): + dv = self.dview(self.G) + dvw = dv(0, weight="foo") + assert dvw == 1 + dvw = dv(1, weight="foo") + assert dvw == 6 + dvw = dv([2, 3], weight="foo") + assert sorted(dvw) == [(2, 1), (3, 1)] + dvd = dict(dv(weight="foo")) + assert dvd[0] == 1 + assert dvd[1] == 6 + assert dvd[2] == 1 + assert dvd[3] == 1 + + +class TestInMultiDegreeView(TestDegreeView): + GRAPH = nx.MultiDiGraph + dview = nx.reportviews.InMultiDegreeView + + def test_str(self): + dv = self.dview(self.G) + rep = str([(0, 0), (1, 1), (2, 1), (3, 3), (4, 1), (5, 1)]) + assert str(dv) == rep + dv = self.G.in_degree() + assert str(dv) == rep + + def test_repr(self): + dv = self.G.in_degree() + rep = "InMultiDegreeView({0: 0, 1: 1, 2: 1, 3: 3, 4: 1, 5: 1})" + assert repr(dv) == rep + + def test_nbunch(self): + dv = self.dview(self.G) + dvn = dv(0) + assert dvn == 0 + dvn = dv([2, 3]) + assert sorted(dvn) == [(2, 1), (3, 3)] + + def test_getitem(self): + dv = self.dview(self.G) + assert dv[0] == 0 + assert dv[1] == 1 + assert dv[2] == 1 + assert dv[3] == 3 + dv = self.dview(self.G, weight="foo") + assert dv[0] == 0 + assert dv[1] == 1 + assert dv[2] == 1 + assert dv[3] == 6 + + def test_weight(self): + dv = self.dview(self.G) + dvw = dv(0, weight="foo") + assert dvw == 0 + dvw = dv(1, weight="foo") + assert dvw == 1 + dvw = dv([2, 3], weight="foo") + assert sorted(dvw) == [(2, 1), (3, 6)] + dvd = dict(dv(weight="foo")) + assert dvd[0] == 0 + assert dvd[1] == 1 + assert dvd[2] == 1 + assert dvd[3] == 6 + + +@pytest.mark.parametrize( + ("reportview", "err_msg_terms"), + ( + (rv.NodeView, "list(G.nodes"), + (rv.NodeDataView, "list(G.nodes.data"), + (rv.EdgeView, "list(G.edges"), + # Directed EdgeViews + (rv.InEdgeView, "list(G.in_edges"), + (rv.OutEdgeView, "list(G.edges"), + # Multi EdgeViews + (rv.MultiEdgeView, "list(G.edges"), + (rv.InMultiEdgeView, "list(G.in_edges"), + (rv.OutMultiEdgeView, "list(G.edges"), + ), +) +def test_slicing_reportviews(reportview, err_msg_terms): + G = nx.complete_graph(3) + view = reportview(G) + with pytest.raises(nx.NetworkXError) as exc: + view[0:2] + errmsg = str(exc.value) + assert type(view).__name__ in errmsg + assert err_msg_terms in errmsg + + +@pytest.mark.parametrize( + "graph", [nx.Graph, nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph] +) +def test_cache_dict_get_set_state(graph): + G = nx.path_graph(5, graph()) + G.nodes, G.edges, G.adj, G.degree + if G.is_directed(): + G.pred, G.succ, G.in_edges, G.out_edges, G.in_degree, G.out_degree + cached_dict = G.__dict__ + assert "nodes" in cached_dict + assert "edges" in cached_dict + assert "adj" in cached_dict + assert "degree" in cached_dict + if G.is_directed(): + assert "pred" in cached_dict + assert "succ" in cached_dict + assert "in_edges" in cached_dict + assert "out_edges" in cached_dict + assert "in_degree" in cached_dict + assert "out_degree" in cached_dict + + # Raises error if the cached properties and views do not work + pickle.loads(pickle.dumps(G, -1)) + deepcopy(G) + + +def test_edge_views_inherit_from_EdgeViewABC(): + all_edge_view_classes = (v for v in dir(nx.reportviews) if "Edge" in v) + for eview_class in all_edge_view_classes: + assert issubclass( + getattr(nx.reportviews, eview_class), nx.reportviews.EdgeViewABC + ) diff --git a/lib/python3.10/site-packages/networkx/classes/tests/test_special.py b/lib/python3.10/site-packages/networkx/classes/tests/test_special.py new file mode 100644 index 0000000000000000000000000000000000000000..1fa79605b484f57ed6cbd17762b21a23406ee006 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/classes/tests/test_special.py @@ -0,0 +1,131 @@ +import networkx as nx + +from .test_digraph import BaseDiGraphTester +from .test_digraph import TestDiGraph as _TestDiGraph +from .test_graph import BaseGraphTester +from .test_graph import TestGraph as _TestGraph +from .test_multidigraph import TestMultiDiGraph as _TestMultiDiGraph +from .test_multigraph import TestMultiGraph as _TestMultiGraph + + +def test_factories(): + class mydict1(dict): + pass + + class mydict2(dict): + pass + + class mydict3(dict): + pass + + class mydict4(dict): + pass + + class mydict5(dict): + pass + + for Graph in (nx.Graph, nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph): + # print("testing class: ", Graph.__name__) + class MyGraph(Graph): + node_dict_factory = mydict1 + adjlist_outer_dict_factory = mydict2 + adjlist_inner_dict_factory = mydict3 + edge_key_dict_factory = mydict4 + edge_attr_dict_factory = mydict5 + + G = MyGraph() + assert isinstance(G._node, mydict1) + assert isinstance(G._adj, mydict2) + G.add_node(1) + assert isinstance(G._adj[1], mydict3) + if G.is_directed(): + assert isinstance(G._pred, mydict2) + assert isinstance(G._succ, mydict2) + assert isinstance(G._pred[1], mydict3) + G.add_edge(1, 2) + if G.is_multigraph(): + assert isinstance(G._adj[1][2], mydict4) + assert isinstance(G._adj[1][2][0], mydict5) + else: + assert isinstance(G._adj[1][2], mydict5) + + +class TestSpecialGraph(_TestGraph): + def setup_method(self): + _TestGraph.setup_method(self) + self.Graph = nx.Graph + + +class TestThinGraph(BaseGraphTester): + def setup_method(self): + all_edge_dict = {"weight": 1} + + class MyGraph(nx.Graph): + def edge_attr_dict_factory(self): + return all_edge_dict + + self.Graph = MyGraph + # build dict-of-dict-of-dict K3 + ed1, ed2, ed3 = (all_edge_dict, all_edge_dict, all_edge_dict) + self.k3adj = {0: {1: ed1, 2: ed2}, 1: {0: ed1, 2: ed3}, 2: {0: ed2, 1: ed3}} + self.k3edges = [(0, 1), (0, 2), (1, 2)] + self.k3nodes = [0, 1, 2] + self.K3 = self.Graph() + self.K3._adj = self.k3adj + self.K3._node = {} + self.K3._node[0] = {} + self.K3._node[1] = {} + self.K3._node[2] = {} + + +class TestSpecialDiGraph(_TestDiGraph): + def setup_method(self): + _TestDiGraph.setup_method(self) + self.Graph = nx.DiGraph + + +class TestThinDiGraph(BaseDiGraphTester): + def setup_method(self): + all_edge_dict = {"weight": 1} + + class MyGraph(nx.DiGraph): + def edge_attr_dict_factory(self): + return all_edge_dict + + self.Graph = MyGraph + # build dict-of-dict-of-dict K3 + ed1, ed2, ed3 = (all_edge_dict, all_edge_dict, all_edge_dict) + ed4, ed5, ed6 = (all_edge_dict, all_edge_dict, all_edge_dict) + self.k3adj = {0: {1: ed1, 2: ed2}, 1: {0: ed3, 2: ed4}, 2: {0: ed5, 1: ed6}} + self.k3edges = [(0, 1), (0, 2), (1, 2)] + self.k3nodes = [0, 1, 2] + self.K3 = self.Graph() + self.K3._succ = self.k3adj + # K3._adj is synced with K3._succ + self.K3._pred = {0: {1: ed3, 2: ed5}, 1: {0: ed1, 2: ed6}, 2: {0: ed2, 1: ed4}} + self.K3._node = {} + self.K3._node[0] = {} + self.K3._node[1] = {} + self.K3._node[2] = {} + + ed1, ed2 = (all_edge_dict, all_edge_dict) + self.P3 = self.Graph() + self.P3._succ = {0: {1: ed1}, 1: {2: ed2}, 2: {}} + # P3._adj is synced with P3._succ + self.P3._pred = {0: {}, 1: {0: ed1}, 2: {1: ed2}} + self.P3._node = {} + self.P3._node[0] = {} + self.P3._node[1] = {} + self.P3._node[2] = {} + + +class TestSpecialMultiGraph(_TestMultiGraph): + def setup_method(self): + _TestMultiGraph.setup_method(self) + self.Graph = nx.MultiGraph + + +class TestSpecialMultiDiGraph(_TestMultiDiGraph): + def setup_method(self): + _TestMultiDiGraph.setup_method(self) + self.Graph = nx.MultiDiGraph diff --git a/lib/python3.10/site-packages/networkx/classes/tests/test_subgraphviews.py b/lib/python3.10/site-packages/networkx/classes/tests/test_subgraphviews.py new file mode 100644 index 0000000000000000000000000000000000000000..73e0fdd2d52bcb7623dbd4e4f502e8bed0a4e3d3 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/classes/tests/test_subgraphviews.py @@ -0,0 +1,362 @@ +import pytest + +import networkx as nx +from networkx.utils import edges_equal + + +class TestSubGraphView: + gview = staticmethod(nx.subgraph_view) + graph = nx.Graph + hide_edges_filter = staticmethod(nx.filters.hide_edges) + show_edges_filter = staticmethod(nx.filters.show_edges) + + @classmethod + def setup_class(cls): + cls.G = nx.path_graph(9, create_using=cls.graph()) + cls.hide_edges_w_hide_nodes = {(3, 4), (4, 5), (5, 6)} + + def test_hidden_nodes(self): + hide_nodes = [4, 5, 111] + nodes_gone = nx.filters.hide_nodes(hide_nodes) + gview = self.gview + G = gview(self.G, filter_node=nodes_gone) + assert self.G.nodes - G.nodes == {4, 5} + assert self.G.edges - G.edges == self.hide_edges_w_hide_nodes + if G.is_directed(): + assert list(G[3]) == [] + assert list(G[2]) == [3] + else: + assert list(G[3]) == [2] + assert set(G[2]) == {1, 3} + pytest.raises(KeyError, G.__getitem__, 4) + pytest.raises(KeyError, G.__getitem__, 112) + pytest.raises(KeyError, G.__getitem__, 111) + assert G.degree(3) == (3 if G.is_multigraph() else 1) + assert G.size() == (7 if G.is_multigraph() else 5) + + def test_hidden_edges(self): + hide_edges = [(2, 3), (8, 7), (222, 223)] + edges_gone = self.hide_edges_filter(hide_edges) + gview = self.gview + G = gview(self.G, filter_edge=edges_gone) + assert self.G.nodes == G.nodes + if G.is_directed(): + assert self.G.edges - G.edges == {(2, 3)} + assert list(G[2]) == [] + assert list(G.pred[3]) == [] + assert list(G.pred[2]) == [1] + assert G.size() == 7 + else: + assert self.G.edges - G.edges == {(2, 3), (7, 8)} + assert list(G[2]) == [1] + assert G.size() == 6 + assert list(G[3]) == [4] + pytest.raises(KeyError, G.__getitem__, 221) + pytest.raises(KeyError, G.__getitem__, 222) + assert G.degree(3) == 1 + + def test_shown_node(self): + induced_subgraph = nx.filters.show_nodes([2, 3, 111]) + gview = self.gview + G = gview(self.G, filter_node=induced_subgraph) + assert set(G.nodes) == {2, 3} + if G.is_directed(): + assert list(G[3]) == [] + else: + assert list(G[3]) == [2] + assert list(G[2]) == [3] + pytest.raises(KeyError, G.__getitem__, 4) + pytest.raises(KeyError, G.__getitem__, 112) + pytest.raises(KeyError, G.__getitem__, 111) + assert G.degree(3) == (3 if G.is_multigraph() else 1) + assert G.size() == (3 if G.is_multigraph() else 1) + + def test_shown_edges(self): + show_edges = [(2, 3), (8, 7), (222, 223)] + edge_subgraph = self.show_edges_filter(show_edges) + G = self.gview(self.G, filter_edge=edge_subgraph) + assert self.G.nodes == G.nodes + if G.is_directed(): + assert G.edges == {(2, 3)} + assert list(G[3]) == [] + assert list(G[2]) == [3] + assert list(G.pred[3]) == [2] + assert list(G.pred[2]) == [] + assert G.size() == 1 + else: + assert G.edges == {(2, 3), (7, 8)} + assert list(G[3]) == [2] + assert list(G[2]) == [3] + assert G.size() == 2 + pytest.raises(KeyError, G.__getitem__, 221) + pytest.raises(KeyError, G.__getitem__, 222) + assert G.degree(3) == 1 + + +class TestSubDiGraphView(TestSubGraphView): + gview = staticmethod(nx.subgraph_view) + graph = nx.DiGraph + hide_edges_filter = staticmethod(nx.filters.hide_diedges) + show_edges_filter = staticmethod(nx.filters.show_diedges) + hide_edges = [(2, 3), (8, 7), (222, 223)] + excluded = {(2, 3), (3, 4), (4, 5), (5, 6)} + + def test_inoutedges(self): + edges_gone = self.hide_edges_filter(self.hide_edges) + hide_nodes = [4, 5, 111] + nodes_gone = nx.filters.hide_nodes(hide_nodes) + G = self.gview(self.G, filter_node=nodes_gone, filter_edge=edges_gone) + + assert self.G.in_edges - G.in_edges == self.excluded + assert self.G.out_edges - G.out_edges == self.excluded + + def test_pred(self): + edges_gone = self.hide_edges_filter(self.hide_edges) + hide_nodes = [4, 5, 111] + nodes_gone = nx.filters.hide_nodes(hide_nodes) + G = self.gview(self.G, filter_node=nodes_gone, filter_edge=edges_gone) + + assert list(G.pred[2]) == [1] + assert list(G.pred[6]) == [] + + def test_inout_degree(self): + edges_gone = self.hide_edges_filter(self.hide_edges) + hide_nodes = [4, 5, 111] + nodes_gone = nx.filters.hide_nodes(hide_nodes) + G = self.gview(self.G, filter_node=nodes_gone, filter_edge=edges_gone) + + assert G.degree(2) == 1 + assert G.out_degree(2) == 0 + assert G.in_degree(2) == 1 + assert G.size() == 4 + + +# multigraph +class TestMultiGraphView(TestSubGraphView): + gview = staticmethod(nx.subgraph_view) + graph = nx.MultiGraph + hide_edges_filter = staticmethod(nx.filters.hide_multiedges) + show_edges_filter = staticmethod(nx.filters.show_multiedges) + + @classmethod + def setup_class(cls): + cls.G = nx.path_graph(9, create_using=cls.graph()) + multiedges = {(2, 3, 4), (2, 3, 5)} + cls.G.add_edges_from(multiedges) + cls.hide_edges_w_hide_nodes = {(3, 4, 0), (4, 5, 0), (5, 6, 0)} + + def test_hidden_edges(self): + hide_edges = [(2, 3, 4), (2, 3, 3), (8, 7, 0), (222, 223, 0)] + edges_gone = self.hide_edges_filter(hide_edges) + G = self.gview(self.G, filter_edge=edges_gone) + assert self.G.nodes == G.nodes + if G.is_directed(): + assert self.G.edges - G.edges == {(2, 3, 4)} + assert list(G[3]) == [4] + assert list(G[2]) == [3] + assert list(G.pred[3]) == [2] # only one 2 but two edges + assert list(G.pred[2]) == [1] + assert G.size() == 9 + else: + assert self.G.edges - G.edges == {(2, 3, 4), (7, 8, 0)} + assert list(G[3]) == [2, 4] + assert list(G[2]) == [1, 3] + assert G.size() == 8 + assert G.degree(3) == 3 + pytest.raises(KeyError, G.__getitem__, 221) + pytest.raises(KeyError, G.__getitem__, 222) + + def test_shown_edges(self): + show_edges = [(2, 3, 4), (2, 3, 3), (8, 7, 0), (222, 223, 0)] + edge_subgraph = self.show_edges_filter(show_edges) + G = self.gview(self.G, filter_edge=edge_subgraph) + assert self.G.nodes == G.nodes + if G.is_directed(): + assert G.edges == {(2, 3, 4)} + assert list(G[3]) == [] + assert list(G.pred[3]) == [2] + assert list(G.pred[2]) == [] + assert G.size() == 1 + else: + assert G.edges == {(2, 3, 4), (7, 8, 0)} + assert G.size() == 2 + assert list(G[3]) == [2] + assert G.degree(3) == 1 + assert list(G[2]) == [3] + pytest.raises(KeyError, G.__getitem__, 221) + pytest.raises(KeyError, G.__getitem__, 222) + + +# multidigraph +class TestMultiDiGraphView(TestMultiGraphView, TestSubDiGraphView): + gview = staticmethod(nx.subgraph_view) + graph = nx.MultiDiGraph + hide_edges_filter = staticmethod(nx.filters.hide_multidiedges) + show_edges_filter = staticmethod(nx.filters.show_multidiedges) + hide_edges = [(2, 3, 0), (8, 7, 0), (222, 223, 0)] + excluded = {(2, 3, 0), (3, 4, 0), (4, 5, 0), (5, 6, 0)} + + def test_inout_degree(self): + edges_gone = self.hide_edges_filter(self.hide_edges) + hide_nodes = [4, 5, 111] + nodes_gone = nx.filters.hide_nodes(hide_nodes) + G = self.gview(self.G, filter_node=nodes_gone, filter_edge=edges_gone) + + assert G.degree(2) == 3 + assert G.out_degree(2) == 2 + assert G.in_degree(2) == 1 + assert G.size() == 6 + + +# induced_subgraph +class TestInducedSubGraph: + @classmethod + def setup_class(cls): + cls.K3 = G = nx.complete_graph(3) + G.graph["foo"] = [] + G.nodes[0]["foo"] = [] + G.remove_edge(1, 2) + ll = [] + G.add_edge(1, 2, foo=ll) + G.add_edge(2, 1, foo=ll) + + def test_full_graph(self): + G = self.K3 + H = nx.induced_subgraph(G, [0, 1, 2, 5]) + assert H.name == G.name + self.graphs_equal(H, G) + self.same_attrdict(H, G) + + def test_partial_subgraph(self): + G = self.K3 + H = nx.induced_subgraph(G, 0) + assert dict(H.adj) == {0: {}} + assert dict(G.adj) != {0: {}} + + H = nx.induced_subgraph(G, [0, 1]) + assert dict(H.adj) == {0: {1: {}}, 1: {0: {}}} + + def same_attrdict(self, H, G): + old_foo = H[1][2]["foo"] + H.edges[1, 2]["foo"] = "baz" + assert G.edges == H.edges + H.edges[1, 2]["foo"] = old_foo + assert G.edges == H.edges + old_foo = H.nodes[0]["foo"] + H.nodes[0]["foo"] = "baz" + assert G.nodes == H.nodes + H.nodes[0]["foo"] = old_foo + assert G.nodes == H.nodes + + def graphs_equal(self, H, G): + assert G._adj == H._adj + assert G._node == H._node + assert G.graph == H.graph + assert G.name == H.name + if not G.is_directed() and not H.is_directed(): + assert H._adj[1][2] is H._adj[2][1] + assert G._adj[1][2] is G._adj[2][1] + else: # at least one is directed + if not G.is_directed(): + G._pred = G._adj + G._succ = G._adj + if not H.is_directed(): + H._pred = H._adj + H._succ = H._adj + assert G._pred == H._pred + assert G._succ == H._succ + assert H._succ[1][2] is H._pred[2][1] + assert G._succ[1][2] is G._pred[2][1] + + +# edge_subgraph +class TestEdgeSubGraph: + @classmethod + def setup_class(cls): + # Create a path graph on five nodes. + cls.G = G = nx.path_graph(5) + # Add some node, edge, and graph attributes. + for i in range(5): + G.nodes[i]["name"] = f"node{i}" + G.edges[0, 1]["name"] = "edge01" + G.edges[3, 4]["name"] = "edge34" + G.graph["name"] = "graph" + # Get the subgraph induced by the first and last edges. + cls.H = nx.edge_subgraph(G, [(0, 1), (3, 4)]) + + def test_correct_nodes(self): + """Tests that the subgraph has the correct nodes.""" + assert [(0, "node0"), (1, "node1"), (3, "node3"), (4, "node4")] == sorted( + self.H.nodes.data("name") + ) + + def test_correct_edges(self): + """Tests that the subgraph has the correct edges.""" + assert edges_equal( + [(0, 1, "edge01"), (3, 4, "edge34")], self.H.edges.data("name") + ) + + def test_add_node(self): + """Tests that adding a node to the original graph does not + affect the nodes of the subgraph. + + """ + self.G.add_node(5) + assert [0, 1, 3, 4] == sorted(self.H.nodes) + self.G.remove_node(5) + + def test_remove_node(self): + """Tests that removing a node in the original graph + removes the nodes of the subgraph. + + """ + self.G.remove_node(0) + assert [1, 3, 4] == sorted(self.H.nodes) + self.G.add_node(0, name="node0") + self.G.add_edge(0, 1, name="edge01") + + def test_node_attr_dict(self): + """Tests that the node attribute dictionary of the two graphs is + the same object. + + """ + for v in self.H: + assert self.G.nodes[v] == self.H.nodes[v] + # Making a change to G should make a change in H and vice versa. + self.G.nodes[0]["name"] = "foo" + assert self.G.nodes[0] == self.H.nodes[0] + self.H.nodes[1]["name"] = "bar" + assert self.G.nodes[1] == self.H.nodes[1] + # Revert the change, so tests pass with pytest-randomly + self.G.nodes[0]["name"] = "node0" + self.H.nodes[1]["name"] = "node1" + + def test_edge_attr_dict(self): + """Tests that the edge attribute dictionary of the two graphs is + the same object. + + """ + for u, v in self.H.edges(): + assert self.G.edges[u, v] == self.H.edges[u, v] + # Making a change to G should make a change in H and vice versa. + self.G.edges[0, 1]["name"] = "foo" + assert self.G.edges[0, 1]["name"] == self.H.edges[0, 1]["name"] + self.H.edges[3, 4]["name"] = "bar" + assert self.G.edges[3, 4]["name"] == self.H.edges[3, 4]["name"] + # Revert the change, so tests pass with pytest-randomly + self.G.edges[0, 1]["name"] = "edge01" + self.H.edges[3, 4]["name"] = "edge34" + + def test_graph_attr_dict(self): + """Tests that the graph attribute dictionary of the two graphs + is the same object. + + """ + assert self.G.graph is self.H.graph + + def test_readonly(self): + """Tests that the subgraph cannot change the graph structure""" + pytest.raises(nx.NetworkXError, self.H.add_node, 5) + pytest.raises(nx.NetworkXError, self.H.remove_node, 0) + pytest.raises(nx.NetworkXError, self.H.add_edge, 5, 6) + pytest.raises(nx.NetworkXError, self.H.remove_edge, 0, 1) diff --git a/lib/python3.10/site-packages/networkx/generators/tests/test_atlas.py b/lib/python3.10/site-packages/networkx/generators/tests/test_atlas.py new file mode 100644 index 0000000000000000000000000000000000000000..add4741c00e8d8aefe4fcf3a2a86815a15aab29c --- /dev/null +++ b/lib/python3.10/site-packages/networkx/generators/tests/test_atlas.py @@ -0,0 +1,75 @@ +from itertools import groupby + +import pytest + +import networkx as nx +from networkx import graph_atlas, graph_atlas_g +from networkx.generators.atlas import NUM_GRAPHS +from networkx.utils import edges_equal, nodes_equal, pairwise + + +class TestAtlasGraph: + """Unit tests for the :func:`~networkx.graph_atlas` function.""" + + def test_index_too_small(self): + with pytest.raises(ValueError): + graph_atlas(-1) + + def test_index_too_large(self): + with pytest.raises(ValueError): + graph_atlas(NUM_GRAPHS) + + def test_graph(self): + G = graph_atlas(6) + assert nodes_equal(G.nodes(), range(3)) + assert edges_equal(G.edges(), [(0, 1), (0, 2)]) + + +class TestAtlasGraphG: + """Unit tests for the :func:`~networkx.graph_atlas_g` function.""" + + @classmethod + def setup_class(cls): + cls.GAG = graph_atlas_g() + + def test_sizes(self): + G = self.GAG[0] + assert G.number_of_nodes() == 0 + assert G.number_of_edges() == 0 + + G = self.GAG[7] + assert G.number_of_nodes() == 3 + assert G.number_of_edges() == 3 + + def test_names(self): + for i, G in enumerate(self.GAG): + assert int(G.name[1:]) == i + + def test_nondecreasing_nodes(self): + # check for nondecreasing number of nodes + for n1, n2 in pairwise(map(len, self.GAG)): + assert n2 <= n1 + 1 + + def test_nondecreasing_edges(self): + # check for nondecreasing number of edges (for fixed number of + # nodes) + for n, group in groupby(self.GAG, key=nx.number_of_nodes): + for m1, m2 in pairwise(map(nx.number_of_edges, group)): + assert m2 <= m1 + 1 + + def test_nondecreasing_degree_sequence(self): + # Check for lexicographically nondecreasing degree sequences + # (for fixed number of nodes and edges). + # + # There are three exceptions to this rule in the order given in + # the "Atlas of Graphs" book, so we need to manually exclude + # those. + exceptions = [("G55", "G56"), ("G1007", "G1008"), ("G1012", "G1013")] + for n, group in groupby(self.GAG, key=nx.number_of_nodes): + for m, group in groupby(group, key=nx.number_of_edges): + for G1, G2 in pairwise(group): + if (G1.name, G2.name) in exceptions: + continue + d1 = sorted(d for v, d in G1.degree()) + d2 = sorted(d for v, d in G2.degree()) + assert d1 <= d2 diff --git a/lib/python3.10/site-packages/networkx/generators/tests/test_classic.py b/lib/python3.10/site-packages/networkx/generators/tests/test_classic.py new file mode 100644 index 0000000000000000000000000000000000000000..9353c7f505cb2ab515614bc2920b60cbde0f0992 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/generators/tests/test_classic.py @@ -0,0 +1,640 @@ +""" +==================== +Generators - Classic +==================== + +Unit tests for various classic graph generators in generators/classic.py +""" + +import itertools +import typing + +import pytest + +import networkx as nx +from networkx.algorithms.isomorphism.isomorph import graph_could_be_isomorphic +from networkx.utils import edges_equal, nodes_equal + +is_isomorphic = graph_could_be_isomorphic + + +class TestGeneratorClassic: + def test_balanced_tree(self): + # balanced_tree(r,h) is a tree with (r**(h+1)-1)/(r-1) edges + for r, h in [(2, 2), (3, 3), (6, 2)]: + t = nx.balanced_tree(r, h) + order = t.order() + assert order == (r ** (h + 1) - 1) / (r - 1) + assert nx.is_connected(t) + assert t.size() == order - 1 + dh = nx.degree_histogram(t) + assert dh[0] == 0 # no nodes of 0 + assert dh[1] == r**h # nodes of degree 1 are leaves + assert dh[r] == 1 # root is degree r + assert dh[r + 1] == order - r**h - 1 # everyone else is degree r+1 + assert len(dh) == r + 2 + + def test_balanced_tree_star(self): + # balanced_tree(r,1) is the r-star + t = nx.balanced_tree(r=2, h=1) + assert is_isomorphic(t, nx.star_graph(2)) + t = nx.balanced_tree(r=5, h=1) + assert is_isomorphic(t, nx.star_graph(5)) + t = nx.balanced_tree(r=10, h=1) + assert is_isomorphic(t, nx.star_graph(10)) + + def test_balanced_tree_path(self): + """Tests that the balanced tree with branching factor one is the + path graph. + + """ + # A tree of height four has five levels. + T = nx.balanced_tree(1, 4) + P = nx.path_graph(5) + assert is_isomorphic(T, P) + + def test_full_rary_tree(self): + r = 2 + n = 9 + t = nx.full_rary_tree(r, n) + assert t.order() == n + assert nx.is_connected(t) + dh = nx.degree_histogram(t) + assert dh[0] == 0 # no nodes of 0 + assert dh[1] == 5 # nodes of degree 1 are leaves + assert dh[r] == 1 # root is degree r + assert dh[r + 1] == 9 - 5 - 1 # everyone else is degree r+1 + assert len(dh) == r + 2 + + def test_full_rary_tree_balanced(self): + t = nx.full_rary_tree(2, 15) + th = nx.balanced_tree(2, 3) + assert is_isomorphic(t, th) + + def test_full_rary_tree_path(self): + t = nx.full_rary_tree(1, 10) + assert is_isomorphic(t, nx.path_graph(10)) + + def test_full_rary_tree_empty(self): + t = nx.full_rary_tree(0, 10) + assert is_isomorphic(t, nx.empty_graph(10)) + t = nx.full_rary_tree(3, 0) + assert is_isomorphic(t, nx.empty_graph(0)) + + def test_full_rary_tree_3_20(self): + t = nx.full_rary_tree(3, 20) + assert t.order() == 20 + + def test_barbell_graph(self): + # number of nodes = 2*m1 + m2 (2 m1-complete graphs + m2-path + 2 edges) + # number of edges = 2*(nx.number_of_edges(m1-complete graph) + m2 + 1 + m1 = 3 + m2 = 5 + b = nx.barbell_graph(m1, m2) + assert nx.number_of_nodes(b) == 2 * m1 + m2 + assert nx.number_of_edges(b) == m1 * (m1 - 1) + m2 + 1 + + m1 = 4 + m2 = 10 + b = nx.barbell_graph(m1, m2) + assert nx.number_of_nodes(b) == 2 * m1 + m2 + assert nx.number_of_edges(b) == m1 * (m1 - 1) + m2 + 1 + + m1 = 3 + m2 = 20 + b = nx.barbell_graph(m1, m2) + assert nx.number_of_nodes(b) == 2 * m1 + m2 + assert nx.number_of_edges(b) == m1 * (m1 - 1) + m2 + 1 + + # Raise NetworkXError if m1<2 + m1 = 1 + m2 = 20 + pytest.raises(nx.NetworkXError, nx.barbell_graph, m1, m2) + + # Raise NetworkXError if m2<0 + m1 = 5 + m2 = -2 + pytest.raises(nx.NetworkXError, nx.barbell_graph, m1, m2) + + # nx.barbell_graph(2,m) = nx.path_graph(m+4) + m1 = 2 + m2 = 5 + b = nx.barbell_graph(m1, m2) + assert is_isomorphic(b, nx.path_graph(m2 + 4)) + + m1 = 2 + m2 = 10 + b = nx.barbell_graph(m1, m2) + assert is_isomorphic(b, nx.path_graph(m2 + 4)) + + m1 = 2 + m2 = 20 + b = nx.barbell_graph(m1, m2) + assert is_isomorphic(b, nx.path_graph(m2 + 4)) + + pytest.raises( + nx.NetworkXError, nx.barbell_graph, m1, m2, create_using=nx.DiGraph() + ) + + mb = nx.barbell_graph(m1, m2, create_using=nx.MultiGraph()) + assert edges_equal(mb.edges(), b.edges()) + + def test_binomial_tree(self): + graphs = (None, nx.Graph, nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph) + for create_using in graphs: + for n in range(4): + b = nx.binomial_tree(n, create_using) + assert nx.number_of_nodes(b) == 2**n + assert nx.number_of_edges(b) == (2**n - 1) + + def test_complete_graph(self): + # complete_graph(m) is a connected graph with + # m nodes and m*(m+1)/2 edges + for m in [0, 1, 3, 5]: + g = nx.complete_graph(m) + assert nx.number_of_nodes(g) == m + assert nx.number_of_edges(g) == m * (m - 1) // 2 + + mg = nx.complete_graph(m, create_using=nx.MultiGraph) + assert edges_equal(mg.edges(), g.edges()) + + g = nx.complete_graph("abc") + assert nodes_equal(g.nodes(), ["a", "b", "c"]) + assert g.size() == 3 + + # creates a self-loop... should it? + g = nx.complete_graph("abcb") + assert nodes_equal(g.nodes(), ["a", "b", "c"]) + assert g.size() == 4 + + g = nx.complete_graph("abcb", create_using=nx.MultiGraph) + assert nodes_equal(g.nodes(), ["a", "b", "c"]) + assert g.size() == 6 + + def test_complete_digraph(self): + # complete_graph(m) is a connected graph with + # m nodes and m*(m+1)/2 edges + for m in [0, 1, 3, 5]: + g = nx.complete_graph(m, create_using=nx.DiGraph) + assert nx.number_of_nodes(g) == m + assert nx.number_of_edges(g) == m * (m - 1) + + g = nx.complete_graph("abc", create_using=nx.DiGraph) + assert len(g) == 3 + assert g.size() == 6 + assert g.is_directed() + + def test_circular_ladder_graph(self): + G = nx.circular_ladder_graph(5) + pytest.raises( + nx.NetworkXError, nx.circular_ladder_graph, 5, create_using=nx.DiGraph + ) + mG = nx.circular_ladder_graph(5, create_using=nx.MultiGraph) + assert edges_equal(mG.edges(), G.edges()) + + def test_circulant_graph(self): + # Ci_n(1) is the cycle graph for all n + Ci6_1 = nx.circulant_graph(6, [1]) + C6 = nx.cycle_graph(6) + assert edges_equal(Ci6_1.edges(), C6.edges()) + + # Ci_n(1, 2, ..., n div 2) is the complete graph for all n + Ci7 = nx.circulant_graph(7, [1, 2, 3]) + K7 = nx.complete_graph(7) + assert edges_equal(Ci7.edges(), K7.edges()) + + # Ci_6(1, 3) is K_3,3 i.e. the utility graph + Ci6_1_3 = nx.circulant_graph(6, [1, 3]) + K3_3 = nx.complete_bipartite_graph(3, 3) + assert is_isomorphic(Ci6_1_3, K3_3) + + def test_cycle_graph(self): + G = nx.cycle_graph(4) + assert edges_equal(G.edges(), [(0, 1), (0, 3), (1, 2), (2, 3)]) + mG = nx.cycle_graph(4, create_using=nx.MultiGraph) + assert edges_equal(mG.edges(), [(0, 1), (0, 3), (1, 2), (2, 3)]) + G = nx.cycle_graph(4, create_using=nx.DiGraph) + assert not G.has_edge(2, 1) + assert G.has_edge(1, 2) + assert G.is_directed() + + G = nx.cycle_graph("abc") + assert len(G) == 3 + assert G.size() == 3 + G = nx.cycle_graph("abcb") + assert len(G) == 3 + assert G.size() == 2 + g = nx.cycle_graph("abc", nx.DiGraph) + assert len(g) == 3 + assert g.size() == 3 + assert g.is_directed() + g = nx.cycle_graph("abcb", nx.DiGraph) + assert len(g) == 3 + assert g.size() == 4 + + def test_dorogovtsev_goltsev_mendes_graph(self): + G = nx.dorogovtsev_goltsev_mendes_graph(0) + assert edges_equal(G.edges(), [(0, 1)]) + assert nodes_equal(list(G), [0, 1]) + G = nx.dorogovtsev_goltsev_mendes_graph(1) + assert edges_equal(G.edges(), [(0, 1), (0, 2), (1, 2)]) + assert nx.average_clustering(G) == 1.0 + assert nx.average_shortest_path_length(G) == 1.0 + assert sorted(nx.triangles(G).values()) == [1, 1, 1] + assert nx.is_planar(G) + G = nx.dorogovtsev_goltsev_mendes_graph(2) + assert nx.number_of_nodes(G) == 6 + assert nx.number_of_edges(G) == 9 + assert nx.average_clustering(G) == 0.75 + assert nx.average_shortest_path_length(G) == 1.4 + assert nx.is_planar(G) + G = nx.dorogovtsev_goltsev_mendes_graph(10) + assert nx.number_of_nodes(G) == 29526 + assert nx.number_of_edges(G) == 59049 + assert G.degree(0) == 1024 + assert G.degree(1) == 1024 + assert G.degree(2) == 1024 + + with pytest.raises(nx.NetworkXError, match=r"n must be greater than"): + nx.dorogovtsev_goltsev_mendes_graph(-1) + with pytest.raises(nx.NetworkXError, match=r"directed graph not supported"): + nx.dorogovtsev_goltsev_mendes_graph(7, create_using=nx.DiGraph) + with pytest.raises(nx.NetworkXError, match=r"multigraph not supported"): + nx.dorogovtsev_goltsev_mendes_graph(7, create_using=nx.MultiGraph) + with pytest.raises(nx.NetworkXError): + nx.dorogovtsev_goltsev_mendes_graph(7, create_using=nx.MultiDiGraph) + + def test_create_using(self): + G = nx.empty_graph() + assert isinstance(G, nx.Graph) + pytest.raises(TypeError, nx.empty_graph, create_using=0.0) + pytest.raises(TypeError, nx.empty_graph, create_using="Graph") + + G = nx.empty_graph(create_using=nx.MultiGraph) + assert isinstance(G, nx.MultiGraph) + G = nx.empty_graph(create_using=nx.DiGraph) + assert isinstance(G, nx.DiGraph) + + G = nx.empty_graph(create_using=nx.DiGraph, default=nx.MultiGraph) + assert isinstance(G, nx.DiGraph) + G = nx.empty_graph(create_using=None, default=nx.MultiGraph) + assert isinstance(G, nx.MultiGraph) + G = nx.empty_graph(default=nx.MultiGraph) + assert isinstance(G, nx.MultiGraph) + + G = nx.path_graph(5) + H = nx.empty_graph(create_using=G) + assert not H.is_multigraph() + assert not H.is_directed() + assert len(H) == 0 + assert G is H + + H = nx.empty_graph(create_using=nx.MultiGraph()) + assert H.is_multigraph() + assert not H.is_directed() + assert G is not H + + # test for subclasses that also use typing.Protocol. See gh-6243 + class Mixin(typing.Protocol): + pass + + class MyGraph(Mixin, nx.DiGraph): + pass + + G = nx.empty_graph(create_using=MyGraph) + + def test_empty_graph(self): + G = nx.empty_graph() + assert nx.number_of_nodes(G) == 0 + G = nx.empty_graph(42) + assert nx.number_of_nodes(G) == 42 + assert nx.number_of_edges(G) == 0 + + G = nx.empty_graph("abc") + assert len(G) == 3 + assert G.size() == 0 + + # create empty digraph + G = nx.empty_graph(42, create_using=nx.DiGraph(name="duh")) + assert nx.number_of_nodes(G) == 42 + assert nx.number_of_edges(G) == 0 + assert isinstance(G, nx.DiGraph) + + # create empty multigraph + G = nx.empty_graph(42, create_using=nx.MultiGraph(name="duh")) + assert nx.number_of_nodes(G) == 42 + assert nx.number_of_edges(G) == 0 + assert isinstance(G, nx.MultiGraph) + + # create empty graph from another + pete = nx.petersen_graph() + G = nx.empty_graph(42, create_using=pete) + assert nx.number_of_nodes(G) == 42 + assert nx.number_of_edges(G) == 0 + assert isinstance(G, nx.Graph) + + def test_ladder_graph(self): + for i, G in [ + (0, nx.empty_graph(0)), + (1, nx.path_graph(2)), + (2, nx.hypercube_graph(2)), + (10, nx.grid_graph([2, 10])), + ]: + assert is_isomorphic(nx.ladder_graph(i), G) + + pytest.raises(nx.NetworkXError, nx.ladder_graph, 2, create_using=nx.DiGraph) + + g = nx.ladder_graph(2) + mg = nx.ladder_graph(2, create_using=nx.MultiGraph) + assert edges_equal(mg.edges(), g.edges()) + + @pytest.mark.parametrize(("m", "n"), [(3, 5), (4, 10), (3, 20)]) + def test_lollipop_graph_right_sizes(self, m, n): + G = nx.lollipop_graph(m, n) + assert nx.number_of_nodes(G) == m + n + assert nx.number_of_edges(G) == m * (m - 1) / 2 + n + + @pytest.mark.parametrize(("m", "n"), [("ab", ""), ("abc", "defg")]) + def test_lollipop_graph_size_node_sequence(self, m, n): + G = nx.lollipop_graph(m, n) + assert nx.number_of_nodes(G) == len(m) + len(n) + assert nx.number_of_edges(G) == len(m) * (len(m) - 1) / 2 + len(n) + + def test_lollipop_graph_exceptions(self): + # Raise NetworkXError if m<2 + pytest.raises(nx.NetworkXError, nx.lollipop_graph, -1, 2) + pytest.raises(nx.NetworkXError, nx.lollipop_graph, 1, 20) + pytest.raises(nx.NetworkXError, nx.lollipop_graph, "", 20) + pytest.raises(nx.NetworkXError, nx.lollipop_graph, "a", 20) + + # Raise NetworkXError if n<0 + pytest.raises(nx.NetworkXError, nx.lollipop_graph, 5, -2) + + # raise NetworkXError if create_using is directed + with pytest.raises(nx.NetworkXError): + nx.lollipop_graph(2, 20, create_using=nx.DiGraph) + with pytest.raises(nx.NetworkXError): + nx.lollipop_graph(2, 20, create_using=nx.MultiDiGraph) + + @pytest.mark.parametrize(("m", "n"), [(2, 0), (2, 5), (2, 10), ("ab", 20)]) + def test_lollipop_graph_same_as_path_when_m1_is_2(self, m, n): + G = nx.lollipop_graph(m, n) + assert is_isomorphic(G, nx.path_graph(n + 2)) + + def test_lollipop_graph_for_multigraph(self): + G = nx.lollipop_graph(5, 20) + MG = nx.lollipop_graph(5, 20, create_using=nx.MultiGraph) + assert edges_equal(MG.edges(), G.edges()) + + @pytest.mark.parametrize( + ("m", "n"), + [(4, "abc"), ("abcd", 3), ([1, 2, 3, 4], "abc"), ("abcd", [1, 2, 3])], + ) + def test_lollipop_graph_mixing_input_types(self, m, n): + expected = nx.compose(nx.complete_graph(4), nx.path_graph(range(100, 103))) + expected.add_edge(0, 100) # Connect complete graph and path graph + assert is_isomorphic(nx.lollipop_graph(m, n), expected) + + def test_lollipop_graph_non_builtin_ints(self): + np = pytest.importorskip("numpy") + G = nx.lollipop_graph(np.int32(4), np.int64(3)) + expected = nx.compose(nx.complete_graph(4), nx.path_graph(range(100, 103))) + expected.add_edge(0, 100) # Connect complete graph and path graph + assert is_isomorphic(G, expected) + + def test_null_graph(self): + assert nx.number_of_nodes(nx.null_graph()) == 0 + + def test_path_graph(self): + p = nx.path_graph(0) + assert is_isomorphic(p, nx.null_graph()) + + p = nx.path_graph(1) + assert is_isomorphic(p, nx.empty_graph(1)) + + p = nx.path_graph(10) + assert nx.is_connected(p) + assert sorted(d for n, d in p.degree()) == [1, 1, 2, 2, 2, 2, 2, 2, 2, 2] + assert p.order() - 1 == p.size() + + dp = nx.path_graph(3, create_using=nx.DiGraph) + assert dp.has_edge(0, 1) + assert not dp.has_edge(1, 0) + + mp = nx.path_graph(10, create_using=nx.MultiGraph) + assert edges_equal(mp.edges(), p.edges()) + + G = nx.path_graph("abc") + assert len(G) == 3 + assert G.size() == 2 + G = nx.path_graph("abcb") + assert len(G) == 3 + assert G.size() == 2 + g = nx.path_graph("abc", nx.DiGraph) + assert len(g) == 3 + assert g.size() == 2 + assert g.is_directed() + g = nx.path_graph("abcb", nx.DiGraph) + assert len(g) == 3 + assert g.size() == 3 + + G = nx.path_graph((1, 2, 3, 2, 4)) + assert G.has_edge(2, 4) + + def test_star_graph(self): + assert is_isomorphic(nx.star_graph(""), nx.empty_graph(0)) + assert is_isomorphic(nx.star_graph([]), nx.empty_graph(0)) + assert is_isomorphic(nx.star_graph(0), nx.empty_graph(1)) + assert is_isomorphic(nx.star_graph(1), nx.path_graph(2)) + assert is_isomorphic(nx.star_graph(2), nx.path_graph(3)) + assert is_isomorphic(nx.star_graph(5), nx.complete_bipartite_graph(1, 5)) + + s = nx.star_graph(10) + assert sorted(d for n, d in s.degree()) == [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 10] + + pytest.raises(nx.NetworkXError, nx.star_graph, 10, create_using=nx.DiGraph) + + ms = nx.star_graph(10, create_using=nx.MultiGraph) + assert edges_equal(ms.edges(), s.edges()) + + G = nx.star_graph("abc") + assert len(G) == 3 + assert G.size() == 2 + + G = nx.star_graph("abcb") + assert len(G) == 3 + assert G.size() == 2 + G = nx.star_graph("abcb", create_using=nx.MultiGraph) + assert len(G) == 3 + assert G.size() == 3 + + G = nx.star_graph("abcdefg") + assert len(G) == 7 + assert G.size() == 6 + + def test_non_int_integers_for_star_graph(self): + np = pytest.importorskip("numpy") + G = nx.star_graph(np.int32(3)) + assert len(G) == 4 + assert G.size() == 3 + + @pytest.mark.parametrize(("m", "n"), [(3, 0), (3, 5), (4, 10), (3, 20)]) + def test_tadpole_graph_right_sizes(self, m, n): + G = nx.tadpole_graph(m, n) + assert nx.number_of_nodes(G) == m + n + assert nx.number_of_edges(G) == m + n - (m == 2) + + @pytest.mark.parametrize(("m", "n"), [("ab", ""), ("ab", "c"), ("abc", "defg")]) + def test_tadpole_graph_size_node_sequences(self, m, n): + G = nx.tadpole_graph(m, n) + assert nx.number_of_nodes(G) == len(m) + len(n) + assert nx.number_of_edges(G) == len(m) + len(n) - (len(m) == 2) + + def test_tadpole_graph_exceptions(self): + # Raise NetworkXError if m<2 + pytest.raises(nx.NetworkXError, nx.tadpole_graph, -1, 3) + pytest.raises(nx.NetworkXError, nx.tadpole_graph, 0, 3) + pytest.raises(nx.NetworkXError, nx.tadpole_graph, 1, 3) + + # Raise NetworkXError if n<0 + pytest.raises(nx.NetworkXError, nx.tadpole_graph, 5, -2) + + # Raise NetworkXError for digraphs + with pytest.raises(nx.NetworkXError): + nx.tadpole_graph(2, 20, create_using=nx.DiGraph) + with pytest.raises(nx.NetworkXError): + nx.tadpole_graph(2, 20, create_using=nx.MultiDiGraph) + + @pytest.mark.parametrize(("m", "n"), [(2, 0), (2, 5), (2, 10), ("ab", 20)]) + def test_tadpole_graph_same_as_path_when_m_is_2(self, m, n): + G = nx.tadpole_graph(m, n) + assert is_isomorphic(G, nx.path_graph(n + 2)) + + @pytest.mark.parametrize("m", [4, 7]) + def test_tadpole_graph_same_as_cycle_when_m2_is_0(self, m): + G = nx.tadpole_graph(m, 0) + assert is_isomorphic(G, nx.cycle_graph(m)) + + def test_tadpole_graph_for_multigraph(self): + G = nx.tadpole_graph(5, 20) + MG = nx.tadpole_graph(5, 20, create_using=nx.MultiGraph) + assert edges_equal(MG.edges(), G.edges()) + + @pytest.mark.parametrize( + ("m", "n"), + [(4, "abc"), ("abcd", 3), ([1, 2, 3, 4], "abc"), ("abcd", [1, 2, 3])], + ) + def test_tadpole_graph_mixing_input_types(self, m, n): + expected = nx.compose(nx.cycle_graph(4), nx.path_graph(range(100, 103))) + expected.add_edge(0, 100) # Connect cycle and path + assert is_isomorphic(nx.tadpole_graph(m, n), expected) + + def test_tadpole_graph_non_builtin_integers(self): + np = pytest.importorskip("numpy") + G = nx.tadpole_graph(np.int32(4), np.int64(3)) + expected = nx.compose(nx.cycle_graph(4), nx.path_graph(range(100, 103))) + expected.add_edge(0, 100) # Connect cycle and path + assert is_isomorphic(G, expected) + + def test_trivial_graph(self): + assert nx.number_of_nodes(nx.trivial_graph()) == 1 + + def test_turan_graph(self): + assert nx.number_of_edges(nx.turan_graph(13, 4)) == 63 + assert is_isomorphic( + nx.turan_graph(13, 4), nx.complete_multipartite_graph(3, 4, 3, 3) + ) + + def test_wheel_graph(self): + for n, G in [ + ("", nx.null_graph()), + (0, nx.null_graph()), + (1, nx.empty_graph(1)), + (2, nx.path_graph(2)), + (3, nx.complete_graph(3)), + (4, nx.complete_graph(4)), + ]: + g = nx.wheel_graph(n) + assert is_isomorphic(g, G) + + g = nx.wheel_graph(10) + assert sorted(d for n, d in g.degree()) == [3, 3, 3, 3, 3, 3, 3, 3, 3, 9] + + pytest.raises(nx.NetworkXError, nx.wheel_graph, 10, create_using=nx.DiGraph) + + mg = nx.wheel_graph(10, create_using=nx.MultiGraph()) + assert edges_equal(mg.edges(), g.edges()) + + G = nx.wheel_graph("abc") + assert len(G) == 3 + assert G.size() == 3 + + G = nx.wheel_graph("abcb") + assert len(G) == 3 + assert G.size() == 4 + G = nx.wheel_graph("abcb", nx.MultiGraph) + assert len(G) == 3 + assert G.size() == 6 + + def test_non_int_integers_for_wheel_graph(self): + np = pytest.importorskip("numpy") + G = nx.wheel_graph(np.int32(3)) + assert len(G) == 3 + assert G.size() == 3 + + def test_complete_0_partite_graph(self): + """Tests that the complete 0-partite graph is the null graph.""" + G = nx.complete_multipartite_graph() + H = nx.null_graph() + assert nodes_equal(G, H) + assert edges_equal(G.edges(), H.edges()) + + def test_complete_1_partite_graph(self): + """Tests that the complete 1-partite graph is the empty graph.""" + G = nx.complete_multipartite_graph(3) + H = nx.empty_graph(3) + assert nodes_equal(G, H) + assert edges_equal(G.edges(), H.edges()) + + def test_complete_2_partite_graph(self): + """Tests that the complete 2-partite graph is the complete bipartite + graph. + + """ + G = nx.complete_multipartite_graph(2, 3) + H = nx.complete_bipartite_graph(2, 3) + assert nodes_equal(G, H) + assert edges_equal(G.edges(), H.edges()) + + def test_complete_multipartite_graph(self): + """Tests for generating the complete multipartite graph.""" + G = nx.complete_multipartite_graph(2, 3, 4) + blocks = [(0, 1), (2, 3, 4), (5, 6, 7, 8)] + # Within each block, no two vertices should be adjacent. + for block in blocks: + for u, v in itertools.combinations_with_replacement(block, 2): + assert v not in G[u] + assert G.nodes[u] == G.nodes[v] + # Across blocks, all vertices should be adjacent. + for block1, block2 in itertools.combinations(blocks, 2): + for u, v in itertools.product(block1, block2): + assert v in G[u] + assert G.nodes[u] != G.nodes[v] + with pytest.raises(nx.NetworkXError, match="Negative number of nodes"): + nx.complete_multipartite_graph(2, -3, 4) + + def test_kneser_graph(self): + # the petersen graph is a special case of the kneser graph when n=5 and k=2 + assert is_isomorphic(nx.kneser_graph(5, 2), nx.petersen_graph()) + + # when k is 1, the kneser graph returns a complete graph with n vertices + for i in range(1, 7): + assert is_isomorphic(nx.kneser_graph(i, 1), nx.complete_graph(i)) + + # the kneser graph of n and n-1 is the empty graph with n vertices + for j in range(3, 7): + assert is_isomorphic(nx.kneser_graph(j, j - 1), nx.empty_graph(j)) + + # in general the number of edges of the kneser graph is equal to + # (n choose k) times (n-k choose k) divided by 2 + assert nx.number_of_edges(nx.kneser_graph(8, 3)) == 280 diff --git a/lib/python3.10/site-packages/networkx/generators/tests/test_cographs.py b/lib/python3.10/site-packages/networkx/generators/tests/test_cographs.py new file mode 100644 index 0000000000000000000000000000000000000000..a71849b019e7fc3f198a240fc137de0cfddaed0d --- /dev/null +++ b/lib/python3.10/site-packages/networkx/generators/tests/test_cographs.py @@ -0,0 +1,18 @@ +"""Unit tests for the :mod:`networkx.generators.cographs` module.""" + +import networkx as nx + + +def test_random_cograph(): + n = 3 + G = nx.random_cograph(n) + + assert len(G) == 2**n + + # Every connected subgraph of G has diameter <= 2 + if nx.is_connected(G): + assert nx.diameter(G) <= 2 + else: + components = nx.connected_components(G) + for component in components: + assert nx.diameter(G.subgraph(component)) <= 2 diff --git a/lib/python3.10/site-packages/networkx/generators/tests/test_community.py b/lib/python3.10/site-packages/networkx/generators/tests/test_community.py new file mode 100644 index 0000000000000000000000000000000000000000..2fa107f6dde9f280123796f81b919c99f92ee20c --- /dev/null +++ b/lib/python3.10/site-packages/networkx/generators/tests/test_community.py @@ -0,0 +1,362 @@ +import pytest + +import networkx as nx + + +def test_random_partition_graph(): + G = nx.random_partition_graph([3, 3, 3], 1, 0, seed=42) + C = G.graph["partition"] + assert C == [{0, 1, 2}, {3, 4, 5}, {6, 7, 8}] + assert len(G) == 9 + assert len(list(G.edges())) == 9 + + G = nx.random_partition_graph([3, 3, 3], 0, 1) + C = G.graph["partition"] + assert C == [{0, 1, 2}, {3, 4, 5}, {6, 7, 8}] + assert len(G) == 9 + assert len(list(G.edges())) == 27 + + G = nx.random_partition_graph([3, 3, 3], 1, 0, directed=True) + C = G.graph["partition"] + assert C == [{0, 1, 2}, {3, 4, 5}, {6, 7, 8}] + assert len(G) == 9 + assert len(list(G.edges())) == 18 + + G = nx.random_partition_graph([3, 3, 3], 0, 1, directed=True) + C = G.graph["partition"] + assert C == [{0, 1, 2}, {3, 4, 5}, {6, 7, 8}] + assert len(G) == 9 + assert len(list(G.edges())) == 54 + + G = nx.random_partition_graph([1, 2, 3, 4, 5], 0.5, 0.1) + C = G.graph["partition"] + assert C == [{0}, {1, 2}, {3, 4, 5}, {6, 7, 8, 9}, {10, 11, 12, 13, 14}] + assert len(G) == 15 + + rpg = nx.random_partition_graph + pytest.raises(nx.NetworkXError, rpg, [1, 2, 3], 1.1, 0.1) + pytest.raises(nx.NetworkXError, rpg, [1, 2, 3], -0.1, 0.1) + pytest.raises(nx.NetworkXError, rpg, [1, 2, 3], 0.1, 1.1) + pytest.raises(nx.NetworkXError, rpg, [1, 2, 3], 0.1, -0.1) + + +def test_planted_partition_graph(): + G = nx.planted_partition_graph(4, 3, 1, 0, seed=42) + C = G.graph["partition"] + assert len(C) == 4 + assert len(G) == 12 + assert len(list(G.edges())) == 12 + + G = nx.planted_partition_graph(4, 3, 0, 1) + C = G.graph["partition"] + assert len(C) == 4 + assert len(G) == 12 + assert len(list(G.edges())) == 54 + + G = nx.planted_partition_graph(10, 4, 0.5, 0.1, seed=42) + C = G.graph["partition"] + assert len(C) == 10 + assert len(G) == 40 + + G = nx.planted_partition_graph(4, 3, 1, 0, directed=True) + C = G.graph["partition"] + assert len(C) == 4 + assert len(G) == 12 + assert len(list(G.edges())) == 24 + + G = nx.planted_partition_graph(4, 3, 0, 1, directed=True) + C = G.graph["partition"] + assert len(C) == 4 + assert len(G) == 12 + assert len(list(G.edges())) == 108 + + G = nx.planted_partition_graph(10, 4, 0.5, 0.1, seed=42, directed=True) + C = G.graph["partition"] + assert len(C) == 10 + assert len(G) == 40 + + ppg = nx.planted_partition_graph + pytest.raises(nx.NetworkXError, ppg, 3, 3, 1.1, 0.1) + pytest.raises(nx.NetworkXError, ppg, 3, 3, -0.1, 0.1) + pytest.raises(nx.NetworkXError, ppg, 3, 3, 0.1, 1.1) + pytest.raises(nx.NetworkXError, ppg, 3, 3, 0.1, -0.1) + + +def test_relaxed_caveman_graph(): + G = nx.relaxed_caveman_graph(4, 3, 0) + assert len(G) == 12 + G = nx.relaxed_caveman_graph(4, 3, 1) + assert len(G) == 12 + G = nx.relaxed_caveman_graph(4, 3, 0.5) + assert len(G) == 12 + G = nx.relaxed_caveman_graph(4, 3, 0.5, seed=42) + assert len(G) == 12 + + +def test_connected_caveman_graph(): + G = nx.connected_caveman_graph(4, 3) + assert len(G) == 12 + + G = nx.connected_caveman_graph(1, 5) + K5 = nx.complete_graph(5) + K5.remove_edge(3, 4) + assert nx.is_isomorphic(G, K5) + + # need at least 2 nodes in each clique + pytest.raises(nx.NetworkXError, nx.connected_caveman_graph, 4, 1) + + +def test_caveman_graph(): + G = nx.caveman_graph(4, 3) + assert len(G) == 12 + + G = nx.caveman_graph(5, 1) + E5 = nx.empty_graph(5) + assert nx.is_isomorphic(G, E5) + + G = nx.caveman_graph(1, 5) + K5 = nx.complete_graph(5) + assert nx.is_isomorphic(G, K5) + + +def test_gaussian_random_partition_graph(): + G = nx.gaussian_random_partition_graph(100, 10, 10, 0.3, 0.01) + assert len(G) == 100 + G = nx.gaussian_random_partition_graph(100, 10, 10, 0.3, 0.01, directed=True) + assert len(G) == 100 + G = nx.gaussian_random_partition_graph( + 100, 10, 10, 0.3, 0.01, directed=False, seed=42 + ) + assert len(G) == 100 + assert not isinstance(G, nx.DiGraph) + G = nx.gaussian_random_partition_graph( + 100, 10, 10, 0.3, 0.01, directed=True, seed=42 + ) + assert len(G) == 100 + assert isinstance(G, nx.DiGraph) + pytest.raises( + nx.NetworkXError, nx.gaussian_random_partition_graph, 100, 101, 10, 1, 0 + ) + # Test when clusters are likely less than 1 + G = nx.gaussian_random_partition_graph(10, 0.5, 0.5, 0.5, 0.5, seed=1) + assert len(G) == 10 + + +def test_ring_of_cliques(): + for i in range(2, 20, 3): + for j in range(2, 20, 3): + G = nx.ring_of_cliques(i, j) + assert G.number_of_nodes() == i * j + if i != 2 or j != 1: + expected_num_edges = i * (((j * (j - 1)) // 2) + 1) + else: + # the edge that already exists cannot be duplicated + expected_num_edges = i * (((j * (j - 1)) // 2) + 1) - 1 + assert G.number_of_edges() == expected_num_edges + with pytest.raises( + nx.NetworkXError, match="A ring of cliques must have at least two cliques" + ): + nx.ring_of_cliques(1, 5) + with pytest.raises( + nx.NetworkXError, match="The cliques must have at least two nodes" + ): + nx.ring_of_cliques(3, 0) + + +def test_windmill_graph(): + for n in range(2, 20, 3): + for k in range(2, 20, 3): + G = nx.windmill_graph(n, k) + assert G.number_of_nodes() == (k - 1) * n + 1 + assert G.number_of_edges() == n * k * (k - 1) / 2 + assert G.degree(0) == G.number_of_nodes() - 1 + for i in range(1, G.number_of_nodes()): + assert G.degree(i) == k - 1 + with pytest.raises( + nx.NetworkXError, match="A windmill graph must have at least two cliques" + ): + nx.windmill_graph(1, 3) + with pytest.raises( + nx.NetworkXError, match="The cliques must have at least two nodes" + ): + nx.windmill_graph(3, 0) + + +def test_stochastic_block_model(): + sizes = [75, 75, 300] + probs = [[0.25, 0.05, 0.02], [0.05, 0.35, 0.07], [0.02, 0.07, 0.40]] + G = nx.stochastic_block_model(sizes, probs, seed=0) + C = G.graph["partition"] + assert len(C) == 3 + assert len(G) == 450 + assert G.size() == 22160 + + GG = nx.stochastic_block_model(sizes, probs, range(450), seed=0) + assert G.nodes == GG.nodes + + # Test Exceptions + sbm = nx.stochastic_block_model + badnodelist = list(range(400)) # not enough nodes to match sizes + badprobs1 = [[0.25, 0.05, 1.02], [0.05, 0.35, 0.07], [0.02, 0.07, 0.40]] + badprobs2 = [[0.25, 0.05, 0.02], [0.05, -0.35, 0.07], [0.02, 0.07, 0.40]] + probs_rect1 = [[0.25, 0.05, 0.02], [0.05, -0.35, 0.07]] + probs_rect2 = [[0.25, 0.05], [0.05, -0.35], [0.02, 0.07]] + asymprobs = [[0.25, 0.05, 0.01], [0.05, -0.35, 0.07], [0.02, 0.07, 0.40]] + pytest.raises(nx.NetworkXException, sbm, sizes, badprobs1) + pytest.raises(nx.NetworkXException, sbm, sizes, badprobs2) + pytest.raises(nx.NetworkXException, sbm, sizes, probs_rect1, directed=True) + pytest.raises(nx.NetworkXException, sbm, sizes, probs_rect2, directed=True) + pytest.raises(nx.NetworkXException, sbm, sizes, asymprobs, directed=False) + pytest.raises(nx.NetworkXException, sbm, sizes, probs, badnodelist) + nodelist = [0] + list(range(449)) # repeated node name in nodelist + pytest.raises(nx.NetworkXException, sbm, sizes, probs, nodelist) + + # Extra keyword arguments test + GG = nx.stochastic_block_model(sizes, probs, seed=0, selfloops=True) + assert G.nodes == GG.nodes + GG = nx.stochastic_block_model(sizes, probs, selfloops=True, directed=True) + assert G.nodes == GG.nodes + GG = nx.stochastic_block_model(sizes, probs, seed=0, sparse=False) + assert G.nodes == GG.nodes + + +def test_generator(): + n = 250 + tau1 = 3 + tau2 = 1.5 + mu = 0.1 + G = nx.LFR_benchmark_graph( + n, tau1, tau2, mu, average_degree=5, min_community=20, seed=10 + ) + assert len(G) == 250 + C = {frozenset(G.nodes[v]["community"]) for v in G} + assert nx.community.is_partition(G.nodes(), C) + + +def test_invalid_tau1(): + with pytest.raises(nx.NetworkXError, match="tau2 must be greater than one"): + n = 100 + tau1 = 2 + tau2 = 1 + mu = 0.1 + nx.LFR_benchmark_graph(n, tau1, tau2, mu, min_degree=2) + + +def test_invalid_tau2(): + with pytest.raises(nx.NetworkXError, match="tau1 must be greater than one"): + n = 100 + tau1 = 1 + tau2 = 2 + mu = 0.1 + nx.LFR_benchmark_graph(n, tau1, tau2, mu, min_degree=2) + + +def test_mu_too_large(): + with pytest.raises(nx.NetworkXError, match="mu must be in the interval \\[0, 1\\]"): + n = 100 + tau1 = 2 + tau2 = 2 + mu = 1.1 + nx.LFR_benchmark_graph(n, tau1, tau2, mu, min_degree=2) + + +def test_mu_too_small(): + with pytest.raises(nx.NetworkXError, match="mu must be in the interval \\[0, 1\\]"): + n = 100 + tau1 = 2 + tau2 = 2 + mu = -1 + nx.LFR_benchmark_graph(n, tau1, tau2, mu, min_degree=2) + + +def test_both_degrees_none(): + with pytest.raises( + nx.NetworkXError, + match="Must assign exactly one of min_degree and average_degree", + ): + n = 100 + tau1 = 2 + tau2 = 2 + mu = 1 + nx.LFR_benchmark_graph(n, tau1, tau2, mu) + + +def test_neither_degrees_none(): + with pytest.raises( + nx.NetworkXError, + match="Must assign exactly one of min_degree and average_degree", + ): + n = 100 + tau1 = 2 + tau2 = 2 + mu = 1 + nx.LFR_benchmark_graph(n, tau1, tau2, mu, min_degree=2, average_degree=5) + + +def test_max_iters_exceeded(): + with pytest.raises( + nx.ExceededMaxIterations, + match="Could not assign communities; try increasing min_community", + ): + n = 10 + tau1 = 2 + tau2 = 2 + mu = 0.1 + nx.LFR_benchmark_graph(n, tau1, tau2, mu, min_degree=2, max_iters=10, seed=1) + + +def test_max_deg_out_of_range(): + with pytest.raises( + nx.NetworkXError, match="max_degree must be in the interval \\(0, n\\]" + ): + n = 10 + tau1 = 2 + tau2 = 2 + mu = 0.1 + nx.LFR_benchmark_graph( + n, tau1, tau2, mu, max_degree=n + 1, max_iters=10, seed=1 + ) + + +def test_max_community(): + n = 250 + tau1 = 3 + tau2 = 1.5 + mu = 0.1 + G = nx.LFR_benchmark_graph( + n, + tau1, + tau2, + mu, + average_degree=5, + max_degree=100, + min_community=50, + max_community=200, + seed=10, + ) + assert len(G) == 250 + C = {frozenset(G.nodes[v]["community"]) for v in G} + assert nx.community.is_partition(G.nodes(), C) + + +def test_powerlaw_iterations_exceeded(): + with pytest.raises( + nx.ExceededMaxIterations, match="Could not create power law sequence" + ): + n = 100 + tau1 = 2 + tau2 = 2 + mu = 1 + nx.LFR_benchmark_graph(n, tau1, tau2, mu, min_degree=2, max_iters=0) + + +def test_no_scipy_zeta(): + zeta2 = 1.6449340668482264 + assert abs(zeta2 - nx.generators.community._hurwitz_zeta(2, 1, 0.0001)) < 0.01 + + +def test_generate_min_degree_itr(): + with pytest.raises( + nx.ExceededMaxIterations, match="Could not match average_degree" + ): + nx.generators.community._generate_min_degree(2, 2, 1, 0.01, 0) diff --git a/lib/python3.10/site-packages/networkx/generators/tests/test_degree_seq.py b/lib/python3.10/site-packages/networkx/generators/tests/test_degree_seq.py new file mode 100644 index 0000000000000000000000000000000000000000..39ed59a5f32270242b8d069c57229d3e10ba7f43 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/generators/tests/test_degree_seq.py @@ -0,0 +1,230 @@ +import pytest + +import networkx as nx + + +class TestConfigurationModel: + """Unit tests for the :func:`~networkx.configuration_model` + function. + + """ + + def test_empty_degree_sequence(self): + """Tests that an empty degree sequence yields the null graph.""" + G = nx.configuration_model([]) + assert len(G) == 0 + + def test_degree_zero(self): + """Tests that a degree sequence of all zeros yields the empty + graph. + + """ + G = nx.configuration_model([0, 0, 0]) + assert len(G) == 3 + assert G.number_of_edges() == 0 + + def test_degree_sequence(self): + """Tests that the degree sequence of the generated graph matches + the input degree sequence. + + """ + deg_seq = [5, 3, 3, 3, 3, 2, 2, 2, 1, 1, 1] + G = nx.configuration_model(deg_seq, seed=12345678) + assert sorted((d for n, d in G.degree()), reverse=True) == [ + 5, + 3, + 3, + 3, + 3, + 2, + 2, + 2, + 1, + 1, + 1, + ] + assert sorted((d for n, d in G.degree(range(len(deg_seq)))), reverse=True) == [ + 5, + 3, + 3, + 3, + 3, + 2, + 2, + 2, + 1, + 1, + 1, + ] + + def test_random_seed(self): + """Tests that each call with the same random seed generates the + same graph. + + """ + deg_seq = [3] * 12 + G1 = nx.configuration_model(deg_seq, seed=1000) + G2 = nx.configuration_model(deg_seq, seed=1000) + assert nx.is_isomorphic(G1, G2) + G1 = nx.configuration_model(deg_seq, seed=10) + G2 = nx.configuration_model(deg_seq, seed=10) + assert nx.is_isomorphic(G1, G2) + + def test_directed_disallowed(self): + """Tests that attempting to create a configuration model graph + using a directed graph yields an exception. + + """ + with pytest.raises(nx.NetworkXNotImplemented): + nx.configuration_model([], create_using=nx.DiGraph()) + + def test_odd_degree_sum(self): + """Tests that a degree sequence whose sum is odd yields an + exception. + + """ + with pytest.raises(nx.NetworkXError): + nx.configuration_model([1, 2]) + + +def test_directed_configuration_raise_unequal(): + with pytest.raises(nx.NetworkXError): + zin = [5, 3, 3, 3, 3, 2, 2, 2, 1, 1] + zout = [5, 3, 3, 3, 3, 2, 2, 2, 1, 2] + nx.directed_configuration_model(zin, zout) + + +def test_directed_configuration_model(): + G = nx.directed_configuration_model([], [], seed=0) + assert len(G) == 0 + + +def test_simple_directed_configuration_model(): + G = nx.directed_configuration_model([1, 1], [1, 1], seed=0) + assert len(G) == 2 + + +def test_expected_degree_graph_empty(): + # empty graph has empty degree sequence + deg_seq = [] + G = nx.expected_degree_graph(deg_seq) + assert dict(G.degree()) == {} + + +def test_expected_degree_graph(): + # test that fixed seed delivers the same graph + deg_seq = [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3] + G1 = nx.expected_degree_graph(deg_seq, seed=1000) + assert len(G1) == 12 + + G2 = nx.expected_degree_graph(deg_seq, seed=1000) + assert nx.is_isomorphic(G1, G2) + + G1 = nx.expected_degree_graph(deg_seq, seed=10) + G2 = nx.expected_degree_graph(deg_seq, seed=10) + assert nx.is_isomorphic(G1, G2) + + +def test_expected_degree_graph_selfloops(): + deg_seq = [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3] + G1 = nx.expected_degree_graph(deg_seq, seed=1000, selfloops=False) + G2 = nx.expected_degree_graph(deg_seq, seed=1000, selfloops=False) + assert nx.is_isomorphic(G1, G2) + assert len(G1) == 12 + + +def test_expected_degree_graph_skew(): + deg_seq = [10, 2, 2, 2, 2] + G1 = nx.expected_degree_graph(deg_seq, seed=1000) + G2 = nx.expected_degree_graph(deg_seq, seed=1000) + assert nx.is_isomorphic(G1, G2) + assert len(G1) == 5 + + +def test_havel_hakimi_construction(): + G = nx.havel_hakimi_graph([]) + assert len(G) == 0 + + z = [1000, 3, 3, 3, 3, 2, 2, 2, 1, 1, 1] + pytest.raises(nx.NetworkXError, nx.havel_hakimi_graph, z) + z = ["A", 3, 3, 3, 3, 2, 2, 2, 1, 1, 1] + pytest.raises(nx.NetworkXError, nx.havel_hakimi_graph, z) + + z = [5, 4, 3, 3, 3, 2, 2, 2] + G = nx.havel_hakimi_graph(z) + G = nx.configuration_model(z) + z = [6, 5, 4, 4, 2, 1, 1, 1] + pytest.raises(nx.NetworkXError, nx.havel_hakimi_graph, z) + + z = [10, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2] + + G = nx.havel_hakimi_graph(z) + + pytest.raises(nx.NetworkXError, nx.havel_hakimi_graph, z, create_using=nx.DiGraph()) + + +def test_directed_havel_hakimi(): + # Test range of valid directed degree sequences + n, r = 100, 10 + p = 1.0 / r + for i in range(r): + G1 = nx.erdos_renyi_graph(n, p * (i + 1), None, True) + din1 = [d for n, d in G1.in_degree()] + dout1 = [d for n, d in G1.out_degree()] + G2 = nx.directed_havel_hakimi_graph(din1, dout1) + din2 = [d for n, d in G2.in_degree()] + dout2 = [d for n, d in G2.out_degree()] + assert sorted(din1) == sorted(din2) + assert sorted(dout1) == sorted(dout2) + + # Test non-graphical sequence + dout = [1000, 3, 3, 3, 3, 2, 2, 2, 1, 1, 1] + din = [103, 102, 102, 102, 102, 102, 102, 102, 102, 102] + pytest.raises(nx.exception.NetworkXError, nx.directed_havel_hakimi_graph, din, dout) + # Test valid sequences + dout = [1, 1, 1, 1, 1, 2, 2, 2, 3, 4] + din = [2, 2, 2, 2, 2, 2, 2, 2, 0, 2] + G2 = nx.directed_havel_hakimi_graph(din, dout) + dout2 = (d for n, d in G2.out_degree()) + din2 = (d for n, d in G2.in_degree()) + assert sorted(dout) == sorted(dout2) + assert sorted(din) == sorted(din2) + # Test unequal sums + din = [2, 2, 2, 2, 2, 2, 2, 2, 2, 2] + pytest.raises(nx.exception.NetworkXError, nx.directed_havel_hakimi_graph, din, dout) + # Test for negative values + din = [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, -2] + pytest.raises(nx.exception.NetworkXError, nx.directed_havel_hakimi_graph, din, dout) + + +def test_degree_sequence_tree(): + z = [1, 1, 1, 1, 1, 2, 2, 2, 3, 4] + G = nx.degree_sequence_tree(z) + assert len(G) == len(z) + assert len(list(G.edges())) == sum(z) / 2 + + pytest.raises( + nx.NetworkXError, nx.degree_sequence_tree, z, create_using=nx.DiGraph() + ) + + z = [1, 1, 1, 1, 1, 1, 2, 2, 2, 3, 4] + pytest.raises(nx.NetworkXError, nx.degree_sequence_tree, z) + + +def test_random_degree_sequence_graph(): + d = [1, 2, 2, 3] + G = nx.random_degree_sequence_graph(d, seed=42) + assert d == sorted(d for n, d in G.degree()) + + +def test_random_degree_sequence_graph_raise(): + z = [1, 1, 1, 1, 1, 1, 2, 2, 2, 3, 4] + pytest.raises(nx.NetworkXUnfeasible, nx.random_degree_sequence_graph, z) + + +def test_random_degree_sequence_large(): + G1 = nx.fast_gnp_random_graph(100, 0.1, seed=42) + d1 = (d for n, d in G1.degree()) + G2 = nx.random_degree_sequence_graph(d1, seed=42) + d2 = (d for n, d in G2.degree()) + assert sorted(d1) == sorted(d2) diff --git a/lib/python3.10/site-packages/networkx/generators/tests/test_directed.py b/lib/python3.10/site-packages/networkx/generators/tests/test_directed.py new file mode 100644 index 0000000000000000000000000000000000000000..8078d9f7a8b04465bfc556d8f101e343657a078a --- /dev/null +++ b/lib/python3.10/site-packages/networkx/generators/tests/test_directed.py @@ -0,0 +1,163 @@ +"""Generators - Directed Graphs +---------------------------- +""" + +import pytest + +import networkx as nx +from networkx.classes import Graph, MultiDiGraph +from networkx.generators.directed import ( + gn_graph, + gnc_graph, + gnr_graph, + random_k_out_graph, + random_uniform_k_out_graph, + scale_free_graph, +) + + +class TestGeneratorsDirected: + def test_smoke_test_random_graphs(self): + gn_graph(100) + gnr_graph(100, 0.5) + gnc_graph(100) + scale_free_graph(100) + + gn_graph(100, seed=42) + gnr_graph(100, 0.5, seed=42) + gnc_graph(100, seed=42) + scale_free_graph(100, seed=42) + + def test_create_using_keyword_arguments(self): + pytest.raises(nx.NetworkXError, gn_graph, 100, create_using=Graph()) + pytest.raises(nx.NetworkXError, gnr_graph, 100, 0.5, create_using=Graph()) + pytest.raises(nx.NetworkXError, gnc_graph, 100, create_using=Graph()) + G = gn_graph(100, seed=1) + MG = gn_graph(100, create_using=MultiDiGraph(), seed=1) + assert sorted(G.edges()) == sorted(MG.edges()) + G = gnr_graph(100, 0.5, seed=1) + MG = gnr_graph(100, 0.5, create_using=MultiDiGraph(), seed=1) + assert sorted(G.edges()) == sorted(MG.edges()) + G = gnc_graph(100, seed=1) + MG = gnc_graph(100, create_using=MultiDiGraph(), seed=1) + assert sorted(G.edges()) == sorted(MG.edges()) + + G = scale_free_graph( + 100, + alpha=0.3, + beta=0.4, + gamma=0.3, + delta_in=0.3, + delta_out=0.1, + initial_graph=nx.cycle_graph(4, create_using=MultiDiGraph), + seed=1, + ) + pytest.raises(ValueError, scale_free_graph, 100, 0.5, 0.4, 0.3) + pytest.raises(ValueError, scale_free_graph, 100, alpha=-0.3) + pytest.raises(ValueError, scale_free_graph, 100, beta=-0.3) + pytest.raises(ValueError, scale_free_graph, 100, gamma=-0.3) + + def test_parameters(self): + G = nx.DiGraph() + G.add_node(0) + + def kernel(x): + return x + + assert nx.is_isomorphic(gn_graph(1), G) + assert nx.is_isomorphic(gn_graph(1, kernel=kernel), G) + assert nx.is_isomorphic(gnc_graph(1), G) + assert nx.is_isomorphic(gnr_graph(1, 0.5), G) + + +def test_scale_free_graph_negative_delta(): + with pytest.raises(ValueError, match="delta_in must be >= 0."): + scale_free_graph(10, delta_in=-1) + with pytest.raises(ValueError, match="delta_out must be >= 0."): + scale_free_graph(10, delta_out=-1) + + +def test_non_numeric_ordering(): + G = MultiDiGraph([("a", "b"), ("b", "c"), ("c", "a")]) + s = scale_free_graph(3, initial_graph=G) + assert len(s) == 3 + assert len(s.edges) == 3 + + +@pytest.mark.parametrize("ig", (nx.Graph(), nx.DiGraph([(0, 1)]))) +def test_scale_free_graph_initial_graph_kwarg(ig): + with pytest.raises(nx.NetworkXError): + scale_free_graph(100, initial_graph=ig) + + +class TestRandomKOutGraph: + """Unit tests for the + :func:`~networkx.generators.directed.random_k_out_graph` function. + + """ + + def test_regularity(self): + """Tests that the generated graph is `k`-out-regular.""" + n = 10 + k = 3 + alpha = 1 + G = random_k_out_graph(n, k, alpha) + assert all(d == k for v, d in G.out_degree()) + G = random_k_out_graph(n, k, alpha, seed=42) + assert all(d == k for v, d in G.out_degree()) + + def test_no_self_loops(self): + """Tests for forbidding self-loops.""" + n = 10 + k = 3 + alpha = 1 + G = random_k_out_graph(n, k, alpha, self_loops=False) + assert nx.number_of_selfloops(G) == 0 + + def test_negative_alpha(self): + with pytest.raises(ValueError, match="alpha must be positive"): + random_k_out_graph(10, 3, -1) + + +class TestUniformRandomKOutGraph: + """Unit tests for the + :func:`~networkx.generators.directed.random_uniform_k_out_graph` + function. + + """ + + def test_regularity(self): + """Tests that the generated graph is `k`-out-regular.""" + n = 10 + k = 3 + G = random_uniform_k_out_graph(n, k) + assert all(d == k for v, d in G.out_degree()) + G = random_uniform_k_out_graph(n, k, seed=42) + assert all(d == k for v, d in G.out_degree()) + + def test_no_self_loops(self): + """Tests for forbidding self-loops.""" + n = 10 + k = 3 + G = random_uniform_k_out_graph(n, k, self_loops=False) + assert nx.number_of_selfloops(G) == 0 + assert all(d == k for v, d in G.out_degree()) + + def test_with_replacement(self): + n = 10 + k = 3 + G = random_uniform_k_out_graph(n, k, with_replacement=True) + assert G.is_multigraph() + assert all(d == k for v, d in G.out_degree()) + n = 10 + k = 9 + G = random_uniform_k_out_graph(n, k, with_replacement=False, self_loops=False) + assert nx.number_of_selfloops(G) == 0 + assert all(d == k for v, d in G.out_degree()) + + def test_without_replacement(self): + n = 10 + k = 3 + G = random_uniform_k_out_graph(n, k, with_replacement=False) + assert not G.is_multigraph() + assert all(d == k for v, d in G.out_degree()) diff --git a/lib/python3.10/site-packages/networkx/generators/tests/test_duplication.py b/lib/python3.10/site-packages/networkx/generators/tests/test_duplication.py new file mode 100644 index 0000000000000000000000000000000000000000..9b6100b78e59067b607e310f14d80e5a00c2b691 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/generators/tests/test_duplication.py @@ -0,0 +1,103 @@ +"""Unit tests for the :mod:`networkx.generators.duplication` module.""" + +import pytest + +import networkx as nx + + +class TestDuplicationDivergenceGraph: + """Unit tests for the + :func:`networkx.generators.duplication.duplication_divergence_graph` + function. + + """ + + def test_final_size(self): + G = nx.duplication_divergence_graph(3, p=1) + assert len(G) == 3 + G = nx.duplication_divergence_graph(3, p=1, seed=42) + assert len(G) == 3 + + def test_probability_too_large(self): + with pytest.raises(nx.NetworkXError): + nx.duplication_divergence_graph(3, p=2) + + def test_probability_too_small(self): + with pytest.raises(nx.NetworkXError): + nx.duplication_divergence_graph(3, p=-1) + + def test_non_extreme_probability_value(self): + G = nx.duplication_divergence_graph(6, p=0.3, seed=42) + assert len(G) == 6 + assert list(G.degree()) == [(0, 2), (1, 3), (2, 2), (3, 3), (4, 1), (5, 1)] + + def test_minimum_desired_nodes(self): + with pytest.raises( + nx.NetworkXError, match=".*n must be greater than or equal to 2" + ): + nx.duplication_divergence_graph(1, p=1) + + def test_create_using(self): + class DummyGraph(nx.Graph): + pass + + class DummyDiGraph(nx.DiGraph): + pass + + G = nx.duplication_divergence_graph(6, 0.3, seed=42, create_using=DummyGraph) + assert isinstance(G, DummyGraph) + with pytest.raises(nx.NetworkXError, match="create_using must not be directed"): + nx.duplication_divergence_graph(6, 0.3, seed=42, create_using=DummyDiGraph) + + +class TestPartialDuplicationGraph: + """Unit tests for the + :func:`networkx.generators.duplication.partial_duplication_graph` + function. + + """ + + def test_final_size(self): + N = 10 + n = 5 + p = 0.5 + q = 0.5 + G = nx.partial_duplication_graph(N, n, p, q) + assert len(G) == N + G = nx.partial_duplication_graph(N, n, p, q, seed=42) + assert len(G) == N + + def test_initial_clique_size(self): + N = 10 + n = 10 + p = 0.5 + q = 0.5 + G = nx.partial_duplication_graph(N, n, p, q) + assert len(G) == n + + def test_invalid_initial_size(self): + with pytest.raises(nx.NetworkXError): + N = 5 + n = 10 + p = 0.5 + q = 0.5 + G = nx.partial_duplication_graph(N, n, p, q) + + def test_invalid_probabilities(self): + N = 1 + n = 1 + for p, q in [(0.5, 2), (0.5, -1), (2, 0.5), (-1, 0.5)]: + args = (N, n, p, q) + pytest.raises(nx.NetworkXError, nx.partial_duplication_graph, *args) + + def test_create_using(self): + class DummyGraph(nx.Graph): + pass + + class DummyDiGraph(nx.DiGraph): + pass + + G = nx.partial_duplication_graph(10, 5, 0.5, 0.5, create_using=DummyGraph) + assert isinstance(G, DummyGraph) + with pytest.raises(nx.NetworkXError, match="create_using must not be directed"): + nx.partial_duplication_graph(10, 5, 0.5, 0.5, create_using=DummyDiGraph) diff --git a/lib/python3.10/site-packages/networkx/linalg/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/linalg/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9deb8af9a055a9af09bead9dc4b8970f55445db Binary files /dev/null and b/lib/python3.10/site-packages/networkx/linalg/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/linalg/__pycache__/algebraicconnectivity.cpython-310.pyc b/lib/python3.10/site-packages/networkx/linalg/__pycache__/algebraicconnectivity.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b62c2a2d94fa95d8a3c8265c792b7b934d758ec Binary files /dev/null and b/lib/python3.10/site-packages/networkx/linalg/__pycache__/algebraicconnectivity.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/linalg/__pycache__/attrmatrix.cpython-310.pyc b/lib/python3.10/site-packages/networkx/linalg/__pycache__/attrmatrix.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c4ffe942646f45b7e410873e470b6d891645979 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/linalg/__pycache__/attrmatrix.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/linalg/__pycache__/bethehessianmatrix.cpython-310.pyc b/lib/python3.10/site-packages/networkx/linalg/__pycache__/bethehessianmatrix.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04f4393499a175c44eb3c6641650f87d6af6e129 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/linalg/__pycache__/bethehessianmatrix.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/linalg/__pycache__/graphmatrix.cpython-310.pyc b/lib/python3.10/site-packages/networkx/linalg/__pycache__/graphmatrix.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74ba49f131948ac474bb73e008f9e1785276d25a Binary files /dev/null and b/lib/python3.10/site-packages/networkx/linalg/__pycache__/graphmatrix.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/linalg/__pycache__/laplacianmatrix.cpython-310.pyc b/lib/python3.10/site-packages/networkx/linalg/__pycache__/laplacianmatrix.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e930e64e7a890c25f759c8fec8f3ece4c5048c1 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/linalg/__pycache__/laplacianmatrix.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/linalg/__pycache__/modularitymatrix.cpython-310.pyc b/lib/python3.10/site-packages/networkx/linalg/__pycache__/modularitymatrix.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb615ba4a0f02ea1e410502a70e5f4abe64f705c Binary files /dev/null and b/lib/python3.10/site-packages/networkx/linalg/__pycache__/modularitymatrix.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/linalg/__pycache__/spectrum.cpython-310.pyc b/lib/python3.10/site-packages/networkx/linalg/__pycache__/spectrum.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89c4cb88cfca2f4e195400bdeb4abe9deeb95f98 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/linalg/__pycache__/spectrum.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/linalg/tests/__init__.py b/lib/python3.10/site-packages/networkx/linalg/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/networkx/linalg/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/linalg/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..058d062aa467ae518cb9a931ae1b403bd64d2f92 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/linalg/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/linalg/tests/__pycache__/test_algebraic_connectivity.cpython-310.pyc b/lib/python3.10/site-packages/networkx/linalg/tests/__pycache__/test_algebraic_connectivity.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14e8d9ad6904ff21c6b440d3cd22b1a9985a8760 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/linalg/tests/__pycache__/test_algebraic_connectivity.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/linalg/tests/__pycache__/test_attrmatrix.cpython-310.pyc b/lib/python3.10/site-packages/networkx/linalg/tests/__pycache__/test_attrmatrix.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6f04c7dc814bc95bcf2c218afdd982bd705920b Binary files /dev/null and b/lib/python3.10/site-packages/networkx/linalg/tests/__pycache__/test_attrmatrix.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/linalg/tests/__pycache__/test_bethehessian.cpython-310.pyc b/lib/python3.10/site-packages/networkx/linalg/tests/__pycache__/test_bethehessian.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b674b92c9eef5e0f3d8667ee9359207d7773818e Binary files /dev/null and b/lib/python3.10/site-packages/networkx/linalg/tests/__pycache__/test_bethehessian.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/linalg/tests/__pycache__/test_graphmatrix.cpython-310.pyc b/lib/python3.10/site-packages/networkx/linalg/tests/__pycache__/test_graphmatrix.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8728197172f8350124a3f7196d67eec4b4072015 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/linalg/tests/__pycache__/test_graphmatrix.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/linalg/tests/__pycache__/test_laplacian.cpython-310.pyc b/lib/python3.10/site-packages/networkx/linalg/tests/__pycache__/test_laplacian.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34d79a3c4181299810f7a22fbbaa3c4a9b708666 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/linalg/tests/__pycache__/test_laplacian.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/linalg/tests/__pycache__/test_modularity.cpython-310.pyc b/lib/python3.10/site-packages/networkx/linalg/tests/__pycache__/test_modularity.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89d72f4fb6850110ef12dd9fb351ddaaae7ed7e1 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/linalg/tests/__pycache__/test_modularity.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/linalg/tests/__pycache__/test_spectrum.cpython-310.pyc b/lib/python3.10/site-packages/networkx/linalg/tests/__pycache__/test_spectrum.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d85fd61ad2ac7e257a63f63ecc7c085cdb894a4d Binary files /dev/null and b/lib/python3.10/site-packages/networkx/linalg/tests/__pycache__/test_spectrum.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/linalg/tests/test_algebraic_connectivity.py b/lib/python3.10/site-packages/networkx/linalg/tests/test_algebraic_connectivity.py new file mode 100644 index 0000000000000000000000000000000000000000..089d917a6832e1ab211eb3ced08344b84ddb285a --- /dev/null +++ b/lib/python3.10/site-packages/networkx/linalg/tests/test_algebraic_connectivity.py @@ -0,0 +1,402 @@ +from math import sqrt + +import pytest + +np = pytest.importorskip("numpy") + + +import networkx as nx + +methods = ("tracemin_pcg", "tracemin_lu", "lanczos", "lobpcg") + + +def test_algebraic_connectivity_tracemin_chol(): + """Test that "tracemin_chol" raises an exception.""" + pytest.importorskip("scipy") + G = nx.barbell_graph(5, 4) + with pytest.raises(nx.NetworkXError): + nx.algebraic_connectivity(G, method="tracemin_chol") + + +def test_fiedler_vector_tracemin_chol(): + """Test that "tracemin_chol" raises an exception.""" + pytest.importorskip("scipy") + G = nx.barbell_graph(5, 4) + with pytest.raises(nx.NetworkXError): + nx.fiedler_vector(G, method="tracemin_chol") + + +def test_spectral_ordering_tracemin_chol(): + """Test that "tracemin_chol" raises an exception.""" + pytest.importorskip("scipy") + G = nx.barbell_graph(5, 4) + with pytest.raises(nx.NetworkXError): + nx.spectral_ordering(G, method="tracemin_chol") + + +def test_fiedler_vector_tracemin_unknown(): + """Test that "tracemin_unknown" raises an exception.""" + pytest.importorskip("scipy") + G = nx.barbell_graph(5, 4) + L = nx.laplacian_matrix(G) + X = np.asarray(np.random.normal(size=(1, L.shape[0]))).T + with pytest.raises(nx.NetworkXError, match="Unknown linear system solver"): + nx.linalg.algebraicconnectivity._tracemin_fiedler( + L, X, normalized=False, tol=1e-8, method="tracemin_unknown" + ) + + +def test_spectral_bisection(): + pytest.importorskip("scipy") + G = nx.barbell_graph(3, 0) + C = nx.spectral_bisection(G) + assert C == ({0, 1, 2}, {3, 4, 5}) + + mapping = dict(enumerate("badfec")) + G = nx.relabel_nodes(G, mapping) + C = nx.spectral_bisection(G) + assert C == ( + {mapping[0], mapping[1], mapping[2]}, + {mapping[3], mapping[4], mapping[5]}, + ) + + +def check_eigenvector(A, l, x): + nx = np.linalg.norm(x) + # Check zeroness. + assert nx != pytest.approx(0, abs=1e-07) + y = A @ x + ny = np.linalg.norm(y) + # Check collinearity. + assert x @ y == pytest.approx(nx * ny, abs=1e-7) + # Check eigenvalue. + assert ny == pytest.approx(l * nx, abs=1e-7) + + +class TestAlgebraicConnectivity: + @pytest.mark.parametrize("method", methods) + def test_directed(self, method): + G = nx.DiGraph() + pytest.raises( + nx.NetworkXNotImplemented, nx.algebraic_connectivity, G, method=method + ) + pytest.raises(nx.NetworkXNotImplemented, nx.fiedler_vector, G, method=method) + + @pytest.mark.parametrize("method", methods) + def test_null_and_singleton(self, method): + G = nx.Graph() + pytest.raises(nx.NetworkXError, nx.algebraic_connectivity, G, method=method) + pytest.raises(nx.NetworkXError, nx.fiedler_vector, G, method=method) + G.add_edge(0, 0) + pytest.raises(nx.NetworkXError, nx.algebraic_connectivity, G, method=method) + pytest.raises(nx.NetworkXError, nx.fiedler_vector, G, method=method) + + @pytest.mark.parametrize("method", methods) + def test_disconnected(self, method): + G = nx.Graph() + G.add_nodes_from(range(2)) + assert nx.algebraic_connectivity(G) == 0 + pytest.raises(nx.NetworkXError, nx.fiedler_vector, G, method=method) + G.add_edge(0, 1, weight=0) + assert nx.algebraic_connectivity(G) == 0 + pytest.raises(nx.NetworkXError, nx.fiedler_vector, G, method=method) + + def test_unrecognized_method(self): + pytest.importorskip("scipy") + G = nx.path_graph(4) + pytest.raises(nx.NetworkXError, nx.algebraic_connectivity, G, method="unknown") + pytest.raises(nx.NetworkXError, nx.fiedler_vector, G, method="unknown") + + @pytest.mark.parametrize("method", methods) + def test_two_nodes(self, method): + pytest.importorskip("scipy") + G = nx.Graph() + G.add_edge(0, 1, weight=1) + A = nx.laplacian_matrix(G) + assert nx.algebraic_connectivity(G, tol=1e-12, method=method) == pytest.approx( + 2, abs=1e-7 + ) + x = nx.fiedler_vector(G, tol=1e-12, method=method) + check_eigenvector(A, 2, x) + + @pytest.mark.parametrize("method", methods) + def test_two_nodes_multigraph(self, method): + pytest.importorskip("scipy") + G = nx.MultiGraph() + G.add_edge(0, 0, spam=1e8) + G.add_edge(0, 1, spam=1) + G.add_edge(0, 1, spam=-2) + A = -3 * nx.laplacian_matrix(G, weight="spam") + assert nx.algebraic_connectivity( + G, weight="spam", tol=1e-12, method=method + ) == pytest.approx(6, abs=1e-7) + x = nx.fiedler_vector(G, weight="spam", tol=1e-12, method=method) + check_eigenvector(A, 6, x) + + def test_abbreviation_of_method(self): + pytest.importorskip("scipy") + G = nx.path_graph(8) + A = nx.laplacian_matrix(G) + sigma = 2 - sqrt(2 + sqrt(2)) + ac = nx.algebraic_connectivity(G, tol=1e-12, method="tracemin") + assert ac == pytest.approx(sigma, abs=1e-7) + x = nx.fiedler_vector(G, tol=1e-12, method="tracemin") + check_eigenvector(A, sigma, x) + + @pytest.mark.parametrize("method", methods) + def test_path(self, method): + pytest.importorskip("scipy") + G = nx.path_graph(8) + A = nx.laplacian_matrix(G) + sigma = 2 - sqrt(2 + sqrt(2)) + ac = nx.algebraic_connectivity(G, tol=1e-12, method=method) + assert ac == pytest.approx(sigma, abs=1e-7) + x = nx.fiedler_vector(G, tol=1e-12, method=method) + check_eigenvector(A, sigma, x) + + @pytest.mark.parametrize("method", methods) + def test_problematic_graph_issue_2381(self, method): + pytest.importorskip("scipy") + G = nx.path_graph(4) + G.add_edges_from([(4, 2), (5, 1)]) + A = nx.laplacian_matrix(G) + sigma = 0.438447187191 + ac = nx.algebraic_connectivity(G, tol=1e-12, method=method) + assert ac == pytest.approx(sigma, abs=1e-7) + x = nx.fiedler_vector(G, tol=1e-12, method=method) + check_eigenvector(A, sigma, x) + + @pytest.mark.parametrize("method", methods) + def test_cycle(self, method): + pytest.importorskip("scipy") + G = nx.cycle_graph(8) + A = nx.laplacian_matrix(G) + sigma = 2 - sqrt(2) + ac = nx.algebraic_connectivity(G, tol=1e-12, method=method) + assert ac == pytest.approx(sigma, abs=1e-7) + x = nx.fiedler_vector(G, tol=1e-12, method=method) + check_eigenvector(A, sigma, x) + + @pytest.mark.parametrize("method", methods) + def test_seed_argument(self, method): + pytest.importorskip("scipy") + G = nx.cycle_graph(8) + A = nx.laplacian_matrix(G) + sigma = 2 - sqrt(2) + ac = nx.algebraic_connectivity(G, tol=1e-12, method=method, seed=1) + assert ac == pytest.approx(sigma, abs=1e-7) + x = nx.fiedler_vector(G, tol=1e-12, method=method, seed=1) + check_eigenvector(A, sigma, x) + + @pytest.mark.parametrize( + ("normalized", "sigma", "laplacian_fn"), + ( + (False, 0.2434017461399311, nx.laplacian_matrix), + (True, 0.08113391537997749, nx.normalized_laplacian_matrix), + ), + ) + @pytest.mark.parametrize("method", methods) + def test_buckminsterfullerene(self, normalized, sigma, laplacian_fn, method): + pytest.importorskip("scipy") + G = nx.Graph( + [ + (1, 10), + (1, 41), + (1, 59), + (2, 12), + (2, 42), + (2, 60), + (3, 6), + (3, 43), + (3, 57), + (4, 8), + (4, 44), + (4, 58), + (5, 13), + (5, 56), + (5, 57), + (6, 10), + (6, 31), + (7, 14), + (7, 56), + (7, 58), + (8, 12), + (8, 32), + (9, 23), + (9, 53), + (9, 59), + (10, 15), + (11, 24), + (11, 53), + (11, 60), + (12, 16), + (13, 14), + (13, 25), + (14, 26), + (15, 27), + (15, 49), + (16, 28), + (16, 50), + (17, 18), + (17, 19), + (17, 54), + (18, 20), + (18, 55), + (19, 23), + (19, 41), + (20, 24), + (20, 42), + (21, 31), + (21, 33), + (21, 57), + (22, 32), + (22, 34), + (22, 58), + (23, 24), + (25, 35), + (25, 43), + (26, 36), + (26, 44), + (27, 51), + (27, 59), + (28, 52), + (28, 60), + (29, 33), + (29, 34), + (29, 56), + (30, 51), + (30, 52), + (30, 53), + (31, 47), + (32, 48), + (33, 45), + (34, 46), + (35, 36), + (35, 37), + (36, 38), + (37, 39), + (37, 49), + (38, 40), + (38, 50), + (39, 40), + (39, 51), + (40, 52), + (41, 47), + (42, 48), + (43, 49), + (44, 50), + (45, 46), + (45, 54), + (46, 55), + (47, 54), + (48, 55), + ] + ) + A = laplacian_fn(G) + try: + assert nx.algebraic_connectivity( + G, normalized=normalized, tol=1e-12, method=method + ) == pytest.approx(sigma, abs=1e-7) + x = nx.fiedler_vector(G, normalized=normalized, tol=1e-12, method=method) + check_eigenvector(A, sigma, x) + except nx.NetworkXError as err: + if err.args not in ( + ("Cholesky solver unavailable.",), + ("LU solver unavailable.",), + ): + raise + + +class TestSpectralOrdering: + _graphs = (nx.Graph, nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph) + + @pytest.mark.parametrize("graph", _graphs) + def test_nullgraph(self, graph): + G = graph() + pytest.raises(nx.NetworkXError, nx.spectral_ordering, G) + + @pytest.mark.parametrize("graph", _graphs) + def test_singleton(self, graph): + G = graph() + G.add_node("x") + assert nx.spectral_ordering(G) == ["x"] + G.add_edge("x", "x", weight=33) + G.add_edge("x", "x", weight=33) + assert nx.spectral_ordering(G) == ["x"] + + def test_unrecognized_method(self): + G = nx.path_graph(4) + pytest.raises(nx.NetworkXError, nx.spectral_ordering, G, method="unknown") + + @pytest.mark.parametrize("method", methods) + def test_three_nodes(self, method): + pytest.importorskip("scipy") + G = nx.Graph() + G.add_weighted_edges_from([(1, 2, 1), (1, 3, 2), (2, 3, 1)], weight="spam") + order = nx.spectral_ordering(G, weight="spam", method=method) + assert set(order) == set(G) + assert {1, 3} in (set(order[:-1]), set(order[1:])) + + @pytest.mark.parametrize("method", methods) + def test_three_nodes_multigraph(self, method): + pytest.importorskip("scipy") + G = nx.MultiDiGraph() + G.add_weighted_edges_from([(1, 2, 1), (1, 3, 2), (2, 3, 1), (2, 3, 2)]) + order = nx.spectral_ordering(G, method=method) + assert set(order) == set(G) + assert {2, 3} in (set(order[:-1]), set(order[1:])) + + @pytest.mark.parametrize("method", methods) + def test_path(self, method): + pytest.importorskip("scipy") + path = list(range(10)) + np.random.shuffle(path) + G = nx.Graph() + nx.add_path(G, path) + order = nx.spectral_ordering(G, method=method) + assert order in [path, list(reversed(path))] + + @pytest.mark.parametrize("method", methods) + def test_seed_argument(self, method): + pytest.importorskip("scipy") + path = list(range(10)) + np.random.shuffle(path) + G = nx.Graph() + nx.add_path(G, path) + order = nx.spectral_ordering(G, method=method, seed=1) + assert order in [path, list(reversed(path))] + + @pytest.mark.parametrize("method", methods) + def test_disconnected(self, method): + pytest.importorskip("scipy") + G = nx.Graph() + nx.add_path(G, range(0, 10, 2)) + nx.add_path(G, range(1, 10, 2)) + order = nx.spectral_ordering(G, method=method) + assert set(order) == set(G) + seqs = [ + list(range(0, 10, 2)), + list(range(8, -1, -2)), + list(range(1, 10, 2)), + list(range(9, -1, -2)), + ] + assert order[:5] in seqs + assert order[5:] in seqs + + @pytest.mark.parametrize( + ("normalized", "expected_order"), + ( + (False, [[1, 2, 0, 3, 4, 5, 6, 9, 7, 8], [8, 7, 9, 6, 5, 4, 3, 0, 2, 1]]), + (True, [[1, 2, 3, 0, 4, 5, 9, 6, 7, 8], [8, 7, 6, 9, 5, 4, 0, 3, 2, 1]]), + ), + ) + @pytest.mark.parametrize("method", methods) + def test_cycle(self, normalized, expected_order, method): + pytest.importorskip("scipy") + path = list(range(10)) + G = nx.Graph() + nx.add_path(G, path, weight=5) + G.add_edge(path[-1], path[0], weight=1) + A = nx.laplacian_matrix(G).todense() + order = nx.spectral_ordering(G, normalized=normalized, method=method) + assert order in expected_order diff --git a/lib/python3.10/site-packages/networkx/linalg/tests/test_attrmatrix.py b/lib/python3.10/site-packages/networkx/linalg/tests/test_attrmatrix.py new file mode 100644 index 0000000000000000000000000000000000000000..01574bb3b8f284edef6c7f92fe1c7e7a239e0610 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/linalg/tests/test_attrmatrix.py @@ -0,0 +1,108 @@ +import pytest + +np = pytest.importorskip("numpy") + +import networkx as nx + + +def test_attr_matrix(): + G = nx.Graph() + G.add_edge(0, 1, thickness=1, weight=3) + G.add_edge(0, 1, thickness=1, weight=3) + G.add_edge(0, 2, thickness=2) + G.add_edge(1, 2, thickness=3) + + def node_attr(u): + return G.nodes[u].get("size", 0.5) * 3 + + def edge_attr(u, v): + return G[u][v].get("thickness", 0.5) + + M = nx.attr_matrix(G, edge_attr=edge_attr, node_attr=node_attr) + np.testing.assert_equal(M[0], np.array([[6.0]])) + assert M[1] == [1.5] + + +def test_attr_matrix_directed(): + G = nx.DiGraph() + G.add_edge(0, 1, thickness=1, weight=3) + G.add_edge(0, 1, thickness=1, weight=3) + G.add_edge(0, 2, thickness=2) + G.add_edge(1, 2, thickness=3) + M = nx.attr_matrix(G, rc_order=[0, 1, 2]) + # fmt: off + data = np.array( + [[0., 1., 1.], + [0., 0., 1.], + [0., 0., 0.]] + ) + # fmt: on + np.testing.assert_equal(M, np.array(data)) + + +def test_attr_matrix_multigraph(): + G = nx.MultiGraph() + G.add_edge(0, 1, thickness=1, weight=3) + G.add_edge(0, 1, thickness=1, weight=3) + G.add_edge(0, 1, thickness=1, weight=3) + G.add_edge(0, 2, thickness=2) + G.add_edge(1, 2, thickness=3) + M = nx.attr_matrix(G, rc_order=[0, 1, 2]) + # fmt: off + data = np.array( + [[0., 3., 1.], + [3., 0., 1.], + [1., 1., 0.]] + ) + # fmt: on + np.testing.assert_equal(M, np.array(data)) + M = nx.attr_matrix(G, edge_attr="weight", rc_order=[0, 1, 2]) + # fmt: off + data = np.array( + [[0., 9., 1.], + [9., 0., 1.], + [1., 1., 0.]] + ) + # fmt: on + np.testing.assert_equal(M, np.array(data)) + M = nx.attr_matrix(G, edge_attr="thickness", rc_order=[0, 1, 2]) + # fmt: off + data = np.array( + [[0., 3., 2.], + [3., 0., 3.], + [2., 3., 0.]] + ) + # fmt: on + np.testing.assert_equal(M, np.array(data)) + + +def test_attr_sparse_matrix(): + pytest.importorskip("scipy") + G = nx.Graph() + G.add_edge(0, 1, thickness=1, weight=3) + G.add_edge(0, 2, thickness=2) + G.add_edge(1, 2, thickness=3) + M = nx.attr_sparse_matrix(G) + mtx = M[0] + data = np.ones((3, 3), float) + np.fill_diagonal(data, 0) + np.testing.assert_equal(mtx.todense(), np.array(data)) + assert M[1] == [0, 1, 2] + + +def test_attr_sparse_matrix_directed(): + pytest.importorskip("scipy") + G = nx.DiGraph() + G.add_edge(0, 1, thickness=1, weight=3) + G.add_edge(0, 1, thickness=1, weight=3) + G.add_edge(0, 2, thickness=2) + G.add_edge(1, 2, thickness=3) + M = nx.attr_sparse_matrix(G, rc_order=[0, 1, 2]) + # fmt: off + data = np.array( + [[0., 1., 1.], + [0., 0., 1.], + [0., 0., 0.]] + ) + # fmt: on + np.testing.assert_equal(M.todense(), np.array(data)) diff --git a/lib/python3.10/site-packages/networkx/linalg/tests/test_bethehessian.py b/lib/python3.10/site-packages/networkx/linalg/tests/test_bethehessian.py new file mode 100644 index 0000000000000000000000000000000000000000..339fe1be390b40083efdd61f1cae4ff62838fc93 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/linalg/tests/test_bethehessian.py @@ -0,0 +1,41 @@ +import pytest + +np = pytest.importorskip("numpy") +pytest.importorskip("scipy") + +import networkx as nx +from networkx.generators.degree_seq import havel_hakimi_graph + + +class TestBetheHessian: + @classmethod + def setup_class(cls): + deg = [3, 2, 2, 1, 0] + cls.G = havel_hakimi_graph(deg) + cls.P = nx.path_graph(3) + + def test_bethe_hessian(self): + "Bethe Hessian matrix" + # fmt: off + H = np.array([[4, -2, 0], + [-2, 5, -2], + [0, -2, 4]]) + # fmt: on + permutation = [2, 0, 1] + # Bethe Hessian gives expected form + np.testing.assert_equal(nx.bethe_hessian_matrix(self.P, r=2).todense(), H) + # nodelist is correctly implemented + np.testing.assert_equal( + nx.bethe_hessian_matrix(self.P, r=2, nodelist=permutation).todense(), + H[np.ix_(permutation, permutation)], + ) + # Equal to Laplacian matrix when r=1 + np.testing.assert_equal( + nx.bethe_hessian_matrix(self.G, r=1).todense(), + nx.laplacian_matrix(self.G).todense(), + ) + # Correct default for the regularizer r + np.testing.assert_equal( + nx.bethe_hessian_matrix(self.G).todense(), + nx.bethe_hessian_matrix(self.G, r=1.25).todense(), + ) diff --git a/lib/python3.10/site-packages/networkx/linalg/tests/test_graphmatrix.py b/lib/python3.10/site-packages/networkx/linalg/tests/test_graphmatrix.py new file mode 100644 index 0000000000000000000000000000000000000000..519198bc07b32f16c1c0ae0cd9b8bbe6b81bce62 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/linalg/tests/test_graphmatrix.py @@ -0,0 +1,276 @@ +import pytest + +np = pytest.importorskip("numpy") +pytest.importorskip("scipy") + +import networkx as nx +from networkx.exception import NetworkXError +from networkx.generators.degree_seq import havel_hakimi_graph + + +def test_incidence_matrix_simple(): + deg = [3, 2, 2, 1, 0] + G = havel_hakimi_graph(deg) + deg = [(1, 0), (1, 0), (1, 0), (2, 0), (1, 0), (2, 1), (0, 1), (0, 1)] + MG = nx.random_clustered_graph(deg, seed=42) + + I = nx.incidence_matrix(G, dtype=int).todense() + # fmt: off + expected = np.array( + [[1, 1, 1, 0], + [0, 1, 0, 1], + [1, 0, 0, 1], + [0, 0, 1, 0], + [0, 0, 0, 0]] + ) + # fmt: on + np.testing.assert_equal(I, expected) + + I = nx.incidence_matrix(MG, dtype=int).todense() + # fmt: off + expected = np.array( + [[1, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 1, 1], + [0, 0, 0, 0, 1, 0, 1]] + ) + # fmt: on + np.testing.assert_equal(I, expected) + + with pytest.raises(NetworkXError): + nx.incidence_matrix(G, nodelist=[0, 1]) + + +class TestGraphMatrix: + @classmethod + def setup_class(cls): + deg = [3, 2, 2, 1, 0] + cls.G = havel_hakimi_graph(deg) + # fmt: off + cls.OI = np.array( + [[-1, -1, -1, 0], + [1, 0, 0, -1], + [0, 1, 0, 1], + [0, 0, 1, 0], + [0, 0, 0, 0]] + ) + cls.A = np.array( + [[0, 1, 1, 1, 0], + [1, 0, 1, 0, 0], + [1, 1, 0, 0, 0], + [1, 0, 0, 0, 0], + [0, 0, 0, 0, 0]] + ) + # fmt: on + cls.WG = havel_hakimi_graph(deg) + cls.WG.add_edges_from( + (u, v, {"weight": 0.5, "other": 0.3}) for (u, v) in cls.G.edges() + ) + # fmt: off + cls.WA = np.array( + [[0, 0.5, 0.5, 0.5, 0], + [0.5, 0, 0.5, 0, 0], + [0.5, 0.5, 0, 0, 0], + [0.5, 0, 0, 0, 0], + [0, 0, 0, 0, 0]] + ) + # fmt: on + cls.MG = nx.MultiGraph(cls.G) + cls.MG2 = cls.MG.copy() + cls.MG2.add_edge(0, 1) + # fmt: off + cls.MG2A = np.array( + [[0, 2, 1, 1, 0], + [2, 0, 1, 0, 0], + [1, 1, 0, 0, 0], + [1, 0, 0, 0, 0], + [0, 0, 0, 0, 0]] + ) + cls.MGOI = np.array( + [[-1, -1, -1, -1, 0], + [1, 1, 0, 0, -1], + [0, 0, 1, 0, 1], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 0]] + ) + # fmt: on + cls.no_edges_G = nx.Graph([(1, 2), (3, 2, {"weight": 8})]) + cls.no_edges_A = np.array([[0, 0], [0, 0]]) + + def test_incidence_matrix(self): + "Conversion to incidence matrix" + I = nx.incidence_matrix( + self.G, + nodelist=sorted(self.G), + edgelist=sorted(self.G.edges()), + oriented=True, + dtype=int, + ).todense() + np.testing.assert_equal(I, self.OI) + + I = nx.incidence_matrix( + self.G, + nodelist=sorted(self.G), + edgelist=sorted(self.G.edges()), + oriented=False, + dtype=int, + ).todense() + np.testing.assert_equal(I, np.abs(self.OI)) + + I = nx.incidence_matrix( + self.MG, + nodelist=sorted(self.MG), + edgelist=sorted(self.MG.edges()), + oriented=True, + dtype=int, + ).todense() + np.testing.assert_equal(I, self.OI) + + I = nx.incidence_matrix( + self.MG, + nodelist=sorted(self.MG), + edgelist=sorted(self.MG.edges()), + oriented=False, + dtype=int, + ).todense() + np.testing.assert_equal(I, np.abs(self.OI)) + + I = nx.incidence_matrix( + self.MG2, + nodelist=sorted(self.MG2), + edgelist=sorted(self.MG2.edges()), + oriented=True, + dtype=int, + ).todense() + np.testing.assert_equal(I, self.MGOI) + + I = nx.incidence_matrix( + self.MG2, + nodelist=sorted(self.MG), + edgelist=sorted(self.MG2.edges()), + oriented=False, + dtype=int, + ).todense() + np.testing.assert_equal(I, np.abs(self.MGOI)) + + I = nx.incidence_matrix(self.G, dtype=np.uint8) + assert I.dtype == np.uint8 + + def test_weighted_incidence_matrix(self): + I = nx.incidence_matrix( + self.WG, + nodelist=sorted(self.WG), + edgelist=sorted(self.WG.edges()), + oriented=True, + dtype=int, + ).todense() + np.testing.assert_equal(I, self.OI) + + I = nx.incidence_matrix( + self.WG, + nodelist=sorted(self.WG), + edgelist=sorted(self.WG.edges()), + oriented=False, + dtype=int, + ).todense() + np.testing.assert_equal(I, np.abs(self.OI)) + + # np.testing.assert_equal(nx.incidence_matrix(self.WG,oriented=True, + # weight='weight').todense(),0.5*self.OI) + # np.testing.assert_equal(nx.incidence_matrix(self.WG,weight='weight').todense(), + # np.abs(0.5*self.OI)) + # np.testing.assert_equal(nx.incidence_matrix(self.WG,oriented=True,weight='other').todense(), + # 0.3*self.OI) + + I = nx.incidence_matrix( + self.WG, + nodelist=sorted(self.WG), + edgelist=sorted(self.WG.edges()), + oriented=True, + weight="weight", + ).todense() + np.testing.assert_equal(I, 0.5 * self.OI) + + I = nx.incidence_matrix( + self.WG, + nodelist=sorted(self.WG), + edgelist=sorted(self.WG.edges()), + oriented=False, + weight="weight", + ).todense() + np.testing.assert_equal(I, np.abs(0.5 * self.OI)) + + I = nx.incidence_matrix( + self.WG, + nodelist=sorted(self.WG), + edgelist=sorted(self.WG.edges()), + oriented=True, + weight="other", + ).todense() + np.testing.assert_equal(I, 0.3 * self.OI) + + # WMG=nx.MultiGraph(self.WG) + # WMG.add_edge(0,1,weight=0.5,other=0.3) + # np.testing.assert_equal(nx.incidence_matrix(WMG,weight='weight').todense(), + # np.abs(0.5*self.MGOI)) + # np.testing.assert_equal(nx.incidence_matrix(WMG,weight='weight',oriented=True).todense(), + # 0.5*self.MGOI) + # np.testing.assert_equal(nx.incidence_matrix(WMG,weight='other',oriented=True).todense(), + # 0.3*self.MGOI) + + WMG = nx.MultiGraph(self.WG) + WMG.add_edge(0, 1, weight=0.5, other=0.3) + + I = nx.incidence_matrix( + WMG, + nodelist=sorted(WMG), + edgelist=sorted(WMG.edges(keys=True)), + oriented=True, + weight="weight", + ).todense() + np.testing.assert_equal(I, 0.5 * self.MGOI) + + I = nx.incidence_matrix( + WMG, + nodelist=sorted(WMG), + edgelist=sorted(WMG.edges(keys=True)), + oriented=False, + weight="weight", + ).todense() + np.testing.assert_equal(I, np.abs(0.5 * self.MGOI)) + + I = nx.incidence_matrix( + WMG, + nodelist=sorted(WMG), + edgelist=sorted(WMG.edges(keys=True)), + oriented=True, + weight="other", + ).todense() + np.testing.assert_equal(I, 0.3 * self.MGOI) + + def test_adjacency_matrix(self): + "Conversion to adjacency matrix" + np.testing.assert_equal(nx.adjacency_matrix(self.G).todense(), self.A) + np.testing.assert_equal(nx.adjacency_matrix(self.MG).todense(), self.A) + np.testing.assert_equal(nx.adjacency_matrix(self.MG2).todense(), self.MG2A) + np.testing.assert_equal( + nx.adjacency_matrix(self.G, nodelist=[0, 1]).todense(), self.A[:2, :2] + ) + np.testing.assert_equal(nx.adjacency_matrix(self.WG).todense(), self.WA) + np.testing.assert_equal( + nx.adjacency_matrix(self.WG, weight=None).todense(), self.A + ) + np.testing.assert_equal( + nx.adjacency_matrix(self.MG2, weight=None).todense(), self.MG2A + ) + np.testing.assert_equal( + nx.adjacency_matrix(self.WG, weight="other").todense(), 0.6 * self.WA + ) + np.testing.assert_equal( + nx.adjacency_matrix(self.no_edges_G, nodelist=[1, 3]).todense(), + self.no_edges_A, + ) diff --git a/lib/python3.10/site-packages/networkx/linalg/tests/test_laplacian.py b/lib/python3.10/site-packages/networkx/linalg/tests/test_laplacian.py new file mode 100644 index 0000000000000000000000000000000000000000..23f1b28e19f1af4097ae3e99501a45439a6f1598 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/linalg/tests/test_laplacian.py @@ -0,0 +1,336 @@ +import pytest + +np = pytest.importorskip("numpy") +pytest.importorskip("scipy") + +import networkx as nx +from networkx.generators.degree_seq import havel_hakimi_graph +from networkx.generators.expanders import margulis_gabber_galil_graph + + +class TestLaplacian: + @classmethod + def setup_class(cls): + deg = [3, 2, 2, 1, 0] + cls.G = havel_hakimi_graph(deg) + cls.WG = nx.Graph( + (u, v, {"weight": 0.5, "other": 0.3}) for (u, v) in cls.G.edges() + ) + cls.WG.add_node(4) + cls.MG = nx.MultiGraph(cls.G) + + # Graph with clsloops + cls.Gsl = cls.G.copy() + for node in cls.Gsl.nodes(): + cls.Gsl.add_edge(node, node) + + # Graph used as an example in Sec. 4.1 of Langville and Meyer, + # "Google's PageRank and Beyond". + cls.DiG = nx.DiGraph() + cls.DiG.add_edges_from( + ( + (1, 2), + (1, 3), + (3, 1), + (3, 2), + (3, 5), + (4, 5), + (4, 6), + (5, 4), + (5, 6), + (6, 4), + ) + ) + cls.DiMG = nx.MultiDiGraph(cls.DiG) + cls.DiWG = nx.DiGraph( + (u, v, {"weight": 0.5, "other": 0.3}) for (u, v) in cls.DiG.edges() + ) + cls.DiGsl = cls.DiG.copy() + for node in cls.DiGsl.nodes(): + cls.DiGsl.add_edge(node, node) + + def test_laplacian(self): + "Graph Laplacian" + # fmt: off + NL = np.array([[ 3, -1, -1, -1, 0], + [-1, 2, -1, 0, 0], + [-1, -1, 2, 0, 0], + [-1, 0, 0, 1, 0], + [ 0, 0, 0, 0, 0]]) + # fmt: on + WL = 0.5 * NL + OL = 0.3 * NL + # fmt: off + DiNL = np.array([[ 2, -1, -1, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0], + [-1, -1, 3, -1, 0, 0], + [ 0, 0, 0, 2, -1, -1], + [ 0, 0, 0, -1, 2, -1], + [ 0, 0, 0, 0, -1, 1]]) + # fmt: on + DiWL = 0.5 * DiNL + DiOL = 0.3 * DiNL + np.testing.assert_equal(nx.laplacian_matrix(self.G).todense(), NL) + np.testing.assert_equal(nx.laplacian_matrix(self.MG).todense(), NL) + np.testing.assert_equal( + nx.laplacian_matrix(self.G, nodelist=[0, 1]).todense(), + np.array([[1, -1], [-1, 1]]), + ) + np.testing.assert_equal(nx.laplacian_matrix(self.WG).todense(), WL) + np.testing.assert_equal(nx.laplacian_matrix(self.WG, weight=None).todense(), NL) + np.testing.assert_equal( + nx.laplacian_matrix(self.WG, weight="other").todense(), OL + ) + + np.testing.assert_equal(nx.laplacian_matrix(self.DiG).todense(), DiNL) + np.testing.assert_equal(nx.laplacian_matrix(self.DiMG).todense(), DiNL) + np.testing.assert_equal( + nx.laplacian_matrix(self.DiG, nodelist=[1, 2]).todense(), + np.array([[1, -1], [0, 0]]), + ) + np.testing.assert_equal(nx.laplacian_matrix(self.DiWG).todense(), DiWL) + np.testing.assert_equal( + nx.laplacian_matrix(self.DiWG, weight=None).todense(), DiNL + ) + np.testing.assert_equal( + nx.laplacian_matrix(self.DiWG, weight="other").todense(), DiOL + ) + + def test_normalized_laplacian(self): + "Generalized Graph Laplacian" + # fmt: off + G = np.array([[ 1. , -0.408, -0.408, -0.577, 0.], + [-0.408, 1. , -0.5 , 0. , 0.], + [-0.408, -0.5 , 1. , 0. , 0.], + [-0.577, 0. , 0. , 1. , 0.], + [ 0. , 0. , 0. , 0. , 0.]]) + GL = np.array([[ 1. , -0.408, -0.408, -0.577, 0. ], + [-0.408, 1. , -0.5 , 0. , 0. ], + [-0.408, -0.5 , 1. , 0. , 0. ], + [-0.577, 0. , 0. , 1. , 0. ], + [ 0. , 0. , 0. , 0. , 0. ]]) + Lsl = np.array([[ 0.75 , -0.2887, -0.2887, -0.3536, 0. ], + [-0.2887, 0.6667, -0.3333, 0. , 0. ], + [-0.2887, -0.3333, 0.6667, 0. , 0. ], + [-0.3536, 0. , 0. , 0.5 , 0. ], + [ 0. , 0. , 0. , 0. , 0. ]]) + + DiG = np.array([[ 1. , 0. , -0.4082, 0. , 0. , 0. ], + [ 0. , 0. , 0. , 0. , 0. , 0. ], + [-0.4082, 0. , 1. , 0. , -0.4082, 0. ], + [ 0. , 0. , 0. , 1. , -0.5 , -0.7071], + [ 0. , 0. , 0. , -0.5 , 1. , -0.7071], + [ 0. , 0. , 0. , -0.7071, 0. , 1. ]]) + DiGL = np.array([[ 1. , 0. , -0.4082, 0. , 0. , 0. ], + [ 0. , 0. , 0. , 0. , 0. , 0. ], + [-0.4082, 0. , 1. , -0.4082, 0. , 0. ], + [ 0. , 0. , 0. , 1. , -0.5 , -0.7071], + [ 0. , 0. , 0. , -0.5 , 1. , -0.7071], + [ 0. , 0. , 0. , 0. , -0.7071, 1. ]]) + DiLsl = np.array([[ 0.6667, -0.5774, -0.2887, 0. , 0. , 0. ], + [ 0. , 0. , 0. , 0. , 0. , 0. ], + [-0.2887, -0.5 , 0.75 , -0.2887, 0. , 0. ], + [ 0. , 0. , 0. , 0.6667, -0.3333, -0.4082], + [ 0. , 0. , 0. , -0.3333, 0.6667, -0.4082], + [ 0. , 0. , 0. , 0. , -0.4082, 0.5 ]]) + # fmt: on + + np.testing.assert_almost_equal( + nx.normalized_laplacian_matrix(self.G, nodelist=range(5)).todense(), + G, + decimal=3, + ) + np.testing.assert_almost_equal( + nx.normalized_laplacian_matrix(self.G).todense(), GL, decimal=3 + ) + np.testing.assert_almost_equal( + nx.normalized_laplacian_matrix(self.MG).todense(), GL, decimal=3 + ) + np.testing.assert_almost_equal( + nx.normalized_laplacian_matrix(self.WG).todense(), GL, decimal=3 + ) + np.testing.assert_almost_equal( + nx.normalized_laplacian_matrix(self.WG, weight="other").todense(), + GL, + decimal=3, + ) + np.testing.assert_almost_equal( + nx.normalized_laplacian_matrix(self.Gsl).todense(), Lsl, decimal=3 + ) + + np.testing.assert_almost_equal( + nx.normalized_laplacian_matrix( + self.DiG, + nodelist=range(1, 1 + 6), + ).todense(), + DiG, + decimal=3, + ) + np.testing.assert_almost_equal( + nx.normalized_laplacian_matrix(self.DiG).todense(), DiGL, decimal=3 + ) + np.testing.assert_almost_equal( + nx.normalized_laplacian_matrix(self.DiMG).todense(), DiGL, decimal=3 + ) + np.testing.assert_almost_equal( + nx.normalized_laplacian_matrix(self.DiWG).todense(), DiGL, decimal=3 + ) + np.testing.assert_almost_equal( + nx.normalized_laplacian_matrix(self.DiWG, weight="other").todense(), + DiGL, + decimal=3, + ) + np.testing.assert_almost_equal( + nx.normalized_laplacian_matrix(self.DiGsl).todense(), DiLsl, decimal=3 + ) + + +def test_directed_laplacian(): + "Directed Laplacian" + # Graph used as an example in Sec. 4.1 of Langville and Meyer, + # "Google's PageRank and Beyond". The graph contains dangling nodes, so + # the pagerank random walk is selected by directed_laplacian + G = nx.DiGraph() + G.add_edges_from( + ( + (1, 2), + (1, 3), + (3, 1), + (3, 2), + (3, 5), + (4, 5), + (4, 6), + (5, 4), + (5, 6), + (6, 4), + ) + ) + # fmt: off + GL = np.array([[ 0.9833, -0.2941, -0.3882, -0.0291, -0.0231, -0.0261], + [-0.2941, 0.8333, -0.2339, -0.0536, -0.0589, -0.0554], + [-0.3882, -0.2339, 0.9833, -0.0278, -0.0896, -0.0251], + [-0.0291, -0.0536, -0.0278, 0.9833, -0.4878, -0.6675], + [-0.0231, -0.0589, -0.0896, -0.4878, 0.9833, -0.2078], + [-0.0261, -0.0554, -0.0251, -0.6675, -0.2078, 0.9833]]) + # fmt: on + L = nx.directed_laplacian_matrix(G, alpha=0.9, nodelist=sorted(G)) + np.testing.assert_almost_equal(L, GL, decimal=3) + + # Make the graph strongly connected, so we can use a random and lazy walk + G.add_edges_from(((2, 5), (6, 1))) + # fmt: off + GL = np.array([[ 1. , -0.3062, -0.4714, 0. , 0. , -0.3227], + [-0.3062, 1. , -0.1443, 0. , -0.3162, 0. ], + [-0.4714, -0.1443, 1. , 0. , -0.0913, 0. ], + [ 0. , 0. , 0. , 1. , -0.5 , -0.5 ], + [ 0. , -0.3162, -0.0913, -0.5 , 1. , -0.25 ], + [-0.3227, 0. , 0. , -0.5 , -0.25 , 1. ]]) + # fmt: on + L = nx.directed_laplacian_matrix( + G, alpha=0.9, nodelist=sorted(G), walk_type="random" + ) + np.testing.assert_almost_equal(L, GL, decimal=3) + + # fmt: off + GL = np.array([[ 0.5 , -0.1531, -0.2357, 0. , 0. , -0.1614], + [-0.1531, 0.5 , -0.0722, 0. , -0.1581, 0. ], + [-0.2357, -0.0722, 0.5 , 0. , -0.0456, 0. ], + [ 0. , 0. , 0. , 0.5 , -0.25 , -0.25 ], + [ 0. , -0.1581, -0.0456, -0.25 , 0.5 , -0.125 ], + [-0.1614, 0. , 0. , -0.25 , -0.125 , 0.5 ]]) + # fmt: on + L = nx.directed_laplacian_matrix(G, alpha=0.9, nodelist=sorted(G), walk_type="lazy") + np.testing.assert_almost_equal(L, GL, decimal=3) + + # Make a strongly connected periodic graph + G = nx.DiGraph() + G.add_edges_from(((1, 2), (2, 4), (4, 1), (1, 3), (3, 4))) + # fmt: off + GL = np.array([[ 0.5 , -0.176, -0.176, -0.25 ], + [-0.176, 0.5 , 0. , -0.176], + [-0.176, 0. , 0.5 , -0.176], + [-0.25 , -0.176, -0.176, 0.5 ]]) + # fmt: on + L = nx.directed_laplacian_matrix(G, alpha=0.9, nodelist=sorted(G)) + np.testing.assert_almost_equal(L, GL, decimal=3) + + +def test_directed_combinatorial_laplacian(): + "Directed combinatorial Laplacian" + # Graph used as an example in Sec. 4.1 of Langville and Meyer, + # "Google's PageRank and Beyond". The graph contains dangling nodes, so + # the pagerank random walk is selected by directed_laplacian + G = nx.DiGraph() + G.add_edges_from( + ( + (1, 2), + (1, 3), + (3, 1), + (3, 2), + (3, 5), + (4, 5), + (4, 6), + (5, 4), + (5, 6), + (6, 4), + ) + ) + # fmt: off + GL = np.array([[ 0.0366, -0.0132, -0.0153, -0.0034, -0.0020, -0.0027], + [-0.0132, 0.0450, -0.0111, -0.0076, -0.0062, -0.0069], + [-0.0153, -0.0111, 0.0408, -0.0035, -0.0083, -0.0027], + [-0.0034, -0.0076, -0.0035, 0.3688, -0.1356, -0.2187], + [-0.0020, -0.0062, -0.0083, -0.1356, 0.2026, -0.0505], + [-0.0027, -0.0069, -0.0027, -0.2187, -0.0505, 0.2815]]) + # fmt: on + + L = nx.directed_combinatorial_laplacian_matrix(G, alpha=0.9, nodelist=sorted(G)) + np.testing.assert_almost_equal(L, GL, decimal=3) + + # Make the graph strongly connected, so we can use a random and lazy walk + G.add_edges_from(((2, 5), (6, 1))) + + # fmt: off + GL = np.array([[ 0.1395, -0.0349, -0.0465, 0. , 0. , -0.0581], + [-0.0349, 0.093 , -0.0116, 0. , -0.0465, 0. ], + [-0.0465, -0.0116, 0.0698, 0. , -0.0116, 0. ], + [ 0. , 0. , 0. , 0.2326, -0.1163, -0.1163], + [ 0. , -0.0465, -0.0116, -0.1163, 0.2326, -0.0581], + [-0.0581, 0. , 0. , -0.1163, -0.0581, 0.2326]]) + # fmt: on + + L = nx.directed_combinatorial_laplacian_matrix( + G, alpha=0.9, nodelist=sorted(G), walk_type="random" + ) + np.testing.assert_almost_equal(L, GL, decimal=3) + + # fmt: off + GL = np.array([[ 0.0698, -0.0174, -0.0233, 0. , 0. , -0.0291], + [-0.0174, 0.0465, -0.0058, 0. , -0.0233, 0. ], + [-0.0233, -0.0058, 0.0349, 0. , -0.0058, 0. ], + [ 0. , 0. , 0. , 0.1163, -0.0581, -0.0581], + [ 0. , -0.0233, -0.0058, -0.0581, 0.1163, -0.0291], + [-0.0291, 0. , 0. , -0.0581, -0.0291, 0.1163]]) + # fmt: on + + L = nx.directed_combinatorial_laplacian_matrix( + G, alpha=0.9, nodelist=sorted(G), walk_type="lazy" + ) + np.testing.assert_almost_equal(L, GL, decimal=3) + + E = nx.DiGraph(margulis_gabber_galil_graph(2)) + L = nx.directed_combinatorial_laplacian_matrix(E) + # fmt: off + expected = np.array( + [[ 0.16666667, -0.08333333, -0.08333333, 0. ], + [-0.08333333, 0.16666667, 0. , -0.08333333], + [-0.08333333, 0. , 0.16666667, -0.08333333], + [ 0. , -0.08333333, -0.08333333, 0.16666667]] + ) + # fmt: on + np.testing.assert_almost_equal(L, expected, decimal=6) + + with pytest.raises(nx.NetworkXError): + nx.directed_combinatorial_laplacian_matrix(G, walk_type="pagerank", alpha=100) + with pytest.raises(nx.NetworkXError): + nx.directed_combinatorial_laplacian_matrix(G, walk_type="silly") diff --git a/lib/python3.10/site-packages/networkx/linalg/tests/test_modularity.py b/lib/python3.10/site-packages/networkx/linalg/tests/test_modularity.py new file mode 100644 index 0000000000000000000000000000000000000000..9f94ff4db33a427fa2f0ef51470bc1c57c8b8682 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/linalg/tests/test_modularity.py @@ -0,0 +1,87 @@ +import pytest + +np = pytest.importorskip("numpy") +pytest.importorskip("scipy") + +import networkx as nx +from networkx.generators.degree_seq import havel_hakimi_graph + + +class TestModularity: + @classmethod + def setup_class(cls): + deg = [3, 2, 2, 1, 0] + cls.G = havel_hakimi_graph(deg) + # Graph used as an example in Sec. 4.1 of Langville and Meyer, + # "Google's PageRank and Beyond". (Used for test_directed_laplacian) + cls.DG = nx.DiGraph() + cls.DG.add_edges_from( + ( + (1, 2), + (1, 3), + (3, 1), + (3, 2), + (3, 5), + (4, 5), + (4, 6), + (5, 4), + (5, 6), + (6, 4), + ) + ) + + def test_modularity(self): + "Modularity matrix" + # fmt: off + B = np.array([[-1.125, 0.25, 0.25, 0.625, 0.], + [0.25, -0.5, 0.5, -0.25, 0.], + [0.25, 0.5, -0.5, -0.25, 0.], + [0.625, -0.25, -0.25, -0.125, 0.], + [0., 0., 0., 0., 0.]]) + # fmt: on + + permutation = [4, 0, 1, 2, 3] + np.testing.assert_equal(nx.modularity_matrix(self.G), B) + np.testing.assert_equal( + nx.modularity_matrix(self.G, nodelist=permutation), + B[np.ix_(permutation, permutation)], + ) + + def test_modularity_weight(self): + "Modularity matrix with weights" + # fmt: off + B = np.array([[-1.125, 0.25, 0.25, 0.625, 0.], + [0.25, -0.5, 0.5, -0.25, 0.], + [0.25, 0.5, -0.5, -0.25, 0.], + [0.625, -0.25, -0.25, -0.125, 0.], + [0., 0., 0., 0., 0.]]) + # fmt: on + + G_weighted = self.G.copy() + for n1, n2 in G_weighted.edges(): + G_weighted.edges[n1, n2]["weight"] = 0.5 + # The following test would fail in networkx 1.1 + np.testing.assert_equal(nx.modularity_matrix(G_weighted), B) + # The following test that the modularity matrix get rescaled accordingly + np.testing.assert_equal( + nx.modularity_matrix(G_weighted, weight="weight"), 0.5 * B + ) + + def test_directed_modularity(self): + "Directed Modularity matrix" + # fmt: off + B = np.array([[-0.2, 0.6, 0.8, -0.4, -0.4, -0.4], + [0., 0., 0., 0., 0., 0.], + [0.7, 0.4, -0.3, -0.6, 0.4, -0.6], + [-0.2, -0.4, -0.2, -0.4, 0.6, 0.6], + [-0.2, -0.4, -0.2, 0.6, -0.4, 0.6], + [-0.1, -0.2, -0.1, 0.8, -0.2, -0.2]]) + # fmt: on + node_permutation = [5, 1, 2, 3, 4, 6] + idx_permutation = [4, 0, 1, 2, 3, 5] + mm = nx.directed_modularity_matrix(self.DG, nodelist=sorted(self.DG)) + np.testing.assert_equal(mm, B) + np.testing.assert_equal( + nx.directed_modularity_matrix(self.DG, nodelist=node_permutation), + B[np.ix_(idx_permutation, idx_permutation)], + ) diff --git a/lib/python3.10/site-packages/networkx/linalg/tests/test_spectrum.py b/lib/python3.10/site-packages/networkx/linalg/tests/test_spectrum.py new file mode 100644 index 0000000000000000000000000000000000000000..e9101303cba60c56825101fa5762b56a3083e7af --- /dev/null +++ b/lib/python3.10/site-packages/networkx/linalg/tests/test_spectrum.py @@ -0,0 +1,71 @@ +import pytest + +np = pytest.importorskip("numpy") +pytest.importorskip("scipy") + +import networkx as nx +from networkx.generators.degree_seq import havel_hakimi_graph + + +class TestSpectrum: + @classmethod + def setup_class(cls): + deg = [3, 2, 2, 1, 0] + cls.G = havel_hakimi_graph(deg) + cls.P = nx.path_graph(3) + cls.WG = nx.Graph( + (u, v, {"weight": 0.5, "other": 0.3}) for (u, v) in cls.G.edges() + ) + cls.WG.add_node(4) + cls.DG = nx.DiGraph() + nx.add_path(cls.DG, [0, 1, 2]) + + def test_laplacian_spectrum(self): + "Laplacian eigenvalues" + evals = np.array([0, 0, 1, 3, 4]) + e = sorted(nx.laplacian_spectrum(self.G)) + np.testing.assert_almost_equal(e, evals) + e = sorted(nx.laplacian_spectrum(self.WG, weight=None)) + np.testing.assert_almost_equal(e, evals) + e = sorted(nx.laplacian_spectrum(self.WG)) + np.testing.assert_almost_equal(e, 0.5 * evals) + e = sorted(nx.laplacian_spectrum(self.WG, weight="other")) + np.testing.assert_almost_equal(e, 0.3 * evals) + + def test_normalized_laplacian_spectrum(self): + "Normalized Laplacian eigenvalues" + evals = np.array([0, 0, 0.7712864461218, 1.5, 1.7287135538781]) + e = sorted(nx.normalized_laplacian_spectrum(self.G)) + np.testing.assert_almost_equal(e, evals) + e = sorted(nx.normalized_laplacian_spectrum(self.WG, weight=None)) + np.testing.assert_almost_equal(e, evals) + e = sorted(nx.normalized_laplacian_spectrum(self.WG)) + np.testing.assert_almost_equal(e, evals) + e = sorted(nx.normalized_laplacian_spectrum(self.WG, weight="other")) + np.testing.assert_almost_equal(e, evals) + + def test_adjacency_spectrum(self): + "Adjacency eigenvalues" + evals = np.array([-np.sqrt(2), 0, np.sqrt(2)]) + e = sorted(nx.adjacency_spectrum(self.P)) + np.testing.assert_almost_equal(e, evals) + + def test_modularity_spectrum(self): + "Modularity eigenvalues" + evals = np.array([-1.5, 0.0, 0.0]) + e = sorted(nx.modularity_spectrum(self.P)) + np.testing.assert_almost_equal(e, evals) + # Directed modularity eigenvalues + evals = np.array([-0.5, 0.0, 0.0]) + e = sorted(nx.modularity_spectrum(self.DG)) + np.testing.assert_almost_equal(e, evals) + + def test_bethe_hessian_spectrum(self): + "Bethe Hessian eigenvalues" + evals = np.array([0.5 * (9 - np.sqrt(33)), 4, 0.5 * (9 + np.sqrt(33))]) + e = sorted(nx.bethe_hessian_spectrum(self.P, r=2)) + np.testing.assert_almost_equal(e, evals) + # Collapses back to Laplacian: + e1 = sorted(nx.bethe_hessian_spectrum(self.P, r=1)) + e2 = sorted(nx.laplacian_spectrum(self.P)) + np.testing.assert_almost_equal(e1, e2) diff --git a/lib/python3.10/site-packages/networkx/readwrite/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6bd1e4bbe7043fe4888c94e851c4c1c497e95c8a Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/__pycache__/adjlist.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/__pycache__/adjlist.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13c21a6b10e1924bf7164008372bb0d9f11e4ef9 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/__pycache__/adjlist.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/__pycache__/edgelist.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/__pycache__/edgelist.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7bf3856edfbd9854cf3ceeb7959009656b53672 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/__pycache__/edgelist.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/__pycache__/gexf.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/__pycache__/gexf.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37bf6798bc0c8035910629fe0c6ee32a8465f4ae Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/__pycache__/gexf.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/__pycache__/gml.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/__pycache__/gml.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03c58486e1baabc62fbce6d64698e2d2719122bc Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/__pycache__/gml.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/__pycache__/graph6.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/__pycache__/graph6.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5cdbc592841786d31a9bc55f504da1987fef3370 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/__pycache__/graph6.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/__pycache__/graphml.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/__pycache__/graphml.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1be7a35e4c4bcb6f2c7f194e1ae19c0ee5880ca5 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/__pycache__/graphml.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/__pycache__/leda.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/__pycache__/leda.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d179be24ba104ea8126ce34bca11123ed10688e5 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/__pycache__/leda.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/__pycache__/multiline_adjlist.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/__pycache__/multiline_adjlist.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27848868ba42b25146044d3628720f85e1638cc0 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/__pycache__/multiline_adjlist.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/__pycache__/p2g.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/__pycache__/p2g.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09c9f524a7b86d353bcf2c426f088c1749eeb642 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/__pycache__/p2g.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/__pycache__/pajek.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/__pycache__/pajek.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8664d9af5b6be5179ae58f6e56f0c9552dca0293 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/__pycache__/pajek.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/__pycache__/sparse6.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/__pycache__/sparse6.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8399a26e224772feba5c1b4164c1e9dafd1c7f7 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/__pycache__/sparse6.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/__pycache__/text.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/__pycache__/text.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e518d8365a069b4483d02bfb12cdb06d17c367b Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/__pycache__/text.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/json_graph/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/json_graph/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..beaf75f953fa6143696c0be11ff66da714a1c56e Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/json_graph/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/json_graph/__pycache__/adjacency.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/json_graph/__pycache__/adjacency.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca87ddd69d23489f2a21fdc200a26ccb42660335 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/json_graph/__pycache__/adjacency.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/json_graph/__pycache__/cytoscape.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/json_graph/__pycache__/cytoscape.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9411c9ee18b11ff642aa9366f884021ac0f08d1 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/json_graph/__pycache__/cytoscape.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/json_graph/__pycache__/node_link.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/json_graph/__pycache__/node_link.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aca84e261e15c8431b1c8fd94129d3cf9befe089 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/json_graph/__pycache__/node_link.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/json_graph/__pycache__/tree.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/json_graph/__pycache__/tree.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9abb865b405e97827af29ba14f6724803fba8f4b Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/json_graph/__pycache__/tree.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/json_graph/tests/__init__.py b/lib/python3.10/site-packages/networkx/readwrite/json_graph/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/networkx/readwrite/json_graph/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/json_graph/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec67eb47e6919dd4883810b4587b1f72995eb520 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/json_graph/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/json_graph/tests/__pycache__/test_adjacency.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/json_graph/tests/__pycache__/test_adjacency.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6dbbe47e8c075a97210aa1f3fbdfad38eafa452 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/json_graph/tests/__pycache__/test_adjacency.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/json_graph/tests/__pycache__/test_cytoscape.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/json_graph/tests/__pycache__/test_cytoscape.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b39f232ab5b9c52958a51f9c601b23a32088b96 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/json_graph/tests/__pycache__/test_cytoscape.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/json_graph/tests/__pycache__/test_node_link.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/json_graph/tests/__pycache__/test_node_link.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..359d464df49e13abf504950882de8d8c2d5d8bda Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/json_graph/tests/__pycache__/test_node_link.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/json_graph/tests/__pycache__/test_tree.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/json_graph/tests/__pycache__/test_tree.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5144d7963e334bb1b4c5e7ab8deda95eee00d92d Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/json_graph/tests/__pycache__/test_tree.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/json_graph/tests/test_adjacency.py b/lib/python3.10/site-packages/networkx/readwrite/json_graph/tests/test_adjacency.py new file mode 100644 index 0000000000000000000000000000000000000000..37506382c55a110b26fdba32a268545d23f4474b --- /dev/null +++ b/lib/python3.10/site-packages/networkx/readwrite/json_graph/tests/test_adjacency.py @@ -0,0 +1,78 @@ +import copy +import json + +import pytest + +import networkx as nx +from networkx.readwrite.json_graph import adjacency_data, adjacency_graph +from networkx.utils import graphs_equal + + +class TestAdjacency: + def test_graph(self): + G = nx.path_graph(4) + H = adjacency_graph(adjacency_data(G)) + assert graphs_equal(G, H) + + def test_graph_attributes(self): + G = nx.path_graph(4) + G.add_node(1, color="red") + G.add_edge(1, 2, width=7) + G.graph["foo"] = "bar" + G.graph[1] = "one" + + H = adjacency_graph(adjacency_data(G)) + assert graphs_equal(G, H) + assert H.graph["foo"] == "bar" + assert H.nodes[1]["color"] == "red" + assert H[1][2]["width"] == 7 + + d = json.dumps(adjacency_data(G)) + H = adjacency_graph(json.loads(d)) + assert graphs_equal(G, H) + assert H.graph["foo"] == "bar" + assert H.graph[1] == "one" + assert H.nodes[1]["color"] == "red" + assert H[1][2]["width"] == 7 + + def test_digraph(self): + G = nx.DiGraph() + nx.add_path(G, [1, 2, 3]) + H = adjacency_graph(adjacency_data(G)) + assert H.is_directed() + assert graphs_equal(G, H) + + def test_multidigraph(self): + G = nx.MultiDiGraph() + nx.add_path(G, [1, 2, 3]) + H = adjacency_graph(adjacency_data(G)) + assert H.is_directed() + assert H.is_multigraph() + assert graphs_equal(G, H) + + def test_multigraph(self): + G = nx.MultiGraph() + G.add_edge(1, 2, key="first") + G.add_edge(1, 2, key="second", color="blue") + H = adjacency_graph(adjacency_data(G)) + assert graphs_equal(G, H) + assert H[1][2]["second"]["color"] == "blue" + + def test_input_data_is_not_modified_when_building_graph(self): + G = nx.path_graph(4) + input_data = adjacency_data(G) + orig_data = copy.deepcopy(input_data) + # Ensure input is unmodified by deserialisation + assert graphs_equal(G, adjacency_graph(input_data)) + assert input_data == orig_data + + def test_adjacency_form_json_serialisable(self): + G = nx.path_graph(4) + H = adjacency_graph(json.loads(json.dumps(adjacency_data(G)))) + assert graphs_equal(G, H) + + def test_exception(self): + with pytest.raises(nx.NetworkXError): + G = nx.MultiDiGraph() + attrs = {"id": "node", "key": "node"} + adjacency_data(G, attrs) diff --git a/lib/python3.10/site-packages/networkx/readwrite/json_graph/tests/test_cytoscape.py b/lib/python3.10/site-packages/networkx/readwrite/json_graph/tests/test_cytoscape.py new file mode 100644 index 0000000000000000000000000000000000000000..5d47f21f4217d1997165c4f19feb67d283d2dab2 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/readwrite/json_graph/tests/test_cytoscape.py @@ -0,0 +1,78 @@ +import copy +import json + +import pytest + +import networkx as nx +from networkx.readwrite.json_graph import cytoscape_data, cytoscape_graph + + +def test_graph(): + G = nx.path_graph(4) + H = cytoscape_graph(cytoscape_data(G)) + assert nx.is_isomorphic(G, H) + + +def test_input_data_is_not_modified_when_building_graph(): + G = nx.path_graph(4) + input_data = cytoscape_data(G) + orig_data = copy.deepcopy(input_data) + # Ensure input is unmodified by cytoscape_graph (gh-4173) + cytoscape_graph(input_data) + assert input_data == orig_data + + +def test_graph_attributes(): + G = nx.path_graph(4) + G.add_node(1, color="red") + G.add_edge(1, 2, width=7) + G.graph["foo"] = "bar" + G.graph[1] = "one" + G.add_node(3, name="node", id="123") + + H = cytoscape_graph(cytoscape_data(G)) + assert H.graph["foo"] == "bar" + assert H.nodes[1]["color"] == "red" + assert H[1][2]["width"] == 7 + assert H.nodes[3]["name"] == "node" + assert H.nodes[3]["id"] == "123" + + d = json.dumps(cytoscape_data(G)) + H = cytoscape_graph(json.loads(d)) + assert H.graph["foo"] == "bar" + assert H.graph[1] == "one" + assert H.nodes[1]["color"] == "red" + assert H[1][2]["width"] == 7 + assert H.nodes[3]["name"] == "node" + assert H.nodes[3]["id"] == "123" + + +def test_digraph(): + G = nx.DiGraph() + nx.add_path(G, [1, 2, 3]) + H = cytoscape_graph(cytoscape_data(G)) + assert H.is_directed() + assert nx.is_isomorphic(G, H) + + +def test_multidigraph(): + G = nx.MultiDiGraph() + nx.add_path(G, [1, 2, 3]) + H = cytoscape_graph(cytoscape_data(G)) + assert H.is_directed() + assert H.is_multigraph() + + +def test_multigraph(): + G = nx.MultiGraph() + G.add_edge(1, 2, key="first") + G.add_edge(1, 2, key="second", color="blue") + H = cytoscape_graph(cytoscape_data(G)) + assert nx.is_isomorphic(G, H) + assert H[1][2]["second"]["color"] == "blue" + + +def test_exception(): + with pytest.raises(nx.NetworkXError): + G = nx.MultiDiGraph() + cytoscape_data(G, name="foo", ident="foo") diff --git a/lib/python3.10/site-packages/networkx/readwrite/json_graph/tests/test_node_link.py b/lib/python3.10/site-packages/networkx/readwrite/json_graph/tests/test_node_link.py new file mode 100644 index 0000000000000000000000000000000000000000..f903f6066a6e7be922136275ba8fa6df0da41caa --- /dev/null +++ b/lib/python3.10/site-packages/networkx/readwrite/json_graph/tests/test_node_link.py @@ -0,0 +1,175 @@ +import json + +import pytest + +import networkx as nx +from networkx.readwrite.json_graph import node_link_data, node_link_graph + + +def test_node_link_edges_default_future_warning(): + "Test FutureWarning is raised when `edges=None` in node_link_data and node_link_graph" + G = nx.Graph([(1, 2)]) + with pytest.warns(FutureWarning, match="\nThe default value will be"): + data = nx.node_link_data(G) # edges=None, the default + with pytest.warns(FutureWarning, match="\nThe default value will be"): + H = nx.node_link_graph(data) # edges=None, the default + + +def test_node_link_deprecated_link_param(): + G = nx.Graph([(1, 2)]) + with pytest.warns(DeprecationWarning, match="Keyword argument 'link'"): + data = nx.node_link_data(G, link="links") + with pytest.warns(DeprecationWarning, match="Keyword argument 'link'"): + H = nx.node_link_graph(data, link="links") + + +class TestNodeLink: + # TODO: To be removed when signature change complete + def test_custom_attrs_dep(self): + G = nx.path_graph(4) + G.add_node(1, color="red") + G.add_edge(1, 2, width=7) + G.graph[1] = "one" + G.graph["foo"] = "bar" + + attrs = { + "source": "c_source", + "target": "c_target", + "name": "c_id", + "key": "c_key", + "link": "c_links", + } + + H = node_link_graph(node_link_data(G, **attrs), multigraph=False, **attrs) + assert nx.is_isomorphic(G, H) + assert H.graph["foo"] == "bar" + assert H.nodes[1]["color"] == "red" + assert H[1][2]["width"] == 7 + + # provide only a partial dictionary of keywords. + # This is similar to an example in the doc string + attrs = { + "link": "c_links", + "source": "c_source", + "target": "c_target", + } + H = node_link_graph(node_link_data(G, **attrs), multigraph=False, **attrs) + assert nx.is_isomorphic(G, H) + assert H.graph["foo"] == "bar" + assert H.nodes[1]["color"] == "red" + assert H[1][2]["width"] == 7 + + def test_exception_dep(self): + G = nx.MultiDiGraph() + with pytest.raises(nx.NetworkXError): + with pytest.warns(FutureWarning, match="\nThe default value will be"): + node_link_data(G, name="node", source="node", target="node", key="node") + + def test_graph(self): + G = nx.path_graph(4) + with pytest.warns(FutureWarning, match="\nThe default value will be"): + H = node_link_graph(node_link_data(G)) + assert nx.is_isomorphic(G, H) + + def test_graph_attributes(self): + G = nx.path_graph(4) + G.add_node(1, color="red") + G.add_edge(1, 2, width=7) + G.graph[1] = "one" + G.graph["foo"] = "bar" + + with pytest.warns(FutureWarning, match="\nThe default value will be"): + H = node_link_graph(node_link_data(G)) + assert H.graph["foo"] == "bar" + assert H.nodes[1]["color"] == "red" + assert H[1][2]["width"] == 7 + + with pytest.warns(FutureWarning, match="\nThe default value will be"): + d = json.dumps(node_link_data(G)) + with pytest.warns(FutureWarning, match="\nThe default value will be"): + H = node_link_graph(json.loads(d)) + assert H.graph["foo"] == "bar" + assert H.graph["1"] == "one" + assert H.nodes[1]["color"] == "red" + assert H[1][2]["width"] == 7 + + def test_digraph(self): + G = nx.DiGraph() + with pytest.warns(FutureWarning, match="\nThe default value will be"): + H = node_link_graph(node_link_data(G)) + assert H.is_directed() + + def test_multigraph(self): + G = nx.MultiGraph() + G.add_edge(1, 2, key="first") + G.add_edge(1, 2, key="second", color="blue") + with pytest.warns(FutureWarning, match="\nThe default value will be"): + H = node_link_graph(node_link_data(G)) + assert nx.is_isomorphic(G, H) + assert H[1][2]["second"]["color"] == "blue" + + def test_graph_with_tuple_nodes(self): + G = nx.Graph() + G.add_edge((0, 0), (1, 0), color=[255, 255, 0]) + with pytest.warns(FutureWarning, match="\nThe default value will be"): + d = node_link_data(G) + dumped_d = json.dumps(d) + dd = json.loads(dumped_d) + with pytest.warns(FutureWarning, match="\nThe default value will be"): + H = node_link_graph(dd) + assert H.nodes[(0, 0)] == G.nodes[(0, 0)] + assert H[(0, 0)][(1, 0)]["color"] == [255, 255, 0] + + def test_unicode_keys(self): + q = "qualité" + G = nx.Graph() + G.add_node(1, **{q: q}) + with pytest.warns(FutureWarning, match="\nThe default value will be"): + s = node_link_data(G) + output = json.dumps(s, ensure_ascii=False) + data = json.loads(output) + with pytest.warns(FutureWarning, match="\nThe default value will be"): + H = node_link_graph(data) + assert H.nodes[1][q] == q + + def test_exception(self): + G = nx.MultiDiGraph() + attrs = {"name": "node", "source": "node", "target": "node", "key": "node"} + with pytest.raises(nx.NetworkXError): + with pytest.warns(FutureWarning, match="\nThe default value will be"): + node_link_data(G, **attrs) + + def test_string_ids(self): + q = "qualité" + G = nx.DiGraph() + G.add_node("A") + G.add_node(q) + G.add_edge("A", q) + with pytest.warns(FutureWarning, match="\nThe default value will be"): + data = node_link_data(G) + assert data["links"][0]["source"] == "A" + assert data["links"][0]["target"] == q + with pytest.warns(FutureWarning, match="\nThe default value will be"): + H = node_link_graph(data) + assert nx.is_isomorphic(G, H) + + def test_custom_attrs(self): + G = nx.path_graph(4) + G.add_node(1, color="red") + G.add_edge(1, 2, width=7) + G.graph[1] = "one" + G.graph["foo"] = "bar" + + attrs = { + "source": "c_source", + "target": "c_target", + "name": "c_id", + "key": "c_key", + "link": "c_links", + } + + H = node_link_graph(node_link_data(G, **attrs), multigraph=False, **attrs) + assert nx.is_isomorphic(G, H) + assert H.graph["foo"] == "bar" + assert H.nodes[1]["color"] == "red" + assert H[1][2]["width"] == 7 diff --git a/lib/python3.10/site-packages/networkx/readwrite/json_graph/tests/test_tree.py b/lib/python3.10/site-packages/networkx/readwrite/json_graph/tests/test_tree.py new file mode 100644 index 0000000000000000000000000000000000000000..643a14d89b5211f2d97b98f2e227e68361781b97 --- /dev/null +++ b/lib/python3.10/site-packages/networkx/readwrite/json_graph/tests/test_tree.py @@ -0,0 +1,48 @@ +import json + +import pytest + +import networkx as nx +from networkx.readwrite.json_graph import tree_data, tree_graph + + +def test_graph(): + G = nx.DiGraph() + G.add_nodes_from([1, 2, 3], color="red") + G.add_edge(1, 2, foo=7) + G.add_edge(1, 3, foo=10) + G.add_edge(3, 4, foo=10) + H = tree_graph(tree_data(G, 1)) + assert nx.is_isomorphic(G, H) + + +def test_graph_attributes(): + G = nx.DiGraph() + G.add_nodes_from([1, 2, 3], color="red") + G.add_edge(1, 2, foo=7) + G.add_edge(1, 3, foo=10) + G.add_edge(3, 4, foo=10) + H = tree_graph(tree_data(G, 1)) + assert H.nodes[1]["color"] == "red" + + d = json.dumps(tree_data(G, 1)) + H = tree_graph(json.loads(d)) + assert H.nodes[1]["color"] == "red" + + +def test_exceptions(): + with pytest.raises(TypeError, match="is not a tree."): + G = nx.complete_graph(3) + tree_data(G, 0) + with pytest.raises(TypeError, match="is not directed."): + G = nx.path_graph(3) + tree_data(G, 0) + with pytest.raises(TypeError, match="is not weakly connected."): + G = nx.path_graph(3, create_using=nx.DiGraph) + G.add_edge(2, 0) + G.add_node(3) + tree_data(G, 0) + with pytest.raises(nx.NetworkXError, match="must be different."): + G = nx.MultiDiGraph() + G.add_node(0) + tree_data(G, 0, ident="node", children="node") diff --git a/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05f2bcb9fb206a75822c4784f24c957983e1a02d Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_adjlist.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_adjlist.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4961bc19e52d99226bc4c8c19255a1d77e8eaece Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_adjlist.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_edgelist.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_edgelist.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..036cc674037ba94c93bf264edf3957eb1dea0a74 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_edgelist.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_gexf.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_gexf.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e928c9a62e6ff71ab528ab1f8a8c2e9219f9cbe3 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_gexf.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_gml.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_gml.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f458e4e3f41a281dd977336db2ac8b49dd87825 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_gml.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_graph6.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_graph6.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32070d1ddadc50b50468c2aceb28d0d446ac14b3 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_graph6.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_graphml.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_graphml.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..873aefba8a90879b76b8c1d76c585e301f14eb3e Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_graphml.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_leda.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_leda.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e25b32ba549e9f97e55a3e909263304003a3ae7b Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_leda.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_p2g.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_p2g.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..872460b61db6793b4017a4b3891837a3330413f0 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_p2g.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_pajek.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_pajek.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6922d95506a09f9a0d93583a530fea7609094bb2 Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_pajek.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_sparse6.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_sparse6.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27258b5823dfbec534615ac6d66074263c8f542b Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_sparse6.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_text.cpython-310.pyc b/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_text.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1479da80c8cea2c8d19fb28f5f62d927bad91b7a Binary files /dev/null and b/lib/python3.10/site-packages/networkx/readwrite/tests/__pycache__/test_text.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/triton/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2911e6304ec2142aac0ff20d6cd59379ec8400c0 Binary files /dev/null and b/lib/python3.10/site-packages/triton/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/__pycache__/errors.cpython-310.pyc b/lib/python3.10/site-packages/triton/__pycache__/errors.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae3eb3d81f4b4dd6e22092833386719201a913f6 Binary files /dev/null and b/lib/python3.10/site-packages/triton/__pycache__/errors.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/__pycache__/testing.cpython-310.pyc b/lib/python3.10/site-packages/triton/__pycache__/testing.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c0d48e06f382de623f788873a3ee59abffde500 Binary files /dev/null and b/lib/python3.10/site-packages/triton/__pycache__/testing.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/backends/__init__.py b/lib/python3.10/site-packages/triton/backends/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fbf65d9e908fdb92474b69f0de86d0826fedeca5 --- /dev/null +++ b/lib/python3.10/site-packages/triton/backends/__init__.py @@ -0,0 +1,50 @@ +import os +import importlib.util +import inspect +from dataclasses import dataclass +from .driver import DriverBase +from .compiler import BaseBackend + + +def _load_module(name, path): + spec = importlib.util.spec_from_file_location(name[:-3], path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def _find_concrete_subclasses(module, base_class): + ret = [] + for attr_name in dir(module): + attr = getattr(module, attr_name) + if isinstance(attr, type) and issubclass(attr, base_class) and not inspect.isabstract(attr): + ret.append(attr) + if len(ret) == 0: + raise RuntimeError(f"Found 0 concrete subclasses of {base_class} in {module}: {ret}") + if len(ret) > 1: + raise RuntimeError(f"Found >1 concrete subclasses of {base_class} in {module}: {ret}") + return ret[0] + + +@dataclass(frozen=True) +class Backend: + compiler: BaseBackend = None + driver: DriverBase = None + + +def _discover_backends(): + backends = dict() + root = os.path.dirname(__file__) + for name in os.listdir(root): + if not os.path.isdir(os.path.join(root, name)): + continue + if name.startswith('__'): + continue + compiler = _load_module(name, os.path.join(root, name, 'compiler.py')) + driver = _load_module(name, os.path.join(root, name, 'driver.py')) + backends[name] = Backend(_find_concrete_subclasses(compiler, BaseBackend), + _find_concrete_subclasses(driver, DriverBase)) + return backends + + +backends = _discover_backends() diff --git a/lib/python3.10/site-packages/triton/backends/compiler.py b/lib/python3.10/site-packages/triton/backends/compiler.py new file mode 100644 index 0000000000000000000000000000000000000000..990690045204b34f8f335073904436448a3e7918 --- /dev/null +++ b/lib/python3.10/site-packages/triton/backends/compiler.py @@ -0,0 +1,76 @@ +import os +import re +import subprocess + +from abc import ABCMeta, abstractmethod, abstractclassmethod +from dataclasses import dataclass +from typing import Union + + +@dataclass(frozen=True) +class GPUTarget(object): + # Target backend, e.g., cuda, hip + backend: str + # Target architecture, e.g., 90 (for cuda compute capability), gfx940 (for hip) + arch: Union[int, str] + warp_size: int + + +class BaseBackend(metaclass=ABCMeta): + + def __init__(self, target: GPUTarget) -> None: + self.target = target + assert self.supports_target(target) + + @staticmethod + def _path_to_binary(binary: str): + base_dir = os.path.join(os.path.dirname(__file__), os.pardir) + paths = [ + os.environ.get(f"TRITON_{binary.upper()}_PATH", ""), + os.path.join(base_dir, "third_party", "cuda", "bin", binary), + ] + for p in paths: + bin = p.split(" ")[0] + if os.path.exists(bin) and os.path.isfile(bin): + result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT) + if result is not None: + version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE) + if version is not None: + return p, version.group(1) + raise RuntimeError(f"Cannot find {binary}") + + @abstractclassmethod + def supports_target(target: GPUTarget): + raise NotImplementedError + + @abstractmethod + def hash(self) -> str: + """Returns a unique identifier for this backend""" + raise NotImplementedError + + @abstractmethod + def parse_options(self, options: dict) -> object: + """ + Converts an `options` dictionary into an arbitrary object and returns it. + This function may contain target-specific heuristics and check the legality of the provided options + """ + raise NotImplementedError + + @abstractmethod + def add_stages(self, stages: dict, options: object) -> None: + """ + Populates `stages` dictionary with entries of the form: + ir_name [str] => Function[(src: str, metadata: dict) -> str|bytes] + The value of each entry may populate a `metadata` dictionary. + Stages will be run sequentially (in inseriton order) and can communicate using `metadata`. + All stages are expected to return a `str` object, except for the last stage which returns + a `bytes` object for execution by the launcher. + """ + raise NotImplementedError + + @abstractmethod + def load_dialects(self, context): + """ + Load additional MLIR dialects into the provided `context` + """ + raise NotImplementedError diff --git a/lib/python3.10/site-packages/triton/backends/driver.py b/lib/python3.10/site-packages/triton/backends/driver.py new file mode 100644 index 0000000000000000000000000000000000000000..e66442943b0399e398b15089c21446fdb62db5f5 --- /dev/null +++ b/lib/python3.10/site-packages/triton/backends/driver.py @@ -0,0 +1,34 @@ +from abc import ABCMeta, abstractmethod, abstractclassmethod + + +class DriverBase(metaclass=ABCMeta): + + @abstractclassmethod + def is_active(self): + pass + + @abstractmethod + def get_current_target(self): + pass + + def __init__(self) -> None: + pass + + +class GPUDriver(DriverBase): + + def __init__(self): + # TODO: support other frameworks than torch + import torch + self.get_device_capability = torch.cuda.get_device_capability + try: + from torch._C import _cuda_getCurrentRawStream + self.get_current_stream = _cuda_getCurrentRawStream + except ImportError: + self.get_current_stream = lambda idx: torch.cuda.current_stream(idx).cuda_stream + self.get_current_device = torch.cuda.current_device + self.set_current_device = torch.cuda.set_device + + # TODO: remove once TMA is cleaned up + def assemble_tensormap_to_arg(self, tensormaps_info, args): + return args diff --git a/lib/python3.10/site-packages/triton/compiler/__init__.py b/lib/python3.10/site-packages/triton/compiler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ce0cfedfcdd10622773990665d21b932b08a63d8 --- /dev/null +++ b/lib/python3.10/site-packages/triton/compiler/__init__.py @@ -0,0 +1,4 @@ +from .compiler import CompiledKernel, ASTSource, compile, AttrsDescriptor, make_backend, LazyDict +from .errors import CompilationError + +__all__ = ["compile", "make_backend", "ASTSource", "AttrsDescriptor", "CompiledKernel", "CompilationError", "LazyDict"] diff --git a/lib/python3.10/site-packages/triton/compiler/code_generator.py b/lib/python3.10/site-packages/triton/compiler/code_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..6903052ca21966fcf4edb29d8429458457c51051 --- /dev/null +++ b/lib/python3.10/site-packages/triton/compiler/code_generator.py @@ -0,0 +1,1302 @@ +import ast +import inspect +import re +import sys +import warnings +import os +import textwrap +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union +from .. import language +from .._C.libtriton import ir +from ..language import constexpr, tensor, str_to_ty +from ..runtime.jit import _normalize_ty +# ideally we wouldn't need any runtime component +from ..runtime import JITFunction +from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) +from types import ModuleType + + +def mangle_ty(ty): + if ty.is_ptr(): + return 'P' + mangle_ty(ty.element_ty) + if ty.is_int(): + SIGNED = language.dtype.SIGNEDNESS.SIGNED + prefix = 'i' if ty.int_signedness == SIGNED else 'u' + return prefix + str(ty.int_bitwidth) + if ty.is_floating(): + return str(ty) + if ty.is_block(): + elt = mangle_ty(ty.scalar) + shape = '_'.join(map(str, ty.shape)) + return f'{elt}S{shape}S' + if ty.is_void(): + return 'V' + assert False, "Unsupported type" + + +def mangle_fn(name, arg_tys, constants): + # doesn't mangle ret type, which must be a function of arg tys + mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys]) + mangled_constants = '_'.join([f'{i}c{repr(constants[i])}' for i in sorted(constants)]) + mangled_constants = mangled_constants.replace('.', '_d_') + mangled_constants = mangled_constants.replace("'", '_sq_') + # [ and ] are not allowed in LLVM identifiers + mangled_constants = mangled_constants.replace('[', '_').replace(']', '_') + ret = f'{name}__{mangled_arg_names}__{mangled_constants}' + return ret + + +def _is_triton_tensor(o: Any) -> bool: + return isinstance(o, tensor) + + +def _is_constexpr(o: Any) -> bool: + return isinstance(o, constexpr) + + +def _is_triton_scalar(o: Any) -> bool: + return _is_triton_tensor(o) and (not o.type.is_block() or o.type.numel == 1) + + +def _is_list_like(o: Any) -> bool: + return isinstance(o, (list, tuple)) + + +def _unwrap_if_constexpr(o: Any): + return o.value if isinstance(o, constexpr) else o + + +def _check_fn_args(node, fn, args): + if fn.noinline: + for idx, arg in enumerate(args): + if not _is_constexpr(arg) and not _is_triton_scalar(arg): + raise UnsupportedLanguageConstruct( + fn.src, node, + f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}' + ) + + +def _get_fn_file_line(fn): + base_fn = fn + while not isinstance(base_fn, JITFunction): + base_fn = base_fn.fn + file_name = base_fn.fn.__code__.co_filename + lines, begin_line = inspect.getsourcelines(base_fn.fn) + # Match the following pattern: + # @triton.autotune(...) <- foo.__code__.co_firstlineno + # @triton.heuristics(...) + # @triton.jit + # def foo(...): <- this line is the first line + for idx, line in enumerate(lines): + if line.strip().startswith("def "): + begin_line += idx + break + return file_name, begin_line + + +_condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels + + +class enter_sub_region: + + def __init__(self, generator): + self.generator = generator + + def __enter__(self): + # record lscope & local_defs in the parent scope + self.liveins = self.generator.lscope.copy() + self.prev_defs = self.generator.local_defs.copy() + self.generator.local_defs = {} + self.insert_block = self.generator.builder.get_insertion_block() + self.insert_point = self.generator.builder.get_insertion_point() + return self.liveins, self.insert_block + + def __exit__(self, *args, **kwargs): + self.generator.builder.restore_insertion_point(self.insert_point) + self.generator.lscope = self.liveins + self.generator.local_defs = self.prev_defs + + +# Check if the given syntax node has an "early" return +class ContainsReturnChecker(ast.NodeVisitor): + + def __init__(self, gscope): + self.gscope = gscope + + def _visit_stmts(self, body) -> bool: + for s in body: + if self.visit(s): + return True + return False + + def _visit_function(self, fn) -> bool: + # Currently we only support JITFunctions defined in the global scope + if isinstance(fn, JITFunction) and not fn.noinline: + fn_node = fn.parse() + return ContainsReturnChecker(self.gscope).visit(fn_node) + return False + + def generic_visit(self, node) -> bool: + ret = False + for _, value in ast.iter_fields(node): + if isinstance(value, list): + for item in value: + if isinstance(item, ast.AST): + ret = ret or self.visit(item) + elif isinstance(value, ast.AST): + ret = ret or self.visit(value) + return ret + + def visit_Attribute(self, node: ast.Attribute) -> bool: + # If the left part is a name, it's possible that + # we call triton native function or a jit function from another module. + # If the left part is not a name, it must return a tensor or a constexpr + # whose methods do not contain return statements + # e.g., (tl.load(x)).to(y) + # So we only check if the expressions within value have return or not + if isinstance(node.value, ast.Name): + if node.value.id in self.gscope: + value = self.gscope[node.value.id] + fn = getattr(value, node.attr) + return self._visit_function(fn) + return False + return self.visit(node.value) + + def visit_Name(self, node: ast.Name) -> bool: + if type(node.ctx) == ast.Store: + return False + if node.id in self.gscope: + fn = self.gscope[node.id] + return self._visit_function(fn) + return False + + def visit_Return(self, node: ast.Return) -> bool: + return True + + def visit_Assign(self, node: ast.Assign) -> bool: + # There couldn't be an early return + # x = ... + return False + + def visit_AugAssign(self, node: ast.AugAssign) -> bool: + # There couldn't be an early return + # x += ... + return False + + def visit_Module(self, node: ast.Module) -> bool: + return self._visit_stmts(node.body) + + def visit_FunctionDef(self, node: ast.FunctionDef) -> bool: + return self._visit_stmts(node.body) + + def visit_If(self, node: ast.If) -> bool: + # TODO: optimize the following case in which we actually don't have + # a return when static_cond is false: + # if dynamic_cond + # if static_cond + # func_with_return + # else + # func_without_return + ret = self._visit_stmts(node.body) + if node.orelse: + ret = ret or self._visit_stmts(node.orelse) + return ret + + def visit_IfExp(self, node: ast.IfExp) -> bool: + return self.visit(node.body) or self.visit(node.orelse) + + def visit_Call(self, node: ast.Call) -> bool: + return self.visit(node.func) + + +class CodeGenerator(ast.NodeVisitor): + + def __init__(self, context, prototype, gscope, attributes, constants, function_name, jit_fn: JITFunction, options, + codegen_fns, debug=None, module=None, is_kernel=False, function_types: Optional[Dict] = None, + noinline=False, file_name: Optional[str] = None, begin_line=0): + self.context = context + self.builder = ir.builder(context) + self.file_name = file_name + # node.lineno starts from 1, so we need to subtract 1 + self.begin_line = begin_line - 1 + self.builder.set_loc(file_name, begin_line, 0) + self.builder.options = options + # dict of functions provided by the backend. Below are the list of possible functions: + # Convert custom types not natively supported on HW. + # convert_custom_types(intput_tensor, dtype, fp_downcast_rounding=None, _builder=None) + self.builder.codegen_fns = codegen_fns + self.module = self.builder.create_module() if module is None else module + self.function_ret_types = {} if function_types is None else function_types + self.prototype = prototype + self.gscope = gscope + self.lscope = dict() + self.attributes = attributes + self.constants = constants + self.jit_fn = jit_fn + self.function_name = function_name + self.is_kernel = is_kernel + self.cur_node = None + self.debug = options.debug if debug is None else debug + self.noinline = noinline + self.scf_stack = [] + self.ret_type = None + # SSA-construction + # name => language.tensor + self.local_defs: Dict[str, tensor] = {} + self.dereference_name: Callable[[str], Any] = self._define_name_lookup() + self.fn = None + # Are we currently visiting an ast.arg's default value? These have some + # special handling. + self.visiting_arg_default_value = False + + builtin_namespace: Dict[str, Any] = {_.__name__: _ for _ in (len, list, range, float, int, isinstance, getattr)} + builtin_namespace.update(( + ('print', language.core.device_print), + ('min', language.minimum), + ('max', language.maximum), + )) + + def _unsupported(self, node, message): + return UnsupportedLanguageConstruct(self.jit_fn.src, node, message) + + def _is_constexpr_global(self, name): + absent_marker = object() + val = self.gscope.get(name, absent_marker) + if val is absent_marker: + return False + + if _is_constexpr(val): + return True + + if a := self.gscope.get("__annotations__", {}).get(name): + return _normalize_ty(a) == "constexpr" + + return False + + def _define_name_lookup(self): + + def local_lookup(name: str, absent): + # this needs to be re-fetched from `self` every time, because it gets switched occasionally + return self.lscope.get(name, absent) + + def global_lookup(name: str, absent): + val = self.gscope.get(name, absent) + # The high-level rule is that only constexpr globals are allowed. + # But actually a bunch of other things, such as module imports, are + # technically Python globals. We have to allow these too! + if (val is absent # + or name in self.builtin_namespace # + or type(val) == ModuleType # + or isinstance(val, JITFunction) # + or getattr(val, "__triton_builtin__", False) # + or getattr(val, "__module__", "").startswith("triton.language") # + or isinstance(val, language.dtype) # + or self._is_constexpr_global(name) # + # Allow accesses to globals while visiting an ast.arg + # because you should be able to do + # @triton.jit def fn(x: tl.constexpr = GLOBAL): ... + or self.visiting_arg_default_value # + or os.environ.get("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS", "0") == "1"): + return val + raise NameError( + textwrap.dedent(f"""\ + Cannot access global variable {name} from within @jit'ed + function. Triton kernels can only access global variables that + are annotated as constexpr (`x: triton.language.constexpr = 42` + or `x = triton.language.constexpr(42)`). Alternatively, set the + envvar TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1, but we do not + promise to support this forever.""").replace("\n", " ")) + + absent_marker = object() + + def name_lookup(name: str) -> Any: + absent = absent_marker + for lookup_function in local_lookup, global_lookup, self.builtin_namespace.get: + value = lookup_function(name, absent) + if value is not absent: + return value + raise NameError(f'{name} is not defined') + + return name_lookup + + def set_value(self, name: str, value: Union[tensor, constexpr]) -> None: + ''' This function: + called by visit_Assign() & visit_FunctionDef() to store left value (lvalue) + 1. record local defined name (FIXME: should consider control flow) + 2. store tensor in self.lvalue + ''' + self.lscope[name] = value + self.local_defs[name] = value + + def _get_insertion_point_and_loc(self): + # XXX: this is a hack to get the location of the insertion point. + # The insertion point's location could be invalid sometimes, + # so we need to explicitly set the location + loc = self.builder.get_loc() + ip = self.builder.get_insertion_point() + return ip, loc + + def _set_insertion_point_and_loc(self, ip, loc): + self.builder.restore_insertion_point(ip) + self.builder.set_loc(loc) + + # + # AST visitor + # + def visit_compound_statement(self, stmts): + # Ensure that stmts is iterable + if not _is_list_like(stmts): + stmts = [stmts] + for stmt in stmts: + self.visit(stmt) + + # Stop parsing as soon as we hit a `return` statement; everything + # after this is dead code. + if isinstance(stmt, ast.Return): + break + + def visit_Module(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_List(self, node): + ctx = self.visit(node.ctx) + assert ctx is None + elts = [self.visit(elt) for elt in node.elts] + return elts + + # By design, only non-kernel functions can return + def visit_Return(self, node): + ret_value = self.visit(node.value) + # ret_block = self.builder.create_block() + # post_ret_block = self.builder.create_block() + # self.builder.create_branch(ret_block) + # self.builder.set_insertion_point_to_end(ret_block) + if ret_value is None: + self.builder.ret([]) + ret_ty = language.void + elif isinstance(ret_value, tuple): + ret_values = [language.core._to_tensor(v, self.builder) for v in ret_value] + ret_types = [v.type for v in ret_values] + self.builder.ret([v.handle for v in ret_values]) + ret_ty = tuple(ret_types) + else: + ret = language.core._to_tensor(ret_value, self.builder) + self.builder.ret([ret.handle]) + ret_ty = ret.type + # self.builder.create_branch(post_ret_block) + # self.builder.set_insertion_point_to_end(post_ret_block) + + if self.ret_type is None: + self.ret_type = ret_ty + elif self.ret_type != ret_ty: + raise TypeError(f'Inconsistent return types: {self.ret_type} and {ret_ty}') + + def visit_FunctionDef(self, node): + arg_names, kwarg_names = self.visit(node.args) + if self.fn: + raise self._unsupported(node, "nested function definition is not supported.") + # initialize defaults + for i, default_value in enumerate(node.args.defaults): + arg_node = node.args.args[-i - 1] + annotation = arg_node.annotation + name = arg_node.arg + st_target = ast.Name(id=name, ctx=ast.Store()) + if annotation is None: + init_node = ast.Assign(targets=[st_target], value=default_value) + else: + init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) + + try: + assert not self.visiting_arg_default_value + self.visiting_arg_default_value = True + self.visit(init_node) + finally: + self.visiting_arg_default_value = False + + # initialize function + visibility = "public" if self.is_kernel else "private" + self.fn = self.builder.get_or_insert_function(self.module, self.function_name, + self.prototype.to_ir(self.builder), visibility, self.noinline) + self.module.push_back(self.fn) + entry = self.fn.add_entry_block() + arg_values = [] + idx = 0 + for i, arg_name in enumerate(arg_names): + if i in self.constants: + cst = self.constants[i] + if not _is_constexpr(cst): + cst = constexpr(self.constants[i]) + arg_values.append(cst) + continue + else: + if i in self.attributes: + for name, value in self.attributes[i]: + self.fn.set_arg_attr(idx, name, value) + arg_values.append(tensor(self.fn.args(idx), self.prototype.param_types[idx])) + idx += 1 + + insert_pt = self.builder.get_insertion_block() + for arg_name, arg_value in zip(arg_names, arg_values): + self.set_value(arg_name, arg_value) + self.builder.set_insertion_point_to_start(entry) + # visit function body + self.visit_compound_statement(node.body) + # finalize function + if self.ret_type is None or self.ret_type == language.void: + self.ret_type = language.void + self.builder.ret([]) + else: + # update return type + if isinstance(self.ret_type, tuple): + self.prototype.ret_types = list(self.ret_type) + self.fn.reset_type(self.prototype.to_ir(self.builder)) + else: + self.prototype.ret_types = [self.ret_type] + self.fn.reset_type(self.prototype.to_ir(self.builder)) + if insert_pt: + self.builder.set_insertion_point_to_end(insert_pt) + # Remove dead code + self.fn.finalize() + + def visit_arguments(self, node): + arg_names = [] + for arg in node.args: + arg_names += [self.visit(arg)] + kwarg_names = self.visit(node.kwarg) + return arg_names, kwarg_names + + def visit_arg(self, node): + ast.NodeVisitor.generic_visit(self, node) + return node.arg + + def visit_AnnAssign(self, node): + # extract attributes + annotation = self.visit(node.annotation) + target = self.visit(node.target) + value = self.visit(node.value) + # constexpr + if annotation == constexpr: + if target in self.lscope: + raise ValueError(f'{target} is already defined.' + f' constexpr cannot be reassigned.') + if not _is_constexpr(value): + value = constexpr(value) + self.lscope[target] = value + return self.lscope[target] + # default: call visit_Assign + return self.visit_Assign(node) + + def visit_Assign(self, node): + _names = [] + for target in node.targets: + _names += [self.visit(target)] + if len(_names) > 1: + raise self._unsupported(node, "simultaneous multiple assignment is not supported.") + names = _names[0] + values = self.visit(node.value) + if not _is_list_like(names): + names = [names] + if not _is_list_like(values): + values = [values] + native_nontensor_types = (language.dtype, ) + for name, value in zip(names, values): + # by default, constexpr are assigned into python variable + value = _unwrap_if_constexpr(value) + if value is not None and \ + not _is_triton_tensor(value) and \ + not isinstance(value, native_nontensor_types): + value = language.core._to_tensor(value, self.builder) + self.set_value(name, value) + + def visit_AugAssign(self, node): + name = node.target.id + lhs = ast.Name(id=name, ctx=ast.Load()) + rhs = ast.BinOp(lhs, node.op, node.value) + assign = ast.Assign(targets=[node.target], value=rhs) + self.visit(assign) + return self.dereference_name(name) + + def visit_Name(self, node): + if type(node.ctx) == ast.Store: + return node.id + return self.dereference_name(node.id) + + def visit_Store(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Load(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Tuple(self, node): + args = [self.visit(x) for x in node.elts] + return tuple(args) + + def _apply_binary_method(self, method_name, lhs, rhs): + # TODO: raise something meaningful if getattr fails below, esp for reverse method + if _is_triton_tensor(lhs): + return getattr(lhs, method_name)(rhs, _builder=self.builder) + if _is_triton_tensor(rhs): + reverse_method_name = re.sub(r"__(.*)__", r"__r\1__", method_name) + return getattr(rhs, reverse_method_name)(lhs, _builder=self.builder) + return getattr(lhs, method_name)(rhs) + + def visit_BinOp(self, node): + lhs = self.visit(node.left) + rhs = self.visit(node.right) + method_name = self._method_name_for_bin_op.get(type(node.op)) + if method_name is None: + raise self._unsupported(node, + "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__)) + return self._apply_binary_method(method_name, lhs, rhs) + + _method_name_for_bin_op: Dict[Type[ast.operator], str] = { + ast.Add: '__add__', + ast.Sub: '__sub__', + ast.Mult: '__mul__', + ast.Div: '__truediv__', + ast.FloorDiv: '__floordiv__', + ast.Mod: '__mod__', + ast.Pow: '__pow__', + ast.LShift: '__lshift__', + ast.RShift: '__rshift__', + ast.BitAnd: '__and__', + ast.BitOr: '__or__', + ast.BitXor: '__xor__', + } + + def visit_then_else_blocks(self, node, liveins, then_block, else_block): + # then block + self.builder.set_insertion_point_to_start(then_block) + self.visit_compound_statement(node.body) + then_block = self.builder.get_insertion_block() + then_defs = self.local_defs.copy() + # else block + else_defs = {} + if node.orelse: + self.builder.set_insertion_point_to_start(else_block) + self.lscope = liveins.copy() + self.local_defs = {} + self.visit_compound_statement(node.orelse) + else_defs = self.local_defs.copy() + else_block = self.builder.get_insertion_block() + + # update block arguments + names = [] + ret_types = [] + ir_ret_types = [] + # variables in livein whose value is updated in `if` + for name in liveins: + # check type + for defs, block_name in [(then_defs, 'then'), (else_defs, 'else')]: + if name in defs: + assert defs[name].type == liveins[name].type, \ + f'initial value for `{name}` is of type {liveins[name].type}, '\ + f'but the {block_name} block redefines it as {defs[name].type}' + if name in then_defs or name in else_defs: + names.append(name) + ret_types.append(then_defs[name].type if name in then_defs else else_defs[name].type) + ir_ret_types.append(then_defs[name].handle.get_type() if name in + then_defs else else_defs[name].handle.get_type()) + # variable defined in then but not in else + if name in then_defs and name not in else_defs: + else_defs[name] = liveins[name] + # variable defined in else but not in then + if name in else_defs and name not in then_defs: + then_defs[name] = liveins[name] + # variables that are both in then and else but not in liveins + # TODO: could probably be cleaned up + for name in then_defs.keys() & else_defs.keys(): + if name in names: + continue + then_ty = then_defs[name].type + else_ty = else_defs[name].type + assert then_ty == else_ty, \ + f'mismatched type for {name} between then block ({then_ty}) '\ + f'and else block ({else_ty})' + names.append(name) + ret_types.append(then_ty) + ir_ret_types.append(then_defs[name].handle.get_type()) + + return then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types + + def visit_if_top_level(self, cond, node): + has_endif_block = True + with enter_sub_region(self) as sr: + liveins, ip_block = sr + then_block = self.builder.create_block() + else_block = self.builder.create_block() + # create basic-block after conditional + endif_block = self.builder.create_block() + # create branch + self.builder.set_insertion_point_to_end(ip_block) + self.builder.create_cond_branch(cond.handle, then_block, else_block) + # visit then and else blocks + then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types = \ + self.visit_then_else_blocks(node, liveins, then_block, else_block) + # then terminator + self.builder.set_insertion_point_to_end(then_block) + if then_block.has_return() and else_block.has_return(): + has_endif_block = False + endif_block.erase() + if not then_block.has_terminator() and has_endif_block: + self.builder.create_branch(endif_block, [then_defs[n].handle for n in names]) + # else terminator + self.builder.set_insertion_point_to_end(else_block) + if not else_block.has_terminator() and has_endif_block: + self.builder.create_branch(endif_block, [else_defs[n].handle for n in names]) + if has_endif_block: + for ty in ir_ret_types: + endif_block.add_argument(ty) + if has_endif_block: + # change block + self.builder.set_insertion_point_to_start(endif_block) + # update value + for i, name in enumerate(names): + new_tensor = language.core.tensor(endif_block.arg(i), ret_types[i]) + self.set_value(name, new_tensor) + + # TODO: refactor + def visit_if_scf(self, cond, node): + with enter_sub_region(self) as sr: + liveins, _ = sr + ip, last_loc = self._get_insertion_point_and_loc() + then_block = self.builder.create_block() + else_block = self.builder.create_block() if node.orelse else None + then_defs, else_defs, then_block, else_block, names, ret_types, _ = \ + self.visit_then_else_blocks(node, liveins, then_block, else_block) + # create if op + self._set_insertion_point_and_loc(ip, last_loc) + if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True) + then_block.merge_block_before(if_op.get_then_block()) + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + if len(names) > 0: + self.builder.create_yield_op([then_defs[n].handle for n in names]) + if not node.orelse: + else_block = if_op.get_else_block() + else: + else_block.merge_block_before(if_op.get_else_block()) + self.builder.set_insertion_point_to_end(if_op.get_else_block()) + if len(names) > 0: + self.builder.create_yield_op([else_defs[n].handle for n in names]) + # update values + for i, name in enumerate(names): + new_tensor = language.core.tensor(if_op.get_result(i), ret_types[i]) + self.set_value(name, new_tensor) + + def visit_If(self, node): + cond = self.visit(node.test) + if _is_triton_tensor(cond): + cond = cond.to(language.int1, _builder=self.builder) + contains_return = ContainsReturnChecker(self.gscope).visit(node) + if self.scf_stack and contains_return: + raise self._unsupported( + node, "Cannot have `return` statements inside `while` or `for` statements in triton " + "(note that this also applies to `return` statements that are inside functions " + "transitively called from within `while`/`for` statements)") + elif self.scf_stack or not contains_return: + self.visit_if_scf(cond, node) + else: + self.visit_if_top_level(cond, node) + else: + cond = _unwrap_if_constexpr(cond) + # not isinstance - we insist the real thing, no subclasses and no ducks + if type(cond) not in _condition_types: + raise self._unsupported( + node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( + ', '.join(_.__name__ for _ in _condition_types), + type(cond).__name__)) + if cond: + self.visit_compound_statement(node.body) + else: + self.visit_compound_statement(node.orelse) + + def visit_IfExp(self, node): + cond = self.visit(node.test) + if _is_triton_tensor(cond): + cond = cond.to(language.int1, _builder=self.builder) + # TODO: Deal w/ more complicated return types (e.g tuple) + with enter_sub_region(self): + ip, last_loc = self._get_insertion_point_and_loc() + + then_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(then_block) + then_val = language.core._to_tensor(self.visit(node.body), self.builder) + then_block = self.builder.get_insertion_block() + + else_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(else_block) + # do not need to reset lscope since + # ternary expressions cannot define new variables + else_val = language.core._to_tensor(self.visit(node.orelse), self.builder) + else_block = self.builder.get_insertion_block() + + self._set_insertion_point_and_loc(ip, last_loc) + + assert then_val.type == else_val.type, \ + f'ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}' + ret_type = then_val.type + + ret_type_ir = [ret_type.to_ir(self.builder)] if ret_type != language.void else [] + if_op = self.builder.create_if_op(ret_type_ir, cond.handle, True) + then_block.merge_block_before(if_op.get_then_block()) + if ret_type_ir: + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + self.builder.create_yield_op([then_val.handle]) + + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + else_block.merge_block_before(if_op.get_else_block()) + if ret_type_ir: + self.builder.set_insertion_point_to_end(if_op.get_else_block()) + self.builder.create_yield_op([else_val.handle]) + return language.core.tensor(if_op.get_result(0), ret_type) if ret_type_ir else None + else: + cond = _unwrap_if_constexpr(cond) + + # not isinstance - we insist the real thing, no subclasses and no ducks + if type(cond) not in _condition_types: + raise self._unsupported( + node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( + ', '.join(_.__name__ for _ in _condition_types), + type(cond).__name__)) + if cond: + return self.visit(node.body) + else: + return self.visit(node.orelse) + + def visit_Pass(self, node): + pass + + def visit_Compare(self, node): + if not (len(node.comparators) == 1 and len(node.ops) == 1): + raise self._unsupported(node, "simultaneous multiple comparison is not supported") + lhs = self.visit(node.left) + rhs = self.visit(node.comparators[0]) + lhs_value = _unwrap_if_constexpr(lhs) + rhs_value = _unwrap_if_constexpr(rhs) + if type(node.ops[0]) == ast.Is: + return constexpr(lhs_value is rhs_value) + if type(node.ops[0]) == ast.IsNot: + return constexpr(lhs_value is not rhs_value) + method_name = self._method_name_for_comp_op.get(type(node.ops[0])) + if method_name is None: + raise self._unsupported( + node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__)) + return self._apply_binary_method(method_name, lhs, rhs) + + _method_name_for_comp_op: Dict[Type[ast.cmpop], str] = { + ast.Eq: '__eq__', ast.NotEq: '__ne__', ast.Lt: '__lt__', ast.LtE: '__le__', ast.Gt: '__gt__', ast.GtE: '__ge__' + } + + def visit_UnaryOp(self, node): + operand = self.visit(node.operand) + fn = self._method_name_for_unary_op.get(type(node.op)) + if fn is None: + raise self._unsupported(node, f"AST unary operator '{node.op.__name__}' is not (currently) implemented.") + if _is_triton_tensor(operand): + return getattr(operand, fn)(_builder=self.builder) + try: + return getattr(operand, fn)() + except AttributeError: + raise self._unsupported( + node, f"AST unary operator '{fn}' is not (currently) implemented on type {type(operand).__name__}") + + _method_name_for_unary_op: Dict[Type[ast.unaryop], str] = { + ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__' + } + + def visit_While(self, node): + with enter_sub_region(self) as sr: + liveins, insert_block = sr + ip, last_loc = self._get_insertion_point_and_loc() + + # loop body (the after region) + # loop_block = self.builder.create_block() + dummy = self.builder.create_block() + self.builder.set_insertion_point_to_start(dummy) + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + loop_defs = self.local_defs + dummy.erase() + + # collect loop-carried values + names = [] + ret_types = [] + init_args = [] + for name in loop_defs: + if name in liveins: + # We should not def new constexpr + assert _is_triton_tensor(loop_defs[name]), f'cannot reassign constxpr {name} in the loop' + assert _is_triton_tensor(liveins[name]), f'cannot reasign constexpr {name} in the loop' + assert loop_defs[name].type == liveins[name].type, \ + f'Loop-carried variable {name} has initial type {liveins[name].type} '\ + f'but is re-assigned to {loop_defs[name].type} in loop! '\ + f'Please make sure that the type stays consistent.' + + # these are loop-carried values + names.append(name) + ret_types.append(loop_defs[name].type) + init_args.append(liveins[name]) + + self._set_insertion_point_and_loc(ip, last_loc) + while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types], + [arg.handle for arg in init_args]) + # merge the condition region + before_block = self.builder.create_block_with_parent(while_op.get_before(), + [ty.to_ir(self.builder) for ty in ret_types]) + self.builder.set_insertion_point_to_start(before_block) + for i, name in enumerate(names): + self.lscope[name] = language.core.tensor(before_block.arg(i), ret_types[i]) + self.local_defs[name] = self.lscope[name] + cond = self.visit(node.test) + self.builder.set_insertion_point_to_end(before_block) + # create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ... + self.builder.create_condition_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))]) + # merge the loop body + after_block = self.builder.create_block_with_parent(while_op.get_after(), + [ty.to_ir(self.builder) for ty in ret_types]) + + # generate loop body + self.builder.set_insertion_point_to_start(after_block) + for i, name in enumerate(names): + self.lscope[name] = language.core.tensor(after_block.arg(i), ret_types[i]) + self.local_defs[name] = self.lscope[name] + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + loop_defs = self.local_defs + yields = [] + for name in loop_defs: + if name in liveins: + yields.append(loop_defs[name]) + self.builder.create_yield_op([y.handle for y in yields]) + + # WhileOp defines new values, update the symbol table (lscope, local_defs) + for i, name in enumerate(names): + new_def = language.core.tensor(while_op.get_result(i), ret_types[i]) + self.lscope[name] = new_def + self.local_defs[name] = new_def + + for stmt in node.orelse: + assert False, "Not implemented" + ast.NodeVisitor.generic_visit(self, stmt) + + def visit_Subscript(self, node): + assert node.ctx.__class__.__name__ == "Load" + lhs = self.visit(node.value) + slices = self.visit(node.slice) + if _is_triton_tensor(lhs): + return lhs.__getitem__(slices, _builder=self.builder) + return lhs[slices] + + def visit_ExtSlice(self, node): + return [self.visit(dim) for dim in node.dims] + + def visit_For(self, node): + IteratorClass = self.visit(node.iter.func) + iter_args = [self.visit(arg) for arg in node.iter.args] + iter_kwargs = dict(self.visit(keyword) for keyword in node.iter.keywords) + if IteratorClass == language.static_range: + iterator = IteratorClass(*iter_args, **iter_kwargs) + static_range = range(iterator.start.value, iterator.end.value, iterator.step.value) + for i in static_range: + self.lscope[node.target.id] = constexpr(i) + self.visit_compound_statement(node.body) + for stmt in node.orelse: + ast.NodeVisitor.generic_visit(self, stmt) + return + num_stages = None + if IteratorClass is language.range: + iterator = IteratorClass(*iter_args, **iter_kwargs) + # visit iterator arguments + # note: only `range` iterator is supported now + # collect lower bound (lb), upper bound (ub), and step + lb = iterator.start + ub = iterator.end + step = iterator.step + num_stages = iterator.num_stages + elif IteratorClass is range: + # visit iterator arguments + # note: only `range` iterator is supported now + # collect lower bound (lb), upper bound (ub), and step + lb = iter_args[0] if len(iter_args) > 1 else self.visit(ast.Num(0)) + ub = iter_args[1] if len(iter_args) > 1 else self.visit(node.iter.args[0]) + step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1)) + else: + raise RuntimeError('Only `range` and `static_range` iterators are currently supported') + # handle negative constant step (not supported by scf.for in MLIR) + negative_step = False + if _is_constexpr(step) and step.value < 0: + step = constexpr(-step.value) + negative_step = True + lb, ub = ub, lb + lb = language.core._to_tensor(lb, self.builder) + ub = language.core._to_tensor(ub, self.builder) + step = language.core._to_tensor(step, self.builder) + # induction variable type + if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int(): + raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})") + iv_type = language.semantic.integer_promote_impl(lb.dtype, ub.dtype) + iv_type = language.semantic.integer_promote_impl(iv_type, step.dtype) + iv_ir_type = iv_type.to_ir(self.builder) + iv_is_signed = iv_type.int_signedness == language.core.dtype.SIGNEDNESS.SIGNED + # lb/ub/step might be constexpr, we need to cast them to tensor + lb = lb.handle + ub = ub.handle + step = step.handle + # ForOp can only accept IndexType as lb/ub/step. Cast integer to Index + lb = self.builder.create_int_cast(lb, iv_ir_type, iv_is_signed) + ub = self.builder.create_int_cast(ub, iv_ir_type, iv_is_signed) + step = self.builder.create_int_cast(step, iv_ir_type, iv_is_signed) + # Create placeholder for the loop induction variable + iv = self.builder.create_undef(iv_ir_type) + self.set_value(node.target.id, language.core.tensor(iv, iv_type)) + + with enter_sub_region(self) as sr: + liveins, insert_block = sr + ip, last_loc = self._get_insertion_point_and_loc() + + # create loop body block + block = self.builder.create_block() + self.builder.set_insertion_point_to_start(block) + # dry visit loop body + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + block.erase() + + # If a variable (name) is defined in both its parent & itself, then it's + # a loop-carried variable. (They must be of the same type) + init_args = [] + yields = [] + names = [] + for name in self.local_defs: + if name in liveins: + assert _is_triton_tensor(self.local_defs[name]), f'{name} is not tensor' + assert _is_triton_tensor(liveins[name]) + assert self.local_defs[name].type == liveins[name].type, \ + f'Loop-carried variable {name} has initial type {liveins[name].type} '\ + f'but is re-assigned to {self.local_defs[name].type} in loop! '\ + f'Please make sure that the type stays consistent.' + + names.append(name) + init_args.append(language.core._to_tensor(liveins[name], self.builder)) + yields.append(language.core._to_tensor(self.local_defs[name], self.builder)) + + # create ForOp + self._set_insertion_point_and_loc(ip, last_loc) + for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args]) + if num_stages is not None: + for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages)) + + self.scf_stack.append(node) + self.builder.set_insertion_point_to_start(for_op.get_body(0)) + # reset local scope to not pick up local defs from the previous dry run. + self.lscope = liveins.copy() + self.local_defs = {} + for i, name in enumerate(names): + self.set_value(name, language.core.tensor(for_op.get_body(0).arg(i + 1), yields[i].type)) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + yields = [] + for name in self.local_defs: + if name in liveins: + yields.append(language.core._to_tensor(self.local_defs[name], self.builder)) + + # create YieldOp + if len(yields) > 0: + self.builder.create_yield_op([y.handle for y in yields]) + for_op_region = for_op.get_body(0).get_parent() + assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block" + + # update induction variable with actual value, and replace all uses + self.builder.set_insertion_point_to_start(for_op.get_body(0)) + iv = for_op.get_induction_var() + if negative_step: + iv = self.builder.create_sub(ub, iv) + iv = self.builder.create_add(iv, lb) + self.lscope[node.target.id].handle.replace_all_uses_with(iv) + self.set_value(node.target.id, language.core.tensor(iv, iv_type)) + + # update lscope & local_defs (ForOp defines new values) + for i, name in enumerate(names): + self.set_value(name, language.core.tensor(for_op.get_result(i), yields[i].type)) + + for stmt in node.orelse: + assert False, "Don't know what to do with else after for" + ast.NodeVisitor.generic_visit(self, stmt) + + def visit_Slice(self, node): + lower = self.visit(node.lower) + upper = self.visit(node.upper) + step = self.visit(node.step) + return slice(lower, upper, step) + + def visit_Index(self, node): + return self.visit(node.value) + + def visit_keyword(self, node) -> Tuple[str, Any]: + return node.arg, self.visit(node.value) + + def visit_Assert(self, node) -> Any: + if not self.debug: + return + test = self.visit(node.test) + msg = self.visit(node.msg) if node.msg is not None else "" + # Convert assert to triton's device_assert which happens on the device + return language.core.device_assert(test, msg, _builder=self.builder) + + def call_JitFunction(self, fn: JITFunction, args, kwargs): + args = inspect.getcallargs(fn.fn, *args, **kwargs) + args = [args[name] for name in fn.arg_names] + args = [arg if _is_triton_tensor(arg) else constexpr(arg) for arg in args] + # generate function def + attributes = dict() + constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)] + constants = {i: args[i] for i in constexprs} + # generate call + args = [None if i in constexprs else arg for i, arg in enumerate(args)] + arg_vals = [arg.handle for arg in args if arg is not None] + arg_types = [arg.type for arg in args if arg is not None] + fn_name = mangle_fn(fn.__name__, arg_types, constants) + # generate function def if necessary + if not self.module.has_function(fn_name): + prototype = language.function_type([], arg_types) + gscope = fn.__globals__ + # If the callee is not set, we use the same debug setting as the caller + file_name, begin_line = _get_fn_file_line(fn) + debug = self.debug if fn.debug is None else fn.debug + generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module, + jit_fn=fn, function_name=fn_name, function_types=self.function_ret_types, + noinline=fn.noinline, file_name=file_name, begin_line=begin_line, + options=self.builder.options, codegen_fns=self.builder.codegen_fns, debug=debug) + try: + generator.visit(fn.parse()) + except Exception as e: + # Wrap the error in the callee with the location of the call. + raise CompilationError(self.jit_fn.src, self.cur_node, None) from e + + callee_ret_type = generator.ret_type + self.function_ret_types[fn_name] = callee_ret_type + else: + callee_ret_type = self.function_ret_types[fn_name] + symbol = self.module.get_function(fn_name) + call_op = self.builder.call(symbol, arg_vals) + if call_op.get_num_results() == 0 or callee_ret_type is None: + return None + elif call_op.get_num_results() == 1: + return tensor(call_op.get_result(0), callee_ret_type) + else: + # should return a tuple of tl.tensor + results = [] + for i in range(call_op.get_num_results()): + results.append(tensor(call_op.get_result(i), callee_ret_type[i])) + return tuple(results) + + def visit_Call(self, node): + fn = _unwrap_if_constexpr(self.visit(node.func)) + static_implementation = self.statically_implemented_functions.get(fn) + if static_implementation is not None: + return static_implementation(self, node) + + kws = dict(self.visit(keyword) for keyword in node.keywords) + args = [self.visit(arg) for arg in node.args] + if fn is language.core.device_assert: # TODO: this should not be so hardcoded + if not self.debug: + return + if isinstance(fn, JITFunction): + _check_fn_args(node, fn, args) + return self.call_JitFunction(fn, args, kws) + if (hasattr(fn, '__self__') and _is_triton_tensor(fn.__self__)) or language.core.is_builtin(fn): + extra_kwargs = dict(_builder=self.builder) + sig = inspect.signature(fn) + if '_generator' in sig.parameters: + extra_kwargs['_generator'] = self + try: + return fn(*args, **extra_kwargs, **kws) + except Exception as e: + # Normally when we raise a CompilationError, we raise it as + # `from None`, because the original fileline from the exception + # is not relevant (and often points into code_generator.py + # itself). But when calling a function, we raise as `from e` to + # preserve the traceback of the original error, which may e.g. + # be in core.py. + raise CompilationError(self.jit_fn.src, node, None) from e + + if fn in self.builtin_namespace.values(): + args = map(_unwrap_if_constexpr, args) + return fn(*args, **kws) + + def visit_Constant(self, node): + return constexpr(node.value) + + def visit_BoolOp(self, node: ast.BoolOp): + if len(node.values) != 2: + raise self._unsupported( + node, "chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.") + lhs = self.visit(node.values[0]) + rhs = self.visit(node.values[1]) + method_name = self._method_name_for_bool_op.get(type(node.op)) + if method_name is None: + raise self._unsupported( + node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__)) + return self._apply_binary_method(method_name, lhs, rhs) + + _method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'} + + if sys.version_info < (3, 8): + + def visit_NameConstant(self, node): + return constexpr(node.value) + + def visit_Num(self, node): + return constexpr(node.n) + + def visit_Str(self, node): + return constexpr(ast.literal_eval(node)) + + def visit_Attribute(self, node): + lhs = self.visit(node.value) + if _is_triton_tensor(lhs): + if node.attr == "T": + return language.semantic.permute(lhs, (1, 0), builder=self.builder) + return getattr(lhs, node.attr) + + def visit_Expr(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_NoneType(self, node): + return None + + def visit_JoinedStr(self, node): + values = list(node.values) + for i, value in enumerate(values): + if isinstance(value, ast.Constant): + values[i] = str(value.value) + elif isinstance(value, ast.FormattedValue): + conversion_code = value.conversion + evaluated = self.visit(value.value) + if not _is_constexpr(evaluated): + raise self._unsupported( + node, + "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type " + + str(type(evaluated))) + values[i] = ("{}" if conversion_code < 0 else "{!" + chr(conversion_code) + "}").format(evaluated.value) + else: + raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(value))) + return ''.join(values) + + def visit(self, node): + if node is None: + return + with warnings.catch_warnings(): + # The ast library added visit_Constant and deprecated some other + # methods but we can't move to that without breaking Python 3.6 and 3.7. + warnings.simplefilter("ignore", DeprecationWarning) # python 3.9 + warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8 + last_node = self.cur_node + last_loc = self.builder.get_loc() + self.cur_node = node + if hasattr(node, 'lineno') and hasattr(node, 'col_offset'): + self.builder.set_loc(self.file_name, self.begin_line + node.lineno, node.col_offset) + last_loc = self.builder.get_loc() + try: + ret = super().visit(node) + except CompilationError: + raise + except Exception as e: + # Wrap the error in a CompilationError which contains the source + # of the @jit function. + raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from None + + # Reset the location to the last one before the visit + if last_loc: + self.cur_node = last_node + self.builder.set_loc(last_loc) + return ret + + def generic_visit(self, node): + raise self._unsupported(node, "unsupported AST node type: {}".format(type(node).__name__)) + + def execute_static_assert(self, node: ast.Call) -> None: + arg_count = len(node.args) + if not (0 < arg_count <= 2) or len(node.keywords): + raise TypeError("`static_assert` requires one or two positional arguments only") + + passed = _unwrap_if_constexpr(self.visit(node.args[0])) + if not isinstance(passed, bool): + raise NotImplementedError( + "Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values" + ) + if not passed: + if arg_count == 1: + message = "" + else: + try: + message = self.visit(node.args[1]) + except Exception as e: + message = "" + + raise CompileTimeAssertionFailure(self.jit_fn.src, node, _unwrap_if_constexpr(message)) + return None + + def static_executor(python_fn): + + def ret(self, node: ast.Call): + kws = { + name: _unwrap_if_constexpr(value) + for name, value in (self.visit(keyword) for keyword in node.keywords) + } + args = [_unwrap_if_constexpr(self.visit(arg)) for arg in node.args] + return constexpr(python_fn(*args, **kws)) + + return ret + + statically_implemented_functions: Dict[object, Callable[[ast.Call], Any]] = { + language.core.static_assert: execute_static_assert, + language.core.static_print: static_executor(print), + int: static_executor(int), + len: static_executor(len), + } + + +def kernel_suffix(signature, specialization): + # suffix format: + # <'c' if equal to 1><'d' if divisible by 16><'e' if divisible by 8> + suffix = '' + for i, _ in enumerate(signature): + suffix += str(i) + if i in specialization.equal_to_1: + suffix += 'c' + if i in specialization.divisible_by_16: + suffix += 'd' + return suffix + + +def ast_to_ttir(fn, specialization, context, options, codegen_fns): + attrs = specialization.attrs + # create kernel prototype + cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i + constants = {cst_key(key): value for key, value in specialization.constants.items()} + # visit kernel AST + gscope = fn.__globals__.copy() + function_name = fn.repr(specialization) + tys = list(specialization.signature.values()) + new_constants = {k: True if k in tys and tys[k] == "i1" else 1 for k in attrs.equal_to_1} + new_attrs = {k: [("tt.divisibility", 16)] for k in attrs.divisible_by_16} + + all_constants = constants.copy() + all_constants.update(new_constants) + arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in specialization.constants] + file_name, begin_line = _get_fn_file_line(fn) + + prototype = language.function_type([], arg_types) + generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, + jit_fn=fn, attributes=new_attrs, is_kernel=True, file_name=file_name, + begin_line=begin_line, options=options, codegen_fns=codegen_fns) + generator.visit(fn.parse()) + + ret = generator.module + # module takes ownership of the context + ret.context = context + return ret diff --git a/lib/python3.10/site-packages/triton/compiler/compiler.py b/lib/python3.10/site-packages/triton/compiler/compiler.py new file mode 100644 index 0000000000000000000000000000000000000000..367aa1b1a325ef1d3a352c0d33f513aeb4388d31 --- /dev/null +++ b/lib/python3.10/site-packages/triton/compiler/compiler.py @@ -0,0 +1,412 @@ +from __future__ import annotations +import hashlib +import json +from .._C.libtriton import get_cache_invalidating_env_vars, ir +from ..backends import backends +from ..backends.compiler import GPUTarget +from .. import __version__ +from ..runtime.autotuner import OutOfResources +from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager +from ..runtime.driver import driver +# TODO: this shouldn't be here +from dataclasses import dataclass +from .code_generator import ast_to_ttir +from pathlib import Path +import re +import functools +import os + + +@dataclass +class AttrsDescriptor: + divisible_by_16: set = None + equal_to_1: set = None + + def __post_init__(self): + if self.divisible_by_16 is None: + self.divisible_by_16 = set() + if self.equal_to_1 is None: + self.equal_to_1 = set() + + def to_dict(self): + return {'divisible_by_16': list(self.divisible_by_16), 'equal_to_1': list(self.equal_to_1)} + + @staticmethod + def from_dict(data): + return AttrsDescriptor(divisible_by_16=set(data.get('divisible_by_16', [])), + equal_to_1=set(data.get('equal_to_1', []))) + + def hash(self): + key = str([sorted(x) for x in self.__dict__.values()]) + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + +# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func, +# and any following whitespace +# - (public\s+)? : optionally match the keyword public and any following whitespace +# - (@\w+) : match an @ symbol followed by one or more word characters +# (letters, digits, or underscores), and capture it as group 1 (the function name) +# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing +# zero or more arguments separated by commas, and capture it as group 2 (the argument list) +# - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3 +mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$" +ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)" +prototype_pattern = { + "ttir": mlir_prototype_pattern, + "ttgir": mlir_prototype_pattern, + "ptx": ptx_prototype_pattern, +} + +mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<)]+|<[^>]+>)+),?' +ptx_arg_type_pattern = r"\.param\s+\.(\w+)" +arg_type_pattern = { + "ttir": mlir_arg_type_pattern, + "ttgir": mlir_arg_type_pattern, + "ptx": ptx_arg_type_pattern, +} + + +def convert_type_repr(x): + # Currently we only capture the pointer type and assume the pointer is on global memory. + # TODO: Capture and support shared memory space + match = re.search(r'!tt\.ptr<([^,]+)', x) + if match is not None: + return '*' + convert_type_repr(match.group(1)) + return x + + +def _get_num_warps_from_ir_str(src: str): + ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:' + # TODO(jlebar): Using a regex to get num-warps is a hack, and will break if + # e.g. someone has an instruction (not module) attribute named "num-warps". + num_warps_matches = re.findall(ttgir_num_warps_pattern, src) + assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps" + num_warps = int(num_warps_matches[0]) + return num_warps + + +class ASTSource: + + def __init__(self, fn, signature, constants=None, attrs=None) -> None: + self.fn = fn + self.ext = "ttir" + self.name = fn.__name__ + self.signature = signature + self.constants = constants + self.attrs = attrs + if isinstance(self.signature, str): + self.signature = {k: v.strip() for k, v in enumerate(self.signature.split(","))} + if self.constants is None: + self.constants = dict() + if self.attrs is None: + self.attrs = AttrsDescriptor() + + def hash(self): + sorted_sig = [v for k, v in sorted(self.signature.items())] + # Note - we stringify the keys here to allow sorting to work for cases + # where constants have mixed int/str keys. + sorted_constants = sorted((str(k), v) for k, v in self.constants.items()) + key = f"{self.fn.cache_key}-{self.attrs.hash()}-{sorted_sig}-{sorted_constants}" + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + def make_ir(self, options, codegen_fns, context): + return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns) + + def parse_options(self): + return dict() + + +class IRSource: + + def __init__(self, path): + self.path = path + path = Path(path) + self.ext = path.suffix[1:] + self.src = path.read_text() + match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE) + self.name = match.group(1) + signature = match.group(2) + types = re.findall(arg_type_pattern[self.ext], signature) + self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)} + + def hash(self): + return hashlib.sha256(self.src.encode("utf-8")).hexdigest() + + def make_ir(self, options, codegen_fns, context): + module = ir.parse_mlir_module(self.path, context) + module.context = context + return module + + def parse_options(self): + if self.ext == "ttgir": + return {'num_warps': _get_num_warps_from_ir_str(self.src)} + return dict() + + +@functools.lru_cache() +def triton_key(): + import pkgutil + TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + contents = [] + # frontend + with open(__file__, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + # compiler + path_prefixes = [ + (os.path.join(TRITON_PATH, "compiler"), "triton.compiler."), + (os.path.join(TRITON_PATH, "backends"), "triton.backends."), + ] + for path, prefix in path_prefixes: + for lib in pkgutil.walk_packages([path], prefix=prefix): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + + # backend + libtriton_hash = hashlib.sha256() + with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f: + while True: + chunk = f.read(1024**2) + if not chunk: + break + libtriton_hash.update(chunk) + contents.append(libtriton_hash.hexdigest()) + # language + language_path = os.path.join(TRITON_PATH, 'language') + for lib in pkgutil.iter_modules([language_path]): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + return f'{__version__}' + '-'.join(contents) + + +def parse(full_name, ext, context): + if ext == "ttir" or ext == "ttgir": + module = ir.parse_mlir_module(full_name, context) + module.context = context + return module + if ext == "llir" or ext == "ptx": + return Path(full_name).read_text() + if ext == "cubin": + return Path(full_name).read_bytes() + + +def filter_traceback(e: BaseException): + """ + Removes code_generator.py and related files from tracebacks. + + These are uninteresting to the user -- "just show me *my* code!" + """ + if e.__cause__ is not None: + filter_traceback(e.__cause__) + if e.__context__ is not None: + filter_traceback(e.__context__) + + # If a user has a file that matches one of these, they're out of luck. + BAD_FILES = [ + "/triton/compiler/code_generator.py", + "/ast.py", + ] + + tb = e.__traceback__ + frames = [] + while tb is not None: + if not any(f for f in BAD_FILES if tb.tb_frame.f_code.co_filename.endswith(f)): + frames.append(tb) + tb = tb.tb_next + + for (cur_frame, next_frame) in zip(frames, frames[1:]): + cur_frame.tb_next = next_frame + + if not frames: + e.__traceback__ = None + else: + frames[-1].tb_next = None + e.__traceback__ = frames[0] + + +def compile(src, target=None, options=None): + if target is None: + target = driver.active.get_current_target() + assert isinstance(target, GPUTarget), "target must be of GPUTarget type" + backend = make_backend(target) + ir_source = not isinstance(src, ASTSource) + # create backend + if ir_source: + assert isinstance(src, str), "source must be either AST or a filepath" + src = IRSource(src) + extra_options = src.parse_options() + options = backend.parse_options(dict(options or dict(), **extra_options)) + # create cache manager + env_vars = get_cache_invalidating_env_vars() + key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{options.hash()}-{str(sorted(env_vars.items()))}" + hash = hashlib.sha256(key.encode("utf-8")).hexdigest() + fn_cache_manager = get_cache_manager(hash) + # For dumping/overriding only hash the source as we want it to be independent of triton + # core changes to make it easier to track kernels by hash. + enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1" + enable_ir_dump = os.environ.get("TRITON_KERNEL_DUMP", "0") == "1" + fn_override_manager = get_override_manager(src.hash()) if enable_override else None + fn_dump_manager = get_dump_manager(src.hash()) if enable_ir_dump else None + metadata_filename = f"{src.name}.json" + metadata_group = fn_cache_manager.get_group(metadata_filename) or {} + metadata_path = metadata_group.get(metadata_filename) + always_compile = os.environ.get("TRITON_ALWAYS_COMPILE", "0") == "1" + if not always_compile and metadata_path is not None: + # cache hit! + metadata = json.loads(Path(metadata_path).read_text()) + return CompiledKernel(src, metadata_group, hash) + # initialize metadata + metadata = { + "hash": hash, + "target": target, + **options.__dict__, + **env_vars, + } + # run compilation pipeline and populate metadata + stages = dict() + backend.add_stages(stages, options) + first_stage = list(stages.keys()).index(src.ext) + # when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests. + if ir_source: + first_stage += 1 + context = ir.context() + ir.load_dialects(context) + backend.load_dialects(context) + codegen_fns = backend.get_codegen_implementation() + try: + module = src.make_ir(options, codegen_fns, context) + except Exception as e: + filter_traceback(e) + raise + use_ttgir_loc = os.environ.get("USE_TTGIR_LOC", "0") == "1" + for ext, compile_ir in list(stages.items())[first_stage:]: + next_module = compile_ir(module, metadata) + ir_filename = f"{src.name}.{ext}" + metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename) + if fn_dump_manager is not None: + fn_dump_manager.put(next_module, ir_filename) + if (fn_override_manager is not None and fn_override_manager.has_file(ir_filename)): + print(f"\nOverriding kernel with file {ir_filename}") + full_name = fn_override_manager.get_file(ir_filename) + next_module = parse(full_name, ext, context) + # use an env variable to parse ttgir from file + if use_ttgir_loc and ext == "ttgir": + ttgir_full_name = fn_cache_manager.get_file(ir_filename) + next_module.create_location_snapshot(ttgir_full_name) + print(f"Create new locations for {ttgir_full_name}") + module = next_module + # write-back metadata + metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, + binary=False) + fn_cache_manager.put_group(metadata_filename, metadata_group) + # return handle to compiled kernel + return CompiledKernel(src, metadata_group, hash) + + +def make_backend(target): + actives = [x.compiler for x in backends.values() if x.compiler.supports_target(target)] + if len(actives) != 1: + raise RuntimeError( + f"{len(actives)} compatible backends for target ({target.backend}) ({actives}). There should only be one.") + return actives[0](target) + + +class LazyDict: + + def __init__(self, data): + self.data = data + self.extras = [] + + def get(self) -> None: + for func, args in self.extras: + self.data = self.data | func(*args) + self.extras.clear() + return self.data + + def add(self, func, args): + self.extras.append((func, args)) + + +class CompiledKernel: + + # Hooks for external tools to monitor the execution of triton kernels + # TODO: move out of this namespace since it's a runtime thing + launch_enter_hook = None + launch_exit_hook = None + + def __init__(self, src, metadata_group, hash): + from collections import namedtuple + metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json"))) + metadata = json.loads(metadata_path.read_text()) + metadata['cluster_dims'] = tuple(metadata['cluster_dims']) + # JSON serialization dumps the target as a dict. Restore it to a GPUTarget. + target = metadata['target'] + metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size']) + KernelMetadata = namedtuple('KernelMetadata', sorted(list(metadata.keys()))) + self.metadata = KernelMetadata(**metadata) + backend = make_backend(self.metadata.target) + self.packed_metadata = backend.pack_metadata(self.metadata) + self.src = src + self.hash = hash + self.name = self.metadata.name + # stores the text of each level of IR that was generated during compilation + asm_files = [Path(p) for c, p in metadata_group.items() if not c.endswith(".json")] + binary_ext = backend.binary_ext + self.asm = { + file.suffix[1:]: file.read_bytes() if file.suffix[1:] == binary_ext else file.read_text() + for file in asm_files + } + self.kernel = self.asm[binary_ext] + # binaries are lazily initialized + # because it involves doing runtime things + # (e.g., checking amount of shared memory on current device) + self.module = None + self.function = None + + def _init_handles(self): + if self.module is not None: + return + device = driver.active.get_current_device() + # create launcher + self.run = driver.active.launcher_cls(self.src, self.metadata) + # not enough shared memory to run the kernel + max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"] + if self.metadata.shared > max_shared: + raise OutOfResources(self.metadata.shared, max_shared, "shared memory") + # TODO: n_regs, n_spills should be metadata generated when calling `ptxas` + self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary( + self.name, self.kernel, self.metadata.shared, device) + + def __getattribute__(self, name): + if name == 'run': + self._init_handles() + return super().__getattribute__(name) + + def launch_metadata(self, grid, stream, *args): + if CompiledKernel.launch_enter_hook is None: + return None + ret = LazyDict({"name": self.name, "function": self.function, "stream": stream}) + if not isinstance(self.src, ASTSource) or self.src.fn.launch_metadata is None: + return ret + arg_dict = {} + arg_idx = 0 + for i, arg_name in enumerate(self.src.fn.arg_names): + if i in self.src.fn.constexprs: + arg_dict[arg_name] = self.src.constants[arg_name] + else: + arg_dict[arg_name] = args[arg_idx] + arg_idx += 1 + ret.add(self.src.fn.launch_metadata, (grid, self.metadata, arg_dict)) + return ret + + def __getitem__(self, grid): + self._init_handles() + + def runner(*args, stream=None): + if stream is None: + device = driver.active.get_current_device() + stream = driver.active.get_current_stream(device) + launch_metadata = self.launch_metadata(grid, stream, *args) + self.run(grid[0], grid[1], grid[2], stream, self.function, self.packed_metadata, launch_metadata, + CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, *args) + + return runner diff --git a/lib/python3.10/site-packages/triton/compiler/errors.py b/lib/python3.10/site-packages/triton/compiler/errors.py new file mode 100644 index 0000000000000000000000000000000000000000..39e6c4dfb04dd2067d50ce7c79f762c5e7e2d5b8 --- /dev/null +++ b/lib/python3.10/site-packages/triton/compiler/errors.py @@ -0,0 +1,51 @@ +import ast +from typing import Optional +from ..errors import TritonError + + +class CompilationError(TritonError): + """Base class for all errors raised during compilation""" + source_line_count_max_in_message = 12 + + def _format_message(self) -> str: + node = self.node + if self.src is None: + source_excerpt = " " + else: + if hasattr(node, 'lineno'): + source_excerpt = self.src.split('\n')[:node.lineno][-self.source_line_count_max_in_message:] + if source_excerpt: + source_excerpt.append(' ' * node.col_offset + '^') + source_excerpt = '\n'.join(source_excerpt) + else: + source_excerpt = " " + else: + source_excerpt = self.src + + message = "at {}:{}:\n{}".format(node.lineno, node.col_offset, source_excerpt) if hasattr( + node, 'lineno') else source_excerpt + if self.error_message: + message += '\n' + self.error_message + return message + + def __init__(self, src: Optional[str], node: ast.AST, error_message: Optional[str] = None): + self.src = src + self.node = node + self.error_message = error_message + self.message = self._format_message() + + def __str__(self): + return self.message + + def __reduce__(self): + # this is necessary to make CompilationError picklable + return type(self), (self.src, self.node, self.error_message) + + +class CompileTimeAssertionFailure(CompilationError): + """Specific exception for failed tests in `static_assert` invocations""" + pass + + +class UnsupportedLanguageConstruct(CompilationError): + pass diff --git a/lib/python3.10/site-packages/triton/compiler/make_launcher.py b/lib/python3.10/site-packages/triton/compiler/make_launcher.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/triton/language/__init__.py b/lib/python3.10/site-packages/triton/language/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..168dccfea7edadf50bb8b6ac91fd85b37fb26859 --- /dev/null +++ b/lib/python3.10/site-packages/triton/language/__init__.py @@ -0,0 +1,284 @@ +"""isort:skip_file""" +# Import order is significant here. + +from . import math +from . import extra +from .standard import ( + argmax, + argmin, + cdiv, + cumprod, + cumsum, + flip, + interleave, + max, + min, + ravel, + sigmoid, + softmax, + sort, + sum, + swizzle2d, + xor_sum, + zeros, + zeros_like, +) +from .core import ( + PropagateNan, + TRITON_MAX_TENSOR_NUMEL, + _experimental_descriptor_load, + _experimental_descriptor_store, + advance, + arange, + associative_scan, + atomic_add, + atomic_and, + atomic_cas, + atomic_max, + atomic_min, + atomic_or, + atomic_xchg, + atomic_xor, + bfloat16, + block_type, + broadcast, + broadcast_to, + cat, + cast, + clamp, + const, + const_pointer_type, + constexpr, + debug_barrier, + device_assert, + device_print, + dot, + dtype, + expand_dims, + float16, + float32, + float64, + float8e4b15, + float8e4nv, + float8e4b8, + float8e5, + float8e5b16, + full, + function_type, + histogram, + inline_asm_elementwise, + int1, + int16, + int32, + int64, + int8, + join, + load, + make_block_ptr, + max_constancy, + max_contiguous, + maximum, + minimum, + multiple_of, + num_programs, + permute, + pi32_t, + pointer_type, + program_id, + range, + reduce, + reshape, + split, + static_assert, + static_print, + static_range, + store, + tensor, + trans, + uint16, + uint32, + uint64, + uint8, + view, + void, + where, +) +from .math import (umulhi, exp, exp2, fma, log, log2, cos, rsqrt, sin, sqrt, sqrt_rn, abs, fdiv, div_rn, erf, floor, + ceil) +from .random import ( + pair_uniform_to_normal, + philox, + philox_impl, + rand, + rand4x, + randint, + randint4x, + randn, + randn4x, + uint_to_uniform_float, +) + +__all__ = [ + "PropagateNan", + "TRITON_MAX_TENSOR_NUMEL", + "_experimental_descriptor_load", + "_experimental_descriptor_store", + "abs", + "advance", + "arange", + "argmax", + "argmin", + "associative_scan", + "atomic_add", + "atomic_and", + "atomic_cas", + "atomic_max", + "atomic_min", + "atomic_or", + "atomic_xchg", + "atomic_xor", + "bfloat16", + "block_type", + "broadcast", + "broadcast_to", + "builtin", + "cat", + "cast", + "cdiv", + "ceil", + "clamp", + "const", + "const_pointer_type", + "constexpr", + "cos", + "cumprod", + "cumsum", + "debug_barrier", + "device_assert", + "device_print", + "div_rn", + "dot", + "dtype", + "erf", + "exp", + "exp2", + "expand_dims", + "extra", + "fdiv", + "flip", + "float16", + "float32", + "float64", + "float8e4b15", + "float8e4nv", + "float8e4b8", + "float8e5", + "float8e5b16", + "floor", + "fma", + "full", + "function_type", + "histogram", + "inline_asm_elementwise", + "interleave", + "int1", + "int16", + "int32", + "int64", + "int8", + "ir", + "join", + "load", + "log", + "log2", + "make_block_ptr", + "math", + "max", + "max_constancy", + "max_contiguous", + "maximum", + "min", + "minimum", + "multiple_of", + "num_programs", + "pair_uniform_to_normal", + "permute", + "philox", + "philox_impl", + "pi32_t", + "pointer_type", + "program_id", + "rand", + "rand4x", + "randint", + "randint4x", + "randn", + "randn4x", + "range", + "ravel", + "reduce", + "reshape", + "rsqrt", + "sigmoid", + "sin", + "softmax", + "sort", + "split", + "sqrt", + "sqrt_rn", + "static_assert", + "static_print", + "static_range", + "store", + "sum", + "swizzle2d", + "tensor", + "trans", + "triton", + "uint16", + "uint32", + "uint64", + "uint8", + "uint_to_uniform_float", + "umulhi", + "view", + "void", + "where", + "xor_sum", + "zeros", + "zeros_like", +] + + +def str_to_ty(name): + if name[0] == "*": + name = name[1:] + if name[0] == "k": + name = name[1:] + ty = str_to_ty(name) + return const_pointer_type(ty) + ty = str_to_ty(name) + return pointer_type(ty) + tys = { + "fp8e4nv": float8e4nv, + "fp8e4b8": float8e4b8, + "fp8e5": float8e5, + "fp8e5b16": float8e5b16, + "fp8e4b15": float8e4b15, + "fp16": float16, + "bf16": bfloat16, + "fp32": float32, + "fp64": float64, + "i1": int1, + "i8": int8, + "i16": int16, + "i32": int32, + "i64": int64, + "u1": int1, + "u8": uint8, + "u16": uint16, + "u32": uint32, + "u64": uint64, + "B": int1, + } + return tys[name] diff --git a/lib/python3.10/site-packages/triton/language/core.py b/lib/python3.10/site-packages/triton/language/core.py new file mode 100644 index 0000000000000000000000000000000000000000..f2d3266e9a9f94b3071c37b20d28a559c5073f82 --- /dev/null +++ b/lib/python3.10/site-packages/triton/language/core.py @@ -0,0 +1,2621 @@ +from __future__ import annotations + +from warnings import warn +from contextlib import contextmanager +from enum import Enum +from functools import partial, wraps +import typing +from typing import Union, Callable, List, Sequence, TypeVar, Optional +import builtins +from ..runtime.jit import jit +import inspect +import os + +from .._C.libtriton import ir +from . import semantic + +T = TypeVar('T') + +TRITON_MAX_TENSOR_NUMEL = 1048576 + +TRITON_BUILTIN = "__triton_builtin__" + +PropagateNan = ir.PROPAGATE_NAN + + +def builtin(fn: T) -> T: + """Mark a function as a builtin.""" + assert callable(fn) + + @wraps(fn) + def wrapper(*args, **kwargs): + if "_builder" not in kwargs or kwargs["_builder"] is None: + raise ValueError("Did you forget to add @triton.jit ? " + "(`_builder` argument must be provided outside of JIT functions.)") + return fn(*args, **kwargs) + + setattr(wrapper, TRITON_BUILTIN, True) + + return wrapper + + +def _tensor_member_fn(fn: T) -> T: + """Decorator that adds this free function as a member fn on class tensor. + + When called as a member function on class tensor, the first argument to `fn` + is `self`, i.e. the tensor object. + + If there are multiple decorators on a function, you probably want this one + to be the highest one (i.e. furthest from the function's `def`), so it's + applied last. + + Unfortunately you still need to add a type stub to the body of class tensor + in order for pytype to know about it. + """ + assert callable(fn) + orig_sig = inspect.signature(fn) + # Does fn take args other than _builder, _generator, and the tensor itself? + has_args = len(orig_sig.parameters.keys() - {"_builder", "_generator"}) > 1 + + if not fn.__doc__: + fn.__doc__ = "" + fn.__doc__ += f""" + This function can also be called as a member function on :py:class:`tensor`, + as :code:`x.{fn.__name__}({"..." if has_args else ""})` instead of + :code:`{fn.__name__}(x{", ..." if has_args else ""})`. + """ + + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + # Match the signature of `fn`, but change the first arg to `self` so the + # docs are a little less weird. + new_params = list(orig_sig.parameters.values()) + new_params[0] = new_params[0].replace(name='self') + new_sig = orig_sig.replace(parameters=new_params) + wrapper.__signature__ = new_sig + wrapper.__doc__ = f"Forwards to :py:func:`{fn.__name__}` free function" + # If fn is a builtin, mark the wrapper as a builtin too. + if is_builtin(fn): + setattr(wrapper, TRITON_BUILTIN, True) + + setattr(tensor, fn.__name__, wrapper) + return fn + + +def _unwrap_iterable(x): + """Returns x[0] if x has one element and x[0] is iterable.""" + if len(x) == 1: + # Determine whether x[0] is iterable. + # + # You might want to use collections.abc.Iterable instead of this + # try/except block. Unfortunately, this doesn't work with constexpr. + # + # The problem is that abc.Iterable checks for __iter__ on the *class*. + # But we want constexpr to expose an __iter__ method if and only if the + # wrapped *object* (i.e. self.value) is iterable. Therefore there's no + # right answer for whether the class constexpr defines __iter__, and + # abc.Iterable doesn't work (at least not without some metaclass magic). + try: + iter(x[0]) + return x[0] + except TypeError: + pass + + return x + + +def is_builtin(fn) -> bool: + """Is this a registered triton builtin function?""" + return getattr(fn, TRITON_BUILTIN, False) + + +@builtin +def to_tensor(x, _builder=None): + return _to_tensor(x, _builder) + + +def _to_tensor(x, builder): + if isinstance(x, bool): + return tensor(builder.get_int1(x), int1) + # Note: compile-time const integers are represented by unsigned values + elif isinstance(x, int): + if -2**31 <= x < 2**31: + return tensor(builder.get_int32(x), int32) + elif 2**31 <= x < 2**32: + return tensor(builder.get_uint32(x), uint32) + elif -2**63 <= x < 2**63: + return tensor(builder.get_int64(x), int64) + elif 2**63 <= x < 2**64: + return tensor(builder.get_uint64(x), uint64) + else: + raise RuntimeError(f'Nonrepresentable integer {x}.') + elif isinstance(x, float): + min_float32 = 2**-126 + max_float32 = (2 - 2**-23) * 2**127 + abs_x = __builtins__['abs'](x) + if abs_x == float("inf") or\ + abs_x == 0.0 or \ + x != x or \ + min_float32 <= abs_x <= max_float32: + return tensor(builder.get_fp32(x), float32) + else: + return tensor(builder.get_fp64(x), float64) + + elif isinstance(x, constexpr): + return _to_tensor(x.value, builder) + elif isinstance(x, tensor): + return x + assert False, f"cannot convert {x} of type {type(x)} to tensor" + + +class dtype: + SINT_TYPES = ['int8', 'int16', 'int32', 'int64'] + UINT_TYPES = ['int1', 'uint8', 'uint16', 'uint32', 'uint64'] + FP_TYPES = ['fp8e4b15', 'fp8e4nv', 'fp8e4b8', 'fp8e5', 'fp8e5b16', 'fp16', 'bf16', 'fp32', 'fp64'] + STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64'] + OTHER_TYPES = ['void'] + + class SIGNEDNESS(Enum): + SIGNED = 0 + UNSIGNED = 1 + + def __init__(self, name): + if hasattr(name, 'value'): + name = name.value + self.name = name + assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name + if name in dtype.SINT_TYPES: + self.int_signedness = dtype.SIGNEDNESS.SIGNED + self.int_bitwidth = int(name.split('int')[-1]) + self.primitive_bitwidth = self.int_bitwidth + elif name in dtype.UINT_TYPES: + self.int_signedness = dtype.SIGNEDNESS.UNSIGNED + self.int_bitwidth = int(name.split('int')[-1]) + self.primitive_bitwidth = self.int_bitwidth + elif name in dtype.FP_TYPES: + if name == 'fp8e4b15': + self.fp_mantissa_width = 3 + self.primitive_bitwidth = 8 + self.exponent_bias = 15 + elif name == 'fp8e4nv': + self.fp_mantissa_width = 3 + self.primitive_bitwidth = 8 + self.exponent_bias = 7 + elif name == 'fp8e4b8': + self.fp_mantissa_width = 3 + self.primitive_bitwidth = 8 + self.exponent_bias = 8 + elif name == 'fp8e5': + self.fp_mantissa_width = 2 + self.primitive_bitwidth = 8 + self.exponent_bias = 15 + elif name == 'fp8e5b16': + self.fp_mantissa_width = 2 + self.primitive_bitwidth = 8 + self.exponent_bias = 16 + elif name == 'fp16': + self.fp_mantissa_width = 10 + self.primitive_bitwidth = 16 + self.exponent_bias = 15 + elif name == 'bf16': + self.fp_mantissa_width = 7 + self.primitive_bitwidth = 16 + self.exponent_bias = 127 + elif name == 'fp32': + self.fp_mantissa_width = 23 + self.primitive_bitwidth = 32 + self.exponent_bias = 127 + elif name == 'fp64': + self.fp_mantissa_width = 53 + self.primitive_bitwidth = 64 + self.exponent_bias = 1023 + else: + raise RuntimeError(f'Unsupported floating-point type {name}') + elif name == 'void': + self.primitive_bitwidth = 0 + + def is_fp8(self): + return 'fp8' in self.name + + def is_fp8e4nv(self): + return self.name == 'fp8e4nv' + + def is_fp8e4b8(self): + return self.name == 'fp8e4b8' + + def is_fp8e4b15(self): + return self.name == 'fp8e4b15' + + def is_fp8e5(self): + return self.name == 'fp8e5' + + def is_fp8e5b16(self): + return self.name == 'fp8e5b16' + + def is_fp16(self): + return self.name == 'fp16' + + def is_bf16(self): + return self.name == 'bf16' + + def is_fp32(self): + return self.name == 'fp32' + + def is_fp64(self): + return self.name == 'fp64' + + def is_int1(self): + return self.name == 'int1' + + def is_int8(self): + return self.name == 'int8' + + def is_int16(self): + return self.name == 'int16' + + def is_int32(self): + return self.name == 'int32' + + def is_int64(self): + return self.name == 'int64' + + def is_uint8(self): + return self.name == 'uint8' + + def is_uint16(self): + return self.name == 'uint16' + + def is_uint32(self): + return self.name == 'uint32' + + def is_uint64(self): + return self.name == 'uint64' + + def is_floating(self): + return self.name in dtype.FP_TYPES + + def is_standard_floating(self): + return self.name in dtype.STANDARD_FP_TYPES + + def is_int_signed(self): + return self.name in dtype.SINT_TYPES + + def is_int_unsigned(self): + return self.name in dtype.UINT_TYPES + + def is_int(self): + return self.name in dtype.SINT_TYPES + dtype.UINT_TYPES + + def is_bool(self): + return self.is_int1() + + @staticmethod + def is_dtype(type_str): + return type_str in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES + + @staticmethod + def is_void(): + raise RuntimeError("Not implemented") + + @staticmethod + def is_block(): + return False + + @staticmethod + def is_ptr(): + return False + + @staticmethod + def is_const(): + return False + + def __eq__(self, other: dtype): + if not isinstance(other, dtype): + return False + return self.name == other.name + + def __ne__(self, other: dtype): + return not self.__eq__(other) + + def __hash__(self): + return hash((self.name, )) + + @property + def scalar(self): + return self + + def to_ir(self, builder: ir.builder) -> ir.type: + if self.name == 'void': + return builder.get_void_ty() + elif self.name == 'int1': + return builder.get_int1_ty() + elif self.name in ('int8', 'uint8'): + return builder.get_int8_ty() + elif self.name in ('int16', 'uint16'): + return builder.get_int16_ty() + elif self.name in ('int32', 'uint32'): + return builder.get_int32_ty() + elif self.name in ('int64', 'uint64'): + return builder.get_int64_ty() + elif self.name == 'fp8e5': + return builder.get_fp8e5_ty() + elif self.name == 'fp8e5b16': + return builder.get_fp8e5b16_ty() + elif self.name == 'fp8e4nv': + return builder.get_fp8e4nv_ty() + elif self.name == 'fp8e4b8': + return builder.get_fp8e4b8_ty() + elif self.name == 'fp8e4b15': + return builder.get_fp8e4b15_ty() + elif self.name == 'fp16': + return builder.get_half_ty() + elif self.name == 'bf16': + return builder.get_bf16_ty() + elif self.name == 'fp32': + return builder.get_float_ty() + elif self.name == 'fp64': + return builder.get_double_ty() + raise ValueError(f'fail to convert {self} to ir type') + + def __str__(self): + return self.name + + def codegen_name(self): + if self.name.startswith("fp"): + return "float" + self.name[2:] + elif self.name.startswith("bf"): + return "bfloat" + self.name[2:] + else: + return self.name + + @property + def cache_key_part(self) -> str: + """See cache_key_part() in triton.cc.""" + return self.name + + def __repr__(self): + """Output of repr needs to be an evaluatable expression""" + return f'triton.language.{self.codegen_name()}' + + +# Some functions have a param named `dtype`, which shadows the `dtype` class. +# We can't change the param name because it is part of function's public API. +# Declare an alias so those functions can still reference the dtype class. +_DtypeClass = dtype + + +class pointer_type(dtype): + + def __init__(self, element_ty: dtype, address_space: int = 1): + if not isinstance(element_ty, dtype): + raise TypeError(f'element_ty is a {type(element_ty).__name__}.') + self.element_ty = element_ty + self.address_space = address_space + + self.name = f'pointer<{element_ty}>' + + def to_ir(self, builder: ir.builder) -> ir.pointer_type: + return builder.get_ptr_ty(self.element_ty.to_ir(builder), 1) + + def __str__(self): + return self.name + + def __repr__(self): + return self.__str__() + + def is_ptr(self): + return True + + def __eq__(self, other: pointer_type) -> bool: + if not isinstance(other, pointer_type): + return False + return self.element_ty == other.element_ty and self.address_space == other.address_space + + def __ne__(self, other: pointer_type) -> bool: + return not self.__eq__(other) + + @property + def scalar(self): + return self + + +class const_pointer_type(pointer_type): + + def __init__(self, element_ty: dtype, address_space: int = 1): + super().__init__(element_ty, address_space) + + def __str__(self): + return f'const_pointer<{self.element_ty}>' + + def is_const(self): + return True + + def __eq__(self, other) -> bool: + if not isinstance(other, const_pointer_type): + return False + return self.element_ty == other.element_ty and self.address_space == other.address_space + + +class block_type(dtype): + + def __init__(self, element_ty: dtype, shape: List): + self.element_ty = element_ty + + # Note that block_type's shape is a list of int + # while tensor's shape is a list of constexpr. + + # shape can be empty ([]) when an input is a 0D tensor. + if not shape: + raise TypeError('0d block_type is forbidden') + if isinstance(shape[0], constexpr): + shape = [s.value for s in shape] + + self.shape = shape + self.numel = 1 + for s in self.shape: + self.numel *= s + if self.numel > TRITON_MAX_TENSOR_NUMEL: + raise ValueError(f"numel ({self.numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})") + + self.name = f'<{self.shape}, {self.element_ty}>' + + def to_ir(self, builder: ir.builder) -> ir.block_type: + return builder.get_block_ty(self.element_ty.to_ir(builder), self.shape) + + def __str__(self): + return self.name + + def __repr__(self): + return self.__str__() + + def is_block(self): + return True + + def get_block_shapes(self) -> List[int]: + return self.shape + + def __eq__(self, other: block_type) -> bool: + if not isinstance(other, block_type): + return False + return self.element_ty == other.element_ty and self.shape == other.shape + + def __ne__(self, other: block_type) -> bool: + return not self.__eq__(other) + + @property + def scalar(self): + return self.element_ty + + +class function_type(dtype): + + def __init__(self, ret_types: List[dtype], param_types: List[dtype]) -> None: + self.ret_types = ret_types + self.param_types = param_types + + def __str__(self): + return f'fn ({self.param_types}) -> {self.ret_types}' + + def to_ir(self, builder: ir.builder): + ir_param_types = [ty.to_ir(builder) for ty in self.param_types] + ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types] + return builder.get_function_ty(ir_param_types, ret_types) + + +# scalar types +void = dtype('void') +int1 = dtype('int1') +int8 = dtype('int8') +int16 = dtype('int16') +int32 = dtype('int32') +int64 = dtype('int64') +uint8 = dtype('uint8') +uint16 = dtype('uint16') +uint32 = dtype('uint32') +uint64 = dtype('uint64') +float8e5 = dtype('fp8e5') +float8e5b16 = dtype('fp8e5b16') +float8e4nv = dtype('fp8e4nv') +float8e4b8 = dtype('fp8e4b8') +float8e4b15 = dtype('fp8e4b15') +float16 = dtype('fp16') +bfloat16 = dtype('bf16') +float32 = dtype('fp32') +float64 = dtype('fp64') +# pointer types +pi32_t = pointer_type(int32) + + +def get_int_dtype(bitwidth: int, signed: bool) -> dtype: + if bitwidth == 1: + return int1 + elif bitwidth == 8 and signed: + return int8 + elif bitwidth == 8 and not signed: + return uint8 + elif bitwidth == 16 and signed: + return int16 + elif bitwidth == 16 and not signed: + return uint16 + elif bitwidth == 32 and signed: + return int32 + elif bitwidth == 32 and not signed: + return uint32 + elif bitwidth == 64 and signed: + return int64 + elif bitwidth == 64 and not signed: + return uint64 + else: + raise ValueError(f'Unsupported bitwidth {bitwidth} and signedness {signed}') + + +# ----------------------- +# constexpr +# ----------------------- + + +class const: + """ + This class is used as a type annotation to mark pointers to constant data. + The `store` function cannot be called with a pointer to const. Constness + is part of the pointer type and the usual Triton type consistency rules + apply. For example you cannot have a function that returns constant pointer + in one return statement and non-constant pointer in another. + """ + pass + + +class constexpr: + """ + This class is used to store a value that is known at compile-time. + """ + + def __init__(self, value): + if isinstance(value, constexpr): + self.value = value.value + else: + self.value = value + + def __repr__(self) -> str: + return f"constexpr[{self.value}]" + + def __index__(self): + return self.value + + # In interpreter mode, constant values are not wrapped in constexpr, + # and therefore do not have a .value attribute. + # As a result, from here and below, we need to call the _constexpr_to_value + # function to obtain either constexpr.value or the value itself. + def __add__(self, other): + return constexpr(self.value + _constexpr_to_value(other)) + + def __radd__(self, other): + return constexpr(_constexpr_to_value(other) + self.value) + + def __sub__(self, other): + return constexpr(self.value - _constexpr_to_value(other)) + + def __rsub__(self, other): + return constexpr(_constexpr_to_value(other) - self.value) + + def __mul__(self, other): + return constexpr(self.value * _constexpr_to_value(other)) + + def __mod__(self, other): + return constexpr(self.value % _constexpr_to_value(other)) + + def __rmul__(self, other): + return constexpr(_constexpr_to_value(other) * self.value) + + def __truediv__(self, other): + return constexpr(self.value / _constexpr_to_value(other)) + + def __rtruediv__(self, other): + return constexpr(_constexpr_to_value(other) / self.value) + + def __floordiv__(self, other): + return constexpr(self.value // _constexpr_to_value(other)) + + def __rfloordiv__(self, other): + return constexpr(_constexpr_to_value(other) // self.value) + + def __gt__(self, other): + return constexpr(self.value > _constexpr_to_value(other)) + + def __rgt__(self, other): + return constexpr(_constexpr_to_value(other) > self.value) + + def __ge__(self, other): + return constexpr(self.value >= _constexpr_to_value(other)) + + def __rge__(self, other): + return constexpr(_constexpr_to_value(other) >= self.value) + + def __lt__(self, other): + return constexpr(self.value < _constexpr_to_value(other)) + + def __rlt__(self, other): + return constexpr(_constexpr_to_value(other) < self.value) + + def __le__(self, other): + return constexpr(self.value <= _constexpr_to_value(other)) + + def __rle__(self, other): + return constexpr(_constexpr_to_value(other) <= self.value) + + def __eq__(self, other): + return constexpr(self.value == _constexpr_to_value(other)) + + def __ne__(self, other): + return constexpr(self.value != _constexpr_to_value(other)) + + def __bool__(self): + return bool(self.value) + + def __neg__(self): + return constexpr(-self.value) + + def __and__(self, other): + return constexpr(self.value & _constexpr_to_value(other)) + + def logical_and(self, other): + return constexpr(self.value and _constexpr_to_value(other)) + + def __or__(self, other): + return constexpr(self.value | _constexpr_to_value(other)) + + def __xor__(self, other): + return constexpr(self.value ^ _constexpr_to_value(other)) + + def logical_or(self, other): + return constexpr(self.value or _constexpr_to_value(other)) + + def __pos__(self): + return constexpr(+self.value) + + def __invert__(self): + return constexpr(~self.value) + + def __pow__(self, other): + return constexpr(self.value**_constexpr_to_value(other)) + + def __rpow__(self, other): + return constexpr(_constexpr_to_value(other)**self.value) + + def __rshift__(self, other): + return constexpr(self.value >> _constexpr_to_value(other)) + + def __lshift__(self, other): + return constexpr(self.value << _constexpr_to_value(other)) + + def __not__(self): + return constexpr(not self.value) + + def __iter__(self): + return iter(self.value) + + def __call__(self, *args, **kwds): + return self.value(*args, **kwds) + + +CONSTEXPR_0 = constexpr(0) + + +def check_bit_width(value, shift_value): + if isinstance(value, tensor) and isinstance(shift_value, constexpr): + bitwidth = value.type.scalar.primitive_bitwidth + if shift_value.value >= bitwidth: + warn( + f"Value {shift_value.value} exceeds the maximum bitwidth ({bitwidth}) for type '{value.dtype}'. This may result in undefined behavior." + ) + + +class tensor: + """Represents an N-dimensional array of values or pointers. + + :code:`tensor` is the fundamental data structure in Triton programs. Most + functions in :py:mod:`triton.language` operate on and return tensors. + + Most of the named member functions here are duplicates of the free functions + in :code:`triton.language`. For example, :code:`triton.language.sqrt(x)` is + equivalent to :code:`x.sqrt()`. + + :code:`tensor` also defines most of the magic/dunder methods, so you can + write :code:`x+y`, :code:`x << 2`, etc. + + .. rubric:: Constructors + .. + For some reason Sphinx includes __init__ before printing the full table + of methods. Not what I want, but I can't figure out how to fix it. Give + it its own section so it looks intentional. :) + """ + + def __init__(self, handle, type: dtype): + """Not called by user code.""" + # IR handle + self.handle = handle + # Block shape + self.shape = type.shape if type.is_block() else () + self.numel = 1 + for s in self.shape: + self.numel *= s + self.numel = constexpr(self.numel) + self.type = type # Tensor type (can be block_type) + # Following the practice in pytorch, dtype is scalar type + self.dtype = type.scalar + self.shape = [constexpr(s) for s in self.shape] + + def __str__(self) -> str: + # ex. "float32[16, 32]" + return str(self.dtype) + '[' + ', '.join(str(s) for s in self.shape) + ']' + + @builtin + def __add__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.add(self, other, _builder) + + @builtin + def __radd__(self, other, _builder=None): + return self.__add__(other, _builder=_builder) + + @builtin + def __sub__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.sub(self, other, _builder) + + @builtin + def __rsub__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.sub(other, self, _builder) + + @builtin + def __mul__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.mul(self, other, _builder) + + @builtin + def __rmul__(self, other, _builder=None): + return self.__mul__(other, _builder=_builder) + + @builtin + def __truediv__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.truediv(self, other, _builder) + + @builtin + def __rtruediv__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.truediv(other, self, _builder) + + @builtin + def __floordiv__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.floordiv(self, other, _builder) + + @builtin + def __rfloordiv__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.floordiv(other, self, _builder) + + @builtin + def __mod__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.mod(self, other, _builder) + + @builtin + def __rmod__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.mod(other, self, _builder) + + # unary operators + @builtin + def __neg__(self, _builder=None): + return semantic.minus(self, _builder) + + @builtin + def __invert__(self, _builder=None): + return semantic.invert(self, _builder) + + # bitwise operators + + @builtin + def __and__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.and_(self, other, _builder) + + @builtin + def __rand__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.and_(other, self, _builder) + + @builtin + def __or__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.or_(self, other, _builder) + + @builtin + def __ror__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.or_(other, self, _builder) + + @builtin + def __xor__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.xor_(self, other, _builder) + + @builtin + def __rxor__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.xor_(other, self, _builder) + + @builtin + def __lshift__(self, other, _builder=None): + check_bit_width(self, other) + other = _to_tensor(other, _builder) + return semantic.shl(self, other, _builder) + + @builtin + def __rlshift__(self, other, _builder=None): + check_bit_width(other, self) + other = _to_tensor(other, _builder) + return semantic.shl(other, self, _builder) + + @builtin + def __rshift__(self, other, _builder=None): + check_bit_width(self, other) + other = _to_tensor(other, _builder) + if self.dtype.is_int_signed(): + return semantic.ashr(self, other, _builder) + else: + return semantic.lshr(self, other, _builder) + + @builtin + def __rrshift__(self, other, _builder=None): + check_bit_width(other, self) + other = _to_tensor(other, _builder) + if self.dtype.is_int_signed(): + return semantic.ashr(other, self, _builder) + else: + return semantic.lshr(other, self, _builder) + + # > + @builtin + def __gt__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.greater_than(self, other, _builder) + + @builtin + def __rgt__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.greater_than(other, self, _builder) + + # >= + @builtin + def __ge__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.greater_equal(self, other, _builder) + + @builtin + def __rge__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.greater_equal(other, self, _builder) + + # < + @builtin + def __lt__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.less_than(self, other, _builder) + + @builtin + def __rlt__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.less_than(other, self, _builder) + + # <= + @builtin + def __le__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.less_equal(self, other, _builder) + + @builtin + def __rle__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.less_equal(other, self, _builder) + + # == + @builtin + def __eq__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.equal(self, other, _builder) + + @builtin + def __req__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.equal(other, self, _builder) + + @builtin + def __ne__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.not_equal(self, other, _builder) + + @builtin + def __rne__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.not_equal(other, self, _builder) + + @builtin + def logical_and(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.logical_and(self, other, _builder) + + @builtin + def logical_or(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.logical_or(self, other, _builder) + + # note: __not__ isn't actually a magic method in python + # but it's ok because our ASTVisitor handles it + @builtin + def __not__(self, _builder=None): + return semantic.not_(self, _builder) + + @builtin + def __getitem__(self, slices, _builder=None): + if isinstance(slices, (slice, constexpr)) or slices is None: + slices = [slices] + ret = self + for dim, sl in enumerate(slices): + if sl is None or isinstance(sl, constexpr) and sl.value is None: + ret = semantic.expand_dims(ret, dim, _builder) + elif isinstance(sl, slice) and sl.start is None and sl.stop is None and sl.step is None: + pass + else: + raise ValueError(f"unsupported tensor index: {sl}") + return ret + + @property + def T(self): + """Transposes a 2D tensor.""" + assert False, "Transposition must be created by the AST Visitor" + + @builtin + def to(self, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _builder=None): + """ + Alias for :py:func:`tensor.cast`. + """ + # Triton doesn't like core functions calling other core functions, so we + # just copy-paste the implementation of cast here. It's not too bad. + if isinstance(bitcast, constexpr): + bitcast = bitcast.value + if bitcast: + return semantic.bitcast(self, dtype, _builder) + return semantic.cast(self, dtype, _builder, fp_downcast_rounding) + + # Type stubs for functions added by the _tensor_member_fn decorator. + # (Unfortunately these can't be created automatically.) + # + # We couldn't write these definitions out even if we wanted to, because some + # of these functions are defined in standard.py. + def broadcast_to(self, *shape) -> tensor: + ... + + def trans(self, *dims) -> tensor: + ... + + def permute(self, *dims) -> tensor: + ... + + def split(self) -> tuple[tensor, tensor]: + ... + + def view(self, *shape) -> tensor: + ... + + def reshape(self, *shape) -> tensor: + ... + + def expand_dims(self, axis) -> tensor: + ... + + def cast(self, dtype, fp_downcast_rounding=None, bitcast=False) -> tensor: + ... + + def store(self, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="") -> tensor: + ... + + def advance(self, offsets) -> tensor: + ... + + def atomic_cas(self, cmp, val, sem=None, scope=None) -> tensor: + ... + + def atomic_xchg(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_add(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_max(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_min(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_and(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_or(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_xor(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def exp(self) -> tensor: + ... + + def log(self) -> tensor: + ... + + def cos(self) -> tensor: + ... + + def sin(self) -> tensor: + ... + + def sqrt(self) -> tensor: + ... + + def rsqrt(self) -> tensor: + ... + + def abs(self) -> tensor: + ... + + def reduce(self, axis, combine_fn, keep_dims=False) -> tensor: + ... + + def associative_scan(self, axis, combine_fn, reverse=False) -> tensor: + ... + + def histogram(self, num_bins) -> tensor: + ... + + def cdiv(self, div) -> tensor: + ... + + def sigmoid(self) -> tensor: + ... + + def softmax(self, ieee_rounding=False) -> tensor: + ... + + def ravel(self) -> tensor: + ... + + def max(self, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False) -> tensor: + ... + + def argmax(self, axis, tie_break_left=True, keep_dims=False) -> tensor: + ... + + def min(self, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False) -> tensor: + ... + + def argmin(self, axis, tie_break_left=True, keep_dims=False) -> tensor: + ... + + def sum(self, axis=None, keep_dims=False) -> tensor: + ... + + def xor_sum(self, axis=None, keep_dims=False) -> tensor: + ... + + def cumsum(self, axis=0, reverse=False) -> tensor: + ... + + def cumprod(self, axis=0, reverse=False) -> tensor: + ... + + def sort(self, dim: constexpr = None, descending: constexpr = CONSTEXPR_0) -> tensor: + ... + + def flip(self, dim=None) -> tensor: + ... + + +def get_bool_env_var(var_name): + v = os.getenv(var_name, "0") + return v == "1" or v == "true" or v == "on" + + +# ----------------------- +# SPMD Programming Model +# ----------------------- +def _constexpr_to_value(v): + if isinstance(v, constexpr): + return v.value + return v + + +@builtin +def program_id(axis, _builder=None): + """ + Returns the id of the current program instance along the given :code:`axis`. + + :param axis: The axis of the 3D launch grid. Must be 0, 1 or 2. + :type axis: int + """ + # if axis == -1: + # pid0 = program_id(0, _builder) + # pid1 = program_id(1, _builder) + # pid2 = program_id(2, _builder) + # npg0 = num_programs(0, _builder) + # npg1 = num_programs(0, _builder) + # return pid0 + pid1*npg0 + pid2*npg0*npg1 + axis = _constexpr_to_value(axis) + return semantic.program_id(axis, _builder) + + +@builtin +def num_programs(axis, _builder=None): + """ + Returns the number of program instances launched along the given :code:`axis`. + + :param axis: The axis of the 3D launch grid. Must be 0, 1 or 2. + :type axis: int + """ + axis = _constexpr_to_value(axis) + return semantic.num_programs(axis, _builder) + + +# ----------------------- +# Block Initialization +# ----------------------- + + +@builtin +def arange(start, end, _builder=None): + """ + Returns contiguous values within the half-open interval :code:`[start, + end)`. :code:`end - start` must be less than or equal to + :code:`TRITON_MAX_TENSOR_NUMEL = 131072` + + :param start: Start of the interval. Must be a power of two. + :type start: int32 + :param end: End of the interval. Must be a power of two greater than + :code:`start`. + :type end: int32 + """ + start = _constexpr_to_value(start) + end = _constexpr_to_value(end) + return semantic.arange(start, end, _builder) + + +def _shape_check_impl(shape): + shape = _constexpr_to_value(shape) + for i, d in enumerate(shape): + if isinstance(d, int): + d = constexpr(d) + if not isinstance(d, constexpr): + raise TypeError(f"Shape element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + if d.value & (d.value - 1) != 0: + raise ValueError(f"Shape element {i} must be a power of 2") + return [_constexpr_to_value(x) for x in shape] + + +@builtin +def full(shape, value, dtype, _builder=None): + """ + Returns a tensor filled with the scalar value for the given :code:`shape` and :code:`dtype`. + + :param shape: Shape of the new array, e.g., (8, 16) or (8, ) + :value value: A scalar value to fill the array with + :type shape: tuple of ints + :param dtype: Data-type of the new array, e.g., :code:`tl.float16` + :type dtype: DType + """ + shape = _shape_check_impl(shape) + value = _constexpr_to_value(value) + dtype = _constexpr_to_value(dtype) + return semantic.full(shape, value, dtype, _builder) + + +# ----------------------- +# Shape Manipulation +# ----------------------- + + +@builtin +def broadcast(input, other, _builder=None): + """ + Tries to broadcast the two given blocks to a common compatible shape. + + :param input: The first input tensor. + :type input: Block + :param other: The second input tensor. + :type other: Block + """ + return semantic.broadcast_impl_value(input, other, _builder) + + +@_tensor_member_fn +@builtin +def broadcast_to(input, *shape, _builder=None): + """ + Tries to broadcast the given tensor to a new :code:`shape`. + + :param input: The input tensor. + :type input: Block + :param shape: The desired shape. + :type shape: + + :code:`shape` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + broadcast_to(x, (32, 32)) + broadcast_to(x, 32, 32) + """ + shape = _shape_check_impl(_unwrap_iterable(shape)) + return semantic.broadcast_impl_shape(input, shape, _builder) + + +@_tensor_member_fn +@builtin +def trans(input: tensor, *dims, _builder=None): + """ + Permutes the dimensions of a tensor. + + If no permutation is specified, tries to do a (1,0) permutation, i.e. tries + to transpose a 2D tensor. + + :param input: The input tensor. + :param dims: The desired ordering of dimensions. For example, + :code:`(2, 1, 0)` reverses the order dims in a a 3D tensor. + + :code:`dims` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + trans(x, (2, 1, 0)) + trans(x, 2, 1, 0) + + :py:func:`permute` is equivalent to this function, except it doesn't + have the special case when no permutation is specified. + """ + if not dims: + dims = (1, 0) + return semantic.permute(input, dims, _builder) + + +@_tensor_member_fn +@builtin +def permute(input, *dims, _builder=None): + """ + Permutes the dimensions of a tensor. + + :param input: The input tensor. + :type input: Block + :param dims: The desired ordering of dimensions. For example, + :code:`(2, 1, 0)` reverses the order dims in a a 3D tensor. + + :code:`dims` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + permute(x, (2, 1, 0)) + permute(x, 2, 1, 0) + + :py:func:`trans` is equivalent to this function, except when + :code:`dims` is empty, it tries to do a (1,0) permutation. + """ + dims = _unwrap_iterable(dims) + return semantic.permute(input, dims, _builder) + + +@builtin +def cat(input, other, can_reorder=False, _builder=None): + """ + Concatenate the given blocks + + :param input: The first input tensor. + :type input: + :param other: The second input tensor. + :type other: + :param reorder: Compiler hint. If true, the compiler is + allowed to reorder elements while concatenating inputs. Only use if the + order does not matter (e.g., result is only used in reduction ops) + """ + return semantic.cat(input, other, can_reorder, _builder) + + +@builtin +def join(a, b, _builder=None): + """ + Join the given tensors in a new, minor dimension. + + For example, given two tensors of shape (4,8), produces a new tensor of + shape (4,8,2). Given two scalars, returns a tensor of shape (2). + + The two inputs are broadcasted to be the same shape. + + If you want to join more than two elements, you can use multiple calls to + this function. This reflects the constraint in Triton that tensors must + have power-of-two sizes. + + join is the inverse of split. + + :param a: The first input tensor. + :type a: Tensor + :param b: The second input tensor. + :type b: Tensor + """ + return semantic.join(a, b, _builder) + + +@jit +def _take_first(a, b): + return a + + +@_tensor_member_fn +@builtin +def split(a, _builder=None, _generator=None) -> tuple[tensor, tensor]: + """ + Split a tensor in two along its last dim, which must have size 2. + + For example, given a tensor of shape (4,8,2), produces two tensors of shape + (4,8). Given a tensor of shape (2), returns two scalars. + + If you want to split into more than two pieces, you can use multiple calls + to this function (probably plus calling reshape). This reflects the + constraint in Triton that tensors must have power-of-two sizes. + + split is the inverse of join. + + :param a: The tensor to split. + :type a: Tensor + """ + # If len(a.shape) == 1, i.e. a.shape == [2], we should return two scalars. + # But semantic.split can only handle returning tensors. Work around this by + # expanding the input to shape [1,2] and then reducing the result. + was_rank_1 = len(a.shape) == 1 + if was_rank_1: + a = semantic.expand_dims(a, 0, _builder) + + out_lhs, out_rhs = semantic.split(a, _builder) + + if was_rank_1: + # Currently `reduce` is the best way to convert a tensor of shape [1] to a scalar. + out_lhs = typing.cast(tensor, reduce(out_lhs, None, _take_first, _builder=_builder, _generator=_generator)) + out_rhs = typing.cast(tensor, reduce(out_rhs, None, _take_first, _builder=_builder, _generator=_generator)) + + return out_lhs, out_rhs + + +@_tensor_member_fn +@builtin +def view(input, *shape, _builder=None): + """ + Returns a tensor with the same elements as `input` but a different shape. + The order of the elements may not be preserved. + + :param input: The input tensor. + :type input: Block + :param shape: The desired shape. + + :code:`shape` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + view(x, (32, 32)) + view(x, 32, 32) + """ + warn("view is deprecated, please use reshape with can_reorder being true.") + shape = _shape_check_impl(_unwrap_iterable(shape)) + return semantic.reshape(input, shape, can_reorder=True, builder=_builder) + + +@_tensor_member_fn +@builtin +def reshape(input, *shape, can_reorder=False, _builder=None): + """ + Returns a tensor with the same number of elements as input but with the + provided shape. + + :param input: The input tensor. + :type input: Block + :param shape: The new shape. + + :code:`shape ` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + reshape(x, (32, 32)) + reshape(x, 32, 32) + """ + shape = _shape_check_impl(_unwrap_iterable(shape)) + return semantic.reshape(input, shape, can_reorder, _builder) + + +def _wrap_axis(axis, ndim): + if not (-ndim <= axis < ndim): + raise ValueError(f"invalid axis {axis}. Expected {-ndim} <= axis < {ndim}") + + return axis if axis >= 0 else axis + ndim + + +@_tensor_member_fn +@builtin +def expand_dims(input, axis, _builder=None): + """ + Expand the shape of a tensor, by inserting new length-1 dimensions. + + Axis indices are with respect to the resulting tensor, so + ``result.shape[axis]`` will be 1 for each axis. + + :param input: The input tensor. + :type input: tl.tensor + :param axis: The indices to add new axes + :type axis: int | Sequence[int] + + """ + input = _to_tensor(input, _builder) + axis = _constexpr_to_value(axis) + axes = list(axis) if isinstance(axis, Sequence) else [axis] + new_ndim = len(input.shape) + len(axes) + axes = [_wrap_axis(_constexpr_to_value(d), new_ndim) for d in axes] + + if len(set(axes)) != len(axes): + raise ValueError(f"expand_dims received duplicate axes, normalized axes = {axes}") + + ret = input + for a in sorted(axes): + ret = semantic.expand_dims(ret, a, _builder) + return ret + + +@_tensor_member_fn +@builtin +def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _builder=None): + """ + Casts a tensor to the given :code:`dtype`. + + :param dtype: The target data type. + :param fp_downcast_rounding: The rounding mode for downcasting + floating-point values. This parameter is only used when self is a + floating-point tensor and dtype is a floating-point type with a + smaller bitwidth. Supported values are :code:`"rtne"` (round to + nearest, ties to even) and :code:`"rtz"` (round towards zero). + :param bitcast: If true, the tensor is bitcasted to the given + :code:`dtype`, instead of being numerically casted. + """ + input = _to_tensor(input, _builder) + if isinstance(bitcast, constexpr): + bitcast = bitcast.value + if bitcast: + return semantic.bitcast(input, dtype, _builder) + return semantic.cast(input, dtype, _builder, fp_downcast_rounding) + + +# ----------------------- +# Linear Algebra +# ----------------------- + + +@builtin +def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=float32, + _builder=None): + """ + Returns the matrix product of two blocks. + + The two blocks must be two-dimensional and have compatible inner dimensions. + + :param input: The first tensor to be multiplied. + :type input: 2D tensor of scalar-type in {:code:`int8`, :code: `float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`} + :param other: The second tensor to be multiplied. + :type other: 2D tensor of scalar-type in {:code:`int8`, :code: `float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`} + :param input_precision: How to exercise the Tensor Cores for f32 x f32. If + the device does not have Tensor Cores or the inputs are not of dtype f32, + this option is ignored. For devices that do have tensor cores, the + default precision is tf32. + :type input_precision: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32"`. Avaliable options for amd: :code:`"ieee"`. + :param allow_tf32: *Deprecated.* If true, input_precision is set to "tf32". + Only one of :code:`input_precision` and :code:`allow_tf32` can be + specified (i.e. at least one must be :code:`None`). + """ + assert input_precision is None or allow_tf32 is None, "Only one of input_precision and allow_tf32 can be specified" + if input_precision is None: + supports_tf32 = _builder and "tf32" in _builder.options.allowed_dot_input_precisions + default_precision = "tf32" if (supports_tf32 and (allow_tf32 or allow_tf32 is None)) else "ieee" + input_precision = os.getenv("TRITON_F32_DEFAULT", default_precision) + + input_precision = _constexpr_to_value(input_precision) + out_dtype = _constexpr_to_value(out_dtype) + max_num_imprecise_acc = _constexpr_to_value(max_num_imprecise_acc) + return semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype, _builder) + + +# ----------------------- +# Non-Atomic Memory Operations +# ----------------------- + + +@builtin +def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", cache_modifier="", eviction_policy="", + volatile=False, _builder=None): + """ + Return a tensor of data whose values are loaded from memory at location defined by `pointer`: + + (1) If `pointer` is a single element pointer, a scalar is be loaded. In + this case: + + - `mask` and `other` must also be scalars, + - `other` is implicitly typecast to `pointer.dtype.element_ty`, and + - `boundary_check` and `padding_option` must be empty. + + (2) If `pointer` is an N-dimensional tensor of pointers, an + N-dimensional tensor is loaded. In this case: + + - `mask` and `other` are implicitly broadcast to `pointer.shape`, + - `other` is implicitly typecast to `pointer.dtype.element_ty`, and + - `boundary_check` and `padding_option` must be empty. + + (3) If `pointer` is a block pointer defined by `make_block_ptr`, a + tensor is loaded. In this case: + + - `mask` and `other` must be None, and + - `boundary_check` and `padding_option` can be specified to control + the behavior of out-of-bound access. + + :param pointer: Pointer to the data to be loaded + :type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType` + :param mask: if `mask[idx]` is false, do not load the data at address `pointer[idx]` + (must be `None` with block pointers) + :type mask: Block of `triton.int1`, optional + :param other: if `mask[idx]` is false, return `other[idx]` + :type other: Block, optional + :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check + :type boundary_check: tuple of ints, optional + :param padding_option: should be one of {"", "zero", "nan"}, do padding while out of bound + :param cache_modifier: changes cache option in NVIDIA PTX + :type cache_modifier: str, optional + :param eviction_policy: changes eviction policy in NVIDIA PTX + :type eviction_policy: str, optional + :param volatile: changes volatile option in NVIDIA PTX + :type volatile: bool, optional + """ + # `mask` and `other` can be constexpr + mask = _constexpr_to_value(mask) + other = _constexpr_to_value(other) + if mask is not None: + mask = _to_tensor(mask, _builder) + if other is not None: + other = _to_tensor(other, _builder) + padding_option = _constexpr_to_value(padding_option) + cache_modifier = _constexpr_to_value(cache_modifier) + eviction_policy = _constexpr_to_value(eviction_policy) + volatile = _constexpr_to_value(volatile) + return semantic.load(pointer, mask, other, boundary_check, padding_option, cache_modifier, eviction_policy, + volatile, _builder) + + +@builtin +def _experimental_descriptor_load(desc_pointer, offsets, shape, dtype, _builder=None): + """ + Experimental feature to access TMA descriptors loads. This is an escape hatch to easily exercise TTGIR operations. + This will be removed in the future and shouldn't be used in production code. + + This loads a tensor of data based on the descriptor and offsets. + """ + type = block_type(dtype, shape) + return semantic.descriptor_load(desc_pointer, offsets, "", "", type, _builder) + + +@builtin +def _experimental_descriptor_store(desc_pointer, value, offsets, _builder=None): + """ + Experimental feature to access TMA descriptors stores. This is an escape hatch to easily exercise TTGIR operations. + This will be removed in the future and shouldn't be used in production code. + + This stores a tensor of data based on the descriptor and offsets. + """ + return semantic.descriptor_store(desc_pointer, value, offsets, _builder) + + +@_tensor_member_fn +@builtin +def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="", _builder=None): + """ + Store a tensor of data into memory locations defined by `pointer`. + + (1) If `pointer` is a single element pointer, a scalar is stored. In + this case: + + - `mask` must also be scalar, and + - `boundary_check` and `padding_option` must be empty. + + (2) If `pointer` is an N-dimensional tensor of pointers, an + N-dimensional block is stored. In this case: + + - `mask` is implicitly broadcast to `pointer.shape`, and + - `boundary_check` must be empty. + + (3) If `pointer` is a block pointer defined by `make_block_ptr`, a block + of data is stored. In this case: + + - `mask` must be None, and + - `boundary_check` can be specified to control the behavior of out-of-bound access. + + `value` is implicitly broadcast to `pointer.shape` and typecast to `pointer.dtype.element_ty`. + + :param pointer: The memory location where the elements of `value` are stored + :type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType` + :param value: The tensor of elements to be stored + :type value: Block + :param mask: If `mask[idx]` is false, do not store `value[idx]` at `pointer[idx]` + :type mask: Block of triton.int1, optional + :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check + :type boundary_check: tuple of ints, optional + :param cache_modifier: changes cache option in NVIDIA PTX + :type cache_modifier: str, optional + :param eviction_policy: changes eviction policy in NVIDIA PTX + :type eviction_policy: str, optional + """ + # `value` can be constexpr + value = _to_tensor(value, _builder) + mask = _constexpr_to_value(mask) + if mask is not None: + mask = _to_tensor(mask, _builder) + cache_modifier = _constexpr_to_value(cache_modifier) + eviction_policy = _constexpr_to_value(eviction_policy) + return semantic.store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy, _builder) + + +@builtin +def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _builder=None): + """ + Returns a pointer to a block in a parent tensor + + :param base: The base pointer to the parent tensor + :param shape: The shape of the parent tensor + :param strides: The strides of the parent tensor + :param offsets: The offsets to the block + :param block_shape: The shape of the block + :param order: The order of the original data format + """ + return semantic.make_block_ptr(base, shape, strides, offsets, block_shape, order, _builder) + + +@_tensor_member_fn +@builtin +def advance(base, offsets, _builder=None): + """ + Advance a block pointer + + :param base: the block pointer to advance + :param offsets: the offsets to advance, a tuple by dimension + """ + return semantic.advance(base, offsets, _builder) + + +# ----------------------- +# Atomic Memory Operations +# ----------------------- + + +def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = f""" + Performs an atomic {name} at the memory location specified by :code:`pointer`. + + Return the data stored at :code:`pointer` before the atomic operation. + + :param pointer: The memory locations to operate on + :type pointer: Block of dtype=triton.PointerDType""" + if has_cmp: + docstr += """ + :param cmp: The values expected to be found in the atomic object + :type cmp: Block of dtype=pointer.dtype.element_ty""" + docstr += """ + :param val: The values with which to perform the atomic operation + :type val: Block of dtype=pointer.dtype.element_ty + :param sem: Memory semantics to use ("ACQUIRE_RELEASE" (default), + "ACQUIRE", "RELEASE", or "RELAXED") + :type sem: str + :param scope: Scope of threads that observe synchronizing effect of the + atomic operation ("GPU" (default), "CTA", or "SYSTEM") + :type scope: str + """ + func.__doc__ = docstr + return func + + return _decorator + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("compare-and-swap", has_cmp=True) +def atomic_cas(pointer, cmp, val, sem=None, scope=None, _builder=None): + cmp = _to_tensor(cmp, _builder) + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + return semantic.atomic_cas(pointer, cmp, val, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("exchange") +def atomic_xchg(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_xchg(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("add") +def atomic_add(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_add(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("max") +def atomic_max(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_max(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("min") +def atomic_min(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_min(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("logical and") +def atomic_and(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_and(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("logical or") +def atomic_or(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_or(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("logical xor") +def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_xor(pointer, val, mask, sem, scope, _builder) + + +# ----------------------- +# Conditioning +# ----------------------- + + +@builtin +def where(condition, x, y, _builder=None): + """ + Returns a tensor of elements from either :code:`x` or :code:`y`, depending on :code:`condition`. + + Note that :code:`x` and :code:`y` are always evaluated regardless of the value of :code:`condition`. + + If you want to avoid unintended memory operations, use the :code:`mask` arguments in `triton.load` and `triton.store` instead. + + The shape of :code:`x` and :code:`y` are both broadcast to the shape of :code:`condition`. + :code:`x` and :code:`y` must have the same data type. + + :param condition: When True (nonzero), yield x, otherwise yield y. + :type condition: Block of triton.bool + :param x: values selected at indices where condition is True. + :param y: values selected at indices where condition is False. + """ + condition = _to_tensor(condition, _builder) + x = _to_tensor(x, _builder) + y = _to_tensor(y, _builder) + return semantic.where(condition, x, y, _builder) + + +# ----------------------- +# Math +# ----------------------- + + +@builtin +def minimum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None): + """ + Computes the element-wise minimum of :code:`x` and :code:`y`. + + :param x: the first input tensor + :type x: Block + :param y: the second input tensor + :type y: Block + :param propagate_nan: whether to propagate NaN values. + :type propagate_nan: tl.PropagateNan + + .. seealso:: :class:`tl.PropagateNan` + """ + x = _to_tensor(x, _builder) + y = _to_tensor(y, _builder) + x = _promote_bfloat16_to_float32(x, _builder=_builder) + y = _promote_bfloat16_to_float32(y, _builder=_builder) + propagate_nan = _constexpr_to_value(propagate_nan) + return semantic.minimum(x, y, propagate_nan, _builder) + + +@builtin +def maximum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None): + """ + Computes the element-wise maximum of :code:`x` and :code:`y`. + + :param x: the first input tensor + :type x: Block + :param y: the second input tensor + :type y: Block + :param propagate_nan: whether to propagate NaN values. + :type propagate_nan: tl.PropagateNan + + .. seealso:: :class:`tl.PropagateNan` + """ + x = _to_tensor(x, _builder) + y = _to_tensor(y, _builder) + x = _promote_bfloat16_to_float32(x, _builder=_builder) + y = _promote_bfloat16_to_float32(y, _builder=_builder) + propagate_nan = _constexpr_to_value(propagate_nan) + return semantic.maximum(x, y, propagate_nan, _builder) + + +@builtin +def clamp(x, min, max, propagate_nan: constexpr = PropagateNan.NONE, _builder=None): + """ + Clamps the input tensor :code:`x` within the range [min, max]. + Behavior when :code:`min` > :code:`max` is undefined. + + :param x: the input tensor + :type x: Block + :param min: the lower bound for clamping + :type min: Block + :param max: the upper bound for clamping + :type max: Block + :param propagate_nan: whether to propagate NaN values. Applies only to the :code:`x` tensor. + If either :code:`min` or :code:`max` is NaN, the result is undefined. + :type propagate_nan: tl.PropagateNan + + .. seealso:: :class:`tl.PropagateNan` + """ + x = _to_tensor(x, _builder) + min = _to_tensor(min, _builder) + max = _to_tensor(max, _builder) + x = _promote_bfloat16_to_float32(x, _builder=_builder) + min = _promote_bfloat16_to_float32(min, _builder=_builder) + max = _promote_bfloat16_to_float32(max, _builder=_builder) + + propagate_nan = _constexpr_to_value(propagate_nan) + + return semantic.clamp(x, min, max, propagate_nan, _builder) + + +# ----------------------- +# Reductions +# ----------------------- + + +def _add_reduction_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None) -> Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis` + + :param input: the input values + :param axis: the dimension along which the reduction should be done + :param keep_dims: if true, keep the reduced dimensions with length 1""" + if return_indices_arg is not None: + docstr += f""" + :param {return_indices_arg}: if true, return index corresponding to the {name} value""" + if tie_break_arg is not None: + docstr += f""" + :param {tie_break_arg}: if true, return the left-most indices in case of ties for values that aren't NaN""" + + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +@contextmanager +def _insertion_guard(builder): + ip = builder.get_insertion_point() + yield + builder.restore_insertion_point(ip) + + +@_tensor_member_fn +@builtin +def reduce(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=None): + """Applies the combine_fn to all elements in :code:`input` tensors along the provided :code:`axis` + + :param input: the input tensor, or tuple of tensors + :param axis: the dimension along which the reduction should be done. If None, reduce all dimensions + :param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit) + :param keep_dims: if true, keep the reduced dimensions with length 1 + + """ + if isinstance(input, tensor): + return reduce((input, ), axis, combine_fn, keep_dims=keep_dims, _builder=_builder, _generator=_generator)[0] + + def make_combine_region(reduce_op): + in_scalar_tys = [t.type.scalar for t in input] + prototype = function_type(in_scalar_tys, in_scalar_tys * 2) + + region = reduce_op.get_region(0) + with _insertion_guard(_builder): + param_types = [ty.to_ir(_builder) for ty in prototype.param_types] + block = _builder.create_block_with_parent(region, param_types) + args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)] + results = _generator.call_JitFunction(combine_fn, args, kwargs={}) + if isinstance(results, tensor): + handles = [results.handle] + else: + handles = [r.handle for r in results] + _builder.create_reduce_ret(*handles) + + def expand_ndims(t, ndims): + for _ in builtins.range(ndims): + t = expand_dims(t, 0, _builder=_builder) + return t + + axis = _constexpr_to_value(axis) + keep_dims = _constexpr_to_value(keep_dims) + if axis is not None: + axis = _wrap_axis(axis, len(input[0].shape)) + ret = semantic.reduction(input, axis, make_combine_region, _builder) + if keep_dims: + if axis is not None: + ret = tuple(expand_dims(t, axis, _builder=_builder) for t in ret) + else: + ret = tuple(expand_ndims(t, len(input[0].shape)) for t in ret) + return ret + + +@builtin +def _promote_bfloat16_to_float32(t, _builder=None): + scalar_ty = t.type.scalar + + # hardware doesn't support FMAX, FMIN, CMP for bfloat16 + if scalar_ty is bfloat16: + return t.to(float32, _builder=_builder) + return t + + +@builtin +def _reduce_with_indices(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=None): + axis = _constexpr_to_value(axis) + n = input.shape[axis] + index = arange(0, n, _builder=_builder) + + if len(input.shape) > 1: + # Broadcast index across the non-reduced axes + axes_to_expand = [constexpr(d) for d in builtins.range(len(input.shape))] + del axes_to_expand[axis] + index = expand_dims(index, axes_to_expand, _builder=_builder) + index = broadcast_to(index, input.shape, _builder=_builder) + + rvalue, rindices = reduce((input, index), axis, combine_fn, keep_dims=keep_dims, _builder=_builder, + _generator=_generator) + return rvalue, rindices + + +# ----------------------- +# Scans +# ----------------------- + + +def _add_scan_docstr(name: str) -> Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis` + + :param input: the input values + :param axis: the dimension along which the scan should be done""" + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +@_tensor_member_fn +@builtin +def associative_scan(input, axis, combine_fn, reverse=False, _builder=None, _generator=None): + """Applies the combine_fn to each elements with a carry in :code:`input` tensors along the provided :code:`axis` and update the carry + + :param input: the input tensor, or tuple of tensors + :param axis: the dimension along which the reduction should be done + :param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit) + :param reverse: apply the associative scan in the reverse direction along axis. + + """ + if isinstance(input, tensor): + return associative_scan((input, ), axis, combine_fn, reverse, _builder=_builder, _generator=_generator)[0] + + def make_combine_region(scan_op): + in_scalar_tys = [t.type.scalar for t in input] + prototype = function_type(in_scalar_tys, in_scalar_tys * 2) + + region = scan_op.get_region(0) + with _insertion_guard(_builder): + param_types = [ty.to_ir(_builder) for ty in prototype.param_types] + block = _builder.create_block_with_parent(region, param_types) + args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)] + results = _generator.call_JitFunction(combine_fn, args, kwargs={}) + if isinstance(results, tensor): + handles = [results.handle] + else: + handles = [r.handle for r in results] + _builder.create_scan_ret(*handles) + + axis = _constexpr_to_value(axis) + if axis is not None: + axis = _wrap_axis(axis, len(input[0].shape)) + return semantic.associative_scan(input, axis, make_combine_region, reverse, _builder) + + +@_tensor_member_fn +@builtin +def histogram(input, num_bins, _builder=None, _generator=None): + """computes an histogram based on input tensor with num_bins bins, the bins have a width of 1 and start at 0. + + :param input: the input tensor + :param num_bins: number of histogram bins + + """ + num_bins = _constexpr_to_value(num_bins) + return semantic.histogram(input, num_bins, _builder) + + +# ----------------------- +# Compiler Hint Ops +# ----------------------- + + +@builtin +def debug_barrier(_builder=None): + ''' + Insert a barrier to synchronize all threads in a block. + ''' + return semantic.debug_barrier(_builder) + + +@builtin +def multiple_of(input, values, _builder=None): + """ + Let the compiler know that the values in :code:`input` are all multiples of :code:`value`. + """ + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return semantic.multiple_of(input, values) + + +@builtin +def max_contiguous(input, values, _builder=None): + """ + Let the compiler know that the `value` first values in :code:`input` are contiguous. + """ + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return semantic.max_contiguous(input, values) + + +@builtin +def max_constancy(input, values, _builder=None): + """ + Let the compiler know that the `value` first values in :code:`input` are constant. + + e.g. if :code:`values` is [4], then each group of 4 values in :code:`input` should all be equal, + for example [0, 0, 0, 0, 1, 1, 1, 1]. + """ + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return semantic.max_constancy(input, values) + + +# ----------------------- +# Debugging functions +# ----------------------- + + +@builtin +def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=False, _builder=None): + ''' + Print the values at compile time. The parameters are the same as the builtin :code:`print`. + + NOTE: Calling the Python builtin :code:`print` is not the same as calling this, it instead maps to :code:`device_print`, + which has special requirements for the arguments. + + .. highlight:: python + .. code-block:: python + + tl.static_print(f"{BLOCK_SIZE=}") + ''' + pass + + +@builtin +def static_assert(cond, msg="", _builder=None): + ''' + Assert the condition at compile time. Does not require that the :code:`TRITON_DEBUG` environment variable + is set. + + .. highlight:: python + .. code-block:: python + + tl.static_assert(BLOCK_SIZE == 1024) + ''' + pass + + +@builtin +def device_print(prefix, *args, hex=False, _builder=None): + ''' + Print the values at runtime from the device. String formatting does not work for runtime values, so you should + provide the values you want to print as arguments. The first value must be a string, all following values must + be scalars or tensors. + + Calling the Python builtin :code:`print` is the same as calling this function, and the requirements for the arguments will match + this function (not the normal requirements for :code:`print`). + + .. highlight:: python + .. code-block:: python + + tl.device_print("pid", pid) + print("pid", pid) + + On CUDA, printfs are streamed through a buffer of limited size (on one host, + we measured the default as 6912 KiB, but this may not be consistent across + GPUs and CUDA versions). If you notice some printfs are being dropped, you + can increase the buffer size by calling + + .. highlight:: python + .. code-block:: python + + triton.runtime.driver.active.utils.set_printf_fifo_size(size_bytes) + + CUDA may raise an error if you try to change this value after running a + kernel that uses printfs. The value set here may only affect the current + device (so if you have multiple GPUs, you'd need to call it multiple times). + + :param prefix: a prefix to print before the values. This is required to be a string literal. + :param args: the values to print. They can be any tensor or scalar. + :param hex: print all values as hex instead of decimal + ''' + import string + prefix = _constexpr_to_value(prefix) + assert isinstance(prefix, str), f"{prefix} is not string" + b_ascii = True + for ch in prefix: + if ch not in string.printable: + b_ascii = False + break + assert b_ascii, f"{prefix} is not an ascii string" + new_args = [] + for arg in args: + new_args.append(_to_tensor(arg, _builder)) + return semantic.device_print(prefix, new_args, hex, _builder) + + +@builtin +def device_assert(cond, msg="", _builder=None): + ''' + Assert the condition at runtime from the device. Requires that the environment variable :code:`TRITON_DEBUG` + is set to a value besides :code:`0` in order for this to have any effect. + + Using the Python :code:`assert` statement is the same as calling this function, except that the second argument + must be provided and must be a string, e.g. :code:`assert pid == 0, "pid != 0"`. The environment variable must + be set for this :code:`assert` statement to have any effect. + + .. highlight:: python + .. code-block:: python + + tl.device_assert(pid == 0) + assert pid == 0, f"pid != 0" + + :param cond: the condition to assert. This is required to be a boolean tensor. + :param msg: the message to print if the assertion fails. This is required to be a string literal. + ''' + msg = _constexpr_to_value(msg) + import inspect + frame = inspect.currentframe() + module = inspect.getmodule(frame) + # The triton function module doesn't have the name attribute. + # We use this trick to find the caller. + while hasattr(module, "__name__"): + frame = frame.f_back + module = inspect.getmodule(frame) + lineno = 0 + func_name = 'unknown' + file_name = 'unknown' + if frame is not None and frame.f_back is not None: + func_name = frame.f_code.co_name + file_name = frame.f_back.f_code.co_filename + # TODO: The line number currently indicates the line + # where the triton function is called but not where the + # device_assert is called. Need to enhance this. + lineno = frame.f_back.f_lineno + return semantic.device_assert(_to_tensor(cond, _builder), msg, file_name, func_name, lineno, _builder) + + +@builtin +def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Union[dtype, Sequence[dtype]], + is_pure: bool, pack: int, _builder=None): + ''' + Execute inline assembly over a tensor. Essentially, this is :code:`map` + where the function is inline assembly. + + The input tensors :code:`args` are implicitly broadcasted to the same shape. + + :code:`dtype` can be a tuple of types, in which case the output is a + tuple of tensors. + + Each invocation of the inline asm processes :code:`pack` elements at a + time. Exactly which set of inputs a block receives is unspecified. + Input elements of size less than 4 bytes are packed into 4-byte + registers. + + This op does not support empty :code:`dtype` -- the inline asm must + return at least one tensor, even if you don't need it. You can work + around this by returning a dummy tensor of arbitrary type; it shouldn't + cost you anything if you don't use it. + + Example using + [PTX](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html) + assembly: + + .. highlight:: python + .. code-block:: python + + @triton.jit + def kernel(A, B, C, D, BLOCK: tl.constexpr): + a = tl.load(A + tl.arange(0, BLOCK)) # uint8 tensor + b = tl.load(B + tl.arange(0, BLOCK)) # float32 tensor + + # For each (a,b) in zip(a,b), perform the following: + # - Let ai be `a` converted to int32. + # - Let af be `a` converted to float. + # - Let m be the max of ai and b. + # - Return ai and mi. + # Do the above 4 elements at a time. + (c, d) = tl.inline_asm_elementwise( + asm=""" + { + // Unpack `a` into `ai`. + .reg .b8 tmp<4>; + mov.b32 {tmp0, tmp1, tmp2, tmp3}, $8; + cvt.u32.u8 $0, tmp0; + cvt.u32.u8 $1, tmp1; + cvt.u32.u8 $2, tmp2; + cvt.u32.u8 $3, tmp3; + } + // Convert `ai` to float. + cvt.rn.f32.s32 $4, $0; + cvt.rn.f32.s32 $5, $1; + cvt.rn.f32.s32 $6, $2; + cvt.rn.f32.s32 $7, $3; + // Take max of `ai` and `b`. + max.f32 $4, $4, $9; + max.f32 $5, $5, $10; + max.f32 $6, $6, $11; + max.f32 $7, $7, $12; + """, + constraints=( + # 8 output registers, namely + # $0=ai0, $1=ai1, $2=ai2, $3=ai3, + # $4=m0, $5=m1, $6=m2, $7=m3. + "=r,=r,=r,=r,=r,=r,=r,=r," + # 5 input registers, namely + # $8=ai, + # $9=b0, $10=b1, $11=b2, $12=b3. + # The four elements from `a` are all packed into one register. + "r,r,r,r,r"), + args=[a, b], + dtype=(tl.int32, tl.float32), + is_pure=True, + pack=4, + ) + tl.store(C + tl.arange(0, BLOCK), c) + tl.store(D + tl.arange(0, BLOCK), d) + + :param asm: assembly to run. Must match target's assembly format. + :param constraints: asm constraints in + [LLVM format](https://llvm.org/docs/LangRef.html#inline-asm-constraint-string) + :param args: the input tensors, whose values are passed to the asm block + :param dtype: the element type(s) of the returned tensor(s) + :param is_pure: if true, the compiler assumes the asm block has no side-effects + :param pack: the number of elements to be processed by one instance of inline assembly + :param _builder: the builder + :return: one tensor or a tuple of tensors of the given dtypes + ''' + asm = _constexpr_to_value(asm) + constraints = _constexpr_to_value(constraints) + pack = _constexpr_to_value(pack) + is_pure = _constexpr_to_value(is_pure) + + # Wrap `dtype` in a tuple if it's not already. + try: + iter(dtype) # type: ignore + has_multiple_outputs = True + except TypeError: + has_multiple_outputs = False + dtype = (dtype, ) # type: ignore + + dtype = typing.cast(Sequence[_DtypeClass], dtype) + + res_tys = dtype + if dispatch_args := [_to_tensor(arg, _builder) for arg in args]: + bin_op_type_checking = partial( + semantic.binary_op_type_checking_impl, + builder=_builder, + arithmetic_check=False, + allow_lhs_ptr=True, + allow_rhs_ptr=True, + ) + broadcast_arg = dispatch_args[0] + # Get the broadcast shape over all the arguments + for item in dispatch_args: + _, broadcast_arg = bin_op_type_checking(item, broadcast_arg) + if broadcast_arg.shape: + # Change the shape of each argument based on the broadcast shape + for i, item in enumerate(dispatch_args): + dispatch_args[i], _ = bin_op_type_checking(item, broadcast_arg) + res_tys = [block_type(dt, broadcast_arg.shape) for dt in dtype] + handles = [t.handle for t in dispatch_args] + call = _builder.create_inline_asm(asm, constraints, handles, [ty.to_ir(_builder) for ty in res_tys], is_pure, pack) + + if not has_multiple_outputs: + return tensor(call.get_result(0), res_tys[0]) + return tuple(tensor(call.get_result(i), ty) for i, ty in enumerate(res_tys)) + + +# ----------------------- +# Iterators +# ----------------------- + + +class static_range: + """ + Iterator that counts upward forever. + + .. highlight:: python + .. code-block:: python + + @triton.jit + def kernel(...): + for i in tl.static_range(10): + ... + :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of + :code:`triton.jit` functions. In addition, it also guides the compiler to unroll the loop aggressively. + :param arg1: the start value. + :param arg2: the end value. + :param step: the step value. + """ + + def __init__(self, arg1, arg2=None, step=None): + assert isinstance(arg1, constexpr) + if step is None: + self.step = constexpr(1) + else: + assert isinstance(step, constexpr) + self.step = step + if arg2 is None: + self.start = constexpr(0) + self.end = arg1 + else: + assert isinstance(arg2, constexpr) + self.start = arg1 + self.end = arg2 + + def __iter__(self): + raise RuntimeError("static_range can only be used in @triton.jit'd functions") + + def __next__(self): + raise RuntimeError("static_range can only be used in @triton.jit'd functions") + + +class range: + """ + Iterator that counts upward forever. + + .. highlight:: python + .. code-block:: python + + @triton.jit + def kernel(...): + for i in tl.range(10, num_stages=3): + ... + :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of + :code:`triton.jit` functions. In addition, it allows user to pass extra attributes to the compiler. + :param arg1: the start value. + :param arg2: the end value. + :param step: the step value. + :param num_stages: pipeline the loop into this many stages (so there are + :code:`num_stages` iterations of the loop in flight at once). + + Note this is subtly different than passing :code:`num_stages` as a + kernel argument. The kernel argument only pipelines loads that feed + into :code:`dot` operations, while this attribute tries to pipeline most + (though not all) loads in this loop. + """ + + def __init__(self, arg1, arg2=None, step=None, num_stages=None): + if step is None: + self.step = constexpr(1) + else: + self.step = step + if arg2 is None: + self.start = constexpr(0) + self.end = arg1 + else: + self.start = arg1 + self.end = arg2 + self.num_stages = num_stages + + def __iter__(self): + raise RuntimeError("tl.range can only be used in @triton.jit'd functions") + + def __next__(self): + raise RuntimeError("tl.range can only be used in @triton.jit'd functions") + + +# ----------------------- +# Extern functions +# ----------------------- + + +def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, + is_pure: bool, _builder=None): + ''' + Dispatch a function to a library + :param func: the function to dispatch + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param ret_shape: the shape of the return value + :param _builder: the builder + :return: the return value of the function + ''' + if len(arg_type_symbol_dict) == 0: + raise ValueError("arg_type_symbol_dict is empty") + + num_args = len(list(arg_type_symbol_dict.keys())[0]) + if len(args) != num_args: + raise ValueError(f"length of input args does not match." + f"Expect {len(args)}, got {num_args}") + + arg_types = [] + arg_list = [] + for arg in args: + if isinstance(arg, tensor): + arg_types.append(arg.dtype) + arg_list.append(arg.handle) + else: + arg_types.append(type(arg)) + arg_list.append(arg) + arg_types = tuple(arg_types) + + if arg_types not in arg_type_symbol_dict: + raise ValueError(f"input arg type does not match." + f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}") + else: + symbol = arg_type_symbol_dict[arg_types][0] + ret_type = arg_type_symbol_dict[arg_types][1] + if ret_shape: + ret_type = block_type(ret_type, ret_shape) + return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder), is_pure), ret_type) + + +@builtin +def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool, + _builder=None): + ''' + Dispatch an elementwise function to a library + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param is_pure: whether the function is pure + :param _builder: the builder + :return: the return value of the function + ''' + dispatch_args = args.copy() + all_scalar = True + ret_shape = None + arg_types = [] + for i in builtins.range(len(dispatch_args)): + dispatch_args[i] = _to_tensor(dispatch_args[i], _builder) + arg_types.append(dispatch_args[i].dtype) + if dispatch_args[i].type.is_block(): + all_scalar = False + if len(arg_types) > 0: + arg_types = tuple(arg_types) + arithmetic_check = True + # If there's a type tuple that is not supported by the library, we will do arithmetic check + if arg_types in arg_type_symbol_dict: + arithmetic_check = False + broadcast_arg = dispatch_args[0] + # Get the broadcast shape over all the arguments + for item in dispatch_args: + _, broadcast_arg = semantic.binary_op_type_checking_impl(item, broadcast_arg, _builder, + arithmetic_check=arithmetic_check) + # Change the shape of each argument based on the broadcast shape + for i in builtins.range(len(dispatch_args)): + dispatch_args[i], _ = semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg, _builder, + arithmetic_check=arithmetic_check) + if not all_scalar: + ret_shape = broadcast_arg.shape + func = _builder.create_extern_elementwise + return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, is_pure, _builder) + + +def binary_op_type_legalization(lhs, rhs, builder): + ''' + Convert both operands to a single common type + :param lhs: the left operand + :param rhs: the right operand + :param builder: the builder + ''' + return semantic.binary_op_type_checking_impl(lhs, rhs, builder) + + +def extern(fn): + """A decorator for external functions.""" + return builtin(fn) diff --git a/lib/python3.10/site-packages/triton/language/extra/libdevice.py b/lib/python3.10/site-packages/triton/language/extra/libdevice.py new file mode 100644 index 0000000000000000000000000000000000000000..625cf3957e56c0abce80cbaa2f5fa79362079770 --- /dev/null +++ b/lib/python3.10/site-packages/triton/language/extra/libdevice.py @@ -0,0 +1,1213 @@ +from .cuda import libdevice as cuda_libdevice +from .hip import libdevice as hip_libdevice +from triton.language import core +from functools import wraps +from typing import TypeVar + +T = TypeVar('T') + + +def dispatch(fn: T) -> T: + """Dispatch a function to a correct implementation.""" + assert callable(fn) + + @wraps(fn) + def wrapper(*args, **kwargs): + _backend = kwargs["_builder"].options.backend_name + if _backend == 'cuda': + _curr_libdevice_module = cuda_libdevice + elif _backend == 'hip': + _curr_libdevice_module = hip_libdevice + else: + raise RuntimeError('unknown backend') + + try: + _impl = getattr(_curr_libdevice_module, fn.__name__) + except AttributeError: + raise RuntimeError(f'`{_backend}` does not provide support for `{fn.__name__}` extra function') + + return _impl(*args, **kwargs) + + return wrapper + + +@core.extern +@dispatch +def clz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def popc(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def byte_perm(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def mulhi(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def mul24(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def brev(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sad(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def abs(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def floor(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rcp64h(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rsqrt(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ceil(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def trunc(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def exp2(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def saturatef(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fma_rn(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def fma_rz(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def fma_rd(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def fma_ru(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def fast_dividef(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def div_rn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def div_rz(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def div_rd(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def div_ru(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def rcp_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rcp_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rcp_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rcp_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sqrt_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sqrt_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sqrt_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sqrt_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sqrt(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def add_rn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def add_rz(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def add_rd(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def add_ru(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def mul_rn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def mul_rz(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def mul_rd(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def mul_ru(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def double2float_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2float_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2float_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2float_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2int_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2int_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2int_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2int_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2uint_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2uint_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2uint_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2uint_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def int2double_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def uint2double_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2int_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2int_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2int_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2int_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2uint_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2uint_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2uint_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2uint_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def int2float_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def int2float_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def int2float_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def int2float_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def uint2float_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def uint2float_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def uint2float_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def uint2float_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def hiloint2double(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def double2loint(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2hiint(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ll_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ll_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ll_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ll_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ull_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ull_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ull_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ull_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ll_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ll_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ll_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ll_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ull_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ull_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ull_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ull_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2float_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2float_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2float_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2float_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2float_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2float_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2float_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2float_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2double_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2double_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2double_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2double_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2double_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2double_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2double_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2double_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def int_as_float(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float_as_int(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def uint_as_float(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float_as_uint(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def longlong_as_double(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double_as_longlong(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_sinf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_cosf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_log2f(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_logf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_expf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_tanf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_exp10f(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_log10f(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_powf(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def hadd(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def rhadd(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def sub_rn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def sub_rz(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def sub_rd(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def sub_ru(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def rsqrt_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ffs(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rint(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def llrint(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def nearbyint(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def isnan(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def signbit(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def copysign(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def finitef(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def isinf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def nextafter(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def sin(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def cos(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sinpi(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def cospi(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def tan(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def log2(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def exp(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def exp10(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def cosh(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sinh(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def tanh(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def atan2(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def atan(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def asin(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def acos(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def log(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def log10(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def log1p(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def acosh(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def asinh(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def atanh(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def expm1(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def hypot(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def rhypot(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def norm3d(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def rnorm3d(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def norm4d(arg0, arg1, arg2, arg3, _builder=None): + ... + + +@core.extern +@dispatch +def rnorm4d(arg0, arg1, arg2, arg3, _builder=None): + ... + + +@core.extern +@dispatch +def cbrt(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rcbrt(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def j0(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def j1(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def y0(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def y1(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def yn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def jn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def cyl_bessel_i0(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def cyl_bessel_i1(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def erf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def erfinv(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def erfc(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def erfcx(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def erfcinv(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def normcdfinv(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def normcdf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def lgamma(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ldexp(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def scalbn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def fmod(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def remainder(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def fma(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def pow(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def tgamma(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def round(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def llround(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fdim(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def ilogb(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def logb(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def isfinited(arg0, _builder=None): + ... diff --git a/lib/python3.10/site-packages/triton/language/math.py b/lib/python3.10/site-packages/triton/language/math.py new file mode 100644 index 0000000000000000000000000000000000000000..de5b5be6b32364111cbc997873296cf987d8349a --- /dev/null +++ b/lib/python3.10/site-packages/triton/language/math.py @@ -0,0 +1,250 @@ +from . import core +from . import semantic +from functools import wraps +from typing import List + +T = core.TypeVar('T') + + +def _check_dtype(dtypes: List[str]) -> T: + """ + We're following libdevice's convention to check accepted data types for math functions. + It is not a good practice to support all data types as accelerators/GPUs don't support + many float16 and bfloat16 math operations. + We should let the users know that they are using and invoke explicit cast to convert + the data type to the supported one. + """ + + def wrapper(fn): + + @wraps(fn) + def check(*args, **kwargs): + # concatenate args and kwargs + all_args = list(args) + list(kwargs.values()) + for arg in [a for a in all_args if isinstance(a, core.tensor)]: + if arg.type.scalar.name not in dtypes: + raise ValueError(f"Expected dtype {dtypes} but got {arg.type.scalar.name}") + return fn(*args, **kwargs) + + return check + + return wrapper + + +def _add_math_1arg_docstr(name: str) -> core.Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Computes the element-wise {name} of :code:`x`. + + :param x: the input values + :type x: Block + """ + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +def _add_math_2arg_docstr(name: str) -> core.Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Computes the element-wise {name} of :code:`x` and :code:`y`. + + :param x: the input values + :type x: Block + :param y: the input values + :type y: Block + """ + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +def _add_math_3arg_docstr(name: str) -> core.Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Computes the element-wise {name} of :code:`x`, :code:`y`, and :code:`z`. + + :param x: the input values + :type x: Block + :param y: the input values + :type y: Block + :param z: the input values + :type z: Block + """ + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +@core.builtin +@_check_dtype(dtypes=["int32", "int64", "uint32", "uint64"]) +@_add_math_2arg_docstr("most significant N bits of the 2N-bit product") +def umulhi(x, y, _builder=None): + x = core._to_tensor(x, _builder) + y = core._to_tensor(y, _builder) + x, y = core.binary_op_type_legalization(x, y, _builder) + return core.tensor(_builder.create_umulhi(x.handle, y.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("exponential") +@core._tensor_member_fn +def exp(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_exp(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("exponential (base 2)") +@core._tensor_member_fn +def exp2(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_exp2(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("natural logarithm") +@core._tensor_member_fn +def log(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_log(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("logarithm (base 2)") +@core._tensor_member_fn +def log2(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_log2(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("cosine") +@core._tensor_member_fn +def cos(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_cos(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("sine") +@core._tensor_member_fn +def sin(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_sin(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("fast square root") +@core._tensor_member_fn +def sqrt(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_sqrt(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32"]) +@_add_math_1arg_docstr("precise square root (rounding to nearest)") +@core._tensor_member_fn +def sqrt_rn(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_precise_sqrt(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("inverse square root") +@core._tensor_member_fn +def rsqrt(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_rsqrt(x.handle), x.type) + + +@core.builtin +@_add_math_1arg_docstr("absolute value") +@core._tensor_member_fn +def abs(x, _builder=None): + x = core._to_tensor(x, _builder) + dtype = x.dtype + if dtype.is_fp8e4b15(): + mask = core.full(x.shape, 0x7F, core.int8, _builder=_builder) + return core.tensor(_builder.create_and(x.handle, mask.handle), x.type) + elif dtype.is_floating(): + return core.tensor(_builder.create_fabs(x.handle), x.type) + elif dtype.is_int_signed(): + return core.tensor(_builder.create_iabs(x.handle), x.type) + elif dtype.is_int_unsigned(): + return x # no-op + else: + assert False, f"Unexpected dtype {dtype}" + + +@core.builtin +@_add_math_2arg_docstr("fast division") +def fdiv(x, y, ieee_rounding=False, _builder=None): + ieee_rounding = core._constexpr_to_value(ieee_rounding) + x = core._to_tensor(x, _builder) + y = core._to_tensor(y, _builder) + return semantic.fdiv(x, y, ieee_rounding, _builder) + + +@core.builtin +@_check_dtype(dtypes=["fp32"]) +@_add_math_2arg_docstr("precise division (rounding to nearest)") +def div_rn(x, y, _builder=None): + x = core._to_tensor(x, _builder) + y = core._to_tensor(y, _builder) + x, y = core.binary_op_type_legalization(x, y, _builder) + return core.tensor(_builder.create_precise_divf(x.handle, y.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("error function") +@core._tensor_member_fn +def erf(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_erf(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("floor") +@core._tensor_member_fn +def floor(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_floor(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("ceil") +@core._tensor_member_fn +def ceil(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_ceil(x.handle), x.type) + + +@core.builtin +@_add_math_3arg_docstr("fused multiply-add") +def fma(x, y, z, _builder=None): + x = core._to_tensor(x, _builder) + y = core._to_tensor(y, _builder) + z = core._to_tensor(z, _builder) + x, y = core.binary_op_type_legalization(x, y, _builder) + z, x = core.binary_op_type_legalization(z, x, _builder) + z, y = core.binary_op_type_legalization(z, y, _builder) + return core.tensor(_builder.create_fma(x.handle, y.handle, z.handle), x.type) diff --git a/lib/python3.10/site-packages/triton/language/random.py b/lib/python3.10/site-packages/triton/language/random.py new file mode 100644 index 0000000000000000000000000000000000000000..430aeb09e2e8af2688b9b6615a3bcb967c2d87bf --- /dev/null +++ b/lib/python3.10/site-packages/triton/language/random.py @@ -0,0 +1,207 @@ +from ..runtime.jit import jit +from . import core as tl +from . import math + +N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox + +# ------------------- +# randint +# ------------------- + + +@jit +def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1). + """ + if c0.dtype == tl.uint32: + PHILOX_KEY_A: tl.constexpr = 0x9E3779B9 + PHILOX_KEY_B: tl.constexpr = 0xBB67AE85 + PHILOX_ROUND_A: tl.constexpr = 0xD2511F53 + PHILOX_ROUND_B: tl.constexpr = 0xCD9E8D57 + else: + tl.static_assert(c0.dtype == tl.uint64, "dtype not supported in philox_impl") + PHILOX_KEY_A: tl.constexpr = 0x9E3779B97F4A7C15 + PHILOX_KEY_B: tl.constexpr = 0xBB67AE8584CAA73B + PHILOX_ROUND_A: tl.constexpr = 0xD2E7470EE14C6C93 + PHILOX_ROUND_B: tl.constexpr = 0xCA5A826395121157 + + for _ in tl.static_range(n_rounds): + # for _ in range(n_rounds): + # update random state + A = PHILOX_ROUND_A + B = PHILOX_ROUND_B + _c0, _c2 = c0, c2 + c0 = math.umulhi(B, _c2) ^ c1 ^ k0 + c2 = math.umulhi(A, _c0) ^ c3 ^ k1 + c1 = B * _c2 + c3 = A * _c0 + # raise key + k0 = k0 + PHILOX_KEY_A + k1 = k1 + PHILOX_KEY_B + return c0, c1, c2, c3 + + +@jit +def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + seed = tl.to_tensor(seed) + c0 = tl.to_tensor(c0) + c1 = tl.to_tensor(c1) + c2 = tl.to_tensor(c2) + c3 = tl.to_tensor(c3) + seed = seed.to(tl.uint64) + if tl.constexpr(c0.dtype.primitive_bitwidth) == 32: + int_dtype = tl.uint32 + seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32) + seed_lo = (seed & 0xffffffff).to(tl.uint32) + else: + tl.static_assert(tl.constexpr(c0.dtype.primitive_bitwidth) == 64, "bitwidth not supported in philox") + int_dtype = tl.uint64 + seed_hi = tl.full((1, ), 0, dtype=int_dtype) + seed_lo = seed + c0 = c0.to(int_dtype, bitcast=True) + c1 = c1.to(int_dtype, bitcast=True) + c2 = c2.to(int_dtype, bitcast=True) + c3 = c3.to(int_dtype, bitcast=True) + return philox_impl(c0, c1, c2, c3, seed_lo, seed_hi, n_rounds) + + +@jit +def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, returns a single + block of random :code:`int32`. + + If you need multiple streams of random numbers, + using `randint4x` is likely to be faster than calling `randint` 4 times. + + :param seed: The seed for generating random numbers. + :param offset: The offsets to generate random numbers for. + """ + ret, _, _, _ = randint4x(seed, offset, n_rounds) + return ret + + +@jit +def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, returns four + blocks of random :code:`int32`. + + This is the maximally efficient entry point + to Triton's Philox pseudo-random number generator. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + # _0 = tl.zeros(offset.shape, offset.dtype) + _0 = offset * 0 + return philox(seed, offset, _0, _0, _0, n_rounds) + + +# ------------------- +# rand +# ------------------- + +# @jit +# def uint32_to_uniform_float(x): +# """ +# Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1). +# """ +# two_to_the_minus_32: tl.constexpr = 2.328306e-10 +# return x * two_to_the_minus_32 + + +@jit +def uint_to_uniform_float(x): + """ + Numerically stable function to convert a random uint into a random float uniformly sampled in [0, 1). + """ + # TODO: fix frontend issues and cleanup + # conditions can be simplified + # scale is ((2**23 - 1) / 2**23) * 2**(N_BITS - 1) + if tl.constexpr(x.dtype == tl.uint32) or tl.constexpr(x.dtype == tl.int32): + # maximum value such that `MAX_INT * scale < 1.0` (with float rounding) + x = x.to(tl.int32, bitcast=True) + scale = 4.6566127342e-10 + else: + tl.static_assert(tl.constexpr(x.dtype == tl.uint64) or tl.constexpr(x.dtype == tl.int64)) + x = x.to(tl.int64, bitcast=True) + scale = 1.0842020432385337e-19 + x = tl.where(x < 0, -x - 1, x) + return x * scale + + +@jit +def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns a block of random :code:`float32` in :math:`U(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + source = randint(seed, offset, n_rounds) + return uint_to_uniform_float(source) + + +@jit +def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offsets` block, + returns 4 blocks of random :code:`float32` in :math:`U(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + i1, i2, i3, i4 = randint4x(seed, offsets, n_rounds) + u1 = uint_to_uniform_float(i1) + u2 = uint_to_uniform_float(i2) + u3 = uint_to_uniform_float(i3) + u4 = uint_to_uniform_float(i4) + return u1, u2, u3, u4 + + +# ------------------- +# randn +# ------------------- + + +@jit +def pair_uniform_to_normal(u1, u2): + """Box-Muller transform""" + u1 = tl.maximum(1.0e-7, u1) + th = 6.283185307179586 * u2 + r = math.sqrt(-2.0 * math.log(u1)) + return r * math.cos(th), r * math.sin(th) + + +@jit +def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns a block of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + i1, i2, _, _ = randint4x(seed, offset, n_rounds) + u1 = uint_to_uniform_float(i1) + u2 = uint_to_uniform_float(i2) + n1, _ = pair_uniform_to_normal(u1, u2) + return n1 + + +@jit +def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns 4 blocks of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + u1, u2, u3, u4 = rand4x(seed, offset, n_rounds) + n1, n2 = pair_uniform_to_normal(u1, u2) + n3, n4 = pair_uniform_to_normal(u3, u4) + return n1, n2, n3, n4 diff --git a/lib/python3.10/site-packages/triton/language/semantic.py b/lib/python3.10/site-packages/triton/language/semantic.py new file mode 100644 index 0000000000000000000000000000000000000000..83d4dfc8c409031ac6439b1b83d8fc48d947e7c9 --- /dev/null +++ b/lib/python3.10/site-packages/triton/language/semantic.py @@ -0,0 +1,1621 @@ +from __future__ import annotations # remove after python 3.11 + +from typing import List, Optional, Sequence, Tuple, TypeVar + +from .._C.libtriton import ir +from . import core as tl +from . import math + +T = TypeVar('T') + + +class IncompatibleTypeErrorImpl(Exception): + + def __init__(self, type_a, type_b): + self.type_a = type_a + self.type_b = type_b + self.message = "invalid operands of type " + self.type_a.__repr__() + " and " + self.type_b.__repr__() + super(IncompatibleTypeErrorImpl, self).__init__(self.message) + + +# ===----------------------------------------------------------------------===## +# Programming Model +# ===----------------------------------------------------------------------===## + + +def program_id(axis: int, builder: ir.builder) -> tl.tensor: + if axis not in (0, 1, 2): + raise ValueError(f"program_id axis must be 0, 1, or 2 but got {axis}") + return tl.tensor(builder.create_get_program_id(axis), tl.int32) + + +def num_programs(axis: int, builder: ir.builder) -> tl.tensor: + if axis not in (0, 1, 2): + raise ValueError(f"num_programs axis must be 0, 1, or 2 but got {axis}") + return tl.tensor(builder.create_get_num_programs(axis), tl.int32) + + +# ===----------------------------------------------------------------------===// +# Implicit Casting Utilities +# ===----------------------------------------------------------------------===// + + +def integer_promote_impl(a_ty: tl.dtype, b_ty: tl.dtype) -> tl.dtype: + a_rank = a_ty.int_bitwidth + b_rank = b_ty.int_bitwidth + a_sn = a_ty.int_signedness + b_sn = b_ty.int_signedness + # Rules for signedness taken from "Usual arithmetic conversions" on + # https://en.cppreference.com/w/c/language/conversion. + if a_sn == b_sn: + return a_ty if a_rank > b_rank else b_ty + elif a_sn == tl.dtype.SIGNEDNESS.UNSIGNED: + return a_ty if a_rank >= b_rank else b_ty + elif b_sn == tl.dtype.SIGNEDNESS.UNSIGNED: + return b_ty if b_rank >= a_rank else a_ty + raise TypeError(f"unexpected signedness {a_sn} and {b_sn}") + + +def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> tl.dtype: + # 1) if one operand is double, the other is implicitly + # converted to double + if a_ty.is_fp64() or b_ty.is_fp64(): + return tl.float64 + # 2) if one operand is float, the other is implicitly + # converted to float + if a_ty.is_fp32() or b_ty.is_fp32(): + return tl.float32 + # 3 ) if one operand is half, the other is implicitly converted to half + # unless we're doing / or %, which do not exist natively in PTX for fp16. + # Supported PTX op: add, sub, mul, fma, neg, abs, min, max, tanh, ex2, setp + if a_ty.is_fp16() or b_ty.is_fp16(): + if div_or_mod: + return tl.float32 + else: + return tl.float16 + # 4) return bf16 only if both operands are of bf16 + if a_ty.is_bf16() or b_ty.is_bf16(): + if div_or_mod: + return tl.float32 + if a_ty.is_bf16() and b_ty.is_bf16(): + return tl.bfloat16 + return tl.float32 + if not a_ty.is_int() or not b_ty.is_int(): + raise TypeError(f"unexpected type {a_ty} and {b_ty}") + # 5 ) both operands are integer and undergo + # integer promotion + if div_or_mod and a_ty.int_signedness != b_ty.int_signedness: + raise TypeError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + + " because they have different signedness;" + "this is unlikely to result in a useful answer. Cast them to the same signedness.") + return integer_promote_impl(a_ty, b_ty) + + +# ===----------------------------------------------------------------------===// +# Binary Operators +# ===----------------------------------------------------------------------===// + + +def check_ptr_type_impl(type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -> None: + if type_a.is_ptr(): + if not allow_ptr_a: + raise IncompatibleTypeErrorImpl(type_a, type_b) + # T* + U* with T != U + if type_b.is_ptr() and (type_a != type_b): + raise IncompatibleTypeErrorImpl(type_a, type_b) + # T* + float + if type_b.is_floating(): + raise IncompatibleTypeErrorImpl(type_a, type_b) + + +def binary_op_type_checking_impl(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder, allow_lhs_ptr=False, + allow_rhs_ptr=False, arithmetic_check=True, + div_or_mod=False) -> Tuple[tl.tensor, tl.tensor]: + # implicit broadcasting + lhs, rhs = broadcast_impl_value(lhs, rhs, builder) + # implicit typecasting + lhs_sca_ty = lhs.type.scalar + rhs_sca_ty = rhs.type.scalar + check_ptr_type_impl(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr) + check_ptr_type_impl(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr) + if arithmetic_check and not lhs_sca_ty.is_ptr() and not rhs_sca_ty.is_ptr(): + ret_sca_ty = computation_type_impl(lhs_sca_ty, rhs_sca_ty, div_or_mod) + lhs = cast(lhs, ret_sca_ty, builder) + rhs = cast(rhs, ret_sca_ty, builder) + return lhs, rhs + + +def add(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if input_scalar_ty.is_ptr() and other_scalar_ty.is_ptr(): + raise TypeError("cannot add pointers together") + + # offset + ptr + # ptr + offset + if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr(): + input, other = other, input + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if input_scalar_ty.is_ptr(): + return tl.tensor(builder.create_addptr(input.handle, other.handle), input.type) + # float + float + elif input_scalar_ty.is_floating(): + return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type) + # int + int + elif input_scalar_ty.is_int(): + return tl.tensor(builder.create_add(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {input_scalar_ty}") + + +def sub(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, True, False) + scalar_ty = input.type.scalar + # ptr - offset + if scalar_ty.is_ptr(): + return tl.tensor(builder.create_addptr(input.handle, minus(other, builder).handle), input.type) + # float - float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fsub(input.handle, other.handle), input.type) + # int - int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_sub(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {scalar_ty}") + + +def mul(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float * float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fmul(input.handle, other.handle), input.type) + # * int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_mul(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {scalar_ty}") + + +def truediv(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + # float / int + if input_scalar_ty.is_floating() and other_scalar_ty.is_int(): + other = cast(other, input_scalar_ty, builder) + # int / float + elif input_scalar_ty.is_int() and other_scalar_ty.is_floating(): + input = cast(input, other_scalar_ty, builder) + # int / int (cast to tl.float32) + elif input_scalar_ty.is_int() and other_scalar_ty.is_int(): + input = cast(input, tl.float32, builder) + other = cast(other, tl.float32, builder) + # float / float (cast to the highest exponent type) + elif input_scalar_ty.is_floating() and other_scalar_ty.is_floating(): + if input_scalar_ty.fp_mantissa_width > other_scalar_ty.fp_mantissa_width: + other = cast(other, input_scalar_ty, builder) + else: + input = cast(input, other_scalar_ty, builder) + # unreachable + else: + raise TypeError(f"unexpected type {input_scalar_ty}") + return tl.tensor(builder.create_fdiv(input.handle, other.handle), input.type) + + +def floordiv(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if input_scalar_ty.is_int() and other_scalar_ty.is_int(): + ret_ty = integer_promote_impl(input_scalar_ty, other_scalar_ty) + input = cast(input, ret_ty, builder) + other = cast(other, ret_ty, builder) + if ret_ty.is_int_signed(): + return tl.tensor(builder.create_sdiv(input.handle, other.handle), input.type) + else: + return tl.tensor(builder.create_udiv(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {input_scalar_ty}") + + +def fdiv(input: tl.tensor, other: tl.tensor, ieee_rounding: bool, builder: ir.builder) -> tl.tensor: + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating(): + raise TypeError("both operands of fdiv must have floating scalar type") + input, other = binary_op_type_checking_impl(input, other, builder, False, False, False, True) + ret = builder.create_fdiv(input.handle, other.handle) + return tl.tensor(ret, input.type) + + +def mod(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + # float % float + if scalar_ty.is_floating(): + # input - input.div(other, rounding_mode="floor") * other + ret = sub(input, mul(math.floor(fdiv(input, other, False, builder), _builder=builder), other, builder), builder) + return ret + # % int + elif scalar_ty.is_int(): + if scalar_ty.int_signedness != other_scalar_ty.int_signedness: + raise TypeError("Cannot mod " + scalar_ty.__repr__() + " by " + other_scalar_ty.__repr__() + " " + "because they have different signedness;" + "this is unlikely to result in a useful answer. Cast them to the same signedness.") + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_srem(input.handle, other.handle), input.type) + else: + return tl.tensor(builder.create_urem(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {scalar_ty}") + + +############## +# other arithmetic ops +############## + + +def minimum(x: tl.tensor, y: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder): + x, y = binary_op_type_checking_impl(x, y, builder) + dtype = x.dtype + if dtype.is_floating(): + if propagate_nan == tl.PropagateNan.ALL: + return tl.tensor(builder.create_minimumf(x.handle, y.handle), x.type) + elif propagate_nan == tl.PropagateNan.NONE: + return tl.tensor(builder.create_minnumf(x.handle, y.handle), x.type) + else: + raise ValueError(f"Unexpected propagate_nan {propagate_nan}") + elif dtype.is_int_signed(): + return tl.tensor(builder.create_minsi(x.handle, y.handle), x.type) + elif dtype.is_int_unsigned(): + return tl.tensor(builder.create_minui(x.handle, y.handle), x.type) + else: + raise TypeError(f"Unexpected dtype {dtype}") + + +def maximum(x: tl.tensor, y: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder): + x, y = binary_op_type_checking_impl(x, y, builder) + dtype = x.dtype + if dtype.is_floating(): + if propagate_nan == tl.PropagateNan.ALL: + return tl.tensor(builder.create_maximumf(x.handle, y.handle), x.type) + elif propagate_nan == tl.PropagateNan.NONE: + return tl.tensor(builder.create_maxnumf(x.handle, y.handle), x.type) + else: + raise ValueError(f"Unexpected propagate_nan {propagate_nan}") + elif dtype.is_int_signed(): + return tl.tensor(builder.create_maxsi(x.handle, y.handle), x.type) + elif dtype.is_int_unsigned(): + return tl.tensor(builder.create_maxui(x.handle, y.handle), x.type) + else: + raise TypeError(f"Unexpected dtype {dtype}") + + +def clamp(x: tl.tensor, min: tl.tensor, max: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder): + min, max = binary_op_type_checking_impl(min, max, builder) + x, min = binary_op_type_checking_impl(x, min, builder) + x, max = binary_op_type_checking_impl(x, max, builder) + + dtype = x.dtype + if dtype.is_floating(): + return tl.tensor(builder.create_clampf(x.handle, min.handle, max.handle, propagate_nan), x.type) + else: + raise TypeError(f"Unexpected dtype {dtype}. Only floating point clamp is supported") + + +############## +# bitwise ops +############## + + +def bitwise_op_type_checking_impl(input: tl.tensor, other: tl.tensor, + builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, False) + input_sca_ty = input.type.scalar + other_sca_ty = other.type.scalar + if not input_sca_ty.is_int() or not other_sca_ty.is_int(): + raise IncompatibleTypeErrorImpl(input_sca_ty, other_sca_ty) + ret_sca_ty = integer_promote_impl(input_sca_ty, other_sca_ty) + if ret_sca_ty != input_sca_ty: + input = cast(input, ret_sca_ty, builder) + if ret_sca_ty != other_sca_ty: + other = cast(other, ret_sca_ty, builder) + return input, other + + +def and_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_and(input.handle, other.handle), input.type) + + +def or_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_or(input.handle, other.handle), input.type) + + +def xor_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_xor(input.handle, other.handle), input.type) + + +def logical_and(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + if not input.type.is_int1(): + input = bitcast(input, tl.dtype("int1"), builder) + if not other.type.is_int1(): + other = bitcast(other, tl.dtype("int1"), builder) + return and_(input, other, builder) + + +def logical_or(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + if not input.type.is_int1(): + input = bitcast(input, tl.dtype("int1"), builder) + if not other.type.is_int1(): + other = bitcast(other, tl.dtype("int1"), builder) + return or_(input, other, builder) + + +def not_(input: tl.tensor, builder: ir.builder): + if not input.type.is_int1(): + input = bitcast(input, tl.dtype("int1"), builder) + return invert(input, builder) + + +def lshr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_lshr(input.handle, other.handle), input.type) + + +def ashr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_ashr(input.handle, other.handle), input.type) + + +def shl(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_shl(input.handle, other.handle), input.type) + + +# ===----------------------------------------------------------------------===// +# Unary Operators +# ===----------------------------------------------------------------------===// + + +def plus(input: tl.tensor) -> tl.tensor: + return input + + +def minus(input: tl.tensor, builder: ir.builder) -> tl.tensor: + input_sca_ty = input.type.scalar + if input_sca_ty.is_ptr(): + raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")") + _0 = tl.tensor(builder.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty) + return sub(_0, input, builder) + + +def invert(input: tl.tensor, builder: tl.tensor) -> tl.tensor: + input_sca_ty = input.type.scalar + if input_sca_ty.is_ptr() or input_sca_ty.is_floating(): + raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")") + _1 = tl.tensor(builder.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty) + return xor_(input, _1, builder) + + +# ===----------------------------------------------------------------------===// +# Comparison Operators +# ===----------------------------------------------------------------------===// +def _bool_like(v: tl.tensor) -> tl.block_type: + if not v.type.is_block(): + return tl.int1 + shape = v.type.shape + return tl.block_type(tl.int1, shape) + + +def greater_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float > float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOGT(input.handle, other.handle), _bool_like(input)) + # > int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSGT(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpUGT(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def greater_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float >= float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOGE(input.handle, other.handle), _bool_like(input)) + # >= int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSGE(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpUGE(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def less_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float < float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOLT(input.handle, other.handle), _bool_like(input)) + # < int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSLT(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpULT(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def less_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float < float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOLE(input.handle, other.handle), _bool_like(input)) + # < int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSLE(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpULE(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float == float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOEQ(input.handle, other.handle), _bool_like(input)) + # == int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_icmpEQ(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def not_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float == float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpUNE(input.handle, other.handle), _bool_like(input)) + # == int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_icmpNE(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +# ===----------------------------------------------------------------------===// +# Block Creation +# ===----------------------------------------------------------------------===// + + +def arange(start: int, end: int, builder: ir.builder) -> tl.tensor: + if not isinstance(start, int) or not isinstance(end, int): + raise ValueError("arange's arguments must be of type tl.constexpr") + is_start_int64 = bool(start >> 32) + is_end_int64 = bool(end >> 32) + if is_start_int64 or is_end_int64: + raise ValueError("arange must fit in int32") + if end <= start: + raise ValueError("arange's end argument must be greater than the start argument") + range = end - start + if (range & (range - 1)) != 0: + raise ValueError("arange's range must be a power of 2") + shape = [range] + ret_ty = tl.block_type(tl.int32, shape) + return tl.tensor(builder.create_make_range(start, end), ret_ty) + + +def full(shape: List[int], value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + if isinstance(value, tl.tensor): + assert value.numel.value == 1, "only accepts size-1 tensor" + value = cast(value, dtype, builder) + else: + # scalar + if dtype is None: + raise ValueError("dtype must be specified when value is not a tensor") + if value == 0: + value = builder.get_null_value(dtype.to_ir(builder)) + else: + get_value_fn = getattr(builder, f"get_{dtype.name}") + value = get_value_fn(value) + value = tl.tensor(value, dtype) + + return splat(value, shape, builder) + + +# ===----------------------------------------------------------------------===// +# Shape Manipulation +# ===----------------------------------------------------------------------===// + + +def splat(value: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor: + assert not value.type.is_block(), "Cannot splat a block tensor" + if len(shape) == 0: + return value + ret_ty = tl.block_type(value.dtype, shape) + return tl.tensor(builder.create_splat(value.handle, shape), ret_ty) + + +def reshape(input: tl.tensor, dst_shape: List[int], can_reorder: bool, builder: ir.builder) -> tl.tensor: + numel = 1 + for s in dst_shape: + numel *= s + if input.type.numel != numel: + raise ValueError("reshape() cannot change total number of elements in tensor") + ret_ty = tl.block_type(input.type.scalar, dst_shape) + return tl.tensor(builder.create_reshape(input.handle, dst_shape, can_reorder), ret_ty) + + +def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + dst_shape = [tl._constexpr_to_value(x) for x in input.shape] + dst_shape.insert(axis, 1) + + if not input.type.is_block(): + return splat(input, shape=dst_shape, builder=builder) + + ret_ty = tl.block_type(input.type.scalar, dst_shape) + return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty) + + +def cat(lhs: tl.tensor, rhs: tl.tensor, can_reorder: bool, builder: ir.builder) -> tl.tensor: + assert can_reorder, "current implementation of `cat` always may reorder elements" + assert len(lhs.shape) == 1 + ret_type = tl.block_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]]) + return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), ret_type) + + +def join(a: tl.tensor, b: tl.tensor, builder: ir.builder) -> tl.tensor: + a, b = broadcast_impl_value(a, b, builder) + + # The IR can't handle joining two scalars, so upcast them to 1D tensors, + # then downcast the result. + was_rank_1 = a.shape == [] + if was_rank_1: + a = expand_dims(a, 0, builder) + b = expand_dims(b, 0, builder) + + if isinstance(a.shape[-1], tl.constexpr): + two = tl.constexpr(2) + else: + two = 2 + new_shape = a.shape + [two] + + ret_type = tl.block_type(a.type.scalar, new_shape) + ret = tl.tensor(builder.create_join(a.handle, b.handle), ret_type) + + if was_rank_1: + ret = reshape(ret, [2], can_reorder=False, builder=builder) + + return ret + + +def split(a: tl.tensor, builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]: + assert (len(a.shape) > 0) + assert (tl._constexpr_to_value(a.shape[-1]) == 2) + + new_shape = a.shape[:-1] + ret_type = tl.block_type(a.type.scalar, new_shape) + outLHS, outRHS = builder.create_split(a.handle) + return ( + tl.tensor(outLHS, ret_type), + tl.tensor(outRHS, ret_type), + ) + + +def permute(input: tl.tensor, dims: Tuple[int], builder: ir.builder) -> tl.tensor: + if len(input.shape) != len(dims): + raise ValueError("permute dims must have the same length as input shape") + if sorted(tl._constexpr_to_value(d) for d in dims) != list(range(len(dims))): + raise ValueError(f"permute dims must be a permutation of 0, 1, ..., n-1, but were {dims}") + + ret_type = tl.block_type(input.type.scalar, [input.shape[d] for d in dims]) + return tl.tensor(builder.create_trans(input.handle, dims), ret_type) + + +def broadcast_impl_shape(input: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor: + if not input.type.is_block(): + ret_ty = tl.block_type(input.type, shape) + return tl.tensor(builder.create_splat(input.handle, shape), ret_ty) + src_shape = input.type.get_block_shapes() + if len(src_shape) != len(shape): + raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}") + if shape == src_shape: + return input + for i, item in enumerate(src_shape): + if shape[i] != item and item != 1: + raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})" + f" must match the existing size ({item}) at non-singleton dimension" + f" {i}: {src_shape}, {shape}") + ret_ty = tl.block_type(input.type.scalar, shape) + return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty) + + +def broadcast_impl_value(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor: + lhs_ty = lhs.type + rhs_ty = rhs.type + + # make_shape_compatible(block, scalar) + if lhs_ty.is_block() and not rhs_ty.is_block(): + rhs_ty = tl.block_type(rhs_ty.scalar, lhs_ty.shape) + rhs = tl.tensor(builder.create_splat(rhs.handle, lhs_ty.get_block_shapes()), rhs_ty) + # make_shape_compatible(scalar, block) + elif not lhs_ty.is_block() and rhs_ty.is_block(): + lhs_ty = tl.block_type(lhs_ty.scalar, rhs_ty.shape) + lhs = tl.tensor(builder.create_splat(lhs.handle, rhs_ty.get_block_shapes()), lhs_ty) + # make_shape_compatible(block, block) + elif lhs_ty.is_block() and rhs_ty.is_block(): + lhs_shape = lhs_ty.get_block_shapes() + rhs_shape = rhs_ty.get_block_shapes() + + if len(lhs_shape) < len(rhs_shape): + # Add new axes to lhs + for _ in range(len(lhs_shape), len(rhs_shape)): + lhs = tl.tensor(builder.create_expand_dims(lhs.handle, 0), + tl.block_type(lhs_ty.scalar, [1] + lhs_shape)) + lhs_ty = lhs.type + lhs_shape = lhs_ty.get_block_shapes() + elif len(rhs_shape) < len(lhs_shape): + # Add new axes to rhs + for _ in range(len(rhs_shape), len(lhs_shape)): + rhs = tl.tensor(builder.create_expand_dims(rhs.handle, 0), + tl.block_type(rhs_ty.scalar, [1] + rhs_shape)) + rhs_ty = rhs.type + rhs_shape = rhs_ty.get_block_shapes() + assert len(rhs_shape) == len(lhs_shape) + + ret_shape = [] + for i, left in enumerate(lhs_shape): + right = rhs_shape[i] + if left == 1: + ret_shape.append(right) + elif (right == 1) or (right == left): + ret_shape.append(left) + else: + raise ValueError("Cannot make_shape_compatible: incompatible dimensions " + "at index " + str(i) + ": " + str(left) + " and " + str(right)) + if lhs_shape != ret_shape: + ret_ty = tl.block_type(lhs_ty.scalar, ret_shape) + lhs = tl.tensor(builder.create_broadcast(lhs.handle, ret_shape), ret_ty) + if rhs_shape != ret_shape: + ret_ty = tl.block_type(rhs_ty.scalar, ret_shape) + rhs = tl.tensor(builder.create_broadcast(rhs.handle, ret_shape), ret_ty) + # (scalar, scalar) => returns original blocks + return lhs, rhs + + +####### +# cast +####### + + +def _str_to_rounding_mode(rounding_mode: Optional[str]): + if rounding_mode is None: + return None + if rounding_mode == 'rtne': + return ir.ROUNDING_MODE.RTNE + if rounding_mode == 'rtz': + return ir.ROUNDING_MODE.RTZ + raise ValueError(f"Invalid rounding mode: {rounding_mode}. Supported rounding modes are 'rtne' and 'rtz'.") + + +def bitcast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor: + src_ty = input.type + if src_ty.is_block(): + dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes()) + if src_ty == dst_ty: + return input + src_sca_ty = src_ty.scalar + dst_sca_ty = dst_ty.scalar + if src_sca_ty.is_ptr() or dst_sca_ty.is_ptr(): + return cast(input, dst_ty, builder) + # Bitcast + src_bits = src_sca_ty.primitive_bitwidth + dst_bits = dst_sca_ty.primitive_bitwidth + if src_bits != dst_bits: + raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + " to " + "data-type of size " + str(dst_bits)) + return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) + + +def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder, + fp_downcast_rounding: Optional[str] = None) -> tl.tensor: + src_ty = input.type + if isinstance(dst_ty, tl.constexpr): + dst_ty = dst_ty.value + if isinstance(fp_downcast_rounding, tl.constexpr): + fp_downcast_rounding = fp_downcast_rounding.value + if src_ty.is_block(): + dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes()) + if src_ty == dst_ty: + return input + + src_sca_ty = src_ty.scalar + dst_sca_ty = dst_ty.scalar + + # For fp downcasting default rounding mode should be RTNE, for all other conversions it should + # not be set + fp_downcast_rounding = _str_to_rounding_mode(fp_downcast_rounding) + use_custom_rounding = False + if dst_sca_ty.is_floating() and src_sca_ty.is_floating( + ) and dst_sca_ty.primitive_bitwidth < src_sca_ty.primitive_bitwidth: + if fp_downcast_rounding is None: fp_downcast_rounding = ir.ROUNDING_MODE.RTNE + elif fp_downcast_rounding != ir.ROUNDING_MODE.RTNE: use_custom_rounding = True + else: + if fp_downcast_rounding is not None: + raise ValueError("fp_downcast_rounding should be set only for truncating fp conversions. " + "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) + + if (src_sca_ty.is_fp8e4nv() or dst_sca_ty.is_fp8e4nv()): + assert builder.options.allow_fp8e4nv, "fp8e4nv data type is not supported on CUDA arch < 89" + + if (src_sca_ty.is_fp8e4b15() or dst_sca_ty.is_fp8e4b15()): + assert builder.codegen_fns.get( + "convert_custom_types") is not None, "target doesn't provide conversion for this type." + return builder.codegen_fns["convert_custom_types"](input, dst_ty, fp_downcast_rounding, _builder=builder) + # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64 + # and non-default rounding modes for downcasting + if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \ + (src_sca_ty.is_floating() and dst_sca_ty.is_fp8()) or \ + use_custom_rounding: + return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder), fp_downcast_rounding), dst_ty) + + # bf16 <=> (not fp32) + if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \ + (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()): + return cast(cast(input, tl.float32, builder), dst_sca_ty, builder) + + # Standard floating types' casting: truncation + # fp64 => fp32, fp16, bf16 + # fp32 => fp16, bf16 + truncate_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth + if truncate_fp: + return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Standard floating types' casting: extension + # fp32 => fp64 + # fp16 => fp32, fp64 + # bf16 => fp32, fp64 + ext_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth + if ext_fp: + return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting between integer types + if src_sca_ty.is_int() and dst_sca_ty.is_int() and \ + (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness): + sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool() + if dst_sca_ty.is_bool(): + ty = input.dtype.to_ir(builder) + _0 = tl.tensor(builder.get_null_value(ty), input.dtype) + return not_equal(input, _0, builder) + else: + return tl.tensor(builder.create_int_cast(input.handle, dst_ty.to_ir(builder), sign_extend), dst_ty) + + # Casting standard floating types to integer types + if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int(): + if dst_sca_ty.is_bool(): + ty = input.dtype.to_ir(builder) + _0 = tl.tensor(builder.get_null_value(ty), input.dtype) + return not_equal(input, _0, builder) + elif dst_sca_ty.is_int_signed(): + return tl.tensor(builder.create_fp_to_si(input.handle, dst_ty.to_ir(builder)), dst_ty) + else: + return tl.tensor(builder.create_fp_to_ui(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting integer types to standard floating types + if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating(): + if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed(): + return tl.tensor(builder.create_ui_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) + else: + return tl.tensor(builder.create_si_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting pointer types to integer types + if src_sca_ty.is_ptr() and dst_sca_ty.is_int(): + bitwidth = dst_sca_ty.int_bitwidth + if bitwidth == 64: + return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)), dst_ty) + if bitwidth == 1: + return not_equal(cast(input, tl.int64, builder), tl.tensor(builder.get_int64(0), tl.int64), builder) + + # Casting integer types to pointer types + if src_sca_ty.is_int() and dst_sca_ty.is_ptr(): + return tl.tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting pointer types to pointer types + if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr(): + return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) + + assert False, f'cannot cast {input} to {dst_ty}' + + +# ===----------------------------------------------------------------------===// +# Memory Operators +# ===----------------------------------------------------------------------===// + + +def _str_to_load_cache_modifier(cache_modifier): + cache = ir.CACHE_MODIFIER.NONE # default + if cache_modifier: + if cache_modifier == ".ca": + cache = ir.CACHE_MODIFIER.CA + elif cache_modifier == ".cg": + cache = ir.CACHE_MODIFIER.CG + else: + raise ValueError(f"Cache modifier {cache_modifier} not supported") + return cache + + +def _str_to_store_cache_modifier(cache_modifier): + cache = ir.CACHE_MODIFIER.NONE # default + if cache_modifier: + if cache_modifier == ".wb": + cache = ir.CACHE_MODIFIER.WB + elif cache_modifier == ".cg": + cache = ir.CACHE_MODIFIER.CG + elif cache_modifier == ".cs": + cache = ir.CACHE_MODIFIER.CS + elif cache_modifier == ".wt": + cache = ir.CACHE_MODIFIER.WT + else: + raise ValueError(f"Cache modifier {cache_modifier} not supported") + return cache + + +def _str_to_eviction_policy(eviction_policy): + eviction = ir.EVICTION_POLICY.NORMAL # default + if eviction_policy: + if eviction_policy == "evict_last": + eviction = ir.EVICTION_POLICY.EVICT_LAST + elif eviction_policy == "evict_first": + eviction = ir.EVICTION_POLICY.EVICT_FIRST + else: + raise ValueError(f"Eviction policy {eviction_policy} not supported") + return eviction + + +def _str_to_padding_option(padding_option): + padding = None # default + if padding_option: + if padding_option == "zero": + padding = ir.PADDING_OPTION.PAD_ZERO + elif padding_option == "nan": + padding = ir.PADDING_OPTION.PAD_NAN + else: + raise ValueError(f"Padding option {padding_option} not supported") + return padding + + +def _str_to_sem(sem_option): + sem = ir.MEM_SEMANTIC.ACQUIRE_RELEASE + if sem_option: + if sem_option == "acquire": + sem = ir.MEM_SEMANTIC.ACQUIRE + elif sem_option == "release": + sem = ir.MEM_SEMANTIC.RELEASE + elif sem_option == "acq_rel": + sem = ir.MEM_SEMANTIC.ACQUIRE_RELEASE + elif sem_option == "relaxed": + sem = ir.MEM_SEMANTIC.RELAXED + else: + raise ValueError(f"Memory semantic {sem_option} not supported") + return sem + + +def _str_to_scope(scope_option): + scope = ir.MEM_SYNC_SCOPE.GPU + if scope_option: + if scope_option == "gpu": + scope = ir.MEM_SYNC_SCOPE.GPU + elif scope_option == "cta": + scope = ir.MEM_SYNC_SCOPE.CTA + elif scope_option == "sys": + scope = ir.MEM_SYNC_SCOPE.SYSTEM + else: + raise ValueError(f"Memory semantic {scope_option} not supported") + return scope + + +def _canonicalize_boundary_check(boundary_check, block_shape): + if boundary_check: + if not hasattr(boundary_check, "__iter__"): + boundary_check = [boundary_check] + boundary_check = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in boundary_check] + for dim in boundary_check: + assert isinstance(dim, int) and 0 <= dim < len(block_shape) + assert len(boundary_check) > 0 + assert len(boundary_check) == len(set(boundary_check)), "Duplicate dimension in `boundary_check`" + return sorted(boundary_check) + return () + + +def _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder): + # Load by a block pointer: `pointer_type>` + # Block pointer can not have `mask` and `other` arguments + if mask is not None or other is not None: + raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers") + + elt_ty = ptr.type.element_ty.element_ty + assert elt_ty != tl.int1, "`tl.int1` should be rewrited in `tl.make_block_ptr`" + if elt_ty.is_int() and padding == ir.PADDING_OPTION.PAD_NAN: + raise ValueError("Padding option `nan` is not supported for integer block pointers") + + # `dst_ty` is de-referenced type of the pointer type + dst_ty = ptr.type.element_ty + + # Check `boundary_check` argument + boundary_check = _canonicalize_boundary_check(boundary_check, dst_ty.get_block_shapes()) + + # Build IR + return tl.tensor( + builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction, is_volatile), dst_ty) + + +def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder): + # Load by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + if not ptr.type.scalar.is_ptr(): + raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.load`") + + # Check `mask`, `other`, `boundary_check`, and `padding` arguments + if mask is None and other is not None: + raise ValueError("`other` cannot be provided without `mask`") + if padding or boundary_check: + raise ValueError("`padding_option` or `boundary_check` argument is not supported for loading a tensor of" + "pointers or loading a scalar. Because the compiler does not know the boundary; please " + "use block pointers (defined by `make_block_ptr`) instead") + + # For a pointer of scalar, check the type of `mask` and `other` + if not ptr.type.is_block(): + if mask and mask.type.is_block(): + raise ValueError("Mask argument cannot be block type if pointer argument is not a block") + if other and other.type.is_block(): + raise ValueError("Other argument cannot be block type if pointer argument is not a block") + + # Make `mask` and `other` into the same shape as `ptr` + if ptr.type.is_block(): + if mask is not None: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + if other is not None: + other = broadcast_impl_shape(other, ptr.type.get_block_shapes(), builder) + + # Get `pointer_type` and `elt_ty` + ptr_ty = ptr.type.scalar + elt_ty = ptr_ty.element_ty + + # Treat `pointer_type` as `pointer_type` + if elt_ty == tl.int1: + elt_ty = tl.int8 + ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) + ptr = cast(ptr, ptr_ty, builder) + + # Cast `other` into `ele_ty` type + if other is not None: + other = cast(other, elt_ty, builder) + + # Create loaded result type `dst_ty` + if ptr.type.is_block(): + shape = ptr.type.get_block_shapes() + dst_ty = tl.block_type(elt_ty, shape) + else: + # Load by de-referencing the pointer of scalar + dst_ty = elt_ty + + # Build IR + if mask is None: + return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty) + else: + return tl.tensor( + builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, eviction, + is_volatile), dst_ty) + + +def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor], boundary_check: Tuple, + padding_option: str, cache_modifier: str, eviction_policy: str, is_volatile: bool, + builder: ir.builder) -> tl.tensor: + # Cache, eviction and padding options + cache = _str_to_load_cache_modifier(cache_modifier) + eviction = _str_to_eviction_policy(eviction_policy) + padding = _str_to_padding_option(padding_option) + + if ptr.type.is_ptr() and ptr.type.element_ty.is_block(): + # Load by a block pointer: `pointer_type>` + return _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder) + else: + # Load by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + return _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder) + + +def descriptor_load(desc_ptr: tl.tensor, offsets, cache_modifier: str, eviction_policy: str, type, + builder: ir.builder) -> tl.tensor: + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + x = builder.create_descriptor_load(desc_ptr.handle, offsets, type.to_ir(builder), + _str_to_load_cache_modifier(cache_modifier), + _str_to_eviction_policy(eviction_policy)) + return tl.tensor(x, type) + + +def descriptor_store(desc_ptr: tl.tensor, value: tl.tensor, offsets, builder: ir.builder) -> tl.tensor: + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + return tl.tensor(builder.create_descriptor_store(desc_ptr.handle, value.handle, offsets), tl.void) + + +def _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder): + # Store by a block pointer: `pointer_type>` + # Block pointers can not have the `mask` argument + if mask is not None: + raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers") + + # Check same shape and element type + block_shape = ptr.type.element_ty.get_block_shapes() + if not val.type.is_block(): + val = broadcast_impl_shape(val, block_shape, builder) + assert val.type.is_block(), "Value argument must be block type or a scalar" + assert block_shape == val.type.get_block_shapes( + ), f"Block shape({block_shape}) and value shape({val.type.get_block_shapes()}) mismatch" + assert ptr.type.element_ty.element_ty == val.type.element_ty, f"Block element type({ptr.type.element_ty.element_ty}) and value element type({val.type.element_ty}) mismatch" + + elt_ty = ptr.type.element_ty.element_ty + assert elt_ty != tl.int1, "`tl.int1` should be rewrited in `tl.make_block_ptr`" + + # Check `boundary_check` argument + boundary_check = _canonicalize_boundary_check(boundary_check, block_shape) + + # Cast to target data type + val = cast(val, elt_ty, builder) + + # Build IR + return tl.tensor(builder.create_tensor_pointer_store(ptr.handle, val.handle, boundary_check, cache, eviction), + tl.void) + + +def _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder): + # Store by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + if not ptr.type.scalar.is_ptr(): + raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.store`") + + # Check `boundary_check` argument + if boundary_check: + raise ValueError("`boundary_check` argument is not supported for storing a tensor of pointers or storing a " + "scalar. Because the compiler does not know the boundary; please use block pointers " + "(defined by `make_block_ptr`) instead") + + # For a pointer of scalar, check the type of `val` and `mask` + if not ptr.type.is_block(): + if val.type.is_block(): + raise ValueError("Value argument cannot be block type if pointer argument is not a block") + if mask and mask.type.is_block(): + raise ValueError("Mask argument cannot be block type if pointer argument is not a block") + + # Make `mask` and `val` into the same shape as `ptr` + if ptr.type.is_block(): + val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder) + if mask is not None: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + + ptr_ty = ptr.type.scalar + elt_ty = ptr_ty.element_ty + + # Treat `pointer_type` as `pointer_type` + if elt_ty == tl.int1: + elt_ty = tl.int8 + ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) + ptr = cast(ptr, ptr_ty, builder) + + # Cast to target data type + val = cast(val, elt_ty, builder) + + # Build IR + if not mask: + return tl.tensor(builder.create_store(ptr.handle, val.handle, cache, eviction), tl.void) + if not mask.type.scalar.is_bool(): + raise ValueError("Mask must have boolean scalar type") + return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle, cache, eviction), tl.void) + + +def store(ptr: tl.tensor, val: tl.tensor, mask: Optional[tl.tensor], boundary_check, cache_modifier: str, + eviction_policy: str, builder: ir.builder) -> tl.tensor: + # Cache and eviction options + cache = _str_to_store_cache_modifier(cache_modifier) + eviction = _str_to_eviction_policy(eviction_policy) + + if ptr.type.is_const() or ptr.type.scalar.is_const(): + raise ValueError("Cannot store to a constant pointer") + + if ptr.type.is_ptr() and ptr.type.element_ty.is_block(): + # Store by a block pointer: `pointer_type>` + return _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder) + else: + # Store by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + return _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder) + + +######### +# atomic +######### + + +def atomic_cas(ptr: tl.tensor, cmp: tl.tensor, val: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + element_ty = ptr.type.scalar.element_ty + if element_ty.primitive_bitwidth not in [16, 32, 64]: + raise ValueError("atomic_cas only supports elements with width {16, 32, 64}") + return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem, scope), val.type) + + +def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, op: str, + builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]: + if not ptr.type.scalar.is_ptr(): + raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__()) + if ptr.type.is_const() or ptr.type.element_ty.is_const(): + raise ValueError("Cannot store to a constant pointer") + element_ty = ptr.type.scalar.element_ty + if element_ty is tl.float16 and op != 'add': + raise ValueError("atomic_" + op + " does not support fp16") + if element_ty in [tl.int1, tl.int8, tl.int16, tl.bfloat16]: + raise ValueError("atomic_" + op + " does not support " + str(element_ty)) + if ptr.type.is_block(): + if mask is not None: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + if val is not None: + val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder) + val = cast(val, ptr.type.scalar.element_ty, builder) + if not mask: + mask_ir = builder.get_int1(True) + mask_ty = tl.int1 + if ptr.type.is_block(): + mask_ir = builder.create_splat(mask_ir, ptr.type.get_block_shapes()) + mask_ty = tl.block_type(tl.int1, ptr.type.get_block_shapes()) + mask = tl.tensor(mask_ir, mask_ty) + return ptr, val, mask + + +def atomic_max(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + sca_ty = val.type.scalar + # direct call to atomic_max for integers + if sca_ty.is_int(): + if sca_ty.is_int_signed(): + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + else: + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + # for float + # return atomic_smax(i_ptr, i_val) if val >= 0 + # return atomic_umin(i_ptr, i_val) if val < 0 + if sca_ty not in {tl.float32, tl.float64}: + raise TypeError(f"atomic_max not supported for dtype {sca_ty}") + + zero = full([], 0.0, sca_ty, builder) + + i_type = tl.int32 if sca_ty == tl.float32 else tl.int64 + i_val = bitcast(val, i_type, builder) + i_ptr = bitcast(ptr, tl.pointer_type(i_type, 1), builder) + ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64 + ui_val = bitcast(val, ui_type, builder) + ui_ptr = bitcast(ptr, tl.pointer_type(ui_type, 1), builder) + pos = greater_equal(val, zero, builder) + neg = less_than(val, zero, builder) + pos_ret = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, + and_(mask, pos, builder).handle, sem, scope), i_val.type) + neg_ret = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ui_ptr.handle, ui_val.handle, + and_(mask, neg, builder).handle, sem, scope), ui_val.type) + ret = where(pos, pos_ret, neg_ret, builder) + return bitcast(ret, sca_ty, builder) + + +def atomic_min(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + sca_ty = val.type.scalar + # direct call to atomic_min for integers + if sca_ty.is_int(): + if sca_ty.is_int_signed(): + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + else: + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + # for float + # return atomic_smin(i_ptr, i_val) if val >= 0 + # return atomic_umax(i_ptr, i_val) if val < 0 + if sca_ty not in {tl.float32, tl.float64}: + raise TypeError(f"atomic_min not supported for dtype {sca_ty}") + + zero = full([], 0.0, sca_ty, builder) + + i_type = tl.int32 if sca_ty == tl.float32 else tl.int64 + i_val = bitcast(val, i_type, builder) + i_ptr = bitcast(ptr, tl.pointer_type(i_type, 1), builder) + ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64 + ui_val = bitcast(val, ui_type, builder) + ui_ptr = bitcast(ptr, tl.pointer_type(ui_type, 1), builder) + pos = greater_equal(val, zero, builder) + neg = less_than(val, zero, builder) + pos_ret = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, i_ptr.handle, i_val.handle, + and_(mask, pos, builder).handle, sem, scope), i_val.type) + neg_ret = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ui_ptr.handle, ui_val.handle, + and_(mask, neg, builder).handle, sem, scope), ui_ptr.type) + ret = where(pos, pos_ret, neg_ret, builder) + return bitcast(ret, sca_ty, builder) + + +def atomic_add(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'add', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + sca_ty = val.type.scalar + op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD + return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + + +def atomic_and(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'and', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + +def atomic_or(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'or', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + +def atomic_xor(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xor', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + +def atomic_xchg(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + +# ===----------------------------------------------------------------------===// +# Linear Algebra +# ===----------------------------------------------------------------------===// + + +def _str_to_dot_input_precision(input_precision, builder): + assert input_precision.lower() in builder.options.allowed_dot_input_precisions, \ + f"input_precision must be one of {builder.options.allowed_dot_input_precisions}. Got {input_precision}" + input_precision = input_precision.upper() + if input_precision == "TF32X3": + input_precision = "TF32x3" + return getattr(ir.INPUT_PRECISION, input_precision) + + +def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optional[str], max_num_imprecise_acc: int, + out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + + def assert_dtypes_valid(lhs_dtype, rhs_dtype, options): + if not options.allow_fp8e4nv: + assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv( + ), "Dot op does not support fp8e4nv on CUDA arch < 90" + if lhs_dtype.is_fp8() and rhs_dtype.is_fp8(): + return + assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" + else: + if lhs_dtype.is_int() or rhs_dtype.is_int(): + assert lhs_dtype == rhs_dtype, f"Both operands must be same type. First operand ({lhs_dtype}) and second operand ({rhs_dtype})" + assert lhs_dtype.is_int8() or lhs_dtype.is_uint8( + ), f"Both operands must be either int8 or uint8. Operand type ({lhs_dtype})" + elif lhs_dtype.is_fp8() or rhs_dtype.is_fp8(): + if options.allow_fp8e4b15: + allowed_types = ['fp8e4nv', 'fp8e5', 'fp8e4b15'] + else: + allowed_types = ['fp8e4nv', 'fp8e5'] + + def _validate_dtype(dtype, allowed_types, operand_name): + if not any(getattr(dtype, f'is_{dtype_name}')() for dtype_name in allowed_types): + supported_types = ', '.join(allowed_types) + raise AssertionError(f"Only supports {supported_types}. {operand_name} ({dtype})") + + _validate_dtype(lhs_dtype, allowed_types, "First operand") + _validate_dtype(rhs_dtype, allowed_types, "Second operand") + else: + assert lhs_dtype.is_fp16() or lhs_dtype.is_bf16() or lhs_dtype.is_fp32() or lhs_dtype.is_int1( + ), f"Unsupported dtype {lhs_dtype}" + assert rhs_dtype.is_fp16() or rhs_dtype.is_bf16() or rhs_dtype.is_fp32() or rhs_dtype.is_int1( + ), f"Unsupported dtype {rhs_dtype}" + assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" + + assert lhs.type.is_block() and rhs.type.is_block() + assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.options) + if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15(): + lhs = cast(lhs, tl.float16, builder) + rhs = cast(rhs, tl.float16, builder) + + if input_precision is None: + input_precision = builder.options.default_dot_input_precision + + input_precision = _str_to_dot_input_precision(input_precision, builder) + + lhs_rank = len(lhs.shape) + rhs_rank = len(rhs.shape) + assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})" + assert lhs.shape[-1].value == rhs.shape[ + -2].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[-1].value}) must be equal to first index of second shape ({rhs.shape[-2].value})" + assert lhs.shape[-2].value >= 16 and lhs.shape[-1].value >= 16 \ + and rhs.shape[-1].value >= 16, \ + f"All non-batch values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 16!" + if lhs.type.scalar.is_int(): + assert lhs.type.scalar == tl.int8, "only int8 supported!" + # TODO: This is CUDA specific, check if ROCm has the same limitation + assert lhs.shape[1].value >= 32, "small blocks not supported!" + _0 = builder.get_int32(0) + ret_scalar_ty = tl.int32 + elif out_dtype.is_bf16(): + raise ValueError( + "out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`") + elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16(): + _0 = builder.get_fp32(0) + ret_scalar_ty = tl.float32 + else: + _0 = builder.get_fp16(0) if out_dtype.is_fp16() else builder.get_fp32(0) + ret_scalar_ty = out_dtype + + M = lhs.type.shape[-2] + N = rhs.type.shape[-1] + B = lhs.type.shape[0] if lhs_rank == 3 else None + ret_ty = tl.block_type(ret_scalar_ty, [B, M, N] if B else [M, N]) + if acc is None: + acc_handle = builder.create_splat(_0, [B, M, N] if B else [M, N]) + else: + acc_handle = acc.handle + assert acc.type == ret_ty + + # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90 + if max_num_imprecise_acc is None: + if lhs.dtype.is_fp8() and rhs.dtype.is_fp8(): + max_num_imprecise_acc = builder.options.max_num_imprecise_acc_default + else: + max_num_imprecise_acc = 0 + + return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, input_precision, max_num_imprecise_acc), + ret_ty) + + +# ===----------------------------------------------------------------------===// +# Indexing +# ===----------------------------------------------------------------------===// + + +def where(condition: tl.tensor, x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor: + condition = cast(condition, tl.int1, builder) + if condition.type.is_block(): + condition, x = broadcast_impl_value(condition, x, builder) + x, y = broadcast_impl_value(x, y, builder) + condition, x = broadcast_impl_value(condition, x, builder) + + x, y = binary_op_type_checking_impl(x, y, builder, True, True) + if not condition.type.is_block(): + condition, _ = broadcast_impl_value(condition, x, builder) + ret_ty = x.type + return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty) + + +# ===----------------------------------------------------------------------===// +# Reduction +# ===----------------------------------------------------------------------=== + + +def wrap_tensor(x, scalar_ty, ret_shape): + if ret_shape: + res_ty = tl.block_type(scalar_ty, ret_shape) + else: + # 0d-tensor -> scalar + res_ty = scalar_ty + return tl.tensor(x, res_ty) + + +def reduction(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder) -> Tuple[tl.tensor, ...]: + if axis is None: + inputs = tuple(reshape(t, [t.numel.value], can_reorder=True, builder=builder) for t in inputs) + axis = 0 + # get result shape + shape = inputs[0].type.shape + rank = len(shape) + assert axis < rank, f"reduction axis must be < inputs rank ({rank})" + ret_shape = [s for i, s in enumerate(shape) if i != axis] + assert all(t.type.shape == shape for t in inputs), "all reduction inputs must have the same shape" + + reduce_op = builder.create_reduce([t.handle for t in inputs], axis) + region_builder_fn(reduce_op) + reduce_op.verify() + + return tuple(wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape) for i in range(len(inputs))) + + +# ===----------------------------------------------------------------------=== +# Associative Scan +# ===----------------------------------------------------------------------=== + + +def associative_scan(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, reverse: bool, + builder: ir.builder) -> Tuple[tl.tensor, ...]: + shape = inputs[0].type.shape + rank = len(shape) + + assert -rank <= axis < rank, f"scan axis {axis} must be < inputs rank ({rank})" + + if axis < 0: + axis += rank + + for t in inputs: + assert t.type.shape == shape, "all scan inputs must have the same shape" + + scan_op = builder.create_scan([t.handle for t in inputs], axis, reverse) + region_builder_fn(scan_op) + scan_op.verify() + + return tuple(wrap_tensor(scan_op.get_result(i), inputs[i].type.scalar, shape) for i in range(len(inputs))) + + +# ===----------------------------------------------------------------------=== +# Histogram +# ===----------------------------------------------------------------------=== + + +def histogram(input: tl.tensor, num_bins: int, builder: ir.builder) -> tl.tensor: + assert len(input.shape) == 1, "histogram only supports 1D input" + assert input.dtype.is_int(), "histogram only supports integer input" + return tl.tensor(builder.create_histogram(input.handle, num_bins), tl.block_type(tl.int32, (num_bins, ))) + + +## + + +def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor: + if max(1, len(x.shape)) != len(values): + raise ValueError("Shape of input to multiple_of does not match the length of values") + x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context())) + return x + + +def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor: + if len(x.shape) != len(values): + raise ValueError("Shape of input to max_contiguous does not match the length of values") + x.handle.set_attr("tt.contiguity", ir.make_attr(values, x.handle.get_context())) + return x + + +def max_constancy(x: tl.tensor, values: List[int]) -> tl.tensor: + if len(x.shape) != len(values): + raise ValueError("Shape of input to max_constancy does not match the length of values") + x.handle.set_attr("tt.constancy", ir.make_attr(values, x.handle.get_context())) + return x + + +def debug_barrier(builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_barrier(), tl.void) + + +def device_print(prefix: str, args: List[tl.tensor], hex: bool, builder: ir.builder) -> tl.tensor: + # It makes sense visually for prefix to end in ": "; make it so. Also, + # non-empty prefixes should start with " ". + if not prefix.endswith(" ") and args: + prefix += " " + if not prefix.endswith(": ") and args: + prefix = prefix[:-1] + ": " + if len(prefix) > 2 and not prefix.startswith(" "): + prefix = " " + prefix + + new_args = [arg.handle for arg in args] + return tl.tensor(builder.create_print(prefix, hex, new_args), tl.void) + + +def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno: int, builder: ir.builder) -> tl.tensor: + cond_ty = cond.type + if not cond_ty.is_block(): + cond_ty = tl.block_type(cond_ty.scalar, (1, )) + cond = tl.tensor(builder.create_splat(cond.handle, (1, )), cond_ty) + return tl.tensor(builder.create_assert(cond.handle, msg, file_name, func_name, lineno), tl.void) + + +def _convert_elem_to_ir_value(builder, elem, require_i64): + if isinstance(elem, int): + elem = tl.constexpr(elem) + if isinstance(elem, tl.constexpr): + if require_i64: + assert -2**63 <= elem.value < 2**63, f"Block pointers only support 64 bit `shape/strides`, " \ + f"got a value {elem.value} which is out of the range" + return builder.get_int64(elem.value) + else: + assert -2**31 <= elem.value < 2**31, f"Block pointers only support 32 bit `offsets/block_shape`, " \ + f"got a value {elem.value} which is out of the range" + return builder.get_int32(elem.value) + elif isinstance(elem, tl.tensor): + assert elem.numel.value == 1, "Expected a scalar in shape/strides/offsets" + assert elem.dtype.is_int(), "Expected an integer scalar type in shape/strides/offsets" + if elem.dtype != tl.int64 and require_i64: + return builder.create_int_cast(elem.handle, builder.get_int64_ty(), elem.dtype.is_int_signed()) + elif elem.dtype != tl.int32 and not require_i64: + assert False, "Block pointers only support 32 bit `offsets/block_shape`, " \ + "add a `.to(tl.int32)` or use regular indexing for 64 bit support" + return elem.handle + assert False, f"Unsupported element type in shape/strides/offsets: {type(elem)}" + + +def _convert_to_ir_values(builder, list_like, require_i64=True): + if hasattr(list_like, "__iter__"): + return [_convert_elem_to_ir_value(builder, elem, require_i64) for elem in list_like] + return [_convert_elem_to_ir_value(builder, list_like, require_i64)] + + +def make_block_ptr(base: tl.tensor, shape, strides, offsets, block_shape, order, builder: ir.builder) -> tl.tensor: + # Convert dynamic arguments to IR values + # NOTES(Chenggang): current `shape/strides` are `int64_t`, while `offsets/block_shape` are `int32_t` + shape = _convert_to_ir_values(builder, shape) + strides = _convert_to_ir_values(builder, strides) + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + + # Check `base` type + if not base.type.is_ptr() or base.type.element_ty.is_block(): + raise ValueError("Expected `base` to be a pointer type (but not a block pointer type or others)") + + # Treat `pointer_type` as `pointer_type` + if base.type.element_ty == tl.int1: + base = cast(base, tl.pointer_type(tl.int8, base.type.address_space), builder) + + # Check whether `block_shape` is static + if not hasattr(block_shape, "__iter__"): + block_shape = [block_shape] + block_shape = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in block_shape] + assert all(isinstance(elem, int) and -2**31 <= elem < 2**31 for elem in block_shape), \ + "Expected a list of constant integers (`int32_t` range) in `block_shape`" + + # Check `order` + if not hasattr(order, "__iter__"): + order = [order] + order = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in order] + assert sorted(order) == list(range(len(order))), "Expected a permutation of (0, 1, ..., len(order)-1) in order" + + # Must have same length + assert all(len(block_shape) == len(list_like) for list_like in [shape, strides, offsets, order]), \ + "Expected shape/strides/offsets/block_shape to have the same length" + + # Build value, the type is: + # `pointer_type>` in Python + # `tt.ptr>` in MLIR + handle = builder.create_make_block_ptr(base.handle, shape, strides, offsets, block_shape, order) + return tl.tensor(handle, tl.pointer_type(tl.block_type(base.type.element_ty, block_shape))) + + +def advance(base: tl.tensor, offsets, builder: ir.builder) -> tl.tensor: + # Convert dynamic offsets to IR values + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + + # Advanced block pointer type is the same as before + return tl.tensor(builder.create_advance(base.handle, offsets), base.type) diff --git a/lib/python3.10/site-packages/triton/language/standard.py b/lib/python3.10/site-packages/triton/language/standard.py new file mode 100644 index 0000000000000000000000000000000000000000..de30cf260bfa4a49d4b9ca0dcc1529f79992522d --- /dev/null +++ b/lib/python3.10/site-packages/triton/language/standard.py @@ -0,0 +1,441 @@ +from __future__ import annotations + +from ..runtime.jit import jit +from . import core +from . import math + +# constexpr utilities (triton metaprogramming sucks) + + +def _unwrap_if_constexpr(o): + return o.value if isinstance(o, core.constexpr) else o + + +def _log2(i: core.constexpr): + log2 = 0 + n = i.value + while n > 1: + n >>= 1 + log2 += 1 + return core.constexpr(log2) + + +def _is_power_of_two(i: core.constexpr): + n = i.value + return core.constexpr((n & (n - 1)) == 0 and n != 0) + + +# ----------------------- +# Standard library +# ----------------------- + + +@core._tensor_member_fn +@jit +def cdiv(x, div): + """ + Computes the ceiling division of :code:`x` by :code:`div` + + :param x: the input number + :type x: Block + :param div: the divisor + :param div: Block + """ + return (x + div - 1) // div + + +@core._tensor_member_fn +@jit +@math._add_math_1arg_docstr("sigmoid") +def sigmoid(x): + return 1 / (1 + math.exp(-x)) + + +@core._tensor_member_fn +@jit +@math._add_math_1arg_docstr("softmax") +def softmax(x, ieee_rounding=False): + z = x - max(x, 0) + num = math.exp(z) + den = sum(num, 0) + return math.fdiv(num, den, ieee_rounding) + + +@core._tensor_member_fn +@jit +def ravel(x): + """ + Returns a contiguous flattened view of :code:`x`. + + :param x: the input tensor + :type x: Block + """ + return core.reshape(x, [x.numel], can_reorder=True) + + +@jit +def swizzle2d(i, j, size_i, size_j, size_g): + """ + Transforms indices of a row-major :code:`size_i * size_j` matrix into those + of one where the indices are col-major for each group of :code:`size_g` + rows. + + For example, for :code:`size_i = size_j = 4` and :code:`size_g = 2`, it will + transform :: + + [[0 , 1 , 2 , 3 ], + [4 , 5 , 6 , 7 ], + [8 , 9 , 10, 11], + [12, 13, 14, 15]] + + into :: + + [[0, 2, 4 , 6 ], + [1, 3, 5 , 7 ], + [8, 10, 12, 14], + [9, 11, 13, 15]] + """ + # "unrolled index in array" + ij = i * size_j + j + # number of elements in `size_g` groups + # of `size_j` columns + size_gj = size_g * size_j + # index of the group in which (i,j) is + group_id = ij // size_gj + # row-index of the first element of this group + off_i = group_id * size_g + # last group may have fewer rows + size_g = core.minimum(size_i - off_i, size_g) + # new row and column indices + new_i = off_i + (ij % size_g) + new_j = (ij % size_gj) // size_g + return new_i, new_j + + +@jit +def zeros(shape, dtype): + """ + Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`. + + :param shape: Shape of the new array, e.g., (8, 16) or (8, ) + :type shape: tuple of ints + :param dtype: Data-type of the new array, e.g., :code:`tl.float16` + :type dtype: DType + """ + return core.full(shape, 0, dtype) + + +@jit +def zeros_like(input): + """ + Creates a tensor of zeros with the same shape and type as a given tensor. + """ + return zeros(input.shape, input.dtype) + + +# max and argmax + + +@jit +def _argmax_combine(value1, index1, value2, index2, tie_break_left): + if tie_break_left: + tie = value1 == value2 and index1 < index2 + else: + tie = False + gt = value1 > value2 or tie + v_ret = core.where(gt, value1, value2) + i_ret = core.where(gt, index1, index2) + return v_ret, i_ret + + +@jit +def _argmax_combine_tie_break_left(value1, index1, value2, index2): + return _argmax_combine(value1, index1, value2, index2, True) + + +@jit +def _argmax_combine_tie_break_fast(value1, index1, value2, index2): + return _argmax_combine(value1, index1, value2, index2, False) + + +@jit +def _elementwise_max(a, b): + return core.maximum(a, b) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("maximum", return_indices_arg="return_indices", + tie_break_arg="return_indices_tie_break_left") +def max(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False): + input = core._promote_bfloat16_to_float32(input) + if return_indices: + if return_indices_tie_break_left: + return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_left, keep_dims=keep_dims) + else: + return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_fast, keep_dims=keep_dims) + else: + if core.constexpr(input.dtype.primitive_bitwidth) < core.constexpr(32): + if core.constexpr(input.dtype.is_floating()): + input = input.to(core.float32) + else: + assert input.dtype.is_int(), "Expecting input to be integer type" + input = input.to(core.int32) + return core.reduce(input, axis, _elementwise_max, keep_dims=keep_dims) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("maximum index", tie_break_arg="tie_break_left") +def argmax(input, axis, tie_break_left=True, keep_dims=False): + (_, ret) = max(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left, keep_dims=keep_dims) + return ret + + +# min and argmin + + +@jit +def _argmin_combine(value1, index1, value2, index2, tie_break_left): + if tie_break_left: + tie = value1 == value2 and index1 < index2 + else: + tie = False + lt = value1 < value2 or tie + value_ret = core.where(lt, value1, value2) + index_ret = core.where(lt, index1, index2) + return value_ret, index_ret + + +@jit +def _argmin_combine_tie_break_left(value1, index1, value2, index2): + return _argmin_combine(value1, index1, value2, index2, True) + + +@jit +def _argmin_combine_tie_break_fast(value1, index1, value2, index2): + return _argmin_combine(value1, index1, value2, index2, False) + + +@jit +def _elementwise_min(a, b): + return core.minimum(a, b) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("minimum", return_indices_arg="return_indices", + tie_break_arg="return_indices_tie_break_left") +def min(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False): + input = core._promote_bfloat16_to_float32(input) + if return_indices: + if return_indices_tie_break_left: + return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_left, keep_dims=keep_dims) + else: + return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_fast, keep_dims=keep_dims) + else: + if core.constexpr(input.dtype.primitive_bitwidth) < 32: + if core.constexpr(input.dtype.is_floating()): + input = input.to(core.float32) + else: + assert input.dtype.is_int(), "Expecting input to be integer type" + input = input.to(core.int32) + return core.reduce(input, axis, _elementwise_min, keep_dims=keep_dims) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("minimum index", tie_break_arg="tie_break_left") +def argmin(input, axis, tie_break_left=True, keep_dims=False): + _, ret = min(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left, keep_dims=keep_dims) + return ret + + +@jit +def _sum_combine(a, b): + return a + b + + +# sum + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("sum") +def sum(input, axis=None, keep_dims=False): + input = core._promote_bfloat16_to_float32(input) + return core.reduce(input, axis, _sum_combine, keep_dims=keep_dims) + + +@jit +def _xor_combine(a, b): + return a ^ b + + +# xor sum + + +@core._tensor_member_fn +@core.builtin +@core._add_reduction_docstr("xor sum") +def xor_sum(input, axis=None, keep_dims=False, _builder=None, _generator=None): + scalar_ty = input.type.scalar + if not scalar_ty.is_int(): + raise ValueError("xor_sum only supported for integers") + + input = core._promote_bfloat16_to_float32(input, _builder=_builder) + return core.reduce(input, axis, _xor_combine, keep_dims=keep_dims, _builder=_builder, _generator=_generator) + + +# cumsum + + +@core._tensor_member_fn +@jit +@core._add_scan_docstr("cumsum") +def cumsum(input, axis=0, reverse=False): + # todo rename this to a generic function name + input = core._promote_bfloat16_to_float32(input) + return core.associative_scan(input, axis, _sum_combine, reverse) + + +# cumprod + + +@jit +def _prod_combine(a, b): + return a * b + + +@core._tensor_member_fn +@jit +@core._add_scan_docstr("cumprod") +def cumprod(input, axis=0, reverse=False): + # todo rename this to a generic function name + input = core._promote_bfloat16_to_float32(input) + return core.associative_scan(input, axis, _prod_combine, reverse) + + +# sort + + +@jit +def _compare_and_swap(x, flip, i: core.constexpr, n_dims: core.constexpr): + n_outer: core.constexpr = x.numel >> n_dims + shape: core.constexpr = [n_outer * 2**i, 2, 2**(n_dims - i - 1)] + y = core.reshape(x, shape) + # slice left/right with 'stride' 2**(n_dims - i - 1) + mask = core.arange(0, 2)[None, :, None] + left = core.broadcast_to(sum(y * (1 - mask), 1)[:, None, :], shape) + right = core.broadcast_to(sum(y * mask, 1)[:, None, :], shape) + left = core.reshape(left, x.shape) + right = core.reshape(right, x.shape) + # actual compare-and-swap + idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) + ileft = left.to(idtype, bitcast=True) + iright = right.to(idtype, bitcast=True) + ix = x.to(idtype, bitcast=True) + ret = ix ^ core.where((left > right) ^ flip, ileft ^ iright, zeros_like(ix)) + return ret.to(x.dtype, bitcast=True) + + +@jit +def _bitonic_merge(x, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr): + ''' + order_type 0 == ascending + order_type 1 == descending + order_type 2 == alternating + ''' + n_outer: core.constexpr = x.numel >> n_dims + core.static_assert(stage <= n_dims) + # flip denotes whether to re-arrange sub-sequences of elements in ascending or + # descending order. + # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage + # if flip = 00110011... then all the elements will be re-arranged alternatingly (with + # a stride of 2) at this stage + if order == 2: + shape: core.constexpr = [n_outer * 2**(n_dims - 1 - stage), 2, 2**stage] + flip = core.reshape(core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape) + else: + flip = order + # perform `stage` rounds of `compare-and-swap` + for i in core.static_range(stage): + x = _compare_and_swap(x, flip, i + (n_dims - stage), n_dims) + return x + + +@core._tensor_member_fn +@jit +def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0): + # handle default dimension or check that it is the most minor dim + _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim + core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported") + # iteratively run bitonic merge-sort steps + n_dims: core.constexpr = _log2(x.shape[_dim]) + for i in core.static_range(1, n_dims + 1): + x = _bitonic_merge(x, i, 2 if i < n_dims else descending, n_dims) + return x + + +# flip + + +def _get_flip_dim(dim, shape): + dim = _unwrap_if_constexpr(dim) + shape = _unwrap_if_constexpr(shape) + if dim is None: + dim = len(shape) - 1 + assert dim == len(shape) - 1, "Currently only support flipping the last dimension" + return core.constexpr(dim) + + +@core._tensor_member_fn +@jit +def flip(x, dim=None): + """ + Flips a tensor `x` along the dimension `dim`. + + :param x: the first input tensor + :type x: Block + :param dim: the dimension to flip along (currently only final dimension supported) + :type dim: int + """ + core.static_assert(_is_power_of_two(x.shape[_get_flip_dim(dim, x.shape)])) + core.static_assert(_is_power_of_two(x.numel)) + # # reshape the tensor to have all dimensions be 2. + # # TODO: We shouldn't have to change the dimensions not sorted. + steps: core.constexpr = _log2(x.numel) + start: core.constexpr = _log2(x.numel) - _log2(x.shape[_get_flip_dim(dim, x.shape)]) + y = core.reshape(x, [2] * steps) + y = core.expand_dims(y, start) + flip = (core.arange(0, 2)[:, None] == 1 - core.arange(0, 2)) + for i in core.static_range(start, steps): + flip2 = flip + for j in core.static_range(0, steps + 1): + if j != i and j != i + 1: + flip2 = core.expand_dims(flip2, j) + y = sum(y * flip2, i + 1, keep_dims=True) + x = core.reshape(y, x.shape) + return x + + +@jit +def interleave(a, b): + """ + Interleaves the values of two tensors along their last dimension. + + The two tensors must have the same shape. + + Equivalent to `tl.join(a, b).reshape(a.shape[-1:] + [2 * a.shape[-1]])` + """ + c = core.join(a, b) + + assert isinstance(c.shape, list) + if len(c.shape) == 1: + # We must have interleaved two scalars. + return c + else: + # This `else` is necessary because Triton's AST parser doesn't + # understand that if we take the `if` above we definitely don't run this + # `else`. + return core.reshape(c, c.shape[:-2] + [2 * c.shape[-2]]) diff --git a/lib/python3.10/site-packages/triton/ops/__init__.py b/lib/python3.10/site-packages/triton/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..18f1d782d2a023b3cf1bc5643e6fbb45e6f8b4d9 --- /dev/null +++ b/lib/python3.10/site-packages/triton/ops/__init__.py @@ -0,0 +1,7 @@ +# from .conv import _conv, conv +from . import blocksparse +from .cross_entropy import _cross_entropy, cross_entropy +from .flash_attention import attention +from .matmul import _matmul, get_higher_dtype, matmul + +__all__ = ["blocksparse", "_cross_entropy", "cross_entropy", "_matmul", "matmul", "attention", "get_higher_dtype"] diff --git a/lib/python3.10/site-packages/triton/ops/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/triton/ops/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a65e63ce7d0dc0a0c6ba153da2eaaecc145cab7 Binary files /dev/null and b/lib/python3.10/site-packages/triton/ops/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/ops/__pycache__/cross_entropy.cpython-310.pyc b/lib/python3.10/site-packages/triton/ops/__pycache__/cross_entropy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..368a2dad85448420eaa1430de7aab435f87f69de Binary files /dev/null and b/lib/python3.10/site-packages/triton/ops/__pycache__/cross_entropy.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/ops/__pycache__/flash_attention.cpython-310.pyc b/lib/python3.10/site-packages/triton/ops/__pycache__/flash_attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4c1c9419c685dcb19bfb4eed9f4b93a3dc2fdc7 Binary files /dev/null and b/lib/python3.10/site-packages/triton/ops/__pycache__/flash_attention.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/ops/__pycache__/matmul.cpython-310.pyc b/lib/python3.10/site-packages/triton/ops/__pycache__/matmul.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62b87085b32d3a05b58739f4706ba97df33bf31d Binary files /dev/null and b/lib/python3.10/site-packages/triton/ops/__pycache__/matmul.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/ops/__pycache__/matmul_perf_model.cpython-310.pyc b/lib/python3.10/site-packages/triton/ops/__pycache__/matmul_perf_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84f0def7ea7c1ba5121063677cbc9f87f3e8ed32 Binary files /dev/null and b/lib/python3.10/site-packages/triton/ops/__pycache__/matmul_perf_model.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/ops/blocksparse/__init__.py b/lib/python3.10/site-packages/triton/ops/blocksparse/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6b24b5377fab564ac951b827e398c744524d711c --- /dev/null +++ b/lib/python3.10/site-packages/triton/ops/blocksparse/__init__.py @@ -0,0 +1,7 @@ +from .matmul import matmul +from .softmax import softmax + +__all__ = [ + "matmul", + "softmax", +] diff --git a/lib/python3.10/site-packages/triton/ops/blocksparse/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/triton/ops/blocksparse/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8874bf1155847e57f0e764b3eacdc30209f51db9 Binary files /dev/null and b/lib/python3.10/site-packages/triton/ops/blocksparse/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/ops/blocksparse/__pycache__/matmul.cpython-310.pyc b/lib/python3.10/site-packages/triton/ops/blocksparse/__pycache__/matmul.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6315461989f89a9c21ce43896022d4caeaed1806 Binary files /dev/null and b/lib/python3.10/site-packages/triton/ops/blocksparse/__pycache__/matmul.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/ops/blocksparse/__pycache__/softmax.cpython-310.pyc b/lib/python3.10/site-packages/triton/ops/blocksparse/__pycache__/softmax.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd0d9ad0230fcb116b1bca84bb7a4eab8e8d78dc Binary files /dev/null and b/lib/python3.10/site-packages/triton/ops/blocksparse/__pycache__/softmax.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/ops/blocksparse/matmul.py b/lib/python3.10/site-packages/triton/ops/blocksparse/matmul.py new file mode 100644 index 0000000000000000000000000000000000000000..098e1543809e9dcea3a77c3a38538bcd34153eb8 --- /dev/null +++ b/lib/python3.10/site-packages/triton/ops/blocksparse/matmul.py @@ -0,0 +1,432 @@ +import torch + +from ... import cdiv, heuristics, jit +from ... import language as tl + +# ******************************************************** +# -------------------------------------------------------- +# Sparse = Dense x Dense (SDD) +# This operation uses super-blocking to make sure that +# it's done efficiently when small blocks can be grouped +# together +# -------------------------------------------------------- +# ******************************************************** + + +@heuristics({ + 'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0, +}) +@jit +def _sdd_kernel(A, B, C, # + stride_za, stride_ha, stride_ma, stride_ak, # + stride_zb, stride_hb, stride_bk, stride_nb, # + stride_zc, stride_hc, stride_mc, stride_nc, # + K, grid_offset, lut, # + TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, # + BLOCK: tl.constexpr, EVEN_K: tl.constexpr # + ): + # ------------ # + # - Prologue - # + # ------------ # + block_id = tl.program_id(0) + grid_offset + lut += block_id * 3 + # offsets + off_z = tl.program_id(2) # batch + off_h = tl.load(lut + 0) # head + + # initialize pointers to A + start_am = tl.load(lut + 1) + offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK) + offs_ak = tl.arange(0, TILE_K) + a_ptrs = A \ + + off_z * stride_za \ + + off_h * stride_ha \ + + offs_am[:, None] * stride_ma \ + + offs_ak[None, :] * stride_ak + # initialize pointers to B + start_bn = tl.load(lut + 2) + offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK) + offs_bk = tl.arange(0, TILE_K) + b_ptrs = B \ + + off_z * stride_zb \ + + off_h * stride_hb \ + + offs_bn[None, :] * stride_nb \ + + offs_bk[:, None] * stride_bk + # ---------------- # + # Inner Loop # + # ---------------- # + acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32) + for k in range(K, 0, -TILE_K): + if EVEN_K: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + else: + a = tl.load(a_ptrs, mask=offs_ak[None, :] < k, other=0.) + b = tl.load(b_ptrs, mask=offs_bk[:, None] < k, other=0.) + acc += tl.dot(a, b, out_dtype=tl.float32) + a_ptrs += TILE_K * stride_ak + b_ptrs += TILE_K * stride_bk + c = acc.to(C.dtype.element_ty) + # ---------------- # + # Epilogue # + # ---------------- # + offs_cm = tl.arange(0, TILE_M) % BLOCK + offs_cn = tl.arange(0, TILE_N) % BLOCK + pc = C \ + + off_z * stride_zc \ + + block_id * stride_hc \ + + offs_cm[:, None] * stride_mc \ + + offs_cn[None, :] * stride_nc + tl.store(pc, c, mask=True) + + +def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=None): + if a.stride(2) != 1 and a.stride(3) != 1: + a = a.contiguous() + if b.stride(2) != 1 and b.stride(3) != 1: + b = b.contiguous() + # (A * B)^T = B^T * A^T + if trans_c: + a, b = b, a + trans_a, trans_b = not trans_b, not trans_a + # shape constraints + a_dim = -2 if trans_a else -1 + b_dim = -1 if trans_b else -2 + Ka, Kb = a.shape[a_dim], b.shape[b_dim] + if Ka != Kb: + raise ValueError(f"Inner dimension mismatch (A: {Ka} vs B: {Kb})") + # allocate output + if out is None: + c = torch.empty((a.shape[0], lut.shape[0], block, block), dtype=a.dtype, device=a.device) + else: + assert out.shape == (a.shape[0], lut.shape[0], block, block) + c = out + grid = [c.shape[1], 1, c.shape[0]] + _sdd_kernel[grid]( + a, b, c, # + a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), # + b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), # + c.stride(0), c.stride(1), c.stride(2), c.stride(3), # + Ka, 0, lut, # + TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4, # + num_warps=4 # + ) + return c + + +def sdd_lut(layout, block, device): + lut = layout.nonzero(as_tuple=False).to(device).int() + lut = lut.contiguous() + return lut, None + + +# ----------------------------- +# Dense = Sparse x Dense (DSD) +# This operation uses a look-up table that contains pre-computed pointer increments +# in order to minimize computations in the inner loop of the matmul kernel. +# ----------------------------- + + +@jit +def _dsd_kernel(A, B, C, # + stride_az, stride_ha, stride_am, stride_ak, # + stride_zb, stride_hb, stride_bk, stride_bn, # + stride_zc, stride_hc, stride_cm, stride_cn, # + DS0, DS1, lut, # + TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr # + ): + # ------------ # + # - Prologue - # + # ------------ # + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M) + pidz = tl.program_id(2) + header = lut + pid_n * 4 + offset = tl.load(header + 0) + K = tl.load(header + 1) + column = tl.load(header + 2) + off_h = tl.load(header + 3) + pinc = lut + offset + # initialize pointers to A (sparse) + block_id = tl.load(pinc + 1) + block_id = tl.multiple_of(block_id, 8) # compiler hint + offs_am = tl.arange(0, TILE_M) + offs_ak = tl.arange(0, TILE_K) + pa = A + pidz * stride_az \ + + block_id * stride_ha \ + + offs_am[:, None] * stride_am \ + + offs_ak[None, :] * stride_ak + # initialize pointers to B (dense) + offs_bn = pid_m * TILE_N + tl.arange(0, TILE_N) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N) + start_bk = tl.load(pinc) + start_bk = tl.multiple_of(start_bk, 8) # compiler hint + offs_bk = start_bk + tl.arange(0, TILE_K) + pb = B + pidz * stride_zb \ + + off_h * stride_hb \ + + offs_bn[None, :] * stride_bn \ + + offs_bk[:, None] * stride_bk + # ---------------- # + # Inner Loop # + # ---------------- # + acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32) + pinc += 2 + inc_a = tl.load(pinc + 1) + inc_a = tl.multiple_of(inc_a, 8) + inc_b = tl.load(pinc) + inc_b = tl.multiple_of(inc_b, 8) + for k in range(K, 0, -TILE_K): + a = tl.load(pa) + b = tl.load(pb) + acc += tl.dot(a, b, out_dtype=tl.float32) + pa += inc_a + pb += inc_b * stride_bk + pinc += 2 + inc_a = tl.load(pinc + 1) + inc_a = tl.multiple_of(inc_a, 8) + inc_b = tl.load(pinc) + inc_b = tl.multiple_of(inc_b, 8) + c = acc.to(C.dtype.element_ty) + # initialize pointers to C + offs_cm = column * TILE_M + tl.arange(0, TILE_M) + offs_cn = pid_m * TILE_N + tl.arange(0, TILE_N) + pc = C \ + + off_h * stride_hc \ + + pidz * stride_zc \ + + offs_cm[:, None] * stride_cm \ + + offs_cn[None, :] * stride_cn + tl.store(pc, c, mask=offs_cn[None, :] < DS0) + + +def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None): + if a.stride(2) != 1 and a.stride(3) != 1: + a = a.contiguous() + if b.stride(2) != 1 and b.stride(3) != 1: + b = b.contiguous() + # shapes / dtypes + AS1 = block * spdims[2 if trans_a else 1] + BS0 = b.size(0) + BS1 = b.size(1) + BS3 = b.size(2 if trans_b else 3) + dtype = a.dtype + # allocate output + CS0 = BS0 + CS1 = BS1 + CS2 = BS3 if trans_c else AS1 + CS3 = AS1 if trans_c else BS3 + if out is None: + c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) + else: + assert out.shape == (CS0, CS1, CS2, CS3) + c = out + # meta-parameter heuristics + TILE_N = 128 + # compute output + grid = lambda meta: [cdiv(BS3, meta['TILE_N']), width, BS0] + _dsd_kernel[grid]( + a, b, c, # + a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), # + b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), # + c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3), # + BS3, AS1, lut, # + TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4, # + num_warps=4, GROUP_SIZE_M=4 # + ) + # exit() + return c + + +def dsd_lut(layout, block, step, trans, device): + """ + Generates the look-up table for incrementing pointers in the DSD/DDS matmul. + Example (BLOCK=32, STEP=16) + [[1, 0, 0, 1, 0], + [0, 1, 1, 0, 1], + [1, 0, 1, 0, 0]] + + Then the offsets for A are + [0 , 16, 32, 48] <- row 0 + \\----/ \\----/ + col=0 col=3 + [64, 80, 96, 112, 128, 144] <- row 1 + \\----/ \\----/ \\------/ + col=1 col=2 col=3 + [160, 176, 192, 208] + which leads to increments table + [0, 16, 16, 16, || 64, 16, 16, 16, 16, 16, || 160, 16, 16, 16] + + Because B is dense, the offsets are + [0, 16, 96, 112] <- row 0 + [32, 48, 64, 80] <- row 1 + [0, 16, 64, 80] <- row 2 + """ + sizes = torch.sum(layout, 2 if trans else 1) + head_id, col_id = torch.ones_like(sizes).nonzero(as_tuple=True) + sizes = sizes.flatten() + segments = sizes * step + # pointer increments + if trans: + nnz = layout.nonzero(as_tuple=False) + else: + nnz = layout.transpose(1, 2).nonzero(as_tuple=False) + num_blocks = nnz.size(0) + offsets = torch.zeros_like(sizes) + offsets[1:] = torch.cumsum(sizes[:-1], dim=0) + offsets = torch.min(offsets, (num_blocks - 1) * torch.ones_like(offsets)) + # ------------------------------- + # dense input pointer increments + # ------------------------------- + # Note that the inner loop matmul kernel may have a fixed step size (e.g., TILE_K) + # that is smaller than the block size, so we need to do a bit of extra work + # to handle this case + B_idx = nnz[:, 2] * block + B_incs = B_idx.clone() + B_incs[1:] -= B_idx[:-1] + div = block // step + B_incs = B_incs.view(-1, 1).repeat(1, div) + B_incs[:, 1:] = step + B_incs[:, 0] -= (div - 1) * step + # first increment for each reduction is actually the offset + B_incs[offsets[segments > 0], 0] = B_idx[offsets[segments > 0]] + B_incs = B_incs.view(-1) + # ------------------------------- + # sparse input pointer increments + # ------------------------------- + # same as above, except that the increments are in the sparse memory layout + if trans: + A_idx = torch.arange(num_blocks, device=layout.device) + else: + A_idx = torch.tensor([], dtype=torch.int64, device=layout.device) + current_offset = 0 + for z in range(layout.size(0)): + layoutw = layout[z, :, :].clone().long() + msum = layoutw.sum() + layoutw[layoutw > 0] = 1 + torch.arange(msum, device=layout.device) + A_idx = torch.cat((A_idx, current_offset + layoutw.T[layoutw.T > 0] - 1)) + current_offset += msum + A_incs = A_idx * block * block + A_incs[1:] -= A_idx[:-1] * block * block + A_incs = A_incs.view(-1, 1).repeat(1, div) + if trans: + A_incs[:, 1:] = step + A_incs[:, 0] -= (div - 1) * step + else: + A_incs[:, 1:] = step * block + A_incs[:, 0] -= (div - 1) * step * block + A_incs[offsets[segments > 0], 0] = A_idx[offsets[segments > 0]] + A_incs = A_incs.view(-1) + # create header + width = col_id.size(0) + offsets = offsets * 2 * div + 4 * width + segments = segments * div + header = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous() + # create increments + incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous() + # pad by a factor 2*MAX_NUM_STAGES + # to accommodate pre-fetching inside the kernel + pad = torch.zeros(20, device=incs.device, dtype=incs.dtype) + incs = torch.cat((incs, pad)) + # create lut + lut = torch.cat((header, incs)) + lut = lut.type(torch.int32).to(device) + # create locks + return lut, width + + +# ----------------------------- +# Dense = Dense x Sparse (DDS) +# ----------------------------- +# AB = (B^T A^T)^T + + +def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None): + return dsd_matmul(b, a, not trans_b, not trans_a, not trans_c, spdims, block, lut, width, out=out) + + +############## +# MAIN API # +############## + + +class _matmul(torch.autograd.Function): + + fn = {'sdd': sdd_matmul, 'dsd': dsd_matmul, 'dds': dds_matmul} + + @staticmethod + def forward(ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_width, da_lut, da_width, db_lut, + db_width, out): + c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out) + # save for backward + ctx.save_for_backward(a, b) + ctx.da_lut = da_lut + ctx.da_width = da_width + ctx.db_lut = db_lut + ctx.db_width = db_width + ctx.mode = mode + ctx.spdims = spdims + ctx.block = block + ctx.trans_a = trans_a + ctx.trans_b = trans_b + ctx.trans_c = trans_c + ctx.has_out = out is not None + return c + + @staticmethod + def backward(ctx, dc): + # saved for backward + a, b = ctx.saved_tensors + da, db = None, None + mode = ctx.mode + # gradients w.r.t. a + if ctx.needs_input_grad[0]: + mode_da = mode[1] + mode[0] + mode[2] + da = _matmul.fn[mode_da](dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, + ctx.da_lut, ctx.da_width) + # gradients w.r.t. b + if ctx.needs_input_grad[1]: + mode_db = mode[2] + mode[1] + mode[0] + db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block, + ctx.db_lut, ctx.db_width) + dout = dc if ctx.has_out else None + return da, db, None, None, None, \ + None, None, None, None, \ + None, None, None, None, None, dout + + +class matmul: + + def __init__(self, layout, block, mode, device, trans_a=False, trans_b=False, trans_c=False): + if mode not in ['sdd', 'dsd', 'dds']: + raise NotImplementedError('Supported modes are: sdd, dsd, dds') + self.block = block + self.mode = mode + self.trans_a = trans_a + self.trans_b = trans_b + self.trans_c = trans_c + self.layout = layout + self.spdims = layout.shape + step = min(block, 32) + if self.mode == 'sdd': + self.c_lut, self.c_width = sdd_lut(layout, block, device) + self.da_lut, self.da_width = dsd_lut(layout, block, step, True, device) + self.db_lut, self.db_width = dsd_lut(layout, block, step, False, device) + if self.mode == 'dsd': + self.c_lut, self.c_width = dsd_lut(layout, block, step, not self.trans_a, device) + self.da_lut, self.da_width = sdd_lut(layout, block, device) + self.db_lut, self.db_width = dsd_lut(layout, block, step, self.trans_a, device) + if self.mode == 'dds': + self.c_lut, self.c_width = dsd_lut(layout, block, step, self.trans_b, device) + self.da_lut, self.da_width = dsd_lut(layout, block, step, not self.trans_b, device) + self.db_lut, self.db_width = sdd_lut(layout, block, device) + + def __call__(self, a, b, out=None): + c = _matmul.apply(a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block, # + self.c_lut, self.c_width, # + self.da_lut, self.da_width, # + self.db_lut, self.db_width, # + out) + return c diff --git a/lib/python3.10/site-packages/triton/ops/blocksparse/softmax.py b/lib/python3.10/site-packages/triton/ops/blocksparse/softmax.py new file mode 100644 index 0000000000000000000000000000000000000000..bcffff26bb515aa65f9e199d9e7c3b36bc6fe1c3 --- /dev/null +++ b/lib/python3.10/site-packages/triton/ops/blocksparse/softmax.py @@ -0,0 +1,228 @@ +import torch + +from ... import jit +from ... import language as tl +from ... import next_power_of_2 + + +def num_warps(n): + if n <= 128: + return 1 + if n <= 256: + return 2 + if n <= 512: + return 4 + if n <= 4096: + return 8 + return 16 + + +@jit +def _blocksparse_softmax_fwd(Out, A, stride_xz, LUT, # + R, extent, stride_zr, stride_hr, # relative attention + scale, is_causal, # + ROW_SIZE: tl.constexpr, # + BLOCK_SIZE: tl.constexpr, # + IS_DENSE: tl.constexpr # + ): + h = tl.program_id(0) + m = tl.program_id(1) + z = tl.program_id(2) + # create index ranges + hm = h * tl.num_programs(1) + m + lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE + block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE + # extract information from LUT + header = LUT + (hm // BLOCK_SIZE) * 2 + size = tl.load(header + 0) + offset = tl.load(header + 1) + # pointer offset + off_a = z * stride_xz + off_a += (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE # block indx + off_a += (m % BLOCK_SIZE) * BLOCK_SIZE # row indx + # do not need to read column indices in the dense case + if IS_DENSE: + ns = tl.arange(0, ROW_SIZE) + else: + off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE + start_n = tl.load(LUT + off_lut + block_n, mask=block_n < size, other=0) + ns = start_n * BLOCK_SIZE + lane_n + # load X + mask = block_n < size + a = tl.load(A + off_a + lane_n, mask=mask, other=-float("inf")) + a = a.to(tl.float32) + # compute + out = a + out *= scale + # apply relative attention + if R is not None: + R += z * stride_zr + R += h * stride_hr + off_lo = (extent - m - 1) + ns + mask_lo = (off_lo >= 0) & (off_lo < extent) + rel_logits = tl.load(R + m * extent + off_lo, mask=mask_lo, other=0.0) + out += rel_logits + out = out.to(tl.float32) + # apply causal mask + out = tl.where((ns > m) & is_causal, -float("inf"), out) + # computation + out = tl.softmax(out) + # write-back + tl.store(Out + off_a + lane_n, out, mask=mask) + + +@jit +def _blocksparse_softmax_bwd(DA, stride_zdx, # + DOut, stride_zdout, # + Out, stride_zout, # + scale, # + LUT, # + DR, extent, stride_zr, stride_hr, stride_er, # + is_causal, # + ROW_SIZE: tl.constexpr, # + BLOCK_SIZE: tl.constexpr, # + IS_DENSE: tl.constexpr): + h = tl.program_id(0) + m = tl.program_id(1) + z = tl.program_id(2) + # create index ranges + hm = h * tl.num_programs(1) + m + lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE + block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE + # extract information from LUT + header = LUT + (hm // BLOCK_SIZE) * 2 + size = tl.load(header + 0) + offset = tl.load(header + 1) + # row-col offset + off_mn = (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE + off_mn += (m % BLOCK_SIZE) * BLOCK_SIZE + mask = block_n < size + # pointers + As = Out + z * stride_zout + off_mn + DOuts = DOut + z * stride_zdout + off_mn + # do not need to read column indices in the dense case + if IS_DENSE: + ns = tl.arange(0, ROW_SIZE) + else: + off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE + start_n = tl.load(LUT + off_lut + block_n, mask=mask, other=0) + ns = start_n * BLOCK_SIZE + lane_n + # load data + a = tl.load(As + lane_n, mask=mask, other=0.0) + a = a.to(tl.float32) + dout = tl.load(DOuts + lane_n, mask=mask, other=0.0) + dout = dout.to(tl.float32) + # compute + a = tl.where((ns > m) & is_causal & (a == a), 0., a) + da = a * (dout - tl.sum(a * dout, 0)) + # apply relative attention + if DR is not None: + DR += z * stride_zr + DR += h * stride_hr + off_lo = (extent - m - 1) + ns + mask_lo = (off_lo >= 0) & (off_lo < extent) & mask + tl.store(DR + m * extent + off_lo, da, mask=mask_lo) + da = da * scale + # convert da + # write-back + DAs = DA + z * stride_zdx + off_mn + tl.store(DAs + lane_n, da, mask=mask) + + +class _softmax(torch.autograd.Function): + + @staticmethod + def make_lut(layout, block, device): + _empty = torch.tensor([], dtype=torch.int64, device=layout.device) + sizes = _empty.clone() + # sizes along rows + for h in range(layout.shape[0]): + sizes = torch.cat((sizes, layout[h, :, :].sum(-1))) + total_sizes = sizes * block + # offsets in block format + offsets = torch.zeros_like(sizes) + offsets[1:] = torch.cumsum(sizes[:-1], dim=0) + # block indices + columns = layout.nonzero(as_tuple=False)[:, 2] + header = torch.stack((sizes, offsets), dim=1).view(-1) + lut = torch.cat((header, columns)).type(torch.int32).to(device) + return lut, int(total_sizes.max()) + + @staticmethod + def forward(ctx, a, scale, rel_logits, is_causal, spdims, block, lut, maxlut, is_dense): + if scale is not None and isinstance(scale, torch.Tensor): + assert scale.device.type == "cpu" + scale = scale.item() + M = a.shape[0] + grid = [spdims[0], spdims[1] * block, M] + rel_shape = (1, 1, 1, 1) if rel_logits is None else rel_logits.shape + rel_strides = (1, 1, 1, 1) if rel_logits is None else rel_logits.stride() + # enqueue kernel + out = torch.empty_like(a) + _blocksparse_softmax_fwd[grid]( + out, a, a.stride(0), lut, # + rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn# + scale, # + is_causal, # + BLOCK_SIZE=block, # + ROW_SIZE=next_power_of_2(maxlut), # + IS_DENSE=is_dense, # + num_warps=num_warps(maxlut) # + ) + # save to context + # ctx.mark_dirty(x) + ctx.save_for_backward(out, lut) + ctx.spdims = spdims + ctx.block = block + ctx.maxlut = maxlut + ctx.scale = scale + ctx.rel_shape = rel_shape + ctx.rel_strides = rel_strides + ctx.rel_dtype = a.dtype + ctx.is_dense = is_dense + ctx.is_causal = is_causal + return out + + @staticmethod + def backward(ctx, dout): + # retrieve from context + out, lut = ctx.saved_tensors + # relative logits gradients + dr = None + if ctx.needs_input_grad[3]: + dr = torch.zeros(ctx.rel_shape, dtype=ctx.rel_dtype, device=out.device) + # run kernel + M = out.shape[0] + grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M) + da = torch.empty_like(dout) + _blocksparse_softmax_bwd[grid]( + da, da.stride(0), # + dout, dout.stride(0), # + out, out.stride(0), # + ctx.scale, # + lut, # + dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2], # + ctx.is_causal, # + BLOCK_SIZE=ctx.block, # + ROW_SIZE=next_power_of_2(ctx.maxlut), # + IS_DENSE=ctx.is_dense, # + num_warps=num_warps(ctx.maxlut) # + ) + return (da, None, None, dr, None, None, None, None, None, None, None, None, None, None, None, None, None, None) + + +class softmax: + + def __init__(self, layout, block, device, is_dense=False): + self.spdims = layout.shape + self.layout = layout + self.block = block + self.lut, self.maxlut = _softmax.make_lut(self.layout, self.block, device) + self.is_dense = is_dense + + def __call__(self, a, *, scale=1.0, rel_logits=None, is_causal=False): + if rel_logits is not None and rel_logits.dtype != a.dtype: + raise ValueError(f"relative position embedding must be {a.dtype}") + a = _softmax.apply(a, scale, rel_logits, is_causal, self.spdims, self.block, self.lut, self.maxlut, + self.is_dense) + return a diff --git a/lib/python3.10/site-packages/triton/ops/cross_entropy.py b/lib/python3.10/site-packages/triton/ops/cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..88e8dae50db09ca17c1bb8464afee2975b3d77eb --- /dev/null +++ b/lib/python3.10/site-packages/triton/ops/cross_entropy.py @@ -0,0 +1,96 @@ +import torch + +from .. import heuristics, jit +from .. import language as tl +from .. import next_power_of_2 + + +def num_warps(N): + if N < 2048: + return 4 + elif N < 8192: + return 8 + return 16 + + +@heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])}) +@heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])}) +@jit +def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr): + row = tl.program_id(0) + cols = tl.arange(0, BLOCK) + idx = tl.load(IDX + row) + # pointers to logit and probs + LOGITS = LOGITS + row * N + cols + WRIT_PROBS = PROBS + row * N + cols + READ_PROBS = PROBS + row * N + idx + # write-back negative log-probs + logits = tl.load(LOGITS, mask=cols < N, other=-float('inf')) + logits = logits.to(tl.float32) + logits = logits - tl.max(logits, 0) + probs = tl.log(tl.sum(tl.exp(logits), 0)) - logits + tl.store(WRIT_PROBS, probs, mask=cols < N) + # There is a bug in the compiler, which fails to insert a barrier here. + # We add it explicitly for now. Will be fixed soon. + tl.debug_barrier() + # write-back loss + probs = tl.load(READ_PROBS) + tl.store(LOSS + row, probs) + + +@heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])}) +@heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])}) +@jit +def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr): + row = tl.program_id(0) + cols = tl.arange(0, BLOCK) + idx = tl.load(IDX + row) + # pointers to probs + PROBS = PROBS + row * N + cols + # We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k] + # and we have -log(p[k]) stored in PROBS, so this is easy + probs = -tl.load(PROBS, mask=cols < N, other=float('inf')) + probs = tl.exp(probs.to(tl.float32)) + delta = cols == idx + # write result in-place in PROBS + dout = tl.load(DPROBS + row) + din = (probs - delta) * dout + tl.store(PROBS, din.to(PROBS.dtype.element_ty), mask=cols < N) + + +class _cross_entropy(torch.autograd.Function): + + @classmethod + def forward(cls, ctx, logits, indices): + # make sure we can use triton + assert (indices.dtype == torch.int64), "Indices are expected to be of type long." + # make kernel + device, dtype = logits.device, logits.dtype + n_cols = logits.shape[-1] + # run the kernel + result = torch.empty_like(indices, dtype=dtype, device=device) + neg_logprobs = torch.empty_like(logits, dtype=dtype, device=device) + grid = lambda opt: (logits.numel() // n_cols, ) + _forward[grid](logits, neg_logprobs, indices, result, n_cols) + # save for backward + ctx.save_for_backward(neg_logprobs, indices) + return result + + @classmethod + def backward(cls, ctx, dneg_logprobs): + """We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k] + so we initialize the gradient as neg_logprobs, so we can just exponentiate + to get p[k], which is most of what we need... neg_logprobs will be + modified in place to become the gradient we want + """ + # load saved tensors + neg_logprobs, indices = ctx.saved_tensors + # run the kernel + # neg_logprobs will be modified in place to become our gradient: + n_cols = neg_logprobs.shape[-1] + grid = lambda opt: (neg_logprobs.numel() // n_cols, ) + _backward[grid](neg_logprobs, indices, dneg_logprobs, n_cols) + return neg_logprobs, None + + +cross_entropy = _cross_entropy.apply diff --git a/lib/python3.10/site-packages/triton/ops/flash_attention.py b/lib/python3.10/site-packages/triton/ops/flash_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..0825ef26c41c078b4d4b7b6b198d46746317307e --- /dev/null +++ b/lib/python3.10/site-packages/triton/ops/flash_attention.py @@ -0,0 +1,466 @@ +""" +Fused Attention +=============== +This is a Triton implementation of the Flash Attention algorithm +(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf) + +Sequence Parallel implementation inspired by HazyResearch +(see https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py) +""" + +import torch +import triton + +from .. import cdiv, jit +from .. import language as tl + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +@jit +def _fwd_kernel(Q, K, V, sm_scale, # + L, # + Out, # + stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vn, stride_vk, # + stride_oz, stride_oh, stride_om, stride_on, # + Z, H, N_CTX, # + Z_H_N_CTX, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # + BLOCK_N: tl.constexpr, # + IS_CAUSAL: tl.constexpr # + ): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + qvk_offset = off_hz * stride_qh + vk_offset = qvk_offset // stride_qm + + K_block_ptr = tl.make_block_ptr( + base=K, + shape=(BLOCK_DMODEL, Z_H_N_CTX), + strides=(stride_kk, stride_kn), + offsets=(0, vk_offset), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1), + ) + V_block_ptr = tl.make_block_ptr( + base=V, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_vn, stride_vk), + offsets=(vk_offset, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # credits to: Adam P. Goucher (https://github.com/apgoucher): + # scale sm_scale by 1/log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout + + offs_k = tl.arange(0, BLOCK_DMODEL) + Q_ptrs = Q + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk + q = tl.load(Q_ptrs) + + q = (q * qk_scale).to(K.dtype.element_ty) + lo = 0 + hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX + for start_n in range(lo, hi, BLOCK_N): + # -- load k, v -- + k = tl.load(K_block_ptr) + v = tl.load(V_block_ptr) + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + if IS_CAUSAL: + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + qk += tl.dot(q, k) + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + # -- scale and update acc -- + acc *= alpha[:, None] + acc += tl.dot(p.to(V.dtype.element_ty), v) + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + # write back l and m + acc = acc / l_i[:, None] + l_ptrs = L + off_hz * N_CTX + offs_m + tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(vk_offset + start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + # O_ptrs = Out + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk + tl.store(O_block_ptr, acc.to(K.dtype.element_ty)) + + +@jit +def _bwd_preprocess( + Out, + DO, + Delta, + BLOCK_M: tl.constexpr, + D_HEAD: tl.constexpr, +): + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_n = tl.arange(0, D_HEAD) + # load + o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + # compute + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(Delta + off_m, delta) + + +@jit +def _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, # + Out, DO, # + DQ, DK, DV, # + L, # + D, # + Q_block_ptr, K_block_ptr, V_block_ptr, # + DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, # + stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vn, stride_vk, # + Z, H, N_CTX, # + off_h, off_z, off_hz, start_n, num_block, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # + BLOCK_N: tl.constexpr, # + SEQUENCE_PARALLEL: tl.constexpr, # + CAUSAL: tl.constexpr, # + MMA_V3: tl.constexpr # + ): + if CAUSAL: + lo = start_n * BLOCK_M + else: + lo = 0 + + Q_offset = (off_z * stride_qz + off_h * stride_qh) // stride_qm + DQ_offset = off_z * stride_qz + off_h * stride_qh + K_offset = (off_z * stride_kz + off_h * stride_kh) // stride_kn + V_offset = (off_z * stride_vz + off_h * stride_vh) // stride_vn + if SEQUENCE_PARALLEL: + DQ_offset += stride_dqa * start_n + DQ_offset = DQ_offset // stride_qm + + Q_block_ptr = tl.advance(Q_block_ptr, (lo + Q_offset, 0)) + K_block_ptr = tl.advance(K_block_ptr, (start_n * BLOCK_M + K_offset, 0)) + V_block_ptr = tl.advance(V_block_ptr, (start_n * BLOCK_M + V_offset, 0)) + DO_block_ptr = tl.advance(DO_block_ptr, (lo + Q_offset, 0)) + DQ_block_ptr = tl.advance(DQ_block_ptr, (lo + DQ_offset, 0)) + DK_block_ptr = tl.advance(DK_block_ptr, (start_n * BLOCK_M + K_offset, 0)) + DV_block_ptr = tl.advance(DV_block_ptr, (start_n * BLOCK_M + V_offset, 0)) + + # initialize row/col offsets + offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) + offs_m = tl.arange(0, BLOCK_N) + # pointer to row-wise quantities in value-like data + D_ptrs = D + off_hz * N_CTX + l_ptrs = L + off_hz * N_CTX + # initialize dv amd dk + dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # k and v stay in SRAM throughout + k = tl.load(K_block_ptr) + v = tl.load(V_block_ptr) + # loop over rows + for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + q = tl.load(Q_block_ptr) + # recompute p = softmax(qk, dim=-1).T + # NOTE: `do` is pre-divided by `l`; no normalization here + if CAUSAL: + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.0), float("-inf")) + else: + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, tl.trans(k)) + qk *= qk_scale + l_i = tl.load(l_ptrs + offs_m_curr) + p = tl.math.exp2(qk - l_i[:, None]) + # compute dv + do = tl.load(DO_block_ptr) + dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do) + # compute dp = dot(v, do) + Di = tl.load(D_ptrs + offs_m_curr) + # dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] + dp = tl.dot(do, tl.trans(v)) + # compute ds = p * (dp - delta[:, None]) + ds = (p * (dp - Di[:, None]) * sm_scale).to(Q.dtype.element_ty) + # compute dk = dot(ds.T, q) + dk += tl.dot(tl.trans(ds), q) + # compute dq + if not SEQUENCE_PARALLEL: + dq = tl.load(DQ_block_ptr) + dq += tl.dot(ds, k) + tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty)) + elif SEQUENCE_PARALLEL: + if MMA_V3: + dq = tl.dot(ds, k) + else: + # not work with mma v3, because M % 64 != 0 + dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds))) + tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty)) + + # increment pointers + DQ_block_ptr = tl.advance(DQ_block_ptr, (BLOCK_M, 0)) + Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_M, 0)) + DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_M, 0)) + # write-back + tl.store(DV_block_ptr, dv.to(V.dtype.element_ty)) + tl.store(DK_block_ptr, dk.to(K.dtype.element_ty)) + + +@jit +def _bwd_kernel(Q, K, V, sm_scale, # + Out, DO, # + DQ, DK, DV, # + L, # + D, # + stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vn, stride_vk, # + Z, H, N_CTX, # + Z_H_N_CTX, # + SQ_Z_H_N_CTX, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # + BLOCK_N: tl.constexpr, # + SEQUENCE_PARALLEL: tl.constexpr, # + CAUSAL: tl.constexpr, # + MMA_V3: tl.constexpr # + ): + qk_scale = sm_scale * 1.44269504 + off_hz = tl.program_id(0) + off_z = off_hz // H + off_h = off_hz % H + + Q_block_ptr = tl.make_block_ptr( + base=Q, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + K_block_ptr = tl.make_block_ptr( + base=K, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_kn, stride_kk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + V_block_ptr = tl.make_block_ptr( + base=V, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_vn, stride_vk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + DO_block_ptr = tl.make_block_ptr( + base=DO, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + if SEQUENCE_PARALLEL: + DQ_block_ptr = tl.make_block_ptr( + base=DQ, + shape=(SQ_Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + else: + DQ_block_ptr = tl.make_block_ptr( + base=DQ, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + + DK_block_ptr = tl.make_block_ptr( + base=DK, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_kn, stride_kk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + DV_block_ptr = tl.make_block_ptr( + base=DV, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_vn, stride_vk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + + num_block_n = tl.cdiv(N_CTX, BLOCK_N) + if not SEQUENCE_PARALLEL: + for start_n in range(0, num_block_n): + _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO, # + DQ, DK, DV, # + L, # + D, # + Q_block_ptr, K_block_ptr, V_block_ptr, # + DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, # + stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vn, stride_vk, # + Z, H, N_CTX, # + off_h, off_z, off_hz, start_n, num_block_n, # + BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, # + BLOCK_N=BLOCK_N, # + SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, # + CAUSAL=CAUSAL, # + MMA_V3=MMA_V3 # + ) + else: + start_n = tl.program_id(1) + _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO, # + DQ, DK, DV, # + L, # + D, # + Q_block_ptr, K_block_ptr, V_block_ptr, # + DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, # + stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vn, stride_vk, # + Z, H, N_CTX, # + off_h, off_z, off_hz, start_n, num_block_n, # + BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, # + BLOCK_N=BLOCK_N, # + SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, # + CAUSAL=CAUSAL, # + MMA_V3=MMA_V3 # + ) + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False): + # only support for Ampere now + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + raise RuntimeError("Flash attention currently only supported for compute capability >= 80") + BLOCK_M = 128 + BLOCK_N = 64 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + o = torch.empty_like(q) + grid = (cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1) + L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + num_warps = 4 if Lk <= 64 else 8 + _fwd_kernel[grid]( + q, k, v, sm_scale, # + L, # + o, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + o.stride(0), o.stride(1), o.stride(2), o.stride(3), # + q.shape[0], q.shape[1], q.shape[2], # + q.shape[0] * q.shape[1] * q.shape[2], # + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, # + IS_CAUSAL=causal, # + num_warps=num_warps, # + num_stages=4 # + ) + + ctx.save_for_backward(q, k, v, o, L) + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = Lk + ctx.causal = causal + ctx.sequence_parallel = sequence_parallel + return o + + @staticmethod + def backward(ctx, do): + capability = torch.cuda.get_device_capability() + MMA_V3 = capability[0] >= 9 + BLOCK = 128 + + if is_hip(): + # Bwd pass runs out of shared memory on HIP with larger block size. + BLOCK = 64 + + q, k, v, o, L = ctx.saved_tensors + sequence_parallel = ctx.sequence_parallel + seq_len_kv = k.shape[2] + do = do.contiguous() + if sequence_parallel: + replicas = cdiv(seq_len_kv, BLOCK) + new_dq_shape = (replicas, ) + q.shape + dq = torch.zeros(new_dq_shape, device=q.device, dtype=q.dtype) + else: + dq = torch.zeros_like(q, dtype=q.dtype) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + delta = torch.empty_like(L) + _bwd_preprocess[(cdiv(q.shape[2], BLOCK) * ctx.grid[1], )]( + o, + do, + delta, + BLOCK_M=BLOCK, + D_HEAD=ctx.BLOCK_DMODEL, + ) + _bwd_kernel[(ctx.grid[1], cdiv(seq_len_kv, BLOCK) if sequence_parallel else 1)]( + q, k, v, ctx.sm_scale, # + o, do, # + dq, dk, dv, # + L, # + delta, # + o.numel(), q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + q.shape[0], q.shape[1], q.shape[2], # + q.shape[0] * q.shape[1] * q.shape[2], # + cdiv(seq_len_kv, BLOCK) * q.shape[0] * q.shape[1] * q.shape[2], # + BLOCK_M=BLOCK, BLOCK_N=BLOCK, # + BLOCK_DMODEL=ctx.BLOCK_DMODEL, # + SEQUENCE_PARALLEL=sequence_parallel, # + CAUSAL=ctx.causal, # + MMA_V3=MMA_V3, # + num_warps=8, # + num_stages=1 # + ) + + if len(dq.shape) == 5: + dq = dq.sum(dim=0) + return dq, dk, dv, None, None, None + + +attention = _attention.apply diff --git a/lib/python3.10/site-packages/triton/ops/matmul.py b/lib/python3.10/site-packages/triton/ops/matmul.py new file mode 100644 index 0000000000000000000000000000000000000000..f7f577a1b80bd85998d52218d93adf0ab3cc5a60 --- /dev/null +++ b/lib/python3.10/site-packages/triton/ops/matmul.py @@ -0,0 +1,219 @@ +import torch + +from .. import Config, autotune, cdiv, heuristics, jit +from .. import language as tl +from .matmul_perf_model import early_config_prune, estimate_matmul_time + +_ordered_datatypes = [torch.int8, torch.float16, torch.bfloat16, torch.float32] + + +def upcast_if_fp8(a): + if "fp8" in str(a): + return torch.float16 + return a + + +def get_higher_dtype(a, b): + a = upcast_if_fp8(a) + b = upcast_if_fp8(b) + if a is b: + return a + + assert a in _ordered_datatypes + assert b in _ordered_datatypes + + for d in _ordered_datatypes: + if a is d: + return b + if b is d: + return a + + +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + +def get_configs_io_bound(): + configs = [] + for num_stages in [2, 3, 4, 5, 6]: + for block_m in [16, 32]: + for block_k in [32, 64]: + for block_n in [32, 64, 128, 256]: + num_warps = 2 if block_n <= 64 else 4 + configs.append( + Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, + num_stages=num_stages, num_warps=num_warps)) + # split_k + for split_k in [2, 4, 8, 16]: + configs.append( + Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, + num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + return configs + + +@autotune( + configs=[ + # basic configs for compute-bound matmuls + Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + # good for int8 + Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + ] + get_configs_io_bound(), + key=['M', 'N', 'K'], + prune_configs_by={ + 'early_config_prune': early_config_prune, + 'perf_model': estimate_matmul_time, + 'top_k': 10, + }, +) +@heuristics({ + 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, +}) +@jit +def _kernel(A, B, C, M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + acc_dtype: tl.constexpr, # + input_precision: tl.constexpr, # + fp8_fast_accum: tl.constexpr, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr # + ): + # matrix multiplication + pid = tl.program_id(0) + pid_z = tl.program_id(1) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0) + if AB_DTYPE is not None: + a = a.to(AB_DTYPE) + b = b.to(AB_DTYPE) + if fp8_fast_accum: + acc = tl.dot(a, b, acc, out_dtype=acc_dtype, input_precision=input_precision) + else: + acc += tl.dot(a, b, out_dtype=acc_dtype, input_precision=input_precision) + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + acc = acc.to(C.dtype.element_ty) + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) + + +class _matmul(torch.autograd.Function): + kernel = _kernel + + _locks = {} + + @staticmethod + def _call(a, b, acc_dtype, input_precision, fp8_fast_accum, output_dtype): + device = a.device + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + + # common type between a and b + ab_dtype = get_higher_dtype(a.dtype, b.dtype) + + # allocates output + if (output_dtype is None): + output_dtype = ab_dtype + + c = torch.empty((M, N), device=device, dtype=output_dtype) + + # Allowed types for acc_type given the types of a and b. + supported_acc_dtypes = { + torch.float16: (torch.float32, torch.float16), torch.bfloat16: (torch.float32, torch.bfloat16), + torch.float32: (torch.float32, ), torch.int8: (torch.int32, ) + } + + if acc_dtype is None: + acc_dtype = supported_acc_dtypes[ab_dtype][0] + else: + assert isinstance(acc_dtype, torch.dtype), "acc_dtype must be a torch.dtype" + assert acc_dtype in supported_acc_dtypes[a.dtype], "acc_dtype not compatible with the type of a" + assert acc_dtype in supported_acc_dtypes[b.dtype], "acc_dtype not compatible with the type of b" + + def to_tl_type(ty): + return getattr(tl, str(ty).split(".")[-1]) + + acc_dtype = to_tl_type(acc_dtype) + ab_dtype = to_tl_type(ab_dtype) + output_dtype = to_tl_type(output_dtype) + + # Tensor cores support input with mixed float8 types. + if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [tl.float8e4nv, tl.float8e5]: + ab_dtype = None + # launch kernel + grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K']) + _kernel[grid]( + a, b, c, M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + acc_dtype=acc_dtype, # + input_precision=input_precision, # + fp8_fast_accum=fp8_fast_accum, # + GROUP_M=8, AB_DTYPE=ab_dtype) + return c + + @staticmethod + def forward(ctx, a, b, acc_dtype=None, input_precision=None, fp8_fast_accum=True, output_dtype=None): + return _matmul._call(a, b, acc_dtype=acc_dtype, input_precision=input_precision, fp8_fast_accum=fp8_fast_accum, + output_dtype=output_dtype) + + +matmul = _matmul.apply diff --git a/lib/python3.10/site-packages/triton/ops/matmul_perf_model.py b/lib/python3.10/site-packages/triton/ops/matmul_perf_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b60b74540b47830574b10cd4f38c5c7024f616c3 --- /dev/null +++ b/lib/python3.10/site-packages/triton/ops/matmul_perf_model.py @@ -0,0 +1,171 @@ +import functools +import heapq + +import torch + +from .. import cdiv +from ..runtime import driver +from ..testing import (get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops, nvsmi) + + +@functools.lru_cache() +def get_clock_rate_in_khz(): + try: + return nvsmi(['clocks.max.sm'])[0] * 1e3 + except FileNotFoundError: + import pynvml + + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + return pynvml.nvmlDeviceGetMaxClockInfo(handle, pynvml.NVML_CLOCK_SM) * 1e3 + + +def get_tensorcore_tflops(device, num_ctas, num_warps, dtype): + ''' return compute throughput in TOPS ''' + total_warps = num_ctas * min(num_warps, 4) + num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs + tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops( + dtype, get_clock_rate_in_khz(), device) + return tflops + + +def get_simd_tflops(device, num_ctas, num_warps, dtype): + ''' return compute throughput in TOPS ''' + total_warps = num_ctas * min(num_warps, 4) + num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs + tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, get_clock_rate_in_khz(), device) + return tflops + + +def get_tflops(device, num_ctas, num_warps, dtype): + capability = torch.cuda.get_device_capability(device) + if capability[0] < 8 and dtype == torch.float32: + return get_simd_tflops(device, num_ctas, num_warps, dtype) + return get_tensorcore_tflops(device, num_ctas, num_warps, dtype) + + +def estimate_matmul_time( + # backend, device, + num_warps, num_stages, # + A, B, C, # + M, N, K, # + BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, # + debug=False, **kwargs # +): + ''' return estimated running time in ms + = max(compute, loading) + store ''' + device = torch.cuda.current_device() + dtype = A.dtype + dtsize = A.element_size() + + num_cta_m = cdiv(M, BLOCK_M) + num_cta_n = cdiv(N, BLOCK_N) + num_cta_k = SPLIT_K + num_ctas = num_cta_m * num_cta_n * num_cta_k + + # If the input is smaller than the block size + M, N = max(M, BLOCK_M), max(N, BLOCK_N) + + # time to compute + total_ops = 2 * M * N * K / (1024 * 1024 * 1024) # GOPS + tput = get_tflops(device, num_ctas, num_warps, dtype) + compute_ms = total_ops / tput + + # time to load data + num_sm = driver.active.utils.get_device_properties(device)["multiprocessor_count"] + active_cta_ratio = min(1, num_ctas / num_sm) + active_cta_ratio_bw1 = min(1, num_ctas / 32) # 32 active ctas are enough to saturate + active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0) # 32-108, remaining 5% + dram_bw = get_dram_gbps(device) * (active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05) # in GB/s + l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?) + # assume 80% of (following) loads are in L2 cache + load_a_dram = M * K * dtsize * (1 + 0.2 * (num_cta_n - 1)) + load_a_l2 = M * K * dtsize * 0.8 * (num_cta_n - 1) + load_b_dram = N * K * dtsize * (1 + 0.2 * (num_cta_m - 1)) + load_b_l2 = N * K * dtsize * 0.8 * (num_cta_m - 1) + # total + total_dram = (load_a_dram + load_b_dram) / (1024 * 1024) # MB + total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024) + # loading time in ms + load_ms = total_dram / dram_bw + total_l2 / l2_bw + + # estimate storing time + store_bw = dram_bw * 0.6 # :o + store_c_dram = M * N * dtsize * SPLIT_K / (1024 * 1024) # MB + if SPLIT_K == 1: + store_ms = store_c_dram / store_bw + else: + reduce_bw = store_bw + store_ms = store_c_dram / reduce_bw + # c.zero_() + zero_ms = M * N * 2 / (1024 * 1024) / store_bw + store_ms += zero_ms + + total_time_ms = max(compute_ms, load_ms) + store_ms + if debug: + print(f'Total time: {total_time_ms}ms, compute time: {compute_ms}ms, ' + f'loading time: {load_ms}ms, store time: {store_ms}ms, ' + f'Activate CTAs: {active_cta_ratio*100}%') + return total_time_ms + + +def early_config_prune(configs, named_args, **kwargs): + device = torch.cuda.current_device() + capability = torch.cuda.get_device_capability() + # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages + dtsize = named_args['A'].element_size() + dtype = named_args['A'].dtype + + # 1. make sure we have enough smem + pruned_configs = [] + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, num_stages = \ + kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], config.num_stages + + max_shared_memory = driver.active.utils.get_device_properties(device)["max_shared_mem"] + required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize + if required_shared_memory <= max_shared_memory: + pruned_configs.append(config) + configs = pruned_configs + + # Some dtypes do not allow atomic_add + if dtype not in [torch.float16, torch.float32]: + configs = [config for config in configs if config.kwargs['SPLIT_K'] == 1] + + # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps) + configs_map = {} + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = \ + kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], kw['SPLIT_K'], config.num_warps, config.num_stages + + key = (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps) + if key in configs_map: + configs_map[key].append((config, num_stages)) + else: + configs_map[key] = [(config, num_stages)] + + pruned_configs = [] + for k, v in configs_map.items(): + BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k + if capability[0] >= 8: + # compute cycles (only works for ampere GPUs) + mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16) + mma_cycles = mmas / min(4, num_warps) * 8 + + ldgsts_latency = 300 # Does this matter? + optimal_num_stages = ldgsts_latency / mma_cycles + + # nearest stages, prefer large #stages + nearest = heapq.nsmallest( + 2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages) + if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages) + + for n in nearest: + pruned_configs.append(n[0]) + else: # Volta & Turing only supports num_stages <= 2 + random_config = v[0][0] + random_config.num_stages = 2 + pruned_configs.append(random_config) + return pruned_configs diff --git a/lib/python3.10/site-packages/triton/profiler/__init__.py b/lib/python3.10/site-packages/triton/profiler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0add689155c3dc28595e6ce1aeead48e133276dc --- /dev/null +++ b/lib/python3.10/site-packages/triton/profiler/__init__.py @@ -0,0 +1,10 @@ +# flake8: noqa +from .scope import scope, enter_scope, exit_scope +from .profile import ( + start, + activate, + deactivate, + finalize, + profile, + DEFAULT_PROFILE_NAME, +) diff --git a/lib/python3.10/site-packages/triton/profiler/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/triton/profiler/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9157584adac2b343e3d790c2e07079093775206d Binary files /dev/null and b/lib/python3.10/site-packages/triton/profiler/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/profiler/__pycache__/flags.cpython-310.pyc b/lib/python3.10/site-packages/triton/profiler/__pycache__/flags.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03db92d0bd9be94482a03f6da20d0e3651fd81ff Binary files /dev/null and b/lib/python3.10/site-packages/triton/profiler/__pycache__/flags.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/profiler/__pycache__/hook.cpython-310.pyc b/lib/python3.10/site-packages/triton/profiler/__pycache__/hook.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e84988708b2710e5794a3bdc87706eff763973e3 Binary files /dev/null and b/lib/python3.10/site-packages/triton/profiler/__pycache__/hook.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/profiler/__pycache__/profile.cpython-310.pyc b/lib/python3.10/site-packages/triton/profiler/__pycache__/profile.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a97412cb316c085302b779527e098c5522d2d4ea Binary files /dev/null and b/lib/python3.10/site-packages/triton/profiler/__pycache__/profile.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/profiler/__pycache__/proton.cpython-310.pyc b/lib/python3.10/site-packages/triton/profiler/__pycache__/proton.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0303f0cbcaa2b25bc8f9351db9f19763bad5d1e Binary files /dev/null and b/lib/python3.10/site-packages/triton/profiler/__pycache__/proton.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/profiler/__pycache__/scope.cpython-310.pyc b/lib/python3.10/site-packages/triton/profiler/__pycache__/scope.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1c12cb305dc85e95eade592f325680f3200b20d Binary files /dev/null and b/lib/python3.10/site-packages/triton/profiler/__pycache__/scope.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/profiler/__pycache__/viewer.cpython-310.pyc b/lib/python3.10/site-packages/triton/profiler/__pycache__/viewer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c3846126d3fb91656a6eacff3ea6995b998a494 Binary files /dev/null and b/lib/python3.10/site-packages/triton/profiler/__pycache__/viewer.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/profiler/flags.py b/lib/python3.10/site-packages/triton/profiler/flags.py new file mode 100644 index 0000000000000000000000000000000000000000..37c75b2433efb08f1c3d8a60a661da0e65fee763 --- /dev/null +++ b/lib/python3.10/site-packages/triton/profiler/flags.py @@ -0,0 +1,33 @@ +""" +This file contains the global flags used in the proton package. +""" + +# Whether to enable profiling. Default is False. +profiling_on = False +# Whether the script is run from the command line. Default is False. +command_line = False + + +def set_profiling_on(): + global profiling_on + profiling_on = True + + +def set_profiling_off(): + global profiling_on + profiling_on = False + + +def get_profiling_on(): + global profiling_on + return profiling_on + + +def set_command_line(): + global command_line + command_line = True + + +def is_command_line(): + global command_line + return command_line diff --git a/lib/python3.10/site-packages/triton/profiler/hook.py b/lib/python3.10/site-packages/triton/profiler/hook.py new file mode 100644 index 0000000000000000000000000000000000000000..a9ec5f36b0f7ea50bdf04d5ef52993e69f995da3 --- /dev/null +++ b/lib/python3.10/site-packages/triton/profiler/hook.py @@ -0,0 +1,33 @@ +from .scope import enter_scope, exit_scope +from triton.compiler import CompiledKernel, LazyDict + +COMPUTE_METADATA_SCOPE_NAME = "__proton_launch_metadata" + + +class TritonHook: + flops_width = [8, 16, 32, 64] + metrics = [f"flops{width}" for width in flops_width] + ["bytes"] + + @staticmethod + def enter(lazy_dict: LazyDict) -> None: + enter_scope(COMPUTE_METADATA_SCOPE_NAME) + metadata = lazy_dict.get() + exit_scope() + fn_metrics = {k: metadata[k] for k in TritonHook.metrics if k in metadata} + enter_scope(metadata["name"], triton_op=True, metrics=fn_metrics) + + @staticmethod + def exit(lazy_dict: LazyDict) -> None: + exit_scope(triton_op=True) + + +def register_triton_hook() -> None: + if CompiledKernel.launch_enter_hook is None: + CompiledKernel.launch_enter_hook = TritonHook.enter + CompiledKernel.launch_exit_hook = TritonHook.exit + + +def unregister_triton_hook() -> None: + if CompiledKernel.launch_enter_hook == TritonHook.enter: + CompiledKernel.launch_enter_hook = None + CompiledKernel.launch_exit_hook = None diff --git a/lib/python3.10/site-packages/triton/profiler/profile.py b/lib/python3.10/site-packages/triton/profiler/profile.py new file mode 100644 index 0000000000000000000000000000000000000000..2bf7938a5967c740eedaf987271c23e4f5fb7c06 --- /dev/null +++ b/lib/python3.10/site-packages/triton/profiler/profile.py @@ -0,0 +1,192 @@ +import functools +import triton + +from triton._C.libproton import proton as libproton +from .hook import register_triton_hook, unregister_triton_hook +from .flags import set_profiling_off, set_profiling_on, is_command_line +from typing import Optional + +DEFAULT_PROFILE_NAME = "proton" + + +def _select_backend() -> str: + backend = triton.runtime.driver.active.get_current_target().backend + if backend == "cuda": + return "cupti" + elif backend == "hip": + return "roctracer" + else: + raise ValueError("No backend is available for the current target.") + + +def start( + name: Optional[str] = None, + *, + context: Optional[str] = "shadow", + data: Optional[str] = "tree", + backend: Optional[str] = None, + hook: Optional[str] = None, +): + """ + Start profiling with the given name and backend. + + Usage: + + ```python + proton.start("my_profile") + # do something + proton.finalize() + ``` + + Args: + name (str, optional): The name (with path) of the profiling session. + If not provided, the default name is "~/proton.hatchet". + backend (str, optional): The backend to use for profiling. + Available options are ["cupti"]. + Defaults to None, which automatically selects the backend matching the current active runtime. + context (str, optional): The context to use for profiling. + Available options are ["shadow", "python"]. + Defaults to "shadow". + data (str, optional): The data structure to use for profiling. + Available options are ["tree"]. + Defaults to "tree". + hook (str, optional): The hook to use for profiling. + Available options are [None, "triton"]. + Defaults to None. + Returns: + session (int): The session ID of the profiling session. + """ + if is_command_line(): + # Ignore the start() call if the script is run from the command line. + return + + if name is None: + name = DEFAULT_PROFILE_NAME + + if backend is None: + backend = _select_backend() + + set_profiling_on() + if hook and hook == "triton": + register_triton_hook() + return libproton.start(name, context, data, backend) + + +def activate(session: Optional[int] = 0) -> None: + """ + Activate the specified session. + The profiling session will be active and data will be recorded. + + Args: + session (int): The session ID of the profiling session. Defaults to 0 (the first session started.) + + Returns: + None + """ + if is_command_line() and session != 0: + raise ValueError("Only one session can be activated when running from the command line.") + libproton.activate(session) + + +def deactivate(session: Optional[int] = 0) -> None: + """ + Stop the specified session. + The profiling session's data will still be in the memory, but no more data will be recorded. + + Args: + session (int): The session ID of the profiling session. Defaults to 0 (the first session started.) + + Returns: + None + """ + if is_command_line() and session != 0: + raise ValueError("Only one session can be deactivated when running from the command line.") + libproton.deactivate(session) + + +def finalize(session: Optional[int] = None, output_format: str = "hatchet") -> None: + """ + Finalizes a profiling session. + Flush and write the profiling data to the file specified by the session name. + + Args: + session (int, optional): The session ID to finalize. If None, all sessions are finalized. Defaults to None. + output_format (str, optional): The output format for the profiling results. + Aavailable options are ["hatchet"]. + + Returns: + None + """ + if session is None: + set_profiling_off() + libproton.finalize_all(output_format) + unregister_triton_hook() + else: + if is_command_line() and session != 0: + raise ValueError("Only one session can be finalized when running from the command line.") + libproton.finalize(session, output_format) + + +def _profiling( + func, + name: Optional[str] = None, + context: Optional[str] = "shadow", + data: Optional[str] = "tree", + backend: Optional[str] = None, + hook: Optional[str] = None, +): + """ + Context manager for profiling. Internally use only. + + Args: + See start() for the arguments. + + Returns: + wrapper (function): The wrapped function. + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + session = start(name, context=context, data=data, backend=backend, hook=hook) + ret = func(*args, **kwargs) + deactivate(session) + return ret + + return wrapper + + +def profile( + func=None, + *, + name: Optional[str] = None, + context: Optional[str] = "shadow", + data: Optional[str] = "tree", + backend: Optional[str] = None, + hook: Optional[str] = None, +): + """ + Decorator for profiling. + + Usage: + + ```python + @proton.profile + def foo(): + pass + ``` + + Args: + See start() for the arguments. + + Returns: + decorator (function): The decorator function. + """ + if func is None: + # It's being used with parentheses, so return a decorator + def decorator(f): + return _profiling(f, name=name, context=context, data=data, backend=backend, hook=hook) + + return decorator + else: + # It's being used without parentheses, so apply the decorator directly + return _profiling(func, name=name, context=context, data=data, backend=backend, hook=hook) diff --git a/lib/python3.10/site-packages/triton/profiler/proton.py b/lib/python3.10/site-packages/triton/profiler/proton.py new file mode 100644 index 0000000000000000000000000000000000000000..21267f97da39d0ae971848aa1c96f25958235979 --- /dev/null +++ b/lib/python3.10/site-packages/triton/profiler/proton.py @@ -0,0 +1,78 @@ +import argparse +import sys +import os +from .profile import start, finalize, _select_backend +from .flags import set_command_line + + +def parse_arguments(): + parser = argparse.ArgumentParser( + description="The proton command utility for profiling scripts and pytest tests.", usage=""" + proton [options] script.py [script_args] [script_options] + proton [options] pytest [pytest_args] [script_options] + python -m triton.profiler.proton [options] script.py [script_args] [script_options] +""", formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument("-n", "--name", type=str, help="Name of the profiling session") + parser.add_argument("-b", "--backend", type=str, help="Profiling backend", default=None, choices=["cupti"]) + parser.add_argument("-c", "--context", type=str, help="Profiling context", default="shadow", + choices=["shadow", "python"]) + parser.add_argument("-d", "--data", type=str, help="Profiling data", default="tree", choices=["tree"]) + parser.add_argument("-k", "--hook", type=str, help="Profiling hook", default=None, choices=[None, "triton"]) + args, target_args = parser.parse_known_args() + return args, target_args + + +def is_pytest(script): + return os.path.basename(script) == 'pytest' + + +def execute_as_main(script, args): + script_path = os.path.abspath(script) + # Prepare a clean global environment + clean_globals = { + "__name__": "__main__", + "__file__": script_path, + "__builtins__": __builtins__, + sys.__name__: sys, + } + + original_argv = sys.argv + sys.argv = [script] + args + + # Execute in the isolated environment + try: + with open(script_path, 'rb') as file: + code = compile(file.read(), script_path, 'exec') + exec(code, clean_globals) + except Exception as e: + print(f"An error occurred while executing the script: {e}") + finally: + sys.argv = original_argv + + +def run_profiling(args, target_args): + backend = args.backend if args.backend else _select_backend() + + start(args.name, context=args.context, data=args.data, backend=backend, hook=args.hook) + + # Set the command line mode to avoid any `start` calls in the script. + set_command_line() + + script = target_args[0] + script_args = target_args[1:] if len(target_args) > 1 else [] + if is_pytest(script): + import pytest + pytest.main(script_args) + else: + execute_as_main(script, script_args) + + finalize() + + +def main(): + args, target_args = parse_arguments() + run_profiling(args, target_args) + + +if __name__ == "__main__": + main() diff --git a/lib/python3.10/site-packages/triton/profiler/scope.py b/lib/python3.10/site-packages/triton/profiler/scope.py new file mode 100644 index 0000000000000000000000000000000000000000..5695b8807500a887b6351e52e4dc9ae60b2bbd0d --- /dev/null +++ b/lib/python3.10/site-packages/triton/profiler/scope.py @@ -0,0 +1,105 @@ +import threading +from functools import wraps +from typing import Optional, Union + +from .flags import get_profiling_on +from triton._C.libproton import proton as libproton + +_local = threading.local() + +MetricValueType = Union[float, int] +PropertyValueType = Union[float, int, str] + + +class scope: + """ + A context manager and decorator for entering and exiting a scope. + + Usage: + context manager: + ```python + with proton.scope("test0", {metric_name: metric_value}): + foo[1,](x, y) + ``` + + decoarator: + ```python + @proton.scope("test0", {metric_name: metric_value}) + def foo(x, y): + ... + ``` + + Args: + name (str): The name of the scope. + metrics (dict[str, float], optional): The metrics of the scope. Default is None. + """ + + def __init__(self, name: str, metrics: Optional[dict[str, MetricValueType]] = None, + properties: Optional[dict[str, PropertyValueType]] = None) -> None: + self._name = name + self._metrics = metrics + self._properties = properties + + def __enter__(self): + if not get_profiling_on(): + return self + self._id = libproton.record_scope() + libproton.enter_scope(self._id, self._name) + if self._metrics: + libproton.add_metrics(self._id, self._metrics) + if self._properties: + libproton.set_properties(self._id, self._properties) + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + if not get_profiling_on(): + return + libproton.exit_scope(self._id, self._name) + + def __call__(self, func): + + @wraps(func) + def wrapper(*args, **kwargs): + if get_profiling_on(): + id = libproton.record_scope() + libproton.enter_scope(id, self._name) + if self._metrics: + libproton.add_metrics(id, self._metrics) + if self._properties: + libproton.set_properties(id, self._properties) + ret = func(*args, **kwargs) + if get_profiling_on(): + libproton.exit_scope(id, self._name) + return ret + + return wrapper + + +def enter_scope(name: str, *, triton_op: bool = False, metrics: Optional[dict[str, MetricValueType]] = None, + properties: Optional[dict[str, PropertyValueType]] = None) -> int: + if not get_profiling_on(): + return -1 + id = libproton.record_scope() + if not hasattr(_local, "scopes"): + _local.scopes = [] + _local.scopes.append((id, name)) + if triton_op: + libproton.enter_op(id, name) + else: + libproton.enter_scope(id, name) + if metrics: + libproton.add_metrics(id, metrics) + if properties: + libproton.set_properties(id, properties) + return id + + +def exit_scope(triton_op: bool = False) -> int: + if not get_profiling_on(): + return -1 + id, name = _local.scopes.pop() + if triton_op: + libproton.exit_op(id, name) + else: + libproton.exit_scope(id, name) + return id diff --git a/lib/python3.10/site-packages/triton/profiler/viewer.py b/lib/python3.10/site-packages/triton/profiler/viewer.py new file mode 100644 index 0000000000000000000000000000000000000000..3ef3a4c93ab4b0035346eb11249f0cba872bfe20 --- /dev/null +++ b/lib/python3.10/site-packages/triton/profiler/viewer.py @@ -0,0 +1,236 @@ +import argparse +from collections import namedtuple +import json +import pandas as pd + +import hatchet as ht +from triton.profiler.hook import COMPUTE_METADATA_SCOPE_NAME, TritonHook + + +def match_available_metrics(metrics, raw_metrics): + ret = [] + if metrics: + for metric in metrics: + metric = metric.lower() + for raw_metric in raw_metrics: + raw_metric_no_unit = raw_metric.split("(")[0].strip().lower() + if metric in (raw_metric, raw_metric_no_unit): + ret.append(raw_metric + " (inc)") + break + else: + ret = [raw_metrics[0]] + " (inc)" + return ret + + +def get_raw_metrics(file): + database = json.load(file) + device_info = database.pop(1) + gf = ht.GraphFrame.from_literal(database) + return gf, gf.show_metric_columns(), device_info + + +def get_min_time_flops(df, device_info): + min_time_flops = pd.DataFrame(0.0, index=df.index, columns=["min_time"]) + for device_type in device_info: + for device_index in device_info[device_type]: + arch = device_info[device_type][device_index]["arch"] + num_sms = device_info[device_type][device_index]["num_sms"] + clock_rate = device_info[device_type][device_index]["clock_rate"] + for width in TritonHook.flops_width: + idx = df["DeviceId"] == device_index + device_frames = df[idx] + if f"flops{width}" not in device_frames.columns: + continue + max_flops = 0 + if device_type == "CUDA": + if arch == "80": + max_flops = 624e12 / (width / 8) + elif arch == "89": + # TODO(Keren): Implement fp16 acc-> 660.6 fp8 + max_flops = (330.3 * 1e12) / (width / 8) + elif arch == "90": + # 114 sms and 1755mhz is the base number of sms and clock rate of H100 pcie + max_flops = ((num_sms / 114 * clock_rate / (1755 * 1e3) * 1513) * 1e12) / (width / 8) + elif device_type == "HIP": + if arch == "gfx90a": + max_flops = 383e12 / (width / 8) + elif arch == "gfx941" or arch == "gfx942": + max_flops = 2614.9e12 / (width / 8) + else: + raise ValueError(f"Unsupported device type: {device_type}") + min_time_flops.loc[idx, "min_time"] += device_frames[f"flops{width}"].fillna(0) / max_flops + return min_time_flops + + +def get_min_time_bytes(df, device_info): + min_time_bytes = pd.DataFrame(0.0, index=df.index, columns=["min_time"]) + for device_type in device_info: + for device_index in device_info[device_type]: + idx = df["DeviceId"] == device_index + device_frames = df[idx] + memory_clock_rate = device_info[device_type][device_index]["memory_clock_rate"] # in khz + bus_width = device_info[device_type][device_index]["bus_width"] # in bits + peak_bandwidth = 2 * bus_width * memory_clock_rate * 1e3 / 8 + min_time_bytes.loc[idx, "min_time"] += device_frames["bytes"] / peak_bandwidth + return min_time_bytes + + +FactorDict = namedtuple("FactorDict", ["name", "factor"]) +time_factor_dict = FactorDict("time", {"time/s": 1, "time/ms": 1e-3, "time/us": 1e-6, "time/ns": 1e-9}) +flops_factor_dict = FactorDict("flops", {"flop/s": 1, "gflop/s": 1e9, "tflop/s": 1e12}) +bytes_factor_dict = FactorDict("bytes", {"byte/s": 1, "gbyte/s": 1e9, "tbyte/s": 1e12}) + +derivable_metrics = { + **{key: flops_factor_dict + for key in flops_factor_dict.factor.keys()}, + **{key: bytes_factor_dict + for key in bytes_factor_dict.factor.keys()}, +} + + +def derive_metrics(gf, metrics, raw_metrics, device_info): + derived_metrics = [] + original_metrics = [] + time_metric_name = match_available_metrics([time_factor_dict.name], raw_metrics)[0] + time_unit = (time_factor_dict.name + "/" + time_metric_name.split("(")[1].split(")")[0]) + for metric in metrics: + if metric == "util": # Tensor core only + min_time_bytes = get_min_time_bytes(gf.dataframe, device_info) + min_time_flops = get_min_time_flops(gf.dataframe, device_info) + time_sec = gf.dataframe[time_metric_name] * (time_factor_dict.factor[time_unit] / + time_factor_dict.factor["time/s"]) + gf.dataframe["util (inc)"] = min_time_flops["min_time"].combine(min_time_bytes["min_time"], max) / time_sec + derived_metrics.append("util (inc)") + elif metric in derivable_metrics: + deriveable_metric = derivable_metrics[metric] + metric_name = deriveable_metric.name + metric_factor_dict = deriveable_metric.factor + matched_metric_name = match_available_metrics([metric_name], raw_metrics)[0] + gf.dataframe[f"{metric} (inc)"] = (gf.dataframe[matched_metric_name] / + (gf.dataframe[time_metric_name] * time_factor_dict.factor[time_unit]) / + metric_factor_dict[metric]) + derived_metrics.append(f"{metric} (inc)") + elif metric in time_factor_dict.factor: + metric_time_unit = time_factor_dict.name + "/" + metric.split("/")[1] + gf.dataframe[f"{metric} (inc)"] = gf.dataframe[time_metric_name] * ( + time_factor_dict.factor[time_unit] / time_factor_dict.factor[metric_time_unit]) + derived_metrics.append(f"{metric} (inc)") + else: + original_metrics.append(metric) + + if original_metrics: + original_metrics = match_available_metrics(original_metrics, raw_metrics) + return derived_metrics + original_metrics + + +def parse(metrics, filename, include, exclude, threshold, depth): + with open(filename, "r") as f: + gf, raw_metrics, device_info = get_raw_metrics(f) + assert len(raw_metrics) > 0, "No metrics found in the input file" + gf.update_inclusive_columns() + metrics = derive_metrics(gf, metrics, raw_metrics, device_info) + if include or exclude: + # make regex do negative match + name_filter = f"^(?!{exclude}).*" if exclude else include + query = ["*", {"name": name_filter}] + gf = gf.filter(query, squash=True) + # filter out metadata computation + query = [{"name": f"^(?!{COMPUTE_METADATA_SCOPE_NAME}).*"}] + gf = gf.filter(query, squash=True) + if threshold: + # TODO: generalize to support multiple metrics + query = ["*", {metrics[0]: f">= {threshold}"}] + gf = gf.filter(query, squash=True) + print(gf.tree(metric_column=metrics, expand_name=True, depth=depth, render_header=False)) + + +def show_metrics(file_name): + with open(file_name, "r") as f: + _, raw_metrics, _ = get_raw_metrics(f) + print("Available metrics:") + if raw_metrics: + for raw_metric in raw_metrics: + raw_metric_no_unit = raw_metric.split("(")[0].strip().lower() + print(f"- {raw_metric_no_unit}") + return + + +def main(): + argparser = argparse.ArgumentParser( + description="Performance data viewer for proton profiles.", + formatter_class=argparse.RawTextHelpFormatter, + ) + argparser.add_argument( + "-l", + "--list", + action="store_true", + help="""List available metrics. Metric names are case insensitive and ignore units. +Derived metrics can be created when source metrics are available. +- time/s, time/ms, time/us, time/ns: time +- flop/s, gflop/s, tflop/s: flops / time +- byte/s, gbyte/s, tbyte/s: bytes / time +- util: max(sum(flops) / peak_flops_time, bytes / peak_bandwidth_time)) +""", + ) + argparser.add_argument( + "-m", + "--metrics", + type=str, + default=None, + help="""At maximum two metrics can be specified, separated by comma. +There are two modes: +1) Choose the output metric to display. It's case insensitive and ignore units. +2) Derive a new metric from existing metrics. +""", + ) + argparser.add_argument( + "-i", + "--include", + type=str, + default=None, + help="Include frames(kernels) that match the given regular expression", + ) + argparser.add_argument( + "-e", + "--exclude", + type=str, + default=None, + help="Exclude frames(kernels) that match the given regular expression", + ) + + argparser.add_argument( + "-t", + "--threshold", + type=float, + default=None, + help= + "Exclude frames(kernels) whose metrics are below the given threshold. This filter only applies on the first metric.", + ) + + argparser.add_argument( + "-d", + "--depth", + type=int, + default=100, + help="The depth of the tree to display", + ) + + args, target_args = argparser.parse_known_args() + assert len(target_args) == 1, "Must specify a file to read" + + file_name = target_args[0] + metrics = args.metrics.split(",") if args.metrics else None + include = args.include + exclude = args.exclude + threshold = args.threshold + depth = args.depth + if include and exclude: + raise ValueError("Cannot specify both include and exclude") + if args.list: + show_metrics(file_name) + elif metrics: + parse(metrics, file_name, include, exclude, threshold, depth) + + +if __name__ == "__main__": + main() diff --git a/lib/python3.10/site-packages/triton/runtime/__init__.py b/lib/python3.10/site-packages/triton/runtime/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0b3979d28d9a4359be5ceb3d2dea08f4d56899e6 --- /dev/null +++ b/lib/python3.10/site-packages/triton/runtime/__init__.py @@ -0,0 +1,23 @@ +from .autotuner import (Autotuner, Config, Heuristics, autotune, heuristics) +from .cache import RedisRemoteCacheBackend, RemoteCacheBackend +from .driver import driver +from .jit import JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret +from .errors import OutOfResources, InterpreterError + +__all__ = [ + "autotune", + "Autotuner", + "Config", + "driver", + "Heuristics", + "heuristics", + "InterpreterError", + "JITFunction", + "KernelInterface", + "MockTensor", + "OutOfResources", + "RedisRemoteCacheBackend", + "reinterpret", + "RemoteCacheBackend", + "TensorWrapper", +] diff --git a/lib/python3.10/site-packages/triton/runtime/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/triton/runtime/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a9d1194369c6e1417fd1d181fc703678b5ec8e3 Binary files /dev/null and b/lib/python3.10/site-packages/triton/runtime/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/runtime/__pycache__/autotuner.cpython-310.pyc b/lib/python3.10/site-packages/triton/runtime/__pycache__/autotuner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92f021a18688d95218eb5e97ebfecdae4804eba0 Binary files /dev/null and b/lib/python3.10/site-packages/triton/runtime/__pycache__/autotuner.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/runtime/__pycache__/build.cpython-310.pyc b/lib/python3.10/site-packages/triton/runtime/__pycache__/build.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..497bad16731959d98d4b4bba0e2f53fd72905e09 Binary files /dev/null and b/lib/python3.10/site-packages/triton/runtime/__pycache__/build.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/runtime/__pycache__/cache.cpython-310.pyc b/lib/python3.10/site-packages/triton/runtime/__pycache__/cache.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..543c5e348cf26d6f0dd0c56ad20a07c4ee9e6917 Binary files /dev/null and b/lib/python3.10/site-packages/triton/runtime/__pycache__/cache.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/runtime/__pycache__/driver.cpython-310.pyc b/lib/python3.10/site-packages/triton/runtime/__pycache__/driver.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..687bec037e13dd67d632d041c90ca1000381cbfa Binary files /dev/null and b/lib/python3.10/site-packages/triton/runtime/__pycache__/driver.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/runtime/__pycache__/errors.cpython-310.pyc b/lib/python3.10/site-packages/triton/runtime/__pycache__/errors.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1732fdad2b943065e1d469c3a8be64d717bf5411 Binary files /dev/null and b/lib/python3.10/site-packages/triton/runtime/__pycache__/errors.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/runtime/__pycache__/interpreter.cpython-310.pyc b/lib/python3.10/site-packages/triton/runtime/__pycache__/interpreter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aeb109ddbbdf397e554085c4e098a0a08e7aba93 Binary files /dev/null and b/lib/python3.10/site-packages/triton/runtime/__pycache__/interpreter.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/runtime/__pycache__/jit.cpython-310.pyc b/lib/python3.10/site-packages/triton/runtime/__pycache__/jit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4980a09f82a17760d5d119180efcd437a84b8763 Binary files /dev/null and b/lib/python3.10/site-packages/triton/runtime/__pycache__/jit.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/runtime/autotuner.py b/lib/python3.10/site-packages/triton/runtime/autotuner.py new file mode 100644 index 0000000000000000000000000000000000000000..73e61866273c4cb9fec7688e2aecbc1e459ce12e --- /dev/null +++ b/lib/python3.10/site-packages/triton/runtime/autotuner.py @@ -0,0 +1,361 @@ +from __future__ import annotations + +import builtins +import os +import time +import inspect +from typing import Dict + +from ..testing import do_bench, do_bench_cudagraph +from .jit import KernelInterface +from .errors import OutOfResources + + +class Autotuner(KernelInterface): + + def __init__( + self, + fn, + arg_names, + configs, + key, + reset_to_zero, + restore_value, + pre_hook=None, + post_hook=None, + prune_configs_by: Dict = None, + warmup=25, + rep=100, + use_cuda_graph=False, + ): + """ + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs. + """ + if not configs: + self.configs = [Config({}, num_warps=4, num_stages=2, num_ctas=1)] + else: + self.configs = configs + self.key_idx = [arg_names.index(k) for k in key] + self.cache = {} + self.arg_names = arg_names + + # Reset to zero or restore values + self.reset_idx = [] + if reset_to_zero is not None: + self.reset_idx = [arg_names.index(k) for k in reset_to_zero] + self.restore_idx = [] + if restore_value is not None: + self.restore_idx = [arg_names.index(k) for k in restore_value] + + # Hook to reset or restore for required tensors + self.pre_hook = lambda args, reset_only=False: 0 + self.post_hook = lambda args, exception: 0 + if pre_hook: + self.pre_hook = pre_hook + elif (len(self.reset_idx) > 0 or len(self.restore_idx) > 0): + + def _pre_hook(args, reset_only=False): + for i in self.reset_idx: + args[i].zero_() + if not reset_only: + self.restore_copies = [args[i].clone() for i in self.restore_idx] + + self.pre_hook = _pre_hook + + if post_hook: + self.post_hook = post_hook + elif len(self.restore_idx) > 0: + + def _post_hook(args, exception): + for i, j in enumerate(self.restore_idx): + args[j].copy_(self.restore_copies[i]) + self.restore_copies = [] + + self.post_hook = _post_hook + + self.perf_model = None + self.configs_top_k = 1.0 + self.early_config_prune = None + if prune_configs_by: + self.perf_model = prune_configs_by.get("perf_model", self.perf_model) + self.configs_top_k = prune_configs_by.get("top_k", self.configs_top_k) + self.early_config_prune = prune_configs_by.get("early_config_prune", self.early_config_prune) + + self.fn = fn + self.base_fn = fn + while not inspect.isfunction(self.base_fn): + self.base_fn = self.base_fn.fn + self.num_warmups = warmup + self.num_reps = rep + import torch + self.use_cuda_graph = use_cuda_graph and torch.cuda.is_available() + + def _bench(self, *args, config, **meta): + from ..compiler.errors import CompileTimeAssertionFailure + + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols.") + # augment meta-parameters with tunable ones + current = dict(meta, **config.all_kwargs()) + full_nargs = {**self.nargs, **current} + + def kernel_call(): + if config.pre_hook: + config.pre_hook(full_nargs) + self.pre_hook(args) + try: + self.fn.run( + *args, + **current, + ) + except Exception as e: + try: + self.post_hook(args, exception=e) + finally: + # Throw exception raised by `self.fn.run` + raise + + self.post_hook(args, exception=None) + + try: + if self.use_cuda_graph: + import torch + with torch.cuda.stream(torch.cuda.Stream()): + bench_res = do_bench_cudagraph(kernel_call, rep=self.num_reps, return_mode="median") + return bench_res + return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8)) + except (OutOfResources, CompileTimeAssertionFailure): + return float("inf") if self.use_cuda_graph else [float("inf"), float("inf"), float("inf")] + + def run(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + used_cached_result = True + if len(self.configs) > 1: + all_args = {**self.nargs, **kwargs} + _args = [] + for name in self.arg_names: + if name in all_args: + _args.append(all_args[name]) + key = [_args[i] for i in self.key_idx] + for arg in _args: + if hasattr(arg, "dtype"): + key.append(str(arg.dtype)) + key = tuple(key) + if key not in self.cache: + # prune configs + used_cached_result = False + pruned_configs = self.prune_configs(kwargs) + bench_start = time.time() + timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} + bench_end = time.time() + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + self.pre_hook(args, reset_only=True) + self.configs_timings = timings + config = self.cache[key] + else: + config = self.configs[0] + self.best_config = config + if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result: + print(f"Triton autotuning for function {self.base_fn.__name__} finished after " + f"{self.bench_time:.2f}s; best config selected: {self.best_config};") + if config.pre_hook is not None: + config.pre_hook({**self.nargs, **kwargs, **config.all_kwargs()}) + ret = self.fn.run( + *args, + **kwargs, + **config.all_kwargs(), + ) + self.nargs = None + return ret + + def prune_configs(self, kwargs): + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = { + config: self.perf_model( + **self.nargs, + **kwargs, + **config.all_kwargs(), + ) + for config in pruned_configs + } + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] + return pruned_configs + + def warmup(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + ret = [] + for config in self.prune_configs(kwargs): + ret.append(self.fn.warmup( + *args, + **kwargs, + **config.all_kwargs(), + )) + self.nargs = None + return ret + + +class Config: + """ + An object that represents a possible kernel configuration for the auto-tuner to try. + + :ivar kwargs: a dictionary of meta-parameters to pass to the kernel as keyword arguments. + :type kwargs: dict[Str, Any] + :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if + `num_warps=8`, then each kernel instance will be automatically parallelized to + cooperatively execute using `8 * 32 = 256` threads. + :type num_warps: int + :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops. + Mostly useful for matrix multiplication workloads on SM80+ GPUs. + :type num_ctas: int + :ivar num_ctas: number of blocks in a block cluster. SM90+ only. + :type maxnreg: Optional[int] + :ivar maxnreg: maximum number of registers one thread can use. Corresponds + to ptx .maxnreg directive. Not supported on all platforms. + :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this + function are args. + """ + + def __init__(self, kwargs, num_warps=4, num_stages=2, num_ctas=1, maxnreg=None, pre_hook=None): + self.kwargs = kwargs + self.num_warps = num_warps + self.num_ctas = num_ctas + self.num_stages = num_stages + self.maxnreg = maxnreg + self.pre_hook = pre_hook + + def all_kwargs(self): + return { + **self.kwargs, **{ + k: v + for (k, v) in ( + ("num_warps", self.num_warps), + ("num_ctas", self.num_ctas), + ("num_stages", self.num_stages), + ("maxnreg", self.maxnreg), + ) if v is not None + } + } + + def __str__(self): + res = [] + for k, v in self.kwargs.items(): + res.append(f"{k}: {v}") + res.append(f"num_warps: {self.num_warps}") + res.append(f"num_ctas: {self.num_ctas}") + res.append(f"num_stages: {self.num_stages}") + res.append(f"maxnreg: {self.maxnreg}") + return ", ".join(res) + + +def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None, + warmup=25, rep=100, use_cuda_graph=False): + """ + Decorator for auto-tuning a :code:`triton.jit`'d function. + + .. highlight:: python + .. code-block:: python + + @triton.autotune(configs=[ + triton.Config(kwargs={'BLOCK_SIZE': 128}, num_warps=4), + triton.Config(kwargs={'BLOCK_SIZE': 1024}, num_warps=8), + ], + key=['x_size'] # the two above configs will be evaluated anytime + # the value of x_size changes + ) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] + :note: When all the configurations are evaluated, the kernel will run multiple times. + This means that whatever value the kernel updates will be updated multiple times. + To avoid this undesired behavior, you can use the `reset_to_zero` argument, which + resets the value of the provided tensor to `zero` before running any configuration. + + If the environment variable :code:`TRITON_PRINT_AUTOTUNING` is set to + :code:`"1"`, Triton will print a message to stdout after autotuning each + kernel, including the time spent autotuning and the best configuration. + + :param configs: a list of :code:`triton.Config` objects + :type configs: list[triton.Config] + :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. + :type key: list[str] + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs. + :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. + :type reset_to_zero: list[str] + :param restore_value: a list of argument names whose value will be restored after evaluating any configs. + :type restore_value: list[str] + :param pre_hook: a function that will be called before the kernel is called. + This overrides the default pre_hook used for 'reset_to_zero' and 'restore_value'. + 'args': a list of arguments passed to the kernel. + 'reset_only': a boolean indicating whether the pre_hook is called to reset the values only, without a corresponding post_hook. + :type pre_hook: lambda args, reset_only + :param post_hook: a function that will be called after the kernel is called. + This overrides the default post_hook used for 'restore_value'. + 'args': a list of arguments passed to the kernel. + 'exception': the exception raised by the kernel in case of a compilation or runtime error. + :type post_hook: lambda args, exception + :param warmup: Warmup time (in ms) to pass to benchmarking, defaults to 25. + :type warmup: int + :param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100. + :type rep: int + """ + + def decorator(fn): + return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook, + post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep, + use_cuda_graph=use_cuda_graph) + + return decorator + + +class Heuristics(KernelInterface): + + def __init__(self, fn, arg_names, values) -> None: + self.fn = fn + self.values = values + self.arg_names = arg_names + + def run(self, *args, **kwargs): + for v, heur in self.values.items(): + kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs}) + return self.fn.run(*args, **kwargs) + + +def heuristics(values): + """ + Decorator for specifying how the values of certain meta-parameters may be computed. + This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable. + + .. highlight:: python + .. code-block:: python + + @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))}) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size + :param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter. + each such function takes a list of positional arguments as input. + :type values: dict[str, Callable[[list[Any]], Any]] + """ + + def decorator(fn): + return Heuristics(fn, fn.arg_names, values) + + return decorator diff --git a/lib/python3.10/site-packages/triton/runtime/build.py b/lib/python3.10/site-packages/triton/runtime/build.py new file mode 100644 index 0000000000000000000000000000000000000000..d7baeb2868b0a6c62f852f5a30251795a56ad381 --- /dev/null +++ b/lib/python3.10/site-packages/triton/runtime/build.py @@ -0,0 +1,78 @@ +import contextlib +import sys +import io +import sysconfig +import os +import shutil +import subprocess +import setuptools + + +@contextlib.contextmanager +def quiet(): + old_stdout, old_stderr = sys.stdout, sys.stderr + sys.stdout, sys.stderr = io.StringIO(), io.StringIO() + try: + yield + finally: + sys.stdout, sys.stderr = old_stdout, old_stderr + + +def _build(name, src, srcdir, library_dirs, include_dirs, libraries): + suffix = sysconfig.get_config_var('EXT_SUFFIX') + so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) + # try to avoid setuptools if possible + cc = os.environ.get("CC") + if cc is None: + # TODO: support more things here. + clang = shutil.which("clang") + gcc = shutil.which("gcc") + cc = gcc if gcc is not None else clang + if cc is None: + raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.") + # This function was renamed and made public in Python 3.10 + if hasattr(sysconfig, 'get_default_scheme'): + scheme = sysconfig.get_default_scheme() + else: + scheme = sysconfig._get_default_scheme() + # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install + # path changes to include 'local'. This change is required to use triton with system-wide python. + if scheme == 'posix_local': + scheme = 'posix_prefix' + py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] + include_dirs = include_dirs + [srcdir, py_include_dir] + cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-o", so] + cc_cmd += [f'-l{lib}' for lib in libraries] + cc_cmd += [f"-L{dir}" for dir in library_dirs] + cc_cmd += [f"-I{dir}" for dir in include_dirs] + ret = subprocess.check_call(cc_cmd) + if ret == 0: + return so + # fallback on setuptools + extra_compile_args = [] + # extra arguments + extra_link_args = [] + # create extension module + ext = setuptools.Extension( + name=name, + language='c', + sources=[src], + include_dirs=include_dirs, + extra_compile_args=extra_compile_args + ['-O3'], + extra_link_args=extra_link_args, + library_dirs=library_dirs, + libraries=libraries, + ) + # build extension module + args = ['build_ext'] + args.append('--build-temp=' + srcdir) + args.append('--build-lib=' + srcdir) + args.append('-q') + args = dict( + name=name, + ext_modules=[ext], + script_args=args, + ) + with quiet(): + setuptools.setup(**args) + return so diff --git a/lib/python3.10/site-packages/triton/runtime/cache.py b/lib/python3.10/site-packages/triton/runtime/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..bd3c29b9964cac408e7bf8aa3a47635d9c7abf54 --- /dev/null +++ b/lib/python3.10/site-packages/triton/runtime/cache.py @@ -0,0 +1,281 @@ +import importlib +import json +import os +import uuid +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Dict, List, Optional +import hashlib + + +def default_cache_dir(): + return os.path.join(Path.home(), ".triton", "cache") + + +def default_override_dir(): + return os.path.join(Path.home(), ".triton", "override") + + +def default_dump_dir(): + return os.path.join(Path.home(), ".triton", "dump") + + +class CacheManager(ABC): + + def __init__(self, key): + pass + + @abstractmethod + def get_file(self, filename) -> Optional[str]: + pass + + @abstractmethod + def put(self, data, filename, binary=True) -> str: + pass + + @abstractmethod + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + pass + + @abstractmethod + def put_group(self, filename: str, group: Dict[str, str]): + pass + + +class FileCacheManager(CacheManager): + + def __init__(self, key, override=False, dump=False): + self.key = key + self.lock_path = None + if dump: + self.cache_dir = default_dump_dir() + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + elif override: + self.cache_dir = default_override_dir() + self.cache_dir = os.path.join(self.cache_dir, self.key) + else: + # create cache directory if it doesn't exist + self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir() + if self.cache_dir: + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + else: + raise RuntimeError("Could not create or locate cache dir") + + def _make_path(self, filename) -> str: + return os.path.join(self.cache_dir, filename) + + def has_file(self, filename) -> bool: + if not self.cache_dir: + raise RuntimeError("Could not create or locate cache dir") + return os.path.exists(self._make_path(filename)) + + def get_file(self, filename) -> Optional[str]: + if self.has_file(filename): + return self._make_path(filename) + else: + return None + + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + grp_filename = f"__grp__{filename}" + if not self.has_file(grp_filename): + return None + grp_filepath = self._make_path(grp_filename) + with open(grp_filepath) as f: + grp_data = json.load(f) + child_paths = grp_data.get("child_paths", None) + # Invalid group data. + if child_paths is None: + return None + result = {} + for c, p in child_paths.items(): + if os.path.exists(p): + result[c] = p + return result + + # Note a group of pushed files as being part of a group + def put_group(self, filename: str, group: Dict[str, str]) -> str: + if not self.cache_dir: + raise RuntimeError("Could not create or locate cache dir") + grp_contents = json.dumps({"child_paths": group}) + grp_filename = f"__grp__{filename}" + return self.put(grp_contents, grp_filename, binary=False) + + def put(self, data, filename, binary=True) -> str: + if not self.cache_dir: + raise RuntimeError("Could not create or locate cache dir") + binary = isinstance(data, bytes) + if not binary: + data = str(data) + assert self.lock_path is not None + filepath = self._make_path(filename) + # Random ID to avoid any collisions + rnd_id = str(uuid.uuid4()) + # we use the PID in case a bunch of these around so we can see what PID made it + pid = os.getpid() + # use tempfile to be robust against program interruptions + temp_path = f"{filepath}.tmp.pid_{pid}_{rnd_id}" + mode = "wb" if binary else "w" + with open(temp_path, mode) as f: + f.write(data) + # Replace is guaranteed to be atomic on POSIX systems if it succeeds + # so filepath cannot see a partial write + os.replace(temp_path, filepath) + return filepath + + +class RemoteCacheBackend: + """ + A backend implementation for accessing a remote/distributed cache. + """ + + def __init__(self, key: str): + pass + + @abstractmethod + def get(self, filenames: List[str]) -> Dict[str, bytes]: + pass + + @abstractmethod + def put(self, filename: str, data: bytes): + pass + + +class RedisRemoteCacheBackend(RemoteCacheBackend): + + def __init__(self, key): + import redis + self._key = key + self._key_fmt = os.environ.get("TRITON_REDIS_KEY_FORMAT", "triton:{key}:{filename}") + self._redis = redis.Redis( + host=os.environ.get("TRITON_REDIS_HOST", "localhost"), + port=int(os.environ.get("TRITON_REDIS_PORT", 6379)), + ) + + def _get_key(self, filename: str) -> str: + return self._key_fmt.format(key=self._key, filename=filename) + + def get(self, filenames: List[str]) -> Dict[str, str]: + results = self._redis.mget([self._get_key(f) for f in filenames]) + return {filename: result for filename, result in zip(filenames, results) if result is not None} + + def put(self, filename: str, data: bytes) -> Dict[str, bytes]: + self._redis.set(self._get_key(filename), data) + + +class RemoteCacheManager(CacheManager): + + def __init__(self, key, override=False, dump=False): + # Setup backend pointed too by `TRITON_REMOTE_CACHE_BACKEND`. + remote_cache_manager = os.environ["TRITON_REMOTE_CACHE_BACKEND"] + module_path, clz_nme = remote_cache_manager.split(":") + module = importlib.import_module(module_path) + remote_cache_cls = getattr(module, clz_nme) + self._backend = remote_cache_cls(key) + + self._override = override + self._dump = dump + + # Use a `FileCacheManager` to materialize remote cache paths locally. + self._file_cache_manager = FileCacheManager(key, override=override, dump=dump) + + def _materialize(self, filename: str, data: bytes): + # We use a backing `FileCacheManager` to provide the materialized data. + return self._file_cache_manager.put(data, filename, binary=True) + + def get_file(self, filename: str) -> Optional[str]: + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.get_file(filename) + + # We always check the remote cache backend -- even if our internal file- + # based cache has the item -- to make sure LRU accounting works as + # expected. + results = self._backend.get([filename]) + if len(results) == 0: + return None + (_, data), = results.items() + return self._materialize(filename, data) + + def put(self, data, filename: str, binary=True) -> str: + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.put(data, filename, binary=binary) + + if not isinstance(data, bytes): + data = str(data).encode("utf-8") + self._backend.put(filename, data) + return self._materialize(filename, data) + + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.get_group(filename) + + grp_filename = f"__grp__{filename}" + grp_filepath = self.get_file(grp_filename) + if grp_filepath is None: + return None + with open(grp_filepath) as f: + grp_data = json.load(f) + child_paths = grp_data.get("child_paths", None) + + result = None + + # Found group data. + if child_paths is not None: + result = {} + for child_path, data in self._backend.get(child_paths).items(): + result[child_path] = self._materialize(child_path, data) + + return result + + def put_group(self, filename: str, group: Dict[str, str]): + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.put_group(filename, group) + + grp_contents = json.dumps({"child_paths": sorted(list(group.keys()))}) + grp_filename = f"__grp__{filename}" + return self.put(grp_contents, grp_filename) + + +__cache_cls = FileCacheManager +__cache_cls_nme = "DEFAULT" + + +def get_cache_manager(key) -> CacheManager: + import os + + user_cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None) + global __cache_cls + global __cache_cls_nme + + if user_cache_manager is not None and user_cache_manager != __cache_cls_nme: + module_path, clz_nme = user_cache_manager.split(":") + module = importlib.import_module(module_path) + __cache_cls = getattr(module, clz_nme) + __cache_cls_nme = user_cache_manager + + return __cache_cls(key) + + +def get_override_manager(key) -> CacheManager: + return __cache_cls(key, override=True) + + +def get_dump_manager(key) -> CacheManager: + return __cache_cls(key, dump=True) + + +def make_so_cache_key(version_hash, signature, constants, ids, **kwargs): + # Get unique key for the compiled code + signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()} + key = f"{version_hash}-{''.join(signature.values())}-{constants}-{ids}" + for kw in kwargs: + key = f"{key}-{kwargs.get(kw)}" + key = hashlib.sha256(key.encode("utf-8")).hexdigest() + return key diff --git a/lib/python3.10/site-packages/triton/runtime/driver.py b/lib/python3.10/site-packages/triton/runtime/driver.py new file mode 100644 index 0000000000000000000000000000000000000000..c3b97a764145490b8c955cc37fe366bb755176a4 --- /dev/null +++ b/lib/python3.10/site-packages/triton/runtime/driver.py @@ -0,0 +1,60 @@ +from ..backends import backends +from ..backends import DriverBase + + +def _create_driver(): + actives = [x.driver for x in backends.values() if x.driver.is_active()] + if len(actives) != 1: + raise RuntimeError(f"{len(actives)} active drivers ({actives}). There should only be one.") + return actives[0]() + + +class LazyProxy: + + def __init__(self, init_fn): + self._init_fn = init_fn + self._obj = None + + def _initialize_obj(self): + if self._obj is None: + self._obj = self._init_fn() + + def __getattr__(self, name): + self._initialize_obj() + return getattr(self._obj, name) + + def __setattr__(self, name, value): + if name in ["_init_fn", "_obj"]: + super().__setattr__(name, value) + else: + self._initialize_obj() + setattr(self._obj, name, value) + + def __delattr__(self, name): + self._initialize_obj() + delattr(self._obj, name) + + def __repr__(self): + if self._obj is None: + return f"<{self.__class__.__name__} for {self._init_fn} not yet initialized>" + return repr(self._obj) + + def __str__(self): + self._initialize_obj() + return str(self._obj) + + +class DriverConfig: + + def __init__(self): + self.default = LazyProxy(_create_driver) + self.active = self.default + + def set_active(self, driver: DriverBase): + self.active = driver + + def reset_active(self): + self.active = self.default + + +driver = DriverConfig() diff --git a/lib/python3.10/site-packages/triton/runtime/errors.py b/lib/python3.10/site-packages/triton/runtime/errors.py new file mode 100644 index 0000000000000000000000000000000000000000..4dce9176709a91675921af217b55f530ff57ddc9 --- /dev/null +++ b/lib/python3.10/site-packages/triton/runtime/errors.py @@ -0,0 +1,26 @@ +from ..errors import TritonError +from typing import Optional + + +class InterpreterError(TritonError): + + def __init__(self, error_message: Optional[str] = None): + self.error_message = error_message + + def __str__(self) -> str: + return self.error_message or "" + + +class OutOfResources(TritonError): + + def __init__(self, required, limit, name): + self.required = required + self.limit = limit + self.name = name + + def __str__(self) -> str: + return f"out of resource: {self.name}, Required: {self.required}, Hardware limit: {self.limit}. Reducing block sizes or `num_stages` may help." + + def __reduce__(self): + # this is necessary to make CompilationError picklable + return (type(self), (self.required, self.limit, self.name)) diff --git a/lib/python3.10/site-packages/triton/runtime/interpreter.py b/lib/python3.10/site-packages/triton/runtime/interpreter.py new file mode 100644 index 0000000000000000000000000000000000000000..a82832ecf9821d2517514a6ed47babd9dbdb5f71 --- /dev/null +++ b/lib/python3.10/site-packages/triton/runtime/interpreter.py @@ -0,0 +1,1127 @@ +import inspect +from typing import Tuple + +import math +import numpy as np + +import triton +import triton.language as tl +from dataclasses import dataclass +from .errors import InterpreterError +from functools import partial +from .._C.libtriton import interpreter as _interpreter +from .._C.libtriton import ir as _ir + + +class TensorHandle: + + def __init__(self, data, dtype): + ''' + data: numpy array + dtype: triton type, either pointer_type or scalar_type. + we don't store block_type here because the shape information is already availale in the data field + attr: a dictionary of attributes + ''' + self.data = data + self.dtype = dtype + self.attr = {} + + def __bool__(self): + return bool(self.data.all()) + + def get_element_ty(self): + dtype = self.dtype + while hasattr(dtype, "element_ty"): + dtype = dtype.element_ty + return dtype + + def clone(self): + return TensorHandle(self.data.copy(), self.dtype) + + def set_attr(self, key, value): + self.attr[key] = value + + +class BlockPointerHandle: + + def __init__(self, base, shape, strides, offsets, tensor_shape, order): + self.base = base + self.shape = shape + self.strides = strides + self.offsets = offsets + self.tensor_shape = tensor_shape + self.order = order + + def materialize_pointers(self, boundary_check): + dtype_tt = self.base.get_element_ty() + n_bytes = dtype_tt.primitive_bitwidth // 8 + tensor_shape = self.tensor_shape + ptrs = np.broadcast_to(self.base.data, self.tensor_shape) + masks = np.ones(self.tensor_shape, dtype=bool) + for dim in range(len(tensor_shape)): + bcast_dims = [1] * len(tensor_shape) + bcast_dims[dim] = tensor_shape[dim] + off = (self.offsets[dim].data + np.arange(tensor_shape[dim])).reshape(bcast_dims) + ptrs = ptrs + (n_bytes * off * self.strides[dim].data).astype(np.uint64) + if dim in boundary_check: + masks = np.logical_and(masks, off < self.shape[dim].data) + ptrs = TensorHandle(ptrs, self.base.dtype.scalar) + return ptrs, masks + + +@dataclass(frozen=True) +class InterpreterOptions: + extern_libs: dict = None + debug: bool = False + arch: str = None + allow_fp8e4nv: bool = True + allow_fp8e4b15: bool = True + default_dot_input_precision: str = "tf32" + allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee") + max_num_imprecise_acc_default: int = 0 + + +def _get_signed_np_dtype(dtype): + if dtype == np.uint8: + return np.int8 + if dtype == np.uint16: + return np.int16 + if dtype == np.uint32: + return np.int32 + if dtype == np.uint64: + return np.int64 + return dtype + + +def _get_np_dtype(tt_dtype): + if isinstance(tt_dtype, tl.pointer_type): + return np.dtype(np.uint64) + np_types = { + tl.int1: np.dtype(bool), + tl.float16: np.dtype(np.float16), + tl.float32: np.dtype(np.float32), + tl.float64: np.dtype(np.float64), + tl.int8: np.dtype(np.int8), + tl.uint8: np.dtype(np.uint8), + tl.int16: np.dtype(np.int16), + tl.uint16: np.dtype(np.uint16), + tl.int32: np.dtype(np.int32), + tl.uint32: np.dtype(np.uint32), + tl.int64: np.dtype(np.int64), + tl.uint64: np.dtype(np.uint64), + # bfloat16 types are stored as uint16 + tl.bfloat16: np.dtype(np.uint16), + # float8 types are stored as uint8 + tl.float8e5: np.dtype(np.uint8), + tl.float8e5b16: np.dtype(np.uint8), + tl.float8e4nv: np.dtype(np.uint8), + tl.float8e4b8: np.dtype(np.uint8), + tl.float8e4b15: np.dtype(np.uint8), + } + if isinstance(tt_dtype, tl.block_type): + if isinstance(tt_dtype.element_ty, tl.pointer_type): + return np.dtype(np.uint64) + return np_types[tt_dtype.element_ty] + return np_types[tt_dtype] + + +def _convert_float(input, input_dtype, output_dtype, rounding_mode): + input_uint_dtype = getattr(np, f"uint{input_dtype.primitive_bitwidth}") + output_unint_dtype = getattr(np, f"uint{output_dtype.primitive_bitwidth}") + input_bin = np.frombuffer(input.tobytes(), dtype=input_uint_dtype) + sign = (input_bin >> (input_dtype.primitive_bitwidth - 1)) & 0x01 + input_exponent_width = input_dtype.primitive_bitwidth - input_dtype.fp_mantissa_width - 1 + output_exponent_width = output_dtype.primitive_bitwidth - output_dtype.fp_mantissa_width - 1 + significand = input_bin & ((1 << input_dtype.fp_mantissa_width) - 1) + bias_input = input_dtype.exponent_bias + bias_output = output_dtype.exponent_bias + exponent = ((input_bin >> input_dtype.fp_mantissa_width) & ((1 << input_exponent_width) - 1)).astype(np.int32) + subnormal_index = exponent == 0 + if np.any(subnormal_index): + # Credit to Phil: phil@openai.com + # subnormal repr: ((-1.0)**sign) * (2.0**(1 - exp_bias)) * (2^(m0) + 2^(m1) + ... + 2^(mn)) + # where m0, m1, ..., mn are the 1-bit of the mantissa + # convert it to normal repr: ((-1.0)**sign) * (2.0**(1 + m0 - exp_bias)) * (1 + 2^(m1 - m0) + ... + 2^(mn - m0)) + bit_pos = np.zeros_like(input_bin, dtype=np.int32) + # Find the most significant bit of the mantissa in the significand + for i in range(input_dtype.fp_mantissa_width): + bit_index = ((significand >> i) & 0x01) + # pos should be >= 1 + bit_pos[bit_index == 1] = input_dtype.fp_mantissa_width - i + zero_significand_index = significand == 0 + exponent[subnormal_index] = 1 - bit_pos[subnormal_index] + # 0 significand and subnormal should be treated as 0 + exponent[zero_significand_index & subnormal_index] = bias_input - bias_output + significand[subnormal_index] = (significand[subnormal_index] << bit_pos[subnormal_index]) & ( + (1 << input_dtype.fp_mantissa_width) - 1) + # Prevent overflow and underflow + exponent_output = np.maximum(0, np.minimum((exponent - bias_input + bias_output), (1 << output_exponent_width) - 1)) + exponent_output = exponent_output.astype(output_unint_dtype) + sign_output = sign.astype(output_unint_dtype) + if input_dtype.primitive_bitwidth > output_dtype.primitive_bitwidth: # Downcast + significand_output = (significand >> (input_dtype.fp_mantissa_width - output_dtype.fp_mantissa_width)) & ( + (1 << output_dtype.fp_mantissa_width) - 1) + if rounding_mode == _ir.ROUNDING_MODE.RTNE: # Round to nearst even + # find the cut-off bit + cut_off = significand & (1 << (input_dtype.fp_mantissa_width - output_dtype.fp_mantissa_width - 1)) + significand_output = significand_output + (cut_off > 0) + significand_output = significand_output.astype(output_unint_dtype) + else: # Upcast + significand_output = (significand.astype(output_unint_dtype) << + (output_dtype.fp_mantissa_width - input_dtype.fp_mantissa_width)) & ( + (1 << output_dtype.fp_mantissa_width) - 1) + subnormal_index = exponent_output == 0 + if np.any(subnormal_index): # underflow + # normal repr: ((-1.0)**sign) * (2.0**(exp - exp_bias_input)) * (1 + 2^(m0) + 2^(m1) + ... + 2^(mn)) + # where m0, m1, ..., mn are the 1-bit of the mantissa + # shift = (1 - exp_bias_output) - (exp - exp_bias_input) + # convert it to subnormal repr: ((-1.0)**sign) * (2.0**(1 - exp_bias_output)) * (2^(-shift) + 2^(m0 - shift) + 2^(m1 - shift) + ... + 2^(mn - shift)) + exponent = ((input_bin >> input_dtype.fp_mantissa_width) & ((1 << input_exponent_width) - 1)).astype(np.int32) + non_zero_exponent_index = exponent != 0 + # If the original exponent is not zero, we still need to shift the significand and consider the 1.0 part in mantissa + subnormal_index = subnormal_index & non_zero_exponent_index + shift = np.zeros_like(input_bin, dtype=np.int32) + shift[subnormal_index] = (1 - bias_output) - (exponent[subnormal_index] - bias_input) + significand_output[subnormal_index] = (significand_output[subnormal_index] >> shift[subnormal_index]) | ( + 1 << (output_dtype.fp_mantissa_width - shift[subnormal_index])) + output = (sign_output << (output_dtype.primitive_bitwidth - 1)) | ( + exponent_output << output_dtype.fp_mantissa_width) | significand_output + return output.reshape(input.shape) + + +def _erf(x): + # Numpy does not support erf + return math.erf(x) + + +def _umulhi_64(a, b): + # Numpy does not support 128-bit multiplication + # So we have to implement it manually + return (int(a) * int(b)) >> 64 + + +np_erf_fp32 = np.vectorize(_erf, otypes=[np.float32]) +np_erf_fp64 = np.vectorize(_erf, otypes=[np.float64]) +np_umulhi_u64 = np.vectorize(_umulhi_64, otypes=[np.uint64]) + + +class ExtraFunctions: + + @staticmethod + def _convert_custom_types(input, dst_ty, fp_downcast_rounding, _builder): + return tl.tensor(_builder.create_fp_to_fp(input.handle, dst_ty, fp_downcast_rounding), dst_ty) + + +class InterpreterBuilder: + ir_sem_to_interpreter_sem = { + _ir.MEM_SEMANTIC.ACQUIRE: _interpreter.MEM_SEMANTIC.ACQUIRE, + _ir.MEM_SEMANTIC.RELEASE: _interpreter.MEM_SEMANTIC.RELEASE, + _ir.MEM_SEMANTIC.RELAXED: _interpreter.MEM_SEMANTIC.RELAXED, + _ir.MEM_SEMANTIC.ACQUIRE_RELEASE: _interpreter.MEM_SEMANTIC.ACQUIRE_RELEASE, + } + + ir_rmw_op_to_interpreter_rmw_op = { + _ir.ATOMIC_OP.ADD: _interpreter.RMW_OP.ADD, + _ir.ATOMIC_OP.FADD: _interpreter.RMW_OP.FADD, + _ir.ATOMIC_OP.MIN: _interpreter.RMW_OP.MIN, + _ir.ATOMIC_OP.UMIN: _interpreter.RMW_OP.UMIN, + _ir.ATOMIC_OP.MAX: _interpreter.RMW_OP.MAX, + _ir.ATOMIC_OP.UMAX: _interpreter.RMW_OP.UMAX, + _ir.ATOMIC_OP.AND: _interpreter.RMW_OP.AND, + _ir.ATOMIC_OP.OR: _interpreter.RMW_OP.OR, + _ir.ATOMIC_OP.XOR: _interpreter.RMW_OP.XOR, + _ir.ATOMIC_OP.XCHG: _interpreter.RMW_OP.XCHG, + } + + def __init__(self) -> None: + self.arch = None + self.options = InterpreterOptions() + self.codegen_fns = {} + self.codegen_fns["convert_custom_types"] = ExtraFunctions._convert_custom_types + + def set_grid_idx(self, x, y, z): + if not x < self.grid_dim[0]: + raise ValueError("x >= grid_dim[0]") + if not y < self.grid_dim[1]: + raise ValueError("y >= grid_dim[1]") + if not z < self.grid_dim[2]: + raise ValueError("z >= grid_dim[2]") + self.grid_idx = (x, y, z) + + def set_grid_dim(self, nx, ny, nz): + self.grid_dim = (nx, ny, nz) + + # constants + + def get_half_ty(self): + return tl.float16 + + def get_bf16_ty(self): + return tl.bfloat16 + + def get_float_ty(self): + return tl.float32 + + def get_double_ty(self): + return tl.float64 + + def get_int8_ty(self): + return tl.int8 + + def get_uint8_ty(self): + return tl.uint8 + + def get_int16_ty(self): + return tl.int16 + + def get_uint16_ty(self): + return tl.uint16 + + def get_int32_ty(self): + return tl.int32 + + def get_uint32_ty(self): + return tl.uint32 + + def get_int64_ty(self): + return tl.int64 + + def get_uint64_ty(self): + return tl.uint64 + + def get_fp8e4nv_ty(self): + return tl.float8e4nv + + def get_fp8e4b15_ty(self): + return tl.float8e4b15 + + def get_fp8e4b8_ty(self): + return tl.float8e4b8 + + def get_fp8e5_ty(self): + return tl.float8e5 + + def get_fp8e5b16_ty(self): + return tl.float8e5b16 + + def get_ptr_ty(self, elt_ty, addr_space): + return tl.pointer_type(elt_ty, addr_space) + + def get_block_ty(self, dtype, shape): + return tl.block_type(dtype, shape) + + def get_int1(self, value): + return TensorHandle(np.array([value], dtype=np.bool_), tl.int1) + + def get_uint8(self, value): + return TensorHandle(np.array([value], dtype=np.uint8), tl.uint8) + + def get_int8(self, value): + return TensorHandle(np.array([value], dtype=np.int8), tl.int8) + + def get_uint16(self, value): + return TensorHandle(np.array([value], dtype=np.uint16), tl.uint16) + + def get_int16(self, value): + return TensorHandle(np.array([value], dtype=np.int16), tl.int16) + + def get_uint32(self, value): + return TensorHandle(np.array([value], dtype=np.uint32), tl.uint32) + + def get_int32(self, value): + return TensorHandle(np.array([value], dtype=np.int32), tl.int32) + + def get_uint64(self, value): + return TensorHandle(np.array([value], dtype=np.uint64), tl.uint64) + + def get_int64(self, value): + return TensorHandle(np.array([value], dtype=np.int64), tl.int64) + + def get_fp16(self, value): + return TensorHandle(np.array([value], dtype=np.float16), tl.float16) + + def get_fp32(self, value): + return TensorHandle(np.array([value], dtype=np.float32), tl.float32) + + def get_fp64(self, value): + return TensorHandle(np.array([value], dtype=np.float64), tl.float64) + + def get_null_value(self, type): + return TensorHandle(np.array([0], dtype=_get_np_dtype(type)), type) + + # programming model + def create_get_program_id(self, axis): + if self.grid_idx is None: + raise ValueError("grid_idx is None") + return TensorHandle(np.array([self.grid_idx[axis]], dtype=np.int32), tl.int32) + + def create_get_num_programs(self, axis): + return TensorHandle(np.array([self.grid_dim[axis]], dtype=np.int32), tl.int32) + + # memory ops + def create_load(self, ptr, _0, _1, is_volatile): + mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1) + other = None + return self.create_masked_load(ptr, mask, other, _0, _1, is_volatile) + + def create_store(self, ptr, val, _0, _1): + mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1) + return self.create_masked_store(ptr, val, mask, None, None) + + def create_masked_load(self, ptrs, mask, other, cache_modifier, eviction_policy, is_volatile): + dtype_tt = ptrs.get_element_ty() + dtype_np = _get_np_dtype(dtype_tt) + if other is None: + other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt) + ret = _interpreter.load(ptrs.data, mask.data, other.data, dtype_np) + return TensorHandle(ret, dtype_tt) + + def create_masked_store(self, ptrs, value, mask, cache_modifier, eviction_policy): + return _interpreter.store(ptrs.data, value.data, mask.data) + + # casting ops + def cast_impl(self, src, dst_type): + src_element_type = src.dtype.scalar + dst_element_type = dst_type.scalar + if (src_element_type == tl.bfloat16 and dst_element_type == tl.float32) or \ + (src_element_type == tl.float32 and dst_element_type == tl.bfloat16): + data = _convert_float(src.data, src_element_type, dst_element_type, None).view(_get_np_dtype(dst_type)) + return TensorHandle(data, dst_type.scalar) + else: + return TensorHandle(src.data.astype(_get_np_dtype(dst_type)), dst_type.scalar) + + create_si_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_ui_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_to_si = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_to_ui = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_ext = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_trunc = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_int_cast = lambda self, src, dst_type, is_signed: self.cast_impl(src, dst_type) + + def create_fp_to_fp(self, src, dst_type, rounding_mode): + src_element_type = src.dtype.scalar + dst_element_type = dst_type.scalar + data = _convert_float(src.data, src_element_type, dst_element_type, rounding_mode).view(_get_np_dtype(dst_type)) + return TensorHandle(data, dst_type.scalar) + + def create_bitcast(self, src, dst_type): + return TensorHandle(src.data.view(_get_np_dtype(dst_type)), dst_type.scalar) + + # binary operators + def binary_op(self, lhs, rhs, op): + return TensorHandle(op(lhs.data, rhs.data), lhs.dtype.scalar) + + create_fadd = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add) + create_fmul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply) + create_fdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide) + create_frem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder) + create_fsub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract) + create_mul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply) + create_precise_divf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide) + create_sdiv = lambda self, lhs, rhs: self.create_idiv(lhs, rhs) + create_udiv = lambda self, lhs, rhs: self.create_idiv(lhs, rhs) + # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. + create_srem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod) + create_urem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod) + create_add = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add) + create_sub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract) + create_shl = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.left_shift) + create_lshr = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.right_shift) + create_minsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minimumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minnumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_maxsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maxui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maximumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maxnumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_icmpSLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_icmpSLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_icmpSGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_icmpSGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_icmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_icmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_icmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_icmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_icmpEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_icmpNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_fcmpOLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_fcmpOGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_fcmpOLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_fcmpOGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_fcmpOEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_fcmpONE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_fcmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_fcmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_fcmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_fcmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_fcmpUEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_fcmpUNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_and = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_and) + create_xor = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_xor) + create_or = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_or) + + def create_idiv(self, lhs, rhs): + # Triton has IEEE, not numpy/torch, semantics for %, and those carry + # through to //, so we have to use a nonstandard expression to get a + # reference result for //. + return TensorHandle((lhs.data - np.fmod(lhs.data, rhs.data)) // rhs.data, lhs.dtype.scalar) + + def create_ashr(self, lhs, rhs): + # Triton's rshift operator depends on the signedness of the left operand + lhs_dtype = _get_signed_np_dtype(lhs.data.dtype) + rhs_dtype = _get_signed_np_dtype(rhs.data.dtype) + lhs.data = lhs.data.astype(lhs_dtype) + rhs.data = rhs.data.astype(rhs_dtype) + return self.binary_op(lhs, rhs, np.right_shift) + + def create_umulhi(self, lhs, rhs): + dtype = lhs.data.dtype + if dtype == np.int64 or dtype == np.uint64: + return TensorHandle(np_umulhi_u64(lhs.data, rhs.data), lhs.dtype.scalar) + else: + compute_dtype = getattr(np, f"uint{dtype.itemsize * 8 * 2}") + lhs_data = lhs.data.astype(compute_dtype) + rhs_data = rhs.data.astype(compute_dtype) + ret_data = np.multiply(lhs_data, rhs_data) >> (dtype.itemsize * 8) + return TensorHandle(ret_data.astype(dtype), lhs.dtype.scalar) + + # ternary functions + def ternary_op(self, lhs, rhs, other, op): + return TensorHandle(op(lhs.data, rhs.data, other.data), other.dtype.scalar) + + create_clampf = lambda self, arg, lo, hi, propagate_nans: self.ternary_op(arg, lo, hi, np.clip) + create_select = lambda self, cond, lhs, rhs: self.ternary_op(cond, lhs, rhs, np.where) + + def create_fma(self, x, y, z): + return TensorHandle(x.data * y.data + z.data, z.dtype.scalar) + + # unary functions + def unary_op(self, arg, op): + return TensorHandle(op(arg.data), arg.dtype.scalar) + + def create_fabs(self, arg): + # Mask out the sign bit based on the primitive length + dtype_tt = arg.dtype + mask_bitwidth = dtype_tt.primitive_bitwidth - 1 + np_uint_dtype = getattr(np, f"uint{dtype_tt.primitive_bitwidth}") + data = arg.data.view(np_uint_dtype) + mask = (1 << mask_bitwidth) - 1 + ret = (data & mask).view(_get_np_dtype(dtype_tt)) + return TensorHandle(ret, arg.dtype.scalar) + + create_cos = lambda self, arg: self.unary_op(arg, np.cos) + create_exp = lambda self, arg: self.unary_op(arg, np.exp) + create_exp2 = lambda self, arg: self.unary_op(arg, np.exp2) + create_iabs = lambda self, arg: self.unary_op(arg, np.abs) + create_floor = lambda self, arg: self.unary_op(arg, np.floor) + create_ceil = lambda self, arg: self.unary_op(arg, np.ceil) + create_log = lambda self, arg: self.unary_op(arg, np.log) + create_log2 = lambda self, arg: self.unary_op(arg, np.log2) + create_precise_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt) + create_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt) + create_sin = lambda self, arg: self.unary_op(arg, np.sin) + + def create_erf(self, arg): + ret = np_erf_fp32(arg.data) if arg.data.dtype == np.float32 else np_erf_fp64(arg.data) + return TensorHandle(ret, arg.dtype.scalar) + + def create_rsqrt(self, arg): + return TensorHandle(1 / np.sqrt(arg.data), arg.dtype.scalar) + + # tensor operators + create_reshape = lambda self, arg, shape, allow_reorder: TensorHandle(arg.data.reshape(shape), arg.dtype.scalar) + + def create_trans(self, arg, perm): + return TensorHandle(np.transpose(arg.data, perm), arg.dtype.scalar) + + def create_dot(self, a, b, d, input_precision, max_num_imprecise_acc): + a_data = a.data + b_data = b.data + if (a.dtype.primitive_bitwidth == 8 and a.dtype.is_floating()) or \ + (b.dtype.primitive_bitwidth == 8 and b.dtype.is_floating()): + a_data = _convert_float(a_data, a.dtype, tl.float16, None).view(np.float16) + b_data = _convert_float(b_data, b.dtype, tl.float16, None).view(np.float16) + return TensorHandle(np.matmul(a_data, b_data, dtype=d.data.dtype) + d.data, d.dtype.scalar) + + def create_make_range(self, start, stop): + return TensorHandle(np.arange(start, stop, dtype=np.int32), tl.int32) + + def create_histogram(self, data, bins): + return TensorHandle(np.histogram(data.data, bins=bins, range=(0, bins))[0], tl.int32) + + # pointer arithmetic + + def create_addptr(self, ptr, offset): + dtype_tt = ptr.get_element_ty() + element_bitwidth = dtype_tt.primitive_bitwidth + # int1's bitwidth is 1, but we need to use 8 for pointer arithmetic + element_bytewidth = max(1, element_bitwidth // 8) + return TensorHandle(ptr.data + element_bytewidth * offset.data.astype(np.uint64), ptr.dtype) + + def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, + is_volatile): + ptrs, masks = ptr.materialize_pointers(boundary_check) + dtype_tt = ptrs.get_element_ty() + dtype_np = _get_np_dtype(dtype_tt) + if padding_option is None: + other = None + elif padding_option == _ir.PADDING_OPTION.PAD_ZERO: + other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt) + elif padding_option == _ir.PADDING_OPTION.PAD_NAN: + other = TensorHandle(np.full_like(ptrs.data, float('nan'), dtype=dtype_np), dtype_tt) + else: + raise ValueError(f"unsupported padding option {padding_option}") + return self.create_masked_load(ptrs, masks, other, cache_modifier, eviction_policy, is_volatile) + + def create_tensor_pointer_store(self, ptr, value, boundary_check, cache_modifier, eviction_policy): + ptrs, masks = ptr.materialize_pointers(boundary_check) + return self.create_masked_store(ptrs, value, masks, cache_modifier, eviction_policy) + + def create_expand_dims(self, arg, axis): + return TensorHandle(np.expand_dims(arg.data, axis), arg.dtype.scalar) + + def create_broadcast(self, arg, shape): + return TensorHandle(np.broadcast_to(arg.data, shape), arg.dtype.scalar) + + def create_int_to_ptr(self, val, dst_ty): + return TensorHandle(val.data.astype(np.uint64), dst_ty.scalar) + + def create_ptr_to_int(self, val, dst_ty): + return TensorHandle(val.data.astype(np.uint64), dst_ty.scalar) + + def create_cat(self, lhs, rhs): + return TensorHandle(np.concatenate([lhs.data, rhs.data]), lhs.dtype.scalar) + + def create_join(self, lhs, rhs): + # Triton only supports joining two original tensors into a new one along the last axis + return TensorHandle(np.stack([lhs.data, rhs.data], axis=-1), lhs.dtype.scalar) + + def create_split(self, val): + # Triton only supports splitting the original tensor into two along the last axis + return (TensorHandle(val.data[..., 0], val.dtype.scalar), TensorHandle(val.data[..., 1], val.dtype.scalar)) + + def create_splat(self, arg, shape): + if isinstance(arg.dtype, tl.block_type): + return TensorHandle(np.full(shape, arg.data[0], dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar) + else: # scalar + return TensorHandle(np.full(shape, arg.data, dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar) + + def create_atomic_cas(self, ptr, cmp, val, sem, scope): + if sem not in self.ir_sem_to_interpreter_sem: + raise ValueError(f"unsupported semantic {sem}") + sem = self.ir_sem_to_interpreter_sem[sem] + return TensorHandle(_interpreter.atomic_cas(ptr.data, cmp.data, val.data, sem), cmp.dtype.scalar) + + def create_atomic_rmw(self, rmwOp, ptr, val, mask, sem, scope): + if rmwOp not in self.ir_rmw_op_to_interpreter_rmw_op: + raise ValueError(f"unsupported rmwOp {rmwOp}") + if sem not in self.ir_sem_to_interpreter_sem: + raise ValueError(f"unsupported semantic {sem}") + rmwOp = self.ir_rmw_op_to_interpreter_rmw_op[rmwOp] + sem = self.ir_sem_to_interpreter_sem[sem] + return TensorHandle(_interpreter.atomic_rmw(rmwOp, ptr.data, val.data, mask.data, sem), val.dtype.scalar) + + def create_extern_elementwise(self, libName, libPath, symbol, argList, retType, isPure): + raise NotImplementedError("extern_elementwise not supported in interpreter mode") + + def create_inline_asm(self, inlineAsm, constraints, values, type, isPure, pack): + raise NotImplementedError("inline_asm not supported in interpreter mode") + + def create_print(self, prefix, hex, values): + # Interpreter's device_print function has a different format than Triton's device_print + msg = f"({self.grid_idx[0]}, {self.grid_idx[1]}, {self.grid_idx[2]})" + if prefix: + msg += f" {prefix}" + if hex: + np.set_printoptions(formatter={'all': lambda x: f"0x{x:02x}"}) + for value in values: + print(msg + f" {value.data}") + if hex: + np.set_printoptions(formatter=None) + + def create_assert(self, condition, message, fileName, funcName, lineNo): + # Interpreter's device_assert function has a different format than Triton's device_assert + assert condition, f"{message} in {fileName}:{funcName}:{lineNo}" + + def create_barrier(self): + # Triton's barrier applies to each program in a grid, so it's a no-op in the interpreter + pass + + def create_make_block_ptr(self, base, shape, strides, offsets, tensor_shape, order): + # Create new offsets to avoid modifying the original + new_offsets = [offset.clone() for offset in offsets] + return BlockPointerHandle(base, shape, strides, new_offsets, tensor_shape, order) + + def create_advance(self, ptr, offsets): + if len(ptr.offsets) != len(offsets): + raise ValueError("len(ptr.offsets) != len(offsets)") + # Create new offsets to avoid modifying the original + new_offsets = [offset.clone() for offset in ptr.offsets] + ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, new_offsets, ptr.tensor_shape, ptr.order) + for i in range(len(offsets)): + ret.offsets[i].data += offsets[i].data + return ret + + def get_all_ones_value(self, type): + np_type = _get_np_dtype(type) + if "int" in np_type.name: + return TensorHandle(np.full(1, -1, dtype=np_type), type.scalar) + else: + raise TypeError(f"unsupported type {type}") + + +def _patch_attr(obj, name, member, builder): + new_member = lambda *args, member=member, **kwargs: (member(*args, ** + {k: v + for k, v in kwargs.items() + if k != "_builder"}, _builder=builder)) + setattr(obj, name, new_member) + + +def _patch_builtin(pkg, builder): + for name, member in inspect.getmembers(pkg): + if tl.core.is_builtin(member): + _patch_attr(pkg, name, member, builder) + + +def _patch_lang_tensor(tensor): + + def _get_bool(self): + data = self.handle.data + # in triton, only scalars can be converted to booleans + # here we need this hack because all scalars are tensors + return bool(data) if data.size == 1 else True + + def _get_transpose(self): + return tl.core.tensor(TensorHandle(np.transpose(self.handle.data), self.handle.dtype), self.dtype.scalar) + + tensor.__index__ = lambda self: int(self.handle.data) + tensor.__bool__ = lambda self: _get_bool(self) + tensor.__repr__ = lambda self: repr(self.handle.data) + tensor.__str__ = lambda self: str(self.handle.data) + tensor.T = property(_get_transpose) + + +class ReduceScanOpIneterface: + + def __init__(self, axis, combine_fn): + self.axis = axis + self.combine_fn = combine_fn + + def check_axis(self, shape, axis): + if axis is not None and axis >= len(shape): + raise ValueError(f"axis {axis} out of bounds for shape {shape}") + + def check_tensor(self, input): + for arg in input: + if not isinstance(arg, tl.core.tensor): + raise ValueError(f"input must be a tensor, got {type(arg)}") + self.check_axis(arg.shape, self.axis) + + def to_tensor(self, ret, dtype): + if hasattr(ret, "shape") and ret.shape: + ret_type = tl.block_type(dtype, ret.shape) + else: + ret = np.array([ret], dtype=_get_np_dtype(dtype)) + ret_type = dtype + return tl.core.tensor(TensorHandle(ret, dtype.scalar), ret_type) + + def apply(self, input): + if not isinstance(input, tuple): + input = (input, ) + self.check_tensor(input) + return self.apply_impl(input) + + def apply_impl(self, input): + raise NotImplementedError("apply_impl not implemented") + + +class ReduceOps(ReduceScanOpIneterface): + + def __init__(self, axis, combine_fn, keep_dims): + super().__init__(axis, combine_fn) + self.keep_dims = keep_dims + + def unravel(self, input, axis): + ret = [] + for data in input: + if axis is not None: + ret.append(data) + else: + axis = 0 + ret.append(self.to_tensor(data.handle.data.flatten(), data.dtype)) + return tuple(ret), axis + + def generic_reduce(self, input): + original_axis = self.axis + input, axis = self.unravel(input, self.axis) + input_data = [] + output_data = [] + input_shape = input[0].handle.data.shape + output_shape = input_shape[0:axis] + input_shape[axis + 1:] + for arg in input: + input_data.append(arg.handle.data) + output_data.append(np.zeros(output_shape, dtype=arg.handle.data.dtype)) + # Reduce on axis + for i in range(input_data[0].size): + # Recover input_index from i using input_shape + input_index = np.unravel_index(i, input_shape) + output_index = input_index[0:axis] + input_index[axis + 1:] + input_tuple = tuple(self.to_tensor(d[input_index], input[ii].dtype) for ii, d in enumerate(input_data)) + if input_index[axis] == 0: + # First element + for j in range(len(output_data)): + output_data[j][output_index] = input_tuple[j].handle.data.item() + else: + acc_tuple = tuple(self.to_tensor(o[output_index], input[oi].dtype) for oi, o in enumerate(output_data)) + combine_fn_ret = self.combine_fn.fn(*acc_tuple, *input_tuple) + acc_tuple = (combine_fn_ret, ) if not isinstance(combine_fn_ret, tuple) else combine_fn_ret + for j in range(len(output_data)): + output_data[j][output_index] = acc_tuple[j].handle.data.item() if isinstance( + acc_tuple[j], tl.core.tensor) else acc_tuple[j] + # Pack output + ret = [] + for i, data in enumerate(output_data): + if self.keep_dims: + if original_axis is not None: + data = np.expand_dims(data, axis) + else: + for _ in range(len(input_shape)): + data = np.expand_dims(data, 0) + + elif original_axis is None: + # Take a scalar + data = data.item() + ret.append(self.to_tensor(data, input[i].dtype)) + return ret[0] if len(ret) == 1 else tuple(ret) + + def min_max(self, input, val_reduce_op, idx_reduce_op=None): + # If input is a tuple, it must be (val, index), and we only take val + input = input[0] if isinstance(input, tuple) else input + val = None + idx = None + if val_reduce_op: + val = self.to_tensor(val_reduce_op(input.handle.data, axis=self.axis, keepdims=self.keep_dims), input.dtype) + if idx_reduce_op: + idx = self.to_tensor(idx_reduce_op(input.handle.data, axis=self.axis, keepdims=self.keep_dims), tl.int32) + if val is not None and idx is not None: + return val, idx + elif val is not None: + return val + elif idx is not None: + return idx + else: + raise ValueError("val_reduce_op and idx_reduce_op are both None") + + def sum(self, input): + return self.to_tensor(np.sum(input.handle.data, axis=self.axis, keepdims=self.keep_dims), input.dtype) + + def apply_impl(self, input): + if self.combine_fn == tl.standard._argmin_combine_tie_break_left: + return self.min_max(input[0], val_reduce_op=np.min, idx_reduce_op=np.argmin) + elif self.combine_fn == tl.standard._argmax_combine_tie_break_left: + return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=np.argmax) + elif self.combine_fn == tl.standard._elementwise_max: + return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=None) + elif self.combine_fn == tl.standard._elementwise_min: + return self.min_max(input[0], val_reduce_op=np.min, idx_reduce_op=None) + elif self.combine_fn == tl.standard._sum_combine: + return self.sum(input[0]) + else: + # Fall back to the slow mode + return self.generic_reduce(input) + + +class ScanOps(ReduceScanOpIneterface): + + def __init__(self, axis, combine_fn, reverse): + super().__init__(axis, combine_fn) + self.reverse = reverse + + def cumsum(self, input): + return [self.to_tensor(np.cumsum(input.handle.data, axis=self.axis), dtype=input.dtype)] + + def cumprod(self, input): + return [self.to_tensor(np.cumprod(input.handle.data, axis=self.axis), dtype=input.dtype)] + + def generic_scan(self, input): + input_data = [] + output_data = [] + shape = input[0].handle.data.shape + for arg in input: + input_data.append(arg.handle.data) + output_data.append(np.zeros(shape, dtype=arg.handle.data.dtype)) + # Scan on axis + for i in range(input_data[0].size): + # Recover index from i using shape + index = np.unravel_index(i, shape) + data = tuple(self.to_tensor(d[index], input[ii].dtype) for ii, d in enumerate(input_data)) + if index[self.axis] == 0: + # First element + for j in range(len(output_data)): + output_data[j][index] = data[j].handle.data.item() + else: + prev_index = tuple(index[i] - 1 if i == self.axis else index[i] for i in range(len(index))) + acc_tuple = tuple(self.to_tensor(o[prev_index], input[oi].dtype) for oi, o in enumerate(output_data)) + combine_fn_ret = self.combine_fn.fn(*acc_tuple, *data) + acc_tuple = (combine_fn_ret, ) if not isinstance(combine_fn_ret, tuple) else combine_fn_ret + for j in range(len(output_data)): + output_data[j][index] = acc_tuple[j].handle.data.item() if isinstance( + acc_tuple[j], tl.core.tensor) else acc_tuple[j] + # Pack output + ret = [] + for i, data in enumerate(output_data): + ret.append(self.to_tensor(data, input[i].dtype)) + return ret + + def apply_impl(self, input): + new_input = [] + if self.reverse: + for arg in input: + new_input.append(self.to_tensor(np.flip(arg.handle.data, axis=self.axis), arg.dtype)) + else: + new_input = input + if self.combine_fn == tl.standard._sum_combine: + ret = self.cumsum(new_input[0]) + elif self.combine_fn == tl.standard._prod_combine: + ret = self.cumprod(new_input[0]) + else: + # Fall back to the slow mode + ret = self.generic_scan(new_input) + if self.reverse: + for arg in ret: + arg.handle.data = np.flip(arg.handle.data, axis=self.axis) + return len(ret) == 1 and ret[0] or tuple(ret) + + +def _patch_reduce_scan(): + # Because interpreter doesn't support region_builder_fn, we cannot patch the builder + # to use the new reduce and scan functions. + # Instead, we need to patch reduce and reduce functions in tl and tl.core + def _new_reduce(input, axis, combine_fn, keep_dims=False, **kwargs): + return ReduceOps(axis, combine_fn, keep_dims).apply(input) + + def _new_scan(input, axis, combine_fn, reverse=False, **kwargs): + return ScanOps(axis, combine_fn, reverse).apply(input) + + tl.reduce = _new_reduce + tl.associative_scan = _new_scan + tl.core.reduce = _new_reduce + tl.core.associative_scan = _new_scan + + +def _patch_lang_core(lang): + + def _new_to_ir(self, builder): + # We need to specify signedness for integer types in the numpy mode + if self.name == 'void': + return builder.get_void_ty() + elif self.name == 'int1': + return builder.get_int1_ty() + elif self.name == 'int8': + return builder.get_int8_ty() + elif self.name == 'uint8': + return builder.get_uint8_ty() + elif self.name == 'int16': + return builder.get_int16_ty() + elif self.name == 'uint16': + return builder.get_uint16_ty() + elif self.name == 'int32': + return builder.get_int32_ty() + elif self.name == 'uint32': + return builder.get_uint32_ty() + elif self.name == 'int64': + return builder.get_int64_ty() + elif self.name == 'uint64': + return builder.get_uint64_ty() + elif self.name == 'fp8e5': + return builder.get_fp8e5_ty() + elif self.name == 'fp8e4nv': + return builder.get_fp8e4nv_ty() + elif self.name == 'fp8e4b15': + return builder.get_fp8e4b15_ty() + elif self.name == 'fp16': + return builder.get_half_ty() + elif self.name == 'bf16': + return builder.get_bf16_ty() + elif self.name == 'fp32': + return builder.get_float_ty() + elif self.name == 'fp64': + return builder.get_double_ty() + raise ValueError(f'fail to convert {self} to ir type') + + # can't just map lang.static_range to `range`, because `tl.static_range` + # can get `step` passed by keyword + def _new_range(arg1, arg2=None, step=None, **kwargs): + if step is None: + step = 1 + if arg2 is None: + start, end = 0, arg1 + else: + start, end = arg1, arg2 + return range(start, end, step) + + def _new_static_assert(cond, msg=""): + assert cond, msg + + def _set_attr(input, values, name): + # skip non tensor types. This may happen for induction variables. + if not isinstance(input, tl.tensor): + return input + # Unwrap constexpr + values = [values] if not isinstance(values, (list, tuple)) else values + values = [v.value if isinstance(v, tl.constexpr) else v for v in values] + if len(values) != max(1, len(input.shape)): + raise ValueError(f"len(values) != len(input.shape) for {name}") + input.handle.set_attr(name, values) + return input + + lang.range = _new_range + lang.static_range = _new_range + lang.static_assert = _new_static_assert + lang.static_print = print + lang.dtype.to_ir = _new_to_ir + lang.multiple_of = partial(_set_attr, name="tt.divisiblity") + lang.max_contiguous = partial(_set_attr, name="tt.contiguity") + lang.max_constancy = partial(_set_attr, name="tt.constancy") + + _patch_reduce_scan() + + +def _patch_lang(fn): + lang = [value for _, value in fn.__globals__.items() if value in [tl, tl.core]] + assert len(lang) == 1, "triton.language must be visible from within jit'd function" + _patch_builtin(lang[0], interpreter_builder) + _patch_builtin(lang[0].tensor, interpreter_builder) + if lang[0] == tl: + _patch_builtin(lang[0].math, interpreter_builder) + _patch_lang_tensor(lang[0].tensor) + _patch_lang_core(lang[0]) + + +# TODO: wrap everything in triton tensors +def _implicit_cvt(arg): + if isinstance(arg, int): + ty = tl.str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg))) + dtype = np.int32 + if -2**31 <= arg < 2**31: + dtype = np.int32 + elif 2**31 <= arg < 2**32: + dtype = np.uint32 + elif -2**63 <= arg < 2**63: + dtype = np.int64 + elif 2**63 <= arg < 2**64: + dtype = np.uint64 + else: + raise ValueError(f"Unsupported integer value {arg}") + handle = TensorHandle(np.array([arg], dtype=dtype), ty) + return tl.tensor(handle, ty) + if hasattr(arg, "data_ptr"): + ty = tl.str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg))) + handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty) + return tl.tensor(handle, ty) + return arg + + +interpreter_builder = InterpreterBuilder() + +# These keywords are not supported by the interpreter +RESERVED_KWS = ["num_warps", "num_stages", "num_ctas", "enable_fp_fusion", "grid", "maxnreg"] + + +class GridExecutor: + + def __init__(self, fn, arg_names, grid): + from .jit import _normalize_ty # TODO: modularize + + self.fn = fn + self.arg_names = arg_names + self.grid = grid + __annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()} + self.constexprs = [name for name in arg_names if __annotations__.get(name) == "constexpr"] + + def _init_args_hst(self, args_dev, kwargs): + args_hst = [] + for arg in args_dev: + if hasattr(arg, "data_ptr"): + args_hst.append(arg.cpu()) + else: + args_hst.append(arg) + # Process keyword arguments + kwargs_hst = {} + for key, value in kwargs.items(): + if hasattr(value, "data_ptr"): + kwargs_hst[key] = value.cpu() + else: + kwargs_hst[key] = value + return args_hst, kwargs_hst + + def _restore_args_dev(self, args_dev, args_hst, kwargs, kwargs_hst): + for arg_dev, arg_hst in zip(args_dev, args_hst): + if hasattr(arg_dev, "data_ptr"): + arg_dev.data.copy_(arg_hst.to(arg_dev.device).data) + + # Restore keyword arguments + for key, kwarg_dev in kwargs.items(): + kwarg_hst = kwargs_hst[key] + if hasattr(kwarg_dev, "data_ptr"): + kwarg_dev.data.copy_(kwarg_hst.to(kwarg_dev.device).data) + + def __call__(self, *args_dev, **kwargs): + # removes reserved keywords from kwargs + kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS} + if kwargs.pop("warmup", False): + return + # copy arguments to the host + args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs) + # remaps core language functions to interpreted ones + _patch_lang(self.fn) + # we need to copy arguments to the host for the interpreter + # implicitly convert tensor arguments to their base pointers + args = inspect.getcallargs(self.fn, *args_hst, **kwargs_hst) + args = {name: arg if name in self.constexprs else _implicit_cvt(arg) for name, arg in args.items()} + # iterate through grid + grid = self.grid(args) if callable(self.grid) else self.grid + assert len(grid) <= 3, "grid must have at most 3 dimensions" + grid = grid + (1, ) * (3 - len(grid)) + interpreter_builder.set_grid_dim(*grid) + try: + for x in range(grid[0]): + for y in range(grid[1]): + for z in range(grid[2]): + interpreter_builder.set_grid_idx(x, y, z) + self.fn(**args) + except Exception as e: + raise InterpreterError(repr(e)) from e + # copy arguments back to propagate side-effects + self._restore_args_dev(args_dev, args_hst, kwargs, kwargs_hst) + + +class InterpretedFunction: + + def __init__(self, fn) -> None: + self.fn = fn + + def run(*args, **kwargs): + grid = kwargs["grid"] + return GridExecutor(self.fn, self.arg_names, grid)(*args, **kwargs) + + self.run = run + signature = inspect.signature(fn) + self.arg_names = [v.name for v in signature.parameters.values()] + + @property + def __name__(self): + return self.fn.__name__ + + def __getitem__(self, grid): + return GridExecutor(self.fn, self.arg_names, grid) + + def __call__(self, *args, **kwargs): + # This is a device function call + _patch_lang(self.fn) + try: + return self.fn(*args, **kwargs) + except Exception as e: + raise InterpreterError(repr(e)) from e diff --git a/lib/python3.10/site-packages/triton/runtime/jit.py b/lib/python3.10/site-packages/triton/runtime/jit.py new file mode 100644 index 0000000000000000000000000000000000000000..a12b1d235b7cdb8eb2b9aaeb40ff7a1d60fdf570 --- /dev/null +++ b/lib/python3.10/site-packages/triton/runtime/jit.py @@ -0,0 +1,956 @@ +from __future__ import annotations, division +import ast +import hashlib +import inspect +import itertools +import os +import re +import textwrap +from collections import defaultdict +from functools import cached_property +from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overload, Dict, Any, Tuple +from ..runtime.driver import driver +from types import ModuleType + +TRITON_MODULE = __name__[:-len(".runtime.jit")] + +T = TypeVar("T") + +# ----------------------------------------------------------------------------- +# Dependencies Finder +# ----------------------------------------------------------------------------- + + +class DependenciesFinder(ast.NodeVisitor): + """ + This AST visitor is used to find dependencies of a JITFunction. This can + be used to invalidate a JITFunction's hash when its source code -- or + that of its dependencies -- changes. + + This visitor also keeps track of the global variables touched by the + JITFunction. When we launch the kernel, we check that these have the same + values as they did when we ran this visitor. If not, we raise an error (or + otherwise we could recompile). + """ + + def __init__(self, name, globals, src) -> None: + super().__init__() + self.name = name + self.hasher = hashlib.sha256(src.encode("utf-8")) + + # This function's __globals__ dict. + self.globals = globals + + # Python builtins that can be accessed from Triton kernels. + self.supported_python_builtins = { + 'float', + 'getattr', + 'int', + 'isinstance', + 'len', + 'list', + 'max', + 'min', + 'print', + 'range', + } + + # used_global_vals tells us which global variables are used by this + # function and all those it transitively calls, plus the values of those + # variables when each function was initially run. (That is, if A calls + # C, and B calls C, then the values for C in used_global_vals will be + # from the first time C was run, either by A or B.) + # + # Each function may have a different __globals__ dict, so the global + # variable `foo` may actually have a different value in the different + # functions. Thus this map is actually + # (var_name, id(__globals__)) -> (var_value, __globals__). + self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {} + + self.visiting_arg_default_value = False + + @property + def ret(self): + return self.hasher.hexdigest() + + def visit_Name(self, node): + if type(node.ctx) == ast.Store: + return node.id + + if node.id in self.local_names: + # The global name is hidden by the local name. + return None + + val = self.globals.get(node.id, None) + + # Only keep track of "interesting" global variables, that non-evil users + # might change. Don't consider functions, modules, builtins, etc. This + # helps keep the list of vars we have to check small. + if (val is not None # + # Python default arguments are resolved only once, when the + # function is defined. So if you do `foo(a=A)` and the value of + # A changes, foo will still use the old value of A. + and not self.visiting_arg_default_value + # It would be pretty evil if someone did `import x` and then + # `x = blah`. + and type(val) != ModuleType + # It would be pretty evil if we used function `foo` inside of + # `bar` and then someone did `foo = baz`. + and not isinstance(val, JITFunction) and not getattr(val, "__triton_builtin__", False) # + and node.id not in self.supported_python_builtins # + ): + self.used_global_vals[(node.id, id(self.globals))] = (val, self.globals) + + return val + + def visit_Tuple(self, node): + # We need to explicitly return the tuple values so that visit_Assign can + # access them in the case of `a, b = ...`. + return [self.visit(elt) for elt in node.elts] + + def visit_Attribute(self, node): + lhs = self.visit(node.value) + while isinstance(lhs, ast.Attribute): + lhs = self.visit(lhs.value) + if lhs is None or (getattr(lhs, "__name__", "") == TRITON_MODULE): + return None + return getattr(lhs, node.attr) + + def visit_Call(self, node): + + def is_triton_builtin(func): + if inspect.isbuiltin(node.func): + return True + module = getattr(func, "__module__", "") + return module.startswith(TRITON_MODULE) + + func = self.visit(node.func) + assert func is None or is_triton_builtin(func) or isinstance( + func, JITFunction + ), f'Function "{func.__name__}" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this' + + # Traverse arguments as well as node.func so we can find JITFunctions + # passed to tl.reduce or tl.associative_scan as the combine_fn + for obj in itertools.chain( + (func, ), + map(self.visit, node.args), + (self.visit(kw.value) for kw in node.keywords), + ): + if not isinstance(obj, JITFunction): + continue + if is_triton_builtin(obj): + continue + + func_cache_key = obj.cache_key + + # Merge our used_global_vals with those of the called function, + # after checking that all overlapping values are consistent. + for k in self.used_global_vals.keys() & obj.used_global_vals.keys(): + var_name, _ = k + v1, _ = self.used_global_vals[k] + v2, _ = obj.used_global_vals[k] + if v1 != v2: + raise RuntimeError( + f"Global variable {var_name} has value {v1} when compiling {self.name}, but inner kernel {func.__name__} has conflicting value {v2} from when it was first compiled. This is not allowed." + ) + + self.used_global_vals.update(obj.used_global_vals) + + noinline = str(getattr(obj, "noinline", False)) + + key = func_cache_key + noinline + self.hasher.update(key.encode("utf-8")) + + def visit_FunctionDef(self, node): + # Save the local name, which may hide the global name. + self.local_names = {arg.arg for arg in node.args.args} + self.generic_visit(node) + + def visit_arguments(self, node): + # The purpose of this function is to visit everything in `arguments` + # just like `generic_visit`, except when we're visiting default values + # (i.e. the `foo` part of `def fn(x = foo)`), we set + # self.visiting_arg_default_value = True. This allows visit_Name to be + # aware that we're inside function default values, which have special + # semantics. + + # According to the AST docs, the arguments node has the following structure. + # + # arguments = (arg* posonlyargs, arg* args, arg? vararg, arg* kwonlyargs, + # expr* kw_defaults, arg? kwarg, expr* defaults) + def visit_defaults(defaults): + try: + assert not self.visiting_arg_default_value + self.visiting_arg_default_value = True + for expr in defaults: + if expr is not None: + self.visit(expr) + finally: + self.visiting_arg_default_value = False + + for arg in itertools.chain(node.posonlyargs, node.args, [node.vararg] if node.vararg else [], node.kwonlyargs): + self.visit(arg) + + visit_defaults(node.kw_defaults) + + if node.kwarg is not None: + self.visit(node.kwarg) + + visit_defaults(node.defaults) + + def visitAssnTarget(self, node): + # Target is either a single string, or a list of strings (if the assn + # target is a tuple). + target = self.visit(node) + if isinstance(target, list): + self.local_names |= set(target) + else: + self.local_names.add(target) + + def visit_Assign(self, node): + if len(node.targets) != 1: + # TODO(jlebar): I don't actually know how to hit this. You don't + # get it from `a, b = ...` -- in that case, node.targets is a single + # Tuple, and in fact we *do* need to handle that case if we want + # existing code to work. + raise TypeError("Simultaneous multiple assignment is not supported.") + + self.visitAssnTarget(node.targets[0]) + + # This will re-visit the target, but that's OK. + self.generic_visit(node) + + def visit_AnnAssign(self, node): + self.visitAssnTarget(node.target) + + # This will re-visit the target, but that's OK. + self.generic_visit(node) + + def visit_For(self, node): + self.visitAssnTarget(node.target) + + # This will re-visit the target, but that's fine. + self.generic_visit(node) + + +# ----------------------------------------------------------------------------- +# JITFunction +# ----------------------------------------------------------------------------- + + +def _normalize_ty(ty) -> str: + if isinstance(ty, type): + return ty.__name__ + elif isinstance(ty, str): + return ty + return repr(ty) + + +class KernelParam: + """Represents a parameter (name plus metadata) to a @jit'ed function.""" + + def __init__(self, num: int, param: inspect.Parameter, do_not_specialize: bool): + self.num = num + self._param = param + self.do_not_specialize = do_not_specialize + + @cached_property + def name(self): + return self._param.name + + @cached_property + def annotation(self): + if not self._param.annotation or self._param.annotation == inspect.Parameter.empty: + return "" + return _normalize_ty(self._param.annotation) + + @cached_property + def annotation_type(self): + annotation = self.annotation + for ty1, ty2 in [("uint", 'u'), ("int", 'i')]: + width = annotation[annotation.find(ty1) + len(ty1):] + if width and ty1 in annotation: + return f"{ty2}{width}" + if annotation == "bool": + return "u1" + return "" + + @cached_property + def is_constexpr(self): + return "constexpr" in self.annotation + + @cached_property + def is_const(self): + return "const" in self.annotation and not self.is_constexpr + + @property + def default(self): + return self._param.default + + @property + def has_default(self): + return self._param.default != inspect.Parameter.empty + + +def compute_spec_key(v): + + if hasattr(v, "data_ptr") and (v.data_ptr() % 16 == 0): + return "D" + elif isinstance(v, int): + # bool is a subclass of int, so we don't check explicitly above. + if (v % 16 == 0): + return "D" + elif v == 1: + return "1" + return "N" + + +dtype2str = {} + + +def mangle_type(arg, is_const=False): + + if arg is None: + return "none" + elif isinstance(arg, bool): + return "i1" + elif isinstance(arg, int): + if -(2**31) <= arg and arg <= 2**31 - 1: + return "i32" + elif 2**63 <= arg and arg <= 2**64 - 1: + return "u64" + else: + return "i64" + elif isinstance(arg, float): + return "fp32" + else: + # dtypes are hashable so we can memoize this mapping: + dsk = (arg.dtype, is_const) + res = dtype2str.get(dsk, None) + if res is None: + res = ("*k" if dsk[1] else "*") + type_canonicalisation_dict[str(dsk[0]).split('.')[-1]] + dtype2str[dsk] = res + return res + + +class KernelInterface(Generic[T]): + run: T + + def __getitem__(self, grid) -> T: + """ + A JIT function is launched with: fn[grid](*args, **kwargs). + Hence JITFunction.__getitem__ returns a callable proxy that + memorizes the grid. + """ + return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) + # return cast(T, functools.partial(cast(Callable, self.run), grid=grid)) + + +def serialize_specialization_data(name, signature, constants, attrs, options, key): + constants = {key: str(value) if value.__class__.__name__ == "dtype" else value for key, value in constants.items()} + import json + obj = { + 'name': name, 'signature': signature, 'constants': constants, 'attrs': attrs.to_dict(), 'options': + options.__dict__, 'key': key + } + serialized_obj = json.dumps(obj) + return serialized_obj + + +def create_function_from_signature(sig, kparams): + """ + Equivalent to sig.bind followed by apply_defaults. This generates a + native Python function (using exec) which can be memoized on a per-kernel + basis to avoid having to run these expensive functions -- which constitute + much of the kernel launch overhead -- every time we run the kernel. + """ + + assert len(sig.parameters) == len(kparams) + + # Create the function argument list and the dict entries for the return statement + func_args = [] + dict_entries = [] + constexpr_vals = [] + non_constexpr_vals = [] + signature_types = [] + specialisations = [] + + for ((name, sp), kp) in zip(sig.parameters.items(), kparams): + if sp.default is inspect.Parameter.empty: + func_args.append(name) + dict_entries.append(f"'{name}': {name}") + else: + func_args.append(f"{name}=default_{name}") + dict_entries.append(f"'{name}': {name}") + if kp.is_constexpr: + constexpr_vals.append(name) + else: + non_constexpr_vals.append(name) + if not kp.do_not_specialize: + specialisations.append('compute_spec_key(%s)' % name) + if kp.annotation_type: + signature_types.append('"%s"' % kp.annotation_type) + else: + signature_types.append('mangle_type(%s, %s)' % (name, 'True' if kp.is_const else 'False')) + + cache_key = ''.join([x + ', ' for x in signature_types + specialisations]) + constexpr_vals = ''.join([x + ', ' for x in constexpr_vals]) + non_constexpr_vals = ''.join([x + ', ' for x in non_constexpr_vals]) + + func_args.append('**excess_kwargs') + + # Join all arguments into a function definition string + args_str = ', '.join(func_args) + dict_str = ', '.join(dict_entries) + func_body = "def dynamic_func(%s):\n return {%s}, (%s), (%s), (%s), excess_kwargs" % ( + args_str, dict_str, cache_key, constexpr_vals, non_constexpr_vals) + + # Prepare defaults to be inserted into function namespace + func_namespace = { + f"default_{name}": param.default + for name, param in sig.parameters.items() + if param.default is not inspect.Parameter.empty + } + + func_namespace['mangle_type'] = mangle_type + func_namespace['compute_spec_key'] = compute_spec_key + + # Execute the function string in func_namespace to create the function + exec(func_body, func_namespace) + + # Extract the newly created function from the namespace + return func_namespace['dynamic_func'] + + +type_canonicalisation_dict = { + "bool": "i1", + "float8e4nv": "fp8e4nv", + "float8e5": "fp8e5", + "float8e4b15": "fp8e4b15", + "float8_e4m3fn": "fp8e4nv", + "float8e4b8": "fp8e4b8", + "float8_e4m3fnuz": "fp8e4b8", + "float8_e5m2": "fp8e5", + "float8e5b16": "fp8e5b16", + "float8_e5m2fnuz": "fp8e5b16", + "float16": "fp16", + "bfloat16": "bf16", + "float32": "fp32", + "float64": "fp64", + "int8": "i8", + "int16": "i16", + "int32": "i32", + "int64": "i64", + "uint8": "u8", + "uint16": "u16", + "uint32": "u32", + "uint64": "u64", +} + +for v in list(type_canonicalisation_dict.values()): + type_canonicalisation_dict[v] = v + + +class JITFunction(KernelInterface[T]): + # Hook for inspecting compiled functions and modules + cache_hook = None + divisibility = 16 + + @staticmethod + def _key_of(arg): + if hasattr(arg, "dtype"): + return arg.dtype + elif isinstance(arg, bool): + return "i1" + elif isinstance(arg, int): + if -(2**31) <= arg and arg <= 2**31 - 1: + return "i32" + elif 2**63 <= arg and arg <= 2**64 - 1: + return "u64" + else: + return "i64" + elif isinstance(arg, float): + return "fp32" + elif arg is None: + return None + else: + raise TypeError(f"Unsupported type {type(arg)} for {arg}") + + @staticmethod + def _spec_of(arg): + if hasattr(arg, "data_ptr"): + return arg.data_ptr() % JITFunction.divisibility == 0 + elif isinstance(arg, int): + return (arg % 16 == 0, arg == 1) + return (arg is None, ) + + def _get_config(self, *args): + from ..compiler import AttrsDescriptor + + def is_divisible_by_16(x): + if hasattr(x, "data_ptr"): + return x.data_ptr() % JITFunction.divisibility == 0 + elif isinstance(x, int): + return x % JITFunction.divisibility == 0 + if x is None: + return True + return False + + divisible_by_16 = { + param.num + for param, arg in zip(self.params, args) + if is_divisible_by_16(arg) and not param.do_not_specialize + } + equal_to_1 = { + param.num + for param, arg in zip(self.params, args) + if isinstance(arg, int) and not isinstance(arg, bool) and arg == 1 and not param.do_not_specialize + } + # folded equal_to_1 and None + # TODO: method to collect all folded args + return AttrsDescriptor(tuple(divisible_by_16), tuple(equal_to_1)) + # return _triton.code_gen.instance_descriptor(divisible_by_16, + # equal_to_1) + + @staticmethod + def _type_of(key, is_const=False): + # `None` is nullptr. Implicitly convert to *i8. + if key is None: + return "*i8" + elif isinstance(key, str): + return key + + dtype_str = str(key).split(".")[-1] + dtype_str = type_canonicalisation_dict[dtype_str] + const_str = "*k" if is_const else "*" + return const_str + dtype_str + + def _make_constants(self, constexpr_key): + constants = dict(zip(self.constexprs, constexpr_key)) + return constants + + def _call_hook( + self, + key, + signature, + device, + constants, + options, + configs, + ): + if JITFunction.cache_hook is None: + return False + + name = self.fn.__name__ + module = self.fn.__module__ + arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])]) + repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}]({arg_reprs})" + + class JitFunctionInfo: + + def __init__(self, module, name, jit_function): + self.module = module + self.name = name + self.jit_function = jit_function + pass + + specialization_data = serialize_specialization_data(name, signature, constants, configs[0], options, key) + + kwargs = { + 'signature': signature, + 'device': device, + 'constants': constants, + 'num_warps': options.num_warps, + 'num_ctas': options.num_ctas, + 'num_stages': options.num_stages, + 'enable_fp_fusion': options.enable_fp_fusion, + 'extern_libs': options.extern_libs, + 'configs': configs, + 'specialization_data': specialization_data, + } + + return JITFunction.cache_hook( + key=key, + repr=repr, + fn=JitFunctionInfo(module, name, self), + compile={"key": key, **kwargs}, + is_manual_warmup=False, + already_compiled=False, + ) + + def add_pre_run_hook(self, hook): + ''' + Add a hook that will be executed prior to the execution of run + function with args and kwargs passed into the kernel + ''' + assert callable(hook) + self.pre_run_hooks.append(hook) + + def create_binder(self): + """ + Precompute as much as possible. + """ + from ..compiler import CompiledKernel, compile, ASTSource, make_backend + self.CompiledKernel = CompiledKernel + self.compile = compile + self.ASTSource = ASTSource + self.make_backend = make_backend + self.binder = create_function_from_signature(self.signature, self.params) + self.constexpr_indices = [i for (i, p) in enumerate(self.params) if p.is_constexpr] + self.non_constexpr_indices = [i for (i, p) in enumerate(self.params) if not p.is_constexpr] + self.specialised_indices = [ + i for (i, p) in enumerate(self.params) if (not p.do_not_specialize) and (not p.is_constexpr) + ] + + def run(self, *args, grid, warmup, **kwargs): + # parse options + device = driver.active.get_current_device() + stream = driver.active.get_current_stream(device) + kwargs["debug"] = self.debug + + # Execute pre run hooks with args and kwargs + for hook in self.pre_run_hooks: + hook(*args, **kwargs) + + if self.binder is None: + self.create_binder() + + bound_args, sig_and_spec, constexpr_vals, non_constexpr_vals, excess_kwargs = self.binder(*args, **kwargs) + + # compute cache key + key = ''.join(sig_and_spec) + str((constexpr_vals, excess_kwargs)) + kernel = self.cache[device].get(key, None) + + if kernel is None: + # Kernel is not cached; we have to compile. + target = driver.active.get_current_target() + backend = self.make_backend(target) + options = backend.parse_options(kwargs) + + # deprecated arguments + assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used" + assert "device" not in kwargs, "device option is deprecated; current device will be used" + assert "stream" not in kwargs, "stream option is deprecated; current stream will be used" + for k in excess_kwargs: + if k not in options.__dict__: + raise KeyError("Keyword argument %s was specified but unrecognised" % k) + + bound_vals = tuple(bound_args.values()) + + # `None` is nullptr. Implicitly convert to *i8. This needs to be + # done here rather than when we build the signature as otherwise + # the kernel cache key could not distinguish between byte pointers + # and None arguments, resulting in a downstream mismatch: + sigkeys = [self.params[i].name for i in self.non_constexpr_indices] + sigvals = sig_and_spec[:len(sigkeys)] + signature = {k: ('*i8' if (v == 'none') else v) for (k, v) in zip(sigkeys, sigvals)} + + configs = (self._get_config(*bound_vals), ) + constants = { + p.name: v + for (v, p) in zip(bound_vals, self.params) + if p.is_constexpr or p.num in configs[0].equal_to_1 or v is None + } + for i, arg in constants.items(): + if callable(arg): + raise TypeError(f"Callable constexpr at index {i} is not supported") + + if self._call_hook(key, signature, device, constants, options, configs): + return None + # compile the kernel + src = self.ASTSource(self, signature, constants, configs[0]) + kernel = self.compile( + src, + target=target, + options=options.__dict__, + ) + self.cache[device][key] = kernel + + # Check that used global values have not changed. + not_present = object() + for (name, globals_dict_id), (val, globals_dict) in self.used_global_vals.items(): + if (newVal := globals_dict.get(name, not_present)) != val: + raise RuntimeError( + f"Global variable {name} has changed since we compiled this kernel, from {val} to {newVal}") + + if not warmup: + # canonicalize grid + assert grid is not None + if callable(grid): + # Arguments are passed as a dict to `grid`, by contract. + # TODO(jlebar): In the new launch API, pass the compiler flags as a + # second parameter to `grid`. + grid = grid(bound_args) + grid_size = len(grid) + grid_0 = grid[0] + grid_1 = grid[1] if grid_size > 1 else 1 + grid_2 = grid[2] if grid_size > 2 else 1 + + # launch kernel + launch_metadata = kernel.launch_metadata(grid, stream, *non_constexpr_vals) + kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata, + self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook, *non_constexpr_vals) + return kernel + + def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinline=None, repr=None, + launch_metadata=None): + do_not_specialize = do_not_specialize if do_not_specialize else [] + + self.fn = fn + self.module = fn.__module__ + self.version = version + self.signature = inspect.signature(fn) + self.do_not_specialize = do_not_specialize + self.starting_line_number = inspect.getsourcelines(fn)[1] + self.repr = lambda _: fn.__name__ if repr is None else repr(_) + self.launch_metadata = launch_metadata + + self.binder = None + + self.params = [] + for i, param in enumerate(self.signature.parameters.values()): + dns = do_not_specialize and (i in do_not_specialize or param.name in do_not_specialize) + self.params.append(KernelParam(i, param, dns)) + + # function source code (without decorators) + self.src = textwrap.dedent(inspect.getsource(fn)) + self.src = self.src[re.search(r"^def\s+\w+\s*\(", self.src, re.MULTILINE).start():] + # cache of just-in-time compiled kernels + self.cache = defaultdict(dict) + self.hash = None + + # Map of global variables used by the function and any functions it + # transitively calls, plus their values. The values are collected when + # the function is first compiled. Then every time we run the function, + # we check that the values of the globals match what's expected, + # otherwise we raise an error. + # + # Different functions can have different __globals__ maps, so the map + # key is actually (var name, id(__globals__)), and the map value is + # (value, __globals__). + self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {} + + # JITFunction can be instantiated as kernel + # when called with a grid using __getitem__ + self.kernel = None + self.debug = True if os.environ.get("TRITON_DEBUG", "0") == "1" else debug + self.noinline = noinline + + # TODO(jlebar): Remove uses of these fields outside this file, then + # remove the fields here. + self.arg_names = [p.name for p in self.params] + self.constexprs = [p.num for p in self.params if p.is_constexpr] + + # Hooks that will be called prior to executing "run" + self.pre_run_hooks = [] + + # reuse docs of wrapped function + self.__doc__ = fn.__doc__ + self.__name__ = fn.__name__ + self.__globals__ = fn.__globals__ + self.__module__ = fn.__module__ + + @property + def cache_key(self): + # TODO : hash should be attribute of `self` + if self.hash is None: + dependencies_finder = DependenciesFinder(name=self.__name__, globals=self.__globals__, src=self.src) + dependencies_finder.visit(self.parse()) + self.hash = dependencies_finder.ret + str(self.starting_line_number) + self.used_global_vals = dict(sorted(dependencies_finder.used_global_vals.items())) + return self.hash + + def warmup(self, *args, grid, **kwargs): + return self.run(grid=grid, warmup=True, *map(MockTensor.wrap_dtype, args), **kwargs) + + def preload(self, specialization_data): + from ..compiler import AttrsDescriptor, compile, ASTSource + import json + import triton.language as tl + device = driver.active.get_current_device() + deserialized_obj = json.loads(specialization_data) + if deserialized_obj['name'] != self.fn.__name__: + raise RuntimeError( + f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self.fn.__name__}") + constants = { + key: tl.dtype(value) if tl.dtype.is_dtype(value) else value + for key, value in deserialized_obj['constants'].items() + } + signature = dict(deserialized_obj['signature'].items()) + src = ASTSource(self, signature, constants, AttrsDescriptor.from_dict(deserialized_obj['attrs'])) + options = { + key: tuple(value) if isinstance(value, list) else value + for key, value in deserialized_obj['options'].items() + } + key = deserialized_obj['key'] + kernel = compile(src, None, options) + self.cache[device][key] = kernel + return kernel + + # we do not parse `src` in the constructor because + # the user might want to monkey-patch self.src dynamically. + # Our unit tests do this, for example. + def parse(self): + tree = ast.parse(self.src) + assert isinstance(tree, ast.Module) + assert len(tree.body) == 1 + assert isinstance(tree.body[0], ast.FunctionDef) + return tree + + def __call__(self, *args, **kwargs): + raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel") + + def __setattr__(self, name, value): + super(JITFunction, self).__setattr__(name, value) + # - when `.src` attribute is set, cache path needs + # to be reinitialized + if name == "src": + self.hash = None + + def __repr__(self): + return f"JITFunction({self.module}:{self.fn.__name__})" + + +# ----------------------------------------------------------------------------- +# `jit` decorator +# ----------------------------------------------------------------------------- + + +@overload +def jit(fn: T) -> JITFunction[T]: + ... + + +@overload +def jit( + *, + version=None, + repr: Optional[Callable] = None, + launch_metadata: Optional[Callable] = None, + do_not_specialize: Optional[Iterable[int]] = None, + debug: Optional[bool] = None, + noinline: Optional[bool] = None, +) -> Callable[[T], JITFunction[T]]: + ... + + +def jit( + fn: Optional[T] = None, + *, + version=None, + repr: Optional[Callable] = None, + launch_metadata: Optional[Callable] = None, + do_not_specialize: Optional[Iterable[int]] = None, + debug: Optional[bool] = None, + noinline: Optional[bool] = None, +) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]: + """ + Decorator for JIT-compiling a function using the Triton compiler. + + :note: When a jit'd function is called, arguments are + implicitly converted to pointers if they have a :code:`.data_ptr()` method + and a `.dtype` attribute. + + :note: This function will be compiled and run on the GPU. It will only have access to: + + * python primitives, + * builtins within the triton package, + * arguments to this function, + * other jit'd functions + + :param fn: the function to be jit-compiled + :type fn: Callable + """ + + def decorator(fn: T) -> JITFunction[T]: + assert callable(fn) + if os.getenv("TRITON_INTERPRET", "0") == "1": + from .interpreter import InterpretedFunction + return InterpretedFunction(fn) + else: + return JITFunction( + fn, + version=version, + do_not_specialize=do_not_specialize, + debug=debug, + noinline=noinline, + repr=repr, + launch_metadata=launch_metadata, + ) + + if fn is not None: + return decorator(fn) + + else: + return decorator + + +# ----------------------------------------------------------------------------- +# Utilities for mocking tensors +# ----------------------------------------------------------------------------- + + +class MockTensor: + """ + Can be used in place of real tensors when calling: + kernel.warmup(MockTensor(torch.float32), ...) + """ + + @staticmethod + def wrap_dtype(arg): + if arg.__class__.__name__ == "dtype" and arg.__module__ == "torch": + return MockTensor(arg) + return arg + + def __init__(self, dtype): + self.dtype = dtype + + @staticmethod + def data_ptr(): + return 0 # optimistically assumes multiple of 16 + + +class TensorWrapper: + + def __init__(self, base, dtype): + self.dtype = dtype + self.base = base + self.data = base.data + self.device = base.device + self.shape = self.base.shape + + def data_ptr(self): + return self.base.data_ptr() + + def stride(self, i): + return self.base.stride(i) + + def __str__(self) -> str: + return f"TensorWrapper[{self.dtype}]({self.base})" + + def element_size(self): + return self.base.element_size() + + def cpu(self): + return TensorWrapper(self.base.cpu(), self.dtype) + + def copy_(self, other): + self.base.copy_(other.base) + + def to(self, device): + return TensorWrapper(self.base.to(device), self.dtype) + + +def reinterpret(tensor, dtype): + if isinstance(tensor, TensorWrapper): + if dtype == tensor.base.dtype: + # Reinterpreting to the original interpretation; return the base. + return tensor.base + else: + # Reinterpreting a wrapped tensor to a different type. + return TensorWrapper(tensor.base, dtype) + elif hasattr(tensor, "data_ptr"): + # A new wrapper is needed around an unwrapped tensor. + return TensorWrapper(tensor, dtype) + else: + raise TypeError(f"Cannot reinterpret a {type(tensor)}.") diff --git a/lib/python3.10/site-packages/triton/tools/__init__.py b/lib/python3.10/site-packages/triton/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/triton/tools/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/triton/tools/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..537a548c7e4e46418381829a2c91b93b3592e24f Binary files /dev/null and b/lib/python3.10/site-packages/triton/tools/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/tools/__pycache__/build_extern.cpython-310.pyc b/lib/python3.10/site-packages/triton/tools/__pycache__/build_extern.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd3b11fde5f3a0396e985cd9c899d2bb2a3b3ee0 Binary files /dev/null and b/lib/python3.10/site-packages/triton/tools/__pycache__/build_extern.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/tools/__pycache__/compile.cpython-310.pyc b/lib/python3.10/site-packages/triton/tools/__pycache__/compile.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b249d39a5fd3a86f97233986690d4de9630ec2d Binary files /dev/null and b/lib/python3.10/site-packages/triton/tools/__pycache__/compile.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/tools/__pycache__/disasm.cpython-310.pyc b/lib/python3.10/site-packages/triton/tools/__pycache__/disasm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc7f0a5c0d04ef48b038977a1b82046d563d7afc Binary files /dev/null and b/lib/python3.10/site-packages/triton/tools/__pycache__/disasm.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/tools/__pycache__/link.cpython-310.pyc b/lib/python3.10/site-packages/triton/tools/__pycache__/link.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..adb2af53f1aeece2361ae4b03ce4886271535020 Binary files /dev/null and b/lib/python3.10/site-packages/triton/tools/__pycache__/link.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/triton/tools/build_extern.py b/lib/python3.10/site-packages/triton/tools/build_extern.py new file mode 100644 index 0000000000000000000000000000000000000000..8f0168d59d7af045bde68a508a000654f4893bb1 --- /dev/null +++ b/lib/python3.10/site-packages/triton/tools/build_extern.py @@ -0,0 +1,365 @@ +import argparse +import subprocess +from abc import ABC, abstractmethod +from typing import Dict, List, Optional + + +class Symbol: + _name: str + _op_name: str + _ret_type: str + _arg_names: List[str] + _arg_types: List[str] + + def __init__( + self, + name: str, + op_name: str, + ret_type: str, + arg_names: List[str], + arg_types: List[str], + ) -> None: + ''' + A symbol is a function declaration. + :param name: name of the symbol + :param op_name: name of the operation + :param ret_type: return type of the operation + :param arg_names: names of the arguments + :param arg_types: types of the arguments + ''' + self._name = name + self._op_name = op_name + self._ret_type = ret_type + self._arg_names = list(arg_names) + self._arg_types = list(arg_types) + + @property + def name(self) -> str: + return self._name + + @property + def op_name(self) -> str: + return self._op_name + + @property + def ret_type(self) -> str: + return self._ret_type + + @property + def arg_names(self) -> List[str]: + return self._arg_names + + @property + def arg_types(self) -> List[str]: + return self._arg_types + + +def convert_type(type_str) -> Optional[str]: + if type_str == "i32": + return "int32" + elif type_str == "u32": + return "uint32" + elif type_str == "i64": + return "int64" + elif type_str == "u64": + return "uint64" + elif type_str == "float": + return "fp32" + elif type_str == "double": + return "fp64" + else: + # ignore other types, such as pointer types + return None + + +def to_unsigned(type_str) -> str: + if type_str == "int32": + return "uint32" + elif type_str == "int64": + return "uint64" + else: + return type_str + + +class ExternLibrary(ABC): + _name: str + _path: str + _symbols: Dict[str, Symbol] + _format: bool + _grouping: bool + + def __init__( + self, + name: str, + path: str, + format: bool = True, + grouping: bool = True, + ) -> None: + ''' + Abstract class for extern library. + :param name: name of the library + :param path: path of the library + :param format: whether to format the generated stub file + ''' + self._name = name + self._path = path + self._symbols = {} + self._format = format + self._grouping = grouping + + @property + def name(self) -> str: + return self._name + + @property + def path(self) -> str: + return self._path + + @property + def symbols(self) -> Dict[str, Symbol]: + return self._symbols + + @property + def grouping(self) -> bool: + return self._grouping + + @abstractmethod + def parse_symbols(self, input_file) -> None: + pass + + @abstractmethod + def _output_stubs(self) -> str: + pass + + def generate_stub_file(self, output_dir) -> None: + file_str = self._output_stubs() + if file_str is None or len(file_str) == 0: + raise Exception("file_str is empty") + + output_file = f"{output_dir}/{self._name}.py" + with open(output_file, "w") as f: + f.write(file_str) + f.close() + if self._format: + subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file], stdout=subprocess.PIPE).communicate() + subprocess.Popen(["isort", output_file], stdout=subprocess.PIPE).communicate() + + +class Libdevice(ExternLibrary): + _symbol_groups: Dict[str, List[Symbol]] + + def __init__(self, path) -> None: + ''' + Constructor for Libdevice. + :param path: path of the libdevice library + ''' + super().__init__("libdevice", path) + self._symbol_groups = {} + self.is_pure = True + + @staticmethod + def _extract_symbol(line) -> Optional[Symbol]: + # Extract symbols from line in the following format: + # "define [internal] @(,)" + entries = line.split("@") + ret_str = entries[0] + func_str = entries[1] + # Get ret_type, skip internal symbols + ret_strs = ret_str.split() + if ret_strs[1] == "internal": + return None + ret_type = convert_type(ret_strs[1]) + if ret_type is None: + return None + # Get function name + func_strs = func_str.split("(") + func_name = func_strs[0].replace("@", "") + op_name = func_name.replace("__nv_", "") + if 'ieee' in op_name: + return None + # Get arg_types + arg_strs = func_strs[1].split(",") + arg_types = [] + arg_names = [] + for i, arg_str in enumerate(arg_strs): + arg_type = convert_type(arg_str.split()[0]) + if arg_type is None: + return None + arg_name = 'arg' + str(i) + arg_types.append(arg_type) + arg_names.append(arg_name) + if op_name == "sad": + # Special case for sad, where the last argument is an unsigned int + arg_types[-1] = to_unsigned(arg_types[-1]) + elif op_name.startswith("u"): + # LLVM does not differentiate between signed and unsigned integer type. + # We have to convert the types to unsigned + ret_type = to_unsigned(ret_type) + for i, arg_type in enumerate(arg_types): + arg_types[i] = to_unsigned(arg_type) + return Symbol(func_name, op_name, ret_type, arg_names, arg_types) + + def _group_symbols(self) -> None: + symbol_set = {} + for symbol in self._symbols.values(): + op_name = symbol.op_name + symbol_set[op_name] = symbol + + # Group functions together by renaming. + renaming = { + 'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh', 'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn': + 'add_rn', 'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru', 'dadd_rz': 'add_rz', 'fadd_rz': + 'add_rz', 'asinf': 'asin', 'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2', 'atanhf': 'atanh', + 'brevll': 'brev', 'cbrtf': 'cbrt', 'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign', 'cosf': 'cos', + 'coshf': 'cosh', 'cospif': 'cospi', 'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1', + 'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn', 'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru', + 'ddiv_ru': 'div_ru', 'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf', 'erfcf': 'erfc', 'erfcinvf': + 'erfcinv', 'erfcxf': 'erfcx', 'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10', 'exp2f': 'exp2', + 'expm1f': 'expm1', 'fabsf': 'abs', 'fabs': 'abs', 'fast_fdividef': 'fast_dividef', 'fdimf': 'fdim', 'ffsll': + 'ffs', 'floorf': 'floor', 'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn', 'fmaf_ru': 'fma_ru', + 'fmaf_rz': 'fma_rz', 'fmodf': 'fmod', 'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb', 'isinff': + 'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan', 'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn', + 'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint', 'llroundf': 'llround', 'logf': 'log', 'log10f': + 'log10', 'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb', 'umax': 'max', 'llmax': 'max', 'ullmax': + 'max', 'fmaxf': 'max', 'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min', 'fminf': 'min', + 'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd', 'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn', + 'dmul_ru': 'mul_ru', 'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz', 'umul24': 'mul24', + 'umulhi': 'mulhi', 'mul64hi': 'mulhi', 'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf': + 'nextafter', 'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf', 'normcdfinvf': 'normcdfinv', + 'popcll': 'popc', 'powif': 'pow', 'powi': 'pow', 'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd', + 'drcp_rd': 'rcp_rd', 'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru', 'drcp_ru': 'rcp_ru', + 'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz', 'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot', + 'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d', 'roundf': 'round', 'rsqrtf': 'rsqrt', + 'frsqrt_rn': 'rsqrt_rn', 'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit', 'signbitd': 'signbit', + 'sinf': 'sin', 'sinhf': 'sinh', 'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd', 'dsqrt_rd': + 'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn', 'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru', + 'fsqrt_rz': 'sqrt_rz', 'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd', 'fsub_rn': 'sub_rn', + 'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru', 'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz', + 'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc', 'y0f': 'y0', 'y1f': 'y1', 'ynf': + 'yn' + } + + for symbol in self._symbols.values(): + op_name = symbol.op_name + if op_name in renaming: + op_name = renaming[op_name] + symbol._op_name = op_name + if op_name in self._symbol_groups: + self._symbol_groups[op_name].append(symbol) + else: + self._symbol_groups[op_name] = [symbol] + + def parse_symbols(self, input_file) -> None: + if len(self.symbols) > 0: + return + output = subprocess.check_output(["grep", "define", input_file]).decode().splitlines() + for line in output: + symbol = self._extract_symbol(line) + if symbol is None: + continue + self._symbols[symbol.name] = symbol + + self._group_symbols() + + def _output_stubs(self) -> str: + # Generate python functions in the following format: + # @extern.extern + # def (, _builder=None): + # arg_type_symbol_dict = {[arg_type]: {(symbol, ret_type)}} + # return core.extern_elementwise("libdevice", , , , _builder) + import_str = "from . import core\n" + + header_str = "" + func_str = "" + for symbols in self._symbol_groups.values(): + func_str += "@core.extern\n" + func_name_str = f"def {symbols[0].op_name}(" + for arg_name in symbols[0].arg_names: + func_name_str += f"{arg_name}, " + func_name_str += "_builder=None):\n" + + return_str = f"\treturn core.extern_elementwise(\"{self._name}\", libdevice_path(), [" + for arg_name in symbols[0].arg_names: + return_str += f"{arg_name}, " + return_str += "], \n" + + arg_type_symbol_dict_str = "{" + for symbol in symbols: + arg_type_symbol_dict_str += "(" + for arg_type in symbol.arg_types: + arg_type_symbol_dict_str += f'core.dtype("{arg_type}"),' + ret_type = f'core.dtype("{symbol.ret_type}")' + arg_type_symbol_dict_str += "): (\"" + symbol.name + "\", " + ret_type + "),\n" + arg_type_symbol_dict_str += "}" + + return_str += arg_type_symbol_dict_str + return_str += f", is_pure={self.is_pure}" + return_str += ", _builder=_builder)\n" + + func_str += func_name_str + return_str + "\n" + file_str = import_str + header_str + func_str + + return file_str + + +class LLVMDisassembler: + _path: str + _ll_file: str + + def __init__(self, path) -> None: + ''' + Invoke llvm-dis to disassemble the given file. + :param path: path to llvm-dis + ''' + self._path = path + self._ll_file = "/tmp/extern_lib.ll" + + def disasm(self, lib_path: str) -> None: + subprocess.Popen([self._path, lib_path, "-o", self.ll_file], stdout=subprocess.PIPE).communicate() + + @property + def ll_file(self) -> str: + return self._ll_file + + @property + def path(self) -> str: + return self._path + + +extern_libs = ["libdevice"] + + +def build( + llvm_dis_path: str, + lib_path: str, + lib_name: str, + output_dir: str, +) -> None: + ''' + Interface function to build the library file. + :param llvm_dis_path: path to the llvm-dis binary + :param lib_path: path to the external library file + :param lib_name: name of the library + :param output_dir: path to the output directory + ''' + if lib_name == "libdevice": + extern_lib = Libdevice(lib_path) + else: + raise Exception(f"Unknown extern library: {lib_name}") + + llvm_disassembler = LLVMDisassembler(llvm_dis_path) + llvm_disassembler.disasm(lib_path) + + extern_lib.parse_symbols(llvm_disassembler.ll_file) + extern_lib.generate_stub_file(output_dir) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--llvm-dis", dest="llvm_dis_path", help="Path to llvm-dis", default="llvm-dis") + parser.add_argument("--lib-path", dest="lib_path", help="Path to the extern library") + parser.add_argument("--lib-name", dest="lib_name", help="Name of the extern library") + parser.add_argument("--output", dest="output_dir", help="Output file path", default="/tmp/") + args = parser.parse_args() + + build(args.llvm_dis_path, args.lib_path, args.lib_name, args.output_dir) diff --git a/lib/python3.10/site-packages/triton/tools/compile.c b/lib/python3.10/site-packages/triton/tools/compile.c new file mode 100644 index 0000000000000000000000000000000000000000..971bf61912a752465c83a0b17ae131bee6f5cc74 --- /dev/null +++ b/lib/python3.10/site-packages/triton/tools/compile.c @@ -0,0 +1,67 @@ +/* clang-format off */ +#include +#include +#include +#include +#include + + +// helpers to check for cuda errors +#define CUDA_CHECK(ans) {{\ + gpuAssert((ans), __FILE__, __LINE__);\ + }}\ + +static inline void gpuAssert(CUresult code, const char *file, int line) {{ + if (code != CUDA_SUCCESS) {{ + const char *prefix = "Triton Error [CUDA]: "; + const char *str; + cuGetErrorString(code, &str); + char err[1024] = {{0}}; + strcat(err, prefix); + strcat(err, str); + printf("%s\\n", err); + exit(code); + }} +}} + +// globals +#define CUBIN_NAME {kernel_name}_cubin +CUmodule {kernel_name}_mod = NULL; +CUfunction {kernel_name}_func = NULL; +unsigned char CUBIN_NAME[{bin_size}] = {{ {bin_data} }}; + + +void unload_{kernel_name}(void) {{ + CUDA_CHECK(cuModuleUnload({kernel_name}_mod)); +}} + +// TODO: some code duplication with `runtime/backend/cuda.c` +void load_{kernel_name}() {{ + int dev = 0; + void *bin = (void *)&CUBIN_NAME; + int shared = {shared}; + CUDA_CHECK(cuModuleLoadData(&{kernel_name}_mod, bin)); + CUDA_CHECK(cuModuleGetFunction(&{kernel_name}_func, {kernel_name}_mod, "{triton_kernel_name}")); + // set dynamic shared memory if necessary + int shared_optin; + CUDA_CHECK(cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev)); + if (shared > 49152 && shared_optin > 49152) {{ + CUDA_CHECK(cuFuncSetCacheConfig({kernel_name}_func, CU_FUNC_CACHE_PREFER_SHARED)); + CUDA_CHECK(cuFuncSetAttribute({kernel_name}_func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin)) + }} +}} + +/* +{kernel_docstring} +*/ +CUresult {kernel_name}(CUstream stream, {signature}) {{ + if ({kernel_name}_func == NULL) + load_{kernel_name}(); + unsigned int gX = {gridX}; + unsigned int gY = {gridY}; + unsigned int gZ = {gridZ}; + void *args[{num_args}] = {{ {arg_pointers} }}; + // TODO: shared memory + if(gX * gY * gZ > 0) + return cuLaunchKernel({kernel_name}_func, gX, gY, gZ, {num_warps} * 32, 1, 1, {shared}, stream, args, NULL); +}} diff --git a/lib/python3.10/site-packages/triton/tools/compile.h b/lib/python3.10/site-packages/triton/tools/compile.h new file mode 100644 index 0000000000000000000000000000000000000000..d98b7063b6ae6292b65b61abf5a30c58b7d28e95 --- /dev/null +++ b/lib/python3.10/site-packages/triton/tools/compile.h @@ -0,0 +1,14 @@ +#ifndef TT_KERNEL_INCLUDES +#define TT_KERNEL_INCLUDES + +#include +#include +#include +#include + +#endif + +void unload_{kernel_name}(void); +void load_{kernel_name}(void); +// tt-linker: {kernel_name}:{full_signature}:{algo_info} +CUresult{_placeholder} {kernel_name}(CUstream stream, {signature}); diff --git a/lib/python3.10/site-packages/triton/tools/compile.py b/lib/python3.10/site-packages/triton/tools/compile.py new file mode 100644 index 0000000000000000000000000000000000000000..872332b03d99a963d6139d42155aa6e8acd45fa4 --- /dev/null +++ b/lib/python3.10/site-packages/triton/tools/compile.py @@ -0,0 +1,145 @@ +import binascii +import hashlib +import importlib.util +import sys +from argparse import ArgumentParser +from pathlib import Path +from typing import List + +import triton +from triton.compiler.code_generator import kernel_suffix +from triton.backends.nvidia.driver import ty_to_cpp + +desc = """ +Triton ahead-of-time compiler: + +This program compiles the kernel with name `kernel-name` in the file at the +provided `path` into self-contained C source-code that embeds the `cubin` +data along with utilities to load, unload and launch the kernel. + +signature is provided as a list of (optionally divisibility-hinted) types +or constexpr values, e.g. + +`compile.py --kernel-name kernel --signature "*fp32:16, i32:16, 1024, i32" --out-name kernel /path/to/kernel.py` + +will compile triton.JITFunction of name `kernel` inside the file `/path/to/kernel.py`. +Said kernel will be specialized such that argument 0, 1 are assumed to be multiple of 16, +and argument 2 is assumed to be a compile-time constant of value 1024, i.e. it won't be part of the generated prototype. + +The resulting entry point will have signature + +CUresult kernel_{specialization_suffix}(CUstream stream, unsigned gX, unsigned gY, unsigned gZ, float* arg0, int32_t arg1, int32_t arg2) + +Different such specialized entry points can be combined using the `linker.py` script. + +NOTE: when resolving the scope of /path/to/kernel.py, the file will be executed from within its parent directory with the python interpreter +used to run this `compile.py` script +""" + +if __name__ == "__main__": + + # command-line arguments + parser = ArgumentParser(description=desc) + parser.add_argument("path", + help="Path to Python source containing desired kernel in its scope. File will be executed.") + parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile", + required=True) + parser.add_argument("--num-warps", "-w", type=int, default=1, help="Number of warps to launch the kernel") + parser.add_argument("--num-stages", "-ns", type=int, default=3, + help="Number of stages (meta-parameter of the kernel)") + parser.add_argument("--out-name", "-on", type=str, default=None, help="Out name for the compiled kernel") + parser.add_argument("--out-path", "-o", type=Path, default=None, help="Out filename") + parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True) + parser.add_argument("--grid", "-g", type=str, help="Launch grid of the kernel", required=True) + args = parser.parse_args() + + out_name = args.out_name if args.out_name else args.kernel_name + out_path = args.out_path if args.out_path else Path(out_name) + + # execute python sources and extract functions wrapped in JITFunction + arg_path = Path(args.path) + sys.path.insert(0, str(arg_path.parent)) + spec = importlib.util.spec_from_file_location(arg_path.stem, arg_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + kernel = getattr(mod, args.kernel_name) + grid = args.grid.split(",") + assert len(grid) == 3 + + # validate and parse signature + signature = list(map(lambda s: s.strip(" "), args.signature.split(","))) + + def hash_signature(signature: List[str]): + m = hashlib.sha256() + m.update(" ".join(signature).encode()) + return m.hexdigest()[:8] + + meta_sig = f"warps{args.num_warps}xstages{args.num_stages}" + sig_hash = hash_signature(signature + [meta_sig]) + + def constexpr(s): + try: + ret = int(s) + return ret + except ValueError: + pass + try: + ret = float(s) + return ret + except ValueError: + pass + return None + + hints = {i: constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s} + hints = {k: v for k, v in hints.items() if v is not None} + constants = {i: constexpr(s) for i, s in enumerate(signature)} + constants = {k: v for k, v in constants.items() if v is not None} + signature = {i: s.split(":")[0] for i, s in enumerate(signature) if i not in constants} + const_sig = 'x'.join([str(v) for v in constants.values()]) + doc_string = [f"{kernel.arg_names[i]}={constants[i]}" for i in constants.keys()] + doc_string += [f"num_warps={args.num_warps}", f"num_stages={args.num_stages}"] + + # compile ast into cubin + for h in hints.values(): + assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}" + divisible_by_16 = [i for i, h in hints.items() if h == 16] + equal_to_1 = [i for i, h in hints.items() if h == 1] + attrs = triton.compiler.AttrsDescriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1) + for i in equal_to_1: + constants.update({i: 1}) + src = triton.compiler.ASTSource(fn=kernel, constants=constants, signature=signature, attrs=attrs) + opts = {"num_warps": args.num_warps, "num_stages": args.num_stages} + ccinfo = triton.compile(src, options=opts) + arg_names = [] + arg_types = [] + for i in signature.keys(): + if i not in equal_to_1: + arg_names += [kernel.arg_names[i]] + arg_types += [signature[i]] + + # dump C stub code + suffix = kernel_suffix(signature.values(), attrs) + func_name = '_'.join([out_name, sig_hash, suffix]) + hex_ = str(binascii.hexlify(ccinfo.asm["cubin"]))[2:-1] + params = { + "kernel_name": func_name, + "triton_kernel_name": args.kernel_name, + "bin_size": len(hex_), + "bin_data": ", ".join([f"0x{x}{y}" for x, y in zip(hex_[::2], hex_[1::2])]), + "signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]), + "full_signature": ", ".join([f"{ty_to_cpp(signature[i])} {kernel.arg_names[i]}" for i in signature.keys()]), + "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names]), + "num_args": len(arg_names), + "kernel_docstring": doc_string, + "shared": ccinfo.metadata.shared, + "num_warps": args.num_warps, + "algo_info": '_'.join([const_sig, meta_sig]), + "gridX": grid[0], + "gridY": grid[1], + "gridZ": grid[2], + "_placeholder": "", + } + for ext in ['h', 'c']: + template_path = Path(__file__).parent / f"compile.{ext}" + with out_path.with_suffix(f".{sig_hash}_{suffix}.{ext}").open("w") as fp: + fp.write(Path(template_path).read_text().format(**params)) diff --git a/lib/python3.10/site-packages/triton/tools/disasm.py b/lib/python3.10/site-packages/triton/tools/disasm.py new file mode 100644 index 0000000000000000000000000000000000000000..1e309a2e4940ed56a432870f19b99d47aeea79f2 --- /dev/null +++ b/lib/python3.10/site-packages/triton/tools/disasm.py @@ -0,0 +1,142 @@ +# MIT License + +# Copyright (c) 2020 Da Yan @ HKUST + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import functools +import os +import re +import subprocess +import tempfile + +from ..common.backend import path_to_cuobjdump, path_to_nvdisasm + +FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*') +SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*') +FNAME_RE = re.compile(r'\s*Function : (\w+)\s*') +BRA_RE = re.compile(r'(.*BRA(?:\.U)? )(0x\w+);') + + +def parseCtrl(sline): + enc = int(SLINE_RE.match(sline).group(1), 16) + stall = (enc >> 41) & 0xf + yld = (enc >> 45) & 0x1 + wrtdb = (enc >> 46) & 0x7 + readb = (enc >> 49) & 0x7 + watdb = (enc >> 52) & 0x3f + + yld_str = 'Y' if yld == 0 else '-' + wrtdb_str = '-' if wrtdb == 7 else str(wrtdb) + readb_str = '-' if readb == 7 else str(readb) + watdb_str = '--' if watdb == 0 else f'{watdb:02d}' + return f'{watdb_str}:{readb_str}:{wrtdb_str}:{yld_str}:{stall:x}' + + +def processSassLines(fline, sline, labels): + asm = FLINE_RE.match(fline).group(1) + # Remove tailing space + if asm.endswith(" ;"): + asm = asm[:-2] + ";" + ctrl = parseCtrl(sline) + # BRA target address + if BRA_RE.match(asm) is not None: + target = int(BRA_RE.match(asm).group(2), 16) + if target in labels: + pass + else: + labels[target] = len(labels) + return (f'{ctrl}', f'{asm}') + + +@functools.lru_cache() +def get_sass(cubin_asm, fun=None): + fd, path = tempfile.mkstemp() + try: + with open(fd, 'wb') as cubin: + cubin.write(cubin_asm) + sass = extract(path, fun) + finally: + os.remove(path) + return sass + + +def extract(file_path, fun): + cuobjdump, _ = path_to_cuobjdump() + nvdisasm, _ = path_to_nvdisasm() + os.environ["NVDISASM_PATH"] = nvdisasm + if fun is None: + sass_str = subprocess.check_output([cuobjdump, "-sass", file_path]) + else: + sass_str = subprocess.check_output([cuobjdump, "-fun", fun, "-sass", file_path]) + sass_lines = sass_str.splitlines() + line_idx = 0 + while line_idx < len(sass_lines): + line = sass_lines[line_idx].decode() + # format: + # function : + # .headerflags: ... + # /*0000*/ asmstr /*0x...*/ + # /*0x...*/ + + # Looking for new function header (function: ) + while FNAME_RE.match(line) is None: + line_idx += 1 + if line_idx < len(sass_lines): + line = sass_lines[line_idx].decode() + else: + return + + fname = FNAME_RE.match(line).group(1) + ret = '' + ret += f'Function:{fname}\n' + line_idx += 2 # bypass .headerflags + line = sass_lines[line_idx].decode() + # Remapping address to label + labels = {} # address -> label_idx + # store sass asm in buffer and them print them (for labels) + # (ctrl, asm) + asm_buffer = [] + while FLINE_RE.match(line) is not None: + # First line (Offset ASM Encoding) + fline = sass_lines[line_idx].decode() + line_idx += 1 + # Second line (Encoding) + sline = sass_lines[line_idx].decode() + line_idx += 1 + asm_buffer.append(processSassLines(fline, sline, labels)) + # peek the next line + line = sass_lines[line_idx].decode() + # Print sass + # label naming convention: LBB#i + for idx, (ctrl, asm) in enumerate(asm_buffer): + # Print label if this is BRA target + offset = idx * 16 + if offset in labels: + label_name = f'LBB{labels[offset]}' + ret += f'{label_name}:\n' + ret += ctrl + '\t' + # if this is BRA, remap offset to label + if BRA_RE.match(asm): + target = int(BRA_RE.match(asm).group(2), 16) + target_name = f'LBB{labels[target]}' + asm = BRA_RE.sub(rf'\1{target_name};', asm) + ret += asm + '\n' + ret += '\n' + return ret diff --git a/lib/python3.10/site-packages/triton/tools/link.py b/lib/python3.10/site-packages/triton/tools/link.py new file mode 100644 index 0000000000000000000000000000000000000000..75a1157a52f92bbd5d2eae640af97ea360da2ef3 --- /dev/null +++ b/lib/python3.10/site-packages/triton/tools/link.py @@ -0,0 +1,322 @@ +from collections import defaultdict +from pathlib import Path +from typing import Sequence, Union + +from dataclasses import dataclass + + +def _exists(x): + return x is not None + + +class LinkerError(Exception): + pass + + +@dataclass +class KernelLinkerMeta: + orig_kernel_name: str + arg_names: Sequence[str] + arg_ctypes: Sequence[str] + sizes: Sequence[Union[int, None]] + sig_hash: str + triton_suffix: str + suffix: str + num_specs: int + """ number of specialized arguments """ + + +class HeaderParser: + + def __init__(self) -> None: + import re + + # [kernel_name, c signature] + self.linker_directives = re.compile("//[\\s]*tt-linker:[\\s]*([\\w]+):(.+):(.+)") + # [name, hash, suffix] + self.kernel_name = re.compile("^([\\w]+)_([\\w]+)_([\\w]+)$") + # [(type, name)] + self.c_sig = re.compile("[\\s]*(\\w+)\\s(\\w+)[,]?") + # [d|c] + self.arg_suffix = re.compile("[c,d]") + + self.kernels = defaultdict(list) + + def extract_linker_meta(self, header: str): + for ln in header.splitlines(): + if ln.startswith("//"): + m = self.linker_directives.match(ln) + if _exists(m): + ker_name, c_sig, algo_info = m.group(1), m.group(2), m.group(3) + name, sig_hash, suffix = self._match_name(ker_name) + c_types, arg_names = self._match_c_sig(c_sig) + num_specs, sizes = self._match_suffix(suffix, c_sig) + self._add_kernel( + "_".join([name, algo_info]), + KernelLinkerMeta( + orig_kernel_name=name, + arg_names=arg_names, + arg_ctypes=c_types, + sizes=sizes, + sig_hash=sig_hash, + triton_suffix=suffix, + suffix=suffix, + num_specs=num_specs, + ), + ) + + def _match_name(self, ker_name: str): + m = self.kernel_name.match(ker_name) + if _exists(m): + name, sig_hash, suffix = m.group(1), m.group(2), m.group(3) + return name, sig_hash, suffix + raise LinkerError(f"{ker_name} is not a valid kernel name") + + def _match_c_sig(self, c_sig: str): + m = self.c_sig.findall(c_sig) + if len(m): + tys, args = [], [] + for ty, arg_name in m: + tys.append(ty) + args.append(arg_name) + return tys, args + + raise LinkerError(f"{c_sig} is not a valid argument signature") + + def _match_suffix(self, suffix: str, c_sig: str): + args = c_sig.split(",") + s2i = {"c": 1, "d": 16} + num_specs = 0 + sizes = [] + # scan through suffix, first find the index, + # then see if it is followed by d or c + for i in range(len(args)): + pos = suffix.find(str(i)) + if pos == -1: + raise LinkerError(f"{suffix} is not a valid kernel suffix") + pos += len(str(i)) + if self.arg_suffix.match(suffix, pos): + num_specs += 1 + sizes.extend([None] * (i - len(sizes))) + sizes.append(s2i[suffix[pos]]) + pos += 1 + if i < len(args) - 1: + suffix = suffix[pos:] + else: + sizes.extend([None] * (len(args) - len(sizes))) + return num_specs, sizes + + def _add_kernel(self, name: str, ker: KernelLinkerMeta): + if name in self.kernels: + last: KernelLinkerMeta = self.kernels[name][-1] + + for cur, new_ in zip(last.arg_ctypes, ker.arg_ctypes): + if cur != new_: + raise LinkerError( + f"Mismatched signature for kernel {name}: \n\texisting sig is: {','.join(last.arg_ctypes)}\n\tcurrent is: {','.join(ker.arg_ctypes)}" + ) + + self.kernels[name].append(ker) + + +def gen_signature_with_full_args(m): + return ", ".join([f"{ty} {arg}" for ty, arg in zip(m.arg_ctypes, m.arg_names)]) + + +def gen_signature(m): + arg_types = [ty for ty, hint in zip(m.arg_ctypes, m.sizes) if hint != 1] + arg_names = [arg for arg, hint in zip(m.arg_names, m.sizes) if hint != 1] + sig = ", ".join([f"{ty} {arg}" for ty, arg in zip(arg_types, arg_names)]) + return sig + + +# generate declarations of kernels with meta-parameter and constant values +def make_algo_decls(name: str, metas: Sequence[KernelLinkerMeta]) -> str: + return f""" +CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}); +void load_{name}(); +void unload_{name}(); + """ + + +# generate declarations of kernels with meta-parameter and constant values +def make_global_decl(meta: KernelLinkerMeta) -> str: + return f""" +CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)}); +CUresult {meta.orig_kernel_name}(CUstream stream, {gen_signature_with_full_args(meta)}, int algo_id); +void load_{meta.orig_kernel_name}(); +void unload_{meta.orig_kernel_name}(); + """ + + +# generate dispatcher function for kernels with different meta-parameter and constant values +def make_default_algo_kernel(meta: KernelLinkerMeta) -> str: + src = f"CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)}){{\n" + src += (f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n") + src += "}\n" + return src + + +# generate dispatcher function for kernels with different integer value hints +def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -> str: + src = f"// launcher for: {name}\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + src += f"CUresult {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(CUstream stream, {gen_signature(meta)});\n" + src += "\n" + + src += (f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{") + src += "\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + cond_fn = ( # + lambda val, hint: f"({val} % {hint} == 0)" # + if hint == 16 # + else f"({val} == {hint})" # + if hint == 1 # + else None) + conds = " && ".join([ # + cond_fn(val, hint) # + for val, hint in zip(meta.arg_names, meta.sizes) # + if hint is not None + ]) + src += (f" if ({conds})\n" if any(meta.sizes) else "if (1)\n" + ) # Edge case where no specializations hence no dispatching required + arg_names = [arg for arg, hint in zip(meta.arg_names, meta.sizes) if hint != 1] + src += f" return {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(stream, {', '.join(arg_names)});\n" + src += "\n" + src += " return CUDA_ERROR_INVALID_VALUE;\n" + src += "}\n" + + for mode in ["load", "unload"]: + src += f"\n// {mode} for: {name}\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + src += f"void {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n" + src += f"void {mode}_{name}() {{" + src += "\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + src += (f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n") + src += "}\n" + return src + + +# generate dispatcher function for kernels with different meta-parameter and constant values +def make_kernel_meta_const_dispatcher(meta: KernelLinkerMeta) -> str: + src = f"CUresult {meta.orig_kernel_name}(CUstream stream, {gen_signature_with_full_args(meta)}, int algo_id){{\n" + src += f" assert (algo_id < (int)sizeof({meta.orig_kernel_name}_kernels));\n" + src += f" return {meta.orig_kernel_name}_kernels[algo_id](stream, {', '.join(meta.arg_names)});\n" + src += "}\n" + return src + + +# generate definition of function pointers of kernel dispatchers based on meta-parameter and constant values +def make_func_pointers(names: str, meta: KernelLinkerMeta) -> str: + # the table of hint dispatchers + src = f"typedef CUresult (*kernel_func_t)(CUstream stream, {gen_signature_with_full_args(meta)});\n" + src += f"kernel_func_t {meta.orig_kernel_name}_kernels[] = {{\n" + for name in names: + src += f" {name},\n" + src += "};\n" + return src + + +# generate definition for load/unload functions for kernels with different meta-parameter and constant values +def make_kernel_load_def(names: str, meta: KernelLinkerMeta) -> str: + src = "" + for mode in ["load", "unload"]: + src += f"void {mode}_{meta.orig_kernel_name}(void){{\n" + for name in names: + src += f" {mode}_{name}();\n" + src += "}\n\n" + return src + + +def make_get_num_algos_decl(meta: KernelLinkerMeta) -> str: + src = f"int {meta.orig_kernel_name}_get_num_algos(void);" + return src + + +def make_get_num_algos_def(meta: KernelLinkerMeta) -> str: + src = f"int {meta.orig_kernel_name}_get_num_algos(void){{\n" + src += f" return (int)(sizeof({meta.orig_kernel_name}_kernels) / sizeof({meta.orig_kernel_name}_kernels[0]));\n" + src += "}\n" + return src + + +desc = """ +Triton ahead-of-time linker: + +This program takes in header files generated by compile.py, and generates a +single entry-point responsible for dispatching the user's input to the right +kernel given the specializations that were compiled. + +Example usage: +python link.py /path/to/headers/*.h -o kernel_name +""" + +if __name__ == "__main__": + from argparse import ArgumentParser + + parser = ArgumentParser(description=desc) + parser.add_argument( + "headers", + nargs="+", + help="Paths to header files to link. Must include linker directive annotations (autogenerated by ttc)", + ) + parser.add_argument("--out", "-o", type=Path, help="Out filename") + parser.add_argument( + "--prefix", + type=str, + default="", + help="String to prefix kernel dispatcher names", + ) + args = parser.parse_args() + + # metadata + parser = HeaderParser() + includes = [] + for header in args.headers: + h_path = Path(header) + h_str = h_path.read_text() + includes.append(h_path.name) + parser.extract_linker_meta(h_str) + + # generate headers + algo_decls = [make_algo_decls(name, meta) for name, meta in parser.kernels.items()] + meta_lists = [meta for name, meta in parser.kernels.items()] + meta = meta_lists[0][0] + get_num_algos_decl = make_get_num_algos_decl(meta) + global_decl = make_global_decl(meta) + with args.out.with_suffix(".h").open("w") as fp: + out = "#include \n" + out += "\n".join(algo_decls) + out += "\n" + out += get_num_algos_decl + out += "\n" + out += global_decl + fp.write(out) + + # generate source + defs = [make_kernel_hints_dispatcher(name, meta) for name, meta in parser.kernels.items()] + names = [name for name in parser.kernels.keys()] + func_pointers_def = make_func_pointers(names, meta) + meta_const_def = make_kernel_meta_const_dispatcher(meta) + load_unload_def = make_kernel_load_def(names, meta) + get_num_algos_def = make_get_num_algos_def(meta) + default_algo_kernel = make_default_algo_kernel(meta) + with args.out.with_suffix(".c").open("w") as fp: + out = "" + out += "#include \n" + out += "#include \n" + out += "#include \n" + out += "\n" + out += "\n".join(defs) + out += "\n" + out += func_pointers_def + out += "\n" + out += get_num_algos_def + out += "\n" + out += meta_const_def + out += "\n" + out += load_unload_def + out += "\n" + out += default_algo_kernel + fp.write(out)